modeling_esm.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058
  1. # coding=utf-8
  2. # Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
  3. # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """PyTorch ESM model."""
  17. import math
  18. from typing import Callable, Optional, Union
  19. import torch
  20. from torch import nn
  21. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import (
  24. BaseModelOutputWithCrossAttentions,
  25. BaseModelOutputWithPoolingAndCrossAttentions,
  26. MaskedLMOutput,
  27. SequenceClassifierOutput,
  28. TokenClassifierOutput,
  29. )
  30. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  31. from ...processing_utils import Unpack
  32. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  33. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  34. from ...utils.generic import OutputRecorder, check_model_inputs
  35. from .configuration_esm import EsmConfig
  36. logger = logging.get_logger(__name__)
  37. def rotate_half(x):
  38. x1, x2 = x.chunk(2, dim=-1)
  39. return torch.cat((-x2, x1), dim=-1)
  40. def apply_rotary_pos_emb(x, cos, sin):
  41. cos = cos[:, :, : x.shape[-2], :]
  42. sin = sin[:, :, : x.shape[-2], :]
  43. return (x * cos) + (rotate_half(x) * sin)
  44. def gelu(x):
  45. """
  46. This is the gelu implementation from the original ESM repo. Using F.gelu yields subtly wrong results.
  47. """
  48. return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
  49. def symmetrize(x):
  50. "Make layer symmetric in final two dimensions, used for contact prediction."
  51. return x + x.transpose(-1, -2)
  52. def average_product_correct(x):
  53. "Perform average product correct, used for contact prediction."
  54. a1 = x.sum(-1, keepdims=True)
  55. a2 = x.sum(-2, keepdims=True)
  56. a12 = x.sum((-1, -2), keepdims=True)
  57. avg = a1 * a2
  58. avg.div_(a12) # in-place to reduce memory
  59. normalized = x - avg
  60. return normalized
  61. class RotaryEmbedding(torch.nn.Module):
  62. """
  63. Rotary position embeddings based on those in
  64. [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
  65. matrices which depend on their relative positions.
  66. """
  67. inv_freq: torch.Tensor # fix linting for `register_buffer`
  68. def __init__(self, dim: int):
  69. super().__init__()
  70. # Generate and save the inverse frequency buffer (non trainable)
  71. inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
  72. self.register_buffer("inv_freq", inv_freq)
  73. self._seq_len_cached = None
  74. self._cos_cached = None
  75. self._sin_cached = None
  76. def _update_cos_sin_tables(self, x, seq_dimension=2):
  77. seq_len = x.shape[seq_dimension]
  78. # Reset the tables if the sequence length has changed,
  79. # or if we're on a new device (possibly due to tracing for instance)
  80. if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
  81. self._seq_len_cached = seq_len
  82. t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
  83. freqs = torch.outer(t, self.inv_freq)
  84. emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
  85. self._cos_cached = emb.cos()[None, None, :, :]
  86. self._sin_cached = emb.sin()[None, None, :, :]
  87. return self._cos_cached, self._sin_cached
  88. def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  89. self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
  90. return (
  91. apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype),
  92. apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype),
  93. )
  94. class EsmContactPredictionHead(nn.Module):
  95. """Performs symmetrization, apc, and computes a logistic regression on the output features"""
  96. def __init__(
  97. self,
  98. in_features: int,
  99. bias=True,
  100. eos_idx: int = 2,
  101. ):
  102. super().__init__()
  103. self.in_features = in_features
  104. self.eos_idx = eos_idx
  105. self.regression = nn.Linear(in_features, 1, bias)
  106. self.activation = nn.Sigmoid()
  107. def forward(self, tokens, attentions):
  108. # remove eos token attentions
  109. eos_mask = tokens.ne(self.eos_idx).to(attentions)
  110. eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
  111. attentions = attentions * eos_mask[:, None, None, :, :]
  112. attentions = attentions[..., :-1, :-1]
  113. # remove cls token attentions
  114. attentions = attentions[..., 1:, 1:]
  115. batch_size, layers, heads, seqlen, _ = attentions.size()
  116. attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
  117. # features: batch x channels x tokens x tokens (symmetric)
  118. attentions = attentions.to(
  119. self.regression.weight.device
  120. ) # attentions always float32, may need to convert to float16
  121. attentions = average_product_correct(symmetrize(attentions))
  122. attentions = attentions.permute(0, 2, 3, 1)
  123. return self.activation(self.regression(attentions).squeeze(3))
  124. class EsmEmbeddings(nn.Module):
  125. """
  126. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
  127. """
  128. def __init__(self, config):
  129. super().__init__()
  130. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  131. if config.emb_layer_norm_before:
  132. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  133. else:
  134. self.layer_norm = None
  135. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  136. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  137. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  138. self.register_buffer(
  139. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  140. )
  141. self.padding_idx = config.pad_token_id
  142. if self.position_embedding_type == "absolute":
  143. self.position_embeddings = nn.Embedding(
  144. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  145. )
  146. self.token_dropout = config.token_dropout
  147. self.mask_token_id = config.mask_token_id
  148. def forward(
  149. self,
  150. input_ids=None,
  151. attention_mask=None,
  152. position_ids=None,
  153. inputs_embeds=None,
  154. ):
  155. if position_ids is None:
  156. if input_ids is not None:
  157. # Create the position ids from the input token ids. Any padded tokens remain padded.
  158. position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
  159. else:
  160. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  161. if inputs_embeds is None:
  162. inputs_embeds = self.word_embeddings(input_ids)
  163. # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an
  164. # embedding_scale factor here.
  165. embeddings = inputs_embeds
  166. # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
  167. # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however,
  168. # masked tokens are treated as if they were selected for input dropout and zeroed out.
  169. # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
  170. # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
  171. # This is analogous to the way that dropout layers scale down outputs during evaluation when not
  172. # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
  173. if self.token_dropout and input_ids is not None:
  174. embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
  175. mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs
  176. src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1]
  177. mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
  178. embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
  179. embeddings.dtype
  180. )
  181. if self.position_embedding_type == "absolute":
  182. position_embeddings = self.position_embeddings(position_ids)
  183. embeddings = embeddings + position_embeddings
  184. if self.layer_norm is not None:
  185. embeddings = self.layer_norm(embeddings)
  186. if attention_mask is not None:
  187. embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
  188. # Matt: I think this line was copied incorrectly from BERT, disabling it for now.
  189. # embeddings = self.dropout(embeddings)
  190. return embeddings
  191. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  192. """
  193. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  194. Args:
  195. inputs_embeds: torch.Tensor
  196. Returns: torch.Tensor
  197. """
  198. input_shape = inputs_embeds.size()[:-1]
  199. sequence_length = input_shape[1]
  200. position_ids = torch.arange(
  201. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  202. )
  203. return position_ids.unsqueeze(0).expand(input_shape)
  204. def eager_attention_forward(
  205. module: nn.Module,
  206. query: torch.Tensor,
  207. key: torch.Tensor,
  208. value: torch.Tensor,
  209. attention_mask: Optional[torch.Tensor],
  210. scaling: float,
  211. dropout: float = 0.0,
  212. head_mask: Optional[torch.Tensor] = None,
  213. **kwargs: Unpack[TransformersKwargs],
  214. ):
  215. # ESM applies relative position embeddings and we don't copy from Llama
  216. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  217. if hasattr(module, "position_embedding_type") and module.position_embedding_type in [
  218. "relative_key",
  219. "relative_key_query",
  220. ]:
  221. seq_length = query.shape[2]
  222. position_ids_l = torch.arange(seq_length, dtype=torch.long, device=attn_weights.device).view(-1, 1)
  223. position_ids_r = torch.arange(seq_length, dtype=torch.long, device=attn_weights.device).view(1, -1)
  224. distance = position_ids_l - position_ids_r
  225. positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
  226. positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility
  227. if module.position_embedding_type == "relative_key":
  228. relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
  229. elif module.position_embedding_type == "relative_key_query":
  230. relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
  231. relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
  232. relative_position_scores = relative_position_scores_query + relative_position_scores_key
  233. attn_weights = attn_weights + relative_position_scores
  234. if attention_mask is not None:
  235. causal_mask = attention_mask[:, :, :, : key.shape[-2]]
  236. attn_weights = attn_weights + causal_mask
  237. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  238. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  239. if head_mask is not None:
  240. attn_weights = attn_weights * head_mask
  241. attn_output = torch.matmul(attn_weights, value)
  242. attn_output = attn_output.transpose(1, 2).contiguous()
  243. return attn_output, attn_weights
  244. class EsmSelfAttention(nn.Module):
  245. def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cross_attention=False):
  246. super().__init__()
  247. self.config = config
  248. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  249. raise ValueError(
  250. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  251. f"heads ({config.num_attention_heads})"
  252. )
  253. self.num_attention_heads = config.num_attention_heads
  254. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  255. self.all_head_size = self.num_attention_heads * self.attention_head_size
  256. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  257. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  258. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  259. self.dropout = config.attention_probs_dropout_prob
  260. self.position_embedding_type = position_embedding_type or getattr(
  261. config, "position_embedding_type", "absolute"
  262. )
  263. self.rotary_embeddings = None
  264. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  265. self.max_position_embeddings = config.max_position_embeddings
  266. self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
  267. elif self.position_embedding_type == "rotary":
  268. self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
  269. self.scaling = 1.0 # For BC we apply scaling before RoPE
  270. self.is_decoder = config.is_decoder
  271. self.layer_idx = layer_idx
  272. self.is_causal = self.is_decoder and not is_cross_attention
  273. def forward(
  274. self,
  275. hidden_states: torch.Tensor,
  276. attention_mask: Optional[torch.FloatTensor] = None,
  277. head_mask: Optional[torch.FloatTensor] = None,
  278. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  279. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  280. **kwargs: Unpack[TransformersKwargs],
  281. ) -> tuple[torch.Tensor]:
  282. batch_size, seq_length = hidden_states.shape[:-1]
  283. hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
  284. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  285. is_cross_attention = encoder_hidden_states is not None
  286. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  287. attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
  288. key_layer = self.key(current_states).view(hidden_shape).transpose(1, 2)
  289. value_layer = self.value(current_states).view(hidden_shape).transpose(1, 2)
  290. # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
  291. # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
  292. # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
  293. # ESM code and fix rotary embeddings.
  294. query_layer = query_layer * self.attention_head_size**-0.5
  295. if self.position_embedding_type == "rotary":
  296. query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
  297. attention_interface: Callable = eager_attention_forward
  298. if self.config._attn_implementation != "eager":
  299. if self.position_embedding_type in ["relative_key", "relative_key_query"]:
  300. raise ValueError(
  301. f"ESM {self.config._attn_implementation} attention does not support {self.position_embedding_type} embeddings. "
  302. "Set attention explicitly to 'eager' with `model.set_attn_implementation('eager')`"
  303. )
  304. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  305. attn_output, attn_weights = attention_interface(
  306. self,
  307. query_layer,
  308. key_layer,
  309. value_layer,
  310. attention_mask,
  311. dropout=0.0 if not self.training else self.dropout,
  312. scaling=self.scaling,
  313. head_mask=head_mask,
  314. **kwargs,
  315. )
  316. attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
  317. return attn_output, attn_weights
  318. class EsmSelfOutput(nn.Module):
  319. def __init__(self, config):
  320. super().__init__()
  321. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  322. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  323. def forward(self, hidden_states, input_tensor):
  324. hidden_states = self.dense(hidden_states)
  325. hidden_states = self.dropout(hidden_states)
  326. hidden_states = hidden_states + input_tensor
  327. return hidden_states
  328. class EsmAttention(nn.Module):
  329. def __init__(self, config, layer_idx=None, is_cross_attention=False):
  330. super().__init__()
  331. self.self = EsmSelfAttention(config, layer_idx=layer_idx, is_cross_attention=is_cross_attention)
  332. self.output = EsmSelfOutput(config)
  333. self.pruned_heads = set()
  334. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  335. def prune_heads(self, heads):
  336. if len(heads) == 0:
  337. return
  338. heads, index = find_pruneable_heads_and_indices(
  339. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  340. )
  341. # Prune linear layers
  342. self.self.query = prune_linear_layer(self.self.query, index)
  343. self.self.key = prune_linear_layer(self.self.key, index)
  344. self.self.value = prune_linear_layer(self.self.value, index)
  345. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  346. # Update hyper params and store pruned heads
  347. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  348. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  349. self.pruned_heads = self.pruned_heads.union(heads)
  350. def forward(
  351. self,
  352. hidden_states,
  353. attention_mask=None,
  354. head_mask=None,
  355. encoder_hidden_states=None,
  356. encoder_attention_mask=None,
  357. **kwargs: Unpack[TransformersKwargs],
  358. ):
  359. hidden_states_ln = self.LayerNorm(hidden_states)
  360. attn_output, _ = self.self(
  361. hidden_states_ln,
  362. attention_mask=attention_mask,
  363. head_mask=head_mask,
  364. encoder_hidden_states=encoder_hidden_states,
  365. encoder_attention_mask=encoder_attention_mask,
  366. **kwargs,
  367. )
  368. attn_output = self.output(attn_output, hidden_states)
  369. return attn_output
  370. class EsmIntermediate(nn.Module):
  371. def __init__(self, config):
  372. super().__init__()
  373. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  374. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  375. hidden_states = self.dense(hidden_states)
  376. hidden_states = gelu(hidden_states)
  377. return hidden_states
  378. class EsmOutput(nn.Module):
  379. def __init__(self, config):
  380. super().__init__()
  381. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  382. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  383. def forward(self, hidden_states, input_tensor):
  384. hidden_states = self.dense(hidden_states)
  385. hidden_states = self.dropout(hidden_states)
  386. hidden_states = hidden_states + input_tensor
  387. return hidden_states
  388. class EsmLayer(GradientCheckpointingLayer):
  389. def __init__(self, config):
  390. super().__init__()
  391. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  392. self.seq_len_dim = 1
  393. self.attention = EsmAttention(config)
  394. self.is_decoder = config.is_decoder
  395. self.add_cross_attention = config.add_cross_attention
  396. if self.add_cross_attention:
  397. if not self.is_decoder:
  398. raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
  399. self.crossattention = EsmAttention(config, is_cross_attention=True)
  400. self.intermediate = EsmIntermediate(config)
  401. self.output = EsmOutput(config)
  402. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  403. def forward(
  404. self,
  405. hidden_states,
  406. attention_mask=None,
  407. head_mask=None,
  408. encoder_hidden_states=None,
  409. encoder_attention_mask=None,
  410. **kwargs: Unpack[TransformersKwargs],
  411. ):
  412. attention_output = self.attention(
  413. hidden_states,
  414. attention_mask=attention_mask,
  415. head_mask=head_mask,
  416. **kwargs,
  417. )
  418. if self.is_decoder and encoder_hidden_states is not None:
  419. if not hasattr(self, "crossattention"):
  420. raise AttributeError(
  421. f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
  422. " with cross-attention layers by setting `config.add_cross_attention=True`"
  423. )
  424. attention_output = self.crossattention(
  425. attention_output,
  426. attention_mask=attention_mask,
  427. head_mask=head_mask,
  428. encoder_hidden_states=encoder_hidden_states,
  429. encoder_attention_mask=encoder_attention_mask,
  430. **kwargs,
  431. )
  432. layer_output = self.feed_forward_chunk(attention_output)
  433. return layer_output
  434. def feed_forward_chunk(self, attention_output):
  435. attention_output_ln = self.LayerNorm(attention_output)
  436. intermediate_output = self.intermediate(attention_output_ln)
  437. layer_output = self.output(intermediate_output, attention_output)
  438. return layer_output
  439. class EsmEncoder(nn.Module):
  440. def __init__(self, config):
  441. super().__init__()
  442. self.config = config
  443. self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
  444. self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  445. self.gradient_checkpointing = False
  446. @can_return_tuple
  447. def forward(
  448. self,
  449. hidden_states,
  450. attention_mask=None,
  451. head_mask=None,
  452. encoder_hidden_states=None,
  453. encoder_attention_mask=None,
  454. **kwargs: Unpack[TransformersKwargs],
  455. ):
  456. for i, layer_module in enumerate(self.layer):
  457. layer_head_mask = head_mask[i] if head_mask is not None else None
  458. hidden_states = layer_module(
  459. hidden_states,
  460. attention_mask=attention_mask,
  461. head_mask=layer_head_mask,
  462. encoder_hidden_states=encoder_hidden_states,
  463. encoder_attention_mask=encoder_attention_mask,
  464. **kwargs,
  465. )
  466. if self.emb_layer_norm_after:
  467. hidden_states = self.emb_layer_norm_after(hidden_states)
  468. return BaseModelOutputWithCrossAttentions(last_hidden_state=hidden_states)
  469. # Copied from transformers.models.bert.modeling_bert.BertPooler
