| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365 |
- # coding=utf-8
- # Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """ESM model configuration"""
- from dataclasses import asdict, dataclass
- from typing import Optional
- from ...configuration_utils import PretrainedConfig
- from ...utils import logging
- logger = logging.get_logger(__name__)
- # TODO Update this
- class EsmConfig(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`ESMModel`]. It is used to instantiate a ESM model
- according to the specified arguments, defining the model architecture. Instantiating a configuration with the
- defaults will yield a similar configuration to that of the ESM
- [facebook/esm-1b](https://huggingface.co/facebook/esm-1b) architecture.
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
- Args:
- vocab_size (`int`, *optional*):
- Vocabulary size of the ESM model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`ESMModel`].
- mask_token_id (`int`, *optional*):
- The index of the mask token in the vocabulary. This must be included in the config because of the
- "mask-dropout" scaling trick, which will scale the inputs depending on the number of masked tokens.
- pad_token_id (`int`, *optional*):
- The index of the padding token in the vocabulary. This must be included in the config because certain parts
- of the ESM code use this instead of the attention mask.
- hidden_size (`int`, *optional*, defaults to 768):
- Dimensionality of the encoder layers and the pooler layer.
- num_hidden_layers (`int`, *optional*, defaults to 12):
- Number of hidden layers in the Transformer encoder.
- num_attention_heads (`int`, *optional*, defaults to 12):
- Number of attention heads for each attention layer in the Transformer encoder.
- intermediate_size (`int`, *optional*, defaults to 3072):
- Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
- hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
- The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
- attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
- The dropout ratio for the attention probabilities.
- max_position_embeddings (`int`, *optional*, defaults to 1026):
- The maximum sequence length that this model might ever be used with. Typically set this to something large
- just in case (e.g., 512 or 1024 or 2048).
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- layer_norm_eps (`float`, *optional*, defaults to 1e-12):
- The epsilon used by the layer normalization layers.
- position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
- Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query", "rotary"`.
- For positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
- [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
- For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
- with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
- is_decoder (`bool`, *optional*, defaults to `False`):
- Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
- use_cache (`bool`, *optional*, defaults to `True`):
- Whether or not the model should return the last key/values attentions (not used by all models). Only
- relevant if `config.is_decoder=True`.
- emb_layer_norm_before (`bool`, *optional*):
- Whether to apply layer normalization after embeddings but before the main stem of the network.
- token_dropout (`bool`, defaults to `False`):
- When this is enabled, masked tokens are treated as if they had been dropped out by input dropout.
- Examples:
- ```python
- >>> from transformers import EsmModel, EsmConfig
- >>> # Initializing a ESM facebook/esm-1b style configuration
- >>> configuration = EsmConfig(vocab_size=33)
- >>> # Initializing a model from the configuration
- >>> model = EsmModel(configuration)
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```"""
- model_type = "esm"
- def __init__(
- self,
- vocab_size=None,
- mask_token_id=None,
- pad_token_id=None,
- hidden_size=768,
- num_hidden_layers=12,
- num_attention_heads=12,
- intermediate_size=3072,
- hidden_dropout_prob=0.1,
- attention_probs_dropout_prob=0.1,
- max_position_embeddings=1026,
- initializer_range=0.02,
- layer_norm_eps=1e-12,
- position_embedding_type="absolute",
- use_cache=True,
- emb_layer_norm_before=None,
- token_dropout=False,
- is_folding_model=False,
- esmfold_config=None,
- vocab_list=None,
- **kwargs,
- ):
- super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)
- self.vocab_size = vocab_size
- self.hidden_size = hidden_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.intermediate_size = intermediate_size
- self.hidden_dropout_prob = hidden_dropout_prob
- self.attention_probs_dropout_prob = attention_probs_dropout_prob
- self.max_position_embeddings = max_position_embeddings
- self.initializer_range = initializer_range
- self.layer_norm_eps = layer_norm_eps
- self.position_embedding_type = position_embedding_type
- self.use_cache = use_cache
- self.emb_layer_norm_before = emb_layer_norm_before
- self.token_dropout = token_dropout
- self.is_folding_model = is_folding_model
- if is_folding_model:
- if esmfold_config is None:
- logger.info("No esmfold_config supplied for folding model, using default values.")
- esmfold_config = EsmFoldConfig()
- elif isinstance(esmfold_config, dict):
- esmfold_config = EsmFoldConfig(**esmfold_config)
- self.esmfold_config = esmfold_config
- if vocab_list is None:
- logger.warning("No vocab_list supplied for folding model, assuming the ESM-2 vocabulary!")
- self.vocab_list = get_default_vocab_list()
- else:
- self.vocab_list = vocab_list
- else:
- self.esmfold_config = None
- self.vocab_list = None
- if self.esmfold_config is not None and getattr(self.esmfold_config, "use_esm_attn_map", False):
- raise ValueError("The HuggingFace port of ESMFold does not support use_esm_attn_map at this time!")
- def to_dict(self):
- """
- Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
- Returns:
- `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
- """
- output = super().to_dict()
- if isinstance(self.esmfold_config, EsmFoldConfig):
- output["esmfold_config"] = self.esmfold_config.to_dict()
- return output
- @dataclass
- class EsmFoldConfig:
- esm_type: Optional[str] = None
- fp16_esm: bool = True
- use_esm_attn_map: bool = False
- esm_ablate_pairwise: bool = False
- esm_ablate_sequence: bool = False
- esm_input_dropout: float = 0
- embed_aa: bool = True
- bypass_lm: bool = False
- lddt_head_hid_dim: int = 128
- trunk: "TrunkConfig" = None
- def __post_init__(self):
- if self.trunk is None:
- self.trunk = TrunkConfig()
- elif isinstance(self.trunk, dict):
- self.trunk = TrunkConfig(**self.trunk)
- def to_dict(self):
- """
- Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
- Returns:
- `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
- """
- output = asdict(self)
- output["trunk"] = self.trunk.to_dict()
- return output
- @dataclass
- class TrunkConfig:
- num_blocks: int = 48
- sequence_state_dim: int = 1024
- pairwise_state_dim: int = 128
- sequence_head_width: int = 32
- pairwise_head_width: int = 32
- position_bins: int = 32
- dropout: float = 0
- layer_drop: float = 0
- cpu_grad_checkpoint: bool = False
- max_recycles: int = 4
- chunk_size: Optional[int] = 128
- structure_module: "StructureModuleConfig" = None
- def __post_init__(self):
- if self.structure_module is None:
- self.structure_module = StructureModuleConfig()
- elif isinstance(self.structure_module, dict):
- self.structure_module = StructureModuleConfig(**self.structure_module)
- if self.max_recycles <= 0:
- raise ValueError(f"`max_recycles` should be positive, got {self.max_recycles}.")
- if self.sequence_state_dim % self.sequence_state_dim != 0:
- raise ValueError(
- "`sequence_state_dim` should be a round multiple of `sequence_state_dim`, got"
- f" {self.sequence_state_dim} and {self.sequence_state_dim}."
- )
- if self.pairwise_state_dim % self.pairwise_state_dim != 0:
- raise ValueError(
- "`pairwise_state_dim` should be a round multiple of `pairwise_state_dim`, got"
- f" {self.pairwise_state_dim} and {self.pairwise_state_dim}."
- )
- sequence_num_heads = self.sequence_state_dim // self.sequence_head_width
- pairwise_num_heads = self.pairwise_state_dim // self.pairwise_head_width
- if self.sequence_state_dim != sequence_num_heads * self.sequence_head_width:
- raise ValueError(
- "`sequence_state_dim` should be equal to `sequence_num_heads * sequence_head_width, got"
- f" {self.sequence_state_dim} != {sequence_num_heads} * {self.sequence_head_width}."
- )
- if self.pairwise_state_dim != pairwise_num_heads * self.pairwise_head_width:
- raise ValueError(
- "`pairwise_state_dim` should be equal to `pairwise_num_heads * pairwise_head_width, got"
- f" {self.pairwise_state_dim} != {pairwise_num_heads} * {self.pairwise_head_width}."
- )
- if self.pairwise_state_dim % 2 != 0:
- raise ValueError(f"`pairwise_state_dim` should be even, got {self.pairwise_state_dim}.")
- if self.dropout >= 0.4:
- raise ValueError(f"`dropout` should not be greater than 0.4, got {self.dropout}.")
- def to_dict(self):
- """
- Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
- Returns:
- `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
- """
- output = asdict(self)
- output["structure_module"] = self.structure_module.to_dict()
- return output
- @dataclass
- class StructureModuleConfig:
- """
- Args:
- sequence_dim:
- Single representation channel dimension
- pairwise_dim:
- Pair representation channel dimension
- ipa_dim:
- IPA hidden channel dimension
- resnet_dim:
- Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
- num_heads_ipa:
- Number of IPA heads
- num_qk_points:
- Number of query/key points to generate during IPA
- num_v_points:
- Number of value points to generate during IPA
- dropout_rate:
- Dropout rate used throughout the layer
- num_blocks:
- Number of structure module blocks
- num_transition_layers:
- Number of layers in the single representation transition (Alg. 23 lines 8-9)
- num_resnet_blocks:
- Number of blocks in the angle resnet
- num_angles:
- Number of angles to generate in the angle resnet
- trans_scale_factor:
- Scale of single representation transition hidden dimension
- epsilon:
- Small number used in angle resnet normalization
- inf:
- Large number used for attention masking
- """
- sequence_dim: int = 384
- pairwise_dim: int = 128
- ipa_dim: int = 16
- resnet_dim: int = 128
- num_heads_ipa: int = 12
- num_qk_points: int = 4
- num_v_points: int = 8
- dropout_rate: float = 0.1
- num_blocks: int = 8
- num_transition_layers: int = 1
- num_resnet_blocks: int = 2
- num_angles: int = 7
- trans_scale_factor: int = 10
- epsilon: float = 1e-8
- inf: float = 1e5
- def to_dict(self):
- return asdict(self)
- def get_default_vocab_list():
- return (
- "<cls>",
- "<pad>",
- "<eos>",
- "<unk>",
- "L",
- "A",
- "G",
- "V",
- "S",
- "E",
- "R",
- "T",
- "I",
- "D",
- "P",
- "K",
- "Q",
- "N",
- "F",
- "Y",
- "M",
- "H",
- "W",
- "C",
- "X",
- "B",
- "U",
- "Z",
- "O",
- ".",
- "-",
- "<null_1>",
- "<mask>",
- )
- __all__ = ["EsmConfig"]
|