#!/usr/bin/env python3
"""
Export SuperPoint and LightGlue models to ONNX for TensorRT conversion.

This script targets a fixed 640x480 (WxH) grayscale input and limits the
number of detected keypoints to make the exported graph TensorRT-friendly.
The export is trace-based, so every shape the downstream engine relies on is
produced with tensor ops rather than Python control flow (which tracing
resolves once, on the dummy input).
"""

from __future__ import annotations

import argparse
from pathlib import Path

import torch

# The local LightGlue package (copied into the project root) must be
# importable, e.g. by running this script from the project root.
from lightglue.superpoint import SuperPoint  # type: ignore
from lightglue.lightglue import LightGlue  # type: ignore

MAX_KEYPOINTS = 128
IMAGE_WIDTH = 640
IMAGE_HEIGHT = 480
DESCRIPTOR_DIM = 256  # SuperPoint descriptor width (last dim of descriptors)
DEFAULT_OPSET = 17


def _fixed_length(x: torch.Tensor, length: int) -> torch.Tensor:
    """Zero-pad, then truncate, dim 1 of a (B, N, C) tensor to exactly ``length`` rows."""
    # Padding by the full ``length`` and slicing back is branch-free: the pad
    # amount and the slice bound are constants, so the traced graph emits
    # ``length`` rows for any N >= 0. A Python ``if pad > 0`` would instead be
    # decided once, for whatever N the dummy export input produced.
    padded = torch.nn.functional.pad(x, (0, 0, 0, length))
    return padded[:, :length, :]


class SuperPointWrapper(torch.nn.Module):
    """Wrap the official SuperPoint module so the outputs have static shapes.

    All outputs have exactly ``max_keypoints`` rows; rows past ``valid_counts``
    are zero padding.

    NOTE(review): keypoint selection inside the upstream SuperPoint module may
    itself use Python control flow that tracing freezes; verify the exported
    graph on images that yield fewer than ``max_keypoints`` detections.
    """

    def __init__(self, max_keypoints: int = MAX_KEYPOINTS):
        super().__init__()
        self.model = SuperPoint(max_num_keypoints=max_keypoints)
        self.max_keypoints = max_keypoints

    def forward(self, image: torch.Tensor):
        """
        Args:
            image: (B, 1, H, W) float32 tensor in [0, 1]

        Returns:
            keypoints: (B, MAX_K, 2)
            scores: (B, MAX_K, 1)
            descriptors: (B, MAX_K, 256)
            valid_counts: (B, 1) int32 number of real keypoints before padding
        """
        out = self.model({"image": image})
        keypoints = out["keypoints"]  # (B, N, 2)
        scores = out["keypoint_scores"].unsqueeze(-1)  # (B, N, 1)
        descriptors = out["descriptors"]  # (B, N, 256)
        max_k = self.max_keypoints

        # Count the real keypoints with tensor ops (sum of ones over the N
        # axis) so the count stays data-dependent in the traced graph instead
        # of being frozen to the value seen during the dummy export run.
        valid_counts = (
            torch.ones_like(scores).sum(dim=1).clamp(max=max_k).to(torch.int32)
        )  # (B, 1)

        keypoints = _fixed_length(keypoints, max_k)
        scores = _fixed_length(scores, max_k)
        descriptors = _fixed_length(descriptors, max_k)
        return keypoints, scores, descriptors, valid_counts


class LightGlueWrapper(torch.nn.Module):
    """Wrap LightGlue so it consumes SuperPoint outputs with static shapes.

    Adaptive depth (early stopping) and adaptive width (point pruning) are
    disabled: both make the matcher's control flow depend on the input, and a
    traced export would freeze whatever decisions were taken on the dummy
    input.

    NOTE(review): padded keypoints (rows at index >= valid_counts) are not
    masked and may still be matched; callers should discard matches that
    involve padded indices.
    """

    def __init__(
        self,
        max_keypoints: int = MAX_KEYPOINTS,
        image_size: tuple[int, int] = (IMAGE_WIDTH, IMAGE_HEIGHT),
    ):
        super().__init__()
        self.model = LightGlue(
            features="superpoint",
            depth_confidence=-1,  # -1 disables early stopping
            width_confidence=-1,  # -1 disables point pruning
        )
        self.max_keypoints = max_keypoints
        # (W, H) of the images the keypoints were detected on.
        self.image_size = tuple(image_size)

    def forward(
        self,
        keypoints0,
        scores0,
        descriptors0,
        keypoints1,
        scores1,
        descriptors1,
    ):
        """
        Args:
            keypoints{0,1}: (B, MAX_K, 2)
            scores{0,1}: (B, MAX_K, 1)
            descriptors{0,1}: (B, MAX_K, 256)

        Returns:
            matches0: (B, MAX_K) indices of matches in image1 (or -1)
            matches1: (B, MAX_K) indices of matches in image0 (or -1)
            mscores0: (B, MAX_K)
            mscores1: (B, MAX_K)
        """
        # Fixed (W, H) per batch item, built from keypoints0 so the batch size
        # follows the input without Python-side shape arithmetic. Supplying it
        # keeps keypoint normalization independent of the keypoints themselves
        # (the zero-padded rows would otherwise pull the extent toward (0, 0)).
        # NOTE(review): assumes LightGlue reads `image_size` as (W, H), as the
        # upstream extractors produce it — confirm for this vendored copy.
        image_size = torch.ones_like(keypoints0[:, 0, :]) * keypoints0.new_tensor(
            self.image_size
        )  # (B, 2)

        # Descriptors are passed through in the (B, MAX_K, 256) layout that
        # SuperPointWrapper emits: LightGlue expects the descriptor dim last,
        # so transposing them here would feed it (B, 256, MAX_K).
        batch = {
            "image0": {
                "keypoints": keypoints0,
                "keypoint_scores": scores0.squeeze(-1),
                "descriptors": descriptors0,
                "image_size": image_size,
            },
            "image1": {
                "keypoints": keypoints1,
                "keypoint_scores": scores1.squeeze(-1),
                "descriptors": descriptors1,
                "image_size": image_size,
            },
        }
        out = self.model(batch)
        matches0 = out["matches0"]  # (B, MAX_K)
        matches1 = out["matches1"]
        mscores0 = out["matching_scores0"]
        mscores1 = out["matching_scores1"]
        return matches0, matches1, mscores0, mscores1


