| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076 |
- #!/usr/bin/env python3
- """
- LightGlue demo with asynchronous video streaming.
- This script demonstrates how to decouple frame acquisition from model
- inference by reading camera frames in a background thread. The goal is to
- reduce blocking caused by `cv2.VideoCapture.read()` and keep the downstream
- pipeline busy with the most recent frame available.
- """
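# Force-disable Windows 11 background throttling (Efficiency Mode / EcoQoS) so that
# the OS does not suspend this Python process while a game is running full-screen.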
- import ctypes
- import sys
def _disable_win11_efficiency_mode():
    """Opt this process out of Windows power throttling (Efficiency Mode / EcoQoS).

    Does nothing on non-Windows platforms. On Windows it calls
    ``SetProcessInformation`` with the ``ProcessPowerThrottling`` class and a
    state whose ``ControlMask`` selects execution-speed throttling while
    ``StateMask`` is 0, i.e. the throttle is switched off.

    The call is best-effort: success, failure and unexpected exceptions are
    all reported on stdout and nothing is raised.
    """
    if sys.platform != "win32":
        return
    try:
        from ctypes import wintypes

        class ProcessPowerThrottlingState(ctypes.Structure):
            # Mirrors PROCESS_POWER_THROTTLING_STATE: Version, ControlMask, StateMask.
            _fields_ = [
                ("Version", ctypes.c_ulong),
                ("ControlMask", ctypes.c_ulong),
                ("StateMask", ctypes.c_ulong),
            ]

        # Use a private kernel32 handle so the prototypes declared below do not
        # leak into the shared ctypes.windll.kernel32 object.
        kernel32 = ctypes.WinDLL("kernel32")

        # Without explicit prototypes ctypes treats HANDLE as a 32-bit C int,
        # which is not a faithful representation on 64-bit Windows.
        kernel32.GetCurrentProcess.restype = wintypes.HANDLE
        kernel32.GetCurrentProcess.argtypes = []
        kernel32.SetProcessInformation.restype = wintypes.BOOL
        kernel32.SetProcessInformation.argtypes = [
            wintypes.HANDLE,
            ctypes.c_int,
            ctypes.c_void_p,
            wintypes.DWORD,
        ]

        process_handle = kernel32.GetCurrentProcess()
        # Version=1, ControlMask=1 (execution speed), StateMask=0 (throttling off).
        state = ProcessPowerThrottlingState(1, 1, 0)
        process_power_throttling = 4  # PROCESS_INFORMATION_CLASS.ProcessPowerThrottling
        succeeded = kernel32.SetProcessInformation(
            process_handle,
            process_power_throttling,
            ctypes.byref(state),
            ctypes.sizeof(state),
        )
        if succeeded:
            print("Win11: 已禁用进程效率模式限制 (EcoQoS)", flush=True)
        else:
            print("Win11: 设置效率模式失败 (非致命)", flush=True)
    except Exception as e:  # pylint: disable=broad-except
        # Older Windows versions may not export SetProcessInformation; never fatal.
        print(f"Win11 效率模式设置异常: {e}", flush=True)


_disable_win11_efficiency_mode()
- import argparse
- import queue
- import threading
- import time
- import socket
- from pathlib import Path
- import cv2
- import numpy as np
- import torch
- from demo_lightglue_camera_position_single_window import (
- AverageTimer,
- VideoStreamer,
- draw_camera_position_on_reference,
- frame2tensor,
- )
- from lightglue import LightGlue, SuperPoint
- # UDP结果发送器(可选)
- try:
- from udp_result_sender import UDPResultSender
- UDP_RESULT_SENDER_AVAILABLE = True
- except ImportError:
- UDPResultSender = None
- UDP_RESULT_SENDER_AVAILABLE = False
- # Import helper functions for SuperPoint post-processing
- try:
- from lightglue.superpoint import simple_nms, top_k_keypoints, sample_descriptors
- # remove_borders is not a function, it's a config parameter in SuperPoint.conf
- HELPER_FUNCTIONS_AVAILABLE = True
- except ImportError as e:
- # These will be imported later if needed
- simple_nms = top_k_keypoints = sample_descriptors = None
- HELPER_FUNCTIONS_AVAILABLE = False
- # TensorRT support
- try:
- import torch_tensorrt
- TENSORRT_AVAILABLE = True
- except ImportError:
- TENSORRT_AVAILABLE = False
- # 屏幕截屏(DXGI Desktop Duplication,用于 Unity 指令 s:Python 截屏作基准图)
- try:
- import dxcam
- try:
- _DXCAM_CAMERA = dxcam.create(output_color="RGB")
- DXCAM_AVAILABLE = True
- except Exception: # pylint: disable=broad-except
- _DXCAM_CAMERA = None
- DXCAM_AVAILABLE = False
- except ImportError:
- dxcam = None # type: ignore[assignment]
- _DXCAM_CAMERA = None
- DXCAM_AVAILABLE = False
- torch.set_grad_enabled(False)
class AsyncVisualizer:
    """Post-process matches, estimate the homography and draw on a background thread.

    The main thread hands raw tensors over with :meth:`submit`. The worker
    thread converts them to NumPy, runs RANSAC homography estimation, draws the
    overlay frames, optionally reports the camera position through
    ``result_sender`` and publishes the drawn frames for :meth:`get_result`.
    Both queues are bounded; when one is full its oldest entry is dropped so
    only recent data is kept.
    """

    def __init__(self, queue_size: int = 2, min_matches: int = 10, result_sender=None):
        """Create the queues and start the daemon worker thread.

        Args:
            queue_size: Capacity of both the input and the output queue.
            min_matches: Minimum number of matches required before a
                homography is estimated.
            result_sender: Optional object exposing ``send_result(...)`` and
                ``close()``; used to report results over UDP.
        """
        self.input_queue: "queue.Queue" = queue.Queue(maxsize=queue_size)
        self.output_queue: "queue.Queue" = queue.Queue(maxsize=queue_size)
        self._stop_requested = False
        self.min_matches = min_matches
        self.result_sender = result_sender  # optional UDP result sender
        self._thread = threading.Thread(target=self._visualizer_worker, name="AsyncVisualizer", daemon=True)
        self._thread.start()

    def _visualizer_worker(self) -> None:
        """Worker loop: consume submitted items until stopped or a ``None`` sentinel arrives."""
        while not self._stop_requested:
            try:
                # Short timeout so the stop flag is re-checked regularly.
                viz_data = self.input_queue.get(timeout=0.1)
            except queue.Empty:
                continue

            if viz_data is None:  # sentinel pushed by stop()
                break

            try:
                self._publish(self._process_frame(viz_data))
            except Exception as e:
                # One bad frame must not kill the worker; report it and carry on.
                print(f"[AsyncVisualizer] Error: {e}")
                import traceback
                traceback.print_exc()

    def _process_frame(self, viz_data):
        """Turn one work item into ``(display_frame, reference_view, num_matches, H)``."""
        (frame, last_frame, last_data, curr_data, matches01,
         camera_center_current, _show_fps, _fps_display) = viz_data

        # Tensor -> NumPy conversion happens here so it does not stall the main thread.
        kpts0 = last_data["keypoints"][0].detach().cpu().numpy()
        kpts1 = curr_data["keypoints"][0].detach().cpu().numpy()
        matches = matches01["matches0"][0].detach().cpu().numpy()

        valid = matches > -1
        mkpts0 = kpts0[valid]
        mkpts1 = kpts1[matches[valid]]
        num_matches = len(mkpts0)

        inliers_ratio = 0.0
        current_H = None
        # findHomography needs at least 4 point pairs, whatever min_matches says.
        if num_matches >= max(self.min_matches, 4):
            current_H, mask = cv2.findHomography(mkpts0, mkpts1, cv2.RANSAC, 5.0)
            if current_H is not None and mask is not None:
                inliers_ratio = float(np.sum(mask)) / max(num_matches, 1)

        display_frame = self._draw_display_frame(frame, num_matches)
        camera_center_ref = self._locate_camera(current_H, camera_center_current)
        self._send_result(camera_center_ref, num_matches, inliers_ratio)

        reference_view = draw_camera_position_on_reference(
            last_frame,
            camera_center_current,
            current_H,
            num_matches,
            self.min_matches,
            inliers_ratio,
        )
        return display_frame, reference_view, num_matches, current_H

    @staticmethod
    def _draw_display_frame(frame, num_matches):
        """Return a BGR copy of ``frame`` with a centre crosshair and the match count."""
        # cvtColor already allocates a new image, so no explicit copy is needed.
        display_frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
        center_x, center_y = display_frame.shape[1] // 2, display_frame.shape[0] // 2
        crosshair_size = 20
        cv2.line(display_frame, (center_x - crosshair_size, center_y),
                 (center_x + crosshair_size, center_y), (0, 0, 255), 4, cv2.LINE_AA)
        cv2.line(display_frame, (center_x, center_y - crosshair_size),
                 (center_x, center_y + crosshair_size), (0, 0, 255), 4, cv2.LINE_AA)
        # The FPS overlay is intentionally not drawn here; the caller owns the latest value.
        cv2.putText(display_frame, f"Matches: {num_matches}", (10, 60),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
        return display_frame

    @staticmethod
    def _locate_camera(current_H, camera_center_current):
        """Map the current-frame centre through the inverse homography, or return ``None``."""
        if current_H is None:
            return None
        try:
            H_inv = np.linalg.inv(current_H)
        except np.linalg.LinAlgError:
            return None
        point = np.array([[camera_center_current]], dtype=np.float32).reshape(-1, 1, 2)
        return cv2.perspectiveTransform(point, H_inv)[0, 0]

    def _send_result(self, camera_center_ref, num_matches, inliers_ratio) -> None:
        """Report the estimate through ``result_sender`` (if any); failures are only logged."""
        if self.result_sender is None:
            return
        try:
            if camera_center_ref is not None:
                self.result_sender.send_result(
                    is_valid=True,
                    num_matches=num_matches,
                    inliers_ratio=inliers_ratio,
                    camera_x=float(camera_center_ref[0]),
                    camera_y=float(camera_center_ref[1]),
                )
            else:
                self.result_sender.send_result(
                    is_valid=False,
                    num_matches=num_matches,
                    inliers_ratio=0.0,
                    camera_x=0.0,
                    camera_y=0.0,
                )
        except Exception as e:
            print(f"[AsyncVisualizer] Failed to send result: {e}")

    def _publish(self, result) -> None:
        """Put ``result`` on the output queue, evicting the oldest entry when full."""
        while True:
            try:
                self.output_queue.put_nowait(result)
                return
            except queue.Full:
                try:
                    self.output_queue.get_nowait()
                except queue.Empty:
                    # The consumer emptied the queue in the meantime; just retry
                    # instead of losing the newest frame.
                    pass

    def submit(self, frame, last_frame, last_data, curr_data, matches01,
               camera_center_current, show_fps, fps_display):
        """Queue raw data for processing without blocking.

        When the input queue is full the oldest pending item is discarded; if
        the queue is still full afterwards the new item is dropped.
        """
        viz_data = (frame, last_frame, last_data, curr_data, matches01,
                    camera_center_current, show_fps, fps_display)
        if self.input_queue.full():
            try:
                self.input_queue.get_nowait()
            except queue.Empty:
                pass
        try:
            self.input_queue.put_nowait(viz_data)
        except queue.Full:
            pass  # still full: drop this item

    def get_result(self, timeout: float = 0.01):
        """Return the next drawn result, waiting at most ``timeout`` seconds.

        Returns:
            ``(display_frame, reference_view, num_matches, H)`` or ``None``
            when nothing became available in time.
        """
        try:
            return self.output_queue.get(timeout=timeout)
        except queue.Empty:
            return None

    def stop(self):
        """Stop the worker thread and close the result sender.

        Never blocks on a full input queue: the stop flag alone ends the
        worker loop, the sentinel only speeds that up.
        """
        self._stop_requested = True
        try:
            self.input_queue.put_nowait(None)
        except queue.Full:
            # Make room for the sentinel by discarding one pending item.
            try:
                self.input_queue.get_nowait()
            except queue.Empty:
                pass
            try:
                self.input_queue.put_nowait(None)
            except queue.Full:
                pass
        self._thread.join(timeout=1.0)
        if self.result_sender is not None:
            try:
                self.result_sender.close()
            except Exception as e:
                print(f"[AsyncVisualizer] Error closing result sender: {e}")
class AsyncVideoStreamer:
    """Wrapper around VideoStreamer that reads frames on a background thread.

    A daemon thread keeps calling ``streamer.next_frame()`` and stores the
    result in a bounded queue, discarding the oldest frame when the queue is
    full so that :meth:`read` returns recent data.
    """

    def __init__(self, streamer: "VideoStreamer", queue_size: int = 1, timeout: float = 1.0):
        """Start the reader thread.

        Args:
            streamer: Frame source exposing ``next_frame()`` and ``cleanup()``.
            queue_size: Queue capacity; values below 1 are raised to 1.
            timeout: Maximum number of seconds :meth:`read` waits for a frame.
        """
        self.streamer = streamer
        self.queue: "queue.Queue[np.ndarray]" = queue.Queue(maxsize=max(queue_size, 1))
        self.timeout = timeout
        self._stop_requested = False
        self._has_error = False
        self._thread = threading.Thread(target=self._reader, name="AsyncVideoStreamer", daemon=True)
        self._thread.start()

    def _reader(self) -> None:
        """Reader loop executed on the background thread."""
        try:
            while not self._stop_requested:
                frame, ret = self.streamer.next_frame()
                if not ret:
                    # In UDP JPEG mode a missing frame only means nothing has
                    # arrived yet, so wait 10 ms and poll again.
                    if getattr(self.streamer, "is_udp_jpeg", False):
                        time.sleep(0.01)
                        continue
                    # Any other source: end of stream or error, stop reading.
                    self._stop_requested = True
                    break
                # Keep only the most recent frame to minimise latency.
                if self.queue.full():
                    try:
                        self.queue.get_nowait()
                    except queue.Empty:
                        pass
                self.queue.put(frame)
        except Exception as exc:  # pylint: disable=broad-except
            self._has_error = True
            print(f"[AsyncVideoStreamer] Reader thread error: {exc}")
        finally:
            self._stop_requested = True

    def read(self):
        """Return ``(frame, True)`` for the next queued frame, or ``(None, False)``.

        Waits up to ``timeout`` seconds while the reader thread is running.
        Once the reader has finished no new frame can arrive, so leftover
        frames are returned immediately and an empty queue yields
        ``(None, False)`` without waiting.
        """
        if self._has_error:
            return None, False
        try:
            if self._thread.is_alive():
                frame = self.queue.get(timeout=self.timeout)
            else:
                frame = self.queue.get_nowait()
        except queue.Empty:
            return None, False
        return frame, True

    def stop(self):
        """Ask the reader to stop, wait up to one second for it, then clean up the streamer."""
        self._stop_requested = True
        if self._thread.is_alive():
            self._thread.join(timeout=1.0)
        self.streamer.cleanup()
def load_reference_frame(opt, device):
    """Load the optional reference image and convert it to a tensor.

    Args:
        opt: Parsed options; reads ``opt.reference_image`` (path or ``None``)
            and ``opt.resize``.
        device: Device passed through to ``frame2tensor``.

    Returns:
        ``(None, None, None)`` when no reference image is configured,
        otherwise ``(ref_image, ref_tensor, 0)`` where ``ref_image`` is the
        grayscale image after resizing.

    Raises:
        IOError: If the reference image cannot be loaded.
    """
    if opt.reference_image is None:
        return None, None, None
    print(f"==> Loading reference image: {opt.reference_image}")
    ref_image = cv2.imread(opt.reference_image, cv2.IMREAD_GRAYSCALE)
    if ref_image is None:
        # cv2.imread can fail on paths it cannot encode (e.g. non-ASCII paths
        # on some Windows builds); retry by reading the bytes ourselves and
        # decoding from memory.
        try:
            raw = np.fromfile(opt.reference_image, dtype=np.uint8)
        except OSError:
            raw = None
        if raw is not None and raw.size > 0:
            ref_image = cv2.imdecode(raw, cv2.IMREAD_GRAYSCALE)
    if ref_image is None:
        raise IOError(f"Cannot load reference image: {opt.reference_image}")
    # Same resize policy as live frames, via the shared helper.
    ref_image = maybe_resize(ref_image, opt.resize)
    ref_tensor = frame2tensor(ref_image, device)
    return ref_image, ref_tensor, 0
def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the behaviour callers
            without arguments get.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="LightGlue demo (asynchronous capture)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = parser.add_argument

    # Input source and output location.
    add("--input", type=str, default="0", help="USB webcam index, IP camera URL, UDP stream (udp://host:port), or video path")
    add("--reference_image", type=str, default=None, help="Optional reference image path")
    add("--output_dir", type=str, default=None, help="Directory to save visualisations")
    add("--image_glob", type=str, nargs="+", default=["*.png", "*.jpg", "*.jpeg"], help="Glob for image sequences")
    add("--skip", type=int, default=0, help="Number of frames to skip between reads")
    add("--max_length", type=int, default=1_000_000, help="Maximum frames")
    add(
        "--resize",
        type=int,
        nargs="+",
        default=[640, 480],
        help="Resize input image. Two numbers = width height, one number = max dimension, -1 = no resize",
    )

    # SuperPoint / LightGlue model settings.
    add("--max_keypoints", type=int, default=1024, help="Maximum number of SuperPoint keypoints")
    add("--keypoint_threshold", type=float, default=0.01, help="SuperPoint detection threshold")
    add("--nms_radius", type=int, default=4, help="SuperPoint NMS radius")
    add("--match_threshold", type=float, default=0.2, help="LightGlue match threshold")
    add("--depth_confidence", type=float, default=0.95, help="LightGlue depth confidence")
    add("--width_confidence", type=float, default=0.99, help="LightGlue width confidence")

    # Precision and TensorRT.
    add("--use_fp16", action="store_true", help="Enable FP16 half precision inference for faster processing")
    add("--use_tensorrt", action="store_true", help="Use TensorRT optimized models (requires torch-tensorrt)")
    add("--tensorrt_precision", type=str, default="fp16", choices=["fp32", "fp16", "int8"],
        help="TensorRT precision mode (fp16 recommended)")
    add("--tensorrt_calibration_data", type=str, default=None,
        help="Directory containing calibration images for INT8 quantization (optional)")
    add("--tensorrt_calibration_batches", type=int, default=10,
        help="Number of calibration batches for INT8 (default: 10)")

    # Matching and asynchronous capture.
    add("--min_matches", type=int, default=10, help="Minimum matches required to compute homography")
    add("--queue_size", type=int, default=1, help="Frame queue size for async reader")
    add("--read_timeout", type=float, default=1.0, help="Seconds to wait for a frame from async reader")

    # Frame orientation.
    add("--flip_horizontal", action="store_true", help="Flip frames horizontally")
    add("--flip_vertical", action="store_true", help="Flip frames vertically")
    add("--rotate", type=int, default=0, choices=[0, 90, 180, 270], help="Rotate frames clockwise")

    # Runtime behaviour and UI.
    add("--show_fps", action="store_true", help="Render FPS overlay")
    add("--force_cpu", action="store_true", help="Run inference on CPU even if CUDA is available")
    add("--no_ip_grab", action="store_true", help="Disable extra grab calls for IP cameras (reduces frame drops but may increase latency)")
    add("--no_display", action="store_true", help="Disable OpenCV window")
    add("--no_ui", action="store_true", help="Suppress console output (UI embedding)")

    # Unity integration over UDP.
    add("--result_ip", type=str, default="127.0.0.1", help="Unity IP address for result transmission (default: 127.0.0.1)")
    add("--result_port", type=int, default=12348, help="Unity UDP port for result transmission (default: 12348)")
    add(
        "--control_port",
        type=int,
        default=0,
        help="Optional UDP port for receiving control commands from Unity (e.g., refresh reference frame). 0=disabled.",
    )
    return parser.parse_args(argv)
def maybe_resize(input_frame, resize_opt):
    """Resize ``input_frame`` according to the ``--resize`` option.

    Args:
        input_frame: Image array; only ``shape[:2]`` (height, width) is read here.
        resize_opt: Sequence of ints. Two values are an explicit
            ``(width, height)``; a single positive value is the target size of
            the longer side (aspect ratio kept); anything else means no resize.

    Returns:
        The resized image, or ``input_frame`` itself when no resize applies.
    """
    if len(resize_opt) == 2:
        return cv2.resize(input_frame, tuple(resize_opt))
    if len(resize_opt) == 1 and resize_opt[0] > 0:
        h, w = input_frame.shape[:2]
        scale = resize_opt[0] / max(h, w)
        # Truncation can collapse the short side of a very elongated image to
        # 0 pixels, which cv2.resize rejects; keep at least one pixel.
        target_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        return cv2.resize(input_frame, target_size)
    return input_frame
def apply_orientation(frame, opt):
    """Rotate and/or mirror ``frame`` as requested by the command-line options.

    The clockwise rotation in ``opt.rotate`` (90, 180 or 270) is applied
    first, then the horizontal flip, then the vertical flip. Any other
    rotation value leaves the frame unrotated.

    Returns:
        The transformed frame, or ``frame`` itself when nothing is requested.
    """
    angle = opt.rotate
    if angle:
        rotate_code = {
            90: cv2.ROTATE_90_CLOCKWISE,
            180: cv2.ROTATE_180,
            270: cv2.ROTATE_90_COUNTERCLOCKWISE,
        }.get(angle)
        if rotate_code is not None:
            frame = cv2.rotate(frame, rotate_code)

    # cv2.flip code 1 mirrors left-right, code 0 mirrors top-bottom.
    for requested, flip_code in ((opt.flip_horizontal, 1), (opt.flip_vertical, 0)):
        if requested:
            frame = cv2.flip(frame, flip_code)
    return frame
def capture_screen_frame(opt):
    """Grab the whole screen via DXGI (dxcam) and return it as a grayscale image.

    The image is resized with ``maybe_resize`` according to ``opt.resize``.

    Returns:
        The grayscale image, or ``None`` when dxcam is unavailable, when the
        grab yields no frame, or when capturing/converting fails (the failure
        is reported on stdout).
    """
    if not DXCAM_AVAILABLE or _DXCAM_CAMERA is None:
        return None
    try:
        shot = _DXCAM_CAMERA.grab()
        if shot is None:
            return None
        channels = shot.shape[2] if len(shot.shape) == 3 else None
        if channels == 3:
            # The module-level camera is created with output_color="RGB".
            gray = cv2.cvtColor(shot, cv2.COLOR_RGB2GRAY)
        elif channels == 4:
            gray = cv2.cvtColor(shot, cv2.COLOR_BGRA2GRAY)
        elif len(shot.shape) == 2:
            gray = shot
        else:
            # Unexpected layout: fall back to the first channel.
            gray = shot[..., 0]
        return maybe_resize(gray, opt.resize)
    except Exception as exc:  # pylint: disable=broad-except
        # Still best-effort, but report the reason instead of failing silently.
        print(f"[ScreenCapture] Failed to capture screen: {exc}")
        return None
- def main():
- opt = parse_args()
- if len(opt.resize) == 2 and opt.resize[1] == -1:
- opt.resize = opt.resize[0:1]
- if len(opt.resize) == 2:
- print(f"Will resize to {opt.resize[0]}x{opt.resize[1]} (WxH)")
- elif len(opt.resize) == 1 and opt.resize[0] > 0:
- print(f"Will resize max dimension to {opt.resize[0]}")
- elif len(opt.resize) == 1:
- print("Will not resize images")
- else:
- raise ValueError("Cannot specify more than two integers for --resize")
- if opt.no_ui:
- import os
- import sys
- sys.stdout = open(os.devnull, "w")
- sys.stderr = open(os.devnull, "w")
- device = "cuda" if torch.cuda.is_available() and not opt.force_cpu else "cpu"
- print(f'Running inference on device "{device}"')
- extractor = SuperPoint(
- max_num_keypoints=opt.max_keypoints,
- detection_threshold=opt.keypoint_threshold,
- nms_radius=opt.nms_radius,
- ).eval().to(device)
- matcher = LightGlue(
- features="superpoint",
- depth_confidence=opt.depth_confidence,
- width_confidence=opt.width_confidence,
- filter_threshold=opt.match_threshold,
- mp=opt.use_fp16, # Enable mixed precision if FP16 is requested
- ).eval().to(device)
- print("Loaded SuperPoint and LightGlue models")
- # TensorRT optimization
- if opt.use_tensorrt and TENSORRT_AVAILABLE and device == "cuda":
- try:
- print("="*60)
- print("Compiling models with TensorRT...")
- print(f"Precision: {opt.tensorrt_precision}")
- print("This may take several minutes on first run...")
- print("="*60)
-
- # Compile SuperPoint with TensorRT
- print("Compiling SuperPoint...")
- example_input = torch.randn(1, 1, opt.resize[1], opt.resize[0]).cuda()
-
- enabled_precisions = {torch.float}
- calibration_cache = None
- if opt.tensorrt_precision == "fp16":
- enabled_precisions.add(torch.half)
- elif opt.tensorrt_precision == "int8":
- enabled_precisions.add(torch.int8)
- print(" Note: INT8 quantization will use default calibration")
- print(" For better accuracy, provide calibration data with --tensorrt_calibration_data")
-
- # Create a hybrid approach: compile only the encoder part (conv layers)
- # Keep dynamic operations (keypoint extraction, NMS) in PyTorch
- print(" Creating encoder-only model for TensorRT compilation...")
-
- class SuperPointEncoder(torch.nn.Module):
- """Only the encoder part of SuperPoint (conv layers + feature extraction)"""
- def __init__(self, superpoint_model):
- super().__init__()
- # Copy encoder layers
- self.conv1a = superpoint_model.conv1a
- self.conv1b = superpoint_model.conv1b
- self.conv2a = superpoint_model.conv2a
- self.conv2b = superpoint_model.conv2b
- self.conv3a = superpoint_model.conv3a
- self.conv3b = superpoint_model.conv3b
- self.conv4a = superpoint_model.conv4a
- self.conv4b = superpoint_model.conv4b
- self.pool = superpoint_model.pool
- self.relu = superpoint_model.relu
- # Feature extraction layers
- self.convPa = superpoint_model.convPa
- self.convPb = superpoint_model.convPb
- self.convDa = superpoint_model.convDa
- self.convDb = superpoint_model.convDb
- # Store config for post-processing
- self.conf = superpoint_model.conf
- self.original_model = superpoint_model
-
- def forward(self, image):
- """Forward pass through encoder only"""
- # Shared Encoder
- x = self.relu(self.conv1a(image))
- x = self.relu(self.conv1b(x))
- x = self.pool(x)
- x = self.relu(self.conv2a(x))
- x = self.relu(self.conv2b(x))
- x = self.pool(x)
- x = self.relu(self.conv3a(x))
- x = self.relu(self.conv3b(x))
- x = self.pool(x)
- x = self.relu(self.conv4a(x))
- x = self.relu(self.conv4b(x))
-
- # Compute the dense keypoint scores
- cPa = self.relu(self.convPa(x))
- scores = self.convPb(cPa)
- scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
- b, _, h, w = scores.shape
- scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
- scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
-
- # Compute the dense descriptors
- cDa = self.relu(self.convDa(x))
- descriptors = self.convDb(cDa)
- descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
-
- return scores, descriptors
-
- encoder_model = SuperPointEncoder(extractor).eval()
-
- # Compile encoder with TensorRT
- extractor_trt = None
- try:
- print(" Compiling encoder with TensorRT...")
- with torch.no_grad():
- # Trace the encoder
- traced_encoder = torch.jit.trace(encoder_model, example_input, strict=False)
- traced_encoder.eval()
-
- # For INT8, torch-tensorrt will automatically handle calibration
- # We just need to provide a single example input
- # The calibration_batches parameter is informational only for now
- if opt.tensorrt_precision == "int8":
- print(f" Note: INT8 calibration will be performed automatically")
- print(f" (Calibration batches setting: {opt.tensorrt_calibration_batches})")
- print(" WARNING: INT8 compilation can take 10-20 minutes, please be patient...")
- import sys
- sys.stdout.flush()
-
- # Compile with TensorRT
- # For INT8, torch-tensorrt automatically generates calibration data
- print(" Starting TensorRT compilation (this may take a while)...")
- import sys
- sys.stdout.flush()
-
- encoder_trt = torch_tensorrt.compile(
- traced_encoder,
- inputs=[example_input],
- enabled_precisions=enabled_precisions,
- workspace_size=1 << 30, # 1GB
- min_block_size=7,
- ir="torchscript",
- truncate_long_and_double=True,
- )
- print(" [OK] Encoder compiled with TensorRT successfully")
- import sys
- sys.stdout.flush()
-
- # Create hybrid wrapper that uses TensorRT encoder + PyTorch post-processing
- # Re-import helper functions to ensure they're available
- try:
- from lightglue.superpoint import simple_nms, top_k_keypoints, sample_descriptors
- except ImportError as import_err:
- print(f" [ERROR] Could not import helper functions: {import_err}")
- print(" Falling back to PyTorch model (TensorRT optimization disabled)")
- extractor_trt = None
- raise ImportError("Required helper functions not available") from import_err
-
- class HybridSuperPoint:
- def __init__(self, trt_encoder, original_model):
- self.trt_encoder = trt_encoder
- self.original_model = original_model
- self.conf = original_model.conf
-
- def __call__(self, inputs):
- if isinstance(inputs, dict):
- image = inputs["image"]
- else:
- image = inputs
-
- # Use TensorRT encoder
- scores, descriptors = self.trt_encoder(image)
-
- # Post-processing in PyTorch (dynamic operations)
- scores = simple_nms(scores, self.conf.nms_radius)
-
- # Discard keypoints near borders
- if self.conf.remove_borders:
- pad = self.conf.remove_borders
- scores[:, :pad] = -1
- scores[:, :, :pad] = -1
- scores[:, -pad:] = -1
- scores[:, :, -pad:] = -1
-
- # Extract keypoints
- best_kp = torch.where(scores > self.conf.detection_threshold)
- scores_vals = scores[best_kp]
- b = image.shape[0]
- keypoints = [
- torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b)
- ]
- scores_list = [scores_vals[best_kp[0] == i] for i in range(b)]
-
- # Top-k keypoints
- if self.conf.max_num_keypoints is not None:
- keypoints, scores_list = list(zip(*[
- top_k_keypoints(k, s, self.conf.max_num_keypoints)
- for k, s in zip(keypoints, scores_list)
- ]))
-
- # Convert (h, w) to (x, y)
- keypoints = [torch.flip(k, [1]).float() for k in keypoints]
-
- # Extract descriptors
- descriptors_list = [
- sample_descriptors(k[None], d[None], 8)[0]
- for k, d in zip(keypoints, descriptors)
- ]
-
- return {
- "keypoints": torch.stack(keypoints, 0),
- "keypoint_scores": torch.stack(scores_list, 0),
- "descriptors": torch.stack(descriptors_list, 0).transpose(-1, -2).contiguous(),
- }
-
- def eval(self):
- return self
-
- def to(self, device):
- return self
-
- extractor_trt = HybridSuperPoint(encoder_trt, extractor)
-
- except Exception as compile_error:
- print(f" [ERROR] TensorRT compilation failed: {compile_error}")
- print(" Falling back to PyTorch model (TensorRT optimization disabled)")
- import traceback
- print(" Full error traceback:")
- traceback.print_exc()
- extractor_trt = None
-
- # Replace extractor with TensorRT version only if compilation succeeded
- if extractor_trt is None:
- print("="*60)
- print("TensorRT optimization skipped, using PyTorch models")
- print("="*60)
- else:
- extractor = extractor_trt # HybridSuperPoint already implements the interface
- print("[OK] SuperPoint encoder compiled with TensorRT")
- print(" (Keypoint extraction and NMS remain in PyTorch for compatibility)")
-
- # Note: LightGlue compilation is more complex due to multiple inputs
- # For now, we'll keep LightGlue as PyTorch model
- print("Note: LightGlue will use PyTorch (TensorRT compilation for LightGlue is more complex)")
- print("="*60)
- print("[OK] TensorRT optimization completed (hybrid approach)")
- print("="*60)
-
- except Exception as e:
- print(f"[ERROR] Failed to compile with TensorRT: {e}")
- print("Falling back to PyTorch models")
- import traceback
- print("Full error traceback:")
- traceback.print_exc()
- import sys
- sys.stdout.flush()
- elif opt.use_tensorrt:
- if not TENSORRT_AVAILABLE:
- print("Warning: TensorRT requested but torch-tensorrt not installed")
- print("Install with: pip install torch-tensorrt")
- elif device != "cuda":
- print("Warning: TensorRT requires CUDA, but running on CPU")
- ref_frame, ref_tensor, last_image_id = load_reference_frame(opt, device)
- if ref_tensor is not None:
- if opt.use_fp16 and device == "cuda":
- with torch.cuda.amp.autocast():
- last_data = extractor({"image": ref_tensor})
- else:
- last_data = extractor({"image": ref_tensor})
- last_frame = ref_frame
- else:
- last_data = None
- last_frame = None
- last_image_id = 0
- streamer = VideoStreamer(opt.input, opt.resize, opt.skip, opt.image_glob, opt.max_length)
-
- # 处理UDP模式
- if hasattr(streamer, "is_udp_jpeg") and streamer.is_udp_jpeg:
- print("UDP JPEG mode: receiver started in background thread", flush=True)
- # 处理摄像头模式
- elif hasattr(streamer, "cap") and streamer.cap is not None:
- is_local_cam = False
- if isinstance(opt.input, str) and opt.input.isdigit():
- is_local_cam = True
- elif isinstance(opt.input, int):
- is_local_cam = True
- if is_local_cam:
- desired_width, desired_height, desired_fps = 640, 480, 30
- streamer.cap.set(cv2.CAP_PROP_FRAME_WIDTH, desired_width)
- streamer.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, desired_height)
- streamer.cap.set(cv2.CAP_PROP_FPS, desired_fps)
- actual_w = streamer.cap.get(cv2.CAP_PROP_FRAME_WIDTH)
- actual_h = streamer.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
- actual_fps = streamer.cap.get(cv2.CAP_PROP_FPS)
- print(f"Camera props requested -> {desired_width}x{desired_height} @{desired_fps} FPS")
- print(f"Camera props applied -> {actual_w:.0f}x{actual_h:.0f} @{actual_fps:.1f} FPS")
- if opt.no_ip_grab and hasattr(streamer, "is_ip_camera"):
- streamer.is_ip_camera = False
- print("IP camera buffer flush disabled (no extra grab calls).")
-
- async_streamer = AsyncVideoStreamer(streamer, queue_size=opt.queue_size, timeout=opt.read_timeout)
- if last_data is None:
- # For UDP mode, wait a bit for the first frame to arrive
- if hasattr(streamer, "is_udp_jpeg") and streamer.is_udp_jpeg:
- print("Waiting for first UDP frame...")
- max_wait_time = 10.0 # Wait up to 10 seconds
- wait_interval = 0.1
- elapsed = 0.0
- first_frame, ret = None, False
- while elapsed < max_wait_time:
- first_frame, ret = async_streamer.read()
- if ret:
- break
- time.sleep(wait_interval)
- elapsed += wait_interval
- if int(elapsed) % 2 == 0 and int(elapsed - wait_interval) % 2 != 0:
- print(f"Still waiting for UDP frame... ({int(elapsed)}s)")
- else:
- first_frame, ret = async_streamer.read()
-
- if not ret:
- raise RuntimeError(
- "Error when reading the first frame. "
- "For UDP mode, make sure:\n"
- " 1. The sender is running and sending data\n"
- " 2. The port number is correct\n"
- " 3. Firewall allows UDP traffic on this port"
- )
-
- first_frame = apply_orientation(first_frame, opt)
- last_frame = first_frame
- last_tensor = frame2tensor(first_frame, device)
- if opt.use_fp16 and device == "cuda":
- with torch.cuda.amp.autocast():
- last_data = extractor({"image": last_tensor})
- else:
- last_data = extractor({"image": last_tensor})
- print("First frame received and processed")
- if opt.output_dir is not None:
- Path(opt.output_dir).mkdir(exist_ok=True)
- print(f"==> Will write outputs to {opt.output_dir}")
- window_name_ref = "Camera Position in Reference"
- if not opt.no_display:
- try:
- cv2.namedWindow(window_name_ref, cv2.WINDOW_NORMAL)
- cv2.resizeWindow(window_name_ref, 640, 480)
- except cv2.error as e:
- print(f"Warning: Could not create OpenCV windows: {e}")
- print("Continuing without display...")
- opt.no_display = True
- print("==> Keyboard control:\n"
- "\tn: set current frame as reference\n"
- "\tq: quit\n"
- "\tf: toggle FPS overlay\n")
- timer = AverageTimer()
- show_fps = opt.show_fps
- fps_display = 0.0
- last_time = time.time()
- last_fps_print_time = time.time()
- fps_print_interval = 2.0 # Print FPS every 2 seconds
- # Initialize UDP result sender (if available)
- result_sender = None
- if UDP_RESULT_SENDER_AVAILABLE:
- try:
- result_sender = UDPResultSender(unity_ip=opt.result_ip, unity_port=opt.result_port)
- print(f"[UDP] Result sender initialized: {opt.result_ip}:{opt.result_port}")
- except Exception as e:
- print(f"[UDP] Failed to initialize result sender: {e}")
- result_sender = None
- else:
- print("[UDP] UDP result sender not available (udp_result_sender.py not found)")
- # Create async visualizer: 无窗口时也需创建以便通过 result_sender 向 Unity 回传结果;仅 no_display 时不再弹窗
- async_visualizer = AsyncVisualizer(queue_size=2, min_matches=opt.min_matches, result_sender=result_sender) if (not opt.no_display or result_sender is not None) else None
- # 可选:启动来自 Unity 的 UDP 控制监听(用于刷新参考图等简单指令)
- control_sock = None
- control_stop_event = None
- control_thread = None
- control_refresh_event = None
- # 控制指令:n=当前摄像头帧作参考图, s=Python截屏作参考图, r=下一帧来自Unity作为参考图
- screen_capture_event = threading.Event()
- next_frame_is_reference_event = threading.Event()
- if opt.control_port and opt.control_port > 0:
def _control_listener(sock: socket.socket, stop_event: threading.Event, refresh_event: threading.Event,
                      screen_ev: threading.Event, next_ref_ev: threading.Event) -> None:
    """Poll ``sock`` for UDP control datagrams and signal the matching event.

    Only the first byte of each datagram is inspected:

    * ``n`` / ``N`` / byte value ``1`` -> ``refresh_event`` is set
    * ``s`` -> ``screen_ev`` is set
    * ``r`` -> ``next_ref_ev`` is set

    Empty datagrams and any other first byte are ignored.

    Args:
        sock: Datagram socket to receive from; its timeout is set to 0.5 s here.
        stop_event: Checked before every receive; the loop exits once it is set.
        refresh_event: Set on an ``n`` / ``N`` / ``1`` command.
        screen_ev: Set on an ``s`` command.
        next_ref_ev: Set on an ``r`` command.

    The loop also exits on any non-timeout ``OSError`` from the socket (such as
    the one raised after the socket has been closed). Other exceptions are
    printed and the loop keeps running.
    """
    # First byte of a datagram -> event to signal.
    dispatch = {
        ord("n"): refresh_event,
        ord("N"): refresh_event,
        1: refresh_event,
        ord("s"): screen_ev,
        ord("r"): next_ref_ev,
    }

    # A short receive timeout lets the loop re-check stop_event regularly.
    sock.settimeout(0.5)
    print(f"[Control] Listening for UDP control commands on 0.0.0.0:{opt.control_port} (n/s/r)")

    while not stop_event.is_set():
        try:
            payload, _sender = sock.recvfrom(1024)
            if payload:
                target = dispatch.get(payload[0])
                if target is not None:
                    target.set()
        except socket.timeout:
            # Nothing arrived within the timeout window; poll again.
            # Must be listed before OSError, which it subclasses.
            continue
        except OSError:
            break
        except Exception as exc:  # pylint: disable=broad-except
            print(f"[Control] Listener error: {exc}")
- try:
- control_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
- control_sock.bind(("0.0.0.0", opt.control_port))
- control_stop_event = threading.Event()
- control_refresh_event = threading.Event()
- control_thread = threading.Thread(
- target=_control_listener,
- args=(control_sock, control_stop_event, control_refresh_event,
- screen_capture_event, next_frame_is_reference_event),
- name="ControlListener",
- daemon=True,
- )
- control_thread.start()
- except Exception as exc: # pylint: disable=broad-except
- print(f"[Control] Failed to start UDP control listener on port {opt.control_port}: {exc}")
- control_sock = None
- control_stop_event = None
- control_thread = None
- control_refresh_event = None
- try:
- while True:
- loop_start_time = time.time()
- # 控制指令 s:Python 截屏作为基准图(在读取本帧前处理)
- if screen_capture_event.is_set():
- screen_capture_event.clear()
- sc_frame = capture_screen_frame(opt)
- if sc_frame is not None:
- last_frame = sc_frame
- last_tensor = frame2tensor(last_frame, device)
- if opt.use_fp16 and device == "cuda":
- with torch.cuda.amp.autocast():
- new_ref = extractor({"image": last_tensor})
- else:
- new_ref = extractor({"image": last_tensor})
- # 只有在新参考图中检测到关键点时才更新基准,避免空 keypoints 导致 LightGlue 报错
- if new_ref.get("keypoints", None) is not None and new_ref["keypoints"].shape[1] > 0:
- last_data = new_ref
- last_image_id += 1
- else:
- print("[Control] Screen capture has no keypoints, keep previous reference")
- elif not DXCAM_AVAILABLE:
- print("[Control] Screen capture skipped: install dxcam (pip install dxcam)")
- frame, ret = async_streamer.read()
- if not ret:
- # For UDP mode, timeout doesn't mean end of stream
- # Continue waiting for new frames
- if hasattr(streamer, "is_udp_jpeg") and streamer.is_udp_jpeg:
- continue # Keep waiting for UDP frames
- # For other modes, timeout means end of stream
- print("Stream ended or timeout exceeded.")
- break
- frame = apply_orientation(frame, opt)
- timer.update("data")
- # 控制指令 r:本帧为 Unity 发来的游戏画面,用作基准图后跳过本帧推理
- if next_frame_is_reference_event.is_set():
- next_frame_is_reference_event.clear()
- last_frame = frame
- frame_tensor = frame2tensor(frame, device)
- if opt.use_fp16 and device == "cuda":
- with torch.cuda.amp.autocast():
- new_ref = extractor({"image": frame_tensor})
- else:
- new_ref = extractor({"image": frame_tensor})
- # 同样只在有关键点时更新参考图
- if new_ref.get("keypoints", None) is not None and new_ref["keypoints"].shape[1] > 0:
- last_data = new_ref
- last_image_id += 1
- print("[Control] Reference updated from Unity game view frame")
- else:
- print("[Control] Unity game view frame has no keypoints, keep previous reference")
- continue
- frame_tensor = frame2tensor(frame, device)
-
- # Use FP16 autocast if enabled
- if opt.use_fp16 and device == "cuda":
- with torch.cuda.amp.autocast():
- curr_data = extractor({"image": frame_tensor})
- else:
- curr_data = extractor({"image": frame_tensor})
- # 如果任一图像没有关键点,跳过本帧匹配,避免 LightGlue 报 IndexError
- if last_data is None or last_data.get("keypoints", None) is None or last_data["keypoints"].shape[1] == 0:
- print("[Guard] Reference has no keypoints, skip matching this frame")
- continue
- if curr_data.get("keypoints", None) is None or curr_data["keypoints"].shape[1] == 0:
- print("[Guard] Current frame has no keypoints, skip matching this frame")
- continue
- matches01 = matcher({"image0": last_data, "image1": curr_data})
- # Update timer immediately after inference (all CPU operations are async now)
- timer.update("forward")
-
- # Calculate FPS based on the entire loop time (more accurate)
- loop_end_time = time.time()
- dt = loop_end_time - loop_start_time
- if dt > 0:
- fps_display = 0.9 * fps_display + 0.1 * (1.0 / dt)
-
- # Use loop_end_time for console printing
- current_time = loop_end_time
-
- # Print FPS to console periodically
- if current_time - last_fps_print_time >= fps_print_interval:
- fp16_status = "FP16" if (opt.use_fp16 and device == "cuda") else "FP32"
- print(f"[FPS] {fps_display:.1f} FPS ({fp16_status})")
- last_fps_print_time = current_time
-
- # Submit all CPU operations to background thread (non-blocking)
- # This includes: tensor->numpy, homography calculation, visualization
- # Now fps_display is already updated, so it will show correct FPS on screen
- center_x, center_y = frame.shape[1] // 2, frame.shape[0] // 2
- camera_center_current = (center_x, center_y)
-
- if async_visualizer is not None:
- async_visualizer.submit(
- frame, # Raw frame (will be processed in background)
- last_frame,
- last_data, # Raw tensor (will be converted in background)
- curr_data, # Raw tensor (will be converted in background)
- matches01, # Raw tensor (will be converted in background)
- camera_center_current,
- show_fps,
- fps_display # Now this is the current frame's FPS
- )
- # Try to get visualization result (non-blocking); no_display 时仍取结果以消费队列,但不弹窗
- if async_visualizer is not None:
- viz_result = async_visualizer.get_result(timeout=0.0) # Non-blocking
- if viz_result is not None and not opt.no_display:
- display_frame, reference_view, num_matches, current_H = viz_result
- cv2.imshow(window_name_ref, reference_view)
- # 处理来自 Unity 的“刷新参考图”控制指令(等价于按键 n)
- if control_thread is not None and control_refresh_event is not None:
- if control_refresh_event.is_set():
- control_refresh_event.clear()
- last_data = curr_data
- last_frame = frame
- last_image_id += 1
- print("[Control] Applied refresh-reference command from Unity, updated reference frame")
-
- # Handle keyboard input (reduce frequency to minimize overhead)
- # Only check every frame (waitKey is necessary for window events)
- if not opt.no_display:
- key = cv2.waitKey(1) & 0xFF
- else:
- key = 0
-
- # Update timer after all operations (for accurate total time)
- # timer.print("LightGlue-Async")
- if key == ord("q"):
- print("Exiting via keyboard (q)")
- break
- if key == ord("n"):
- last_data = curr_data
- last_frame = frame
- last_image_id += 1
- print("Updated reference frame")
- elif key == ord("f"):
- show_fps = not show_fps
- finally:
- async_streamer.stop()
- if async_visualizer is not None:
- async_visualizer.stop()
- if result_sender is not None:
- result_sender.close()
- # 关闭控制监听
- if control_stop_event is not None:
- try:
- control_stop_event.set()
- except Exception:
- pass
- if control_sock is not None:
- try:
- control_sock.close()
- except Exception:
- pass
- try:
- cv2.destroyAllWindows()
- except:
- pass # Ignore errors if windows weren't created
# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|