| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
- # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
- # and is publicly available at https://github.com/dptech-corp/Uni-Fold.
- import os
- from typing import Mapping, Sequence
- import json
- from absl import logging
- from modelscope.models.science.unifold.data import protein
def get_chain_id_map(
    sequences: Sequence[str],
    descriptions: Sequence[str],
):
    """
    Makes a mapping from PDB-format chain ID to sequence and descriptions,
    and parses the order of multi-chains.

    Identical sequences share a single chain ID. Chain IDs are taken from
    ``protein.PDB_CHAIN_IDS`` in the order in which each distinct sequence
    first appears in ``sequences``.

    Args:
        sequences: One sequence per input chain, in input order.
        descriptions: One description per input chain, aligned with
            ``sequences``.

    Returns:
        A tuple ``(chain_id_map, chain_order)`` where ``chain_id_map`` maps
        each chain ID to ``{'descriptions': [...], 'sequence': seq}`` (the
        list holds the description of every input chain with that
        sequence), and ``chain_order`` lists the chain ID assigned to each
        input chain, in input order.

    Raises:
        ValueError: If ``sequences`` and ``descriptions`` differ in length.
    """
    # zip() below would otherwise silently drop the unmatched tail.
    if len(sequences) != len(descriptions):
        raise ValueError('sequences and descriptions must have equal length. '
                         f'Got {len(sequences)} != {len(descriptions)}.')

    # dict.fromkeys de-duplicates while keeping first-seen order; storing the
    # position of each distinct sequence gives O(1) lookups in the loop below
    # instead of a linear list scan per chain.
    seq_to_index = {
        seq: index
        for index, seq in enumerate(dict.fromkeys(sequences))
    }

    chain_id_map = {
        chain_id: {
            'descriptions': [],
            'sequence': seq
        }
        for chain_id, seq in zip(protein.PDB_CHAIN_IDS, seq_to_index)
    }

    chain_order = []
    for seq, des in zip(sequences, descriptions):
        # Indexing fails here if there are more distinct sequences than
        # entries in protein.PDB_CHAIN_IDS.
        chain_id = protein.PDB_CHAIN_IDS[seq_to_index[seq]]
        chain_id_map[chain_id]['descriptions'].append(des)
        chain_order.append(chain_id)
    return chain_id_map, chain_order
def divide_multi_chains(
    fasta_name: str,
    output_dir_base: str,
    sequences: Sequence[str],
    descriptions: Sequence[str],
):
    """
    Divides the multi-chains fasta into several single fasta files and
    records multi-chains mapping information.

    The following files are written under
    ``os.path.join(output_dir_base, fasta_name)``:

    * ``chain_id_map.json``: the chain ID -> sequence/descriptions mapping
      returned by ``get_chain_id_map``.
    * ``chains.txt``: the chain ID of every input chain, space separated,
      in input order.
    * ``<fasta_name>_<chain_id>.fasta``: one single-record fasta file per
      chain ID in the mapping, with the header ``>chain_<chain_id>``.

    Args:
        fasta_name: Name used for the output sub-directory and as the prefix
            of every per-chain fasta file.
        output_dir_base: Directory under which the output sub-directory is
            created.
        sequences: One sequence per input chain, in input order.
        descriptions: One description per input chain, aligned with
            ``sequences``.

    Returns:
        A tuple ``(temp_names, temp_paths)``: the per-chain fasta names
        (without extension) and the paths of the files written, in the
        iteration order of the chain ID mapping.

    Raises:
        ValueError: If ``sequences`` and ``descriptions`` differ in length,
            or if there are more chains than ``protein.PDB_MAX_CHAINS``.
    """
    if len(sequences) != len(descriptions):
        raise ValueError('sequences and descriptions must have equal length. '
                         f'Got {len(sequences)} != {len(descriptions)}.')
    if len(sequences) > protein.PDB_MAX_CHAINS:
        raise ValueError(
            'Cannot process more chains than the PDB format supports. '
            f'Got {len(sequences)} chains.')

    chain_id_map, chain_order = get_chain_id_map(sequences, descriptions)

    output_dir = os.path.join(output_dir_base, fasta_name)
    # exist_ok avoids the check-then-create race when several workers
    # prepare the same output directory at once.
    os.makedirs(output_dir, exist_ok=True)

    chain_id_map_path = os.path.join(output_dir, 'chain_id_map.json')
    with open(chain_id_map_path, 'w') as f:
        json.dump(chain_id_map, f, indent=4, sort_keys=True)

    joined_chain_order = ' '.join(chain_order)
    chain_order_path = os.path.join(output_dir, 'chains.txt')
    with open(chain_order_path, 'w') as f:
        f.write(joined_chain_order)

    logging.info('Mapping multi-chains fasta with chain order: %s',
                 joined_chain_order)

    temp_names = []
    temp_paths = []
    for chain_id, chain_info in chain_id_map.items():
        temp_name = f'{fasta_name}_{chain_id}'
        temp_path = os.path.join(output_dir, temp_name + '.fasta')
        seq = chain_info['sequence']
        with open(temp_path, 'w') as f:
            # Single record: header line, then the sequence with no
            # trailing newline.
            f.write(f'>chain_{chain_id}\n{seq}')
        temp_names.append(temp_name)
        temp_paths.append(temp_path)
    return temp_names, temp_paths
|