| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203 |
- #!/usr/bin/env python3
- """
- Export SuperPoint and LightGlue models to ONNX for TensorRT conversion.
- This script targets a fixed 640x480 (WxH) grayscale input and limits the
- number of detected keypoints to make the exported graph TensorRT-friendly.
- """
- from __future__ import annotations
- import argparse
- from pathlib import Path
- import torch
- # Ensure we can import the local LightGlue package that was copied into the
- # project root.
- from lightglue.superpoint import SuperPoint # type: ignore
- from lightglue.lightglue import LightGlue # type: ignore
# Keypoint budget per image: the wrappers below pad/truncate every
# per-keypoint tensor to exactly this many entries.
MAX_KEYPOINTS = 128

# Fixed export resolution of the grayscale input image, width x height.
IMAGE_WIDTH, IMAGE_HEIGHT = 640, 480
class SuperPointWrapper(torch.nn.Module):
    """Wrap the official SuperPoint module so the outputs have static shapes.

    Every per-keypoint output is zero-padded (or truncated) to exactly
    ``max_keypoints`` entries along dim 1. A ``valid_counts`` tensor reports
    how many leading entries are real keypoints, so downstream code can mask
    out the padding.
    """

    def __init__(self, max_keypoints: int = MAX_KEYPOINTS):
        super().__init__()
        self.model = SuperPoint(max_num_keypoints=max_keypoints)
        self.max_keypoints = max_keypoints

    @staticmethod
    def _pad_or_truncate(x: torch.Tensor, k: int) -> torch.Tensor:
        """Return ``x`` of shape (B, N, C) as (B, k, C), zero-padded or cut along dim 1."""
        # Always pad by the full ``k`` rows, then keep the first ``k``.
        # Both the pad amount and the slice bound are constants, so no
        # Python-level branch on N gets frozen into the traced graph
        # (the previous ``if pad_k > 0`` was evaluated only once, at trace time).
        return torch.nn.functional.pad(x, (0, 0, 0, k))[:, :k, :]

    def forward(self, image: torch.Tensor):
        """
        Args:
            image: (B, 1, H, W) float32 tensor in [0, 1]
        Returns:
            keypoints: (B, MAX_K, 2)
            scores: (B, MAX_K, 1)
            descriptors: (B, MAX_K, 256)
            valid_counts: (B, 1) int32, number of real keypoints before padding
        """
        out = self.model({"image": image})
        keypoints = out["keypoints"]  # (B, N, 2)
        scores = out["keypoint_scores"].unsqueeze(-1)  # (B, N, 1)
        descriptors = out["descriptors"]  # (B, N, 256)

        # NOTE(review): N is decided by self.model and presumably depends on
        # the image content; confirm the exported SuperPoint graph behaves
        # for keypoint counts other than the one seen on the dummy input.

        max_k = self.max_keypoints

        # Count real keypoints with tensor ops instead of torch.tensor([N]):
        # a tensor built from a Python int is recorded as a constant during
        # tracing, which would freeze valid_counts to the dummy input's count.
        # Padding a ones-mask the same way as the data and summing it yields
        # min(N, max_k) per batch element, computed at runtime.
        valid_mask = self._pad_or_truncate(torch.ones_like(scores), max_k)
        valid_counts = valid_mask.sum(dim=1).to(torch.int32)  # (B, 1)

        return (
            self._pad_or_truncate(keypoints, max_k),
            self._pad_or_truncate(scores, max_k),
            self._pad_or_truncate(descriptors, max_k),
            valid_counts,
        )
class LightGlueWrapper(torch.nn.Module):
    """Wrap LightGlue so it consumes SuperPoint outputs with static shapes.

    Inputs are the padded tensors produced by ``SuperPointWrapper``; the
    descriptors are passed through in that (B, MAX_K, 256) layout.
    """

    def __init__(
        self,
        max_keypoints: int = MAX_KEYPOINTS,
        image_width: int = IMAGE_WIDTH,
        image_height: int = IMAGE_HEIGHT,
    ):
        super().__init__()
        # depth_confidence / width_confidence = -1 disable LightGlue's adaptive
        # early stopping and point pruning (per the upstream LightGlue README).
        # Both decide, from the input values, how many layers run and how many
        # points survive; a traced export would freeze whatever the dummy
        # input happened to produce.
        # NOTE(review): assumes the vendored lightglue copy accepts these
        # upstream config keys - confirm.
        self.model = LightGlue(
            features="superpoint",
            depth_confidence=-1,
            width_confidence=-1,
        )
        self.max_keypoints = max_keypoints
        # (1, 2) [width, height] of the source images. Upstream LightGlue uses
        # "image_size" to normalize keypoint coordinates; without it the scale
        # is derived from the keypoints' own extent, which the zero padding
        # (keypoints at (0, 0)) distorts. Non-persistent: not a learned weight.
        self.register_buffer(
            "image_size",
            torch.tensor([[float(image_width), float(image_height)]]),
            persistent=False,
        )

    def forward(
        self,
        keypoints0,
        scores0,
        descriptors0,
        keypoints1,
        scores1,
        descriptors1,
    ):
        """
        Args:
            keypoints{0,1}: (B, MAX_K, 2)
            scores{0,1}: (B, MAX_K, 1)
            descriptors{0,1}: (B, MAX_K, 256)
        Returns:
            matches0: (B, MAX_K) indices of matches in image1 (or -1)
            matches1: (B, MAX_K) indices of matches in image0 (or -1)
            mscores0: (B, MAX_K)
            mscores1: (B, MAX_K)
        """
        image_size = self.image_size.expand(keypoints0.shape[0], 2)  # (B, 2)
        # Descriptors stay (B, MAX_K, 256): upstream LightGlue checks the
        # descriptor dimension on the LAST axis, so the former
        # transpose(-1, -2) produced (B, 256, MAX_K) and broke that contract.
        batch = {
            "image0": {
                "keypoints": keypoints0,
                "keypoint_scores": scores0.squeeze(-1),
                "descriptors": descriptors0,
                "image_size": image_size,
            },
            "image1": {
                "keypoints": keypoints1,
                "keypoint_scores": scores1.squeeze(-1),
                "descriptors": descriptors1,
                "image_size": image_size,
            },
        }
        out = self.model(batch)
        matches0 = out["matches0"]  # (B, MAX_K)
        matches1 = out["matches1"]
        mscores0 = out["matching_scores0"]
        mscores1 = out["matching_scores1"]
        return matches0, matches1, mscores0, mscores1
