Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Whisper language detection #1097

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
"@types/node": "^22.10.1",
"@webgpu/types": "^0.1.51",
"catharsis": "github:xenova/catharsis",
"fastest-levenshtein": "^1.0.16",
"jest": "^30.0.0-alpha.6",
"jest-environment-node": "^30.0.0-alpha.6",
"jsdoc-to-markdown": "^9.1.1",
Expand Down
7 changes: 7 additions & 0 deletions src/generation/configuration_utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,13 @@ export class GenerationConfig {
*/
bad_words_ids = null;

/**
* List of token ids that are allowed to be generated.
* @type {number[][]}
* @default null
*/
good_words_ids = null;

/**
* List of token ids that must be generated.
* If given a `number[][]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`.
Expand Down
34 changes: 34 additions & 0 deletions src/generation/logits_process.js
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,40 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor {
}
}

/**
 * A [`LogitsProcessor`] that constrains generation to a whitelist of token ids:
 * every logit whose token id is not in `good_words_ids` is set to `-Infinity`,
 * so only whitelisted tokens can be sampled. The inverse of `NoBadWordsLogitsProcessor`.
 */
export class OnlyGoodWordsLogitsProcessor extends LogitsProcessor {
    /**
     * Create a `OnlyGoodWordsLogitsProcessor`.
     * @param {number[][]} good_words_ids List of list of token ids that are allowed to be generated.
     * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
     */
    constructor(good_words_ids, eos_token_id) {
        super();
        this.good_words_ids = good_words_ids;
        // NOTE(review): `eos_token_id` is stored but never consulted in `_call`, so EOS
        // tokens are masked out too unless explicitly listed in `good_words_ids` —
        // confirm this is intended (current callers stop via a stopping criteria
        // after one step, so it has no effect there).
        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
    }

    /**
     * Apply logit processor.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Object} The processed logits.
     */
    _call(input_ids, logits) {
        // Build a Set once per call: the inner loop scans the entire vocabulary,
        // so O(1) membership checks avoid an accidental O(vocab * n_good) scan
        // that `Array.prototype.includes` would incur on every token id.
        const good_ids = new Set(this.good_words_ids.flat());
        // Iterate over batches of input IDs and logits
        for (let i = 0; i < input_ids.length; ++i) {
            const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
            // For every ID, set its logit score to -Infinity unless it's in our set of valid token IDs
            for (let j = 0; j < batch_logits_data.length; ++j) {
                if (!good_ids.has(j)) {
                    batch_logits_data[j] = -Infinity;
                }
            }
        }
        return logits;
    }
}

/**
* [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension,
* where the first half correspond to the conditional logits (predicted from the input prompt) and the second half
Expand Down
13 changes: 13 additions & 0 deletions src/generation/stopping_criteria.js
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,16 @@ export class InterruptableStoppingCriteria extends StoppingCriteria {
return new Array(input_ids.length).fill(this.interrupted);
}
}

/**
 * A [`StoppingCriteria`] that reports every sequence in the batch as finished
 * on the very first check, so generation halts after a single decoding step.
 * Useful for running the decoder exactly once (e.g. Whisper language detection).
 */
export class AlwaysStopCriteria extends StoppingCriteria {
    // No constructor needed: the implicit default constructor forwards to super().

    /**
     * @param {bigint[][]} input_ids The generated sequences so far (one per batch item).
     * @param {number[][]} scores Token scores (unused).
     * @returns {boolean[]} `true` for every batch item, unconditionally stopping generation.
     */
    _call(input_ids, scores) {
        return new Array(input_ids.length).fill(true);
    }
}
101 changes: 78 additions & 23 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ import {
TopKLogitsWarper,
TopPLogitsWarper,
ClassifierFreeGuidanceLogitsProcessor,
OnlyGoodWordsLogitsProcessor,
} from './generation/logits_process.js';

import {
Expand All @@ -112,7 +113,7 @@ import {
import { RawImage } from './utils/image.js';

import { dynamic_time_warping, max, medianFilter } from './utils/maths.js';
import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
import { AlwaysStopCriteria, EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
import { LogitsSampler } from './generation/logits_sampler.js';
import { apis } from './env.js';

Expand Down Expand Up @@ -1212,6 +1213,10 @@ export class PreTrainedModel extends Callable {
processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
}

if (generation_config.good_words_ids !== null) {
processors.push(new OnlyGoodWordsLogitsProcessor(generation_config.good_words_ids, generation_config.eos_token_id));
}

if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
}
Expand Down Expand Up @@ -3119,10 +3124,48 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
}

/**
* Detects language by running input through the model and checking for language tokens in the output.
*
* @param {WhisperGenerationConfig} generation_config
* @param {import('./models/whisper/generation_whisper.js').WhisperGenerationFunctionParameters} options
* @returns {Promise<number[]>} A list of language token IDs detected.
*/
async _detect_language(options) {
const inputs = options.inputs
const generation_config = options.generation_config;
const batch_size = inputs?.dims?.[0]
if (!inputs || batch_size <= 0 || inputs.size <= 0) {
throw new Error("Cannot detect language for empty input");
}
const start_of_transcript = generation_config.decoder_start_token_id;
const decoder_input_ids = full([batch_size, 1], Number(1.0)).mul_(start_of_transcript).tolist();
const all_lang_ids = Object.values(generation_config.lang_to_id);
if (!all_lang_ids || all_lang_ids.length <= 0) {
throw new Error("Cannot detect language without language code to token ID map for model");
}
const stopping_criteria = new StoppingCriteriaList();
stopping_criteria.push(new AlwaysStopCriteria());
const good_words_ids = [all_lang_ids];
const output = await this.generate({
...options,
generation_config: {
...generation_config,
good_words_ids,
num_beams: 1,
do_sample: false,
},
stopping_criteria,
decoder_input_ids,
});
const sane = Array.from((/**@type {Tensor}**/(output)).data).flatMap(x => Number(x));
const lang_ids = sane.filter(x => Object.values(generation_config.lang_to_id).includes(x));
return lang_ids;
}

/**
* @param {import('./models/whisper/generation_whisper.js').WhisperGenerationFunctionParameters} options
*/
_retrieve_init_tokens(generation_config) {
async _retrieve_init_tokens(options) {
const generation_config = options.generation_config
// prefix tokens are of the form:
// - Multilingual: <|startoftranscript|> <|lang_id|> <|task|> [<|notimestamps|>]
// - English-only: <|startoftranscript|> [<|notimestamps|>]
Expand All @@ -3134,16 +3177,26 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
let language = generation_config.language;
const task = generation_config.task;
if (generation_config.is_multilingual) {
let lang_id;
if (!language) {
// TODO: Implement language detection
console.warn('No language specified - defaulting to English (en).');
language = 'en';
try {
const lang_token_ids = await this._detect_language(options);
lang_id = lang_token_ids[0];
if (!lang_id) {
throw new Error("No language detected");
}
} catch (err) {
console.warn("No language detected - defaulting to English (en).");
language = "en";
}
}

// Add language token
const language_code = whisper_language_to_code(language);
const language_token = `<|${language_code}|>`;
init_tokens.push(generation_config.lang_to_id[language_token])
if (language) {
// Add language token
const language_code = whisper_language_to_code(language);
const language_token = `<|${language_code}|>`;
lang_id = generation_config.lang_to_id[language_token];
}
init_tokens.push(lang_id);

// Add task token
// NOTE: Defaults to 'transcribe' if no task is specified
Expand Down Expand Up @@ -3180,22 +3233,24 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
* @param {import('./models/whisper/generation_whisper.js').WhisperGenerationFunctionParameters} options
* @returns {Promise<ModelOutput|Tensor>} The output of the model, which can contain the generated token ids, attentions, and scores.
*/
async generate({
inputs = null,
generation_config = null,
logits_processor = null,
stopping_criteria = null,
async generate(options) {
let {
inputs = null,
generation_config = null,
logits_processor = null,
//stopping_criteria = null,

// Whisper-specific options (passed to kwargs)
// prompt_ids = null,
// language = null,
// task = null,
// Whisper-specific options (passed to kwargs)
// prompt_ids = null,
// language = null,
// task = null,

...kwargs
} = options;

...kwargs
}) {
generation_config = this._prepare_generation_config(generation_config, kwargs);

const init_tokens = kwargs.decoder_input_ids ?? this._retrieve_init_tokens(generation_config);
const init_tokens = kwargs.decoder_input_ids ?? await this._retrieve_init_tokens({ ...options, generation_config });

if (generation_config.return_timestamps) {
logits_processor ??= new LogitsProcessorList();
Expand Down
18 changes: 16 additions & 2 deletions tests/models/whisper/test_modeling_whisper.js
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,8 @@ export default () => {
it(
"language unset; task unset",
async () => {
// language defaults to 'en'
// language defaults to detect, falling back to 'en'
// task defaults to 'transcribe'

const outputs = await model.generate({
input_features,
max_new_tokens: 1,
Expand All @@ -66,6 +65,21 @@ export default () => {
MAX_TEST_EXECUTION_TIME,
);

it(
"language unset; task set",
async () => {
// language defaults to detect, falling back to 'en'
const outputs = await model.generate({
input_features,
max_new_tokens: 1,
task: "translate",
});

// Prefix layout: <|startoftranscript|>, language token, task token, <|notimestamps|>.
// presumably 50259n is the detected (or 'en'-fallback) language token and 50358n
// the 'translate' task token — verify against the model's tokenizer if this changes.
expect(outputs.tolist()).toEqual([[/* Prefix */ 50258n, 50259n, 50358n, 50363n, /* Generated */ 45084n]]);
},
MAX_TEST_EXECUTION_TIME,
);

it(
"language set; task unset",
async () => {
Expand Down
33 changes: 28 additions & 5 deletions tests/pipelines.test.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { pipeline, cos_sim } from "../src/transformers.js";
import { init, MAX_TEST_EXECUTION_TIME } from "./init.js";
import { collect_and_execute_pipeline_tests, compare, loadAudio } from "./test_utils.js";
import { collect_and_execute_pipeline_tests, compare, compareString, loadAudio } from "./test_utils.js";

// Initialise the testing environment
init();
Expand Down Expand Up @@ -724,7 +724,8 @@ xdescribe("Pipelines (ignored)", () => {
// Transcribe English
let output = await transcriber(audioData);
expect(output.text.length).toBeGreaterThan(50);
// { text: " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country." }
const expected = " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country.";
compareString(expected, output.text);
}

{
Expand Down Expand Up @@ -757,20 +758,42 @@ xdescribe("Pipelines (ignored)", () => {
// Transcribe French
let output = await transcriber(audioData, { language: "french", task: "transcribe" });
expect(output.text.length).toBeGreaterThan(20);
// { text: " J'adore, j'aime, je n'aime pas, je déteste." }
const expected = " J'adore, j'aime, je n'aime pas, je déteste.";
compareString(expected, output.text);
}

{
// Translate French to English.
let output = await transcriber(audioData, { language: "french", task: "translate" });
expect(output.text.length).toBeGreaterThan(20);
// { text: " I love, I like, I don't like, I hate." }
const expected = " I love, I like, I don't like, I hate.";
compareString(expected, output.text);
}
await transcriber.dispose();
},
MAX_TEST_EXECUTION_TIME,
);


it(
`${models[1]}-language-detect`,
async () => {
let transcriber = await pipeline("automatic-speech-recognition", models[1]);
let url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/japanese-audio.wav";
let audioData = await loadAudio(url);
{
// Transcribe Japanese by autodetecting language
// Note: this sample needs to be hard enough for Whisper not to be able to transcribe it properly
// with the fallback 'en' language set!
// Passing `language: null` leaves the language unset, which triggers detection.
let output = await transcriber(audioData, { language: null, task: "transcribe" });
expect(output.text.length).toBeGreaterThan(20);
const expected = "モリナガの美味しい牛乳は濃い青色に牛乳瓶を払ったゼザインのパック牛乳である。";
// Lower tolerance (0.8) than the default: transcription of this clip varies slightly between runs/backends.
compareString(expected, output.text, 0.8);
}
await transcriber.dispose();
},
MAX_TEST_EXECUTION_TIME,
);

it(
models[2].join(" + "),
async () => {
Expand Down
18 changes: 18 additions & 0 deletions tests/test_utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import fs from "fs";
import path from "path";
import { fileURLToPath } from "url";

import { distance } from "fastest-levenshtein";

export async function loadAudio(url) {
// NOTE: Since the Web Audio API is not available in Node.js, we will need to use the `wavefile` library to obtain the raw audio data.
// For more information, see: https://huggingface.co/docs/transformers.js/guides/node-audio-processing
Expand Down Expand Up @@ -68,6 +70,22 @@ export function compare(val1, val2, tol = 0.1) {
}
}

/**
 * Compare two strings adding some tolerance for variation between model outputs.
 *
 * Similarity score is computed using Levenshtein distance (n_diff) between the two strings, as a fraction of the first string's length:
 * similarity score = 1 - n_diff / str1.length.
 *
 * @param {string} str1 The first (reference/expected) string.
 * @param {string} str2 The second (actual) string.
 * @param {number} tol Tolerance score for similarity between strings, from -Infinity to 1.0 (100% match).
 */
export function compareString(str1, str2, tol = 0.9) {
    const dist = distance(str1, str2);
    // Guard the divisor with Math.max: the previous `str1.length ?? 1` never fired
    // (`length` is never nullish), so an empty reference string divided by zero,
    // yielding NaN/-Infinity. With the guard, "" vs "" scores 1 (exact match).
    const score = 1 - dist / Math.max(str1.length, 1);
    expect(score).toBeGreaterThanOrEqual(tol);
}

const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
const models_dir = path.join(__dirname, "models");
Expand Down
Loading