regress_test_utils.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import contextlib
  3. import hashlib
  4. import os
  5. import pickle
  6. import random
  7. import re
  8. import shutil
  9. import tempfile
  10. from collections import OrderedDict
  11. from collections.abc import Mapping
  12. from pathlib import Path
  13. from types import FunctionType
  14. from typing import Any, Dict, Union
  15. import json
  16. import numpy as np
  17. import torch
  18. import torch.optim
  19. from torch import nn
  20. from .test_utils import compare_arguments_nested
  21. class RegressTool:
  22. """This class is used to stop inference/training results from changing by some unaware affections by unittests.
  23. Firstly, run a baseline test to create a result file, then changes can be observed between
  24. the latest version and the baseline file.
  25. """
  26. def __init__(self,
  27. baseline: bool = None,
  28. store_func: FunctionType = None,
  29. load_func: FunctionType = None):
  30. """A func to store the baseline file and a func to load the baseline file.
  31. """
  32. self.baseline = baseline
  33. self.store_func = store_func
  34. self.load_func = load_func
  35. print(f'Current working dir is: {Path.cwd()}')
  36. def store(self, local, remote):
  37. if self.store_func is not None:
  38. self.store_func(local, remote)
  39. else:
  40. path = os.path.abspath(
  41. os.path.join(Path.cwd(), 'data', 'test', 'regression'))
  42. os.makedirs(path, exist_ok=True)
  43. shutil.copy(local, os.path.join(path, remote))
  44. def load(self, local, remote):
  45. if self.load_func is not None:
  46. self.load_func(local, remote)
  47. else:
  48. path = os.path.abspath(
  49. os.path.join(Path.cwd(), 'data', 'test', 'regression'))
  50. baseline = os.path.join(path, remote)
  51. if not os.path.exists(baseline):
  52. raise ValueError(f'base line file {baseline} not exist')
  53. print(
  54. f'local file found:{baseline}, md5:{hashlib.md5(open(baseline,"rb").read()).hexdigest()}'
  55. )
  56. if os.path.exists(local):
  57. os.remove(local)
  58. os.symlink(baseline, local, target_is_directory=False)
  59. @contextlib.contextmanager
  60. def monitor_module_single_forward(self,
  61. module: nn.Module,
  62. file_name: str,
  63. compare_fn=None,
  64. compare_model_output=True,
  65. **kwargs):
  66. """Monitor a pytorch module in a single forward.
  67. Args:
  68. module: A torch module
  69. file_name: The file_name to store or load file
  70. compare_fn: A custom fn used to compare the results manually.
  71. compare_model_output: Only compare the input module's output, skip all other tensors
  72. >>> def compare_fn(v1, v2, key, type):
  73. >>> return None
  74. v1 is the baseline value
  75. v2 is the value of current version
  76. key is the key of submodules
  77. type is in one of 'input', 'output'
  78. kwargs:
  79. atol: The absolute gap between two np arrays.
  80. rtol: The relative gap between two np arrays.
  81. """
  82. baseline = os.getenv('REGRESSION_BASELINE')
  83. if baseline is None or self.baseline is None:
  84. yield
  85. return
  86. baseline = self.baseline
  87. io_json = {}
  88. absolute_path = f'./{file_name}.bin'
  89. if not isinstance(module, nn.Module):
  90. assert hasattr(module, 'model')
  91. module = module.model
  92. hack_forward(module, file_name, io_json)
  93. intercept_module(module, io_json)
  94. yield
  95. hack_forward(module, None, None, restore=True)
  96. intercept_module(module, None, restore=True)
  97. if baseline:
  98. with open(absolute_path, 'wb') as f:
  99. pickle.dump(io_json, f)
  100. self.store(absolute_path, f'{file_name}.bin')
  101. os.remove(absolute_path)
  102. else:
  103. name = os.path.basename(absolute_path)
  104. baseline = os.path.join(tempfile.gettempdir(), name)
  105. self.load(baseline, name)
  106. with open(baseline, 'rb') as f:
  107. base = pickle.load(f)
  108. class SafeNumpyEncoder(json.JSONEncoder):
  109. def parse_default(self, obj):
  110. if isinstance(obj, np.ndarray):
  111. return obj.tolist()
  112. if isinstance(obj, np.floating):
  113. return float(obj)
  114. if isinstance(obj, np.integer):
  115. return int(obj)
  116. return json.JSONEncoder.default(self, obj)
  117. def default(self, obj):
  118. try:
  119. return self.default(obj)
  120. except Exception:
  121. print(
  122. f'Type {obj.__class__} cannot be serialized and printed'
  123. )
  124. return None
  125. if compare_model_output:
  126. print(
  127. 'Ignore inner modules, only the output of the model will be verified.'
  128. )
  129. base = {
  130. key: value
  131. for key, value in base.items() if key == file_name
  132. }
  133. for key, value in base.items():
  134. value['input'] = {'args': None, 'kwargs': None}
  135. io_json = {
  136. key: value
  137. for key, value in io_json.items() if key == file_name
  138. }
  139. for key, value in io_json.items():
  140. value['input'] = {'args': None, 'kwargs': None}
  141. print(f'baseline: {json.dumps(base, cls=SafeNumpyEncoder)}')
  142. print(f'latest : {json.dumps(io_json, cls=SafeNumpyEncoder)}')
  143. if not compare_io_and_print(base, io_json, compare_fn, **kwargs):
  144. raise ValueError('Result not match!')
  145. @contextlib.contextmanager
  146. def monitor_module_train(self,
  147. trainer: Union[Dict, Any],
  148. file_name,
  149. level='config',
  150. compare_fn=None,
  151. ignore_keys=None,
  152. compare_random=True,
  153. reset_dropout=True,
  154. lazy_stop_callback=None,
  155. **kwargs):
  156. """Monitor a pytorch module's backward data and cfg data within a step of the optimizer.
  157. This is usually useful when you try to change some dangerous code
  158. which has the risk of affecting the training loop.
  159. Args:
  160. trainer: A dict or an object contains the model/optimizer/lr_scheduler
  161. file_name: The file_name to store or load file
  162. level: The regression level.
  163. 'strict' for matching every single tensor.
  164. Please make sure the parameters of head are fixed
  165. and the drop-out rate is zero.
  166. 'config' for matching the initial config, like cfg file, optimizer param_groups,
  167. lr_scheduler params and the random seed.
  168. 'metric' for compare the best metrics in the evaluation loop.
  169. compare_fn: A custom fn used to compare the results manually.
  170. ignore_keys: The keys to ignore of the named_parameters.
  171. compare_random: If to compare random setttings, default True.
  172. reset_dropout: Reset all dropout modules to 0.0.
  173. lazy_stop_callback: A callback passed in, when the moniting is over, this callback will be called.
  174. kwargs:
  175. atol: The absolute gap between two np arrays.
  176. rtol: The relative gap between two np arrays.
  177. >>> def compare_fn(v1, v2, key, type):
  178. >>> return None
  179. v1 is the baseline value
  180. v2 is the value of current version
  181. key is the key of modules/parameters
  182. type is in one of 'input', 'output', 'backward', 'optimizer', 'lr_scheduler', 'cfg', 'state'
  183. """
  184. baseline = os.getenv('REGRESSION_BASELINE')
  185. if baseline is None or self.baseline is None:
  186. yield
  187. return
  188. baseline = self.baseline
  189. io_json = {}
  190. bw_json = {}
  191. absolute_path = f'./{file_name}.bin'
  192. if level == 'strict':
  193. print(
  194. "[Important] The level of regression is 'strict', please make sure your model's parameters are "
  195. 'fixed and all drop-out rates have been set to zero.')
  196. assert hasattr(
  197. trainer, 'model') or 'model' in trainer, 'model must be in trainer'
  198. module = trainer['model'] if isinstance(trainer,
  199. dict) else trainer.model
  200. if not isinstance(module, nn.Module):
  201. assert hasattr(module, 'model')
  202. module = module.model
  203. assert hasattr(
  204. trainer, 'optimizer'
  205. ) or 'optimizer' in trainer, 'optimizer must be in trainer'
  206. assert hasattr(
  207. trainer, 'lr_scheduler'
  208. ) or 'lr_scheduler' in trainer, 'lr_scheduler must be in trainer'
  209. optimizer: torch.optim.Optimizer = trainer['optimizer'] if isinstance(
  210. trainer, dict) else trainer.optimizer
  211. lr_scheduler: torch.optim.lr_scheduler._LRScheduler = trainer['lr_scheduler'] if isinstance(trainer, dict) \
  212. else trainer.lr_scheduler
  213. torch_state = numpify_tensor_nested(torch.get_rng_state())
  214. np_state = np.random.get_state()
  215. random_seed = random.getstate()
  216. seed = trainer._seed if hasattr(
  217. trainer,
  218. '_seed') else trainer.seed if hasattr(trainer, 'seed') else None
  219. if reset_dropout:
  220. with torch.no_grad():
  221. def reinit_dropout(_module):
  222. for name, submodule in _module.named_children():
  223. if isinstance(submodule, torch.nn.Dropout):
  224. setattr(_module, name, torch.nn.Dropout(0.))
  225. else:
  226. reinit_dropout(submodule)
  227. reinit_dropout(module)
  228. if level == 'strict':
  229. hack_forward(module, file_name, io_json)
  230. intercept_module(module, io_json)
  231. hack_backward(
  232. module, optimizer, bw_json, lazy_stop_callback=lazy_stop_callback)
  233. yield
  234. hack_backward(module, optimizer, None, restore=True)
  235. if level == 'strict':
  236. hack_forward(module, None, None, restore=True)
  237. intercept_module(module, None, restore=True)
  238. optimizer_dict = optimizer.state_dict()
  239. optimizer_dict.pop('state', None)
  240. summary = {
  241. 'forward': io_json,
  242. 'backward': bw_json,
  243. 'optimizer': {
  244. 'type': optimizer.__class__.__name__,
  245. 'defaults': optimizer.defaults,
  246. 'state_dict': optimizer_dict
  247. },
  248. 'lr_scheduler': {
  249. 'type': lr_scheduler.__class__.__name__,
  250. 'state_dict': lr_scheduler.state_dict()
  251. },
  252. 'cfg': trainer.cfg.to_dict() if hasattr(trainer, 'cfg') else None,
  253. 'state': {
  254. 'torch_state': torch_state,
  255. 'np_state': np_state,
  256. 'random_seed': random_seed,
  257. 'seed': seed,
  258. }
  259. }
  260. if baseline:
  261. with open(absolute_path, 'wb') as f:
  262. pickle.dump(summary, f)
  263. self.store(absolute_path, f'{file_name}.bin')
  264. os.remove(absolute_path)
  265. else:
  266. name = os.path.basename(absolute_path)
  267. baseline = os.path.join(tempfile.gettempdir(), name)
  268. self.load(baseline, name)
  269. with open(baseline, 'rb') as f:
  270. baseline_json = pickle.load(f)
  271. if level == 'strict' and not compare_io_and_print(
  272. baseline_json['forward'], io_json, compare_fn, **kwargs):
  273. raise RuntimeError('Forward not match!')
  274. if not compare_backward_and_print(
  275. baseline_json['backward'],
  276. bw_json,
  277. compare_fn=compare_fn,
  278. ignore_keys=ignore_keys,
  279. level=level,
  280. **kwargs):
  281. raise RuntimeError('Backward not match!')
  282. cfg_opt1 = {
  283. 'optimizer': baseline_json['optimizer'],
  284. 'lr_scheduler': baseline_json['lr_scheduler'],
  285. 'cfg': baseline_json['cfg'],
  286. 'state': None if not compare_random else baseline_json['state']
  287. }
  288. cfg_opt2 = {
  289. 'optimizer': summary['optimizer'],
  290. 'lr_scheduler': summary['lr_scheduler'],
  291. 'cfg': summary['cfg'],
  292. 'state': None if not compare_random else summary['state']
  293. }
  294. if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn,
  295. **kwargs):
  296. raise RuntimeError('Cfg or optimizers not match!')
  297. class MsRegressTool(RegressTool):
  298. class EarlyStopError(Exception):
  299. pass
  300. @contextlib.contextmanager
  301. def monitor_ms_train(self,
  302. trainer,
  303. file_name,
  304. level='config',
  305. compare_fn=None,
  306. ignore_keys=None,
  307. compare_random=True,
  308. lazy_stop_callback=None,
  309. **kwargs):
  310. if lazy_stop_callback is None:
  311. def lazy_stop_callback():
  312. class EarlyStopHook:
  313. PRIORITY = 90
  314. def before_run(self, trainer):
  315. pass
  316. def after_run(self, trainer):
  317. pass
  318. def before_epoch(self, trainer):
  319. pass
  320. def after_epoch(self, trainer):
  321. pass
  322. def before_iter(self, trainer):
  323. pass
  324. def before_train_epoch(self, trainer):
  325. self.before_epoch(trainer)
  326. def before_val_epoch(self, trainer):
  327. self.before_epoch(trainer)
  328. def after_train_epoch(self, trainer):
  329. self.after_epoch(trainer)
  330. def after_val_epoch(self, trainer):
  331. self.after_epoch(trainer)
  332. def before_train_iter(self, trainer):
  333. self.before_iter(trainer)
  334. def before_val_iter(self, trainer):
  335. self.before_iter(trainer)
  336. def after_train_iter(self, trainer):
  337. self.after_iter(trainer)
  338. def after_val_iter(self, trainer):
  339. self.after_iter(trainer)
  340. def every_n_epochs(self, trainer, n):
  341. return (trainer.epoch + 1) % n == 0 if n > 0 else False
  342. def every_n_inner_iters(self, runner, n):
  343. return (runner.inner_iter
  344. + 1) % n == 0 if n > 0 else False
  345. def every_n_iters(self, trainer, n):
  346. return (trainer.iter + 1) % n == 0 if n > 0 else False
  347. def end_of_epoch(self, trainer):
  348. return trainer.inner_iter + 1 == trainer.iters_per_epoch
  349. def is_last_epoch(self, trainer):
  350. return trainer.epoch + 1 == trainer.max_epochs
  351. def is_last_iter(self, trainer):
  352. return trainer.iter + 1 == trainer.max_iters
  353. def get_triggered_stages(self):
  354. return []
  355. def state_dict(self):
  356. return {}
  357. def load_state_dict(self, state_dict):
  358. pass
  359. def after_iter(self, trainer):
  360. raise MsRegressTool.EarlyStopError('Test finished.')
  361. trainer.register_hook(EarlyStopHook())
  362. def _train_loop(trainer, *args_train, **kwargs_train):
  363. with self.monitor_module_train(
  364. trainer,
  365. file_name,
  366. level,
  367. compare_fn=compare_fn,
  368. ignore_keys=ignore_keys,
  369. compare_random=compare_random,
  370. lazy_stop_callback=lazy_stop_callback,
  371. **kwargs):
  372. try:
  373. return trainer.train_loop_origin(*args_train,
  374. **kwargs_train)
  375. except MsRegressTool.EarlyStopError:
  376. pass
  377. trainer.train_loop_origin, trainer.train_loop = \
  378. trainer.train_loop, type(trainer.train_loop)(_train_loop, trainer)
  379. yield
  380. def compare_module(module1: nn.Module, module2: nn.Module):
  381. for p1, p2 in zip(module1.parameters(), module2.parameters()):
  382. if p1.data.ne(p2.data).sum() > 0:
  383. return False
  384. return True
  385. def numpify_tensor_nested(tensors, reduction=None, clip_value=10000):
  386. try:
  387. from modelscope.outputs import ModelOutputBase
  388. except ImportError:
  389. ModelOutputBase = dict
  390. "Numpify `tensors` (even if it's a nested list/tuple of tensors)."
  391. if isinstance(tensors, (Mapping, ModelOutputBase)):
  392. return OrderedDict({
  393. k: numpify_tensor_nested(t, reduction, clip_value)
  394. for k, t in tensors.items()
  395. })
  396. if isinstance(tensors, list):
  397. return list(
  398. numpify_tensor_nested(t, reduction, clip_value) for t in tensors)
  399. if isinstance(tensors, tuple):
  400. return tuple(
  401. numpify_tensor_nested(t, reduction, clip_value) for t in tensors)
  402. if isinstance(tensors, torch.Tensor):
  403. t: np.ndarray = tensors.cpu().numpy()
  404. if clip_value is not None:
  405. t = np.where(t > clip_value, clip_value, t)
  406. t = np.where(t < -clip_value, -clip_value, t)
  407. if reduction == 'sum':
  408. return t.sum(dtype=float)
  409. elif reduction == 'mean':
  410. return t.mean(dtype=float)
  411. return t
  412. return tensors
  413. def detach_tensor_nested(tensors):
  414. try:
  415. from modelscope.outputs import ModelOutputBase
  416. except ImportError:
  417. ModelOutputBase = dict
  418. "Detach `tensors` (even if it's a nested list/tuple of tensors)."
  419. if isinstance(tensors, (Mapping, ModelOutputBase)):
  420. return OrderedDict(
  421. {k: detach_tensor_nested(t)
  422. for k, t in tensors.items()})
  423. if isinstance(tensors, list):
  424. return list(detach_tensor_nested(t) for t in tensors)
  425. if isinstance(tensors, tuple):
  426. return tuple(detach_tensor_nested(t) for t in tensors)
  427. if isinstance(tensors, torch.Tensor):
  428. return tensors.detach()
  429. return tensors
  430. def hack_forward(module: nn.Module,
  431. name,
  432. io_json,
  433. restore=False,
  434. keep_tensors=False):
  435. def _forward(self, *args, **kwargs):
  436. ret = self.forward_origin(*args, **kwargs)
  437. if keep_tensors:
  438. args = numpify_tensor_nested(detach_tensor_nested(args))
  439. kwargs = numpify_tensor_nested(detach_tensor_nested(kwargs))
  440. output = numpify_tensor_nested(detach_tensor_nested(ret))
  441. else:
  442. args = {
  443. 'sum':
  444. numpify_tensor_nested(
  445. detach_tensor_nested(args), reduction='sum'),
  446. 'mean':
  447. numpify_tensor_nested(
  448. detach_tensor_nested(args), reduction='mean'),
  449. }
  450. kwargs = {
  451. 'sum':
  452. numpify_tensor_nested(
  453. detach_tensor_nested(kwargs), reduction='sum'),
  454. 'mean':
  455. numpify_tensor_nested(
  456. detach_tensor_nested(kwargs), reduction='mean'),
  457. }
  458. output = {
  459. 'sum':
  460. numpify_tensor_nested(
  461. detach_tensor_nested(ret), reduction='sum'),
  462. 'mean':
  463. numpify_tensor_nested(
  464. detach_tensor_nested(ret), reduction='mean'),
  465. }
  466. io_json[name] = {
  467. 'input': {
  468. 'args': args,
  469. 'kwargs': kwargs,
  470. },
  471. 'output': output,
  472. }
  473. return ret
  474. if not restore and not hasattr(module, 'forward_origin'):
  475. module.forward_origin, module.forward = module.forward, type(
  476. module.forward)(_forward, module)
  477. if restore and hasattr(module, 'forward_origin'):
  478. module.forward = module.forward_origin
  479. del module.forward_origin
  480. def hack_backward(module: nn.Module,
  481. optimizer,
  482. io_json,
  483. restore=False,
  484. lazy_stop_callback=None):
  485. def _step(self, *args, **kwargs):
  486. for name, param in module.named_parameters():
  487. io_json[name] = {
  488. 'data': {
  489. 'sum':
  490. numpify_tensor_nested(
  491. detach_tensor_nested(param.data), reduction='sum'),
  492. 'mean':
  493. numpify_tensor_nested(
  494. detach_tensor_nested(param.data), reduction='mean'),
  495. },
  496. 'grad': {
  497. 'sum':
  498. numpify_tensor_nested(
  499. detach_tensor_nested(param.grad), reduction='sum'),
  500. 'mean':
  501. numpify_tensor_nested(
  502. detach_tensor_nested(param.grad), reduction='mean'),
  503. }
  504. }
  505. ret = self.step_origin(*args, **kwargs)
  506. for name, param in module.named_parameters():
  507. io_json[name]['data_after'] = {
  508. 'sum':
  509. numpify_tensor_nested(
  510. detach_tensor_nested(param.data), reduction='sum'),
  511. 'mean':
  512. numpify_tensor_nested(
  513. detach_tensor_nested(param.data), reduction='mean'),
  514. }
  515. if lazy_stop_callback is not None:
  516. lazy_stop_callback()
  517. return ret
  518. if not restore and not hasattr(optimizer, 'step_origin'):
  519. optimizer.step_origin, optimizer.step = optimizer.step, type(
  520. optimizer.state_dict)(_step, optimizer)
  521. if restore and hasattr(optimizer, 'step_origin'):
  522. optimizer.step = optimizer.step_origin
  523. del optimizer.step_origin
  524. def intercept_module(module: nn.Module,
  525. io_json,
  526. parent_name=None,
  527. restore=False):
  528. for name, module in module.named_children():
  529. full_name = parent_name + '.' + name if parent_name is not None else name
  530. hack_forward(module, full_name, io_json, restore)
  531. intercept_module(module, io_json, full_name, restore)
  532. def compare_io_and_print(baseline_json, io_json, compare_fn=None, **kwargs):
  533. if compare_fn is None:
  534. def compare_fn(*args, **kwargs):
  535. return None
  536. keys1 = set(baseline_json.keys())
  537. keys2 = set(io_json.keys())
  538. added = keys1 - keys2
  539. removed = keys2 - keys1
  540. print(f'unmatched keys: {added}, {removed}')
  541. shared_keys = keys1.intersection(keys2)
  542. match = True
  543. for key in shared_keys:
  544. v1 = baseline_json[key]
  545. v2 = io_json[key]
  546. v1input = numpify_tensor_nested(v1['input'])
  547. v2input = numpify_tensor_nested(v2['input'])
  548. res = compare_fn(v1input, v2input, key, 'input')
  549. if res is not None:
  550. print(
  551. f'input of {key} compared with user compare_fn with result:{res}\n'
  552. )
  553. match = match and res
  554. else:
  555. match = compare_arguments_nested(
  556. f'unmatched module {key} input args', v1input['args'],
  557. v2input['args'], **kwargs) and match
  558. match = compare_arguments_nested(
  559. f'unmatched module {key} input kwargs', v1input['kwargs'],
  560. v2input['kwargs'], **kwargs) and match
  561. v1output = numpify_tensor_nested(v1['output'])
  562. v2output = numpify_tensor_nested(v2['output'])
  563. res = compare_fn(v1output, v2output, key, 'output')
  564. if res is not None:
  565. print(
  566. f'output of {key} compared with user compare_fn with result:{res}\n'
  567. )
  568. match = match and res
  569. else:
  570. match = compare_arguments_nested(
  571. f'unmatched module {key} outputs',
  572. arg1=v1output,
  573. arg2=v2output,
  574. **kwargs) and match
  575. return match
  576. def compare_backward_and_print(baseline_json,
  577. bw_json,
  578. level,
  579. ignore_keys=None,
  580. compare_fn=None,
  581. **kwargs):
  582. if compare_fn is None:
  583. def compare_fn(*args, **kwargs):
  584. return None
  585. keys1 = set(baseline_json.keys())
  586. keys2 = set(bw_json.keys())
  587. added = keys1 - keys2
  588. removed = keys2 - keys1
  589. print(f'unmatched backward keys: {added}, {removed}')
  590. shared_keys = keys1.intersection(keys2)
  591. match = True
  592. for key in shared_keys:
  593. if ignore_keys is not None and key in ignore_keys:
  594. continue
  595. res = compare_fn(baseline_json[key], bw_json[key], key, 'backward')
  596. if res is not None:
  597. print(f'backward data of {key} compared with '
  598. f'user compare_fn with result:{res}\n')
  599. match = match and res
  600. else:
  601. data1, grad1, data_after1 = baseline_json[key][
  602. 'data'], baseline_json[key]['grad'], baseline_json[key][
  603. 'data_after']
  604. data2, grad2, data_after2 = bw_json[key]['data'], bw_json[key][
  605. 'grad'], bw_json[key]['data_after']
  606. match = compare_arguments_nested(
  607. f'unmatched module {key} tensor data',
  608. arg1=data1,
  609. arg2=data2,
  610. **kwargs) and match
  611. if level == 'strict':
  612. match = compare_arguments_nested(
  613. f'unmatched module {key} grad data',
  614. arg1=grad1,
  615. arg2=grad2,
  616. **kwargs) and match
  617. match = compare_arguments_nested(
  618. f'unmatched module {key} data after step', data_after1,
  619. data_after2, **kwargs) and match
  620. return match
  621. def compare_cfg_and_optimizers(baseline_json,
  622. cfg_json,
  623. compare_fn=None,
  624. **kwargs):
  625. if compare_fn is None:
  626. def compare_fn(*args, **kwargs):
  627. return None
  628. optimizer1, lr_scheduler1, cfg1, state1 = baseline_json[
  629. 'optimizer'], baseline_json['lr_scheduler'], baseline_json[
  630. 'cfg'], baseline_json['state']
  631. optimizer2, lr_scheduler2, cfg2, state2 = cfg_json['optimizer'], cfg_json[
  632. 'lr_scheduler'], cfg_json['cfg'], baseline_json['state']
  633. match = True
  634. res = compare_fn(optimizer1, optimizer2, None, 'optimizer')
  635. if res is not None:
  636. print(f'optimizer compared with user compare_fn with result:{res}\n')
  637. match = match and res
  638. else:
  639. if optimizer1['type'] != optimizer2['type']:
  640. print(
  641. f"Optimizer type not equal:{optimizer1['type']} and {optimizer2['type']}"
  642. )
  643. match = compare_arguments_nested(
  644. 'unmatched optimizer defaults', optimizer1['defaults'],
  645. optimizer2['defaults'], **kwargs) and match
  646. match = compare_arguments_nested(
  647. 'unmatched optimizer state_dict', optimizer1['state_dict'],
  648. optimizer2['state_dict'], **kwargs) and match
  649. res = compare_fn(lr_scheduler1, lr_scheduler2, None, 'lr_scheduler')
  650. if res is not None:
  651. print(
  652. f'lr_scheduler compared with user compare_fn with result:{res}\n')
  653. match = match and res
  654. else:
  655. if lr_scheduler1['type'] != lr_scheduler2['type']:
  656. print(
  657. f"Optimizer type not equal:{lr_scheduler1['type']} and {lr_scheduler2['type']}"
  658. )
  659. match = compare_arguments_nested(
  660. 'unmatched lr_scheduler state_dict', lr_scheduler1['state_dict'],
  661. lr_scheduler2['state_dict'], **kwargs) and match
  662. res = compare_fn(cfg1, cfg2, None, 'cfg')
  663. if res is not None:
  664. print(f'cfg compared with user compare_fn with result:{res}\n')
  665. match = match and res
  666. else:
  667. match = compare_arguments_nested(
  668. 'unmatched cfg', arg1=cfg1, arg2=cfg2, **kwargs) and match
  669. res = compare_fn(state1, state2, None, 'state')
  670. if res is not None:
  671. print(
  672. f'random state compared with user compare_fn with result:{res}\n')
  673. match = match and res
  674. else:
  675. match = compare_arguments_nested('unmatched random state', state1,
  676. state2, **kwargs) and match
  677. return match
  678. class IgnoreKeyFn:
  679. def __init__(self, keys):
  680. if isinstance(keys, str):
  681. keys = [keys]
  682. self.keys = keys if isinstance(keys, list) else []
  683. def __call__(self, v1output, v2output, key, type):
  684. for _key in self.keys:
  685. pattern = re.compile(_key)
  686. if key is not None and pattern.fullmatch(key):
  687. return True
  688. return None