Shortcuts

collie.metrics.classify_f1_pre_rec_metric 源代码

from typing import List, Dict
from collections import Counter

import torch

from collie.log import logger
from collie.metrics.base import BaseMetric
from collie.utils.seq_len_to_mask import seq_len_to_mask

__all__ = ['ClassifyFPreRecMetric']


def _compute_f_pre_rec(beta_square, tp, fn, fp):
    r"""

    :param tp: int, true positive
    :param fn: int, false negative
    :param fp: int, false positive
    :return: (f, pre, rec)
    """
    pre = tp / (fp + tp + 1e-13)
    rec = tp / (fn + tp + 1e-13)
    f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)

    return f, pre, rec


class ClassifyFPreRecMetric(BaseMetric):
    r"""``Metric`` that computes the **F-score** of classification results.

    :param gather_result: whether the inputs of all processes should be
        gathered automatically before being passed to :meth:`update`.
    :param tag_vocab: label vocabulary mapping a label name to its index
        (a ``dict``; any other iterable of ``(name, index)`` pairs is also
        accepted). Defaults to ``None``, in which case the numeric label is
        used as the label name in the per-label results.
    :param only_gross: whether to report only the overall ``f``, ``pre`` and
        ``rec``. If ``False``, the per-label ``f``, ``pre`` and ``rec`` are
        returned as well.
    :param f_type: ``'micro'`` or ``'macro'``.

        * ``micro``: sum TP, FN and FP over all labels first, then compute
          f, precision and recall from the totals;
        * ``macro``: compute f, precision and recall for every label
          separately, then average them with equal weight per label.
    :param beta: the ``beta`` of the **f_beta** score, commonly ``0.5``, ``1``
        or ``2``. With ``0.5`` precision weighs more than recall, with ``1``
        both weigh the same, with ``2`` recall weighs more than precision.
        The score is computed as:

        .. math::

            f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}
    """

    def __init__(self,
                 gather_result: bool = False,
                 tag_vocab=None,
                 only_gross: bool = True,
                 f_type='micro',
                 beta=1) -> None:
        super().__init__(gather_result)
        self.f_type = f_type
        self.beta = beta
        self.beta_square = self.beta**2
        self.only_gross = only_gross
        self.tag_vocab = tag_vocab
        if self.tag_vocab is not None:
            # Iterating a plain dict yields keys only, so the (name, index)
            # pairs have to come from ``items()``. Other vocabulary objects
            # are still iterated directly, as before.
            if isinstance(self.tag_vocab, dict):
                pairs = self.tag_vocab.items()
            else:
                pairs = self.tag_vocab
            # index -> name, used to label the per-tag results.
            self.tag_vocab_revert = {index: name for name, index in pairs}
        else:
            self.tag_vocab_revert = None
        # Per-label running counts, keyed by label index.
        self._tp: Counter = Counter()
        self._fp: Counter = Counter()
        self._fn: Counter = Counter()

    def reset(self):
        """Reset the accumulated ``tp``, ``fp`` and ``fn`` counts."""
        self._tp.clear()
        self._fp.clear()
        self._fn.clear()

    def get_metric(self) -> Dict:
        r"""Compute the final result from the counts accumulated by
        :meth:`update`.

        :return: a dict with the keys ``"f"``, ``"pre"`` and ``"rec"`` when
            ``f_type`` is ``'micro'`` or ``'macro'``. When ``only_gross`` is
            ``False`` it additionally holds ``"f-{tag}"``, ``"pre-{tag}"`` and
            ``"rec-{tag}"`` for every label seen. All values are rounded to
            6 decimal places.
        """
        evaluate_result = {}
        if not self.only_gross or self.f_type == 'macro':
            tags = set(self._fn) | set(self._fp) | set(self._tp)
            f_sum = 0
            pre_sum = 0
            rec_sum = 0
            for tag in tags:
                if self.tag_vocab is not None:
                    # Fall back to the raw index for labels that are missing
                    # from the vocabulary (e.g. an unexpected prediction).
                    tag_name = self.tag_vocab_revert.get(tag, tag)
                else:
                    tag_name = int(tag)
                f, pre, rec = _compute_f_pre_rec(
                    self.beta_square, self._tp[tag], self._fn[tag],
                    self._fp[tag])
                f_sum += f
                pre_sum += pre
                rec_sum += rec
                # ``tag != ''`` guards against an empty tag.
                if not self.only_gross and tag != '':
                    evaluate_result['f-{}'.format(tag_name)] = f
                    evaluate_result['pre-{}'.format(tag_name)] = pre
                    evaluate_result['rec-{}'.format(tag_name)] = rec

            if self.f_type == 'macro':
                # Nothing accumulated yet: report zeros instead of dividing
                # by zero.
                n_tags = max(len(tags), 1)
                evaluate_result['f'] = f_sum / n_tags
                evaluate_result['pre'] = pre_sum / n_tags
                evaluate_result['rec'] = rec_sum / n_tags

        if self.f_type == 'micro':
            f, pre, rec = _compute_f_pre_rec(
                self.beta_square,
                sum(self._tp.values()),
                sum(self._fn.values()),
                sum(self._fp.values()))
            evaluate_result['f'] = f
            evaluate_result['pre'] = pre
            evaluate_result['rec'] = rec

        return {key: round(value, 6) for key, value in evaluate_result.items()}

    def update(self, result: Dict):
        r"""Accumulate the counts for one batch of predictions.

        :param result: a dict whose keys include at least ``"pred"`` and
            ``"target"``:

            * pred - predicted tensor of shape ``torch.Size([B,])``,
              ``torch.Size([B, n_classes])``, ``torch.Size([B, max_len])`` or
              ``torch.Size([B, max_len, n_classes])``
            * target - ground-truth tensor of shape ``torch.Size([B,])`` or
              ``torch.Size([B, max_len])``
            * seq_len - optional sequence lengths, ``None`` or a tensor of
              shape ``torch.Size([B])``; used to mask out padding when
              ``target`` has more than one dimension.
        :raises RuntimeError: if the shapes of ``pred`` and ``target`` are
            incompatible.
        """
        assert "pred" in result and "target" in result, "pred and target must in result, but they not."
        pred = result['pred']
        target = result['target']

        # Inputs are expected to be tensors; a list of tensors (presumably
        # produced when results are gathered manually under DDP) is stacked
        # into a single tensor first.
        if isinstance(pred, list):
            pred = torch.stack(pred, dim=0)
        if isinstance(target, list):
            target = torch.stack(target, dim=0)

        seq_len = result.get('seq_len')
        if seq_len is not None and target.ndim > 1:
            max_len = target.shape[-1]
            masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
        else:
            masks = torch.ones_like(target)
        masks = masks.eq(1)

        if pred.dim() == target.dim():
            if pred.numel() != target.numel():
                raise RuntimeError(
                    'when pred have same dimensions with target, they should '
                    'have same element numbers. \n'
                    f'while target have element numbers:{target.numel()}, '
                    f'pred have element numbers: {pred.numel()}')
        elif pred.dim() == target.dim() + 1:
            # The trailing dimension holds the class scores.
            pred = pred.argmax(dim=-1)
            if seq_len is None and target.dim() > 1:
                logger.warning_once(
                    'You are not passing `seq_len` to exclude pad when '
                    'calculate accuracy.')
        else:
            raise RuntimeError(
                f'when pred have '
                f'size:{pred.shape}, target should have size: {pred.shape} or '
                f'{pred.shape[:-1]}, got {target.shape}.')

        target = target.masked_select(masks)
        pred = pred.masked_select(masks)

        # Visit every label that occurs in the targets OR the predictions:
        # a label that is predicted but absent from this batch's targets
        # still contributes false positives.
        labels = set(target.tolist()) | set(pred.tolist())
        for label in labels:
            pred_hit = pred == label
            target_hit = target == label
            self._tp[label] += (pred_hit & target_hit).sum().item()
            self._fp[label] += (pred_hit & ~target_hit).sum().item()
            self._fn[label] += (~pred_hit & target_hit).sum().item()