diff --git a/package-lock.json b/package-lock.json
index 13a1d5d97..00c626b96 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -19,6 +19,7 @@
         "@types/node": "^22.10.1",
         "@webgpu/types": "^0.1.51",
         "catharsis": "github:xenova/catharsis",
+        "fastest-levenshtein": "^1.0.16",
         "jest": "^30.0.0-alpha.6",
         "jest-environment-node": "^30.0.0-alpha.6",
         "jsdoc-to-markdown": "^9.1.1",
@@ -4554,6 +4555,7 @@
       "resolved": "https://registry.npmjs.org/fastest-levenshtein/-/fastest-levenshtein-1.0.16.tgz",
       "integrity": "sha512-eRnCtTTtGZFpQCwhJiUOuxPQWRXVKYDn0b2PeHfXL6/Zi53SLAzAHfVhVWK2AryC/WH05kGfxhFIPvTF0SXQzg==",
       "dev": true,
+      "license": "MIT",
       "engines": {
         "node": ">= 4.9.1"
       }
diff --git a/package.json b/package.json
index bc845c0df..3bf1df085 100644
--- a/package.json
+++ b/package.json
@@ -65,6 +65,7 @@
     "@types/node": "^22.10.1",
     "@webgpu/types": "^0.1.51",
     "catharsis": "github:xenova/catharsis",
+    "fastest-levenshtein": "^1.0.16",
     "jest": "^30.0.0-alpha.6",
     "jest-environment-node": "^30.0.0-alpha.6",
     "jsdoc-to-markdown": "^9.1.1",
diff --git a/src/generation/configuration_utils.js b/src/generation/configuration_utils.js
index 8474057da..639069b5a 100644
--- a/src/generation/configuration_utils.js
+++ b/src/generation/configuration_utils.js
@@ -197,6 +197,13 @@ export class GenerationConfig {
      */
    bad_words_ids = null;
 
+    /**
+     * List of list of token ids that are allowed to be generated; all other tokens are masked out (the counterpart of `bad_words_ids`).
+     * @type {number[][]}
+     * @default null
+     */
+    good_words_ids = null;
+
     /**
      * List of token ids that must be generated.
      * If given a `number[][]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`.
diff --git a/src/generation/logits_process.js b/src/generation/logits_process.js
index f82634f75..4949a137d 100644
--- a/src/generation/logits_process.js
+++ b/src/generation/logits_process.js
@@ -572,6 +572,46 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor {
     }
 }
 
+/**
+ * A [`LogitsProcessor`] that allows only the specified token ids to be generated;
+ * every other token is masked with `-Infinity`. Note that end-of-sequence tokens
+ * are not treated specially: they must appear in `good_words_ids` to remain generable.
+ */
+export class OnlyGoodWordsLogitsProcessor extends LogitsProcessor {
+    /**
+     * Create an `OnlyGoodWordsLogitsProcessor`.
+     * @param {number[][]} good_words_ids List of list of token ids that are allowed to be generated.
+     * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+     */
+    constructor(good_words_ids, eos_token_id) {
+        super();
+        this.good_words_ids = good_words_ids;
+        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
+    }
+
+    /**
+     * Apply logit processor.
+     * @param {bigint[][]} input_ids The input IDs.
+     * @param {Tensor} logits The logits.
+     * @returns {Object} The processed logits.
+     */
+    _call(input_ids, logits) {
+        // Use a Set for O(1) membership checks while scanning the whole vocabulary.
+        const good_ids = new Set(this.good_words_ids.flat());
+        // Iterate over batches of input IDs and logits
+        for (let i = 0; i < input_ids.length; ++i) {
+            const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
+            // For every ID, set its logit score to -Infinity unless it's in our list of valid token IDs
+            for (let j = 0; j < batch_logits_data.length; ++j) {
+                if (!good_ids.has(j)) {
+                    batch_logits_data[j] = -Infinity;
+                }
+            }
+        }
+        return logits;
+    }
+}
+
 /**
  * [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension,
  * where the first half correspond to the conditional logits (predicted from the input prompt) and the second half
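For reviewers unfamiliar with logits processors, the masking performed by `OnlyGoodWordsLogitsProcessor` reduces to the sketch below (a plain `Float32Array` stands in for the `Tensor` data buffer, and the token ids are made up for illustration):

```js
// Allowed ids, in the same nested format as `good_words_ids`.
const good_words_ids = [[2, 5], [9]];
const good_ids = new Set(good_words_ids.flat());

// Fake logits for a 10-token vocabulary.
const logits = new Float32Array([0.1, 0.4, 0.3, 0.2, 0.0, 0.8, 0.1, 0.2, 0.5, 0.9]);
for (let j = 0; j < logits.length; ++j) {
    if (!good_ids.has(j)) {
        logits[j] = -Infinity; // token j can no longer be chosen
    }
}
// Only indices 2, 5 and 9 remain finite, so greedy or sampled decoding
// is forced to pick one of the allowed tokens.
console.log(logits);
```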
diff --git a/src/generation/stopping_criteria.js b/src/generation/stopping_criteria.js
index 08434f2b4..aa3327a66 100644
--- a/src/generation/stopping_criteria.js
+++ b/src/generation/stopping_criteria.js
@@ -154,3 +154,17 @@ export class InterruptableStoppingCriteria extends StoppingCriteria {
         return new Array(input_ids.length).fill(this.interrupted);
     }
 }
+
+/**
+ * This class can be used to always stop generation after a single step.
+ */
+export class AlwaysStopCriteria extends StoppingCriteria {
+    /**
+     * @param {number[][]} input_ids The generated sequences so far.
+     * @param {number[][]} scores The prediction scores (unused).
+     * @returns {boolean[]} A list containing `true` for every item in the batch.
+     */
+    _call(input_ids, scores) {
+        return new Array(input_ids.length).fill(true);
+    }
+}
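The intended use (see `_detect_language` in src/models.js below) is to force `generate()` to take exactly one decoding step. A minimal usage sketch, assuming a loaded `model` and prepared `inputs`:

```js
import { AlwaysStopCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';

// Report every sequence as finished after the first pass, so generate()
// emits exactly one new token per batch item.
const stopping_criteria = new StoppingCriteriaList();
stopping_criteria.push(new AlwaysStopCriteria());

const output = await model.generate({ inputs, stopping_criteria });
```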
diff --git a/src/models.js b/src/models.js
index 8e53e2fbe..27faf6ebc 100644
--- a/src/models.js
+++ b/src/models.js
@@ -90,6 +90,7 @@ import {
     TopKLogitsWarper,
     TopPLogitsWarper,
     ClassifierFreeGuidanceLogitsProcessor,
+    OnlyGoodWordsLogitsProcessor,
 } from './generation/logits_process.js';
 
 import {
@@ -112,7 +113,7 @@ import {
 import { RawImage } from './utils/image.js';
 
 import { dynamic_time_warping, max, medianFilter } from './utils/maths.js';
-import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
+import { AlwaysStopCriteria, EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
 import { LogitsSampler } from './generation/logits_sampler.js';
 import { apis } from './env.js';
 
@@ -1212,6 +1213,10 @@ export class PreTrainedModel extends Callable {
             processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
         }
 
+        if (generation_config.good_words_ids !== null) {
+            processors.push(new OnlyGoodWordsLogitsProcessor(generation_config.good_words_ids, generation_config.eos_token_id));
+        }
+
         if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
             processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
         }
@@ -3119,10 +3124,50 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
     }
 
     /**
+     * Detects language by running input through the model and checking for language tokens in the output.
      *
-     * @param {WhisperGenerationConfig} generation_config
+     * @param {import('./models/whisper/generation_whisper.js').WhisperGenerationFunctionParameters} options
+     * @returns {Promise<number[]>} A list of language token IDs detected.
+     */
+    async _detect_language(options) {
+        const inputs = options.inputs;
+        const generation_config = options.generation_config;
+        const batch_size = inputs?.dims?.[0];
+        if (!inputs || batch_size <= 0 || inputs.size <= 0) {
+            throw new Error("Cannot detect language for empty input");
+        }
+        // Prime the decoder with a batch of <|startoftranscript|> tokens.
+        const start_of_transcript = generation_config.decoder_start_token_id;
+        const decoder_input_ids = full([batch_size, 1], start_of_transcript).tolist();
+        const all_lang_ids = Object.values(generation_config.lang_to_id ?? {});
+        if (all_lang_ids.length === 0) {
+            throw new Error("Cannot detect language without language code to token ID map for model");
+        }
+        // Generate exactly one token, restricted to the set of language tokens.
+        const stopping_criteria = new StoppingCriteriaList();
+        stopping_criteria.push(new AlwaysStopCriteria());
+        const good_words_ids = [all_lang_ids];
+        const output = await this.generate({
+            ...options,
+            generation_config: {
+                ...generation_config,
+                good_words_ids,
+                num_beams: 1,
+                do_sample: false,
+            },
+            stopping_criteria,
+            decoder_input_ids,
+        });
+        const output_token_ids = Array.from((/** @type {Tensor} */(output)).data).map(Number);
+        const lang_ids = output_token_ids.filter((id) => all_lang_ids.includes(id));
+        return lang_ids;
+    }
+
+    /**
+     * @param {import('./models/whisper/generation_whisper.js').WhisperGenerationFunctionParameters} options
      */
-    _retrieve_init_tokens(generation_config) {
+    async _retrieve_init_tokens(options) {
+        const generation_config = options.generation_config;
         // prefix tokens are of the form:
         //  - Multilingual: <|startoftranscript|> <|lang_id|> <|task|> [<|notimestamps|>]
         //  - English-only: <|startoftranscript|> [<|notimestamps|>]
@@ -3134,16 +3179,26 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
         let language = generation_config.language;
         const task = generation_config.task;
         if (generation_config.is_multilingual) {
+            let lang_id;
             if (!language) {
-                // TODO: Implement language detection
-                console.warn('No language specified - defaulting to English (en).');
-                language = 'en';
+                try {
+                    const lang_token_ids = await this._detect_language(options);
+                    lang_id = lang_token_ids[0];
+                    if (!lang_id) {
+                        throw new Error("No language detected");
+                    }
+                } catch (err) {
+                    console.warn("No language detected - defaulting to English (en).");
+                    language = "en";
+                }
             }
-
-            // Add language token
-            const language_code = whisper_language_to_code(language);
-            const language_token = `<|${language_code}|>`;
-            init_tokens.push(generation_config.lang_to_id[language_token])
+            if (language) {
+                // Add language token
+                const language_code = whisper_language_to_code(language);
+                const language_token = `<|${language_code}|>`;
+                lang_id = generation_config.lang_to_id[language_token];
+            }
+            init_tokens.push(lang_id);
 
             // Add task token
             // NOTE: Defaults to 'transcribe' if no task is specified
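`_detect_language` returns raw language token ids; mapping one back to a language code is just a matter of inverting `generation_config.lang_to_id`, whose keys have the form `<|en|>`. An illustrative helper (not part of this diff):

```js
// Recover the language code for a detected token id by inverting lang_to_id.
// Illustrative only; `lang_to_id` maps tokens like "<|en|>" to their ids.
function languageCodeForToken(lang_to_id, token_id) {
    for (const [token, id] of Object.entries(lang_to_id)) {
        if (id === token_id) {
            return token.slice(2, -2); // "<|ja|>" -> "ja"
        }
    }
    return null;
}

// languageCodeForToken({ "<|en|>": 50259, "<|ja|>": 50266 }, 50266) === "ja"
```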
@@ -3180,22 +3235,24 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
      * @param {import('./models/whisper/generation_whisper.js').WhisperGenerationFunctionParameters} options
      * @returns {Promise} The output of the model, which can contain the generated token ids, attentions, and scores.
      */
-    async generate({
-        inputs = null,
-        generation_config = null,
-        logits_processor = null,
-        stopping_criteria = null,
+    async generate(options) {
+        let {
+            inputs = null,
+            generation_config = null,
+            logits_processor = null,
+            // stopping_criteria = null,
 
-        // Whisper-specific options (passed to kwargs)
-        // prompt_ids = null,
-        // language = null,
-        // task = null,
+            // Whisper-specific options (passed to kwargs)
+            // prompt_ids = null,
+            // language = null,
+            // task = null,
+
+            ...kwargs
+        } = options;
 
-        ...kwargs
-    }) {
         generation_config = this._prepare_generation_config(generation_config, kwargs);
 
-        const init_tokens = kwargs.decoder_input_ids ?? this._retrieve_init_tokens(generation_config);
+        const init_tokens = kwargs.decoder_input_ids ?? await this._retrieve_init_tokens({ ...options, generation_config });
 
         if (generation_config.return_timestamps) {
             logits_processor ??= new LogitsProcessorList();
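With these changes, multilingual Whisper checkpoints no longer require an explicit `language`. A sketch of the new behaviour through the public pipeline API; the audio URL is the one used by the pipeline test below, while the model id is an assumption (any multilingual Whisper export should work):

```js
import { pipeline } from "@huggingface/transformers";

// Model id is illustrative; the test suite uses its own multilingual checkpoint.
const transcriber = await pipeline("automatic-speech-recognition", "Xenova/whisper-tiny");

// language: null triggers the new detection path; previously this warned and
// silently fell back to English.
const url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/japanese-audio.wav";
const output = await transcriber(url, { language: null, task: "transcribe" });
console.log(output.text);
```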
diff --git a/tests/models/whisper/test_modeling_whisper.js b/tests/models/whisper/test_modeling_whisper.js
index fff3cc1f7..713e33ada 100644
--- a/tests/models/whisper/test_modeling_whisper.js
+++ b/tests/models/whisper/test_modeling_whisper.js
@@ -53,9 +53,8 @@ export default () => {
       it(
         "language unset; task unset",
         async () => {
-          // language defaults to 'en'
+          // language defaults to auto-detection, falling back to 'en'
           // task defaults to 'transcribe'
-
           const outputs = await model.generate({
             input_features,
             max_new_tokens: 1,
@@ -66,6 +65,21 @@ export default () => {
         MAX_TEST_EXECUTION_TIME,
       );
 
+      it(
+        "language unset; task set",
+        async () => {
+          // language defaults to auto-detection, falling back to 'en'
+          const outputs = await model.generate({
+            input_features,
+            max_new_tokens: 1,
+            task: "translate",
+          });
+
+          expect(outputs.tolist()).toEqual([[/* Prefix */ 50258n, 50259n, 50358n, 50363n, /* Generated */ 45084n]]);
+        },
+        MAX_TEST_EXECUTION_TIME,
+      );
+
       it(
         "language set; task unset",
         async () => {
diff --git a/tests/pipelines.test.js b/tests/pipelines.test.js
index bfdef3872..bd9abab66 100644
--- a/tests/pipelines.test.js
+++ b/tests/pipelines.test.js
@@ -1,6 +1,6 @@
 import { pipeline, cos_sim } from "../src/transformers.js";
 import { init, MAX_TEST_EXECUTION_TIME } from "./init.js";
-import { collect_and_execute_pipeline_tests, compare, loadAudio } from "./test_utils.js";
+import { collect_and_execute_pipeline_tests, compare, compareString, loadAudio } from "./test_utils.js";
 
 // Initialise the testing environment
 init();
@@ -724,7 +724,8 @@ xdescribe("Pipelines (ignored)", () => {
         // Transcribe English
         let output = await transcriber(audioData);
         expect(output.text.length).toBeGreaterThan(50);
-        // { text: " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country." }
+        const expected = " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country.";
+        compareString(expected, output.text);
       }
 
       {
@@ -757,20 +758,42 @@ xdescribe("Pipelines (ignored)", () => {
         // Transcribe French
         let output = await transcriber(audioData, { language: "french", task: "transcribe" });
         expect(output.text.length).toBeGreaterThan(20);
-        // { text: " J'adore, j'aime, je n'aime pas, je déteste." }
+        const expected = " J'adore, j'aime, je n'aime pas, je déteste.";
+        compareString(expected, output.text);
       }
 
       {
         // Translate French to English.
         let output = await transcriber(audioData, { language: "french", task: "translate" });
         expect(output.text.length).toBeGreaterThan(20);
-        // { text: " I love, I like, I don't like, I hate." }
+        const expected = " I love, I like, I don't like, I hate.";
+        compareString(expected, output.text);
       }
       await transcriber.dispose();
     },
     MAX_TEST_EXECUTION_TIME,
   );
-
+
+  it(
+    `${models[1]}-language-detect`,
+    async () => {
+      let transcriber = await pipeline("automatic-speech-recognition", models[1]);
+      let url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/japanese-audio.wav";
+      let audioData = await loadAudio(url);
+      {
+        // Transcribe Japanese with auto-detected language
+        // NOTE: this sample must be hard enough that Whisper cannot transcribe it
+        // correctly when it falls back to the 'en' language!
+        let output = await transcriber(audioData, { language: null, task: "transcribe" });
+        expect(output.text.length).toBeGreaterThan(20);
+        const expected = "モリナガの美味しい牛乳は濃い青色に牛乳瓶を払ったゼザインのパック牛乳である。";
+        compareString(expected, output.text, 0.8);
+      }
+      await transcriber.dispose();
+    },
+    MAX_TEST_EXECUTION_TIME,
+  );
+
   it(
     models[2].join(" + "),
     async () => {
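To make `compareString`'s tolerance concrete before its definition below: the score is 1 - (edit distance / reference length), so the default `tol = 0.9` allows roughly one edit per ten reference characters. A worked example using the same `fastest-levenshtein` dependency:

```js
import { distance } from "fastest-levenshtein";

// Same math as compareString: score = 1 - edits / reference length.
const expected = "I love, I like, I don't like, I hate.";
const actual = "I love, I like, I do not like, I hate.";

const dist = distance(expected, actual); // 2 edits: insert " ", substitute "'" -> "o"
const score = 1 - dist / expected.length; // 1 - 2/37 ≈ 0.946
console.log(score >= 0.9); // true: this pair passes at the default tolerance
```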
diff --git a/tests/test_utils.js b/tests/test_utils.js
index c42c5f201..ca744d3fc 100644
--- a/tests/test_utils.js
+++ b/tests/test_utils.js
@@ -2,6 +2,8 @@ import fs from "fs";
 import path from "path";
 import { fileURLToPath } from "url";
 
+import { distance } from "fastest-levenshtein";
+
 export async function loadAudio(url) {
   // NOTE: Since the Web Audio API is not available in Node.js, we will need to use the `wavefile` library to obtain the raw audio data.
   // For more information, see: https://huggingface.co/docs/transformers.js/guides/node-audio-processing
@@ -68,6 +70,22 @@ export function compare(val1, val2, tol = 0.1) {
   }
 }
 
+/**
+ * Compare two strings, adding some tolerance for variation between model outputs.
+ *
+ * The similarity score is computed using the Levenshtein distance (n_diff) between the two strings, as a fraction of the first string's length:
+ * similarity score = 1 - n_diff / str1.length.
+ *
+ * @param {string} str1 The first (reference) string
+ * @param {string} str2 The second string
+ * @param {number} tol Tolerance score for similarity between strings, from -Infinity to 1.0 (100% match).
+ */
+export function compareString(str1, str2, tol = 0.9) {
+  const dist = distance(str1, str2);
+  const score = 1 - dist / (str1.length || 1); // `||` (not `??`) avoids dividing by zero when `str1` is empty
+  expect(score).toBeGreaterThanOrEqual(tol);
+}
+
 const __filename = fileURLToPath(import.meta.url);
 const __dirname = path.dirname(__filename);
 const models_dir = path.join(__dirname, "models");
diff --git a/tests/utils/logits_process.test.js b/tests/utils/logits_process.test.js
index 5da188ed4..0055af5e0 100644
--- a/tests/utils/logits_process.test.js
+++ b/tests/utils/logits_process.test.js
@@ -81,6 +81,61 @@ describe("Logits Processors", () => {
       );
     });
 
+    describe("good_words_ids", () => {
+      it(
+        "generates nothing given empty good_words_ids",
+        async () => {
+          const text_input = "hello";
+          const generated_text_target = "";
+          const text_target = [{ generated_text: text_input + generated_text_target }];
+          const output = await pipe(text_input, {
+            max_new_tokens: 5,
+            good_words_ids: [
+              [],
+            ],
+          });
+          compare(output, text_target);
+        },
+        MAX_TEST_EXECUTION_TIME,
+      );
+
+      it(
+        "passes basic test",
+        async () => {
+          const text_input = "hello";
+          // Default output tokens for this input: 22172,18547,8136,18547,8136
+          // Default output text for this input: helloerdingsAndroid Load Между ligger
+          const generated_text_target = "Android helloAndroid hello hello";
+          const text_target = [{ generated_text: text_input + generated_text_target }];
+          const output = await pipe(text_input, {
+            max_new_tokens: 5,
+            good_words_ids: [
+              [22172, 8136], // hello, Android
+            ],
+          });
+          compare(output, text_target);
+        },
+        MAX_TEST_EXECUTION_TIME,
+      );
+
+      it(
+        "passes test with many good words",
+        async () => {
+          const text_input = "hello";
+          const generated_text_target = "erdingsAndroidierraég migli";
+          const text_target = [{ generated_text: text_input + generated_text_target }];
+          const good_words_ids = [];
+          for (let i = 0; i < 100000; ++i) {
+            good_words_ids.push([i * 2 + 1]); // allow all odd token ids
+          }
+          good_words_ids.push([22172, 8136]);
+          const output = await pipe(text_input, { max_new_tokens: 5, good_words_ids });
+          compare(output, text_target);
+        },
+        MAX_TEST_EXECUTION_TIME,
+      );
+    });
+
   afterAll(async () => {
     await pipe?.dispose();
   }, MAX_MODEL_DISPOSE_TIME);
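For completeness, the new `good_words_ids` option threads through the public API like any other generation parameter. A hedged usage sketch: the package import is the published name, but the model id is only an assumption standing in for the suite's tiny test model; the token ids 22172 ("hello") and 8136 ("Android") come from the tests above.

```js
import { pipeline } from "@huggingface/transformers";

// Model id is illustrative; substitute any text-generation checkpoint.
const pipe = await pipeline("text-generation", "hf-internal-testing/tiny-random-LlamaForCausalLM");

// Restrict decoding to the two allowed token ids used in the tests above.
const output = await pipe("hello", {
  max_new_tokens: 5,
  good_words_ids: [[22172, 8136]],
});
console.log(output[0].generated_text); // e.g. "helloAndroid helloAndroid hello hello"
```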