Shortcuts

BaseMetric

class BaseMetric(gather_result=False)[源代码]

Metric 的基类。

参数:

gather_result (bool, default: False) -- 在计算 metric 的时候是否自动将各个进程上的输入进行聚合后再输入到 update 之中。

construct(trainer)[源代码]

将 trainer 传入到 metric 中以便于 gather 时候使用

reset()[源代码]

用来重置 init 中定义的值。在调用 get_metric 方法后会自动调用一次该方法

abstract update(result)[源代码]
参数:

result (Dict) --

经过 gather 后的输入。一般为如下格式的字典:

{
    'logits': [logit1, logit2, ..., logit_dp_size],
    'labels': [label1, label2, ..., label_dp_size]
}

其中 dp_size 为 并行的卡数量

gather(result)[源代码]

将不同进程上的 result 数据聚合在一起,使用了 DDP 情况。

参数:

result (Dict[str, Tensor]) --

:class Trainer 中 eval_fn 返回的结果。类型为 Dict[str, torch.Tensor]。 例如:

result = {'logits': logit, 'labels': label}

返回类型:

Dict[str, List]

返回:

经过 gather 后的结果。类型为 Dict[str, torch.Tensor]。

dp_size 不为 1 时 (即开启了数据并行的情况下), 会把不同 dp 进程的 result 按照第一个维度进行拼接。