Source code for gliner.modeling.encoder
import warnings
from typing import Any, Dict, List, Tuple, Union, Optional
from pathlib import Path
import torch
from torch import nn
from transformers import AutoModel, AutoConfig
from transformers.modeling_outputs import BaseModelOutput
from ..utils import MissedPackageException, is_module_available
from .layers import LayersFuser
from ..infer_packing import InferencePackingConfig, unpack_spans, pack_requests
# Check for optional dependencies
IS_LLM2VEC = is_module_available("llm2vec")
IS_PEFT = is_module_available("peft")
IS_TURBOT5 = is_module_available("turbot5")
IS_FLASHDEBERTA = is_module_available("flashdeberta")
if IS_LLM2VEC:
from llm2vec.models import GemmaBiModel, LlamaBiModel, Qwen2BiModel, MistralBiModel
DECODER_MODEL_MAPPING = {
"MistralConfig": MistralBiModel,
"LlamaConfig": LlamaBiModel,
"GemmaConfig": GemmaBiModel,
"Qwen2Config": Qwen2BiModel,
}
else:
DECODER_MODEL_MAPPING = {}
if IS_TURBOT5:
from turbot5.model.modeling import T5EncoderModel
else:
from transformers import T5EncoderModel
if IS_FLASHDEBERTA:
from flashdeberta import FlashDebertaV2Model as DebertaV2Model
else:
from transformers import DebertaV2Model
if IS_PEFT:
from peft import LoraConfig, get_peft_model
class Transformer(nn.Module):
"""Flexible transformer wrapper supporting multiple architectures and configurations.
This class provides a unified interface for various transformer models including
encoder-only (BERT, DeBERTa), encoder-decoder (T5), and decoder-only models
(LLaMA, Mistral) with bidirectional adaptations. It handles model initialization,
adapter loading, and specialized forward passes for different architectures.
Attributes:
model: The underlying transformer model instance.
layers_fuser: Optional layer fusion module when config.fuse_layers is True.
config: Configuration object containing model hyperparameters.
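Example:
A minimal construction sketch (illustrative, not taken from the
library's tests; the SimpleNamespace stubs only the config
attributes this initializer actually reads):

>>> from types import SimpleNamespace
>>> cfg = SimpleNamespace(
...     encoder_config=None, labels_encoder_config=None,
...     vocab_size=-1, _attn_implementation=None, fuse_layers=False,
... )
>>> backbone = Transformer("microsoft/deberta-v3-small", cfg, from_pretrained=True)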
"""
def __init__(
self,
model_name: str,
config: Any,
from_pretrained: bool = False,
labels_encoder: bool = False,
cache_dir: Optional[Union[str, Path]] = None,
) -> None:
"""Initializes the transformer wrapper.
Args:
model_name: Name or path of the pretrained model to load.
config: Configuration object containing model hyperparameters. Must have
attributes like `encoder_config`, `labels_encoder_config`, `vocab_size`,
`_attn_implementation`, and `fuse_layers`.
from_pretrained: If True, loads pretrained weights. If False, initializes
from config only. Defaults to False.
labels_encoder: If True, initializes as a labels encoder using
`config.labels_encoder_config`. Defaults to False.
cache_dir: Optional directory for caching downloaded models. Defaults to None.
Raises:
MissedPackageException: If required packages (llm2vec, peft) are not installed
when needed for specific model types.
"""
super().__init__()
if labels_encoder:
encoder_config = config.labels_encoder_config
else:
encoder_config = config.encoder_config
if encoder_config is None:
encoder_config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
if config.vocab_size != -1:
encoder_config.vocab_size = config.vocab_size
if config._attn_implementation is not None and not labels_encoder:
encoder_config._attn_implementation = config._attn_implementation
config_name = encoder_config.__class__.__name__
kwargs = {}
if config_name in DECODER_MODEL_MAPPING:
if not IS_LLM2VEC:
raise MissedPackageException(
f"The llm2vec package must be installed to use this decoder model: {config_name}"
)
else:
ModelClass = DECODER_MODEL_MAPPING[config_name]
custom = True
elif config_name in {"T5Config", "MT5Config"}:
custom = True
ModelClass = T5EncoderModel
if IS_TURBOT5:
kwargs = {"attention_type": "flash"}
elif config_name in {"DebertaV2Config"}:
custom = True
ModelClass = DebertaV2Model
else:
custom = False
ModelClass = AutoModel
if from_pretrained:
self.model = ModelClass.from_pretrained(model_name, trust_remote_code=True, cache_dir=cache_dir)
elif not custom:
self.model = ModelClass.from_config(encoder_config, trust_remote_code=True)
else:
self.model = ModelClass(encoder_config, **kwargs)
adapter_config_file = Path(model_name) / "adapter_config.json"
if adapter_config_file.exists():
if not IS_PEFT:
warnings.warn(
"Adapter configs were detected, if you want to apply them you need to install peft package.",
stacklevel=2,
)
else:
adapter_config = LoraConfig.from_pretrained(model_name)
self.model = get_peft_model(self.model, adapter_config)
if config.fuse_layers:
self.layers_fuser = LayersFuser(encoder_config.num_hidden_layers, encoder_config.hidden_size)
if labels_encoder:
config.labels_encoder_config = encoder_config
else:
config.encoder_config = encoder_config
self.config = config
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
"""Forward pass through the transformer model.
Handles different attention mask configurations and model architectures,
including support for pair attention masks for packed sequences.
Args:
*args: Variable positional arguments passed to the model.
**kwargs: Variable keyword arguments. Special arguments include:
- pair_attention_mask: Optional pairwise attention mask of shape
(batch_size, seq_len, seq_len) for packed sequences.
- attention_mask: Standard attention mask of shape (batch_size, seq_len).
- input_ids: Input token IDs of shape (batch_size, seq_len).
- Other model-specific arguments.
Returns:
Encoded representations of shape (batch_size, seq_len, hidden_size).
If config.fuse_layers is True, returns fused layer outputs, otherwise
returns the last hidden state.
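Example:
An illustrative call continuing the class-level sketch; `ids`,
`mask`, and `pair_mask` are assumed tensors prepared by the caller:

>>> out = backbone(input_ids=ids, attention_mask=mask,
...                pair_attention_mask=pair_mask)
>>> out.shape  # (batch_size, seq_len, hidden_size)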
"""
pair_attention_mask = kwargs.pop("pair_attention_mask", None)
base_attention_mask = kwargs.pop("attention_mask", None)
# Extract input_ids if present
args = list(args)
input_ids = kwargs.pop("input_ids", None)
if input_ids is None and args:
input_ids = args[0]
args = args[1:]
args = tuple(args)
# Set default kwargs
kwargs.setdefault("output_attentions", False)
kwargs.setdefault("return_dict", True)
# Handle output_hidden_states based on fuse_layers config
if self.config.fuse_layers:
kwargs["output_hidden_states"] = True
else:
kwargs.setdefault("output_hidden_states", False)
if pair_attention_mask is not None:
mask_info = self._prepare_pair_attention_masks(
pair_attention_mask,
base_attention_mask,
input_ids,
kwargs.get("inputs_embeds"),
)
model_kwargs = dict(kwargs)
model_name = self.model.__class__.__name__
if model_name in {"DebertaV2Model", "DebertaModel"}:
output = self._forward_deberta(
input_ids=input_ids,
model_kwargs=model_kwargs,
mask_info=mask_info,
)
elif model_name == "ModernBertModel":
output = self._forward_modernbert(
input_ids=input_ids,
model_kwargs=model_kwargs,
mask_info=mask_info,
)
elif model_name in {"T5EncoderModel", "MT5EncoderModel", "T5Model"}:
output = self._forward_t5(
input_ids=input_ids,
model_kwargs=model_kwargs,
mask_info=mask_info,
)
else:
model_kwargs.pop("packing_config", None)
model_kwargs["attention_mask"] = mask_info["extended_mask"]
output = self.model(*args, **model_kwargs)
else:
if base_attention_mask is not None:
kwargs["attention_mask"] = base_attention_mask
output = self.model(input_ids, *args, **kwargs)
# Common logic for both paths
if self.config.fuse_layers:
encoder_layer = self.layers_fuser(output.hidden_states)
else:
encoder_layer = output[0]
return encoder_layer
def _get_model_dtype(self) -> torch.dtype:
"""Gets the data type of the model parameters.
Returns:
The dtype of the model's parameters, or torch.float32 if no parameters exist.
"""
try:
return next(self.model.parameters()).dtype
except StopIteration:
return torch.float32
def _prepare_pair_attention_masks(
self,
pair_attention_mask: torch.Tensor,
attention_mask: Optional[torch.Tensor],
input_ids: Optional[torch.Tensor],
inputs_embeds: Optional[torch.Tensor],
) -> Dict[str, torch.Tensor]:
"""Prepares attention masks for packed sequence processing.
Converts pair attention masks (which specify token-to-token visibility) into
various mask formats required by different transformer architectures. Ensures
diagonal elements are attended to and inactive tokens are properly masked.
Args:
pair_attention_mask: Pairwise attention mask of shape (batch_size, seq_len, seq_len)
where 1 indicates attention is allowed.
attention_mask: Optional standard attention mask of shape (batch_size, seq_len).
input_ids: Optional input token IDs for device detection.
inputs_embeds: Optional input embeddings for device detection.
Returns:
Dictionary containing:
- token_mask: Per-token mask of shape (batch_size, seq_len).
- token_mask_bool: Boolean version of token_mask.
- extended_mask: 4D attention mask of shape (batch_size, 1, seq_len, seq_len)
with -inf for masked positions.
- block_mask: Boolean 3D mask of shape (batch_size, seq_len, seq_len).
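Example:
A conceptual sketch of the expected input: two requests of lengths
2 and 1 packed into one length-4 row, with position 3 left as
padding. Each block may attend only within itself:

>>> import torch
>>> pair = torch.zeros(1, 4, 4)
>>> pair[0, :2, :2] = 1  # first packed request
>>> pair[0, 2, 2] = 1    # second packed request
>>> # token_mask_bool -> [[True, True, True, False]]
>>> # extended_mask holds -inf wherever pair is 0, except on padding
>>> # rows, which are zeroed so softmax stays finite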
"""
device = pair_attention_mask.device
if input_ids is not None:
device = input_ids.device
elif inputs_embeds is not None:
device = inputs_embeds.device
pair_mask_bool = pair_attention_mask.to(device=device, dtype=torch.bool)
token_mask_bool = pair_mask_bool.any(dim=-1)
if attention_mask is not None:
token_mask_bool = token_mask_bool & attention_mask.to(device=device, dtype=torch.bool)
seq_len = pair_mask_bool.size(-1)
if seq_len:
identity = torch.eye(seq_len, device=device, dtype=torch.bool).unsqueeze(0)
token_diag = token_mask_bool.unsqueeze(-1)
pair_mask_bool = pair_mask_bool | (identity & token_diag)
active = token_mask_bool.unsqueeze(-1) & token_mask_bool.unsqueeze(-2)
pair_mask_bool = pair_mask_bool & active
if attention_mask is not None:
token_mask = token_mask_bool.to(attention_mask.dtype)
else:
token_mask = token_mask_bool.to(dtype=torch.float32)
mask_dtype = self._get_model_dtype()
neg_inf = torch.finfo(mask_dtype).min
extended_mask = (
torch.zeros(pair_mask_bool.shape, dtype=mask_dtype, device=device)
.masked_fill(~pair_mask_bool, neg_inf)
.unsqueeze(1)
)
inactive = ~token_mask_bool
if inactive.any():
extended_mask = extended_mask.masked_fill(
inactive.unsqueeze(1).unsqueeze(-1),
torch.tensor(0.0, dtype=mask_dtype, device=device),
)
return {
"token_mask": token_mask,
"token_mask_bool": token_mask_bool,
"extended_mask": extended_mask,
"block_mask": pair_mask_bool,
}
def _forward_deberta(
self,
input_ids: Optional[torch.Tensor],
model_kwargs: Dict[str, Any],
mask_info: Dict[str, torch.Tensor],
) -> BaseModelOutput:
"""Forward pass through DeBERTa models with packed attention support.
Handles the specific requirements of DeBERTa architecture including embeddings,
relative position encodings, and optional enhanced mask tuning (z_steps).
Args:
input_ids: Input token IDs of shape (batch_size, seq_len), or None if
inputs_embeds is provided.
model_kwargs: Dictionary of model-specific keyword arguments including
inputs_embeds, token_type_ids, position_ids, output_attentions,
output_hidden_states, and return_dict.
mask_info: Dictionary containing prepared attention masks from
_prepare_pair_attention_masks.
Returns:
BaseModelOutput containing:
- last_hidden_state: Final layer output of shape (batch_size, seq_len, hidden_size).
- hidden_states: Tuple of all layer outputs if requested.
- attentions: Tuple of attention weights if requested.
Raises:
ValueError: If input_ids and inputs_embeds are both provided or both omitted.
"""
inputs_embeds = model_kwargs.pop("inputs_embeds", None)
token_type_ids = model_kwargs.pop("token_type_ids", None)
position_ids = model_kwargs.pop("position_ids", None)
output_attentions = model_kwargs.pop("output_attentions")
produce_hidden = model_kwargs.pop("output_hidden_states")
return_dict = model_kwargs.pop("return_dict")
if input_ids is None and inputs_embeds is None:
raise ValueError("Either input_ids or inputs_embeds must be provided for packed attention")
if input_ids is not None and inputs_embeds is not None:
raise ValueError("Cannot supply both input_ids and inputs_embeds")
if token_type_ids is None:
ref = inputs_embeds if inputs_embeds is not None else input_ids
shape = ref.size()[:-1] if inputs_embeds is not None else ref.size()
token_type_ids = torch.zeros(shape, dtype=torch.long, device=ref.device)
embedding_output = self.model.embeddings(
input_ids=input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
mask=mask_info["token_mask"],
inputs_embeds=inputs_embeds,
)
encoder_outputs = self.model.encoder(
embedding_output,
mask_info["block_mask"],
output_hidden_states=True,
output_attentions=output_attentions,
return_dict=True,
)
encoded_layers = list(encoder_outputs.hidden_states)
if getattr(self.model, "z_steps", 0) > 1:
hidden_states = encoded_layers[-2]
layers = [self.model.encoder.layer[-1] for _ in range(self.model.z_steps)]
query_states = encoded_layers[-1]
rel_embeddings = self.model.encoder.get_rel_embedding()
attention_mask = self.model.encoder.get_attention_mask(mask_info["block_mask"])
rel_pos = self.model.encoder.get_rel_pos(embedding_output)
for layer in layers[1:]:
query_states = layer(
hidden_states,
attention_mask,
output_attentions=False,
query_states=query_states,
relative_pos=rel_pos,
rel_embeddings=rel_embeddings,
)
encoded_layers.append(query_states)
sequence_output = encoded_layers[-1]
hidden_states_tuple = tuple(encoded_layers) if produce_hidden else None
attentions = encoder_outputs.attentions if output_attentions else None
if not return_dict:
result = (sequence_output,)
if hidden_states_tuple is not None:
result += (hidden_states_tuple,)
if attentions is not None:
result += (attentions,)
return result
return BaseModelOutput(
last_hidden_state=sequence_output,
hidden_states=hidden_states_tuple,
attentions=attentions,
)
def _forward_modernbert(
self,
input_ids: Optional[torch.Tensor],
model_kwargs: Dict[str, Any],
mask_info: Dict[str, torch.Tensor],
) -> BaseModelOutput:
"""Forward pass through ModernBERT models with packed attention support.
Handles ModernBERT-specific features including global and sliding window
attention patterns, and temporarily switches to eager attention mode
when using packed attention masks.
Args:
input_ids: Input token IDs of shape (batch_size, seq_len), or None if
inputs_embeds is provided.
model_kwargs: Dictionary of model-specific keyword arguments including
inputs_embeds, position_ids, indices, cu_seqlens, max_seqlen,
batch_size, seq_len, output_attentions, output_hidden_states, return_dict.
mask_info: Dictionary containing prepared attention masks from
_prepare_pair_attention_masks.
Returns:
BaseModelOutput containing:
- last_hidden_state: Final layer output of shape (batch_size, seq_len, hidden_size).
- hidden_states: Tuple of all layer outputs if requested.
- attentions: Tuple of attention weights if requested.
Raises:
ValueError: If input_ids and inputs_embeds are both provided or both omitted.
"""
inputs_embeds = model_kwargs.pop("inputs_embeds", None)
position_ids = model_kwargs.pop("position_ids", None)
cu_seqlens = model_kwargs.pop("cu_seqlens", None)
max_seqlen = model_kwargs.pop("max_seqlen", None)
batch_size = model_kwargs.pop("batch_size", None)
seq_len = model_kwargs.pop("seq_len", None)
output_attentions = model_kwargs.pop("output_attentions")
output_hidden_states = model_kwargs.pop("output_hidden_states")
return_dict = model_kwargs.pop("return_dict")
if (input_ids is None) == (inputs_embeds is None):
raise ValueError("ModernBERT requires exactly one of input_ids or inputs_embeds")
token_mask_bool = mask_info["token_mask_bool"].to(torch.bool)
if batch_size is None or seq_len is None:
ref = inputs_embeds if inputs_embeds is not None else input_ids
batch_size, seq_len = ref.shape[:2]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
base_attention_mask = token_mask_bool
original_impl = self.model.config._attn_implementation
if original_impl == "flash_attention_2":
self.model.config._attn_implementation = "eager"
self.model._maybe_set_compile()
global_attention_mask, sliding_window_mask = self.model._update_attention_mask(
base_attention_mask,
output_attentions=output_attentions,
)
block = mask_info["block_mask"].unsqueeze(1)
neg_inf = torch.finfo(global_attention_mask.dtype).min
global_attention_mask = global_attention_mask.masked_fill(~block, neg_inf)
sliding_window_mask = sliding_window_mask.masked_fill(~block, neg_inf)
hidden_states = self.model.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for encoder_layer in self.model.layers:
if output_hidden_states:
all_hidden_states = (*all_hidden_states, hidden_states)
layer_outputs = encoder_layer(
hidden_states,
attention_mask=global_attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions and len(layer_outputs) > 1:
all_self_attentions = (*all_self_attentions, layer_outputs[1])
if output_hidden_states:
all_hidden_states = (*all_hidden_states, hidden_states)
hidden_states = self.model.final_norm(hidden_states)
if original_impl == "flash_attention_2":
self.model.config._attn_implementation = original_impl
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
def _forward_t5(
self,
input_ids: Optional[torch.Tensor],
model_kwargs: Dict[str, Any],
mask_info: Dict[str, torch.Tensor],
) -> BaseModelOutput:
"""Forward pass through T5 encoder models with packed attention support.
Handles T5/MT5-specific architecture requirements including relative position
bias and proper attention mask formatting for the encoder stack.
Args:
input_ids: Input token IDs of shape (batch_size, seq_len), or None if
inputs_embeds is provided.
model_kwargs: Dictionary of model-specific keyword arguments including
input_ids (overrides the positional argument when provided), inputs_embeds, head_mask,
past_key_values, use_cache, output_attentions, output_hidden_states,
return_dict, cache_position.
mask_info: Dictionary containing prepared attention masks from
_prepare_pair_attention_masks.
Returns:
BaseModelOutput containing:
- last_hidden_state: Final layer output of shape (batch_size, seq_len, hidden_size).
- hidden_states: Tuple of all layer outputs if requested.
- attentions: Tuple of attention weights if requested.
Raises:
ValueError: If neither input_ids nor inputs_embeds is provided, or if
unsupported kwargs are passed.
"""
stack = self.model.encoder
kw_input_ids = model_kwargs.pop("input_ids", None)
if input_ids is None or kw_input_ids is not None:
input_ids = kw_input_ids
inputs_embeds = model_kwargs.pop("inputs_embeds", None)
head_mask = model_kwargs.pop("head_mask", None)
past_key_values = model_kwargs.pop("past_key_values", None)
use_cache = model_kwargs.pop("use_cache", stack.config.use_cache)
output_attentions = model_kwargs.pop("output_attentions")
output_hidden_states = model_kwargs.pop("output_hidden_states")
return_dict = model_kwargs.pop("return_dict")
cache_position = model_kwargs.pop("cache_position", None)
if model_kwargs:
raise ValueError(f"Unsupported kwargs for T5 forward: {list(model_kwargs.keys())}")
if inputs_embeds is None:
if input_ids is None:
raise ValueError("Either input_ids or inputs_embeds must be provided")
inputs_embeds = stack.embed_tokens(input_ids)
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
device = inputs_embeds.device
if cache_position is None:
cache_position = torch.arange(seq_length, device=device)
block_mask = mask_info["block_mask"].to(device=device, dtype=torch.bool)
dtype = inputs_embeds.dtype
neg_inf = torch.finfo(dtype).min
causal_mask = torch.zeros(block_mask.shape, dtype=dtype, device=device)
causal_mask = causal_mask.masked_fill(~block_mask, neg_inf).unsqueeze(1)
head_mask = stack.get_head_mask(head_mask, stack.config.num_layers)
hidden_states = stack.dropout(inputs_embeds)
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
position_bias = None
for idx, layer_module in enumerate(stack.block):
if output_hidden_states:
all_hidden_states = (
*all_hidden_states,
hidden_states,
)
layer_head_mask = head_mask[idx] if head_mask is not None else None
layer_outputs = layer_module(
hidden_states,
attention_mask=causal_mask,
position_bias=position_bias,
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
layer_head_mask=layer_head_mask,
cross_attn_layer_head_mask=None,
past_key_values=None if not use_cache else past_key_values,
use_cache=False,
output_attentions=output_attentions,
return_dict=True,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
position_bias = layer_outputs[1]
if output_attentions:
all_attentions = (
*all_attentions,
layer_outputs[2],
)
hidden_states = stack.final_layer_norm(hidden_states)
hidden_states = stack.dropout(hidden_states)
if output_hidden_states:
all_hidden_states = (*all_hidden_states, hidden_states)
if not return_dict:
result = (hidden_states,)
if output_hidden_states:
result += (all_hidden_states,)
if output_attentions:
result += (all_attentions,)
return result
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
class Encoder(nn.Module):
"""Standard encoder module wrapping a transformer model with optional projection.
This class provides a high-level interface for encoding text sequences, including
support for inference-time packing to improve throughput. It handles embedding
extraction and optional projection to a different hidden size.
Attributes:
bert_layer: The underlying Transformer instance.
projection: Optional linear projection layer when config.hidden_size differs
from the model's native hidden size.
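Example:
A minimal construction sketch (illustrative; the SimpleNamespace
stubs only the config attributes read along this code path):

>>> from types import SimpleNamespace
>>> cfg = SimpleNamespace(
...     model_name="microsoft/deberta-v3-small", hidden_size=512,
...     encoder_config=None, labels_encoder_config=None,
...     vocab_size=-1, _attn_implementation=None, fuse_layers=False,
... )
>>> enc = Encoder(cfg, from_pretrained=True)  # adds a 768 -> 512 projection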
"""
def __init__(
self, config: Any, from_pretrained: bool = False, cache_dir: Optional[Union[str, Path]] = None
) -> None:
"""Initializes the encoder.
Args:
config: Configuration object containing model hyperparameters including
`model_name`, `hidden_size`, and transformer-specific settings.
from_pretrained: If True, loads pretrained weights for the transformer.
Defaults to False.
cache_dir: Optional directory for caching downloaded models. Defaults to None.
"""
super().__init__()
self.bert_layer = Transformer(config.model_name, config, from_pretrained, cache_dir=cache_dir)
bert_hidden_size = self.bert_layer.model.config.hidden_size
if config.hidden_size != bert_hidden_size:
self.projection = nn.Linear(bert_hidden_size, config.hidden_size)
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
"""Resizes token embeddings to accommodate new vocabulary size.
Args:
new_num_tokens: New vocabulary size.
pad_to_multiple_of: Optional value to pad vocabulary size to a multiple.
Defaults to None.
Returns:
The resized embedding layer.
"""
return self.bert_layer.model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
def get_input_embeddings(self) -> nn.Embedding:
"""Gets the input embedding layer.
Returns:
The model's input embedding layer.
"""
return self.bert_layer.model.get_input_embeddings()
def encode_text(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor, *args: Any, **kwargs: Any
) -> torch.Tensor:
"""Encodes input text sequences into contextualized embeddings.
Supports inference-time packing to batch multiple variable-length sequences
efficiently when packing_config is provided and not in training mode.
Args:
input_ids: Input token IDs of shape (batch_size, seq_len).
attention_mask: Attention mask of shape (batch_size, seq_len) where 1
indicates valid tokens and 0 indicates padding.
*args: Additional positional arguments passed to the transformer.
**kwargs: Additional keyword arguments including:
- packing_config: Optional InferencePackingConfig for efficient batching.
- pair_attention_mask: Optional pairwise attention mask for packed sequences.
Returns:
Token embeddings of shape (batch_size, seq_len, hidden_size).
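Example:
Illustrative usage, continuing the class-level sketch; the token IDs
here are arbitrary placeholders rather than real tokenizer output:

>>> import torch
>>> ids = torch.randint(0, 1000, (2, 8))
>>> mask = torch.ones(2, 8, dtype=torch.long)
>>> enc.encode_text(ids, mask).shape
torch.Size([2, 8, 512])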
"""
packing_config: Optional[InferencePackingConfig] = kwargs.pop("packing_config", None)
pair_attention_mask = kwargs.pop("pair_attention_mask", None)
if (
packing_config is not None
and not self.training
and isinstance(input_ids, torch.Tensor)
and isinstance(attention_mask, torch.Tensor)
and input_ids.dim() == 2
):
token_embeddings = self._encode_with_packing(
input_ids,
attention_mask,
packing_config,
pair_attention_mask,
*args,
**kwargs,
)
else:
bert_kwargs = dict(kwargs)
if attention_mask is not None:
bert_kwargs["attention_mask"] = attention_mask
if pair_attention_mask is not None:
bert_kwargs["pair_attention_mask"] = pair_attention_mask
token_embeddings = self.bert_layer(
input_ids=input_ids,
**bert_kwargs,
)
if hasattr(self, "projection"):
token_embeddings = self.projection(token_embeddings)
return token_embeddings
def _encode_with_packing(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
packing_config: InferencePackingConfig,
pair_attention_mask: Optional[torch.Tensor],
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
"""Encodes sequences using inference-time packing for efficiency.
Packs multiple variable-length sequences into fewer, more efficient batches
to maximize GPU utilization during inference. Short sequences are combined
into single packed sequences.
Args:
input_ids: Input token IDs of shape (batch_size, seq_len).
attention_mask: Attention mask of shape (batch_size, seq_len).
packing_config: Configuration for packing behavior.
pair_attention_mask: Optional pairwise attention mask.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
Token embeddings of shape (batch_size, seq_len, hidden_size) with
proper unpacking to restore original batch structure.
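Note:
A conceptual shape walkthrough: for a batch whose rows have valid
lengths [3, 5, 2], `pack_requests` concatenates the valid tokens
into fewer packed rows, the encoder runs once over those rows, and
`unpack_spans` returns tensors of shapes (3, H), (5, H), and (2, H),
which are scattered back into a zeroed (3, seq_len, H) output.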
"""
lengths = attention_mask.sum(dim=-1, dtype=torch.int64).tolist()
seq_len = int(input_ids.size(1))
if not lengths or all(int(ln) == seq_len for ln in lengths):
bert_kwargs = dict(kwargs)
bert_kwargs["attention_mask"] = attention_mask
if pair_attention_mask is not None:
bert_kwargs["pair_attention_mask"] = pair_attention_mask
return self.bert_layer(input_ids=input_ids, **bert_kwargs)
requests = []
for row, length in zip(input_ids, lengths):
if length <= 0:
requests.append({"input_ids": []})
else:
requests.append({"input_ids": row[:length].tolist()})
pad_token_id = self.bert_layer.model.config.pad_token_id
if pad_token_id is None:
pad_token_id = 0
packed = pack_requests(requests, packing_config, pad_token_id)
device = input_ids.device
packed_ids = packed.input_ids.to(device=device)
packed_mask = packed.pair_attention_mask.to(device=device)
packed_fallback = packed.attention_mask.to(device=device)
attn_to_use = packed_mask if packed_mask.numel() else packed_fallback
bert_kwargs = dict(kwargs)
if packed_mask.numel():
bert_kwargs["attention_mask"] = packed_fallback
bert_kwargs["pair_attention_mask"] = packed_mask
else:
bert_kwargs["attention_mask"] = attn_to_use
token_embeddings = self.bert_layer(
input_ids=packed_ids,
**bert_kwargs,
)
unpacked: List[torch.Tensor] = unpack_spans(token_embeddings, packed)
hidden_size = token_embeddings.size(-1)
batch, seq = input_ids.size()
output = token_embeddings.new_zeros(batch, seq, hidden_size)
for idx, target in enumerate(unpacked):
tgt_len = int(target.size(0))
if tgt_len == 0:
continue
output[idx, :tgt_len] = target
return output
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
"""Forward pass through the encoder.
Args:
*args: Positional arguments passed to encode_text.
**kwargs: Keyword arguments passed to encode_text.
Returns:
Token embeddings of shape (batch_size, seq_len, hidden_size).
"""
token_embeddings = self.encode_text(*args, **kwargs)
return token_embeddings
class BiEncoder(Encoder):
"""Bi-encoder architecture with separate encoders for text and labels.
This encoder processes text sequences and label sequences through potentially
different transformer models, producing aligned representations for both. The
label representations are mean-pooled to create fixed-size embeddings.
Attributes:
bert_layer: Inherited text encoder from Encoder.
projection: Inherited optional projection from Encoder.
labels_encoder: Separate Transformer instance for encoding labels.
labels_projection: Optional projection for label embeddings when label
encoder hidden size differs from config.hidden_size.
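Example:
A construction sketch extending the Encoder example above; the label
encoder name is an arbitrary stand-in, not a library default:

>>> cfg.labels_encoder = "sentence-transformers/all-MiniLM-L6-v2"
>>> bi = BiEncoder(cfg, from_pretrained=True)  # adds a 384 -> 512 labels_projection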
"""
def __init__(
self, config: Any, from_pretrained: bool = False, cache_dir: Optional[Union[str, Path]] = None
) -> None:
"""Initializes the bi-encoder.
Args:
config: Configuration object containing model hyperparameters including
`labels_encoder` (model name for label encoder) and `hidden_size`.
from_pretrained: If True, loads pretrained weights for both encoders.
Defaults to False.
cache_dir: Optional directory for caching downloaded models. Defaults to None.
"""
super().__init__(config, from_pretrained, cache_dir=cache_dir)
if config.labels_encoder is not None:
self.labels_encoder = Transformer(config.labels_encoder, config, from_pretrained, True, cache_dir=cache_dir)
le_hidden_size = self.labels_encoder.model.config.hidden_size
if config.hidden_size != le_hidden_size:
self.labels_projection = nn.Linear(le_hidden_size, config.hidden_size)
def mean_pooling(self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""Applies mean pooling over token embeddings using attention mask.
Computes the average of token embeddings weighted by the attention mask,
ignoring padded positions.
Args:
token_embeddings: Token-level embeddings of shape (batch_size, seq_len, hidden_size).
attention_mask: Binary mask of shape (batch_size, seq_len) where 1 indicates
valid tokens and 0 indicates padding.
Returns:
Pooled embeddings of shape (batch_size, hidden_size).
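Example:
Continuing the class-level sketch: with one padded position masked
out, only the two valid tokens contribute to the average:

>>> emb = torch.tensor([[[1., 1.], [3., 3.], [9., 9.]]])
>>> mask = torch.tensor([[1, 1, 0]])
>>> bi.mean_pooling(emb, mask)
tensor([[2., 2.]])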
"""
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
def encode_labels(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor, *args: Any, **kwargs: Any
) -> torch.Tensor:
"""Encodes label sequences into fixed-size embeddings.
Processes labels through the dedicated labels encoder and applies mean pooling
to produce sentence-level representations.
Args:
input_ids: Label token IDs of shape (batch_size, seq_len).
attention_mask: Attention mask of shape (batch_size, seq_len).
*args: Additional positional arguments.
**kwargs: Additional keyword arguments (packing_config and pair_attention_mask
are removed as they're not supported for labels).
Returns:
Pooled label embeddings of shape (batch_size, hidden_size).
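Example:
An illustrative call, assuming `tok` is the tokenizer paired with
the labels encoder:

>>> batch = tok(["person", "location"], return_tensors="pt", padding=True)
>>> vecs = bi.encode_labels(batch["input_ids"], batch["attention_mask"])
>>> vecs.shape  # torch.Size([2, 512]) with the sketch config above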
"""
label_kwargs = dict(kwargs)
label_kwargs.pop("packing_config", None)
label_kwargs.pop("pair_attention_mask", None)
label_kwargs["attention_mask"] = attention_mask
labels_embeddings = self.labels_encoder(input_ids, *args, **label_kwargs)
if hasattr(self, "labels_projection"):
labels_embeddings = self.labels_projection(labels_embeddings)
labels_embeddings = self.mean_pooling(labels_embeddings, attention_mask)
return labels_embeddings
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels_input_ids: Optional[torch.Tensor] = None,
labels_attention_mask: Optional[torch.Tensor] = None,
*args: Any,
**kwargs: Any,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the bi-encoder.
Encodes both text sequences (token-level) and label sequences (pooled) to
produce aligned representations.
Args:
input_ids: Text token IDs of shape (batch_size, seq_len).
attention_mask: Text attention mask of shape (batch_size, seq_len).
labels_input_ids: Label token IDs of shape (batch_size, label_seq_len).
Defaults to None.
labels_attention_mask: Label attention mask of shape (batch_size, label_seq_len).
Defaults to None.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
A tuple containing:
- token_embeddings: Text embeddings of shape (batch_size, seq_len, hidden_size).
- labels_embeddings: Pooled label embeddings of shape (batch_size, hidden_size).
"""
token_embeddings = self.encode_text(input_ids, attention_mask, *args, **kwargs)
labels_embeddings = self.encode_labels(labels_input_ids, labels_attention_mask, *args, **kwargs)
return token_embeddings, labels_embeddings