import os
import functools
import inspect
from inspect import Parameter
from accelerate.utils.modeling import find_tied_parameters
import dataclasses
from types import MethodType
from typing import (Callable, Any, Dict, Union, Mapping, Sequence, Tuple,
Optional, List, AnyStr)
from collections import defaultdict, OrderedDict
from operator import length_hint
from copy import deepcopy
import torch
from collie.log.logger import logger
from .rich_progress import f_rich_progress
# Names exported by a star-import of this module.
# NOTE(review): "initization_mapping" is listed below, but no definition with
# that name is visible in this part of the file. Confirm it is defined
# elsewhere; a name in ``__all__`` that the module does not define makes
# ``from <module> import *`` raise AttributeError.
__all__ = ["find_tensors", "progress", "dictToObj", "apply_to_collection",
"dict_as_params", "initization_mapping", "is_static_method",
"auto_param_call", "get_keys_to_not_convert", "concat_tensor"]
def find_tensors():
    """
    Print every tensor currently tracked by the garbage collector.

    For each tracked object that either is a tensor or exposes a tensor
    through a ``data`` attribute, its type, size, dtype and device are
    printed on one line.

    Adopted from https://discuss.pytorch.org/t/how-to-debug-causes-of-gpu-memory-leaks/6741/3
    """
    import gc

    import torch

    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                print(type(obj), obj.size(), obj.dtype, obj.device)
        except Exception:
            # Best-effort debugging aid: arbitrary objects may raise while
            # being probed. Skip them, but let KeyboardInterrupt and
            # SystemExit propagate (a bare ``except`` would swallow those).
            continue
def concat_tensor(tensor_list, dim=0):
    """
    Concatenate the tensors in ``tensor_list`` on the CPU.

    Every tensor is detached, moved to the CPU and cloned before the
    concatenation, so the operation does not increase device memory usage.
    The list passed in is emptied in place.

    :param tensor_list: list of tensors to concatenate; it is cleared.
    :param dim: dimension along which the tensors are concatenated.
    :return: the concatenated tensor, located on the CPU
    """
    cpu_copies = []
    for tensor in tensor_list:
        cpu_copies.append(tensor.detach().cpu().clone())
    # Empty the caller's list so it no longer references the source tensors.
    tensor_list.clear()
    return torch.cat(cpu_copies, dim=dim)
def stack_tensor(tensor_list, dim=0):
    """
    Stack the tensors in ``tensor_list`` on the CPU.

    Every tensor is detached, moved to the CPU and cloned before stacking,
    so the operation does not increase device memory usage. The list passed
    in is emptied in place.

    :param tensor_list: list of tensors to stack; it is cleared.
    :param dim: index of the new dimension created by the stack.
    :return: the stacked tensor, located on the CPU
    """
    cpu_copies = []
    for tensor in tensor_list:
        cpu_copies.append(tensor.detach().cpu().clone())
    # Empty the caller's list so it no longer references the source tensors.
    tensor_list.clear()
    return torch.stack(cpu_copies, dim=dim)
