processing_pop2piano.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. # coding=utf-8
  2. # Copyright 2023 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Processor class for Pop2Piano."""
  16. import os
  17. from typing import Optional, Union
  18. import numpy as np
  19. from ...feature_extraction_utils import BatchFeature
  20. from ...processing_utils import ProcessorMixin
  21. from ...tokenization_utils import BatchEncoding, PaddingStrategy, TruncationStrategy
  22. from ...utils import TensorType
  23. from ...utils.import_utils import requires
  24. @requires(backends=("essentia", "librosa", "pretty_midi", "scipy", "torch"))
  25. class Pop2PianoProcessor(ProcessorMixin):
  26. r"""
  27. Constructs an Pop2Piano processor which wraps a Pop2Piano Feature Extractor and Pop2Piano Tokenizer into a single
  28. processor.
  29. [`Pop2PianoProcessor`] offers all the functionalities of [`Pop2PianoFeatureExtractor`] and [`Pop2PianoTokenizer`].
  30. See the docstring of [`~Pop2PianoProcessor.__call__`] and [`~Pop2PianoProcessor.decode`] for more information.
  31. Args:
  32. feature_extractor (`Pop2PianoFeatureExtractor`):
  33. An instance of [`Pop2PianoFeatureExtractor`]. The feature extractor is a required input.
  34. tokenizer (`Pop2PianoTokenizer`):
  35. An instance of ['Pop2PianoTokenizer`]. The tokenizer is a required input.
  36. """
  37. attributes = ["feature_extractor", "tokenizer"]
  38. feature_extractor_class = "Pop2PianoFeatureExtractor"
  39. tokenizer_class = "Pop2PianoTokenizer"
  40. def __init__(self, feature_extractor, tokenizer):
  41. super().__init__(feature_extractor, tokenizer)
  42. def __call__(
  43. self,
  44. audio: Union[np.ndarray, list[float], list[np.ndarray]] = None,
  45. sampling_rate: Optional[Union[int, list[int]]] = None,
  46. steps_per_beat: int = 2,
  47. resample: Optional[bool] = True,
  48. notes: Union[list, TensorType] = None,
  49. padding: Union[bool, str, PaddingStrategy] = False,
  50. truncation: Union[bool, str, TruncationStrategy] = None,
  51. max_length: Optional[int] = None,
  52. pad_to_multiple_of: Optional[int] = None,
  53. verbose: bool = True,
  54. **kwargs,
  55. ) -> Union[BatchFeature, BatchEncoding]:
  56. """
  57. This method uses [`Pop2PianoFeatureExtractor.__call__`] method to prepare log-mel-spectrograms for the model,
  58. and [`Pop2PianoTokenizer.__call__`] to prepare token_ids from notes.
  59. Please refer to the docstring of the above two methods for more information.
  60. """
  61. # Since Feature Extractor needs both audio and sampling_rate and tokenizer needs both token_ids and
  62. # feature_extractor_output, we must check for both.
  63. if (audio is None and sampling_rate is None) and (notes is None):
  64. raise ValueError(
  65. "You have to specify at least audios and sampling_rate in order to use feature extractor or "
  66. "notes to use the tokenizer part."
  67. )
  68. if audio is not None and sampling_rate is not None:
  69. inputs = self.feature_extractor(
  70. audio=audio,
  71. sampling_rate=sampling_rate,
  72. steps_per_beat=steps_per_beat,
  73. resample=resample,
  74. **kwargs,
  75. )
  76. if notes is not None:
  77. encoded_token_ids = self.tokenizer(
  78. notes=notes,
  79. padding=padding,
  80. truncation=truncation,
  81. max_length=max_length,
  82. pad_to_multiple_of=pad_to_multiple_of,
  83. verbose=verbose,
  84. **kwargs,
  85. )
  86. if notes is None:
  87. return inputs
  88. elif audio is None or sampling_rate is None:
  89. return encoded_token_ids
  90. else:
  91. inputs["token_ids"] = encoded_token_ids["token_ids"]
  92. return inputs
  93. def batch_decode(
  94. self,
  95. token_ids,
  96. feature_extractor_output: BatchFeature,
  97. return_midi: bool = True,
  98. ) -> BatchEncoding:
  99. """
  100. This method uses [`Pop2PianoTokenizer.batch_decode`] method to convert model generated token_ids to midi_notes.
  101. Please refer to the docstring of the above two methods for more information.
  102. """
  103. return self.tokenizer.batch_decode(
  104. token_ids=token_ids, feature_extractor_output=feature_extractor_output, return_midi=return_midi
  105. )
  106. def save_pretrained(self, save_directory, **kwargs):
  107. if os.path.isfile(save_directory):
  108. raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
  109. os.makedirs(save_directory, exist_ok=True)
  110. return super().save_pretrained(save_directory, **kwargs)
  111. @classmethod
  112. def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
  113. args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
  114. return cls(*args)
# Names exported by `from <module> import *`: only the processor class.
__all__ = ["Pop2PianoProcessor"]