trt_engine.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. """
  2. Utility classes for running TensorRT engines with PyCUDA.
  3. """
  4. from __future__ import annotations
  5. from dataclasses import dataclass
  6. from pathlib import Path
  7. from typing import Dict, List
  8. import numpy as np
  9. import tensorrt as trt
  10. try:
  11. import pycuda.autoinit # noqa: F401
  12. import pycuda.driver as cuda
  13. except ImportError as exc: # pragma: no cover - optional dependency
  14. raise ImportError(
  15. "PyCUDA is required for TensorRT inference. Install pycuda before "
  16. "running the TensorRT demo."
  17. ) from exc
  18. @dataclass
  19. class HostDeviceMem:
  20. host: np.ndarray
  21. device: "cuda.DeviceAllocation"
  22. class TensorRTEngine:
  23. """Lightweight wrapper around TensorRT runtime execution."""
  24. def __init__(self, engine_path: Path):
  25. self.logger = trt.Logger(trt.Logger.ERROR)
  26. self.runtime = trt.Runtime(self.logger)
  27. engine_bytes = engine_path.read_bytes()
  28. self.engine = self.runtime.deserialize_cuda_engine(engine_bytes)
  29. if self.engine is None:
  30. raise RuntimeError(f"Failed to load TensorRT engine: {engine_path}")
  31. self.context = self.engine.create_execution_context()
  32. self.bindings: List[int] = []
  33. self.inputs: Dict[str, HostDeviceMem] = {}
  34. self.outputs: Dict[str, HostDeviceMem] = {}
  35. self.stream = cuda.Stream()
  36. self._allocate_buffers()
  37. def _allocate_buffers(self):
  38. for binding in self.engine:
  39. idx = self.engine.get_binding_index(binding)
  40. dims = self.context.get_binding_shape(idx)
  41. dtype = trt.nptype(self.engine.get_binding_dtype(binding))
  42. host_mem = np.empty(shape=dims, dtype=dtype)
  43. device_mem = cuda.mem_alloc(host_mem.nbytes)
  44. self.bindings.append(int(device_mem))
  45. pair = HostDeviceMem(host_mem, device_mem)
  46. if self.engine.binding_is_input(binding):
  47. self.inputs[binding] = pair
  48. else:
  49. self.outputs[binding] = pair
  50. def infer(self, feed_dict: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
  51. """Run inference. `feed_dict` keys must match engine input names."""
  52. for name, arr in feed_dict.items():
  53. if name not in self.inputs:
  54. raise KeyError(f"Unknown input binding: {name}")
  55. host_mem = self.inputs[name].host
  56. np.copyto(host_mem, arr.astype(host_mem.dtype, copy=False))
  57. cuda.memcpy_htod_async(
  58. self.inputs[name].device, host_mem, self.stream
  59. )
  60. self.context.execute_async_v2(self.bindings, self.stream.handle, None)
  61. results: Dict[str, np.ndarray] = {}
  62. for name, pair in self.outputs.items():
  63. cuda.memcpy_dtoh_async(pair.host, pair.device, self.stream)
  64. results[name] = pair.host.copy()
  65. self.stream.synchronize()
  66. return results
  67. def set_input_shape(self, name: str, shape):
  68. idx = self.engine.get_binding_index(name)
  69. self.context.set_binding_shape(idx, shape)
  70. def load_engines(superpoint_plan: Path, lightglue_plan: Path):
  71. sp_engine = TensorRTEngine(superpoint_plan)
  72. lg_engine = TensorRTEngine(lightglue_plan)
  73. return sp_engine, lg_engine