from __future__ import annotations
import torch
from torch import nn, fft
from torch.nn import functional as F
from ..layers import create_projection_layer
def _split_complex(x):
"""Assume last dim = 2k and split into real / imag parts."""
return torch.chunk(x, 2, dim=-1)
def _split_quaternion(x):
"""Assume last dim = 4k and split into (1,i,j,k) parts."""
return torch.chunk(x, 4, dim=-1)
def _norm_clamp(x, max_norm):
    """Clamp scalar norms from above for numerical stability (no-op if None)."""
    return x.clamp(max=max_norm) if max_norm is not None else x
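# Convention sketch (illustrative, not part of the API): with D = 4,
# _split_complex(torch.tensor([1., 2., 3., 4.])) returns
# (tensor([1., 2.]), tensor([3., 4.])), i.e. the first half of an embedding
# holds the real parts and the second half the imaginary parts.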
class NormBasedInteraction(nn.Module):
def __init__(
self,
dim: int,
p: int = 2,
power: float = 1.0,
clamp_norm: float | None = 10.0,
use_scorer: bool = False,
dropout: float = 0.3,
):
"""
Base class for norm-based KGE interactions.
Args:
dim: Embedding dimension
p: ℓ_p norm (e.g. 1 or 2) used in ‖·‖_p
power: Raise norm to this power before negating
            clamp_norm: Optional upper bound on the (powered) distance, for numerical stability
use_scorer: If True, use learned projection instead of norm
dropout: Dropout rate for scorer (if used)
"""
super().__init__()
self.p = p
self.power = power
self.clamp = clamp_norm
# Optional learned scorer instead of norm
if use_scorer:
self.scorer = create_projection_layer(dim, dropout, 1)
else:
self.scorer = None
def _score(self, x):
"""
Score residual vector.
Args:
x: (..., D) residual vector
Returns:
scores: (...) scalar scores (higher = better)
"""
if self.scorer is not None:
# Use learned projection
return self.scorer(x).squeeze(-1)
# Use norm-based scoring
d = torch.linalg.norm(x, ord=self.p, dim=-1)
d = d.pow(self.power)
d = _norm_clamp(d, self.clamp)
return -d # Negative distance (higher = better)
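# Usage sketch (shapes illustrative): every norm-based subclass maps (..., D)
# head/relation/tail embeddings to (...) scalar scores; scores are negated
# distances, so 0 is best and more negative is worse:
#   inter = TransEInteraction(dim=4)
#   h, r, t = (torch.randn(2, 4) for _ in range(3))
#   inter(h, r, t)  # shape (2,), values in [-clamp_norm, 0]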
class UMInteraction(NormBasedInteraction):
"""Unstructured model ‖h - t‖."""
def __init__(self, dim: int = 768, **kwargs):
super().__init__(dim=dim, **kwargs)
def forward(self, h, r, t):
return self._score(h - t)
class SEInteraction(NormBasedInteraction):
    """Structure Embedding (SE), simplified.
    Applies a shared relation-specific diagonal matrix to head and tail:
    h' = diag(r) · h , t' = diag(r) · t
    which reduces to the element-wise form r ∘ (h − t). (The original SE
    uses two dense relation matrices, one for heads and one for tails.)
    """

    def __init__(self, dim: int = 768, **kwargs):
        super().__init__(dim=dim, **kwargs)

    def forward(self, h, r, t):
        # diag(r) · x is just r * x, so avoid materializing the
        # (..., D, D) diagonal matrix via diag_embed/matmul.
        return self._score(r * (h - t))
class TransEInteraction(NormBasedInteraction):
"""TransE ‖h + r − t‖."""
def __init__(self, dim: int = 768, p: int = 1, **kwargs):
super().__init__(dim=dim, p=p, **kwargs)
def forward(self, h, r, t):
return self._score(h + r - t)
class TransHInteraction(NormBasedInteraction):
"""
TransH – project entities to a relation-specific hyperplane.
Learn mappings from base relation r to:
r_tr = W_tr * r + b_tr (translation)
w = W_w * r + b_w (hyperplane normal)
"""
def __init__(self, dim: int, p: int = 2, power: float = 1.0, **kwargs):
super().__init__(dim=dim, p=p, power=power, **kwargs)
self.r_to_rtr = nn.Linear(dim, dim)
self.r_to_w = nn.Linear(dim, dim)
def forward(self, h, r, t):
# Map base relation vector -> translation & normal
r_tr = self.r_to_rtr(r) # (..., D)
w = self.r_to_w(r) # (..., D)
w = F.normalize(w, dim=-1)
def proj(x):
# Project: x_proj = x - (x·w) w
dot = (x * w).sum(dim=-1, keepdim=True)
return x - dot * w
h_proj = proj(h)
t_proj = proj(t)
return self._score(h_proj + r_tr - t_proj)
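# Sanity sketch: w is unit-normalized, so proj removes exactly the component
# of x along w: (proj(x) * w).sum(-1) ≈ 0, and proj is idempotent. Head and
# tail are thus compared within the relation's hyperplane before translating.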
class TransFInteraction(NormBasedInteraction):
"""
TransF – element-wise relation-specific scaling before translation.
Learn mappings from base relation r to:
r_vec = W_r * r + b_r
alpha = W_alpha * r + b_alpha
beta = W_beta * r + b_beta
Score is ‖(alpha ∘ h) + r_vec − (beta ∘ t)‖_p
"""
def __init__(self, dim: int, p: int = 2, power: float = 1.0, **kwargs):
super().__init__(dim=dim, p=p, power=power, **kwargs)
self.r_to_rvec = nn.Linear(dim, dim)
self.r_to_alpha = nn.Linear(dim, dim)
self.r_to_beta = nn.Linear(dim, dim)
# Initialize to start close to plain TransE
# r_vec ≈ r (identity), alpha ≈ 1, beta ≈ 1
        # nn.Linear(dim, dim) has a square (dim × dim) weight by construction,
        # so identity initialization always applies.
        nn.init.eye_(self.r_to_rvec.weight)
nn.init.zeros_(self.r_to_rvec.bias)
nn.init.zeros_(self.r_to_alpha.weight)
nn.init.ones_(self.r_to_alpha.bias)
nn.init.zeros_(self.r_to_beta.weight)
nn.init.ones_(self.r_to_beta.bias)
def forward(self, h, r, t):
r_vec = self.r_to_rvec(r) # (..., D)
alpha = self.r_to_alpha(r) # (..., D)
beta = self.r_to_beta(r) # (..., D)
h_ = alpha * h
t_ = beta * t
return self._score(h_ + r_vec - t_)
class PairREInteraction(NormBasedInteraction):
"""
PairRE – per-relation element-wise scaling of h & t.
Learn mappings from base relation r to:
alpha = W_alpha * r + b_alpha
beta = W_beta * r + b_beta
"""
def __init__(self, dim: int, p: int = 2, power: float = 1.0, **kwargs):
super().__init__(dim=dim, p=p, power=power, **kwargs)
self.r_to_alpha = nn.Linear(dim, dim)
self.r_to_beta = nn.Linear(dim, dim)
def forward(self, h, r, t):
alpha = self.r_to_alpha(r) # (..., D)
beta = self.r_to_beta(r) # (..., D)
return self._score(alpha * h - beta * t)
class TripleREInteraction(NormBasedInteraction):
"""
TripleRE – LineaRE + scalar γ per relation.
Learn mappings from base relation r to:
alpha = W_alpha * r + b_alpha
beta = W_beta * r + b_beta
delta = W_delta * r + b_delta
gamma = w_gamma^T * r + b_gamma (scalar)
"""
def __init__(self, dim: int, p: int = 2, power: float = 1.0, **kwargs):
super().__init__(dim=dim, p=p, power=power, **kwargs)
self.r_to_alpha = nn.Linear(dim, dim)
self.r_to_beta = nn.Linear(dim, dim)
self.r_to_delta = nn.Linear(dim, dim)
self.r_to_gamma = nn.Linear(dim, 1)
def forward(self, h, r, t):
alpha = self.r_to_alpha(r) # (..., D)
beta = self.r_to_beta(r) # (..., D)
delta = self.r_to_delta(r) # (..., D)
gamma = self.r_to_gamma(r) # (..., 1)
base_score = self._score(alpha * h + delta - beta * t) # (...)
return gamma.squeeze(-1) * base_score
class DistMultInteraction(nn.Module):
"""DistMult – Σ_d h_d r_d t_d."""
def forward(self, h, r, t):
return (h * r * t).sum(dim=-1)
class SimplEInteraction(nn.Module):
"""SimplE – split every embedding into (forward, backward) halves.
score = ½( ⟨h_f, r_f, t_b⟩ + ⟨t_f, r_b, h_b⟩ )
Requires even dimension.
"""
def __init__(self, dim: int = 768):
super().__init__()
if dim % 2 != 0:
raise ValueError(f"SimplE requires even dimension, got {dim}")
def forward(self, h, r, t):
        # Reuse the generic half-split: first half = forward, second = backward
        h_f, h_b = _split_complex(h)
        t_f, t_b = _split_complex(t)
        r_f, r_b = _split_complex(r)
s1 = (h_f * r_f * t_b).sum(dim=-1)
s2 = (t_f * r_b * h_b).sum(dim=-1)
return 0.5 * (s1 + s2)
class TuckERInteraction(nn.Module):
"""
TuckER – global core tensor W (D_r × D_e × D_e).
"""
def __init__(self, d_e: int, d_r: int, dropout: float = 0.2):
super().__init__()
self.d_e = d_e
self.d_r = d_r
self.W = nn.Parameter(torch.empty(d_r, d_e, d_e))
nn.init.xavier_uniform_(self.W.data)
self.bn0 = nn.BatchNorm1d(d_e)
self.bn1 = nn.BatchNorm1d(d_e)
self.dropout = nn.Dropout(dropout)
self.input_dropout = nn.Dropout(dropout)
def forward(self, h, r, t):
# Store original shape for reshaping later
orig_shape = h.shape[:-1]
# Flatten to 2D for BatchNorm: (batch_size, d_e)
h_2d = h.reshape(-1, self.d_e)
t_2d = t.reshape(-1, self.d_e)
        # Apply BatchNorm (input dropout is applied after reshaping)
h_bn = self.bn0(h_2d)
t_bn = self.bn1(t_2d)
# Reshape back to original shape (except last dim)
h_bn = h_bn.view(*orig_shape, self.d_e)
t_bn = t_bn.view(*orig_shape, self.d_e)
# Apply input dropout
h_bn = self.input_dropout(h_bn)
t_bn = self.input_dropout(t_bn)
# Reshape r for matrix multiplication
r_shape = r.shape[:-1]
r_2d = r.reshape(-1, self.d_r)
        # Core interaction: mix the core tensor with r to get one (d_e, d_e)
        # matrix per example: W_mat[b] = Σ_i r[b, i] · W[i]
        W_mat = torch.matmul(r_2d, self.W.view(self.d_r, -1))  # (batch, d_e * d_e)
        W_mat = W_mat.view(*r_shape, self.d_e, self.d_e)
# Apply dropout on core tensor output
W_mat = self.dropout(W_mat)
# Compute h x W_mat
hr = torch.matmul(h_bn.unsqueeze(-2), W_mat).squeeze(-2)
# Final score
scores = (hr * t_bn).sum(dim=-1)
return scores
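# Shape sketch (illustrative): r of shape (B, d_r) selects one (d_e, d_e)
# mixing matrix per example from the core tensor W; hr = h^T W_r is
# (B, d_e), and the score is its dot product with t, i.e. the trilinear
# product of h, r, t with W.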
class DistMAInteraction(nn.Module):
"""DistMA – sum of pairwise dot products."""
def forward(self, h, r, t):
return (h * r).sum(dim=-1) + (h * t).sum(dim=-1) + (r * t).sum(dim=-1)
class ComplExInteraction(nn.Module):
"""ComplEx – Re(⟨h, r, conj(t)⟩) with complex embeddings.
Requires even dimension.
"""
def __init__(self, dim: int = 768):
super().__init__()
if dim % 2 != 0:
raise ValueError(f"ComplEx requires even dimension, got {dim}")
def forward(self, h, r, t):
h_re, h_im = _split_complex(h)
r_re, r_im = _split_complex(r)
t_re, t_im = _split_complex(t)
return (h_re * r_re * t_re + h_re * r_im * t_im + h_im * r_re * t_im - h_im * r_im * t_re).sum(dim=-1)
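# Equivalence sketch: with the half-split convention above, the expanded sum
# equals the complex form Re(⟨h, r, conj(t)⟩):
#   h_c = torch.complex(*_split_complex(h))
#   r_c = torch.complex(*_split_complex(r))
#   t_c = torch.complex(*_split_complex(t))
#   (h_c * r_c * t_c.conj()).real.sum(dim=-1)  # same scores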
class QuatEInteraction(nn.Module):
"""QuatE – use Hamilton product (a,b,c,d)⨂(e,f,g,h).
Requires dimension divisible by 4.
"""
def __init__(self, dim: int = 768):
super().__init__()
if dim % 4 != 0:
raise ValueError(f"QuatE requires dimension divisible by 4, got {dim}")
def forward(self, h, r, t):
h0, h1, h2, h3 = _split_quaternion(h)
r0, r1, r2, r3 = _split_quaternion(r)
t0, t1, t2, t3 = _split_quaternion(t)
# Hamilton product h ⨂ r
A0 = h0 * r0 - h1 * r1 - h2 * r2 - h3 * r3
A1 = h0 * r1 + h1 * r0 + h2 * r3 - h3 * r2
A2 = h0 * r2 - h1 * r3 + h2 * r0 + h3 * r1
A3 = h0 * r3 + h1 * r2 - h2 * r1 + h3 * r0
return (A0 * t0 + A1 * t1 + A2 * t2 + A3 * t3).sum(dim=-1)
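# Sanity sketch (hypothetical values): with D = 4, h = i = (0, 1, 0, 0) and
# r = j = (0, 0, 1, 0), the Hamilton product h ⨂ r is k = (0, 0, 0, 1), so
# the score against t = k is 1 and against t = −k is −1.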
class HolEInteraction(nn.Module):
"""HolE – circular correlation ϕ(h, t) · r."""
def forward(self, h, r, t):
# Convert to float32 for FFT stability
h = h.to(torch.float32)
r = r.to(torch.float32)
t = t.to(torch.float32)
# FFT-based circular correlation
fft_h = fft.rfft(h, dim=-1)
fft_t = fft.rfft(t, dim=-1)
corr = fft.irfft(fft_h.conj() * fft_t, n=h.shape[-1], dim=-1)
return (corr * r).sum(dim=-1)
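# Equivalence sketch: the FFT route computes the naive circular correlation
#   corr[..., k] == (h * torch.roll(t, shifts=-k, dims=-1)).sum(dim=-1)
# for k = 0 .. D−1, in O(D log D) instead of O(D^2).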
class ERMLPInteraction(nn.Module):
"""ER-MLP: 2-layer perceptron on concatenated [h, r, t]."""
def __init__(self, dim: int, hidden: int = 2048):
super().__init__()
self.mlp = nn.Sequential(nn.Linear(3 * dim, hidden), nn.ReLU(), nn.Linear(hidden, 1))
def forward(self, h, r, t):
x = torch.cat([h, r, t], dim=-1)
return self.mlp(x).squeeze(-1)
class ConvKBInteraction(nn.Module):
"""ConvKB: Convolutional Knowledge Base interaction (Conv1d version)."""
def __init__(self, dim: int, n_filters: int = 32, dropout: float = 0.3, use_bias: bool = True):
super().__init__()
self.dim = dim
self.n_filters = n_filters
# Dropout layer
self.dropout = nn.Dropout(dropout)
# Conv1d over the 3-embedding dimension
self.conv = nn.Conv1d(
in_channels=3, # [h, r, t]
out_channels=n_filters,
kernel_size=1,
stride=1,
padding=0,
bias=use_bias,
)
# Project concatenated feature maps to score
self.fc = nn.Linear(n_filters * dim, 1)
def forward(self, h, r, t):
"""
Score triples (h, r, t).
Args:
h: Head entities (..., D)
r: Relations (..., D)
t: Tail entities (..., D)
Returns:
scores: Triple scores (...)
"""
# Store original shape
orig_shape = h.shape[:-1]
# Flatten to batch dimension
batch_size = h.reshape(-1, self.dim).shape[0]
h_flat = h.reshape(batch_size, self.dim)
r_flat = r.reshape(batch_size, self.dim)
t_flat = t.reshape(batch_size, self.dim)
# Stack [h, r, t]: (B, 3, k)
stacked = torch.stack([h_flat, r_flat, t_flat], dim=1) # (B, 3, k)
# Apply 1D convolution: (B, n_filters, k)
x = self.conv(stacked)
x = F.relu(x)
# Flatten: (B, n_filters * k)
x = x.view(batch_size, -1)
# Apply dropout
x = self.dropout(x)
# Project to score
scores = self.fc(x).squeeze(-1) # (B,)
# Reshape back to original shape
scores = scores.view(*orig_shape)
return scores
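# Design note: Conv1d with in_channels=3 and kernel_size=1 applies each of
# the n_filters across the (h_d, r_d, t_d) triple at every dimension d,
# which matches the 1×3 filters of the original ConvKB up to a transpose of
# the stacked (D, 3) input matrix.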
class ConvEInteraction(nn.Module):
"""ConvE: Convolutional interaction matching reference implementation.
Stacks head and relation embeddings vertically and applies 2D convolution.
"""
def __init__(
self,
dim: int,
emb_dim1: int,
n_filters: int = 32,
kernel_size: int = 3,
input_drop: float = 0.2,
hidden_drop: float = 0.3,
feat_drop: float = 0.2,
use_bias: bool = True,
):
super().__init__()
self.dim = dim
        if dim % emb_dim1 != 0:
            raise ValueError(f"Embedding dim {dim} must be divisible by emb_dim1 {emb_dim1}")
        self.emb_dim1 = emb_dim1
        self.emb_dim2 = dim // emb_dim1
# Dropout layers
self.inp_drop = nn.Dropout(input_drop)
self.hidden_drop = nn.Dropout(hidden_drop)
self.feature_map_drop = nn.Dropout2d(feat_drop)
# Convolutional layer
self.conv1 = nn.Conv2d(1, n_filters, (kernel_size, kernel_size), stride=1, padding=0, bias=use_bias)
conv_out_h = 2 * emb_dim1 - kernel_size + 1
conv_out_w = self.emb_dim2 - kernel_size + 1
hidden_size = n_filters * conv_out_h * conv_out_w
# Fully connected layer to project back to embedding dimension
self.fc = nn.Linear(hidden_size, dim)
def forward(self, h, r, t):
"""
Score triples (h, r, t).
Args:
h: Head entities (..., D)
r: Relations (..., D)
t: Tail entities (..., D)
Returns:
scores: Triple scores (...)
"""
# Store original shape for later
orig_shape = h.shape[:-1]
# Flatten to batch dimension
batch_size = h.reshape(-1, self.dim).shape[0]
h_flat = h.reshape(batch_size, self.dim)
r_flat = r.reshape(batch_size, self.dim)
t_flat = t.reshape(batch_size, self.dim)
# Reshape embeddings into 2D "images"
h_img = h_flat.view(batch_size, 1, self.emb_dim1, self.emb_dim2)
r_img = r_flat.view(batch_size, 1, self.emb_dim1, self.emb_dim2)
# Stack head and relation vertically (along height dimension)
stacked = torch.cat([h_img, r_img], dim=2) # (B, 1, 2*emb_dim1, emb_dim2)
# Input dropout
x = self.inp_drop(stacked)
# Convolution
x = self.conv1(x) # (B, n_filters, conv_out_h, conv_out_w)
x = F.relu(x)
# Feature map dropout
x = self.feature_map_drop(x)
# Flatten feature maps
x = x.view(batch_size, -1)
# Fully connected projection
x = self.fc(x) # (B, dim)
x = self.hidden_drop(x)
x = F.relu(x)
# Score against tail entities (dot product)
scores = (x * t_flat).sum(dim=-1)
# Reshape back to original shape
scores = scores.view(*orig_shape)
return scores
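# Shape sketch (illustrative numbers): dim=200 with emb_dim1=10 gives
# emb_dim2=20; the stacked input is (B, 1, 20, 20), 32 filters of size 3
# produce (B, 32, 18, 18) feature maps, and fc maps 32*18*18 = 10368 -> 200
# before the dot product with t.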
class TriplesScoreLayer(nn.Module):
"""Wrapper for knowledge graph triple scoring interactions.
Optimized for relation extraction in entity recognition models.
Args:
interaction_mode: The type of interaction to use. Available modes:
- Translational: UM, SE, TransE, TransH, TransF, PairRE, TripleRE
- Semantic: DistMult, SimplE, ComplEx, QuatE, HolE, DistMA
- Neural: TuckER, ERMLP, ConvE, ConvKB
dim: Embedding dimension (required for most interactions).
**kwargs: Extra parameters for specific interactions:
- TuckER: requires d_e, d_r, optional dropout
- ERMLP: optional hidden (default 2048)
- ConvE: requires emb_dim1, optional n_filters, kernel_size, input_drop,
hidden_drop, feat_drop, use_bias
- ConvKB: optional n_filters, dropout, use_bias
- Norm-based (TransE, TransH, etc.): optional p, power, clamp_norm,
use_scorer, dropout
"""
# Define dimension requirements for each interaction
DIMENSION_REQUIREMENTS = {
"ComplEx": lambda d: d % 2 == 0,
"SimplE": lambda d: d % 2 == 0,
"QuatE": lambda d: d % 4 == 0,
}
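    # Usage sketch (shapes illustrative):
    #   layer = TriplesScoreLayer("TransE", dim=768, p=1)
    #   h, r, t = (torch.randn(8, 768) for _ in range(3))
    #   scores = layer(h, r, t)  # (8,)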
def __init__(self, interaction_mode: str, dim: int = 768, **kwargs):
super().__init__()
self.mode = interaction_mode
self.dim = dim
# Validate dimension requirements
self.validate_dimensions(dim)
# Create the appropriate interaction
if interaction_mode == "UM":
self.interaction = UMInteraction(dim=dim, **kwargs)
elif interaction_mode == "SE":
self.interaction = SEInteraction(dim=dim, **kwargs)
elif interaction_mode == "TransE":
self.interaction = TransEInteraction(dim=dim, **kwargs)
elif interaction_mode == "TransH":
self.interaction = TransHInteraction(dim=dim, **kwargs)
elif interaction_mode == "TransF":
self.interaction = TransFInteraction(dim=dim, **kwargs)
elif interaction_mode == "PairRE":
self.interaction = PairREInteraction(dim=dim, **kwargs)
elif interaction_mode == "TripleRE":
self.interaction = TripleREInteraction(dim=dim, **kwargs)
elif interaction_mode == "DistMult":
self.interaction = DistMultInteraction()
elif interaction_mode == "SimplE":
self.interaction = SimplEInteraction(dim=dim)
elif interaction_mode == "DistMA":
self.interaction = DistMAInteraction()
elif interaction_mode == "ComplEx":
self.interaction = ComplExInteraction(dim=dim)
elif interaction_mode == "QuatE":
self.interaction = QuatEInteraction(dim=dim)
elif interaction_mode == "HolE":
self.interaction = HolEInteraction()
elif interaction_mode == "TuckER":
d_e = kwargs.get("d_e", dim)
d_r = kwargs.get("d_r", dim)
dropout = kwargs.get("dropout", 0.2)
self.interaction = TuckERInteraction(d_e, d_r, dropout)
elif interaction_mode == "ERMLP":
hidden = kwargs.get("hidden", 2048)
self.interaction = ERMLPInteraction(dim, hidden)
elif interaction_mode == "ConvE":
emb_dim1 = kwargs.get("emb_dim1")
if emb_dim1 is None:
raise ValueError("ConvE requires `emb_dim1` argument (height of reshaped embedding).")
n_filters = kwargs.get("n_filters", 9)
kernel_size = kwargs.get("kernel_size", 3)
input_drop = kwargs.get("input_drop", 0.2)
hidden_drop = kwargs.get("hidden_drop", 0.3)
feat_drop = kwargs.get("feat_drop", 0.2)
use_bias = kwargs.get("use_bias", True)
self.interaction = ConvEInteraction(
dim, emb_dim1, n_filters, kernel_size, input_drop, hidden_drop, feat_drop, use_bias
)
elif interaction_mode == "ConvKB":
n_filters = kwargs.get("n_filters", 32)
dropout = kwargs.get("dropout", 0.3)
use_bias = kwargs.get("use_bias", True)
self.interaction = ConvKBInteraction(dim, n_filters, dropout, use_bias)
else:
raise ValueError(f"Unknown interaction mode '{interaction_mode}'.")
def validate_dimensions(self, dim: int):
"""
Validate that the embedding dimension meets requirements for this interaction.
Args:
dim: The embedding dimension to validate
Raises:
ValueError: If dimension requirements are not met
"""
if self.mode in self.DIMENSION_REQUIREMENTS:
check = self.DIMENSION_REQUIREMENTS[self.mode]
if not check(dim):
if self.mode in ["ComplEx", "SimplE"]:
msg = f"{self.mode} requires even embedding dimension. Got {dim}."
elif self.mode == "QuatE":
msg = f"{self.mode} requires embedding dimension divisible by 4. Got {dim}."
else:
msg = f"{self.mode} has dimension requirements not satisfied by {dim}."
raise ValueError(msg)
def forward(self, h, r, t):
"""
Score triples (h, r, t).
Args:
h: Head entities (..., D)
r: Relations (..., D)
t: Tail entities (..., D)
Returns:
scores: Triple scores (...)
"""
return self.interaction(h, r, t)
def forward_batched_relations(self, h, t, rel_embeddings):
"""
Efficiently score entity pairs against all relation types.
Args:
h: Head entities (B, N, D)
t: Tail entities (B, N, D)
rel_embeddings: Relation type embeddings (B, C, D) or (C, D)
Returns:
scores: (B, N, C) scores for each pair against each relation type
"""
B, N, D = h.shape
# Handle both batched and unbatched relation embeddings
if rel_embeddings.dim() == 2:
C, _ = rel_embeddings.shape
rel_embeddings = rel_embeddings.unsqueeze(0).expand(B, C, D)
else:
C = rel_embeddings.shape[1]
# Expand dimensions for broadcasting
h_exp = h.unsqueeze(2).expand(B, N, C, D)
t_exp = t.unsqueeze(2).expand(B, N, C, D)
r_exp = rel_embeddings.unsqueeze(1).expand(B, N, C, D)
# Reshape to (B*N*C, D) for efficient batch processing
h_flat = h_exp.reshape(B * N * C, D)
r_flat = r_exp.reshape(B * N * C, D)
t_flat = t_exp.reshape(B * N * C, D)
# Score all triples at once
scores_flat = self.interaction(h_flat, r_flat, t_flat)
# Reshape back to (B, N, C)
scores = scores_flat.view(B, N, C)
return scores
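    # Sketch: for B=2 batches of N=5 pairs scored against C=7 relation types,
    #   h, t = torch.randn(2, 5, 768), torch.randn(2, 5, 768)
    #   rels = torch.randn(7, 768)                   # shared (C, D) table
    #   layer.forward_batched_relations(h, t, rels)  # (2, 5, 7)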
def forward_single_relation(self, h, t, r):
"""
Score entity pairs with a single relation type.
Args:
h: Head entities (B, N, D)
t: Tail entities (B, N, D)
r: Single relation embedding (B, D) or (D,)
Returns:
scores: (B, N) scores for each pair with the given relation
"""
B, N, D = h.shape
# Expand relation to match batch and pair dimensions
if r.dim() == 1:
r = r.unsqueeze(0).unsqueeze(0).expand(B, N, D)
elif r.dim() == 2:
r = r.unsqueeze(1).expand(B, N, D)
# Flatten for scoring
h_flat = h.reshape(B * N, D)
t_flat = t.reshape(B * N, D)
r_flat = r.reshape(B * N, D)
# Score
scores_flat = self.interaction(h_flat, r_flat, t_flat)
# Reshape to (B, N)
scores = scores_flat.view(B, N)
return scores
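    # Sketch: scoring the same pairs under a single relation embedding:
    #   layer.forward_single_relation(h, t, torch.randn(768))  # (B, N)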