"""Source code for gliner.modeling.span_rep: span representation layers."""

import torch
import torch.nn.functional as F
from torch import nn

from .layers import create_projection_layer


class SpanQuery(nn.Module):
    """Build span representations from a bank of learned per-width queries.

    Each token vector is multiplied element-wise by one learned query column
    per span width, and the result goes through a Linear + ReLU projection.

    Attributes:
        query_seg (nn.Parameter): Query matrix of shape [hidden_size, max_width],
            initialised uniformly in [-1, 1].
        project (nn.Sequential): Linear(hidden_size, hidden_size) followed by ReLU.
    """

    def __init__(self, hidden_size, max_width, trainable=True):
        """Create the query bank and the output projection.

        Args:
            hidden_size (int): Size D of the token representations.
            max_width (int): Number of span widths (query columns).
            trainable (bool, optional): When falsy, ``query_seg`` is frozen
                (``requires_grad=False``). Defaults to True.
        """
        super().__init__()
        # The randn allocation is overwritten right away by the uniform init;
        # it is kept so the global RNG stream is consumed as before.
        self.query_seg = nn.Parameter(torch.randn(hidden_size, max_width))
        nn.init.uniform_(self.query_seg, a=-1, b=1)
        self.query_seg.requires_grad_(bool(trainable))
        self.project = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ReLU())

    def forward(self, h, *args):
        """Scale every token by each width's query and project.

        Args:
            h (torch.Tensor): Token representations of shape [B, L, D].
            *args: Ignored; accepted for a uniform span-layer call signature.

        Returns:
            torch.Tensor: Span representations of shape [B, L, max_width, D].
        """
        # [B, L, D] with [D, W] -> [B, L, W, D]; no index is summed over.
        scaled = torch.einsum("bld, ds->blsd", h, self.query_seg)
        return self.project(scaled)
class SpanMLP(nn.Module):
    """Produce one representation per span width with a single wide Linear.

    The Linear maps each token from D to D * max_width features. These are
    split into max_width chunks of size D and passed through ReLU.

    Attributes:
        mlp (nn.Linear): Linear(hidden_size, hidden_size * max_width).
    """

    def __init__(self, hidden_size, max_width):
        """Create the widening Linear layer.

        Args:
            hidden_size (int): Size D of the token representations.
            max_width (int): Number of span widths to produce per token.
        """
        super().__init__()
        self.mlp = nn.Linear(hidden_size, hidden_size * max_width)

    def forward(self, h, *args):
        """Widen each token and reshape into per-width representations.

        Args:
            h (torch.Tensor): Token representations of shape [B, L, D].
            *args: Ignored; accepted for a uniform span-layer call signature.

        Returns:
            torch.Tensor: ReLU-activated span representations of shape
            [B, L, max_width, D].
        """
        batch, length, dim = h.size()
        widened = self.mlp(h)  # [B, L, D * max_width]
        return widened.view(batch, length, -1, dim).relu()
class SpanCAT(nn.Module):
    """Concatenate each token with a learned per-width query, then project.

    Every token is paired with max_width learned 128-dimensional query
    vectors. Each [token ; query] pair is mapped back to hidden_size by a
    Linear + ReLU.

    Attributes:
        max_width (int): Number of span widths.
        query_seg (nn.Parameter): Learned query storage of shape [128, max_width].
        project (nn.Sequential): Linear(hidden_size + 128, hidden_size) then ReLU.
    """

    def __init__(self, hidden_size, max_width):
        """Create the query parameter and the projection.

        Args:
            hidden_size (int): Size D of the token representations.
            max_width (int): Number of span widths.
        """
        super().__init__()
        self.max_width = max_width
        self.query_seg = nn.Parameter(torch.randn(128, max_width))
        self.project = nn.Sequential(nn.Linear(hidden_size + 128, hidden_size), nn.ReLU())

    def forward(self, h, *args):
        """Pair every token with each width's query and project.

        Args:
            h (torch.Tensor): Token representations of shape [B, L, D].
            *args: Ignored; accepted for a uniform span-layer call signature.

        Returns:
            torch.Tensor: Span representations of shape [B, L, max_width, D].
        """
        batch, length, dim = h.size()
        width = self.max_width
        tokens = h.unsqueeze(2).expand(batch, length, width, dim)
        # NOTE(review): this is a view, not a transpose, of the [128, max_width]
        # parameter, so a row here is not a column of query_seg. That is harmless
        # for a learned parameter but must stay as-is to match trained checkpoints.
        queries = self.query_seg.view(1, 1, width, -1).expand(batch, length, -1, -1)
        return self.project(torch.cat((tokens, queries), dim=-1))
