Adding prompt when transcribing with Whisper #462

Merged: 16 commits, Oct 3, 2023
docs/reference/speech-transcription-vosk.md (2 changes: 1 addition & 1 deletion)
@@ -11,7 +11,7 @@ The *VoskLearner* class is a wrapper of library [[1]](#alphacep/vosk-api/python-g

The [VoskLearner](/src/opendr/perception/speech_transcription/vosk/vosk_learner.py) class has the following public methods:

#### `WhisperLearner` constructor
#### `VoskLearner` constructor

```python
VoskLearner(self, device, sample_rate)
docs/reference/speech-transcription-whisper.md (8 changes: 6 additions & 2 deletions)
@@ -104,17 +104,21 @@ Parameters:
#### `WhisperLearner.infer`

```python
WhisperLearner.infer(self, audio)
WhisperLearner.infer(self, audio, initial_prompt)
```

This method runs inference on an audio sample. Please call the load() method before calling this method.
This method runs inference on an audio sample. Please call the load() method before calling this method. `initial_prompt` is a string that can be used to suggest the context of the transcription, for example the name of a person who will appear in the transcription.

Returns the transcription as a `WhisperTranscription` that contains the transcription text
and other side information.

Parameters:
- **audio**: *Union[Timeseries, np.ndarray, torch.Tensor, str]*\
The audio sample as a `Timeseries`, `torch.Tensor`, or `np.ndarray` or a file path as `str`.
- **initial_prompt**: *Optional[str]*\
Optional text to provide as a prompt for the first window. This can be used to provide, or "prompt-engineer", a context for transcription, e.g. custom vocabularies or proper nouns, to
make it more likely to predict those words correctly.
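
A minimal usage sketch (the import path, the exact `load()` arguments, and the audio file name are assumptions not shown in this diff; the `tiny.en` model and the prompt value mirror the ROS examples below):

```python
from opendr.perception.speech_transcription.whisper_learner.whisper_learner import WhisperLearner  # assumed module path

learner = WhisperLearner(language="en", device="cpu")
learner.load(name="tiny.en")  # load the model before calling infer(); argument name assumed

# Bias the first decoding window toward a proper noun expected in the speech.
result = learner.infer(audio="speech.wav", initial_prompt="Felix")
print(result.text)  # the returned WhisperTranscription carries the transcription text
```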



#### `WhisperLearner.load`
projects/opendr_ws/src/opendr_perception/README.md (9 changes: 8 additions & 1 deletion)
@@ -998,11 +998,17 @@ The node makes use of the toolkit's speech transcription tools:
2. You are then ready to start the speech transcription node

```shell
# Enable log to console.
rosrun opendr_perception speech_transcription_node.py --verbose True
```
```shell
# Use Whisper instead of Vosk and choose tiny.en variant.
rosrun opendr_perception speech_transcription_node.py --backbone whisper --model_name tiny.en --verbose True
```
```shell
# Suggest to Whisper that the speech will contain the name 'Felix'.
rosrun opendr_perception speech_transcription_node.py --backbone whisper --model_name tiny.en --initial_prompt "Felix" --verbose True
```
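
To inspect the result while the node is running, you can echo the default output topic (listed under the output topics below) from another terminal:

```shell
# View the published transcriptions.
rostopic echo /opendr/speech_transcription
```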
The following optional arguments are available (more are listed in the source code):
- `-h or --help`: show a help message and exit
- `-i or --input_audio_topic INPUT_AUDIO_TOPIC`: topic name for input audio (default=`/audio/audio`)
@@ -1012,7 +1018,8 @@ The node makes use of the toolkit's speech transcription tools:
- `--model_name MODEL_NAME`: Specific model name for each backbone. Example: 'tiny', 'tiny.en', 'base', 'base.en' for Whisper, 'vosk-model-small-en-us-0.15' for Vosk (default=`None`)
- `--model_path MODEL_PATH`: Path to downloaded model files (default=`None`)
- `--language LANGUAGE`: Whisper uses the language parameter to avoid language detection. Vosk uses the language parameter to select a specific model. Example: 'en' for Whisper, 'en-us' for Vosk (default=`en-us`). Check the available language codes for Whisper at the [Whisper repository](https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/tokenizer.py#L10). Check the available language codes for Vosk from the Vosk model name at the [Vosk website](https://alphacephei.com/vosk/models).
- `--verbose VERBOSE`: Display transcription (default=`False`)
- `--initial_prompt INITIAL_PROMPT`: Prompt to provide some context or instruction for the transcription, only for Whisper
- `--verbose VERBOSE`: Display transcription (default=`False`).

3. Default output topics:
- Speech transcription: `/opendr/speech_transcription`
@@ -64,6 +64,7 @@ def __init__(
temperature: Union[float, Tuple[float, ...]]=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
logprob_threshold: Optional[float]=-0.8,
no_speech_threshold: float=0.6,
initial_prompt: Optional[str]=None,
phrase_timeout: float=2,
input_audio_topic: str="/audio/audio",
output_transcription_topic: str="/opendr/speech_transcription",
@@ -101,6 +102,11 @@ def __init__(
:param no_speech_threshold: Threshold for detecting long silence in Whisper.
:type no_speech_threshold: float.

:param initial_prompt: Optional text to provide as a prompt for the first window. This can be used to
provide, or "prompt-engineer", a context for transcription, e.g. custom vocabularies or proper nouns to
make it more likely to predict those words correctly.
:type initial_prompt: Optional[str].

:param phrase_timeout: The most recent seconds used for detecting long silence in Whisper.
:type phrase_timeout: float.

@@ -144,6 +150,7 @@ def __init__(
self.temperature = temperature
self.logprob_threshold = logprob_threshold
self.no_speech_threshold = no_speech_threshold
self.initial_prompt = initial_prompt
self.phrase_timeout = phrase_timeout

# Initialize model
@@ -373,7 +380,7 @@ def _whisper_process_and_publish(self):
"""
audio_array = WhisperLearner.load_audio(self.temp_file)
self.vad = self._whisper_vad(audio_array)
transcription_whisper = self.audio_model.infer(audio_array)
transcription_whisper = self.audio_model.infer(audio_array, initial_prompt=self.initial_prompt)

vosk_transcription = self._postprocess_whisper(
audio_array, transcription_whisper
@@ -501,6 +508,12 @@ def main():
"--sample_rate", type=int, default=16000, help="Sampling rate for audio data."
"Check your audio source for correct value."
)
parser.add_argument(
"--initial_prompt",
default="",
type=str,
help="Prompt to provide some context or instruction for the transcription, only for Whisper",
)
args = parser.parse_args(rospy.myargv()[1:])

try:
@@ -560,6 +573,7 @@ def main():
temperature=temperature,
logprob_threshold=args.logprob_threshold,
no_speech_threshold=args.no_speech_threshold,
initial_prompt=args.initial_prompt,
input_audio_topic=args.input_audio_topic,
output_transcription_topic=args.output_transcription_topic,
performance_topic=args.performance_topic,
projects/opendr_ws_2/src/opendr_perception/README.md (7 changes: 7 additions & 0 deletions)
@@ -1006,11 +1006,17 @@ The node makes use of the toolkit's speech transcription tools:
2. You are then ready to start the speech transcription node

```shell
# Enable log to console.
ros2 run opendr_perception speech_transcription --verbose True
```
```shell
# Use Whisper instead of Vosk and choose tiny.en variant.
ros2 run opendr_perception speech_transcription --backbone whisper --model_name tiny.en --verbose True
```
```shell
# Suggest to Whisper that the speech will contain the name 'Felix'.
ros2 run opendr_perception speech_transcription --backbone whisper --model_name tiny.en --initial_prompt "Felix" --verbose True
```
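
As with the ROS1 node, the published transcriptions can be inspected from another terminal on the default output topic:

```shell
# View the published transcriptions.
ros2 topic echo /opendr/speech_transcription
```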
The following optional arguments are available (more are listed in the source code):
- `-h or --help`: show a help message and exit
- `-i or --input_audio_topic INPUT_AUDIO_TOPIC`: topic name for input audio (default=`/audio/audio`)
@@ -1020,6 +1026,7 @@ The node makes use of the toolkit's speech transcription tools:
- `--model_name MODEL_NAME`: Specific model name for each backbone. Example: 'tiny', 'tiny.en', 'base', 'base.en' for Whisper, 'vosk-model-small-en-us-0.15' for Vosk (default=`None`)
- `--model_path MODEL_PATH`: Path to downloaded model files (default=`None`)
- `--language LANGUAGE`: Whisper uses the language parameter to avoid language detection. Vosk uses the language parameter to select a specific model. Example: 'en' for Whisper, 'en-us' for Vosk (default=`en-us`). Check the available language codes for Whisper at the [Whisper repository](https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/tokenizer.py#L10). Check the available language codes for Vosk from the Vosk model name at the [Vosk website](https://alphacephei.com/vosk/models).
- `--initial_prompt INITIAL_PROMPT`: Prompt to provide some context or instruction for the transcription, only for Whisper
- `--verbose VERBOSE`: Display transcription (default=`False`)

3. Default output topics:
@@ -66,6 +66,7 @@ def __init__(
temperature: Union[float, Tuple[float, ...]]=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
logprob_threshold: Optional[float]=-0.8,
no_speech_threshold: float=0.6,
initial_prompt: Optional[str]=None,
phrase_timeout: float=2,
input_audio_topic: str="/audio/audio",
output_transcription_topic: str="/opendr/speech_transcription",
@@ -103,6 +104,11 @@ def __init__(
:param no_speech_threshold: Threshold for detecting long silence in Whisper.
:type no_speech_threshold: float.

:param initial_prompt: Optional text to provide as a prompt for the first window. This can be used to
provide, or "prompt-engineer", a context for transcription, e.g. custom vocabularies or proper nouns to
make it more likely to predict those words correctly.
:type initial_prompt: Optional[str].

:param phrase_timeout: The most recent seconds used for detecting long silence in Whisper.
:type phrase_timeout: float.

@@ -146,6 +152,7 @@ def __init__(
self.temperature = temperature
self.logprob_threshold = logprob_threshold
self.no_speech_threshold = no_speech_threshold
self.initial_prompt = initial_prompt
self.phrase_timeout = phrase_timeout

# Initialize model
@@ -377,7 +384,7 @@ def _whisper_process_and_publish(self):
"""
audio_array = WhisperLearner.load_audio(self.temp_file)
self.vad = self._whisper_vad(audio_array)
transcription_whisper = self.audio_model.infer(audio_array)
transcription_whisper = self.audio_model.infer(audio_array, initial_prompt=self.initial_prompt)

vosk_transcription = self._postprocess_whisper(
audio_array, transcription_whisper
@@ -504,6 +511,12 @@ def main(args=None):
"--sample_rate", type=int, default=16000, help="Sampling rate for audio data."
"Check your audio source for correct value."
)
parser.add_argument(
"--initial_prompt",
default="",
type=str,
help="Prompt to provide some context or instruction for the transcription, only for Whisper",
)
args = parser.parse_args()

try:
@@ -562,6 +575,7 @@ def main(args=None):
temperature=temperature,
logprob_threshold=args.logprob_threshold,
no_speech_threshold=args.no_speech_threshold,
initial_prompt=args.initial_prompt,
input_audio_topic=args.input_audio_topic,
output_transcription_topic=args.output_transcription_topic,
performance_topic=args.performance_topic,
projects/python/perception/speech_transcription/README.md (4 changes: 3 additions & 1 deletion)
@@ -39,7 +39,9 @@ The `demo_live.py` is a simple command line tool that continuously record and tr
```
python demo_live.py -d 5 -i 0.25 --backbone whisper --model_name tiny.en --language en --device cuda
```

```
python demo_live.py -d 5 -i 0.25 --backbone whisper --model_name tiny.en --language en --initial_prompt "Vosk" --device cuda
```
```
python demo_live.py -d 5 -i 0.25 --backbone vosk --language en-us
```
projects/python/perception/speech_transcription/demo.py (8 changes: 4 additions & 4 deletions)
@@ -53,13 +53,13 @@ def str2bool(v):
choices=["whisper", "vosk"],
)
parser.add_argument(
"--model-path",
"--model_path",
type=str,
help="path to the model files, if not given, the pretrained model will be downloaded",
default=None,
)
parser.add_argument(
"--model-name",
"--model_name",
type=str,
help="Specific name for Whisper model",
choices="Available models name: ['tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium',"
@@ -72,12 +72,12 @@ def str2bool(v):
help="Language for the model",
)
parser.add_argument(
"--download-dir",
"--download_dir",
type=str,
help="Path to the directory where the model will be downloaded",
)
parser.add_argument(
"--builtin-transcribe",
"--builtin_transcribe",
type=str2bool,
help="Use the built-in transcribe function of the Whisper model",
)
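
A side note on the flag renames above: argparse already maps `--model-name` to `args.model_name` (hyphens become underscores in the attribute name), so the rename only changes what is typed on the command line, presumably to keep the demo flags consistent with the ROS nodes. A quick illustration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default=None)

# The renamed flag is passed with an underscore, but the parsed attribute
# would have been args.model_name under either spelling.
args = parser.parse_args(["--model_name", "tiny.en"])
print(args.model_name)  # -> tiny.en
```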
projects/python/perception/speech_transcription/demo_live.py (19 changes: 14 additions & 5 deletions)
@@ -13,7 +13,7 @@
# limitations under the License.


from typing import Callable
from typing import Optional, Callable
import argparse
import time

@@ -50,8 +50,8 @@ def record_audio(duration: int, sample_rate: int) -> np.ndarray:
return audio_data


def transcribe_audio(audio_data: np.ndarray, transcribe_function: Callable):
output = transcribe_function(audio_data)
def transcribe_audio(audio_data: np.ndarray, initial_prompt: Optional[str], transcribe_function: Callable):
output = transcribe_function(audio=audio_data, initial_prompt=initial_prompt)
output = output.text

print("Transcription: ", output)
@@ -71,7 +71,7 @@ def wait_for_start_command(learner, sample_rate):


def main(
backbone, duration, interval, model_path, model_name, language, download_dir, device
backbone, duration, interval, model_path, model_name, initial_prompt, language, download_dir, device
):
if backbone == "whisper":
learner = WhisperLearner(language=language, device=device)
@@ -89,14 +89,16 @@ def main(

# Wait for the user to say "start" before starting the loop
sample_rate = 16000
print("Waiting for 'start' command. Say 'start' to start the transcribe loop.")
print("Say 'stop' to stop the transcribe loop.")
wait_for_start_command(learner, sample_rate)

while True:
# Record the audio
audio_data = record_audio(duration, sample_rate)

# Transcribe the recorded audio and check for the "stop" command
transcription = transcribe_audio(audio_data, learner.infer).lower()
transcription = transcribe_audio(audio_data, initial_prompt, learner.infer).lower()

if "stop" in transcription:
print("Stop command received. Exiting the program.")
@@ -152,6 +154,12 @@ def main(
"'medium.en', 'medium', 'large-v1', 'large-v2', 'large']",
default=None,
)
parser.add_argument(
"--initial_prompt",
default="",
type=str,
help="Prompt to provide some context or instruction for the transcription, only for Whisper",
)
parser.add_argument(
"--language",
type=str,
@@ -170,6 +178,7 @@ def main(
interval=args.interval,
model_path=args.model_path,
model_name=args.model_name,
initial_prompt=args.initial_prompt,
language=args.language,
download_dir=args.download_dir,
device=args.device,
@@ -287,6 +287,7 @@ def download(
def infer(
self,
audio: Union[Timeseries, np.ndarray, torch.Tensor, str],
initial_prompt: Optional[str] = None,
) -> WhisperTranscription:
"""
Run inference on an audio sample. Please call the load() method before calling this method.
@@ -295,6 +296,10 @@ def infer(
audio (Union[Timeseries, np.ndarray, torch.Tensor, str]): The audio sample as a Timeseries, torch.Tensor, or
np.ndarray or a string of file path.

initial_prompt (str, optional): Optional text to provide as a prompt for the first window.
This can be used to provide, or "prompt-engineer", a context for transcription, e.g. custom vocabularies or
proper nouns to make it more likely to predict those words correctly.

Returns:
WhisperTranscription: Transcription results with side information.

@@ -319,6 +324,7 @@ def infer(
logprob_threshold=self.logprob_threshold,
condition_on_previous_text=self.condition_on_previous_text,
word_timestamps=self.word_timestamps,
initial_prompt=initial_prompt,
prepend_punctuations=self.prepend_punctuations,
append_punctuations=self.append_punctuations,
**asdict(self.decode_options),
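
The keyword arguments above match the upstream `openai-whisper` `transcribe()` signature, so `initial_prompt` appears to be forwarded directly to Whisper's own decoder prompt. For reference, a minimal sketch of the equivalent direct call against the upstream package (the audio file name is a placeholder):

```python
import whisper  # upstream openai-whisper package

model = whisper.load_model("tiny.en")
result = model.transcribe(
    "speech.wav",
    initial_prompt="Felix",  # biases decoding of the first window toward this name
)
print(result["text"])
```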