import warnings
from typing import List, Union, Literal
from collections import defaultdict
import numpy as np
[docs]
class UndefinedMetricWarning(UserWarning):
    """Warning category used when a metric is ill-defined.

    Passed to ``warnings.warn`` by the metric helpers in this module, so
    callers can filter it separately from other ``UserWarning`` instances.
    """
def _prf_divide(
    numerator: np.ndarray,
    denominator: np.ndarray,
    metric: Literal["precision", "recall", "f-score"],
    modifier: str,
    average: str,
    warn_for: List[str],
    zero_division: Union[str, int] = "warn",
) -> np.ndarray:
    """Divide element-wise, replacing results where the denominator is zero.

    Entries whose denominator equals zero are set to 0.0 when
    ``zero_division`` is ``"warn"`` or ``0`` and to 1.0 otherwise. A warning
    is emitted through ``_warn_prf`` only when at least one denominator is
    zero, ``zero_division`` is ``"warn"`` and ``metric`` is in ``warn_for``.

    Args:
        numerator: Array of numerators.
        denominator: Array of denominators, same shape as ``numerator``.
        metric: Name of the metric being computed; used in the warning text.
        modifier: Forwarded to ``_warn_prf`` for the warning text.
        average: Forwarded to ``_warn_prf`` for the warning text.
        warn_for: Metric names for which a warning may be emitted.
        zero_division: ``"warn"``, ``0`` or another value (treated as 1).

    Returns:
        The array of quotients with zero-denominator entries replaced.
    """
    # Compute the mask once; it drives both the replacement and the warning.
    zero_mask = np.asarray(denominator) == 0

    # Silence NumPy's own divide-by-zero / 0/0 warnings; those entries are
    # overwritten just below.
    with np.errstate(divide="ignore", invalid="ignore"):
        result = np.true_divide(numerator, denominator)
    result[zero_mask] = 0.0 if zero_division in ["warn", 0] else 1.0

    # Reduce the mask explicitly: taking the truth value of a multi-element
    # array raises ValueError.
    if not zero_mask.any():
        return result
    if zero_division != "warn" or metric not in warn_for:
        return result

    if metric == "f-score":
        # Avoid the redundant "F-Score and F-score are" wording.
        msg_start = "F-score is"
    elif "f-score" in warn_for:
        msg_start = f"{metric.title()} and F-score are"
    else:
        msg_start = f"{metric.title()} is"

    _warn_prf(
        average=average,
        modifier=modifier,
        msg_start=msg_start,
        result_size=len(result),
    )
    return result
def _warn_prf(average: str, modifier: str, msg_start: str, result_size: int):
    """Emit an ``UndefinedMetricWarning`` describing an ill-defined metric.

    Args:
        average: Averaging mode; ``"samples"`` swaps which of "label" and
            "sample" is used for each slot of the message.
        modifier: Word placed before the axis name in the message.
        msg_start: Leading text of the warning.
        result_size: Size of the computed result; ``1`` selects the singular
            wording, anything else the plural wording.
    """
    if average == "samples":
        axis0, axis1 = "label", "sample"
    else:
        axis0, axis1 = "sample", "label"

    if result_size == 1:
        body = f"{msg_start} ill-defined and being set to 0.0 due to no {modifier} {axis0}."
    else:
        body = f"{msg_start} ill-defined and being set to 0.0 in {axis1}s with no {modifier} {axis0}s."

    hint = " Use `zero_division` parameter to control this behavior."
    warnings.warn(body + hint, UndefinedMetricWarning, stacklevel=3)
[docs]
def flatten_for_eval(y_true, y_pred):
    """Flatten per-sample item lists, tagging each item with its sample index.

    ``y_true`` and ``y_pred`` are walked in lockstep. Every item of the
    sample at position ``i`` is unpacked into a new list with ``i`` appended
    as its last element.

    Returns:
        A ``(all_true, all_pred)`` tuple of the two flattened lists.
    """
    flat_true, flat_pred = [], []
    for sample_idx, (gold_items, pred_items) in enumerate(zip(y_true, y_pred)):
        flat_true += [list(item) + [sample_idx] for item in gold_items]
        flat_pred += [list(item) + [sample_idx] for item in pred_items]
    return flat_true, flat_pred