Shortcuts

collie.data.batch_sampler 源代码

class CollieBatchSampler:
    """Batch sampler with configurable handling of the final, short batch.

    Groups the indices produced by ``sampler`` into lists of ``batch_size``
    indices. When the final batch holds fewer than ``batch_size`` indices it
    can be left as is (``'normal'``), discarded (``'drop'``), or topped up
    with indices taken again from the start of ``sampler`` (``'fill'``).

    :param sampler: iterable of sample indices. :meth:`__len__` calls
        ``len(sampler)``, and ``'fill'`` iterates ``sampler`` a second time,
        so it should be a sized, re-iterable object rather than a one-shot
        iterator.
    :param batch_size: number of indices per batch; must be positive.
    :param last_batch: how to treat a final batch smaller than ``batch_size``:

        * ``'normal'`` - yield it unchanged (it is shorter than the others).
        * ``'drop'`` - discard it.
        * ``'fill'`` - pad it to ``batch_size`` by restarting ``sampler``
          from the beginning, cycling as many times as needed.
    :raises ValueError: if ``last_batch`` is not one of the options above or
        ``batch_size`` is not positive.
    """

    def __init__(self, sampler, batch_size: int, last_batch: str = "normal"):
        # Explicit checks instead of ``assert`` so validation survives
        # ``python -O``: an unknown ``last_batch`` would otherwise silently
        # discard the short final batch while ``__len__`` still counts it.
        if last_batch not in ("normal", "drop", "fill"):
            raise ValueError(
                "last_batch must be 'normal', 'drop' or 'fill', "
                f"got {last_batch!r}"
            )
        # A non-positive size makes 'drop' yield empty batches forever and
        # makes the other modes fail later with IndexError/ZeroDivisionError.
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )
        self.sampler = sampler
        self.batch_size = batch_size
        self.last_batch = last_batch

    def __iter__(self):
        """Yield successive batches (lists of indices) drawn from ``self.sampler``.

        :raises RuntimeError: in ``'fill'`` mode, if a fresh pass over
            ``self.sampler`` yields no indices (e.g. a one-shot iterator that
            was already consumed), since padding could never complete.
        """
        # Structure follows torch.utils.data.BatchSampler.__iter__.
        if self.last_batch == "drop":
            sampler_iter = iter(self.sampler)
            while True:
                try:
                    # A short final batch raises StopIteration part-way
                    # through the comprehension and is therefore discarded.
                    batch = [next(sampler_iter) for _ in range(self.batch_size)]
                except StopIteration:
                    break
                yield batch
        else:
            # Pre-allocate and assign by position (as torch does) instead of
            # appending to a growing list.
            batch = [0] * self.batch_size
            idx_in_batch = 0
            for idx in self.sampler:
                batch[idx_in_batch] = idx
                idx_in_batch += 1
                if idx_in_batch == self.batch_size:
                    yield batch
                    idx_in_batch = 0
                    batch = [0] * self.batch_size
            if idx_in_batch > 0:
                if self.last_batch == "normal":
                    yield batch[:idx_in_batch]
                elif self.last_batch == "fill":
                    # Pad with indices from the start of the sampler, wrapping
                    # around again if the sampler is shorter than the gap.
                    sampler_iter = iter(self.sampler)
                    pass_is_empty = True
                    while idx_in_batch < self.batch_size:
                        try:
                            batch[idx_in_batch] = next(sampler_iter)
                        except StopIteration:
                            if pass_is_empty:
                                # A fresh pass produced nothing; restarting
                                # again would loop forever.
                                raise RuntimeError(
                                    "CollieBatchSampler: sampler yielded no "
                                    "indices when re-iterated to fill the last "
                                    "batch; last_batch='fill' requires a "
                                    "re-iterable sampler"
                                ) from None
                            sampler_iter = iter(self.sampler)
                            pass_is_empty = True
                            continue
                        pass_is_empty = False
                        idx_in_batch += 1
                    yield batch

    def __len__(self) -> int:
        """Return the number of batches one iteration will yield.

        ``'drop'`` rounds ``len(sampler) / batch_size`` down; ``'normal'`` and
        ``'fill'`` round it up, because the short final batch is kept.
        """
        if self.last_batch == "drop":
            return len(self.sampler) // self.batch_size  # type: ignore[arg-type]
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]

    def set_epoch(self, epoch_idx):
        """Forward ``epoch_idx`` to ``self.sampler.set_epoch`` if the sampler has one.

        Samplers without ``set_epoch`` are left untouched instead of raising
        ``AttributeError``.
        """
        set_epoch = getattr(self.sampler, "set_epoch", None)
        if set_epoch is not None:
            set_epoch(epoch_idx)