| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397 |
- #!/usr/bin/env python3
- """
- TensorRT-based LightGlue demo with camera position visualization.
- This script expects TensorRT engines produced by `build_tensorrt.py` and runs
- SuperPoint + LightGlue fully on TensorRT while keeping the visualization logic
- close to the original PyTorch demo.
- """
- from __future__ import annotations
- import argparse
- import queue
- import threading
- import time
- from pathlib import Path
- from typing import Optional, Tuple
- import cv2
- import numpy as np
- from trt_engine import load_engines
class AverageTimer:
    """Exponentially smoothed stopwatch for named stages of a loop.

    Call :meth:`update` after each stage to record the time elapsed since the
    previous ``update``/``reset``; call :meth:`print` once per iteration to
    emit a one-line summary and re-arm the timer.

    Args:
        smoothing: Weight given to the newest measurement; the previous
            smoothed value receives ``1 - smoothing``.
        newline: When true, each summary ends with a newline; otherwise it
            ends with a carriage return so the next summary overwrites it.
    """

    def __init__(self, smoothing: float = 0.3, newline: bool = False):
        self.smoothing = smoothing
        self.newline = newline
        self.times = {}
        self.will_print = {}
        self.reset()

    def reset(self):
        """Restart the interval clock and mark every stage as not yet updated."""
        # perf_counter is monotonic, so a system clock adjustment can never
        # produce a negative or wildly wrong interval.
        now = time.perf_counter()
        self.start = now
        self.last_time = now
        for name in self.will_print:
            self.will_print[name] = False

    def update(self, name: str = "default"):
        """Record the time since the last ``update``/``reset`` under *name*."""
        now = time.perf_counter()
        dt = now - self.last_time
        if name in self.times:
            dt = self.smoothing * dt + (1 - self.smoothing) * self.times[name]
        self.times[name] = dt
        self.will_print[name] = True
        self.last_time = now

    def print(self, text: str = "Timer"):
        """Print the stages updated since the last reset, their sum and the FPS."""
        parts = [f"[{text}]"]
        total = 0.0
        for key, val in self.times.items():
            if self.will_print[key]:
                parts.append(f"{key}={val:.3f}")
                total += val
        fps = 1.0 / total if total > 0 else 0.0
        parts.append(f"total={total:.3f} sec {{{fps:.1f} FPS}}")
        # The trailing space before the line terminator matches the historical
        # output format of this timer.
        print(" ".join(parts) + " ", end="\n" if self.newline else "\r", flush=True)
        self.reset()
