# Source code for gliner.modeling.utils

from typing import Tuple

import torch


def extract_word_embeddings(
    token_embeds: torch.Tensor,
    words_mask: torch.Tensor,
    attention_mask: torch.Tensor,
    batch_size: int,
    max_text_length: int,
    embed_dim: int,
    text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Extract word-level embeddings from subword token embeddings.

    Maps subword token embeddings back to word-level embeddings using a word
    mask that indicates which subword token corresponds to which word. Every
    position with a non-zero ``words_mask`` value is copied into the word slot
    ``words_mask - 1``; positions with value zero are ignored.

    Args:
        token_embeds: Subword token embeddings from transformer.
            Shape: (batch_size, seq_len, embed_dim)
        words_mask: Mask mapping subword positions to word indices. Non-zero
            values indicate the word index (1-indexed). Zero values are
            special tokens or continuation subwords to ignore.
            Shape: (batch_size, seq_len)
        attention_mask: Standard attention mask from tokenizer. Not used by
            the computation; kept so existing callers keep working.
            Shape: (batch_size, seq_len)
        batch_size: Size of the batch.
        max_text_length: Maximum number of words across all examples in batch.
        embed_dim: Embedding dimension size.
        text_lengths: Number of words in each example.
            Shape: (batch_size, 1) or (batch_size,)

    Returns:
        Tuple containing:
            - words_embedding: Word-level embeddings extracted from token
              embeddings. Word slots that receive no token stay zero.
              Shape: (batch_size, max_text_length, embed_dim)
            - mask: Boolean mask indicating valid word positions (True) vs
              padding (False). Shape: (batch_size, max_text_length)
    """
    words_embedding = torch.zeros(
        batch_size, max_text_length, embed_dim, dtype=token_embeds.dtype, device=token_embeds.device
    )

    # Positions where words_mask > 0 are the tokens that represent a word.
    batch_indices, token_positions = torch.where(words_mask > 0)

    # Convert the 1-indexed word ids stored in the mask to 0-indexed slots.
    target_word_idx = words_mask[batch_indices, token_positions] - 1

    # Scatter the selected token embeddings into their word slots.
    words_embedding[batch_indices, target_word_idx] = token_embeds[batch_indices, token_positions]

    # Word positions are always built as integers: tying the dtype to the
    # attention mask breaks for bool masks and loses precision for half floats.
    word_positions = torch.arange(max_text_length, dtype=torch.long, device=token_embeds.device).expand(
        batch_size, -1
    )

    # Force lengths into a column so both (batch_size,) and (batch_size, 1)
    # broadcast against the (batch_size, max_text_length) position grid
    # per example rather than along the word axis.
    mask = word_positions < text_lengths.reshape(-1, 1)

    return words_embedding, mask
def extract_prompt_features(
    class_token_index: int,
    token_embeds: torch.Tensor,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    batch_size: int,
    embed_dim: int,
    embed_ent_token: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Extract prompt/entity type embeddings from special class tokens.

    Collects, for every example, the embeddings at the positions whose token
    id equals ``class_token_index`` (e.g. [ENT] tokens), or at the positions
    immediately after them when ``embed_ent_token`` is False, and packs them
    left-aligned into a padded tensor.

    In prompt-based NER, the input is typically:
        [ENT] Person [ENT] Organization [SEP] John works at Google

    Args:
        class_token_index: Token ID of the special class token to extract
            (e.g., token ID for [ENT]).
        token_embeds: Token embeddings from transformer.
            Shape: (batch_size, seq_len, embed_dim)
        input_ids: Token IDs from tokenizer. Shape: (batch_size, seq_len)
        attention_mask: Standard attention mask from tokenizer. Only its dtype
            is used, as the dtype of the returned mask.
            Shape: (batch_size, seq_len)
        batch_size: Size of the batch.
        embed_dim: Embedding dimension size.
        embed_ent_token: If True, use the [ENT] token embedding itself.
            If False, use the embedding of the token immediately after [ENT]
            (i.e., the entity type name token). Default: True.

    Returns:
        Tuple containing:
            - prompts_embedding: Embeddings for each prompt/entity type.
              Shape: (batch_size, max_num_types, embed_dim) where
              max_num_types is the maximum number of class tokens across
              examples in the batch. Padding slots are zero.
            - prompts_embedding_mask: Mask of valid prompt positions (non-zero)
              vs padding (zero), in ``attention_mask.dtype``.
              Shape: (batch_size, max_num_types)
    """
    # Locate every class token and count them per example.
    class_token_mask = input_ids == class_token_index
    num_class_tokens = torch.sum(class_token_mask, dim=-1, keepdim=True)

    # Width of the padded output: the largest class-token count in the batch.
    max_num_types = num_class_tokens.max()

    # Slot indices are always integers: an arange in the attention-mask dtype
    # fails for bool masks and is inexact for low-precision float masks.
    slot_positions = torch.arange(max_num_types, dtype=torch.long, device=token_embeds.device).expand(
        batch_size, -1
    )

    # Slot j of example b is valid when j < number of class tokens in b.
    valid_slots = slot_positions < num_class_tokens

    # Both torch.where calls enumerate in row-major order and yield the same
    # number of entries per example, so slot k lines up with the k-th class
    # token of that example.
    batch_indices, target_class_idx = torch.where(valid_slots)
    _, class_indices = torch.where(class_token_mask)

    # Optionally shift to the token after [ENT] (the entity type name).
    if not embed_ent_token:
        class_indices = class_indices + 1

    prompts_embedding = torch.zeros(
        batch_size, max_num_types, embed_dim, dtype=token_embeds.dtype, device=token_embeds.device
    )
    prompts_embedding[batch_indices, target_class_idx] = token_embeds[batch_indices, class_indices]

    prompts_embedding_mask = valid_slots.to(attention_mask.dtype)

    return prompts_embedding, prompts_embedding_mask
def extract_prompt_features_and_word_embeddings(
    class_token_index: int,
    token_embeds: torch.Tensor,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    text_lengths: torch.Tensor,
    words_mask: torch.Tensor,
    embed_ent_token: bool = True,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Extract prompt embeddings and word embeddings from one encoder output.

    Thin wrapper that derives the batch size and embedding width from
    ``token_embeds``, takes the longest text length from ``text_lengths``,
    and then delegates to ``extract_prompt_features`` and
    ``extract_word_embeddings``.

    Args:
        class_token_index: Token ID of the special class token (e.g., [ENT]).
        token_embeds: Token embeddings from transformer.
            Shape: (batch_size, seq_len, embed_dim)
        input_ids: Token IDs from tokenizer. Shape: (batch_size, seq_len)
        attention_mask: Standard attention mask from tokenizer.
            Shape: (batch_size, seq_len)
        text_lengths: Number of words in each example.
            Shape: (batch_size, 1) or (batch_size,)
        words_mask: Mask mapping subword positions to word indices.
            Shape: (batch_size, seq_len)
        embed_ent_token: If True, use [ENT] token embedding. If False, use
            the token after [ENT] (the entity type name). Default: True.
        **kwargs: Additional keyword arguments forwarded to
            extract_prompt_features.

    Returns:
        Tuple containing:
            - prompts_embedding: Entity type embeddings.
              Shape: (batch_size, max_num_types, embed_dim)
            - prompts_embedding_mask: Mask for valid entity type positions.
              Shape: (batch_size, max_num_types)
            - words_embedding: Word-level text embeddings.
              Shape: (batch_size, max_text_length, embed_dim)
            - mask: Mask for valid word positions.
              Shape: (batch_size, max_text_length)
    """
    # token_embeds must be rank 3; the middle (sequence) axis is not needed.
    num_examples, _seq_len, hidden_size = token_embeds.size()
    longest_text = text_lengths.max()

    # Prompt / entity-type side.
    type_embeds, type_mask = extract_prompt_features(
        class_token_index,
        token_embeds,
        input_ids,
        attention_mask,
        num_examples,
        hidden_size,
        embed_ent_token,
        **kwargs,
    )

    # Word-level text side.
    word_embeds, word_mask = extract_word_embeddings(
        token_embeds,
        words_mask,
        attention_mask,
        num_examples,
        longest_text,
        hidden_size,
        text_lengths,
    )

    return type_embeds, type_mask, word_embeds, word_mask
def build_entity_pairs(
    adj: torch.Tensor,
    span_rep: torch.Tensor,
    threshold: float = 0.5,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build entity pairs for relation extraction based on adjacency scores.

    Selects every directed pair (head, tail) with head != tail whose adjacency
    score is strictly greater than ``threshold`` and gathers the corresponding
    head and tail embeddings. All directed pairs (i, j) are considered, not
    just the upper triangle, because relation direction matters. Within each
    example, pairs are listed in row-major (head, then tail) order.

    Args:
        adj: Adjacency matrix with scores or probabilities for entity pairs.
            Shape: (batch_size, num_entities, num_entities). The diagonal
            (self-pairs) is ignored.
        span_rep: Entity/span embeddings for each entity in the batch.
            Shape: (batch_size, num_entities, embed_dim)
        threshold: Minimum adjacency score to consider a pair as a potential
            relation. Pairs with adj[i, j] > threshold are kept. Default: 0.5.

    Returns:
        Tuple containing:
            - pair_idx: Indices of (head, tail) entity pairs, -1 at padding.
              Shape: (batch_size, max_pairs, 2)
            - pair_mask: Boolean mask of valid pairs (True) vs padding (False).
              Shape: (batch_size, max_pairs)
            - head_rep: Embeddings of head entities for each pair.
              Shape: (batch_size, max_pairs, embed_dim)
            - tail_rep: Embeddings of tail entities for each pair.
              Shape: (batch_size, max_pairs, embed_dim)

        At padding positions the -1 index is clamped to 0, so ``head_rep`` and
        ``tail_rep`` hold entity 0's embedding there; use ``pair_mask`` to
        ignore them. When no pair in the batch exceeds the threshold,
        ``max_pairs`` is 1, every position is padding, and both embedding
        tensors are zeros.
    """
    B, E, _ = adj.shape
    device = adj.device
    D = span_rep.shape[-1]

    # All ordered (i, j) with i != j, enumerated in row-major order.
    entity_ids = torch.arange(E, device=device)
    rows, cols = torch.where(entity_ids.unsqueeze(1) != entity_ids.unsqueeze(0))

    # (B, E * (E - 1)) selection of candidate pairs above the threshold.
    above_threshold = adj[:, rows, cols] > threshold

    # Per example: (num_valid_pairs, 2) tensor of (head, tail) indices.
    batch_pair_lists = [torch.stack([rows[sel], cols[sel]], dim=-1) for sel in above_threshold]

    # Padded width: the largest pair count in the batch (0 for an empty batch).
    N = max((pairs.shape[0] for pairs in batch_pair_lists), default=0)

    # No pair anywhere: return a single all-padding slot per example.
    if N == 0:
        pair_idx = torch.full((B, 1, 2), -1, dtype=torch.long, device=device)
        pair_mask = torch.zeros((B, 1), dtype=torch.bool, device=device)
        # Separate tensors so an in-place op on one cannot alter the other.
        head_rep = torch.zeros((B, 1, D), dtype=span_rep.dtype, device=device)
        tail_rep = torch.zeros((B, 1, D), dtype=span_rep.dtype, device=device)
        return pair_idx, pair_mask, head_rep, tail_rep

    pair_idx = torch.full((B, N, 2), -1, dtype=torch.long, device=device)
    pair_mask = torch.zeros((B, N), dtype=torch.bool, device=device)

    # Left-align each example's pairs; the remainder stays as padding.
    for b, pairs in enumerate(batch_pair_lists):
        m = pairs.shape[0]
        pair_idx[b, :m] = pairs
        pair_mask[b, :m] = True

    # Gather embeddings; clamp the -1 padding index to 0 to keep it in range.
    batch_idx = torch.arange(B, device=device).unsqueeze(1)  # (B, 1)
    head_rep = span_rep[batch_idx, pair_idx[..., 0].clamp_min(0)]  # (B, N, D)
    tail_rep = span_rep[batch_idx, pair_idx[..., 1].clamp_min(0)]  # (B, N, D)

    return pair_idx, pair_mask, head_rep, tail_rep