model_summary.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numbers
  15. import warnings
  16. from collections import OrderedDict
  17. import numpy as np
  18. import paddle
  19. from paddle import nn
  20. from paddle.autograd import no_grad
  21. from paddle.static import InputSpec
  22. __all__ = []
  23. def summary(net, input_size=None, dtypes=None, input=None):
  24. """Prints a string summary of the network.
  25. Args:
  26. net (Layer): The network which must be a subinstance of Layer.
  27. input_size (tuple|InputSpec|list[tuple|InputSpec], optional): Size of input tensor. if model only
  28. have one input, input_size can be tuple or InputSpec. if model
  29. have multiple input, input_size must be a list which contain
  30. every input's shape. Note that input_size only dim of
  31. batch_size can be None or -1. Default: None. Note that
  32. input_size and input cannot be None at the same time.
  33. dtypes (str, optional): If dtypes is None, 'float32' will be used, Default: None.
  34. input (Tensor, optional): If input is given, input_size and dtype will be ignored, Default: None.
  35. Returns:
  36. Dict: A summary of the network including total params and total trainable params.
  37. Examples:
  38. .. code-block:: python
  39. :name: code-example-1
  40. >>> # example 1: Single Input Demo
  41. >>> import paddle
  42. >>> import paddle.nn as nn
  43. >>> # Define Network
  44. >>> class LeNet(nn.Layer):
  45. ... def __init__(self, num_classes=10):
  46. ... super().__init__()
  47. ... self.num_classes = num_classes
  48. ... self.features = nn.Sequential(
  49. ... nn.Conv2D(1, 6, 3, stride=1, padding=1),
  50. ... nn.ReLU(),
  51. ... nn.MaxPool2D(2, 2),
  52. ... nn.Conv2D(6, 16, 5, stride=1, padding=0),
  53. ... nn.ReLU(),
  54. ... nn.MaxPool2D(2, 2))
  55. ...
  56. ... if num_classes > 0:
  57. ... self.fc = nn.Sequential(
  58. ... nn.Linear(400, 120),
  59. ... nn.Linear(120, 84),
  60. ... nn.Linear(84, 10))
  61. ...
  62. ... def forward(self, inputs):
  63. ... x = self.features(inputs)
  64. ...
  65. ... if self.num_classes > 0:
  66. ... x = paddle.flatten(x, 1)
  67. ... x = self.fc(x)
  68. ... return x
  69. ...
  70. >>> lenet = LeNet()
  71. >>> params_info = paddle.summary(lenet, (1, 1, 28, 28)) # doctest: +NORMALIZE_WHITESPACE
  72. ---------------------------------------------------------------------------
  73. Layer (type) Input Shape Output Shape Param #
  74. ===========================================================================
  75. Conv2D-1 [[1, 1, 28, 28]] [1, 6, 28, 28] 60
  76. ReLU-1 [[1, 6, 28, 28]] [1, 6, 28, 28] 0
  77. MaxPool2D-1 [[1, 6, 28, 28]] [1, 6, 14, 14] 0
  78. Conv2D-2 [[1, 6, 14, 14]] [1, 16, 10, 10] 2,416
  79. ReLU-2 [[1, 16, 10, 10]] [1, 16, 10, 10] 0
  80. MaxPool2D-2 [[1, 16, 10, 10]] [1, 16, 5, 5] 0
  81. Linear-1 [[1, 400]] [1, 120] 48,120
  82. Linear-2 [[1, 120]] [1, 84] 10,164
  83. Linear-3 [[1, 84]] [1, 10] 850
  84. ===========================================================================
  85. Total params: 61,610
  86. Trainable params: 61,610
  87. Non-trainable params: 0
  88. ---------------------------------------------------------------------------
  89. Input size (MB): 0.00
  90. Forward/backward pass size (MB): 0.11
  91. Params size (MB): 0.24
  92. Estimated Total Size (MB): 0.35
  93. ---------------------------------------------------------------------------
  94. <BLANKLINE>
  95. >>> print(params_info)
  96. {'total_params': 61610, 'trainable_params': 61610}
  97. .. code-block:: python
  98. :name: code-example-2
  99. >>> # example 2: multi input demo
  100. >>> import paddle
  101. >>> import paddle.nn as nn
  102. >>> class LeNetMultiInput(nn.Layer):
  103. ... def __init__(self, num_classes=10):
  104. ... super().__init__()
  105. ... self.num_classes = num_classes
  106. ... self.features = nn.Sequential(
  107. ... nn.Conv2D(1, 6, 3, stride=1, padding=1),
  108. ... nn.ReLU(),
  109. ... nn.MaxPool2D(2, 2),
  110. ... nn.Conv2D(6, 16, 5, stride=1, padding=0),
  111. ... nn.ReLU(),
  112. ... nn.MaxPool2D(2, 2))
  113. ...
  114. ... if num_classes > 0:
  115. ... self.fc = nn.Sequential(
  116. ... nn.Linear(400, 120),
  117. ... nn.Linear(120, 84),
  118. ... nn.Linear(84, 10))
  119. ...
  120. ... def forward(self, inputs, y):
  121. ... x = self.features(inputs)
  122. ...
  123. ... if self.num_classes > 0:
  124. ... x = paddle.flatten(x, 1)
  125. ... x = self.fc(x + y)
  126. ... return x
  127. ...
  128. >>> lenet_multi_input = LeNetMultiInput()
  129. >>> params_info = paddle.summary(lenet_multi_input,
  130. ... [(1, 1, 28, 28), (1, 400)],
  131. ... dtypes=['float32', 'float32']) # doctest: +NORMALIZE_WHITESPACE
  132. ---------------------------------------------------------------------------
  133. Layer (type) Input Shape Output Shape Param #
  134. ===========================================================================
  135. Conv2D-1 [[1, 1, 28, 28]] [1, 6, 28, 28] 60
  136. ReLU-1 [[1, 6, 28, 28]] [1, 6, 28, 28] 0
  137. MaxPool2D-1 [[1, 6, 28, 28]] [1, 6, 14, 14] 0
  138. Conv2D-2 [[1, 6, 14, 14]] [1, 16, 10, 10] 2,416
  139. ReLU-2 [[1, 16, 10, 10]] [1, 16, 10, 10] 0
  140. MaxPool2D-2 [[1, 16, 10, 10]] [1, 16, 5, 5] 0
  141. Linear-1 [[1, 400]] [1, 120] 48,120
  142. Linear-2 [[1, 120]] [1, 84] 10,164
  143. Linear-3 [[1, 84]] [1, 10] 850
  144. ===========================================================================
  145. Total params: 61,610
  146. Trainable params: 61,610
  147. Non-trainable params: 0
  148. ---------------------------------------------------------------------------
  149. Input size (MB): 0.00
  150. Forward/backward pass size (MB): 0.11
  151. Params size (MB): 0.24
  152. Estimated Total Size (MB): 0.35
  153. ---------------------------------------------------------------------------
  154. <BLANKLINE>
  155. >>> print(params_info)
  156. {'total_params': 61610, 'trainable_params': 61610}
  157. .. code-block:: python
  158. :name: code-example-3
  159. >>> # example 3: List/Dict Input Demo
  160. >>> import paddle
  161. >>> import paddle.nn as nn
  162. >>> # list input demo
  163. >>> class LeNetListInput(nn.Layer):
  164. ... def __init__(self, num_classes=10):
  165. ... super().__init__()
  166. ... self.num_classes = num_classes
  167. ... self.features = nn.Sequential(
  168. ... nn.Conv2D(1, 6, 3, stride=1, padding=1),
  169. ... nn.ReLU(),
  170. ... nn.MaxPool2D(2, 2),
  171. ... nn.Conv2D(6, 16, 5, stride=1, padding=0),
  172. ... nn.ReLU(),
  173. ... nn.MaxPool2D(2, 2))
  174. ...
  175. ... if num_classes > 0:
  176. ... self.fc = nn.Sequential(
  177. ... nn.Linear(400, 120),
  178. ... nn.Linear(120, 84),
  179. ... nn.Linear(84, 10))
  180. ...
  181. ... def forward(self, inputs):
  182. ... x = self.features(inputs[0])
  183. ...
  184. ... if self.num_classes > 0:
  185. ... x = paddle.flatten(x, 1)
  186. ... x = self.fc(x + inputs[1])
  187. ... return x
  188. ...
  189. >>> lenet_list_input = LeNetListInput()
  190. >>> input_data = [paddle.rand([1, 1, 28, 28]), paddle.rand([1, 400])]
  191. >>> params_info = paddle.summary(lenet_list_input, input=input_data) # doctest: +NORMALIZE_WHITESPACE
  192. ---------------------------------------------------------------------------
  193. Layer (type) Input Shape Output Shape Param #
  194. ===========================================================================
  195. Conv2D-1 [[1, 1, 28, 28]] [1, 6, 28, 28] 60
  196. ReLU-1 [[1, 6, 28, 28]] [1, 6, 28, 28] 0
  197. MaxPool2D-1 [[1, 6, 28, 28]] [1, 6, 14, 14] 0
  198. Conv2D-2 [[1, 6, 14, 14]] [1, 16, 10, 10] 2,416
  199. ReLU-2 [[1, 16, 10, 10]] [1, 16, 10, 10] 0
  200. MaxPool2D-2 [[1, 16, 10, 10]] [1, 16, 5, 5] 0
  201. Linear-1 [[1, 400]] [1, 120] 48,120
  202. Linear-2 [[1, 120]] [1, 84] 10,164
  203. Linear-3 [[1, 84]] [1, 10] 850
  204. ===========================================================================
  205. Total params: 61,610
  206. Trainable params: 61,610
  207. Non-trainable params: 0
  208. ---------------------------------------------------------------------------
  209. Input size (MB): 0.00
  210. Forward/backward pass size (MB): 0.11
  211. Params size (MB): 0.24
  212. Estimated Total Size (MB): 0.35
  213. ---------------------------------------------------------------------------
  214. <BLANKLINE>
  215. >>> print(params_info)
  216. {'total_params': 61610, 'trainable_params': 61610}
  217. >>> # dict input demo
  218. >>> class LeNetDictInput(nn.Layer):
  219. ... def __init__(self, num_classes=10):
  220. ... super().__init__()
  221. ... self.num_classes = num_classes
  222. ... self.features = nn.Sequential(
  223. ... nn.Conv2D(1, 6, 3, stride=1, padding=1),
  224. ... nn.ReLU(),
  225. ... nn.MaxPool2D(2, 2),
  226. ... nn.Conv2D(6, 16, 5, stride=1, padding=0),
  227. ... nn.ReLU(),
  228. ... nn.MaxPool2D(2, 2))
  229. ...
  230. ... if num_classes > 0:
  231. ... self.fc = nn.Sequential(
  232. ... nn.Linear(400, 120),
  233. ... nn.Linear(120, 84),
  234. ... nn.Linear(84, 10))
  235. ...
  236. ... def forward(self, inputs):
  237. ... x = self.features(inputs['x1'])
  238. ...
  239. ... if self.num_classes > 0:
  240. ... x = paddle.flatten(x, 1)
  241. ... x = self.fc(x + inputs['x2'])
  242. ... return x
  243. ...
  244. >>> lenet_dict_input = LeNetDictInput()
  245. >>> input_data = {'x1': paddle.rand([1, 1, 28, 28]),
  246. ... 'x2': paddle.rand([1, 400])}
  247. >>> # The module suffix number indicates its sequence in modules of the same type, used for differentiation identification
  248. >>> params_info = paddle.summary(lenet_dict_input, input=input_data) # doctest: +NORMALIZE_WHITESPACE
  249. ---------------------------------------------------------------------------
  250. Layer (type) Input Shape Output Shape Param #
  251. ===========================================================================
  252. Conv2D-3 [[1, 1, 28, 28]] [1, 6, 28, 28] 60
  253. ReLU-3 [[1, 6, 28, 28]] [1, 6, 28, 28] 0
  254. MaxPool2D-3 [[1, 6, 28, 28]] [1, 6, 14, 14] 0
  255. Conv2D-4 [[1, 6, 14, 14]] [1, 16, 10, 10] 2,416
  256. ReLU-4 [[1, 16, 10, 10]] [1, 16, 10, 10] 0
  257. MaxPool2D-4 [[1, 16, 10, 10]] [1, 16, 5, 5] 0
  258. Linear-4 [[1, 400]] [1, 120] 48,120
  259. Linear-5 [[1, 120]] [1, 84] 10,164
  260. Linear-6 [[1, 84]] [1, 10] 850
  261. ===========================================================================
  262. Total params: 61,610
  263. Trainable params: 61,610
  264. Non-trainable params: 0
  265. ---------------------------------------------------------------------------
  266. Input size (MB): 0.00
  267. Forward/backward pass size (MB): 0.11
  268. Params size (MB): 0.24
  269. Estimated Total Size (MB): 0.35
  270. ---------------------------------------------------------------------------
  271. <BLANKLINE>
  272. >>> print(params_info)
  273. {'total_params': 61610, 'trainable_params': 61610}
  274. """
  275. if input_size is None and input is None:
  276. raise ValueError("input_size and input cannot be None at the same time")
  277. if input_size is None and input is not None:
  278. if paddle.is_tensor(input):
  279. input_size = tuple(input.shape)
  280. elif isinstance(input, (list, tuple)):
  281. input_size = []
  282. for x in input:
  283. input_size.append(tuple(x.shape))
  284. elif isinstance(input, dict):
  285. input_size = []
  286. for key in input.keys():
  287. input_size.append(tuple(input[key].shape))
  288. elif isinstance(input, paddle.base.framework.Variable):
  289. input_size = tuple(input.shape)
  290. else:
  291. raise ValueError(
  292. "Input is not tensor, list, tuple and dict, unable to determine input_size, please input input_size."
  293. )
  294. if isinstance(input_size, InputSpec):
  295. _input_size = tuple(input_size.shape)
  296. elif isinstance(input_size, list):
  297. _input_size = []
  298. for item in input_size:
  299. if isinstance(item, int):
  300. item = (item,)
  301. assert isinstance(
  302. item, (tuple, InputSpec)
  303. ), f'When input_size is list, \
  304. expect item in input_size is a tuple or InputSpec, but got {type(item)}'
  305. if isinstance(item, InputSpec):
  306. _input_size.append(tuple(item.shape))
  307. else:
  308. _input_size.append(item)
  309. elif isinstance(input_size, int):
  310. _input_size = (input_size,)
  311. else:
  312. _input_size = input_size
  313. if not paddle.in_dynamic_mode():
  314. warnings.warn(
  315. "Your model was created in static graph mode, this may not get correct summary information!"
  316. )
  317. in_train_mode = False
  318. else:
  319. in_train_mode = net.training
  320. if in_train_mode:
  321. net.eval()
  322. def _is_shape(shape):
  323. for item in shape:
  324. if isinstance(item, (list, tuple)):
  325. return False
  326. return True
  327. def _check_shape(shape):
  328. num_unknown = 0
  329. new_shape = []
  330. for i in range(len(shape)):
  331. item = shape[i]
  332. if item is None or item == -1:
  333. num_unknown += 1
  334. if num_unknown > 1:
  335. raise ValueError(
  336. 'Option input_size only the dim of batch_size can be None or -1.'
  337. )
  338. item = 1
  339. elif isinstance(item, numbers.Number):
  340. if item <= 0:
  341. raise ValueError(
  342. f"Expected element in input size greater than zero, but got {item}"
  343. )
  344. new_shape.append(item)
  345. return tuple(new_shape)
  346. def _check_input(input_size):
  347. if isinstance(input_size, (list, tuple)) and _is_shape(input_size):
  348. return _check_shape(input_size)
  349. else:
  350. return [_check_input(i) for i in input_size]
  351. _input_size = _check_input(_input_size)
  352. result, params_info = summary_string(net, _input_size, dtypes, input)
  353. print(result)
  354. if in_train_mode:
  355. net.train()
  356. return params_info
  357. @no_grad()
  358. def summary_string(model, input_size=None, dtypes=None, input=None):
  359. def _all_is_numper(items):
  360. for item in items:
  361. if not isinstance(item, numbers.Number):
  362. return False
  363. return True
  364. def _build_dtypes(input_size, dtype):
  365. if dtype is None:
  366. dtype = 'float32'
  367. if isinstance(input_size, (list, tuple)) and _all_is_numper(input_size):
  368. return [dtype]
  369. else:
  370. return [_build_dtypes(i, dtype) for i in input_size]
  371. if not isinstance(dtypes, (list, tuple)):
  372. dtypes = _build_dtypes(input_size, dtypes)
  373. batch_size = 1
  374. summary_str = ''
  375. depth = len(list(model.sublayers()))
  376. def _get_shape_from_tensor(x):
  377. if isinstance(x, (paddle.base.Variable, paddle.base.core.eager.Tensor)):
  378. return list(x.shape)
  379. elif isinstance(x, (list, tuple)):
  380. return [_get_shape_from_tensor(xx) for xx in x]
  381. def _get_output_shape(output):
  382. if isinstance(output, (list, tuple)):
  383. output_shape = [_get_output_shape(o) for o in output]
  384. elif hasattr(output, 'shape'):
  385. output_shape = list(output.shape)
  386. else:
  387. output_shape = []
  388. return output_shape
  389. def register_hook(layer):
  390. def hook(layer, input, output):
  391. class_name = str(layer.__class__).split(".")[-1].split("'")[0]
  392. try:
  393. layer_idx = int(layer._full_name.split('_')[-1])
  394. except:
  395. layer_idx = len(summary)
  396. m_key = "%s-%i" % (class_name, layer_idx + 1)
  397. summary[m_key] = OrderedDict()
  398. try:
  399. summary[m_key]["input_shape"] = _get_shape_from_tensor(input)
  400. except:
  401. warnings.warn('Get layer {} input shape failed!')
  402. summary[m_key]["input_shape"] = []
  403. try:
  404. summary[m_key]["output_shape"] = _get_output_shape(output)
  405. except:
  406. warnings.warn('Get layer {} output shape failed!')
  407. summary[m_key]["output_shape"]
  408. params = 0
  409. if paddle.in_dynamic_mode():
  410. layer_state_dict = layer._parameters
  411. else:
  412. layer_state_dict = layer.state_dict()
  413. summary[m_key]["trainable_params"] = 0
  414. trainable_flag = False
  415. for k, v in layer_state_dict.items():
  416. params += np.prod(v.shape)
  417. try:
  418. if (getattr(layer, k).trainable) and (
  419. not getattr(layer, k).stop_gradient
  420. ):
  421. summary[m_key]["trainable_params"] += np.prod(v.shape)
  422. summary[m_key]["trainable"] = True
  423. trainable_flag = True
  424. elif not trainable_flag:
  425. summary[m_key]["trainable"] = False
  426. except:
  427. summary[m_key]["trainable"] = True
  428. summary[m_key]["nb_params"] = params
  429. if (
  430. not isinstance(layer, nn.Sequential)
  431. and not isinstance(layer, nn.LayerList)
  432. and (not (layer == model) or depth < 1)
  433. ):
  434. hooks.append(layer.register_forward_post_hook(hook))
  435. # For rnn, gru and lstm layer
  436. elif hasattr(layer, 'could_use_cudnn') and layer.could_use_cudnn:
  437. hooks.append(layer.register_forward_post_hook(hook))
  438. if isinstance(input_size, tuple):
  439. input_size = [input_size]
  440. def build_input(input_size, dtypes):
  441. if isinstance(input_size, (list, tuple)) and _all_is_numper(input_size):
  442. if isinstance(dtypes, (list, tuple)):
  443. dtype = dtypes[0]
  444. else:
  445. dtype = dtypes
  446. return paddle.cast(paddle.rand(list(input_size)), dtype)
  447. else:
  448. return [
  449. build_input(i, dtype) for i, dtype in zip(input_size, dtypes)
  450. ]
  451. # create properties
  452. summary = OrderedDict()
  453. hooks = []
  454. # register hook
  455. model.apply(register_hook)
  456. if input is not None:
  457. x = input
  458. model(x)
  459. else:
  460. x = build_input(input_size, dtypes)
  461. # make a forward pass
  462. model(*x)
  463. # remove these hooks
  464. for h in hooks:
  465. h.remove()
  466. def _get_str_length(summary):
  467. head_length = {
  468. 'layer_width': 15,
  469. 'input_shape_width': 20,
  470. 'output_shape_width': 20,
  471. 'params_width': 15,
  472. 'table_width': 75,
  473. }
  474. for layer in summary:
  475. if head_length['output_shape_width'] < len(
  476. str(summary[layer]["output_shape"])
  477. ):
  478. head_length['output_shape_width'] = len(
  479. str(summary[layer]["output_shape"])
  480. )
  481. if head_length['input_shape_width'] < len(
  482. str(summary[layer]["input_shape"])
  483. ):
  484. head_length['input_shape_width'] = len(
  485. str(summary[layer]["input_shape"])
  486. )
  487. if head_length['layer_width'] < len(str(layer)):
  488. head_length['layer_width'] = len(str(layer))
  489. if head_length['params_width'] < len(
  490. str(summary[layer]["nb_params"])
  491. ):
  492. head_length['params_width'] = len(
  493. str(summary[layer]["nb_params"])
  494. )
  495. _temp_width = 0
  496. for k, v in head_length.items():
  497. if k != 'table_width':
  498. _temp_width += v
  499. if head_length['table_width'] < _temp_width + 5:
  500. head_length['table_width'] = _temp_width + 5
  501. return head_length
  502. table_width = _get_str_length(summary)
  503. summary_str += "-" * table_width['table_width'] + "\n"
  504. line_new = "{:^{}} {:^{}} {:^{}} {:^{}}".format(
  505. "Layer (type)",
  506. table_width['layer_width'],
  507. "Input Shape",
  508. table_width['input_shape_width'],
  509. "Output Shape",
  510. table_width['output_shape_width'],
  511. "Param #",
  512. table_width['params_width'],
  513. )
  514. summary_str += line_new + "\n"
  515. summary_str += "=" * table_width['table_width'] + "\n"
  516. total_params = 0
  517. total_output = 0
  518. trainable_params = 0
  519. max_length = 0
  520. for layer in summary:
  521. # input_shape, output_shape, trainable, nb_params
  522. line_new = "{:^{}} {:^{}} {:^{}} {:^{}}".format(
  523. layer,
  524. table_width['layer_width'],
  525. str(summary[layer]["input_shape"]),
  526. table_width['input_shape_width'],
  527. str(summary[layer]["output_shape"]),
  528. table_width['output_shape_width'],
  529. "{:,}".format(summary[layer]["nb_params"]),
  530. table_width['params_width'],
  531. )
  532. total_params += summary[layer]["nb_params"]
  533. try:
  534. total_output += np.sum(
  535. np.prod(summary[layer]["output_shape"], axis=-1)
  536. )
  537. except:
  538. for output_shape in summary[layer]["output_shape"]:
  539. total_output += np.sum(np.prod(output_shape, axis=-1))
  540. if "trainable" in summary[layer]:
  541. if summary[layer]["trainable"]:
  542. trainable_params += summary[layer]["trainable_params"]
  543. summary_str += line_new + "\n"
  544. def _get_input_size(input_size, size):
  545. if isinstance(input_size, (list, tuple)) and _all_is_numper(input_size):
  546. size = abs(np.prod(input_size) * 4.0 / (1024**2.0))
  547. else:
  548. size = sum([_get_input_size(i, size) for i in input_size])
  549. return size
  550. total_input_size = _get_input_size(input_size, 0)
  551. total_output_size = abs(
  552. 2.0 * total_output * 4.0 / (1024**2.0)
  553. ) # x2 for gradients
  554. total_params_size = abs(total_params * 4.0 / (1024**2.0))
  555. total_size = total_params_size + total_output_size + total_input_size
  556. summary_str += "=" * table_width['table_width'] + "\n"
  557. summary_str += f"Total params: {total_params:,}" + "\n"
  558. summary_str += f"Trainable params: {trainable_params:,}" + "\n"
  559. summary_str += (
  560. f"Non-trainable params: {total_params - trainable_params:,}" + "\n"
  561. )
  562. summary_str += "-" * table_width['table_width'] + "\n"
  563. summary_str += "Input size (MB): %0.2f" % total_input_size + "\n"
  564. summary_str += (
  565. "Forward/backward pass size (MB): %0.2f" % total_output_size + "\n"
  566. )
  567. summary_str += "Params size (MB): %0.2f" % total_params_size + "\n"
  568. summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n"
  569. summary_str += "-" * table_width['table_width'] + "\n"
  570. # return summary
  571. return summary_str, {
  572. 'total_params': total_params,
  573. 'trainable_params': trainable_params,
  574. }