Shortcuts

collie.metrics.base 源代码

from abc import ABC, abstractmethod
import torch.distributed as dist
from typing import Any, Dict, List, Optional, Sequence
from functools import reduce

import torch

from collie.utils import apply_to_collection

[文档]class BaseMetric(ABC): """ **Metric** 的基类。 :param gather_result: 在计算 metric 的时候是否自动将各个进程上的输入进行聚合后再输入到 update 之中。 """ def __init__(self, gather_result: bool=False) -> None: self.gather_result = gather_result
[文档] def construct(self, trainer): """ 将 trainer 传入到 metric 中以便于 gather 时候使用 """ self.trainer = trainer
[文档] def reset(self): r""" 用来重置 init 中定义的值。在调用 get_metric 方法后会自动调用一次该方法 """ pass
@abstractmethod def get_metric(self) -> Optional[Dict]: raise NotImplementedError()
[文档] @abstractmethod def update(self, result: Dict): r""" :param result: 经过 gather 后的输入。一般为如下格式的字典:: { 'logits': [logit1, logit2, ..., logit_dp_size], 'labels': [label1, label2, ..., label_dp_size] } 其中 ``dp_size`` 为 并行的卡数量 """ raise NotImplementedError
[文档] def gather(self, result: Dict[str, torch.Tensor]) -> Dict[str, List]: r""" 将不同进程上的 result 数据聚合在一起,使用了 DDP 情况。 :param result: :class `Trainer` 中 eval_fn 返回的结果。类型为 Dict[str, torch.Tensor]。 例如:: result = {'logits': logit, 'labels': label} :return: 经过 gather 后的结果。类型为 Dict[str, torch.Tensor]。 当 ``dp_size`` 不为 1 时 (即开启了数据并行的情况下), 会把不同 dp 进程的 ``result`` 按照第一个维度进行拼接。 """ if self.trainer.config.dp_size > 1: group = self.trainer.engine.mpu.get_data_parallel_group() for key in result.keys(): if isinstance(result[key], torch.Tensor): gather_list = [torch.zeros_like(result[key]).to(result[key].dtype).to(result[key].device) for _ in range(self.trainer.config.dp_size)] dist.all_gather(gather_list, result[key], group=group) result[key] = torch.cat(gather_list, dim=0) elif isinstance(result[key], Sequence): gather_list = [None for _ in range(self.trainer.config.dp_size)] dist.all_gather_object(gather_list, result[key], group=group) result[key] = reduce(lambda x, y: list(x) + list(y), gather_list) return result