| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import collections.abc
- import math
- import os.path as osp
- from itertools import repeat
- import numpy as np
- import torch
- from mmcls.datasets.base_dataset import BaseDataset
def get_trained_checkpoints_name(work_path):
    """Pick the checkpoint file to load from a training work directory.

    A "best" checkpoint is preferred: a ``.pth`` file whose name contains
    ``best_`` and whose last ``_``-separated field is a number (taken as the
    epoch). If several such files exist, the one with the highest epoch is
    returned. If there is none, the ``epoch_<N>.pth`` file with the largest
    positive ``N`` is returned instead.

    Args:
        work_path: Directory to scan (only its direct entries are examined).

    Returns:
        The file name (not the full path) of the selected checkpoint, or
        ``None`` when no matching file is found.
    """
    import os

    best_name, best_epoch = None, -1
    latest_name, latest_epoch = None, 0

    # os.listdir() returns entries in arbitrary order; sorting makes the
    # selection reproducible across file systems.
    for f_name in sorted(os.listdir(work_path)):
        if not f_name.endswith('.pth'):
            continue

        if 'best_' in f_name:
            suffix = f_name.replace('.pth', '').split('_')[-1]
            if suffix.isdigit() and int(suffix) > best_epoch:
                best_name, best_epoch = f_name, int(suffix)

        if 'epoch_' in f_name:
            epoch_num = f_name.replace('epoch_', '').replace('.pth', '')
            if epoch_num.isdigit() and int(epoch_num) > latest_epoch:
                latest_name, latest_epoch = f_name, int(epoch_num)

    # A best checkpoint always takes precedence over the latest epoch file.
    return best_name if best_name is not None else latest_name
def preprocess_transform(cfgs):
    """Convert list-valued ``size`` fields of ``Resize`` steps to tuples.

    The sequence is modified in place and then returned.

    Args:
        cfgs: Indexable sequence of transform configs exposing ``type`` (and,
            for ``Resize`` entries, ``size``) as attributes, or ``None``.

    Returns:
        The same ``cfgs`` object, or ``None`` when ``cfgs`` is ``None``.
    """
    if cfgs is None:
        return None

    for idx, step in enumerate(cfgs):
        needs_tuple = step.type == 'Resize' and isinstance(step.size, list)
        if needs_tuple:
            cfgs[idx].size = tuple(step.size)

    return cfgs
def get_ms_dataset_root(ms_dataset):
    """Derive the extracted data root from the first record of a dataset.

    The ``'image:FILE'`` path of the first record is split on the substring
    ``'extracted'``; the root is everything before it, followed by
    ``'extracted'`` and the next two path components.

    Args:
        ms_dataset: Indexable dataset whose records map ``'image:FILE'`` to a
            path string, or ``None``.

    Returns:
        The derived root path, or ``None`` when the dataset is ``None`` or
        empty.

    Raises:
        ValueError: If the root cannot be derived from the first record
            (for example the path has no ``'extracted'`` part or too few
            components after it). The original error is attached as the
            cause.
    """
    if ms_dataset is None or len(ms_dataset) < 1:
        return None

    try:
        head_and_tail = ms_dataset[0]['image:FILE'].split('extracted')
        data_root = head_and_tail[0]
        # The text after 'extracted' starts with '/', so index 0 of this
        # split is empty and the two wanted components sit at 1 and 2.
        path_post = head_and_tail[1].split('/')
        return osp.join(data_root, 'extracted', path_post[1], path_post[2])
    except Exception as e:
        raise ValueError(f'Dataset Error: {e}') from e
