| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- """
- Utility classes for running TensorRT engines with PyCUDA.
- """
- from __future__ import annotations
- from dataclasses import dataclass
- from pathlib import Path
- from typing import Dict, List
- import numpy as np
- import tensorrt as trt
- try:
- import pycuda.autoinit # noqa: F401
- import pycuda.driver as cuda
- except ImportError as exc: # pragma: no cover - optional dependency
- raise ImportError(
- "PyCUDA is required for TensorRT inference. Install pycuda before "
- "running the TensorRT demo."
- ) from exc
@dataclass
class HostDeviceMem:
    """Host-side array paired with its device-side allocation for one binding.

    Attributes:
        host: the NumPy array holding data on the CPU side.
        device: the PyCUDA device allocation the host data is copied to/from.
    """

    host: np.ndarray  # CPU-side staging array
    device: "cuda.DeviceAllocation"  # GPU-side buffer (string annotation: no import needed at runtime)
class TensorRTEngine:
    """Lightweight wrapper around TensorRT runtime execution.

    Deserializes a serialized engine ("plan") file and creates one execution
    context and one CUDA stream. For every engine binding it keeps a
    page-locked host buffer and a device buffer.

    Engines with dynamic input dimensions are supported. Buffers are only
    allocated once every binding has a concrete shape, so call
    :meth:`set_input_shape` for each dynamic input before :meth:`infer`.
    Buffers are reallocated whenever a binding shape changes.
    """

    def __init__(self, engine_path: Path):
        """Load the serialized engine at *engine_path* (``Path`` or ``str``).

        Raises:
            RuntimeError: if the engine cannot be deserialized or no execution
                context can be created.
        """
        self.logger = trt.Logger(trt.Logger.ERROR)
        self.runtime = trt.Runtime(self.logger)
        engine_path = Path(engine_path)  # also accept plain string paths
        engine_bytes = engine_path.read_bytes()
        self.engine = self.runtime.deserialize_cuda_engine(engine_bytes)
        if self.engine is None:
            raise RuntimeError(f"Failed to load TensorRT engine: {engine_path}")
        self.context = self.engine.create_execution_context()
        if self.context is None:
            raise RuntimeError(
                f"Failed to create TensorRT execution context: {engine_path}"
            )
        # Device pointers ordered by binding index, as execute_async_v2 expects.
        # Empty while any dynamic dimension is still unresolved.
        self.bindings: List[int] = []
        self.inputs: Dict[str, HostDeviceMem] = {}
        self.outputs: Dict[str, HostDeviceMem] = {}
        self.stream = cuda.Stream()
        self._allocate_buffers()

    @staticmethod
    def _is_resolved(shape) -> bool:
        """Return True if *shape* contains no dynamic (negative) dimension."""
        return all(int(dim) >= 0 for dim in shape)

    def _allocate_buffers(self):
        """(Re)allocate host/device buffers for every binding at its current shape.

        Shapes are read from the execution context. If the existing buffers
        already match these shapes, nothing is done.

        Otherwise the old device memory is released. If any binding still has
        an unresolved dynamic dimension, allocation is deferred and
        ``self.bindings`` is left empty.
        """
        shapes = {}
        for binding in self.engine:
            idx = self.engine.get_binding_index(binding)
            shapes[binding] = (idx, tuple(self.context.get_binding_shape(idx)))

        current = {**self.inputs, **self.outputs}
        if current.keys() == shapes.keys() and all(
            current[name].host.shape == shape
            for name, (_, shape) in shapes.items()
        ):
            return  # buffers already fit the current shapes

        # Release old device memory deterministically instead of waiting for GC.
        for pair in current.values():
            pair.device.free()
        self.bindings = []
        self.inputs = {}
        self.outputs = {}

        # np/pagelocked allocation would fail on -1 dims; wait for set_input_shape().
        if not all(self._is_resolved(shape) for _, shape in shapes.values()):
            return

        bindings = [0] * len(shapes)
        for name, (idx, shape) in shapes.items():
            dtype = trt.nptype(self.engine.get_binding_dtype(name))
            # Async host<->device copies require page-locked host memory.
            host_mem = cuda.pagelocked_empty(shape, dtype)
            # Device allocation of 0 bytes is rejected by CUDA; keep >= 1 byte.
            device_mem = cuda.mem_alloc(max(host_mem.nbytes, 1))
            bindings[idx] = int(device_mem)
            pair = HostDeviceMem(host_mem, device_mem)
            if self.engine.binding_is_input(name):
                self.inputs[name] = pair
            else:
                self.outputs[name] = pair
        self.bindings = bindings

    def infer(self, feed_dict: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Run inference and block until the results are available.

        Args:
            feed_dict: maps input binding names to arrays. Each array is
                cast to the binding dtype and copied into that binding's host
                buffer. It must match, or be broadcastable to, the buffer's
                shape. Inputs not present in ``feed_dict`` are used with
                whatever data their buffers currently hold.

        Returns:
            A fresh copy of every output buffer, keyed by output binding name.

        Raises:
            RuntimeError: if buffers are not allocated yet (a dynamic input
                shape is still unset) or TensorRT reports an execution failure.
            KeyError: if a key of ``feed_dict`` is not an input binding.
        """
        if not self.bindings:
            raise RuntimeError(
                "TensorRT buffers are not allocated; call set_input_shape() "
                "for every dynamic input before infer()"
            )
        # Validate all names before issuing any transfer.
        for name in feed_dict:
            if name not in self.inputs:
                raise KeyError(f"Unknown input binding: {name}")

        for name, arr in feed_dict.items():
            pair = self.inputs[name]
            np.copyto(pair.host, np.asarray(arr).astype(pair.host.dtype, copy=False))
            cuda.memcpy_htod_async(pair.device, pair.host, self.stream)

        if not self.context.execute_async_v2(self.bindings, self.stream.handle, None):
            self.stream.synchronize()  # let queued copies finish before bailing
            raise RuntimeError("TensorRT execution failed")

        for pair in self.outputs.values():
            cuda.memcpy_dtoh_async(pair.host, pair.device, self.stream)
        # The async device->host copies must complete before the host buffers
        # are read; copying earlier could return stale data.
        self.stream.synchronize()
        return {name: pair.host.copy() for name, pair in self.outputs.items()}

    def set_input_shape(self, name: str, shape):
        """Set the runtime shape of input *name* and resize buffers to match.

        Raises:
            KeyError: if *name* is not an input binding of the engine.
            ValueError: if TensorRT rejects *shape* for this binding.
        """
        idx = self.engine.get_binding_index(name)
        if idx < 0 or not self.engine.binding_is_input(name):
            raise KeyError(f"Unknown input binding: {name}")
        dims = tuple(shape)
        # set_binding_shape signals an invalid shape by returning False.
        if not self.context.set_binding_shape(idx, dims):
            raise ValueError(f"Invalid shape {dims} for input binding: {name}")
        self._allocate_buffers()
def load_engines(superpoint_plan: Path, lightglue_plan: Path):
    """Build the SuperPoint and LightGlue TensorRT engines from their plan files.

    The SuperPoint engine is constructed first. If it fails, the LightGlue
    engine is not built.

    Returns:
        A ``(superpoint_engine, lightglue_engine)`` tuple of ``TensorRTEngine``.
    """
    return TensorRTEngine(superpoint_plan), TensorRTEngine(lightglue_plan)
|