# Source code for collie.utils.padder
""" **CoLLie** 中的通用 ``collate_fn`` 构造器
"""
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
# Public names exported by this module (``from ... import *``).
__all__ = ["ColliePadder"]
class ColliePadder:
    """Generic ``collate_fn`` builder for **CoLLie**.

    Every field of a batch is converted to tensors. Samples of unequal shape
    are padded up to the element-wise maximum shape in the batch, stacked
    along a new leading dimension, and moved to ``device``.

    :param padding_token_id: mapping from field name to the value used to pad
        that field, e.g. ``{"input_ids": 0, "labels": -100}``. Fields missing
        from the mapping are padded with ``0``. ``None`` (the default) means
        ``{"attention_mask": 0, "labels": -100}``, freshly built for every
        instance so instances never share one mutable dict.
    :param padding_left: pad at the start of every dimension instead of the end.
    :param device: device the padded tensors are moved to. ``None`` (the
        default) means the current CUDA device when CUDA is available,
        otherwise the CPU.
    """

    # Value types that ``collate_fn`` can turn into a tensor.
    _ARRAY_LIKE = (torch.Tensor, np.ndarray, list, int, float)

    def __init__(
        self,
        padding_token_id: Optional[Dict[str, Any]] = None,
        padding_left: bool = False,
        device: Optional[Union[str, torch.device]] = None,
    ) -> None:
        # Build the default per call: a dict literal as the default argument
        # would be shared (and mutable) across all instances.
        self.padding_token_id = (
            {"attention_mask": 0, "labels": -100}
            if padding_token_id is None
            else padding_token_id
        )
        self.padding_left = padding_left
        self.device = device
        # Field whose padding value ``collate_fn`` uses. ``__call__`` switches
        # it temporarily while collating dict batches.
        self.key = "input_ids"

    def collate_fn(self, batch: Sequence[Any]) -> torch.Tensor:
        """Pad and stack the values of one field.

        Size-1 dimensions are removed from every sample before padding.
        Scalars and all-singleton shapes become shape ``(1,)``.

        :param batch: one value per sample; each may be a ``torch.Tensor``,
            ``np.ndarray``, (nested) ``list``, ``int`` or ``float``.
        :return: tensor of shape ``(len(batch), *max_shape)`` on the target
            device, padded with ``padding_token_id.get(self.key, 0)``.
        :raises TypeError: if a value has an unsupported type.
        :raises ValueError: if ``batch`` is empty, or if samples have
            different ranks after squeezing.
        """
        pad_value = self.padding_token_id.get(self.key, 0)
        samples = [self._squeeze(self._to_tensor(value)) for value in batch]
        if not samples:
            raise ValueError("Cannot collate an empty batch")
        ranks = {sample.dim() for sample in samples}
        if len(ranks) > 1:
            raise ValueError(
                f"Cannot pad samples of different ranks: {sorted(ranks)}"
            )
        # Element-wise max per dimension. ``max()`` over torch.Size tuples
        # would compare lexicographically and under-size later dimensions.
        max_shape = [max(sizes) for sizes in zip(*(s.shape for s in samples))]
        padded = [self._pad_to(sample, max_shape, pad_value) for sample in samples]
        # Pad on the samples' own device; transfer the stacked batch once.
        return torch.stack(padded, dim=0).to(self._target_device())

    def __call__(self, batch: List[Any]) -> Any:
        """Collate a batch whose samples are array-likes, tuples or dicts.

        * Array-like samples are padded into a single tensor.
        * Tuple samples give a tuple with one padded tensor per position.
        * Dict samples give a dict. Array-like fields are padded with that
          field's own padding value. Tuple-of-array fields become a list of
          padded tensors. ``str`` and tuple-of-``str`` fields become a plain
          list of the per-sample values.

        :param batch: list of samples; all are assumed to share the type and
            structure of ``batch[0]``.
        :return: the collated batch, structured as described above.
        :raises ValueError: if ``batch`` is empty.
        :raises TypeError: if a sample or field has an unsupported type.
        """
        if len(batch) == 0:
            raise ValueError("Cannot collate an empty batch")
        first = batch[0]
        if isinstance(first, self._ARRAY_LIKE):
            return self.collate_fn(list(batch))
        if isinstance(first, tuple):
            return tuple(
                self.collate_fn([sample[j] for sample in batch])
                for j in range(len(first))
            )
        if isinstance(first, dict):
            previous_key = self.key
            padded_dict = {}
            try:
                for key in first:
                    # collate_fn looks up the padding value for ``self.key``.
                    self.key = key
                    padded_dict[key] = self._collate_field(
                        [sample[key] for sample in batch]
                    )
            finally:
                # Restore, otherwise later non-dict batches would be padded
                # with the last field's value (e.g. -100 from "labels").
                self.key = previous_key
            return padded_dict
        raise TypeError(f"Unsupported type: {type(batch[0])}")

    def _collate_field(self, values: List[Any]) -> Any:
        """Collate one dict field's per-sample values (uses ``self.key``)."""
        first = values[0]
        if isinstance(first, self._ARRAY_LIKE):
            return self.collate_fn(values)
        if isinstance(first, tuple) and len(first) > 0:
            if isinstance(first[0], (torch.Tensor, np.ndarray, list)):
                return [
                    self.collate_fn([value[j] for value in values])
                    for j in range(len(first))
                ]
            if isinstance(first[0], str):
                return list(values)
        if isinstance(first, str):
            return list(values)
        raise TypeError(f"Unsupported type: {type(first)}")

    @staticmethod
    def _to_tensor(value: Any) -> torch.Tensor:
        """Convert one supported value to a tensor (tensors pass through)."""
        if isinstance(value, torch.Tensor):
            return value
        if isinstance(value, np.ndarray):
            return torch.from_numpy(value)
        if isinstance(value, (int, float, list)):
            return torch.tensor(value)
        raise TypeError(f"Unsupported type: {type(value)}")

    @staticmethod
    def _squeeze(sample: torch.Tensor) -> torch.Tensor:
        """Drop every size-1 dimension, keeping at least one dimension."""
        # Presumably strips singleton batch dims, e.g. a (1, seq_len)
        # tokenizer output. Size-0 dims are kept so empty sequences survive.
        # ``reshape`` (unlike ``view``) also accepts non-contiguous tensors.
        shape = [size for size in sample.shape if size != 1] or [1]
        return sample.reshape(shape)

    def _pad_to(
        self, sample: torch.Tensor, target_shape: Sequence[int], value: Any
    ) -> torch.Tensor:
        """Pad ``sample`` with ``value`` up to ``target_shape``."""
        pad: List[int] = []
        # F.pad takes (before, after) pairs starting from the LAST dimension.
        for target, size in zip(reversed(target_shape), reversed(sample.shape)):
            missing = target - size
            pad.extend((missing, 0) if self.padding_left else (0, missing))
        return F.pad(sample, pad, value=value)

    def _target_device(self) -> Union[str, torch.device]:
        """Return the configured device, or CUDA if available, else CPU."""
        device = getattr(self, "device", None)
        if device is not None:
            return device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")