"""Decoder modules for autoregressive text generation with optional constraints.
This module provides decoder architectures built on causal language models, supporting
both standard generation and prefix-constrained decoding using trie structures. It
includes custom generation implementations and numerical stability improvements.
"""
import warnings
from typing import Any, List, Union, Optional
from pathlib import Path
import torch
from torch import nn
from transformers import AutoConfig, LogitsProcessor, LogitsProcessorList, AutoModelForCausalLM
from ..utils import is_module_available
from ..decoding.trie import LabelsTrie
# Check for optional dependencies: the PEFT names below are imported only when
# `is_module_available("peft")` reports the package as available.
IS_PEFT = is_module_available("peft")
if IS_PEFT:
    from peft import LoraConfig, get_peft_model
class NumericalStabilityProcessor(LogitsProcessor):
    """Logits processor that keeps generation scores finite and bounded.

    On every call the processor rewrites negative-infinity entries using
    ``torch.finfo(scores.dtype).min``, clamps all entries to ``[-1e9, 1e9]``,
    and then shifts every entry by a small constant.

    Attributes:
        epsilon: Constant added to every logit after clamping.
    """

    def __init__(self, epsilon: float = 1e-6) -> None:
        """Stores the constant that is added to the processed logits.

        Args:
            epsilon: Value added to each logit. Defaults to 1e-6.
        """
        self.epsilon = epsilon

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Returns a stabilized copy of ``scores``.

        Args:
            input_ids: Previously generated token IDs of shape (batch_size, seq_len).
                Not consulted by this processor.
            scores: Raw logit scores of shape (batch_size, vocab_size).

        Returns:
            Logit scores of shape (batch_size, vocab_size) with negative infinities
            replaced, values clamped, and ``epsilon`` added.
        """
        # Scalar replacement value for -inf entries, placed on the scores' device.
        dtype_floor = torch.tensor(torch.finfo(scores.dtype).min).to(scores.device)
        neg_inf_positions = torch.isneginf(scores)
        finite_scores = torch.where(neg_inf_positions, dtype_floor, scores)

        # Bound the remaining extremes before applying the constant shift.
        bounded_scores = torch.clamp(finite_scores, min=-1e9, max=1e9)
        return bounded_scores + self.epsilon
class Decoder(nn.Module):
    """High-level decoder interface for autoregressive generation.

    This class provides a unified interface for text generation from embeddings,
    supporting both standard generation and constrained decoding using trie structures.
    It includes a custom token-by-token generation loop and a wrapper around the
    wrapped model's ``generate()`` method.

    Attributes:
        decoder_layer: The underlying DecoderTransformer instance.
        decoder_hidden_size: Hidden dimension size of the decoder model.
    """

    def __init__(
        self, config: Any, from_pretrained: bool = False, cache_dir: Optional[Union[str, Path]] = None
    ) -> None:
        """Initializes the decoder.

        Args:
            config: Configuration object containing model hyperparameters including
                `labels_decoder` (model name) and decoder-specific settings.
            from_pretrained: Forwarded to DecoderTransformer; if True, pretrained
                weights are requested for the decoder. Defaults to False.
            cache_dir: Optional directory for caching downloaded models. Defaults to None.
        """
        super().__init__()
        self.decoder_layer = DecoderTransformer(config.labels_decoder, config, from_pretrained, cache_dir=cache_dir)
        self.decoder_hidden_size = self.decoder_layer.model.config.hidden_size

    @staticmethod
    def _resolve_special_token_ids(cfg: Any, eos_token_id: Optional[int], pad_token_id: Optional[int]) -> tuple:
        """Resolves EOS/PAD token IDs, falling back to the model config.

        Explicit ``is None`` checks are used instead of ``or`` so that a valid
        token ID of ``0`` is never mistaken for "not provided".

        Args:
            cfg: Model config exposing ``eos_token_id`` and ``pad_token_id``.
            eos_token_id: Caller-supplied EOS ID, or None to use ``cfg.eos_token_id``.
            pad_token_id: Caller-supplied PAD ID, or None to use ``cfg.pad_token_id``
                and, if that is also None, the resolved EOS ID.

        Returns:
            Tuple ``(eos_token_id, pad_token_id)``; either may still be None when
            neither the caller nor the config provides a value.
        """
        if eos_token_id is None:
            eos_token_id = cfg.eos_token_id
        if pad_token_id is None:
            pad_token_id = cfg.pad_token_id
        if pad_token_id is None:
            pad_token_id = eos_token_id
        return eos_token_id, pad_token_id

    def ids_to_embeds(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        """Converts token IDs to their corresponding embeddings.

        Args:
            input_ids: Token IDs of shape (batch_size, seq_len).

        Returns:
            Output of the model's input-embedding layer for ``input_ids``;
            presumably (batch_size, seq_len, hidden_size).
        """
        input_ids = input_ids.to(self.decoder_layer.model.device)
        embedding_layer = self.decoder_layer.model.get_input_embeddings()
        return embedding_layer(input_ids)

    @torch.inference_mode()
    def generate_from_embeds_custom(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_new_tokens: int = 32,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        temperature: float = 1.0,
        do_sample: bool = False,
        labels_trie: Optional["LabelsTrie"] = None,
        **kwargs: Any,
    ) -> torch.LongTensor:
        """Custom generation implementation from embeddings with optional trie constraints.

        Runs the prefix through the model once, then generates token by token while
        reusing the returned ``past_key_values``. When a trie is supplied, the logits
        of every unfinished sequence are masked so only trie-approved tokens remain.

        Args:
            inputs_embeds: Input embeddings of shape (batch_size, prefix_len, hidden_size)
                serving as the generation prefix.
            attention_mask: Optional attention mask of shape (batch_size, prefix_len).
                If None, a mask of ones is used. Defaults to None.
            max_new_tokens: Maximum number of new tokens to generate. Defaults to 32.
            eos_token_id: Token ID marking end of sequence. If None, uses the model
                config value. ``0`` is honoured as a valid ID. Defaults to None.
            pad_token_id: Token ID for padding. If None, uses the model config value,
                then eos_token_id. ``0`` is honoured as a valid ID. Defaults to None.
            temperature: Logits are divided by this value whenever it differs
                from 1.0. Defaults to 1.0.
            do_sample: If True, uses multinomial sampling. If False, uses greedy
                decoding (argmax). Defaults to False.
            labels_trie: Optional trie for constrained decoding. ``labels_trie.get(seq)``
                is queried with the tokens generated so far; an empty or None result
                restricts the step to eos_token_id. Defaults to None.
            **kwargs: Additional keyword arguments (currently unused).

        Returns:
            Generated token IDs of shape (batch_size, generated_len), where
            generated_len is the length of the longest generated sequence. Each
            sequence ends with its EOS token if one was produced; shorter sequences
            are right-padded with pad_token_id. The prefix is not included.

        Raises:
            ValueError: If no EOS token ID is given and the model config has none.
        """
        model = self.decoder_layer.model
        device = inputs_embeds.device
        batch_size, prefix_len, _ = inputs_embeds.shape

        eos_token_id, pad_token_id = self._resolve_special_token_ids(model.config, eos_token_id, pad_token_id)
        if eos_token_id is None:
            # Without an EOS id neither the stop check nor the trie fallback can work.
            raise ValueError("eos_token_id must be provided or defined in the model config.")

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, prefix_len, dtype=torch.long, device=device)

        # Encode the prefix once; later steps feed one token and reuse the cache.
        out = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, use_cache=True)
        past_key_values = out.past_key_values
        next_logits = out.logits[:, -1]  # (B, V)

        unfinished = torch.ones(batch_size, dtype=torch.bool, device=device)
        generated = [[] for _ in range(batch_size)]

        for step in range(max_new_tokens):
            # One host transfer per step instead of one per batch element.
            active = unfinished.tolist()

            if labels_trie is not None:
                vocab_size = next_logits.shape[1]
                mask_tensor = torch.full((batch_size, vocab_size), -float("inf"), device=device)
                for b, is_active in enumerate(active):
                    if is_active:
                        allowed_tokens = labels_trie.get(generated[b])
                        if allowed_tokens is None or len(allowed_tokens) == 0:
                            # Dead end in the trie: force the sequence to terminate.
                            allowed_tokens = [eos_token_id]
                        mask_tensor[b, allowed_tokens] = 0
                    else:
                        # Finished rows are ignored below; leave their logits untouched.
                        mask_tensor[b, :] = 0
                next_logits = next_logits + mask_tensor

            if temperature != 1.0:
                next_logits = next_logits / temperature

            if do_sample:
                probs = torch.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
            else:
                next_token = next_logits.argmax(dim=-1, keepdim=True)  # (B, 1)

            step_tokens = next_token.squeeze(-1).tolist()
            for b, is_active in enumerate(active):
                if is_active:
                    generated[b].append(step_tokens[b])

            # squeeze(-1) keeps the batch dimension even when batch_size == 1.
            eos_hit = next_token.squeeze(-1) == eos_token_id
            unfinished = unfinished & ~eos_hit

            # Stop when all sequences ended, or when no further token will be consumed
            # (avoids a forward pass whose logits would be discarded).
            if step == max_new_tokens - 1 or not unfinished.any():
                break

            # Finished rows are fed the pad token so the batch stays rectangular.
            next_token = next_token.masked_fill(~unfinished.unsqueeze(1), pad_token_id)
            attention_mask = torch.cat(
                [attention_mask, torch.ones(batch_size, 1, dtype=torch.long, device=device)],
                dim=1,
            )
            out = model(
                input_ids=next_token, attention_mask=attention_mask, past_key_values=past_key_values, use_cache=True
            )
            past_key_values = out.past_key_values
            next_logits = out.logits[:, -1]

        max_len = max((len(seq) for seq in generated), default=0)
        out_ids = torch.full((batch_size, max_len), pad_token_id, dtype=torch.long, device=device)
        for b, seq in enumerate(generated):
            if seq:
                out_ids[b, : len(seq)] = torch.tensor(seq, dtype=torch.long, device=device)
        return out_ids

    @torch.inference_mode()
    def generate_from_embeds(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_new_tokens: int = 32,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        temperature: float = 1.0,
        do_sample: bool = False,
        num_return_sequences: int = 1,
        labels_trie: Optional["LabelsTrie"] = None,
        **kwargs: Any,
    ) -> torch.LongTensor:
        """Generation from embeddings using the wrapped model's ``generate()`` method.

        Casts the inputs to the model dtype, resolves EOS/PAD IDs, optionally builds a
        ``prefix_allowed_tokens_fn`` callback from the trie, and calls
        ``model.generate()`` with a NumericalStabilityProcessor attached.

        Args:
            inputs_embeds: Input embeddings of shape (batch_size, prefix_len, hidden_size)
                serving as the generation prefix.
            attention_mask: Optional attention mask of shape (batch_size, prefix_len).
                A provided mask is cast to the model dtype; if None, a mask of ones
                (dtype long) is created. Defaults to None.
            max_new_tokens: Maximum number of new tokens to generate. Defaults to 32.
            eos_token_id: Token ID marking end of sequence. If None, uses the model
                config value. ``0`` is honoured as a valid ID. Defaults to None.
            pad_token_id: Token ID for padding. If None, uses the model config value,
                then eos_token_id. ``0`` is honoured as a valid ID. Defaults to None.
            temperature: Sampling temperature forwarded to ``generate()``. Defaults to 1.0.
            do_sample: Forwarded to ``generate()``. Defaults to False.
            num_return_sequences: Number of sequences to generate per input. Also
                passed as ``num_beams``. Defaults to 1.
            labels_trie: Optional trie structure for constrained decoding via
                prefix_allowed_tokens_fn. Defaults to None.
            **kwargs: Additional keyword arguments passed to model.generate().

        Returns:
            Token IDs as returned by ``model.generate()``; presumably
            batch_size * num_return_sequences rows.
            NOTE(review): the earlier documentation stated that the input prefix is
            included in the output. Because only ``inputs_embeds`` (no ``input_ids``)
            is passed, the result likely contains just the newly generated tokens —
            confirm against the installed transformers version.
        """
        model = self.decoder_layer.model
        inputs_embeds = inputs_embeds.to(dtype=model.dtype)
        if attention_mask is not None:
            attention_mask = attention_mask.to(dtype=model.dtype)

        device = inputs_embeds.device
        batch_size, prefix_len, _ = inputs_embeds.shape

        # Resolve token IDs without treating a valid ID of 0 as missing.
        eos_token_id, pad_token_id = self._resolve_special_token_ids(model.config, eos_token_id, pad_token_id)

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, prefix_len, dtype=torch.long, device=device)

        if labels_trie is not None:

            def prefix_allowed_tokens(batch_idx: int, input_ids: torch.Tensor) -> List[int]:
                """Callback function for constrained decoding.

                Args:
                    batch_idx: Index of the sequence in the batch (unused).
                    input_ids: Token IDs passed by ``generate()`` for this sequence.

                Returns:
                    List of allowed token IDs for the next position; falls back to
                    ``[eos_token_id]`` when the trie yields nothing.
                """
                allowed_tokens = labels_trie.get(input_ids.tolist())
                if not allowed_tokens:  # Empty or None
                    allowed_tokens = [eos_token_id]
                return allowed_tokens

        else:
            prefix_allowed_tokens = None

        generated_ids = model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=None,
            max_new_tokens=max_new_tokens,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            temperature=temperature,
            do_sample=do_sample,
            use_cache=True,
            num_return_sequences=num_return_sequences,
            num_beams=num_return_sequences,
            prefix_allowed_tokens_fn=prefix_allowed_tokens,
            logits_processor=LogitsProcessorList(
                [
                    NumericalStabilityProcessor(),
                ]
            ),
            **kwargs,
        )
        return generated_ids

    def generate(self, *args: Any, **kwargs: Any) -> torch.LongTensor:
        """Flexible generation method supporting both embeddings and token IDs.

        Routes to generate_from_embeds() when ``inputs_embeds`` is passed as a
        keyword argument; otherwise delegates to the wrapped model's generate().

        Args:
            *args: Variable positional arguments passed to the generation method.
            **kwargs: Variable keyword arguments. If 'inputs_embeds' is present,
                routes to generate_from_embeds(), otherwise routes to model.generate().

        Returns:
            Generated token IDs. Shape depends on the specific generation method used.
        """
        if "inputs_embeds" in kwargs:
            return self.generate_from_embeds(*args, **kwargs)
        return self.decoder_layer.model.generate(*args, **kwargs)

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Forward pass through the decoder.

        Calls the underlying decoder layer with the given arguments and returns
        its result unchanged.

        Args:
            *args: Variable positional arguments passed to the decoder layer.
            **kwargs: Variable keyword arguments passed to the decoder layer.

        Returns:
            Output of the decoder layer; presumably logits of shape
            (batch_size, seq_len, vocab_size) — confirm against DecoderTransformer.
        """
        return self.decoder_layer(*args, **kwargs)