Source code for gliner.training.trainer

"""Custom Trainer implementation with enhanced loss functions and optimizer configuration.

This module extends the Hugging Face Transformers Trainer class to support
custom loss functions (focal loss, label smoothing), flexible learning rates
for different parameter groups, and robust error handling during training.
"""

import os
import inspect
import logging
from typing import Any, Dict, List, Tuple, Union, Optional
from dataclasses import field, dataclass

import torch
import transformers
from torch import nn
from transformers.trainer import (
    get_parameter_names,
    is_sagemaker_mp_enabled,
)
from transformers.trainer_utils import set_seed

if is_sagemaker_mp_enabled():
    from transformers.trainer_pt_utils import smp_forward_backward
from torch.utils.data import Dataset, DataLoader

ALL_LAYERNORM_LAYERS = [nn.LayerNorm]

logger = logging.getLogger(__name__)


def seed_worker(_):
    """Set worker seed during DataLoader initialization.

    Helper function to ensure reproducibility by seeding each DataLoader worker
    process with a unique but deterministic seed based on PyTorch's initial seed.

    Args:
        _: Worker ID (unused, but required by DataLoader worker_init_fn signature).
    """
    worker_seed = torch.initial_seed() % 2**32
    set_seed(worker_seed)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Extended training arguments with custom loss and optimization parameters.

    Extends the standard Hugging Face TrainingArguments with additional
    parameters for focal loss, label smoothing, differential learning rates,
    and custom negative sampling strategies.

    Attributes:
        cache_dir: Directory to cache downloaded models and datasets.
        optim: Optimizer to use. Defaults to "adamw_torch".
        others_lr: Optional separate learning rate for non-encoder parameters
            (e.g., classification heads). If None, uses the main learning rate.
        others_weight_decay: Weight decay for non-encoder parameters when using
            others_lr. Defaults to 0.0.
        focal_loss_alpha: Alpha parameter for focal loss. Values < 0 disable
            focal loss weighting. Defaults to -1.
        focal_loss_gamma: Gamma (focusing parameter) for focal loss. Higher
            values increase focus on hard examples. Defaults to 0.
        focal_loss_prob_margin: Probability margin for focal loss computation.
            Defaults to 0.
        label_smoothing: Label smoothing factor. 0.0 means no smoothing.
            Defaults to 0.
        loss_reduction: Reduction method for loss ('sum', 'mean', or 'none').
            Defaults to 'sum'.
        negatives: Ratio of negative samples to use. Defaults to 1.0.
        masking: Masking strategy for training ('global' or other strategies).
            Defaults to 'global'.
    """

    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    others_lr: Optional[float] = None
    others_weight_decay: Optional[float] = 0.0
    focal_loss_alpha: Optional[float] = -1
    focal_loss_gamma: Optional[float] = 0
    focal_loss_prob_margin: Optional[float] = 0
    label_smoothing: Optional[float] = 0
    loss_reduction: Optional[str] = "sum"
    negatives: Optional[float] = 1.0
    masking: Optional[str] = "global"
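# Example (illustrative sketch): the custom fields above are passed alongside the
# standard Hugging Face TrainingArguments fields. The values below are arbitrary
# examples, not recommended defaults.
#
#     args = TrainingArguments(
#         output_dir="outputs",
#         learning_rate=5e-6,          # encoder learning rate
#         others_lr=1e-5,              # separate LR for non-encoder parameters
#         others_weight_decay=0.01,
#         focal_loss_alpha=0.75,       # >= 0 enables focal-loss weighting
#         focal_loss_gamma=2,
#         label_smoothing=0.01,
#         loss_reduction="sum",
#         num_train_epochs=3,
#         save_safetensors=True,
#     )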
class Trainer(transformers.Trainer):
    """Transformers v4/v5 compatible custom Trainer.

    - v5-safe method signatures (num_items_in_batch)
    - no hard dependency on self.use_apex
    - skips only OOM by default (other exceptions are raised so you don't
      silently get 0 loss)
    """

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # Called by HF during checkpoint saves.
        if not self.args.should_save:
            return

        output_dir = output_dir or self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)

        model_to_save = self.accelerator.unwrap_model(self.model)

        # Prefer safetensors if TrainingArguments says so.
        safe = bool(getattr(self.args, "save_safetensors", False))

        sp = getattr(model_to_save, "save_pretrained", None)
        if sp is None:
            # Last-resort fallback: behave like HF (weights only).
            torch.save(
                model_to_save.state_dict(),
                os.path.join(output_dir, "pytorch_model.bin"),
            )
            return

        sp_sig = inspect.signature(sp).parameters
        kwargs = {}
        if "safe_serialization" in sp_sig:
            kwargs["safe_serialization"] = safe
        if state_dict is not None and "state_dict" in sp_sig:
            kwargs["state_dict"] = state_dict

        model_to_save.save_pretrained(output_dir, **kwargs)

        proc = getattr(self, "processing_class", None) or getattr(self, "tokenizer", None)
        if proc is not None and hasattr(proc, "save_pretrained"):
            proc.save_pretrained(output_dir)
    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        # Make the final save consistent with checkpoint saving.
        self._save(output_dir)
    @property
    def use_apex(self) -> bool:
        return bool(getattr(self, "_use_apex", False))

    @use_apex.setter
    def use_apex(self, value: bool) -> None:
        self._use_apex = bool(value)
    def compute_loss(
        self,
        model,
        inputs,
        return_outputs: bool = False,
        num_items_in_batch: Optional[int] = None,
    ):
        # Inputs are already prepared in training_step / prediction_step.
        outputs = model(
            alpha=self.args.focal_loss_alpha,
            gamma=self.args.focal_loss_gamma,
            prob_margin=self.args.focal_loss_prob_margin,
            label_smoothing=self.args.label_smoothing,
            reduction=self.args.loss_reduction,
            negatives=self.args.negatives,
            masking=self.args.masking,
            **inputs,
        )
        loss = outputs.loss if hasattr(outputs, "loss") else outputs["loss"]
        return (loss, outputs) if return_outputs else loss
    def training_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch: Optional[int] = None,
    ) -> torch.Tensor:
        model.train()
        inputs = self._prepare_inputs(inputs)

        # Guardrail: if labels are missing, fail loudly
        # (otherwise you end up with loss=None -> silent 0).
        if "labels" not in inputs:
            raise KeyError(f"Batch has no 'labels'. Keys: {list(inputs.keys())}")

        try:
            if is_sagemaker_mp_enabled():
                loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
                return loss_mb.reduce_mean().detach().to(self.args.device)

            with self.compute_loss_context_manager():
                loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

            if loss is None:
                raise RuntimeError(
                    "Model returned loss=None (check labels / remove_unused_columns / forward)."
                )

            # Average on multi-GPU.
            if self.args.n_gpu > 1:
                loss = loss.mean()

            # Match upstream Trainer behavior: scale loss for gradient accumulation before backward.
            if self.args.gradient_accumulation_steps > 1 and self.deepspeed is None:
                loss = loss / self.args.gradient_accumulation_steps

            self.accelerator.backward(loss)
            return loss.detach()

        except torch.cuda.OutOfMemoryError as e:
            logger.warning("Skipping batch due to CUDA OOM: %s", e)
            model.zero_grad(set_to_none=True)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return torch.zeros((), device=self.args.device)

        except RuntimeError as e:
            # Some OOMs come as RuntimeError("CUDA out of memory...").
            if "out of memory" in str(e).lower():
                logger.warning("Skipping batch due to OOM RuntimeError: %s", e)
                model.zero_grad(set_to_none=True)
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                return torch.zeros((), device=self.args.device)
            # Anything else: re-raise, so you don't silently train on zero loss.
            raise
    def create_optimizer(self):
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()

        opt_model = self.model

        if self.optimizer is not None:
            return self.optimizer

        decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
        decay_parameters = [name for name in decay_parameters if "bias" not in name]

        if self.args.others_lr is not None:
            # Encoder parameters (token_rep_layer) keep the main learning rate and
            # weight decay; all other parameters use others_lr / others_weight_decay.
            encoder_parameters = [
                name for name, _ in opt_model.named_parameters() if "token_rep_layer" in name
            ]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n in decay_parameters and n not in encoder_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.others_weight_decay,
                    "lr": self.args.others_lr,
                },
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n not in decay_parameters and n not in encoder_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                    "lr": self.args.others_lr,
                },
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n in decay_parameters and n in encoder_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n not in decay_parameters and n in encoder_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]
        else:
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]

        # Works across v4/v5.
        if hasattr(transformers.Trainer, "get_optimizer_cls_and_kwargs"):
            optimizer_cls, optimizer_kwargs = transformers.Trainer.get_optimizer_cls_and_kwargs(self.args)
        else:
            # Very old fallback.
            optimizer_cls, optimizer_kwargs = super().get_optimizer_cls_and_kwargs(self.args)

        self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
        return self.optimizer
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        model.eval()
        inputs = self._prepare_inputs(inputs)

        with torch.no_grad():
            loss, outputs = self.compute_loss(model, inputs, return_outputs=True)

        logits = getattr(outputs, "logits", None)
        labels = inputs.get("labels", None)

        if prediction_loss_only:
            return (loss, None, None)
        return (loss, logits, labels)
    def get_train_dataloader(self) -> DataLoader:
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator

        dataloader_params = {
            "batch_size": self._train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
    def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
        if (
            hasattr(self, "_eval_dataloaders")
            and dataloader_key in self._eval_dataloaders
            and self.args.dataloader_persistent_workers
        ):
            return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])

        eval_dataset = (
            self.eval_dataset[eval_dataset]
            if isinstance(eval_dataset, str)
            else eval_dataset
            if eval_dataset is not None
            else self.eval_dataset
        )

        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
        if self.args.dataloader_persistent_workers:
            if hasattr(self, "_eval_dataloaders"):
                self._eval_dataloaders[dataloader_key] = eval_dataloader
            else:
                self._eval_dataloaders = {dataloader_key: eval_dataloader}

        return self.accelerator.prepare(eval_dataloader)
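# Usage sketch (hypothetical: `my_model`, `collator`, and the datasets stand in for
# whatever GLiNER model, data collator, and span datasets you train with; only the
# Trainer/TrainingArguments API shown here comes from this module):
#
#     trainer = Trainer(
#         model=my_model,
#         args=args,                       # the TrainingArguments example above
#         train_dataset=train_dataset,
#         eval_dataset=eval_dataset,
#         data_collator=collator,
#     )
#     trainer.train()
#     trainer.save_model("outputs/final")  # routes through the safetensors-aware _save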