modeling_evolla.py 67 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/evolla/modular_evolla.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_evolla.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 Westlake Representational Learning Lab (Fajie Yuan Lab) team and the HuggingFace Inc. team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import math
  22. import warnings
  23. from dataclasses import dataclass
  24. from typing import Callable, Optional, Union
  25. import torch
  26. from torch import Tensor, nn
  27. from ...activations import ACT2FN
  28. from ...cache_utils import Cache, DynamicCache
  29. from ...generation import GenerationMixin
  30. from ...integrations import use_kernel_forward_from_hub
  31. from ...masking_utils import create_causal_mask
  32. from ...modeling_layers import GradientCheckpointingLayer
  33. from ...modeling_outputs import (
  34. BaseModelOutputWithCrossAttentions,
  35. BaseModelOutputWithPast,
  36. BaseModelOutputWithPoolingAndCrossAttentions,
  37. CausalLMOutputWithPast,
  38. ModelOutput,
  39. )
  40. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  41. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, ModuleUtilsMixin, PreTrainedModel, get_parameter_dtype
  42. from ...processing_utils import Unpack
  43. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  44. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  45. from ...utils.deprecation import deprecate_kwarg
  46. from ...utils.generic import OutputRecorder, check_model_inputs
  47. from .configuration_evolla import EvollaConfig, SaProtConfig
  48. def create_position_ids_from_input_ids(input_ids, padding_idx):
  49. """
  50. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  51. are ignored. This is modified from fairseq's `utils.make_positions`.
  52. Args:
  53. x: torch.Tensor x:
  54. Returns: torch.Tensor
  55. """
  56. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  57. mask = input_ids.ne(padding_idx).int()
  58. incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
  59. return incremental_indices.long() + padding_idx
  60. class EvollaSaProtEmbeddings(nn.Module):
  61. """
  62. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
  63. """
  64. def __init__(self, config):
  65. super().__init__()
  66. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  67. if config.emb_layer_norm_before:
  68. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  69. else:
  70. self.layer_norm = None
  71. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  72. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  73. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  74. self.register_buffer(
  75. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  76. )
  77. self.padding_idx = config.pad_token_id
  78. if self.position_embedding_type == "absolute":
  79. self.position_embeddings = nn.Embedding(
  80. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  81. )
  82. self.token_dropout = config.token_dropout
  83. self.mask_token_id = config.mask_token_id
  84. # remove the position_ids in EsmEmbeddings
  85. self.position_ids = None
  86. def forward(
  87. self,
  88. input_ids=None,
  89. attention_mask=None,
  90. position_ids=None,
  91. inputs_embeds=None,
  92. ):
  93. if position_ids is None:
  94. if input_ids is not None:
  95. # Create the position ids from the input token ids. Any padded tokens remain padded.
  96. position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
  97. else:
  98. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  99. if inputs_embeds is None:
  100. inputs_embeds = self.word_embeddings(input_ids)
  101. # Note that if we want to support EVOLLA_SA_PROT-1 (not 1b!) in future then we need to support an
  102. # embedding_scale factor here.
  103. embeddings = inputs_embeds
  104. # Matt: EVOLLA_SA_PROT has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
  105. # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however,
  106. # masked tokens are treated as if they were selected for input dropout and zeroed out.
  107. # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
  108. # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
  109. # This is analogous to the way that dropout layers scale down outputs during evaluation when not
  110. # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
  111. if self.token_dropout and input_ids is not None:
  112. embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
  113. mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all EVOLLA_SA_PROT model training runs
  114. src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1]
  115. mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
  116. embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
  117. embeddings.dtype
  118. )
  119. if self.position_embedding_type == "absolute":
  120. position_embeddings = self.position_embeddings(position_ids)
  121. embeddings = embeddings + position_embeddings
  122. if self.layer_norm is not None:
  123. embeddings = self.layer_norm(embeddings)
  124. if attention_mask is not None:
  125. embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
  126. # Matt: I think this line was copied incorrectly from BERT, disabling it for now.
  127. # embeddings = self.dropout(embeddings)
  128. return embeddings
  129. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  130. """
  131. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  132. Args:
  133. inputs_embeds: torch.Tensor
  134. Returns: torch.Tensor
  135. """
  136. input_shape = inputs_embeds.size()[:-1]
  137. sequence_length = input_shape[1]
  138. position_ids = torch.arange(
  139. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  140. )
  141. return position_ids.unsqueeze(0).expand(input_shape)
  142. def rotate_half_esm(x):
  143. x1, x2 = x.chunk(2, dim=-1)
  144. return torch.cat((-x2, x1), dim=-1)
  145. def apply_rotary_pos_emb_esm(x, cos, sin):
  146. cos = cos[:, :, : x.shape[-2], :]
  147. sin = sin[:, :, : x.shape[-2], :]
  148. return (x * cos) + (rotate_half_esm(x) * sin)
  149. class EvollaSaProtRotaryEmbedding(nn.Module):
  150. """
  151. Rotary position embeddings based on those in
  152. [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
  153. matrices which depend on their relative positions.
  154. """
  155. inv_freq: torch.Tensor # fix linting for `register_buffer`
  156. def __init__(self, dim: int):
  157. super().__init__()
  158. # Generate and save the inverse frequency buffer (non trainable)
  159. inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
  160. self.register_buffer("inv_freq", inv_freq)
  161. self._seq_len_cached = None
  162. self._cos_cached = None
  163. self._sin_cached = None
  164. def _update_cos_sin_tables(self, x, seq_dimension=2):
  165. seq_len = x.shape[seq_dimension]
  166. # Reset the tables if the sequence length has changed,
  167. # or if we're on a new device (possibly due to tracing for instance)
  168. if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
  169. self._seq_len_cached = seq_len
  170. t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
  171. freqs = torch.outer(t, self.inv_freq)
  172. emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
  173. self._cos_cached = emb.cos()[None, None, :, :]
  174. self._sin_cached = emb.sin()[None, None, :, :]
  175. return self._cos_cached, self._sin_cached
  176. def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  177. self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
  178. return (
  179. apply_rotary_pos_emb_esm(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype),
  180. apply_rotary_pos_emb_esm(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype),
  181. )
  182. def eager_attention_forward(
  183. module: nn.Module,
  184. query: torch.Tensor,
  185. key: torch.Tensor,
  186. value: torch.Tensor,
  187. attention_mask: Optional[torch.Tensor],
  188. scaling: float,
  189. dropout: float = 0.0,
  190. head_mask: Optional[torch.Tensor] = None,
  191. **kwargs: Unpack[TransformersKwargs],
  192. ):
  193. # EVOLLA_SA_PROT applies relative position embeddings and we don't copy from Llama
  194. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  195. if hasattr(module, "position_embedding_type") and module.position_embedding_type in [
  196. "relative_key",
  197. "relative_key_query",
  198. ]:
  199. seq_length = query.shape[2]
  200. position_ids_l = torch.arange(seq_length, dtype=torch.long, device=attn_weights.device).view(-1, 1)
  201. position_ids_r = torch.arange(seq_length, dtype=torch.long, device=attn_weights.device).view(1, -1)
  202. distance = position_ids_l - position_ids_r
  203. positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
  204. positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility
  205. if module.position_embedding_type == "relative_key":
  206. relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
  207. elif module.position_embedding_type == "relative_key_query":
  208. relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
  209. relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
  210. relative_position_scores = relative_position_scores_query + relative_position_scores_key
  211. attn_weights = attn_weights + relative_position_scores
  212. if attention_mask is not None:
  213. causal_mask = attention_mask[:, :, :, : key.shape[-2]]
  214. attn_weights = attn_weights + causal_mask
  215. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  216. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  217. if head_mask is not None:
  218. attn_weights = attn_weights * head_mask
  219. attn_output = torch.matmul(attn_weights, value)
  220. attn_output = attn_output.transpose(1, 2).contiguous()
  221. return attn_output, attn_weights
  222. class EvollaSaProtSelfAttention(nn.Module):
  223. def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cross_attention=False):
  224. super().__init__()
  225. self.config = config
  226. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  227. raise ValueError(
  228. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  229. f"heads ({config.num_attention_heads})"
  230. )
  231. self.num_attention_heads = config.num_attention_heads
  232. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  233. self.all_head_size = self.num_attention_heads * self.attention_head_size
  234. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  235. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  236. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  237. self.dropout = config.attention_probs_dropout_prob
  238. self.position_embedding_type = position_embedding_type or getattr(
  239. config, "position_embedding_type", "absolute"
  240. )
  241. self.rotary_embeddings = None
  242. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  243. self.max_position_embeddings = config.max_position_embeddings
  244. self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
  245. elif self.position_embedding_type == "rotary":
  246. self.rotary_embeddings = EvollaSaProtRotaryEmbedding(dim=self.attention_head_size)
  247. self.is_decoder = config.is_decoder
  248. self.layer_idx = layer_idx
  249. self.scaling = 1.0
  250. self.is_causal = self.is_decoder and not is_cross_attention
  251. def forward(
  252. self,
  253. hidden_states: torch.Tensor,
  254. attention_mask: Optional[torch.FloatTensor] = None,
  255. head_mask: Optional[torch.FloatTensor] = None,
  256. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  257. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  258. **kwargs: Unpack[TransformersKwargs],
  259. ) -> tuple[torch.Tensor]:
  260. batch_size, seq_length = hidden_states.shape[:-1]
  261. hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
  262. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  263. is_cross_attention = encoder_hidden_states is not None
  264. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  265. attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
  266. key_layer = self.key(current_states).view(hidden_shape).transpose(1, 2)
  267. value_layer = self.value(current_states).view(hidden_shape).transpose(1, 2)
  268. # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
  269. # EVOLLA_SA_PROT scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
  270. # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
  271. # EVOLLA_SA_PROT code and fix rotary embeddings.
  272. query_layer = query_layer * self.attention_head_size**-0.5
  273. if self.position_embedding_type == "rotary":
  274. query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
  275. attention_interface: Callable = eager_attention_forward
  276. if self.config._attn_implementation != "eager":
  277. if self.position_embedding_type in ["relative_key", "relative_key_query"]:
  278. raise ValueError(
  279. f"ESM {self.config._attn_implementation} attention does not support {self.position_embedding_type} embeddings. "
  280. "Set attention explicitly to 'eager' with `model.set_attn_implementation('eager')`"
  281. )
  282. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  283. attn_output, attn_weights = attention_interface(
  284. self,
  285. query_layer,
  286. key_layer,
  287. value_layer,
  288. attention_mask,
  289. dropout=0.0 if not self.training else self.dropout,
  290. scaling=self.scaling,
  291. head_mask=head_mask,
  292. **kwargs,
  293. )
  294. attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
  295. return attn_output, attn_weights
  296. class EvollaSaProtSelfOutput(nn.Module):
  297. def __init__(self, config):
  298. super().__init__()
  299. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  300. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  301. def forward(self, hidden_states, input_tensor):
  302. hidden_states = self.dense(hidden_states)
  303. hidden_states = self.dropout(hidden_states)
  304. hidden_states = hidden_states + input_tensor
  305. return hidden_states
  306. class EvollaSaProtAttention(nn.Module):
  307. def __init__(self, config, layer_idx=None, is_cross_attention=False):
  308. super().__init__()
  309. self.self = EvollaSaProtSelfAttention(config, layer_idx=layer_idx, is_cross_attention=is_cross_attention)
  310. self.output = EvollaSaProtSelfOutput(config)
  311. self.pruned_heads = set()
  312. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  313. def prune_heads(self, heads):
  314. if len(heads) == 0:
  315. return
  316. heads, index = find_pruneable_heads_and_indices(
  317. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  318. )
  319. # Prune linear layers
  320. self.self.query = prune_linear_layer(self.self.query, index)
  321. self.self.key = prune_linear_layer(self.self.key, index)
  322. self.self.value = prune_linear_layer(self.self.value, index)
  323. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  324. # Update hyper params and store pruned heads
  325. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  326. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  327. self.pruned_heads = self.pruned_heads.union(heads)
  328. def forward(
  329. self,
  330. hidden_states,
  331. attention_mask=None,
  332. head_mask=None,
  333. encoder_hidden_states=None,
  334. encoder_attention_mask=None,
  335. **kwargs: Unpack[TransformersKwargs],
  336. ):
  337. hidden_states_ln = self.LayerNorm(hidden_states)
  338. attn_output, _ = self.self(
  339. hidden_states_ln,
  340. attention_mask=attention_mask,
  341. head_mask=head_mask,
  342. encoder_hidden_states=encoder_hidden_states,
  343. encoder_attention_mask=encoder_attention_mask,
  344. **kwargs,
  345. )
  346. attn_output = self.output(attn_output, hidden_states)
  347. return attn_output
  348. def gelu(x):
  349. """
  350. This is the gelu implementation from the original EVOLLA_SA_PROT repo. Using F.gelu yields subtly wrong results.
  351. """
  352. return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
  353. class EvollaSaProtIntermediate(nn.Module):
  354. def __init__(self, config):
  355. super().__init__()
  356. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  357. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  358. hidden_states = self.dense(hidden_states)
  359. hidden_states = gelu(hidden_states)
  360. return hidden_states
  361. class EvollaSaProtOutput(nn.Module):
  362. def __init__(self, config):
  363. super().__init__()
  364. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  365. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  366. def forward(self, hidden_states, input_tensor):
  367. hidden_states = self.dense(hidden_states)
  368. hidden_states = self.dropout(hidden_states)
  369. hidden_states = hidden_states + input_tensor
  370. return hidden_states
  371. class EvollaSaProtLayer(GradientCheckpointingLayer):
  372. def __init__(self, config):
  373. super().__init__()
  374. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  375. self.seq_len_dim = 1
  376. self.attention = EvollaSaProtAttention(config)
  377. self.is_decoder = config.is_decoder
  378. self.add_cross_attention = config.add_cross_attention
  379. if self.add_cross_attention:
  380. if not self.is_decoder:
  381. raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
  382. self.crossattention = EvollaSaProtAttention(config, is_cross_attention=True)
  383. self.intermediate = EvollaSaProtIntermediate(config)
  384. self.output = EvollaSaProtOutput(config)
  385. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  386. def forward(
  387. self,
  388. hidden_states,
  389. attention_mask=None,
  390. head_mask=None,
  391. encoder_hidden_states=None,
  392. encoder_attention_mask=None,
  393. **kwargs: Unpack[TransformersKwargs],
  394. ):
  395. attention_output = self.attention(
  396. hidden_states,
  397. attention_mask=attention_mask,
  398. head_mask=head_mask,
  399. **kwargs,
  400. )
  401. if self.is_decoder and encoder_hidden_states is not None:
  402. if not hasattr(self, "crossattention"):
  403. raise AttributeError(
  404. f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
  405. " with cross-attention layers by setting `config.add_cross_attention=True`"
  406. )
  407. attention_output = self.crossattention(
  408. attention_output,
  409. attention_mask=attention_mask,
  410. head_mask=head_mask,
  411. encoder_hidden_states=encoder_hidden_states,
  412. encoder_attention_mask=encoder_attention_mask,
  413. **kwargs,
  414. )
  415. layer_output = self.feed_forward_chunk(attention_output)
  416. return layer_output
  417. def feed_forward_chunk(self, attention_output):
  418. attention_output_ln = self.LayerNorm(attention_output)
  419. intermediate_output = self.intermediate(attention_output_ln)
  420. layer_output = self.output(intermediate_output, attention_output)
  421. return layer_output
  422. class EvollaSaProtEncoder(nn.Module):
  423. def __init__(self, config):
  424. super().__init__()
  425. self.config = config
  426. self.layer = nn.ModuleList([EvollaSaProtLayer(config) for _ in range(config.num_hidden_layers)])
  427. self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  428. self.gradient_checkpointing = False
  429. @can_return_tuple
  430. def forward(
  431. self,
  432. hidden_states,
  433. attention_mask=None,
  434. head_mask=None,
  435. encoder_hidden_states=None,
  436. encoder_attention_mask=None,
  437. **kwargs: Unpack[TransformersKwargs],
  438. ):
  439. for i, layer_module in enumerate(self.layer):
  440. layer_head_mask = head_mask[i] if head_mask is not None else None
  441. hidden_states = layer_module(
  442. hidden_states,
  443. attention_mask=attention_mask,
  444. head_mask=layer_head_mask,
  445. encoder_hidden_states=encoder_hidden_states,
  446. encoder_attention_mask=encoder_attention_mask,
  447. **kwargs,
  448. )
  449. if self.emb_layer_norm_after:
  450. hidden_states = self.emb_layer_norm_after(hidden_states)
  451. return BaseModelOutputWithCrossAttentions(last_hidden_state=hidden_states)
  452. class EvollaSaProtPooler(nn.Module):
  453. def __init__(self, config):
  454. super().__init__()
  455. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  456. self.activation = nn.Tanh()
  457. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  458. # We "pool" the model by simply taking the hidden state corresponding
  459. # to the first token.
  460. first_token_tensor = hidden_states[:, 0]
  461. pooled_output = self.dense(first_token_tensor)
  462. pooled_output = self.activation(pooled_output)
  463. return pooled_output
  464. @auto_docstring
  465. class EvollaSaProtPreTrainedModel(PreTrainedModel):
  466. config: SaProtConfig
  467. _no_split_modules = ["EvollaSaProtLayer"]
  468. _supports_flash_attn = True
  469. _supports_sdpa = True
  470. _supports_attention_backend = True
  471. _can_record_outputs = {
  472. "hidden_states": EvollaSaProtLayer,
  473. "attentions": [OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="attention")],
  474. "cross_attentions": [
  475. OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="crossattention"),
  476. ],
  477. }
  478. def _init_weights(self, module):
  479. """Initialize the weights"""
  480. std = self.config.initializer_range
  481. if isinstance(module, nn.Linear):
  482. module.weight.data.normal_(mean=0.0, std=std)
  483. if module.bias is not None:
  484. module.bias.data.zero_()
  485. elif isinstance(module, nn.Embedding):
  486. module.weight.data.normal_(mean=0.0, std=std)
  487. if module.padding_idx is not None:
  488. module.weight.data[module.padding_idx].zero_()
  489. elif isinstance(module, nn.LayerNorm):
  490. module.bias.data.zero_()
  491. module.weight.data.fill_(1.0)
  492. class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
  493. def __init__(self, config: SaProtConfig):
  494. super().__init__(config)
  495. self.embeddings = EvollaSaProtEmbeddings(config)
  496. self.encoder = EvollaSaProtEncoder(config)
  497. def get_input_embeddings(self):
  498. return self.embeddings.word_embeddings
  499. def set_input_embeddings(self, value):
  500. self.embeddings.word_embeddings = value
  501. def _prune_heads(self, heads_to_prune):
  502. """
  503. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  504. class PreTrainedModel
  505. """
  506. for layer, heads in heads_to_prune.items():
  507. self.encoder.layer[layer].attention.prune_heads(heads)
  508. @check_model_inputs()
  509. def forward(
  510. self,
  511. input_ids: Optional[torch.Tensor],
  512. attention_mask: Optional[torch.Tensor] = None,
  513. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
  514. input_shape = input_ids.size()
  515. batch_size, seq_length = input_shape
  516. device = input_ids.device
  517. if attention_mask is None:
  518. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  519. inputs_embeds = self.embeddings(input_ids=input_ids, attention_mask=attention_mask)
  520. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
  521. encoder_outputs = self.encoder(inputs_embeds, attention_mask=extended_attention_mask)
  522. sequence_output = encoder_outputs[0]
  523. return BaseModelOutputWithPoolingAndCrossAttentions(
  524. last_hidden_state=sequence_output,
  525. hidden_states=encoder_outputs.hidden_states,
  526. attentions=encoder_outputs.attentions,
  527. cross_attentions=encoder_outputs.cross_attentions,
  528. )
  529. def get_extended_attention_mask(
  530. self,
  531. attention_mask: Tensor,
  532. input_shape: tuple[int],
  533. device: Optional[torch.device] = None,
  534. dtype: Optional[torch.dtype] = None,
  535. ) -> Tensor:
  536. """
  537. Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
  538. Arguments:
  539. attention_mask (`torch.Tensor`):
  540. Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
  541. input_shape (`Tuple[int]`):
  542. The shape of the input to the model.
  543. Returns:
  544. `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
  545. """
  546. if dtype is None:
  547. dtype = get_parameter_dtype(self)
  548. if not (attention_mask.dim() == 2 and self.config.is_decoder):
  549. # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
  550. if device is not None:
  551. warnings.warn(
  552. "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
  553. )
  554. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  555. # ourselves in which case we just need to make it broadcastable to all heads.
  556. if attention_mask.dim() == 3:
  557. extended_attention_mask = attention_mask[:, None, :, :]
  558. elif attention_mask.dim() == 2:
  559. # Provided a padding mask of dimensions [batch_size, seq_length]
  560. # - if the model is a decoder, apply a causal mask in addition to the padding mask
  561. # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
  562. if self.config.is_decoder:
  563. extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
  564. input_shape, attention_mask, device
  565. )
  566. else:
  567. extended_attention_mask = attention_mask[:, None, None, :]
  568. else:
  569. raise ValueError(
  570. f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
  571. )
  572. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  573. # masked positions, this operation will create a tensor which is 0.0 for
  574. # positions we want to attend and the dtype's smallest value for masked positions.
  575. # Since we are adding it to the raw scores before the softmax, this is
  576. # effectively the same as removing these entirely.
  577. extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
  578. extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
  579. return extended_attention_mask
  580. class EvollaSequenceCompressorAttention(nn.Module):
  581. def __init__(self, dim, dim_head=64, heads=8):
  582. super().__init__()
  583. self.scale = dim_head**-0.5
  584. self.heads = heads
  585. inner_dim = dim_head * heads
  586. self.norm_media = nn.LayerNorm(dim)
  587. self.norm_latents = nn.LayerNorm(dim)
  588. self.to_q = nn.Linear(dim, inner_dim, bias=False)
  589. self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
  590. self.to_out = nn.Linear(inner_dim, dim, bias=False)
  591. def forward(self, x, latents, mask):
  592. """
  593. Args:
  594. x (torch.Tensor): image features
  595. shape (b, n1, D)
  596. latent (torch.Tensor): latent features
  597. shape (b, n2, D); n2: num of latent tokens
  598. """
  599. x = self.norm_media(x)
  600. latents = self.norm_latents(latents)
  601. h = self.heads
  602. q = self.to_q(latents)
  603. kv_input = torch.cat((x, latents), dim=-2)
  604. k, v = self.to_kv(kv_input).chunk(
  605. 2, dim=-1
  606. ) # each: batch_size, max_protein_length+num_latents, dim_head*num_heads
  607. q = q.view(q.size(0), q.size(1), h, -1).permute(0, 2, 1, 3)
  608. k = k.view(k.size(0), k.size(1), h, -1).permute(0, 2, 1, 3)
  609. v = v.view(v.size(0), v.size(1), h, -1).permute(0, 2, 1, 3)
  610. q = q * self.scale # batch_size, num_heads, num_latents, dim_head
  611. # attention
  612. sim = torch.matmul(q, k.transpose(-1, -2))
  613. sim = sim - sim.amax(dim=-1, keepdim=True).detach()
  614. bs, nh, skd, okd = sim.shape
  615. ones = torch.ones(nh, skd).to(mask.device) # Create a tensor of ones with shape (nh, skd)
  616. mask_exp = mask[:, None, None, :]
  617. ones_exp = ones[None, :, :, None]
  618. mask = mask_exp * ones_exp
  619. sim = sim.masked_fill((1 - mask).bool(), -1e4)
  620. attn = sim.softmax(dim=-1)
  621. out = torch.matmul(attn, v)
  622. out = out.permute(0, 2, 1, 3)
  623. # [batch, seq, head, features] -> [batch, seq, head*features]
  624. out = out.reshape(out.size(0), out.size(1), -1)
  625. return self.to_out(out)
  626. class EvollaFeedForward(nn.Module):
  627. def __init__(self, dim, mult=4):
  628. super().__init__()
  629. inner_dim = int(dim * mult)
  630. self.norm = nn.LayerNorm(dim)
  631. self.fc1 = nn.Linear(dim, inner_dim, bias=False)
  632. self.activation = nn.GELU()
  633. self.fc2 = nn.Linear(inner_dim, dim, bias=False)
  634. def forward(self, x):
  635. return self.fc2(self.activation(self.fc1(self.norm(x))))
  636. class EvollaSequenceCompressorResampler(nn.Module):
  637. def __init__(self, config: EvollaConfig):
  638. super().__init__()
  639. protein_repr_dim = config.protein_encoder_config.hidden_size
  640. self.num_latents = config.resampler_num_latents
  641. self.latents = nn.Parameter(torch.randn(self.num_latents, protein_repr_dim), requires_grad=True)
  642. self.layers = nn.ModuleList([])
  643. for _ in range(config.resampler_depth):
  644. self.layers.append(
  645. nn.ModuleList(
  646. [
  647. EvollaSequenceCompressorAttention(
  648. dim=protein_repr_dim, dim_head=config.resampler_dim_head, heads=config.resampler_heads
  649. ),
  650. EvollaFeedForward(dim=protein_repr_dim, mult=config.resampler_ff_mult),
  651. ]
  652. )
  653. )
  654. self.norm = nn.LayerNorm(config.hidden_size)
  655. self.protein_projector = nn.Linear(protein_repr_dim, config.hidden_size)
  656. def forward(self, embeds, mask):
  657. b = embeds.shape[0]
  658. bs, _ = mask.shape # bs, max_protein_length
  659. latent_mask = torch.ones(bs, self.num_latents).to(mask.device)
  660. mask = torch.cat((mask, latent_mask), dim=1) # bs, max_protein_length + num_latents
  661. # blocks
  662. ones = torch.ones(b).to(self.latents.device)
  663. latents = self.latents[None] * ones.view(-1, 1, 1) # [b,n,d]
  664. latents = latents.to(embeds.dtype)
  665. for attn, ff in self.layers:
  666. latents = attn(embeds, latents, mask) + latents
  667. latents = ff(latents) + latents
  668. transformed_feature = self.protein_projector(latents)
  669. return self.norm(transformed_feature)
