debugging.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import logging
  16. import paddle
  17. from paddle.base.log_helper import get_logger
# Module-level logger used by the helpers in this file. It is created with
# level ``logging.INFO`` and a "<asctime>-<levelname>: <message>" format.
_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
  21. class OperatorStatsUnit:
  22. def __init__(self):
  23. self.op_type = None
  24. self.fp32_calls = 0
  25. self.fp16_calls = 0
  26. self.bf16_calls = 0
  27. self.other_calls = 0
  28. def update(self, dtype):
  29. if dtype is None:
  30. self.other_calls = self.other_calls + 1
  31. else:
  32. if dtype == paddle.float32:
  33. self.fp32_calls = self.fp32_calls + 1
  34. elif dtype == paddle.float16:
  35. self.fp16_calls = self.fp16_calls + 1
  36. elif dtype == paddle.bfloat16:
  37. self.bf16_calls = self.bf16_calls + 1
  38. else:
  39. self.other_calls = self.other_calls + 1
  40. def addto(self, another):
  41. self.fp32_calls += another.fp32_calls
  42. self.fp16_calls += another.fp16_calls
  43. self.bf16_calls += another.bf16_calls
  44. self.other_calls += another.other_calls
  45. def convert_to_list(self):
  46. return [
  47. self.fp16_calls,
  48. self.bf16_calls,
  49. self.fp32_calls,
  50. self.other_calls,
  51. ]
  52. def _is_floating_point(dtype):
  53. if dtype in [
  54. paddle.base.core.VarDesc.VarType.FP64,
  55. paddle.base.core.VarDesc.VarType.FP32,
  56. paddle.base.core.VarDesc.VarType.FP16,
  57. paddle.base.core.VarDesc.VarType.BF16,
  58. ]:
  59. return True
  60. else:
  61. return False
  62. def _get_var_dtype_from_block(block, op, arg_name, is_input):
  63. var_names = op.input(arg_name) if is_input else op.output(arg_name)
  64. assert isinstance(var_names, list)
  65. if len(var_names) == 0:
  66. return None
  67. var_name = var_names[0]
  68. try:
  69. var = block._var_recursive(var_name)
  70. return var.dtype
  71. except:
  72. _logger.warning(
  73. "Operator < {} > gets {} < {} : {} > error!".format(
  74. op.type, "input" if is_input else "output", arg_name, var_name
  75. )
  76. )
  77. return None
  78. def _extract_compute_dtype(op, block):
  79. var_name = None
  80. compute_dtype = None
  81. for in_name in op.input_names:
  82. var_dtype = _get_var_dtype_from_block(block, op, in_name, True)
  83. if var_dtype is None:
  84. continue
  85. if compute_dtype is None:
  86. compute_dtype = var_dtype
  87. else:
  88. if compute_dtype != var_dtype:
  89. if _is_floating_point(compute_dtype) and _is_floating_point(
  90. var_dtype
  91. ):
  92. _logger.warning(
  93. f"Operator < {op.type} > has different input data types, input_names = {op.input_names}, output_names = {op.output_names}."
  94. )
  95. elif _is_floating_point(var_dtype):
  96. # When there are multiple inputs, such as embedding
  97. # (ids is integer, w is floating-point), the kernel
  98. # dtype is normally decided by the input of floating-point.
  99. compute_dtype = var_dtype
  100. for out_name in op.output_names:
  101. var_dtype = _get_var_dtype_from_block(block, op, out_name, False)
  102. if var_dtype is None:
  103. continue
  104. if compute_dtype is None:
  105. # Kernel dtype is mostly decided by the input's dtype.
  106. # When the operator has no input, it mightly has a attr
  107. # such as dtype to specify the output's dtype.
  108. compute_dtype = var_dtype
  109. else:
  110. if compute_dtype != var_dtype:
  111. if _is_floating_point(compute_dtype) and _is_floating_point(
  112. var_dtype
  113. ):
  114. _logger.warning(
  115. f"Operator < {op.type} > has different input / output data types, input_names = {op.input_names}, output_names = {op.output_names}."
  116. )
  117. return compute_dtype
  118. def _merge_op_stats(op_stats_list):
  119. merged_op_stats_dict = {}
  120. for each_op_stats_dict in op_stats_list:
  121. for op_type, unit in each_op_stats_dict.items():
  122. if merged_op_stats_dict.get(op_type, None) is None:
  123. merged_op_stats_dict[op_type] = copy.copy(unit)
  124. else:
  125. merged_op_stats_dict[op_type].addto(unit)
  126. return merged_op_stats_dict
  127. def _get_op_stats_list(program):
  128. def _is_special_ops_with_input_x(op_type):
  129. # operators have input X and have inputs different dtypes.
  130. special_op_list = ['cast', 'batch_norm', 'instance_norm', 'layer_norm']
  131. if op_type in special_op_list:
  132. return True
  133. if op_type.replace("_grad", "") in special_op_list:
  134. return True
  135. return False
  136. op_stats_list = []
  137. for block in program.blocks:
  138. block_op_stats_dict = {}
  139. for op in block.ops:
  140. if block_op_stats_dict.get(op.type, None) is None:
  141. unit = OperatorStatsUnit()
  142. block_op_stats_dict[op.type] = unit
  143. else:
  144. unit = block_op_stats_dict[op.type]
  145. if op.type in [
  146. 'create_py_reader',
  147. 'read',
  148. 'create_double_buffer_reader',
  149. ]:
  150. compute_dtype = None
  151. elif _is_special_ops_with_input_x(op.type):
  152. # Not check the input and output dtype difference for this operators.
  153. compute_dtype = _get_var_dtype_from_block(block, op, 'X', True)
  154. elif "Param" in op.input_names:
  155. # Specify compute_dtype for optimizers.
  156. compute_dtype = _get_var_dtype_from_block(
  157. block, op, 'Param', True
  158. )
  159. else:
  160. compute_dtype = _extract_compute_dtype(op, block)
  161. unit.update(dtype=compute_dtype)
  162. op_stats_list.append(block_op_stats_dict)
  163. return op_stats_list
  164. def collect_operator_stats(program=None, print_subblocks=False):
  165. """
  166. Collect the number of operators for different data types through parsing
  167. the program. The statistical data are categorized according to four data
  168. types, namely float32, float16, bfloat16 and others.
  169. Args:
  170. program(Program, optional): The program to parse. Default None, and the default main_program will be parsed.
  171. print_subblocks(bool, optional): Whether to print the operator stats for each subblock. Default False.
  172. Examples:
  173. .. code-block:: python
  174. >>> import paddle
  175. >>> paddle.enable_static()
  176. >>> class SimpleConvNet(paddle.nn.Layer):
  177. ... def __init__(self):
  178. ... super().__init__()
  179. ... self.conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
  180. ... self.linear = paddle.nn.Linear(in_features=26, out_features=10)
  181. ...
  182. ... def forward(self, x):
  183. ... out = self.conv(x)
  184. ... out = paddle.nn.functional.relu(out)
  185. ... out = self.linear(out)
  186. ... out = paddle.nn.functional.softmax(out)
  187. ... return out
  188. >>> main_program = paddle.static.Program()
  189. >>> startup_program = paddle.static.Program()
  190. >>> with paddle.utils.unique_name.guard():
  191. ... with paddle.static.program_guard(main_program, startup_program):
  192. ... model = SimpleConvNet()
  193. ... x = paddle.static.data(
  194. ... name='input', shape=[None, 1, 28, 28], dtype='float32'
  195. ... )
  196. ... out = model(x)
  197. ... loss = paddle.mean(out)
  198. ... optimizer = paddle.optimizer.AdamW()
  199. ... optimizer = paddle.static.amp.decorate(optimizer)
  200. ... optimizer.minimize(loss)
  201. >>> paddle.static.amp.debugging.collect_operator_stats(main_program)
  202. <------------------------------------------------ op list of all blocks ------------------------------------------------->
  203. <------------------------------------------------------- op list -------------------------------------------------------->
  204. <--------------- Op Name ---------------- | -- FP16 Calls --- | -- BF16 Calls --- | --- FP32 Calls--- | -- Other Calls -->
  205. adamw | 0 | 0 | 4 | 0
  206. cast | 5 | 0 | 6 | 0
  207. check_finite_and_unscale | 0 | 0 | 1 | 0
  208. conv2d | 1 | 0 | 0 | 0
  209. conv2d_grad | 1 | 0 | 0 | 0
  210. elementwise_add | 2 | 0 | 0 | 0
  211. elementwise_add_grad | 2 | 0 | 0 | 0
  212. elementwise_mul | 0 | 0 | 1 | 0
  213. elementwise_mul_grad | 0 | 0 | 1 | 0
  214. fill_constant | 0 | 0 | 1 | 0
  215. matmul_v2 | 1 | 0 | 0 | 0
  216. matmul_v2_grad | 1 | 0 | 0 | 0
  217. memcpy | 0 | 0 | 0 | 1
  218. reduce_mean | 0 | 0 | 1 | 0
  219. reduce_mean_grad | 0 | 0 | 1 | 0
  220. relu | 1 | 0 | 0 | 0
  221. relu_grad | 1 | 0 | 0 | 0
  222. reshape2 | 0 | 0 | 1 | 0
  223. reshape2_grad | 0 | 0 | 1 | 0
  224. softmax | 0 | 0 | 1 | 0
  225. softmax_grad | 0 | 0 | 1 | 0
  226. update_loss_scaling | 0 | 0 | 1 | 0
  227. <----------------------------------------------------- op count: 22 ----------------------------------------------------->
  228. """
  229. def _convert_to_list(op_stats_unit_dict):
  230. for key, value in op_stats_unit_dict.items():
  231. op_stats_unit_dict[key] = value.convert_to_list()
  232. return op_stats_unit_dict
  233. if program is None:
  234. program = paddle.static.default_main_program()
  235. op_stats_list = _get_op_stats_list(program)
  236. merged_op_stats = _merge_op_stats(op_stats_list)
  237. if print_subblocks and len(op_stats_list) > 1:
  238. for i in range(len(op_stats_list)):
  239. print("<{:-^120}>".format(" op list of block " + str(i) + " "))
  240. paddle.amp.debugging._print_operator_stats(
  241. _convert_to_list(op_stats_list[i])
  242. )
  243. print("<{:-^120}>".format(" op list of all blocks "))
  244. paddle.amp.debugging._print_operator_stats(
  245. _convert_to_list(merged_op_stats)
  246. )