class progress:
    """Wrapper class around the ``rich`` progress bar.

    .. code-block::

        for batch in progress(dataloader):
            # do something

    .. code-block::

        with progress(dataloader) as bar:
            for batch in bar:
                # do something
                bar.set_postfix(Loss=1.0)

    .. code-block::

        bar = progress(dataloader)
        for batch in bar:
            # do something
            bar.set_postfix(Loss=1.0)

    :param sequence: the sequence to traverse; must be an iterable object.
    :param desc: description shown at the far left of the progress bar.
    :param total: total number of items. If ``None`` it is derived from
        ``operator.length_hint(sequence)``.
    :param completed: progress that is already completed.
    :param upgrade_period: time interval between progress bar updates.
    :param disable: whether the progress bar is hidden.
    :param post_desc: supplementary description shown at the far right.
    """

    def __init__(self, sequence, desc="Workin on...", total=None, completed=0,
                 upgrade_period=0.1, disable=False, post_desc: str = ""):
        self.bar = f_rich_progress
        self.bar.set_disable(disable)
        self.sequence = sequence
        self.total = float(length_hint(sequence)) if total is None else total
        self.completed = completed
        self.task_id = self.bar.add_task(
            desc, upgrade_period=upgrade_period, completed=completed,
            post_desc=post_desc, visible=not disable, total=self.total
        )

    def __iter__(self):
        yield from self.bar.track(
            self.sequence, task_id=self.task_id, total=self.total)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Nothing to release here; the task is destroyed in ``__del__``.
        ...

    def __del__(self):
        # ``__init__`` may have failed before ``task_id`` (or ``bar``) was
        # assigned; in that case there is no task to destroy and touching the
        # missing attribute would raise inside the finalizer.
        try:
            bar = self.bar
            task_id = self.task_id
        except AttributeError:
            return
        bar.destroy_task(task_id)

    def set_post_desc(self, post_desc: str):
        """
        Set the supplementary description at the far right of the bar.

        .. code-block::

            bar = progress(dataloader)
            bar.set_post_desc("Loss=1.0")

        :param post_desc: the text to display.
        """
        self.bar.update(self.task_id, post_desc=post_desc, advance=0)

    def set_postfix(self, **kwargs):
        """
        Set the supplementary description at the far right of the bar.

        Every key/value pair passed in is rendered in the format
        ``key1: value1, key2: value2, ..``.

        .. code-block::

            bar = progress(dataloader)
            bar.set_postfix(Loss=1.0, Batch=1)
        """
        post_desc = ", ".join(f"{k}: {v}" for k, v in kwargs.items())
        self.set_post_desc(post_desc)

    def set_description(self, desc):
        """
        Set the description at the far left of the progress bar.

        :param desc: the text to display.
        """
        self.update(desc=desc)

    def reset(
        self, desc: Optional[str] = None, total: Optional[float] = None,
        completed: int = 0, visible: Optional[bool] = None,
        post_desc: Optional[str] = None,
    ):
        """
        Reset the progress bar; this can also reset its displayed time.

        :param desc: description shown at the far left of the progress bar.
        :param total: total number of items. ``None`` leaves it unchanged.
        :param completed: progress that is already completed.
        :param visible: whether the progress bar is visible.
        :param post_desc: supplementary description shown at the far right.
            When ``None`` the argument is not forwarded to the underlying bar.
        """
        options = dict(description=desc, total=total, completed=completed,
                       visible=visible)
        if post_desc is not None:
            options["post_desc"] = post_desc
        self.bar.reset(self.task_id, **options)

    def update(
        self, desc: Optional[str] = None, total: Optional[float] = None,
        completed: Optional[float] = None, advance: Optional[float] = None,
        visible: Optional[bool] = None, refresh: bool = False,
        post_desc: Optional[str] = None,
    ) -> None:
        """
        Update the progress bar, allowing fine-grained changes of its content.

        :param desc: description shown at the far left of the progress bar.
        :param total: total number of items. ``None`` leaves it unchanged.
        :param completed: progress that is already completed.
        :param advance: amount of progress added by this update.
        :param visible: whether the progress bar is visible.
        :param refresh: whether to force a refresh of the progress bar.
        :param post_desc: supplementary description shown at the far right.
            When ``None`` the argument is not forwarded to the underlying bar.
        """
        options = dict(total=total, completed=completed, advance=advance,
                       description=desc, visible=visible, refresh=refresh)
        if post_desc is not None:
            options["post_desc"] = post_desc
        self.bar.update(self.task_id, **options)
