_stats.py 1014 B

123456789101112131415161718192021222324252627282930
  1. # NOTE! PLEASE KEEP THIS FILE *FREE* OF TORCH DEPS! IT SHOULD BE IMPORTABLE ANYWHERE.
  2. # IF YOU FEEL AN OVERWHELMING URGE TO ADD A TORCH DEP, MAKE A TRAMPOLINE FILE A LA torch._dynamo.utils
  3. # AND SCRUB AWAY TORCH NOTIONS THERE.
  4. import collections
  5. import functools
  6. from collections import OrderedDict
  7. from typing import Callable, TypeVar
  8. from typing_extensions import ParamSpec
  9. simple_call_counter: OrderedDict[str, int] = collections.OrderedDict()
  10. _P = ParamSpec("_P")
  11. _R = TypeVar("_R")
  12. def count_label(label: str) -> None:
  13. prev = simple_call_counter.setdefault(label, 0)
  14. simple_call_counter[label] = prev + 1
  15. def count(fn: Callable[_P, _R]) -> Callable[_P, _R]:
  16. @functools.wraps(fn)
  17. def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
  18. if fn.__qualname__ not in simple_call_counter:
  19. simple_call_counter[fn.__qualname__] = 0
  20. simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
  21. return fn(*args, **kwargs)
  22. return wrapper