modeling_esmfold.py 84 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309
  1. # coding=utf-8
  2. # Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import math
  16. import sys
  17. from collections.abc import Sequence
  18. from dataclasses import dataclass
  19. from functools import partial
  20. from typing import Callable, Optional, Union
  21. import numpy as np
  22. import torch
  23. import torch.nn as nn
  24. from torch.nn import LayerNorm
  25. from ...integrations.deepspeed import is_deepspeed_available
  26. from ...modeling_outputs import ModelOutput
  27. from ...utils import (
  28. ContextManagers,
  29. auto_docstring,
  30. is_scipy_available,
  31. logging,
  32. )
  33. from .modeling_esm import EsmModel, EsmPreTrainedModel
  34. from .openfold_utils import (
  35. OFProtein,
  36. Rigid,
  37. Rotation,
  38. atom14_to_atom37,
  39. chunk_layer,
  40. compute_predicted_aligned_error,
  41. compute_tm,
  42. frames_and_literature_positions_to_atom14_pos,
  43. make_atom14_masks,
  44. residue_constants,
  45. to_pdb,
  46. torsion_angles_to_frames,
  47. )
# Module-level logger from the library's `logging` utilities (imported from `...utils` above).
logger = logging.get_logger(__name__)
  49. @dataclass
  50. @auto_docstring(
  51. custom_intro="""
  52. Output type of [`EsmForProteinFoldingOutput`].
  53. """
  54. )
  55. class EsmForProteinFoldingOutput(ModelOutput):
  56. r"""
  57. frames (`torch.FloatTensor`):
  58. Output frames.
  59. sidechain_frames (`torch.FloatTensor`):
  60. Output sidechain frames.
  61. unnormalized_angles (`torch.FloatTensor`):
  62. Predicted unnormalized backbone and side chain torsion angles.
  63. angles (`torch.FloatTensor`):
  64. Predicted backbone and side chain torsion angles.
  65. positions (`torch.FloatTensor`):
  66. Predicted positions of the backbone and side chain atoms.
  67. states (`torch.FloatTensor`):
  68. Hidden states from the protein folding trunk.
  69. s_s (`torch.FloatTensor`):
  70. Per-residue embeddings derived by concatenating the hidden states of each layer of the ESM-2 LM stem.
  71. s_z (`torch.FloatTensor`):
  72. Pairwise residue embeddings.
  73. distogram_logits (`torch.FloatTensor`):
  74. Input logits to the distogram used to compute residue distances.
  75. lm_logits (`torch.FloatTensor`):
  76. Logits output by the ESM-2 protein language model stem.
  77. aatype (`torch.FloatTensor`):
  78. Input amino acids (AlphaFold2 indices).
  79. atom14_atom_exists (`torch.FloatTensor`):
  80. Whether each atom exists in the atom14 representation.
  81. residx_atom14_to_atom37 (`torch.FloatTensor`):
  82. Mapping between atoms in the atom14 and atom37 representations.
  83. residx_atom37_to_atom14 (`torch.FloatTensor`):
  84. Mapping between atoms in the atom37 and atom14 representations.
  85. atom37_atom_exists (`torch.FloatTensor`):
  86. Whether each atom exists in the atom37 representation.
  87. residue_index (`torch.FloatTensor`):
  88. The index of each residue in the protein chain. Unless internal padding tokens are used, this will just be
  89. a sequence of integers from 0 to `sequence_length`.
  90. lddt_head (`torch.FloatTensor`):
  91. Raw outputs from the lddt head used to compute plddt.
  92. plddt (`torch.FloatTensor`):
  93. Per-residue confidence scores. Regions of low confidence may indicate areas where the model's prediction is
  94. uncertain, or where the protein structure is disordered.
  95. ptm_logits (`torch.FloatTensor`):
  96. Raw logits used for computing ptm.
  97. ptm (`torch.FloatTensor`):
  98. TM-score output representing the model's high-level confidence in the overall structure.
  99. aligned_confidence_probs (`torch.FloatTensor`):
  100. Per-residue confidence scores for the aligned structure.
  101. predicted_aligned_error (`torch.FloatTensor`):
  102. Predicted error between the model's prediction and the ground truth.
  103. max_predicted_aligned_error (`torch.FloatTensor`):
  104. Per-sample maximum predicted error.
  105. """
  106. frames: Optional[torch.FloatTensor] = None
  107. sidechain_frames: Optional[torch.FloatTensor] = None
  108. unnormalized_angles: Optional[torch.FloatTensor] = None
  109. angles: Optional[torch.FloatTensor] = None
  110. positions: Optional[torch.FloatTensor] = None
  111. states: Optional[torch.FloatTensor] = None
  112. s_s: Optional[torch.FloatTensor] = None
  113. s_z: Optional[torch.FloatTensor] = None
  114. distogram_logits: Optional[torch.FloatTensor] = None
  115. lm_logits: Optional[torch.FloatTensor] = None
  116. aatype: Optional[torch.FloatTensor] = None
  117. atom14_atom_exists: Optional[torch.FloatTensor] = None
  118. residx_atom14_to_atom37: Optional[torch.FloatTensor] = None
  119. residx_atom37_to_atom14: Optional[torch.FloatTensor] = None
  120. atom37_atom_exists: Optional[torch.FloatTensor] = None
  121. residue_index: Optional[torch.FloatTensor] = None
  122. lddt_head: Optional[torch.FloatTensor] = None
  123. plddt: Optional[torch.FloatTensor] = None
  124. ptm_logits: Optional[torch.FloatTensor] = None
  125. ptm: Optional[torch.FloatTensor] = None
  126. aligned_confidence_probs: Optional[torch.FloatTensor] = None
  127. predicted_aligned_error: Optional[torch.FloatTensor] = None
  128. max_predicted_aligned_error: Optional[torch.FloatTensor] = None
  129. def is_fp16_enabled(device_type):
  130. # Autocast world
  131. autocast_dtype = (
  132. torch.get_autocast_dtype(device_type)
  133. if hasattr(torch, "get_autocast_dtype")
  134. else torch.get_autocast_gpu_dtype()
  135. )
  136. fp16_enabled = autocast_dtype == torch.float16
  137. fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
  138. return fp16_enabled
  139. def is_deepspeed_initialized():
  140. if is_deepspeed_available():
  141. return False
  142. else:
  143. try:
  144. import deepspeed
  145. # This is not available in all DeepSpeed versions.
  146. return deepspeed.utils.is_initialized()
  147. except Exception:
  148. return False
  149. def collate_dense_tensors(samples: list[torch.Tensor], pad_v: float = 0) -> torch.Tensor:
  150. """
  151. Takes a list of tensors with the following dimensions:
  152. [(d_11, ..., d_1K),
  153. (d_21, ..., d_2K), ..., (d_N1, ..., d_NK)]
  154. and stack + pads them into a single tensor of:
  155. (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK})
  156. """
  157. if len(samples) == 0:
  158. return torch.Tensor()
  159. if len({x.dim() for x in samples}) != 1:
  160. raise RuntimeError(f"Samples has varying dimensions: {[x.dim() for x in samples]}")
  161. (device,) = tuple({x.device for x in samples}) # assumes all on same device
  162. max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
  163. result = torch.empty(len(samples), *max_shape, dtype=samples[0].dtype, device=device)
  164. result.fill_(pad_v)
  165. for i in range(len(samples)):
  166. result_i = result[i]
  167. t = samples[i]
  168. result_i[tuple(slice(0, k) for k in t.shape)] = t
  169. return result
  170. def flatten_final_dims(t: torch.Tensor, no_dims: int):
  171. return t.reshape(t.shape[:-no_dims] + (-1,))
  172. def permute_final_dims(tensor: torch.Tensor, inds: list[int]):
  173. zero_index = -1 * len(inds)
  174. first_inds = list(range(len(tensor.shape[:zero_index])))
  175. return tensor.permute(first_inds + [zero_index + i for i in inds])
  176. def dict_multimap(fn, dicts):
  177. first = dicts[0]
  178. new_dict = {}
  179. for k, v in first.items():
  180. all_v = [d[k] for d in dicts]
  181. if isinstance(v, dict):
  182. new_dict[k] = dict_multimap(fn, all_v)
  183. else:
  184. new_dict[k] = fn(all_v)
  185. return new_dict
  186. def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
  187. shape = weights.shape
  188. scale = scale / max(1, shape[1])
  189. if not is_scipy_available():
  190. logger.warning(
  191. "This init requires scipy, but scipy was not found, default to an approximation that might not be"
  192. " equivalent."
  193. )
  194. std = math.sqrt(scale)
  195. torch.nn.init.normal_(weights, std=std).clamp(min=0.0, max=2.0 * std)
  196. else:
  197. from scipy.stats import truncnorm
  198. std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1)
  199. samples = truncnorm.rvs(a=-2, b=2, loc=0, scale=std, size=weights.numel())
  200. samples = np.reshape(samples, shape)
  201. weights.copy_(torch.tensor(samples, device=weights.device))
  202. def ipa_point_weights_init_(weights):
  203. with torch.no_grad():
  204. softplus_inverse_1 = 0.541324854612918
  205. weights.fill_(softplus_inverse_1)
  206. class EsmFoldLinear(nn.Linear):
  207. """
  208. A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear.
  209. Implements the initializers in 1.11.4, plus some additional ones found in the code.
  210. """
  211. def __init__(
  212. self,
  213. in_dim: int,
  214. out_dim: int,
  215. bias: bool = True,
  216. init: str = "default",
  217. init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
  218. ):
  219. """
  220. Args:
  221. in_dim:
  222. The final dimension of inputs to the layer
  223. out_dim:
  224. The final dimension of layer outputs
  225. bias:
  226. Whether to learn an additive bias. True by default
  227. init:
  228. The initializer to use. Choose from:
  229. "default": LeCun fan-in truncated normal initialization "relu": He initialization w/ truncated normal
  230. distribution "glorot": Fan-average Glorot uniform initialization "gating": Weights=0, Bias=1 "normal":
  231. Normal initialization with std=1/sqrt(fan_in) "final": Weights=0, Bias=0
  232. Overridden by init_fn if the latter is not None.
  233. init_fn:
  234. A custom initializer taking weight and bias as inputs. Overrides init if not None.
  235. """
  236. super().__init__(in_dim, out_dim, bias=bias)
  237. if bias:
  238. with torch.no_grad():
  239. self.bias.fill_(0)
  240. self.init = init
  241. self.init_fn = init_fn
  242. if init not in ["default", "relu", "glorot", "gating", "normal", "final"]:
  243. raise ValueError("Invalid init string.")
  244. class EsmFoldLayerNorm(nn.Module):
  245. def __init__(self, c_in, eps=1e-5):
  246. super().__init__()
  247. self.c_in = (c_in,)
  248. self.eps = eps
  249. self.weight = nn.Parameter(torch.ones(c_in))
  250. self.bias = nn.Parameter(torch.zeros(c_in))
  251. def forward(self, x):
  252. d = x.dtype
  253. if d is torch.bfloat16 and not is_deepspeed_initialized():
  254. with torch.autocast(device_type="cuda", enabled=False):
  255. out = nn.functional.layer_norm(x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps)
  256. else:
  257. out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps)
  258. return out
  259. @torch.jit.ignore
  260. def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
  261. """
  262. Softmax, but without automatic casting to fp32 when the input is of type bfloat16
  263. """
  264. d = t.dtype
  265. if d is torch.bfloat16 and not is_deepspeed_initialized():
  266. with torch.autocast(device_type="cuda", enabled=False):
  267. s = torch.nn.functional.softmax(t, dim=dim)
  268. else:
  269. s = torch.nn.functional.softmax(t, dim=dim)
  270. return s
  271. class EsmFoldAttention(nn.Module):
  272. """
  273. Standard multi-head attention using AlphaFold's default layer initialization. Allows multiple bias vectors.
  274. """
  275. def __init__(
  276. self,
  277. c_q: int,
  278. c_k: int,
  279. c_v: int,
  280. c_hidden: int,
  281. no_heads: int,
  282. gating: bool = True,
  283. ):
  284. """
  285. Args:
  286. c_q:
  287. Input dimension of query data
  288. c_k:
  289. Input dimension of key data
  290. c_v:
  291. Input dimension of value data
  292. c_hidden:
  293. Per-head hidden dimension
  294. no_heads:
  295. Number of attention heads
  296. gating:
  297. Whether the output should be gated using query data
  298. """
  299. super().__init__()
  300. self.c_q = c_q
  301. self.c_k = c_k
  302. self.c_v = c_v
  303. self.c_hidden = c_hidden
  304. self.no_heads = no_heads
  305. self.gating = gating
  306. # DISCREPANCY: c_hidden is not the per-head channel dimension, as
  307. # stated in the supplement, but the overall channel dimension.
  308. self.linear_q = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot")
  309. self.linear_k = EsmFoldLinear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot")
  310. self.linear_v = EsmFoldLinear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot")
  311. self.linear_o = EsmFoldLinear(self.c_hidden * self.no_heads, self.c_q, init="final")
  312. self.linear_g = None
  313. if self.gating:
  314. self.linear_g = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, init="gating")
  315. self.sigmoid = nn.Sigmoid()
  316. def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  317. # [*, Q/K/V, H * C_hidden]
  318. q = self.linear_q(q_x)
  319. k = self.linear_k(kv_x)
  320. v = self.linear_v(kv_x)
  321. # [*, Q/K, H, C_hidden]
  322. q = q.view(q.shape[:-1] + (self.no_heads, -1))
  323. k = k.view(k.shape[:-1] + (self.no_heads, -1))
  324. v = v.view(v.shape[:-1] + (self.no_heads, -1))
  325. # [*, H, Q/K, C_hidden]
  326. q = q.transpose(-2, -3)
  327. k = k.transpose(-2, -3)
  328. v = v.transpose(-2, -3)
  329. q /= math.sqrt(self.c_hidden)
  330. return q, k, v
  331. def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
  332. if self.linear_g is not None:
  333. g = self.sigmoid(self.linear_g(q_x))
  334. # [*, Q, H, C_hidden]
  335. g = g.view(g.shape[:-1] + (self.no_heads, -1))
  336. o = o * g
  337. # [*, Q, H * C_hidden]
  338. o = flatten_final_dims(o, 2)
  339. # [*, Q, C_q]
  340. o = self.linear_o(o)
  341. return o
  342. def forward(
  343. self,
  344. q_x: torch.Tensor,
  345. kv_x: torch.Tensor,
  346. biases: Optional[list[torch.Tensor]] = None,
  347. use_memory_efficient_kernel: bool = False,
  348. use_lma: bool = False,
  349. lma_q_chunk_size: int = 1024,
  350. lma_kv_chunk_size: int = 4096,
  351. use_flash: bool = False,
  352. flash_mask: Optional[torch.Tensor] = None,
  353. ) -> torch.Tensor:
  354. """
  355. Args:
  356. q_x:
  357. [*, Q, C_q] query data
  358. kv_x:
  359. [*, K, C_k] key data
  360. biases:
  361. List of biases that broadcast to [*, H, Q, K]
  362. use_memory_efficient_kernel:
  363. Whether to use a custom memory-efficient attention kernel. This should be the default choice for most.
  364. If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead
  365. use_lma:
  366. Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a
  367. stock PyTorch implementation is used instead
  368. lma_q_chunk_size:
  369. Query chunk size (for LMA)
  370. lma_kv_chunk_size:
  371. Key/Value chunk size (for LMA)
  372. Returns
  373. [*, Q, C_q] attention update
  374. """
  375. if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
  376. raise ValueError("If use_lma is specified, lma_q_chunk_size and lma_kv_chunk_size must be provided")
  377. if use_flash and biases is not None:
  378. raise ValueError("use_flash is incompatible with the bias option. For masking, use flash_mask instead")
  379. attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
  380. if sum(attn_options) > 1:
  381. raise ValueError("Choose at most one alternative attention algorithm")
  382. if biases is None:
  383. biases = []
  384. # [*, H, Q/K, C_hidden]
  385. query, key, value = self._prep_qkv(q_x, kv_x)
  386. key = permute_final_dims(key, (1, 0))
  387. # [*, H, Q, K]
  388. output = torch.matmul(query, key)
  389. for b in biases:
  390. output += b
  391. output = softmax_no_cast(output, -1)
  392. # [*, H, Q, C_hidden]
  393. output = torch.matmul(output, value)
  394. output = output.transpose(-2, -3)
  395. output = self._wrap_up(output, q_x)
  396. return output
  397. class EsmFoldTriangleAttention(nn.Module):
  398. def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9):
  399. """
  400. Args:
  401. c_in:
  402. Input channel dimension
  403. c_hidden:
  404. Overall hidden channel dimension (not per-head)
  405. no_heads:
  406. Number of attention heads
  407. """
  408. super().__init__()
  409. self.c_in = c_in
  410. self.c_hidden = c_hidden
  411. self.no_heads = no_heads
  412. self.starting = starting
  413. self.inf = inf
  414. self.layer_norm = LayerNorm(self.c_in)
  415. self.linear = EsmFoldLinear(c_in, self.no_heads, bias=False, init="normal")
  416. self.mha = EsmFoldAttention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)
  417. @torch.jit.ignore
  418. def _chunk(
  419. self,
  420. x: torch.Tensor,
  421. biases: list[torch.Tensor],
  422. chunk_size: int,
  423. use_memory_efficient_kernel: bool = False,
  424. use_lma: bool = False,
  425. inplace_safe: bool = False,
  426. ) -> torch.Tensor:
  427. "triangle! triangle!"
  428. mha_inputs = {
  429. "q_x": x,
  430. "kv_x": x,
  431. "biases": biases,
  432. }
  433. return chunk_layer(
  434. partial(self.mha, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma),
  435. mha_inputs,
  436. chunk_size=chunk_size,
  437. no_batch_dims=len(x.shape[:-2]),
  438. _out=x if inplace_safe else None,
  439. )
  440. def forward(
  441. self,
  442. x: torch.Tensor,
  443. mask: Optional[torch.Tensor] = None,
  444. chunk_size: Optional[int] = None,
  445. use_memory_efficient_kernel: bool = False,
  446. use_lma: bool = False,
  447. inplace_safe: bool = False,
  448. ) -> torch.Tensor:
  449. """
  450. Args:
  451. x:
  452. [*, I, J, C_in] input tensor (e.g. the pair representation)
  453. Returns:
  454. [*, I, J, C_in] output tensor
  455. """
  456. if mask is None:
  457. # [*, I, J]
  458. mask = x.new_ones(
  459. x.shape[:-1],
  460. )
  461. if not self.starting:
  462. x = x.transpose(-2, -3)
  463. mask = mask.transpose(-1, -2)
  464. # [*, I, J, C_in]
  465. x = self.layer_norm(x)
  466. # [*, I, 1, 1, J]
  467. mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
  468. # [*, H, I, J]
  469. triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
  470. # [*, 1, H, I, J]
  471. triangle_bias = triangle_bias.unsqueeze(-4)
  472. biases = [mask_bias, triangle_bias]
  473. if chunk_size is not None:
  474. x = self._chunk(
  475. x,
  476. biases,
  477. chunk_size,
  478. use_memory_efficient_kernel=use_memory_efficient_kernel,
  479. use_lma=use_lma,
  480. inplace_safe=inplace_safe,
  481. )
  482. else:
  483. x = self.mha(
  484. q_x=x, kv_x=x, biases=biases, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma
  485. )
  486. if not self.starting:
  487. x = x.transpose(-2, -3)
  488. return x
  489. class EsmFoldTriangleMultiplicativeUpdate(nn.Module):
  490. """
  491. Implements Algorithms 11 and 12.
  492. """
  493. def __init__(self, config, _outgoing=True):
  494. super().__init__()
  495. c_hidden = config.pairwise_state_dim
  496. self._outgoing = _outgoing
  497. self.linear_a_p = EsmFoldLinear(c_hidden, c_hidden)
  498. self.linear_a_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
  499. self.linear_b_p = EsmFoldLinear(c_hidden, c_hidden)
  500. self.linear_b_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
  501. self.linear_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
  502. self.linear_z = EsmFoldLinear(c_hidden, c_hidden, init="final")
  503. self.layer_norm_in = LayerNorm(c_hidden)
  504. self.layer_norm_out = LayerNorm(c_hidden)
  505. self.sigmoid = nn.Sigmoid()
  506. def _combine_projections(
  507. self, a: torch.Tensor, b: torch.Tensor, _inplace_chunk_size: Optional[int] = None
  508. ) -> torch.Tensor:
  509. if self._outgoing:
  510. a = permute_final_dims(a, (2, 0, 1))
  511. b = permute_final_dims(b, (2, 1, 0))
  512. else:
  513. a = permute_final_dims(a, (2, 1, 0))
  514. b = permute_final_dims(b, (2, 0, 1))
  515. if _inplace_chunk_size is not None:
  516. # To be replaced by torch vmap
  517. for i in range(0, a.shape[-3], _inplace_chunk_size):
  518. a_chunk = a[..., i : i + _inplace_chunk_size, :, :]
  519. b_chunk = b[..., i : i + _inplace_chunk_size, :, :]
  520. a[..., i : i + _inplace_chunk_size, :, :] = torch.matmul(
  521. a_chunk,
  522. b_chunk,
  523. )
  524. p = a
  525. else:
  526. p = torch.matmul(a, b)
  527. return permute_final_dims(p, (1, 2, 0))
  528. def _inference_forward(
  529. self,
  530. z: torch.Tensor,
  531. mask: Optional[torch.Tensor] = None,
  532. inplace_chunk_size: Optional[int] = None,
  533. with_add: bool = True,
  534. ):
  535. """
  536. Args:
  537. z:
  538. A [*, N, N, C_z] pair representation
  539. mask:
  540. A [*, N, N] pair mask
  541. inplace_chunk_size:
  542. Size of chunks used in the main computation. Increase to trade memory for speed.
  543. with_add:
  544. If True, z is overwritten with (z + update). Otherwise, it is overwritten with (update).
  545. Returns:
  546. A reference to the overwritten z
  547. More memory-efficient, inference-only version of the forward function. Uses in-place operations, fusion of the
  548. addition that happens after this module in the Evoformer, a smidge of recomputation, and a cache of overwritten
  549. values to lower peak memory consumption of this module from 5x the size of the input tensor z to 2.5x its size.
  550. Useful for inference on extremely long sequences.
  551. It works as follows. We will make reference to variables used in the default forward implementation below.
  552. Naively, triangle multiplication attention requires the manifestation of 5 tensors the size of z: 1) z, the
  553. "square" input tensor, 2) a, the first projection of z, 3) b, the second projection of b, 4) g, a z-sized mask,
  554. and 5) a z-sized tensor for intermediate computations. For large N, this is prohibitively expensive; for
  555. N=4000, for example, z is more than 8GB alone. To avoid this problem, we compute b, g, and all intermediate
  556. tensors in small chunks, noting that the chunks required to compute a chunk of the output depend only on the
  557. tensor a and corresponding vertical and horizontal chunks of z. This suggests an algorithm that loops over
  558. pairs of chunks of z: hereafter "columns" and "rows" of z, even though each "column" and "row" in fact contains
  559. inplace_chunk_size contiguous true columns and rows of z. Writing output chunks to a new tensor would bring
  560. total memory consumption down to 3x the size of z. However, more memory can be saved by writing output chunks
  561. directly to z in-place. WLOG, we choose to write output chunks vertically, overwriting the ith "column" of z at
  562. the end of the ith iteration of the main loop. Despite this overwriting, the ith column is always one column
  563. ahead of previously overwritten columns and can be recovered directly from z. After the first iteration,
  564. however, the ith row of z is always at least partially overwritten. For this reason, we introduce the z-cache,
  565. a tensor one-half the size of z. The z-cache initially contains the left half (2nd and 3rd quadrants) of z. For
  566. 0 < i < N/2, the missing left part of the ith row of z is recovered from this cache at the beginning of the ith
  567. iteration. Once i exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th quadrants of z instead.
  568. Though the 3rd quadrant of the original z is entirely overwritten at this point, it can be recovered from the
  569. z-cache itself. Thereafter, the ith row of z can be recovered in its entirety from the reoriented z-cache.
  570. After the final iteration, z has been completely overwritten and contains the triangular multiplicative update.
  571. If with_add is True, it instead contains the sum of z and the triangular multiplicative update. In either case,
  572. peak memory consumption is just 2.5x the size of z, disregarding memory used for chunks and other small
  573. variables.
  574. """
  575. if mask is None:
  576. mask = z.new_ones(z.shape[:-1])
  577. mask = mask.unsqueeze(-1)
  578. def compute_projection_helper(pair, mask, a=True):
  579. if a:
  580. linear_g = self.linear_a_g
  581. linear_p = self.linear_a_p
  582. else:
  583. linear_g = self.linear_b_g
  584. linear_p = self.linear_b_p
  585. pair = self.layer_norm_in(pair)
  586. p = linear_g(pair)
  587. p.sigmoid_()
  588. p *= linear_p(pair)
  589. p *= mask
  590. p = permute_final_dims(p, (2, 0, 1))
  591. return p
  592. def compute_projection(pair, mask, a=True, chunked=True):
  593. need_transpose = self._outgoing ^ a
  594. if not chunked:
  595. p = compute_projection_helper(pair, mask, a)
  596. if need_transpose:
  597. p = p.transpose(-1, -2)
  598. else:
  599. # This computation is chunked so as not to exceed our 2.5x
  600. # budget with a large intermediate tensor
  601. linear_g = self.linear_a_g if a else self.linear_b_g
  602. c = linear_g.bias.shape[-1]
  603. out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1]
  604. p = pair.new_zeros(out_shape)
  605. for i in range(0, pair.shape[-3], inplace_chunk_size):
  606. pair_chunk = pair[..., i : i + inplace_chunk_size, :, :]
  607. pair_chunk = compute_projection_helper(
  608. pair[..., i : i + inplace_chunk_size, :, :],
  609. mask[..., i : i + inplace_chunk_size, :, :],
  610. a,
  611. )
  612. if need_transpose:
  613. pair_chunk = pair_chunk.transpose(-1, -2)
  614. p[..., i : i + inplace_chunk_size] = pair_chunk
  615. else:
  616. p[..., i : i + inplace_chunk_size, :] = pair_chunk
  617. del pair_chunk
  618. return p
  619. # We start by fully manifesting a. In addition to the input, this
  620. # brings total memory consumption to 2x z (disregarding size of chunks)
  621. # [*, N, N, c]
  622. a = compute_projection(z, mask, True, chunked=True)
  623. if inplace_chunk_size is not None:
  624. n = a.shape[-1]
  625. half_n = n // 2 + n % 2
  626. row_dim = -3
  627. col_dim = -2
  628. b_chunk_dim = row_dim if self._outgoing else col_dim
  629. def empty_slicer(t):
  630. return [slice(None) for _ in t.shape]
  631. def slice_tensor(t, start, end, dim):
  632. # Slices start:end from the dim dimension of t
  633. s = empty_slicer(t)
  634. s[dim] = slice(start, end)
  635. return t[s]
  636. def flip_z_cache_(z_cache, z):
  637. # "Reorient" the z_cache (see below), filling it with quadrants
  638. # 3---recovered from the z_cache---and 4---recovered from z---
  639. # of the input tensor z.
  640. quadrant_3 = slice_tensor(z_cache, half_n, None, row_dim)
  641. z_cache = z_cache.transpose(row_dim, col_dim)
  642. # If n is odd, we need to shrink the z_cache by one row
  643. z_cache = z_cache[..., : (n // 2), :, :]
  644. # Move the 3rd quadrant of z into the
  645. first_half_slicer = empty_slicer(z_cache)
  646. first_half_slicer[col_dim] = slice(0, half_n)
  647. z_cache[first_half_slicer] = quadrant_3
  648. # Get the fourth quadrant of z
  649. quadrant_4 = slice_tensor(z, half_n, None, row_dim)
  650. quadrant_4 = slice_tensor(quadrant_4, half_n, None, col_dim)
  651. # Insert said quadrant into the rotated z-cache
  652. quadrant_3_slicer = empty_slicer(z_cache)
  653. quadrant_3_slicer[col_dim] = slice(half_n, None)
  654. z_cache[quadrant_3_slicer] = quadrant_4
  655. return z_cache
  656. # Initialize the z cache to the left half of z.
  657. z_cache_shape = list(z.shape)
  658. z_cache_shape[col_dim] = half_n
  659. z_cache = z.new_zeros(z_cache_shape)
  660. z_cache_slicer = empty_slicer(z_cache)
  661. z_cache_slicer[col_dim] = slice(0, half_n)
  662. z_cache.copy_(z[z_cache_slicer])
  663. z_cache_rotated = False
  664. # We need to reorient the z-cache at the halfway point, and we
  665. # don't want a single chunk to straddle that point. We contract one
  666. # of the chunks in the middle to address that problem.
  667. i_range = list(range(0, half_n, inplace_chunk_size))
  668. initial_offsets = [i_2 - i_1 for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])]
  669. after_half = list(range(half_n, n, inplace_chunk_size))
  670. after_half_offsets = [inplace_chunk_size for _ in after_half]
  671. combined_range_with_offsets = zip(i_range + after_half, initial_offsets + after_half_offsets)
  672. for i, offset in combined_range_with_offsets:
  673. if not z_cache_rotated and i >= half_n:
  674. z_cache = flip_z_cache_(z_cache, z)
  675. z_cache_rotated = True
  676. z_chunk_b = slice_tensor(z, i, i + offset, b_chunk_dim)
  677. mask_chunk = slice_tensor(mask, i, i + offset, b_chunk_dim)
  678. z_chunk_b = z_chunk_b.clone()
  679. if b_chunk_dim == col_dim:
  680. z_chunk_b = slice_tensor(z, i, i + offset, col_dim)
  681. else: # b_chunk_dim == row_dim
  682. # In this case, the b-dimension (b_chunk_dim) is partially
  683. # overwritten at the end of each iteration. We need to
  684. # restore the missing component from the z-cache.
  685. if not z_cache_rotated:
  686. z_chunk_slicer = empty_slicer(z_chunk_b)
  687. z_chunk_slicer[col_dim] = slice(0, half_n)
  688. z_chunk_b[z_chunk_slicer] = slice_tensor(z_cache, i, i + offset, row_dim)
  689. else:
  690. z_cache_offset = i - half_n
  691. z_chunk_b = slice_tensor(z_cache, z_cache_offset, z_cache_offset + offset, row_dim)
  692. b_chunk = compute_projection(z_chunk_b, mask_chunk, a=False, chunked=False)
  693. del z_chunk_b
  694. x_chunk = torch.matmul(a, b_chunk)
  695. x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
  696. x_chunk = self.layer_norm_out(x_chunk)
  697. x_chunk = self.linear_z(x_chunk)
  698. # The g dimension (col_dim) is parallel to and ahead of the
  699. # overwrites in z. We can extract the g chunk normally.
  700. z_chunk_g = slice_tensor(z, i, i + offset, col_dim)
  701. g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g))
  702. g_chunk.sigmoid_()
  703. del z_chunk_g
  704. x_chunk *= g_chunk
  705. # Write the columns into z in-place
  706. z_slicer = empty_slicer(z)
  707. z_slicer[col_dim] = slice(i, i + offset)
  708. if with_add:
  709. z[z_slicer] += x_chunk
  710. else:
  711. z[z_slicer] = x_chunk
  712. else:
  713. b = compute_projection(z, mask, False, False)
  714. x = torch.matmul(a, b)
  715. x = self.layer_norm_out(x)
  716. x = self.linear_z(x)
  717. g = self.linear_g(z)
  718. g.sigmoid_()
  719. x *= g
  720. if with_add:
  721. z += x
  722. else:
  723. z = x
  724. return z
  725. def forward(
  726. self,
  727. z: torch.Tensor,
  728. mask: Optional[torch.Tensor] = None,
  729. inplace_safe: bool = False,
  730. _add_with_inplace: bool = False,
  731. _inplace_chunk_size: Optional[int] = 256,
  732. ) -> torch.Tensor:
  733. """
  734. Args:
  735. x:
  736. [*, N_res, N_res, C_z] input tensor
  737. mask:
  738. [*, N_res, N_res] input mask
  739. Returns:
  740. [*, N_res, N_res, C_z] output tensor
  741. """
  742. if inplace_safe:
  743. x = self._inference_forward(
  744. z,
  745. mask,
  746. inplace_chunk_size=_inplace_chunk_size,
  747. with_add=_add_with_inplace,
  748. )
  749. return x
  750. if mask is None:
  751. mask = z.new_ones(z.shape[:-1])
  752. mask = mask.unsqueeze(-1)
  753. z = self.layer_norm_in(z)
  754. a = mask
  755. a = a * self.sigmoid(self.linear_a_g(z))
  756. a = a * self.linear_a_p(z)
  757. b = mask
  758. b = b * self.sigmoid(self.linear_b_g(z))
  759. b = b * self.linear_b_p(z)
  760. device_type = a.device.type if a.device.type != "mps" else "cpu"
  761. if is_fp16_enabled(device_type):
  762. with torch.autocast(device_type=device_type, enabled=False):
  763. x = self._combine_projections(a.float(), b.float())
  764. else:
  765. x = self._combine_projections(a, b)
  766. del a, b
  767. x = self.layer_norm_out(x)
  768. x = self.linear_z(x)
  769. g = self.sigmoid(self.linear_g(z))
  770. x = x * g
  771. return x
  772. class EsmFoldPreTrainedModel(EsmPreTrainedModel):
  773. """
  774. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  775. models.
  776. """
  777. # Subclass `EsMPreTrainedModel` to deal with special init
  778. def _init_weights(self, module):
  779. """Initialize the weights"""
  780. if isinstance(module, EsmFoldLinear):
  781. with torch.no_grad():
  782. if module.init_fn is not None:
  783. module.init_fn(module.weight, module.bias)
  784. elif module.init == "default":
  785. trunc_normal_init_(module.weight, scale=1.0)
  786. elif module.init == "relu":
  787. trunc_normal_init_(module.weight, scale=2.0)
  788. elif module.init == "glorot":
  789. nn.init.xavier_uniform_(module.weight, gain=1)
  790. elif module.init == "gating":
  791. module.weight.fill_(0.0)
  792. if module.bias:
  793. module.bias.fill_(1.0)
  794. elif module.init == "normal":
  795. torch.nn.init.kaiming_normal_(module.weight, nonlinearity="linear")
  796. elif module.init == "final":
  797. module.weight.fill_(0.0)
  798. elif isinstance(module, EsmFoldInvariantPointAttention):
  799. ipa_point_weights_init_(module.head_weights)
  800. elif isinstance(module, EsmFoldTriangularSelfAttentionBlock):
  801. torch.nn.init.zeros_(module.tri_mul_in.linear_z.weight)
  802. torch.nn.init.zeros_(module.tri_mul_in.linear_z.bias)
  803. torch.nn.init.zeros_(module.tri_mul_out.linear_z.weight)
  804. torch.nn.init.zeros_(module.tri_mul_out.linear_z.bias)
  805. torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.weight)
  806. torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.bias)
  807. torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.weight)
  808. torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.bias)
  809. torch.nn.init.zeros_(module.sequence_to_pair.o_proj.weight)
  810. torch.nn.init.zeros_(module.sequence_to_pair.o_proj.bias)
  811. torch.nn.init.zeros_(module.pair_to_sequence.linear.weight)
  812. torch.nn.init.zeros_(module.seq_attention.o_proj.weight)
  813. torch.nn.init.zeros_(module.seq_attention.o_proj.bias)
  814. torch.nn.init.zeros_(module.mlp_seq.mlp[-2].weight)
  815. torch.nn.init.zeros_(module.mlp_seq.mlp[-2].bias)
  816. torch.nn.init.zeros_(module.mlp_pair.mlp[-2].weight)
  817. torch.nn.init.zeros_(module.mlp_pair.mlp[-2].bias)
  818. else:
  819. super()._init_weights(module)
  820. class EsmFoldSelfAttention(nn.Module):
  821. def __init__(self, embed_dim, num_heads, head_width, gated=False):
  822. super().__init__()
  823. assert embed_dim == num_heads * head_width
  824. self.embed_dim = embed_dim
  825. self.num_heads = num_heads
  826. self.head_width = head_width
  827. self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
  828. self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True)
  829. self.gated = gated
  830. if gated:
  831. self.g_proj = nn.Linear(embed_dim, embed_dim)
  832. torch.nn.init.zeros_(self.g_proj.weight)
  833. torch.nn.init.ones_(self.g_proj.bias)
  834. self.rescale_factor = self.head_width**-0.5
  835. torch.nn.init.zeros_(self.o_proj.bias)
  836. def forward(self, x, mask=None, bias=None, indices=None):
  837. """
  838. Basic self attention with optional mask and external pairwise bias. To handle sequences of different lengths,
  839. use mask.
  840. Inputs:
  841. x: batch of input sequences (.. x L x C) mask: batch of boolean masks where 1=valid, 0=padding position (..
  842. x L_k) bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads)
  843. Outputs:
  844. sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads)
  845. """
  846. t = self.proj(x).view(*x.shape[:2], self.num_heads, -1)
  847. t = t.permute(0, 2, 1, 3)
  848. q, k, v = t.chunk(3, dim=-1)
  849. q = self.rescale_factor * q
  850. a = torch.einsum("...qc,...kc->...qk", q, k)
  851. # Add external attention bias.
  852. if bias is not None:
  853. a = a + bias.permute(0, 3, 1, 2)
  854. # Do not attend to padding tokens.
  855. if mask is not None:
  856. mask = mask[:, None, None]
  857. a = a.masked_fill(mask == False, -np.inf) # noqa: E712
  858. a = nn.functional.softmax(a, dim=-1)
  859. y = torch.einsum("...hqk,...hkc->...qhc", a, v)
  860. y = y.reshape(*y.shape[:2], -1)
  861. if self.gated:
  862. y = self.g_proj(x).sigmoid() * y
  863. y = self.o_proj(y)
  864. return y, a.permute(0, 3, 1, 2)
  865. class EsmFoldDropout(nn.Module):
  866. """
  867. Implementation of dropout with the ability to share the dropout mask along a particular dimension.
  868. """
  869. def __init__(self, r: float, batch_dim: Union[int, list[int]]):
  870. super().__init__()
  871. self.r = r
  872. if isinstance(batch_dim, int):
  873. batch_dim = [batch_dim]
  874. self.batch_dim = batch_dim
  875. self.dropout = nn.Dropout(self.r)
  876. def forward(self, x: torch.Tensor) -> torch.Tensor:
  877. shape = list(x.shape)
  878. if self.batch_dim is not None:
  879. for bd in self.batch_dim:
  880. shape[bd] = 1
  881. return x * self.dropout(x.new_ones(shape))
  882. class EsmFoldSequenceToPair(nn.Module):
  883. def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim):
  884. super().__init__()
  885. self.layernorm = nn.LayerNorm(sequence_state_dim)
  886. self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True)
  887. self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True)
  888. torch.nn.init.zeros_(self.proj.bias)
  889. torch.nn.init.zeros_(self.o_proj.bias)
  890. def forward(self, sequence_state):
  891. """
  892. Inputs:
  893. sequence_state: B x L x sequence_state_dim
  894. Output:
  895. pairwise_state: B x L x L x pairwise_state_dim
  896. Intermediate state:
  897. B x L x L x 2*inner_dim
  898. """
  899. assert len(sequence_state.shape) == 3
  900. s = self.layernorm(sequence_state)
  901. s = self.proj(s)
  902. q, k = s.chunk(2, dim=-1)
  903. prod = q[:, None, :, :] * k[:, :, None, :]
  904. diff = q[:, None, :, :] - k[:, :, None, :]
  905. x = torch.cat([prod, diff], dim=-1)
  906. x = self.o_proj(x)
  907. return x
  908. class EsmFoldPairToSequence(nn.Module):
  909. def __init__(self, pairwise_state_dim, num_heads):
  910. super().__init__()
  911. self.layernorm = nn.LayerNorm(pairwise_state_dim)
  912. self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False)
  913. def forward(self, pairwise_state):
  914. """
  915. Inputs:
  916. pairwise_state: B x L x L x pairwise_state_dim
  917. Output:
  918. pairwise_bias: B x L x L x num_heads
  919. """
  920. assert len(pairwise_state.shape) == 4
  921. z = self.layernorm(pairwise_state)
  922. pairwise_bias = self.linear(z)
  923. return pairwise_bias
  924. class EsmFoldResidueMLP(nn.Module):
  925. def __init__(self, embed_dim, inner_dim, dropout=0):
  926. super().__init__()
  927. self.mlp = nn.Sequential(
  928. nn.LayerNorm(embed_dim),
  929. nn.Linear(embed_dim, inner_dim),
  930. nn.ReLU(),
  931. nn.Linear(inner_dim, embed_dim),
  932. nn.Dropout(dropout),
  933. )
  934. def forward(self, x):
  935. return x + self.mlp(x)
  936. class EsmFoldTriangularSelfAttentionBlock(nn.Module):
  937. def __init__(self, config):
  938. super().__init__()
  939. self.config = config
  940. sequence_state_dim = config.sequence_state_dim
  941. pairwise_state_dim = config.pairwise_state_dim
  942. sequence_num_heads = sequence_state_dim // config.sequence_head_width
  943. pairwise_num_heads = pairwise_state_dim // config.pairwise_head_width
  944. self.layernorm_1 = nn.LayerNorm(sequence_state_dim)
  945. self.sequence_to_pair = EsmFoldSequenceToPair(sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim)
  946. self.pair_to_sequence = EsmFoldPairToSequence(pairwise_state_dim, sequence_num_heads)
  947. self.seq_attention = EsmFoldSelfAttention(
  948. sequence_state_dim, sequence_num_heads, config.sequence_head_width, gated=True
  949. )
  950. self.tri_mul_out = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=True)
  951. self.tri_mul_in = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=False)
  952. self.tri_att_start = EsmFoldTriangleAttention(
  953. pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=True
  954. )
  955. self.tri_att_end = EsmFoldTriangleAttention(
  956. pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=False
  957. )
  958. self.mlp_seq = EsmFoldResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=config.dropout)
  959. self.mlp_pair = EsmFoldResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=config.dropout)
  960. self.drop = nn.Dropout(config.dropout)
  961. self.row_drop = EsmFoldDropout(config.dropout * 2, 2)
  962. self.col_drop = EsmFoldDropout(config.dropout * 2, 1)
  963. def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
  964. """
  965. Inputs:
  966. sequence_state: B x L x sequence_state_dim pairwise_state: B x L x L x pairwise_state_dim mask: B x L boolean
  967. tensor of valid positions
  968. Output:
  969. sequence_state: B x L x sequence_state_dim pairwise_state: B x L x L x pairwise_state_dim
  970. """
  971. if len(sequence_state.shape) != 3:
  972. raise ValueError(f"`sequence_state` should be a 3d-tensor, got {len(sequence_state.shape)} dims.")
  973. if len(pairwise_state.shape) != 4:
  974. raise ValueError(f"`pairwise_state` should be a 4d-tensor, got {len(pairwise_state.shape)} dims.")
  975. if mask is not None and len(mask.shape) != 2:
  976. raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")
  977. batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
  978. pairwise_state_dim = pairwise_state.shape[3]
  979. if sequence_state_dim != self.config.sequence_state_dim:
  980. raise ValueError(
  981. "`sequence_state` last dimension should be equal to `self.sequence_state_dim`. Got "
  982. f"{sequence_state_dim} != {self.config.sequence_state_dim}."
  983. )
  984. if pairwise_state_dim != self.config.pairwise_state_dim:
  985. raise ValueError(
  986. "`pairwise_state` last dimension should be equal to `self.pairwise_state_dim`. Got "
  987. f"{pairwise_state_dim} != {self.config.pairwise_state_dim}."
  988. )
  989. if batch_dim != pairwise_state.shape[0]:
  990. raise ValueError(
  991. f"`sequence_state` and `pairwise_state` have inconsistent batch size: {batch_dim} != "
  992. f"{pairwise_state.shape[0]}."
  993. )
  994. if seq_dim != pairwise_state.shape[1] or seq_dim != pairwise_state.shape[2]:
  995. raise ValueError(
  996. f"`sequence_state` and `pairwise_state` have inconsistent sequence length: {seq_dim} != "
  997. f"{pairwise_state.shape[1]} or {pairwise_state.shape[2]}."
  998. )
  999. # Update sequence state
  1000. bias = self.pair_to_sequence(pairwise_state)
  1001. # Self attention with bias + mlp.
  1002. y = self.layernorm_1(sequence_state)
  1003. y, _ = self.seq_attention(y, mask=mask, bias=bias)
  1004. sequence_state = sequence_state + self.drop(y)
  1005. sequence_state = self.mlp_seq(sequence_state)
  1006. # Update pairwise state
  1007. pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)
  1008. # Axial attention with triangular bias.
  1009. tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None
  1010. pairwise_state = pairwise_state + self.row_drop(self.tri_mul_out(pairwise_state, mask=tri_mask))
  1011. pairwise_state = pairwise_state + self.col_drop(self.tri_mul_in(pairwise_state, mask=tri_mask))
  1012. pairwise_state = pairwise_state + self.row_drop(
  1013. self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
  1014. )
  1015. pairwise_state = pairwise_state + self.col_drop(
  1016. self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
  1017. )
  1018. # MLP over pairs.
  1019. pairwise_state = self.mlp_pair(pairwise_state)
  1020. return sequence_state, pairwise_state
  1021. class EsmCategoricalMixture:
  1022. def __init__(self, param, bins=50, start=0, end=1):
  1023. # All tensors are of shape ..., bins.
  1024. self.logits = param
  1025. bins = torch.linspace(start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype)
  1026. self.v_bins = (bins[:-1] + bins[1:]) / 2
  1027. def log_prob(self, true):
  1028. # Shapes are:
  1029. # self.probs: ... x bins
  1030. # true : ...
  1031. true_index = (true.unsqueeze(-1) - self.v_bins[[None] * true.ndim]).abs().argmin(-1)
  1032. nll = self.logits.log_softmax(-1)
  1033. return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1)
  1034. def mean(self):
  1035. return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1)
  1036. def categorical_lddt(logits, bins=50):
  1037. # Logits are ..., 37, bins.
  1038. return EsmCategoricalMixture(logits, bins=bins).mean()
  1039. def get_axial_mask(mask):
  1040. """
  1041. Helper to convert B x L mask of valid positions to axial mask used in row column attentions.
  1042. Input:
  1043. mask: B x L tensor of booleans
  1044. Output:
  1045. mask: B x L x L tensor of booleans
  1046. """
  1047. if mask is None:
  1048. return None
  1049. if len(mask.shape) != 2:
  1050. raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")
  1051. batch_dim, seq_dim = mask.shape
  1052. m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim)
  1053. m = m.reshape(batch_dim * seq_dim, seq_dim)
  1054. return m
  1055. class EsmFoldRelativePosition(nn.Module):
  1056. def __init__(self, config):
  1057. super().__init__()
  1058. self.bins = config.position_bins
  1059. # Note an additional offset is used so that the 0th position
  1060. # is reserved for masked pairs.
  1061. self.embedding = torch.nn.Embedding(2 * self.bins + 2, config.pairwise_state_dim)
  1062. def forward(self, residue_index, mask=None):
  1063. """
  1064. Input:
  1065. residue_index: B x L tensor of indices (dtype=torch.long) mask: B x L tensor of booleans
  1066. Output:
  1067. pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
  1068. """
  1069. if residue_index.dtype != torch.long:
  1070. raise ValueError(f"`residue_index` has dtype {residue_index.dtype}, it should be `torch.long`.")
  1071. if mask is not None and residue_index.shape != mask.shape:
  1072. raise ValueError(
  1073. f"`residue_index` and `mask` have inconsistent shapes: {residue_index.shape} != {mask.shape}."
  1074. )
  1075. diff = residue_index[:, None, :] - residue_index[:, :, None]
  1076. diff = diff.clamp(-self.bins, self.bins)
  1077. diff = diff + self.bins + 1 # Add 1 to adjust for padding index.
  1078. if mask is not None:
  1079. mask = mask[:, None, :] * mask[:, :, None]
  1080. diff[mask == False] = 0 # noqa: E712
  1081. output = self.embedding(diff)
  1082. return output
  1083. class EsmFoldAngleResnetBlock(nn.Module):
  1084. def __init__(self, config):
  1085. super().__init__()
  1086. self.linear_1 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="relu")
  1087. self.linear_2 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="final")
  1088. self.relu = nn.ReLU()
  1089. def forward(self, a: torch.Tensor) -> torch.Tensor:
  1090. s_initial = a
  1091. a = self.relu(a)
  1092. a = self.linear_1(a)
  1093. a = self.relu(a)
  1094. a = self.linear_2(a)
  1095. return a + s_initial
  1096. class EsmFoldAngleResnet(nn.Module):
  1097. """
  1098. Implements Algorithm 20, lines 11-14
  1099. """
  1100. def __init__(self, config):
  1101. super().__init__()
  1102. self.config = config
  1103. self.linear_in = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
  1104. self.linear_initial = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
  1105. self.layers = nn.ModuleList()
  1106. for _ in range(config.num_resnet_blocks):
  1107. layer = EsmFoldAngleResnetBlock(config)
  1108. self.layers.append(layer)
  1109. self.linear_out = EsmFoldLinear(config.resnet_dim, config.num_angles * 2)
  1110. self.relu = nn.ReLU()
  1111. def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  1112. """
  1113. Args:
  1114. s:
  1115. [*, C_hidden] single embedding
  1116. s_initial:
  1117. [*, C_hidden] single embedding as of the start of the StructureModule
  1118. Returns:
  1119. [*, no_angles, 2] predicted angles
  1120. """
  1121. # NOTE: The ReLU's applied to the inputs are absent from the supplement
  1122. # pseudocode but present in the source. For maximal compatibility with
  1123. # the pretrained weights, I'm going with the source.
  1124. # [*, C_hidden]
  1125. s_initial = self.relu(s_initial)
  1126. s_initial = self.linear_initial(s_initial)
  1127. s = self.relu(s)
  1128. s = self.linear_in(s)
  1129. s = s + s_initial
  1130. for l in self.layers:
  1131. s = l(s)
  1132. s = self.relu(s)
  1133. # [*, no_angles * 2]
  1134. s = self.linear_out(s)
  1135. # [*, no_angles, 2]
  1136. s = s.view(s.shape[:-1] + (-1, 2))
  1137. unnormalized_s = s
  1138. norm_denom = torch.sqrt(
  1139. torch.clamp(
  1140. torch.sum(s**2, dim=-1, keepdim=True),
  1141. min=self.config.epsilon,
  1142. )
  1143. )
  1144. s = s / norm_denom
  1145. return unnormalized_s, s