def get_classes(classes=None):
    """Resolve class names from a sequence or from a file path.

    Args:
        classes: Either a tuple/list of class names, which is returned
            unchanged, or a string taken as the path of a file to read with
            ``mmcv.list_from_file``.

    Returns:
        The class names: the input sequence itself, or the list read from
        the file.

    Raises:
        ValueError: If ``classes`` is neither a string nor a tuple/list
            (including the default ``None``).
    """
    if isinstance(classes, (tuple, list)):
        return classes

    if isinstance(classes, str):
        # Imported lazily: mmcv is only needed to read the file, so the
        # sequence and error paths work without it being importable.
        import mmcv
        return mmcv.list_from_file(classes)

    raise ValueError(f'Unsupported type {type(classes)} of classes.')
class MmDataset(BaseDataset):
    """``BaseDataset`` subclass whose samples come from an in-memory dataset.

    Each record of ``ms_dataset`` is read through the keys ``'image:FILE'``
    (image file name) and ``'category'`` (label).
    """

    def __init__(self,
                 ms_dataset,
                 pipeline,
                 classes=None,
                 test_mode=False,
                 data_prefix=''):
        """Store the source dataset and initialise the base dataset.

        Args:
            ms_dataset: Sized, iterable collection of records.
            pipeline: Forwarded to ``BaseDataset``.
            classes: Forwarded to ``BaseDataset``.
            test_mode: Forwarded to ``BaseDataset``.
            data_prefix: Forwarded to ``BaseDataset``.

        Raises:
            ValueError: If ``ms_dataset`` is empty.
        """
        # Set before the base initialiser runs; presumably the base class
        # calls load_annotations() during __init__ -- confirm against mmcls.
        self.ms_dataset = ms_dataset
        if len(self.ms_dataset) < 1:
            raise ValueError('Dataset Error: dataset is empty')
        super().__init__(
            data_prefix=data_prefix,
            pipeline=pipeline,
            classes=classes,
            test_mode=test_mode)

    def load_annotations(self):
        """Build one annotation dict per record of ``self.ms_dataset``.

        Returns:
            A list of dicts with the keys ``'img_prefix'``, ``'img_info'``
            (holding ``'filename'``) and ``'gt_label'`` (an int64 array).

        Raises:
            ValueError: If ``self.CLASSES`` is ``None``.
        """
        if self.CLASSES is None:
            raise ValueError(
                f'Dataset Error: Not found classesname.txt: {self.CLASSES}')

        return [{
            'img_prefix': self.data_prefix,
            'img_info': {
                'filename': record['image:FILE']
            },
            'gt_label': np.array(record['category'], dtype=np.int64),
        } for record in self.ms_dataset]
def _trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with truncated-normal samples.

    Called by ``trunc_normal_`` inside ``torch.no_grad()``.

    Args:
        tensor: Tensor to overwrite.
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.
        a: Lower truncation bound.
        b: Upper truncation bound.

    Returns:
        The same ``tensor`` object, with every value clamped to ``[a, b]``.
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    # Imported locally because the module does not import ``warnings``.
    import warnings

    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        # stacklevel=2 attributes the warning to the caller of this helper.
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    v = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [v, u], then translate to
    # [2v-1, 2u-1].
    tensor.uniform_(2 * v - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    """Fill ``tensor`` in place with truncated-normal values, without grad.

    Args:
        tensor: Tensor to overwrite.
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.
        a: Lower truncation bound.
        b: Upper truncation bound.

    Returns:
        The same ``tensor`` object, with values inside ``[a, b]``.
    """
    with torch.no_grad():
        return _trunc_normal_(tensor, mean, std, a, b)
- # From PyTorch internals
def _ntuple(n):
    """Return a converter that expands a single value into an ``n``-tuple.

    The converter returns non-string iterables unchanged (the very same
    object) and wraps any other value, strings included, in a tuple holding
    ``n`` references to it.
    """

    def parse(x):
        passes_through = isinstance(x, collections.abc.Iterable)
        if isinstance(x, str) or not passes_through:
            return tuple(repeat(x, n))
        return x

    return parse


# Ready-made converters for the common arities; ``to_ntuple`` exposes the
# factory itself for any other length.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
|