Source code for gliner.data_processing.utils

import random
from typing import Dict, List, Tuple, Optional, Sequence

import torch


def pad_2d_tensor(key_data):
    """Pad a list of 2D tensors to uniform dimensions.

    Takes a list of 2D tensors with potentially different shapes and pads them
    to match the maximum dimensions across all tensors. All tensors are padded
    with zeros to create a uniform rectangular shape, then stacked into a
    single 3D tensor with a batch dimension.

    Args:
        key_data: List of 2D tensors to pad. Each tensor can have different
            dimensions, but all must be 2D.

    Returns:
        A 3D tensor of shape (batch_size, max_rows, max_cols) containing all
        input tensors padded and stacked along the batch dimension.

    Raises:
        ValueError: If the input list is empty.

    Example:
        >>> tensor1 = torch.tensor([[1, 2], [3, 4]])  # 2x2
        >>> tensor2 = torch.tensor([[5, 6, 7]])       # 1x3
        >>> result = pad_2d_tensor([tensor1, tensor2])
        >>> result.shape
        torch.Size([2, 2, 3])
    """
    if not key_data:
        raise ValueError("The input list 'key_data' should not be empty.")

    # Determine the maximum size along both dimensions
    max_rows = max(tensor.shape[0] for tensor in key_data)
    max_cols = max(tensor.shape[1] for tensor in key_data)

    tensors = []
    for tensor in key_data:
        rows, cols = tensor.shape
        row_padding = max_rows - rows
        col_padding = max_cols - cols
        # Pad the tensor along both dimensions (right and bottom)
        padded_tensor = torch.nn.functional.pad(
            tensor, (0, col_padding, 0, row_padding), mode="constant", value=0
        )
        tensors.append(padded_tensor)

    # Stack the tensors into a single tensor along a new batch dimension
    padded_tensors = torch.stack(tensors)
    return padded_tensors
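
# Usage sketch for pad_2d_tensor (shapes here are arbitrary): two ragged
# matrices are zero-padded to the shared maximum (max_rows, max_cols) and
# stacked into a single batch tensor.
#
#     >>> a = torch.zeros(2, 4)
#     >>> b = torch.ones(3, 2)
#     >>> pad_2d_tensor([a, b]).shape
#     torch.Size([2, 3, 4])
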
def get_negatives(batch_list: List[Dict], sampled_neg: int = 5, key="ner") -> List[str]:
    """Sample negative entity or relation types from a batch.

    Extracts all unique entity/relation types from a batch of examples and
    randomly samples a subset to use as negative types for contrastive
    learning. This helps the model learn to distinguish between similar but
    incorrect types.

    Args:
        batch_list: List of example dictionaries. Each dictionary should
            contain the specified key with annotations in the format where the
            last element of each annotation tuple is the type label.
        sampled_neg: Maximum number of negative types to sample (default: 5).
            If fewer unique types exist, all will be returned.
        key: Dictionary key to access annotations (default: "ner"). Common
            values are "ner" for entities or "relations" for relation types.

    Returns:
        List of randomly sampled type strings. Length will be
        min(sampled_neg, number of unique types in batch).

    Example:
        >>> batch = [{"ner": [(0, 1, "PERSON"), (2, 3, "ORG")]},
        ...          {"ner": [(0, 1, "LOC"), (3, 4, "PERSON")]}]
        >>> negatives = get_negatives(batch, sampled_neg=2, key="ner")
        >>> len(negatives) <= 2
        True
    """
    element_types = set()
    for b in batch_list:
        if b.get(key, False):
            types = {el[-1] for el in b[key]}
            element_types.update(types)
    element_types = list(element_types)
    selected_elements = random.sample(element_types, k=min(sampled_neg, len(element_types)))
    return selected_elements
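
# Usage sketch for get_negatives (the labels are made up). Sampling is
# random, so only membership is checked, not order.
#
#     >>> batch = [{"ner": [(0, 0, "PERSON")]},
#     ...          {"ner": [(1, 2, "ORG"), (3, 3, "DATE")]}]
#     >>> negatives = get_negatives(batch, sampled_neg=2)
#     >>> set(negatives) <= {"PERSON", "ORG", "DATE"}
#     True
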
def prepare_word_mask(
    texts: Sequence[Sequence[str]],
    tokenized_inputs,
    *,
    skip_first_words: Optional[Sequence[int]] = None,
    token_level: bool = False,
) -> List[List[int]]:
    """Create word-level masks for subword-tokenized sequences.

    Maps subword tokens back to their original word positions, enabling span
    extraction at the word level. Each subword token is assigned an integer
    indicating which word it belongs to (1-indexed), with special tokens and
    continuation subwords optionally masked as 0.

    This is essential for span-based NER where predictions are made at the
    word level but the model processes subword tokens. The mask allows the
    model to aggregate subword representations into word-level
    representations.

    Args:
        texts: Original text sequences as lists of words, one sequence per
            example.
        tokenized_inputs: Tokenized output from a transformer tokenizer with a
            word_ids() method (e.g., from HuggingFace tokenizers).
        skip_first_words: Optional per-example counts of words to skip at the
            beginning of each sequence (e.g., prompt words). Must have the
            same length as texts if provided. Skipped words are masked as 0.
        token_level: If True, every subword token of a word receives that
            word's mask value, not just the first (token-level granularity).
            If False, only the first subword token of each word gets a mask
            value; continuation tokens are masked as 0 (default: False).

    Returns:
        List of word mask lists, one per input sequence. Each mask list has
        the same length as the corresponding tokenized sequence. Values are:

        - 0: Special tokens, skipped words, or (when token_level is False)
          continuation subwords
        - 1, 2, 3, ...: Word indices (1-indexed) after skipping

    Raises:
        ValueError: If skip_first_words length doesn't match texts length.

    Example:
        >>> texts = [["Hello", "world"]]
        >>> # Assuming tokenizer splits "Hello" -> ["Hel", "##lo"]
        >>> # and "world" -> ["world"]
        >>> mask = prepare_word_mask(texts, tokenized_inputs)
        >>> # Result might be: [[0, 1, 0, 2, 0]]
        >>> # [CLS, Hel, ##lo, world, SEP]
    """
    n = len(texts)
    if skip_first_words is None:
        skip_first_words = [0] * n
    elif len(skip_first_words) != n:
        raise ValueError("skip_first_words must have same length as texts")

    words_masks: List[List[int]] = []
    for i in range(n):
        mask: List[int] = []
        prev_word_id: Optional[int] = None
        seen_words = 0  # counts distinct word_ids traversed in this sequence
        for wid in tokenized_inputs.word_ids(i):
            if wid is None:
                # Special tokens (CLS, SEP, PAD, etc.)
                mask.append(0)
            elif wid != prev_word_id or token_level:
                # If we just moved to a new word, update seen_words
                if wid != prev_word_id:
                    seen_words += 1
                if seen_words <= skip_first_words[i]:
                    # This word is in the skip range (e.g., prompt tokens)
                    mask.append(0)
                else:
                    # 1-based word index after skipping
                    mask.append(seen_words - skip_first_words[i])
            else:
                # Same-word continuation with token_level=False -> mask as 0
                mask.append(0)
            prev_word_id = wid
        words_masks.append(mask)
    return words_masks
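
# Usage sketch for prepare_word_mask, assuming a HuggingFace *fast* tokenizer
# (the "bert-base-cased" checkpoint below is an arbitrary choice; any
# tokenizer whose encodings expose word_ids() works, and the exact mask
# depends on how the words split into subwords).
#
#     >>> from transformers import AutoTokenizer
#     >>> tok = AutoTokenizer.from_pretrained("bert-base-cased")
#     >>> texts = [["Paris", "is", "beautiful"]]
#     >>> enc = tok(texts, is_split_into_words=True)
#     >>> prepare_word_mask(texts, enc)  # assumes one subword per word here
#     [[0, 1, 2, 3, 0]]
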
def make_mapping(types: List[str]) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Create bidirectional mappings between type labels and integer IDs.

    Generates forward and reverse dictionaries for converting between string
    labels (e.g., entity or relation types) and integer IDs used in model
    training. Duplicate types are removed while preserving the order of first
    occurrence. IDs start from 1 (reserving 0 for no-label/padding).

    Args:
        types: List of type label strings. May contain duplicates, which will
            be removed while preserving order.

    Returns:
        Tuple containing:

        - Forward mapping (Dict[str, int]): Maps type labels to integer IDs
          starting from 1
        - Reverse mapping (Dict[int, str]): Maps integer IDs back to type
          labels

    Example:
        >>> types = ["PERSON", "ORG", "LOC", "PERSON"]  # "PERSON" duplicated
        >>> fwd, rev = make_mapping(types)
        >>> fwd
        {'PERSON': 1, 'ORG': 2, 'LOC': 3}
        >>> rev
        {1: 'PERSON', 2: 'ORG', 3: 'LOC'}
    """
    # De-duplicate while preserving order
    uniq = list(dict.fromkeys(types))
    fwd = {k: i for i, k in enumerate(uniq, start=1)}
    rev = {v: k for k, v in fwd.items()}
    return fwd, rev
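
# Usage sketch for make_mapping: the duplicate "PER" is dropped and IDs start
# at 1, leaving 0 free for the no-label/padding slot.
#
#     >>> fwd, rev = make_mapping(["PER", "ORG", "PER"])
#     >>> fwd
#     {'PER': 1, 'ORG': 2}
#     >>> rev[2]
#     'ORG'
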
def prepare_span_idx(num_tokens, max_width):
    """Generate all possible span indices for a sequence.

    Creates a list of all possible (start, end) span pairs for a sequence,
    where each span has a width (end - start) less than max_width. This is
    used in span-based NER models that enumerate and classify all possible
    spans.

    The spans follow these conventions:

    - Start index is inclusive
    - End index is inclusive (so span (i, i) is a single token)
    - Spans are generated in left-to-right order, with shorter spans first
      for each starting position

    Args:
        num_tokens: Length of the sequence (number of tokens).
        max_width: Maximum span width to generate. A span of width w covers
            w+1 tokens (e.g., width 0 is a single token).

    Returns:
        List of (start, end) tuples representing all valid spans. Each tuple
        contains:

        - start: Starting token index (0-indexed, inclusive)
        - end: Ending token index (0-indexed, inclusive)

        Note that spans whose end index falls past the last token are still
        generated; callers are expected to mask them out.

    Example:
        >>> spans = prepare_span_idx(num_tokens=3, max_width=2)
        >>> spans
        [(0, 0), (0, 1), (1, 1), (1, 2), (2, 2), (2, 3)]
        >>> # For sequence ["The", "cat", "sat"]:
        >>> # (0, 0) = "The"
        >>> # (0, 1) = "The cat"
        >>> # (1, 1) = "cat"
        >>> # (1, 2) = "cat sat"
        >>> # (2, 2) = "sat"
        >>> # (2, 3) runs past the sequence and must be masked out downstream
    """
    span_idx = [(i, i + j) for i in range(num_tokens) for j in range(max_width)]
    return span_idx
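
# Usage sketch for prepare_span_idx: since the enumeration can emit spans
# that run past the sequence, a caller typically filters them with a validity
# check like the (illustrative) one below.
#
#     >>> spans = prepare_span_idx(num_tokens=3, max_width=2)
#     >>> [s for s in spans if s[1] < 3]  # keep only in-range spans
#     [(0, 0), (0, 1), (1, 1), (1, 2), (2, 2)]
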