Source code for gliner.modeling.outputs

from typing import Optional
from dataclasses import dataclass

import torch
from transformers.utils import ModelOutput


[docs] @dataclass class GLiNERBaseOutput(ModelOutput): """Base output class for GLiNER models. This class contains the fundamental outputs produced by GLiNER models, including loss, logits, and embeddings for both prompts (entity types) and input words/tokens. Attributes: loss (Optional[torch.FloatTensor]): The total loss for training. Shape: scalar tensor. logits (Optional[torch.FloatTensor]): The prediction scores for entity spans or other outputs. Shape varies depending on the model configuration, typically [batch_size, num_spans, num_classes] or similar. prompts_embedding (Optional[torch.FloatTensor]): Embeddings for the entity type prompts/labels. Shape: [batch_size, num_classes, hidden_size]. prompts_embedding_mask (Optional[torch.LongTensor]): Attention mask for prompt embeddings. Shape: [batch_size, num_classes]. words_embedding (Optional[torch.FloatTensor]): Embeddings for input words/tokens. Shape: [batch_size, seq_len, hidden_size]. mask (Optional[torch.LongTensor]): Attention mask for input tokens. Shape: [batch_size, seq_len]. """ loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None prompts_embedding: Optional[torch.FloatTensor] = None prompts_embedding_mask: Optional[torch.LongTensor] = None words_embedding: Optional[torch.FloatTensor] = None mask: Optional[torch.LongTensor] = None
[docs] @dataclass class GLiNERDecoderOutput(GLiNERBaseOutput): """Output class for GLiNER models with decoder components. Extends GLiNERBaseOutput with additional decoder-specific outputs, including decoder embeddings and span indices. This is typically used in GLiNER variants that include an explicit decoder module. Attributes: loss (Optional[torch.FloatTensor]): Inherited from GLiNERBaseOutput. logits (Optional[torch.FloatTensor]): Inherited from GLiNERBaseOutput. prompts_embedding (Optional[torch.FloatTensor]): Inherited from GLiNERBaseOutput. prompts_embedding_mask (Optional[torch.LongTensor]): Inherited from GLiNERBaseOutput. words_embedding (Optional[torch.FloatTensor]): Inherited from GLiNERBaseOutput. mask (Optional[torch.LongTensor]): Inherited from GLiNERBaseOutput. decoder_loss (Optional[torch.FloatTensor]): Loss specific to the decoder component. Shape: scalar tensor. decoder_embedding (Optional[torch.FloatTensor]): Output embeddings from the decoder. Shape: [batch_size, num_decoder_tokens, hidden_size]. decoder_embedding_mask (Optional[torch.LongTensor]): Attention mask for decoder embeddings. Shape: [batch_size, num_decoder_tokens]. decoder_span_idx (Optional[torch.LongTensor]): Indices of spans processed by the decoder. Shape: [batch_size, num_spans, 2], where the last dimension contains [start_idx, end_idx]. """ decoder_loss: Optional[torch.FloatTensor] = None decoder_embedding: Optional[torch.FloatTensor] = None decoder_embedding_mask: Optional[torch.LongTensor] = None decoder_span_idx: Optional[torch.LongTensor] = None
[docs] @dataclass class GLiNERRelexOutput(GLiNERBaseOutput): """Output class for GLiNER models with relation extraction. Extends GLiNERBaseOutput with relation-specific outputs for models that perform both entity recognition and relation extraction (Relex). This enables joint modeling of entities and their relationships. Attributes: loss (Optional[torch.FloatTensor]): Inherited from GLiNERBaseOutput. logits (Optional[torch.FloatTensor]): Inherited from GLiNERBaseOutput. prompts_embedding (Optional[torch.FloatTensor]): Inherited from GLiNERBaseOutput. prompts_embedding_mask (Optional[torch.LongTensor]): Inherited from GLiNERBaseOutput. words_embedding (Optional[torch.FloatTensor]): Inherited from GLiNERBaseOutput. mask (Optional[torch.LongTensor]): Inherited from GLiNERBaseOutput. rel_idx (Optional[torch.LongTensor]): Indices of entity pairs for which relations are predicted. Shape: [batch_size, num_relations, 2], where the last dimension contains indices of the two entities. rel_logits (Optional[torch.FloatTensor]): Prediction scores for relations between entity pairs. Shape: [batch_size, num_relations, num_relation_types]. rel_mask (Optional[torch.FloatTensor]): Mask indicating valid relation predictions. Shape: [batch_size, num_relations]. """ rel_idx: Optional[torch.LongTensor] = None rel_logits: Optional[torch.FloatTensor] = None rel_mask: Optional[torch.FloatTensor] = None