def _split_past_key_values(past_key_values, micro_batch_size, micro_batch_num):
if micro_batch_num == 1:
return (past_key_values,)
past_kv_split = [() for _ in range(micro_batch_num)]
for layer_past in past_key_values:
if isinstance(layer_past, (tuple, list)):
assert len(layer_past) == 2
else:
# prefix tuning 的 past key values 是个 tensor
assert isinstance(layer_past, torch.Tensor)
assert layer_past.shape[0] == 2
key_split = torch.split(layer_past[0], micro_batch_size)
value_split = torch.split(layer_past[1], micro_batch_size)
assert len(key_split) == micro_batch_num, len(key_split)
assert len(value_split) == micro_batch_num, len(value_split)
for i in range(micro_batch_num):
past_kv_split[i] += ((key_split[i], value_split[i]),)
return tuple(past_kv_split)
def _split_dict(inputs, micro_batch_size, micro_batch_num):
inputs_split = {}
for key in list(inputs.keys()):
dim = 0
if key == "past_key_values":
dim = 2
if isinstance(inputs[key], torch.Tensor):
inputs_split[key] = torch.split(inputs[key], micro_batch_size, dim)
elif isinstance(inputs[key], Sequence):
inputs_split[key] = [torch.split(input_, micro_batch_size, dim) for input_ in inputs[key]]
inputs_split[key] = list(zip(*inputs_split[key]))
inputs_split = [{key: value[i] for key, value in inputs_split.items()} for i in range(micro_batch_num)]
return inputs_split
def _split_batch(batch, micro_batch_size, micro_batch_num):
"""
将 ``batch`` 划分为 ``micro_batch_num`` 个 ``micro_batch_size`` 大小。
仅在流水线情况的训练和验证中用到。
:param batch: tuple from dataloader
:param micro_batch_size:
:param micro_batch_num:
:return: tuple
"""
if isinstance(batch, torch.Tensor):
batch_split = torch.split(batch, micro_batch_size)
elif isinstance(batch, dict):
batch_split = _split_dict(batch, micro_batch_size, micro_batch_num)
else:
raise NotImplementedError(f"Invalid type of batch: {type(batch)}"
"Must be Tensor or dict.")
assert len(batch_split) == micro_batch_num, len(batch_split)
return batch_split
class _MissingAttributeError(AttributeError, KeyError):
    """Raised by ``Dict`` when an attribute has no matching key.

    It is an ``AttributeError`` so that ``hasattr``, ``getattr`` with a
    default and ``copy.deepcopy`` work, and also a ``KeyError`` so that code
    written against the previous behaviour keeps catching it.
    """


class Dict(dict):
    """A ``dict`` whose items can also be read and written as attributes.

    ``obj.name = value`` stores ``obj["name"]`` and ``obj.name`` reads it
    back.
    """

    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails. A missing key must
        # surface as an AttributeError; a bare KeyError is not recognised by
        # ``hasattr`` / ``getattr(obj, name, default)`` and therefore breaks
        # the copy protocol lookups (``__deepcopy__`` and friends).
        try:
            return self[name]
        except KeyError:
            raise _MissingAttributeError(name) from None
def dictToObj(dictObj):
    """
    Recursively convert a ``dict`` into an attribute-accessible ``Dict``.

    Nested dict values are converted as well; any value that is not a dict
    (including a list that contains dicts) is kept as it is.

    :param dictObj: the object to convert.
    :return: a ``Dict`` when ``dictObj`` is a dict, otherwise ``dictObj``
        itself.
    """
    if isinstance(dictObj, dict):
        converted = Dict()
        for key in dictObj:
            converted[key] = dictToObj(dictObj[key])
        return converted
    return dictObj