class SpanConvBlock(nn.Module):
    """One fixed-width convolution or pooling pass used to build span features.

    The sequence is zero-padded on the right by ``kernel_size - 1``, so every
    start position gets an output covering ``kernel_size`` tokens. The last
    windows therefore include padding zeros.

    Attributes:
        conv (nn.Module): Conv1d, MaxPool1d or AvgPool1d, depending on span_mode.
        span_mode (str): One of 'conv_conv', 'conv_max', 'conv_mean', 'conv_sum'.
        pad (int): Right padding size, ``kernel_size - 1``.
    """

    _SUPPORTED_MODES = ("conv_conv", "conv_max", "conv_mean", "conv_sum")

    def __init__(self, hidden_size, kernel_size, span_mode="conv_normal"):
        """Create the convolution or pooling operator for one span width.

        Args:
            hidden_size (int): Number of channels D.
            kernel_size (int): Window (span) width.
            span_mode (str, optional): 'conv_conv' (learned Conv1d),
                'conv_max' (max pooling), 'conv_mean' (mean pooling) or
                'conv_sum' (mean pooling rescaled to a sum). The default
                'conv_normal' is not a supported mode, so it must be passed
                explicitly.

        Raises:
            ValueError: If ``span_mode`` is not a supported mode. Previously an
                unsupported mode silently left ``self.conv`` undefined and only
                failed later inside ``forward``.
        """
        super().__init__()
        if span_mode == "conv_conv":
            self.conv = nn.Conv1d(hidden_size, hidden_size, kernel_size=kernel_size)
            nn.init.kaiming_uniform_(self.conv.weight, nonlinearity="relu")
        elif span_mode == "conv_max":
            self.conv = nn.MaxPool1d(kernel_size=kernel_size, stride=1)
        elif span_mode in {"conv_mean", "conv_sum"}:
            self.conv = nn.AvgPool1d(kernel_size=kernel_size, stride=1)
        else:
            raise ValueError(
                f"Unsupported span_mode {span_mode!r} for SpanConvBlock; "
                f"expected one of {self._SUPPORTED_MODES}"
            )
        self.span_mode = span_mode
        self.pad = kernel_size - 1

    def forward(self, x):
        """Apply the windowed operator at every start position.

        Args:
            x (torch.Tensor): Input of shape [B, L, D].

        Returns:
            torch.Tensor: Output of shape [B, L, D].
        """
        x = torch.einsum("bld->bdl", x)  # channels-first for 1-D conv/pool
        if self.pad > 0:
            # Right-pad with zeros so each of the L start positions has a full window.
            x = F.pad(x, (0, self.pad), "constant", 0)
        x = self.conv(x)
        if self.span_mode == "conv_sum":
            # Mean over the window times the window size gives the window sum.
            x = x * (self.pad + 1)
        return torch.einsum("bdl->bld", x)
