PipelineGenerationMixin¶
- class PipelineGenerationMixin[源代码]¶
重写
transformers提供的GenerationMixin以支持 CoLLie 中的流水线 模型。- 参数:
engine -- DeepSpeedEngine 实例,可由
setup_ds_engine()函数生成
- generate_forward(input_ids=None, attention_mask=None, inputs_embeds=None, position_ids=None, use_cache=True, past_key_values=None, **kwargs)[源代码]¶
进行迭代的流水线模型的前向传播(生成)
- 返回类型:
Tensor
- train_forward(labels, input_ids=None, attention_mask=None, inputs_embeds=None, position_ids=None, past_key_values=None, **kwargs)[源代码]¶
进行一次流水线模型的正反向传播
- 返回类型:
Tensor