# Copyright (c) Alibaba, Inc. and its affiliates.
import collections.abc
import math
import os
import os.path as osp
import warnings
from itertools import repeat

import numpy as np
import torch
from mmcls.datasets.base_dataset import BaseDataset


def get_trained_checkpoints_name(work_path):
    """Return the file name of the checkpoint to load from ``work_path``.

    A "best" checkpoint is preferred: a ``.pth`` file whose name contains
    ``best_`` and whose last ``_``-separated field is a number (for example
    ``best_accuracy_top-1_epoch_12.pth``). If several exist, the one with the
    highest number wins. Otherwise the ``epoch_<N>.pth`` file with the
    highest ``N`` is returned.

    Args:
        work_path (str): Directory to scan (non-recursive).

    Returns:
        str | None: A file name relative to ``work_path``, or ``None`` when
        no matching checkpoint is present.
    """
    file_list = os.listdir(work_path)

    # os.listdir() order is arbitrary, so rank candidates explicitly instead
    # of returning the first match; ties are broken by file name.
    best_candidates = []
    for f_name in file_list:
        if 'best_' in f_name and f_name.endswith('.pth'):
            suffix = f_name.replace('.pth', '').split('_')[-1]
            # isdecimal() (unlike isdigit()) only accepts what int() parses.
            if suffix.isdecimal():
                best_candidates.append((int(suffix), f_name))
    if best_candidates:
        return max(best_candidates)[1]

    # Fall back to the latest periodic checkpoint, e.g. "epoch_20.pth".
    epoch_candidates = []
    for f_name in file_list:
        if 'epoch_' in f_name and f_name.endswith('.pth'):
            epoch_num = f_name.replace('epoch_', '').replace('.pth', '')
            if epoch_num.isdecimal():
                epoch_candidates.append((int(epoch_num), f_name))
    if epoch_candidates:
        return max(epoch_candidates)[1]
    return None


def preprocess_transform(cfgs):
    """Convert list-valued ``Resize`` sizes to tuples, in place.

    Args:
        cfgs (list | None): Pipeline step configs with attribute access
            (``cfg.type``, ``cfg.size``) — presumably mmcv ``ConfigDict``s;
            confirm against callers.

    Returns:
        The same ``cfgs`` object after mutation, or ``None`` if ``cfgs`` is
        ``None``.
    """
    if cfgs is None:
        return None
    for cfg in cfgs:
        # NOTE(review): presumably the downstream Resize transform requires
        # a tuple rather than a list for ``size`` — confirm.
        if cfg.type == 'Resize' and isinstance(cfg.size, list):
            cfg.size = tuple(cfg.size)
    return cfgs


def get_ms_dataset_root(ms_dataset):
    """Derive the extracted-data directory from the first dataset record.

    The path in ``ms_dataset[0]['image:FILE']`` is split at the first
    ``extracted`` occurrence and the two following path components are kept,
    e.g. ``/cache/extracted/a/b/img.jpg`` -> ``/cache/extracted/a/b``.

    Args:
        ms_dataset: Indexable, sized collection of dict-like records.

    Returns:
        str | None: The directory, or ``None`` for a ``None``/empty dataset.

    Raises:
        ValueError: If the first record's path cannot be parsed this way
            (missing key, no ``extracted`` segment, too few components, ...).
    """
    if ms_dataset is None or len(ms_dataset) < 1:
        return None
    try:
        parts = ms_dataset[0]['image:FILE'].split('extracted')
        data_root = parts[0]
        path_post = parts[1].split('/')
        extracted_data_root = osp.join(data_root, 'extracted', path_post[1],
                                       path_post[2])
    except Exception as e:
        # Keep the original error attached for debugging.
        raise ValueError(f'Dataset Error: {e}') from e
    return extracted_data_root


def get_classes(classes=None):
    """Resolve class names from a file path or an explicit sequence.

    Args:
        classes (str | tuple | list): Path to a text file read with
            ``mmcv.list_from_file``, or the class names themselves.

    Returns:
        list | tuple: Class names; a tuple/list argument is returned as is.

    Raises:
        ValueError: For any other type, including ``None``.
    """
    if isinstance(classes, str):
        # mmcv is only needed to read the file, so import it lazily here.
        import mmcv
        # take it as a file path
        return mmcv.list_from_file(classes)
    if isinstance(classes, (tuple, list)):
        return classes
    raise ValueError(f'Unsupported type {type(classes)} of classes.')


class MmDataset(BaseDataset):
    """mmcls classification dataset built from a sequence of records.

    Each record must provide an ``'image:FILE'`` image path and an
    integer-like ``'category'`` label.
    """

    def __init__(self,
                 ms_dataset,
                 pipeline,
                 classes=None,
                 test_mode=False,
                 data_prefix=''):
        # Assigned before the base constructor, which presumably calls
        # load_annotations() (that method reads self.ms_dataset) — confirm
        # against the installed mmcls version.
        self.ms_dataset = ms_dataset
        if self.ms_dataset is None or len(self.ms_dataset) < 1:
            raise ValueError('Dataset Error: dataset is empty')
        super().__init__(
            data_prefix=data_prefix,
            pipeline=pipeline,
            classes=classes,
            test_mode=test_mode)

    def load_annotations(self):
        """Build one info dict per record for the data pipeline.

        Returns:
            list[dict]: Dicts with keys ``img_prefix``, ``img_info``
            (``{'filename': ...}``) and ``gt_label`` (``np.int64`` array).

        Raises:
            ValueError: If ``self.CLASSES`` is ``None``.
        """
        if self.CLASSES is None:
            raise ValueError(
                f'Dataset Error: Not found classesname.txt: {self.CLASSES}')
        return [{
            'img_prefix': self.data_prefix,
            'img_info': {
                'filename': data_info['image:FILE']
            },
            'gt_label': np.array(data_info['category'], dtype=np.int64),
        } for data_info in self.ms_dataset]


def _trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with N(mean, std^2) samples limited to [a, b].

    Callers are expected to disable autograd (see ``trunc_normal_``).
    Returns the same ``tensor``.
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    v = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [v, u], then translate to
    # [2v-1, 2u-1].
    tensor.uniform_(2 * v - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    """Fill ``tensor`` in place from a normal distribution truncated to [a, b].

    Runs under ``torch.no_grad()`` so the fill is not recorded by autograd.

    Args:
        tensor: Tensor to overwrite.
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.
        a: Lower cutoff.
        b: Upper cutoff.

    Returns:
        The same ``tensor``.
    """
    with torch.no_grad():
        return _trunc_normal_(tensor, mean, std, a, b)


# From PyTorch internals
def _ntuple(n):
    """Return a converter that repeats a scalar ``n`` times as a tuple.

    Non-string iterables are passed through unchanged (not length-checked).
    """

    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple