EvaluatorForClassfication¶
- class EvaluatorForClassfication(only_latin=True, max_new_tokens=3, *args, **kwargs)[源代码]¶
专为分类任务提供的检验器 :type only_latin:
bool, default:True:param only_latin: 当使用 helm 风格(即基于生成的)分类检验时,是否去除拉丁字母和数字和空格以外的字符。 :type max_new_tokens:int, default:3:param max_new_tokens: 当使用 helm 风格(即基于生成的)分类检验时,生成的 token 数量。一般为了提高准确率可以多生成几个 token。
- static eval_fn(evaluator, batch)[源代码]¶
一次验证的基本单元
- 参数:
evaluator -- 训练器
batch (
Dict) --一个 batch 的数据,类型为长度为
Dict,格式为:- {
- "input_ids": [
torch.tensor([[1, 100, 100, 2]]), torch.tensor([[1, 100, 100, 2]]), torch.tensor([[1, 100, 100, 2]])
], "target": torch.tensor([[0]])
}
- 返回类型:
Any- 返回:
一次验证的结果,为 Dict 类型,该结果会被传入 metric 的 update 方法中