def _is_namedtuple(obj: object) -> bool:
# https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8
return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
def _is_dataclass_instance(obj: object) -> bool:
# https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
def apply_to_collection(
    data: Any,
    dtype: Union[type, Any, Tuple[Union[type, Any]]],
    function: Callable,
    *args: Any,
    wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
    include_none: bool = True,
    **kwargs: Any,
) -> Any:
    """
    Recursively apply ``function`` to the elements of ``data`` whose type is
    ``dtype``.

    The implementation follows the one in `pytorch-lightning
    <https://github.com/PyTorchLightning/pytorch-lightning>`_.

    :param data: the collection (or single value) to process;
    :param dtype: ``function`` is only applied to elements of this type;
    :param function: the function applied to matching elements;
    :param args: extra positional arguments forwarded to ``function``;
    :param wrong_dtype: type(s) ``function`` is never applied to, even when
        the element is also an instance of ``dtype``;
    :param include_none: whether results equal to ``None`` are kept, default
        ``True``;
    :param kwargs: extra keyword arguments forwarded to ``function``;
    :return: the collection rebuilt from the processed elements;
    """
    def _recurse(item: Any) -> Any:
        # Same call for every nested element, with all options forwarded.
        return apply_to_collection(
            item, dtype, function, *args,
            wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
        )

    # Leaf: the element itself has to be transformed.
    if isinstance(data, dtype):
        if wrong_dtype is None or not isinstance(data, wrong_dtype):
            return function(data, *args, **kwargs)

    container_type = type(data)

    # Mappings are rebuilt with the same type and the same key order.
    if isinstance(data, Mapping):
        pairs = []
        for key, value in data.items():
            new_value = _recurse(value)
            if new_value is None and not include_none:
                continue
            pairs.append((key, new_value))
        rebuilt = OrderedDict(pairs)
        if isinstance(data, defaultdict):
            return container_type(data.default_factory, rebuilt)
        return container_type(rebuilt)

    # Namedtuples and other non-string sequences.
    namedtuple_like = _is_namedtuple(data)
    if namedtuple_like or (isinstance(data, Sequence) and not isinstance(data, str)):
        items = []
        for element in data:
            new_value = _recurse(element)
            if new_value is None and not include_none:
                continue
            items.append(new_value)
        if namedtuple_like:
            return container_type(*items)
        return container_type(items)

    if _is_dataclass_instance(data):
        # Make a deepcopy of the data, but register the field values in the
        # memo so they are not deep-copied themselves: that work would be
        # wasted on values that likely get overwritten right below.
        originals = {}
        memo = {}
        for field in dataclasses.fields(data):
            current = getattr(data, field.name)
            originals[field.name] = (current, field.init)
            memo[id(current)] = current
        result = deepcopy(data, memo=memo)
        # Apply the function to each ``init`` field.
        for name, (current, is_init) in originals.items():
            if is_init:
                new_value = _recurse(current)
                if new_value is None and not include_none:
                    # retain old value
                    new_value = getattr(data, name)
            else:
                # retain old value
                new_value = getattr(data, name)
            setattr(result, name, new_value)
        return result

    # data is neither of dtype, nor a collection
    return data
