Source code for gliner.training.trainer

"""Custom Trainer implementation with enhanced loss functions and optimizer configuration.

This module extends the Hugging Face Transformers Trainer class to support
custom loss functions (focal loss, label smoothing), flexible learning rates
for different parameter groups, and robust error handling during training.
"""

import logging
from typing import Any, Dict, List, Tuple, Union, Optional
from dataclasses import field, dataclass

import torch
import transformers
from torch import nn
from transformers.trainer import (
    get_parameter_names,
    is_sagemaker_mp_enabled,
)
from transformers.trainer_utils import set_seed

if transformers.utils.is_apex_available():
    from apex import amp

if is_sagemaker_mp_enabled():
    from transformers.trainer_pt_utils import smp_forward_backward

from torch.utils.data import Dataset, DataLoader

ALL_LAYERNORM_LAYERS = [nn.LayerNorm]

logger = logging.getLogger(__name__)


def seed_worker(_):
    """Set worker seed during DataLoader initialization.

    Helper function to ensure reproducibility by seeding each DataLoader worker
    process with a unique but deterministic seed based on PyTorch's initial seed.

    Args:
        _: Worker ID (unused, but required by the DataLoader worker_init_fn signature).
    """
    worker_seed = torch.initial_seed() % 2**32
    set_seed(worker_seed)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Extended training arguments with custom loss and optimization parameters.

    Extends the standard Hugging Face TrainingArguments with additional parameters
    for focal loss, label smoothing, differential learning rates, and custom
    negative sampling strategies.

    Attributes:
        cache_dir: Directory to cache downloaded models and datasets.
        optim: Optimizer to use. Defaults to "adamw_torch".
        others_lr: Optional separate learning rate for non-encoder parameters
            (e.g., classification heads). If None, uses the main learning rate.
        others_weight_decay: Weight decay for non-encoder parameters when using
            others_lr. Defaults to 0.0.
        focal_loss_alpha: Alpha parameter for focal loss. Values < 0 disable
            focal loss weighting. Defaults to -1.
        focal_loss_gamma: Gamma (focusing parameter) for focal loss. Higher values
            increase focus on hard examples. Defaults to 0.
        focal_loss_prob_margin: Probability margin for focal loss computation.
            Defaults to 0.
        label_smoothing: Label smoothing factor. 0.0 means no smoothing. Defaults to 0.
        loss_reduction: Reduction method for the loss ('sum', 'mean', or 'none').
            Defaults to 'sum'.
        negatives: Ratio of negative samples to use. Defaults to 1.0.
        masking: Masking strategy for training ('global' or other strategies).
            Defaults to 'global'.
    """

    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    others_lr: Optional[float] = None
    others_weight_decay: Optional[float] = 0.0
    focal_loss_alpha: Optional[float] = -1
    focal_loss_gamma: Optional[float] = 0
    focal_loss_prob_margin: Optional[float] = 0
    label_smoothing: Optional[float] = 0
    loss_reduction: Optional[str] = "sum"
    negatives: Optional[float] = 1.0
    masking: Optional[str] = "global"

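# Illustrative sketch of configuring the custom arguments above (the concrete
# values are placeholders for illustration, not recommended defaults):
#
#     training_args = TrainingArguments(
#         output_dir="outputs",
#         learning_rate=1e-5,        # encoder (backbone) learning rate
#         others_lr=3e-5,            # separate rate for non-encoder parameters
#         others_weight_decay=0.01,
#         focal_loss_alpha=0.75,     # values < 0 disable focal weighting
#         focal_loss_gamma=2,
#         loss_reduction="sum",
#         negatives=1.0,
#         masking="global",
#     )
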
class Trainer(transformers.Trainer):
    """Custom Trainer with enhanced loss functions and error handling.

    Extends the Hugging Face Trainer to support:

    - Custom loss functions (focal loss, label smoothing)
    - Differential learning rates for encoder vs. other parameters
    - Robust error handling with automatic recovery from failed batches
    - Custom negative sampling and masking strategies
    - Persistent worker support for data loading

    The trainer automatically handles CUDA out-of-memory errors and other
    exceptions during training by skipping problematic batches and continuing.
    """

    def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor:
        """Perform a training step on a batch of inputs.

        Executes the forward pass, loss computation, and backward pass for a
        single training batch. Includes automatic error handling to skip
        problematic batches without crashing the training run.

        Args:
            model: The model to train.
            inputs: Dictionary of input tensors and targets for the model. The
                dictionary will be unpacked before being fed to the model. Most
                models expect targets under the 'labels' key.
            *args: Additional positional arguments (unused, for compatibility).
            **kwargs: Additional keyword arguments (unused, for compatibility).

        Returns:
            Training loss tensor for this batch, scaled by gradient accumulation
            steps. Returns a zero tensor with requires_grad=True if an error occurs.

        Note:
            If an exception occurs during the training step, the method logs the
            error, zeros gradients, clears the CUDA cache, and returns a zero loss
            so that training can continue.
        """
        model.train()
        try:
            inputs = self._prepare_inputs(inputs)

            if is_sagemaker_mp_enabled():
                loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
                return loss_mb.reduce_mean().detach().to(self.args.device)

            with self.compute_loss_context_manager():
                loss = self.compute_loss(model, inputs)

            del inputs
            torch.cuda.empty_cache()

            kwargs = {}

            if self.args.n_gpu > 1:
                loss = loss.mean()  # Average on multi-GPU training

            if self.use_apex:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                self.accelerator.backward(loss, **kwargs)

            return loss.detach() / self.args.gradient_accumulation_steps
        except Exception as e:
            logger.info("Skipping iteration due to error: %s", e)
            model.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()
            # Safely get the device for DataParallel-wrapped or plain models
            _model = getattr(model, "module", model)
            device = next(_model.parameters()).device
            return torch.tensor(0.0, requires_grad=True, device=device)

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """Save the trained model to a directory.

        Args:
            output_dir: Directory path where the model should be saved. If None,
                uses the default output directory from the training arguments.
            _internal_call: Whether this is an internal call from the Trainer.
                Used for compatibility with the parent class.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        self.model.save_pretrained(output_dir)

    def compute_loss(self, model, inputs):
        """Compute the loss using the custom loss functions.

        Performs a forward pass with custom loss parameters, including the focal
        loss, label smoothing, and negative sampling configurations from the
        training arguments.

        Args:
            model: The model to compute the loss for.
            inputs: Dictionary of input tensors including features and labels.

        Returns:
            Computed loss tensor.

        Note:
            The loss function parameters (alpha, gamma, label_smoothing, etc.)
            are passed to the model's forward method, so the model must support
            these keyword arguments.
        """
        # Forward pass
        outputs = model(
            alpha=self.args.focal_loss_alpha,
            gamma=self.args.focal_loss_gamma,
            prob_margin=self.args.focal_loss_prob_margin,
            label_smoothing=self.args.label_smoothing,
            reduction=self.args.loss_reduction,
            negatives=self.args.negatives,
            masking=self.args.masking,
            **inputs,
        )
        loss = outputs.loss
        return loss

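    # Illustrative sketch (an assumption, not part of this module) of the extra
    # keyword arguments a compatible model's forward method would need to accept
    # alongside its usual batch inputs, matching the call above:
    #
    #     def forward(self, ..., alpha=-1, gamma=0, prob_margin=0,
    #                 label_smoothing=0, reduction="sum", negatives=1.0,
    #                 masking="global", **batch):
    #         ...
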
    def create_optimizer(self):
        """Create and configure the optimizer with parameter groups.

        Sets up the optimizer with support for:

        - Separate learning rates for encoder and non-encoder parameters
        - Weight decay only for non-bias and non-LayerNorm parameters
        - Custom weight decay values for different parameter groups

        Returns:
            Configured optimizer instance.

        Note:
            If self.args.others_lr is set, creates four parameter groups:

            1. Non-encoder parameters with weight decay
            2. Non-encoder parameters without weight decay
            3. Encoder parameters with weight decay
            4. Encoder parameters without weight decay

            Otherwise, creates two standard parameter groups, with and without
            weight decay.
        """
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()

        opt_model = self.model

        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            if self.args.others_lr is not None:
                encoder_parameters = [name for name, _ in opt_model.named_parameters() if "token_rep_layer" in name]
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters()
                            if (n in decay_parameters and n not in encoder_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.others_weight_decay,
                        "lr": self.args.others_lr,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and n not in encoder_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                        "lr": self.args.others_lr,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters()
                            if (n in decay_parameters and n in encoder_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and n in encoder_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                ]
            else:
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters()
                            if (n in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        return self.optimizer

    def prediction_step(
        self,
        model: torch.nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Perform an evaluation step on the model using the given inputs.

        Executes a single forward pass for evaluation without computing gradients.

        Args:
            model: The model to evaluate.
            inputs: Dictionary of input tensors and targets for the model. The
                dictionary will be unpacked before being fed to the model. Most
                models expect targets under the 'labels' key.
            prediction_loss_only: If True, only returns the loss and ignores
                logits and labels.
            ignore_keys: Optional list of keys in the model output dictionary
                that should be ignored when gathering predictions. Currently unused.

        Returns:
            A tuple of (loss, logits, labels):

            - loss: Loss tensor if computed, None otherwise
            - logits: Model predictions if prediction_loss_only is False, None otherwise
            - labels: Ground-truth labels if prediction_loss_only is False, None otherwise
        """
        with torch.no_grad():
            loss = None
            with self.compute_loss_context_manager():
                outputs = model(**inputs)
                loss = outputs.loss
                logits = outputs.logits
        labels = inputs["labels"]

        if prediction_loss_only:
            return (loss, None, None)
        return (loss, logits, labels)

    def get_train_dataloader(self) -> DataLoader:
        """Create and return the training DataLoader.

        Constructs a DataLoader with the appropriate sampler, collation function,
        and worker configuration for the training dataset. Includes seeded worker
        initialization for reproducibility.

        Returns:
            Configured and accelerator-prepared training DataLoader.

        Raises:
            ValueError: If train_dataset is None.

        Note:
            For IterableDataset, sampler and drop_last are not set. For regular
            datasets, uses the sampler from _get_train_sampler() and applies
            worker seeding via the seed_worker function.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator

        dataloader_params = {
            "batch_size": self._train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))

    def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
        """Create and return the evaluation DataLoader.

        Constructs a DataLoader for evaluation with support for persistent workers
        and multiple evaluation datasets. Caches DataLoaders when persistent
        workers are enabled to avoid recreation overhead.

        Args:
            eval_dataset: Evaluation dataset to use. Can be:

                - None: Uses self.eval_dataset
                - str: Uses self.eval_dataset[eval_dataset] (for named eval sets)
                - Dataset: Overrides self.eval_dataset directly

        Returns:
            Configured and accelerator-prepared evaluation DataLoader.

        Raises:
            ValueError: If both eval_dataset and self.eval_dataset are None.

        Note:
            When persistent_workers is True, DataLoaders are cached in
            self._eval_dataloaders to avoid worker process recreation between
            evaluation calls. The cache key is the dataset name (if a string)
            or "eval" for the default dataset.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
        if (
            hasattr(self, "_eval_dataloaders")
            and dataloader_key in self._eval_dataloaders
            and self.args.dataloader_persistent_workers
        ):
            return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])

        eval_dataset = (
            self.eval_dataset[eval_dataset]
            if isinstance(eval_dataset, str)
            else eval_dataset
            if eval_dataset is not None
            else self.eval_dataset
        )
        data_collator = self.data_collator

        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        # accelerator.free_memory() will destroy the references, so
        # we need to store the non-prepared version
        eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
        if self.args.dataloader_persistent_workers:
            if hasattr(self, "_eval_dataloaders"):
                self._eval_dataloaders[dataloader_key] = eval_dataloader
            else:
                self._eval_dataloaders = {dataloader_key: eval_dataloader}

        return self.accelerator.prepare(eval_dataloader)
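

# Illustrative end-to-end sketch of using this Trainer. The model name, dataset
# variables, and collator are placeholders/assumptions for illustration, not
# requirements of this module:
#
#     from gliner import GLiNER
#     from gliner.training import Trainer, TrainingArguments
#
#     model = GLiNER.from_pretrained("urchade/gliner_small-v2.1")
#     args = TrainingArguments(
#         output_dir="outputs",
#         num_train_epochs=1,
#         per_device_train_batch_size=8,
#         others_lr=1e-5,
#     )
#     trainer = Trainer(
#         model=model,
#         args=args,
#         train_dataset=train_data,
#         eval_dataset=eval_data,
#         data_collator=collator,  # a collator that builds the model's input dict
#     )
#     trainer.train()
#     trainer.save_model("outputs/final")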