canmt_model.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301
  1. # Part of the implementation is borrowed and modified from FAIRSEQ,
  2. # publicly available at https://github.com/facebookresearch/fairseq
  3. # Copyright 2022-2023 The Alibaba MT Team Authors. All rights reserved.
  4. import math
  5. from typing import Any, Dict, List, Optional, Tuple
  6. import numpy
  7. import torch
  8. import torch.nn as nn
  9. from fairseq import utils
  10. from fairseq.distributed import fsdp_wrap
  11. from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel,
  12. FairseqIncrementalDecoder, register_model,
  13. register_model_architecture)
  14. from fairseq.modules import (AdaptiveSoftmax, BaseLayer, FairseqDropout,
  15. LayerDropModuleList, LayerNorm,
  16. PositionalEmbedding,
  17. SinusoidalPositionalEmbedding,
  18. TransformerDecoderLayer, TransformerEncoderLayer)
  19. from fairseq.modules.checkpoint_activations import checkpoint_wrapper
  20. from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
  21. from torch import Tensor
  # Fallback position limits applied by CanmtModel.build_model when the parsed
  # arguments do not define max_source_positions / max_target_positions.
  DEFAULT_MAX_SOURCE_POSITIONS = 1024
  DEFAULT_MAX_TARGET_POSITIONS = 1024

  # Default parameter-count threshold handed to fsdp_wrap(min_num_params=...).
  DEFAULT_MIN_PARAMS_TO_WRAP = 10 ** 8
  class CanmtModel(FairseqEncoderDecoderModel):
      """Encoder-decoder model with an additional, source-side decoder.

      Besides the encoder and the target decoder, the model owns a
      ``second_decoder`` that is built over the source dictionary and
      cross-attends to hidden states produced by the target decoder
      (see :meth:`forward`).

      Args:
          args (argparse.Namespace): parsed command-line arguments
          encoder (TransformerEncoder): the encoder
          decoder (TransformerDecoder): the target-side decoder
          second_decoder (TransformerDecoder): the source-side decoder

      The CanmtModel provides the following named architectures and
      command-line arguments:

      .. argparse::
          :ref: fairseq.models.transformer_parser
          :prog:
      """

      def __init__(self, args, encoder, decoder, second_decoder):
          super().__init__(encoder, decoder)
          self.args = args
          self.supports_align_args = True
          self.encoder = encoder
          self.decoder = decoder
          self.second_decoder = second_decoder

      @staticmethod
      def add_args(parser):
          """Add model-specific arguments to the parser."""
          # fmt: off
          parser.add_argument(
              '--activation-fn',
              choices=utils.get_available_activation_fns(),
              help='activation function to use')
          parser.add_argument(
              '--dropout', type=float, metavar='D', help='dropout probability')
          parser.add_argument(
              '--attention-dropout',
              type=float,
              metavar='D',
              help='dropout probability for attention weights')
          parser.add_argument(
              '--activation-dropout',
              '--relu-dropout',
              type=float,
              metavar='D',
              help='dropout probability after activation in FFN.')
          parser.add_argument(
              '--encoder-embed-path',
              type=str,
              metavar='STR',
              help='path to pre-trained encoder embedding')
          parser.add_argument(
              '--encoder-embed-dim',
              type=int,
              metavar='N',
              help='encoder embedding dimension')
          parser.add_argument(
              '--encoder-ffn-embed-dim',
              type=int,
              metavar='N',
              help='encoder embedding dimension for FFN')
          parser.add_argument(
              '--encoder-layers',
              type=int,
              metavar='N',
              help='num encoder layers')
          parser.add_argument(
              '--encoder-attention-heads',
              type=int,
              metavar='N',
              help='num encoder attention heads')
          parser.add_argument(
              '--encoder-normalize-before',
              action='store_true',
              help='apply layernorm before each encoder block')
          parser.add_argument(
              '--encoder-learned-pos',
              action='store_true',
              help='use learned positional embeddings in the encoder')
          parser.add_argument(
              '--decoder-embed-path',
              type=str,
              metavar='STR',
              help='path to pre-trained decoder embedding')
          parser.add_argument(
              '--decoder-embed-dim',
              type=int,
              metavar='N',
              help='decoder embedding dimension')
          parser.add_argument(
              '--decoder-ffn-embed-dim',
              type=int,
              metavar='N',
              help='decoder embedding dimension for FFN')
          parser.add_argument(
              '--decoder-layers',
              type=int,
              metavar='N',
              help='num decoder layers')
          parser.add_argument(
              '--decoder-attention-heads',
              type=int,
              metavar='N',
              help='num decoder attention heads')
          parser.add_argument(
              '--decoder-learned-pos',
              action='store_true',
              help='use learned positional embeddings in the decoder')
          parser.add_argument(
              '--decoder-normalize-before',
              action='store_true',
              help='apply layernorm before each decoder block')
          # Help text previously lacked the closing parenthesis.
          parser.add_argument(
              '--decoder-output-dim',
              type=int,
              metavar='N',
              help='decoder output dimension (extra linear layer '
              'if different from decoder embed dim)')
          parser.add_argument(
              '--share-decoder-input-output-embed',
              action='store_true',
              help='share decoder input and output embeddings')
          parser.add_argument(
              '--share-all-embeddings',
              action='store_true',
              help='share encoder, decoder and output embeddings'
              ' (requires shared dictionary and embed dim)')
          parser.add_argument(
              '--no-token-positional-embeddings',
              default=False,
              action='store_true',
              help='if set, disables positional embeddings (outside self attention)')
          # A stray trailing comma used to follow this call, turning the
          # statement into a throw-away one-element tuple.
          parser.add_argument(
              '--adaptive-softmax-cutoff',
              metavar='EXPR',
              help='comma separated list of adaptive softmax cutoff points. '
              'Must be used with adaptive_loss criterion')
          parser.add_argument(
              '--adaptive-softmax-dropout',
              type=float,
              metavar='D',
              help='sets adaptive softmax dropout for the tail projections')
          parser.add_argument(
              '--layernorm-embedding',
              action='store_true',
              help='add layernorm to embedding')
          parser.add_argument(
              '--no-scale-embedding',
              action='store_true',
              help='if True, dont scale embeddings')
          parser.add_argument(
              '--checkpoint-activations',
              action='store_true',
              help='checkpoint activations at each layer, which saves GPU '
              'memory usage at the cost of some additional compute')
          # The flag is forwarded as ``offload_to_cpu`` to checkpoint_wrapper,
          # so the help names the CPU; the two literals are also now separated
          # by a space.
          parser.add_argument(
              '--offload-activations',
              action='store_true',
              help='checkpoint activations at each layer, then offload them '
              'to CPU. Sets --checkpoint-activations.')
          parser.add_argument(
              '--no-cross-attention',
              default=False,
              action='store_true',
              help='do not perform cross-attention')
          parser.add_argument(
              '--cross-self-attention',
              default=False,
              action='store_true',
              help='perform cross+self-attention')
          parser.add_argument(
              '--encoder-layerdrop',
              type=float,
              metavar='D',
              default=0,
              help='LayerDrop probability for encoder')
          parser.add_argument(
              '--decoder-layerdrop',
              type=float,
              metavar='D',
              default=0,
              help='LayerDrop probability for decoder')
          parser.add_argument(
              '--encoder-layers-to-keep',
              default=None,
              help='which layers to *keep* when pruning as a comma-separated list')
          parser.add_argument(
              '--decoder-layers-to-keep',
              default=None,
              help='which layers to *keep* when pruning as a comma-separated list')
          parser.add_argument(
              '--quant-noise-pq',
              type=float,
              metavar='D',
              default=0,
              help='iterative PQ quantization noise at training time')
          parser.add_argument(
              '--quant-noise-pq-block-size',
              type=int,
              metavar='D',
              default=8,
              help='block size of quantization noise at training time')
          parser.add_argument(
              '--quant-noise-scalar',
              type=float,
              metavar='D',
              default=0,
              help='scalar quantization noise and scalar quantization at training time')
          parser.add_argument(
              '--min-params-to-wrap',
              type=int,
              metavar='D',
              default=DEFAULT_MIN_PARAMS_TO_WRAP,
              help=(
                  'minimum number of params for a layer to be wrapped with FSDP() when '
                  'training with --ddp-backend=fully_sharded. Smaller values will '
                  'improve memory efficiency, but may make torch.distributed '
                  'communication less efficient due to smaller input sizes. This option '
                  'is set to 0 (i.e., always wrap) when --checkpoint-activations or '
                  '--offload-activations are passed.'))
          # fmt: on

      @classmethod
      def build_model(cls, args, task):
          """Build a new model instance.

          Args:
              args (argparse.Namespace): parsed command-line arguments; missing
                  attributes are filled in by ``base_architecture`` and some
                  fields are updated in place.
              task: object exposing ``vocab_src`` and ``vocab_tgt`` dictionaries.

          Returns:
              CanmtModel: model holding an encoder, a target decoder and a
              second decoder that reuses the source dictionary and the
              encoder embedding table.

          Raises:
              ValueError: if ``--share-all-embeddings`` is requested with
                  different dictionaries, different embedding dimensions, or a
                  decoder embedding path that differs from the encoder's.
          """
          # make sure all arguments are present in older models
          base_architecture(args)

          if args.encoder_layers_to_keep:
              args.encoder_layers = len(args.encoder_layers_to_keep.split(','))
          if args.decoder_layers_to_keep:
              args.decoder_layers = len(args.decoder_layers_to_keep.split(','))

          if getattr(args, 'max_source_positions', None) is None:
              args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
          if getattr(args, 'max_target_positions', None) is None:
              args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

          src_dict, tgt_dict = task.vocab_src, task.vocab_tgt

          if args.share_all_embeddings:
              if src_dict != tgt_dict:
                  raise ValueError(
                      '--share-all-embeddings requires a joined dictionary')
              if args.encoder_embed_dim != args.decoder_embed_dim:
                  raise ValueError(
                      '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim'
                  )
              if args.decoder_embed_path and (
                      args.decoder_embed_path != args.encoder_embed_path):
                  raise ValueError(
                      '--share-all-embeddings not compatible with --decoder-embed-path'
                  )
              encoder_embed_tokens = cls.build_embedding(
                  args, src_dict, args.encoder_embed_dim,
                  args.encoder_embed_path)
              decoder_embed_tokens = encoder_embed_tokens
              args.share_decoder_input_output_embed = True
          else:
              encoder_embed_tokens = cls.build_embedding(
                  args, src_dict, args.encoder_embed_dim,
                  args.encoder_embed_path)
              decoder_embed_tokens = cls.build_embedding(
                  args, tgt_dict, args.decoder_embed_dim,
                  args.decoder_embed_path)

          if getattr(args, 'offload_activations', False):
              args.checkpoint_activations = True  # offloading implies checkpointing

          encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
          decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
          # The second decoder works over the source vocabulary and shares the
          # encoder's embedding table.
          second_decoder = cls.build_decoder(args, src_dict,
                                             encoder_embed_tokens)

          if not args.share_all_embeddings:
              min_params_to_wrap = getattr(args, 'min_params_to_wrap',
                                           DEFAULT_MIN_PARAMS_TO_WRAP)
              # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
              encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
              decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
              # NOTE(review): second_decoder is not passed through fsdp_wrap
              # here - confirm whether that is intentional.
          return cls(args, encoder, decoder, second_decoder)

      @classmethod
      def build_embedding(cls, args, dictionary, embed_dim, path=None):
          """Create an embedding table for *dictionary*.

          When *path* is given, pre-trained vectors are parsed from it and
          loaded into the new table.
          """
          num_embeddings = len(dictionary)
          padding_idx = dictionary.pad()

          emb = Embedding(num_embeddings, embed_dim, padding_idx)
          # if provided, load from preloaded dictionaries
          if path:
              embed_dict = utils.parse_embedding(path)
              utils.load_embedding(embed_dict, dictionary, emb)
          return emb

      @classmethod
      def build_encoder(cls, args, src_dict, embed_tokens):
          """Instantiate the :class:`TransformerEncoder`."""
          return TransformerEncoder(args, src_dict, embed_tokens)

      @classmethod
      def build_decoder(cls, args, tgt_dict, embed_tokens):
          """Instantiate a :class:`TransformerDecoder`.

          Cross-attention is disabled when ``args.no_cross_attention`` is set.
          """
          return TransformerDecoder(
              args,
              tgt_dict,
              embed_tokens,
              no_encoder_attn=getattr(args, 'no_cross_attention', False),
          )

      def forward(
          self,
          src_tokens,
          src_lengths,
          prev_output_tokens,
          prev_src_tokens,
          return_all_hiddens: bool = True,
          features_only: bool = False,
          alignment_layer: Optional[int] = None,
          alignment_heads: Optional[int] = None,
      ):
          """Run the encoder, the target decoder and the second decoder.

          The target decoder is called twice on *prev_output_tokens*: once
          with the encoder output, and once with ``encoder_out=None`` and
          ``full_context_alignment=True``. From the extra dict of the second
          call, ``'last_layer'`` and ``'self_attn_padding_mask'`` are packed
          into an encoder-output-like dict which the second decoder attends
          to while processing *prev_src_tokens*.

          Returns:
              tuple:
                  - the output of the first (encoder-conditioned) decoder call
                  - the output of the second decoder
                  - the dict with ``'encoder_out'`` and
                    ``'encoder_padding_mask'`` fed to the second decoder
          """
          encoder_out = self.encoder(
              src_tokens,
              src_lengths=src_lengths,
              return_all_hiddens=return_all_hiddens)
          decoder_out = self.decoder(
              prev_output_tokens,
              encoder_out=encoder_out,
              features_only=features_only,
              alignment_layer=alignment_layer,
              alignment_heads=alignment_heads,
              src_lengths=src_lengths,
              return_all_hiddens=return_all_hiddens,
          )
          # Second pass without encoder states.
          # NOTE(review): full_context_alignment presumably lifts the causal
          # mask - confirm against TransformerDecoder.
          decoder_out_re = self.decoder(
              prev_output_tokens,
              encoder_out=None,
              features_only=features_only,
              full_context_alignment=True,
              alignment_layer=alignment_layer,
              alignment_heads=alignment_heads,
              src_lengths=src_lengths,
              return_all_hiddens=return_all_hiddens,
          )
          decoder_out_tensor = decoder_out_re[1]['last_layer']
          decoder_padding = decoder_out_re[1]['self_attn_padding_mask']
          # Shaped like an encoder output so the second decoder can consume it
          # through its ``encoder_out`` argument.
          decoder_kvs = {
              'encoder_out': [decoder_out_tensor],
              'encoder_padding_mask': [decoder_padding]
          }
          src_out = self.second_decoder(
              prev_src_tokens,
              encoder_out=decoder_kvs,
              features_only=features_only,
              alignment_layer=alignment_layer,
              alignment_heads=alignment_heads,
              src_lengths=None,
              return_all_hiddens=return_all_hiddens,
          )
          return decoder_out, src_out, decoder_kvs

      @torch.jit.export
      def get_normalized_probs(
          self,
          net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
          log_probs: bool,
          sample: Optional[Dict[str, Tensor]] = None,
      ):
          """Get normalized probabilities (or log probs) from a net's output."""
          return self.get_normalized_probs_scriptable(net_output, log_probs,
                                                      sample)

      def _last_step_outputs(self, decoder_out, temperature: float):
          """Post-process a raw decoder output for step-wise generation.

          Shared by :meth:`forward_decoder` and :meth:`forward_decoder_src`.

          Args:
              decoder_out: sequence returned by a decoder's ``forward``; item 0
                  holds the logits, item 1 (when present) is either an
                  attention tensor or a dict with ``'attn'`` and
                  ``'last_layer'`` entries.
              temperature (float): divisor applied to the last-step logits.

          Returns:
              tuple: ``(log_probs, attn, last_layer)`` where *log_probs* and
              *attn* are restricted to the last position and *attn* may be
              ``None``.
          """
          attn: Optional[Tensor] = None
          decoder_len = len(decoder_out)
          if decoder_len > 1 and decoder_out[1] is not None:
              if isinstance(decoder_out[1], Tensor):
                  attn = decoder_out[1]
              else:
                  attn_holder = decoder_out[1]['attn']
                  if isinstance(attn_holder, Tensor):
                      attn = attn_holder
                  elif attn_holder is not None:
                      attn = attn_holder[0]
              if attn is not None:
                  attn = attn[:, -1, :]

          # div_ works in place on a slice of decoder_out[0], so the caller's
          # logits tensor is scaled as well (a no-op for temperature == 1.0).
          decoder_out_tuple = (
              decoder_out[0][:, -1:, :].div_(temperature),
              None if decoder_len <= 1 else decoder_out[1],
          )
          probs = self.get_normalized_probs(
              decoder_out_tuple, log_probs=True, sample=None)
          probs = probs[:, -1, :]
          # Requires decoder_out[1] to be a dict carrying 'last_layer'.
          decoder_out_tensor = decoder_out[1]['last_layer']
          return probs, attn, decoder_out_tensor

      def forward_decoder(
          self,
          tokens,
          encoder_outs: Dict[str, List[Tensor]],
          incremental_states: Dict[str, Dict[str, Optional[Tensor]]],
          temperature: float = 1.0,
      ):
          """Run one generation step of the target decoder.

          Args:
              tokens: previously generated target tokens.
              encoder_outs: encoder output dict passed as ``encoder_out``.
              incremental_states: incremental-decoding cache handed to the
                  decoder as ``incremental_state``.
              temperature (float): divisor for the last-step logits.

          Returns:
              tuple: ``(log_probs, attn, last_layer)`` for the last position;
              *attn* may be ``None``.
          """
          # decode
          decoder_out = self.decoder.forward(
              tokens,
              encoder_out=encoder_outs,
              incremental_state=incremental_states,
          )
          return self._last_step_outputs(decoder_out, temperature)

      def forward_decoder_src(
          self,
          tokens,
          encoder_outs: Dict[str, List[Tensor]],
          incremental_states: Dict[str, Dict[str, Optional[Tensor]]],
          temperature: float = 1.0,
      ):
          """Run the second (source-side) decoder on *tokens*.

          Args:
              tokens: previously generated source-side tokens.
              encoder_outs: dict passed to the second decoder as
                  ``encoder_out``.
              incremental_states: accepted for signature parity with
                  :meth:`forward_decoder`; it is not forwarded to the second
                  decoder.
              temperature (float): divisor for the last-step logits.

          Returns:
              tuple: ``(log_probs, attn, last_layer, decoder_out)`` where the
              first three items refer to the last position and *decoder_out*
              is the raw output of the second decoder.
          """
          # decode each model
          decoder_out = self.second_decoder.forward(
              tokens, encoder_out=encoder_outs)
          probs, attn, decoder_out_tensor = self._last_step_outputs(
              decoder_out, temperature)
          return probs, attn, decoder_out_tensor, decoder_out

      def forward_encoder(self, net_input: Dict[str, Tensor]):
          """Run the encoder on *net_input*, ignoring decoder-only entries.

          The keys ``'prev_output_tokens'``, ``'prev_src_tokens'`` and
          ``'sources'`` are dropped before the remaining entries are handed to
          ``self.encoder.forward_torchscript``.
          """
          encoder_input = {
              k: v
              for k, v in net_input.items() if k != 'prev_output_tokens'
              and k != 'prev_src_tokens' and k != 'sources'
          }
          return self.encoder.forward_torchscript(encoder_input)

      def reorder_encoder_out(self, encoder_outs: Optional[Dict[str,
                                                                 List[Tensor]]],
                              new_order):
          """
          Reorder encoder output according to *new_order*.

          Args:
              encoder_out: output from the ``forward()`` method
              new_order (LongTensor): desired order

          Returns:
              *encoder_out* rearranged according to *new_order*
          """
          assert encoder_outs is not None
          return self.encoder.reorder_encoder_out(encoder_outs, new_order)

      def reorder_incremental_state(
          self,
          incremental_states: Dict[str, Dict[str, Optional[Tensor]]],
          new_order,
      ):
          """Reorder the target decoder's incremental state by *new_order*.

          Only ``self.decoder`` is touched; the second decoder is not involved.
          """
          self.decoder.reorder_incremental_state_scripting(
              incremental_states, new_order)
  class TransformerEncoder(FairseqEncoder):
      """
      Stack of *args.encoder_layers* :class:`TransformerEncoderLayer` modules
      applied on top of (optionally scaled and position-augmented) token
      embeddings.

      Args:
          args (argparse.Namespace): parsed command-line arguments
          dictionary (~fairseq.data.Dictionary): encoding dictionary
          embed_tokens (torch.nn.Embedding): input embedding
      """

      def __init__(self, args, dictionary, embed_tokens):
          self.args = args
          super().__init__(dictionary)
          self.register_buffer('version', torch.Tensor([3]))

          self.dropout_module = FairseqDropout(
              args.dropout, module_name=self.__class__.__name__)
          self.encoder_layerdrop = args.encoder_layerdrop

          embed_dim = embed_tokens.embedding_dim
          self.padding_idx = embed_tokens.padding_idx
          self.max_source_positions = args.max_source_positions

          self.embed_tokens = embed_tokens

          # Embeddings are multiplied by sqrt(embed_dim) unless disabled.
          if args.no_scale_embedding:
              self.embed_scale = 1.0
          else:
              self.embed_scale = math.sqrt(embed_dim)

          if args.no_token_positional_embeddings:
              self.embed_positions = None
          else:
              self.embed_positions = PositionalEmbedding(
                  args.max_source_positions,
                  embed_dim,
                  self.padding_idx,
                  learned=args.encoder_learned_pos,
              )

          use_export = getattr(args, 'export', False)
          self.layernorm_embedding = (
              LayerNorm(embed_dim, export=use_export)
              if getattr(args, 'layernorm_embedding', False) else None)

          wants_quant_noise = (not args.adaptive_input
                               and args.quant_noise_pq > 0)
          if wants_quant_noise:
              self.quant_noise = apply_quant_noise_(
                  nn.Linear(embed_dim, embed_dim, bias=False),
                  args.quant_noise_pq,
                  args.quant_noise_pq_block_size,
              )
          else:
              self.quant_noise = None

          # LayerDrop needs its dedicated container; otherwise a plain list.
          if self.encoder_layerdrop > 0.0:
              self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
          else:
              self.layers = nn.ModuleList([])
          stacked_layers = [
              self.build_encoder_layer(args)
              for _ in range(args.encoder_layers)
          ]
          self.layers.extend(stacked_layers)
          self.num_layers = len(self.layers)

          self.layer_norm = (
              LayerNorm(embed_dim, export=use_export)
              if args.encoder_normalize_before else None)

      def build_encoder_layer(self, args):
          """Create one encoder layer, optionally checkpointed, then FSDP-wrap it.

          With activation checkpointing enabled the FSDP threshold is forced to
          0; otherwise ``args.min_params_to_wrap`` (or the module default) is
          used.
          """
          layer = TransformerEncoderLayer(args)
          use_checkpoint = getattr(args, 'checkpoint_activations', False)
          if use_checkpoint:
              to_cpu = getattr(args, 'offload_activations', False)
              layer = checkpoint_wrapper(layer, offload_to_cpu=to_cpu)
              wrap_threshold = 0
          else:
              wrap_threshold = getattr(args, 'min_params_to_wrap',
                                       DEFAULT_MIN_PARAMS_TO_WRAP)
          return fsdp_wrap(layer, min_num_params=wrap_threshold)

      def forward_embedding(self,
                            src_tokens,
                            token_embedding: Optional[torch.Tensor] = None):
          """Embed *src_tokens* and apply positions, layer norm, dropout and
          quantization noise where configured.

          Returns:
              tuple: ``(x, embed)`` - the fully processed input and the scaled
              token embedding before any further processing.
          """
          # look the tokens up unless the caller supplied embeddings
          if token_embedding is None:
              token_embedding = self.embed_tokens(src_tokens)
          embed = self.embed_scale * token_embedding
          x = embed
          if self.embed_positions is not None:
              x = embed + self.embed_positions(src_tokens)
          if self.layernorm_embedding is not None:
              x = self.layernorm_embedding(x)
          x = self.dropout_module(x)
          if self.quant_noise is not None:
              x = self.quant_noise(x)
          return x, embed

      def forward(
          self,
          src_tokens,
          src_lengths: Optional[torch.Tensor] = None,
          return_all_hiddens: bool = False,
          token_embeddings: Optional[torch.Tensor] = None,
      ):
          """
          Thin wrapper that delegates to :meth:`forward_scriptable`.

          Args:
              src_tokens (LongTensor): tokens in the source language of shape
                  `(batch, src_len)`
              src_lengths (torch.LongTensor): lengths of each source sentence of
                  shape `(batch)`
              return_all_hiddens (bool, optional): also return all of the
                  intermediate hidden states (default: False).
              token_embeddings (torch.Tensor, optional): precomputed embeddings
                  default `None` will recompute embeddings

          Returns:
              dict:
                  - **encoder_out** (Tensor): the last encoder layer's output of
                    shape `(src_len, batch, embed_dim)`
                  - **encoder_padding_mask** (ByteTensor): the positions of
                    padding elements of shape `(batch, src_len)`
                  - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                    of shape `(batch, src_len, embed_dim)`
                  - **encoder_states** (List[Tensor]): all intermediate
                    hidden states of shape `(src_len, batch, embed_dim)`.
                    Only populated if *return_all_hiddens* is True.
          """
          return self.forward_scriptable(src_tokens, src_lengths,
                                         return_all_hiddens, token_embeddings)

      def forward_scriptable(
          self,
          src_tokens,
          src_lengths: Optional[torch.Tensor] = None,
          return_all_hiddens: bool = False,
          token_embeddings: Optional[torch.Tensor] = None,
      ):
          """
          Encode *src_tokens*; see :meth:`forward` for the public entry point.

          Args:
              src_tokens (LongTensor): tokens in the source language of shape
                  `(batch, src_len)`
              src_lengths (torch.LongTensor): lengths of each source sentence of
                  shape `(batch)`
              return_all_hiddens (bool, optional): also return all of the
                  intermediate hidden states (default: False).
              token_embeddings (torch.Tensor, optional): precomputed embeddings
                  default `None` will recompute embeddings

          Returns:
              dict:
                  - **encoder_out** (Tensor): the last encoder layer's output of
                    shape `(src_len, batch, embed_dim)`
                  - **encoder_padding_mask** (ByteTensor): the positions of
                    padding elements of shape `(batch, src_len)`
                  - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                    of shape `(batch, src_len, embed_dim)`
                  - **encoder_states** (List[Tensor]): all intermediate
                    hidden states of shape `(src_len, batch, embed_dim)`.
                    Only populated if *return_all_hiddens* is True.
          """
          # positions that hold the padding symbol
          encoder_padding_mask = src_tokens.eq(self.padding_idx)
          # on XLA devices padding is always assumed to be present
          has_pads = (src_tokens.device.type == 'xla'
                      or encoder_padding_mask.any())

          x, encoder_embedding = self.forward_embedding(src_tokens,
                                                        token_embeddings)

          # zero out the representation at padded positions
          if has_pads:
              x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

          # B x T x C -> T x B x C
          x = x.transpose(0, 1)

          encoder_states = []
          if return_all_hiddens:
              encoder_states.append(x)

          # run the layer stack
          for layer in self.layers:
              x = layer(
                  x,
                  encoder_padding_mask=encoder_padding_mask
                  if has_pads else None)
              if return_all_hiddens:
                  assert encoder_states is not None
                  encoder_states.append(x)

          if self.layer_norm is not None:
              x = self.layer_norm(x)

          return {
              'encoder_out': [x],  # T x B x C
              'encoder_padding_mask': [encoder_padding_mask],  # B x T
              'encoder_embedding': [encoder_embedding],  # B x T x C
              'encoder_states': encoder_states,  # List[T x B x C]
              'src_tokens': [],
              'src_lengths': [],
          }

      @torch.jit.export
      def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]],
                              new_order):
          """
          Rearrange every entry of *encoder_out* along its batch dimension.

          Args:
              encoder_out: output from the ``forward()`` method
              new_order (LongTensor): desired order

          Returns:
              *encoder_out* rearranged according to *new_order*. The
              ``'encoder_states'`` list is updated in place and reused.
          """
          # batch is dimension 1 for T x B x C tensors, dimension 0 otherwise
          if len(encoder_out['encoder_out']) == 0:
              new_encoder_out = []
          else:
              new_encoder_out = [
                  encoder_out['encoder_out'][0].index_select(1, new_order)
              ]

          if len(encoder_out['encoder_padding_mask']) == 0:
              new_encoder_padding_mask = []
          else:
              new_encoder_padding_mask = [
                  encoder_out['encoder_padding_mask'][0].index_select(
                      0, new_order)
              ]

          if len(encoder_out['encoder_embedding']) == 0:
              new_encoder_embedding = []
          else:
              new_encoder_embedding = [
                  encoder_out['encoder_embedding'][0].index_select(0, new_order)
              ]

          if len(encoder_out['src_tokens']) == 0:
              src_tokens = []
          else:
              src_tokens = [
                  (encoder_out['src_tokens'][0]).index_select(0, new_order)
              ]

          if len(encoder_out['src_lengths']) == 0:
              src_lengths = []
          else:
              src_lengths = [
                  (encoder_out['src_lengths'][0]).index_select(0, new_order)
              ]

          encoder_states = encoder_out['encoder_states']
          if len(encoder_states) > 0:
              for idx, state in enumerate(encoder_states):
                  encoder_states[idx] = state.index_select(1, new_order)

          return {
              'encoder_out': new_encoder_out,  # T x B x C
              'encoder_padding_mask': new_encoder_padding_mask,  # B x T
              'encoder_embedding': new_encoder_embedding,  # B x T x C
              'encoder_states': encoder_states,  # List[T x B x C]
              'src_tokens': src_tokens,  # B x T
              'src_lengths': src_lengths,  # B x 1
          }

      def max_positions(self):
          """Maximum input length supported by the encoder."""
          if self.embed_positions is not None:
              return min(self.max_source_positions,
                         self.embed_positions.max_positions)
          return self.max_source_positions

      def upgrade_state_dict_named(self, state_dict, name):
          """Upgrade a (possibly old) state dict for new versions of fairseq."""
          if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
              weights_key = '{}.embed_positions.weights'.format(name)
              if weights_key in state_dict:
                  print('deleting {0}'.format(weights_key))
                  del state_dict[weights_key]
              float_tensor_key = '{}.embed_positions._float_tensor'.format(name)
              state_dict[float_tensor_key] = torch.FloatTensor(1)

          # let every layer upgrade its own entries (layer norms)
          for layer_idx in range(self.num_layers):
              layer_prefix = '{}.layers.{}'.format(name, layer_idx)
              self.layers[layer_idx].upgrade_state_dict_named(
                  state_dict, layer_prefix)

          version_key = '{}.version'.format(name)
          stored_version = utils.item(
              state_dict.get(version_key, torch.Tensor([1]))[0])
          if stored_version < 2:
              # earlier checkpoints did not normalize after the stack of layers
              self.layer_norm = None
              self.normalize = False
              state_dict[version_key] = torch.Tensor([1])
          return state_dict
  733. class TransformerDecoder(FairseqIncrementalDecoder):
  734. """
  735. Transformer decoder consisting of *args.decoder_layers* layers. Each layer
  736. is a :class:`TransformerDecoderLayer`.
  737. Args:
  738. args (argparse.Namespace): parsed command-line arguments
  739. dictionary (~fairseq.data.Dictionary): decoding dictionary
  740. embed_tokens (torch.nn.Embedding): output embedding
  741. no_encoder_attn (bool, optional): whether to attend to encoder outputs
  742. (default: False).
  743. """
  744. def __init__(
  745. self,
  746. args,
  747. dictionary,
  748. embed_tokens,
  749. no_encoder_attn=False,
  750. output_projection=None,
  751. ):
  752. self.args = args
  753. super().__init__(dictionary)
  754. self.register_buffer('version', torch.Tensor([3]))
  755. self._future_mask = torch.empty(0)
  756. self.dropout_module = FairseqDropout(
  757. args.dropout, module_name=self.__class__.__name__)
  758. self.decoder_layerdrop = args.decoder_layerdrop
  759. self.share_input_output_embed = args.share_decoder_input_output_embed
  760. input_embed_dim = embed_tokens.embedding_dim
  761. embed_dim = args.decoder_embed_dim
  762. self.embed_dim = embed_dim
  763. self.output_embed_dim = args.decoder_output_dim
  764. self.padding_idx = embed_tokens.padding_idx
  765. self.max_target_positions = args.max_target_positions
  766. self.embed_tokens = embed_tokens
  767. self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(
  768. embed_dim)
  769. if not args.adaptive_input and args.quant_noise_pq > 0:
  770. self.quant_noise = apply_quant_noise_(
  771. nn.Linear(embed_dim, embed_dim, bias=False),
  772. args.quant_noise_pq,
  773. args.quant_noise_pq_block_size,
  774. )
  775. else:
  776. self.quant_noise = None
  777. self.project_in_dim = (
  778. Linear(input_embed_dim, embed_dim, bias=False)
  779. if embed_dim != input_embed_dim else None)
  780. self.embed_positions = (
  781. PositionalEmbedding(
  782. self.max_target_positions,
  783. embed_dim,
  784. self.padding_idx,
  785. learned=args.decoder_learned_pos,
  786. ) if not args.no_token_positional_embeddings else None)
  787. export = getattr(args, 'export', False)
  788. if getattr(args, 'layernorm_embedding', False):
  789. self.layernorm_embedding = LayerNorm(embed_dim, export=export)
  790. else:
  791. self.layernorm_embedding = None
  792. self.cross_self_attention = getattr(args, 'cross_self_attention',
  793. False)
  794. if self.decoder_layerdrop > 0.0:
  795. self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
  796. else:
  797. self.layers = nn.ModuleList([])
  798. self.layers.extend([
  799. self.build_decoder_layer(args, no_encoder_attn)
  800. for _ in range(args.decoder_layers)
  801. ])
  802. self.num_layers = len(self.layers)
  803. if args.decoder_normalize_before and not getattr(
  804. args, 'no_decoder_final_norm', False):
  805. self.layer_norm = LayerNorm(embed_dim, export=export)
  806. else:
  807. self.layer_norm = None
  808. self.project_out_dim = (
  809. Linear(embed_dim, self.output_embed_dim, bias=False)
  810. if embed_dim != self.output_embed_dim
  811. and not args.tie_adaptive_weights else None)
  812. self.adaptive_softmax = None
  813. self.output_projection = output_projection
  814. if self.output_projection is None:
  815. self.build_output_projection(args, dictionary, embed_tokens)
  816. def build_output_projection(self, args, dictionary, embed_tokens):
  817. if args.adaptive_softmax_cutoff is not None:
  818. self.adaptive_softmax = AdaptiveSoftmax(
  819. len(dictionary),
  820. self.output_embed_dim,
  821. utils.eval_str_list(args.adaptive_softmax_cutoff, type=int),
  822. dropout=args.adaptive_softmax_dropout,
  823. adaptive_inputs=embed_tokens
  824. if args.tie_adaptive_weights else None,
  825. factor=args.adaptive_softmax_factor,
  826. tie_proj=args.tie_adaptive_proj,
  827. )
  828. elif self.share_input_output_embed:
  829. self.output_projection = nn.Linear(
  830. self.embed_tokens.weight.shape[1],
  831. self.embed_tokens.weight.shape[0],
  832. bias=False,
  833. )
  834. self.output_projection.weight = self.embed_tokens.weight
  835. else:
  836. self.output_projection = nn.Linear(
  837. self.output_embed_dim, len(dictionary), bias=False)
  838. nn.init.normal_(
  839. self.output_projection.weight,
  840. mean=0,
  841. std=self.output_embed_dim**-0.5)
  842. num_base_layers = getattr(args, 'base_layers', 0)
  843. for i in range(num_base_layers):
  844. self.layers.insert(
  845. ((i + 1) * args.decoder_layers) // (num_base_layers + 1),
  846. BaseLayer(args),
  847. )
  848. def build_decoder_layer(self, args, no_encoder_attn=False):
  849. layer = TransformerDecoderLayer(args, no_encoder_attn)
  850. checkpoint = getattr(args, 'checkpoint_activations', False)
  851. if checkpoint:
  852. offload_to_cpu = getattr(args, 'offload_activations', False)
  853. layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
  854. min_params_to_wrap = (
  855. getattr(args, 'min_params_to_wrap', DEFAULT_MIN_PARAMS_TO_WRAP)
  856. if not checkpoint else 0)
  857. layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
  858. return layer
  859. def forward(
  860. self,
  861. prev_output_tokens,
  862. encoder_out: Optional[Dict[str, List[Tensor]]] = None,
  863. incremental_state: Optional[Dict[str, Dict[str,
  864. Optional[Tensor]]]] = None,
  865. features_only: bool = False,
  866. full_context_alignment: bool = False,
  867. alignment_layer: Optional[int] = None,
  868. alignment_heads: Optional[int] = None,
  869. src_lengths: Optional[Any] = None,
  870. return_all_hiddens: bool = False,
  871. ):
  872. """
  873. Args:
  874. prev_output_tokens (LongTensor): previous decoder outputs of shape
  875. `(batch, tgt_len)`, for teacher forcing
  876. encoder_out (optional): output from the encoder, used for
  877. encoder-side attention, should be of size T x B x C
  878. incremental_state (dict): dictionary used for storing state during
  879. :ref:`Incremental decoding`
  880. features_only (bool, optional): only return features without
  881. applying output layer (default: False).
  882. full_context_alignment (bool, optional): don't apply
  883. auto-regressive mask to self-attention (default: False).
  884. Returns:
  885. tuple:
  886. - the decoder's output of shape `(batch, tgt_len, vocab)`
  887. - a dictionary with any model-specific outputs
  888. """
  889. x, extra = self.extract_features(
  890. prev_output_tokens,
  891. encoder_out=encoder_out,
  892. incremental_state=incremental_state,
  893. full_context_alignment=full_context_alignment,
  894. alignment_layer=alignment_layer,
  895. alignment_heads=alignment_heads,
  896. )
  897. if not features_only:
  898. x = self.output_layer(x)
  899. return x, extra
  900. def extract_features(
  901. self,
  902. prev_output_tokens,
  903. encoder_out: Optional[Dict[str, List[Tensor]]],
  904. incremental_state: Optional[Dict[str, Dict[str,
  905. Optional[Tensor]]]] = None,
  906. full_context_alignment: bool = False,
  907. alignment_layer: Optional[int] = None,
  908. alignment_heads: Optional[int] = None,
  909. ):
  910. return self.extract_features_scriptable(
  911. prev_output_tokens,
  912. encoder_out,
  913. incremental_state,
  914. full_context_alignment,
  915. alignment_layer,
  916. alignment_heads,
  917. )
  918. """
  919. A scriptable subclass of this class has an extract_features method and calls
  920. super().extract_features, but super() is not supported in torchscript. A copy of
  921. this function is made to be used in the subclass instead.
  922. """
  923. def extract_features_scriptable(
  924. self,
  925. prev_output_tokens,
  926. encoder_out: Optional[Dict[str, List[Tensor]]],
  927. incremental_state: Optional[Dict[str, Dict[str,
  928. Optional[Tensor]]]] = None,
  929. full_context_alignment: bool = False,
  930. alignment_layer: Optional[int] = None,
  931. alignment_heads: Optional[int] = None,
  932. ):
  933. """
  934. Similar to *forward* but only return features.
  935. Includes several features from "Jointly Learning to Align and
  936. Translate with Transformer Models" (Garg et al., EMNLP 2019).
  937. Args:
  938. full_context_alignment (bool, optional): don't apply
  939. auto-regressive mask to self-attention (default: False).
  940. alignment_layer (int, optional): return mean alignment over
  941. heads at this layer (default: last layer).
  942. alignment_heads (int, optional): only average alignment over
  943. this many heads (default: all heads).
  944. Returns:
  945. tuple:
  946. - the decoder's features of shape `(batch, tgt_len, embed_dim)`
  947. - a dictionary with any model-specific outputs
  948. """
  949. bs, slen = prev_output_tokens.size()
  950. if alignment_layer is None:
  951. alignment_layer = self.num_layers - 1
  952. enc: Optional[Tensor] = None
  953. padding_mask: Optional[Tensor] = None
  954. if encoder_out is not None and len(encoder_out['encoder_out']) > 0:
  955. enc = encoder_out['encoder_out'][0]
  956. assert (enc.size()[1] == bs
  957. ), f'Expected enc.shape == (t, {bs}, c) got {enc.shape}'
  958. if encoder_out is not None and len(
  959. encoder_out['encoder_padding_mask']) > 0:
  960. padding_mask = encoder_out['encoder_padding_mask'][0]
  961. # embed positions
  962. positions = None
  963. if self.embed_positions is not None:
  964. positions = self.embed_positions(
  965. prev_output_tokens, incremental_state=incremental_state)
  966. if incremental_state is not None:
  967. prev_output_tokens = prev_output_tokens[:, -1:]
  968. if positions is not None:
  969. positions = positions[:, -1:]
  970. # embed tokens and positions
  971. x = self.embed_scale * self.embed_tokens(prev_output_tokens)
  972. if self.quant_noise is not None:
  973. x = self.quant_noise(x)
  974. if self.project_in_dim is not None:
  975. x = self.project_in_dim(x)
  976. if positions is not None:
  977. x += positions
  978. if self.layernorm_embedding is not None:
  979. x = self.layernorm_embedding(x)
  980. x = self.dropout_module(x)
  981. # B x T x C -> T x B x C
  982. x = x.transpose(0, 1)
  983. self_attn_padding_mask: Optional[Tensor] = None
  984. if self.cross_self_attention or prev_output_tokens.eq(
  985. self.padding_idx).any():
  986. self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
  987. # decoder layers
  988. attn: Optional[Tensor] = None
  989. inner_states: List[Optional[Tensor]] = [x]
  990. for idx, layer in enumerate(self.layers):
  991. if incremental_state is None and not full_context_alignment:
  992. self_attn_mask = self.buffered_future_mask(x)
  993. else:
  994. self_attn_mask = None
  995. x, layer_attn, self_attn_hidden = layer(
  996. x,
  997. enc,
  998. padding_mask,
  999. incremental_state,
  1000. self_attn_mask=self_attn_mask,
  1001. self_attn_padding_mask=self_attn_padding_mask,
  1002. need_attn=bool((idx == alignment_layer)),
  1003. need_head_weights=bool((idx == alignment_layer)),
  1004. )
  1005. inner_states.append(x)
  1006. if layer_attn is not None and idx == alignment_layer:
  1007. attn = layer_attn.float().to(x)
  1008. if attn is not None:
  1009. if alignment_heads is not None:
  1010. attn = attn[:alignment_heads]
  1011. attn = attn.mean(dim=0)
  1012. if self.layer_norm is not None:
  1013. x = self.layer_norm(x)
  1014. last_layer = x
  1015. # T x B x C -> B x T x C
  1016. x = x.transpose(0, 1)
  1017. if self.project_out_dim is not None:
  1018. x = self.project_out_dim(x)
  1019. return x, {
  1020. 'attn': [attn],
  1021. 'inner_states': inner_states,
  1022. 'last_layer': last_layer,
  1023. 'self_attn_padding_mask': self_attn_padding_mask
  1024. }
  1025. def output_layer(self, features):
  1026. """Project features to the vocabulary size."""
  1027. if self.adaptive_softmax is None:
  1028. # project back to size of vocabulary
  1029. return self.output_projection(features)
  1030. else:
  1031. return features
  1032. def max_positions(self):
  1033. """Maximum output length supported by the decoder."""
  1034. if self.embed_positions is None:
  1035. return self.max_target_positions
  1036. return min(self.max_target_positions,
  1037. self.embed_positions.max_positions)
  1038. def buffered_future_mask(self, tensor):
  1039. dim = tensor.size(0)
  1040. if (self._future_mask.size(0) == 0
  1041. or (not self._future_mask.device == tensor.device)
  1042. or self._future_mask.size(0) < dim):
  1043. self._future_mask = torch.triu(
  1044. utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1)
  1045. self._future_mask = self._future_mask.to(tensor)
  1046. return self._future_mask[:dim, :dim]
  1047. def upgrade_state_dict_named(self, state_dict, name):
  1048. """Upgrade a (possibly old) state dict for new versions of fairseq."""
  1049. if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
  1050. weights_key = '{}.embed_positions.weights'.format(name)
  1051. if weights_key in state_dict:
  1052. del state_dict[weights_key]
  1053. state_dict['{}.embed_positions._float_tensor'.format(
  1054. name)] = torch.FloatTensor(1)
  1055. if f'{name}.output_projection.weight' not in state_dict:
  1056. if self.share_input_output_embed:
  1057. embed_out_key = f'{name}.embed_tokens.weight'
  1058. else:
  1059. embed_out_key = f'{name}.embed_out'
  1060. if embed_out_key in state_dict:
  1061. state_dict[f'{name}.output_projection.weight'] = state_dict[
  1062. embed_out_key]
  1063. if not self.share_input_output_embed:
  1064. del state_dict[embed_out_key]
  1065. for i in range(self.num_layers):
  1066. # update layer norms
  1067. layer_norm_map = {
  1068. '0': 'self_attn_layer_norm',
  1069. '1': 'encoder_attn_layer_norm',
  1070. '2': 'final_layer_norm',
  1071. }
  1072. for old, new in layer_norm_map.items():
  1073. for m in ('weight', 'bias'):
  1074. k = '{}.layers.{}.layer_norms.{}.{}'.format(
  1075. name, i, old, m)
  1076. if k in state_dict:
  1077. state_dict['{}.layers.{}.{}.{}'.format(
  1078. name, i, new, m)] = state_dict[k]
  1079. del state_dict[k]
  1080. version_key = '{}.version'.format(name)
  1081. if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
  1082. # earlier checkpoints did not normalize after the stack of layers
  1083. self.layer_norm = None
  1084. self.normalize = False
  1085. state_dict[version_key] = torch.Tensor([1])
  1086. return state_dict
  1087. def Embedding(num_embeddings, embedding_dim, padding_idx):
  1088. m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
  1089. nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
  1090. nn.init.constant_(m.weight[padding_idx], 0)
  1091. return m
  1092. def Linear(in_features, out_features, bias=True):
  1093. m = nn.Linear(in_features, out_features, bias)
  1094. nn.init.xavier_uniform_(m.weight)
  1095. if bias:
  1096. nn.init.constant_(m.bias, 0.0)
  1097. return m
  1098. def base_architecture(args):
  1099. args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
  1100. args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
  1101. args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 2048)
  1102. args.encoder_layers = getattr(args, 'encoder_layers', 6)
  1103. args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
  1104. args.encoder_normalize_before = getattr(args, 'encoder_normalize_before',
  1105. False)
  1106. args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False)
  1107. args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
  1108. args.decoder_embed_dim = getattr(args, 'decoder_embed_dim',
  1109. args.encoder_embed_dim)
  1110. args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim',
  1111. args.encoder_ffn_embed_dim)
  1112. args.decoder_layers = getattr(args, 'decoder_layers', 6)
  1113. args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
  1114. args.decoder_normalize_before = getattr(args, 'decoder_normalize_before',
  1115. False)
  1116. args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
  1117. args.attention_dropout = getattr(args, 'attention_dropout', 0.0)
  1118. args.activation_dropout = getattr(args, 'activation_dropout', 0.0)
  1119. args.activation_fn = getattr(args, 'activation_fn', 'relu')
  1120. args.dropout = getattr(args, 'dropout', 0.1)
  1121. args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff',
  1122. None)
  1123. args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout',
  1124. 0)
  1125. args.share_decoder_input_output_embed = getattr(
  1126. args, 'share_decoder_input_output_embed', False)
  1127. args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
  1128. args.no_token_positional_embeddings = getattr(
  1129. args, 'no_token_positional_embeddings', False)
  1130. args.adaptive_input = getattr(args, 'adaptive_input', False)
  1131. args.no_cross_attention = getattr(args, 'no_cross_attention', False)
  1132. args.cross_self_attention = getattr(args, 'cross_self_attention', False)
  1133. args.decoder_output_dim = getattr(args, 'decoder_output_dim',
  1134. args.decoder_embed_dim)
  1135. args.decoder_input_dim = getattr(args, 'decoder_input_dim',
  1136. args.decoder_embed_dim)
  1137. args.no_scale_embedding = getattr(args, 'no_scale_embedding', False)
  1138. args.layernorm_embedding = getattr(args, 'layernorm_embedding', False)
  1139. args.tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', False)
  1140. args.checkpoint_activations = getattr(args, 'checkpoint_activations',
  1141. False)
  1142. args.offload_activations = getattr(args, 'offload_activations', False)
  1143. if args.offload_activations:
  1144. args.checkpoint_activations = True
  1145. args.encoder_layers_to_keep = getattr(args, 'encoder_layers_to_keep', None)
  1146. args.decoder_layers_to_keep = getattr(args, 'decoder_layers_to_keep', None)
  1147. args.encoder_layerdrop = getattr(args, 'encoder_layerdrop', 0)
  1148. args.decoder_layerdrop = getattr(args, 'decoder_layerdrop', 0)
  1149. args.quant_noise_pq = getattr(args, 'quant_noise_pq', 0)
  1150. args.quant_noise_pq_block_size = getattr(args, 'quant_noise_pq_block_size',
  1151. 8)
  1152. args.quant_noise_scalar = getattr(args, 'quant_noise_scalar', 0)
  1153. def transformer_deep(args):
  1154. args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
  1155. args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072)
  1156. args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12)
  1157. args.encoder_layers = getattr(args, 'encoder_layers', 24)
  1158. args.encoder_normalize_before = getattr(args, 'encoder_normalize_before',
  1159. True)
  1160. args.decoder_normalize_before = getattr(args, 'decoder_normalize_before',
  1161. True)
  1162. args.decoder_layers = getattr(args, 'decoder_layers', 3)
  1163. args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 768)
  1164. args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 3072)
  1165. args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 12)
  1166. args.attention_dropout = getattr(args, 'attention_dropout', 0.01)
  1167. args.activation_dropout = getattr(args, 'activation_dropout', 0.01)
  1168. args.dropout = getattr(args, 'dropout', 0.01)
  1169. base_architecture(args)