Lomo¶
- class Lomo(model, lr=0.001, clip_grad_norm=None, clip_grad_value=None, weight_decay=0.0, loss_scale_args={})[源代码]¶
一个自定义的优化器类Lomo,用于在分布式训练中的梯度更新。
该类实现两个梯度更新函数
fuse_update()和fuse_update_zero3(),分别用于非ZeRO和ZeRO模式下的梯度更新。- 参数:
model -- 待优化的模型
lr (default:
0.001) -- 学习率,默认值为1e-3clip_grad_norm (default:
None) --梯度裁剪的范数阈值
备注
clip_grad_norm须为正数
zero3_enabled -- 是否开启了 zero3
clip_grad_value (default:
None) -- 梯度裁剪的值域阈值loss_scale_args (default:
{}) -- 用于初始化DynamicLossScaler的参数