Shortcuts

collie.callbacks.load_best_model_callback 源代码

import os
from typing import Callable, Union

from collie.log.logger import logger
from collie.driver.io import IODriver
from collie.utils import env
from .has_monitor_callback import HasMonitorCallback

__all__ = ['LoadBestModelCallback']


[文档]class LoadBestModelCallback(HasMonitorCallback): r"""保存 monitor 值最佳的模型,并在训练结束的时候重新加载模型的 ``Callbcak``。 默认会在加载之后删除权重文件。仅在训练正常结束的时候才能加载最好的模型。 :param folder: 保存的文件夹。 :param process_exclusion: -- 是否互斥地执行保存操作;在模型规模较大时该参数可以 节省一定的内存。 :param monitor: 监控的 metric 值。 * 为 ``None`` 时,不设置监控值。 * 为 ``str`` 时, CoLLiE 将尝试直接使用该名称从 ``evaluation`` 的结果中寻找,如果最终在 ``evaluation`` 结果中没有找到完全一致的名称,则将使用最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor``。 * 为 :class:`Callable` 时, 则接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作 为 ``monitor`` 的结果,如果当前结果中没有相关的 ``monitor`` 值则返回 ``None``。 :param larger_better: 该 metric 值是否是越大越好; :param delete_after_train: 在训练结束后是否删掉模型; :param kwargs: 传给 :meth:`.Trainer.save_model` 和 :meth:`.Trainer.\ load_model` 的额外参数。 """ def __init__(self, folder: str, process_exclusion: bool = False, monitor: Union[str, Callable, None] = None, larger_better: bool = True, delete_after_train: bool = True, **kwargs ): super().__init__( monitor=monitor, larger_better=larger_better, must_have_monitor=True) self.save_folder = folder self.delete_after_train = delete_after_train self.meta = {'epoch': -1, 'batch': -1} self.process_exclusion = process_exclusion self.kwargs = kwargs self.real_save_folder = os.path.join(folder, "best") def on_evaluate_end(self, trainer, results): if self.is_better_results(results, keep_if_better=True): self.meta['epoch'] = trainer.epoch_idx self.meta['batch'] = trainer.batch_idx trainer.save_model( self.real_save_folder, self.process_exclusion, **self.kwargs ) def on_train_end(self, trainer): if abs(self.monitor_value) != float('inf'): # 如果是 inf 说明从来没有运行过。 logger.info(f'Loading best model from {self.real_save_folder} ' f"with '{self._real_monitor}: {self.monitor_value} " f"(achieved in Epoch: {self.meta['epoch']}, Batch in " f"epoch: {self.meta['batch']}) ...") trainer.load_model(self.real_save_folder, self.process_exclusion, **self.kwargs) if self.delete_after_train: self._delete_folder() def _delete_folder(self): if env.rank == 0: protocol = self.kwargs.get("protocol", "file") driver = IODriver.from_protocol(protocol) driver.delete(self.save_folder)