class VideoStreamer:
    """Grayscale frame source over an image directory or an OpenCV capture.

    The kind of source is decided in this order: an existing directory is
    read as a sorted image listing; any other existing path is opened as a
    capture; an integer or all-digit string is opened as a camera index;
    anything else is opened through the FFMPEG backend and flagged as an IP
    camera.

    Args:
        source: Directory, file path, camera index (``int`` or digit string)
            or stream URL.
        resize: ``[width, height]`` for an exact resize, ``[max_side]`` to
            scale the longer side, anything else (or non-positive values)
            to keep the native size.
        skip: Number of capture frames to discard after each returned frame
            (ignored for image directories).
        image_glob: Glob patterns used to list images in a directory source.
        max_length: Upper bound on the number of images taken from a directory.

    Raises:
        IOError: If *source* is a directory that matches no images.
    """

    def __init__(self, source, resize, skip, image_glob, max_length=1_000_000):
        self.source = source
        self.skip = skip
        self.max_length = max_length
        self.resize = resize
        self.i = 0
        self.cap = None
        self.is_ip_camera = False

        is_index = isinstance(source, int) or (isinstance(source, str) and source.isdigit())
        # Path() rejects ints, so an integer camera index must bypass the
        # filesystem checks instead of crashing on them.
        path = None if isinstance(source, int) else Path(source)

        if path is not None and path.is_dir():
            listing = []
            for pattern in image_glob:
                listing.extend(path.glob(pattern))
            self.listing = sorted(listing)[: self.max_length]
            self.max_length = len(self.listing)
            if self.max_length == 0:
                raise IOError(f"No images found in directory: {source}")
            print(f"Found {self.max_length} images in {source}")
        elif path is not None and path.exists():
            self.cap = cv2.VideoCapture(source)
        elif is_index:
            self.cap = cv2.VideoCapture(int(source))
        else:
            # Not a path and not an index: treat it as a network stream.
            self.is_ip_camera = True
            self.cap = cv2.VideoCapture(source, cv2.CAP_FFMPEG)
            self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
            self.cap.set(cv2.CAP_PROP_FPS, 30)
            self.cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"MJPG"))

    def next_frame(self):
        """Return ``(frame, True)`` with the next grayscale frame, or ``(None, False)``.

        ``(None, False)`` is returned when the capture fails to deliver a
        frame, when the image listing is exhausted, or when an image cannot
        be loaded.
        """
        if self.cap is not None:
            if self.is_ip_camera:
                # Discard up to three buffered frames so the decoded one is recent.
                for _ in range(3):
                    if not self.cap.grab():
                        break
            ret, frame = self.cap.read()
            if not ret:
                return None, False
            if len(frame.shape) == 3:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        else:
            if self.i >= self.max_length:
                return None, False
            image_file = self.listing[self.i]
            frame = cv2.imread(str(image_file), cv2.IMREAD_GRAYSCALE)
            if frame is None:
                print(f"Failed to load image: {image_file}")
                return None, False
            self.i += 1

        frame = self._apply_resize(frame)

        if self.cap is not None:
            # grab() advances the stream without decoding the skipped frames.
            for _ in range(self.skip):
                if not self.cap.grab():
                    break
        return frame, True

    def _apply_resize(self, frame):
        """Resize *frame* according to ``self.resize``; non-positive sizes mean no resize."""
        if len(self.resize) == 2 and self.resize[0] > 0 and self.resize[1] > 0:
            return cv2.resize(frame, tuple(self.resize))
        if len(self.resize) == 1 and self.resize[0] > 0:
            h, w = frame.shape[:2]
            scale = self.resize[0] / max(h, w)
            # Keep both sides at least one pixel for extreme aspect ratios.
            return cv2.resize(frame, (max(1, int(w * scale)), max(1, int(h * scale))))
        return frame

    def cleanup(self):
        """Release the underlying capture, if there is one."""
        if self.cap is not None:
            self.cap.release()
def _draw_crosshair(image, center, radius, arm, color):
    """Draw a circle of *radius* plus a cross with half-length *arm* at *center*."""
    cx, cy = center
    cv2.circle(image, center, radius, color, 2)
    cv2.line(image, (cx - arm, cy), (cx + arm, cy), color, 3)
    cv2.line(image, (cx, cy - arm), (cx, cy + arm), color, 3)


