pir_utils.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from functools import wraps
  15. import paddle
  16. from paddle.framework.dtype import bind_datatype, bind_vartype
  17. def _switch_to_pir_():
  18. paddle.base.framework.global_var._use_pir_api_ = True
  19. paddle.framework.set_flags({"FLAGS_enable_pir_in_executor": True})
  20. paddle.pir.register_paddle_dialect()
  21. # TODO find a better place to init the registion of dist dialect.
  22. paddle.pir.register_dist_dialect()
  23. paddle.base.Program = paddle.pir.Program
  24. paddle.base.program_guard = paddle.pir.core.program_guard
  25. paddle.base.default_main_program = paddle.pir.core.default_main_program
  26. paddle.base.default_startup_program = (
  27. paddle.pir.core.default_startup_program
  28. )
  29. paddle.static.Program = paddle.pir.Program
  30. paddle.static.program_guard = paddle.pir.core.program_guard
  31. paddle.static.default_main_program = paddle.pir.core.default_main_program
  32. paddle.static.default_startup_program = (
  33. paddle.pir.core.default_startup_program
  34. )
  35. def _switch_to_old_ir_():
  36. paddle.base.framework.global_var._use_pir_api_ = False
  37. paddle.framework.set_flags({"FLAGS_enable_pir_in_executor": False})
  38. paddle.base.Program = paddle.base.framework.Program
  39. paddle.base.program_guard = paddle.base.framework.program_guard
  40. paddle.base.default_main_program = (
  41. paddle.base.framework.default_main_program
  42. )
  43. paddle.base.default_startup_program = (
  44. paddle.base.framework.default_startup_program
  45. )
  46. paddle.static.Program = paddle.base.framework.Program
  47. paddle.static.program_guard = paddle.base.framework.program_guard
  48. paddle.static.default_main_program = (
  49. paddle.base.framework.default_main_program
  50. )
  51. paddle.static.default_startup_program = (
  52. paddle.base.framework.default_startup_program
  53. )
  54. class IrGuard:
  55. def __enter__(self):
  56. self.in_dygraph_outside = paddle.base.framework.in_dygraph_mode()
  57. self.old_flag = paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
  58. "FLAGS_enable_pir_api"
  59. ]
  60. if self.in_dygraph_outside:
  61. paddle.enable_static()
  62. if not self.old_flag:
  63. paddle.framework.set_flags({"FLAGS_enable_pir_api": True})
  64. paddle.base.framework.global_var._use_pir_api_ = True
  65. bind_datatype()
  66. self._switch_to_pir()
  67. def __exit__(self, exc_type, exc_val, exc_tb):
  68. if self.in_dygraph_outside:
  69. paddle.disable_static()
  70. if not self.old_flag:
  71. paddle.framework.set_flags({"FLAGS_enable_pir_api": False})
  72. paddle.base.framework.global_var._use_pir_api_ = False
  73. bind_vartype()
  74. self._switch_to_old_ir()
  75. def _switch_to_pir(self):
  76. if paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
  77. "FLAGS_enable_pir_api"
  78. ]:
  79. _switch_to_pir_()
  80. def _switch_to_old_ir(self):
  81. if not paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
  82. "FLAGS_enable_pir_api"
  83. ]:
  84. _switch_to_old_ir_()
  85. else:
  86. raise RuntimeError(
  87. "IrGuard._switch_to_old_ir only work when paddle.framework.in_pir_mode() is false, \
  88. please set FLAGS_enable_pir_api = false"
  89. )
  90. class OldIrGuard:
  91. def __enter__(self):
  92. self.in_dygraph_outside = paddle.base.framework.in_dygraph_mode()
  93. self.old_flag = paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
  94. "FLAGS_enable_pir_api"
  95. ]
  96. if self.in_dygraph_outside:
  97. paddle.enable_static()
  98. if self.old_flag:
  99. paddle.framework.set_flags({"FLAGS_enable_pir_api": False})
  100. paddle.base.framework.global_var._use_pir_api_ = False
  101. bind_vartype()
  102. _switch_to_old_ir_()
  103. def __exit__(self, exc_type, exc_val, exc_tb):
  104. if self.in_dygraph_outside:
  105. paddle.disable_static()
  106. if self.old_flag:
  107. paddle.framework.set_flags({"FLAGS_enable_pir_api": True})
  108. paddle.base.framework.global_var._use_pir_api_ = True
  109. bind_datatype()
  110. _switch_to_pir_()
  111. class DygraphPirGuard:
  112. def __enter__(self):
  113. self.old_flag = paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
  114. "FLAGS_enable_pir_api"
  115. ]
  116. if not self.old_flag:
  117. paddle.framework.set_flags({"FLAGS_enable_pir_api": True})
  118. paddle.base.framework.global_var._use_pir_api_ = True
  119. bind_datatype()
  120. self._switch_to_pir()
  121. def __exit__(self, exc_type, exc_val, exc_tb):
  122. if not self.old_flag:
  123. paddle.framework.set_flags({"FLAGS_enable_pir_api": False})
  124. paddle.base.framework.global_var._use_pir_api_ = False
  125. bind_vartype()
  126. self._switch_to_old_ir()
  127. def _switch_to_pir(self):
  128. if paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
  129. "FLAGS_enable_pir_api"
  130. ]:
  131. _switch_to_pir_()
  132. def _switch_to_old_ir(self):
  133. if not paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
  134. "FLAGS_enable_pir_api"
  135. ]:
  136. _switch_to_old_ir_()
  137. else:
  138. raise RuntimeError(
  139. "IrGuard._switch_to_old_ir only work when paddle.framework.in_pir_mode() is false, \
  140. please set FLAGS_enable_pir_api = false"
  141. )
  142. class DygraphOldIrGuard:
  143. def __enter__(self):
  144. self.old_flag = paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
  145. "FLAGS_enable_pir_api"
  146. ]
  147. if self.old_flag:
  148. paddle.framework.set_flags({"FLAGS_enable_pir_api": False})
  149. paddle.base.framework.global_var._use_pir_api_ = False
  150. bind_vartype()
  151. _switch_to_old_ir_()
  152. def __exit__(self, exc_type, exc_val, exc_tb):
  153. if self.old_flag:
  154. paddle.framework.set_flags({"FLAGS_enable_pir_api": True})
  155. paddle.base.framework.global_var._use_pir_api_ = True
  156. bind_datatype()
  157. _switch_to_pir_()
  158. def test_with_pir_api(func):
  159. @wraps(func)
  160. def impl(*args, **kwargs):
  161. with OldIrGuard():
  162. func(*args, **kwargs)
  163. with IrGuard():
  164. func(*args, **kwargs)
  165. return impl
  166. def test_with_old_ir_only(func):
  167. @wraps(func)
  168. def impl(*args, **kwargs):
  169. with OldIrGuard():
  170. func(*args, **kwargs)
  171. return impl
  172. def test_with_dygraph_pir(func):
  173. @wraps(func)
  174. def impl(*args, **kwargs):
  175. with DygraphOldIrGuard():
  176. func(*args, **kwargs)
  177. with DygraphPirGuard():
  178. func(*args, **kwargs)
  179. return impl