Shortcuts

Lomo

class Lomo(model, lr=0.001, clip_grad_norm=None, clip_grad_value=None, weight_decay=0.0, loss_scale_args={})[源代码]

一个自定义的优化器类Lomo,用于在分布式训练中的梯度更新。

该类实现两个梯度更新函数 fuse_update()fuse_update_zero3(),分别用于非ZeRO和ZeRO模式下的梯度更新。

参数:
  • model -- 待优化的模型

  • lr (default: 0.001) -- 学习率,默认值为1e-3

  • clip_grad_norm (default: None) --

    梯度裁剪的范数阈值

    备注

    clip_grad_norm须为正数

  • zero3_enabled -- 是否开启了 zero3

  • clip_grad_value (default: None) -- 梯度裁剪的值域阈值

  • loss_scale_args (default: {}) -- 用于初始化 DynamicLossScaler 的参数

fuse_update()[源代码]

在非ZeRO模式下更新模型参数的梯度。

返回:

func,一个闭包函数,用于更新模型参数的梯度

fuse_update_zero3()[源代码]

在ZeRO模式下更新模型参数的梯度。

返回:

func,一个闭包函数,用于更新模型参数的梯度。

fused_backward(loss, lr)[源代码]

执行一步反向传播并更新模型的梯度。

参数:
  • loss -- 模型的loss值

  • lr -- 学习率

grad_norm(loss)[源代码]

计算梯度的范数。

参数:

loss -- 模型的loss值