local_audio_player.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. # mypy: ignore-errors
  2. from __future__ import annotations
  3. import queue
  4. import asyncio
  5. from typing import Any, Union, Callable, AsyncGenerator, cast
  6. from typing_extensions import TYPE_CHECKING
  7. from .. import _legacy_response
  8. from .._extras import numpy as np, sounddevice as sd
  9. from .._response import StreamedBinaryAPIResponse, AsyncStreamedBinaryAPIResponse
  10. if TYPE_CHECKING:
  11. import numpy.typing as npt
# Sample rate passed as `samplerate=` to every `sd.OutputStream` opened by
# `LocalAudioPlayer` below.
SAMPLE_RATE = 24000
  13. class LocalAudioPlayer:
  14. def __init__(
  15. self,
  16. should_stop: Union[Callable[[], bool], None] = None,
  17. ):
  18. self.channels = 1
  19. self.dtype = np.float32
  20. self.should_stop = should_stop
  21. async def _tts_response_to_buffer(
  22. self,
  23. response: Union[
  24. _legacy_response.HttpxBinaryResponseContent,
  25. AsyncStreamedBinaryAPIResponse,
  26. StreamedBinaryAPIResponse,
  27. ],
  28. ) -> npt.NDArray[np.float32]:
  29. chunks: list[bytes] = []
  30. if isinstance(response, _legacy_response.HttpxBinaryResponseContent) or isinstance(
  31. response, StreamedBinaryAPIResponse
  32. ):
  33. for chunk in response.iter_bytes(chunk_size=1024):
  34. if chunk:
  35. chunks.append(chunk)
  36. else:
  37. async for chunk in response.iter_bytes(chunk_size=1024):
  38. if chunk:
  39. chunks.append(chunk)
  40. audio_bytes = b"".join(chunks)
  41. audio_np = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32767.0
  42. audio_np = audio_np.reshape(-1, 1)
  43. return audio_np
  44. async def play(
  45. self,
  46. input: Union[
  47. npt.NDArray[np.int16],
  48. npt.NDArray[np.float32],
  49. _legacy_response.HttpxBinaryResponseContent,
  50. AsyncStreamedBinaryAPIResponse,
  51. StreamedBinaryAPIResponse,
  52. ],
  53. ) -> None:
  54. audio_content: npt.NDArray[np.float32]
  55. if isinstance(input, np.ndarray):
  56. if input.dtype == np.int16 and self.dtype == np.float32:
  57. audio_content = (input.astype(np.float32) / 32767.0).reshape(-1, self.channels)
  58. elif input.dtype == np.float32:
  59. audio_content = cast("npt.NDArray[np.float32]", input)
  60. else:
  61. raise ValueError(f"Unsupported dtype: {input.dtype}")
  62. else:
  63. audio_content = await self._tts_response_to_buffer(input)
  64. loop = asyncio.get_event_loop()
  65. event = asyncio.Event()
  66. idx = 0
  67. def callback(
  68. outdata: npt.NDArray[np.float32],
  69. frame_count: int,
  70. _time_info: Any,
  71. _status: Any,
  72. ):
  73. nonlocal idx
  74. remainder = len(audio_content) - idx
  75. if remainder == 0 or (callable(self.should_stop) and self.should_stop()):
  76. loop.call_soon_threadsafe(event.set)
  77. raise sd.CallbackStop
  78. valid_frames = frame_count if remainder >= frame_count else remainder
  79. outdata[:valid_frames] = audio_content[idx : idx + valid_frames]
  80. outdata[valid_frames:] = 0
  81. idx += valid_frames
  82. stream = sd.OutputStream(
  83. samplerate=SAMPLE_RATE,
  84. callback=callback,
  85. dtype=audio_content.dtype,
  86. channels=audio_content.shape[1],
  87. )
  88. with stream:
  89. await event.wait()
  90. async def play_stream(
  91. self,
  92. buffer_stream: AsyncGenerator[Union[npt.NDArray[np.float32], npt.NDArray[np.int16], None], None],
  93. ) -> None:
  94. loop = asyncio.get_event_loop()
  95. event = asyncio.Event()
  96. buffer_queue: queue.Queue[Union[npt.NDArray[np.float32], npt.NDArray[np.int16], None]] = queue.Queue(maxsize=50)
  97. async def buffer_producer():
  98. async for buffer in buffer_stream:
  99. if buffer is None:
  100. break
  101. await loop.run_in_executor(None, buffer_queue.put, buffer)
  102. await loop.run_in_executor(None, buffer_queue.put, None) # Signal completion
  103. def callback(
  104. outdata: npt.NDArray[np.float32],
  105. frame_count: int,
  106. _time_info: Any,
  107. _status: Any,
  108. ):
  109. nonlocal current_buffer, buffer_pos
  110. frames_written = 0
  111. while frames_written < frame_count:
  112. if current_buffer is None or buffer_pos >= len(current_buffer):
  113. try:
  114. current_buffer = buffer_queue.get(timeout=0.1)
  115. if current_buffer is None:
  116. loop.call_soon_threadsafe(event.set)
  117. raise sd.CallbackStop
  118. buffer_pos = 0
  119. if current_buffer.dtype == np.int16 and self.dtype == np.float32:
  120. current_buffer = (current_buffer.astype(np.float32) / 32767.0).reshape(-1, self.channels)
  121. except queue.Empty:
  122. outdata[frames_written:] = 0
  123. return
  124. remaining_frames = len(current_buffer) - buffer_pos
  125. frames_to_write = min(frame_count - frames_written, remaining_frames)
  126. outdata[frames_written : frames_written + frames_to_write] = current_buffer[
  127. buffer_pos : buffer_pos + frames_to_write
  128. ]
  129. buffer_pos += frames_to_write
  130. frames_written += frames_to_write
  131. current_buffer = None
  132. buffer_pos = 0
  133. producer_task = asyncio.create_task(buffer_producer())
  134. with sd.OutputStream(
  135. samplerate=SAMPLE_RATE,
  136. channels=self.channels,
  137. dtype=self.dtype,
  138. callback=callback,
  139. ):
  140. await event.wait()
  141. await producer_task