from typing import List, Optional
class Node:
    """A single trie node holding one token id and its children.

    A node is *permanent* when it belongs to an entity that must never be
    pruned by ``Trie.remove_entity``. It is *terminal* when some added
    entity ends exactly at this node.
    """

    __slots__ = ("_children", "_key", "_permanent", "_terminal")

    def __init__(self, key: int, permanent: bool):
        self._key = key
        self._permanent = permanent
        # Set by Trie.add on the last node of an entity, so that removing a
        # longer entity does not prune a shorter one that ends here.
        self._terminal = False
        # Children keyed by token id; dicts preserve insertion order, like
        # iterating a vector.
        self._children: dict[int, "Node"] = {}

    def get_key(self) -> int:
        """Return the token id stored in this node."""
        return self._key

    def is_permanent(self) -> bool:
        """Return True if this node must never be pruned."""
        return self._permanent

    def make_permanent(self) -> None:
        """Mark this node as permanent; there is no way to undo this."""
        self._permanent = True

    def is_terminal(self) -> bool:
        """Return True if some added entity ends at this node."""
        return self._terminal

    def set_terminal(self, terminal: bool) -> None:
        """Set or clear the end-of-entity marker on this node."""
        self._terminal = terminal

    def add_child(self, child: "Node") -> None:
        """Attach child under its key, replacing any child with the same key."""
        self._children[child.get_key()] = child

    def get_child(self, child_key: int) -> Optional["Node"]:
        """Return the child with the given key, or None if absent."""
        return self._children.get(child_key)

    def get_children(self) -> List["Node"]:
        """Return the children in insertion order."""
        return list(self._children.values())

    def has_children(self) -> bool:
        """Return True if this node has at least one child."""
        return bool(self._children)

    def delete_child(self, child_key: int) -> None:
        """Remove the child with the given key; a missing key is ignored."""
        self._children.pop(child_key, None)


class Trie:
    """Prefix tree over token-id sequences ("entities").

    Entities added with ``permanent=True`` are never pruned by
    ``remove_entity``. Non-permanent entities can be removed again.
    """

    def __init__(self, init_value: Optional[List[List[int]]] = None):
        # Root has key=0 and is permanent, so it is never pruned
        # (matches the C++ code).
        self.root = Node(0, True)
        if init_value:
            self.add_batch(init_value, permanent=True)

    def add_batch(self, entities: List[List[int]], permanent: bool) -> None:
        """Add every sequence in entities with the same permanence."""
        for entity in entities:
            self.add(entity, permanent)

    def add(self, entity: List[int], permanent: bool) -> None:
        """Insert entity, creating any missing nodes along its path.

        When permanent is True, nodes that already exist on the path (for
        example from an earlier non-permanent add) are upgraded to
        permanent. Without this, a later remove_entity could prune the
        permanent entity. The last node of the path is marked terminal.
        """
        current = self.root
        for token_id in entity:
            nxt = current.get_child(token_id)
            if nxt is None:
                nxt = Node(token_id, permanent)
                current.add_child(nxt)
            elif permanent:
                nxt.make_permanent()
            current = nxt
        current.set_terminal(True)

    def get_possible_next_keys(self, entity: List[int]) -> List[int]:
        """Return the token ids that may follow the prefix entity.

        Returns an empty list if entity is not a path in the trie. Keys are
        returned in the order their children were inserted.
        """
        node = self.root
        for token_id in entity:
            nxt = node.get_child(token_id)
            if nxt is None:
                return []
            node = nxt
        return [child.get_key() for child in node.get_children()]

    def get_branch(self, entity: List[int]) -> List[Node]:
        """Return the nodes along entity's path, starting with the root.

        Returns an empty list if the full path does not exist.
        """
        branch = [self.root]
        node = self.root
        for token_id in entity:
            nxt = node.get_child(token_id)
            if nxt is None:
                return []
            node = nxt
            branch.append(node)
        return branch

    def remove_batch(self, entities: List[List[int]]) -> None:
        """Remove every sequence in entities (see remove_entity)."""
        for entity in entities:
            self.remove_entity(entity)

    def remove_entity(self, entity: List[int]) -> None:
        """Remove a non-permanent entity and prune nodes no longer needed.

        Does nothing if entity is not a path in the trie, or if it is empty.
        Walking from the end of the path towards the root, a node is deleted
        only while it has no children, is not permanent, and no other entity
        ends at it.
        """
        branch = self.get_branch(entity)
        # Path not found, or empty entity (branch is just the root).
        if len(branch) <= 1:
            return
        end = branch[-1]
        if not end.is_permanent():
            end.set_terminal(False)
        for child, parent in zip(reversed(branch[1:]), reversed(branch[:-1])):
            if child.has_children() or child.is_permanent() or child.is_terminal():
                break
            parent.delete_child(child.get_key())
class LabelsTrie:
    """Label-oriented wrapper around Trie.

    Sequences passed to the constructor are added as permanent entries, so
    remove_entity and remove_batch silently leave them in place. Sequences
    added later via add / add_batch are non-permanent and can be pruned
    again by the remove methods.
    """

    def __init__(self, entities: Optional[List[List[int]]] = None):
        """Initialize the trie.

        Args:
            entities: Optional initial list of token sequences to add to the
                trie as permanent (non-removable) entries. If None or empty,
                creates an empty trie.
        """
        # Trie already treats None and an empty list as "no initial entries".
        self.trie = Trie(entities)

    def add_batch(self, entities: List[List[int]]) -> None:
        """Add multiple token sequences to the trie as removable entries.

        Args:
            entities: List of token sequences to add.
        """
        self.trie.add_batch(entities, permanent=False)

    def add(self, tokens: List[int]) -> None:
        """Add a single token sequence to the trie as a removable entry.

        Args:
            tokens: Token sequence to add.
        """
        self.trie.add(tokens, permanent=False)

    def get(self, prefix: List[int]) -> List[int]:
        """Get possible next tokens after a given prefix.

        Args:
            prefix: The token sequence to search for.

        Returns:
            List of possible next token IDs, in insertion order. Empty if
            prefix is not a path in the trie.
        """
        return self.trie.get_possible_next_keys(prefix)

    def remove_batch(self, entities: List[List[int]]) -> None:
        """Remove multiple token sequences from the trie.

        Entries given to the constructor are permanent and are not removed.

        Args:
            entities: List of token sequences to remove.
        """
        self.trie.remove_batch(entities)

    def remove_entity(self, tokens: List[int]) -> None:
        """Remove a single token sequence from the trie.

        Does nothing if tokens is not in the trie. Entries given to the
        constructor are permanent and are not removed.

        Args:
            tokens: Token sequence to remove.
        """
        self.trie.remove_entity(tokens)