utils.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import collections.abc
  3. import math
  4. import os.path as osp
  5. from itertools import repeat
  6. import numpy as np
  7. import torch
  8. from mmcls.datasets.base_dataset import BaseDataset
  def get_trained_checkpoints_name(work_path):
      """Pick the checkpoint file name to use from a training work directory.

      A "best" checkpoint (a ``.pth`` file whose name contains ``best_`` and
      ends with ``_<epoch>``) is preferred. If several exist, the one with the
      highest epoch number is returned so the result does not depend on the
      arbitrary order of ``os.listdir``. Otherwise the ``epoch_<n>.pth`` file
      with the largest ``n`` is returned.

      Args:
          work_path: Directory that is listed for checkpoint files.

      Returns:
          The selected file name (not joined with ``work_path``), or ``None``
          when no matching checkpoint is found.
      """
      import os

      file_list = os.listdir(work_path)

      # Prefer the best model; among several keep the most recent epoch.
      best_name, best_epoch = None, -1
      for f_name in file_list:
          if 'best_' not in f_name or not f_name.endswith('.pth'):
              continue
          epoch_str = f_name.replace('.pth', '').split('_')[-1]
          if epoch_str.isdigit() and int(epoch_str) > best_epoch:
              best_name, best_epoch = f_name, int(epoch_str)
      if best_name is not None:
          return best_name

      # Otherwise fall back to the latest epoch checkpoint. The threshold
      # starts at 0, so ``epoch_0.pth`` is never selected (as before).
      latest_name, latest_epoch = None, 0
      for f_name in file_list:
          if 'epoch_' not in f_name or not f_name.endswith('.pth'):
              continue
          epoch_str = f_name.replace('epoch_', '').replace('.pth', '')
          if epoch_str.isdigit() and int(epoch_str) > latest_epoch:
              latest_name, latest_epoch = f_name, int(epoch_str)
      return latest_name
  def preprocess_transform(cfgs):
      """Convert list-valued ``size`` fields of ``Resize`` entries to tuples.

      Args:
          cfgs: Sequence of transform configs exposing ``type`` (and, for
              ``Resize`` entries, ``size``) as attributes, or ``None``.

      Returns:
          The same ``cfgs`` object, modified in place, or ``None`` when
          ``cfgs`` is ``None``.
      """
      if cfgs is None:
          return None

      for index, transform in enumerate(cfgs):
          if transform.type != 'Resize':
              continue
          # Only list sizes are converted; any other size value is kept.
          if isinstance(transform.size, list):
              cfgs[index].size = tuple(transform.size)

      return cfgs
  def get_ms_dataset_root(ms_dataset):
      """Derive the extracted data root from the first record of a dataset.

      The ``'image:FILE'`` path of the first record is split on the literal
      ``'extracted'``. The result is the part before it, joined with
      ``'extracted'`` and the next two path components that follow it.

      Args:
          ms_dataset: Indexable collection of records with an ``'image:FILE'``
              string entry, or ``None``.

      Returns:
          The derived root path, or ``None`` when ``ms_dataset`` is ``None``
          or empty.

      Raises:
          ValueError: If the first record cannot be parsed (missing key, no
              ``'extracted'`` marker, too few path components, ...). The
              original exception is attached as ``__cause__``.
      """
      if ms_dataset is None or len(ms_dataset) < 1:
          return None
      try:
          # Look up and split the path once instead of twice.
          pieces = ms_dataset[0]['image:FILE'].split('extracted')
          data_root = pieces[0]
          # pieces[1] starts with '/', so index 0 of the split is ''.
          path_post = pieces[1].split('/')
          return osp.join(data_root, 'extracted', path_post[1], path_post[2])
      except Exception as e:
          # Broad on purpose: callers get one error type for any malformed
          # record. Chain the cause so the root failure stays visible.
          raise ValueError(f'Dataset Error: {e}') from e
  def get_classes(classes=None):
      """Resolve class names from a file path or an explicit sequence.

      Args:
          classes: A ``str`` is taken as a file path and read with
              ``mmcv.list_from_file``. A ``tuple`` or ``list`` is returned
              unchanged.

      Returns:
          The class names.

      Raises:
          ValueError: If ``classes`` is neither a ``str``, ``tuple`` nor
              ``list`` (this includes the default ``None``).
      """
      if isinstance(classes, str):
          # Take it as a file path. mmcv is imported only here so that
          # list/tuple input and the error path do not require it.
          import mmcv
          return mmcv.list_from_file(classes)
      if isinstance(classes, (tuple, list)):
          return classes
      raise ValueError(f'Unsupported type {type(classes)} of classes.')
  class MmDataset(BaseDataset):
      """``BaseDataset`` subclass backed by an in-memory ``ms_dataset``.

      Each record of ``ms_dataset`` is read through its ``'image:FILE'`` and
      ``'category'`` entries when annotations are built.
      """

      def __init__(self,
                   ms_dataset,
                   pipeline,
                   classes=None,
                   test_mode=False,
                   data_prefix=''):
          """Store the dataset and delegate the rest to ``BaseDataset``.

          Raises:
              ValueError: If ``ms_dataset`` has no records.
          """
          # Must be set before the base initializer runs.
          # NOTE(review): presumably because BaseDataset.__init__ calls
          # load_annotations - confirm against mmcls.
          self.ms_dataset = ms_dataset
          if len(self.ms_dataset) < 1:
              raise ValueError('Dataset Error: dataset is empty')

          super().__init__(
              data_prefix=data_prefix,
              pipeline=pipeline,
              classes=classes,
              test_mode=test_mode)

      def load_annotations(self):
          """Build one annotation dict per record of ``self.ms_dataset``.

          Returns:
              A list of dicts with keys ``'img_prefix'``, ``'img_info'``
              (holding ``'filename'``) and ``'gt_label'`` (an ``int64``
              numpy array made from the record's ``'category'``).

          Raises:
              ValueError: If ``self.CLASSES`` is ``None``.
          """
          if self.CLASSES is None:
              raise ValueError(
                  f'Dataset Error: Not found classesname.txt: {self.CLASSES}')

          return [
              {
                  'img_prefix': self.data_prefix,
                  'img_info': {'filename': record['image:FILE']},
                  'gt_label': np.array(record['category'], dtype=np.int64),
              }
              for record in self.ms_dataset
          ]
  def _trunc_normal_(tensor, mean, std, a, b):
      """Fill ``tensor`` in place with truncated-normal samples in ``[a, b]``.

      Cut & paste from PyTorch official master until it's in a few official
      releases - RW. Method based on
      https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

      Args:
          tensor: Tensor that is overwritten in place.
          mean: Mean of the underlying normal distribution.
          std: Standard deviation of the underlying normal distribution.
          a: Lower truncation bound.
          b: Upper truncation bound.

      Returns:
          The same ``tensor`` object, after in-place modification.
      """
      # Function-scope import: ``warnings`` is used below but is not among
      # the module-level imports visible in this file, so the out-of-range
      # path used to raise NameError instead of warning.
      import warnings

      def norm_cdf(x):
          # Computes standard normal cumulative distribution function
          return (1. + math.erf(x / math.sqrt(2.))) / 2.

      if (mean < a - 2 * std) or (mean > b + 2 * std):
          warnings.warn(
              'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
              'The distribution of values may be incorrect.',
              stacklevel=2)

      # Values are generated by using a truncated uniform distribution and
      # then using the inverse CDF for the normal distribution.
      # Get upper and lower cdf values
      v = norm_cdf((a - mean) / std)
      u = norm_cdf((b - mean) / std)

      # Uniformly fill tensor with values from [v, u], then translate to
      # [2v-1, 2u-1].
      tensor.uniform_(2 * v - 1, 2 * u - 1)

      # Use inverse cdf transform for normal distribution to get truncated
      # standard normal
      tensor.erfinv_()

      # Transform to proper mean, std
      tensor.mul_(std * math.sqrt(2.))
      tensor.add_(mean)

      # Clamp to ensure it's in the proper range
      tensor.clamp_(min=a, max=b)
      return tensor
  def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
      # type: (Tensor, float, float, float, float) -> Tensor
      """Fill ``tensor`` in place via ``_trunc_normal_`` with autograd off.

      Args:
          tensor: Tensor to overwrite.
          mean: Mean passed through to ``_trunc_normal_``.
          std: Standard deviation passed through to ``_trunc_normal_``.
          a: Lower bound passed through to ``_trunc_normal_``.
          b: Upper bound passed through to ``_trunc_normal_``.

      Returns:
          Whatever ``_trunc_normal_`` returns for the given arguments.
      """
      # The fill runs under no_grad so it is not recorded by autograd.
      with torch.no_grad():
          filled = _trunc_normal_(tensor, mean, std, a, b)
      return filled
  126. # From PyTorch internals
  def _ntuple(n):
      """Return a converter that expands a scalar into an ``n``-tuple.

      The returned function hands back any non-string iterable unchanged
      (the same object, with no length check) and turns every other value
      ``x``, strings included, into a tuple holding ``x`` repeated ``n``
      times.
      """

      def parse(x):
          is_plain_value = isinstance(x, str) or not isinstance(
              x, collections.abc.Iterable)
          if is_plain_value:
              return tuple(repeat(x, n))
          return x

      return parse


  # Ready-made converters for the common sizes.
  to_1tuple = _ntuple(1)
  to_2tuple = _ntuple(2)
  to_3tuple = _ntuple(3)
  to_4tuple = _ntuple(4)
  # Factory alias: ``to_ntuple(n)`` builds a converter for any ``n``.
  to_ntuple = _ntuple