processing_evolla.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. # coding=utf-8
  2. # Copyright 2025 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. Processor class for EVOLLA.
  17. """
  18. import os
  19. from typing import Optional, Union
  20. from ...feature_extraction_utils import BatchFeature
  21. from ...processing_utils import (
  22. ProcessorMixin,
  23. )
  24. from ..auto import AutoTokenizer
# Keys a protein dict passed to `EvollaProcessor.__call__` may contain; a dict with any other key is rejected
# with a `ValueError`.
PROTEIN_VALID_KEYS = ["aa_seq", "foldseek", "msa"]
  26. class EvollaProcessor(ProcessorMixin):
  27. r"""
  28. Constructs a EVOLLA processor which wraps a LLama tokenizer and SaProt tokenizer (EsmTokenizer) into a single processor.
  29. [`EvollaProcessor`] offers all the functionalities of [`EsmTokenizer`] and [`LlamaTokenizerFast`]. See the
  30. docstring of [`~EvollaProcessor.__call__`] and [`~EvollaProcessor.decode`] for more information.
  31. Args:
  32. protein_tokenizer (`EsmTokenizer`):
  33. An instance of [`EsmTokenizer`]. The protein tokenizer is a required input.
  34. tokenizer (`LlamaTokenizerFast`, *optional*):
  35. An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input.
  36. protein_max_length (`int`, *optional*, defaults to 1024):
  37. The maximum length of the sequence to be generated.
  38. text_max_length (`int`, *optional*, defaults to 512):
  39. The maximum length of the text to be generated.
  40. """
  41. attributes = ["protein_tokenizer", "tokenizer"]
  42. valid_kwargs = ["sequence_max_length"]
  43. # protein_tokenizer_class = "EsmTokenizer"
  44. # tokenizer_class = "LlamaTokenizerFast"
  45. protein_tokenizer_class = "AutoTokenizer"
  46. tokenizer_class = "AutoTokenizer"
  47. protein_tokenizer_dir_name = "protein_tokenizer"
  48. # tokenizer_dir_name = "text_tokenizer"
  49. def __init__(self, protein_tokenizer, tokenizer=None, protein_max_length=1024, text_max_length=512, **kwargs):
  50. if protein_tokenizer is None:
  51. raise ValueError("You need to specify an `protein_tokenizer`.")
  52. if tokenizer is None:
  53. raise ValueError("You need to specify a `tokenizer`.")
  54. super().__init__(protein_tokenizer, tokenizer)
  55. self.tokenizer.pad_token = "<|reserved_special_token_0|>"
  56. self.protein_max_length = protein_max_length
  57. self.text_max_length = text_max_length
  58. def process_proteins(self, proteins, protein_max_length=1024):
  59. sa_sequences = []
  60. for protein in proteins:
  61. aa_seq = protein.get("aa_seq")
  62. foldseek = protein.get("foldseek")
  63. sa_sequence = "".join([s.upper() + f.lower() for s, f in zip(aa_seq, foldseek)])
  64. sa_sequences.append(sa_sequence)
  65. sa_tokens = self.protein_tokenizer.batch_encode_plus(
  66. sa_sequences, return_tensors="pt", truncation=True, max_length=protein_max_length, padding=True
  67. )
  68. return sa_tokens
  69. def process_text(
  70. self,
  71. texts,
  72. text_max_length: int = 512,
  73. ):
  74. prompts = []
  75. for messages in texts:
  76. prompt = self.tokenizer.apply_chat_template(
  77. messages,
  78. tokenize=False,
  79. add_generation_prompt=True,
  80. )
  81. prompts.append(prompt)
  82. prompt_inputs = self.tokenizer(
  83. prompts,
  84. add_special_tokens=False,
  85. return_tensors="pt",
  86. padding="longest",
  87. truncation=True,
  88. max_length=text_max_length,
  89. )
  90. return prompt_inputs
  91. def __call__(
  92. self,
  93. proteins: Optional[Union[list[dict], dict]] = None,
  94. messages_list: Optional[Union[list[list[dict]], list[dict]]] = None,
  95. protein_max_length: Optional[int] = None,
  96. text_max_length: Optional[int] = None,
  97. **kwargs,
  98. ):
  99. r"""This method takes batched or non-batched proteins and messages_list and converts them into format that can be used by
  100. the model.
  101. Args:
  102. proteins (`Union[List[dict], dict]`):
  103. A list of dictionaries or a single dictionary containing the following keys:
  104. - `"aa_seq"` (`str`) -- The amino acid sequence of the protein.
  105. - `"foldseek"` (`str`) -- The foldseek string of the protein.
  106. messages_list (`Union[List[List[dict]], List[dict]]`):
  107. A list of lists of dictionaries or a list of dictionaries containing the following keys:
  108. - `"role"` (`str`) -- The role of the message.
  109. - `"content"` (`str`) -- The content of the message.
  110. protein_max_length (`int`, *optional*, defaults to 1024):
  111. The maximum length of the sequence to be generated.
  112. text_max_length (`int`, *optional*, defaults to 512):
  113. The maximum length of the text.
  114. Return:
  115. a dict with following keys:
  116. - `protein_input_ids` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The input IDs for the protein sequence.
  117. - `protein_attention_mask` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The attention mask for the protein sequence.
  118. - `text_input_ids` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The input IDs for the text sequence.
  119. - `text_attention_mask` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The attention mask for the text sequence.
  120. """
  121. # proteins and messages_list should be provided
  122. if proteins is None or messages_list is None:
  123. raise ValueError("You need to specify `messages_list` and `proteins`.")
  124. protein_max_length = protein_max_length if protein_max_length is not None else self.protein_max_length
  125. text_max_length = text_max_length if text_max_length is not None else self.text_max_length
  126. # proteins should be List[dict]
  127. if isinstance(proteins, dict):
  128. proteins = [proteins]
  129. # messages_list should be List[List[dict]]
  130. if isinstance(messages_list, (list, tuple)) and not isinstance(messages_list[0], (list, tuple)):
  131. messages_list = [messages_list]
  132. # Check if batched proteins are in the correct format
  133. if isinstance(proteins, (list, tuple)) and not all(isinstance(p, dict) for p in proteins):
  134. raise ValueError("The proteins should be a list of dictionaries, but not all elements are dictionaries.")
  135. if isinstance(proteins, (list, tuple)) and not all(
  136. all(k in PROTEIN_VALID_KEYS for k in p.keys()) for p in proteins
  137. ):
  138. raise ValueError(
  139. "There should be a list of dictionaries with keys: "
  140. f"{', '.join(PROTEIN_VALID_KEYS)} for each protein."
  141. f"But got: {proteins}"
  142. )
  143. # Check if batched messages_list is in the correct format
  144. if isinstance(messages_list, (list, tuple)):
  145. for messages in messages_list:
  146. if not isinstance(messages, (list, tuple)):
  147. raise ValueError(f"Each messages in messages_list should be a list instead of {type(messages)}.")
  148. if not all(isinstance(m, dict) for m in messages):
  149. raise ValueError(
  150. "Each message in messages_list should be a list of dictionaries, but not all elements are dictionaries."
  151. )
  152. if any(len(m.keys()) != 2 for m in messages) or any(
  153. set(m.keys()) != {"role", "content"} for m in messages
  154. ):
  155. raise ValueError(
  156. "Each message in messages_list should be a list of dictionaries with two keys: 'role' and 'content'."
  157. f"But got: {messages}"
  158. )
  159. else:
  160. raise ValueError(
  161. f"The messages_list should be a list of lists of dictionaries, but it's {type(messages_list)}."
  162. )
  163. sa_tokens = self.process_proteins(proteins, protein_max_length)
  164. text_tokens = self.process_text(messages_list, text_max_length)
  165. return BatchFeature(
  166. data={
  167. "protein_input_ids": sa_tokens["input_ids"],
  168. "protein_attention_mask": sa_tokens["attention_mask"],
  169. "input_ids": text_tokens["input_ids"],
  170. "attention_mask": text_tokens["attention_mask"],
  171. }
  172. )
  173. def batch_decode(self, *args, **kwargs):
  174. return self.tokenizer.batch_decode(*args, **kwargs)
  175. def decode(self, *args, **kwargs):
  176. return self.tokenizer.decode(*args, **kwargs)
  177. def protein_batch_decode(self, *args, **kwargs):
  178. return self.protein_tokenizer.batch_decode(*args, **kwargs)
  179. def protein_decode(self, *args, **kwargs):
  180. return self.protein_tokenizer.decode(*args, **kwargs)
  181. # overwrite to save the protein tokenizer in a separate folder
  182. # Adapted from instructblip.processing_instructblip.py (https://github.com/huggingface/transformers/blob/9b479a245b793cac2a8b2e87c6d8e81bb24e20c4/src/transformers/models/instructblip/processing_instructblip.py#L191-L221)
  183. def save_pretrained(self, save_directory, **kwargs):
  184. # only save the protein tokenizer in sub_dir
  185. self.protein_tokenizer.save_pretrained(os.path.join(save_directory, self.protein_tokenizer_dir_name))
  186. # we modify the attributes so that only the text tokenizer are saved in the main folder
  187. protein_tokenizer_present = "protein_tokenizer" in self.attributes
  188. # find the correct position of it in the attributes list
  189. protein_tokenizer_index = self.attributes.index("protein_tokenizer") if protein_tokenizer_present else None
  190. if protein_tokenizer_present and protein_tokenizer_index is not None:
  191. self.attributes.remove("protein_tokenizer")
  192. outputs = super().save_pretrained(save_directory, **kwargs)
  193. if protein_tokenizer_present and protein_tokenizer_index is not None:
  194. self.attributes.insert(protein_tokenizer_index, "protein_tokenizer")
  195. return outputs
  196. # overwrite to load the protein tokenizer from a separate folder
  197. # Adapted from instructblip.processing_instructblip.py (https://github.com/huggingface/transformers/blob/9b479a245b793cac2a8b2e87c6d8e81bb24e20c4/src/transformers/models/instructblip/processing_instructblip.py#L191-L221)
  198. @classmethod
  199. def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
  200. processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
  201. # if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs'
  202. if isinstance(processor, tuple):
  203. processor = processor[0]
  204. protein_tokenizer = AutoTokenizer.from_pretrained(
  205. pretrained_model_name_or_path, subfolder=cls.protein_tokenizer_dir_name
  206. )
  207. processor.protein_tokenizer = protein_tokenizer
  208. return processor
# Names exported by a star-import of this module.
__all__ = ["EvollaProcessor"]