class SpanConv(nn.Module):
    """Stack per-width convolution/pooling outputs into span representations.

    Width 1 is the token itself. Widths 2..max_width each come from a
    SpanConvBlock with that kernel size. All widths are stacked and sent
    through ReLU + Linear.

    Attributes:
        convs (nn.ModuleList): SpanConvBlocks for kernel sizes 2..max_width.
        project (nn.Sequential): ReLU followed by Linear(hidden_size, hidden_size).
    """

    def __init__(self, hidden_size, max_width, span_mode):
        """Create one SpanConvBlock per span width above 1.

        Args:
            hidden_size (int): Size D of the token representations.
            max_width (int): Largest span width.
            span_mode (str): Operator used by every SpanConvBlock.
        """
        super().__init__()
        self.convs = nn.ModuleList(
            [SpanConvBlock(hidden_size, width, span_mode) for width in range(2, max_width + 1)]
        )
        self.project = nn.Sequential(nn.ReLU(), nn.Linear(hidden_size, hidden_size))

    def forward(self, x, *args):
        """Compute and stack representations for every span width.

        Args:
            x (torch.Tensor): Token representations of shape [B, L, D].
            *args: Ignored; accepted for a uniform span-layer call signature.

        Returns:
            torch.Tensor: Span representations of shape [B, L, max_width, D].
        """
        per_width = [x] + [block(x) for block in self.convs]
        return self.project(torch.stack(per_width, dim=-2))
class SpanEndpointsBlock(nn.Module):
    """Gather the first and last token representations of fixed-width spans.

    For every start position ``i`` the block returns the pair
    ``(x[:, i], x[:, i + kernel_size - 1])``. End positions past the
    sequence end read zero padding.

    Attributes:
        kernel_size (int): The span width.
    """

    def __init__(self, kernel_size):
        """Store the span width.

        Args:
            kernel_size (int): The span width whose endpoints are extracted.
        """
        super().__init__()
        self.kernel_size = kernel_size

    def forward(self, x):
        """Extract start and end representations for every start position.

        Args:
            x (torch.Tensor): Input of shape [B, L, D].

        Returns:
            torch.Tensor: Endpoint representations of shape [B, L, 2, D].
        """
        B, L, D = x.size()
        # Build [start, end] index pairs directly on x's device. The original
        # used a Python list plus a host-to-device copy on every call.
        starts = torch.arange(L, device=x.device)
        ends = starts + (self.kernel_size - 1)
        span_idx = torch.stack((starts, ends), dim=-1)  # [L, 2]
        # Zero-pad the sequence so end indices up to L + kernel_size - 2 are valid.
        x = F.pad(x, (0, 0, 0, self.kernel_size - 1), "constant", 0)
        start_end_rep = torch.index_select(x, dim=1, index=span_idx.view(-1))
        return start_end_rep.view(B, L, 2, D)
class ConvShare(nn.Module):
    """Span representations from one convolution kernel shared across widths.

    Span width ``w`` is computed with the first ``w`` taps of a single
    [D, D, max_width] kernel over a right-zero-padded sequence. The stacked
    widths then go through ReLU + Linear.

    Attributes:
        max_width (int): Largest span width.
        conv_weigth (nn.Parameter): Shared kernel of shape
            [hidden_size, hidden_size, max_width], Kaiming-uniform initialised.
            The misspelled name is kept because it is the state_dict key.
        project (nn.Sequential): ReLU followed by Linear(hidden_size, hidden_size).
    """

    def __init__(self, hidden_size, max_width):
        """Create the shared kernel and the output projection.

        Args:
            hidden_size (int): Size D of the token representations.
            max_width (int): Largest span width (number of kernel taps).
        """
        super().__init__()
        self.max_width = max_width
        self.conv_weigth = nn.Parameter(torch.randn(hidden_size, hidden_size, max_width))
        nn.init.kaiming_uniform_(self.conv_weigth, nonlinearity="relu")
        self.project = nn.Sequential(nn.ReLU(), nn.Linear(hidden_size, hidden_size))

    def forward(self, x, *args):
        """Convolve with 1..max_width taps of the shared kernel and project.

        Args:
            x (torch.Tensor): Token representations of shape [B, L, D].
            *args: Ignored; accepted for a uniform span-layer call signature.

        Returns:
            torch.Tensor: Span representations of shape [B, L, max_width, D].
        """
        channels_first = x.transpose(1, 2)  # [B, D, L]
        per_width = []
        for width in range(1, self.max_width + 1):
            # Right-pad by width - 1 so the output length stays L.
            padded = F.pad(channels_first, (0, width - 1), "constant", 0)
            conv_out = F.conv1d(padded, self.conv_weigth[:, :, :width])
            per_width.append(conv_out.transpose(-1, -2))
        return self.project(torch.stack(per_width, dim=-2))
