AccuracyMetric¶
- class AccuracyMetric(gather_result=False)[源代码]¶
计算准确率的 metric
- 参数:
gather_result (
bool, default:False) -- 在计算 metric 的时候是否自动将各个进程上的输入进行聚合后再输入到 update 之中。
- get_metric()[源代码]¶
get_metric()函数将根据update()函数累计的评价指标统计量来计算最终的评价结果。- 返回类型:
Dict- 返回:
字典形式的评测结果,例如:
{"acc": float, 'total': float, 'correct': float}
- update(result)[源代码]¶
update()函数将针对一个批次的预测结果做评价指标的累计。- 参数:
result (
Dict) --类型为 Dict 且 keys 至少包含["pred", "target"]
pred - 预测的 tensor, tensor 的形状可以是
torch.Size([B,])、torch.Size([B, n_classes])、torch.Size([B, max_len])或torch.Size([B, max_len, n_classes])target - 真实值的 tensor, tensor 的形状可以是
torch.Size([B,])、torch.Size([B, max_len])或torch.Size([B, max_len])seq_len - 序列长度标记, 标记的形状可以是
None, 或者torch.Size([B])。 如果 mask 也被传进来的话seq_len会被忽略