"""
Whisper model implementation for speech-to-text.
This module provides the WhisperTranscriber class which uses OpenAI's Whisper
model for transcription.
"""
import os
import tempfile
from typing import Any, Dict, List, Optional, Union, Generator
import numpy as np
from fmus_vox.core.audio import Audio
from fmus_vox.core.errors import TranscriptionError, ModelError
from fmus_vox.core.utils import get_logger, download_file, ensure_path_exists
from fmus_vox.stt.transcriber import Transcriber, TranscriptionResult
[docs]
class WhisperTranscriber(Transcriber):
    """
    Speech-to-text transcriber built on OpenAI's Whisper model.

    Whisper is a general-purpose speech recognition model that can transcribe
    speech in multiple languages and translate it to English.

    Args:
        model: Whisper model name, with or without the ``whisper-`` prefix
            (tiny, base, small, medium, large, large-v2, large-v3)
        device: Computation device (cpu, cuda, auto)
        download_root: Directory to download and store models
        **kwargs: Additional model-specific parameters
    """

    # Whisper size identifiers (no "whisper-" prefix); every size maps to itself.
    _model_size_map = {
        size: size
        for size in (
            "tiny",
            "base",
            "small",
            "medium",
            "large",
            "large-v2",
            "large-v3",
        )
    }

    # Fully qualified model names: each known size carrying the "whisper-" prefix.
    _available_models = [f"whisper-{size}" for size in _model_size_map]
[docs]
def __init__(self, model: str = "whisper-base",
             device: Optional[str] = None,
             download_root: Optional[str] = None,
             **kwargs):
    """
    Set up a Whisper-backed transcriber.

    Args:
        model: Whisper model name; the ``whisper-`` prefix is added when missing
        device: Computation device (cpu, cuda, auto)
        download_root: Directory to download and store models; when falsy,
            the ``models_dir`` entry of ``self.config`` is used instead
        **kwargs: Forwarded to the base class; ``beam_size``, ``temperature``
            and ``language_detection`` are additionally read here

    Raises:
        ModelError: If the model name does not match a known Whisper size
    """
    # Normalize to the fully qualified "whisper-<size>" form.
    qualified_model = model if model.startswith("whisper-") else f"whisper-{model}"

    # Resolve (and thereby validate) the size before the base class is set up.
    self.model_size = self._get_model_size(qualified_model)

    super().__init__(model=qualified_model, device=device, **kwargs)

    # An explicit download root wins; otherwise fall back to the configured
    # models directory (self.config is presumably provided by the base class).
    if download_root:
        self.download_root = download_root
    else:
        self.download_root = self.config.get("models_dir")

    # Decoding options, each falling back to its default when not supplied.
    read_option = kwargs.get
    self.beam_size = read_option("beam_size", 5)
    self.temperature = read_option("temperature", 0.0)
    self.language_detection = read_option("language_detection", True)

    self.logger.debug(f"Initialized WhisperTranscriber with model_size={self.model_size}")
def _get_model_size(self, model: str) -> str:
"""Get the Whisper model size from the model name."""
# Remove whisper- prefix if present
if model.startswith("whisper-"):
model = model[8:]
# Check if model size is valid
if model not in self._model_size_map:
raise ModelError(f"Invalid Whisper model size: {model}. "
f"Available sizes: {list(self._model_size_map.keys())}")
return self._model_size_map[model]
def _load_model(self) -> Any:
    """
    Load the Whisper model.

    Returns:
        Loaded Whisper model

    Raises:
        ModelError: If ``torch``/``whisper`` cannot be imported, or if loading
            the model fails for any other reason. The original exception is
            attached as ``__cause__``.
    """
    # Imported here rather than at module level, so the packages are only
    # needed once a model is actually loaded. Kept in their own try block so
    # that only a genuine import failure produces the "please install" hint.
    try:
        import torch
        import whisper
    except ImportError as e:
        raise ModelError(
            "Failed to import whisper. Please install it with 'pip install openai-whisper'"
        ) from e

    try:
        # Resolve the "auto" placeholder to a concrete device.
        device = self.device
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.logger.info(f"Loading Whisper model {self.model_size} on {device}")

        return whisper.load_model(
            self.model_size,
            device=device,
            download_root=self.download_root,
        )
    except Exception as e:
        # Chain the cause so the underlying traceback is not lost.
        raise ModelError(f"Failed to load Whisper model: {str(e)}") from e
[docs]
def stream(self, audio_stream: Generator[Audio, None, None],
           language: Optional[str] = None) -> Generator[TranscriptionResult, None, None]:
    """
    Stream transcription for incoming audio chunks.

    Incoming chunks are resampled to 16 kHz when needed and accumulated until
    at least 5 seconds of audio are buffered. The buffer is then transcribed
    and its last 0.5 seconds are kept as overlap for the next window, to avoid
    cutting words at window boundaries. Once the stream is exhausted, whatever
    is left in the buffer is transcribed, but only if audio arrived after the
    last transcription, so the retained overlap is never emitted on its own.

    Args:
        audio_stream: Generator yielding Audio objects
        language: Language code (if None, auto-detect)

    Yields:
        TranscriptionResult for each processed window

    Raises:
        TranscriptionError: If transcription fails; the original exception is
            attached as ``__cause__``
    """
    sample_rate = 16000          # rate every chunk is brought to before buffering
    target_duration = 5.0        # seconds of audio gathered before transcribing
    overlap_duration = 0.5       # seconds carried over into the next window
    overlap_samples = int(overlap_duration * sample_rate)

    try:
        # Called up front so model problems surface before audio is consumed
        # (presumably triggers loading; the returned model is not used here).
        self._model.get()

        buffer = None
        buffer_duration = 0.0
        # True while the buffer holds audio that has not been transcribed yet.
        has_new_audio = False

        for audio_chunk in audio_stream:
            if audio_chunk.sample_rate != sample_rate:
                audio_chunk = audio_chunk.resample(target_sr=sample_rate)

            if buffer is None:
                buffer = audio_chunk.data
            else:
                buffer = np.concatenate([buffer, audio_chunk.data])
            buffer_duration += audio_chunk.duration
            if len(audio_chunk.data) > 0:
                has_new_audio = True

            if buffer_duration < target_duration:
                continue

            result = self.transcribe_with_metadata(Audio(buffer, sample_rate), language)

            # Keep only the tail of the window as overlap for the next one.
            if len(buffer) > overlap_samples:
                buffer = buffer[-overlap_samples:]
                buffer_duration = overlap_duration
            else:
                buffer = None
                buffer_duration = 0.0
            has_new_audio = False

            yield result

        # Final flush: skip it when the buffer only holds overlap that was
        # already part of the previous transcription.
        if has_new_audio and buffer is not None and len(buffer) > 0:
            yield self.transcribe_with_metadata(Audio(buffer, sample_rate), language)
    except Exception as e:
        raise TranscriptionError(f"Whisper streaming transcription failed: {str(e)}") from e