from typing import Any, Optional
import torch
import torch.nn.functional as F
from torch import nn
def compute_degree(A: torch.Tensor) -> torch.Tensor:
    """Return per-node degrees of a batched adjacency matrix.

    Node i's degree is D_ii = Σ_j A_ij, i.e. the total edge weight
    incident to that node.

    Args:
        A: Adjacency matrix of shape (B, E, E) where B is batch size and E
            is the number of entities/nodes.

    Returns:
        Tensor of shape (B, E) holding each node's degree, floored at 1e-6
        so downstream divisions cannot hit zero.
    """
    degrees = torch.sum(A, dim=-1)
    return torch.clamp(degrees, min=1e-6)
def _apply_pair_mask(A: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
"""Zero out adjacency entries where at least one endpoint is masked.
This ensures that edges to/from padded entities are properly masked out.
An edge (i, j) is kept only if both mask[i] and mask[j] are non-zero.
Args:
A: Adjacency matrix of shape (B, E, E).
mask: Optional boolean/float mask of shape (B, E) where 1 indicates
valid entities and 0 indicates padding. If None, returns A unchanged.
Returns:
Masked adjacency matrix of shape (B, E, E).
"""
if mask is None:
return A
m = mask.float() # (B, E)
return A * m.unsqueeze(2) * m.unsqueeze(1) # (B, E, E)
def dot_product_adjacency(
    X: torch.Tensor, mask: Optional[torch.Tensor] = None, normalize: bool = False
) -> torch.Tensor:
    """Build an adjacency matrix from (optionally cosine) dot-product scores.

    Pairwise similarities between entity embeddings are computed as raw dot
    products, or cosine similarities when ``normalize`` is set, and squashed
    through a sigmoid.

    Args:
        X: Entity embeddings of shape (B, E, D) — batch, entities, features.
        mask: Optional (B, E) mask marking valid entities.
        normalize: When True, L2-normalize embeddings first so the scores
            are cosine similarities. Defaults to False.

    Returns:
        Adjacency matrix of shape (B, E, E) with entries in (0, 1).
    """
    feats = F.normalize(X, p=2, dim=-1) if normalize else X
    scores = torch.sigmoid(feats @ feats.transpose(-2, -1))  # (B, E, E)
    return _apply_pair_mask(scores, mask)
class MLPDecoder(nn.Module):
    """Pairwise adjacency decoder based on an MLP over concatenated nodes.

    For every ordered node pair (i, j) the two embeddings are concatenated
    and scored by a small feed-forward network, so interactions are modeled
    explicitly per pair.

    Args:
        in_dim: Input embedding dimension.
        hidden_dim: Hidden layer dimension for the MLP.
    """

    def __init__(self, in_dim: int, hidden_dim: int):
        """Initialize the MLP decoder.

        Args:
            in_dim: Input embedding dimension.
            hidden_dim: Hidden layer dimension.
        """
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Score every node pair with the MLP.

        Args:
            X: Entity embeddings of shape (B, E, D).
            mask: Optional (B, E) mask marking valid entities.

        Returns:
            Adjacency matrix of shape (B, E, E) with entries in (0, 1).
        """
        batch, n_ent, dim = X.shape
        # Broadcast so position (i, j) holds [X_i ; X_j].
        left = X.unsqueeze(2).expand(batch, n_ent, n_ent, dim)
        right = X.unsqueeze(1).expand(batch, n_ent, n_ent, dim)
        pairs = torch.cat((left, right), dim=-1)
        logits = self.mlp(pairs).squeeze(-1)  # (B, E, E)
        return _apply_pair_mask(torch.sigmoid(logits), mask)
class AttentionAdjacency(nn.Module):
    """Adjacency matrix taken directly from multi-head attention weights.

    Self-attention over the entity embeddings is run once; the resulting
    attention map (averaged over heads when returned per-head) serves as
    the adjacency matrix.

    Args:
        d_model: Model dimension (embedding size).
        nhead: Number of attention heads.
    """

    def __init__(self, d_model: int, nhead: int):
        """Initialize the attention-based adjacency module.

        Args:
            d_model: Model dimension for attention.
            nhead: Number of attention heads.
        """
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)

    def forward(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Derive the adjacency matrix from self-attention weights.

        Args:
            X: Entity embeddings of shape (B, E, D).
            mask: Optional (B, E) mask; 1 marks valid entities.

        Returns:
            Adjacency matrix of shape (B, E, E) from (head-averaged)
            attention weights.
        """
        # key_padding_mask expects True at PADDED positions, hence the inversion.
        pad = None if mask is None else ~mask.bool()
        _, weights = self.attn(X, X, X, key_padding_mask=pad, need_weights=True)
        # Some torch versions return per-head weights (B, h, E, E).
        if weights.dim() == 4:
            weights = weights.mean(dim=1)
        # NOTE(review): if an entire sequence is masked, softmax over all
        # -inf keys can yield NaN rows — presumably callers never pass a
        # fully-padded batch element; confirm upstream.
        return _apply_pair_mask(weights, mask)
class BilinearDecoder(nn.Module):
    """Bilinear adjacency decoder.

    Embeddings are linearly projected to a latent space Z, and the
    adjacency is sigmoid(Z @ Z^T).

    Args:
        in_dim: Input embedding dimension.
        latent_dim: Latent projection dimension.
    """

    def __init__(self, in_dim: int, latent_dim: int):
        """Initialize the bilinear decoder.

        Args:
            in_dim: Input embedding dimension.
            latent_dim: Dimension of the latent projection space.
        """
        super().__init__()
        self.proj = nn.Linear(in_dim, latent_dim)

    def forward(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Score node pairs via the bilinear form in latent space.

        Args:
            X: Entity embeddings of shape (B, E, D).
            mask: Optional (B, E) mask marking valid entities.

        Returns:
            Adjacency matrix of shape (B, E, E) with entries in (0, 1).
        """
        latent = self.proj(X)
        scores = latent @ latent.transpose(-2, -1)
        return _apply_pair_mask(torch.sigmoid(scores), mask)
class SimpleGCNLayer(nn.Module):
    """Single GCN layer with symmetric degree normalization.

    Implements H = ReLU(D^(-1/2) Â D^(-1/2) X W), where Â is the adjacency
    with self-loops added and D its degree matrix.

    Args:
        in_dim: Input feature dimension.
        out_dim: Output feature dimension.
    """

    def __init__(self, in_dim: int, out_dim: int):
        """Initialize the GCN layer.

        Args:
            in_dim: Input feature dimension.
            out_dim: Output feature dimension.
        """
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, X: torch.Tensor, A: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Propagate features with symmetric normalization.

        Args:
            X: Node features of shape (B, E, D).
            A: Adjacency matrix of shape (B, E, E).
            mask: Optional (B, E) mask; self-loops are only added on valid
                (unmasked) nodes.

        Returns:
            Updated node features of shape (B, E, out_dim).
        """
        if mask is None:
            eye = torch.eye(A.size(1), device=A.device)
            A_hat = A + eye.unsqueeze(0)
        else:
            # Keep only valid↔valid edges, then self-loop the valid nodes.
            A_hat = _apply_pair_mask(A, mask) + torch.diag_embed(mask.float())
        # D^(-1/2) Â D^(-1/2); compute_degree clamps degrees away from zero.
        d_inv_sqrt = torch.diag_embed(compute_degree(A_hat).pow(-0.5))
        propagated = (d_inv_sqrt @ A_hat @ d_inv_sqrt) @ X
        return F.relu(self.linear(propagated))
class GCNDecoder(nn.Module):
    """GCN-based adjacency decoder.

    First computes an initial adjacency via dot-product similarity, refines
    node representations with one GCN layer, then predicts the final
    adjacency from the projected refined features.

    Args:
        in_dim: Input embedding dimension.
        hidden_dim: Hidden dimension for GCN and projection layers.
    """

    def __init__(self, in_dim: int, hidden_dim: int):
        """Initialize the GCN decoder.

        Args:
            in_dim: Input embedding dimension.
            hidden_dim: Hidden dimension for the GCN layer and projection.
        """
        super().__init__()
        self.gcn = SimpleGCNLayer(in_dim, hidden_dim)
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute adjacency using GCN refinement.

        Args:
            X: Entity embeddings of shape (B, E, D).
            mask: Optional (B, E) mask indicating valid entities.

        Returns:
            Adjacency matrix of shape (B, E, E) with values in (0, 1).
        """
        A0 = dot_product_adjacency(X, mask)  # initial adjacency (already masked)
        H = self.gcn(X, A0, mask)            # refined node features
        # Project once and reuse — the original ran self.proj(H) twice,
        # doubling the linear-layer work for an identical result.
        Z = self.proj(H)
        A = torch.sigmoid(torch.bmm(Z, Z.transpose(1, 2)))
        return _apply_pair_mask(A, mask)
class GATDecoder(nn.Module):
    """Graph-attention-based adjacency decoder.

    Multi-head self-attention refines the node representations; the
    adjacency is then predicted from the projected refined features.

    Args:
        d_model: Model dimension for attention.
        nhead: Number of attention heads.
        hidden_dim: Hidden dimension for the final projection.
    """

    def __init__(self, d_model: int, nhead: int, hidden_dim: int):
        """Initialize the GAT decoder.

        Args:
            d_model: Model dimension for attention mechanism.
            nhead: Number of attention heads.
            hidden_dim: Hidden dimension for the output projection.
        """
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.linear = nn.Linear(d_model, hidden_dim)

    def forward(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute adjacency using GAT refinement.

        Args:
            X: Entity embeddings of shape (B, E, D).
            mask: Optional (B, E) mask indicating valid entities.

        Returns:
            Adjacency matrix of shape (B, E, E) with values in (0, 1).
        """
        # key_padding_mask expects True at padded positions.
        key_padding = (~mask.bool()) if mask is not None else None
        # need_weights=False: the original requested and head-averaged the
        # attention map but never used it; skipping it removes dead work and
        # lets torch take its fused attention path.
        H, _ = self.attn(X, X, X, key_padding_mask=key_padding, need_weights=False)
        Z = self.linear(H)
        A = torch.sigmoid(torch.bmm(Z, Z.transpose(1, 2)))
        return _apply_pair_mask(A, mask)
class RelationsRepLayer(nn.Module):
    """Unified wrapper around the available adjacency-computation methods.

    Provides one interface for turning entity embeddings into an adjacency
    matrix. Supported modes:

    - 'dot': dot-product/cosine similarity
    - 'mlp': MLP-based pairwise decoder
    - 'attention'/'attn': multi-head attention weights
    - 'bilinear': bilinear projection
    - 'gcn': graph convolutional refinement
    - 'gat': graph attention network

    Every mode accepts an optional mask for variable-length inputs.

    Args:
        in_dim: Input embedding dimension.
        relation_mode: Adjacency method name; one of 'dot', 'mlp',
            'attention', 'attn', 'bilinear', 'gcn', 'gat'.
        **kwargs: Method-specific options:
            - hidden_dim (int): for 'mlp', 'gcn', 'gat'; defaults to in_dim.
            - nhead (int): for 'attention'/'attn' and 'gat'; defaults to 8.
            - latent_dim (int): for 'bilinear'; defaults to in_dim.

    Raises:
        ValueError: If relation_mode is not one of the supported methods.

    Example:
        >>> layer = RelationsRepLayer(in_dim=128, relation_mode="gcn", hidden_dim=64)
        >>> X = torch.randn(4, 10, 128)  # (batch=4, entities=10, dim=128)
        >>> mask = torch.ones(4, 10)     # All entities valid
        >>> A = layer(X, mask)           # (4, 10, 10) adjacency matrix
    """

    def __init__(self, in_dim: int, relation_mode: str, **kwargs: Any):
        """Initialize the relations representation layer.

        Args:
            in_dim: Input embedding dimension.
            relation_mode: Adjacency computation method. One of: 'dot',
                'mlp', 'attention', 'attn', 'bilinear', 'gcn', 'gat'.
            **kwargs: Method-specific arguments (hidden_dim, nhead, latent_dim).

        Raises:
            ValueError: If relation_mode is not recognized.
        """
        super().__init__()
        mode = relation_mode.lower()
        hidden = kwargs.get("hidden_dim", in_dim)
        heads = kwargs.get("nhead", 8)
        if mode == "dot":
            class _DotAdjacency(nn.Module):
                """Thin nn.Module facade over dot_product_adjacency."""
                def forward(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
                    return dot_product_adjacency(X, mask)
            layer: nn.Module = _DotAdjacency()
        elif mode == "mlp":
            layer = MLPDecoder(in_dim, hidden)
        elif mode in ("attention", "attn"):
            layer = AttentionAdjacency(in_dim, heads)
        elif mode == "bilinear":
            layer = BilinearDecoder(in_dim, kwargs.get("latent_dim", in_dim))
        elif mode == "gcn":
            layer = GCNDecoder(in_dim, hidden)
        elif mode == "gat":
            layer = GATDecoder(in_dim, heads, hidden)
        else:
            raise ValueError(f"Unknown relation mode: {relation_mode}")
        self.relation_rep_layer = layer

    def forward(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Compute the adjacency matrix via the configured method.

        Args:
            X: Entity/mention embeddings of shape (B, E, D) — batch size,
                number of entities, embedding dimension.
            mask: Optional (B, E) mask; 1 marks valid entities, 0 padding.
            *args: Extra positional arguments (unused, kept for compatibility).
            **kwargs: Extra keyword arguments (unused, kept for compatibility).

        Returns:
            Adjacency matrix of shape (B, E, E) with values in [0, 1];
            A[b, i, j] is the predicted edge weight from entity i to j.
        """
        return self.relation_rep_layer(X, *args, mask=mask, **kwargs)