# Output container returned by `EvollaProteinEncoder.forward`.
@dataclass
@auto_docstring
class EvollaProteinEncoderModelOutput(ModelOutput):
    # Output of `EvollaSequenceCompressorResampler`: the protein hidden states compressed into the resampler's
    # latent tokens and projected to the language-model hidden size.
    sequence_compressor_output: Optional[torch.FloatTensor] = None
    # Final hidden states of the SaProt protein encoder (`EvollaSaProtProteinEncoder`).
    last_hidden_state: Optional[torch.FloatTensor] = None
    # NOTE(review): not populated by `EvollaProteinEncoder.forward`; stays None there.
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    # NOTE(review): not populated by `EvollaProteinEncoder.forward`; stays None there.
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  677. class EvollaProteinEncoder(nn.Module):
  678. def __init__(self, config: EvollaConfig):
  679. super().__init__()
  680. self.model = EvollaSaProtProteinEncoder(config=config.protein_encoder_config)
  681. self.sequence_compressor_resampler = EvollaSequenceCompressorResampler(config=config)
  682. @can_return_tuple
  683. def forward(self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor, **kwargs):
  684. protein_output = self.model(input_ids=input_ids, attention_mask=attention_mask)
  685. protein_embeds = protein_output.last_hidden_state
  686. sequence_repr = self.sequence_compressor_resampler(protein_embeds, attention_mask)
  687. return EvollaProteinEncoderModelOutput(
  688. sequence_compressor_output=sequence_repr,
  689. last_hidden_state=protein_output.last_hidden_state,
  690. )
  691. class EvollaSequenceAlignerCrossAttention(nn.Module):
  692. def __init__(
  693. self,
  694. config,
  695. protein_encoder_dim: Optional[int] = None,
  696. structure_encoder_dim: Optional[int] = None,
  697. msa_encoder_dim: Optional[int] = None,
  698. ):
  699. super().__init__()
  700. self.hidden_size = config.hidden_size
  701. self.num_attention_heads = config.num_attention_heads
  702. self.scale = self.num_attention_heads**-0.5
  703. self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
  704. self.all_head_size = self.num_attention_heads * self.attention_head_size
  705. attention_probs_dropout_prob = config.aligner_attention_probs_dropout_prob
  706. enable_bias = config.aligner_enable_bias
  707. ffn_mult = config.aligner_ffn_mult
  708. self.query = nn.Linear(self.hidden_size, self.all_head_size)
  709. if protein_encoder_dim is not None:
  710. self.key_protein = nn.Linear(protein_encoder_dim, self.all_head_size)
  711. self.value_protein = nn.Linear(protein_encoder_dim, self.all_head_size)
  712. else:
  713. self.key_protein = None
  714. self.value_protein = None
  715. if structure_encoder_dim is not None:
  716. self.key_structure = nn.Linear(structure_encoder_dim, self.all_head_size)
  717. self.value_structure = nn.Linear(structure_encoder_dim, self.all_head_size)
  718. else:
  719. self.key_structure = None
  720. self.value_structure = None
  721. if msa_encoder_dim is not None:
  722. self.key_msa = nn.Linear(msa_encoder_dim, self.all_head_size)
  723. self.value_msa = nn.Linear(msa_encoder_dim, self.all_head_size)
  724. else:
  725. self.key_msa = None
  726. self.value_msa = None
  727. self.attention_norm = EvollaRMSNorm(self.hidden_size)
  728. self.dropout = nn.Dropout(attention_probs_dropout_prob)
  729. self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=enable_bias)
  730. self.ff = EvollaFeedForward(self.hidden_size, ffn_mult)
  731. self.gate_attention = nn.Parameter(torch.tensor([0.0]))
  732. self.gate_ffw = nn.Parameter(torch.tensor([0.0]))
  733. def cross_attention(
  734. self,
  735. query_states,
  736. protein_key_value_states,
  737. structure_key_value_states,
  738. msa_key_value_states,
  739. query_attn_mask,
  740. protein_kv_attn_mask,
  741. structure_kv_attn_mask,
  742. msa_kv_attn_mask,
  743. ):
  744. """
  745. query_states: text
  746. key_value_states: protein
  747. query_states: [bs, query_seq_len, dim]
  748. key_value_states: [bs, kv_seq_len, dim]
  749. query_attn_mask: [bs, query_seq_len]
  750. kv_attn_mask: [bs, kv_seq_len]
  751. """
  752. # Concatenate protein and structure
  753. kv_attn_mask = [protein_kv_attn_mask, structure_kv_attn_mask, msa_kv_attn_mask]
  754. kv_attn_mask = [_ for _ in kv_attn_mask if _ is not None]
  755. if not kv_attn_mask:
  756. raise ValueError("At least one modality should be provided for cross attention.")
  757. kv_attn_mask = torch.cat(kv_attn_mask, dim=1)
  758. query_layer = self.attention_norm(query_states)
  759. # Warning: This place might cause issues, refers to
  760. # https://discuss.pytorch.org/t/cuda-error-cublas-status-not-supported-when-calling-cublasltmatmul-from-torch-nn-functional-linear/170214/13
  761. # Solution: add `DISABLE_ADDMM_CUDA_LT=1` as environment variable
  762. # Apply linear transformation to input_query, input_key, and input_value
  763. query_layer = self.query(query_layer) # [bs, querylength, dim]
  764. if self.key_protein is not None and self.value_protein is not None:
  765. protein_key_value_states = protein_key_value_states.to(query_states)
  766. key_layer_protein = self.key_protein(protein_key_value_states) # [bs, keylength, dim]
  767. value_layer_protein = self.value_protein(protein_key_value_states) # [bs, keylength, dim]
  768. else:
  769. key_layer_protein = None
  770. value_layer_protein = None
  771. if self.key_structure is not None and self.value_structure is not None:
  772. structure_key_value_states = structure_key_value_states.to(query_states)
  773. key_layer_structure = self.key_structure(structure_key_value_states) # [bs, keylength, dim]
  774. value_layer_structure = self.value_structure(structure_key_value_states) # [bs, keylength, dim]
  775. else:
  776. key_layer_structure = None
  777. value_layer_structure = None
  778. if self.key_msa is not None and self.value_msa is not None:
  779. msa_key_value_states = msa_key_value_states.to(query_states)
  780. key_layer_msa = self.key_msa(msa_key_value_states) # [bs, keylength, dim]
  781. value_layer_msa = self.value_msa(msa_key_value_states) # [bs, keylength, dim]
  782. else:
  783. key_layer_msa = None
  784. value_layer_msa = None
  785. key_layer = [key_layer_protein, key_layer_structure, key_layer_msa]
  786. key_layer = [_ for _ in key_layer if _ is not None]
  787. key_layer = torch.cat(key_layer, dim=1)
  788. value_layer = [value_layer_protein, value_layer_structure, value_layer_msa]
  789. value_layer = [_ for _ in value_layer if _ is not None]
  790. value_layer = torch.cat(value_layer, dim=1)
  791. new_query_layer_shape = query_layer.size()[:-1] + (
  792. self.num_attention_heads,
  793. self.attention_head_size,
  794. )
  795. query_layer = query_layer.view(*new_query_layer_shape).permute(0, 2, 1, 3)
  796. new_key_layer_shape = key_layer.size()[:-1] + (
  797. self.num_attention_heads,
  798. self.attention_head_size,
  799. )
  800. key_layer = key_layer.view(*new_key_layer_shape).permute(0, 2, 1, 3)
  801. new_value_layer_shape = value_layer.size()[:-1] + (
  802. self.num_attention_heads,
  803. self.attention_head_size,
  804. )
  805. value_layer = value_layer.view(*new_value_layer_shape).permute(0, 2, 1, 3)
  806. query_layer = query_layer * self.scale
  807. # attention_mask: [bs, 1, querylength, keylength]
  808. if query_attn_mask is None:
  809. query_attn_mask = torch.ones(query_states.size(0), query_states.size(1)).to(query_states.device)
  810. attention_mask = query_attn_mask[:, None, :, None] * kv_attn_mask[:, None, None, :]
  811. # Compute the scaled dot-product attention scores
  812. attn_weights = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [bs, numheads, querylength, keylength]
  813. attn_weights = attn_weights - attn_weights.amax(dim=-1, keepdim=True).detach() # To stabilize score
  814. attention_scores = attn_weights.masked_fill(
  815. (1 - attention_mask).bool(), torch.finfo(attn_weights.dtype).min
  816. ) # [bs, numheads, querylength, keylength]
  817. attention_probs = nn.Softmax(dim=-1)(attention_scores)
  818. # attention_probs_dropped = self.dropout(attention_probs)
  819. context_layer = torch.matmul(attention_probs, value_layer) # [bs, numheads, querylength, dim/numheads]
  820. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  821. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  822. context_layer = context_layer.view(*new_context_layer_shape)
  823. context_layer = self.out_proj(context_layer)
  824. return context_layer
  825. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  826. def forward(
  827. self,
  828. query_states,
  829. protein_kv_states,
  830. structure_kv_states,
  831. msa_kv_states,
  832. query_attn_mask,
  833. protein_kv_attn_mask=None,
  834. structure_kv_attn_mask=None,
  835. msa_kv_attn_mask=None,
  836. protein_batch_mask=None,
  837. structure_batch_mask=None,
  838. msa_batch_mask=None,
  839. past_key_values=None,
  840. ):
  841. if protein_kv_states is not None:
  842. bs, protein_kv_seq_len, dim = protein_kv_states.shape
  843. if protein_kv_attn_mask is None:
  844. protein_kv_attn_mask = (
  845. torch.ones(bs, protein_kv_seq_len).to(protein_batch_mask.device)
  846. * protein_batch_mask.expand(size=(protein_kv_seq_len, bs)).T
  847. ).to(protein_kv_states.device)
  848. else:
  849. protein_kv_attn_mask = None
  850. if structure_kv_states is not None:
  851. bs, structure_kv_seq_len, dim = structure_kv_states.shape
  852. if structure_kv_attn_mask is None:
  853. structure_kv_attn_mask = (
  854. torch.ones(bs, structure_kv_seq_len).to(protein_batch_mask.device)
  855. * structure_batch_mask.expand(size=(structure_kv_seq_len, bs)).T
  856. ).to(structure_kv_states.device)
  857. else:
  858. structure_kv_attn_mask = None
  859. if msa_kv_states is not None:
  860. bs, msa_kv_seq_len, dim = msa_kv_states.shape
  861. if msa_kv_attn_mask is None:
  862. msa_kv_attn_mask = (
  863. torch.ones(bs, msa_kv_seq_len).to(protein_batch_mask.device)
  864. * msa_batch_mask.expand(size=(msa_kv_seq_len, bs)).T
  865. ).to(msa_kv_states.device)
  866. else:
  867. msa_kv_attn_mask = None
  868. hidden_states = query_states
  869. # only when there's at least one valid modality, crossattention will be performed
  870. if (
  871. (protein_kv_states is not None and protein_kv_attn_mask.any())
  872. or (structure_kv_states is not None and structure_kv_attn_mask.any())
  873. or (msa_kv_states is not None and msa_kv_attn_mask.any())
  874. ):
  875. residual = hidden_states
  876. hidden_states = self.cross_attention(
  877. query_states=hidden_states,
  878. protein_key_value_states=protein_kv_states,
  879. structure_key_value_states=structure_kv_states,
  880. msa_key_value_states=msa_kv_states,
  881. query_attn_mask=query_attn_mask,
  882. protein_kv_attn_mask=protein_kv_attn_mask,
  883. structure_kv_attn_mask=structure_kv_attn_mask,
  884. msa_kv_attn_mask=msa_kv_attn_mask,
  885. ) # [bs, query_seq_len, dim]
  886. # tanh gate
  887. hidden_states = torch.tanh(self.gate_attention) * hidden_states
  888. hidden_states = residual + hidden_states # input_query
  889. residual = hidden_states
  890. hidden_states = self.ff(hidden_states) * torch.tanh(self.gate_ffw)
  891. hidden_states = residual + hidden_states
  892. return hidden_states
  893. @use_kernel_forward_from_hub("RMSNorm")
  894. class EvollaRMSNorm(nn.Module):
  895. def __init__(self, hidden_size, eps=1e-6):
  896. """
  897. EvollaRMSNorm is equivalent to T5LayerNorm
  898. """
  899. super().__init__()
  900. self.weight = nn.Parameter(torch.ones(hidden_size))
  901. self.variance_epsilon = eps
  902. def forward(self, hidden_states):
  903. input_dtype = hidden_states.dtype
  904. hidden_states = hidden_states.to(torch.float32)
  905. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  906. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  907. return self.weight * hidden_states.to(input_dtype)
  908. def extra_repr(self):
  909. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  910. class EvollaRotaryEmbedding(nn.Module):
  911. inv_freq: torch.Tensor # fix linting for `register_buffer`
  912. def __init__(self, config: EvollaConfig, device=None):
  913. super().__init__()
  914. # BC: "rope_type" was originally "type"
  915. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  916. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  917. else:
  918. self.rope_type = "default"
  919. self.max_seq_len_cached = config.max_position_embeddings
  920. self.original_max_seq_len = config.max_position_embeddings
  921. self.config = config
  922. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  923. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  924. self.register_buffer("inv_freq", inv_freq, persistent=False)
  925. self.original_inv_freq = self.inv_freq
  926. @torch.no_grad()
  927. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  928. def forward(self, x, position_ids):
  929. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  930. position_ids_expanded = position_ids[:, None, :].float()
  931. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  932. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  933. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  934. emb = torch.cat((freqs, freqs), dim=-1)
  935. cos = emb.cos() * self.attention_scaling
  936. sin = emb.sin() * self.attention_scaling
  937. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  938. class EvollaMLP(nn.Module):
  939. def __init__(self, config):
  940. super().__init__()
  941. self.config = config
  942. self.hidden_size = config.hidden_size
  943. self.intermediate_size = config.intermediate_size
  944. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  945. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  946. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  947. self.act_fn = ACT2FN[config.hidden_act]
  948. def forward(self, x):
  949. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  950. return down_proj
  951. def rotate_half(x):
  952. """Rotates half the hidden dims of the input."""
  953. x1 = x[..., : x.shape[-1] // 2]
  954. x2 = x[..., x.shape[-1] // 2 :]
  955. return torch.cat((-x2, x1), dim=-1)
  956. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  957. """Applies Rotary Position Embedding to the query and key tensors.
  958. Args:
  959. q (`torch.Tensor`): The query tensor.
  960. k (`torch.Tensor`): The key tensor.
  961. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  962. sin (`torch.Tensor`): The sine part of the rotary embedding.
  963. position_ids (`torch.Tensor`, *optional*):
  964. Deprecated and unused.
  965. unsqueeze_dim (`int`, *optional*, defaults to 1):
  966. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  967. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  968. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  969. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  970. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  971. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  972. Returns:
  973. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  974. """
  975. cos = cos.unsqueeze(unsqueeze_dim)
  976. sin = sin.unsqueeze(unsqueeze_dim)
  977. q_embed = (q * cos) + (rotate_half(q) * sin)
  978. k_embed = (k * cos) + (rotate_half(k) * sin)
  979. return q_embed, k_embed
  980. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  981. """
  982. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  983. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  984. """
  985. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  986. if n_rep == 1:
  987. return hidden_states
  988. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  989. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  990. class EvollaAttention(nn.Module):
  991. """Multi-headed attention from 'Attention Is All You Need' paper"""
  992. def __init__(self, config: EvollaConfig, layer_idx: int):
  993. super().__init__()
  994. self.config = config
  995. self.layer_idx = layer_idx
  996. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  997. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  998. self.scaling = self.head_dim**-0.5
  999. self.attention_dropout = config.attention_dropout
  1000. self.is_causal = True
  1001. self.q_proj = nn.Linear(
  1002. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  1003. )
  1004. self.k_proj = nn.Linear(
  1005. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  1006. )
  1007. self.v_proj = nn.Linear(
  1008. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  1009. )
  1010. self.o_proj = nn.Linear(
  1011. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  1012. )
  1013. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  1014. def forward(
  1015. self,
  1016. hidden_states: torch.Tensor,
  1017. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  1018. attention_mask: Optional[torch.Tensor],
  1019. past_key_values: Optional[Cache] = None,
  1020. cache_position: Optional[torch.LongTensor] = None,
  1021. **kwargs: Unpack[TransformersKwargs],
  1022. ) -> tuple[torch.Tensor, torch.Tensor]:
  1023. input_shape = hidden_states.shape[:-1]
  1024. hidden_shape = (*input_shape, -1, self.head_dim)
  1025. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  1026. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  1027. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  1028. cos, sin = position_embeddings
  1029. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  1030. if past_key_values is not None:
  1031. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  1032. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  1033. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  1034. attention_interface: Callable = eager_attention_forward
  1035. if self.config._attn_implementation != "eager":
  1036. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  1037. attn_output, attn_weights = attention_interface(
  1038. self,
  1039. query_states,
  1040. key_states,
  1041. value_states,
  1042. attention_mask,
  1043. dropout=0.0 if not self.training else self.attention_dropout,
  1044. scaling=self.scaling,
  1045. **kwargs,
  1046. )
  1047. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  1048. attn_output = self.o_proj(attn_output)
  1049. return attn_output, attn_weights
  1050. class EvollaDecoderLayer(GradientCheckpointingLayer):
  1051. def __init__(self, config: EvollaConfig, layer_idx: int):
  1052. super().__init__()
  1053. self.hidden_size = config.hidden_size
  1054. self.self_attn = EvollaAttention(config=config, layer_idx=layer_idx)
  1055. self.mlp = EvollaMLP(config)
  1056. self.input_layernorm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  1057. self.post_attention_layernorm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  1058. if (layer_idx + 1) % max(config.num_hidden_layers // config.aligner_num_add_layers, 1) == 0:
  1059. self.adapter = EvollaSequenceAlignerCrossAttention(
  1060. config,
  1061. protein_encoder_dim=config.hidden_size,
  1062. )
  1063. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  1064. def forward(
  1065. self,
  1066. hidden_states: torch.Tensor,
  1067. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  1068. attention_mask: Optional[torch.Tensor] = None,
  1069. position_ids: Optional[torch.LongTensor] = None,
  1070. past_key_values: Optional[Cache] = None,
  1071. use_cache: Optional[bool] = False,
  1072. cache_position: Optional[torch.LongTensor] = None,
  1073. protein_kv_states: Optional[torch.Tensor] = None,
  1074. structure_kv_states: Optional[torch.Tensor] = None,
  1075. msa_kv_states: Optional[torch.Tensor] = None,
  1076. protein_batch_mask: Optional[torch.Tensor] = None,
  1077. structure_batch_mask: Optional[torch.Tensor] = None,
  1078. msa_batch_mask: Optional[torch.Tensor] = None,
  1079. query_attn_mask: Optional[torch.Tensor] = None,
  1080. **kwargs,
  1081. ) -> torch.Tensor:
  1082. residual = hidden_states
  1083. hidden_states = self.input_layernorm(hidden_states)
  1084. # Self Attention
  1085. hidden_states, _ = self.self_attn(
  1086. hidden_states=hidden_states,
  1087. attention_mask=attention_mask,
  1088. position_ids=position_ids,
  1089. past_key_values=past_key_values,
  1090. use_cache=use_cache,
  1091. cache_position=cache_position,
  1092. position_embeddings=position_embeddings,
  1093. **kwargs,
  1094. )
  1095. hidden_states = residual + hidden_states
  1096. # Fully Connected
  1097. residual = hidden_states
  1098. hidden_states = self.post_attention_layernorm(hidden_states)
  1099. hidden_states = self.mlp(hidden_states)
  1100. hidden_states = residual + hidden_states
  1101. if hasattr(self, "adapter"):
  1102. hidden_states = self.adapter(
  1103. query_states=hidden_states,
  1104. protein_kv_states=protein_kv_states,
  1105. structure_kv_states=structure_kv_states,
  1106. msa_kv_states=msa_kv_states,
  1107. query_attn_mask=query_attn_mask,
  1108. protein_batch_mask=protein_batch_mask,
  1109. structure_batch_mask=structure_batch_mask,
  1110. msa_batch_mask=msa_batch_mask,
  1111. )
  1112. return hidden_states
  1113. @auto_docstring
  1114. class EvollaPreTrainedModel(PreTrainedModel):
  1115. config: EvollaConfig
  1116. base_model_prefix = "model"
  1117. supports_gradient_checkpointing = True
  1118. _no_split_modules = [
  1119. "EvollaDecoderLayer",
  1120. "EvollaSequenceCompressorResampler",
  1121. "EvollaSequenceAlignerCrossAttention",
  1122. ]
  1123. _skip_keys_device_placement = ["past_key_values"]
  1124. _supports_flash_attn = False # see dependency on `EvollaSaProtProteinEncoder`
  1125. _supports_sdpa = True
  1126. _supports_flex_attn = False # see dependency on `EvollaSaProtProteinEncoder`
  1127. _can_compile_fullgraph = True
  1128. _supports_attention_backend = False
  1129. _can_record_outputs = {
  1130. "hidden_states": EvollaDecoderLayer,
  1131. "attentions": EvollaAttention,
  1132. }
  1133. def _init_weights(self, module):
  1134. std = self.config.initializer_range
  1135. super()._init_weights(module)
  1136. if isinstance(module, EvollaSequenceAlignerCrossAttention):
  1137. module.gate_attention.zero_()
  1138. module.gate_ffw.zero_()
  1139. module.attention_norm.weight.data.fill_(1.0)
  1140. elif isinstance(module, EvollaSequenceCompressorResampler):
  1141. module.latents.data.normal_(mean=0.0, std=std)
  1142. class EvollaModel(EvollaPreTrainedModel):
  1143. def __init__(self, config: EvollaConfig):
  1144. super().__init__(config)
  1145. self.padding_idx = config.pad_token_id
  1146. self.vocab_size = config.vocab_size
  1147. self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx)
  1148. self.protein_encoder = EvollaProteinEncoder(config=config)
  1149. self.layers = nn.ModuleList(
  1150. [
  1151. EvollaDecoderLayer(
  1152. config=config,
  1153. layer_idx=layer_idx,
  1154. )
  1155. for layer_idx in range(config.num_hidden_layers)
  1156. ]
  1157. )
  1158. self.norm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  1159. self.rotary_emb = EvollaRotaryEmbedding(config=config)
  1160. self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False)
  1161. self.post_init()
  1162. def get_input_embeddings(self):
  1163. return self.embed_tokens
  1164. def set_input_embeddings(self, value):
  1165. self.embed_tokens = value
  1166. @auto_docstring
  1167. @check_model_inputs()
  1168. def forward(
  1169. self,
  1170. input_ids: Optional[torch.LongTensor] = None,
  1171. attention_mask: Optional[torch.Tensor] = None,
  1172. position_ids: Optional[torch.LongTensor] = None,
  1173. past_key_values: Optional[Cache] = None,
  1174. inputs_embeds: Optional[torch.FloatTensor] = None,
  1175. use_cache: Optional[bool] = None,
  1176. cache_position: Optional[torch.LongTensor] = None,
  1177. protein_input_ids: Optional[torch.LongTensor] = None,
  1178. protein_attention_mask: Optional[torch.Tensor] = None,
  1179. structure_feats: Optional[torch.FloatTensor] = None,
  1180. msa_feats: Optional[torch.FloatTensor] = None,
  1181. structure_batch_mask: Optional[torch.Tensor] = None,
  1182. msa_batch_mask: Optional[torch.Tensor] = None,
  1183. **kwargs,
  1184. ) -> Union[tuple, BaseModelOutputWithPast]:
  1185. r"""
  1186. protein_input_ids (torch.LongTensor):
  1187. The input IDs for the protein sequence in structure-aware tokens. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`.
  1188. protein_attention_mask (torch.Tensor):
  1189. The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.
  1190. structure_feats (torch.FloatTensor):
  1191. The input IDs for purely structure-based features. Should be of shape `(batch_size, structure_seq_length, structure_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
  1192. msa_feats (torch.FloatTensor):
  1193. The input IDs for purely MSA-based features. Should be of shape `(batch_size, msa_seq_length, msa_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
  1194. structure_batch_mask (torch.Tensor):
  1195. The batch mask to decide which protein sequences are purely structure-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `structure_feats`. Dummpy input for now.
  1196. msa_batch_mask (torch.Tensor):
  1197. The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummpy input for now.
  1198. """
  1199. if (input_ids is None) ^ (inputs_embeds is not None):
  1200. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1201. if inputs_embeds is None:
  1202. inputs_embeds = self.embed_tokens(input_ids)
  1203. if use_cache and past_key_values is None:
  1204. past_key_values = DynamicCache(config=self.config)
  1205. if cache_position is None:
  1206. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1207. cache_position = torch.arange(
  1208. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  1209. )
  1210. if position_ids is None:
  1211. position_ids = cache_position.unsqueeze(0)
  1212. protein_feats = None
  1213. protein_batch_mask = None
  1214. # If provided, actually compute them
  1215. if protein_input_ids is not None and protein_attention_mask is not None:
  1216. protein_outputs = self.protein_encoder(
  1217. input_ids=protein_input_ids,
  1218. attention_mask=protein_attention_mask,
  1219. )
  1220. protein_feats = protein_outputs.sequence_compressor_output
  1221. protein_batch_mask = torch.tensor([True] * protein_input_ids.shape[0], device=protein_input_ids.device)
  1222. causal_mask = create_causal_mask(
  1223. config=self.config,
  1224. input_embeds=inputs_embeds,
  1225. attention_mask=attention_mask,
  1226. cache_position=cache_position,
  1227. past_key_values=past_key_values,
  1228. )
  1229. hidden_states = inputs_embeds
  1230. # create position embeddings to be shared across the decoder layers
  1231. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  1232. for decoder_layer in self.layers:
  1233. hidden_states = decoder_layer(
  1234. hidden_states,
  1235. attention_mask=causal_mask,
  1236. position_ids=position_ids,
  1237. past_key_values=past_key_values,
  1238. use_cache=use_cache,
  1239. cache_position=cache_position,
  1240. position_embeddings=position_embeddings,
  1241. protein_kv_states=protein_feats,
  1242. structure_kv_states=structure_feats,
  1243. msa_kv_states=msa_feats,
  1244. protein_batch_mask=protein_batch_mask,
  1245. structure_batch_mask=structure_batch_mask,
  1246. msa_batch_mask=msa_batch_mask,
  1247. query_attn_mask=attention_mask,
  1248. **kwargs,
  1249. )
  1250. hidden_states = self.norm(hidden_states)
  1251. output = BaseModelOutputWithPast(
  1252. last_hidden_state=hidden_states,
  1253. past_key_values=past_key_values,
  1254. )
  1255. return output
  1256. class EvollaForProteinText2Text(EvollaPreTrainedModel, GenerationMixin):
  1257. def __init__(self, config):
  1258. super().__init__(config)
  1259. self.model = EvollaModel(config)
  1260. self.vocab_size = config.vocab_size
  1261. self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
  1262. self.post_init()
  1263. def get_input_embeddings(self):
  1264. return self.model.get_input_embeddings()
  1265. def set_input_embeddings(self, value):
  1266. return self.model.set_input_embeddings(value)
  1267. @can_return_tuple
  1268. @auto_docstring
  1269. def forward(
  1270. self,
  1271. input_ids: Optional[torch.LongTensor] = None, # text input ids
  1272. attention_mask: Optional[torch.Tensor] = None, # text attention mask
  1273. inputs_embeds: Optional[torch.FloatTensor] = None, # text input embeddings
  1274. labels: Optional[torch.LongTensor] = None,
  1275. protein_input_ids: Optional[torch.LongTensor] = None,
  1276. protein_attention_mask: Optional[torch.Tensor] = None,
  1277. use_cache: Optional[bool] = None,
  1278. **kwargs,
  1279. ):
  1280. r"""
  1281. protein_input_ids (torch.LongTensor):
  1282. The input IDs for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`.
  1283. protein_attention_mask (torch.Tensor):
  1284. The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.
  1285. Example:
  1286. ```python
  1287. >>> from transformers import EvollaProcessor, EvollaForProteinText2Text
  1288. >>> model = EvollaForProteinText2Text.from_pretrained("westlake/Evolla-10B-hf")
  1289. >>> processor = EvollaProcessor.from_pretrained("westlake/Evolla-10B-hf")
  1290. >>> protein_information = {
  1291. "aa_seq": "your amino acid sequence",
  1292. "foldseek": "your foldseek sequence",
  1293. }
  1294. >>> question = "What is the function of this protein?"
  1295. >>> message = [
  1296. {"role": "system", "content": "You are an AI expert that can answer any questions about protein."},
  1297. {"role": "user", "content": question},
  1298. ]
  1299. >>> inputs = processor(proteins=[protein_information], messages_list=[message], return_tensors="pt", padding="longest")
  1300. >>> outputs = model.generate(**inputs)
  1301. >>> print(processor.batch_decode(outputs, skip_special_tokens=True))
  1302. ```"""
  1303. outputs = self.model(
  1304. input_ids=input_ids,
  1305. attention_mask=attention_mask,
  1306. inputs_embeds=inputs_embeds,
  1307. protein_input_ids=protein_input_ids,
  1308. protein_attention_mask=protein_attention_mask,
  1309. use_cache=use_cache,
  1310. **kwargs,
  1311. )
  1312. hidden_states = outputs[0]
  1313. logits = self.lm_head(hidden_states)
  1314. loss = None
  1315. if labels is not None:
  1316. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
  1317. lm_outputs = CausalLMOutputWithPast(
  1318. loss=loss,
  1319. logits=logits,
  1320. past_key_values=outputs.past_key_values,
  1321. hidden_states=outputs.hidden_states,
  1322. attentions=outputs.attentions,
  1323. )
  1324. return lm_outputs
  1325. __all__ = ["EvollaForProteinText2Text", "EvollaModel", "EvollaPreTrainedModel"]