def _get_fun_msg(fn, with_fp=True)->str:
"""
获取函数的基本信息,帮助报错::
>>> print(_get_fun_msg(_get_fun_msg))
`_get_fun_msg(fn) -> str`(In file:/Users/hnyan/Desktop/projects/fastNLP/fastNLP/fastNLP/core/utils/utils.py)
:param callable fn:
:param with_fp: 是否包含函数所在的文件信息;
:return:
"""
if isinstance(fn, functools.partial):
return _get_fun_msg(fn.func)
try:
fn_name = fn.__qualname__ + str(inspect.signature(fn))
except:
fn_name = str(fn)
if with_fp:
try:
fp = '(In file:' + os.path.abspath(inspect.getfile(fn)) + ')'
except:
fp = ''
else:
fp = ''
msg = f'`{fn_name}`' + fp
return msg
def _check_valid_parameters_number(fn,
expected_params: List[str],
fn_name=None):
r"""检查一个函数是否需要 expected_params 参数(检测数量是否匹配)。除掉 self (如
果是method),给定默认值的参数等。如果匹配不上,就会进行报错。
:param fn: 需要检测的函数,可以是 method 或者 function 。
:param expected_params: 期待应该支持的参数。
:param fn_name: fn 的名字,当传入的 fn 不是 callable 的时候方便报错。
:return:
"""
if fn_name is not None:
assert callable(
fn), f'`{fn_name}` should be callable, instead of `{type(fn)}`.'
try:
args: List[Any] = []
kwargs: Dict[str, Any] = {}
name = ''
if isinstance(fn, functools.partial) and not hasattr(fn, '__name__'):
name = 'partial:'
f = fn.func
while isinstance(f, functools.partial):
name += 'partial:'
f = f.func
fn.__name__ = name + f.__name__ # type: ignore
inspect.getcallargs(fn, *args, *expected_params, **kwargs)
if name: # 如果一开始没有name的,需要给人家删除掉
delattr(fn, '__name__')
except TypeError as e:
logger.error(
f'The function:{_get_fun_msg(fn)} will be provided with '
f'parameters:{expected_params}. The following exception will '
'happen.')
raise e
def dict_as_params(input_keys: Union[str, Sequence[str]], output_keys: Union[str, Sequence[str]]):
    """Helper that lets a model take a dict as input and return a dict.

    The values stored under ``input_keys`` are taken from the input dict, in
    order, and passed to the model; the model outputs are then stored in the
    same dict under ``output_keys``. Keys that are not involved are not
    dropped.

    This is useful for modules such as ``nn.LayerNorm`` that, in pipeline
    parallelism, usually need no change to their forward pass but do need a
    different input/output structure. It avoids rewriting such modules over
    and over and fits the initialisation of a pipeline ``LayerSpec``.

    .. code-block::

        dict_as_params(input_keys="input_ids", output_keys="hidden_states")(nn.Embedding, vocab_size, hidden_size)
        LayerSpec(
            dict_as_params(input_keys="input_ids", output_keys="hidden_states"),
            nn.Embedding, vocab_size, hidden_size
        )

    :param input_keys: key(s) the model needs as input. The corresponding
        values are read from the input dict in order and passed positionally.
    :param output_keys: key(s) for the model output. The outputs are matched
        with ``output_keys`` in order and written into the dict that is
        returned.

    .. note::

        Pay attention to the correspondence between the order of the model
        inputs/outputs and the order of ``input_keys`` / ``output_keys`` so
        that no tensor ends up under the wrong key.
    """
    def _inner(cls: type, *args, **kwargs):
        obj = cls(*args, **kwargs)
        object.__setattr__(obj, "dict_as_params_input_keys", input_keys)
        object.__setattr__(obj, "dict_as_params_output_keys", output_keys)
        raw_forward = obj.forward

        def _forward(self, dict_inputs: dict):
            # 1. Collect the positional inputs of the wrapped forward.
            if isinstance(input_keys, str):
                positional = [dict_inputs[input_keys]]
            elif isinstance(input_keys, Sequence):
                positional = [dict_inputs[name] for name in input_keys]
            else:
                raise ValueError(f"input_keys should be str or Sequence[str], but got {type(input_keys)}")

            outputs = raw_forward(*positional)

            # 2. Write the outputs back into the very same dict.
            if isinstance(output_keys, str):
                dict_inputs[output_keys] = outputs
                produced = (output_keys,)
            elif isinstance(output_keys, Sequence):
                assert isinstance(outputs, Sequence) and len(outputs) == len(output_keys), \
                    "outputs should be Sequence and have the same length as output_keys"
                for name, value in zip(output_keys, outputs):
                    dict_inputs[name] = value
                produced = output_keys
            else:
                raise ValueError(f"output_keys should be str or Sequence[str], but got {type(output_keys)}")

            # 3. Detach every entry that was not just produced, except
            #    "past_key_values". Only existing keys are reassigned, so
            #    mutating while iterating is safe here.
            for name, value in dict_inputs.items():
                if name not in produced and name != "past_key_values":
                    dict_inputs[name] = value.detach()
            return dict_inputs

        object.__setattr__(obj, "raw_forward", raw_forward)
        obj.forward = MethodType(_forward, obj)
        return obj
    return _inner
