Shortcuts

CoLLiE.utils

Dist Utils

Env

CoLLiE 的环境变量,可以从中获取各种并行的 world_size 和 rank。

setup_distribution

设置分布式环境。

setup_ds_engine

启动 DeepSpeed 引擎。

set_seed

设置随机数种子。

broadcast_tensor

src 广播 tensor

zero3_load_state_dict

用于加载 ZeRO stage 3 的模型参数。

is_zero3_enabled

判断是否启用了 ZeRO stage 3。

Utils

find_tensors

打印出垃圾回收区的所有张量。

apply_to_collection

递归地对 data 中的元素执行函数 function,且仅在满足元素为 dtype 时执行。

dict_as_params

使用字典作为参数输入的辅助函数

concat_tensor

拼接 tensor_list 中的张量,并且在拼接时将张量转移到 cpu 上来避免显存的增加。

progress

包装了 rich 进度条的类。

Data Provider

BaseProvider

BaseProvider 为异步数据提供器的基类,提供了一些基本的接口

GradioProvider

基于 Gradio 的异步数据提供器,会在本地启动一个 Gradio 服务,将用户输入的文本作为模型的输入

DashProvider

基于 Dash 的异步数据提供器,会在本地启动一个 Dash 服务,将用户输入的文本作为模型的输入

Monitor

BaseMonitor

BaseMonitor是一个基础的监控器类,用于记录模型训练过程中的统计信息

StepTimeMonitor

用来记录每个step的时间

TGSMonitor

用来记录每秒每张 GPU 可训练的 token 数 (token / s / GPU)

MemoryMonitor

用来记录每个step的内存占用

LossMonitor

用来记录每个step的loss

EvalMonitor

用来记录每个step的eval结果,仅支持 intfloat 类型的结果

LRMonitor

用来记录每个step的learning rate

Padder

ColliePadder

CoLLie 中的通用 collate_fn 构造器

其它版本