csanmt_translation_trainer.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. import time
  4. from typing import Dict, Optional
  5. import tensorflow as tf
  6. from modelscope.hub.snapshot_download import snapshot_download
  7. from modelscope.models.nlp import CsanmtForTranslation
  8. from modelscope.trainers.base import BaseTrainer
  9. from modelscope.trainers.builder import TRAINERS
  10. from modelscope.utils.constant import ModelFile
  11. from modelscope.utils.logger import get_logger
  12. if tf.__version__ >= '2.0':
  13. tf = tf.compat.v1
  14. tf.disable_eager_execution()
  15. logger = get_logger()
  16. @TRAINERS.register_module(module_name=r'csanmt-translation')
  17. class CsanmtTranslationTrainer(BaseTrainer):
  18. def __init__(self, model: str, cfg_file: str = None, *args, **kwargs):
  19. model = self.get_or_download_model_dir(model)
  20. tf.reset_default_graph()
  21. self.model_dir = model
  22. self.model_path = osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER)
  23. if cfg_file is None:
  24. cfg_file = osp.join(model, ModelFile.CONFIGURATION)
  25. super().__init__(cfg_file)
  26. self.params = {}
  27. self._override_params_from_file()
  28. tf_config = tf.ConfigProto(allow_soft_placement=True)
  29. tf_config.gpu_options.allow_growth = True
  30. self._session = tf.Session(config=tf_config)
  31. self.source_wids = tf.placeholder(
  32. dtype=tf.int64, shape=[None, None], name='source_wids')
  33. self.target_wids = tf.placeholder(
  34. dtype=tf.int64, shape=[None, None], name='target_wids')
  35. self.output = {}
  36. self.global_step = tf.train.create_global_step()
  37. self.model = CsanmtForTranslation(self.model_path, **self.params)
  38. output = self.model(input=self.source_wids, label=self.target_wids)
  39. self.output.update(output)
  40. self.model_saver = tf.train.Saver(
  41. tf.global_variables(),
  42. max_to_keep=self.params['keep_checkpoint_max'])
  43. with self._session.as_default() as sess:
  44. logger.info(f'loading model from {self.model_path}')
  45. pretrained_variables_map = get_pretrained_variables_map(
  46. self.model_path)
  47. tf.train.init_from_checkpoint(self.model_path,
  48. pretrained_variables_map)
  49. sess.run(tf.global_variables_initializer())
  50. def _override_params_from_file(self):
  51. self.params['hidden_size'] = self.cfg['model']['hidden_size']
  52. self.params['filter_size'] = self.cfg['model']['filter_size']
  53. self.params['num_heads'] = self.cfg['model']['num_heads']
  54. self.params['num_encoder_layers'] = self.cfg['model'][
  55. 'num_encoder_layers']
  56. self.params['num_decoder_layers'] = self.cfg['model'][
  57. 'num_decoder_layers']
  58. self.params['layer_preproc'] = self.cfg['model']['layer_preproc']
  59. self.params['layer_postproc'] = self.cfg['model']['layer_postproc']
  60. self.params['shared_embedding_and_softmax_weights'] = self.cfg[
  61. 'model']['shared_embedding_and_softmax_weights']
  62. self.params['shared_source_target_embedding'] = self.cfg['model'][
  63. 'shared_source_target_embedding']
  64. self.params['initializer_scale'] = self.cfg['model'][
  65. 'initializer_scale']
  66. self.params['position_info_type'] = self.cfg['model'][
  67. 'position_info_type']
  68. self.params['max_relative_dis'] = self.cfg['model']['max_relative_dis']
  69. self.params['num_semantic_encoder_layers'] = self.cfg['model'][
  70. 'num_semantic_encoder_layers']
  71. self.params['src_vocab_size'] = self.cfg['model']['src_vocab_size']
  72. self.params['trg_vocab_size'] = self.cfg['model']['trg_vocab_size']
  73. self.params['attention_dropout'] = 0.0
  74. self.params['residual_dropout'] = 0.0
  75. self.params['relu_dropout'] = 0.0
  76. self.params['train_src'] = self.cfg['dataset']['train_src']
  77. self.params['train_trg'] = self.cfg['dataset']['train_trg']
  78. self.params['vocab_src'] = self.cfg['dataset']['src_vocab']['file']
  79. self.params['vocab_trg'] = self.cfg['dataset']['trg_vocab']['file']
  80. self.params['num_gpus'] = self.cfg['train']['num_gpus']
  81. self.params['warmup_steps'] = self.cfg['train']['warmup_steps']
  82. self.params['update_cycle'] = self.cfg['train']['update_cycle']
  83. self.params['keep_checkpoint_max'] = self.cfg['train'][
  84. 'keep_checkpoint_max']
  85. self.params['confidence'] = self.cfg['train']['confidence']
  86. self.params['optimizer'] = self.cfg['train']['optimizer']
  87. self.params['adam_beta1'] = self.cfg['train']['adam_beta1']
  88. self.params['adam_beta2'] = self.cfg['train']['adam_beta2']
  89. self.params['adam_epsilon'] = self.cfg['train']['adam_epsilon']
  90. self.params['gradient_clip_norm'] = self.cfg['train'][
  91. 'gradient_clip_norm']
  92. self.params['learning_rate_decay'] = self.cfg['train'][
  93. 'learning_rate_decay']
  94. self.params['initializer'] = self.cfg['train']['initializer']
  95. self.params['initializer_scale'] = self.cfg['train'][
  96. 'initializer_scale']
  97. self.params['learning_rate'] = self.cfg['train']['learning_rate']
  98. self.params['train_batch_size_words'] = self.cfg['train'][
  99. 'train_batch_size_words']
  100. self.params['scale_l1'] = self.cfg['train']['scale_l1']
  101. self.params['scale_l2'] = self.cfg['train']['scale_l2']
  102. self.params['train_max_len'] = self.cfg['train']['train_max_len']
  103. self.params['num_of_epochs'] = self.cfg['train']['num_of_epochs']
  104. self.params['save_checkpoints_steps'] = self.cfg['train'][
  105. 'save_checkpoints_steps']
  106. self.params['num_of_samples'] = self.cfg['train']['num_of_samples']
  107. self.params['eta'] = self.cfg['train']['eta']
  108. self.params['beam_size'] = self.cfg['evaluation']['beam_size']
  109. self.params['lp_rate'] = self.cfg['evaluation']['lp_rate']
  110. self.params['max_decoded_trg_len'] = self.cfg['evaluation'][
  111. 'max_decoded_trg_len']
  112. self.params['seed'] = self.cfg['model']['seed']
  113. def train(self, *args, **kwargs):
  114. logger.info('Begin csanmt training')
  115. train_src = osp.join(self.model_dir, self.params['train_src'])
  116. train_trg = osp.join(self.model_dir, self.params['train_trg'])
  117. vocab_src = osp.join(self.model_dir, self.params['vocab_src'])
  118. vocab_trg = osp.join(self.model_dir, self.params['vocab_trg'])
  119. epoch = 0
  120. iteration = 0
  121. with self._session.as_default() as tf_session:
  122. while True:
  123. epoch += 1
  124. if epoch >= self.params['num_of_epochs']:
  125. break
  126. tf.logging.info('%s: Epoch %i' % (__name__, epoch))
  127. train_input_fn = input_fn(
  128. train_src,
  129. train_trg,
  130. vocab_src,
  131. vocab_trg,
  132. batch_size_words=self.params['train_batch_size_words'],
  133. max_len=self.params['train_max_len'],
  134. num_gpus=self.params['num_gpus']
  135. if self.params['num_gpus'] > 1 else 1,
  136. is_train=True,
  137. session=tf_session,
  138. epoch=epoch)
  139. features, labels = train_input_fn
  140. try:
  141. while True:
  142. features_batch, labels_batch = tf_session.run(
  143. [features, labels])
  144. iteration += 1
  145. feed_dict = {
  146. self.source_wids: features_batch,
  147. self.target_wids: labels_batch
  148. }
  149. sess_outputs = self._session.run(
  150. self.output, feed_dict=feed_dict)
  151. loss_step = sess_outputs['loss']
  152. logger.info('Iteration: {}, step loss: {:.6f}'.format(
  153. iteration, loss_step))
  154. if iteration % self.params[
  155. 'save_checkpoints_steps'] == 0:
  156. tf.logging.info('%s: Saving model on step: %d.' %
  157. (__name__, iteration))
  158. ck_path = self.model_dir + 'model.ckpt'
  159. self.model_saver.save(
  160. tf_session,
  161. ck_path,
  162. global_step=tf.train.get_global_step())
  163. except tf.errors.OutOfRangeError:
  164. tf.logging.info('epoch %d end!' % (epoch))
  165. tf.logging.info(
  166. '%s: NMT training completed at time: %s.' %
  167. (__name__, time.asctime(time.localtime(time.time()))))
  168. def evaluate(self,
  169. checkpoint_path: Optional[str] = None,
  170. *args,
  171. **kwargs) -> Dict[str, float]:
  172. """evaluate a dataset
  173. evaluate a dataset via a specific model from the `checkpoint_path` path, if the `checkpoint_path`
  174. does not exist, read from the config file.
  175. Args:
  176. checkpoint_path (Optional[str], optional): the model path. Defaults to None.
  177. Returns:
  178. Dict[str, float]: the results about the evaluation
  179. Example:
  180. {"accuracy": 0.5091743119266054, "f1": 0.673780487804878}
  181. """
  182. pass
  183. def input_fn(src_file,
  184. trg_file,
  185. src_vocab_file,
  186. trg_vocab_file,
  187. num_buckets=20,
  188. max_len=100,
  189. batch_size=200,
  190. batch_size_words=4096,
  191. num_gpus=1,
  192. is_train=True,
  193. session=None,
  194. epoch=None):
  195. src_vocab = tf.lookup.StaticVocabularyTable(
  196. tf.lookup.TextFileInitializer(
  197. src_vocab_file,
  198. key_dtype=tf.string,
  199. key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
  200. value_dtype=tf.int64,
  201. value_index=tf.lookup.TextFileIndex.LINE_NUMBER),
  202. num_oov_buckets=1) # NOTE unk-> vocab_size
  203. trg_vocab = tf.lookup.StaticVocabularyTable(
  204. tf.lookup.TextFileInitializer(
  205. trg_vocab_file,
  206. key_dtype=tf.string,
  207. key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
  208. value_dtype=tf.int64,
  209. value_index=tf.lookup.TextFileIndex.LINE_NUMBER),
  210. num_oov_buckets=1) # NOTE unk-> vocab_size
  211. src_dataset = tf.data.TextLineDataset(src_file)
  212. trg_dataset = tf.data.TextLineDataset(trg_file)
  213. src_trg_dataset = tf.data.Dataset.zip((src_dataset, trg_dataset))
  214. src_trg_dataset = src_trg_dataset.map(
  215. lambda src, trg: (tf.string_split([src]), tf.string_split([trg])),
  216. num_parallel_calls=10).prefetch(1000000)
  217. src_trg_dataset = src_trg_dataset.map(
  218. lambda src, trg: (src.values, trg.values),
  219. num_parallel_calls=10).prefetch(1000000)
  220. src_trg_dataset = src_trg_dataset.map(
  221. lambda src, trg: (src_vocab.lookup(src), trg_vocab.lookup(trg)),
  222. num_parallel_calls=10).prefetch(1000000)
  223. if is_train:
  224. def key_func(src_data, trg_data):
  225. bucket_width = (max_len + num_buckets - 1) // num_buckets
  226. bucket_id = tf.maximum(
  227. tf.size(input=src_data) // bucket_width,
  228. tf.size(input=trg_data) // bucket_width)
  229. return tf.cast(tf.minimum(num_buckets, bucket_id), dtype=tf.int64)
  230. def reduce_func(unused_key, windowed_data):
  231. return windowed_data.padded_batch(
  232. batch_size_words, padded_shapes=([None], [None]))
  233. def window_size_func(key):
  234. bucket_width = (max_len + num_buckets - 1) // num_buckets
  235. key += 1
  236. size = (num_gpus * batch_size_words // (key * bucket_width))
  237. return tf.cast(size, dtype=tf.int64)
  238. src_trg_dataset = src_trg_dataset.filter(
  239. lambda src, trg: tf.logical_and(
  240. tf.size(input=src) <= max_len,
  241. tf.size(input=trg) <= max_len))
  242. src_trg_dataset = src_trg_dataset.apply(
  243. tf.data.experimental.group_by_window(
  244. key_func=key_func,
  245. reduce_func=reduce_func,
  246. window_size_func=window_size_func))
  247. else:
  248. src_trg_dataset = src_trg_dataset.padded_batch(
  249. batch_size * num_gpus, padded_shapes=([None], [None]))
  250. iterator = tf.data.make_initializable_iterator(src_trg_dataset)
  251. tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
  252. features, labels = iterator.get_next()
  253. if is_train:
  254. session.run(iterator.initializer)
  255. if epoch == 1:
  256. session.run(tf.tables_initializer())
  257. return features, labels
  258. def get_pretrained_variables_map(checkpoint_file_path, ignore_scope=None):
  259. reader = tf.train.NewCheckpointReader(
  260. tf.train.latest_checkpoint(checkpoint_file_path))
  261. saved_shapes = reader.get_variable_to_shape_map()
  262. if ignore_scope is None:
  263. var_names = sorted([(var.name, var.name.split(':')[0])
  264. for var in tf.global_variables()
  265. if var.name.split(':')[0] in saved_shapes])
  266. else:
  267. var_names = sorted([(var.name, var.name.split(':')[0])
  268. for var in tf.global_variables()
  269. if var.name.split(':')[0] in saved_shapes and all(
  270. scope not in var.name
  271. for scope in ignore_scope)])
  272. restore_vars = []
  273. name2var = dict(
  274. zip(
  275. map(lambda x: x.name.split(':')[0], tf.global_variables()),
  276. tf.global_variables()))
  277. restore_map = {}
  278. with tf.variable_scope('', reuse=True):
  279. for var_name, saved_var_name in var_names:
  280. curr_var = name2var[saved_var_name]
  281. var_shape = curr_var.get_shape().as_list()
  282. if var_shape == saved_shapes[saved_var_name]:
  283. restore_vars.append(curr_var)
  284. restore_map[saved_var_name] = curr_var
  285. return restore_map