def is_static_method(func):
    """Tell whether ``func`` is a static method.

    The previous implementation could only return ``False``: it tested
    ``inspect.ismethod`` inside an ``inspect.isfunction`` branch (the two are
    mutually exclusive) and then tested whether a plain function is an
    instance of ``staticmethod``.

    Two forms are recognised:

    * the raw ``staticmethod`` object, e.g. ``cls.__dict__["name"]``;
    * the plain function obtained through attribute access (``cls.name`` or
      ``instance.name``). Its owning class is resolved from
      ``func.__qualname__`` inside the function's globals, and the class
      attribute is inspected without triggering descriptors.

    Bound methods (instance methods and class methods), builtins and
    functions whose owner cannot be resolved (for example classes defined
    inside a function body) yield ``False``.

    :param func: the object to inspect.
    :return: ``True`` if ``func`` is a static method, else ``False``.
    """
    if isinstance(func, staticmethod):
        return True
    if not inspect.isfunction(func):
        # Bound methods, builtins and arbitrary callables.
        return False

    *owner_path, attr_name = func.__qualname__.split(".")
    if not owner_path or "<locals>" in owner_path:
        # Module-level function, or an owner that is not reachable by name.
        return False

    owner = func.__globals__.get(owner_path[0])
    for part in owner_path[1:]:
        if owner is None:
            break
        owner = getattr(owner, part, None)
    if not inspect.isclass(owner):
        return False

    # ``getattr_static`` returns the raw descriptor stored on the class.
    raw = inspect.getattr_static(owner, attr_name, None)
    return isinstance(raw, staticmethod) and raw.__func__ is func
