format.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. from enum import Enum
  2. from typing import Union
  3. import torch
  4. class Format(str, Enum):
  5. NCHW = 'NCHW'
  6. NHWC = 'NHWC'
  7. NCL = 'NCL'
  8. NLC = 'NLC'
  9. FormatT = Union[str, Format]
  10. def get_spatial_dim(fmt: FormatT):
  11. fmt = Format(fmt)
  12. if fmt is Format.NLC:
  13. dim = (1,)
  14. elif fmt is Format.NCL:
  15. dim = (2,)
  16. elif fmt is Format.NHWC:
  17. dim = (1, 2)
  18. else:
  19. dim = (2, 3)
  20. return dim
  21. def get_channel_dim(fmt: FormatT):
  22. fmt = Format(fmt)
  23. if fmt is Format.NHWC:
  24. dim = 3
  25. elif fmt is Format.NLC:
  26. dim = 2
  27. else:
  28. dim = 1
  29. return dim
  30. def nchw_to(x: torch.Tensor, fmt: Format):
  31. if fmt == Format.NHWC:
  32. x = x.permute(0, 2, 3, 1)
  33. elif fmt == Format.NLC:
  34. x = x.flatten(2).transpose(1, 2)
  35. elif fmt == Format.NCL:
  36. x = x.flatten(2)
  37. return x
  38. def nhwc_to(x: torch.Tensor, fmt: Format):
  39. if fmt == Format.NCHW:
  40. x = x.permute(0, 3, 1, 2)
  41. elif fmt == Format.NLC:
  42. x = x.flatten(1, 2)
  43. elif fmt == Format.NCL:
  44. x = x.flatten(1, 2).transpose(1, 2)
  45. return x