cartoon_translation_trainer.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import os.path as osp
  4. from typing import Dict, Optional
  5. import numpy as np
  6. import tensorflow as tf
  7. from packaging import version
  8. from tqdm import tqdm
  9. from modelscope.models.cv.cartoon import (CartoonModel, all_file,
  10. simple_superpixel, tf_data_loader,
  11. write_batch_image)
  12. from modelscope.trainers.base import BaseTrainer
  13. from modelscope.trainers.builder import TRAINERS
  14. from modelscope.utils.constant import ModelFile
  15. from modelscope.utils.logger import get_logger
  16. logger = get_logger()
  17. if version.parse(tf.__version__) < version.parse('2'):
  18. pass
  19. else:
  20. logger.info(
  21. f'TensorFlow version {_tf_version} found, TF2.x is not supported by CartoonTranslationTrainer.'
  22. )
  23. @TRAINERS.register_module(module_name=r'cartoon-translation')
  24. class CartoonTranslationTrainer(BaseTrainer):
  25. def __init__(self,
  26. model: str,
  27. cfg_file: str = None,
  28. work_dir=None,
  29. photo=None,
  30. cartoon=None,
  31. max_steps=None,
  32. *args,
  33. **kwargs):
  34. """
  35. Args:
  36. model: the model_id of trained model
  37. cfg_file: the path of configuration file
  38. work_dir: the path to save training results
  39. photo: the path of photo images for training
  40. cartoon: the path of cartoon images for training
  41. max_steps: the number of total iteration for training
  42. Returns:
  43. initialized trainer: object of CartoonTranslationTrainer
  44. """
  45. model = self.get_or_download_model_dir(model)
  46. tf.reset_default_graph()
  47. self.model_dir = model
  48. self.model_path = osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER)
  49. if cfg_file is None:
  50. cfg_file = osp.join(model, ModelFile.CONFIGURATION)
  51. super().__init__(cfg_file)
  52. self.params = {}
  53. self._override_params_from_file()
  54. if work_dir is not None:
  55. self.params['work_dir'] = work_dir
  56. if photo is not None:
  57. self.params['photo'] = photo
  58. if cartoon is not None:
  59. self.params['cartoon'] = cartoon
  60. if max_steps is not None:
  61. self.params['max_steps'] = max_steps
  62. if not os.path.exists(self.params['work_dir']):
  63. os.makedirs(self.params['work_dir'])
  64. self.face_photo_list = all_file(self.params['photo'])
  65. self.face_cartoon_list = all_file(self.params['cartoon'])
  66. tf_config = tf.ConfigProto(allow_soft_placement=True)
  67. tf_config.gpu_options.allow_growth = True
  68. self._session = tf.Session(config=tf_config)
  69. self.input_photo = tf.placeholder(tf.float32, [
  70. self.params['batch_size'], self.params['patch_size'],
  71. self.params['patch_size'], 3
  72. ])
  73. self.input_superpixel = tf.placeholder(tf.float32, [
  74. self.params['batch_size'], self.params['patch_size'],
  75. self.params['patch_size'], 3
  76. ])
  77. self.input_cartoon = tf.placeholder(tf.float32, [
  78. self.params['batch_size'], self.params['patch_size'],
  79. self.params['patch_size'], 3
  80. ])
  81. self.model = CartoonModel(self.model_dir)
  82. output = self.model(self.input_photo, self.input_cartoon,
  83. self.input_superpixel)
  84. self.output_cartoon = output['output_cartoon']
  85. self.g_loss = output['g_loss']
  86. self.d_loss = output['d_loss']
  87. tf.summary.scalar('g_loss', self.g_loss)
  88. tf.summary.scalar('d_loss', self.d_loss)
  89. self.train_writer = tf.summary.FileWriter(self.params['work_dir']
  90. + '/train_log')
  91. self.summary_op = tf.summary.merge_all()
  92. all_vars = tf.trainable_variables()
  93. gene_vars = [var for var in all_vars if 'gene' in var.name]
  94. disc_vars = [var for var in all_vars if 'disc' in var.name]
  95. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  96. with tf.control_dependencies(update_ops):
  97. self.g_optim = tf.train.AdamOptimizer(self.params['adv_train_lr'], beta1=0.5, beta2=0.99) \
  98. .minimize(self.g_loss, var_list=gene_vars)
  99. self.d_optim = tf.train.AdamOptimizer(self.params['adv_train_lr'], beta1=0.5, beta2=0.99) \
  100. .minimize(self.d_loss, var_list=disc_vars)
  101. self.saver = tf.train.Saver(max_to_keep=1000)
  102. with self._session.as_default() as sess:
  103. sess.run(tf.global_variables_initializer())
  104. if self.params['resume_epoch'] != 0:
  105. logger.info(f'loading model from {self.model_path}')
  106. self.saver.restore(
  107. sess,
  108. osp.join(self.model_path,
  109. 'model-' + str(self.params['resume_epoch'])))
  110. def _override_params_from_file(self):
  111. self.params['photo'] = self.cfg['train']['photo']
  112. self.params['cartoon'] = self.cfg['train']['cartoon']
  113. self.params['patch_size'] = self.cfg['train']['patch_size']
  114. self.params['work_dir'] = self.cfg['train']['work_dir']
  115. self.params['batch_size'] = self.cfg['train']['batch_size']
  116. self.params['adv_train_lr'] = self.cfg['train']['adv_train_lr']
  117. self.params['max_steps'] = self.cfg['train']['max_steps']
  118. self.params['logging_interval'] = self.cfg['train']['logging_interval']
  119. self.params['ckpt_period_interval'] = self.cfg['train'][
  120. 'ckpt_period_interval']
  121. self.params['resume_epoch'] = self.cfg['train']['resume_epoch']
  122. self.params['num_gpus'] = self.cfg['train']['num_gpus']
  123. def train(self, *args, **kwargs):
  124. logger.info('Begin local cartoon translator training')
  125. photo_ds = tf_data_loader(self.face_photo_list,
  126. self.params['batch_size'])
  127. cartoon_ds = tf_data_loader(self.face_cartoon_list,
  128. self.params['batch_size'])
  129. photo_iterator = photo_ds.make_initializable_iterator()
  130. cartoon_iterator = cartoon_ds.make_initializable_iterator()
  131. photo_next = photo_iterator.get_next()
  132. cartoon_next = cartoon_iterator.get_next()
  133. device = 'gpu:0' if tf.test.is_gpu_available else 'cpu:0'
  134. with tf.device(device):
  135. for max_steps in tqdm(range(self.params['max_steps'])):
  136. self._session.run(photo_iterator.initializer)
  137. self._session.run(cartoon_iterator.initializer)
  138. photo_batch, cartoon_batch = self._session.run(
  139. [photo_next, cartoon_next])
  140. transfer_res = self._session.run(
  141. self.output_cartoon,
  142. feed_dict={self.input_photo: photo_batch})
  143. input_superpixel = simple_superpixel(transfer_res, seg_num=200)
  144. g_loss, _ = self._session.run(
  145. [self.g_loss, self.g_optim],
  146. feed_dict={
  147. self.input_photo: photo_batch,
  148. self.input_superpixel: input_superpixel,
  149. self.input_cartoon: cartoon_batch
  150. })
  151. d_loss, _, train_info = self._session.run(
  152. [self.d_loss, self.d_optim, self.summary_op],
  153. feed_dict={
  154. self.input_photo: photo_batch,
  155. self.input_superpixel: input_superpixel,
  156. self.input_cartoon: cartoon_batch
  157. })
  158. self.train_writer.add_summary(train_info, max_steps)
  159. if np.mod(max_steps + 1, self.params['logging_interval']
  160. ) == 0 or max_steps == 0:
  161. logger.info(
  162. f'Iter: {max_steps}, d_loss: {d_loss}, g_loss: {g_loss}'
  163. )
  164. if np.mod(max_steps + 1,
  165. self.params['ckpt_period_interval']
  166. ) == 0 or max_steps == 0:
  167. self.saver.save(
  168. self._session,
  169. self.params['work_dir'] + '/saved_models/model',
  170. write_meta_graph=False,
  171. global_step=max_steps)
  172. result_face = self._session.run(
  173. self.output_cartoon,
  174. feed_dict={
  175. self.input_photo: photo_batch,
  176. self.input_superpixel: photo_batch,
  177. self.input_cartoon: cartoon_batch
  178. })
  179. write_batch_image(
  180. result_face, self.params['work_dir'] + '/images',
  181. str('%8d' % max_steps) + '_face_result.jpg', 4)
  182. write_batch_image(
  183. photo_batch, self.params['work_dir'] + '/images',
  184. str('%8d' % max_steps) + '_face_photo.jpg', 4)
  185. def evaluate(self,
  186. checkpoint_path: Optional[str] = None,
  187. *args,
  188. **kwargs) -> Dict[str, float]:
  189. """evaluate a dataset
  190. evaluate a dataset via a specific model from the `checkpoint_path` path, if the `checkpoint_path`
  191. does not exist, read from the config file.
  192. Args:
  193. checkpoint_path (Optional[str], optional): the model path. Defaults to None.
  194. Returns:
  195. Dict[str, float]: the results about the evaluation
  196. Example:
  197. {"accuracy": 0.5091743119266054, "f1": 0.673780487804878}
  198. """
  199. raise NotImplementedError(
  200. 'evaluate is not supported by CartoonTranslationTrainer')