pyinterfaces.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import os.path as osp
  2. from typing import Tuple
  3. import torch
  4. from torch.cuda.amp import custom_bwd, custom_fwd
  5. from torch.utils.cpp_extension import load
  6. try:
  7. from cudaops_ailut import (ailut_cbackward, ailut_cforward, lut_cbackward,
  8. lut_cforward)
  9. except ImportError:
  10. CUR_DIR = osp.abspath(osp.dirname(__file__))
  11. cudaops_ailut = load(
  12. name='cudaops_ailut',
  13. sources=[
  14. osp.join(CUR_DIR, 'Ailut', 'csrc/ailut_transform.cpp'),
  15. osp.join(CUR_DIR, 'Ailut', 'csrc/ailut_transform_cpu.cpp'),
  16. osp.join(CUR_DIR, 'Ailut', 'csrc/ailut_transform_cuda.cu')
  17. ],
  18. verbose=True)
  19. from cudaops_ailut import (ailut_cbackward, ailut_cforward, lut_cbackward,
  20. lut_cforward)
  21. class LUTTransformFunction(torch.autograd.Function):
  22. @staticmethod
  23. @custom_fwd(cast_inputs=torch.float32)
  24. def forward(ctx, img: torch.Tensor, lut: torch.Tensor) -> torch.Tensor:
  25. img = img.contiguous()
  26. lut = lut.contiguous()
  27. assert img.ndimension() == 4, \
  28. 'only support 2D image with batch and channel dimensions (4D tensor)'
  29. assert lut.ndimension() in [5], \
  30. 'only support 3D lookup table with batch dimension (5D tensor)'
  31. output = img.new_zeros(
  32. (img.size(0), lut.size(1), img.size(2), img.size(3)))
  33. lut_cforward(img, lut, output)
  34. ctx.save_for_backward(img, lut)
  35. return output
  36. @staticmethod
  37. @custom_bwd
  38. def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor]:
  39. grad_output = grad_output.contiguous()
  40. img, lut = ctx.saved_tensors
  41. grad_img = torch.zeros_like(img)
  42. grad_lut = torch.zeros_like(lut)
  43. lut_cbackward(grad_output, img, lut, grad_img, grad_lut)
  44. return grad_img, grad_lut
  45. class AiLUTTransformFunction(torch.autograd.Function):
  46. @staticmethod
  47. @custom_fwd(cast_inputs=torch.float32)
  48. def forward(ctx, img: torch.Tensor, lut: torch.Tensor,
  49. vertices: torch.tensor) -> torch.Tensor:
  50. img = img.contiguous()
  51. lut = lut.contiguous()
  52. vertices = vertices.contiguous()
  53. assert img.ndimension() == 4, \
  54. 'only support 2D image with batch and channel dimensions (4D tensor)'
  55. assert lut.ndimension() in [5], \
  56. 'only support 3D lookup table with batch dimension (5D tensor)'
  57. assert vertices.ndimension() == 3, \
  58. 'only support 1D vertices list with batch and channel dimensions (3D tensor)'
  59. output = img.new_zeros(
  60. (img.size(0), lut.size(1), img.size(2), img.size(3)))
  61. ailut_cforward(img, lut, vertices, output)
  62. ctx.save_for_backward(img, lut, vertices)
  63. return output
  64. @staticmethod
  65. @custom_bwd
  66. def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor]:
  67. grad_output = grad_output.contiguous()
  68. img, lut, vertices = ctx.saved_tensors
  69. grad_img = torch.zeros_like(img)
  70. grad_lut = torch.zeros_like(lut)
  71. grad_ver = torch.zeros_like(vertices)
  72. ailut_cbackward(grad_output, img, lut, vertices, grad_img, grad_lut,
  73. grad_ver)
  74. return grad_img, grad_lut, grad_ver
  75. def ailut_transform(img: torch.Tensor, lut: torch.Tensor,
  76. vertices: torch.Tensor) -> torch.Tensor:
  77. r"""Adaptive Interval 3D Lookup Table Transform (AiLUT-Transform).
  78. Args:
  79. img (torch.Tensor): input image of shape (b, 3, h, w).
  80. lut (torch.Tensor): output values of the 3D LUT, shape (b, 3, d, d, d).
  81. vertices (torch.Tensor): sampling coordinates along each dimension of
  82. the 3D LUT, shape (b, 3, d).
  83. Returns:
  84. torch.Tensor: transformed image of shape (b, 3, h, w).
  85. """
  86. return AiLUTTransformFunction.apply(img, lut, vertices)
  87. def lut_transform(img: torch.Tensor, lut: torch.Tensor) -> torch.Tensor:
  88. r"""Standard 3D Lookup Table Transform.
  89. Args:
  90. img (torch.Tensor): input image of shape (b, 3, h, w).
  91. lut (torch.Tensor): output values of the 3D LUT, shape (b, 3, d, d, d).
  92. Returns:
  93. torch.Tensor: transformed image of shape (b, 3, h, w).
  94. """
  95. return LUTTransformFunction.apply(img, lut)