alphafold.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
# The Uni-Fold implementation is also open-sourced by the authors under the Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.
  3. import torch
  4. import torch.nn as nn
  5. from unicore.utils import tensor_tree_map
  6. from ..data import residue_constants
  7. from .attentions import gen_msa_attn_mask, gen_tri_attn_mask
  8. from .auxillary_heads import AuxiliaryHeads
  9. from .common import residual
  10. from .embedders import (ExtraMSAEmbedder, InputEmbedder, RecyclingEmbedder,
  11. TemplateAngleEmbedder, TemplatePairEmbedder)
  12. from .evoformer import EvoformerStack, ExtraMSAStack
  13. from .featurization import (atom14_to_atom37, build_extra_msa_feat,
  14. build_template_angle_feat,
  15. build_template_pair_feat,
  16. build_template_pair_feat_v2, pseudo_beta_fn)
  17. from .structure_module import StructureModule
  18. from .template import (TemplatePairStack, TemplatePointwiseAttention,
  19. TemplateProjection)
  20. class AlphaFold(nn.Module):
  21. def __init__(self, config):
  22. super(AlphaFold, self).__init__()
  23. self.globals = config.globals
  24. config = config.model
  25. template_config = config.template
  26. extra_msa_config = config.extra_msa
  27. self.input_embedder = InputEmbedder(
  28. **config['input_embedder'],
  29. use_chain_relative=config.is_multimer,
  30. )
  31. self.recycling_embedder = RecyclingEmbedder(
  32. **config['recycling_embedder'], )
  33. if config.template.enabled:
  34. self.template_angle_embedder = TemplateAngleEmbedder(
  35. **template_config['template_angle_embedder'], )
  36. self.template_pair_embedder = TemplatePairEmbedder(
  37. **template_config['template_pair_embedder'], )
  38. self.template_pair_stack = TemplatePairStack(
  39. **template_config['template_pair_stack'], )
  40. else:
  41. self.template_pair_stack = None
  42. self.enable_template_pointwise_attention = template_config[
  43. 'template_pointwise_attention'].enabled
  44. if self.enable_template_pointwise_attention:
  45. self.template_pointwise_att = TemplatePointwiseAttention(
  46. **template_config['template_pointwise_attention'], )
  47. else:
  48. self.template_proj = TemplateProjection(
  49. **template_config['template_pointwise_attention'], )
  50. self.extra_msa_embedder = ExtraMSAEmbedder(
  51. **extra_msa_config['extra_msa_embedder'], )
  52. self.extra_msa_stack = ExtraMSAStack(
  53. **extra_msa_config['extra_msa_stack'], )
  54. self.evoformer = EvoformerStack(**config['evoformer_stack'], )
  55. self.structure_module = StructureModule(**config['structure_module'], )
  56. self.aux_heads = AuxiliaryHeads(config['heads'], )
  57. self.config = config
  58. self.dtype = torch.float
  59. self.inf = self.globals.inf
  60. if self.globals.alphafold_original_mode:
  61. self.alphafold_original_mode()
  62. def __make_input_float__(self):
  63. self.input_embedder = self.input_embedder.float()
  64. self.recycling_embedder = self.recycling_embedder.float()
  65. def half(self):
  66. super().half()
  67. if (not getattr(self, 'inference', False)):
  68. self.__make_input_float__()
  69. self.dtype = torch.half
  70. return self
  71. def bfloat16(self):
  72. super().bfloat16()
  73. if (not getattr(self, 'inference', False)):
  74. self.__make_input_float__()
  75. self.dtype = torch.bfloat16
  76. return self
  77. def alphafold_original_mode(self):
  78. def set_alphafold_original_mode(module):
  79. if hasattr(module, 'apply_alphafold_original_mode'):
  80. module.apply_alphafold_original_mode()
  81. if hasattr(module, 'act'):
  82. module.act = nn.ReLU()
  83. self.apply(set_alphafold_original_mode)
  84. def inference_mode(self):
  85. def set_inference_mode(module):
  86. setattr(module, 'inference', True)
  87. self.apply(set_inference_mode)
  88. def __convert_input_dtype__(self, batch):
  89. for key in batch:
  90. # only convert features with mask
  91. if batch[key].dtype != self.dtype and 'mask' in key:
  92. batch[key] = batch[key].type(self.dtype)
  93. return batch
  94. def embed_templates_pair_core(self, batch, z, pair_mask,
  95. tri_start_attn_mask, tri_end_attn_mask,
  96. templ_dim, multichain_mask_2d):
  97. if self.config.template.template_pair_embedder.v2_feature:
  98. t = build_template_pair_feat_v2(
  99. batch,
  100. inf=self.config.template.inf,
  101. eps=self.config.template.eps,
  102. multichain_mask_2d=multichain_mask_2d,
  103. **self.config.template.distogram,
  104. )
  105. num_template = t[0].shape[-4]
  106. single_templates = [
  107. self.template_pair_embedder([x[..., ti, :, :, :]
  108. for x in t], z)
  109. for ti in range(num_template)
  110. ]
  111. else:
  112. t = build_template_pair_feat(
  113. batch,
  114. inf=self.config.template.inf,
  115. eps=self.config.template.eps,
  116. **self.config.template.distogram,
  117. )
  118. single_templates = [
  119. self.template_pair_embedder(x, z)
  120. for x in torch.unbind(t, dim=templ_dim)
  121. ]
  122. t = self.template_pair_stack(
  123. single_templates,
  124. pair_mask,
  125. tri_start_attn_mask=tri_start_attn_mask,
  126. tri_end_attn_mask=tri_end_attn_mask,
  127. templ_dim=templ_dim,
  128. chunk_size=self.globals.chunk_size,
  129. block_size=self.globals.block_size,
  130. return_mean=not self.enable_template_pointwise_attention,
  131. )
  132. return t
  133. def embed_templates_pair(self, batch, z, pair_mask, tri_start_attn_mask,
  134. tri_end_attn_mask, templ_dim):
  135. if self.config.template.template_pair_embedder.v2_feature and 'asym_id' in batch:
  136. multichain_mask_2d = (
  137. batch['asym_id'][..., :, None] == batch['asym_id'][...,
  138. None, :])
  139. multichain_mask_2d = multichain_mask_2d.unsqueeze(0)
  140. else:
  141. multichain_mask_2d = None
  142. if self.training or self.enable_template_pointwise_attention:
  143. t = self.embed_templates_pair_core(batch, z, pair_mask,
  144. tri_start_attn_mask,
  145. tri_end_attn_mask, templ_dim,
  146. multichain_mask_2d)
  147. if self.enable_template_pointwise_attention:
  148. t = self.template_pointwise_att(
  149. t,
  150. z,
  151. template_mask=batch['template_mask'],
  152. chunk_size=self.globals.chunk_size,
  153. )
  154. t_mask = torch.sum(
  155. batch['template_mask'], dim=-1, keepdims=True) > 0
  156. t_mask = t_mask[..., None, None].type(t.dtype)
  157. t *= t_mask
  158. else:
  159. t = self.template_proj(t, z)
  160. else:
  161. template_aatype_shape = batch['template_aatype'].shape
  162. # template_aatype is either [n_template, n_res] or [1, n_template_, n_res]
  163. batch_templ_dim = 1 if len(template_aatype_shape) == 3 else 0
  164. n_templ = batch['template_aatype'].shape[batch_templ_dim]
  165. if n_templ <= 0:
  166. t = None
  167. else:
  168. template_batch = {
  169. k: v
  170. for k, v in batch.items() if k.startswith('template_')
  171. }
  172. def embed_one_template(i):
  173. def slice_template_tensor(t):
  174. s = [slice(None) for _ in t.shape]
  175. s[batch_templ_dim] = slice(i, i + 1)
  176. return t[s]
  177. template_feats = tensor_tree_map(
  178. slice_template_tensor,
  179. template_batch,
  180. )
  181. t = self.embed_templates_pair_core(
  182. template_feats, z, pair_mask, tri_start_attn_mask,
  183. tri_end_attn_mask, templ_dim, multichain_mask_2d)
  184. return t
  185. t = embed_one_template(0)
  186. # iterate templates one by one
  187. for i in range(1, n_templ):
  188. t += embed_one_template(i)
  189. t /= n_templ
  190. t = self.template_proj(t, z)
  191. return t
  192. def embed_templates_angle(self, batch):
  193. template_angle_feat, template_angle_mask = build_template_angle_feat(
  194. batch,
  195. v2_feature=self.config.template.template_pair_embedder.v2_feature)
  196. t = self.template_angle_embedder(template_angle_feat)
  197. return t, template_angle_mask
  198. def iteration_evoformer(self, feats, m_1_prev, z_prev, x_prev):
  199. batch_dims = feats['target_feat'].shape[:-2]
  200. n = feats['target_feat'].shape[-2]
  201. seq_mask = feats['seq_mask']
  202. pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
  203. msa_mask = feats['msa_mask']
  204. m, z = self.input_embedder(
  205. feats['target_feat'],
  206. feats['msa_feat'],
  207. )
  208. if m_1_prev is None:
  209. m_1_prev = m.new_zeros(
  210. (*batch_dims, n, self.config.input_embedder.d_msa),
  211. requires_grad=False,
  212. )
  213. if z_prev is None:
  214. z_prev = z.new_zeros(
  215. (*batch_dims, n, n, self.config.input_embedder.d_pair),
  216. requires_grad=False,
  217. )
  218. if x_prev is None:
  219. x_prev = z.new_zeros(
  220. (*batch_dims, n, residue_constants.atom_type_num, 3),
  221. requires_grad=False,
  222. )
  223. x_prev = pseudo_beta_fn(feats['aatype'], x_prev, None)
  224. z += self.recycling_embedder.recyle_pos(x_prev)
  225. m_1_prev_emb, z_prev_emb = self.recycling_embedder(
  226. m_1_prev,
  227. z_prev,
  228. )
  229. m[..., 0, :, :] += m_1_prev_emb
  230. z += z_prev_emb
  231. z += self.input_embedder.relpos_emb(
  232. feats['residue_index'].long(),
  233. feats.get('sym_id', None),
  234. feats.get('asym_id', None),
  235. feats.get('entity_id', None),
  236. feats.get('num_sym', None),
  237. )
  238. m = m.type(self.dtype)
  239. z = z.type(self.dtype)
  240. tri_start_attn_mask, tri_end_attn_mask = gen_tri_attn_mask(
  241. pair_mask, self.inf)
  242. if self.config.template.enabled:
  243. template_mask = feats['template_mask']
  244. if torch.any(template_mask):
  245. z = residual(
  246. z,
  247. self.embed_templates_pair(
  248. feats,
  249. z,
  250. pair_mask,
  251. tri_start_attn_mask,
  252. tri_end_attn_mask,
  253. templ_dim=-4,
  254. ),
  255. self.training,
  256. )
  257. if self.config.extra_msa.enabled:
  258. a = self.extra_msa_embedder(build_extra_msa_feat(feats))
  259. extra_msa_row_mask = gen_msa_attn_mask(
  260. feats['extra_msa_mask'],
  261. inf=self.inf,
  262. gen_col_mask=False,
  263. )
  264. z = self.extra_msa_stack(
  265. a,
  266. z,
  267. msa_mask=feats['extra_msa_mask'],
  268. chunk_size=self.globals.chunk_size,
  269. block_size=self.globals.block_size,
  270. pair_mask=pair_mask,
  271. msa_row_attn_mask=extra_msa_row_mask,
  272. msa_col_attn_mask=None,
  273. tri_start_attn_mask=tri_start_attn_mask,
  274. tri_end_attn_mask=tri_end_attn_mask,
  275. )
  276. if self.config.template.embed_angles:
  277. template_1d_feat, template_1d_mask = self.embed_templates_angle(
  278. feats)
  279. m = torch.cat([m, template_1d_feat], dim=-3)
  280. msa_mask = torch.cat([feats['msa_mask'], template_1d_mask], dim=-2)
  281. msa_row_mask, msa_col_mask = gen_msa_attn_mask(
  282. msa_mask,
  283. inf=self.inf,
  284. )
  285. m, z, s = self.evoformer(
  286. m,
  287. z,
  288. msa_mask=msa_mask,
  289. pair_mask=pair_mask,
  290. msa_row_attn_mask=msa_row_mask,
  291. msa_col_attn_mask=msa_col_mask,
  292. tri_start_attn_mask=tri_start_attn_mask,
  293. tri_end_attn_mask=tri_end_attn_mask,
  294. chunk_size=self.globals.chunk_size,
  295. block_size=self.globals.block_size,
  296. )
  297. return m, z, s, msa_mask, m_1_prev_emb, z_prev_emb
  298. def iteration_evoformer_structure_module(self,
  299. batch,
  300. m_1_prev,
  301. z_prev,
  302. x_prev,
  303. cycle_no,
  304. num_recycling,
  305. num_ensembles=1):
  306. z, s = 0, 0
  307. n_seq = batch['msa_feat'].shape[-3]
  308. assert num_ensembles >= 1
  309. for ensemble_no in range(num_ensembles):
  310. idx = cycle_no * num_ensembles + ensemble_no
  311. # fetch_cur_batch = lambda t: t[min(t.shape[0] - 1, idx), ...]
  312. def fetch_cur_batch(t):
  313. return t[min(t.shape[0] - 1, idx), ...]
  314. feats = tensor_tree_map(fetch_cur_batch, batch)
  315. m, z0, s0, msa_mask, m_1_prev_emb, z_prev_emb = self.iteration_evoformer(
  316. feats, m_1_prev, z_prev, x_prev)
  317. z += z0
  318. s += s0
  319. del z0, s0
  320. if num_ensembles > 1:
  321. z /= float(num_ensembles)
  322. s /= float(num_ensembles)
  323. outputs = {}
  324. outputs['msa'] = m[..., :n_seq, :, :]
  325. outputs['pair'] = z
  326. outputs['single'] = s
  327. # norm loss
  328. if (not getattr(self, 'inference',
  329. False)) and num_recycling == (cycle_no + 1):
  330. delta_msa = m
  331. delta_msa[...,
  332. 0, :, :] = delta_msa[...,
  333. 0, :, :] - m_1_prev_emb.detach()
  334. delta_pair = z - z_prev_emb.detach()
  335. outputs['delta_msa'] = delta_msa
  336. outputs['delta_pair'] = delta_pair
  337. outputs['msa_norm_mask'] = msa_mask
  338. outputs['sm'] = self.structure_module(
  339. s,
  340. z,
  341. feats['aatype'],
  342. mask=feats['seq_mask'],
  343. )
  344. outputs['final_atom_positions'] = atom14_to_atom37(
  345. outputs['sm']['positions'], feats)
  346. outputs['final_atom_mask'] = feats['atom37_atom_exists']
  347. outputs['pred_frame_tensor'] = outputs['sm']['frames'][-1]
  348. # use float32 for numerical stability
  349. if (not getattr(self, 'inference', False)):
  350. m_1_prev = m[..., 0, :, :].float()
  351. z_prev = z.float()
  352. x_prev = outputs['final_atom_positions'].float()
  353. else:
  354. m_1_prev = m[..., 0, :, :]
  355. z_prev = z
  356. x_prev = outputs['final_atom_positions']
  357. return outputs, m_1_prev, z_prev, x_prev
  358. def forward(self, batch):
  359. m_1_prev = batch.get('m_1_prev', None)
  360. z_prev = batch.get('z_prev', None)
  361. x_prev = batch.get('x_prev', None)
  362. is_grad_enabled = torch.is_grad_enabled()
  363. num_iters = int(batch['num_recycling_iters']) + 1
  364. num_ensembles = int(batch['msa_mask'].shape[0]) // num_iters
  365. if self.training:
  366. # don't use ensemble during training
  367. assert num_ensembles == 1
  368. # convert dtypes in batch
  369. batch = self.__convert_input_dtype__(batch)
  370. for cycle_no in range(num_iters):
  371. is_final_iter = cycle_no == (num_iters - 1)
  372. with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
  373. (
  374. outputs,
  375. m_1_prev,
  376. z_prev,
  377. x_prev,
  378. ) = self.iteration_evoformer_structure_module(
  379. batch,
  380. m_1_prev,
  381. z_prev,
  382. x_prev,
  383. cycle_no=cycle_no,
  384. num_recycling=num_iters,
  385. num_ensembles=num_ensembles,
  386. )
  387. if not is_final_iter:
  388. del outputs
  389. if 'asym_id' in batch:
  390. outputs['asym_id'] = batch['asym_id'][0, ...]
  391. outputs.update(self.aux_heads(outputs))
  392. return outputs