| 123456789101112131415161718192021222324252627282930313233343536 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from typing import Dict
- from modelscope.metainfo import Metrics
- from modelscope.utils.registry import default_group
- from .base import Metric
- from .builder import METRICS
@METRICS.register_module(
    group_key=default_group, module_name=Metrics.prediction_saving_wrapper)
class PredictionSavingWrapper(Metric):
    """A pass-through metric that hands every batch to a user-supplied saver.

    Rather than computing scores, each call to ``add`` forwards the batch's
    inputs and outputs to ``saving_fn`` so predictions can be saved to file;
    ``evaluate`` then reports no metric values at all.

    Args:
        saving_fn: Callable invoked as ``saving_fn(inputs, outputs)`` for
            every batch passed to ``add``.
        **kwargs: Forwarded unchanged to ``Metric.__init__``.
    """

    def __init__(self, saving_fn, **kwargs):
        super().__init__(**kwargs)
        # The only piece of instance state this wrapper keeps.
        self.saving_fn = saving_fn

    def add(self, outputs: Dict, inputs: Dict) -> None:
        """Forward one batch to ``saving_fn``.

        Note the argument order flips here: ``add`` receives
        ``(outputs, inputs)`` but calls ``saving_fn(inputs, outputs)``.
        """
        save = self.saving_fn
        save(inputs, outputs)

    def evaluate(self) -> Dict:
        """Return an empty dict: this wrapper produces no metric values."""
        return dict()

    def merge(self, other: 'PredictionSavingWrapper') -> None:
        """Do nothing; no accumulated results exist to combine with ``other``."""

    def __getstate__(self):
        """Return ``None`` so that no instance state is pickled.

        NOTE(review): because the returned state is falsy, unpickling will not
        call ``__setstate__`` and the restored object will not have
        ``saving_fn`` -- presumably intentional (the callable may not be
        picklable); confirm before relying on pickled copies.
        """
        return None

    def __setstate__(self, state):
        """Ignore ``state``; nothing is restored on unpickling."""
|