demo_tensorrt_async.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. #!/usr/bin/env python3
  2. """
  3. TensorRT-based LightGlue demo with camera position visualization.
  4. This script expects TensorRT engines produced by `build_tensorrt.py` and runs
  5. SuperPoint + LightGlue fully on TensorRT while keeping the visualization logic
  6. close to the original PyTorch demo.
  7. """
  8. from __future__ import annotations
  9. import argparse
  10. import queue
  11. import threading
  12. import time
  13. from pathlib import Path
  14. from typing import Optional, Tuple
  15. import cv2
  16. import numpy as np
  17. from trt_engine import load_engines
  18. class AverageTimer:
  19. def __init__(self, smoothing: float = 0.3, newline: bool = False):
  20. self.smoothing = smoothing
  21. self.newline = newline
  22. self.times = {}
  23. self.will_print = {}
  24. self.reset()
  25. def reset(self):
  26. now = time.time()
  27. self.start = now
  28. self.last_time = now
  29. for name in self.will_print:
  30. self.will_print[name] = False
  31. def update(self, name: str = "default"):
  32. now = time.time()
  33. dt = now - self.last_time
  34. if name in self.times:
  35. dt = self.smoothing * dt + (1 - self.smoothing) * self.times[name]
  36. self.times[name] = dt
  37. self.will_print[name] = True
  38. self.last_time = now
  39. def print(self, text: str = "Timer"):
  40. total = 0.0
  41. print(f"[{text}]", end=" ")
  42. for key in self.times:
  43. val = self.times[key]
  44. if self.will_print[key]:
  45. print(f"{key}={val:.3f}", end=" ")
  46. total += val
  47. fps = 1.0 / total if total > 0 else 0.0
  48. print(f"total={total:.3f} sec {{{fps:.1f} FPS}}", end=" ")
  49. print(flush=True) if self.newline else print(end="\r", flush=True)
  50. self.reset()
  51. class VideoStreamer:
  52. def __init__(self, source, resize, skip, image_glob, max_length=1_000_000):
  53. self.source = source
  54. self.skip = skip
  55. self.max_length = max_length
  56. self.resize = resize
  57. self.i = 0
  58. self.cap = None
  59. self.is_ip_camera = False
  60. if Path(source).is_dir():
  61. self.listing = []
  62. for ext in image_glob:
  63. self.listing.extend(list(Path(source).glob(ext)))
  64. self.listing = sorted(self.listing)[: self.max_length]
  65. self.max_length = len(self.listing)
  66. if self.max_length == 0:
  67. raise IOError(f"No images found in directory: {source}")
  68. print(f"Found {self.max_length} images in {source}")
  69. elif Path(source).exists():
  70. self.cap = cv2.VideoCapture(source)
  71. else:
  72. is_digit = isinstance(source, int) or (isinstance(source, str) and source.isdigit())
  73. if not is_digit and not Path(str(source)).exists():
  74. self.is_ip_camera = True
  75. self.cap = cv2.VideoCapture(source, cv2.CAP_FFMPEG)
  76. else:
  77. self.cap = cv2.VideoCapture(int(source) if is_digit else source)
  78. if self.is_ip_camera:
  79. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
  80. self.cap.set(cv2.CAP_PROP_FPS, 30)
  81. self.cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"MJPG"))
  82. def next_frame(self):
  83. if self.cap is not None:
  84. if self.is_ip_camera:
  85. for _ in range(3):
  86. ret = self.cap.grab()
  87. if not ret:
  88. break
  89. ret, frame = self.cap.read()
  90. if not ret:
  91. return None, False
  92. if len(frame.shape) == 3:
  93. frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
  94. else:
  95. if self.i >= self.max_length:
  96. return None, False
  97. image_file = self.listing[self.i]
  98. frame = cv2.imread(str(image_file), cv2.IMREAD_GRAYSCALE)
  99. if frame is None:
  100. print(f"Failed to load image: {image_file}")
  101. return None, False
  102. self.i += 1
  103. if len(self.resize) == 2:
  104. frame = cv2.resize(frame, tuple(self.resize))
  105. elif len(self.resize) == 1 and self.resize[0] > 0:
  106. h, w = frame.shape[:2]
  107. scale = self.resize[0] / max(h, w)
  108. frame = cv2.resize(frame, (int(w * scale), int(h * scale)))
  109. if self.cap is not None:
  110. for _ in range(self.skip):
  111. ret, _ = self.cap.read()
  112. if not ret:
  113. return frame, True
  114. return frame, True
  115. def cleanup(self):
  116. if self.cap is not None:
  117. self.cap.release()
  118. def draw_camera_position_on_reference(
  119. reference_frame: np.ndarray,
  120. camera_center_current,
  121. H: Optional[np.ndarray],
  122. num_matches: int = 0,
  123. min_matches: int = 10,
  124. inliers_ratio: float = 0.0,
  125. ):
  126. h_ref, w_ref = reference_frame.shape[:2]
  127. ref_colored = cv2.cvtColor(reference_frame.copy(), cv2.COLOR_GRAY2BGR)
  128. center_ref_int = (w_ref // 2, h_ref // 2)
  129. cv2.circle(ref_colored, center_ref_int, 15, (0, 255, 0), 2)
  130. cv2.line(ref_colored, (center_ref_int[0] - 20, center_ref_int[1]), (center_ref_int[0] + 20, center_ref_int[1]), (0, 255, 0), 3)
  131. cv2.line(ref_colored, (center_ref_int[0], center_ref_int[1] - 20), (center_ref_int[0], center_ref_int[1] + 20), (0, 255, 0), 3)
  132. if H is None or num_matches < min_matches:
  133. status_text = f"Matches: {num_matches}/{min_matches}"
  134. cv2.putText(ref_colored, status_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
  135. cv2.putText(ref_colored, "Camera position not available", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
  136. return ref_colored
  137. try:
  138. H_inv = np.linalg.inv(H)
  139. camera_center_ref = cv2.perspectiveTransform(
  140. np.array([[camera_center_current]], dtype=np.float32).reshape(-1, 1, 2),
  141. H_inv,
  142. )[0, 0]
  143. camera_center_ref = np.clip(camera_center_ref, [0, 0], [w_ref - 1, h_ref - 1])
  144. camera_pos_int = (int(camera_center_ref[0]), int(camera_center_ref[1]))
  145. cv2.circle(ref_colored, camera_pos_int, 12, (0, 0, 255), 2)
  146. cv2.line(ref_colored, (camera_pos_int[0] - 15, camera_pos_int[1]), (camera_pos_int[0] + 15, camera_pos_int[1]), (0, 0, 255), 3)
  147. cv2.line(ref_colored, (camera_pos_int[0], camera_pos_int[1] - 15), (camera_pos_int[0], camera_pos_int[1] + 15), (0, 0, 255), 3)
  148. cv2.line(ref_colored, center_ref_int, camera_pos_int, (255, 0, 255), 2)
  149. cv2.putText(ref_colored, f"Matches: {num_matches}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
  150. cv2.putText(ref_colored, f"Inliers: {inliers_ratio:.1%}", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
  151. except np.linalg.LinAlgError:
  152. cv2.putText(ref_colored, "Homography not invertible", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
  153. return ref_colored
  154. MAX_KEYPOINTS = 128
  155. IMAGE_WIDTH = 640
  156. IMAGE_HEIGHT = 480
  157. class AsyncVideoStreamer:
  158. """Background frame grabber identical to the PyTorch demo."""
  159. def __init__(self, streamer: VideoStreamer, queue_size: int = 1, timeout: float = 1.0):
  160. self.streamer = streamer
  161. self.queue: "queue.Queue[np.ndarray]" = queue.Queue(maxsize=max(queue_size, 1))
  162. self.timeout = timeout
  163. self._stop_requested = False
  164. self._has_error = False
  165. self._thread = threading.Thread(target=self._reader, daemon=True)
  166. self._thread.start()
  167. def _reader(self):
  168. try:
  169. while not self._stop_requested:
  170. frame, ret = self.streamer.next_frame()
  171. if not ret:
  172. self._stop_requested = True
  173. break
  174. if self.queue.full():
  175. try:
  176. self.queue.get_nowait()
  177. except queue.Empty:
  178. pass
  179. self.queue.put(frame)
  180. except Exception as exc: # pragma: no cover
  181. self._has_error = True
  182. print(f"[AsyncVideoStreamer] error: {exc}")
  183. finally:
  184. self._stop_requested = True
  185. def read(self) -> Tuple[Optional[np.ndarray], bool]:
  186. if self._has_error:
  187. return None, False
  188. try:
  189. frame = self.queue.get(timeout=self.timeout)
  190. return frame, True
  191. except queue.Empty:
  192. return None, False
  193. def stop(self):
  194. self._stop_requested = True
  195. if self._thread.is_alive():
  196. self._thread.join(timeout=1.0)
  197. self.streamer.cleanup()
  198. def preprocess_frame(frame: np.ndarray) -> np.ndarray:
  199. """Convert frame to float32 normalized tensor (1,1,H,W)."""
  200. if frame.ndim == 2:
  201. gray = frame
  202. else:
  203. gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
  204. gray = gray.astype(np.float32) / 255.0
  205. return np.expand_dims(np.expand_dims(gray, axis=0), axis=0)
  206. def load_reference_features(engine, frame: np.ndarray):
  207. inputs = {"input": preprocess_frame(frame)}
  208. outputs = engine.infer(inputs)
  209. return (
  210. outputs["keypoints"],
  211. outputs["scores"],
  212. outputs["descriptors"],
  213. outputs["valid_counts"],
  214. )
  215. def parse_args():
  216. parser = argparse.ArgumentParser(description="TensorRT LightGlue demo")
  217. parser.add_argument("--input", type=str, default="0", help="Camera index or URL")
  218. parser.add_argument("--sp-engine", type=Path, default=Path("models/superpoint.plan"), help="SuperPoint TensorRT engine")
  219. parser.add_argument("--lg-engine", type=Path, default=Path("models/lightglue.plan"), help="LightGlue TensorRT engine")
  220. parser.add_argument("--queue_size", type=int, default=1, help="Async frame queue size")
  221. parser.add_argument("--read_timeout", type=float, default=1.0, help="Frame read timeout")
  222. parser.add_argument("--skip", type=int, default=0, help="Frames to skip")
  223. parser.add_argument("--resize", type=int, nargs="+", default=[640, 480], help="Resize dimensions (WxH)")
  224. parser.add_argument("--min_matches", type=int, default=10, help="Minimum matches for homography")
  225. parser.add_argument("--max_keypoints", type=int, default=MAX_KEYPOINTS, help="Max keypoints (must match engine)")
  226. parser.add_argument("--no_display", action="store_true", help="Disable OpenCV windows")
  227. parser.add_argument("--show_fps", action="store_true", help="Overlay FPS text")
  228. parser.add_argument("--no_ip_grab", action="store_true", help="Disable IP camera buffer flushing")
  229. return parser.parse_args()
  230. def main():
  231. opt = parse_args()
  232. sp_engine, lg_engine = load_engines(opt.sp_engine, opt.lg_engine)
  233. streamer = VideoStreamer(opt.input, opt.resize, opt.skip, ["*.png", "*.jpg"], 1_000_000)
  234. if opt.no_ip_grab and hasattr(streamer, "is_ip_camera"):
  235. streamer.is_ip_camera = False
  236. async_streamer = AsyncVideoStreamer(streamer, queue_size=opt.queue_size, timeout=opt.read_timeout)
  237. frame0, ret = async_streamer.read()
  238. assert ret, "Failed to grab initial frame"
  239. keypoints_ref, scores_ref, desc_ref, counts_ref = load_reference_features(sp_engine, frame0)
  240. last_frame = frame0
  241. keypoints_ref = keypoints_ref.astype(np.float32)
  242. scores_ref = scores_ref.astype(np.float32)
  243. desc_ref = desc_ref.astype(np.float32)
  244. valid_ref = int(counts_ref.reshape(-1)[0])
  245. if not opt.no_display:
  246. cv2.namedWindow("Async Camera View", cv2.WINDOW_NORMAL)
  247. cv2.namedWindow("Camera Position in Reference", cv2.WINDOW_NORMAL)
  248. cv2.resizeWindow("Async Camera View", 640, 480)
  249. cv2.resizeWindow("Camera Position in Reference", 640, 480)
  250. timer = AverageTimer()
  251. fps_display = 0.0
  252. last_time = time.time()
  253. try:
  254. while True:
  255. frame, ret = async_streamer.read()
  256. if not ret:
  257. print("Stream ended.")
  258. break
  259. timer.update("data")
  260. kp_cur, sc_cur, desc_cur, counts_cur = load_reference_features(sp_engine, frame)
  261. kp_cur = kp_cur.astype(np.float32)
  262. sc_cur = sc_cur.astype(np.float32)
  263. desc_cur = desc_cur.astype(np.float32)
  264. valid_cur = int(counts_cur.reshape(-1)[0])
  265. lg_inputs = {
  266. "input_0": keypoints_ref,
  267. "input_1": scores_ref,
  268. "input_2": desc_ref,
  269. "input_3": kp_cur,
  270. "input_4": sc_cur,
  271. "input_5": desc_cur,
  272. }
  273. lg_outputs = lg_engine.infer(lg_inputs)
  274. matches0 = lg_outputs["matches0"][0].astype(np.int32)
  275. matches1 = lg_outputs["matches1"][0].astype(np.int32)
  276. mconf0 = lg_outputs["scores0"][0]
  277. timer.update("forward")
  278. current_time = time.time()
  279. dt = current_time - last_time
  280. if dt > 0:
  281. fps_display = 0.9 * fps_display + 0.1 * (1.0 / dt)
  282. last_time = current_time
  283. frame_gray = frame if frame.ndim == 2 else cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
  284. display = cv2.cvtColor(frame_gray, cv2.COLOR_GRAY2BGR)
  285. h, w = display.shape[:2]
  286. center = (w // 2, h // 2)
  287. cv2.line(display, (center[0] - 20, center[1]), (center[0] + 20, center[1]), (0, 0, 255), 4, cv2.LINE_AA)
  288. cv2.line(display, (center[0], center[1] - 20), (center[0], center[1] + 20), (0, 0, 255), 4, cv2.LINE_AA)
  289. idx_range = np.arange(keypoints_ref.shape[1])
  290. valid_mask_ref = idx_range < valid_ref
  291. good = (matches0 > -1) & valid_mask_ref
  292. match_idx_cur = matches0[good]
  293. good &= match_idx_cur < valid_cur
  294. mkpts0 = keypoints_ref[0, good]
  295. mkpts1 = kp_cur[0, matches0[good]]
  296. H = None
  297. inliers_ratio = 0.0
  298. if mkpts0.shape[0] >= opt.min_matches:
  299. H, mask = cv2.findHomography(mkpts0, mkpts1, cv2.RANSAC, 5.0)
  300. if H is not None and mask is not None:
  301. inliers_ratio = float(np.sum(mask)) / float(mask.size)
  302. ref_view = draw_camera_position_on_reference(
  303. last_frame if last_frame.ndim == 2 else cv2.cvtColor(last_frame, cv2.COLOR_BGR2GRAY),
  304. center,
  305. H,
  306. mkpts0.shape[0],
  307. opt.min_matches,
  308. inliers_ratio,
  309. )
  310. if not opt.no_display:
  311. if opt.show_fps:
  312. cv2.putText(display, f"FPS: {fps_display:.1f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
  313. cv2.imshow("Async Camera View", display)
  314. cv2.imshow("Camera Position in Reference", ref_view)
  315. key = cv2.waitKey(1) & 0xFF
  316. else:
  317. key = 0
  318. if key == ord("q"):
  319. print("Exiting via keyboard.")
  320. break
  321. if key == ord("n"):
  322. keypoints_ref = kp_cur
  323. scores_ref = sc_cur
  324. desc_ref = desc_cur
  325. valid_ref = valid_cur
  326. last_frame = frame
  327. print("Reference frame updated.")
  328. timer.update("viz")
  329. timer.print("LightGlue-TensorRT")
  330. finally:
  331. async_streamer.stop()
  332. cv2.destroyAllWindows()
  333. if __name__ == "__main__":
  334. main()