def extract_elements(sequence, indices):
    """Pick rows of ``sequence`` along dim 1 at the given positions.

    Args:
        sequence (torch.Tensor): Source tensor of shape [B, L, D].
        indices (torch.Tensor): Integer positions of shape [B, K].

    Returns:
        torch.Tensor: ``out[b, k] = sequence[b, indices[b, k]]``, shape [B, K, D].
    """
    feature_dim = sequence.size(-1)
    # Repeat each position across the feature axis so gather copies whole rows.
    row_index = indices.unsqueeze(2).expand(-1, -1, feature_dim)
    return sequence.gather(1, row_index)
class SpanMarker(nn.Module):
    """Span representations from separately projected start and end tokens.

    Each token is projected by a start MLP and an end MLP. For every span,
    the projected start and end rows are gathered, concatenated, passed
    through ReLU, and mapped back to hidden_size.

    Attributes:
        max_width (int): Largest span width.
        project_start (nn.Sequential): D -> 2D -> D MLP (ReLU, Dropout) for starts.
        project_end (nn.Sequential): D -> 2D -> D MLP (ReLU, Dropout) for ends.
        out_project (nn.Linear): Linear(2 * hidden_size, hidden_size).
    """

    def __init__(self, hidden_size, max_width, dropout=0.4):
        """Create the start/end MLPs and the output projection.

        Args:
            hidden_size (int): Size D of the token representations.
            max_width (int): Largest span width.
            dropout (float, optional): Dropout inside both MLPs. Defaults to 0.4.
        """
        super().__init__()
        self.max_width = max_width

        def endpoint_mlp():
            # Two-layer bottleneck-free MLP: D -> 2D -> D.
            return nn.Sequential(
                nn.Linear(hidden_size, hidden_size * 2, bias=True),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_size * 2, hidden_size, bias=True),
            )

        self.project_start = endpoint_mlp()
        self.project_end = endpoint_mlp()
        self.out_project = nn.Linear(hidden_size * 2, hidden_size, bias=True)

    def forward(self, h, span_idx):
        """Combine start and end markers into span representations.

        Args:
            h (torch.Tensor): Token representations of shape [B, L, D].
            span_idx (torch.Tensor): Span boundaries of shape [B, S, 2] (start,
                end). S must equal L * max_width for the final reshape.

        Returns:
            torch.Tensor: Span representations of shape [B, L, max_width, D].
        """
        batch, length, dim = h.size()
        start_rep = self.project_start(h)
        end_rep = self.project_end(h)
        start_span = extract_elements(start_rep, span_idx[:, :, 0])
        end_span = extract_elements(end_rep, span_idx[:, :, 1])
        fused = self.out_project(torch.cat([start_span, end_span], dim=-1).relu())
        return fused.view(batch, length, self.max_width, dim)
