loss.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. '''
  2. CVPR 2020 submission, Paper ID 6791
  3. Source code for 'Learning to Cartoonize Using White-Box Cartoon Representations'
  4. '''
  5. import os.path as osp
  6. import numpy as np
  7. import scipy.stats as st
  8. import tensorflow as tf
  9. from joblib import Parallel, delayed
  10. from skimage import color, segmentation
  11. from .network import disc_sn
# Per-channel means subtracted in Vgg19.build_conv4_4 after the input is scaled
# to [0, 255]; index 0 is subtracted from the first ("blue") channel, index 1
# from the second ("green") and index 2 from the third ("red").
VGG_MEAN = [103.939, 116.779, 123.68]
  13. class Vgg19:
  14. def __init__(self, vgg19_npy_path=None):
  15. self.data_dict = np.load(
  16. vgg19_npy_path, encoding='latin1', allow_pickle=True).item()
  17. print('Finished loading vgg19.npy')
  18. def build_conv4_4(self, rgb, include_fc=False):
  19. rgb_scaled = (rgb + 1) * 127.5
  20. blue, green, red = tf.split(
  21. axis=3, num_or_size_splits=3, value=rgb_scaled)
  22. bgr = tf.concat(
  23. axis=3,
  24. values=[
  25. blue - VGG_MEAN[0], green - VGG_MEAN[1], red - VGG_MEAN[2]
  26. ])
  27. self.conv1_1 = self.conv_layer(bgr, 'conv1_1')
  28. self.relu1_1 = tf.nn.relu(self.conv1_1)
  29. self.conv1_2 = self.conv_layer(self.relu1_1, 'conv1_2')
  30. self.relu1_2 = tf.nn.relu(self.conv1_2)
  31. self.pool1 = self.max_pool(self.relu1_2, 'pool1')
  32. self.conv2_1 = self.conv_layer(self.pool1, 'conv2_1')
  33. self.relu2_1 = tf.nn.relu(self.conv2_1)
  34. self.conv2_2 = self.conv_layer(self.relu2_1, 'conv2_2')
  35. self.relu2_2 = tf.nn.relu(self.conv2_2)
  36. self.pool2 = self.max_pool(self.relu2_2, 'pool2')
  37. self.conv3_1 = self.conv_layer(self.pool2, 'conv3_1')
  38. self.relu3_1 = tf.nn.relu(self.conv3_1)
  39. self.conv3_2 = self.conv_layer(self.relu3_1, 'conv3_2')
  40. self.relu3_2 = tf.nn.relu(self.conv3_2)
  41. self.conv3_3 = self.conv_layer(self.relu3_2, 'conv3_3')
  42. self.relu3_3 = tf.nn.relu(self.conv3_3)
  43. self.conv3_4 = self.conv_layer(self.relu3_3, 'conv3_4')
  44. self.relu3_4 = tf.nn.relu(self.conv3_4)
  45. self.pool3 = self.max_pool(self.relu3_4, 'pool3')
  46. self.conv4_1 = self.conv_layer(self.pool3, 'conv4_1')
  47. self.relu4_1 = tf.nn.relu(self.conv4_1)
  48. self.conv4_2 = self.conv_layer(self.relu4_1, 'conv4_2')
  49. self.relu4_2 = tf.nn.relu(self.conv4_2)
  50. self.conv4_3 = self.conv_layer(self.relu4_2, 'conv4_3')
  51. self.relu4_3 = tf.nn.relu(self.conv4_3)
  52. self.conv4_4 = self.conv_layer(self.relu4_3, 'conv4_4')
  53. self.relu4_4 = tf.nn.relu(self.conv4_4)
  54. self.pool4 = self.max_pool(self.relu4_4, 'pool4')
  55. return self.conv4_4
  56. def max_pool(self, bottom, name):
  57. return tf.nn.max_pool(
  58. bottom,
  59. ksize=[1, 2, 2, 1],
  60. strides=[1, 2, 2, 1],
  61. padding='SAME',
  62. name=name)
  63. def conv_layer(self, bottom, name):
  64. with tf.variable_scope(name):
  65. filt = self.get_conv_filter(name)
  66. conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')
  67. conv_biases = self.get_bias(name)
  68. bias = tf.nn.bias_add(conv, conv_biases)
  69. return bias
  70. def fc_layer(self, bottom, name):
  71. with tf.variable_scope(name):
  72. shape = bottom.get_shape().as_list()
  73. dim = 1
  74. for d in shape[1:]:
  75. dim *= d
  76. x = tf.reshape(bottom, [-1, dim])
  77. weights = self.get_fc_weight(name)
  78. biases = self.get_bias(name)
  79. fc = tf.nn.bias_add(tf.matmul(x, weights), biases)
  80. return fc
  81. def get_conv_filter(self, name):
  82. return tf.constant(self.data_dict[name][0], name='filter')
  83. def get_bias(self, name):
  84. return tf.constant(self.data_dict[name][1], name='biases')
  85. def get_fc_weight(self, name):
  86. return tf.constant(self.data_dict[name][0], name='weights')
  87. def content_loss(model_dir, input_photo, transfer_res, input_superpixel):
  88. vgg_model = Vgg19(osp.join(model_dir, 'vgg19.npy'))
  89. vgg_photo = vgg_model.build_conv4_4(input_photo)
  90. vgg_output = vgg_model.build_conv4_4(transfer_res)
  91. vgg_superpixel = vgg_model.build_conv4_4(input_superpixel)
  92. h, w, c = vgg_photo.get_shape().as_list()[1:]
  93. abs_photo = tf.losses.absolute_difference(vgg_photo, vgg_output)
  94. photo_loss = tf.reduce_mean(abs_photo) / (h * w * c)
  95. abs_superpixel = tf.losses.absolute_difference(vgg_superpixel, vgg_output)
  96. superpixel_loss = tf.reduce_mean(abs_superpixel) / (h * w * c)
  97. loss = photo_loss + superpixel_loss
  98. return loss
  99. def style_loss(input_cartoon, output_cartoon):
  100. blur_fake = guided_filter(output_cartoon, output_cartoon, r=5, eps=2e-1)
  101. blur_cartoon = guided_filter(input_cartoon, input_cartoon, r=5, eps=2e-1)
  102. gray_fake, gray_cartoon = color_shift(output_cartoon, input_cartoon)
  103. d_loss_gray, g_loss_gray = lsgan_loss(
  104. disc_sn,
  105. gray_cartoon,
  106. gray_fake,
  107. scale=1,
  108. patch=True,
  109. name='disc_gray')
  110. d_loss_blur, g_loss_blur = lsgan_loss(
  111. disc_sn,
  112. blur_cartoon,
  113. blur_fake,
  114. scale=1,
  115. patch=True,
  116. name='disc_blur')
  117. sty_g_loss = (g_loss_blur) + g_loss_gray
  118. sty_d_loss = d_loss_blur + d_loss_gray
  119. return sty_g_loss, sty_d_loss
  120. def gan_loss(discriminator,
  121. real,
  122. fake,
  123. scale=1,
  124. channel=32,
  125. patch=False,
  126. name='discriminator'):
  127. real_logit = discriminator(
  128. real, scale, channel, name=name, patch=patch, reuse=False)
  129. fake_logit = discriminator(
  130. fake, scale, channel, name=name, patch=patch, reuse=True)
  131. real_logit = tf.nn.sigmoid(real_logit)
  132. fake_logit = tf.nn.sigmoid(fake_logit)
  133. g_loss_blur = -tf.reduce_mean(tf.log(fake_logit))
  134. d_loss_blur = -tf.reduce_mean(tf.log(real_logit) + tf.log(1. - fake_logit))
  135. return d_loss_blur, g_loss_blur
  136. def lsgan_loss(discriminator,
  137. real,
  138. fake,
  139. scale=1,
  140. channel=32,
  141. patch=False,
  142. name='discriminator'):
  143. real_logit = discriminator(
  144. real, scale, channel, name=name, patch=patch, reuse=False)
  145. fake_logit = discriminator(
  146. fake, scale, channel, name=name, patch=patch, reuse=True)
  147. g_loss = tf.reduce_mean((fake_logit - 1)**2)
  148. d_loss = 0.5 * (
  149. tf.reduce_mean((real_logit - 1)**2) + tf.reduce_mean(fake_logit**2))
  150. return d_loss, g_loss
  151. def total_variation_loss(image, k_size=1):
  152. h, w = image.get_shape().as_list()[1:3]
  153. tv_h = tf.reduce_mean(
  154. (image[:, k_size:, :, :] - image[:, :h - k_size, :, :])**2)
  155. tv_w = tf.reduce_mean(
  156. (image[:, :, k_size:, :] - image[:, :, :w - k_size, :])**2)
  157. tv_loss = (tv_h + tv_w) / (3 * h * w)
  158. return tv_loss
  159. def guided_filter(x, y, r, eps=1e-2):
  160. x_shape = tf.shape(x)
  161. N = tf_box_filter(
  162. tf.ones((1, x_shape[1], x_shape[2], 1), dtype=x.dtype), r)
  163. mean_x = tf_box_filter(x, r) / N
  164. mean_y = tf_box_filter(y, r) / N
  165. cov_xy = tf_box_filter(x * y, r) / N - mean_x * mean_y
  166. var_x = tf_box_filter(x * x, r) / N - mean_x * mean_x
  167. A = cov_xy / (var_x + eps)
  168. b = mean_y - A * mean_x
  169. mean_A = tf_box_filter(A, r) / N
  170. mean_b = tf_box_filter(b, r) / N
  171. output = mean_A * x + mean_b
  172. return output
  173. def color_shift(image1, image2, mode='uniform'):
  174. b1, g1, r1 = tf.split(image1, num_or_size_splits=3, axis=3)
  175. b2, g2, r2 = tf.split(image2, num_or_size_splits=3, axis=3)
  176. if mode == 'normal':
  177. b_weight = tf.random.normal(shape=[1], mean=0.114, stddev=0.1)
  178. g_weight = np.random.normal(shape=[1], mean=0.587, stddev=0.1)
  179. r_weight = np.random.normal(shape=[1], mean=0.299, stddev=0.1)
  180. elif mode == 'uniform':
  181. b_weight = tf.random.uniform(shape=[1], minval=0.014, maxval=0.214)
  182. g_weight = tf.random.uniform(shape=[1], minval=0.487, maxval=0.687)
  183. r_weight = tf.random.uniform(shape=[1], minval=0.199, maxval=0.399)
  184. output1 = (b_weight * b1 + g_weight * g1 + r_weight * r1) / (
  185. b_weight + g_weight + r_weight)
  186. output2 = (b_weight * b2 + g_weight * g2 + r_weight * r2) / (
  187. b_weight + g_weight + r_weight)
  188. return output1, output2
  189. def simple_superpixel(batch_image, seg_num=200):
  190. def process_slic(image):
  191. seg_label = segmentation.slic(
  192. image,
  193. n_segments=seg_num,
  194. sigma=1,
  195. compactness=10,
  196. convert2lab=True,
  197. start_label=1)
  198. image = color.label2rgb(seg_label, image, kind='avg', bg_label=0)
  199. return image
  200. num_job = np.shape(batch_image)[0]
  201. batch_out = Parallel(n_jobs=num_job)(
  202. delayed(process_slic)(image) for image in batch_image)
  203. return np.array(batch_out)
  204. def tf_box_filter(x, r):
  205. ch = x.get_shape().as_list()[-1]
  206. weight = 1 / ((2 * r + 1)**2)
  207. box_kernel = weight * np.ones((2 * r + 1, 2 * r + 1, ch, 1))
  208. box_kernel = np.array(box_kernel).astype(np.float32)
  209. output = tf.nn.depthwise_conv2d(x, box_kernel, [1, 1, 1, 1], 'SAME')
  210. return output
# This module only defines loss helpers; running it directly does nothing.
if __name__ == '__main__':
    pass