""" Utility classes for running TensorRT engines with PyCUDA. """ from __future__ import annotations from dataclasses import dataclass from pathlib import Path from typing import Dict, List import numpy as np import tensorrt as trt try: import pycuda.autoinit # noqa: F401 import pycuda.driver as cuda except ImportError as exc: # pragma: no cover - optional dependency raise ImportError( "PyCUDA is required for TensorRT inference. Install pycuda before " "running the TensorRT demo." ) from exc @dataclass class HostDeviceMem: host: np.ndarray device: "cuda.DeviceAllocation" class TensorRTEngine: """Lightweight wrapper around TensorRT runtime execution.""" def __init__(self, engine_path: Path): self.logger = trt.Logger(trt.Logger.ERROR) self.runtime = trt.Runtime(self.logger) engine_bytes = engine_path.read_bytes() self.engine = self.runtime.deserialize_cuda_engine(engine_bytes) if self.engine is None: raise RuntimeError(f"Failed to load TensorRT engine: {engine_path}") self.context = self.engine.create_execution_context() self.bindings: List[int] = [] self.inputs: Dict[str, HostDeviceMem] = {} self.outputs: Dict[str, HostDeviceMem] = {} self.stream = cuda.Stream() self._allocate_buffers() def _allocate_buffers(self): for binding in self.engine: idx = self.engine.get_binding_index(binding) dims = self.context.get_binding_shape(idx) dtype = trt.nptype(self.engine.get_binding_dtype(binding)) host_mem = np.empty(shape=dims, dtype=dtype) device_mem = cuda.mem_alloc(host_mem.nbytes) self.bindings.append(int(device_mem)) pair = HostDeviceMem(host_mem, device_mem) if self.engine.binding_is_input(binding): self.inputs[binding] = pair else: self.outputs[binding] = pair def infer(self, feed_dict: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: """Run inference. `feed_dict` keys must match engine input names.""" for name, arr in feed_dict.items(): if name not in self.inputs: raise KeyError(f"Unknown input binding: {name}") host_mem = self.inputs[name].host np.copyto(host_mem, arr.astype(host_mem.dtype, copy=False)) cuda.memcpy_htod_async( self.inputs[name].device, host_mem, self.stream ) self.context.execute_async_v2(self.bindings, self.stream.handle, None) results: Dict[str, np.ndarray] = {} for name, pair in self.outputs.items(): cuda.memcpy_dtoh_async(pair.host, pair.device, self.stream) results[name] = pair.host.copy() self.stream.synchronize() return results def set_input_shape(self, name: str, shape): idx = self.engine.get_binding_index(name) self.context.set_binding_shape(idx, shape) def load_engines(superpoint_plan: Path, lightglue_plan: Path): sp_engine = TensorRTEngine(superpoint_plan) lg_engine = TensorRTEngine(lightglue_plan) return sp_engine, lg_engine