Shortcuts

EvaluatorForClassfication

class EvaluatorForClassfication(only_latin=True, max_new_tokens=3, *args, **kwargs)[源代码]

专为分类任务提供的检验器 :type only_latin: bool, default: True :param only_latin: 当使用 helm 风格(即基于生成的)分类检验时,是否去除拉丁字母和数字和空格以外的字符。 :type max_new_tokens: int, default: 3 :param max_new_tokens: 当使用 helm 风格(即基于生成的)分类检验时,生成的 token 数量。

一般为了提高准确率可以多生成几个 token。

static eval_fn(evaluator, batch)[源代码]

一次验证的基本单元

参数:
  • evaluator -- 训练器

  • batch (Dict) --

    一个 batch 的数据,类型为长度为 Dict,格式为:

    
    
    {
    "input_ids": [

    torch.tensor([[1, 100, 100, 2]]), torch.tensor([[1, 100, 100, 2]]), torch.tensor([[1, 100, 100, 2]])

    ], "target": torch.tensor([[0]])

    }

返回类型:

Any

返回:

一次验证的结果,为 Dict 类型,该结果会被传入 metricupdate 方法中