| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- # mypy: ignore-errors
- from __future__ import annotations
- import queue
- import asyncio
- from typing import Any, Union, Callable, AsyncGenerator, cast
- from typing_extensions import TYPE_CHECKING
- from .. import _legacy_response
- from .._extras import numpy as np, sounddevice as sd
- from .._response import StreamedBinaryAPIResponse, AsyncStreamedBinaryAPIResponse
- if TYPE_CHECKING:
- import numpy.typing as npt
# Default sample rate passed as ``samplerate`` to every ``sd.OutputStream``
# opened by LocalAudioPlayer (override per instance via ``sample_rate``).
SAMPLE_RATE = 24000


class LocalAudioPlayer:
    """Play audio through ``sounddevice`` from arrays, TTS responses or streams.

    Samples are handed to the output stream as ``(frames, channels)`` arrays.
    ``int16`` input is converted to ``float32`` by dividing by ``32767.0``.

    Args:
        should_stop: Optional zero-argument callable polled by the ``play``
            callback; playback stops as soon as it returns a truthy value.
            ``play_stream`` does not consult it.
        sample_rate: Sample rate used for the output stream. Defaults to
            ``SAMPLE_RATE``.
    """

    def __init__(
        self,
        should_stop: Union[Callable[[], bool], None] = None,
        sample_rate: int = SAMPLE_RATE,
    ):
        self.channels = 1
        self.dtype = np.float32
        self.should_stop = should_stop
        self.sample_rate = sample_rate

    def _to_float32_frames(
        self,
        samples: Union[npt.NDArray[np.int16], npt.NDArray[np.float32]],
    ) -> npt.NDArray[np.float32]:
        """Return ``samples`` as a 2-D ``float32`` frame array.

        ``int16`` input is scaled by ``1 / 32767`` and reshaped to
        ``(-1, self.channels)``. ``float32`` input that is already 2-D is
        returned unchanged; 1-D ``float32`` input is reshaped the same way as
        ``int16`` input so that ``shape[1]`` (the channel count) exists.

        Raises:
            ValueError: If the dtype is neither ``int16`` nor ``float32``.
        """
        if samples.dtype == np.int16 and self.dtype == np.float32:
            return (samples.astype(np.float32) / 32767.0).reshape(-1, self.channels)
        if samples.dtype == np.float32:
            if samples.ndim == 1:
                return samples.reshape(-1, self.channels)
            return cast("npt.NDArray[np.float32]", samples)
        raise ValueError(f"Unsupported dtype: {samples.dtype}")

    async def _tts_response_to_buffer(
        self,
        response: Union[
            _legacy_response.HttpxBinaryResponseContent,
            AsyncStreamedBinaryAPIResponse,
            StreamedBinaryAPIResponse,
        ],
    ) -> npt.NDArray[np.float32]:
        """Read a whole binary response and decode it into a frame array.

        The body is read in 1024-byte chunks (synchronously for the two sync
        response types, with ``async for`` otherwise), interpreted as 16-bit
        integers, scaled by ``1 / 32767`` and returned with shape ``(-1, 1)``.
        """
        sync_types = (
            _legacy_response.HttpxBinaryResponseContent,
            StreamedBinaryAPIResponse,
        )
        if isinstance(response, sync_types):
            audio_bytes = b"".join(response.iter_bytes(chunk_size=1024))
        else:
            audio_bytes = b"".join(
                [chunk async for chunk in response.iter_bytes(chunk_size=1024)]
            )

        audio_np = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32767.0
        return audio_np.reshape(-1, 1)

    async def play(
        self,
        input: Union[
            npt.NDArray[np.int16],
            npt.NDArray[np.float32],
            _legacy_response.HttpxBinaryResponseContent,
            AsyncStreamedBinaryAPIResponse,
            StreamedBinaryAPIResponse,
        ],
    ) -> None:
        """Play a complete piece of audio and return once playback has ended.

        Args:
            input: Either a NumPy array (``int16`` or ``float32``, 1-D or
                ``(frames, channels)``) or a binary response whose body is
                decoded by ``_tts_response_to_buffer``.

        Raises:
            ValueError: If an array with an unsupported dtype is passed.
        """
        audio_content: npt.NDArray[np.float32]
        if isinstance(input, np.ndarray):
            audio_content = self._to_float32_frames(input)
        else:
            audio_content = await self._tts_response_to_buffer(input)

        loop = asyncio.get_running_loop()
        event = asyncio.Event()
        total_frames = len(audio_content)
        idx = 0

        def callback(
            outdata: npt.NDArray[np.float32],
            frame_count: int,
            _time_info: Any,
            _status: Any,
        ):
            nonlocal idx

            remainder = total_frames - idx
            if remainder == 0 or (callable(self.should_stop) and self.should_stop()):
                # The output buffer arrives uninitialized; silence it so the
                # final block does not contain stale data.
                outdata[:] = 0
                # The event belongs to the event loop, so it is set through
                # the loop's thread-safe scheduling call.
                loop.call_soon_threadsafe(event.set)
                raise sd.CallbackStop

            valid_frames = min(frame_count, remainder)
            outdata[:valid_frames] = audio_content[idx : idx + valid_frames]
            outdata[valid_frames:] = 0  # pad a short final block with silence
            idx += valid_frames

        stream = sd.OutputStream(
            samplerate=self.sample_rate,
            callback=callback,
            dtype=audio_content.dtype,
            channels=audio_content.shape[1],
        )
        with stream:
            await event.wait()

    async def play_stream(
        self,
        buffer_stream: AsyncGenerator[Union[npt.NDArray[np.float32], npt.NDArray[np.int16], None], None],
    ) -> None:
        """Play audio chunks as they are produced by an async generator.

        Chunks are moved into a bounded queue (50 entries) by a producer task
        and consumed by the output stream callback. A ``None`` item, or the
        generator finishing, ends playback. ``int16`` chunks are converted to
        ``float32`` when ``self.dtype`` is ``float32``; 1-D chunks are reshaped
        to ``(-1, self.channels)``.
        """
        loop = asyncio.get_running_loop()
        event = asyncio.Event()
        buffer_queue: queue.Queue[Union[npt.NDArray[np.float32], npt.NDArray[np.int16], None]] = queue.Queue(maxsize=50)

        async def buffer_producer():
            async for buffer in buffer_stream:
                if buffer is None:
                    break
                # Queue.put blocks when the queue is full, so run it in the
                # default executor instead of on the event loop.
                await loop.run_in_executor(None, buffer_queue.put, buffer)
            await loop.run_in_executor(None, buffer_queue.put, None)  # Signal completion

        current_buffer = None
        buffer_pos = 0

        def callback(
            outdata: npt.NDArray[np.float32],
            frame_count: int,
            _time_info: Any,
            _status: Any,
        ):
            nonlocal current_buffer, buffer_pos

            frames_written = 0
            while frames_written < frame_count:
                if current_buffer is None or buffer_pos >= len(current_buffer):
                    try:
                        current_buffer = buffer_queue.get(timeout=0.1)
                    except queue.Empty:
                        # Nothing available yet: emit silence for the rest of
                        # this block and try again on the next callback.
                        outdata[frames_written:] = 0
                        return

                    if current_buffer is None:
                        # End of stream: silence the unfilled tail before
                        # stopping, since the buffer starts uninitialized.
                        outdata[frames_written:] = 0
                        loop.call_soon_threadsafe(event.set)
                        raise sd.CallbackStop

                    buffer_pos = 0
                    if current_buffer.dtype == np.int16 and self.dtype == np.float32:
                        current_buffer = (current_buffer.astype(np.float32) / 32767.0).reshape(-1, self.channels)
                    elif current_buffer.ndim == 1:
                        # A 1-D chunk cannot be assigned into the 2-D
                        # (frames, channels) output slice.
                        current_buffer = current_buffer.reshape(-1, self.channels)

                remaining_frames = len(current_buffer) - buffer_pos
                frames_to_write = min(frame_count - frames_written, remaining_frames)
                outdata[frames_written : frames_written + frames_to_write] = current_buffer[
                    buffer_pos : buffer_pos + frames_to_write
                ]
                buffer_pos += frames_to_write
                frames_written += frames_to_write

        producer_task = asyncio.create_task(buffer_producer())

        with sd.OutputStream(
            samplerate=self.sample_rate,
            channels=self.channels,
            dtype=self.dtype,
            callback=callback,
        ):
            await event.wait()

        await producer_task
|