export_to_onnx.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. #!/usr/bin/env python3
  2. """
  3. Export SuperPoint and LightGlue models to ONNX for TensorRT conversion.
  4. This script targets a fixed 640x480 (WxH) grayscale input and limits the
  5. number of detected keypoints to make the exported graph TensorRT-friendly.
  6. """
  7. from __future__ import annotations
  8. import argparse
  9. from pathlib import Path
  10. import torch
  11. # Import the local LightGlue package copied into the project root. No path
  12. # manipulation is done here, so run from the project root or add it to PYTHONPATH.
  13. from lightglue.superpoint import SuperPoint # type: ignore
  14. from lightglue.lightglue import LightGlue # type: ignore
  15. MAX_KEYPOINTS = 128
  16. IMAGE_WIDTH = 640
  17. IMAGE_HEIGHT = 480
  18. class SuperPointWrapper(torch.nn.Module):
  19. """Wrap the official SuperPoint module so the outputs have static shapes."""
  20. def __init__(self, max_keypoints: int = MAX_KEYPOINTS):
  21. super().__init__()
  22. self.model = SuperPoint(max_num_keypoints=max_keypoints)
  23. self.max_keypoints = max_keypoints
  24. def forward(self, image: torch.Tensor):
  25. """
  26. Args:
  27. image: (B, 1, H, W) float32 tensor in [0, 1]
  28. Returns:
  29. keypoints: (B, MAX_K, 2)
  30. scores: (B, MAX_K, 1)
  31. descriptors: (B, MAX_K, 256)
  32. valid_counts: (B, 1) number of real keypoints before padding
  33. """
  34. out = self.model({"image": image})
  35. keypoints = out["keypoints"] # (B, N, 2)
  36. scores = out["keypoint_scores"].unsqueeze(-1) # (B, N, 1)
  37. descriptors = out["descriptors"] # (B, N, 256)
  38. batch_size, num_kp, _ = keypoints.shape
  39. max_k = self.max_keypoints
  40. # Clamp to max_k and record valid counts (for downstream masking).
  41. clamped_num = torch.clamp(torch.tensor([num_kp], device=image.device), max=max_k)
  42. valid_counts = clamped_num.expand(batch_size, 1).to(torch.int32)
  43. keypoints = keypoints[:, :max_k, :]
  44. scores = scores[:, :max_k, :]
  45. descriptors = descriptors[:, :max_k, :]
  46. pad_k = max_k - keypoints.shape[1]
  47. if pad_k > 0:
  48. pad_shape_kp = (0, 0, 0, pad_k)
  49. keypoints = torch.nn.functional.pad(keypoints, pad_shape_kp)
  50. scores = torch.nn.functional.pad(scores, (0, 0, 0, pad_k))
  51. descriptors = torch.nn.functional.pad(descriptors, (0, 0, 0, pad_k))
  52. return keypoints, scores, descriptors, valid_counts
  53. class LightGlueWrapper(torch.nn.Module):
  54. """Wrap LightGlue so it consumes SuperPoint outputs with static shapes."""
  55. def __init__(self, max_keypoints: int = MAX_KEYPOINTS):
  56. super().__init__()
  57. self.model = LightGlue(features="superpoint")
  58. self.max_keypoints = max_keypoints
  59. def forward(
  60. self,
  61. keypoints0,
  62. scores0,
  63. descriptors0,
  64. keypoints1,
  65. scores1,
  66. descriptors1,
  67. ):
  68. """
  69. Args:
  70. keypoints{0,1}: (B, MAX_K, 2)
  71. scores{0,1}: (B, MAX_K, 1)
  72. descriptors{0,1}: (B, MAX_K, 256)
  73. Returns:
  74. matches0: (B, MAX_K) indices of matches in image1 (or -1)
  75. matches1: (B, MAX_K) indices of matches in image0 (or -1)
  76. mscores0: (B, MAX_K)
  77. mscores1: (B, MAX_K)
  78. """
  79. batch = {
  80. "image0": {
  81. "keypoints": keypoints0,
  82. "keypoint_scores": scores0.squeeze(-1),
  83. "descriptors": descriptors0.transpose(-1, -2),
  84. },
  85. "image1": {
  86. "keypoints": keypoints1,
  87. "keypoint_scores": scores1.squeeze(-1),
  88. "descriptors": descriptors1.transpose(-1, -2),
  89. },
  90. }
  91. out = self.model(batch)
  92. matches0 = out["matches0"] # (B, MAX_K)
  93. matches1 = out["matches1"]
  94. mscores0 = out["matching_scores0"]
  95. mscores1 = out["matching_scores1"]
  96. return matches0, matches1, mscores0, mscores1
  97. def export_model(module: torch.nn.Module, inputs, output_path: Path, output_names):
  98. output_path.parent.mkdir(parents=True, exist_ok=True)
  99. module.eval()
  100. with torch.no_grad():
  101. torch.onnx.export(
  102. module,
  103. inputs,
  104. output_path.as_posix(),
  105. export_params=True,
  106. opset_version=17,
  107. do_constant_folding=True,
  108. input_names=[f"input_{i}" for i in range(len(inputs))]
  109. if isinstance(inputs, (tuple, list))
  110. else ["input"],
  111. output_names=output_names,
  112. dynamic_axes=None,
  113. )
  114. print(f"[OK] Exported ONNX: {output_path}")
  115. def parse_args():
  116. parser = argparse.ArgumentParser(description="Export LightGlue pipeline to ONNX.")
  117. parser.add_argument(
  118. "--output-dir",
  119. type=Path,
  120. default=Path("models"),
  121. help="Where to store the ONNX files.",
  122. )
  123. parser.add_argument(
  124. "--device",
  125. type=str,
  126. default="cuda",
  127. choices=["cuda", "cpu"],
  128. help="Torch device for the export dummy run.",
  129. )
  130. parser.add_argument(
  131. "--max-keypoints",
  132. type=int,
  133. default=MAX_KEYPOINTS,
  134. help="Maximum number of keypoints to keep (must match inference).",
  135. )
  136. return parser.parse_args()
  137. def main():
  138. args = parse_args()
  139. device = torch.device(args.device if torch.cuda.is_available() else "cpu")
  140. dummy = torch.rand(1, 1, IMAGE_HEIGHT, IMAGE_WIDTH, device=device)
  141. sp = SuperPointWrapper(max_keypoints=args.max_keypoints).to(device)
  142. export_model(
  143. sp,
  144. dummy,
  145. args.output_dir / "superpoint.onnx",
  146. ["keypoints", "scores", "descriptors", "valid_counts"],
  147. )
  148. # Prepare dummy inputs for LightGlue (batch=1)
  149. keypoints = torch.zeros(1, args.max_keypoints, 2, device=device)
  150. scores = torch.zeros(1, args.max_keypoints, 1, device=device)
  151. descriptors = torch.zeros(1, args.max_keypoints, 256, device=device)
  152. lg = LightGlueWrapper(max_keypoints=args.max_keypoints).to(device)
  153. export_model(
  154. lg,
  155. (
  156. keypoints,
  157. scores,
  158. descriptors,
  159. keypoints,
  160. scores,
  161. descriptors,
  162. ),
  163. args.output_dir / "lightglue.onnx",
  164. ["matches0", "matches1", "scores0", "scores1"],
  165. )
  166. if __name__ == "__main__":
  167. main()