Shortcuts

collie.metrics.ppl 源代码

from typing import Dict, Optional
from .base import BaseMetric

class PPLMetric(BaseMetric):
    """Metric that computes
    `Perplexity <https://en.wikipedia.org/wiki/Perplexity>`_.

    Each call to :meth:`update` adds the sum of the ``"ppl"`` values of one
    batch to a running total and counts the batch entries;
    :meth:`get_metric` reports the mean over everything seen since the last
    :meth:`reset`.

    :param gather_result: forwarded unchanged to ``BaseMetric.__init__``.
    """

    def __init__(self, gather_result: bool = False) -> None:
        super().__init__(gather_result)
        # Running sum of all perplexity values seen so far.
        self.ppl = 0.
        # Number of entries that contributed to ``self.ppl``.
        self.total = 0

    def reset(self):
        """Clear the accumulated perplexity sum and the entry count."""
        self.ppl = 0.
        self.total = 0

    def get_metric(self) -> Optional[Dict]:
        """Return the mean perplexity accumulated so far.

        :return: ``{'ppl': mean}`` with the mean rounded to 6 decimal
            places. When nothing has been accumulated the value is ``0.0``.
        """
        # The small epsilon keeps the division defined when ``total`` is 0.
        mean_ppl = self.ppl / (self.total + 1e-12)
        return {'ppl': round(mean_ppl, 6)}

    def update(self, result: Dict):
        """Accumulate the perplexity values of one batch.

        :param result: mapping that must contain the key ``"ppl"``. The
            value must support ``.sum()`` and ``.shape[0]`` (presumably a
            tensor/array with one entry per sample -- confirm with caller).
        :raises AssertionError: if ``"ppl"`` is missing from ``result``.
        """
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped when Python runs with ``-O``; the exception type and
        # message are the same as before.
        if "ppl" not in result:
            raise AssertionError("ppl not in result!")
        batch_ppl = result["ppl"]
        self.ppl += float(batch_ppl.sum())
        self.total += batch_ppl.shape[0]