base_result.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. import random
  16. import time
  17. import weakref
  18. from collections import UserList
  19. from pathlib import Path
  20. import numpy as np
  21. from ....utils import logging
  22. from .mixin import JsonMixin, StrMixin
  class CopyableWeakMethod(weakref.WeakMethod):
      """Weak reference to a bound method that can be copied and deep copied.

      Both copy hooks hand back this very object instead of building a new
      reference, so a copy of a container holding such a reference still
      holds the identical reference object.
      """

      def __copy__(self):
          """Return this same reference object as its own shallow copy."""
          return self

      def __deepcopy__(self, memo):
          """Return whatever ``__copy__`` returns; ``memo`` is not consulted."""
          same_ref = self.__copy__()
          return same_ref
  class AutoWeakList(UserList):
      """List that holds bound methods only through weak references.

      Bound methods added with ``append`` are wrapped in
      ``CopyableWeakMethod``; every other item is stored unchanged. Reading
      items back (iteration or indexing) resolves those references again.
      """

      def append(self, item):
          """Add ``item`` to the end of the list.

          A bound method is stored as a ``CopyableWeakMethod`` wrapping it;
          anything else is stored exactly as given.
          """
          stored = CopyableWeakMethod(item) if inspect.ismethod(item) else item
          super().append(stored)

      def __iter__(self):
          """Yield the stored items, resolving weak method references.

          A reference that resolves to ``None`` is skipped rather than
          yielded.
          """
          for entry in self.data:
              if not isinstance(entry, CopyableWeakMethod):
                  yield entry
                  continue
              method = entry()
              if method is not None:
                  yield method

      def __getitem__(self, index):
          """Return the item at ``index``, resolving a weak method reference.

          Unlike iteration, a reference that resolves to ``None`` is not
          skipped here: ``None`` is returned for that position.
          """
          entry = super().__getitem__(index)
          return entry() if isinstance(entry, CopyableWeakMethod) else entry
  class BaseResult(dict, JsonMixin, StrMixin):
      """Base class for result objects that can save themselves.

      The object is a ``dict`` holding the result data, combined with
      ``JsonMixin`` and ``StrMixin``. Save callables are collected in
      ``self._save_funcs`` and invoked together by ``save_all``.
      """

      def __init__(self, data: dict) -> None:
          """Initializes the BaseResult with the given data.

          Args:
              data (dict): The initial data.
          """
          super().__init__(data)
          # Created before the mixin initialisers run.
          # NOTE(review): presumably the mixins register their save methods
          # in this list; the mixin code is not visible here, confirm.
          self._save_funcs = AutoWeakList()
          StrMixin.__init__(self)
          JsonMixin.__init__(self)
          # numpy print options are process-wide, so this affects how every
          # array is printed after the first result object is created.
          np.set_printoptions(threshold=1, edgeitems=1)
          # Cache for the generated fallback name (see ``_get_input_fn``).
          self._rand_fn = None

      def save_all(self, save_path: str) -> None:
          """Calls all registered save methods with the given save path.

          A callable whose signature has a ``save_path`` parameter receives
          it as a keyword argument; any other callable is called without
          arguments.

          Args:
              save_path (str): The path to save the result to.
          """
          for func in self._save_funcs:
              params = inspect.signature(func).parameters
              if "save_path" in params:
                  func(save_path=save_path)
              else:
                  func()

      def _get_input_fn(self):
          """Return the file name used as reference when naming saved results.

          The name is the final path component of ``self["input_path"]``, or
          of its first element when that value is a list. When no usable
          input path exists (key missing, value ``None``, an empty list, or a
          list whose first element is ``None``), a name built from the current
          timestamp and a random number is generated once, logged as a
          warning, cached in ``self._rand_fn`` and returned on later calls.

          Returns:
              str: The reference file name.
          """
          input_path = self.get("input_path", None)
          if isinstance(input_path, list):
              # An empty list used to raise IndexError on ``[0]``; treat it
              # like a missing input path instead.
              input_path = input_path[0] if input_path else None

          if input_path is not None:
              return Path(input_path).name

          # Reuse the generated name so repeated calls stay consistent.
          if self._rand_fn:
              return self._rand_fn

          timestamp = int(time.time())
          random_number = random.randint(1000, 9999)
          fp = f"{timestamp}_{random_number}"
          logging.warning(
              f"There is not input file name as reference for name of saved result file. So the saved result file would be named with timestamp and random number: `{fp}`."
          )
          self._rand_fn = Path(fp).name
          return self._rand_fn