# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import annotations
  15. from typing import Generic, Iterable, Iterator, TypeVar
  16. T = TypeVar("T")
  17. class OrderedSet(Generic[T]):
  18. """
  19. A set that preserves the order of insertion.
  20. """
  21. _data: dict[T, None]
  22. def __init__(self, items: Iterable[T] | None = None):
  23. """
  24. Examples:
  25. >>> s = OrderedSet([1, 2, 3])
  26. >>> s
  27. OrderedSet(1, 2, 3)
  28. >>> s = OrderedSet()
  29. >>> s
  30. OrderedSet()
  31. """
  32. self._data = dict.fromkeys(items) if items is not None else {}
  33. def __iter__(self) -> Iterator[T]:
  34. """
  35. Examples:
  36. >>> s = OrderedSet([1, 2, 3])
  37. >>> for item in s:
  38. ... print(item)
  39. 1
  40. 2
  41. 3
  42. """
  43. return iter(self._data)
  44. def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
  45. """
  46. Union two sets.
  47. Args:
  48. other: Another set to be unioned.
  49. Returns:
  50. The union of two sets.
  51. Examples:
  52. >>> s1 = OrderedSet([1, 2, 3])
  53. >>> s2 = OrderedSet([2, 3, 4])
  54. >>> s1 | s2
  55. OrderedSet(1, 2, 3, 4)
  56. """
  57. return OrderedSet(list(self) + list(other))
  58. def __ior__(self, other: OrderedSet[T]):
  59. """
  60. Union two sets in place.
  61. Args:
  62. other: Another set to be unioned.
  63. Examples:
  64. >>> s1 = OrderedSet([1, 2, 3])
  65. >>> s2 = OrderedSet([2, 3, 4])
  66. >>> s1 |= s2
  67. >>> s1
  68. OrderedSet(1, 2, 3, 4)
  69. """
  70. self._data.update(dict.fromkeys(other))
  71. return self
  72. def __and__(self, other: OrderedSet[T]) -> OrderedSet[T]:
  73. """
  74. Intersect two sets.
  75. Args:
  76. other: Another set to be intersected.
  77. Returns:
  78. The intersection of two sets.
  79. Examples:
  80. >>> s1 = OrderedSet([1, 2, 3])
  81. >>> s2 = OrderedSet([2, 3, 4])
  82. >>> s1 & s2
  83. OrderedSet(2, 3)
  84. """
  85. return OrderedSet([item for item in self if item in other])
  86. def __iand__(self, other: OrderedSet[T]):
  87. """
  88. Intersect two sets in place.
  89. Args:
  90. other: Another set to be intersected.
  91. Examples:
  92. >>> s1 = OrderedSet([1, 2, 3])
  93. >>> s2 = OrderedSet([2, 3, 4])
  94. >>> s1 &= s2
  95. >>> s1
  96. OrderedSet(2, 3)
  97. """
  98. self._data = {item: None for item in self if item in other}
  99. return self
  100. def __sub__(self, other: OrderedSet[T]) -> OrderedSet[T]:
  101. """
  102. Subtract two sets.
  103. Args:
  104. other: Another set to be subtracted.
  105. Returns:
  106. The subtraction of two sets.
  107. Examples:
  108. >>> s1 = OrderedSet([1, 2, 3])
  109. >>> s2 = OrderedSet([2, 3, 4])
  110. >>> s1 - s2
  111. OrderedSet(1)
  112. """
  113. return OrderedSet([item for item in self if item not in other])
  114. def __isub__(self, other: OrderedSet[T]):
  115. """
  116. Subtract two sets in place.
  117. Args:
  118. other: Another set to be subtracted.
  119. Examples:
  120. >>> s1 = OrderedSet([1, 2, 3])
  121. >>> s2 = OrderedSet([2, 3, 4])
  122. >>> s1 -= s2
  123. >>> s1
  124. OrderedSet(1)
  125. """
  126. self._data = {item: None for item in self if item not in other}
  127. return self
  128. def __xor__(self, other: OrderedSet[T]) -> OrderedSet[T]:
  129. """
  130. Symmetric difference of two sets.
  131. Args:
  132. other: Another set to be xor'ed.
  133. Returns:
  134. The symmetric difference of two sets.
  135. Examples:
  136. >>> s1 = OrderedSet([1, 2, 3])
  137. >>> s2 = OrderedSet([2, 3, 4])
  138. >>> s1 ^ s2
  139. OrderedSet(1, 4)
  140. """
  141. return OrderedSet(
  142. [item for item in self if item not in other]
  143. ) | OrderedSet([item for item in other if item not in self])
  144. def __ixor__(self, other: OrderedSet[T]):
  145. """
  146. Symmetric difference of two sets in place.
  147. Args:
  148. other: Another set to be xor'ed.
  149. Examples:
  150. >>> s1 = OrderedSet([1, 2, 3])
  151. >>> s2 = OrderedSet([2, 3, 4])
  152. >>> s1 ^= s2
  153. >>> s1
  154. OrderedSet(1, 4)
  155. """
  156. # TODO(Python3.8-cleanup): Use dict union syntax when Python 3.9 is
  157. # minimum supported version.
  158. # self._data = {item: None for item in self if item not in other} | {
  159. # item: None for item in other if item not in self
  160. # }
  161. self._data = {
  162. **{item: None for item in self if item not in other},
  163. **{item: None for item in other if item not in self},
  164. }
  165. return self
  166. def add(self, item: T):
  167. """
  168. Add an item to the set.
  169. Args:
  170. item: The item to be added.
  171. Examples:
  172. >>> s = OrderedSet([1, 2, 3])
  173. >>> s.add(4)
  174. >>> s
  175. OrderedSet(1, 2, 3, 4)
  176. """
  177. self._data.setdefault(item)
  178. def remove(self, item: T):
  179. """
  180. Remove an item from the set.
  181. Args:
  182. item: The item to be removed.
  183. Examples:
  184. >>> s = OrderedSet([1, 2, 3])
  185. >>> s.remove(2)
  186. >>> s
  187. OrderedSet(1, 3)
  188. """
  189. del self._data[item]
  190. def __contains__(self, item: T) -> bool:
  191. """
  192. Examples:
  193. >>> s = OrderedSet([1, 2, 3])
  194. >>> 1 in s
  195. True
  196. >>> 4 in s
  197. False
  198. """
  199. return item in self._data
  200. def __len__(self) -> int:
  201. """
  202. Examples:
  203. >>> s = OrderedSet([1, 2, 3])
  204. >>> len(s)
  205. 3
  206. """
  207. return len(self._data)
  208. def __bool__(self) -> bool:
  209. """
  210. Examples:
  211. >>> s = OrderedSet([1, 2, 3])
  212. >>> bool(s)
  213. True
  214. >>> s = OrderedSet()
  215. >>> bool(s)
  216. False
  217. """
  218. return bool(self._data)
  219. def __eq__(self, other: object) -> bool:
  220. """
  221. Examples:
  222. >>> s1 = OrderedSet([1, 2, 3])
  223. >>> s2 = OrderedSet([1, 2, 3])
  224. >>> s1 == s2
  225. True
  226. >>> s3 = OrderedSet([3, 2, 1])
  227. >>> s1 == s3
  228. False
  229. """
  230. if not isinstance(other, OrderedSet):
  231. return NotImplemented
  232. return list(self) == list(other)
  233. def __repr__(self) -> str:
  234. data_repr = ", ".join(map(repr, self._data))
  235. return f"OrderedSet({data_repr})"