# basemodel.py (10 KB)
  1. from utils.general import CUDA, DEVICE
  2. from models.yolov5.yolo import Model
  3. import torch
  4. import cv2
  5. import numpy as np
  6. from models.yolov5.yolo import load_yolov5_ckpt
  7. from utils.yolov5_utils import fuse_conv_and_bn
  8. import glob
  9. import torch.nn as nn
  10. from utils.weight_init import init_weights
  11. from models.yolov5.common import C3, Conv
  12. from torchsummary import summary
  13. import torch.nn.functional as F
  14. import copy
  15. TEXTDET_MASK = 0
  16. TEXTDET_DET = 1
  17. TEXTDET_INFERENCE = 2
  18. class double_conv_up_c3(nn.Module):
  19. def __init__(self, in_ch, mid_ch, out_ch, act=True):
  20. super(double_conv_up_c3, self).__init__()
  21. self.conv = nn.Sequential(
  22. C3(in_ch+mid_ch, mid_ch, act=act),
  23. nn.ConvTranspose2d(mid_ch, out_ch, kernel_size=4, stride = 2, padding=1, bias=False),
  24. nn.BatchNorm2d(out_ch),
  25. nn.ReLU(inplace=True),
  26. )
  27. def forward(self, x):
  28. return self.conv(x)
  29. class double_conv_c3(nn.Module):
  30. def __init__(self, in_ch, out_ch, stride=1, act=True):
  31. super(double_conv_c3, self).__init__()
  32. if stride > 1 :
  33. self.down = nn.AvgPool2d(2,stride=2) if stride > 1 else None
  34. self.conv = C3(in_ch, out_ch, act=act)
  35. def forward(self, x):
  36. if self.down is not None :
  37. x = self.down(x)
  38. x = self.conv(x)
  39. return x
  40. class UnetHead(nn.Module):
  41. def __init__(self, act=True) -> None:
  42. super(UnetHead, self).__init__()
  43. self.down_conv1 = double_conv_c3(512, 512, 2, act=act)
  44. self.upconv0 = double_conv_up_c3(0, 512, 256, act=act)
  45. self.upconv2 = double_conv_up_c3(256, 512, 256, act=act)
  46. self.upconv3 = double_conv_up_c3(0, 512, 256, act=act)
  47. self.upconv4 = double_conv_up_c3(128, 256, 128, act=act)
  48. self.upconv5 = double_conv_up_c3(64, 128, 64, act=act)
  49. self.upconv6 = nn.Sequential(
  50. nn.ConvTranspose2d(64, 1, kernel_size=4, stride = 2, padding=1, bias=False),
  51. nn.Sigmoid()
  52. )
  53. def forward(self, f160, f80, f40, f20, f3, forward_mode=TEXTDET_MASK):
  54. # input: 640@3
  55. d10 = self.down_conv1(f3) # 512@10
  56. u20 = self.upconv0(d10) # 256@10
  57. u40 = self.upconv2(torch.cat([f20, u20], dim = 1)) # 256@40
  58. if forward_mode == TEXTDET_DET:
  59. return f80, f40, u40
  60. else:
  61. u80 = self.upconv3(torch.cat([f40, u40], dim = 1)) # 256@80
  62. u160 = self.upconv4(torch.cat([f80, u80], dim = 1)) # 128@160
  63. u320 = self.upconv5(torch.cat([f160, u160], dim = 1)) # 64@320
  64. mask = self.upconv6(u320)
  65. if forward_mode == TEXTDET_MASK:
  66. return mask
  67. else:
  68. return mask, [f80, f40, u40]
  69. def init_weight(self, init_func):
  70. self.apply(init_func)
  71. class DBHead(nn.Module):
  72. def __init__(self, in_channels, k = 50, shrink_with_sigmoid=True, act=True):
  73. super().__init__()
  74. self.k = k
  75. self.shrink_with_sigmoid = shrink_with_sigmoid
  76. self.upconv3 = double_conv_up_c3(0, 512, 256, act=act)
  77. self.upconv4 = double_conv_up_c3(128, 256, 128, act=act)
  78. self.conv = nn.Sequential(
  79. nn.Conv2d(128, in_channels, 1),
  80. nn.BatchNorm2d(in_channels),
  81. nn.ReLU(inplace=True)
  82. )
  83. self.binarize = nn.Sequential(
  84. nn.Conv2d(in_channels, in_channels // 4, 3, padding=1),
  85. nn.BatchNorm2d(in_channels // 4),
  86. nn.ReLU(inplace=True),
  87. nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
  88. nn.BatchNorm2d(in_channels // 4),
  89. nn.ReLU(inplace=True),
  90. nn.ConvTranspose2d(in_channels // 4, 1, 2, 2)
  91. )
  92. self.thresh = self._init_thresh(in_channels)
  93. def forward(self, f80, f40, u40, shrink_with_sigmoid=True, step_eval=False):
  94. shrink_with_sigmoid = self.shrink_with_sigmoid
  95. u80 = self.upconv3(torch.cat([f40, u40], dim = 1)) # 256@80
  96. x = self.upconv4(torch.cat([f80, u80], dim = 1)) # 128@160
  97. x = self.conv(x)
  98. threshold_maps = self.thresh(x)
  99. x = self.binarize(x)
  100. shrink_maps = torch.sigmoid(x)
  101. if self.training:
  102. binary_maps = self.step_function(shrink_maps, threshold_maps)
  103. if shrink_with_sigmoid:
  104. return torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)
  105. else:
  106. return torch.cat((shrink_maps, threshold_maps, binary_maps, x), dim=1)
  107. else:
  108. if step_eval:
  109. return self.step_function(shrink_maps, threshold_maps)
  110. else:
  111. return torch.cat((shrink_maps, threshold_maps), dim=1)
  112. def init_weight(self, init_func):
  113. self.apply(init_func)
  114. def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
  115. in_channels = inner_channels
  116. if serial:
  117. in_channels += 1
  118. self.thresh = nn.Sequential(
  119. nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
  120. nn.BatchNorm2d(inner_channels // 4),
  121. nn.ReLU(inplace=True),
  122. self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias),
  123. nn.BatchNorm2d(inner_channels // 4),
  124. nn.ReLU(inplace=True),
  125. self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
  126. nn.Sigmoid())
  127. return self.thresh
  128. def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
  129. if smooth:
  130. inter_out_channels = out_channels
  131. if out_channels == 1:
  132. inter_out_channels = in_channels
  133. module_list = [
  134. nn.Upsample(scale_factor=2, mode='nearest'),
  135. nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)]
  136. if out_channels == 1:
  137. module_list.append(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=1, bias=True))
  138. return nn.Sequential(module_list)
  139. else:
  140. return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
  141. def step_function(self, x, y):
  142. return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
  143. class TextDetector(nn.Module):
  144. def __init__(self, weights, map_location='cpu', forward_mode=TEXTDET_MASK, act=True):
  145. super(TextDetector, self).__init__()
  146. yolov5s_backbone = load_yolov5_ckpt(weights=weights, map_location=map_location)
  147. yolov5s_backbone.eval()
  148. out_indices = [1, 3, 5, 7, 9]
  149. yolov5s_backbone.out_indices = out_indices
  150. yolov5s_backbone.model = yolov5s_backbone.model[:max(out_indices)+1]
  151. self.act = act
  152. self.seg_net = UnetHead(act=act)
  153. self.backbone = yolov5s_backbone
  154. self.dbnet = None
  155. self.forward_mode = forward_mode
  156. def train_mask(self):
  157. self.forward_mode = TEXTDET_MASK
  158. self.backbone.eval()
  159. self.seg_net.train()
  160. def initialize_db(self, unet_weights):
  161. self.dbnet = DBHead(64, act=self.act)
  162. self.seg_net.load_state_dict(torch.load(unet_weights, map_location='cpu')['weights'])
  163. self.dbnet.init_weight(init_weights)
  164. self.dbnet.upconv3 = copy.deepcopy(self.seg_net.upconv3)
  165. self.dbnet.upconv4 = copy.deepcopy(self.seg_net.upconv4)
  166. del self.seg_net.upconv3
  167. del self.seg_net.upconv4
  168. del self.seg_net.upconv5
  169. del self.seg_net.upconv6
  170. # del self.seg_net.conv_mask
  171. def train_db(self):
  172. self.forward_mode = TEXTDET_DET
  173. self.backbone.eval()
  174. self.seg_net.eval()
  175. self.dbnet.train()
  176. def forward(self, x):
  177. forward_mode = self.forward_mode
  178. with torch.no_grad():
  179. outs = self.backbone(x)
  180. if forward_mode == TEXTDET_MASK:
  181. return self.seg_net(*outs, forward_mode=forward_mode)
  182. elif forward_mode == TEXTDET_DET:
  183. with torch.no_grad():
  184. outs = self.seg_net(*outs, forward_mode=forward_mode)
  185. return self.dbnet(*outs)
  186. def get_base_det_models(model_path, device='cpu', half=False, act='leaky'):
  187. textdetector_dict = torch.load(model_path, map_location=device)
  188. blk_det = load_yolov5_ckpt(textdetector_dict['blk_det'], map_location=device)
  189. text_seg = UnetHead(act=act)
  190. text_seg.load_state_dict(textdetector_dict['text_seg'])
  191. text_det = DBHead(64, act=act)
  192. text_det.load_state_dict(textdetector_dict['text_det'])
  193. if half:
  194. return blk_det.eval().half(), text_seg.eval().half(), text_det.eval().half()
  195. return blk_det.eval().to(device), text_seg.eval().to(device), text_det.eval().to(device)
  196. class TextDetBase(nn.Module):
  197. def __init__(self, model_path, device='cpu', half=False, fuse=False, act='leaky'):
  198. super(TextDetBase, self).__init__()
  199. self.blk_det, self.text_seg, self.text_det = get_base_det_models(model_path, device, half, act=act)
  200. if fuse:
  201. self.fuse()
  202. def fuse(self):
  203. def _fuse(model):
  204. for m in model.modules():
  205. if isinstance(m, (Conv)) and hasattr(m, 'bn'):
  206. m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
  207. delattr(m, 'bn') # remove batchnorm
  208. m.forward = m.forward_fuse # update forward
  209. return model
  210. self.text_seg = _fuse(self.text_seg)
  211. self.text_det = _fuse(self.text_det)
  212. def forward(self, features):
  213. blks, features = self.blk_det(features, detect=True)
  214. mask, features = self.text_seg(*features, forward_mode=TEXTDET_INFERENCE)
  215. lines = self.text_det(*features, step_eval=False)
  216. return blks[0], mask, lines
  217. class TextDetBaseDNN:
  218. def __init__(self, input_size, model_path):
  219. self.input_size = input_size
  220. self.model = cv2.dnn.readNetFromONNX(model_path)
  221. self.uoln = self.model.getUnconnectedOutLayersNames()
  222. def __call__(self, im_in):
  223. blob = cv2.dnn.blobFromImage(im_in, scalefactor=1 / 255.0, size=(self.input_size, self.input_size))
  224. self.model.setInput(blob)
  225. blks, mask, lines_map = self.model.forward(self.uoln)
  226. return blks, mask, lines_map
  227. if __name__ == '__main__':
  228. device = 'cuda'
  229. weights = r'data/yolov5sblk.ckpt'
  230. # yolov5s_backbone = load_yolov5_ckpt(weights=weights, map_location='cpu')
  231. model = TextDetector(weights, map_location=DEVICE)
  232. model.to(DEVICE)
  233. model.train_mask()
  234. summary(model, (3, 640, 640), device=DEVICE)
  235. # model.initialize_db(unet_weights='data/unet_head.pt')
  236. # model.train_db()
  237. # summary(model, (3, 640, 640), device=DEVICE)