class SpanMarkerV0(nn.Module):
    """Start/end marker span layer built from ``create_projection_layer``.

    Same scheme as SpanMarker: project starts and ends separately, gather
    them per span, concatenate, ReLU, and project back to hidden_size.

    Attributes:
        max_width (int): Largest span width.
        project_start (nn.Module): Projection applied to all tokens for starts.
        project_end (nn.Module): Projection applied to all tokens for ends.
        out_project (nn.Module): Projection from 2 * hidden_size to hidden_size.
    """

    def __init__(self, hidden_size: int, max_width: int, dropout: float = 0.4):
        """Create the start/end projections and the output projection.

        Args:
            hidden_size (int): Size D of the token representations.
            max_width (int): Largest span width.
            dropout (float, optional): Dropout passed to every projection.
                Defaults to 0.4.
        """
        super().__init__()
        self.max_width = max_width
        self.project_start = create_projection_layer(hidden_size, dropout)
        self.project_end = create_projection_layer(hidden_size, dropout)
        self.out_project = create_projection_layer(hidden_size * 2, dropout, hidden_size)

    def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor:
        """Combine projected start and end tokens into span representations.

        Args:
            h (torch.Tensor): Token representations of shape [B, L, D].
            span_idx (torch.Tensor): Span boundaries of shape [B, S, 2] with
                S == L * max_width.

        Returns:
            torch.Tensor: Span representations of shape [B, L, max_width, D].
        """
        batch, length, dim = h.size()
        starts_all = self.project_start(h)
        ends_all = self.project_end(h)
        span_starts = extract_elements(starts_all, span_idx[:, :, 0])
        span_ends = extract_elements(ends_all, span_idx[:, :, 1])
        pair = torch.cat([span_starts, span_ends], dim=-1).relu()
        projected = self.out_project(pair)
        return projected.view(batch, length, self.max_width, dim)
class SpanMarkerV1(nn.Module):
    """Start/end markers augmented with a sequence-mean context vector.

    For each candidate span the features are
    [ project_start(h)[start] ‖ project_end(h)[end] ‖ mean(h over L) ].
    They go through ReLU, then ``out_project``, and are reshaped to
    [B, L, max_width, D].

    NOTE(review): ``project_first`` is constructed (and so appears in the
    state_dict), but ``forward`` never calls it. The third feature is the
    *unprojected* mean of ``h`` over the sequence dimension. No mask is
    applied, so padding positions are included in the mean. This is kept
    unchanged for checkpoint compatibility. The unused parameters presumably
    require ``find_unused_parameters=True`` under DDP; confirm.

    Attributes:
        max_width (int): Largest span width.
        project_start (nn.Module): Projection for span start tokens.
        project_end (nn.Module): Projection for span end tokens.
        project_first (nn.Module): Projection that is currently unused by forward.
        out_project (nn.Module): Projection from 3 * hidden_size to hidden_size.
    """

    def __init__(self, hidden_size: int, max_width: int, dropout: float = 0.4):
        """Create the endpoint projections and the output projection.

        Args:
            hidden_size (int): Size D of the token representations.
            max_width (int): Largest span width.
            dropout (float, optional): Dropout passed to every projection.
                Defaults to 0.4.
        """
        super().__init__()
        self.max_width = max_width
        # Creation order is preserved so parameter init consumes the RNG identically.
        self.project_start = create_projection_layer(hidden_size, dropout)
        self.project_end = create_projection_layer(hidden_size, dropout)
        self.project_first = create_projection_layer(hidden_size, dropout)
        # start + end + context (3 * D) -> D
        self.out_project = create_projection_layer(hidden_size * 3, dropout, hidden_size)

    def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor:
        """Build span features from endpoints plus the sequence-mean context.

        Args:
            h (torch.Tensor): Token representations of shape [B, L, D].
            span_idx (torch.Tensor): Span boundaries of shape [B, S, 2], where
                S must equal L * max_width for the final reshape.

        Returns:
            torch.Tensor: Span representations of shape [B, L, max_width, D].
        """
        batch, length, dim = h.size()
        start_tokens = self.project_start(h)  # [B, L, D]
        end_tokens = self.project_end(h)  # [B, L, D]
        # Unprojected mean over the sequence axis (padding included): [B, D].
        context = torch.mean(h, dim=1)
        span_start = extract_elements(start_tokens, span_idx[..., 0])  # [B, S, D]
        span_end = extract_elements(end_tokens, span_idx[..., 1])  # [B, S, D]
        span_context = context.unsqueeze(1).expand_as(span_start)  # [B, S, D]
        features = torch.cat((span_start, span_end, span_context), dim=-1).relu()  # [B, S, 3D]
        return self.out_project(features).view(batch, length, self.max_width, dim)
