ops.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108
  1. # Part of the implementation is borrowed and modified from SegLink,
  2. # publicly available at https://github.com/bgshih/seglink
  3. import math
  4. import os
  5. import shutil
  6. import sys
  7. import uuid
  8. import absl.flags as absl_flags
  9. import cv2
  10. import numpy as np
  11. import tensorflow as tf
  12. from . import utils
  # TF 2.x exposes the graph-mode API used in this module under tf.compat.v1.
  # Compare the numeric major version: a plain string comparison would order
  # '10.0' before '2.0'.
  if int(tf.__version__.split('.')[0]) >= 2:
      tf = tf.compat.v1

  # Parse sys.argv with absl in known_only mode before the tf flags are used,
  # so that unknown command line flags do not raise, e.g.
  #   absl.flags._exceptions.UnrecognizedFlagError:
  #   Unknown command line flag 'OCRDetectionPipeline: Unknown command line flag
  absl_flags.FLAGS(sys.argv, known_only=True)
  FLAGS = tf.app.flags.FLAGS
  tf.app.flags.DEFINE_string('weight_init_method', 'xavier',
                             'Weight initialization method')

  # constants
  # width of one decoded segment row: cx, cy, w, h, theta_cos, theta_sin
  OFFSET_DIM = 6
  # width of one combined box row: cx, cy, w, h, theta
  RBOX_DIM = 5
  # links from a node to its 8 same-layer neighbours
  N_LOCAL_LINKS = 8
  # links from a node to 4 nodes of the previous (finer) layer
  N_CROSS_LINKS = 4
  N_SEG_CLASSES = 2
  N_LNK_CLASSES = 4
  MATCH_STATUS_POS = 1
  MATCH_STATUS_NEG = -1
  MATCH_STATUS_IGNORE = 0
  MUT_LABEL = 3
  POS_LABEL = 1
  NEG_LABEL = 0
  # number of detection layers walked by the decoding helpers
  N_DET_LAYERS = 6
  def load_oplib(lib_name):
      """
      Load TensorFlow operator library.
      ARGS
        lib_name: library name; the file 'lib<lib_name>.so' is looked up in
                  the directory of this source file
      RETURN
        oplib: the value returned by tf.load_op_library
      """
      # use absolute path so that ops.py can be called from other directory
      lib_path = os.path.join(
          os.path.dirname(os.path.realpath(__file__)),
          'lib{0}.so'.format(lib_name))
      # duplicate library with a random new name so that
      # a running program will not be interrupted when the original library
      # is updated; the copy is named after the lib_name argument
      lib_copy_path = '/tmp/lib{0}_{1}.so'.format(
          str(uuid.uuid4())[:8], lib_name)
      shutil.copyfile(lib_path, lib_copy_path)
      return tf.load_op_library(lib_copy_path)
  52. def _nn_variable(name, shape, init_method, collection=None, **kwargs):
  53. """
  54. Create or reuse a variable
  55. ARGS
  56. name: variable name
  57. shape: variable shape
  58. init_method: 'zero', 'kaiming', 'xavier', or (mean, std)
  59. collection: if not none, add variable to this collection
  60. kwargs: extra parameters passed to tf.get_variable
  61. RETURN
  62. var: a new or existing variable
  63. """
  64. if init_method == 'zero':
  65. initializer = tf.constant_initializer(0.0)
  66. elif init_method == 'kaiming':
  67. if len(shape) == 4: # convolutional filters
  68. kh, kw, n_in = shape[:3]
  69. init_std = math.sqrt(2.0 / (kh * kw * n_in))
  70. elif len(shape) == 2: # linear weights
  71. n_in, n_out = shape
  72. init_std = math.sqrt(1.0 / n_out)
  73. else:
  74. raise 'Unsupported shape'
  75. initializer = tf.truncated_normal_initializer(0.0, init_std)
  76. elif init_method == 'xavier':
  77. if len(shape) == 4:
  78. initializer = tf.keras.initializers.glorot_normal()
  79. else:
  80. initializer = tf.keras.initializers.glorot_normal()
  81. elif isinstance(init_method, tuple):
  82. assert (len(init_method) == 2)
  83. initializer = tf.truncated_normal_initializer(init_method[0],
  84. init_method[1])
  85. else:
  86. raise 'Unsupported weight initialization method: ' + init_method
  87. var = tf.get_variable(name, shape=shape, initializer=initializer)
  88. if collection is not None:
  89. tf.add_to_collection(collection, var)
  90. return var
  def conv2d(x,
             n_in,
             n_out,
             ksize,
             stride=1,
             padding='SAME',
             weight_init=None,
             bias=True,
             relu=False,
             scope=None,
             **kwargs):
      """
      2-D convolution with explicit 'SAME' padding, optional bias and ReLU
      ARGS
        x: input tensor; dimensions 1 and 2 are read as height and width
        n_in: number of input channels
        n_out: number of output channels
        ksize: side length of the square kernel
        stride: stride used for both spatial dimensions
        padding: 'SAME' pads x explicitly and then runs a VALID convolution;
                 any other value convolves x without padding
        weight_init: weight init method, defaults to FLAGS.weight_init_method
        bias: a zero-initialised bias is added only when this is exactly True
        relu: ReLU is applied only when this is exactly True
        scope: variable scope name, defaults to 'conv2d'
        kwargs: 'trainable' decides whether variables are added to the
                'weights'/'biases' collections; all kwargs are handed to
                _nn_variable
      RETURN
        yb: convolution output after the optional bias
        y: yb after the optional ReLU
      """
      weight_init = weight_init or FLAGS.weight_init_method
      trainable = kwargs.get('trainable', True)
      if padding == 'SAME':
          in_height = x.get_shape()[1]
          in_width = x.get_shape()[2]
          if in_height % stride == 0:
              pad_along_height = max(ksize - stride, 0)
          else:
              pad_along_height = max(ksize - (in_height % stride), 0)
          if in_width % stride == 0:
              pad_along_width = max(ksize - stride, 0)
          else:
              pad_along_width = max(ksize - (in_width % stride), 0)
          # when the total padding is odd, the extra pixel goes to top/left
          pad_bottom = pad_along_height // 2
          pad_top = pad_along_height - pad_bottom
          pad_right = pad_along_width // 2
          pad_left = pad_along_width - pad_right
          paddings = tf.constant([[0, 0], [pad_top, pad_bottom],
                                  [pad_left, pad_right], [0, 0]])
          input_padded = tf.pad(x, paddings, 'CONSTANT')
      else:
          input_padded = x
      with tf.variable_scope(scope or 'conv2d'):
          # convolution
          kernel = _nn_variable(
              'weight', [ksize, ksize, n_in, n_out],
              weight_init,
              collection='weights' if trainable else None,
              **kwargs)
          yc = tf.nn.conv2d(
              input_padded, kernel, [1, stride, stride, 1], padding='VALID')
          # add bias; without a bias the raw convolution is the result, so
          # yb must be bound before the branch
          yb = yc
          if bias is True:
              bias_var = _nn_variable(
                  'bias', [n_out],
                  'zero',
                  collection='biases' if trainable else None,
                  **kwargs)
              yb = tf.nn.bias_add(yc, bias_var)
          # apply ReLU
          y = yb
          if relu is True:
              y = tf.nn.relu(yb)
      return yb, y
  def group_conv2d_relu(x,
                        n_in,
                        n_out,
                        ksize,
                        stride=1,
                        group=4,
                        padding='SAME',
                        weight_init=None,
                        bias=True,
                        relu=False,
                        name='group_conv2d',
                        **kwargs):
      """
      Grouped convolution followed by ReLU
      The input is split into `group` equal parts along its last axis, each
      part goes through its own conv2d, and the pre-activation outputs are
      concatenated again.
      ARGS
        x: input tensor, channels on the last axis
        n_in: total number of input channels
        n_out: total number of output channels
        ksize, stride, padding, weight_init, bias, relu: passed to conv2d
        group: number of groups
        name: prefix for the per-group scopes '<name>_<i>' and the concat op
        kwargs: accepted but not used by this function
      RETURN
        conv: concatenated per-group convolution outputs (before ReLU)
        relu: ReLU of conv
      """
      group_axis = len(x.get_shape()) - 1
      # channel counts end up in variable shapes, so they must be integers;
      # '/' would produce floats
      in_per_group = int(n_in // group)
      out_per_group = int(n_out // group)
      splits = tf.split(x, [in_per_group] * group, group_axis)
      conv_list = []
      for i in range(group):
          conv_split, _ = conv2d(
              splits[i],
              in_per_group,
              out_per_group,
              ksize=ksize,
              stride=stride,
              padding=padding,
              weight_init=weight_init,
              bias=bias,
              relu=relu,
              scope='%s_%d' % (name, i))
          conv_list.append(conv_split)
      conv = tf.concat(values=conv_list, axis=group_axis, name=name + '_concat')
      activated = tf.nn.relu(conv)
      return conv, activated
  def group_conv2d_bn_relu(x,
                           n_in,
                           n_out,
                           ksize,
                           stride=1,
                           group=4,
                           padding='SAME',
                           weight_init=None,
                           bias=True,
                           relu=False,
                           name='group_conv2d',
                           **kwargs):
      """
      Grouped convolution followed by batch normalization and ReLU
      The input is split into `group` equal parts along its last axis, each
      part goes through its own conv2d, and the pre-activation outputs are
      concatenated, batch-normalized and passed through ReLU.
      ARGS
        x: input tensor, channels on the last axis
        n_in: total number of input channels
        n_out: total number of output channels
        ksize, stride, padding, weight_init, bias, relu: passed to conv2d
        group: number of groups
        name: prefix for the per-group scopes, the concat op and the
              '<name>_bn' scope
        kwargs: accepted but not used by this function
      RETURN
        conv: concatenated per-group convolution outputs (before batch norm)
        relu: ReLU of the batch-normalized conv
      """
      group_axis = len(x.get_shape()) - 1
      # channel counts end up in variable shapes, so they must be integers;
      # '/' would produce floats
      in_per_group = int(n_in // group)
      out_per_group = int(n_out // group)
      splits = tf.split(x, [in_per_group] * group, group_axis)
      conv_list = []
      for i in range(group):
          conv_split, _ = conv2d(
              splits[i],
              in_per_group,
              out_per_group,
              ksize=ksize,
              stride=stride,
              padding=padding,
              weight_init=weight_init,
              bias=bias,
              relu=relu,
              scope='%s_%d' % (name, i))
          conv_list.append(conv_split)
      conv = tf.concat(values=conv_list, axis=group_axis, name=name + '_concat')
      with tf.variable_scope(name + '_bn'):
          # batch norm is always built with training=True here
          bn = tf.layers.batch_normalization(
              conv, momentum=0.9, epsilon=1e-5, scale=True, training=True)
      activated = tf.nn.relu(bn)
      return conv, activated
  def next_conv(x,
                n_in,
                n_out,
                ksize,
                stride=1,
                group=4,
                padding='SAME',
                weight_init=None,
                bias=True,
                relu=False,
                name='next_conv2d',
                **kwargs):
      """
      Three-stage block: 1x1 conv -> grouped ksize conv -> 1x1 conv
      ARGS
        x: input tensor
        n_in: number of input channels; stage a reduces them to n_in // 2
        n_out: number of output channels; stage b produces n_out // 2
        ksize: kernel size of the grouped convolution
        stride: stride of the grouped convolution
        group: number of groups of the grouped convolution
        padding, weight_init, bias, relu: passed to every stage
        name: scope prefix; stages use '<name>_a', '<name>_b', '<name>_c'
        kwargs: passed to every stage
      RETURN
        conv_c, relu_c: the pair returned by the last 1x1 conv_relu
      """
      # channel counts end up in variable shapes, so they must be integers;
      # '/' would produce floats
      n_mid_in = int(n_in // 2)
      n_mid_out = int(n_out // 2)
      conv_a, relu_a = conv_relu(
          x,
          n_in,
          n_mid_in,
          ksize=1,
          stride=1,
          padding=padding,
          weight_init=weight_init,
          bias=bias,
          relu=relu,
          scope=name + '_a',
          **kwargs)
      conv_b, relu_b = group_conv2d_relu(
          relu_a,
          n_mid_in,
          n_mid_out,
          ksize=ksize,
          stride=stride,
          group=group,
          padding=padding,
          weight_init=weight_init,
          bias=bias,
          relu=relu,
          name=name + '_b',
          **kwargs)
      conv_c, relu_c = conv_relu(
          relu_b,
          n_mid_out,
          n_out,
          ksize=1,
          stride=1,
          padding=padding,
          weight_init=weight_init,
          bias=bias,
          relu=relu,
          scope=name + '_c',
          **kwargs)
      return conv_c, relu_c
  def next_conv_bn(x,
                   n_in,
                   n_out,
                   ksize,
                   stride=1,
                   group=4,
                   padding='SAME',
                   weight_init=None,
                   bias=True,
                   relu=False,
                   name='next_conv2d',
                   **kwargs):
      """
      Three-stage block with batch norm:
      1x1 conv_bn_relu -> grouped ksize conv_bn_relu -> 1x1 conv_bn_relu
      ARGS
        x: input tensor
        n_in: number of input channels; stage a reduces them to n_in // 2
        n_out: number of output channels; stage b produces n_out // 2
        ksize: kernel size of the grouped convolution
        stride: stride of the grouped convolution
        group: number of groups of the grouped convolution
        padding, weight_init, bias, relu: passed to every stage
        name: scope prefix; stages use '<name>_a', '<name>_b', '<name>_c'
        kwargs: passed to every stage
      RETURN
        conv_c, relu_c: the pair returned by the last 1x1 conv_bn_relu
      """
      # channel counts end up in variable shapes, so they must be integers;
      # '/' would produce floats
      n_mid_in = int(n_in // 2)
      n_mid_out = int(n_out // 2)
      conv_a, relu_a = conv_bn_relu(
          x,
          n_in,
          n_mid_in,
          ksize=1,
          stride=1,
          padding=padding,
          weight_init=weight_init,
          bias=bias,
          relu=relu,
          scope=name + '_a',
          **kwargs)
      conv_b, relu_b = group_conv2d_bn_relu(
          relu_a,
          n_mid_in,
          n_mid_out,
          ksize=ksize,
          stride=stride,
          group=group,
          padding=padding,
          weight_init=weight_init,
          bias=bias,
          relu=relu,
          name=name + '_b',
          **kwargs)
      conv_c, relu_c = conv_bn_relu(
          relu_b,
          n_mid_out,
          n_out,
          ksize=1,
          stride=1,
          padding=padding,
          weight_init=weight_init,
          bias=bias,
          relu=relu,
          scope=name + '_c',
          **kwargs)
      return conv_c, relu_c
  def conv2d_ori(x,
                 n_in,
                 n_out,
                 ksize,
                 stride=1,
                 padding='SAME',
                 weight_init=None,
                 bias=True,
                 relu=False,
                 scope=None,
                 **kwargs):
      """
      2-D convolution that hands `padding` straight to tf.nn.conv2d
      ARGS
        x: input tensor
        n_in: number of input channels
        n_out: number of output channels
        ksize: side length of the square kernel
        stride: stride used for both spatial dimensions
        padding: padding mode given to tf.nn.conv2d
        weight_init: weight init method, defaults to FLAGS.weight_init_method
        bias: a zero-initialised bias is added only when this is exactly True
        relu: ReLU is applied only when this is exactly True
        scope: variable scope name, defaults to 'conv2d'
        kwargs: 'trainable' decides whether variables join the
                'weights'/'biases' collections; all kwargs go to _nn_variable
      RETURN
        y: the output tensor
      """
      init_method = weight_init or FLAGS.weight_init_method
      is_trainable = kwargs.get('trainable', True)
      with tf.variable_scope(scope or 'conv2d'):
          filters = _nn_variable(
              'weight', [ksize, ksize, n_in, n_out],
              init_method,
              collection='weights' if is_trainable else None,
              **kwargs)
          out = tf.nn.conv2d(
              x, filters, [1, stride, stride, 1], padding=padding)
          if bias is True:
              bias_var = _nn_variable(
                  'bias', [n_out],
                  'zero',
                  collection='biases' if is_trainable else None,
                  **kwargs)
              out = tf.nn.bias_add(out, bias_var)
          if relu is True:
              out = tf.nn.relu(out)
      return out
  def conv_relu(*args, **kwargs):
      """
      conv2d with relu forced to True; scope defaults to 'conv_relu'
      RETURN
        the (pre-activation, activated) pair returned by conv2d
      """
      kwargs.setdefault('scope', 'conv_relu')
      kwargs['relu'] = True
      return conv2d(*args, **kwargs)
  350. def conv_bn_relu(*args, **kwargs):
  351. kwargs['relu'] = True
  352. if 'scope' not in kwargs:
  353. kwargs['scope'] = 'conv_relu'
  354. conv, relu = conv2d(*args, **kwargs)
  355. with tf.variable_scope(kwargs['scope'] + '_bn'):
  356. bn = tf.layers.batch_normalization(
  357. conv, momentum=0.9, epsilon=1e-5, scale=True, training=True)
  358. bn_relu = tf.nn.relu(bn)
  359. return bn, bn_relu
  def conv_relu_ori(*args, **kwargs):
      """
      conv2d_ori with relu forced to True; scope defaults to 'conv_relu'
      RETURN
        the tensor returned by conv2d_ori
      """
      kwargs.setdefault('scope', 'conv_relu')
      kwargs['relu'] = True
      return conv2d_ori(*args, **kwargs)
  def atrous_conv2d(x,
                    n_in,
                    n_out,
                    ksize,
                    dilation,
                    padding='SAME',
                    weight_init=None,
                    bias=True,
                    relu=False,
                    scope=None,
                    **kwargs):
      """
      Atrous (dilated) 2-D convolution with optional bias and ReLU
      ARGS
        x: input tensor
        n_in: number of input channels
        n_out: number of output channels
        ksize: side length of the square kernel
        dilation: rate given to tf.nn.atrous_conv2d
        padding: padding mode given to tf.nn.atrous_conv2d
        weight_init: weight init method, defaults to FLAGS.weight_init_method
        bias: a zero-initialised bias is added only when this is exactly True
        relu: ReLU is applied only when this is exactly True
        scope: variable scope name, defaults to 'atrous_conv2d'
        kwargs: 'trainable' decides whether variables join the
                'weights'/'biases' collections; all kwargs go to _nn_variable
      RETURN
        y: the output tensor
      """
      init_method = weight_init or FLAGS.weight_init_method
      is_trainable = kwargs.get('trainable', True)
      with tf.variable_scope(scope or 'atrous_conv2d'):
          filters = _nn_variable(
              'weight', [ksize, ksize, n_in, n_out],
              init_method,
              collection='weights' if is_trainable else None,
              **kwargs)
          out = tf.nn.atrous_conv2d(x, filters, dilation, padding=padding)
          if bias is True:
              bias_var = _nn_variable(
                  'bias', [n_out],
                  'zero',
                  collection='biases' if is_trainable else None,
                  **kwargs)
              out = tf.nn.bias_add(out, bias_var)
          if relu is True:
              out = tf.nn.relu(out)
      return out
  def avg_pool(x, ksize, stride, padding='SAME', scope=None):
      """
      Average pooling with a square window
      ARGS
        x: input tensor
        ksize: window side length for dimensions 1 and 2
        stride: stride for dimensions 1 and 2
        padding: padding mode given to tf.nn.avg_pool
        scope: variable scope name, defaults to 'avg_pool'
      RETURN
        the pooled tensor
      """
      window = [1, ksize, ksize, 1]
      strides = [1, stride, stride, 1]
      with tf.variable_scope(scope or 'avg_pool'):
          pooled = tf.nn.avg_pool(x, window, strides, padding)
      return pooled
  def max_pool(x, ksize, stride, padding='SAME', scope=None):
      """
      Max pooling with a square window
      ARGS
        x: input tensor
        ksize: window side length for dimensions 1 and 2
        stride: stride for dimensions 1 and 2
        padding: padding mode given to tf.nn.max_pool
        scope: variable scope name, defaults to 'max_pool'
      RETURN
        the pooled tensor
      """
      window = [1, ksize, ksize, 1]
      strides = [1, stride, stride, 1]
      with tf.variable_scope(scope or 'max_pool'):
          pooled = tf.nn.max_pool(x, window, strides, padding)
      return pooled
  def score_loss(gt_labels, match_scores, n_classes):
      """
      Classification loss
      ARGS
        gt_labels: int32 [n]
        match_scores: [n, n_classes]
        n_classes: number of classes, used as the one-hot depth
      RETURN
        loss: softmax cross-entropy summed over all entries
      """
      embeddings = tf.one_hot(tf.cast(gt_labels, tf.int64), n_classes, 1.0, 0.0)
      # labels and logits are passed by keyword so the scores are always the
      # logits and the one-hot targets are always the labels
      losses = tf.nn.softmax_cross_entropy_with_logits(
          labels=embeddings, logits=match_scores)
      return tf.reduce_sum(losses)
  def smooth_l1_loss(offsets, gt_offsets, scope=None):
      """
      Smooth L1 loss between offsets and gt_offsets
      Elements with |diff| < 1 contribute 0.5 * diff^2, all others
      |diff| - 0.5. No gradient flows into gt_offsets.
      ARGS
        offsets: predicted offsets for one example
        gt_offsets: corresponding groundtruth offsets
        scope: variable scope name, defaults to 'smooth_l1_loss'
      RETURN
        loss: element losses summed along axis 1
      """
      with tf.variable_scope(scope or 'smooth_l1_loss'):
          target = tf.stop_gradient(gt_offsets)
          abs_err = tf.abs(offsets - target)
          is_small = tf.cast(tf.less(abs_err, 1.0), tf.float32)
          is_large = 1.0 - is_small
          quadratic_part = (0.5 * tf.square(abs_err)) * is_small
          linear_part = (abs_err - 0.5) * is_large
          return tf.reduce_sum(quadratic_part + linear_part, 1)
  def polygon_to_rboxe(polygon):
      """
      Convert a 4-point polygon to a rotated box
      ARGS
        polygon: indexable as [x1, y1, x2, y2, x3, y3, x4, y4]
      RETURN
        np.array([c_x, c_y, w, h, theta]) where
          c_x, c_y: mean of the four points
          w: mean length of edges 1-2 and 4-3
          h: sum of the centre's distances to the lines 1-2 and 3-4
          theta: mean of the angles of edges 1->2 and 4->3
      """
      x1, y1, x2, y2, x3, y3, x4, y4 = (polygon[i] for i in range(8))
      c_x = (x1 + x2 + x3 + x4) / 4
      c_y = (y1 + y2 + y3 + y4) / 4
      # width from the two edges 1-2 and 3-4
      edge_12 = point_dist(x1, y1, x2, y2)
      edge_34 = point_dist(x3, y3, x4, y4)
      w = (edge_12 + edge_34) / 2
      # height from the centre's distance to the same two edges
      to_edge_12 = point_line_dist(c_x, c_y, x1, y1, x2, y2)
      to_edge_34 = point_line_dist(c_x, c_y, x3, y3, x4, y4)
      h = to_edge_12 + to_edge_34
      angle_12 = np.arctan2(y2 - y1, x2 - x1)
      angle_43 = np.arctan2(y3 - y4, x3 - x4)
      theta = (angle_12 + angle_43) / 2
      return np.array([c_x, c_y, w, h, theta])
  def point_dist(x1, y1, x2, y2):
      """Return the Euclidean distance between (x1, y1) and (x2, y2)."""
      dx = x2 - x1
      dy = y2 - y1
      return np.sqrt(dx * dx + dy * dy)
  def point_line_dist(px, py, x1, y1, x2, y2):
      """
      Distance from point (px, py) to the line through (x1, y1) and (x2, y2)
      A small epsilon (1e-6) is added to the segment length so that two
      identical line points do not divide by zero.
      """
      eps = 1e-6
      run = x2 - x1
      rise = y2 - y1
      length = np.sqrt(run * run + rise * rise) + eps
      cross = np.abs(px * rise - py * run + x2 * y1 - y2 * x1)
      return cross / length
  def get_combined_polygon(rboxes, resize_size):
      """
      Merge rotated boxes into one polygon
      All boxes are drawn filled onto a blank image; the minimum-area
      rectangle of the largest resulting contour is returned.
      ARGS
        rboxes: array of rotated boxes, one per row, as accepted by
                utils.rboxes_to_polygons
        resize_size: (image_h, image_w) of the drawing canvas
      RETURN
        combined_polygon: flat array with the 4 corner points of the
                          rectangle, or eight zeros when nothing was drawn
      """
      image_h, image_w = resize_size[0], resize_size[1]
      img = np.zeros((image_h, image_w, 3), np.uint8)
      n_boxes = rboxes.shape[0]
      if n_boxes > 0:
          # the conversion does not depend on the loop index: do it once
          polygons = utils.rboxes_to_polygons(rboxes)
      for i in range(n_boxes):
          segment = np.reshape(
              np.array(polygons[i, :], np.int32), (-1, 1, 2))
          cv2.drawContours(img, [segment], 0, (255, 255, 255), -1)
      img2gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
      _, thresh = cv2.threshold(img2gray, 127, 255, cv2.THRESH_BINARY)
      # findContours returns (image, contours, hierarchy) in OpenCV 3 and
      # (contours, hierarchy) in OpenCV 4; the contours are always second
      # from the end
      contours = cv2.findContours(thresh, cv2.RETR_TREE,
                                  cv2.CHAIN_APPROX_SIMPLE)[-2]
      if len(contours) > 0:
          # the first contour with the largest area wins
          cnt = max(contours, key=cv2.contourArea)
          rect = cv2.minAreaRect(cnt)
          combined_polygon = np.array(cv2.boxPoints(rect)).reshape(-1)
      else:
          combined_polygon = np.array([0, 0, 0, 0, 0, 0, 0, 0])
      return combined_polygon
  def combine_segs(segs):
      """
      Combine the segments of one group into a single rotated box
      ARGS
        segs: [n, 6] rows of (cx, cy, w, h, theta_cos, theta_sin)
      RETURN
        (cx, cy, w, h, theta) of the combined box; an np.array for a single
        segment and a tuple otherwise
      """
      segs = np.asarray(segs)
      assert segs.ndim == 2, 'invalid segs ndim'
      assert segs.shape[-1] == 6, 'invalid segs shape'
      if len(segs) == 1:
          cx, cy, w, h = segs[0, :4]
          # column 4 holds the cosine and column 5 the sine, the same layout
          # the multi-segment path below reads
          theta_cos = segs[0, 4]
          theta_sin = segs[0, 5]
          theta = np.arctan2(theta_sin, theta_cos)
          return np.array([cx, cy, w, h, theta])
      # find the best straight line fitting all center points: y = kx + b
      cxs = segs[:, 0]
      cys = segs[:, 1]
      theta_coss = segs[:, 4]
      theta_sins = segs[:, 5]
      # mean direction from the summed sine and cosine components
      bar_theta = np.arctan2(theta_sins.sum(), theta_coss.sum())
      k = np.tan(bar_theta)
      b = np.mean(cys - k * cxs)
      # orthogonal projection of every centre onto the line y = kx + b
      proj_xs = (k * cys + cxs - k * b) / (k**2 + 1)
      proj_ys = (k * k * cys + k * cxs + b) / (k**2 + 1)
      proj_points = np.stack((proj_xs, proj_ys), -1)
      # find the max distance
      max_dist = -1
      idx1 = -1
      idx2 = -1
      n_points = len(proj_points)
      for i in range(n_points):
          point1 = proj_points[i, :]
          for j in range(i + 1, n_points):
              point2 = proj_points[j, :]
              dist = np.sqrt(np.sum((point1 - point2)**2))
              if dist > max_dist:
                  idx1 = i
                  idx2 = j
                  max_dist = dist
      assert idx1 >= 0 and idx2 >= 0
      # the bbox: bcx, bcy, bw, bh, average_theta
      seg1 = segs[idx1, :]
      seg2 = segs[idx2, :]
      bcx, bcy = (seg1[:2] + seg2[:2]) / 2.0
      bh = np.mean(segs[:, 3])
      # the two end segments each add half of their width
      bw = max_dist + (seg1[2] + seg2[2]) / 2.0
      return bcx, bcy, bw, bh, bar_theta
  def combine_segments_batch(segments_batch, group_indices_batch,
                             segment_counts_batch):
      """
      Combine grouped segments into rotated boxes, image by image
      Only the first image is processed: the batch size is fixed to 1.
      ARGS
        segments_batch: [batch, n, 6] segments
        group_indices_batch: [batch, n] group index of every segment
        segment_counts_batch: per image, the number of group indices to try
      RETURN
        float32 array of combined boxes per image and int32 array with the
        number of boxes per image
      """
      batch_size = 1
      rboxes_per_image = []
      counts_per_image = []
      for image_id in range(batch_size):
          segments = segments_batch[image_id, :, :]
          group_indices = group_indices_batch[image_id, :]
          rboxes = []
          for group_id in range(segment_counts_batch[image_id]):
              members = segments[np.where(group_indices == group_id)[0], :]
              # group ids without any member are skipped
              if members.shape[0] > 0:
                  rboxes.append(combine_segs(members))
          rboxes_per_image.append(rboxes)
          counts_per_image.append(len(rboxes))
      # pad every image with zero boxes up to the largest box count
      max_count = np.max(counts_per_image)
      for image_id in range(batch_size):
          count = counts_per_image[image_id]
          if count != max_count:
              zero_boxes = (max_count - count) * [RBOX_DIM * [0.0]]
              rboxes_per_image[image_id] = np.vstack(
                  (rboxes_per_image[image_id], np.array(zero_boxes)))
      boxes = np.asarray(rboxes_per_image, np.float32)
      counts = np.asarray(counts_per_image, np.int32)
      return boxes, counts
  565. # combine_segments rewrite in python version
  def combine_segments_python(segments, group_indices, segment_counts):
      """
      Graph wrapper that runs combine_segments_batch through tf.py_func
      RETURN
        combined_rboxes: float32 tensor
        combined_counts: int32 tensor
      """
      py_inputs = [segments, group_indices, segment_counts]
      py_output_types = [tf.float32, tf.int32]
      combined_rboxes, combined_counts = tf.py_func(
          combine_segments_batch, py_inputs, py_output_types)
      return combined_rboxes, combined_counts
  571. # decode_segments_links rewrite in python version
  def get_coord(offsets, map_size, offsets_defaults):
      """
      Convert a flat node offset into (layer index, x, y)
      ARGS
        offsets: flat node offset over all 6 layers
        map_size: per layer (height, width) of the feature map
        offsets_defaults: per layer [first node offset, first link offset]
      RETURN
        l_idx, x, y: layer index and position inside that layer's map
      """
      # the node belongs to the first layer whose successor starts beyond it;
      # anything past the last boundary falls into layer 5
      l_idx = 5
      for layer in range(5):
          if offsets < offsets_defaults[layer + 1][0]:
              l_idx = layer
              break
      # layer 0 is addressed from zero, without reading offsets_defaults[0]
      if l_idx == 0:
          local = offsets
      else:
          local = offsets - offsets_defaults[l_idx][0]
      width = map_size[l_idx][1]
      x = local % width
      y = local // width
      return l_idx, x, y
  def get_coord_link(offsets, map_size, offsets_defaults):
      """
      Convert a flat link offset into (layer index, x, y, link index)
      Links of layer 0 come in runs of N_LOCAL_LINKS per node, links of all
      later layers in runs of N_LOCAL_LINKS + N_CROSS_LINKS per node.
      ARGS
        offsets: flat link offset
        map_size: per layer (height, width) of the feature map
        offsets_defaults: per layer [first node offset, first link offset]
      RETURN
        l_idx, x, y: coordinates of the node owning the link
        link_idx: index of the link within that node
      """
      first_layer_links = offsets_defaults[1][1]
      if offsets < first_layer_links:
          offsets_node, link_idx = divmod(offsets, N_LOCAL_LINKS)
      else:
          links_per_node = N_LOCAL_LINKS + N_CROSS_LINKS
          node_in_rest, link_idx = divmod(offsets - first_layer_links,
                                          links_per_node)
          # node offsets of the later layers start after layer 0
          offsets_node = node_in_rest + offsets_defaults[1][0]
      l_idx, x, y = get_coord(offsets_node, map_size, offsets_defaults)
      return l_idx, x, y, link_idx
  def is_valid_coord(l_idx, x, y, map_size):
      """Tell whether (x, y) lies inside the feature map of layer l_idx."""
      h, w = map_size[l_idx][0], map_size[l_idx][1]
      return 0 <= x < w and 0 <= y < h
  def get_neighbours(l_idx, x, y, map_size, offsets_defaults):
      """
      List the neighbours of node (l_idx, x, y) that lie inside their map
      Every node has 8 same-layer neighbours; nodes of layers > 0 also have
      4 neighbours at (2x .. 2x+1, 2y .. 2y+1) in layer l_idx - 1.
      ARGS
        l_idx, x, y: layer index and position of the node
        map_size: per layer (height, width) of the feature map
        offsets_defaults: per layer [first node offset, first link offset]
      RETURN
        list of [node_offset, link_offset, link_idx]; link_idx is the
        position of the neighbour in the full candidate list (0-7 or 0-11)
      """
      local_steps = ((-1, -1), (0, -1), (1, -1), (-1, 0),
                     (1, 0), (-1, 1), (0, 1), (1, 1))
      candidates = [(l_idx, x + dx, y + dy) for dx, dy in local_steps]
      if l_idx == 0:
          links_per_node = N_LOCAL_LINKS
      else:
          links_per_node = N_LOCAL_LINKS + N_CROSS_LINKS
          candidates.extend([(l_idx - 1, 2 * x, 2 * y),
                             (l_idx - 1, 2 * x + 1, 2 * y),
                             (l_idx - 1, 2 * x, 2 * y + 1),
                             (l_idx - 1, 2 * x + 1, 2 * y + 1)])
      # offset of this node's first link; link_idx is added per neighbour
      first_link = offsets_defaults[l_idx][1] + (
          map_size[l_idx][1] * y + x) * links_per_node
      neighbours_offsets = []
      for link_idx, (nl_idx, nx, ny) in enumerate(candidates):
          if not is_valid_coord(nl_idx, nx, ny, map_size):
              continue
          node_offset = offsets_defaults[nl_idx][
              0] + map_size[nl_idx][1] * ny + nx
          neighbours_offsets.append(
              [node_offset, first_link + link_idx, link_idx])
      return neighbours_offsets
  def decode_segments_links_python(image_size, all_nodes, all_links, all_reg,
                                   anchor_sizes):
      """
      Graph wrapper that decodes segments and groups with decode_batch
      The batch size is fixed to 1.
      ARGS
        image_size: image size tensor handed to decode_batch
        all_nodes: per-layer node score tensors
        all_links: per-layer link score tensors
        all_reg: per-layer regression tensors
        anchor_sizes: per-layer anchor sizes, wrapped in tf.constant
      RETURN
        segments (float32), group_indices (int32), segment_counts (int32),
        group_indices_all (int32)
      """
      batch_size = 1  # FLAGS.test_batch_size

      def _flatten(layer_outputs, depth):
          # [batch, -1, depth] per layer, concatenated over all layers
          return tf.concat(
              [tf.reshape(o, [batch_size, -1, depth]) for o in layer_outputs],
              axis=1)

      all_nodes_flat = _flatten(all_nodes, N_SEG_CLASSES)
      all_links_flat = _flatten(all_links, N_LNK_CLASSES)
      all_reg_flat = _flatten(all_reg, OFFSET_DIM)
      py_inputs = [
          all_nodes_flat, all_links_flat, all_reg_flat, image_size,
          tf.constant(anchor_sizes)
      ]
      outputs = tf.py_func(decode_batch, py_inputs,
                           [tf.float32, tf.int32, tf.int32, tf.int32])
      segments, group_indices, segment_counts, group_indices_all = outputs
      return segments, group_indices, segment_counts, group_indices_all
  def decode_segments_links_train(image_size, all_nodes, all_links, all_reg,
                                  anchor_sizes):
      """
      Graph wrapper that decodes segments and groups with decode_batch
      The batch size is read from FLAGS.train_batch_size.
      ARGS
        image_size: image size tensor handed to decode_batch
        all_nodes: per-layer node score tensors
        all_links: per-layer link score tensors
        all_reg: per-layer regression tensors
        anchor_sizes: per-layer anchor sizes, wrapped in tf.constant
      RETURN
        segments (float32), group_indices (int32), segment_counts (int32),
        group_indices_all (int32)
      """
      batch_size = FLAGS.train_batch_size

      def _flatten(layer_outputs, depth):
          # [batch, -1, depth] per layer, concatenated over all layers
          return tf.concat(
              [tf.reshape(o, [batch_size, -1, depth]) for o in layer_outputs],
              axis=1)

      all_nodes_flat = _flatten(all_nodes, N_SEG_CLASSES)
      all_links_flat = _flatten(all_links, N_LNK_CLASSES)
      all_reg_flat = _flatten(all_reg, OFFSET_DIM)
      py_inputs = [
          all_nodes_flat, all_links_flat, all_reg_flat, image_size,
          tf.constant(anchor_sizes)
      ]
      outputs = tf.py_func(decode_batch, py_inputs,
                           [tf.float32, tf.int32, tf.int32, tf.int32])
      segments, group_indices, segment_counts, group_indices_all = outputs
      return segments, group_indices, segment_counts, group_indices_all
  def decode_batch(all_nodes, all_links, all_reg, image_size, anchor_sizes):
      """
      Decode every image of a batch with decode_image and pad the results
      ARGS
        all_nodes: [batch, n_nodes, ...] node scores
        all_links: [batch, n_links, ...] link scores
        all_reg: [batch, n_nodes, ...] regression outputs
        image_size: image size handed to decode_image
        anchor_sizes: per-layer anchor sizes handed to decode_image
      RETURN
        segments (float32, zero-padded to the largest segment count),
        group indices (int32, padded with -1), segment counts (int32) and
        the unpadded per-node group indices (int32)
      """
      batch_size = all_nodes.shape[0]
      batch_segments = []
      batch_group_indices = []
      batch_segments_counts = []
      batch_group_indices_all = []
      for image_id in range(batch_size):
          decoded = decode_image(all_nodes[image_id, :, :],
                                 all_links[image_id, :, :],
                                 all_reg[image_id, :, :], image_size,
                                 anchor_sizes)
          segments, group_indices, segments_count, group_indices_all = decoded
          batch_segments.append(segments)
          batch_group_indices.append(group_indices)
          batch_segments_counts.append(segments_count)
          batch_group_indices_all.append(group_indices_all)
      # bring every image up to the largest segment count of the batch
      max_count = np.max(batch_segments_counts)
      for image_id in range(batch_size):
          count = batch_segments_counts[image_id]
          if count != max_count:
              zero_segments = (max_count - count) * [OFFSET_DIM * [0.0]]
              batch_segments[image_id] = np.vstack(
                  (batch_segments[image_id], np.array(zero_segments)))
              batch_group_indices[image_id] = np.hstack(
                  (batch_group_indices[image_id],
                   np.array((max_count - count) * [-1])))
      return (np.asarray(batch_segments, np.float32),
              np.asarray(batch_group_indices, np.int32),
              np.asarray(batch_segments_counts, np.int32),
              np.asarray(batch_group_indices_all, np.int32))
  def decode_image(image_node_scores, image_link_scores, image_reg, image_size,
                   anchor_sizes):
      """
      Decode the segments and their group indices for one image
      ARGS
        image_node_scores: node scores of the image
        image_link_scores: link scores of the image
        image_reg: per-node regression rows
                   (cx, cy, width, height, theta_cos, theta_sin), encoded
        image_size: image size; layer i has map size image_size // 2**(2 + i)
        anchor_sizes: per-layer anchor size
      RETURN
        image_segments: [count, OFFSET_DIM] float32 decoded segments
        image_group_indices: group index of every kept node
        image_segments_counts: number of kept nodes
        image_group_indices_all: group index per node, -1 for dropped nodes
      """
      # per layer: map size and the offsets of its first node and first link
      map_size = []
      offsets_defaults = []
      node_base = 0
      link_base = 0
      for layer in range(N_DET_LAYERS):
          offsets_defaults.append([node_base, link_base])
          map_size.append(image_size // (2**(2 + layer)))
          n_cells = map_size[layer][0] * map_size[layer][1]
          node_base += n_cells
          if layer == 0:
              link_base += n_cells * N_LOCAL_LINKS
          else:
              link_base += n_cells * (N_LOCAL_LINKS + N_CROSS_LINKS)
      image_group_indices_all = decode_image_by_join(
          image_node_scores, image_link_scores, FLAGS.node_threshold,
          FLAGS.link_threshold, map_size, offsets_defaults)
      # groups are numbered from 1 by the join; shift so that dropped nodes
      # become -1 and groups start at 0
      image_group_indices_all -= 1
      kept_offsets = np.where(image_group_indices_all >= 0)[0]
      image_group_indices = image_group_indices_all[kept_offsets]
      image_segments_counts = len(image_group_indices)
      image_segments = np.zeros((image_segments_counts, OFFSET_DIM),
                                dtype=np.float32)
      eps = 1e-6
      for row, offsets in enumerate(kept_offsets):
          encoded = image_reg[offsets]
          l_idx, x, y = get_coord(offsets, map_size, offsets_defaults)
          rs = anchor_sizes[l_idx]
          # centre: regression scaled by the anchor size plus the cell centre
          image_segments[row, 0] = encoded[0] * rs + (2**(2 + l_idx)) * (x + 0.5)
          image_segments[row, 1] = encoded[1] * rs + (2**(2 + l_idx)) * (y + 0.5)
          # width and height are regressed in log space
          image_segments[row, 2] = np.exp(encoded[2]) * rs - eps
          image_segments[row, 3] = np.exp(encoded[3]) * rs - eps
          # theta_cos and theta_sin are copied unchanged
          image_segments[row, 4] = encoded[4]
          image_segments[row, 5] = encoded[5]
      return (image_segments, image_group_indices, image_segments_counts,
              image_group_indices_all)
  def decode_image_by_join(node_scores, link_scores, node_threshold,
                           link_threshold, map_size, offsets_defaults):
      """
      Group positive nodes that are connected by positive links
      A node or link is positive when its POS_LABEL score reaches the
      threshold. Two positive nodes are joined when the link between them is
      positive; groups are tracked with a parent array (union-find).
      ARGS
        node_scores: [n_nodes, n_classes] node scores
        link_scores: [n_links, n_classes] link scores
        node_threshold: minimum positive score of a node
        link_threshold: minimum positive score of a link
        map_size: per layer (height, width) of the feature map
        offsets_defaults: per layer [first node offset, first link offset]
      RETURN
        mask: int32 per node, 0 for non-positive nodes, otherwise the 1-based
              group number in order of first appearance
      """
      node_mask = node_scores[:, POS_LABEL] >= node_threshold
      link_mask = link_scores[:, POS_LABEL] >= link_threshold
      # parent[i] == -1 marks node i as the root of its group
      parent = np.zeros_like(node_mask, np.int32) - 1
      positive_nodes = np.where(node_mask == 1)[0]

      def find_root(point):
          if parent[point] == -1:
              return point
          root = parent[point]
          while parent[root] != -1:
              root = parent[root]
          # shortcut for later lookups of the same point
          parent[point] = root
          return root

      def join(p1, p2):
          root1 = find_root(p1)
          root2 = find_root(p2)
          if root1 != root2:
              parent[root1] = root2

      # join by link
      for offsets in positive_nodes:
          l_idx, x, y = get_coord(offsets, map_size, offsets_defaults)
          neighbours = get_neighbours(l_idx, x, y, map_size, offsets_defaults)
          for node_offset, link_offset, _ in neighbours:
              link_is_pos = link_mask[link_offset]
              node_is_pos = node_mask[node_offset]
              if link_is_pos and node_is_pos:
                  join(offsets, node_offset)

      # number the groups 1, 2, ... in the order their roots are first met
      group_of_root = {}
      mask = np.zeros_like(node_mask, dtype=np.int32)
      for point in positive_nodes:
          point_root = find_root(point)
          if point_root not in group_of_root:
              group_of_root[point_root] = len(group_of_root) + 1
          mask[point] = group_of_root[point_root]
      return mask
  811. def get_link_mask(node_mask, offsets_defaults, link_max):
  812. link_mask = np.zeros_like(link_max)
  813. link_mask[0:offsets_defaults[1][1]] = np.tile(
  814. node_mask[0:offsets_defaults[1][0]],
  815. (N_LOCAL_LINKS, 1)).transpose().reshape(offsets_defaults[1][1])
  816. link_mask[offsets_defaults[1][1]:offsets_defaults[2][1]] = np.tile(
  817. node_mask[offsets_defaults[1][0]:offsets_defaults[2][0]],
  818. (N_LOCAL_LINKS + N_CROSS_LINKS, 1)).transpose().reshape(
  819. (offsets_defaults[2][1] - offsets_defaults[1][1]))
  820. link_mask[offsets_defaults[2][1]:offsets_defaults[3][1]] = np.tile(
  821. node_mask[offsets_defaults[2][0]:offsets_defaults[3][0]],
  822. (N_LOCAL_LINKS + N_CROSS_LINKS, 1)).transpose().reshape(
  823. (offsets_defaults[3][1] - offsets_defaults[2][1]))
  824. link_mask[offsets_defaults[3][1]:offsets_defaults[4][1]] = np.tile(
  825. node_mask[offsets_defaults[3][0]:offsets_defaults[4][0]],
  826. (N_LOCAL_LINKS + N_CROSS_LINKS, 1)).transpose().reshape(
  827. (offsets_defaults[4][1] - offsets_defaults[3][1]))
  828. link_mask[offsets_defaults[4][1]:offsets_defaults[5][1]] = np.tile(
  829. node_mask[offsets_defaults[4][0]:offsets_defaults[5][0]],
  830. (N_LOCAL_LINKS + N_CROSS_LINKS, 1)).transpose().reshape(
  831. (offsets_defaults[5][1] - offsets_defaults[4][1]))
  832. link_mask[offsets_defaults[5][1]:] = np.tile(
  833. node_mask[offsets_defaults[5][0]:],
  834. (N_LOCAL_LINKS + N_CROSS_LINKS, 1)).transpose().reshape(
  835. (len(link_mask) - offsets_defaults[5][1]))
  836. return link_mask
  837. def get_link8(link_scores_raw, map_size):
  838. # link[i-1] -local- start -16- end -cross- link[i]
  839. link8_mask = np.zeros((link_scores_raw.shape[0]))
  840. for i in range(N_DET_LAYERS):
  841. if i == 0:
  842. offsets_start = map_size[i][0] * map_size[i][1] * N_LOCAL_LINKS
  843. offsets_end = map_size[i][0] * map_size[i][1] * (
  844. N_LOCAL_LINKS + 16)
  845. offsets_link = map_size[i][0] * map_size[i][1] * (
  846. N_LOCAL_LINKS + 16)
  847. link8_mask[:offsets_start] = 1
  848. else:
  849. offsets_start = offsets_link + map_size[i][0] * map_size[i][
  850. 1] * N_LOCAL_LINKS
  851. offsets_end = offsets_link + map_size[i][0] * map_size[i][1] * (
  852. N_LOCAL_LINKS + 16)
  853. offsets_link_pre = offsets_link
  854. offsets_link += map_size[i][0] * map_size[i][1] * (
  855. N_LOCAL_LINKS + 16 + N_CROSS_LINKS)
  856. link8_mask[offsets_link_pre:offsets_start] = 1
  857. link8_mask[offsets_end:offsets_link] = 1
  858. return link_scores_raw[np.where(link8_mask > 0)[0], :]
  859. def decode_image_by_mutex(node_scores, link_scores, node_threshold,
  860. link_threshold, map_size, offsets_defaults):
  861. node_mask = node_scores[:, POS_LABEL] >= node_threshold
  862. link_pos = link_scores[:, POS_LABEL]
  863. link_mut = link_scores[:, MUT_LABEL]
  864. link_max = np.max(np.vstack((link_pos, link_mut)), axis=0)
  865. offsets_pos_list = np.where(node_mask == 1)[0].tolist()
  866. link_mask_th = link_max >= link_threshold
  867. link_mask = get_link_mask(node_mask, offsets_defaults, link_max)
  868. offsets_link_max = np.argsort(-(link_max * link_mask * link_mask_th))
  869. offsets_link_max = offsets_link_max[:len(offsets_pos_list) * 8]
  870. group_mask = np.zeros_like(node_mask, dtype=np.int32) - 1
  871. mutex_mask = len(node_mask) * [[]]
  872. def find_parent(point):
  873. return group_mask[point]
  874. def set_parent(point, parent):
  875. group_mask[point] = parent
  876. def set_mutex_constraint(point, mutex_point_list):
  877. mutex_mask[point] = mutex_point_list
  878. def find_mutex_constraint(point):
  879. mutex_point_list = mutex_mask[point]
  880. # update mutex_point_list
  881. mutex_point_list_new = []
  882. if not mutex_point_list == []:
  883. for mutex_point in mutex_point_list:
  884. if not is_root(mutex_point):
  885. mutex_point = find_root(mutex_point)
  886. if mutex_point not in mutex_point_list_new:
  887. mutex_point_list_new.append(mutex_point)
  888. set_mutex_constraint(point, mutex_point_list_new)
  889. return mutex_point_list_new
  890. def combine_mutex_constraint(point, parent):
  891. mutex_point_list = find_mutex_constraint(point)
  892. mutex_parent_list = find_mutex_constraint(parent)
  893. for mutex_point in mutex_point_list:
  894. if not is_root(mutex_point):
  895. mutex_point = find_root(mutex_point)
  896. if mutex_point not in mutex_parent_list:
  897. mutex_parent_list.append(mutex_point)
  898. set_mutex_constraint(parent, mutex_parent_list)
  899. def add_mutex_constraint(p1, p2):
  900. mutex_point_list1 = find_mutex_constraint(p1)
  901. mutex_point_list2 = find_mutex_constraint(p2)
  902. if p1 not in mutex_point_list2:
  903. mutex_point_list2.append(p1)
  904. if p2 not in mutex_point_list1:
  905. mutex_point_list1.append(p2)
  906. set_mutex_constraint(p1, mutex_point_list1)
  907. set_mutex_constraint(p2, mutex_point_list2)
  908. def is_root(point):
  909. return find_parent(point) == -1
  910. def find_root(point):
  911. root = point
  912. update_parent = False
  913. while not is_root(root):
  914. root = find_parent(root)
  915. update_parent = True
  916. # for acceleration of find_root
  917. if update_parent:
  918. set_parent(point, root)
  919. return root
  920. def join(p1, p2):
  921. root1 = find_root(p1)
  922. root2 = find_root(p2)
  923. if root1 != root2 and (root1 not in find_mutex_constraint(root2)):
  924. set_parent(root1, root2)
  925. combine_mutex_constraint(root1, root2)
  926. def disjoin(p1, p2):
  927. root1 = find_root(p1)
  928. root2 = find_root(p2)
  929. if root1 != root2:
  930. add_mutex_constraint(root1, root2)
  931. def get_all():
  932. root_map = {}
  933. def get_index(root):
  934. if root not in root_map:
  935. root_map[root] = len(root_map) + 1
  936. return root_map[root]
  937. mask = np.zeros_like(node_mask, dtype=np.int32)
  938. for _, point in enumerate(offsets_pos_list):
  939. point_root = find_root(point)
  940. bbox_idx = get_index(point_root)
  941. mask[point] = bbox_idx
  942. return mask
  943. # join by link
  944. pos_link = 0
  945. mut_link = 0
  946. for _, offsets_link in enumerate(offsets_link_max):
  947. l_idx, x, y, link_idx = get_coord_link(offsets_link, map_size,
  948. offsets_defaults)
  949. offsets = offsets_defaults[l_idx][0] + map_size[l_idx][1] * y + x
  950. if offsets in offsets_pos_list:
  951. neighbours = get_neighbours(l_idx, x, y, map_size,
  952. offsets_defaults)
  953. if not len(np.where(np.array(neighbours)[:,
  954. 2] == link_idx)[0]) == 0:
  955. noffsets = neighbours[np.where(
  956. np.array(neighbours)[:, 2] == link_idx)[0][0]]
  957. link_pos_value = link_pos[noffsets[1]]
  958. link_mut_value = link_mut[noffsets[1]]
  959. node_cls = node_mask[noffsets[0]]
  960. if node_cls and (link_pos_value > link_mut_value):
  961. pos_link += 1
  962. join(offsets, noffsets[0])
  963. elif node_cls and (link_pos_value < link_mut_value):
  964. mut_link += 1
  965. disjoin(offsets, noffsets[0])
  966. mask = get_all()
  967. return mask