def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None,
mapping: Optional[Dict] = None) -> Any:
r"""
该函数会根据输入函数的形参名从 ``*args`` (均为 **dict** 类型)中找到匹配的值进行调用,如果传入的数据与 ``fn`` 的形参不匹配,可以通过
``mapping`` 参数进行转换。``mapping`` 参数中的一对 ``(key, value)`` 表示在 ``*args`` 中找到 ``key`` 对应的值,并将这个值传递给形参中名为
``value`` 的参数。
1. 该函数用来提供给用户根据字符串匹配从而实现自动调用;
2. 注意 ``mapping`` 默认为 ``None``,如果您希望指定输入和运行函数的参数的对应方式,那么您应当让 ``mapping`` 为一个字典传入进来;
如果 ``mapping`` 不为 ``None``,那么我们一定会先使用 ``mapping`` 将输入的字典的 ``keys`` 修改过来,因此请务必亲自检查 ``mapping`` 的正确性;
3. 如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值;
4. 如果输入的函数是一个 ``partial`` 函数,情况同第三点,即和默认参数的情况相同;
Examples::
>>> # 1
>>> loss_fn = CrossEntropyLoss() # 如果其需要的参数为 def CrossEntropyLoss(y, pred);
>>> batch = {"x": 20, "y": 1}
>>> output = {"pred": 0}
>>> acc = auto_param_call(loss_fn, batch, output)
>>> # 2
>>> def test_fn(x, y, a, b=10):
>>> return x + y + a + b
>>> print(auto_param_call(test_fn, {"x": 10}, {"y": 20, "a": 30})) # res: 70
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20})) # res: 140
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200})) # res: 240
:param fn: 用来进行实际计算的函数,其参数可以包含有默认值;
:param args: 一系列的位置参数,应当为一系列的字典,我们需要从这些输入中提取 ``fn`` 计算所需要的实际参数;
:param signature_fn: 函数,用来替换 ``fn`` 的函数签名,如果该参数不为 ``None``,那么我们首先会从该函数中提取函数签名,
然后通过该函数签名提取参数值后,再传给 ``fn`` 进行实际的运算;
:param mapping: 一个字典,用来更改其前面的字典的键值;
:return: ``fn`` 运行的结果;
"""
if signature_fn is not None:
if not callable(signature_fn):
raise ValueError(f"Parameter `signature_fn` should be `Callable`.")
_need_params = OrderedDict(inspect.signature(signature_fn).parameters)
else:
_need_params = OrderedDict(inspect.signature(fn).parameters)
_kwargs = None
for _name, _param in _need_params.items():
if _param.kind == Parameter.VAR_POSITIONAL:
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn_msg}.")
if _param.kind == Parameter.VAR_KEYWORD:
_kwargs = (_name, _param)
if _kwargs is not None:
_need_params.pop(_kwargs[0])
_default_params = {}
for _name, _param in _need_params.items():
if _param.default != Parameter.empty:
_default_params[_name] = _param.default
if mapping is not None:
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
assert isinstance(mapping, Dict), f"Exception happens when calling {fn_msg}. " \
f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}."
_has_params = {}
duplicate_names = []
for arg in args:
if not isinstance(arg, (Dict, dict)):
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
raise TypeError(f"Exception happens when calling {fn_msg}. "
f"The input part of function `auto_param_call` must be `Dict` type, instead of {type(arg)}.")
for _name, _value in arg.items():
if mapping is not None and _name in mapping:
_name = mapping[_name]
if _name not in _has_params:
if _kwargs is not None or _name in _need_params:
_has_params[_name] = _value
# 同一参数对象在两个输入的资源中都出现,造成混淆;
elif _name in _need_params and not (_has_params[_name] is _value):
duplicate_names.append(_name)
if duplicate_names:
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
raise ValueError(f"The following key present in several inputs:{duplicate_names} when calling {fn_msg}.")
# 将具有默认值但是没有被输入修改过的参数值传进去;
for _name, _value in _default_params.items():
if _name not in _has_params:
_has_params[_name] = _value
if len(_has_params) < len(_need_params):
miss_params = list(set(_need_params.keys()) - set(_has_params.keys()))
fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
_provided_keys = _get_keys(args)
raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn_msg} "
f"are not found in the input keys({_provided_keys}).")
return fn(**_has_params)
def _get_keys(args:List[Dict]) -> List[List[str]]:
"""
返回每个 dict 的 keys
:param args:
:return:
"""
_provided_keys = []
for arg in args:
_provided_keys.append(list(arg.keys()))
return _provided_keys
def get_keys_to_not_convert(model):
    r"""
    An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
    we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
    to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
    int8.

    The result is made of the tied parameter names reported by ``find_tied_parameters`` plus the name of the last
    parameter yielded by ``model.named_parameters()``, each with a trailing ``.weight`` / ``.bias`` removed.

    :param model: Input model
    :return: list of names with the parameter suffix stripped; a model without parameters contributes no "last
        parameter" entry.
    """
    tied_params = find_tied_parameters(model)
    # For compatibility with Accelerate < 0.18
    if isinstance(tied_params, dict):
        tied_keys = list(tied_params.values())
    else:
        tied_keys = sum([x[1:] for x in tied_params], [])

    # otherwise they have an attached head
    parameter_names = [name for name, _ in model.named_parameters()]
    # Slicing keeps this safe for a model without parameters, where indexing
    # ``[-1]`` raised IndexError.
    last_names = parameter_names[-1:]

    # add last module together with tied weights
    list_untouched = tied_keys + [name for name in last_names if name not in tied_keys]

    # Remove ".weight" / ".bias" from the keys. Only a trailing suffix is
    # stripped: replacing every occurrence mangled names such as
    # "a.weight_proj.weight" into "a_proj".
    suffixes_to_remove = (".weight", ".bias")
    filtered_module_names = []
    for name in list_untouched:
        for suffix in suffixes_to_remove:
            if name.endswith(suffix):
                name = name[:-len(suffix)]
        filtered_module_names.append(name)

    return filtered_module_names