distillation_metric.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import importlib
  15. import copy
  16. from .rec_metric import RecMetric
  17. from .det_metric import DetMetric
  18. from .e2e_metric import E2EMetric
  19. from .cls_metric import ClsMetric
  20. from .vqa_token_ser_metric import VQASerTokenMetric
  21. from .vqa_token_re_metric import VQAReTokenMetric
  22. class DistillationMetric(object):
  23. def __init__(self, key=None, base_metric_name=None, main_indicator=None, **kwargs):
  24. self.main_indicator = main_indicator
  25. self.key = key
  26. self.main_indicator = main_indicator
  27. self.base_metric_name = base_metric_name
  28. self.kwargs = kwargs
  29. self.metrics = None
  30. def _init_metrcis(self, preds):
  31. self.metrics = dict()
  32. mod = importlib.import_module(__name__)
  33. for key in preds:
  34. self.metrics[key] = getattr(mod, self.base_metric_name)(
  35. main_indicator=self.main_indicator, **self.kwargs
  36. )
  37. self.metrics[key].reset()
  38. def __call__(self, preds, batch, **kwargs):
  39. assert isinstance(preds, dict)
  40. if self.metrics is None:
  41. self._init_metrcis(preds)
  42. output = dict()
  43. for key in preds:
  44. self.metrics[key].__call__(preds[key], batch, **kwargs)
  45. def get_metric(self):
  46. """
  47. return metrics {
  48. 'acc': 0,
  49. 'norm_edit_dis': 0,
  50. }
  51. """
  52. output = dict()
  53. for key in self.metrics:
  54. metric = self.metrics[key].get_metric()
  55. # main indicator
  56. if key == self.key:
  57. output.update(metric)
  58. else:
  59. for sub_key in metric:
  60. output["{}_{}".format(key, sub_key)] = metric[sub_key]
  61. return output
  62. def reset(self):
  63. for key in self.metrics:
  64. self.metrics[key].reset()