predict_base.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import onnxruntime
  2. class PredictBase(object):
  3. def __init__(self):
  4. pass
  5. def get_onnx_session(self, model_dir, use_gpu, gpu_id = 0):
  6. # 使用gpu
  7. if use_gpu:
  8. providers =[('CUDAExecutionProvider',{"cudnn_conv_algo_search": "DEFAULT","device_id": gpu_id}),'CPUExecutionProvider']
  9. else:
  10. providers =['CPUExecutionProvider']
  11. onnx_session = onnxruntime.InferenceSession(model_dir, None,providers=providers)
  12. # print("providers:", onnxruntime.get_device())
  13. return onnx_session
  14. def get_output_name(self, onnx_session):
  15. """
  16. output_name = onnx_session.get_outputs()[0].name
  17. :param onnx_session:
  18. :return:
  19. """
  20. output_name = []
  21. for node in onnx_session.get_outputs():
  22. output_name.append(node.name)
  23. return output_name
  24. def get_input_name(self, onnx_session):
  25. """
  26. input_name = onnx_session.get_inputs()[0].name
  27. :param onnx_session:
  28. :return:
  29. """
  30. input_name = []
  31. for node in onnx_session.get_inputs():
  32. input_name.append(node.name)
  33. return input_name
  34. def get_input_feed(self, input_name, image_numpy):
  35. """
  36. input_feed={self.input_name: image_numpy}
  37. :param input_name:
  38. :param image_numpy:
  39. :return:
  40. """
  41. input_feed = {}
  42. for name in input_name:
  43. input_feed[name] = image_numpy
  44. return input_feed