class EsmFoldInvariantPointAttention(nn.Module):
    """
    Implements Algorithm 22.

    Attention over residues whose logits combine three terms: scalar query/key products, a per-head bias from the
    pair representation, and a distance penalty between query and key points placed with the frames `r`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        c_s = config.sequence_dim
        c_z = config.pairwise_dim
        self.hidden_dim = config.ipa_dim
        self.num_heads = config.num_heads_ipa
        self.num_qk_points = config.num_qk_points
        self.num_v_points = config.num_v_points

        # These linear layers differ from their specifications in the
        # supplement. There, they lack bias and use Glorot initialization.
        # Here as in the official source, they have bias and use the default
        # Lecun initialization.
        hc = config.ipa_dim * config.num_heads_ipa
        self.linear_q = EsmFoldLinear(c_s, hc)
        self.linear_kv = EsmFoldLinear(c_s, 2 * hc)

        # Query points: 3 coordinates per (head, point).
        hpq = config.num_heads_ipa * config.num_qk_points * 3
        self.linear_q_points = EsmFoldLinear(c_s, hpq)

        # Key and value points come from a single projection and are split in forward.
        hpkv = config.num_heads_ipa * (config.num_qk_points + config.num_v_points) * 3
        self.linear_kv_points = EsmFoldLinear(c_s, hpkv)

        # Per-head scalar attention bias computed from the pair representation.
        self.linear_b = EsmFoldLinear(c_z, config.num_heads_ipa)

        # Per-head weights of the point-distance term (passed through softplus in forward).
        self.head_weights = nn.Parameter(torch.zeros(config.num_heads_ipa))

        # Per head, the output concatenates pair features (c_z), scalar values (ipa_dim) and, for each value point,
        # its three coordinates plus its norm (4 numbers).
        concat_out_dim = config.num_heads_ipa * (c_z + config.ipa_dim + config.num_v_points * 4)
        self.linear_out = EsmFoldLinear(concat_out_dim, c_s, init="final")

        self.softmax = nn.Softmax(dim=-1)
        self.softplus = nn.Softplus()

    def forward(
        self,
        s: torch.Tensor,
        z: Optional[torch.Tensor],
        r: Rigid,
        mask: torch.Tensor,
        _offload_inference: bool = False,
        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
            z:
                [*, N_res, N_res, C_z] pair representation. Despite the Optional annotation it must be provided:
                `self.linear_b(z[0])` is applied unconditionally.
            r:
                [*, N_res] transformation object
            mask:
                [*, N_res] mask
            _offload_inference:
                If True, `z` is moved to CPU after the pair bias is computed and moved back to the device of the
                point outputs just before it is needed again.
            _z_reference_list:
                Not used in this method.
        Returns:
            [*, N_res, C_s] single representation update
        """
        # Wrap z in a list so that z[0] can be rebound (offloaded to CPU and brought back) below.
        z = [z]

        #######################################
        # Generate scalar and point activations
        #######################################
        # [*, N_res, H * C_hidden]
        q = self.linear_q(s)
        kv = self.linear_kv(s)

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.num_heads, -1))

        # [*, N_res, H, 2 * C_hidden]
        kv = kv.view(kv.shape[:-1] + (self.num_heads, -1))

        # [*, N_res, H, C_hidden]
        k, v = torch.split(kv, self.hidden_dim, dim=-1)

        # [*, N_res, H * P_q * 3]
        q_pts = self.linear_q_points(s)

        # This is kind of clunky, but it's how the original does it
        # The last dimension is laid out as three contiguous blocks, one per coordinate;
        # split them and stack the blocks on a new trailing axis.
        # [*, N_res, H * P_q, 3]
        q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
        q_pts = torch.stack(q_pts, dim=-1)
        # Apply each residue's frame to its points (presumably local -> global; confirm `Rigid.apply` semantics).
        q_pts = r[..., None].apply(q_pts)

        # [*, N_res, H, P_q, 3]
        q_pts = q_pts.view(q_pts.shape[:-2] + (self.num_heads, self.num_qk_points, 3))

        # [*, N_res, H * (P_q + P_v) * 3]
        kv_pts = self.linear_kv_points(s)

        # [*, N_res, H * (P_q + P_v), 3]
        kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
        kv_pts = torch.stack(kv_pts, dim=-1)
        kv_pts = r[..., None].apply(kv_pts)

        # [*, N_res, H, (P_q + P_v), 3]
        kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3))

        # [*, N_res, H, P_q/P_v, 3]
        k_pts, v_pts = torch.split(kv_pts, [self.num_qk_points, self.num_v_points], dim=-2)

        ##########################
        # Compute attention scores
        ##########################
        # [*, N_res, N_res, H]
        b = self.linear_b(z[0])

        if _offload_inference:
            # getrefcount counts its own argument, so 2 means the list holds the only other reference;
            # presumably this guarantees that moving z to CPU actually releases the device tensor.
            assert sys.getrefcount(z[0]) == 2
            z[0] = z[0].cpu()

        # [*, H, N_res, N_res]
        # For MPS tensors, `is_fp16_enabled` is queried with "cpu".
        device_type = q.device.type if q.device.type != "mps" else "cpu"
        if is_fp16_enabled(device_type):
            # Compute the scalar logits in float32 with autocast disabled.
            with torch.autocast(device_type=device_type, enabled=False):
                a = torch.matmul(
                    permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
                    permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
                )
        else:
            a = torch.matmul(
                permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
                permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
            )

        # Each of the three logit terms (scalar, pair bias, point distance) carries a 1/3 factor under the sqrt.
        a *= math.sqrt(1.0 / (3 * self.hidden_dim))
        a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))

        # [*, N_res, N_res, H, P_q, 3]
        pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
        pt_att = pt_att**2

        # [*, N_res, N_res, H, P_q]
        pt_att = sum(torch.unbind(pt_att, dim=-1))
        # Softplus-transformed per-head weights, reshaped to broadcast over [*, N_res, N_res, H, P_q].
        head_weights = self.softplus(self.head_weights).view(*((1,) * len(pt_att.shape[:-2]) + (-1, 1)))
        head_weights = head_weights * math.sqrt(1.0 / (3 * (self.num_qk_points * 9.0 / 2)))
        pt_att = pt_att * head_weights

        # [*, N_res, N_res, H]
        # Weighted squared distances enter the logits with a -1/2 factor.
        pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
        # [*, N_res, N_res]
        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        # For a 0/1 mask: 0 where both residues are valid, -config.inf otherwise.
        square_mask = self.config.inf * (square_mask - 1)

        # [*, H, N_res, N_res]
        pt_att = permute_final_dims(pt_att, (2, 0, 1))

        a = a + pt_att
        a = a + square_mask.unsqueeze(-3)
        a = self.softmax(a)

        ################
        # Compute output
        ################
        # [*, N_res, H, C_hidden]
        o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3)

        # [*, N_res, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, H, 3, N_res, P_v]
        o_pt = torch.sum(
            (a[..., None, :, :, None] * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]),
            dim=-2,
        )

        # [*, N_res, H, P_v, 3]
        o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
        # Map the aggregated points back with the inverse frames (presumably global -> local).
        o_pt = r[..., None, None].invert_apply(o_pt)

        # [*, N_res, H * P_v]
        o_pt_norm = flatten_final_dims(torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.config.epsilon), 2)

        # [*, N_res, H * P_v, 3]
        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

        if _offload_inference:
            z[0] = z[0].to(o_pt.device)

        # [*, N_res, H, C_z]
        o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))

        # [*, N_res, H * C_z]
        o_pair = flatten_final_dims(o_pair, 2)

        # [*, N_res, C_s]
        s = self.linear_out(
            torch.cat((o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype)
        )

        return s
  1301. class EsmFoldBackboneUpdate(nn.Module):
  1302. """
  1303. Implements part of Algorithm 23.
  1304. """
  1305. def __init__(self, config):
  1306. super().__init__()
  1307. self.linear = EsmFoldLinear(config.sequence_dim, 6, init="final")
  1308. def forward(self, s: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  1309. """
  1310. Args:
  1311. [*, N_res, C_s] single representation
  1312. Returns:
  1313. [*, N_res, 6] update vector
  1314. """
  1315. # [*, 6]
  1316. update = self.linear(s)
  1317. return update
  1318. class EsmFoldStructureModuleTransitionLayer(nn.Module):
  1319. def __init__(self, config):
  1320. super().__init__()
  1321. self.linear_1 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
  1322. self.linear_2 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
  1323. self.linear_3 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="final")
  1324. self.relu = nn.ReLU()
  1325. def forward(self, s):
  1326. s_initial = s
  1327. s = self.linear_1(s)
  1328. s = self.relu(s)
  1329. s = self.linear_2(s)
  1330. s = self.relu(s)
  1331. s = self.linear_3(s)
  1332. s = s + s_initial
  1333. return s
  1334. class EsmFoldStructureModuleTransition(nn.Module):
  1335. def __init__(self, config):
  1336. super().__init__()
  1337. self.config = config
  1338. self.layers = nn.ModuleList()
  1339. for _ in range(config.num_transition_layers):
  1340. l = EsmFoldStructureModuleTransitionLayer(config)
  1341. self.layers.append(l)
  1342. self.dropout = nn.Dropout(config.dropout_rate)
  1343. self.layer_norm = LayerNorm(config.sequence_dim)
  1344. def forward(self, s):
  1345. for l in self.layers:
  1346. s = l(s)
  1347. s = self.dropout(s)
  1348. s = self.layer_norm(s)
  1349. return s
class EsmFoldStructureModule(nn.Module):
    """
    Structure module: starting from identity frames, runs `config.num_blocks` iterations of invariant point
    attention + transition, updates per-residue rigid frames after each iteration, and decodes torsion angles and
    atom14 positions from them. Per-iteration predictions are collected and stacked in the returned dict.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        # Buffers to be lazily initialized later (see `_init_residue_constants`)
        # self.default_frames
        # self.group_idx
        # self.atom_mask
        # self.lit_positions

        self.layer_norm_s = LayerNorm(config.sequence_dim)
        self.layer_norm_z = LayerNorm(config.pairwise_dim)

        self.linear_in = EsmFoldLinear(config.sequence_dim, config.sequence_dim)

        self.ipa = EsmFoldInvariantPointAttention(config)

        self.ipa_dropout = nn.Dropout(config.dropout_rate)
        self.layer_norm_ipa = LayerNorm(config.sequence_dim)

        self.transition = EsmFoldStructureModuleTransition(config)
        self.bb_update = EsmFoldBackboneUpdate(config)
        self.angle_resnet = EsmFoldAngleResnet(config)

    def forward(
        self,
        evoformer_output_dict,
        aatype,
        mask=None,
        _offload_inference=False,
    ):
        """
        Args:
            evoformer_output_dict:
                Dictionary containing:
                    "single":
                        [*, N_res, C_s] single representation
                    "pair":
                        [*, N_res, N_res, C_z] pair representation
            aatype:
                [*, N_res] amino acid indices
            mask:
                Optional [*, N_res] sequence mask. Defaults to all ones.
            _offload_inference:
                If True, `evoformer_output_dict["pair"]` is moved to CPU while the blocks run and moved back to the
                device of the single representation afterwards (this mutates the caller's dict). The normalized pair
                representation is then handed to the IPA only through `_z_reference_list`.
        Returns:
            A dictionary of outputs. For the keys "frames", "sidechain_frames", "unnormalized_angles", "angles",
            "positions" and "states", the per-block values are combined with `dict_multimap(torch.stack, ...)`
            (presumably stacked along a new leading block dimension — confirm against `dict_multimap`). The key
            "single" holds the final single representation.
        """
        s = evoformer_output_dict["single"]

        if mask is None:
            # [*, N]
            mask = s.new_ones(s.shape[:-1])

        # [*, N, C_s]
        s = self.layer_norm_s(s)

        # [*, N, N, C_z]
        z = self.layer_norm_z(evoformer_output_dict["pair"])

        z_reference_list = None
        if _offload_inference:
            # NOTE(review): presumably checks that the dict holds the only reference to the pair tensor (the second
            # count is getrefcount's own argument), so that moving it to CPU actually frees device memory. Do not
            # add local aliases of this tensor above, or this assertion will fail.
            assert sys.getrefcount(evoformer_output_dict["pair"]) == 2
            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
            # Hand z to the IPA via a mutable list instead of a direct argument.
            z_reference_list = [z]
            z = None

        # [*, N, C_s]
        s_initial = s
        s = self.linear_in(s)

        # [*, N]
        rigids = Rigid.identity(
            s.shape[:-1],
            s.dtype,
            s.device,
            self.training,
            fmt="quat",
        )
        outputs = []
        for i in range(self.config.num_blocks):
            # [*, N, C_s]
            s = s + self.ipa(
                s,
                z,
                rigids,
                mask,
                _offload_inference=_offload_inference,
                _z_reference_list=z_reference_list,
            )
            s = self.ipa_dropout(s)
            s = self.layer_norm_ipa(s)
            s = self.transition(s)

            # [*, N]
            rigids = rigids.compose_q_update_vec(self.bb_update(s))

            # To hew as closely as possible to AlphaFold, we convert our
            # quaternion-based transformations to rotation-matrix ones
            # here
            backb_to_global = Rigid(
                Rotation(rot_mats=rigids.get_rots().get_rot_mats(), quats=None),
                rigids.get_trans(),
            )

            backb_to_global = backb_to_global.scale_translation(self.config.trans_scale_factor)

            # [*, N, 7, 2]
            unnormalized_angles, angles = self.angle_resnet(s, s_initial)

            all_frames_to_global = self.torsion_angles_to_frames(backb_to_global, angles, aatype)

            pred_xyz = self.frames_and_literature_positions_to_atom14_pos(all_frames_to_global, aatype)

            scaled_rigids = rigids.scale_translation(self.config.trans_scale_factor)

            preds = {
                "frames": scaled_rigids.to_tensor_7(),
                "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
                "unnormalized_angles": unnormalized_angles,
                "angles": angles,
                "positions": pred_xyz,
                "states": s,
            }

            outputs.append(preds)

            # NOTE(review): presumably detaches the rotation component so rotation gradients do not flow into
            # the next block — confirm against `Rigid.stop_rot_gradient`.
            rigids = rigids.stop_rot_gradient()

        del z, z_reference_list

        if _offload_inference:
            # Restore the pair representation that was moved to CPU above.
            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].to(s.device)

        outputs = dict_multimap(torch.stack, outputs)
        outputs["single"] = s

        return outputs

    def _init_residue_constants(self, float_dtype, device):
        """
        Register the residue-constant lookup tables as non-persistent buffers (not saved in the state dict).

        Each buffer is created only if the attribute does not exist yet. NOTE(review): a later call with a different
        dtype/device therefore does not recreate them; they follow the module only through `.to()`.
        """
        if not hasattr(self, "default_frames"):
            self.register_buffer(
                "default_frames",
                torch.tensor(
                    residue_constants.restype_rigid_group_default_frame,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "group_idx"):
            # No explicit dtype: used as an index table, so it keeps torch's inferred (integer) dtype.
            self.register_buffer(
                "group_idx",
                torch.tensor(
                    residue_constants.restype_atom14_to_rigid_group,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "atom_mask"):
            self.register_buffer(
                "atom_mask",
                torch.tensor(
                    residue_constants.restype_atom14_mask,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "lit_positions"):
            self.register_buffer(
                "lit_positions",
                torch.tensor(
                    residue_constants.restype_atom14_rigid_group_positions,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )

    def torsion_angles_to_frames(self, r, alpha, f):
        # Lazily initialize the residue constants on the correct device
        self._init_residue_constants(alpha.dtype, alpha.device)
        # Separated purely to make testing less annoying; delegates to the module-level function of the same name.
        return torsion_angles_to_frames(r, alpha, f, self.default_frames)

    def frames_and_literature_positions_to_atom14_pos(self, r, f):  # [*, N, 8]  # [*, N]
        # Lazily initialize the residue constants on the correct device
        self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)

        # Delegates to the module-level function of the same name with the lazily created buffers.
        return frames_and_literature_positions_to_atom14_pos(
            r,
            f,
            self.default_frames,
            self.group_idx,
            self.atom_mask,
            self.lit_positions,
        )
  1520. class EsmFoldingTrunk(nn.Module):
  1521. def __init__(self, config):
  1522. super().__init__()
  1523. self.config = config
  1524. c_s = config.sequence_state_dim
  1525. c_z = config.pairwise_state_dim
  1526. self.pairwise_positional_embedding = EsmFoldRelativePosition(config)
  1527. self.blocks = nn.ModuleList([EsmFoldTriangularSelfAttentionBlock(config) for _ in range(config.num_blocks)])
  1528. self.recycle_bins = 15
  1529. self.recycle_s_norm = nn.LayerNorm(c_s)
  1530. self.recycle_z_norm = nn.LayerNorm(c_z)
  1531. self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
  1532. self.recycle_disto.weight[0].detach().zero_()
  1533. self.structure_module = EsmFoldStructureModule(config.structure_module)
  1534. self.trunk2sm_s = nn.Linear(c_s, config.structure_module.sequence_dim)
  1535. self.trunk2sm_z = nn.Linear(c_z, config.structure_module.pairwise_dim)
  1536. self.chunk_size = config.chunk_size
  1537. def set_chunk_size(self, chunk_size):
  1538. # This parameter means the axial attention will be computed
  1539. # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2).
  1540. # It's equivalent to running a for loop over chunks of the dimension we're iterative over,
  1541. # where the chunk_size is the size of the chunks, so 128 would mean to parse 128-length chunks.
  1542. self.chunk_size = chunk_size
  1543. def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles):
  1544. """
  1545. Inputs:
  1546. seq_feats: B x L x C tensor of sequence features pair_feats: B x L x L x C tensor of pair features residx: B
  1547. x L long tensor giving the position in the sequence mask: B x L boolean tensor indicating valid residues
  1548. Output:
  1549. predicted_structure: B x L x (num_atoms_per_residue * 3) tensor wrapped in a Coordinates object
  1550. """
  1551. device = seq_feats.device
  1552. s_s_0 = seq_feats
  1553. s_z_0 = pair_feats
  1554. if no_recycles is None:
  1555. no_recycles = self.config.max_recycles
  1556. else:
  1557. if no_recycles < 0:
  1558. raise ValueError("Number of recycles must not be negative.")
  1559. no_recycles += 1 # First 'recycle' is just the standard forward pass through the model.
  1560. def trunk_iter(s, z, residx, mask):
  1561. z = z + self.pairwise_positional_embedding(residx, mask=mask)
  1562. for block in self.blocks:
  1563. s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
  1564. return s, z
  1565. s_s = s_s_0
  1566. s_z = s_z_0
  1567. recycle_s = torch.zeros_like(s_s)
  1568. recycle_z = torch.zeros_like(s_z)
  1569. recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)
  1570. for recycle_idx in range(no_recycles):
  1571. with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]):
  1572. # === Recycling ===
  1573. recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device)
  1574. recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device)
  1575. recycle_z += self.recycle_disto(recycle_bins.detach()).to(device)
  1576. s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)
  1577. # === Structure module ===
  1578. structure = self.structure_module(
  1579. {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
  1580. true_aa,
  1581. mask.float(),
  1582. )
  1583. recycle_s = s_s
  1584. recycle_z = s_z
  1585. # Distogram needs the N, CA, C coordinates, and bin constants same as alphafold.
  1586. recycle_bins = EsmFoldingTrunk.distogram(
  1587. structure["positions"][-1][:, :, :3],
  1588. 3.375,
  1589. 21.375,
  1590. self.recycle_bins,
  1591. )
  1592. structure["s_s"] = s_s
  1593. structure["s_z"] = s_z
  1594. return structure
  1595. @staticmethod
  1596. def distogram(coords, min_bin, max_bin, num_bins):
  1597. # Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates.
  1598. boundaries = torch.linspace(
  1599. min_bin,
  1600. max_bin,
  1601. num_bins - 1,
  1602. device=coords.device,
  1603. )
  1604. boundaries = boundaries**2
  1605. N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]
  1606. # Infer CB coordinates.
  1607. b = CA - N
  1608. c = C - CA
  1609. a = b.cross(c, dim=-1)
  1610. CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
  1611. dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True)
  1612. bins = torch.sum(dists > boundaries, dim=-1) # [..., L, L]
  1613. return bins
# TODO: Document, in the class docstring, the methods that convert the outputs to PDB format or otherwise prepare
# them for downstream use.
  1616. @auto_docstring(
  1617. custom_intro="""
  1618. ESMForProteinFolding is the HuggingFace port of the original ESMFold model. It consists of an ESM-2 "stem" followed
  1619. by a protein folding "head", although unlike most other output heads, this "head" is similar in size and runtime to
  1620. the rest of the model combined! It outputs a dictionary containing predicted structural information about the input
  1621. protein(s).
  1622. """
  1623. )
  1624. class EsmForProteinFolding(EsmPreTrainedModel):
  1625. _no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"]
  1626. _supports_flash_attn = False
  1627. _supports_sdpa = False
  1628. _supports_attention_backend = False
  1629. _can_record_outputs = None
  1630. def __init__(self, config):
  1631. super().__init__(config)
  1632. self.config = config
  1633. self.distogram_bins = 64
  1634. self.esm = EsmModel(config, add_pooling_layer=False)
  1635. self.esm.requires_grad_(False)
  1636. if self.config.esmfold_config.fp16_esm:
  1637. self.esm.half()
  1638. self.esm_feats = self.config.hidden_size
  1639. self.esm_attns = self.config.num_hidden_layers * self.config.num_attention_heads
  1640. self.esm_layers = self.config.num_hidden_layers
  1641. self.register_buffer("af2_to_esm", self._af2_to_esm_from_vocab_list(config.vocab_list))
  1642. self.esm_s_combine = nn.Parameter(torch.zeros(self.esm_layers + 1))
  1643. trunk_config = self.config.esmfold_config.trunk
  1644. c_s = trunk_config.sequence_state_dim
  1645. c_z = trunk_config.pairwise_state_dim
  1646. self.esm_s_mlp = nn.Sequential(
  1647. LayerNorm(self.esm_feats),
  1648. nn.Linear(self.esm_feats, c_s),
  1649. nn.ReLU(),
  1650. nn.Linear(c_s, c_s),
  1651. )
  1652. # 0 is padding, N is unknown residues, N + 1 is mask.
  1653. self.n_tokens_embed = residue_constants.restype_num + 3
  1654. self.pad_idx = 0
  1655. self.unk_idx = self.n_tokens_embed - 2
  1656. self.mask_idx = self.n_tokens_embed - 1
  1657. self.esm_dict_cls_idx = self.config.vocab_list.index("<cls>")
  1658. self.esm_dict_mask_idx = self.config.vocab_list.index("<mask>")
  1659. self.esm_dict_eos_idx = self.config.vocab_list.index("<eos>")
  1660. self.esm_dict_padding_idx = self.config.vocab_list.index("<pad>")
  1661. if self.config.esmfold_config.embed_aa:
  1662. self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0)
  1663. self.trunk = EsmFoldingTrunk(trunk_config)
  1664. self.distogram_head = nn.Linear(c_z, self.distogram_bins)
  1665. self.ptm_head = nn.Linear(c_z, self.distogram_bins)
  1666. self.lm_head = nn.Linear(c_s, self.n_tokens_embed)
  1667. self.lddt_bins = 50
  1668. structure_module_config = trunk_config.structure_module
  1669. self.lddt_head = nn.Sequential(
  1670. nn.LayerNorm(structure_module_config.sequence_dim),
  1671. nn.Linear(structure_module_config.sequence_dim, self.config.esmfold_config.lddt_head_hid_dim),
  1672. nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, self.config.esmfold_config.lddt_head_hid_dim),
  1673. nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, 37 * self.lddt_bins),
  1674. )
  1675. @staticmethod
  1676. def _af2_to_esm_from_vocab_list(vocab_list: list[str]) -> torch.Tensor:
  1677. # Remember that t is shifted from residue_constants by 1 (0 is padding).
  1678. esm_reorder = [vocab_list.index("<pad>")] + [vocab_list.index(v) for v in residue_constants.restypes_with_x]
  1679. return torch.tensor(esm_reorder)
  1680. @auto_docstring
  1681. def forward(
  1682. self,
  1683. input_ids: torch.Tensor,
  1684. attention_mask: Optional[torch.Tensor] = None,
  1685. position_ids: Optional[torch.Tensor] = None,
  1686. masking_pattern: Optional[torch.Tensor] = None,
  1687. num_recycles: Optional[int] = None,
  1688. output_hidden_states: Optional[bool] = False,
  1689. ) -> EsmForProteinFoldingOutput:
  1690. r"""
  1691. masking_pattern (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1692. Locations of tokens to mask during training as a form of regularization. Mask values selected in `[0, 1]`.
  1693. num_recycles (`int`, *optional*, defaults to `None`):
  1694. Number of times to recycle the input sequence. If `None`, defaults to `config.num_recycles`. "Recycling"
  1695. consists of passing the output of the folding trunk back in as input to the trunk. During training, the
  1696. number of recycles should vary with each batch, to ensure that the model learns to output valid predictions
  1697. after each recycle. During inference, num_recycles should be set to the highest value that the model was
  1698. trained with for maximum accuracy. Accordingly, when this value is set to `None`, config.max_recycles is
  1699. used.
  1700. Example:
  1701. ```python
  1702. >>> from transformers import AutoTokenizer, EsmForProteinFolding
  1703. >>> model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
  1704. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
  1705. >>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False) # A tiny random peptide
  1706. >>> outputs = model(**inputs)
  1707. >>> folded_positions = outputs.positions
  1708. ```
  1709. """
  1710. cfg = self.config.esmfold_config
  1711. aa = input_ids # B x L
  1712. B = aa.shape[0]
  1713. L = aa.shape[1]
  1714. device = input_ids.device
  1715. if attention_mask is None:
  1716. attention_mask = torch.ones_like(aa, device=device)
  1717. if position_ids is None:
  1718. position_ids = torch.arange(L, device=device).expand_as(input_ids)
  1719. # === ESM ===
  1720. esmaa = self.af2_idx_to_esm_idx(aa, attention_mask)
  1721. if masking_pattern is not None:
  1722. masked_aa, esmaa, mlm_targets = self.bert_mask(aa, esmaa, attention_mask, masking_pattern)
  1723. else:
  1724. masked_aa = aa
  1725. mlm_targets = None
  1726. # We get sequence and pair representations from whatever version of ESM /
  1727. # configuration we are using. The sequence representation esm_s is always
  1728. # present. The pair embedding esm_z may be present depending on the
  1729. # configuration of the model. If esm_z is not used by the model then it
  1730. # is returned as None here.
  1731. esm_s = self.compute_language_model_representations(esmaa)
  1732. # Convert esm_s and esm_z, if present, to the precision used by the trunk and
  1733. # the structure module. These tensors may be a lower precision if, for example,
  1734. # we're running the language model in fp16 precision.
  1735. esm_s = esm_s.to(self.esm_s_combine.dtype)
  1736. if cfg.esm_ablate_sequence:
  1737. esm_s = esm_s * 0
  1738. esm_s = esm_s.detach()
  1739. # === preprocessing ===
  1740. esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)
  1741. s_s_0 = self.esm_s_mlp(esm_s)
  1742. s_z_0 = s_s_0.new_zeros(B, L, L, cfg.trunk.pairwise_state_dim)
  1743. if self.config.esmfold_config.embed_aa:
  1744. s_s_0 += self.embedding(masked_aa)
  1745. structure: dict = self.trunk(s_s_0, s_z_0, aa, position_ids, attention_mask, no_recycles=num_recycles)
  1746. # Documenting what we expect:
  1747. structure = {
  1748. k: v
  1749. for k, v in structure.items()
  1750. if k
  1751. in [
  1752. "s_z",
  1753. "s_s",
  1754. "frames",
  1755. "sidechain_frames",
  1756. "unnormalized_angles",
  1757. "angles",
  1758. "positions",
  1759. "states",
  1760. ]
  1761. }
  1762. # Add BERT mask for the loss to use, if available.
  1763. if mlm_targets:
  1764. structure["mlm_targets"] = mlm_targets
  1765. disto_logits = self.distogram_head(structure["s_z"])
  1766. disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2
  1767. structure["distogram_logits"] = disto_logits
  1768. lm_logits = self.lm_head(structure["s_s"])
  1769. structure["lm_logits"] = lm_logits
  1770. structure["aatype"] = aa
  1771. make_atom14_masks(structure)
  1772. # Of course, this doesn't respect the true mask because it doesn't know about it...
  1773. # We're not going to properly mask change of index tensors:
  1774. # "residx_atom14_to_atom37",
  1775. # "residx_atom37_to_atom14",
  1776. for k in [
  1777. "atom14_atom_exists",
  1778. "atom37_atom_exists",
  1779. ]:
  1780. structure[k] *= attention_mask.unsqueeze(-1)
  1781. structure["residue_index"] = position_ids
  1782. lddt_head = self.lddt_head(structure["states"]).reshape(structure["states"].shape[0], B, L, -1, self.lddt_bins)
  1783. structure["lddt_head"] = lddt_head
  1784. plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins)
  1785. structure["plddt"] = plddt
  1786. ptm_logits = self.ptm_head(structure["s_z"])
  1787. structure["ptm_logits"] = ptm_logits
  1788. structure["ptm"] = compute_tm(ptm_logits, max_bin=31, no_bins=self.distogram_bins)
  1789. structure.update(compute_predicted_aligned_error(ptm_logits, max_bin=31, no_bins=self.distogram_bins))
  1790. return EsmForProteinFoldingOutput(**structure)
  1791. def af2_idx_to_esm_idx(self, aa, mask):
  1792. # avoid indexing on different devices
  1793. if self.af2_to_esm.device != aa.device:
  1794. self.af2_to_esm = self.af2_to_esm.to(aa.device)
  1795. aa = (aa + 1).masked_fill(mask != 1, 0)
  1796. return self.af2_to_esm[aa]
  1797. def compute_language_model_representations(self, esmaa: torch.Tensor) -> torch.Tensor:
  1798. device = next(self.parameters()).device
  1799. B, L = esmaa.shape # B = batch size, L = sequence length.
  1800. if self.config.esmfold_config.bypass_lm:
  1801. esm_s = torch.zeros(B, L, self.esm_s_combine.size[0], -1, self.esm_feats, device=device)
  1802. return esm_s
  1803. bosi, eosi = self.esm_dict_cls_idx, self.esm_dict_eos_idx
  1804. bos = esmaa.new_full((B, 1), bosi)
  1805. eos = esmaa.new_full((B, 1), self.esm_dict_padding_idx)
  1806. esmaa = torch.cat([bos, esmaa, eos], dim=1)
  1807. # Use the first padding index as eos during inference.
  1808. esmaa[range(B), (esmaa != 1).sum(1)] = eosi
  1809. # _, esm_z, esm_s = self.esm(esmaa, return_pairs=self.config.esmfold_config.use_esm_attn_map)
  1810. # Because we do not support use_esm_attn_map in the HF port as it is not used in any public models,
  1811. # esm_z is always None
  1812. esm_hidden_states = self.esm(esmaa, attention_mask=esmaa != 1, output_hidden_states=True)["hidden_states"]
  1813. esm_s = torch.stack(esm_hidden_states, dim=2)
  1814. esm_s = esm_s[:, 1:-1] # B, L, nLayers, C
  1815. return esm_s
  1816. def bert_mask(self, aa, esmaa, mask, pattern):
  1817. new_aa = aa.clone()
  1818. target = aa.clone()
  1819. new_esmaa = esmaa.clone()
  1820. new_aa[pattern == 1] = self.mask_idx
  1821. target[pattern != 1] = 0
  1822. new_esmaa[pattern == 1] = self.esm_dict_mask_idx
  1823. return new_aa, new_esmaa, target
  1824. @torch.no_grad()
  1825. def infer(
  1826. self,
  1827. seqs: Union[str, list[str]],
  1828. position_ids=None,
  1829. ):
  1830. if isinstance(seqs, str):
  1831. lst = [seqs]
  1832. else:
  1833. lst = seqs
  1834. # Returns the raw outputs of the model given an input sequence.
  1835. device = next(self.parameters()).device
  1836. aatype = collate_dense_tensors(
  1837. [
  1838. torch.from_numpy(
  1839. residue_constants.sequence_to_onehot(
  1840. sequence=seq,
  1841. mapping=residue_constants.restype_order_with_x,
  1842. map_unknown_to_x=True,
  1843. )
  1844. )
  1845. .to(device)
  1846. .argmax(dim=1)
  1847. for seq in lst
  1848. ]
  1849. ) # B=1 x L
  1850. mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst])
  1851. position_ids = (
  1852. torch.arange(aatype.shape[1], device=device).expand(len(lst), -1)
  1853. if position_ids is None
  1854. else position_ids.to(device)
  1855. )
  1856. if position_ids.ndim == 1:
  1857. position_ids = position_ids.unsqueeze(0)
  1858. return self.forward(
  1859. aatype,
  1860. mask,
  1861. position_ids=position_ids,
  1862. )
  1863. @staticmethod
  1864. def output_to_pdb(output: dict) -> list[str]:
  1865. """Returns the pbd (file) string from the model given the model output."""
  1866. output = {k: v.to("cpu").numpy() for k, v in output.items()}
  1867. pdbs = []
  1868. final_atom_positions = atom14_to_atom37(output["positions"][-1], output)
  1869. final_atom_mask = output["atom37_atom_exists"]
  1870. for i in range(output["aatype"].shape[0]):
  1871. aa = output["aatype"][i]
  1872. pred_pos = final_atom_positions[i]
  1873. mask = final_atom_mask[i]
  1874. resid = output["residue_index"][i] + 1
  1875. pred = OFProtein(
  1876. aatype=aa,
  1877. atom_positions=pred_pos,
  1878. atom_mask=mask,
  1879. residue_index=resid,
  1880. b_factors=output["plddt"][i],
  1881. )
  1882. pdbs.append(to_pdb(pred))
  1883. return pdbs
  1884. def infer_pdb(self, seqs, *args, **kwargs) -> str:
  1885. """Returns the pdb (file) string from the model given an input sequence."""
  1886. assert isinstance(seqs, str)
  1887. output = self.infer(seqs, *args, **kwargs)
  1888. return self.output_to_pdb(output)[0]
  1889. def infer_pdbs(self, seqs: list[str], *args, **kwargs) -> list[str]:
  1890. """Returns the pdb (file) string from the model given an input sequence."""
  1891. output = self.infer(seqs, *args, **kwargs)
  1892. return self.output_to_pdb(output)
  1893. __all__ = ["EsmForProteinFolding", "EsmFoldPreTrainedModel"]