# op_version.py (2.2 KB)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools

from ..base import core
  15. __all__ = []
  16. def Singleton(cls):
  17. _instance = {}
  18. def _singleton(*args, **kargs):
  19. if cls not in _instance:
  20. _instance[cls] = cls(*args, **kargs)
  21. return _instance[cls]
  22. return _singleton
  23. class OpUpdateInfoHelper:
  24. def __init__(self, info):
  25. self._info = info
  26. def verify_key_value(self, name=''):
  27. result = False
  28. key_funcs = {
  29. core.OpAttrInfo: 'name',
  30. core.OpInputOutputInfo: 'name',
  31. }
  32. if name == '':
  33. result = True
  34. elif type(self._info) in key_funcs:
  35. if getattr(self._info, key_funcs[type(self._info)])() == name:
  36. result = True
  37. return result
  38. @Singleton
  39. class OpLastCheckpointChecker:
  40. def __init__(self):
  41. self.raw_version_map = core.get_op_version_map()
  42. self.checkpoints_map = {}
  43. self._construct_map()
  44. def _construct_map(self):
  45. for op_name in self.raw_version_map:
  46. last_checkpoint = self.raw_version_map[op_name].checkpoints()[-1]
  47. infos = last_checkpoint.version_desc().infos()
  48. self.checkpoints_map[op_name] = infos
  49. def filter_updates(self, op_name, type=core.OpUpdateType.kInvalid, key=''):
  50. updates = []
  51. if op_name in self.checkpoints_map:
  52. for update in self.checkpoints_map[op_name]:
  53. if (update.type() == type) or (
  54. type == core.OpUpdateType.kInvalid
  55. ):
  56. if OpUpdateInfoHelper(update.info()).verify_key_value(key):
  57. updates.append(update.info())
  58. return updates