class EsmPooler(nn.Module):
    # Pools a sequence of hidden states into one vector per example by passing the first position's hidden
    # state through a dense layer and a tanh.
    # NOTE(review): this class sits under a "Copied from ... BertPooler" marker; keep the code in sync with it.
    def __init__(self, config):
        super().__init__()
        # Square projection: hidden_size -> hidden_size.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        # Indexing with [:, 0] drops dim 1, so the result has one row per entry along dim 0.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
@auto_docstring
class EsmPreTrainedModel(PreTrainedModel):
    # Common base class for the ESM models in this file: declares class-level flags read by the
    # `PreTrainedModel` machinery and provides the weight-initialization routine.
    config: EsmConfig
    base_model_prefix = "esm"
    supports_gradient_checkpointing = True
    accepts_loss_kwargs = False
    _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
    _keys_to_ignore_on_load_unexpected = ["position_embeddings.weight"]
    # Attention backends this model declares support for.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    # Maps each recordable output name to the module class it is captured from. For attentions the recorder
    # takes element 1 of `EsmSelfAttention`'s output, restricted by the name of the owning sub-layer.
    # NOTE(review): presumably consumed by the `check_model_inputs` decorator — confirm.
    _can_record_outputs = {
        "hidden_states": EsmLayer,
        "attentions": [OutputRecorder(EsmSelfAttention, index=1, layer_name="attention")],
        "cross_attentions": [
            OutputRecorder(EsmSelfAttention, index=1, layer_name="crossattention"),
        ],
    }

    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead
    def _init_weights(self, module):
        """Initialize the weights of a single module.

        Linear and Embedding weights are drawn from a normal distribution with standard deviation
        `config.initializer_range`; Linear biases, the Embedding padding row, and the `EsmLMHead` bias are
        zeroed; LayerNorm is reset to weight 1 and bias 0. Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # Zero the padding row again, since the normal init above overwrote it.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, EsmLMHead):
            module.bias.data.zero_()

    def get_output_embeddings(self):
        # NOTE: get_output_embeddings() must return None to prevent accidental weight tying.
        # See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400
        return None
  523. @auto_docstring
  524. class EsmModel(EsmPreTrainedModel):
  525. """
  526. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  527. cross-attention is added between the self-attention layers, following the architecture described in [Attention is
  528. all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  529. Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  530. To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
  531. to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
  532. `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
  533. """
  534. def __init__(self, config, add_pooling_layer=True):
  535. r"""
  536. add_pooling_layer (bool, *optional*, defaults to `True`):
  537. Whether to add a pooling layer
  538. """
  539. super().__init__(config)
  540. self.config = config
  541. self.embeddings = EsmEmbeddings(config)
  542. self.encoder = EsmEncoder(config)
  543. self.pooler = EsmPooler(config) if add_pooling_layer else None
  544. self.contact_head = EsmContactPredictionHead(
  545. in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
  546. )
  547. # Initialize weights and apply final processing
  548. self.post_init()
  549. def get_input_embeddings(self):
  550. return self.embeddings.word_embeddings
  551. def set_input_embeddings(self, value):
  552. self.embeddings.word_embeddings = value
  553. def _prune_heads(self, heads_to_prune):
  554. """
  555. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  556. class PreTrainedModel
  557. """
  558. for layer, heads in heads_to_prune.items():
  559. self.encoder.layer[layer].attention.prune_heads(heads)
  560. @check_model_inputs()
  561. @auto_docstring
  562. def forward(
  563. self,
  564. input_ids: Optional[torch.Tensor] = None,
  565. attention_mask: Optional[torch.Tensor] = None,
  566. position_ids: Optional[torch.Tensor] = None,
  567. head_mask: Optional[torch.Tensor] = None,
  568. inputs_embeds: Optional[torch.Tensor] = None,
  569. encoder_hidden_states: Optional[torch.Tensor] = None,
  570. encoder_attention_mask: Optional[torch.Tensor] = None,
  571. **kwargs: Unpack[TransformersKwargs],
  572. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
  573. r"""
  574. input_ids (`torch.LongTensor` of shape `((batch_size, sequence_length))`):
  575. Indices of input sequence tokens in the vocabulary.
  576. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  577. [`PreTrainedTokenizer.__call__`] for details.
  578. [What are input IDs?](../glossary#input-ids)
  579. position_ids (`torch.LongTensor` of shape `((batch_size, sequence_length))`, *optional*):
  580. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  581. config.max_position_embeddings - 1]`.
  582. [What are position IDs?](../glossary#position-ids)
  583. inputs_embeds (`torch.FloatTensor` of shape `((batch_size, sequence_length), hidden_size)`, *optional*):
  584. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  585. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  586. model's internal embedding lookup matrix.
  587. """
  588. if (input_ids is None) ^ (inputs_embeds is not None):
  589. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  590. if inputs_embeds is None:
  591. inputs_embeds = self.embeddings(
  592. input_ids=input_ids,
  593. position_ids=position_ids,
  594. )
  595. if self.config._attn_implementation != "flash_attention_2":
  596. batch_size, seq_length = inputs_embeds.shape[:-1]
  597. if attention_mask is None:
  598. attention_mask = torch.ones(((batch_size, seq_length)), device=inputs_embeds.device)
  599. attention_mask: torch.Tensor = self.get_extended_attention_mask(
  600. attention_mask, input_shape=(batch_size, seq_length)
  601. )
  602. # If a 2D or 3D attention mask is provided for the cross-attention
  603. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  604. if self.config.is_decoder and encoder_hidden_states is not None:
  605. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  606. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  607. if encoder_attention_mask is None:
  608. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
  609. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  610. else:
  611. encoder_extended_attention_mask = None
  612. # Prepare head mask if needed
  613. # 1.0 in head_mask indicate we keep the head
  614. # attention_probs has shape bsz x n_heads x N x N
  615. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  616. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  617. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  618. encoder_outputs = self.encoder(
  619. inputs_embeds,
  620. attention_mask=attention_mask,
  621. head_mask=head_mask,
  622. encoder_hidden_states=encoder_hidden_states,
  623. encoder_attention_mask=encoder_extended_attention_mask,
  624. **kwargs,
  625. )
  626. sequence_output = encoder_outputs[0]
  627. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  628. return BaseModelOutputWithPoolingAndCrossAttentions(
  629. last_hidden_state=sequence_output,
  630. pooler_output=pooled_output,
  631. )
  632. def predict_contacts(self, tokens, attention_mask):
  633. attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
  634. attns = torch.stack(attns, dim=1) # Matches the original model layout
  635. # In the original model, attentions for padding tokens are completely zeroed out.
  636. # This makes no difference most of the time because the other tokens won't attend to them,
  637. # but it does for the contact prediction task, which takes attentions as input,
  638. # so we have to mimic that here.
  639. attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
  640. attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
  641. return self.contact_head(tokens, attns)
  642. @auto_docstring
  643. class EsmForMaskedLM(EsmPreTrainedModel):
  644. _tied_weights_keys = ["lm_head.decoder.weight"]
  645. def __init__(self, config):
  646. super().__init__(config)
  647. if config.is_decoder:
  648. logger.warning(
  649. "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for "
  650. "bi-directional self-attention."
  651. )
  652. self.esm = EsmModel(config, add_pooling_layer=False)
  653. self.lm_head = EsmLMHead(config)
  654. self.init_weights()
  655. self.post_init()
  656. def get_output_embeddings(self):
  657. return self.lm_head.decoder
  658. def set_output_embeddings(self, new_embeddings):
  659. self.lm_head.decoder = new_embeddings
  660. @can_return_tuple
  661. @auto_docstring
  662. def forward(
  663. self,
  664. input_ids: Optional[torch.LongTensor] = None,
  665. attention_mask: Optional[torch.Tensor] = None,
  666. position_ids: Optional[torch.LongTensor] = None,
  667. head_mask: Optional[torch.Tensor] = None,
  668. inputs_embeds: Optional[torch.FloatTensor] = None,
  669. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  670. encoder_attention_mask: Optional[torch.Tensor] = None,
  671. labels: Optional[torch.LongTensor] = None,
  672. **kwargs: Unpack[TransformersKwargs],
  673. ) -> Union[tuple, MaskedLMOutput]:
  674. r"""
  675. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  676. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  677. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  678. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  679. """
  680. outputs = self.esm(
  681. input_ids,
  682. attention_mask=attention_mask,
  683. position_ids=position_ids,
  684. head_mask=head_mask,
  685. inputs_embeds=inputs_embeds,
  686. encoder_hidden_states=encoder_hidden_states,
  687. encoder_attention_mask=encoder_attention_mask,
  688. **kwargs,
  689. )
  690. sequence_output = outputs[0]
  691. prediction_scores = self.lm_head(sequence_output)
  692. masked_lm_loss = None
  693. if labels is not None:
  694. loss_fct = CrossEntropyLoss()
  695. labels = labels.to(prediction_scores.device)
  696. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  697. return MaskedLMOutput(
  698. loss=masked_lm_loss,
  699. logits=prediction_scores,
  700. hidden_states=outputs.hidden_states,
  701. attentions=outputs.attentions,
  702. )
  703. def predict_contacts(self, tokens, attention_mask):
  704. return self.esm.predict_contacts(tokens, attention_mask=attention_mask)
  705. class EsmLMHead(nn.Module):
  706. """ESM Head for masked language modeling."""
  707. def __init__(self, config):
  708. super().__init__()
  709. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  710. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  711. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  712. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  713. def forward(self, features, **kwargs):
  714. x = self.dense(features)
  715. x = gelu(x)
  716. x = self.layer_norm(x)
  717. # project back to size of vocabulary with bias
  718. x = self.decoder(x) + self.bias
  719. return x
  720. @auto_docstring(
  721. custom_intro="""
  722. ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  723. output) e.g. for GLUE tasks.
  724. """
  725. )
  726. class EsmForSequenceClassification(EsmPreTrainedModel):
  727. def __init__(self, config):
  728. super().__init__(config)
  729. self.num_labels = config.num_labels
  730. self.config = config
  731. self.esm = EsmModel(config, add_pooling_layer=False)
  732. self.classifier = EsmClassificationHead(config)
  733. self.init_weights()
  734. self.post_init()
  735. @can_return_tuple
  736. @auto_docstring
  737. def forward(
  738. self,
  739. input_ids: Optional[torch.LongTensor] = None,
  740. attention_mask: Optional[torch.Tensor] = None,
  741. position_ids: Optional[torch.LongTensor] = None,
  742. head_mask: Optional[torch.Tensor] = None,
  743. inputs_embeds: Optional[torch.FloatTensor] = None,
  744. labels: Optional[torch.LongTensor] = None,
  745. **kwargs: Unpack[TransformersKwargs],
  746. ) -> Union[tuple, SequenceClassifierOutput]:
  747. r"""
  748. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  749. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  750. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  751. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  752. """
  753. outputs = self.esm(
  754. input_ids,
  755. attention_mask=attention_mask,
  756. position_ids=position_ids,
  757. head_mask=head_mask,
  758. inputs_embeds=inputs_embeds,
  759. **kwargs,
  760. )
  761. sequence_output = outputs[0]
  762. logits = self.classifier(sequence_output)
  763. loss = None
  764. if labels is not None:
  765. labels = labels.to(logits.device)
  766. if self.config.problem_type is None:
  767. if self.num_labels == 1:
  768. self.config.problem_type = "regression"
  769. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  770. self.config.problem_type = "single_label_classification"
  771. else:
  772. self.config.problem_type = "multi_label_classification"
  773. if self.config.problem_type == "regression":
  774. loss_fct = MSELoss()
  775. if self.num_labels == 1:
  776. loss = loss_fct(logits.squeeze(), labels.squeeze())
  777. else:
  778. loss = loss_fct(logits, labels)
  779. elif self.config.problem_type == "single_label_classification":
  780. loss_fct = CrossEntropyLoss()
  781. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  782. elif self.config.problem_type == "multi_label_classification":
  783. loss_fct = BCEWithLogitsLoss()
  784. loss = loss_fct(logits, labels)
  785. return SequenceClassifierOutput(
  786. loss=loss,
  787. logits=logits,
  788. hidden_states=outputs.hidden_states,
  789. attentions=outputs.attentions,
  790. )
  791. @auto_docstring
  792. class EsmForTokenClassification(EsmPreTrainedModel):
  793. def __init__(self, config):
  794. super().__init__(config)
  795. self.num_labels = config.num_labels
  796. self.esm = EsmModel(config, add_pooling_layer=False)
  797. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  798. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  799. self.init_weights()
  800. self.post_init()
  801. @can_return_tuple
  802. @auto_docstring
  803. def forward(
  804. self,
  805. input_ids: Optional[torch.LongTensor] = None,
  806. attention_mask: Optional[torch.Tensor] = None,
  807. position_ids: Optional[torch.LongTensor] = None,
  808. head_mask: Optional[torch.Tensor] = None,
  809. inputs_embeds: Optional[torch.FloatTensor] = None,
  810. labels: Optional[torch.LongTensor] = None,
  811. **kwargs: Unpack[TransformersKwargs],
  812. ) -> Union[tuple, TokenClassifierOutput]:
  813. r"""
  814. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  815. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  816. """
  817. outputs = self.esm(
  818. input_ids,
  819. attention_mask=attention_mask,
  820. position_ids=position_ids,
  821. head_mask=head_mask,
  822. inputs_embeds=inputs_embeds,
  823. **kwargs,
  824. )
  825. sequence_output = outputs[0]
  826. sequence_output = self.dropout(sequence_output)
  827. logits = self.classifier(sequence_output)
  828. loss = None
  829. if labels is not None:
  830. loss_fct = CrossEntropyLoss()
  831. labels = labels.to(logits.device)
  832. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  833. return TokenClassifierOutput(
  834. loss=loss,
  835. logits=logits,
  836. hidden_states=outputs.hidden_states,
  837. attentions=outputs.attentions,
  838. )
  839. class EsmClassificationHead(nn.Module):
  840. """Head for sentence-level classification tasks."""
  841. def __init__(self, config):
  842. super().__init__()
  843. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  844. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  845. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  846. def forward(self, features, **kwargs):
  847. x = features[:, 0, :] # take <s> token (equiv. to [CLS])
  848. x = self.dropout(x)
  849. x = self.dense(x)
  850. x = torch.tanh(x)
  851. x = self.dropout(x)
  852. x = self.out_proj(x)
  853. return x
  854. def create_position_ids_from_input_ids(input_ids, padding_idx):
  855. """
  856. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  857. are ignored. This is modified from fairseq's `utils.make_positions`.
  858. Args:
  859. x: torch.Tensor x:
  860. Returns: torch.Tensor
  861. """
  862. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  863. mask = input_ids.ne(padding_idx).int()
  864. incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
  865. return incremental_indices.long() + padding_idx
# Public names of this module (what `from <module> import *` exports).
__all__ = [
    "EsmForMaskedLM",
    "EsmForSequenceClassification",
    "EsmForTokenClassification",
    "EsmModel",
    "EsmPreTrainedModel",
]