def export_model(module: torch.nn.Module, inputs, output_path: Path, output_names):
    """Export ``module`` to an ONNX file at ``output_path`` (opset 17).

    Args:
        module: Model to export; it is switched to eval mode first.
        inputs: One example tensor, or a tuple/list of example tensors.
        output_path: Destination file; missing parent directories are created.
        output_names: Names given to the graph outputs, in order.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    module.eval()

    # Positional inputs are named input_0, input_1, ...; a lone tensor is "input".
    if isinstance(inputs, (tuple, list)):
        graph_inputs = [f"input_{idx}" for idx in range(len(inputs))]
    else:
        graph_inputs = ["input"]

    with torch.no_grad():
        torch.onnx.export(
            module,
            inputs,
            output_path.as_posix(),
            export_params=True,
            opset_version=17,
            do_constant_folding=True,
            input_names=graph_inputs,
            output_names=output_names,
            dynamic_axes=None,  # no dynamic axes are declared
        )
    print(f"[OK] Exported ONNX: {output_path}")
def parse_args():
    """Parse the command-line options of the export script.

    Returns:
        argparse.Namespace with ``output_dir`` (Path), ``device`` ("cuda" or
        "cpu") and ``max_keypoints`` (a positive int).

    Exits via ``parser.error`` (SystemExit) when ``--max-keypoints`` is not
    positive.
    """
    parser = argparse.ArgumentParser(description="Export LightGlue pipeline to ONNX.")
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("models"),
        help="Where to store the ONNX files.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help="Torch device for the export dummy run.",
    )
    parser.add_argument(
        "--max-keypoints",
        type=int,
        default=MAX_KEYPOINTS,
        help="Maximum number of keypoints to keep (must match inference).",
    )
    args = parser.parse_args()
    # The keypoint budget sizes every dummy export tensor; zero would make them
    # empty and a negative value cannot be used as a tensor dimension at all.
    if args.max_keypoints <= 0:
        parser.error(
            f"--max-keypoints must be a positive integer, got {args.max_keypoints}"
        )
    return args
def _random_features(num_keypoints: int, device: torch.device):
    """Return random (keypoints, scores, descriptors) for one image, batch size 1.

    Shapes match the documented SuperPointWrapper outputs: (1, K, 2),
    (1, K, 1) and (1, K, 256).
    """
    # Keypoints spread over the export resolution instead of all at (0, 0).
    extent = torch.tensor(
        [IMAGE_WIDTH - 1, IMAGE_HEIGHT - 1], dtype=torch.float32, device=device
    )
    keypoints = torch.rand(1, num_keypoints, 2, device=device) * extent
    scores = torch.rand(1, num_keypoints, 1, device=device)
    # Unit-length descriptors; presumably close to what SuperPoint emits
    # (L2-normalized) - confirm against the vendored model.
    descriptors = torch.nn.functional.normalize(
        torch.randn(1, num_keypoints, 256, device=device), dim=-1
    )
    return keypoints, scores, descriptors


def main():
    """Export SuperPoint and LightGlue to ONNX files under ``--output-dir``."""
    args = parse_args()

    # Same fallback as before (CPU when CUDA is unavailable), but no longer silent.
    if args.device == "cuda" and not torch.cuda.is_available():
        print("[WARN] CUDA is not available; exporting on CPU instead.")
        device = torch.device("cpu")
    else:
        device = torch.device(args.device)

    # Fixed seed so repeated exports are fed identical dummy inputs.
    torch.manual_seed(0)

    dummy = torch.rand(1, 1, IMAGE_HEIGHT, IMAGE_WIDTH, device=device)
    sp = SuperPointWrapper(max_keypoints=args.max_keypoints).to(device)
    export_model(
        sp,
        dummy,
        args.output_dir / "superpoint.onnx",
        ["keypoints", "scores", "descriptors", "valid_counts"],
    )

    # Dummy LightGlue inputs (batch=1). Previously both images shared the same
    # all-zero tensors: every keypoint and descriptor was identical, and the
    # same tensor objects were passed for two different graph inputs. Separate
    # random features give each ONNX input its own, non-degenerate example.
    features0 = _random_features(args.max_keypoints, device)
    features1 = _random_features(args.max_keypoints, device)
    lg = LightGlueWrapper(max_keypoints=args.max_keypoints).to(device)
    export_model(
        lg,
        (*features0, *features1),
        args.output_dir / "lightglue.onnx",
        ["matches0", "matches1", "scores0", "scores1"],
    )


if __name__ == "__main__":
    main()
|