ir.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ..base.framework import _apply_pass
  15. from . import core
  16. def get_data_vars(program):
  17. data_vars = []
  18. for var_name, var in program.global_block().vars.items():
  19. if var.is_data:
  20. data_vars.append(var_name)
  21. return data_vars
  22. def _update_grad_persistable(main_program):
  23. grad_merge_attr_name = "grad_merge_cond_name"
  24. op_role_var_attr_name = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
  25. has_grad_merge = False
  26. has_persistable_grad_var = False
  27. grad_vars = []
  28. for block_id in range(main_program.num_blocks):
  29. block = main_program.block(block_id)
  30. for op in block.ops:
  31. if grad_merge_attr_name in op.attr_names:
  32. has_grad_merge = True
  33. if op_role_var_attr_name not in op.attr_names:
  34. continue
  35. p_g = op.attr(op_role_var_attr_name)
  36. for g in p_g[1::2]:
  37. g_var = block._find_var_recursive(g)
  38. if g_var is None:
  39. continue
  40. grad_vars.append(g_var)
  41. if g_var.persistable:
  42. has_persistable_grad_var = True
  43. if has_grad_merge and has_persistable_grad_var:
  44. for g_var in grad_vars:
  45. g_var.persistable = True
  46. def apply_build_strategy(
  47. main_program, startup_program, build_strategy, pass_attrs
  48. ):
  49. def update_attr(attrs, attr_types, name, value, typ=None):
  50. if name not in attrs:
  51. attrs[name] = value
  52. if typ:
  53. attr_types[name] = typ
  54. def apply_pass(name):
  55. attrs = dict(pass_attrs)
  56. attr_types = {}
  57. update_attr(attrs, attr_types, "nranks", 1, "size_t")
  58. update_attr(attrs, attr_types, "use_cuda", False, "bool")
  59. # TODO(zjl): how to skip fetch variables ?
  60. update_attr(
  61. attrs,
  62. attr_types,
  63. "mem_opt_skip_vars",
  64. get_data_vars(main_program),
  65. "list[str]",
  66. )
  67. _apply_pass(main_program, startup_program, name, attrs, attr_types)
  68. _update_grad_persistable(main_program)
  69. use_cuda = pass_attrs.get("use_cuda", False)
  70. build_strategy = build_strategy._copy()
  71. if build_strategy.sync_batch_norm:
  72. apply_pass("sync_batch_norm_pass")
  73. build_strategy.sync_batch_norm = False
  74. if build_strategy.fuse_relu_depthwise_conv and use_cuda:
  75. apply_pass("fuse_relu_depthwise_conv_pass")
  76. build_strategy.fuse_relu_depthwise_conv = False
  77. if build_strategy.fuse_resunit:
  78. apply_pass("fuse_resunit_pass")
  79. build_strategy.fuse_resunit = False
  80. if build_strategy.fuse_bn_act_ops and use_cuda:
  81. apply_pass("fuse_bn_act_pass")
  82. build_strategy.fuse_bn_act_ops = False
  83. if build_strategy.fuse_bn_add_act_ops and use_cuda:
  84. apply_pass("fuse_bn_add_act_pass")
  85. build_strategy.fuse_bn_add_act_ops = False
  86. if build_strategy.enable_auto_fusion and use_cuda:
  87. apply_pass("fusion_group_pass")
  88. build_strategy.enable_auto_fusion = False
  89. if build_strategy.fuse_gemm_epilogue:
  90. apply_pass("fuse_gemm_epilogue_pass")
  91. build_strategy.fuse_gemm_epilogue = False
  92. if build_strategy.fuse_dot_product_attention:
  93. apply_pass("fuse_dot_product_attention_pass")
  94. build_strategy.fuse_dot_product_attention = False
  95. if build_strategy.fuse_elewise_add_act_ops:
  96. apply_pass("fuse_elewise_add_act_pass")
  97. build_strategy.fuse_elewise_add_act_ops = False
  98. if build_strategy.fuse_all_optimizer_ops:
  99. apply_pass(
  100. [
  101. "coalesce_grad_tensor_pass",
  102. "fuse_adam_op_pass",
  103. "fuse_sgd_op_pass",
  104. "fuse_momentum_op_pass",
  105. ]
  106. )
  107. build_strategy.fuse_all_optimizer_ops = False
  108. # TODO(zjl): support fuse all reduce ops
  109. if build_strategy.cache_runtime_context:
  110. apply_pass("runtime_context_cache_pass")
  111. build_strategy.cache_runtime_context = False
  112. if build_strategy.enable_addto and use_cuda:
  113. # NOTE: how to get fetch vars to skip memory optimization?
  114. apply_pass("inplace_addto_op_pass")
  115. build_strategy.enable_addto = False
  116. if build_strategy.enable_inplace:
  117. apply_pass("buffer_shared_inplace_pass")
  118. build_strategy.enable_inplace = False
  119. build_strategy._clear_finalized()
  120. return build_strategy