def draw_camera_position_on_reference(
    reference_frame: np.ndarray,
    camera_center_current,
    H: Optional[np.ndarray],
    num_matches: int = 0,
    min_matches: int = 10,
    inliers_ratio: float = 0.0,
):
    """Render the reference frame with the current view centre projected onto it.

    Args:
        reference_frame: Grayscale reference image; it is not modified.
        camera_center_current: ``(x, y)`` point in the current frame to project.
        H: Homography whose inverse maps current-frame points into the
            reference frame, or ``None`` when no estimate is available.
        num_matches: Number of matches behind *H*, shown in the overlay.
        min_matches: Matches required before a position is drawn.
        inliers_ratio: Inlier fraction shown in the overlay.

    Returns:
        A new BGR image. The reference centre is always marked in green; when
        a position can be computed it is marked in red (clipped to the image
        bounds) and joined to the centre by a magenta line, otherwise a red
        status message is written instead.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    red = (0, 0, 255)
    white = (255, 255, 255)

    h_ref, w_ref = reference_frame.shape[:2]
    # cvtColor returns a freshly allocated image, so no defensive copy is needed.
    ref_colored = cv2.cvtColor(reference_frame, cv2.COLOR_GRAY2BGR)
    center_ref_int = (w_ref // 2, h_ref // 2)
    _draw_crosshair(ref_colored, center_ref_int, 15, 20, (0, 255, 0))

    if H is None or num_matches < min_matches:
        status_text = f"Matches: {num_matches}/{min_matches}"
        cv2.putText(ref_colored, status_text, (10, 30), font, 0.7, red, 2)
        cv2.putText(ref_colored, "Camera position not available", (10, 60), font, 0.7, red, 2)
        return ref_colored

    # Only the inversion can raise LinAlgError, so only it is guarded.
    try:
        H_inv = np.linalg.inv(H)
    except np.linalg.LinAlgError:
        cv2.putText(ref_colored, "Homography not invertible", (10, 30), font, 0.7, red, 2)
        return ref_colored

    camera_center_ref = cv2.perspectiveTransform(
        np.array([[camera_center_current]], dtype=np.float32).reshape(-1, 1, 2),
        H_inv,
    )[0, 0]
    # A degenerate homography can project to NaN, which np.clip leaves as NaN
    # and int() rejects; report it instead of crashing the caller.
    if np.any(np.isnan(camera_center_ref)):
        cv2.putText(ref_colored, "Camera position not available", (10, 30), font, 0.7, red, 2)
        return ref_colored

    camera_center_ref = np.clip(camera_center_ref, [0, 0], [w_ref - 1, h_ref - 1])
    camera_pos_int = (int(camera_center_ref[0]), int(camera_center_ref[1]))
    _draw_crosshair(ref_colored, camera_pos_int, 12, 15, red)
    cv2.line(ref_colored, center_ref_int, camera_pos_int, (255, 0, 255), 2)
    cv2.putText(ref_colored, f"Matches: {num_matches}", (10, 30), font, 0.7, white, 2)
    cv2.putText(ref_colored, f"Inliers: {inliers_ratio:.1%}", (10, 60), font, 0.7, white, 2)
    return ref_colored
# Default for the --max_keypoints command-line option.
MAX_KEYPOINTS = 128
# NOTE(review): IMAGE_WIDTH / IMAGE_HEIGHT are not referenced in the visible
# part of this file; presumably the input resolution the engines were built
# for (they equal the --resize default) -- confirm before relying on them.
IMAGE_WIDTH = 640
IMAGE_HEIGHT = 480
class AsyncVideoStreamer:
    """Read frames from a streamer on a daemon thread, keeping only the newest.

    The worker starts as soon as the object is constructed. When the bounded
    queue is full the oldest queued frame is discarded to make room, so a
    consumer that falls behind sees recent frames rather than a backlog.
    """

    def __init__(self, streamer: "VideoStreamer", queue_size: int = 1, timeout: float = 1.0):
        self.streamer = streamer
        capacity = max(queue_size, 1)
        self.queue: "queue.Queue[np.ndarray]" = queue.Queue(maxsize=capacity)
        self.timeout = timeout
        self._stop_requested = False
        self._has_error = False
        self._thread = threading.Thread(target=self._reader, daemon=True)
        self._thread.start()

    def _reader(self):
        """Worker loop: pull frames until stopped, exhausted, or an error occurs."""
        try:
            while not self._stop_requested:
                grabbed, ok = self.streamer.next_frame()
                if not ok:
                    # The finally clause below records that the worker is done.
                    break
                self._drop_oldest_if_full()
                self.queue.put(grabbed)
        except Exception as exc:  # pragma: no cover
            self._has_error = True
            print(f"[AsyncVideoStreamer] error: {exc}")
        finally:
            self._stop_requested = True

    def _drop_oldest_if_full(self):
        """Discard one queued frame when the queue has no free slot."""
        if not self.queue.full():
            return
        try:
            self.queue.get_nowait()
        except queue.Empty:
            # The consumer emptied the queue in the meantime; nothing to drop.
            pass

    def read(self) -> Tuple[Optional[np.ndarray], bool]:
        """Return ``(frame, True)``, or ``(None, False)`` on error or timeout."""
        if self._has_error:
            return None, False
        try:
            newest = self.queue.get(timeout=self.timeout)
        except queue.Empty:
            return None, False
        return newest, True

    def stop(self):
        """Ask the worker to stop, wait up to one second for it, then clean up the streamer."""
        self._stop_requested = True
        worker = self._thread
        if worker.is_alive():
            worker.join(timeout=1.0)
        self.streamer.cleanup()
def preprocess_frame(frame: np.ndarray) -> np.ndarray:
    """Return *frame* as a float32 array of shape (1, 1, H, W) scaled by 1/255.

    A two-dimensional input is used as-is; any other input is first converted
    from BGR to grayscale.
    """
    gray = frame if frame.ndim == 2 else cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    normalized = gray.astype(np.float32) / 255.0
    # Prepend the two leading singleton axes.
    return normalized[np.newaxis, np.newaxis, ...]
def load_reference_features(engine, frame: np.ndarray):
    """Run *engine* on the preprocessed *frame* and unpack its outputs.

    Despite the name, the function is not specific to reference frames: it
    feeds ``preprocess_frame(frame)`` to ``engine.infer`` under the key
    ``"input"``.

    Returns:
        The ``(keypoints, scores, descriptors, valid_counts)`` entries of the
        engine's output mapping, in that order.
    """
    result = engine.infer({"input": preprocess_frame(frame)})
    output_names = ("keypoints", "scores", "descriptors", "valid_counts")
    return tuple(result[name] for name in output_names)
def parse_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line; ``None`` (the default) keeps the original
            behaviour of reading ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="TensorRT LightGlue demo")
    parser.add_argument(
        "--input",
        type=str,
        default="0",
        help="Camera index, stream URL, video file, or image directory",
    )
    parser.add_argument("--sp-engine", type=Path, default=Path("models/superpoint.plan"), help="SuperPoint TensorRT engine")
    parser.add_argument("--lg-engine", type=Path, default=Path("models/lightglue.plan"), help="LightGlue TensorRT engine")
    parser.add_argument("--queue_size", type=int, default=1, help="Async frame queue size")
    parser.add_argument("--read_timeout", type=float, default=1.0, help="Frame read timeout")
    parser.add_argument("--skip", type=int, default=0, help="Frames to skip")
    parser.add_argument("--resize", type=int, nargs="+", default=[640, 480], help="Resize dimensions (WxH)")
    parser.add_argument("--min_matches", type=int, default=10, help="Minimum matches for homography")
    parser.add_argument("--max_keypoints", type=int, default=MAX_KEYPOINTS, help="Max keypoints (must match engine)")
    parser.add_argument("--no_display", action="store_true", help="Disable OpenCV windows")
    parser.add_argument("--show_fps", action="store_true", help="Overlay FPS text")
    parser.add_argument("--no_ip_grab", action="store_true", help="Disable IP camera buffer flushing")
    return parser.parse_args(argv)
# A homography has eight degrees of freedom, so cv2.findHomography needs at
# least four point correspondences regardless of the user's --min_matches.
_MIN_HOMOGRAPHY_POINTS = 4


def _extract_features(engine, frame):
    """Run the feature engine on *frame*; return float32 arrays and the valid count."""
    keypoints, scores, descriptors, counts = load_reference_features(engine, frame)
    return (
        keypoints.astype(np.float32),
        scores.astype(np.float32),
        descriptors.astype(np.float32),
        int(counts.reshape(-1)[0]),
    )


def _matched_keypoints(matches0, keypoints_ref, keypoints_cur, valid_ref, valid_cur):
    """Return the (reference, current) coordinates of matches between valid keypoints."""
    ref_indices = np.arange(keypoints_ref.shape[1])
    # Every condition is evaluated per reference slot so the masks share one
    # length; filtering the already-selected subset instead would yield a
    # shorter array that cannot be combined with the full-length mask.
    good = (matches0 > -1) & (ref_indices < valid_ref) & (matches0 < valid_cur)
    return keypoints_ref[0, good], keypoints_cur[0, matches0[good]]


def _estimate_homography(mkpts0, mkpts1, min_matches):
    """Return ``(H, inlier_ratio)`` from RANSAC, or ``(None, 0.0)`` with too few matches."""
    if mkpts0.shape[0] < max(min_matches, _MIN_HOMOGRAPHY_POINTS):
        return None, 0.0
    H, mask = cv2.findHomography(mkpts0, mkpts1, cv2.RANSAC, 5.0)
    if H is None or mask is None:
        return H, 0.0
    return H, float(np.sum(mask)) / float(mask.size)


def main():
    """Run the live matching demo until the stream ends or ``q`` is pressed.

    The first frame becomes the reference. Every later frame is matched
    against it, a homography is estimated from the matches, and (unless
    ``--no_display`` is given) the current view and the projected camera
    position are shown. Pressing ``n`` makes the current frame the new
    reference.

    Raises:
        RuntimeError: If no initial frame can be read from the source.
    """
    opt = parse_args()
    # NOTE(review): opt.max_keypoints is parsed but never consulted here.
    sp_engine, lg_engine = load_engines(opt.sp_engine, opt.lg_engine)
    streamer = VideoStreamer(opt.input, opt.resize, opt.skip, ["*.png", "*.jpg"], 1_000_000)
    if opt.no_ip_grab:
        streamer.is_ip_camera = False
    async_streamer = AsyncVideoStreamer(streamer, queue_size=opt.queue_size, timeout=opt.read_timeout)

    try:
        frame0, ret = async_streamer.read()
        if not ret:
            # Raised inside the try block so the reader thread and the capture
            # are still released; an assert would also vanish under ``-O``.
            raise RuntimeError("Failed to grab initial frame")
        keypoints_ref, scores_ref, desc_ref, valid_ref = _extract_features(sp_engine, frame0)
        last_frame = frame0

        if not opt.no_display:
            cv2.namedWindow("Async Camera View", cv2.WINDOW_NORMAL)
            cv2.namedWindow("Camera Position in Reference", cv2.WINDOW_NORMAL)
            cv2.resizeWindow("Async Camera View", 640, 480)
            cv2.resizeWindow("Camera Position in Reference", 640, 480)

        timer = AverageTimer()
        fps_display = 0.0
        last_time = time.time()

        while True:
            frame, ret = async_streamer.read()
            if not ret:
                print("Stream ended.")
                break
            timer.update("data")

            kp_cur, sc_cur, desc_cur, valid_cur = _extract_features(sp_engine, frame)
            lg_outputs = lg_engine.infer(
                {
                    "input_0": keypoints_ref,
                    "input_1": scores_ref,
                    "input_2": desc_ref,
                    "input_3": kp_cur,
                    "input_4": sc_cur,
                    "input_5": desc_cur,
                }
            )
            matches0 = lg_outputs["matches0"][0].astype(np.int32)
            timer.update("forward")

            current_time = time.time()
            dt = current_time - last_time
            if dt > 0:
                fps_display = 0.9 * fps_display + 0.1 * (1.0 / dt)
            last_time = current_time

            mkpts0, mkpts1 = _matched_keypoints(matches0, keypoints_ref, kp_cur, valid_ref, valid_cur)
            H, inliers_ratio = _estimate_homography(mkpts0, mkpts1, opt.min_matches)

            key = 0
            if not opt.no_display:
                # All drawing happens here because nothing below uses the
                # rendered images when the display is disabled.
                frame_gray = frame if frame.ndim == 2 else cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                display = cv2.cvtColor(frame_gray, cv2.COLOR_GRAY2BGR)
                h, w = display.shape[:2]
                center = (w // 2, h // 2)
                cv2.line(display, (center[0] - 20, center[1]), (center[0] + 20, center[1]), (0, 0, 255), 4, cv2.LINE_AA)
                cv2.line(display, (center[0], center[1] - 20), (center[0], center[1] + 20), (0, 0, 255), 4, cv2.LINE_AA)
                ref_view = draw_camera_position_on_reference(
                    last_frame if last_frame.ndim == 2 else cv2.cvtColor(last_frame, cv2.COLOR_BGR2GRAY),
                    center,
                    H,
                    mkpts0.shape[0],
                    opt.min_matches,
                    inliers_ratio,
                )
                if opt.show_fps:
                    cv2.putText(display, f"FPS: {fps_display:.1f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
                cv2.imshow("Async Camera View", display)
                cv2.imshow("Camera Position in Reference", ref_view)
                key = cv2.waitKey(1) & 0xFF

            if key == ord("q"):
                print("Exiting via keyboard.")
                break
            if key == ord("n"):
                keypoints_ref = kp_cur
                scores_ref = sc_cur
                desc_ref = desc_cur
                valid_ref = valid_cur
                last_frame = frame
                print("Reference frame updated.")
            timer.update("viz")
            timer.print("LightGlue-TensorRT")
    finally:
        async_streamer.stop()
        if not opt.no_display:
            # No windows exist with --no_display, and OpenCV builds without
            # GUI support do not implement destroyAllWindows.
            cv2.destroyAllWindows()
# Run the demo only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|