def export_model(
    module: torch.nn.Module,
    inputs,
    output_path: Path,
    output_names,
    *,
    input_names=None,
    opset_version: int = DEFAULT_OPSET,
):
    """Trace ``module`` on ``inputs`` and write a static-shape ONNX graph.

    Args:
        module: Model to export; it is switched to eval mode.
        inputs: A single tensor or a tuple/list of tensors fed to ``module``.
        output_path: Destination ``.onnx`` file; parent dirs are created.
        output_names: Names for the graph outputs, in return order.
        input_names: Names for the graph inputs. Defaults to ``input_0 ...``
            for a tuple/list of inputs, or ``input`` for a single tensor.
        opset_version: ONNX opset to target.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if input_names is None:
        input_names = (
            [f"input_{i}" for i in range(len(inputs))]
            if isinstance(inputs, (tuple, list))
            else ["input"]
        )
    module.eval()
    with torch.no_grad():
        torch.onnx.export(
            module,
            inputs,
            output_path.as_posix(),
            export_params=True,
            opset_version=opset_version,
            do_constant_folding=True,
            input_names=list(input_names),
            output_names=list(output_names),
            dynamic_axes=None,  # fully static shapes for TensorRT
        )
    print(f"[OK] Exported ONNX: {output_path}")


def parse_args():
    """Parse and validate the command-line options."""
    parser = argparse.ArgumentParser(description="Export LightGlue pipeline to ONNX.")
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("models"),
        help="Where to store the ONNX files.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help="Torch device for the export dummy run.",
    )
    parser.add_argument(
        "--max-keypoints",
        type=int,
        default=MAX_KEYPOINTS,
        help="Maximum number of keypoints to keep (must match inference).",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=DEFAULT_OPSET,
        help="ONNX opset version to export with.",
    )
    args = parser.parse_args()
    if args.max_keypoints <= 0:
        parser.error("--max-keypoints must be a positive integer.")
    return args


def _dummy_features(num_keypoints: int, device: torch.device):
    """Random (keypoints, scores, descriptors) shaped like SuperPointWrapper outputs (batch=1)."""
    # Keypoints are placed inside the image bounds; the values only need to
    # be plausible, since the trace records shapes and ops, not data.
    image_extent = torch.tensor(
        [IMAGE_WIDTH - 1, IMAGE_HEIGHT - 1], dtype=torch.float32, device=device
    )
    keypoints = torch.rand(1, num_keypoints, 2, device=device) * image_extent
    scores = torch.rand(1, num_keypoints, 1, device=device)
    descriptors = torch.randn(1, num_keypoints, DESCRIPTOR_DIM, device=device)
    return keypoints, scores, descriptors


def main():
    """Export superpoint.onnx and lightglue.onnx into ``--output-dir``."""
    args = parse_args()
    if args.device == "cuda" and not torch.cuda.is_available():
        print("[WARN] CUDA is not available; exporting on CPU instead.")
        device = torch.device("cpu")
    else:
        device = torch.device(args.device)

    # Fixed seed so repeated exports trace the same dummy data.
    torch.manual_seed(0)

    dummy = torch.rand(1, 1, IMAGE_HEIGHT, IMAGE_WIDTH, device=device)
    sp = SuperPointWrapper(max_keypoints=args.max_keypoints).to(device)
    export_model(
        sp,
        dummy,
        args.output_dir / "superpoint.onnx",
        ["keypoints", "scores", "descriptors", "valid_counts"],
        opset_version=args.opset,
    )

    # Distinct tensors per image (batch=1): passing the very same tensor
    # objects for both images can let the tracer wire both branches to a
    # single graph input.
    feats0 = _dummy_features(args.max_keypoints, device)
    feats1 = _dummy_features(args.max_keypoints, device)
    lg = LightGlueWrapper(max_keypoints=args.max_keypoints).to(device)
    export_model(
        lg,
        (*feats0, *feats1),
        args.output_dir / "lightglue.onnx",
        ["matches0", "matches1", "scores0", "scores1"],
        opset_version=args.opset,
    )


if __name__ == "__main__":
    main()