From 242d33cfc69b9669fc3a29e36305682e1b663091 Mon Sep 17 00:00:00 2001 From: ae9is <125031666+ae9is@users.noreply.github.com> Date: Fri, 13 Dec 2024 12:17:39 +0000 Subject: [PATCH 1/7] Add language detection support with Whisper tasks --- src/models.js | 81 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 59 insertions(+), 22 deletions(-) diff --git a/src/models.js b/src/models.js index 8e53e2fbe..92e1761a7 100644 --- a/src/models.js +++ b/src/models.js @@ -3119,10 +3119,35 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { } /** + * Detects language by running input through the model and checking for language tokens in the output. * - * @param {WhisperGenerationConfig} generation_config + * @param {import('./models/whisper/generation_whisper.js').WhisperGenerationFunctionParameters} options + * @returns {Promise} A list of language token IDs detected. + */ + async _detect_language(options) { + const inputs = options.inputs + const generation_config = options.generation_config; + const batch_size = inputs?.dims?.[0] + if (!inputs || batch_size <= 0 || inputs.size <= 0) { + throw new Error("Cannot detect language for empty input"); + } + const start_of_transcript = generation_config.decoder_start_token_id; + const decoder_input_ids = full([batch_size, 1], Number(1.0)).mul_(start_of_transcript).tolist(); + const all_lang_ids = Object.values(generation_config.lang_to_id); + if (!all_lang_ids || all_lang_ids.length <= 0) { + throw new Error("Cannot detect language without language code to token ID map for model"); + } + const output = await this.generate({ ...options, decoder_input_ids }); + const sane = Array.from((/**@type {Tensor}**/(output)).data).flatMap(x => Number(x)); + const lang_ids = sane.filter(x => Object.values(generation_config.lang_to_id).includes(x)); + return lang_ids; + } + + /** + * @param {import('./models/whisper/generation_whisper.js').WhisperGenerationFunctionParameters} options */ - 
_retrieve_init_tokens(generation_config) { + async _retrieve_init_tokens(options) { + const generation_config = options.generation_config // prefix tokens are of the form: // - Multilingual: <|startoftranscript|> <|lang_id|> <|task|> [<|notimestamps|>] // - English-only: <|startoftranscript|> [<|notimestamps|>] @@ -3134,16 +3159,26 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { let language = generation_config.language; const task = generation_config.task; if (generation_config.is_multilingual) { + let lang_id; if (!language) { - // TODO: Implement language detection - console.warn('No language specified - defaulting to English (en).'); - language = 'en'; + try { + const lang_token_ids = await this._detect_language(options); + lang_id = lang_token_ids[0]; + if (!lang_id) { + throw new Error("No language detected"); + } + } catch (err) { + console.warn("No language detected - defaulting to English (en)."); + language = "en"; + } } - - // Add language token - const language_code = whisper_language_to_code(language); - const language_token = `<|${language_code}|>`; - init_tokens.push(generation_config.lang_to_id[language_token]) + if (language) { + // Add language token + const language_code = whisper_language_to_code(language); + const language_token = `<|${language_code}|>`; + lang_id = generation_config.lang_to_id[language_token]; + } + init_tokens.push(lang_id); // Add task token // NOTE: Defaults to 'transcribe' if no task is specified @@ -3180,22 +3215,24 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { * @param {import('./models/whisper/generation_whisper.js').WhisperGenerationFunctionParameters} options * @returns {Promise} The output of the model, which can contain the generated token ids, attentions, and scores. 
*/ - async generate({ - inputs = null, - generation_config = null, - logits_processor = null, - stopping_criteria = null, + async generate(options) { + let { + inputs = null, + generation_config = null, + logits_processor = null, + //stopping_criteria = null, - // Whisper-specific options (passed to kwargs) - // prompt_ids = null, - // language = null, - // task = null, + // Whisper-specific options (passed to kwargs) + // prompt_ids = null, + // language = null, + // task = null, + + ...kwargs + } = options; - ...kwargs - }) { generation_config = this._prepare_generation_config(generation_config, kwargs); - const init_tokens = kwargs.decoder_input_ids ?? this._retrieve_init_tokens(generation_config); + const init_tokens = kwargs.decoder_input_ids ?? await this._retrieve_init_tokens({ ...options, generation_config }); if (generation_config.return_timestamps) { logits_processor ??= new LogitsProcessorList(); From 467851a879a5d2e3869aaaf11249a9d3d1f8543b Mon Sep 17 00:00:00 2001 From: ae9is <125031666+ae9is@users.noreply.github.com> Date: Fri, 13 Dec 2024 12:21:53 +0000 Subject: [PATCH 2/7] Add tests for Whisper language detection --- tests/models/whisper/test_modeling_whisper.js | 18 +++++++++++-- tests/pipelines.test.js | 26 +++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/tests/models/whisper/test_modeling_whisper.js b/tests/models/whisper/test_modeling_whisper.js index fff3cc1f7..713e33ada 100644 --- a/tests/models/whisper/test_modeling_whisper.js +++ b/tests/models/whisper/test_modeling_whisper.js @@ -53,9 +53,8 @@ export default () => { it( "language unset; task unset", async () => { - // language defaults to 'en' + // language defaults to detect, falling back to 'en' // task defaults to 'transcribe' - const outputs = await model.generate({ input_features, max_new_tokens: 1, @@ -66,6 +65,21 @@ export default () => { MAX_TEST_EXECUTION_TIME, ); + it( + "language unset; task set", + async () => { + // language defaults to detect, 
falling back to 'en' + const outputs = await model.generate({ + input_features, + max_new_tokens: 1, + task: "translate", + }); + + expect(outputs.tolist()).toEqual([[/* Prefix */ 50258n, 50259n, 50358n, 50363n, /* Generated */ 45084n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + it( "language set; task unset", async () => { diff --git a/tests/pipelines.test.js b/tests/pipelines.test.js index bfdef3872..8736ff2d9 100644 --- a/tests/pipelines.test.js +++ b/tests/pipelines.test.js @@ -770,7 +770,33 @@ xdescribe("Pipelines (ignored)", () => { }, MAX_TEST_EXECUTION_TIME, ); + + it( + models[1], + async () => { + let transcriber = await pipeline("automatic-speech-recognition", models[1]); + + let url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/french-audio.wav"; + let audioData = await loadAudio(url); + + { + // Transcribe French by autodetecting language + let output = await transcriber(audioData, { language: null, task: "transcribe" }); + expect(output.text.length).toBeGreaterThan(20); + // { text: " J'adore, j'aime, je n'aime pas, je déteste." } + } + { + // Translate French to English with language autodetect + let output = await transcriber(audioData, { language: null, task: "translate" }); + expect(output.text.length).toBeGreaterThan(20); + // { text: " I love, I like, I don't like, I hate." 
} + } + await transcriber.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + it( models[2].join(" + "), async () => { From 88bba08eede6e7825c7f1595c4e6875dd3d6b4f7 Mon Sep 17 00:00:00 2001 From: ae9is <125031666+ae9is@users.noreply.github.com> Date: Fri, 13 Dec 2024 12:31:21 +0000 Subject: [PATCH 3/7] Add test utility to compare string similarity --- package-lock.json | 2 ++ package.json | 1 + tests/test_utils.js | 18 ++++++++++++++++++ 3 files changed, 21 insertions(+) diff --git a/package-lock.json b/package-lock.json index 13a1d5d97..00c626b96 100644 --- a/package-lock.json +++ b/package-lock.json @@ -19,6 +19,7 @@ "@types/node": "^22.10.1", "@webgpu/types": "^0.1.51", "catharsis": "github:xenova/catharsis", + "fastest-levenshtein": "^1.0.16", "jest": "^30.0.0-alpha.6", "jest-environment-node": "^30.0.0-alpha.6", "jsdoc-to-markdown": "^9.1.1", @@ -4554,6 +4555,7 @@ "resolved": "https://registry.npmjs.org/fastest-levenshtein/-/fastest-levenshtein-1.0.16.tgz", "integrity": "sha512-eRnCtTTtGZFpQCwhJiUOuxPQWRXVKYDn0b2PeHfXL6/Zi53SLAzAHfVhVWK2AryC/WH05kGfxhFIPvTF0SXQzg==", "dev": true, + "license": "MIT", "engines": { "node": ">= 4.9.1" } diff --git a/package.json b/package.json index bc845c0df..3bf1df085 100644 --- a/package.json +++ b/package.json @@ -65,6 +65,7 @@ "@types/node": "^22.10.1", "@webgpu/types": "^0.1.51", "catharsis": "github:xenova/catharsis", + "fastest-levenshtein": "^1.0.16", "jest": "^30.0.0-alpha.6", "jest-environment-node": "^30.0.0-alpha.6", "jsdoc-to-markdown": "^9.1.1", diff --git a/tests/test_utils.js b/tests/test_utils.js index c42c5f201..ca744d3fc 100644 --- a/tests/test_utils.js +++ b/tests/test_utils.js @@ -2,6 +2,8 @@ import fs from "fs"; import path from "path"; import { fileURLToPath } from "url"; +import { distance } from "fastest-levenshtein"; + export async function loadAudio(url) { // NOTE: Since the Web Audio API is not available in Node.js, we will need to use the `wavefile` library to obtain the raw audio data. 
// For more information, see: https://huggingface.co/docs/transformers.js/guides/node-audio-processing @@ -68,6 +70,22 @@ export function compare(val1, val2, tol = 0.1) { } } +/** + * Compare two strings adding some tolerance for variation between model outputs. + * + * Similarity score is computed using Levenshtein distance (n_diff) between the two strings, as a fraction of the first string's length: + * similarity score = 1 - n_diff / str1.length. + * + * @param {string} str1 The first string + * @param {string} str2 The second string + * @param {number} tol Tolerance score for similarity between strings, from -Infinity to 1.0 (100% match). + */ +export function compareString(str1, str2, tol = 0.9) { + const dist = distance(str1, str2); + const score = 1 - dist / (str1.length ?? 1); + expect(score).toBeGreaterThanOrEqual(tol); +} + const __filename = fileURLToPath(import.meta.url); const __dirname = path.dirname(__filename); const models_dir = path.join(__dirname, "models"); From fc56bdc178a7ef2c8e8be72e071a003e55077408 Mon Sep 17 00:00:00 2001 From: ae9is <125031666+ae9is@users.noreply.github.com> Date: Fri, 13 Dec 2024 12:42:23 +0000 Subject: [PATCH 4/7] Quality check output for some Whisper pipeline tests --- tests/pipelines.test.js | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/pipelines.test.js b/tests/pipelines.test.js index 8736ff2d9..67d84c9b4 100644 --- a/tests/pipelines.test.js +++ b/tests/pipelines.test.js @@ -1,6 +1,6 @@ import { pipeline, cos_sim } from "../src/transformers.js"; import { init, MAX_TEST_EXECUTION_TIME } from "./init.js"; -import { collect_and_execute_pipeline_tests, compare, loadAudio } from "./test_utils.js"; +import { collect_and_execute_pipeline_tests, compare, compareString, loadAudio } from "./test_utils.js"; // Initialise the testing environment init(); @@ -724,7 +724,8 @@ xdescribe("Pipelines (ignored)", () => { // Transcribe English let output = await transcriber(audioData); 
expect(output.text.length).toBeGreaterThan(50); - // { text: " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country." } + const expected = " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country."; + compareString(expected, output.text); } { @@ -757,14 +758,16 @@ xdescribe("Pipelines (ignored)", () => { // Transcribe French let output = await transcriber(audioData, { language: "french", task: "transcribe" }); expect(output.text.length).toBeGreaterThan(20); - // { text: " J'adore, j'aime, je n'aime pas, je déteste." } + const expected = " J'adore, j'aime, je n'aime pas, je déteste."; + compareString(expected, output.text); } { // Translate French to English. let output = await transcriber(audioData, { language: "french", task: "translate" }); expect(output.text.length).toBeGreaterThan(20); - // { text: " I love, I like, I don't like, I hate." } + const expected = " I love, I like, I don't like, I hate."; + compareString(expected, output.text); } await transcriber.dispose(); }, @@ -783,14 +786,16 @@ xdescribe("Pipelines (ignored)", () => { // Transcribe French by autodetecting language let output = await transcriber(audioData, { language: null, task: "transcribe" }); expect(output.text.length).toBeGreaterThan(20); - // { text: " J'adore, j'aime, je n'aime pas, je déteste." } + const expected = " J'adore, j'aime, je n'aime pas, je déteste."; + compareString(expected, output.text); } { // Translate French to English with language autodetect let output = await transcriber(audioData, { language: null, task: "translate" }); expect(output.text.length).toBeGreaterThan(20); - // { text: " I love, I like, I don't like, I hate." 
} + const expected = " I love, I like, I don't like, I hate."; + compareString(expected, output.text); } await transcriber.dispose(); }, From 7cd642c34e01e0998b75a534d69a9123aa28c47c Mon Sep 17 00:00:00 2001 From: ae9is <125031666+ae9is@users.noreply.github.com> Date: Sun, 15 Dec 2024 00:13:43 +0000 Subject: [PATCH 5/7] Add a new logits processor to only generate allowed token IDs --- src/generation/logits_process.js | 34 ++++++++++++++++++ tests/utils/logits_process.test.js | 55 ++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/src/generation/logits_process.js b/src/generation/logits_process.js index f82634f75..4949a137d 100644 --- a/src/generation/logits_process.js +++ b/src/generation/logits_process.js @@ -572,6 +572,40 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor { } } +export class OnlyGoodWordsLogitsProcessor extends LogitsProcessor { + /** + * Create a `OnlyGoodWordsLogitsProcessor`. + * @param {number[][]} good_words_ids List of list of token ids that are allowed to be generated. + * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + */ + constructor(good_words_ids, eos_token_id) { + super(); + this.good_words_ids = good_words_ids; + this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id]; + } + + /** + * Apply logit processor. + * @param {bigint[][]} input_ids The input IDs. + * @param {Tensor} logits The logits. + * @returns {Object} The processed logits. 
+ */ + _call(input_ids, logits) { + const good_ids = this.good_words_ids.flat(); + // Iterate over batches of input IDs and logits + for (let i = 0; i < input_ids.length; ++i) { + const batch_logits_data = /** @type {Float32Array} */(logits[i].data); + // For every ID, set its logit score to -Infinity unless it's in our list of valid token IDs + for (let j = 0; j < batch_logits_data.length; ++j) { + if (!good_ids.includes(j)) { + batch_logits_data[j] = -Infinity; + } + } + } + return logits + } +} + /** * [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension, * where the first half correspond to the conditional logits (predicted from the input prompt) and the second half diff --git a/tests/utils/logits_process.test.js b/tests/utils/logits_process.test.js index 5da188ed4..0055af5e0 100644 --- a/tests/utils/logits_process.test.js +++ b/tests/utils/logits_process.test.js @@ -81,6 +81,61 @@ describe("Logits Processors", () => { ); }); + describe("good_words_ids", () => { + it( + "generates nothing given empty good_words_ids", + async () => { + const text_input = "hello"; + const generated_text_target = ""; + const text_target = [{ generated_text: text_input + generated_text_target }]; + const output = await pipe(text_input, { + max_new_tokens: 5, + good_words_ids: [ + [], + ], + }); + compare(output, text_target); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "passes basic test", + async () => { + const text_input = "hello"; + // Default output tokens for this input: 22172,18547,8136,18547,8136 + // Default output text for this input: helloerdingsAndroid Load Между ligger + const generated_text_target = "Android helloAndroid hello hello"; + const text_target = [{ generated_text: text_input + generated_text_target }]; + const output = await pipe(text_input, { + max_new_tokens: 5, + good_words_ids: [ + [22172, 8136], // hello, Android + ], + }); + compare(output, text_target); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + 
"passes test with many good words", + async () => { + const text_input = "hello"; + const generated_text_target = "erdingsAndroidierraég migli"; + const text_target = [{ generated_text: text_input + generated_text_target }]; + const good_words_ids = []; + for (let i = 0; i < 100000; ++i) { + good_words_ids.push([i * 2 + 1]); // allow all odd numbers + } + good_words_ids.push([22172, 8136]); + const output = await pipe(text_input, { max_new_tokens: 5, good_words_ids }); + compare(output, text_target); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + afterAll(async () => { await pipe?.dispose(); }, MAX_MODEL_DISPOSE_TIME); From ecdd59871b716d551bd8caf32c61a1339ae527de Mon Sep 17 00:00:00 2001 From: ae9is <125031666+ae9is@users.noreply.github.com> Date: Sun, 15 Dec 2024 00:16:21 +0000 Subject: [PATCH 6/7] Improve Whisper language detection performance --- src/generation/configuration_utils.js | 7 +++++++ src/generation/stopping_criteria.js | 13 +++++++++++++ src/models.js | 22 ++++++++++++++++++++-- 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/generation/configuration_utils.js b/src/generation/configuration_utils.js index 8474057da..639069b5a 100644 --- a/src/generation/configuration_utils.js +++ b/src/generation/configuration_utils.js @@ -197,6 +197,13 @@ export class GenerationConfig { */ bad_words_ids = null; + /** + * List of token ids that are allowed to be generated. + * @type {number[][]} + * @default null + */ + good_words_ids = null; + /** * List of token ids that must be generated. * If given a `number[][]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`. 
diff --git a/src/generation/stopping_criteria.js b/src/generation/stopping_criteria.js index 08434f2b4..aa3327a66 100644 --- a/src/generation/stopping_criteria.js +++ b/src/generation/stopping_criteria.js @@ -154,3 +154,16 @@ export class InterruptableStoppingCriteria extends StoppingCriteria { return new Array(input_ids.length).fill(this.interrupted); } } + +/** + * This class can be used to always stop generation after one pass. + */ +export class AlwaysStopCriteria extends StoppingCriteria { + constructor() { + super(); + } + + _call(input_ids, scores) { + return new Array(input_ids.length).fill(true); + } +} diff --git a/src/models.js b/src/models.js index 92e1761a7..27faf6ebc 100644 --- a/src/models.js +++ b/src/models.js @@ -90,6 +90,7 @@ import { TopKLogitsWarper, TopPLogitsWarper, ClassifierFreeGuidanceLogitsProcessor, + OnlyGoodWordsLogitsProcessor, } from './generation/logits_process.js'; import { @@ -112,7 +113,7 @@ import { import { RawImage } from './utils/image.js'; import { dynamic_time_warping, max, medianFilter } from './utils/maths.js'; -import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js'; +import { AlwaysStopCriteria, EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js'; import { LogitsSampler } from './generation/logits_sampler.js'; import { apis } from './env.js'; @@ -1212,6 +1213,10 @@ export class PreTrainedModel extends Callable { processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)); } + if (generation_config.good_words_ids !== null) { + processors.push(new OnlyGoodWordsLogitsProcessor(generation_config.good_words_ids, generation_config.eos_token_id)); + } + if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) { processors.push(new MinLengthLogitsProcessor(generation_config.min_length, 
generation_config.eos_token_id)); } @@ -3137,7 +3142,20 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { if (!all_lang_ids || all_lang_ids.length <= 0) { throw new Error("Cannot detect language without language code to token ID map for model"); } - const output = await this.generate({ ...options, decoder_input_ids }); + const stopping_criteria = new StoppingCriteriaList(); + stopping_criteria.push(new AlwaysStopCriteria()); + const good_words_ids = [all_lang_ids]; + const output = await this.generate({ + ...options, + generation_config: { + ...generation_config, + good_words_ids, + num_beams: 1, + do_sample: false, + }, + stopping_criteria, + decoder_input_ids, + }); const sane = Array.from((/**@type {Tensor}**/(output)).data).flatMap(x => Number(x)); const lang_ids = sane.filter(x => Object.values(generation_config.lang_to_id).includes(x)); return lang_ids; From db845409a11634a93da94bad699c228116d5c1a1 Mon Sep 17 00:00:00 2001 From: ae9is <125031666+ae9is@users.noreply.github.com> Date: Sun, 15 Dec 2024 00:31:35 +0000 Subject: [PATCH 7/7] Fix Whisper language detection pipeline test --- tests/pipelines.test.js | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/tests/pipelines.test.js b/tests/pipelines.test.js index 67d84c9b4..bd9abab66 100644 --- a/tests/pipelines.test.js +++ b/tests/pipelines.test.js @@ -775,27 +775,19 @@ xdescribe("Pipelines (ignored)", () => { ); it( - models[1], + `${models[1]}-language-detect`, async () => { let transcriber = await pipeline("automatic-speech-recognition", models[1]); - - let url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/french-audio.wav"; + let url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/japanese-audio.wav"; let audioData = await loadAudio(url); - { - // Transcribe French by autodetecting language + // Transcribe Japanese by autodetecting language + // Note: this sample needs to be hard 
enough for Whisper not to be able to transcribe it properly + // with the fallback 'en' language set! let output = await transcriber(audioData, { language: null, task: "transcribe" }); expect(output.text.length).toBeGreaterThan(20); - const expected = " J'adore, j'aime, je n'aime pas, je déteste."; - compareString(expected, output.text); - } - - { - // Translate French to English with language autodetect - let output = await transcriber(audioData, { language: null, task: "translate" }); - expect(output.text.length).toBeGreaterThan(20); - const expected = " I love, I like, I don't like, I hate."; - compareString(expected, output.text); + const expected = "モリナガの美味しい牛乳は濃い青色に牛乳瓶を払ったゼザインのパック牛乳である。"; + compareString(expected, output.text, 0.8); } await transcriber.dispose(); },