# torchsummary.py
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
  6. def summary(model, input_size, batch_size=-1, device="cuda"):
  7. def register_hook(module):
  8. def hook(module, input, output):
  9. class_name = str(module.__class__).split(".")[-1].split("'")[0]
  10. module_idx = len(summary)
  11. m_key = "%s-%i" % (class_name, module_idx + 1)
  12. summary[m_key] = OrderedDict()
  13. summary[m_key]["input_shape"] = list(input[0].size())
  14. summary[m_key]["input_shape"][0] = batch_size
  15. if isinstance(output, (list, tuple)):
  16. summary[m_key]["output_shape"] = [
  17. [-1] + list(o.size())[1:] for o in output
  18. ]
  19. else:
  20. summary[m_key]["output_shape"] = list(output.size())
  21. summary[m_key]["output_shape"][0] = batch_size
  22. params = 0
  23. if hasattr(module, "weight") and hasattr(module.weight, "size"):
  24. params += torch.prod(torch.LongTensor(list(module.weight.size())))
  25. summary[m_key]["trainable"] = module.weight.requires_grad
  26. if hasattr(module, "bias") and hasattr(module.bias, "size"):
  27. params += torch.prod(torch.LongTensor(list(module.bias.size())))
  28. summary[m_key]["nb_params"] = params
  29. if (
  30. not isinstance(module, nn.Sequential)
  31. and not isinstance(module, nn.ModuleList)
  32. and not (module == model)
  33. ):
  34. hooks.append(module.register_forward_hook(hook))
  35. device = device.lower()
  36. assert device in [
  37. "cuda",
  38. "cpu",
  39. ], "Input device is not valid, please specify 'cuda' or 'cpu'"
  40. if device == "cuda" and torch.cuda.is_available():
  41. dtype = torch.cuda.FloatTensor
  42. else:
  43. dtype = torch.FloatTensor
  44. # multiple inputs to the network
  45. if isinstance(input_size, tuple):
  46. input_size = [input_size]
  47. # batch_size of 2 for batchnorm
  48. x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
  49. # print(type(x[0]))
  50. # create properties
  51. summary = OrderedDict()
  52. hooks = []
  53. # register hook
  54. model.apply(register_hook)
  55. # make a forward pass
  56. # print(x.shape)
  57. model(*x)
  58. # remove these hooks
  59. for h in hooks:
  60. h.remove()
  61. print("----------------------------------------------------------------")
  62. line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
  63. print(line_new)
  64. print("================================================================")
  65. total_params = 0
  66. total_output = 0
  67. trainable_params = 0
  68. for layer in summary:
  69. # input_shape, output_shape, trainable, nb_params
  70. line_new = "{:>20} {:>25} {:>15}".format(
  71. layer,
  72. str(summary[layer]["output_shape"]),
  73. "{0:,}".format(summary[layer]["nb_params"]),
  74. )
  75. total_params += summary[layer]["nb_params"]
  76. total_output += np.prod(summary[layer]["output_shape"])
  77. if "trainable" in summary[layer]:
  78. if summary[layer]["trainable"] == True:
  79. trainable_params += summary[layer]["nb_params"]
  80. print(line_new)
  81. # assume 4 bytes/number (float on cuda).
  82. total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
  83. total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients
  84. total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
  85. total_size = total_params_size + total_output_size + total_input_size
  86. print("================================================================")
  87. print("Total params: {0:,}".format(total_params))
  88. print("Trainable params: {0:,}".format(trainable_params))
  89. print("Non-trainable params: {0:,}".format(total_params - trainable_params))
  90. print("----------------------------------------------------------------")
  91. print("Input size (MB): %0.2f" % total_input_size)
  92. print("Forward/backward pass size (MB): %0.2f" % total_output_size)
  93. print("Params size (MB): %0.2f" % total_params_size)
  94. print("Estimated Total Size (MB): %0.2f" % total_size)
  95. print("----------------------------------------------------------------")
  96. # return summary