import torch
import torch.nn.functional as F
from torch import nn
from .layers import create_projection_layer
[docs]
class SpanQuery(nn.Module):
"""Span representation using learned query vectors.
This layer learns a set of query vectors, one for each span width, and
projects token representations onto these queries to produce span
representations.
Attributes:
query_seg (nn.Parameter): Learnable query matrix of shape [hidden_size, max_width].
project (nn.Sequential): MLP projection layer with ReLU activation.
"""
[docs]
def __init__(self, hidden_size, max_width, trainable=True):
"""Initialize the SpanQuery layer.
Args:
hidden_size (int): Dimension of the hidden representations.
max_width (int): Maximum span width to represent.
trainable (bool, optional): Whether query parameters are trainable.
Defaults to True.
"""
super().__init__()
self.query_seg = nn.Parameter(torch.randn(hidden_size, max_width))
nn.init.uniform_(self.query_seg, a=-1, b=1)
if not trainable:
self.query_seg.requires_grad = False
self.project = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ReLU())
[docs]
def forward(self, h, *args):
"""Compute span representations using query projection.
Args:
h (torch.Tensor): Token representations of shape [B, L, D].
*args: Additional arguments (unused).
Returns:
torch.Tensor: Span representations of shape [B, L, max_width, D].
"""
# h of shape [B, L, D]
# query_seg of shape [D, max_width]
span_rep = torch.einsum("bld, ds->blsd", h, self.query_seg)
return self.project(span_rep)
[docs]
class SpanMLP(nn.Module):
"""Span representation using a simple MLP.
This layer applies a linear transformation to produce multiple span
representations per position.
Attributes:
mlp (nn.Linear): Linear layer that expands hidden_size to
hidden_size * max_width.
"""
[docs]
def __init__(self, hidden_size, max_width):
"""Initialize the SpanMLP layer.
Args:
hidden_size (int): Dimension of the hidden representations.
max_width (int): Maximum span width to represent.
"""
super().__init__()
self.mlp = nn.Linear(hidden_size, hidden_size * max_width)
[docs]
def forward(self, h, *args):
"""Compute span representations using MLP projection.
Args:
h (torch.Tensor): Token representations of shape [B, L, D].
*args: Additional arguments (unused).
Returns:
torch.Tensor: Span representations of shape [B, L, max_width, D]
with ReLU activation applied.
"""
# h of shape [B, L, D]
# query_seg of shape [D, max_width]
B, L, D = h.size()
span_rep = self.mlp(h)
span_rep = span_rep.view(B, L, -1, D)
return span_rep.relu()
[docs]
class SpanCAT(nn.Module):
"""Span representation using concatenation with learned queries.
This layer concatenates token representations with learnable query vectors
and projects them to produce span representations.
Attributes:
max_width (int): Maximum span width to represent.
query_seg (nn.Parameter): Learnable query matrix of shape [128, max_width].
project (nn.Sequential): MLP projection layer with ReLU activation.
"""
[docs]
def __init__(self, hidden_size, max_width):
"""Initialize the SpanCAT layer.
Args:
hidden_size (int): Dimension of the hidden representations.
max_width (int): Maximum span width to represent.
"""
super().__init__()
self.max_width = max_width
self.query_seg = nn.Parameter(torch.randn(128, max_width))
self.project = nn.Sequential(nn.Linear(hidden_size + 128, hidden_size), nn.ReLU())
[docs]
def forward(self, h, *args):
"""Compute span representations by concatenating with queries.
Args:
h (torch.Tensor): Token representations of shape [B, L, D].
*args: Additional arguments (unused).
Returns:
torch.Tensor: Span representations of shape [B, L, max_width, D].
"""
# h of shape [B, L, D]
# query_seg of shape [D, max_width]
B, L, D = h.size()
h = h.view(B, L, 1, D).repeat(1, 1, self.max_width, 1)
q = self.query_seg.view(1, 1, self.max_width, -1).repeat(B, L, 1, 1)
span_rep = torch.cat([h, q], dim=-1)
span_rep = self.project(span_rep)
return span_rep
[docs]
class SpanConvBlock(nn.Module):
"""A single convolutional block for span representation.
This block applies either convolution or pooling operations with a specific
kernel size to capture span information.
Attributes:
conv (nn.Module): Convolution or pooling layer.
span_mode (str): Type of operation ('conv_conv', 'conv_max', 'conv_mean', 'conv_sum').
pad (int): Padding size for the operation.
"""
[docs]
def __init__(self, hidden_size, kernel_size, span_mode="conv_normal"):
"""Initialize the SpanConvBlock.
Args:
hidden_size (int): Dimension of the hidden representations.
kernel_size (int): Size of the convolution/pooling kernel.
span_mode (str, optional): Type of operation to use. Options are:
'conv_conv', 'conv_max', 'conv_mean', 'conv_sum'.
Defaults to 'conv_normal'.
"""
super().__init__()
if span_mode == "conv_conv":
self.conv = nn.Conv1d(hidden_size, hidden_size, kernel_size=kernel_size)
# initialize the weights
nn.init.kaiming_uniform_(self.conv.weight, nonlinearity="relu")
elif span_mode == "conv_max":
self.conv = nn.MaxPool1d(kernel_size=kernel_size, stride=1)
elif span_mode in {"conv_mean", "conv_sum"}:
self.conv = nn.AvgPool1d(kernel_size=kernel_size, stride=1)
self.span_mode = span_mode
self.pad = kernel_size - 1
[docs]
def forward(self, x):
"""Apply the convolutional block.
Args:
x (torch.Tensor): Input tensor of shape [B, L, D].
Returns:
torch.Tensor: Output tensor of shape [B, L, D].
"""
x = torch.einsum("bld->bdl", x)
if self.pad > 0:
x = F.pad(x, (0, self.pad), "constant", 0)
x = self.conv(x)
if self.span_mode == "conv_sum":
x = x * (self.pad + 1)
return torch.einsum("bdl->bld", x)
[docs]
class SpanConv(nn.Module):
"""Span representation using multiple convolutional layers.
This layer uses convolutions with different kernel sizes to capture
spans of different widths.
Attributes:
convs (nn.ModuleList): List of convolutional blocks with varying kernel sizes.
project (nn.Sequential): MLP projection layer with ReLU activation.
"""
[docs]
def __init__(self, hidden_size, max_width, span_mode):
"""Initialize the SpanConv layer.
Args:
hidden_size (int): Dimension of the hidden representations.
max_width (int): Maximum span width to represent.
span_mode (str): Type of convolution operation to use.
"""
super().__init__()
kernels = [i + 2 for i in range(max_width - 1)]
self.convs = nn.ModuleList()
for kernel in kernels:
self.convs.append(SpanConvBlock(hidden_size, kernel, span_mode))
self.project = nn.Sequential(nn.ReLU(), nn.Linear(hidden_size, hidden_size))
[docs]
def forward(self, x, *args):
"""Compute span representations using multiple convolutions.
Args:
x (torch.Tensor): Input tensor of shape [B, L, D].
*args: Additional arguments (unused).
Returns:
torch.Tensor: Span representations of shape [B, L, max_width, D].
"""
span_reps = [x]
for conv in self.convs:
h = conv(x)
span_reps.append(h)
span_reps = torch.stack(span_reps, dim=-2)
return self.project(span_reps)
[docs]
class SpanEndpointsBlock(nn.Module):
"""Extract start and end token representations for spans.
This block extracts the first and last token of each span.
Attributes:
kernel_size (int): The span width (kernel size).
"""
[docs]
def __init__(self, kernel_size):
"""Initialize the SpanEndpointsBlock.
Args:
kernel_size (int): The span width to extract endpoints for.
"""
super().__init__()
self.kernel_size = kernel_size
[docs]
def forward(self, x):
"""Extract start and end representations for all spans.
Args:
x (torch.Tensor): Input tensor of shape [B, L, D].
Returns:
torch.Tensor: Start and end representations of shape [B, L, 2, D].
"""
B, L, D = x.size()
span_idx = torch.LongTensor([[i, i + self.kernel_size - 1] for i in range(L)]).to(x.device)
x = F.pad(x, (0, 0, 0, self.kernel_size - 1), "constant", 0)
# endrep
start_end_rep = torch.index_select(x, dim=1, index=span_idx.view(-1))
start_end_rep = start_end_rep.view(B, L, 2, D)
return start_end_rep
[docs]
class ConvShare(nn.Module):
"""Span representation using shared convolution weights.
This layer uses a single set of convolution weights shared across
different span widths.
Attributes:
max_width (int): Maximum span width to represent.
conv_weigth (nn.Parameter): Shared convolution weights of shape
[hidden_size, hidden_size, max_width].
project (nn.Sequential): MLP projection layer with ReLU activation.
"""
[docs]
def __init__(self, hidden_size, max_width):
"""Initialize the ConvShare layer.
Args:
hidden_size (int): Dimension of the hidden representations.
max_width (int): Maximum span width to represent.
"""
super().__init__()
self.max_width = max_width
self.conv_weigth = nn.Parameter(torch.randn(hidden_size, hidden_size, max_width))
nn.init.kaiming_uniform_(self.conv_weigth, nonlinearity="relu")
self.project = nn.Sequential(nn.ReLU(), nn.Linear(hidden_size, hidden_size))
[docs]
def forward(self, x, *args):
"""Compute span representations using shared convolutions.
Args:
x (torch.Tensor): Input tensor of shape [B, L, D].
*args: Additional arguments (unused).
Returns:
torch.Tensor: Span representations of shape [B, L, max_width, D].
"""
span_reps = []
x = torch.einsum("bld->bdl", x)
for i in range(self.max_width):
pad = i
x_i = F.pad(x, (0, pad), "constant", 0)
conv_w = self.conv_weigth[:, :, : i + 1]
out_i = F.conv1d(x_i, conv_w)
span_reps.append(out_i.transpose(-1, -2))
out = torch.stack(span_reps, dim=-2)
return self.project(out)
[docs]
class SpanMarker(nn.Module):
"""Span representation using marker-based approach.
This layer projects start and end positions separately and combines them
to form span representations.
Attributes:
max_width (int): Maximum span width to represent.
project_start (nn.Sequential): MLP for projecting start positions.
project_end (nn.Sequential): MLP for projecting end positions.
out_project (nn.Linear): Final projection layer.
"""
[docs]
def __init__(self, hidden_size, max_width, dropout=0.4):
"""Initialize the SpanMarker layer.
Args:
hidden_size (int): Dimension of the hidden representations.
max_width (int): Maximum span width to represent.
dropout (float, optional): Dropout rate. Defaults to 0.4.
"""
super().__init__()
self.max_width = max_width
self.project_start = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 2, bias=True),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_size * 2, hidden_size, bias=True),
)
self.project_end = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 2, bias=True),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_size * 2, hidden_size, bias=True),
)
self.out_project = nn.Linear(hidden_size * 2, hidden_size, bias=True)
[docs]
def forward(self, h, span_idx):
"""Compute span representations using start and end markers.
Args:
h (torch.Tensor): Token representations of shape [B, L, D].
span_idx (torch.Tensor): Span indices of shape [B, *, 2] where
span_idx[..., 0] are start indices and span_idx[..., 1] are
end indices.
Returns:
torch.Tensor: Span representations of shape [B, L, max_width, D].
"""
# h of shape [B, L, D]
# query_seg of shape [D, max_width]
B, L, D = h.size()
# project start and end
start_rep = self.project_start(h)
end_rep = self.project_end(h)
start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])
# concat start and end
cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()
# project
cat = self.out_project(cat)
# reshape
return cat.view(B, L, self.max_width, D)
[docs]
class SpanMarkerV0(nn.Module):
"""Marks and projects span endpoints using an MLP.
A cleaner version of SpanMarker using the create_projection_layer utility.
Attributes:
max_width (int): Maximum span width to represent.
project_start (nn.Module): MLP for projecting start positions.
project_end (nn.Module): MLP for projecting end positions.
out_project (nn.Module): Final projection layer.
"""
[docs]
def __init__(self, hidden_size: int, max_width: int, dropout: float = 0.4):
"""Initialize the SpanMarkerV0 layer.
Args:
hidden_size (int): Dimension of the hidden representations.
max_width (int): Maximum span width to represent.
dropout (float, optional): Dropout rate. Defaults to 0.4.
"""
super().__init__()
self.max_width = max_width
self.project_start = create_projection_layer(hidden_size, dropout)
self.project_end = create_projection_layer(hidden_size, dropout)
self.out_project = create_projection_layer(hidden_size * 2, dropout, hidden_size)
[docs]
def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor:
"""Compute span representations using start and end markers.
Args:
h (torch.Tensor): Token representations of shape [B, L, D].
span_idx (torch.Tensor): Span indices of shape [B, *, 2].
Returns:
torch.Tensor: Span representations of shape [B, L, max_width, D].
"""
B, L, D = h.size()
start_rep = self.project_start(h)
end_rep = self.project_end(h)
start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])
cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()
return self.out_project(cat).view(B, L, self.max_width, D)
[docs]
class SpanMarkerV1(nn.Module):
"""Marks span endpoints and augments them with the first-token embedding.
For each candidate span we build
[ start_proj ‖ end_proj ‖ first_token_proj ] → MLP → span_rep
and finally reshape to [B, L, max_width, D].
Attributes:
max_width (int): Maximum span width to represent.
project_start (nn.Module): MLP for projecting start positions.
project_end (nn.Module): MLP for projecting end positions.
project_first (nn.Module): MLP for projecting the average token.
out_project (nn.Module): Final projection layer.
"""
[docs]
def __init__(self, hidden_size: int, max_width: int, dropout: float = 0.4):
"""Initialize the SpanMarkerV1 layer.
Args:
hidden_size (int): Dimension of the hidden representations.
max_width (int): Maximum span width to represent.
dropout (float, optional): Dropout rate. Defaults to 0.4.
"""
super().__init__()
self.max_width = max_width
# Independent projections for the three ingredients
self.project_start = create_projection_layer(hidden_size, dropout)
self.project_end = create_projection_layer(hidden_size, dropout)
self.project_first = create_projection_layer(hidden_size, dropout)
# 3 x hidden_size (start + end + first) → hidden_size
self.out_project = create_projection_layer(hidden_size * 3, dropout, hidden_size)
[docs]
def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor:
"""Compute span representations with average token augmentation.
For each span, concatenates start marker, end marker, and average
token embedding, then projects to produce the final representation.
Args:
h (torch.Tensor): Token representations, shape [B, L, D].
span_idx (torch.Tensor): Indices of candidate spans, shape [B, *, 2]
(* can be L x max_width or any flattened span dimension).
Returns:
torch.Tensor: Span representations, shape [B, L, max_width, D].
"""
B, L, D = h.size()
# Pre-compute per-token projections
start_rep = self.project_start(h) # [B, L, D]
end_rep = self.project_end(h) # [B, L, D]
# Project the first-token embedding once
average_token_proj = torch.mean(h, dim=1)
# Gather start/end representations for each span
start_span_rep = extract_elements(start_rep, span_idx[..., 0]) # [B, S, D]
end_span_rep = extract_elements(end_rep, span_idx[..., 1]) # [B, S, D]
# Broadcast first-token embedding to every span
first_span_rep = average_token_proj.unsqueeze(1).expand_as(start_span_rep) # [B, S, D]
# Concatenate and project
span_feat = torch.cat((start_span_rep, end_span_rep, first_span_rep), dim=-1).relu() # [B, S, 3D]
out = self.out_project(span_feat) # [B, S, D]
# Reshape back to [B, L, max_width, D] (S = L x max_width)
return out.view(B, L, self.max_width, D)
[docs]
class ConvShareV2(nn.Module):
"""Span representation using shared convolution weights (version 2).
Similar to ConvShare but uses Xavier initialization and no projection layer.
Attributes:
max_width (int): Maximum span width to represent.
conv_weigth (nn.Parameter): Shared convolution weights of shape
[hidden_size, hidden_size, max_width].
"""
[docs]
def __init__(self, hidden_size, max_width):
"""Initialize the ConvShareV2 layer.
Args:
hidden_size (int): Dimension of the hidden representations.
max_width (int): Maximum span width to represent.
"""
super().__init__()
self.max_width = max_width
self.conv_weigth = nn.Parameter(torch.randn(hidden_size, hidden_size, max_width))
nn.init.xavier_normal_(self.conv_weigth)
[docs]
def forward(self, x, *args):
"""Compute span representations using shared convolutions.
Args:
x (torch.Tensor): Input tensor of shape [B, L, D].
*args: Additional arguments (unused).
Returns:
torch.Tensor: Span representations of shape [B, L, max_width, D].
"""
span_reps = []
x = torch.einsum("bld->bdl", x)
for i in range(self.max_width):
pad = i
x_i = F.pad(x, (0, pad), "constant", 0)
conv_w = self.conv_weigth[:, :, : i + 1]
out_i = F.conv1d(x_i, conv_w)
span_reps.append(out_i.transpose(-1, -2))
out = torch.stack(span_reps, dim=-2)
return out
[docs]
class SpanRepLayer(nn.Module):
"""Factory class for various span representation approaches.
This class provides a unified interface to instantiate different span
representation methods based on the specified mode.
Attributes:
span_rep_layer (nn.Module): The underlying span representation layer.
"""
[docs]
def __init__(self, hidden_size, max_width, span_mode, **kwargs):
"""Initialize the SpanRepLayer with the specified mode.
Args:
hidden_size (int): Dimension of the hidden representations.
max_width (int): Maximum span width to represent.
span_mode (str): Type of span representation to use. Options:
- 'marker': SpanMarker
- 'markerV0': SpanMarkerV0
- 'markerV1': SpanMarkerV1
- 'query': SpanQuery
- 'mlp': SpanMLP
- 'cat': SpanCAT
- 'conv_conv': SpanConv with convolution
- 'conv_max': SpanConv with max pooling
- 'conv_mean': SpanConv with mean pooling
- 'conv_sum': SpanConv with sum pooling
- 'conv_share': ConvShare
**kwargs: Additional arguments passed to the span representation layer.
Raises:
ValueError: If an unknown span_mode is provided.
"""
super().__init__()
if span_mode == "marker":
self.span_rep_layer = SpanMarker(hidden_size, max_width, **kwargs)
elif span_mode == "markerV0":
self.span_rep_layer = SpanMarkerV0(hidden_size, max_width, **kwargs)
elif span_mode == "markerV1":
self.span_rep_layer = SpanMarkerV1(hidden_size, max_width, **kwargs)
elif span_mode == "query":
self.span_rep_layer = SpanQuery(hidden_size, max_width, trainable=True)
elif span_mode == "mlp":
self.span_rep_layer = SpanMLP(hidden_size, max_width)
elif span_mode == "cat":
self.span_rep_layer = SpanCAT(hidden_size, max_width)
elif span_mode == "conv_conv":
self.span_rep_layer = SpanConv(hidden_size, max_width, span_mode="conv_conv")
elif span_mode == "conv_max":
self.span_rep_layer = SpanConv(hidden_size, max_width, span_mode="conv_max")
elif span_mode == "conv_mean":
self.span_rep_layer = SpanConv(hidden_size, max_width, span_mode="conv_mean")
elif span_mode == "conv_sum":
self.span_rep_layer = SpanConv(hidden_size, max_width, span_mode="conv_sum")
elif span_mode == "conv_share":
self.span_rep_layer = ConvShare(hidden_size, max_width)
else:
raise ValueError(f"Unknown span mode {span_mode}")
[docs]
def forward(self, x, *args):
"""Forward pass through the selected span representation layer.
Args:
x (torch.Tensor): Input tensor, typically of shape [B, L, D].
*args: Additional arguments passed to the underlying layer.
Returns:
torch.Tensor: Span representations, typically of shape
[B, L, max_width, D].
"""
return self.span_rep_layer(x, *args)