Components & Configs

GLiNER supports multiple architecture variants, each with its own configuration class. This page documents the configuration parameters for each architecture and provides training examples.

Architecture Overview

| Architecture | Config Class | Use Case |
|---|---|---|
| UniEncoderSpan | UniEncoderSpanConfig | Standard span-based NER, original GLiNER |
| UniEncoderToken | UniEncoderTokenConfig | Token-level NER, long-form extraction |
| BiEncoderSpan | BiEncoderSpanConfig | Span NER with separate label encoder |
| BiEncoderToken | BiEncoderTokenConfig | Token NER with separate label encoder |
| UniEncoderSpanDecoder | UniEncoderSpanDecoderConfig | Generative label prediction |
| UniEncoderSpanRelex | UniEncoderSpanRelexConfig | Joint entity and relation extraction |

Base Configuration Parameters

All GLiNER architectures share these base configuration parameters from BaseGLiNERConfig:

Core Parameters

model_name

str, optional, defaults to "microsoft/deberta-v3-small"

Base encoder model identifier from Hugging Face Hub or local path.


name

str, optional, defaults to "gliner"

Optional display name for this model configuration.


max_width

int, optional, defaults to 12

Maximum span width (in number of tokens) allowed when generating candidate spans. Only applies to span-based architectures.
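Span-based architectures score every contiguous span of up to max_width words, so the number of candidates grows roughly as sequence length × max_width. A minimal sketch of that enumeration, assuming word-level positions and inclusive end indices; enumerate_spans is a hypothetical helper, and the library may instead keep a fixed number of spans per position and mask those that run past the end.

```python
from typing import List, Tuple


def enumerate_spans(num_words: int, max_width: int = 12) -> List[Tuple[int, int]]:
    """Return every (start, end) candidate span, end inclusive, at most max_width words long."""
    spans = []
    for start in range(num_words):
        for width in range(max_width):
            end = start + width
            if end >= num_words:  # span would run past the last word
                break
            spans.append((start, end))
    return spans


print(enumerate_spans(num_words=4, max_width=2))
# [(0, 0), (0, 1), (1, 1), (1, 2), (2, 2), (2, 3), (3, 3)]
```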


hidden_size

int, optional, defaults to 512

Dimensionality of hidden representations in internal layers.


dropout

float, optional, defaults to 0.4

Dropout rate applied to intermediate layers.


fine_tune

bool, optional, defaults to True

Whether to fine-tune the encoder during training.


subtoken_pooling

str, optional, defaults to "first"

How subword tokens are pooled into a single word representation. Currently only "first" (first-subtoken pooling) is supported; more approaches will be added in the future.


span_mode

str, optional, defaults to "markerV0"

Defines the strategy for constructing span representations from encoder outputs. Only applies to span-based architectures.

Available options:

  • "markerV0": Projects the start and end token representations with MLPs, concatenates them, and then applies a final projection. Lightweight and the default.

  • "marker": Similar to markerV0 but with deeper two-layer projections; better suited to complex tasks.

  • "query": Uses learned per-span-width query vectors and dot-product interaction.

  • "mlp": Applies a feedforward MLP and reshapes the output into span format; fast but position-agnostic.

  • "cat": Concatenates token features with learned span width embeddings before projection.

  • "conv_conv": Uses multiple 1D convolutions with increasing kernel sizes; captures internal span structure.

  • "conv_max": Max pooling over the tokens in a span; emphasizes the strongest token.

  • "conv_mean": Mean pooling across span tokens.

  • "conv_sum": Sum pooling; a raw additive representation.

  • "conv_share": Shared convolution kernel over span widths; a parameter-efficient alternative.
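The conv_max, conv_mean and conv_sum modes differ in how the token vectors inside a span are reduced to one vector. The toy sketch below shows only that reduction step on plain Python lists; pool_span is a hypothetical helper, and any convolution or projection layers the library applies around the pooling are omitted.

```python
from typing import List

Vector = List[float]


def pool_span(token_vecs: List[Vector], start: int, end: int, mode: str) -> Vector:
    """Reduce the token vectors in [start, end] (inclusive) to one span vector."""
    span = token_vecs[start:end + 1]
    columns = list(zip(*span))  # one tuple per hidden dimension
    if mode == "max":
        return [max(col) for col in columns]
    if mode == "mean":
        return [sum(col) / len(col) for col in columns]
    if mode == "sum":
        return [sum(col) for col in columns]
    raise ValueError(f"unknown pooling mode: {mode}")


tokens = [[1.0, 0.0], [3.0, -2.0], [2.0, 4.0]]  # three words, hidden size 2
print(pool_span(tokens, 0, 1, "max"))   # [3.0, 0.0]
print(pool_span(tokens, 0, 2, "mean"))  # [2.0, 0.6666666666666666]
print(pool_span(tokens, 1, 2, "sum"))   # [5.0, 2.0]
```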


post_fusion_schema

str, optional, defaults to ""

Defines the multi-step attention schema used to fuse span and label embeddings. The value is a string with hyphen-separated tokens that determine the sequence of attention operations applied in the CrossFuser module.

Each token in the schema defines one of the following attention types:

  • "l2l": label-to-label self-attention (intra-label interaction)

  • "t2t": token-to-token self-attention (intra-span interaction)

  • "l2t": label-to-token cross-attention (labels attend to span tokens)

  • "t2l": token-to-label cross-attention (tokens attend to labels)

Examples:

  • "l2l-l2t-t2t": apply label self-attention → label-to-token attention → token self-attention

  • "l2t": a single step in which labels attend to span tokens

  • "": disables fusion entirely (no interaction is applied)

:::tip
The number of fusion layers (num_post_fusion_layers) controls how many times the entire schema is repeated.
:::

num_post_fusion_layers

int, optional, defaults to 1

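Number of fusion layers applied to the span and label embeddings. Each layer runs the full post_fusion_schema once, so the schema is repeated this many times.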
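Together, post_fusion_schema and num_post_fusion_layers fix the order of attention steps in the fusion module. A minimal sketch of that expansion, assuming the schema is split on hyphens and the whole sequence is repeated once per layer, as the tip describes; expand_fusion_schema is hypothetical and only validates and orders step names, it performs no attention.

```python
from typing import List

VALID_STEPS = {"l2l", "t2t", "l2t", "t2l"}


def expand_fusion_schema(schema: str, num_layers: int = 1) -> List[str]:
    """Expand a post_fusion_schema string into the ordered list of attention steps."""
    if not schema:
        return []  # empty schema: fusion disabled
    steps = schema.split("-")
    unknown = [s for s in steps if s not in VALID_STEPS]
    if unknown:
        raise ValueError(f"unknown fusion step(s): {unknown}")
    return steps * num_layers  # the whole schema is repeated once per layer


print(expand_fusion_schema("l2l-l2t-t2t"))
# ['l2l', 'l2t', 't2t']
print(expand_fusion_schema("l2t-t2l", num_layers=2))
# ['l2t', 't2l', 'l2t', 't2l']
print(expand_fusion_schema("", num_layers=3))
# []
```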


vocab_size

int, optional, defaults to -1

Vocabulary size override if needed. Automatically set during model initialization.


max_neg_type_ratio

int, optional, defaults to 1

Controls how many negative (non-matching) entity types are sampled during training, as a ratio to the number of positive types in each example.
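One way to read this parameter is as a cap on how many label names absent from an example are added as negatives, relative to the number of positive labels. A minimal sketch under that assumption; sample_negative_types is hypothetical, and the library may additionally randomize the effective ratio per example.

```python
import random
from typing import List, Optional


def sample_negative_types(
    positive: List[str],
    pool: List[str],
    max_neg_type_ratio: int = 1,
    rng: Optional[random.Random] = None,
) -> List[str]:
    """Sample label names not present in the example, at most len(positive) * ratio of them."""
    rng = rng or random.Random()
    candidates = [t for t in pool if t not in positive]
    k = min(len(candidates), len(positive) * max_neg_type_ratio)
    return rng.sample(candidates, k)


rng = random.Random(0)
negatives = sample_negative_types(
    ["person", "location"],
    ["person", "location", "date", "product", "event"],
    max_neg_type_ratio=1,
    rng=rng,
)
print(negatives)  # two of 'date', 'product', 'event'
```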


max_types

int, optional, defaults to 25

Maximum number of entity types supported per batch.


max_len

int, optional, defaults to 384

Maximum sequence length accepted by the encoder.


words_splitter_type

str, optional, defaults to "whitespace"

Heuristic used for word-level splitting during inference.
Choices: "whitespace", "spacy", "moses", "stanza", "universal"
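Whichever splitter is chosen, its job is to turn raw text into words plus character offsets, so that predicted word spans can be mapped back to the original string. A minimal sketch of a plain whitespace splitter, assuming offsets are (start, end) character positions with an exclusive end; whitespace_split is hypothetical, and the library's own "whitespace" splitter may also separate punctuation from words.

```python
import re
from typing import List, Tuple


def whitespace_split(text: str) -> List[Tuple[str, int, int]]:
    """Split on runs of whitespace, keeping (word, start_char, end_char) for each word."""
    return [(m.group(), m.start(), m.end()) for m in re.finditer(r"\S+", text)]


print(whitespace_split("Apple Inc. was founded"))
# [('Apple', 0, 5), ('Inc.', 6, 10), ('was', 11, 14), ('founded', 15, 22)]
```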


num_rnn_layers

int, optional, defaults to 1

Number of LSTM layers to apply on top of encoder outputs. Set to 0 to disable LSTM.


fuse_layers

bool, optional, defaults to False

If True, combine representations from multiple encoders (labels and main encoder).


embed_ent_token

bool, optional, defaults to True

If True, <<ENT>> tokens will be pooled for each label. If False, the first token of each label will be pooled as label embedding.


class_token_index

int, optional, defaults to -1

Index of the entity token in the vocabulary. Set automatically during initialization.


encoder_config

dict or PretrainedConfig, optional

A nested config dictionary for the encoder model. If a dict is passed, its model_type must be set or inferred.


ent_token

str, optional, defaults to "<<ENT>>"

Special token used to mark entity type boundaries in the input.


sep_token

str, optional, defaults to "<<SEP>>"

Token used to separate entity types from input text.
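Taken together, ent_token, sep_token and embed_ent_token describe how labels are packed into the encoder input. A minimal sketch, assuming each label is preceded by ent_token and the label list is closed by sep_token before the text; build_prompt is hypothetical and works on whole words, before any subword tokenization.

```python
from typing import List


def build_prompt(labels: List[str], words: List[str],
                 ent_token: str = "<<ENT>>", sep_token: str = "<<SEP>>") -> List[str]:
    """Prefix the input words with one ent_token per label, followed by sep_token."""
    prompt: List[str] = []
    for label in labels:
        prompt.extend([ent_token, label])
    prompt.append(sep_token)
    return prompt + words


prompt = build_prompt(["person", "organization"], ["John", "works", "at", "Microsoft"])
print(prompt)
# ['<<ENT>>', 'person', '<<ENT>>', 'organization', '<<SEP>>', 'John', 'works', 'at', 'Microsoft']

# With embed_ent_token=True, label embeddings would be read at these positions
ent_positions = [i for i, tok in enumerate(prompt) if tok == "<<ENT>>"]
print(ent_positions)  # [0, 2]
```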


_attn_implementation

str, optional

Optional override for the encoder's attention implementation. For example, pass "eager" to disable Flash Attention when it is installed.

Example:

from gliner import GLiNER

model = GLiNER.from_pretrained(
    "urchade/gliner_mediumv2.1",
    _attn_implementation="eager"  # Disable Flash Attention
)

UniEncoder Span Configuration

UniEncoderSpanConfig is used for the original GLiNER architecture with span-based prediction.

Architecture-Specific Parameters

This architecture uses all base parameters without additional architecture-specific parameters.

Usage Example

from gliner import GLiNERConfig, GLiNER

# Create config for UniEncoderSpan
config = GLiNERConfig(
    model_name="microsoft/deberta-v3-small",
    max_width=12,
    hidden_size=512,
    span_mode="markerV0",
    # labels_encoder=None  # Makes it UniEncoder
    # labels_decoder=None  # No decoder
    # relations_layer=None  # No relations
)

# Initialize model from config
model = GLiNER.from_config(config)

Training Config Example

# Model Configuration
model_name: microsoft/deberta-v3-base
labels_encoder: null  # UniEncoder
name: "span level gliner"
max_width: 12
hidden_size: 768
dropout: 0.4
fine_tune: true
subtoken_pooling: first
span_mode: markerV0
post_fusion_schema: ""
num_post_fusion_layers: 1

# Training Parameters
num_steps: 30000
train_batch_size: 8
eval_every: 1000
warmup_ratio: 0.1
scheduler_type: "cosine"

# Loss Configuration
loss_alpha: -1
loss_gamma: 0
label_smoothing: 0
loss_reduction: "sum"

# Learning Rate Configuration
lr_encoder: 1e-5
lr_others: 5e-5
weight_decay_encoder: 0.01
weight_decay_other: 0.01
max_grad_norm: 1.0

# Data Configuration
train_data: "data.json"
prev_path: null  # Training from scratch
save_total_limit: 3

# Advanced Settings
max_types: 25
max_len: 384
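The scheduler_type: "cosine" and warmup_ratio: 0.1 settings describe a learning rate that ramps up linearly over the first 10% of num_steps and then decays along a cosine curve. A minimal sketch of that shape, assuming it matches the cosine-with-warmup schedule in Hugging Face transformers (half a cosine period, decaying to zero); lr_at_step is hypothetical, and whether lr_encoder and lr_others each follow this curve with their own peak is an assumption.

```python
import math


def lr_at_step(step: int, peak_lr: float, num_steps: int, warmup_ratio: float) -> float:
    """Linear warmup to peak_lr, then cosine decay to zero at num_steps."""
    warmup_steps = int(num_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, num_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Values from the config above: num_steps=30000, warmup_ratio=0.1, lr_others=5e-5
for step in (0, 1500, 3000, 16500, 30000):
    print(step, lr_at_step(step, peak_lr=5e-5, num_steps=30000, warmup_ratio=0.1))
```

The peak is reached at step 3000, the rate is back to half the peak at step 16500 (halfway through the decay), and it reaches zero at step 30000.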

UniEncoder Token Configuration

UniEncoderTokenConfig is used for token-level classification, suitable for long-form entity extraction.

Architecture-Specific Parameters

span_mode

str, required, fixed to "token-level"

This parameter is automatically set to "token-level" and cannot be changed for this architecture.

Usage Example

from gliner import GLiNERConfig, GLiNER

# Create config for UniEncoderToken
config = GLiNERConfig(
    model_name="microsoft/deberta-v3-small",
    hidden_size=512,
    span_mode="token-level",  # Selects token-level prediction; UniEncoderTokenConfig sets this automatically
)

model = GLiNER.from_config(config)

Training Config Example

# Model Configuration
model_name: microsoft/deberta-v3-base
labels_encoder: null
name: "token level gliner"
hidden_size: 768
dropout: 0.4
fine_tune: true
subtoken_pooling: first
span_mode: token-level  # Token-level prediction
num_rnn_layers: 1  # LSTM helps with token sequences

# Training Parameters (same as span)
num_steps: 30000
train_batch_size: 8
eval_every: 1000
warmup_ratio: 0.1
scheduler_type: "cosine"

# Loss Configuration
loss_alpha: -1
loss_gamma: 0
label_smoothing: 0
loss_reduction: "sum"

# Learning Rate Configuration
lr_encoder: 1e-5
lr_others: 5e-5
weight_decay_encoder: 0.01
weight_decay_other: 0.01
max_grad_norm: 1.0

# Data Configuration
train_data: "data.json"
prev_path: null
save_total_limit: 3

# Advanced Settings
max_types: 25
max_len: 384
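Token-level models label individual words and assemble entity spans afterward, which is why they are not limited by max_width. A toy sketch of that assembly step, assuming one label (or None) per word; tokens_to_spans is hypothetical, and unlike a real decoder it would merge two adjacent entities of the same type into one.

```python
from typing import List, Optional, Tuple


def tokens_to_spans(tags: List[Optional[str]]) -> List[Tuple[int, int, str]]:
    """Group runs of consecutive words with the same label into (start, end, label), end inclusive."""
    spans = []
    start = None
    for i, tag in enumerate(tags + [None]):  # trailing None closes a final open run
        if start is not None and tag != tags[start]:
            spans.append((start, i - 1, tags[start]))
            start = None
        if start is None and tag is not None:
            start = i
    return spans


print(tokens_to_spans(["person", "person", None, "organization"]))
# [(0, 1, 'person'), (3, 3, 'organization')]
```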

BiEncoder Span Configuration

BiEncoderSpanConfig uses separate encoders for text and entity labels, enabling pre-computation of label embeddings.

Architecture-Specific Parameters

labels_encoder

str, required

Model identifier or path for the label encoder. Typically a sentence transformer model.

Examples:

  • "sentence-transformers/all-MiniLM-L6-v2"

  • "BAAI/bge-small-en-v1.5"


labels_encoder_config

dict or PretrainedConfig, optional

Nested configuration for the label encoder model.

Important Notes

:::warning Embedding Resizing Not Supported
Unlike UniEncoder models, BiEncoder models do not support token embedding resizing. The vocabulary is fixed to the pretrained encoder's vocabulary.
:::

Usage Example

from gliner import GLiNERConfig, GLiNER

# Create config for BiEncoderSpan
config = GLiNERConfig(
    model_name="microsoft/deberta-v3-base",
    labels_encoder="sentence-transformers/all-MiniLM-L6-v2",  # Bi-encoder
    max_width=12,
    hidden_size=768,
    span_mode="markerV0",
)

model = GLiNER.from_config(config)

# Pre-compute label embeddings for efficiency
labels = ["person", "organization", "location"]
labels_embeddings = model.encode_labels(labels)

# Use pre-computed embeddings for inference
entities = model.batch_predict_with_embeds(
    texts=["Apple Inc. was founded by Steve Jobs."],
    labels_embeddings=labels_embeddings,
    labels=labels
)
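Because label embeddings in a bi-encoder do not depend on the input text, they can be computed once and reused for any number of texts. The toy sketch below illustrates that idea with hand-written vectors and a dot-product scorer; the vectors, the cache, and best_label are hypothetical stand-ins, not GLiNER's internals or its scoring head.

```python
from typing import Dict, List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def best_label(span_vec: Vector, label_cache: Dict[str, Vector]) -> str:
    """Pick the cached label whose vector scores highest against the span vector."""
    return max(label_cache, key=lambda name: dot(span_vec, label_cache[name]))


# Stand-ins for label embeddings computed once, independently of any text
label_cache = {"person": [1.0, 0.0], "organization": [0.0, 1.0]}

# The same cache is reused for spans from any number of texts
print(best_label([0.9, 0.2], label_cache))  # person
print(best_label([0.1, 0.8], label_cache))  # organization
```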

Training Config Example

# Model Configuration
model_name: microsoft/deberta-v3-base
labels_encoder: sentence-transformers/all-MiniLM-L6-v2  # Bi-encoder
name: "bi-encoder span gliner"
max_width: 12
hidden_size: 768
dropout: 0.4
fine_tune: true
subtoken_pooling: first
span_mode: markerV0
post_fusion_schema: "l2t-t2l"  # Cross-attention fusion

# Training Parameters
num_steps: 30000
train_batch_size: 8
eval_every: 1000
warmup_ratio: 0.1
scheduler_type: "cosine"

# Loss Configuration (Focal loss recommended)
loss_alpha: 0.25
loss_gamma: 2.0
label_smoothing: 0
loss_reduction: "sum"

# Learning Rate Configuration
lr_encoder: 1e-5
lr_others: 5e-5
weight_decay_encoder: 0.01
weight_decay_other: 0.01
max_grad_norm: 1.0

# Data Configuration
train_data: "data.json"
prev_path: null
save_total_limit: 3

# Advanced Settings
max_types: 100  # Can handle many more types
max_len: 384

BiEncoder Token Configuration

BiEncoderTokenConfig combines the bi-encoder architecture with token-level prediction.

Architecture-Specific Parameters

labels_encoder

str, required

Model identifier for the label encoder.

span_mode

str, required, fixed to "token-level"

Automatically set to "token-level" for this architecture.

Usage Example

from gliner import GLiNERConfig, GLiNER

# Create config for BiEncoderToken
config = GLiNERConfig(
    model_name="microsoft/deberta-v3-base",
    labels_encoder="sentence-transformers/all-MiniLM-L6-v2",
    hidden_size=768,
    span_mode="token-level",
)

model = GLiNER.from_config(config)

Training Config Example

# Model Configuration
model_name: microsoft/deberta-v3-base
labels_encoder: sentence-transformers/all-MiniLM-L6-v2
name: "bi-encoder token gliner"
hidden_size: 768
dropout: 0.4
fine_tune: true
subtoken_pooling: first
span_mode: token-level
num_rnn_layers: 1

# Training Parameters
num_steps: 30000
train_batch_size: 8
eval_every: 1000
warmup_ratio: 0.1
scheduler_type: "cosine"

# Loss Configuration
loss_alpha: 0.25
loss_gamma: 2.0
label_smoothing: 0
loss_reduction: "sum"

# Learning Rate Configuration
lr_encoder: 1e-5
lr_others: 5e-5
weight_decay_encoder: 0.01
weight_decay_other: 0.01
max_grad_norm: 1.0

# Data Configuration
train_data: "data.json"
prev_path: null
save_total_limit: 3

# Advanced Settings
max_types: 100
max_len: 384

UniEncoder Span Decoder Configuration

UniEncoderSpanDecoderConfig extends span-based NER with a generative decoder for label generation.

Architecture-Specific Parameters

labels_decoder

str, required

Model identifier for the generative decoder (e.g., GPT-2).

Examples:

  • "gpt2"

  • "distilgpt2"

  • "EleutherAI/gpt-neo-125M"


decoder_mode

str, optional

Defines how decoder inputs are constructed.

Choices:

  • "prompt": Use entity type embeddings as decoder context

  • "span": Use span token representations as decoder context


full_decoder_context

bool, optional, defaults to True

Whether to provide full context to the decoder (all tokens in span) or just boundary markers.


blank_entity_prob

float, optional, defaults to 0.1

Probability of using a generic "entity" label during training for improved generalization.
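A minimal sketch of how such label blanking could work, assuming the substitution is drawn independently for each annotated span; blank_labels is a hypothetical helper, and the library may apply the replacement at a different granularity (for example per example).

```python
import random
from typing import List, Optional, Tuple

Span = Tuple[int, int, str]  # (start_word, end_word, label)


def blank_labels(spans: List[Span], blank_entity_prob: float = 0.1,
                 rng: Optional[random.Random] = None) -> List[Span]:
    """Replace each span's label with the generic "entity" with probability blank_entity_prob."""
    rng = rng or random.Random()
    return [
        (start, end, "entity" if rng.random() < blank_entity_prob else label)
        for start, end, label in spans
    ]


spans = [(0, 0, "person"), (3, 3, "organization")]
print(blank_labels(spans, blank_entity_prob=0.1, rng=random.Random(42)))
```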


labels_decoder_config

dict or PretrainedConfig, optional

Nested configuration for the decoder model.


decoder_loss_coef

float, optional, defaults to 0.5

Weight for the decoder generation loss in the total loss.


span_loss_coef

float, optional, defaults to 0.5

Weight for the span classification loss in the total loss.
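A worked example of how the two coefficients combine, assuming the total loss is a plain weighted sum of the two terms; combined_loss is a hypothetical helper. With the default 0.5/0.5 weights the total is simply the average of the span and decoder losses.

```python
def combined_loss(span_loss: float, decoder_loss: float,
                  span_loss_coef: float = 0.5, decoder_loss_coef: float = 0.5) -> float:
    """Weighted sum of the span classification and decoder generation losses."""
    return span_loss_coef * span_loss + decoder_loss_coef * decoder_loss


print(combined_loss(2.0, 4.0))  # 3.0 (defaults average the two terms)
print(combined_loss(2.0, 4.0, span_loss_coef=1.0, decoder_loss_coef=0.5))  # 4.0
```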

Usage Example

from gliner import GLiNERConfig, GLiNER

# Create config for UniEncoderSpanDecoder
config = GLiNERConfig(
    model_name="microsoft/deberta-v3-base",
    labels_decoder="gpt2",  # Add decoder
    decoder_mode="span",
    full_decoder_context=True,
    blank_entity_prob=0.1,
    decoder_loss_coef=0.5,
    span_loss_coef=0.5,
)

model = GLiNER.from_config(config)

Training Config Example

# Model Configuration
model_name: microsoft/deberta-v3-base
labels_decoder: gpt2  # Generative decoder
decoder_mode: span
full_decoder_context: true
blank_entity_prob: 0.1
name: "span decoder gliner"
max_width: 12
hidden_size: 768
dropout: 0.4
fine_tune: true
span_mode: markerV0

# Multi-task Loss Weights
decoder_loss_coef: 0.5
span_loss_coef: 0.5

# Training Parameters
num_steps: 30000
train_batch_size: 4  # Smaller due to decoder
eval_every: 1000
warmup_ratio: 0.1
scheduler_type: "cosine"

# Loss Configuration
loss_alpha: -1
loss_gamma: 0
label_smoothing: 0.1  # Helps with generation
loss_reduction: "sum"

# Learning Rate Configuration
lr_encoder: 1e-5
lr_others: 5e-5
weight_decay_encoder: 0.01
weight_decay_other: 0.01
max_grad_norm: 1.0

# Data Configuration
train_data: "data.json"
prev_path: null
save_total_limit: 3

# Advanced Settings
max_types: 25
max_len: 384

UniEncoder Span Relex Configuration

UniEncoderSpanRelexConfig extends span-based NER with relation extraction capabilities.

Architecture-Specific Parameters

relations_layer

str, required

Type of relation representation layer to use.

Choices:

  • "dot": Dot product between entity representations

  • "gcn": Graph convolutional network for modeling interactions between entities

  • "gat": Graph attention network for modeling interactions between entities


triples_layer

str, optional

Type of triple scoring layer for (head, relation, tail) scoring.

Choices:

  • "distmult": DistMult scoring function

  • "complex": ComplEx scoring function

  • "transe": TransE scoring function
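The three triples_layer options are named after standard knowledge-graph scoring functions. A minimal sketch of their textbook forms on plain Python lists, assuming h, r and t are already-computed head, relation and tail embeddings (TransE is shown with the L2 distance); how GLiNER produces, batches and normalizes these embeddings is not shown.

```python
from typing import List


def distmult(h: List[float], r: List[float], t: List[float]) -> float:
    """DistMult: trilinear product sum_i h_i * r_i * t_i."""
    return sum(hi * ri * ti for hi, ri, ti in zip(h, r, t))


def transe(h: List[float], r: List[float], t: List[float]) -> float:
    """TransE: negative L2 distance between h + r and t (higher means more plausible)."""
    return -(sum((hi + ri - ti) ** 2 for hi, ri, ti in zip(h, r, t)) ** 0.5)


def complex_score(h: List[complex], r: List[complex], t: List[complex]) -> float:
    """ComplEx: real part of sum_i h_i * r_i * conj(t_i)."""
    return sum(hi * ri * ti.conjugate() for hi, ri, ti in zip(h, r, t)).real


print(distmult([1.0, 2.0], [0.5, 1.0], [1.0, 1.0]))  # 2.5
print(transe([1.0, 0.0], [0.0, 1.0], [1.0, 2.0]))    # -1.0
print(complex_score([1 + 1j], [1 + 0j], [1 + 1j]))    # 2.0
```

Note that DistMult is symmetric in h and t, while ComplEx (through the conjugate) and TransE (through the translation) can distinguish a relation from its reverse.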


embed_rel_token

bool, optional, defaults to True

Whether to embed relation type tokens similar to entity tokens.


rel_token_index

int, optional, defaults to -1

Index of the relation token in vocabulary. Set automatically during initialization.


rel_token

str, optional, defaults to "<<REL>>"

Special token used to mark relation types in the input.


span_loss_coef

float, optional, defaults to 1.0

Weight for entity span classification loss.


adjacency_loss_coef

float, optional, defaults to 1.0

Weight for entity pair adjacency prediction loss.


relation_loss_coef

float, optional, defaults to 1.0

Weight for relation type classification loss.

Usage Example

from gliner import GLiNERConfig, GLiNER

# Create config for UniEncoderSpanRelex
config = GLiNERConfig(
    model_name="microsoft/deberta-v3-base",
    relations_layer="biaffine",  # Enable relations
    triples_layer="distmult",
    rel_token="<<REL>>",
    span_loss_coef=1.0,
    adjacency_loss_coef=1.0,
    relation_loss_coef=1.0,
)

model = GLiNER.from_config(config)

Training Config Example

# Model Configuration
model_name: microsoft/deberta-v3-base
relations_layer: biaffine  # Enable relation extraction
triples_layer: distmult
rel_token: "<<REL>>"
embed_rel_token: true
name: "span relex gliner"
max_width: 12
hidden_size: 768
dropout: 0.4
fine_tune: true
span_mode: markerV0

# Multi-task Loss Weights
span_loss_coef: 1.0
adjacency_loss_coef: 1.0
relation_loss_coef: 1.0

# Training Parameters
num_steps: 30000
train_batch_size: 6  # Smaller due to relation computation
eval_every: 1000
warmup_ratio: 0.1
scheduler_type: "cosine"

# Loss Configuration
loss_alpha: -1
loss_gamma: 0
label_smoothing: 0
loss_reduction: "sum"

# Learning Rate Configuration
lr_encoder: 1e-5
lr_others: 5e-5
weight_decay_encoder: 0.01
weight_decay_other: 0.01
max_grad_norm: 1.0

# Data Configuration
train_data: "data_with_relations.json"  # Must include relation annotations
prev_path: null
save_total_limit: 3

# Advanced Settings
max_types: 25
max_len: 384

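Data Format for Relation Extraction

train_data = [
    {
        "tokenized_text": ["John", "works", "at", "Microsoft"],
        # Each entity: [start_word, end_word (inclusive), entity_type]
        "ner": [[0, 0, "person"], [3, 3, "organization"]],
        # Each relation: [head_entity_idx, tail_entity_idx, relation_type];
        # the indices point into the "ner" list above, not into the words
        "relations": [[0, 1, "works_at"]],
    }
]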
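Because relation indices refer to positions in the "ner" list rather than to words, a quick consistency check before training can catch index mix-ups. A minimal sketch, assuming inclusive end indices as in the example above; validate_example is a hypothetical helper, not part of GLiNER.

```python
from typing import Any, Dict, List


def validate_example(example: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in one relation-extraction training example."""
    errors: List[str] = []
    n_words = len(example["tokenized_text"])
    ner = example.get("ner", [])
    for start, end, label in ner:
        if not 0 <= start <= end < n_words:
            errors.append(f"entity {label!r} span [{start}, {end}] is out of range")
    for head, tail, rel in example.get("relations", []):
        if not (0 <= head < len(ner) and 0 <= tail < len(ner)):
            errors.append(f"relation {rel!r} points to a missing entity index")
    return errors


good = {
    "tokenized_text": ["John", "works", "at", "Microsoft"],
    "ner": [[0, 0, "person"], [3, 3, "organization"]],
    "relations": [[0, 1, "works_at"]],
}
bad = dict(good, relations=[[0, 3, "works_at"]])  # 3 is a word index, not an entity index
print(validate_example(good))  # []
print(validate_example(bad))   # ["relation 'works_at' points to a missing entity index"]
```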

TrainingArguments

Custom extension of transformers.TrainingArguments with additional parameters for GLiNER models.

GLiNER-Specific Parameters

others_lr

float, optional
Learning rate for non-encoder parameters (e.g., span layers, label encoder). If not specified, uses main learning_rate.


others_weight_decay

float, optional, defaults to 0.0
Weight decay for non-encoder parameters.


focal_loss_alpha

float, optional, defaults to -1
Alpha parameter for focal loss. If ≥ 0, focal loss is activated.

Focal loss formula:
$$\mathrm{FL}(p_t) = -\alpha \, (1 - p_t)^{\gamma} \log(p_t)$$
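For a single binary decision, the formula can be evaluated directly. A minimal sketch, assuming the common sigmoid focal-loss convention in which α weights positive targets and 1 - α negative ones, and in which a negative α skips that weighting; with the defaults (α = -1, γ = 0) it reduces to ordinary binary cross-entropy. focal_loss is a hypothetical helper and does not reproduce GLiNER's batching, masking, or label smoothing.

```python
import math


def focal_loss(p: float, y: int, alpha: float = -1.0, gamma: float = 0.0) -> float:
    """Binary focal loss for predicted positive-class probability p and target y in {0, 1}."""
    p_t = p if y == 1 else 1.0 - p                  # probability assigned to the true class
    loss = -((1.0 - p_t) ** gamma) * math.log(p_t)
    if alpha >= 0:                                   # alpha < 0 disables class weighting
        alpha_t = alpha if y == 1 else 1.0 - alpha
        loss *= alpha_t
    return loss


bce = focal_loss(0.9, 1)                             # alpha=-1, gamma=0: plain cross-entropy
focal = focal_loss(0.9, 1, alpha=0.25, gamma=2.0)
print(bce, focal)                                    # the easy example is heavily down-weighted
```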


focal_loss_gamma

float, optional, defaults to 0
Gamma parameter for focal loss. Higher values increase focus on hard examples.


focal_loss_prob_margin

float, optional, defaults to 0.0
Probability margin for focal loss adjustment.


label_smoothing

float, optional, defaults to 0.0
Label smoothing factor ε for regularization.


loss_reduction

str, optional, defaults to "sum"
How to aggregate loss across samples.
Choices: "sum", "mean"


negatives

float, optional, defaults to 1.0
Ratio of negative to positive spans during training.
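One way to read this parameter is as a cap on how many negative spans are kept per positive span. A minimal sketch under that assumption; subsample_spans is hypothetical, and the library may sample differently or combine this with the masking strategy.

```python
import random
from typing import List, Optional, Tuple

Span = Tuple[int, int]  # (start_word, end_word), end inclusive


def subsample_spans(positive: List[Span], negative: List[Span], negatives: float = 1.0,
                    rng: Optional[random.Random] = None) -> List[Span]:
    """Keep every positive span plus about negatives * len(positive) randomly chosen negatives."""
    rng = rng or random.Random()
    k = min(len(negative), round(negatives * len(positive)))
    return positive + rng.sample(negative, k)


pos = [(0, 0), (3, 3)]
neg = [(0, 1), (1, 1), (1, 2), (2, 2), (2, 3)]
kept = subsample_spans(pos, neg, negatives=1.0, rng=random.Random(0))
print(len(kept))  # 4: both positives plus two sampled negatives
```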


masking

str, optional, defaults to "none"
Masking strategy for negative sampling.
Choices: "none", "global", "label", "span"