prediction_saving_wrapper.py 907 B

123456789101112131415161718192021222324252627282930313233343536
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Dict
  3. from modelscope.metainfo import Metrics
  4. from modelscope.utils.registry import default_group
  5. from .base import Metric
  6. from .builder import METRICS
  7. @METRICS.register_module(
  8. group_key=default_group, module_name=Metrics.prediction_saving_wrapper)
  9. class PredictionSavingWrapper(Metric):
  10. """The wrapper to save predictions to file.
  11. Args:
  12. saving_fn: The saving_fn used to save predictions to files.
  13. """
  14. def __init__(self, saving_fn, **kwargs):
  15. super().__init__(**kwargs)
  16. self.saving_fn = saving_fn
  17. def add(self, outputs: Dict, inputs: Dict):
  18. self.saving_fn(inputs, outputs)
  19. def evaluate(self):
  20. return {}
  21. def merge(self, other: 'PredictionSavingWrapper'):
  22. pass
  23. def __getstate__(self):
  24. pass
  25. def __setstate__(self, state):
  26. pass