export.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. from cv2 import imshow
  2. from matplotlib import lines
  3. import numpy as np
  4. import onnxruntime
  5. import cv2
  6. import torch
  7. import onnx
  8. from basemodel import TextDetBase
  9. import onnxsim
  10. from models.yolov5.common import Conv
  11. from models.yolov5.yolo import Detect
  12. import torch.nn as nn
  13. import time
  14. from seg_dataset import letterbox
  15. from utils.yolov5_utils import fuse_conv_and_bn
  16. class SiLU(nn.Module): # export-friendly version of nn.SiLU()
  17. @staticmethod
  18. def forward(x):
  19. return x * torch.sigmoid(x)
  20. def concate_models(blk_weights, seg_weights, det_weights, save_path):
  21. textdetector_dict = dict()
  22. textdetector_dict['blk_det'] = torch.load(blk_weights, map_location='cpu')
  23. textdetector_dict['text_seg'] = torch.load(seg_weights, map_location='cpu')['weights']
  24. textdetector_dict['text_det'] = torch.load(det_weights, map_location='cpu')['weights']
  25. torch.save(textdetector_dict, save_path)
  26. def export_onnx(model, im, file, opset, train=False, simplify=True, dynamic=False, inplace=False):
  27. # YOLOv5 ONNX export
  28. f = file + '.onnx'
  29. for k, m in model.named_modules():
  30. if isinstance(m, Conv): # assign export-friendly activations
  31. if isinstance(m.act, nn.SiLU):
  32. m.act = SiLU()
  33. elif isinstance(m, Detect):
  34. m.inplace = inplace
  35. m.onnx_dynamic = False
  36. torch.onnx.export(model, im, f, verbose=False, opset_version=opset,
  37. training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
  38. do_constant_folding=not train,
  39. input_names=['images'],
  40. output_names=['blk', 'seg', 'det'],
  41. dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640)
  42. 'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
  43. } if dynamic else None)
  44. # Checks
  45. model_onnx = onnx.load(f) # load onnx model
  46. onnx.checker.check_model(model_onnx) # check onnx model
  47. model_onnx, check = onnxsim.simplify(
  48. model_onnx,
  49. dynamic_input_shape=dynamic,
  50. input_shapes={'images': list(im.shape)} if dynamic else None)
  51. assert check, 'assert check failed'
  52. onnx.save(model_onnx, f)