modular_evolla.py 40 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023
  1. # coding=utf-8
  2. # Copyright 2025 Westlake Representational Learning Lab (Fajie Yuan Lab) team and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import warnings
  16. from dataclasses import dataclass
  17. from typing import Optional, Union
  18. import torch
  19. from torch import Tensor, nn
  20. from ...cache_utils import Cache, DynamicCache
  21. from ...generation import GenerationMixin
  22. from ...masking_utils import create_causal_mask
  23. from ...modeling_outputs import (
  24. BaseModelOutputWithPast,
  25. BaseModelOutputWithPoolingAndCrossAttentions,
  26. CausalLMOutputWithPast,
  27. ModelOutput,
  28. )
  29. from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel, get_parameter_dtype
  30. from ...utils import (
  31. auto_docstring,
  32. can_return_tuple,
  33. logging,
  34. )
  35. from ...utils.deprecation import deprecate_kwarg
  36. from ...utils.generic import OutputRecorder, check_model_inputs
  37. from ..esm.modeling_esm import (
  38. EsmAttention,
  39. EsmEmbeddings,
  40. EsmEncoder,
  41. EsmIntermediate,
  42. EsmLayer,
  43. EsmOutput,
  44. EsmPooler,
  45. EsmSelfAttention,
  46. EsmSelfOutput,
  47. )
  48. from ..llama.modeling_llama import (
  49. LlamaAttention,
  50. LlamaDecoderLayer,
  51. LlamaMLP,
  52. LlamaPreTrainedModel,
  53. LlamaRMSNorm,
  54. LlamaRotaryEmbedding,
  55. )
  56. from .configuration_evolla import EvollaConfig, SaProtConfig
  57. logger = logging.get_logger(__name__)
  58. class EvollaSaProtEmbeddings(EsmEmbeddings):
  59. def __init__(self, config):
  60. super().__init__(config)
  61. # remove the position_ids in EsmEmbeddings
  62. self.position_ids = None
  63. def rotate_half_esm(x):
  64. x1, x2 = x.chunk(2, dim=-1)
  65. return torch.cat((-x2, x1), dim=-1)
  66. def apply_rotary_pos_emb_esm(x, cos, sin):
  67. cos = cos[:, :, : x.shape[-2], :]
  68. sin = sin[:, :, : x.shape[-2], :]
  69. return (x * cos) + (rotate_half_esm(x) * sin)
  70. class EvollaSaProtRotaryEmbedding(nn.Module):
  71. """
  72. Rotary position embeddings based on those in
  73. [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
  74. matrices which depend on their relative positions.
  75. """
  76. inv_freq: torch.Tensor # fix linting for `register_buffer`
  77. def __init__(self, dim: int):
  78. super().__init__()
  79. # Generate and save the inverse frequency buffer (non trainable)
  80. inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
  81. self.register_buffer("inv_freq", inv_freq)
  82. self._seq_len_cached = None
  83. self._cos_cached = None
  84. self._sin_cached = None
  85. def _update_cos_sin_tables(self, x, seq_dimension=2):
  86. seq_len = x.shape[seq_dimension]
  87. # Reset the tables if the sequence length has changed,
  88. # or if we're on a new device (possibly due to tracing for instance)
  89. if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
  90. self._seq_len_cached = seq_len
  91. t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
  92. freqs = torch.outer(t, self.inv_freq)
  93. emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
  94. self._cos_cached = emb.cos()[None, None, :, :]
  95. self._sin_cached = emb.sin()[None, None, :, :]
  96. return self._cos_cached, self._sin_cached
  97. def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  98. self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
  99. return (
  100. apply_rotary_pos_emb_esm(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype),
  101. apply_rotary_pos_emb_esm(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype),
  102. )
  103. class EvollaSaProtSelfAttention(EsmSelfAttention):
  104. def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cross_attention=False):
  105. nn.Module.__init__(self)
  106. self.config = config
  107. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  108. raise ValueError(
  109. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  110. f"heads ({config.num_attention_heads})"
  111. )
  112. self.num_attention_heads = config.num_attention_heads
  113. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  114. self.all_head_size = self.num_attention_heads * self.attention_head_size
  115. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  116. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  117. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  118. self.dropout = config.attention_probs_dropout_prob
  119. self.position_embedding_type = position_embedding_type or getattr(
  120. config, "position_embedding_type", "absolute"
  121. )
  122. self.rotary_embeddings = None
  123. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  124. self.max_position_embeddings = config.max_position_embeddings
  125. self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
  126. elif self.position_embedding_type == "rotary":
  127. self.rotary_embeddings = EvollaSaProtRotaryEmbedding(dim=self.attention_head_size)
  128. self.is_decoder = config.is_decoder
  129. self.layer_idx = layer_idx
  130. self.scaling = 1.0
  131. self.is_causal = self.is_decoder and not is_cross_attention
  132. class EvollaSaProtSelfOutput(EsmSelfOutput):
  133. pass
  134. class EvollaSaProtAttention(EsmAttention):
  135. pass
  136. class EvollaSaProtIntermediate(EsmIntermediate):
  137. pass
  138. class EvollaSaProtOutput(EsmOutput):
  139. pass
  140. class EvollaSaProtLayer(EsmLayer):
  141. pass
  142. class EvollaSaProtEncoder(EsmEncoder):
  143. pass
  144. class EvollaSaProtPooler(EsmPooler):
  145. pass
  146. @auto_docstring
  147. class EvollaSaProtPreTrainedModel(PreTrainedModel):
  148. config: SaProtConfig
  149. _no_split_modules = ["EvollaSaProtLayer"]
  150. _supports_flash_attn = True
  151. _supports_sdpa = True
  152. _supports_attention_backend = True
  153. _can_record_outputs = {
  154. "hidden_states": EvollaSaProtLayer,
  155. "attentions": [OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="attention")],
  156. "cross_attentions": [
  157. OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="crossattention"),
  158. ],
  159. }
  160. def _init_weights(self, module):
  161. """Initialize the weights"""
  162. std = self.config.initializer_range
  163. if isinstance(module, nn.Linear):
  164. module.weight.data.normal_(mean=0.0, std=std)
  165. if module.bias is not None:
  166. module.bias.data.zero_()
  167. elif isinstance(module, nn.Embedding):
  168. module.weight.data.normal_(mean=0.0, std=std)
  169. if module.padding_idx is not None:
  170. module.weight.data[module.padding_idx].zero_()
  171. elif isinstance(module, nn.LayerNorm):
  172. module.bias.data.zero_()
  173. module.weight.data.fill_(1.0)
  174. class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
  175. def __init__(self, config: SaProtConfig):
  176. super().__init__(config)
  177. self.embeddings = EvollaSaProtEmbeddings(config)
  178. self.encoder = EvollaSaProtEncoder(config)
  179. def get_input_embeddings(self):
  180. return self.embeddings.word_embeddings
  181. def set_input_embeddings(self, value):
  182. self.embeddings.word_embeddings = value
  183. def _prune_heads(self, heads_to_prune):
  184. """
  185. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  186. class PreTrainedModel
  187. """
  188. for layer, heads in heads_to_prune.items():
  189. self.encoder.layer[layer].attention.prune_heads(heads)
  190. @check_model_inputs()
  191. def forward(
  192. self,
  193. input_ids: Optional[torch.Tensor],
  194. attention_mask: Optional[torch.Tensor] = None,
  195. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
  196. input_shape = input_ids.size()
  197. batch_size, seq_length = input_shape
  198. device = input_ids.device
  199. if attention_mask is None:
  200. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  201. inputs_embeds = self.embeddings(input_ids=input_ids, attention_mask=attention_mask)
  202. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
  203. encoder_outputs = self.encoder(inputs_embeds, attention_mask=extended_attention_mask)
  204. sequence_output = encoder_outputs[0]
  205. return BaseModelOutputWithPoolingAndCrossAttentions(
  206. last_hidden_state=sequence_output,
  207. hidden_states=encoder_outputs.hidden_states,
  208. attentions=encoder_outputs.attentions,
  209. cross_attentions=encoder_outputs.cross_attentions,
  210. )
  211. def get_extended_attention_mask(
  212. self,
  213. attention_mask: Tensor,
  214. input_shape: tuple[int],
  215. device: Optional[torch.device] = None,
  216. dtype: Optional[torch.dtype] = None,
  217. ) -> Tensor:
  218. """
  219. Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
  220. Arguments:
  221. attention_mask (`torch.Tensor`):
  222. Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
  223. input_shape (`Tuple[int]`):
  224. The shape of the input to the model.
  225. Returns:
  226. `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
  227. """
  228. if dtype is None:
  229. dtype = get_parameter_dtype(self)
  230. if not (attention_mask.dim() == 2 and self.config.is_decoder):
  231. # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
  232. if device is not None:
  233. warnings.warn(
  234. "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
  235. )
  236. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  237. # ourselves in which case we just need to make it broadcastable to all heads.
  238. if attention_mask.dim() == 3:
  239. extended_attention_mask = attention_mask[:, None, :, :]
  240. elif attention_mask.dim() == 2:
  241. # Provided a padding mask of dimensions [batch_size, seq_length]
  242. # - if the model is a decoder, apply a causal mask in addition to the padding mask
  243. # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
  244. if self.config.is_decoder:
  245. extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
  246. input_shape, attention_mask, device
  247. )
  248. else:
  249. extended_attention_mask = attention_mask[:, None, None, :]
  250. else:
  251. raise ValueError(
  252. f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
  253. )
  254. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  255. # masked positions, this operation will create a tensor which is 0.0 for
  256. # positions we want to attend and the dtype's smallest value for masked positions.
  257. # Since we are adding it to the raw scores before the softmax, this is
  258. # effectively the same as removing these entirely.
  259. extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
  260. extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
  261. return extended_attention_mask
  262. class EvollaSequenceCompressorAttention(nn.Module):
  263. def __init__(self, dim, dim_head=64, heads=8):
  264. super().__init__()
  265. self.scale = dim_head**-0.5
  266. self.heads = heads
  267. inner_dim = dim_head * heads
  268. self.norm_media = nn.LayerNorm(dim)
  269. self.norm_latents = nn.LayerNorm(dim)
  270. self.to_q = nn.Linear(dim, inner_dim, bias=False)
  271. self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
  272. self.to_out = nn.Linear(inner_dim, dim, bias=False)
  273. def forward(self, x, latents, mask):
  274. """
  275. Args:
  276. x (torch.Tensor): image features
  277. shape (b, n1, D)
  278. latent (torch.Tensor): latent features
  279. shape (b, n2, D); n2: num of latent tokens
  280. """
  281. x = self.norm_media(x)
  282. latents = self.norm_latents(latents)
  283. h = self.heads
  284. q = self.to_q(latents)
  285. kv_input = torch.cat((x, latents), dim=-2)
  286. k, v = self.to_kv(kv_input).chunk(
  287. 2, dim=-1
  288. ) # each: batch_size, max_protein_length+num_latents, dim_head*num_heads
  289. q = q.view(q.size(0), q.size(1), h, -1).permute(0, 2, 1, 3)
  290. k = k.view(k.size(0), k.size(1), h, -1).permute(0, 2, 1, 3)
  291. v = v.view(v.size(0), v.size(1), h, -1).permute(0, 2, 1, 3)
  292. q = q * self.scale # batch_size, num_heads, num_latents, dim_head
  293. # attention
  294. sim = torch.matmul(q, k.transpose(-1, -2))
  295. sim = sim - sim.amax(dim=-1, keepdim=True).detach()
  296. bs, nh, skd, okd = sim.shape
  297. ones = torch.ones(nh, skd).to(mask.device) # Create a tensor of ones with shape (nh, skd)
  298. mask_exp = mask[:, None, None, :]
  299. ones_exp = ones[None, :, :, None]
  300. mask = mask_exp * ones_exp
  301. sim = sim.masked_fill((1 - mask).bool(), -1e4)
  302. attn = sim.softmax(dim=-1)
  303. out = torch.matmul(attn, v)
  304. out = out.permute(0, 2, 1, 3)
  305. # [batch, seq, head, features] -> [batch, seq, head*features]
  306. out = out.reshape(out.size(0), out.size(1), -1)
  307. return self.to_out(out)
  308. class EvollaFeedForward(nn.Module):
  309. def __init__(self, dim, mult=4):
  310. super().__init__()
  311. inner_dim = int(dim * mult)
  312. self.norm = nn.LayerNorm(dim)
  313. self.fc1 = nn.Linear(dim, inner_dim, bias=False)
  314. self.activation = nn.GELU()
  315. self.fc2 = nn.Linear(inner_dim, dim, bias=False)
  316. def forward(self, x):
  317. return self.fc2(self.activation(self.fc1(self.norm(x))))
  318. class EvollaSequenceCompressorResampler(nn.Module):
  319. def __init__(self, config: EvollaConfig):
  320. super().__init__()
  321. protein_repr_dim = config.protein_encoder_config.hidden_size
  322. self.num_latents = config.resampler_num_latents
  323. self.latents = nn.Parameter(torch.randn(self.num_latents, protein_repr_dim), requires_grad=True)
  324. self.layers = nn.ModuleList([])
  325. for _ in range(config.resampler_depth):
  326. self.layers.append(
  327. nn.ModuleList(
  328. [
  329. EvollaSequenceCompressorAttention(
  330. dim=protein_repr_dim, dim_head=config.resampler_dim_head, heads=config.resampler_heads
  331. ),
  332. EvollaFeedForward(dim=protein_repr_dim, mult=config.resampler_ff_mult),
  333. ]
  334. )
  335. )
  336. self.norm = nn.LayerNorm(config.hidden_size)
  337. self.protein_projector = nn.Linear(protein_repr_dim, config.hidden_size)
  338. def forward(self, embeds, mask):
  339. b = embeds.shape[0]
  340. bs, _ = mask.shape # bs, max_protein_length
  341. latent_mask = torch.ones(bs, self.num_latents).to(mask.device)
  342. mask = torch.cat((mask, latent_mask), dim=1) # bs, max_protein_length + num_latents
  343. # blocks
  344. ones = torch.ones(b).to(self.latents.device)
  345. latents = self.latents[None] * ones.view(-1, 1, 1) # [b,n,d]
  346. latents = latents.to(embeds.dtype)
  347. for attn, ff in self.layers:
  348. latents = attn(embeds, latents, mask) + latents
  349. latents = ff(latents) + latents
  350. transformed_feature = self.protein_projector(latents)
  351. return self.norm(transformed_feature)
  352. @dataclass
  353. @auto_docstring
  354. class EvollaProteinEncoderModelOutput(ModelOutput):
  355. sequence_compressor_output: Optional[torch.FloatTensor] = None
  356. last_hidden_state: Optional[torch.FloatTensor] = None
  357. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  358. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  359. class EvollaProteinEncoder(nn.Module):
  360. def __init__(self, config: EvollaConfig):
  361. super().__init__()
  362. self.model = EvollaSaProtProteinEncoder(config=config.protein_encoder_config)
  363. self.sequence_compressor_resampler = EvollaSequenceCompressorResampler(config=config)
  364. @can_return_tuple
  365. def forward(self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor, **kwargs):
  366. protein_output = self.model(input_ids=input_ids, attention_mask=attention_mask)
  367. protein_embeds = protein_output.last_hidden_state
  368. sequence_repr = self.sequence_compressor_resampler(protein_embeds, attention_mask)
  369. return EvollaProteinEncoderModelOutput(
  370. sequence_compressor_output=sequence_repr,
  371. last_hidden_state=protein_output.last_hidden_state,
  372. )
  373. class EvollaSequenceAlignerCrossAttention(nn.Module):
  374. def __init__(
  375. self,
  376. config,
  377. protein_encoder_dim: Optional[int] = None,
  378. structure_encoder_dim: Optional[int] = None,
  379. msa_encoder_dim: Optional[int] = None,
  380. ):
  381. super().__init__()
  382. self.hidden_size = config.hidden_size
  383. self.num_attention_heads = config.num_attention_heads
  384. self.scale = self.num_attention_heads**-0.5
  385. self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
  386. self.all_head_size = self.num_attention_heads * self.attention_head_size
  387. attention_probs_dropout_prob = config.aligner_attention_probs_dropout_prob
  388. enable_bias = config.aligner_enable_bias
  389. ffn_mult = config.aligner_ffn_mult
  390. self.query = nn.Linear(self.hidden_size, self.all_head_size)
  391. if protein_encoder_dim is not None:
  392. self.key_protein = nn.Linear(protein_encoder_dim, self.all_head_size)
  393. self.value_protein = nn.Linear(protein_encoder_dim, self.all_head_size)
  394. else:
  395. self.key_protein = None
  396. self.value_protein = None
  397. if structure_encoder_dim is not None:
  398. self.key_structure = nn.Linear(structure_encoder_dim, self.all_head_size)
  399. self.value_structure = nn.Linear(structure_encoder_dim, self.all_head_size)
  400. else:
  401. self.key_structure = None
  402. self.value_structure = None
  403. if msa_encoder_dim is not None:
  404. self.key_msa = nn.Linear(msa_encoder_dim, self.all_head_size)
  405. self.value_msa = nn.Linear(msa_encoder_dim, self.all_head_size)
  406. else:
  407. self.key_msa = None
  408. self.value_msa = None
  409. self.attention_norm = EvollaRMSNorm(self.hidden_size)
  410. self.dropout = nn.Dropout(attention_probs_dropout_prob)
  411. self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=enable_bias)
  412. self.ff = EvollaFeedForward(self.hidden_size, ffn_mult)
  413. self.gate_attention = nn.Parameter(torch.tensor([0.0]))
  414. self.gate_ffw = nn.Parameter(torch.tensor([0.0]))
  415. def cross_attention(
  416. self,
  417. query_states,
  418. protein_key_value_states,
  419. structure_key_value_states,
  420. msa_key_value_states,
  421. query_attn_mask,
  422. protein_kv_attn_mask,
  423. structure_kv_attn_mask,
  424. msa_kv_attn_mask,
  425. ):
  426. """
  427. query_states: text
  428. key_value_states: protein
  429. query_states: [bs, query_seq_len, dim]
  430. key_value_states: [bs, kv_seq_len, dim]
  431. query_attn_mask: [bs, query_seq_len]
  432. kv_attn_mask: [bs, kv_seq_len]
  433. """
  434. # Concatenate protein and structure
  435. kv_attn_mask = [protein_kv_attn_mask, structure_kv_attn_mask, msa_kv_attn_mask]
  436. kv_attn_mask = [_ for _ in kv_attn_mask if _ is not None]
  437. if not kv_attn_mask:
  438. raise ValueError("At least one modality should be provided for cross attention.")
  439. kv_attn_mask = torch.cat(kv_attn_mask, dim=1)
  440. query_layer = self.attention_norm(query_states)
  441. # Warning: This place might cause issues, refers to
  442. # https://discuss.pytorch.org/t/cuda-error-cublas-status-not-supported-when-calling-cublasltmatmul-from-torch-nn-functional-linear/170214/13
  443. # Solution: add `DISABLE_ADDMM_CUDA_LT=1` as environment variable
  444. # Apply linear transformation to input_query, input_key, and input_value
  445. query_layer = self.query(query_layer) # [bs, querylength, dim]
  446. if self.key_protein is not None and self.value_protein is not None:
  447. protein_key_value_states = protein_key_value_states.to(query_states)
  448. key_layer_protein = self.key_protein(protein_key_value_states) # [bs, keylength, dim]
  449. value_layer_protein = self.value_protein(protein_key_value_states) # [bs, keylength, dim]
  450. else:
  451. key_layer_protein = None
  452. value_layer_protein = None
  453. if self.key_structure is not None and self.value_structure is not None:
  454. structure_key_value_states = structure_key_value_states.to(query_states)
  455. key_layer_structure = self.key_structure(structure_key_value_states) # [bs, keylength, dim]
  456. value_layer_structure = self.value_structure(structure_key_value_states) # [bs, keylength, dim]
  457. else:
  458. key_layer_structure = None
  459. value_layer_structure = None
  460. if self.key_msa is not None and self.value_msa is not None:
  461. msa_key_value_states = msa_key_value_states.to(query_states)
  462. key_layer_msa = self.key_msa(msa_key_value_states) # [bs, keylength, dim]
  463. value_layer_msa = self.value_msa(msa_key_value_states) # [bs, keylength, dim]
  464. else:
  465. key_layer_msa = None
  466. value_layer_msa = None
  467. key_layer = [key_layer_protein, key_layer_structure, key_layer_msa]
  468. key_layer = [_ for _ in key_layer if _ is not None]
  469. key_layer = torch.cat(key_layer, dim=1)
  470. value_layer = [value_layer_protein, value_layer_structure, value_layer_msa]
  471. value_layer = [_ for _ in value_layer if _ is not None]
  472. value_layer = torch.cat(value_layer, dim=1)
  473. new_query_layer_shape = query_layer.size()[:-1] + (
  474. self.num_attention_heads,
  475. self.attention_head_size,
  476. )
  477. query_layer = query_layer.view(*new_query_layer_shape).permute(0, 2, 1, 3)
  478. new_key_layer_shape = key_layer.size()[:-1] + (
  479. self.num_attention_heads,
  480. self.attention_head_size,
  481. )
  482. key_layer = key_layer.view(*new_key_layer_shape).permute(0, 2, 1, 3)
  483. new_value_layer_shape = value_layer.size()[:-1] + (
  484. self.num_attention_heads,
  485. self.attention_head_size,
  486. )
  487. value_layer = value_layer.view(*new_value_layer_shape).permute(0, 2, 1, 3)
  488. query_layer = query_layer * self.scale
  489. # attention_mask: [bs, 1, querylength, keylength]
  490. if query_attn_mask is None:
  491. query_attn_mask = torch.ones(query_states.size(0), query_states.size(1)).to(query_states.device)
  492. attention_mask = query_attn_mask[:, None, :, None] * kv_attn_mask[:, None, None, :]
  493. # Compute the scaled dot-product attention scores
  494. attn_weights = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [bs, numheads, querylength, keylength]
  495. attn_weights = attn_weights - attn_weights.amax(dim=-1, keepdim=True).detach() # To stabilize score
  496. attention_scores = attn_weights.masked_fill(
  497. (1 - attention_mask).bool(), torch.finfo(attn_weights.dtype).min
  498. ) # [bs, numheads, querylength, keylength]
  499. attention_probs = nn.Softmax(dim=-1)(attention_scores)
  500. # attention_probs_dropped = self.dropout(attention_probs)
  501. context_layer = torch.matmul(attention_probs, value_layer) # [bs, numheads, querylength, dim/numheads]
  502. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  503. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  504. context_layer = context_layer.view(*new_context_layer_shape)
  505. context_layer = self.out_proj(context_layer)
  506. return context_layer
  507. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  508. def forward(
  509. self,
  510. query_states,
  511. protein_kv_states,
  512. structure_kv_states,
  513. msa_kv_states,
  514. query_attn_mask,
  515. protein_kv_attn_mask=None,
  516. structure_kv_attn_mask=None,
  517. msa_kv_attn_mask=None,
  518. protein_batch_mask=None,
  519. structure_batch_mask=None,
  520. msa_batch_mask=None,
  521. past_key_values=None,
  522. ):
  523. if protein_kv_states is not None:
  524. bs, protein_kv_seq_len, dim = protein_kv_states.shape
  525. if protein_kv_attn_mask is None:
  526. protein_kv_attn_mask = (
  527. torch.ones(bs, protein_kv_seq_len).to(protein_batch_mask.device)
  528. * protein_batch_mask.expand(size=(protein_kv_seq_len, bs)).T
  529. ).to(protein_kv_states.device)
  530. else:
  531. protein_kv_attn_mask = None
  532. if structure_kv_states is not None:
  533. bs, structure_kv_seq_len, dim = structure_kv_states.shape
  534. if structure_kv_attn_mask is None:
  535. structure_kv_attn_mask = (
  536. torch.ones(bs, structure_kv_seq_len).to(protein_batch_mask.device)
  537. * structure_batch_mask.expand(size=(structure_kv_seq_len, bs)).T
  538. ).to(structure_kv_states.device)
  539. else:
  540. structure_kv_attn_mask = None
  541. if msa_kv_states is not None:
  542. bs, msa_kv_seq_len, dim = msa_kv_states.shape
  543. if msa_kv_attn_mask is None:
  544. msa_kv_attn_mask = (
  545. torch.ones(bs, msa_kv_seq_len).to(protein_batch_mask.device)
  546. * msa_batch_mask.expand(size=(msa_kv_seq_len, bs)).T
  547. ).to(msa_kv_states.device)
  548. else:
  549. msa_kv_attn_mask = None
  550. hidden_states = query_states
  551. # only when there's at least one valid modality, crossattention will be performed
  552. if (
  553. (protein_kv_states is not None and protein_kv_attn_mask.any())
  554. or (structure_kv_states is not None and structure_kv_attn_mask.any())
  555. or (msa_kv_states is not None and msa_kv_attn_mask.any())
  556. ):
  557. residual = hidden_states
  558. hidden_states = self.cross_attention(
  559. query_states=hidden_states,
  560. protein_key_value_states=protein_kv_states,
  561. structure_key_value_states=structure_kv_states,
  562. msa_key_value_states=msa_kv_states,
  563. query_attn_mask=query_attn_mask,
  564. protein_kv_attn_mask=protein_kv_attn_mask,
  565. structure_kv_attn_mask=structure_kv_attn_mask,
  566. msa_kv_attn_mask=msa_kv_attn_mask,
  567. ) # [bs, query_seq_len, dim]
  568. # tanh gate
  569. hidden_states = torch.tanh(self.gate_attention) * hidden_states
  570. hidden_states = residual + hidden_states # input_query
  571. residual = hidden_states
  572. hidden_states = self.ff(hidden_states) * torch.tanh(self.gate_ffw)
  573. hidden_states = residual + hidden_states
  574. return hidden_states
  575. class EvollaRMSNorm(LlamaRMSNorm):
  576. pass
  577. class EvollaRotaryEmbedding(LlamaRotaryEmbedding):
  578. pass
  579. class EvollaMLP(LlamaMLP):
  580. pass
  581. class EvollaAttention(LlamaAttention):
  582. pass
  583. class EvollaDecoderLayer(LlamaDecoderLayer):
  584. def __init__(self, config: EvollaConfig, layer_idx: int):
  585. super().__init__(config, layer_idx)
  586. if (layer_idx + 1) % max(config.num_hidden_layers // config.aligner_num_add_layers, 1) == 0:
  587. self.adapter = EvollaSequenceAlignerCrossAttention(
  588. config,
  589. protein_encoder_dim=config.hidden_size,
  590. )
  591. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  592. def forward(
  593. self,
  594. hidden_states: torch.Tensor,
  595. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  596. attention_mask: Optional[torch.Tensor] = None,
  597. position_ids: Optional[torch.LongTensor] = None,
  598. past_key_values: Optional[Cache] = None,
  599. use_cache: Optional[bool] = False,
  600. cache_position: Optional[torch.LongTensor] = None,
  601. protein_kv_states: Optional[torch.Tensor] = None,
  602. structure_kv_states: Optional[torch.Tensor] = None,
  603. msa_kv_states: Optional[torch.Tensor] = None,
  604. protein_batch_mask: Optional[torch.Tensor] = None,
  605. structure_batch_mask: Optional[torch.Tensor] = None,
  606. msa_batch_mask: Optional[torch.Tensor] = None,
  607. query_attn_mask: Optional[torch.Tensor] = None,
  608. **kwargs,
  609. ):
  610. residual = hidden_states
  611. hidden_states = self.input_layernorm(hidden_states)
  612. # Self Attention
  613. hidden_states, _ = self.self_attn(
  614. hidden_states=hidden_states,
  615. attention_mask=attention_mask,
  616. position_ids=position_ids,
  617. past_key_values=past_key_values,
  618. use_cache=use_cache,
  619. cache_position=cache_position,
  620. position_embeddings=position_embeddings,
  621. **kwargs,
  622. )
  623. hidden_states = residual + hidden_states
  624. # Fully Connected
  625. residual = hidden_states
  626. hidden_states = self.post_attention_layernorm(hidden_states)
  627. hidden_states = self.mlp(hidden_states)
  628. hidden_states = residual + hidden_states
  629. if hasattr(self, "adapter"):
  630. hidden_states = self.adapter(
  631. query_states=hidden_states,
  632. protein_kv_states=protein_kv_states,
  633. structure_kv_states=structure_kv_states,
  634. msa_kv_states=msa_kv_states,
  635. query_attn_mask=query_attn_mask,
  636. protein_batch_mask=protein_batch_mask,
  637. structure_batch_mask=structure_batch_mask,
  638. msa_batch_mask=msa_batch_mask,
  639. )
  640. return hidden_states
  641. class EvollaPreTrainedModel(LlamaPreTrainedModel):
  642. _supports_flash_attn = False # see dependency on `EvollaSaProtProteinEncoder`
  643. _supports_flex_attn = False # see dependency on `EvollaSaProtProteinEncoder`
  644. _supports_attention_backend = False
  645. _no_split_modules = [
  646. "EvollaDecoderLayer",
  647. "EvollaSequenceCompressorResampler",
  648. "EvollaSequenceAlignerCrossAttention",
  649. ]
  650. def _init_weights(self, module):
  651. std = self.config.initializer_range
  652. PreTrainedModel._init_weights(self, module)
  653. if isinstance(module, EvollaSequenceAlignerCrossAttention):
  654. module.gate_attention.zero_()
  655. module.gate_ffw.zero_()
  656. module.attention_norm.weight.data.fill_(1.0)
  657. elif isinstance(module, EvollaSequenceCompressorResampler):
  658. module.latents.data.normal_(mean=0.0, std=std)
  659. class EvollaModel(EvollaPreTrainedModel):
  660. def __init__(self, config: EvollaConfig):
  661. super().__init__(config)
  662. self.padding_idx = config.pad_token_id
  663. self.vocab_size = config.vocab_size
  664. self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx)
  665. self.protein_encoder = EvollaProteinEncoder(config=config)
  666. self.layers = nn.ModuleList(
  667. [
  668. EvollaDecoderLayer(
  669. config=config,
  670. layer_idx=layer_idx,
  671. )
  672. for layer_idx in range(config.num_hidden_layers)
  673. ]
  674. )
  675. self.norm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  676. self.rotary_emb = EvollaRotaryEmbedding(config=config)
  677. self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False)
  678. self.post_init()
  679. def get_input_embeddings(self):
  680. return self.embed_tokens
  681. def set_input_embeddings(self, value):
  682. self.embed_tokens = value
  683. @auto_docstring
  684. @check_model_inputs()
  685. def forward(
  686. self,
  687. input_ids: Optional[torch.LongTensor] = None,
  688. attention_mask: Optional[torch.Tensor] = None,
  689. position_ids: Optional[torch.LongTensor] = None,
  690. past_key_values: Optional[Cache] = None,
  691. inputs_embeds: Optional[torch.FloatTensor] = None,
  692. use_cache: Optional[bool] = None,
  693. cache_position: Optional[torch.LongTensor] = None,
  694. protein_input_ids: Optional[torch.LongTensor] = None,
  695. protein_attention_mask: Optional[torch.Tensor] = None,
  696. structure_feats: Optional[torch.FloatTensor] = None,
  697. msa_feats: Optional[torch.FloatTensor] = None,
  698. structure_batch_mask: Optional[torch.Tensor] = None,
  699. msa_batch_mask: Optional[torch.Tensor] = None,
  700. **kwargs,
  701. ) -> Union[tuple, BaseModelOutputWithPast]:
  702. r"""
  703. protein_input_ids (torch.LongTensor):
  704. The input IDs for the protein sequence in structure-aware tokens. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`.
  705. protein_attention_mask (torch.Tensor):
  706. The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.
  707. structure_feats (torch.FloatTensor):
  708. The input IDs for purely structure-based features. Should be of shape `(batch_size, structure_seq_length, structure_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
  709. msa_feats (torch.FloatTensor):
  710. The input IDs for purely MSA-based features. Should be of shape `(batch_size, msa_seq_length, msa_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
  711. structure_batch_mask (torch.Tensor):
  712. The batch mask to decide which protein sequences are purely structure-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `structure_feats`. Dummpy input for now.
  713. msa_batch_mask (torch.Tensor):
  714. The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummpy input for now.
  715. """
  716. if (input_ids is None) ^ (inputs_embeds is not None):
  717. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  718. if inputs_embeds is None:
  719. inputs_embeds = self.embed_tokens(input_ids)
  720. if use_cache and past_key_values is None:
  721. past_key_values = DynamicCache(config=self.config)
  722. if cache_position is None:
  723. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  724. cache_position = torch.arange(
  725. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  726. )
  727. if position_ids is None:
  728. position_ids = cache_position.unsqueeze(0)
  729. protein_feats = None
  730. protein_batch_mask = None
  731. # If provided, actually compute them
  732. if protein_input_ids is not None and protein_attention_mask is not None:
  733. protein_outputs = self.protein_encoder(
  734. input_ids=protein_input_ids,
  735. attention_mask=protein_attention_mask,
  736. )
  737. protein_feats = protein_outputs.sequence_compressor_output
  738. protein_batch_mask = torch.tensor([True] * protein_input_ids.shape[0], device=protein_input_ids.device)
  739. causal_mask = create_causal_mask(
  740. config=self.config,
  741. input_embeds=inputs_embeds,
  742. attention_mask=attention_mask,
  743. cache_position=cache_position,
  744. past_key_values=past_key_values,
  745. )
  746. hidden_states = inputs_embeds
  747. # create position embeddings to be shared across the decoder layers
  748. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  749. for decoder_layer in self.layers:
  750. hidden_states = decoder_layer(
  751. hidden_states,
  752. attention_mask=causal_mask,
  753. position_ids=position_ids,
  754. past_key_values=past_key_values,
  755. use_cache=use_cache,
  756. cache_position=cache_position,
  757. position_embeddings=position_embeddings,
  758. protein_kv_states=protein_feats,
  759. structure_kv_states=structure_feats,
  760. msa_kv_states=msa_feats,
  761. protein_batch_mask=protein_batch_mask,
  762. structure_batch_mask=structure_batch_mask,
  763. msa_batch_mask=msa_batch_mask,
  764. query_attn_mask=attention_mask,
  765. **kwargs,
  766. )
  767. hidden_states = self.norm(hidden_states)
  768. output = BaseModelOutputWithPast(
  769. last_hidden_state=hidden_states,
  770. past_key_values=past_key_values,
  771. )
  772. return output
  773. class EvollaForProteinText2Text(EvollaPreTrainedModel, GenerationMixin):
  774. def __init__(self, config):
  775. super().__init__(config)
  776. self.model = EvollaModel(config)
  777. self.vocab_size = config.vocab_size
  778. self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
  779. self.post_init()
  780. def get_input_embeddings(self):
  781. return self.model.get_input_embeddings()
  782. def set_input_embeddings(self, value):
  783. return self.model.set_input_embeddings(value)
  784. @can_return_tuple
  785. @auto_docstring
  786. def forward(
  787. self,
  788. input_ids: Optional[torch.LongTensor] = None, # text input ids
  789. attention_mask: Optional[torch.Tensor] = None, # text attention mask
  790. inputs_embeds: Optional[torch.FloatTensor] = None, # text input embeddings
  791. labels: Optional[torch.LongTensor] = None,
  792. protein_input_ids: Optional[torch.LongTensor] = None,
  793. protein_attention_mask: Optional[torch.Tensor] = None,
  794. use_cache: Optional[bool] = None,
  795. **kwargs,
  796. ):
  797. r"""
  798. protein_input_ids (torch.LongTensor):
  799. The input IDs for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`.
  800. protein_attention_mask (torch.Tensor):
  801. The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.
  802. Example:
  803. ```python
  804. >>> from transformers import EvollaProcessor, EvollaForProteinText2Text
  805. >>> model = EvollaForProteinText2Text.from_pretrained("westlake/Evolla-10B-hf")
  806. >>> processor = EvollaProcessor.from_pretrained("westlake/Evolla-10B-hf")
  807. >>> protein_information = {
  808. "aa_seq": "your amino acid sequence",
  809. "foldseek": "your foldseek sequence",
  810. }
  811. >>> question = "What is the function of this protein?"
  812. >>> message = [
  813. {"role": "system", "content": "You are an AI expert that can answer any questions about protein."},
  814. {"role": "user", "content": question},
  815. ]
  816. >>> inputs = processor(proteins=[protein_information], messages_list=[message], return_tensors="pt", padding="longest")
  817. >>> outputs = model.generate(**inputs)
  818. >>> print(processor.batch_decode(outputs, skip_special_tokens=True))
  819. ```"""
  820. outputs = self.model(
  821. input_ids=input_ids,
  822. attention_mask=attention_mask,
  823. inputs_embeds=inputs_embeds,
  824. protein_input_ids=protein_input_ids,
  825. protein_attention_mask=protein_attention_mask,
  826. use_cache=use_cache,
  827. **kwargs,
  828. )
  829. hidden_states = outputs[0]
  830. logits = self.lm_head(hidden_states)
  831. loss = None
  832. if labels is not None:
  833. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
  834. lm_outputs = CausalLMOutputWithPast(
  835. loss=loss,
  836. logits=logits,
  837. past_key_values=outputs.past_key_values,
  838. hidden_states=outputs.hidden_states,
  839. attentions=outputs.attentions,
  840. )
  841. return lm_outputs
  842. __all__ = ["EvollaForProteinText2Text", "EvollaModel", "EvollaPreTrainedModel"]