"""Ray Serve deployment for GLiNER with dynamic batching and memory-aware batch sizing."""
import os
import logging
from typing import Any, Dict, List, Tuple, Union, Optional
import torch
from .config import GLiNERServeConfig
from .memory import GLiNERMemoryEstimator
logger = logging.getLogger(__name__)
class GLiNERServer:
"""GLiNER Ray Serve deployment with dynamic batching.
Supports both entity extraction (NER) and relation extraction.
Automatically detects model type and adjusts behavior accordingly.
Uses low-level batch methods (prepare_batch, collate_batch, run_batch,
decode_batch) to avoid DataLoader initialization overhead on each call.
Features:
- Dynamic batching with Ray Serve's @serve.batch
- Memory-aware batch size estimation to prevent CUDA OOM
- Precompilation for power-of-two batch sizes
- Support for both NER and relation extraction models
- FlashDeBERTa support for faster inference
- Sequence packing for improved throughput
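
    Example (minimal sketch; assumes this class is exported from
    ``gliner.serve`` alongside ``GLiNERServeConfig``):
        >>> from gliner.serve import GLiNERServeConfig, GLiNERServer
        >>> server = GLiNERServer(GLiNERServeConfig(model="urchade/gliner_small-v2.1"))
        >>> results = server.predict("John works at Google", ["person", "organization"])
        >>> results[0]["entities"]  # one result dict per input text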
"""
def __init__(self, config: GLiNERServeConfig):
"""Initialize the GLiNER server deployment.
Args:
config: Server configuration with model and serving parameters.
"""
from gliner import GLiNER, InferencePackingConfig # noqa: PLC0415
self.config = config
env_vars = config.to_env_vars()
for key, value in env_vars.items():
os.environ[key] = value
if config.tokenizer_threads > 0:
os.environ["TOKENIZERS_PARALLELISM"] = "true"
torch.set_num_threads(config.tokenizer_threads)
torch.set_float32_matmul_precision("high")
dtype_map = {
"float32": torch.float32,
"float16": torch.float16,
"fp16": torch.float16,
"bfloat16": torch.bfloat16,
"bf16": torch.bfloat16,
}
self.torch_dtype = dtype_map.get(config.dtype.lower(), torch.bfloat16)
self.memory_estimator = GLiNERMemoryEstimator(
safety_factor=config.memory_overhead_factor,
target_memory_fraction=config.target_memory_fraction,
calibration_probe_batch_size=config.calibration_probe_batch_size,
)
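        # The bare CUDA-context footprint is measured before the model loads;
        # measure_model_memory() (called after loading, below) then captures
        # the weights' usage on top of that baseline.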
if torch.cuda.is_available():
self.memory_estimator.measure_cuda_context()
logger.info("Loading model: %s", config.model)
if config.enable_flashdeberta:
logger.info("FlashDeBERTa enabled")
self.model = GLiNER.from_pretrained(
config.model,
max_length=config.max_model_len,
max_width=config.max_span_width,
map_location=config.device,
dtype=self.torch_dtype,
)
self.model.eval()
if config.quantization:
logger.info("Applying quantization: %s", config.quantization)
self.model.quantize(config.quantization)
if torch.cuda.is_available():
self.memory_estimator.measure_model_memory()
self._supports_relations = self._detect_relation_support()
logger.info("Relation extraction support: %s", self._supports_relations)
self.collator = self.model.create_collator()
if config.enable_sequence_packing:
self.packing_config = InferencePackingConfig(
max_length=config.max_model_len,
)
logger.info("Sequence packing enabled")
else:
self.packing_config = None
if config.enable_compilation:
self._precompile()
if torch.cuda.is_available():
self._calibrate_memory()
def _detect_relation_support(self) -> bool:
"""Detect if the model supports relation extraction."""
model_type = getattr(self.model.config, "model_type", "")
return "relex" in model_type.lower()
def _precompile(self) -> None:
"""Precompile model for configured batch sizes."""
logger.info("Precompiling model for batch sizes: %s", self.config.precompiled_batch_sizes)
self.model.compile()
dummy_labels = ["person", "organization", "location"]
dummy_relations = ["works_at", "located_in"] if self._supports_relations else None
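        # Warm-up loop: push every configured (typically power-of-two) batch
        # size through the compiled model so each shape is captured before
        # real traffic arrives.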
for batch_size in self.config.precompiled_batch_sizes:
dummy_texts = [f"Sample text number {i} for precompilation warmup." for i in range(batch_size)]
for _ in range(self.config.warmup_iterations):
if self._supports_relations and dummy_relations:
self._run_batch_internal(
dummy_texts,
dummy_labels,
relations=dummy_relations,
threshold=0.5,
relation_threshold=0.5,
flat_ner=True,
multi_label=False,
)
else:
self._run_batch_internal(
dummy_texts,
dummy_labels,
threshold=0.5,
flat_ner=True,
multi_label=False,
)
logger.info(" Batch size %d: compiled", batch_size)
if torch.cuda.is_available():
torch.cuda.synchronize()
logger.info("Precompilation complete.")
def _calibrate_memory(self) -> None:
"""Build the memory calibration table across power-of-two seq lengths."""
logger.info("Calibrating memory table...")
self.memory_estimator.calibrate(
self._run_batch_internal,
max_seq_len=self.config.max_model_len,
min_seq_len=self.config.calibration_min_seq_len,
)
logger.info("Memory calibration complete.")
def batch_size_fn(self, seq_len: Optional[int] = None) -> int:
"""Largest precompiled batch size that fits at ``seq_len``.
With no arguments, returns the worst-case answer (``max_model_len``),
suitable for the deployment's initial ``max_batch_size``. Called again
from ``_infer_batch`` with the observed seq length (text + label +
relation words) to re-size Ray's batcher for the next accumulation.
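
        Example (values are illustrative; real answers depend on the
        calibration table and ``precompiled_batch_sizes``):
            >>> server.batch_size_fn()             # worst case at max_model_len
            8
            >>> server.batch_size_fn(seq_len=128)  # shorter inputs fit more
            64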
"""
if not torch.cuda.is_available():
return self.config.precompiled_batch_sizes[-1]
if seq_len is None:
seq_len = self.config.max_model_len
return self.memory_estimator.batch_size_fn(
seq_len=seq_len,
precompiled_sizes=self.config.precompiled_batch_sizes,
)
def observed_seq_len(
self,
texts: List[str],
labels: Optional[List[str]] = None,
relations: Optional[List[str]] = None,
) -> int:
"""Total input word count: longest text + all label/relation words.
Labels and relations are concatenated into the input by the model, so
they extend the effective sequence length for every sample in the
batch.
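
        Example (worked count; the total is then clamped to
        ``[calibration_min_seq_len, max_model_len]``):
            >>> server.observed_seq_len(
            ...     ["John works at Google"],           # longest text: 4 words
            ...     labels=["person", "organization"],  # +2 label words
            ... )                                       # 4 + 2 = 6 before clamping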
"""
max_text_words = max((len(t.split()) for t in texts if t.strip()), default=0)
prompt_words = 0
if labels:
prompt_words += sum(len(label.split()) for label in labels)
if relations:
prompt_words += sum(len(r.split()) for r in relations)
total = max_text_words + prompt_words
return min(max(total, self.config.calibration_min_seq_len), self.config.max_model_len)
def _filter_labels(self, labels: List[str]) -> List[str]:
"""Filter labels based on max_labels config."""
if self.config.max_labels > 0 and len(labels) > self.config.max_labels:
logger.warning("Truncating labels from %d to %d", len(labels), self.config.max_labels)
return labels[: self.config.max_labels]
return labels
@torch.inference_mode()
def _run_batch_internal(
self,
texts: List[str],
labels: List[str],
relations: Optional[List[str]] = None,
threshold: float = 0.5,
relation_threshold: float = 0.5,
flat_ner: bool = True,
multi_label: bool = False,
) -> Union[List[List[Dict[str, Any]]], Tuple[List[List[Dict[str, Any]]], List[List[Dict[str, Any]]]]]:
"""Run batch inference using low-level methods (no DataLoader).
This is the core inference method that avoids DataLoader initialization
overhead by directly using prepare_batch, collate_batch, run_batch,
decode_batch, and map_entities_to_text.
Args:
texts: List of input texts.
labels: Entity type labels.
relations: Relation type labels (for relex models).
threshold: Entity confidence threshold.
relation_threshold: Relation confidence threshold.
flat_ner: Whether to use flat NER.
multi_label: Whether to allow multiple labels per span.
Returns:
For NER models: List of entity lists.
For relex models: Tuple of (entities, relations) lists.
"""
if self._supports_relations:
return self._run_batch_relex(texts, labels, relations, threshold, relation_threshold, flat_ner, multi_label)
else:
return self._run_batch_ner(texts, labels, threshold, flat_ner, multi_label)
def _run_batch_ner(
self,
texts: List[str],
labels: List[str],
threshold: float,
flat_ner: bool,
multi_label: bool,
) -> List[List[Dict[str, Any]]]:
"""Run NER batch inference using low-level methods."""
prepared = self.model.prepare_batch(texts, labels)
if not prepared["valid_texts"]:
return [[] for _ in range(prepared["num_original"])]
batch = self.model.collate_batch(
prepared["input_x"],
prepared["entity_types"],
self.collator,
)
model_output = self.model.run_batch(
batch,
threshold=threshold,
packing_config=self.packing_config,
move_to_device=True,
)
decoded = self.model.decode_batch(
model_output,
batch,
threshold=threshold,
flat_ner=flat_ner,
multi_label=multi_label,
)
entity_results = self.model.map_entities_to_text(
decoded,
prepared["valid_texts"],
prepared["valid_to_orig_idx"],
prepared["start_token_map"],
prepared["end_token_map"],
prepared["num_original"],
)
return entity_results
def _run_batch_relex(
self,
texts: List[str],
labels: List[str],
relations: Optional[List[str]],
threshold: float,
relation_threshold: float,
flat_ner: bool,
multi_label: bool,
) -> Tuple[List[List[Dict[str, Any]]], List[List[Dict[str, Any]]]]:
"""Run relation extraction batch inference using low-level methods."""
prepared = self.model.prepare_batch(texts, labels, relations=relations)
if not prepared["valid_texts"]:
num_orig = prepared["num_original"]
return [[] for _ in range(num_orig)], [[] for _ in range(num_orig)]
batch = self.model.collate_batch(
prepared["input_x"],
prepared["entity_types"],
self.collator,
relation_types=prepared.get("relation_types", []),
)
model_output = self.model.run_batch(
batch,
threshold=threshold,
packing_config=self.packing_config,
move_to_device=True,
)
decoded_entities, decoded_relations = self.model.decode_batch(
model_output,
batch,
threshold=threshold,
relation_threshold=relation_threshold,
flat_ner=flat_ner,
multi_label=multi_label,
)
entity_results = self.model.map_entities_to_text(
decoded_entities,
prepared["valid_texts"],
prepared["valid_to_orig_idx"],
prepared["start_token_map"],
prepared["end_token_map"],
prepared["num_original"],
)
relation_results = self.model.map_relations_to_text(
decoded_relations,
decoded_entities,
prepared["valid_texts"],
prepared["valid_to_orig_idx"],
prepared["start_token_map"],
prepared["end_token_map"],
prepared["num_original"],
)
return entity_results, relation_results
def predict(
self,
texts: Union[str, List[str]],
labels: List[str],
relations: Optional[List[str]] = None,
threshold: Optional[float] = None,
relation_threshold: Optional[float] = None,
flat_ner: bool = True,
multi_label: bool = False,
) -> List[Dict[str, Any]]:
"""Predict entities and optionally relations.
Args:
texts: Input text(s) to process.
labels: Entity type labels to extract.
relations: Relation type labels (only for relex models).
threshold: Confidence threshold for entities.
relation_threshold: Confidence threshold for relations.
flat_ner: Whether to use flat NER (no overlapping entities).
multi_label: Whether to allow multiple labels per span.
Returns:
List of result dicts, one per input text. Each dict contains:
- "entities": List of entity dicts with start, end, text, label, score
- "relations": List of relation dicts (only if model supports relations)
"""
if isinstance(texts, str):
texts = [texts]
if threshold is None:
threshold = self.config.default_threshold
if relation_threshold is None:
relation_threshold = self.config.default_relation_threshold
labels = self._filter_labels(labels)
if self._supports_relations and relations:
entities, rels = self._run_batch_internal(
texts,
labels,
relations=relations,
threshold=threshold,
relation_threshold=relation_threshold,
flat_ner=flat_ner,
multi_label=multi_label,
)
results = [{"entities": ents, "relations": r} for ents, r in zip(entities, rels)]
else:
entities = self._run_batch_internal(
texts,
labels,
threshold=threshold,
flat_ner=flat_ner,
multi_label=multi_label,
)
results = [{"entities": ents} for ents in entities]
return results
def _build_deployment(config: GLiNERServeConfig):
"""Build Ray Serve deployment from config."""
from ray import serve # noqa: PLC0415
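    # Ray Serve takes the batch wait timeout in seconds; the config stores ms.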
batch_wait_s = max(config.batch_wait_timeout_ms, 0.0) / 1000.0
initial_max_batch_size = config.max_batch_size
@serve.deployment(
num_replicas=config.num_replicas,
ray_actor_options={
"num_gpus": config.num_gpus_per_replica,
"num_cpus": config.num_cpus_per_replica,
},
max_ongoing_requests=config.max_ongoing_requests,
)
class GLiNERDeployment:
def __init__(self, serve_config: GLiNERServeConfig):
self.server = GLiNERServer(serve_config)
# Seed Ray's batcher with the pessimistic worst-case size so the
# first batch is safe. ``_infer_batch`` re-calls ``batch_size_fn``
# on every dispatch to re-size the batcher based on observed
# sequence lengths.
            initial_batch_size = self.server.batch_size_fn()
            self._infer_batch.set_max_batch_size(initial_batch_size)
            logger.info(
                "Ray Serve batch size initialized to %d (precompiled: %s)",
                initial_batch_size,
                serve_config.precompiled_batch_sizes,
            )
@serve.batch(
max_batch_size=initial_max_batch_size,
batch_wait_timeout_s=batch_wait_s,
)
async def _infer_batch(
self,
texts: List[str],
labels_list: List[List[str]],
relations_list: List[Optional[List[str]]],
thresholds: List[float],
relation_thresholds: List[float],
flat_ner_list: List[bool],
multi_label_list: List[bool],
) -> List[Dict[str, Any]]:
"""Single forward pass over the Ray-accumulated batch.
Before dispatch, re-sizes Ray's batcher via ``set_max_batch_size``
using ``batch_size_fn`` on the observed seq length — so the next
accumulation picks the largest precompiled size that fits.
Assumes batch requests are homogeneous — labels/thresholds/flags
are taken from the first request.
"""
next_max_batch = self.server.batch_size_fn(
seq_len=self.server.observed_seq_len(
texts,
labels=labels_list[0] if labels_list else None,
relations=relations_list[0] if relations_list else None,
)
)
self._infer_batch.set_max_batch_size(next_max_batch)
return self.server.predict(
texts,
labels_list[0],
relations=relations_list[0],
threshold=thresholds[0],
relation_threshold=relation_thresholds[0],
flat_ner=flat_ner_list[0],
multi_label=multi_label_list[0],
)
async def predict(
self,
text: str,
labels: List[str],
relations: Optional[List[str]] = None,
threshold: Optional[float] = None,
relation_threshold: Optional[float] = None,
flat_ner: bool = True,
multi_label: bool = False,
) -> Dict[str, Any]:
"""Single prediction endpoint."""
if threshold is None:
threshold = self.server.config.default_threshold
if relation_threshold is None:
relation_threshold = self.server.config.default_relation_threshold
            # @serve.batch unbatches the return value: this call gets back the
            # single result dict for this request.
            result = await self._infer_batch(
                text,
                labels,
                relations,
                threshold,
                relation_threshold,
                flat_ner,
                multi_label,
            )
            return result
async def __call__(self, request) -> Dict[str, Any]:
"""Handle HTTP requests."""
payload = await request.json()
return await self.predict(
text=payload["text"],
labels=payload["labels"],
relations=payload.get("relations"),
threshold=payload.get("threshold"),
relation_threshold=payload.get("relation_threshold"),
flat_ner=payload.get("flat_ner", True),
multi_label=payload.get("multi_label", False),
)
return GLiNERDeployment.bind(config)
def serve(
config: GLiNERServeConfig,
blocking: bool = False,
) -> Any:
"""Start GLiNER Ray Serve deployment.
Args:
config: Server configuration.
blocking: If True, blocks until the server is shut down.
Returns:
Ray Serve deployment handle for making predictions.
Example:
>>> from gliner.serve import GLiNERServeConfig, serve
>>> config = GLiNERServeConfig(model="urchade/gliner_small-v2.1")
>>> handle = serve(config)
>>> # Make predictions
>>> ref = handle.predict.remote("John works at Google", ["person", "org"])
>>> print(ref.result())
"""
import ray # noqa: PLC0415
from ray import serve as ray_serve # noqa: PLC0415
if not ray.is_initialized():
ray.init(address=config.ray_address, ignore_reinit_error=True)
ray_serve.start(detached=True, http_options={"port": config.http_port})
app = _build_deployment(config)
handle = ray_serve.run(app, name="gliner", route_prefix=config.route_prefix)
logger.info("GLiNER server running at http://localhost:%d%s", config.http_port, config.route_prefix)
if blocking:
import time # noqa: PLC0415
import signal # noqa: PLC0415
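        # Install cooperative handlers so Ctrl-C / SIGTERM break the wait
        # loop below and let Serve shut down cleanly instead of dying
        # mid-request.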
        shutdown_requested = False

        def handle_signal(_signum, _frame):
            nonlocal shutdown_requested
            shutdown_requested = True

        signal.signal(signal.SIGINT, handle_signal)
        signal.signal(signal.SIGTERM, handle_signal)
        while not shutdown_requested:
            time.sleep(1)
ray_serve.shutdown()
return handle
def shutdown() -> None:
"""Shutdown the GLiNER Ray Serve deployment."""
from ray import serve as ray_serve # noqa: PLC0415
ray_serve.shutdown()
class GLiNERFactory:
"""vLLM-style synchronous facade over a GLiNER Ray Serve deployment.
Bundles config → deploy → client into one lifecycle-managed object so
callers never see Ray's ObjectRefs.
Pass a list of texts to ``predict`` to preserve dynamic batching: each
text is dispatched as a separate request so Ray Serve's ``@serve.batch``
can accumulate them into a single forward pass. A Python loop of
single-text calls would serialize and defeat batching.
Example:
>>> from gliner.serve import GLiNERFactory
>>> llm = GLiNERFactory(model="urchade/gliner_small-v2.1")
>>> outputs = llm.predict(
... ["John works at Google", "Paris is in France"],
... labels=["person", "organization", "location"],
... )
>>> llm.shutdown()
Or as a context manager:
>>> with GLiNERFactory(model="urchade/gliner_small-v2.1") as llm:
... out = llm.predict("John works at Google", ["person", "org"])
"""
def __init__(
self,
model: Optional[str] = None,
*,
config: Optional[GLiNERServeConfig] = None,
**kwargs,
):
"""Build a config (if not provided) and start the Ray Serve deployment.
Args:
model: Model name or path. Ignored if ``config`` is provided.
config: Prebuilt ``GLiNERServeConfig``. Mutually exclusive with
``model``/``kwargs``.
**kwargs: Forwarded to ``GLiNERServeConfig`` when building one.
"""
if config is not None:
if model is not None or kwargs:
raise ValueError("Pass either `config` or `model`/kwargs, not both.")
else:
if model is None:
raise ValueError("Must provide either `model` or `config`.")
config = GLiNERServeConfig(model=model, **kwargs)
self.config = config
self._handle = serve(config, blocking=False)
self._closed = False
@property
def handle(self):
"""Underlying Ray Serve deployment handle — for async/advanced use."""
return self._handle
def predict(
self,
texts: Union[str, List[str]],
labels: List[str],
relations: Optional[List[str]] = None,
threshold: Optional[float] = None,
relation_threshold: Optional[float] = None,
flat_ner: bool = True,
multi_label: bool = False,
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
"""Blocking prediction. Returns a dict for ``str`` input, list for list input."""
single = isinstance(texts, str)
items = [texts] if single else list(texts)
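        # One request per text: @serve.batch on the replica gathers these
        # concurrent calls back into a single forward pass.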
refs = [
self._handle.predict.remote(
t,
labels,
relations,
threshold,
relation_threshold,
flat_ner,
multi_label,
)
for t in items
]
results = [ref.result() for ref in refs]
return results[0] if single else results
async def predict_async(
self,
texts: Union[str, List[str]],
labels: List[str],
relations: Optional[List[str]] = None,
threshold: Optional[float] = None,
relation_threshold: Optional[float] = None,
flat_ner: bool = True,
multi_label: bool = False,
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
"""Async prediction. Concurrent calls accumulate into one batch."""
import asyncio # noqa: PLC0415
single = isinstance(texts, str)
items = [texts] if single else list(texts)
refs = [
self._handle.predict.remote(
t,
labels,
relations,
threshold,
relation_threshold,
flat_ner,
multi_label,
)
for t in items
]
results = list(await asyncio.gather(*refs))
return results[0] if single else results
def shutdown(self) -> None:
"""Tear down the Ray Serve deployment and the Ray runtime it booted.
Idempotent. Shutting down Ray after Serve avoids leaving the driver
attached to a detached Serve instance — the latter produces noisy
``ServeController ... killed by ray.kill`` retry warnings in the
raylet log when the process exits.
"""
if self._closed:
return
import ray # noqa: PLC0415
from ray import serve as ray_serve # noqa: PLC0415
ray_serve.shutdown()
if ray.is_initialized():
ray.shutdown()
self._closed = True
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.shutdown()
return False
def __del__(self):
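        # Best-effort cleanup at garbage collection; a finalizer must never
        # raise, hence the blanket except below.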
try:
self.shutdown()
except Exception:
pass