helpers.py 1.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. """ Layer/Module Helpers
  2. Hacked together by / Copyright 2020 Ross Wightman
  3. """
  4. from itertools import repeat
  5. import collections.abc
  6. # From PyTorch internals
  7. def _ntuple(n):
  8. def parse(x):
  9. if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
  10. return tuple(x)
  11. return tuple(repeat(x, n))
  12. return parse
  # Fixed-arity converters built from _ntuple: to_Ktuple(x) repeats a scalar
  # K times and passes non-str iterables through as a tuple.
  to_1tuple, to_2tuple, to_3tuple, to_4tuple = map(_ntuple, range(1, 5))
  # Factory alias: to_ntuple(n) returns a converter for arbitrary arity n.
  to_ntuple = _ntuple
  def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
      """Round ``v`` to a multiple of ``divisor`` without shrinking it too much.

      Steps:

      1. Round ``v`` to the nearest multiple of ``divisor`` using
         ``int(v + divisor / 2) // divisor * divisor``.
      2. Clamp the result from below at ``min_value``. A falsy ``min_value``
         (``None`` or ``0``) falls back to ``divisor``.
      3. If the result is below ``round_limit * v`` (with the default ``.9``,
         rounding lost more than 10% of ``v``), add one more ``divisor``.

      Args:
          v: Value to round.
          divisor: Step the result is rounded to.
          min_value: Lower bound for the result; falsy means ``divisor``.
          round_limit: Fraction of ``v`` the result must not fall below.

      Returns:
          The adjusted value.
      """
      lower_bound = min_value if min_value else divisor
      nearest_multiple = int(v + divisor / 2) // divisor * divisor
      candidate = max(lower_bound, nearest_multiple)
      # Only a single extra step is added; the check is not repeated.
      return candidate + divisor if candidate < round_limit * v else candidate
  def extend_tuple(x, n):
      """Return ``x`` as a tuple truncated or padded to length ``n``.

      Anything that is not a ``tuple`` or ``list`` is treated as a single
      element, i.e. wrapped as ``(x,)``. If the resulting tuple is shorter
      than ``n``, it is padded by repeating its last element. Otherwise it is
      truncated with ``x[:n]``.

      Args:
          x: A scalar, tuple, or list.
          n: Target length (not validated; a negative ``n`` slices from the
              end as ``x[:n]`` does).

      Returns:
          A plain ``tuple``.

      Raises:
          ValueError: If ``x`` is an empty tuple/list and ``n`` exceeds its
              length. There is no last element to pad with, and the original
              code failed here with an opaque ``IndexError``.
      """
      x = tuple(x) if isinstance(x, (tuple, list)) else (x,)
      pad_n = n - len(x)
      if pad_n <= 0:
          return x[:n]
      if not x:
          raise ValueError(f"cannot extend an empty sequence to length {n}")
      return x + (x[-1],) * pad_n