class ConvShareV2(nn.Module):
    """Shared-kernel span convolution without an output projection.

    Same computation as ConvShare: span width ``w`` uses the first ``w``
    taps of one [D, D, max_width] kernel over a right-zero-padded sequence.
    Differences: the kernel is Xavier-normal initialised, and the stacked
    outputs are returned as-is.

    Attributes:
        max_width (int): Largest span width.
        conv_weigth (nn.Parameter): Shared kernel of shape
            [hidden_size, hidden_size, max_width]. The misspelled name is kept
            because it is the state_dict key.
    """

    def __init__(self, hidden_size, max_width):
        """Create the shared kernel.

        Args:
            hidden_size (int): Size D of the token representations.
            max_width (int): Largest span width (number of kernel taps).
        """
        super().__init__()
        self.max_width = max_width
        self.conv_weigth = nn.Parameter(torch.randn(hidden_size, hidden_size, max_width))
        nn.init.xavier_normal_(self.conv_weigth)

    def forward(self, x, *args):
        """Convolve with 1..max_width taps of the shared kernel.

        Args:
            x (torch.Tensor): Token representations of shape [B, L, D].
            *args: Ignored; accepted for a uniform span-layer call signature.

        Returns:
            torch.Tensor: Span representations of shape [B, L, max_width, D].
        """
        channels_first = x.transpose(1, 2)  # [B, D, L]
        outputs = []
        for taps in range(1, self.max_width + 1):
            # Right-pad by taps - 1 so the output length stays L.
            padded = F.pad(channels_first, (0, taps - 1), "constant", 0)
            outputs.append(F.conv1d(padded, self.conv_weigth[:, :, :taps]).transpose(-1, -2))
        return torch.stack(outputs, dim=-2)
class SpanRepLayer(nn.Module):
    """Select and wrap one span representation strategy by name.

    Attributes:
        span_rep_layer (nn.Module): The instantiated span representation layer.
    """

    def __init__(self, hidden_size, max_width, span_mode, **kwargs):
        """Instantiate the layer that corresponds to ``span_mode``.

        Args:
            hidden_size (int): Size D of the token representations.
            max_width (int): Largest span width.
            span_mode (str): One of
                - 'marker': SpanMarker
                - 'markerV0': SpanMarkerV0
                - 'markerV1': SpanMarkerV1
                - 'query': SpanQuery (trainable queries)
                - 'mlp': SpanMLP
                - 'cat': SpanCAT
                - 'conv_conv' / 'conv_max' / 'conv_mean' / 'conv_sum': SpanConv
                  with the matching operator
                - 'conv_share': ConvShare
            **kwargs: Forwarded only to the marker variants; ignored otherwise.

        Raises:
            ValueError: If ``span_mode`` is not one of the names above.
        """
        super().__init__()
        if span_mode in ("marker", "markerV0", "markerV1"):
            marker_cls = {
                "marker": SpanMarker,
                "markerV0": SpanMarkerV0,
                "markerV1": SpanMarkerV1,
            }[span_mode]
            layer = marker_cls(hidden_size, max_width, **kwargs)
        elif span_mode == "query":
            layer = SpanQuery(hidden_size, max_width, trainable=True)
        elif span_mode == "mlp":
            layer = SpanMLP(hidden_size, max_width)
        elif span_mode == "cat":
            layer = SpanCAT(hidden_size, max_width)
        elif span_mode in ("conv_conv", "conv_max", "conv_mean", "conv_sum"):
            layer = SpanConv(hidden_size, max_width, span_mode=span_mode)
        elif span_mode == "conv_share":
            layer = ConvShare(hidden_size, max_width)
        else:
            raise ValueError(f"Unknown span mode {span_mode}")
        self.span_rep_layer = layer

    def forward(self, x, *args):
        """Delegate to the selected span representation layer.

        Args:
            x (torch.Tensor): Token representations, typically [B, L, D].
            *args: Extra positional arguments (e.g. span indices) passed through.

        Returns:
            torch.Tensor: Span representations, typically [B, L, max_width, D].
        """
        layer = self.span_rep_layer
        return layer(x, *args)