demo_lightglue_camera_position_async.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076
  1. #!/usr/bin/env python3
  2. """
  3. LightGlue demo with asynchronous video streaming.
  4. This script demonstrates how to decouple frame acquisition from model
  5. inference by reading camera frames in a background thread. The goal is to
  6. reduce blocking caused by `cv2.VideoCapture.read()` and keep the downstream
  7. pipeline busy with the most recent frame available.
  8. """
  9. # 强制关闭 Windows 11 的后台限制(效率模式 / EcoQoS),避免游戏全屏时 Python 被系统挂起
  10. import ctypes
  11. import sys
  12. def _disable_win11_efficiency_mode():
  13. if sys.platform != "win32":
  14. return
  15. try:
  16. handle = ctypes.windll.kernel32.GetCurrentProcess()
  17. # PROCESS_POWER_THROTTLING_STATE: Version=1, ControlMask=1(Speed), StateMask=0(Disable limit)
  18. class ProcessPowerThrottlingState(ctypes.Structure):
  19. _fields_ = [
  20. ("Version", ctypes.c_ulong),
  21. ("ControlMask", ctypes.c_ulong),
  22. ("StateMask", ctypes.c_ulong),
  23. ]
  24. state = ProcessPowerThrottlingState(1, 1, 0)
  25. ProcessPowerThrottling = 4 # ProcessPowerThrottling
  26. if ctypes.windll.kernel32.SetProcessInformation(
  27. handle, ProcessPowerThrottling, ctypes.byref(state), ctypes.sizeof(state)
  28. ):
  29. print("Win11: 已禁用进程效率模式限制 (EcoQoS)", flush=True)
  30. else:
  31. print("Win11: 设置效率模式失败 (非致命)", flush=True)
  32. except Exception as e: # pylint: disable=broad-except
  33. print(f"Win11 效率模式设置异常: {e}", flush=True)
  34. _disable_win11_efficiency_mode()
  35. import argparse
  36. import queue
  37. import threading
  38. import time
  39. import socket
  40. from pathlib import Path
  41. import cv2
  42. import numpy as np
  43. import torch
  44. from demo_lightglue_camera_position_single_window import (
  45. AverageTimer,
  46. VideoStreamer,
  47. draw_camera_position_on_reference,
  48. frame2tensor,
  49. )
  50. from lightglue import LightGlue, SuperPoint
  51. # UDP结果发送器(可选)
  52. try:
  53. from udp_result_sender import UDPResultSender
  54. UDP_RESULT_SENDER_AVAILABLE = True
  55. except ImportError:
  56. UDPResultSender = None
  57. UDP_RESULT_SENDER_AVAILABLE = False
  58. # Import helper functions for SuperPoint post-processing
  59. try:
  60. from lightglue.superpoint import simple_nms, top_k_keypoints, sample_descriptors
  61. # remove_borders is not a function, it's a config parameter in SuperPoint.conf
  62. HELPER_FUNCTIONS_AVAILABLE = True
  63. except ImportError as e:
  64. # These will be imported later if needed
  65. simple_nms = top_k_keypoints = sample_descriptors = None
  66. HELPER_FUNCTIONS_AVAILABLE = False
  67. # TensorRT support
  68. try:
  69. import torch_tensorrt
  70. TENSORRT_AVAILABLE = True
  71. except ImportError:
  72. TENSORRT_AVAILABLE = False
  73. # 屏幕截屏(DXGI Desktop Duplication,用于 Unity 指令 s:Python 截屏作基准图)
  74. try:
  75. import dxcam
  76. try:
  77. _DXCAM_CAMERA = dxcam.create(output_color="RGB")
  78. DXCAM_AVAILABLE = True
  79. except Exception: # pylint: disable=broad-except
  80. _DXCAM_CAMERA = None
  81. DXCAM_AVAILABLE = False
  82. except ImportError:
  83. dxcam = None # type: ignore[assignment]
  84. _DXCAM_CAMERA = None
  85. DXCAM_AVAILABLE = False
  86. torch.set_grad_enabled(False)
  87. class AsyncVisualizer:
  88. """Asynchronous visualizer that handles CPU operations and drawing in a background thread."""
  89. def __init__(self, queue_size: int = 2, min_matches: int = 10, result_sender=None):
  90. self.input_queue: "queue.Queue" = queue.Queue(maxsize=queue_size)
  91. self.output_queue: "queue.Queue" = queue.Queue(maxsize=queue_size)
  92. self._stop_requested = False
  93. self.min_matches = min_matches
  94. self.result_sender = result_sender # UDP结果发送器(可选)
  95. self._thread = threading.Thread(target=self._visualizer_worker, name="AsyncVisualizer", daemon=True)
  96. self._thread.start()
  97. def _visualizer_worker(self) -> None:
  98. """Background thread that performs all CPU operations and drawing."""
  99. while not self._stop_requested:
  100. try:
  101. # Get data from queue (with timeout to check stop flag)
  102. try:
  103. viz_data = self.input_queue.get(timeout=0.1)
  104. except queue.Empty:
  105. continue
  106. if viz_data is None: # Sentinel value to stop
  107. break
  108. # Unpack data - now includes raw tensors
  109. (frame, last_frame, last_data, curr_data, matches01,
  110. camera_center_current, show_fps, fps_display) = viz_data
  111. # All CPU operations happen here (in background thread)
  112. # 1. Convert tensors to numpy (this was blocking the main thread)
  113. kpts0 = last_data["keypoints"][0].detach().cpu().numpy()
  114. kpts1 = curr_data["keypoints"][0].detach().cpu().numpy()
  115. matches = matches01["matches0"][0].detach().cpu().numpy()
  116. scores = matches01["matching_scores0"][0].detach().cpu().numpy()
  117. # 2. Process matches and compute homography (this was blocking)
  118. valid = matches > -1
  119. mkpts0 = kpts0[valid]
  120. mkpts1 = kpts1[matches[valid]]
  121. mconf = scores[valid] if valid.any() else np.array([])
  122. num_matches = len(mkpts0)
  123. inliers_ratio = 0.0
  124. current_H = None
  125. if num_matches >= self.min_matches:
  126. current_H, mask = cv2.findHomography(mkpts0, mkpts1, cv2.RANSAC, 5.0)
  127. if current_H is not None and mask is not None:
  128. inliers_ratio = float(np.sum(mask)) / max(num_matches, 1)
  129. mean_score = float(mconf.mean()) if len(mconf) else 0.0
  130. # print(f"[Homography] matches={num_matches} inliers={inliers_ratio:.2%} score={mean_score:.3f}")
  131. else:
  132. # print(f"[Homography] failed with matches={num_matches}")
  133. pass
  134. else:
  135. if num_matches > 0:
  136. mean_score = float(mconf.mean()) if len(mconf) else 0.0
  137. # print(f"[Matches] insufficient ({num_matches}/{self.min_matches}) score={mean_score:.3f}")
  138. # 3. Prepare display frame (drawing operations)
  139. display_frame = cv2.cvtColor(frame.copy(), cv2.COLOR_GRAY2BGR)
  140. center_x, center_y = display_frame.shape[1] // 2, display_frame.shape[0] // 2
  141. crosshair_size = 20
  142. cv2.line(display_frame, (center_x - crosshair_size, center_y),
  143. (center_x + crosshair_size, center_y), (0, 0, 255), 4, cv2.LINE_AA)
  144. cv2.line(display_frame, (center_x, center_y - crosshair_size),
  145. (center_x, center_y + crosshair_size), (0, 0, 255), 4, cv2.LINE_AA)
  146. # Don't draw FPS here - it will be drawn in main thread with latest value
  147. # if show_fps:
  148. # cv2.putText(display_frame, f"FPS: {fps_display:.1f}", (10, 30),
  149. # cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
  150. cv2.putText(display_frame, f"Matches: {num_matches}", (10, 60),
  151. cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
  152. # 4. Calculate camera position in reference frame (if homography is valid)
  153. camera_center_ref = None
  154. if current_H is not None and num_matches >= self.min_matches:
  155. try:
  156. H_inv = np.linalg.inv(current_H)
  157. camera_center_ref = cv2.perspectiveTransform(
  158. np.array([[camera_center_current]], dtype=np.float32).reshape(-1, 1, 2),
  159. H_inv
  160. )[0, 0]
  161. except np.linalg.LinAlgError:
  162. camera_center_ref = None
  163. # 5. Send result to Unity via UDP (if sender is available)
  164. if self.result_sender is not None:
  165. try:
  166. if camera_center_ref is not None:
  167. self.result_sender.send_result(
  168. is_valid=True,
  169. num_matches=num_matches,
  170. inliers_ratio=inliers_ratio,
  171. camera_x=float(camera_center_ref[0]),
  172. camera_y=float(camera_center_ref[1])
  173. )
  174. else:
  175. # Send invalid result
  176. self.result_sender.send_result(
  177. is_valid=False,
  178. num_matches=num_matches,
  179. inliers_ratio=0.0,
  180. camera_x=0.0,
  181. camera_y=0.0
  182. )
  183. except Exception as e:
  184. print(f"[AsyncVisualizer] Failed to send result: {e}")
  185. # 6. Draw reference view
  186. reference_view = draw_camera_position_on_reference(
  187. last_frame,
  188. camera_center_current,
  189. current_H,
  190. num_matches,
  191. self.min_matches,
  192. inliers_ratio,
  193. )
  194. # Put drawn frames in output queue (non-blocking, drop if full)
  195. try:
  196. self.output_queue.put_nowait((display_frame, reference_view, num_matches, current_H))
  197. except queue.Full:
  198. # Drop old frame, keep only latest
  199. try:
  200. self.output_queue.get_nowait()
  201. self.output_queue.put_nowait((display_frame, reference_view, num_matches, current_H))
  202. except queue.Empty:
  203. pass
  204. except Exception as e:
  205. print(f"[AsyncVisualizer] Error: {e}")
  206. import traceback
  207. traceback.print_exc()
  208. continue
  209. def submit(self, frame, last_frame, last_data, curr_data, matches01,
  210. camera_center_current, show_fps, fps_display):
  211. """Submit raw data for processing (non-blocking)."""
  212. viz_data = (frame, last_frame, last_data, curr_data, matches01,
  213. camera_center_current, show_fps, fps_display)
  214. try:
  215. # Non-blocking put, drop if queue is full (keep only latest)
  216. if self.input_queue.full():
  217. try:
  218. self.input_queue.get_nowait()
  219. except queue.Empty:
  220. pass
  221. self.input_queue.put_nowait(viz_data)
  222. except queue.Full:
  223. pass # Drop if still full
  224. def get_result(self, timeout: float = 0.01):
  225. """Get visualization result (non-blocking)."""
  226. try:
  227. return self.output_queue.get(timeout=timeout)
  228. except queue.Empty:
  229. return None
  230. def stop(self):
  231. """Stop the visualizer thread."""
  232. self._stop_requested = True
  233. self.input_queue.put(None) # Sentinel value
  234. self._thread.join(timeout=1.0)
  235. # Close result sender if available
  236. if self.result_sender is not None:
  237. try:
  238. self.result_sender.close()
  239. except Exception as e:
  240. print(f"[AsyncVisualizer] Error closing result sender: {e}")
  241. class AsyncVideoStreamer:
  242. """Wrapper around VideoStreamer that reads frames on a background thread."""
  243. def __init__(self, streamer: VideoStreamer, queue_size: int = 1, timeout: float = 1.0):
  244. self.streamer = streamer
  245. self.queue: "queue.Queue[np.ndarray]" = queue.Queue(maxsize=max(queue_size, 1))
  246. self.timeout = timeout
  247. self._stop_requested = False
  248. self._has_error = False
  249. self._thread = threading.Thread(target=self._reader, name="AsyncVideoStreamer", daemon=True)
  250. self._thread.start()
  251. def _reader(self) -> None:
  252. try:
  253. while not self._stop_requested:
  254. frame, ret = self.streamer.next_frame()
  255. if not ret:
  256. # For UDP mode, no frame doesn't mean end of stream
  257. # Just wait a bit and retry
  258. if hasattr(self.streamer, "is_udp_jpeg") and self.streamer.is_udp_jpeg:
  259. time.sleep(0.01) # Wait 10ms before retry
  260. continue
  261. # End of stream or error: signal stop and exit
  262. self._stop_requested = True
  263. break
  264. # Keep only the most recent frame to minimise latency
  265. if self.queue.full():
  266. try:
  267. self.queue.get_nowait()
  268. except queue.Empty:
  269. pass
  270. self.queue.put(frame)
  271. except Exception as exc: # pylint: disable=broad-except
  272. self._has_error = True
  273. print(f"[AsyncVideoStreamer] Reader thread error: {exc}")
  274. finally:
  275. self._stop_requested = True
  276. def read(self):
  277. """Return the latest frame. Blocks up to `timeout` seconds."""
  278. if self._has_error:
  279. return None, False
  280. try:
  281. frame = self.queue.get(timeout=self.timeout)
  282. return frame, True
  283. except queue.Empty:
  284. return None, False
  285. def stop(self):
  286. self._stop_requested = True
  287. if self._thread.is_alive():
  288. self._thread.join(timeout=1.0)
  289. self.streamer.cleanup()
  290. def load_reference_frame(opt, device):
  291. if opt.reference_image is None:
  292. return None, None, None
  293. print(f"==> Loading reference image: {opt.reference_image}")
  294. ref_image = cv2.imread(opt.reference_image, cv2.IMREAD_GRAYSCALE)
  295. if ref_image is None:
  296. raise IOError(f"Cannot load reference image: {opt.reference_image}")
  297. h, w = ref_image.shape[:2]
  298. if len(opt.resize) == 2:
  299. ref_image = cv2.resize(ref_image, tuple(opt.resize))
  300. elif len(opt.resize) == 1 and opt.resize[0] > 0:
  301. scale = opt.resize[0] / max(h, w)
  302. ref_image = cv2.resize(ref_image, (int(w * scale), int(h * scale)))
  303. ref_tensor = frame2tensor(ref_image, device)
  304. return ref_image, ref_tensor, 0
  305. def parse_args():
  306. parser = argparse.ArgumentParser(
  307. description="LightGlue demo (asynchronous capture)",
  308. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  309. )
  310. parser.add_argument("--input", type=str, default="0", help="USB webcam index, IP camera URL, UDP stream (udp://host:port), or video path")
  311. parser.add_argument("--reference_image", type=str, default=None, help="Optional reference image path")
  312. parser.add_argument("--output_dir", type=str, default=None, help="Directory to save visualisations")
  313. parser.add_argument("--image_glob", type=str, nargs="+", default=["*.png", "*.jpg", "*.jpeg"], help="Glob for image sequences")
  314. parser.add_argument("--skip", type=int, default=0, help="Number of frames to skip between reads")
  315. parser.add_argument("--max_length", type=int, default=1_000_000, help="Maximum frames")
  316. parser.add_argument(
  317. "--resize",
  318. type=int,
  319. nargs="+",
  320. default=[640, 480],
  321. help="Resize input image. Two numbers = width height, one number = max dimension, -1 = no resize",
  322. )
  323. parser.add_argument("--max_keypoints", type=int, default=1024, help="Maximum number of SuperPoint keypoints")
  324. parser.add_argument("--keypoint_threshold", type=float, default=0.01, help="SuperPoint detection threshold")
  325. parser.add_argument("--nms_radius", type=int, default=4, help="SuperPoint NMS radius")
  326. parser.add_argument("--match_threshold", type=float, default=0.2, help="LightGlue match threshold")
  327. parser.add_argument("--depth_confidence", type=float, default=0.95, help="LightGlue depth confidence")
  328. parser.add_argument("--width_confidence", type=float, default=0.99, help="LightGlue width confidence")
  329. parser.add_argument("--use_fp16", action="store_true", help="Enable FP16 half precision inference for faster processing")
  330. parser.add_argument("--use_tensorrt", action="store_true", help="Use TensorRT optimized models (requires torch-tensorrt)")
  331. parser.add_argument("--tensorrt_precision", type=str, default="fp16", choices=["fp32", "fp16", "int8"],
  332. help="TensorRT precision mode (fp16 recommended)")
  333. parser.add_argument("--tensorrt_calibration_data", type=str, default=None,
  334. help="Directory containing calibration images for INT8 quantization (optional)")
  335. parser.add_argument("--tensorrt_calibration_batches", type=int, default=10,
  336. help="Number of calibration batches for INT8 (default: 10)")
  337. parser.add_argument("--min_matches", type=int, default=10, help="Minimum matches required to compute homography")
  338. parser.add_argument("--queue_size", type=int, default=1, help="Frame queue size for async reader")
  339. parser.add_argument("--read_timeout", type=float, default=1.0, help="Seconds to wait for a frame from async reader")
  340. parser.add_argument("--flip_horizontal", action="store_true", help="Flip frames horizontally")
  341. parser.add_argument("--flip_vertical", action="store_true", help="Flip frames vertically")
  342. parser.add_argument("--rotate", type=int, default=0, choices=[0, 90, 180, 270], help="Rotate frames clockwise")
  343. parser.add_argument("--show_fps", action="store_true", help="Render FPS overlay")
  344. parser.add_argument("--force_cpu", action="store_true", help="Run inference on CPU even if CUDA is available")
  345. parser.add_argument("--no_ip_grab", action="store_true", help="Disable extra grab calls for IP cameras (reduces frame drops but may increase latency)")
  346. parser.add_argument("--no_display", action="store_true", help="Disable OpenCV window")
  347. parser.add_argument("--no_ui", action="store_true", help="Suppress console output (UI embedding)")
  348. parser.add_argument("--result_ip", type=str, default="127.0.0.1", help="Unity IP address for result transmission (default: 127.0.0.1)")
  349. parser.add_argument("--result_port", type=int, default=12348, help="Unity UDP port for result transmission (default: 12348)")
  350. parser.add_argument(
  351. "--control_port",
  352. type=int,
  353. default=0,
  354. help="Optional UDP port for receiving control commands from Unity (e.g., refresh reference frame). 0=disabled.",
  355. )
  356. return parser.parse_args()
  357. def maybe_resize(input_frame, resize_opt):
  358. if len(resize_opt) == 2:
  359. return cv2.resize(input_frame, tuple(resize_opt))
  360. if len(resize_opt) == 1 and resize_opt[0] > 0:
  361. h, w = input_frame.shape[:2]
  362. scale = resize_opt[0] / max(h, w)
  363. return cv2.resize(input_frame, (int(w * scale), int(h * scale)))
  364. return input_frame
  365. def apply_orientation(frame, opt):
  366. if opt.rotate == 90:
  367. frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
  368. elif opt.rotate == 180:
  369. frame = cv2.rotate(frame, cv2.ROTATE_180)
  370. elif opt.rotate == 270:
  371. frame = cv2.rotate(frame, cv2.ROTATE_90_COUNTERCLOCKWISE)
  372. if opt.flip_horizontal:
  373. frame = cv2.flip(frame, 1)
  374. if opt.flip_vertical:
  375. frame = cv2.flip(frame, 0)
  376. return frame
  377. def capture_screen_frame(opt):
  378. """使用 DXGI (dxcam) 截取整屏,返回与 opt.resize 一致的灰度图;无 dxcam 时返回 None。"""
  379. if not DXCAM_AVAILABLE or _DXCAM_CAMERA is None:
  380. return None
  381. try:
  382. frame = _DXCAM_CAMERA.grab()
  383. if frame is None:
  384. return None
  385. # dxcam 默认返回 RGB,shape=(H, W, 3)
  386. if len(frame.shape) == 3 and frame.shape[2] == 3:
  387. img = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
  388. elif len(frame.shape) == 3 and frame.shape[2] == 4:
  389. img = cv2.cvtColor(frame, cv2.COLOR_BGRA2GRAY)
  390. else:
  391. img = frame if len(frame.shape) == 2 else frame[..., 0]
  392. img = maybe_resize(img, opt.resize)
  393. return img
  394. except Exception: # pylint: disable=broad-except
  395. return None
  396. def main():
  397. opt = parse_args()
  398. if len(opt.resize) == 2 and opt.resize[1] == -1:
  399. opt.resize = opt.resize[0:1]
  400. if len(opt.resize) == 2:
  401. print(f"Will resize to {opt.resize[0]}x{opt.resize[1]} (WxH)")
  402. elif len(opt.resize) == 1 and opt.resize[0] > 0:
  403. print(f"Will resize max dimension to {opt.resize[0]}")
  404. elif len(opt.resize) == 1:
  405. print("Will not resize images")
  406. else:
  407. raise ValueError("Cannot specify more than two integers for --resize")
  408. if opt.no_ui:
  409. import os
  410. import sys
  411. sys.stdout = open(os.devnull, "w")
  412. sys.stderr = open(os.devnull, "w")
  413. device = "cuda" if torch.cuda.is_available() and not opt.force_cpu else "cpu"
  414. print(f'Running inference on device "{device}"')
  415. extractor = SuperPoint(
  416. max_num_keypoints=opt.max_keypoints,
  417. detection_threshold=opt.keypoint_threshold,
  418. nms_radius=opt.nms_radius,
  419. ).eval().to(device)
  420. matcher = LightGlue(
  421. features="superpoint",
  422. depth_confidence=opt.depth_confidence,
  423. width_confidence=opt.width_confidence,
  424. filter_threshold=opt.match_threshold,
  425. mp=opt.use_fp16, # Enable mixed precision if FP16 is requested
  426. ).eval().to(device)
  427. print("Loaded SuperPoint and LightGlue models")
  428. # TensorRT optimization
  429. if opt.use_tensorrt and TENSORRT_AVAILABLE and device == "cuda":
  430. try:
  431. print("="*60)
  432. print("Compiling models with TensorRT...")
  433. print(f"Precision: {opt.tensorrt_precision}")
  434. print("This may take several minutes on first run...")
  435. print("="*60)
  436. # Compile SuperPoint with TensorRT
  437. print("Compiling SuperPoint...")
  438. example_input = torch.randn(1, 1, opt.resize[1], opt.resize[0]).cuda()
  439. enabled_precisions = {torch.float}
  440. calibration_cache = None
  441. if opt.tensorrt_precision == "fp16":
  442. enabled_precisions.add(torch.half)
  443. elif opt.tensorrt_precision == "int8":
  444. enabled_precisions.add(torch.int8)
  445. print(" Note: INT8 quantization will use default calibration")
  446. print(" For better accuracy, provide calibration data with --tensorrt_calibration_data")
  447. # Create a hybrid approach: compile only the encoder part (conv layers)
  448. # Keep dynamic operations (keypoint extraction, NMS) in PyTorch
  449. print(" Creating encoder-only model for TensorRT compilation...")
  450. class SuperPointEncoder(torch.nn.Module):
  451. """Only the encoder part of SuperPoint (conv layers + feature extraction)"""
  452. def __init__(self, superpoint_model):
  453. super().__init__()
  454. # Copy encoder layers
  455. self.conv1a = superpoint_model.conv1a
  456. self.conv1b = superpoint_model.conv1b
  457. self.conv2a = superpoint_model.conv2a
  458. self.conv2b = superpoint_model.conv2b
  459. self.conv3a = superpoint_model.conv3a
  460. self.conv3b = superpoint_model.conv3b
  461. self.conv4a = superpoint_model.conv4a
  462. self.conv4b = superpoint_model.conv4b
  463. self.pool = superpoint_model.pool
  464. self.relu = superpoint_model.relu
  465. # Feature extraction layers
  466. self.convPa = superpoint_model.convPa
  467. self.convPb = superpoint_model.convPb
  468. self.convDa = superpoint_model.convDa
  469. self.convDb = superpoint_model.convDb
  470. # Store config for post-processing
  471. self.conf = superpoint_model.conf
  472. self.original_model = superpoint_model
  473. def forward(self, image):
  474. """Forward pass through encoder only"""
  475. # Shared Encoder
  476. x = self.relu(self.conv1a(image))
  477. x = self.relu(self.conv1b(x))
  478. x = self.pool(x)
  479. x = self.relu(self.conv2a(x))
  480. x = self.relu(self.conv2b(x))
  481. x = self.pool(x)
  482. x = self.relu(self.conv3a(x))
  483. x = self.relu(self.conv3b(x))
  484. x = self.pool(x)
  485. x = self.relu(self.conv4a(x))
  486. x = self.relu(self.conv4b(x))
  487. # Compute the dense keypoint scores
  488. cPa = self.relu(self.convPa(x))
  489. scores = self.convPb(cPa)
  490. scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
  491. b, _, h, w = scores.shape
  492. scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
  493. scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
  494. # Compute the dense descriptors
  495. cDa = self.relu(self.convDa(x))
  496. descriptors = self.convDb(cDa)
  497. descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
  498. return scores, descriptors
  499. encoder_model = SuperPointEncoder(extractor).eval()
  500. # Compile encoder with TensorRT
  501. extractor_trt = None
  502. try:
  503. print(" Compiling encoder with TensorRT...")
  504. with torch.no_grad():
  505. # Trace the encoder
  506. traced_encoder = torch.jit.trace(encoder_model, example_input, strict=False)
  507. traced_encoder.eval()
  508. # For INT8, torch-tensorrt will automatically handle calibration
  509. # We just need to provide a single example input
  510. # The calibration_batches parameter is informational only for now
  511. if opt.tensorrt_precision == "int8":
  512. print(f" Note: INT8 calibration will be performed automatically")
  513. print(f" (Calibration batches setting: {opt.tensorrt_calibration_batches})")
  514. print(" WARNING: INT8 compilation can take 10-20 minutes, please be patient...")
  515. import sys
  516. sys.stdout.flush()
  517. # Compile with TensorRT
  518. # For INT8, torch-tensorrt automatically generates calibration data
  519. print(" Starting TensorRT compilation (this may take a while)...")
  520. import sys
  521. sys.stdout.flush()
  522. encoder_trt = torch_tensorrt.compile(
  523. traced_encoder,
  524. inputs=[example_input],
  525. enabled_precisions=enabled_precisions,
  526. workspace_size=1 << 30, # 1GB
  527. min_block_size=7,
  528. ir="torchscript",
  529. truncate_long_and_double=True,
  530. )
  531. print(" [OK] Encoder compiled with TensorRT successfully")
  532. import sys
  533. sys.stdout.flush()
  534. # Create hybrid wrapper that uses TensorRT encoder + PyTorch post-processing
  535. # Re-import helper functions to ensure they're available
  536. try:
  537. from lightglue.superpoint import simple_nms, top_k_keypoints, sample_descriptors
  538. except ImportError as import_err:
  539. print(f" [ERROR] Could not import helper functions: {import_err}")
  540. print(" Falling back to PyTorch model (TensorRT optimization disabled)")
  541. extractor_trt = None
  542. raise ImportError("Required helper functions not available") from import_err
  543. class HybridSuperPoint:
  544. def __init__(self, trt_encoder, original_model):
  545. self.trt_encoder = trt_encoder
  546. self.original_model = original_model
  547. self.conf = original_model.conf
  548. def __call__(self, inputs):
  549. if isinstance(inputs, dict):
  550. image = inputs["image"]
  551. else:
  552. image = inputs
  553. # Use TensorRT encoder
  554. scores, descriptors = self.trt_encoder(image)
  555. # Post-processing in PyTorch (dynamic operations)
  556. scores = simple_nms(scores, self.conf.nms_radius)
  557. # Discard keypoints near borders
  558. if self.conf.remove_borders:
  559. pad = self.conf.remove_borders
  560. scores[:, :pad] = -1
  561. scores[:, :, :pad] = -1
  562. scores[:, -pad:] = -1
  563. scores[:, :, -pad:] = -1
  564. # Extract keypoints
  565. best_kp = torch.where(scores > self.conf.detection_threshold)
  566. scores_vals = scores[best_kp]
  567. b = image.shape[0]
  568. keypoints = [
  569. torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b)
  570. ]
  571. scores_list = [scores_vals[best_kp[0] == i] for i in range(b)]
  572. # Top-k keypoints
  573. if self.conf.max_num_keypoints is not None:
  574. keypoints, scores_list = list(zip(*[
  575. top_k_keypoints(k, s, self.conf.max_num_keypoints)
  576. for k, s in zip(keypoints, scores_list)
  577. ]))
  578. # Convert (h, w) to (x, y)
  579. keypoints = [torch.flip(k, [1]).float() for k in keypoints]
  580. # Extract descriptors
  581. descriptors_list = [
  582. sample_descriptors(k[None], d[None], 8)[0]
  583. for k, d in zip(keypoints, descriptors)
  584. ]
  585. return {
  586. "keypoints": torch.stack(keypoints, 0),
  587. "keypoint_scores": torch.stack(scores_list, 0),
  588. "descriptors": torch.stack(descriptors_list, 0).transpose(-1, -2).contiguous(),
  589. }
  590. def eval(self):
  591. return self
  592. def to(self, device):
  593. return self
  594. extractor_trt = HybridSuperPoint(encoder_trt, extractor)
  595. except Exception as compile_error:
  596. print(f" [ERROR] TensorRT compilation failed: {compile_error}")
  597. print(" Falling back to PyTorch model (TensorRT optimization disabled)")
  598. import traceback
  599. print(" Full error traceback:")
  600. traceback.print_exc()
  601. extractor_trt = None
  602. # Replace extractor with TensorRT version only if compilation succeeded
  603. if extractor_trt is None:
  604. print("="*60)
  605. print("TensorRT optimization skipped, using PyTorch models")
  606. print("="*60)
  607. else:
  608. extractor = extractor_trt # HybridSuperPoint already implements the interface
  609. print("[OK] SuperPoint encoder compiled with TensorRT")
  610. print(" (Keypoint extraction and NMS remain in PyTorch for compatibility)")
  611. # Note: LightGlue compilation is more complex due to multiple inputs
  612. # For now, we'll keep LightGlue as PyTorch model
  613. print("Note: LightGlue will use PyTorch (TensorRT compilation for LightGlue is more complex)")
  614. print("="*60)
  615. print("[OK] TensorRT optimization completed (hybrid approach)")
  616. print("="*60)
  617. except Exception as e:
  618. print(f"[ERROR] Failed to compile with TensorRT: {e}")
  619. print("Falling back to PyTorch models")
  620. import traceback
  621. print("Full error traceback:")
  622. traceback.print_exc()
  623. import sys
  624. sys.stdout.flush()
  625. elif opt.use_tensorrt:
  626. if not TENSORRT_AVAILABLE:
  627. print("Warning: TensorRT requested but torch-tensorrt not installed")
  628. print("Install with: pip install torch-tensorrt")
  629. elif device != "cuda":
  630. print("Warning: TensorRT requires CUDA, but running on CPU")
  631. ref_frame, ref_tensor, last_image_id = load_reference_frame(opt, device)
  632. if ref_tensor is not None:
  633. if opt.use_fp16 and device == "cuda":
  634. with torch.cuda.amp.autocast():
  635. last_data = extractor({"image": ref_tensor})
  636. else:
  637. last_data = extractor({"image": ref_tensor})
  638. last_frame = ref_frame
  639. else:
  640. last_data = None
  641. last_frame = None
  642. last_image_id = 0
  643. streamer = VideoStreamer(opt.input, opt.resize, opt.skip, opt.image_glob, opt.max_length)
  644. # 处理UDP模式
  645. if hasattr(streamer, "is_udp_jpeg") and streamer.is_udp_jpeg:
  646. print("UDP JPEG mode: receiver started in background thread", flush=True)
  647. # 处理摄像头模式
  648. elif hasattr(streamer, "cap") and streamer.cap is not None:
  649. is_local_cam = False
  650. if isinstance(opt.input, str) and opt.input.isdigit():
  651. is_local_cam = True
  652. elif isinstance(opt.input, int):
  653. is_local_cam = True
  654. if is_local_cam:
  655. desired_width, desired_height, desired_fps = 640, 480, 30
  656. streamer.cap.set(cv2.CAP_PROP_FRAME_WIDTH, desired_width)
  657. streamer.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, desired_height)
  658. streamer.cap.set(cv2.CAP_PROP_FPS, desired_fps)
  659. actual_w = streamer.cap.get(cv2.CAP_PROP_FRAME_WIDTH)
  660. actual_h = streamer.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
  661. actual_fps = streamer.cap.get(cv2.CAP_PROP_FPS)
  662. print(f"Camera props requested -> {desired_width}x{desired_height} @{desired_fps} FPS")
  663. print(f"Camera props applied -> {actual_w:.0f}x{actual_h:.0f} @{actual_fps:.1f} FPS")
  664. if opt.no_ip_grab and hasattr(streamer, "is_ip_camera"):
  665. streamer.is_ip_camera = False
  666. print("IP camera buffer flush disabled (no extra grab calls).")
  667. async_streamer = AsyncVideoStreamer(streamer, queue_size=opt.queue_size, timeout=opt.read_timeout)
  668. if last_data is None:
  669. # For UDP mode, wait a bit for the first frame to arrive
  670. if hasattr(streamer, "is_udp_jpeg") and streamer.is_udp_jpeg:
  671. print("Waiting for first UDP frame...")
  672. max_wait_time = 10.0 # Wait up to 10 seconds
  673. wait_interval = 0.1
  674. elapsed = 0.0
  675. first_frame, ret = None, False
  676. while elapsed < max_wait_time:
  677. first_frame, ret = async_streamer.read()
  678. if ret:
  679. break
  680. time.sleep(wait_interval)
  681. elapsed += wait_interval
  682. if int(elapsed) % 2 == 0 and int(elapsed - wait_interval) % 2 != 0:
  683. print(f"Still waiting for UDP frame... ({int(elapsed)}s)")
  684. else:
  685. first_frame, ret = async_streamer.read()
  686. if not ret:
  687. raise RuntimeError(
  688. "Error when reading the first frame. "
  689. "For UDP mode, make sure:\n"
  690. " 1. The sender is running and sending data\n"
  691. " 2. The port number is correct\n"
  692. " 3. Firewall allows UDP traffic on this port"
  693. )
  694. first_frame = apply_orientation(first_frame, opt)
  695. last_frame = first_frame
  696. last_tensor = frame2tensor(first_frame, device)
  697. if opt.use_fp16 and device == "cuda":
  698. with torch.cuda.amp.autocast():
  699. last_data = extractor({"image": last_tensor})
  700. else:
  701. last_data = extractor({"image": last_tensor})
  702. print("First frame received and processed")
  703. if opt.output_dir is not None:
  704. Path(opt.output_dir).mkdir(exist_ok=True)
  705. print(f"==> Will write outputs to {opt.output_dir}")
  706. window_name_ref = "Camera Position in Reference"
  707. if not opt.no_display:
  708. try:
  709. cv2.namedWindow(window_name_ref, cv2.WINDOW_NORMAL)
  710. cv2.resizeWindow(window_name_ref, 640, 480)
  711. except cv2.error as e:
  712. print(f"Warning: Could not create OpenCV windows: {e}")
  713. print("Continuing without display...")
  714. opt.no_display = True
  715. print("==> Keyboard control:\n"
  716. "\tn: set current frame as reference\n"
  717. "\tq: quit\n"
  718. "\tf: toggle FPS overlay\n")
  719. timer = AverageTimer()
  720. show_fps = opt.show_fps
  721. fps_display = 0.0
  722. last_time = time.time()
  723. last_fps_print_time = time.time()
  724. fps_print_interval = 2.0 # Print FPS every 2 seconds
  725. # Initialize UDP result sender (if available)
  726. result_sender = None
  727. if UDP_RESULT_SENDER_AVAILABLE:
  728. try:
  729. result_sender = UDPResultSender(unity_ip=opt.result_ip, unity_port=opt.result_port)
  730. print(f"[UDP] Result sender initialized: {opt.result_ip}:{opt.result_port}")
  731. except Exception as e:
  732. print(f"[UDP] Failed to initialize result sender: {e}")
  733. result_sender = None
  734. else:
  735. print("[UDP] UDP result sender not available (udp_result_sender.py not found)")
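  736. # Create the async visualizer when a window is shown or a result_sender exists: without a window it is still needed so results can be sent back to Unity via result_sender; with no_display set no window is opened.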
  737. async_visualizer = AsyncVisualizer(queue_size=2, min_matches=opt.min_matches, result_sender=result_sender) if (not opt.no_display or result_sender is not None) else None
  738. # 可选:启动来自 Unity 的 UDP 控制监听(用于刷新参考图等简单指令)
  739. control_sock = None
  740. control_stop_event = None
  741. control_thread = None
  742. control_refresh_event = None
  743. # 控制指令:n=当前摄像头帧作参考图, s=Python截屏作参考图, r=下一帧来自Unity作为参考图
  744. screen_capture_event = threading.Event()
  745. next_frame_is_reference_event = threading.Event()
  746. if opt.control_port and opt.control_port > 0:
  def _control_listener(sock: socket.socket, stop_event: threading.Event, refresh_event: threading.Event,
                        screen_ev: threading.Event, next_ref_ev: threading.Event) -> None:
      """Receive one-byte UDP control commands and raise the matching event.

      Only the first byte of each datagram is inspected:
        * ``n`` / ``N`` / byte value 1 -> ``refresh_event`` is set
        * ``s``                        -> ``screen_ev`` is set
        * ``r``                        -> ``next_ref_ev`` is set
      Any other byte is ignored.

      Args:
          sock: UDP socket to read from; its timeout is set to 0.5 s here.
          stop_event: The loop ends once this event is set.
          refresh_event: Set on ``n`` / ``N`` / 1.
          screen_ev: Set on ``s``.
          next_ref_ev: Set on ``r``.
      """
      # First payload byte -> event to raise.
      triggers = {
          ord("n"): refresh_event,
          ord("N"): refresh_event,
          1: refresh_event,
          ord("s"): screen_ev,
          ord("r"): next_ref_ev,
      }

      # A short timeout lets the loop re-check stop_event regularly.
      sock.settimeout(0.5)
      print(f"[Control] Listening for UDP control commands on 0.0.0.0:{opt.control_port} (n/s/r)")

      while not stop_event.is_set():
          try:
              try:
                  payload, _sender = sock.recvfrom(1024)
              except socket.timeout:
                  # Nothing arrived within the timeout; poll stop_event again.
                  continue
              if payload:
                  target_event = triggers.get(payload[0])
                  if target_event is not None:
                      target_event.set()
          except OSError:
              # Any other socket-level failure ends the listener
              # (presumably the socket being closed during shutdown).
              break
          except Exception as exc:  # pylint: disable=broad-except
              # Report unexpected errors and keep listening.
              print(f"[Control] Listener error: {exc}")
  772. try:
  773. control_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  774. control_sock.bind(("0.0.0.0", opt.control_port))
  775. control_stop_event = threading.Event()
  776. control_refresh_event = threading.Event()
  777. control_thread = threading.Thread(
  778. target=_control_listener,
  779. args=(control_sock, control_stop_event, control_refresh_event,
  780. screen_capture_event, next_frame_is_reference_event),
  781. name="ControlListener",
  782. daemon=True,
  783. )
  784. control_thread.start()
  785. except Exception as exc: # pylint: disable=broad-except
  786. print(f"[Control] Failed to start UDP control listener on port {opt.control_port}: {exc}")
  787. control_sock = None
  788. control_stop_event = None
  789. control_thread = None
  790. control_refresh_event = None
  791. try:
  792. while True:
  793. loop_start_time = time.time()
  794. # 控制指令 s:Python 截屏作为基准图(在读取本帧前处理)
  795. if screen_capture_event.is_set():
  796. screen_capture_event.clear()
  797. sc_frame = capture_screen_frame(opt)
  798. if sc_frame is not None:
  799. last_frame = sc_frame
  800. last_tensor = frame2tensor(last_frame, device)
  801. if opt.use_fp16 and device == "cuda":
  802. with torch.cuda.amp.autocast():
  803. new_ref = extractor({"image": last_tensor})
  804. else:
  805. new_ref = extractor({"image": last_tensor})
  806. # 只有在新参考图中检测到关键点时才更新基准,避免空 keypoints 导致 LightGlue 报错
  807. if new_ref.get("keypoints", None) is not None and new_ref["keypoints"].shape[1] > 0:
  808. last_data = new_ref
  809. last_image_id += 1
  810. else:
  811. print("[Control] Screen capture has no keypoints, keep previous reference")
  812. elif not DXCAM_AVAILABLE:
  813. print("[Control] Screen capture skipped: install dxcam (pip install dxcam)")
  814. frame, ret = async_streamer.read()
  815. if not ret:
  816. # For UDP mode, timeout doesn't mean end of stream
  817. # Continue waiting for new frames
  818. if hasattr(streamer, "is_udp_jpeg") and streamer.is_udp_jpeg:
  819. continue # Keep waiting for UDP frames
  820. # For other modes, timeout means end of stream
  821. print("Stream ended or timeout exceeded.")
  822. break
  823. frame = apply_orientation(frame, opt)
  824. timer.update("data")
  825. # 控制指令 r:本帧为 Unity 发来的游戏画面,用作基准图后跳过本帧推理
  826. if next_frame_is_reference_event.is_set():
  827. next_frame_is_reference_event.clear()
  828. last_frame = frame
  829. frame_tensor = frame2tensor(frame, device)
  830. if opt.use_fp16 and device == "cuda":
  831. with torch.cuda.amp.autocast():
  832. new_ref = extractor({"image": frame_tensor})
  833. else:
  834. new_ref = extractor({"image": frame_tensor})
  835. # 同样只在有关键点时更新参考图
  836. if new_ref.get("keypoints", None) is not None and new_ref["keypoints"].shape[1] > 0:
  837. last_data = new_ref
  838. last_image_id += 1
  839. print("[Control] Reference updated from Unity game view frame")
  840. else:
  841. print("[Control] Unity game view frame has no keypoints, keep previous reference")
  842. continue
  843. frame_tensor = frame2tensor(frame, device)
  844. # Use FP16 autocast if enabled
  845. if opt.use_fp16 and device == "cuda":
  846. with torch.cuda.amp.autocast():
  847. curr_data = extractor({"image": frame_tensor})
  848. else:
  849. curr_data = extractor({"image": frame_tensor})
  850. # 如果任一图像没有关键点,跳过本帧匹配,避免 LightGlue 报 IndexError
  851. if last_data is None or last_data.get("keypoints", None) is None or last_data["keypoints"].shape[1] == 0:
  852. print("[Guard] Reference has no keypoints, skip matching this frame")
  853. continue
  854. if curr_data.get("keypoints", None) is None or curr_data["keypoints"].shape[1] == 0:
  855. print("[Guard] Current frame has no keypoints, skip matching this frame")
  856. continue
  857. matches01 = matcher({"image0": last_data, "image1": curr_data})
  858. # Update timer immediately after inference (all CPU operations are async now)
  859. timer.update("forward")
  860. # Calculate FPS based on the entire loop time (more accurate)
  861. loop_end_time = time.time()
  862. dt = loop_end_time - loop_start_time
  863. if dt > 0:
  864. fps_display = 0.9 * fps_display + 0.1 * (1.0 / dt)
  865. # Use loop_end_time for console printing
  866. current_time = loop_end_time
  867. # Print FPS to console periodically
  868. if current_time - last_fps_print_time >= fps_print_interval:
  869. fp16_status = "FP16" if (opt.use_fp16 and device == "cuda") else "FP32"
  870. print(f"[FPS] {fps_display:.1f} FPS ({fp16_status})")
  871. last_fps_print_time = current_time
  872. # Submit all CPU operations to background thread (non-blocking)
  873. # This includes: tensor->numpy, homography calculation, visualization
  874. # Now fps_display is already updated, so it will show correct FPS on screen
  875. center_x, center_y = frame.shape[1] // 2, frame.shape[0] // 2
  876. camera_center_current = (center_x, center_y)
  877. if async_visualizer is not None:
  878. async_visualizer.submit(
  879. frame, # Raw frame (will be processed in background)
  880. last_frame,
  881. last_data, # Raw tensor (will be converted in background)
  882. curr_data, # Raw tensor (will be converted in background)
  883. matches01, # Raw tensor (will be converted in background)
  884. camera_center_current,
  885. show_fps,
  886. fps_display # Now this is the current frame's FPS
  887. )
  888. # Try to get visualization result (non-blocking); no_display 时仍取结果以消费队列,但不弹窗
  889. if async_visualizer is not None:
  890. viz_result = async_visualizer.get_result(timeout=0.0) # Non-blocking
  891. if viz_result is not None and not opt.no_display:
  892. display_frame, reference_view, num_matches, current_H = viz_result
  893. cv2.imshow(window_name_ref, reference_view)
  894. # 处理来自 Unity 的“刷新参考图”控制指令(等价于按键 n)
  895. if control_thread is not None and control_refresh_event is not None:
  896. if control_refresh_event.is_set():
  897. control_refresh_event.clear()
  898. last_data = curr_data
  899. last_frame = frame
  900. last_image_id += 1
  901. print("[Control] Applied refresh-reference command from Unity, updated reference frame")
  902. # Handle keyboard input. cv2.waitKey is called on every frame because
  903. # OpenCV needs it to process window events; it is skipped when no_display is set.
  904. if not opt.no_display:
  905. key = cv2.waitKey(1) & 0xFF
  906. else:
  907. key = 0
  908. # Update timer after all operations (for accurate total time)
  909. # timer.print("LightGlue-Async")
  910. if key == ord("q"):
  911. print("Exiting via keyboard (q)")
  912. break
  913. if key == ord("n"):
  914. last_data = curr_data
  915. last_frame = frame
  916. last_image_id += 1
  917. print("Updated reference frame")
  918. elif key == ord("f"):
  919. show_fps = not show_fps
  920. finally:
  921. async_streamer.stop()
  922. if async_visualizer is not None:
  923. async_visualizer.stop()
  924. if result_sender is not None:
  925. result_sender.close()
  926. # 关闭控制监听
  927. if control_stop_event is not None:
  928. try:
  929. control_stop_event.set()
  930. except Exception:
  931. pass
  932. if control_sock is not None:
  933. try:
  934. control_sock.close()
  935. except Exception:
  936. pass
  937. try:
  938. cv2.destroyAllWindows()
  939. except:
  940. pass # Ignore errors if windows weren't created
  # Script entry point: run the demo only when this file is executed directly.
  if __name__ == "__main__":
      main()