utils.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
  2. # and is publicly available at https://github.com/dptech-corp/Uni-Fold.
  3. import os
  4. from typing import Mapping, Sequence
  5. import json
  6. from absl import logging
  7. from modelscope.models.science.unifold.data import protein
  8. def get_chain_id_map(
  9. sequences: Sequence[str],
  10. descriptions: Sequence[str],
  11. ):
  12. """
  13. Makes a mapping from PDB-format chain ID to sequence and description,
  14. and parses the order of multi-chains
  15. """
  16. unique_seqs = []
  17. for seq in sequences:
  18. if seq not in unique_seqs:
  19. unique_seqs.append(seq)
  20. chain_id_map = {
  21. chain_id: {
  22. 'descriptions': [],
  23. 'sequence': seq
  24. }
  25. for chain_id, seq in zip(protein.PDB_CHAIN_IDS, unique_seqs)
  26. }
  27. chain_order = []
  28. for seq, des in zip(sequences, descriptions):
  29. chain_id = protein.PDB_CHAIN_IDS[unique_seqs.index(seq)]
  30. chain_id_map[chain_id]['descriptions'].append(des)
  31. chain_order.append(chain_id)
  32. return chain_id_map, chain_order
  33. def divide_multi_chains(
  34. fasta_name: str,
  35. output_dir_base: str,
  36. sequences: Sequence[str],
  37. descriptions: Sequence[str],
  38. ):
  39. """
  40. Divides the multi-chains fasta into several single fasta files and
  41. records multi-chains mapping information.
  42. """
  43. if len(sequences) != len(descriptions):
  44. raise ValueError('sequences and descriptions must have equal length. '
  45. f'Got {len(sequences)} != {len(descriptions)}.')
  46. if len(sequences) > protein.PDB_MAX_CHAINS:
  47. raise ValueError(
  48. 'Cannot process more chains than the PDB format supports. '
  49. f'Got {len(sequences)} chains.')
  50. chain_id_map, chain_order = get_chain_id_map(sequences, descriptions)
  51. output_dir = os.path.join(output_dir_base, fasta_name)
  52. if not os.path.exists(output_dir):
  53. os.makedirs(output_dir)
  54. chain_id_map_path = os.path.join(output_dir, 'chain_id_map.json')
  55. with open(chain_id_map_path, 'w') as f:
  56. json.dump(chain_id_map, f, indent=4, sort_keys=True)
  57. chain_order_path = os.path.join(output_dir, 'chains.txt')
  58. with open(chain_order_path, 'w') as f:
  59. f.write(' '.join(chain_order))
  60. logging.info('Mapping multi-chains fasta with chain order: %s',
  61. ' '.join(chain_order))
  62. temp_names = []
  63. temp_paths = []
  64. for chain_id in chain_id_map.keys():
  65. temp_name = fasta_name + '_{}'.format(chain_id)
  66. temp_path = os.path.join(output_dir, temp_name + '.fasta')
  67. des = 'chain_{}'.format(chain_id)
  68. seq = chain_id_map[chain_id]['sequence']
  69. with open(temp_path, 'w') as f:
  70. f.write('>' + des + '\n' + seq)
  71. temp_names.append(temp_name)
  72. temp_paths.append(temp_path)
  73. return temp_names, temp_paths