diff --git a/.prettierrc b/.prettierrc
index 0967ef424..57d5ce89a 100644
--- a/.prettierrc
+++ b/.prettierrc
@@ -1 +1,10 @@
-{}
+{
+  "overrides": [
+    {
+      "files": ["tests/**/*.js"],
+      "options": {
+        "printWidth": 10000000
+      }
+    }
+  ]
+}
diff --git a/tests/configs.test.js b/tests/configs.test.js
index fb16879b6..f66a8a887 100644
--- a/tests/configs.test.js
+++ b/tests/configs.test.js
@@ -1,27 +1,23 @@
-
-
-import { AutoConfig, env } from '../src/transformers.js';
-import { getFile } from '../src/utils/hub.js';
+import { AutoConfig, env } from "../src/transformers.js";
+import { getFile } from "../src/utils/hub.js";
 // Initialise the testing environment
 env.allowLocalModels = false;
 env.useFSCache = false;
 const TEST_DATA = {
-    'Xenova/bert-base-uncased': {
-        model_type: 'bert',
-    },
-}
-
-describe('Configs', () => {
-
-    for (const [model_id, minimal_config] of Object.entries(TEST_DATA)) {
+  "Xenova/bert-base-uncased": {
+    model_type: "bert",
+  },
+};
-        it(model_id, async () => {
-            const config = await AutoConfig.from_pretrained(model_id);
-            for(const [key, value] of Object.entries(minimal_config)) {
-                expect(config[key]).toEqual(value);
-            }
-        });
-    }
+describe("Configs", () => {
+  for (const [model_id, minimal_config] of Object.entries(TEST_DATA)) {
+    it(model_id, async () => {
+      const config = await AutoConfig.from_pretrained(model_id);
+      for (const [key, value] of Object.entries(minimal_config)) {
+        expect(config[key]).toEqual(value);
+      }
+    });
+  }
 });
diff --git a/tests/data_structures.test.js b/tests/data_structures.test.js
index c1a95307a..2c744f755 100644
--- a/tests/data_structures.test.js
+++ b/tests/data_structures.test.js
@@ -1,36 +1,33 @@
+import { PriorityQueue } from "../src/utils/data-structures.js";
+describe("Priority queue", () => {
+  const EXAMPLE_ARRAY = [2, 5, 3, 1, 4];
+  it("default (max heap)", () => {
+    const queue = new PriorityQueue();
+    queue.extend(EXAMPLE_ARRAY);
+    expect(queue.pop()).toBe(5);
+  });
-import { PriorityQueue } from '../src/utils/data-structures.js';
+  it("min heap", () => {
+    const queue = new PriorityQueue((a, b) => a < b);
+    queue.extend(EXAMPLE_ARRAY);
+    expect(queue.pop()).toBe(1);
+  });
+  it("heap w/ max size", () => {
+    const queue = new PriorityQueue((a, b) => a > b, 3);
+    queue.extend([1, 2, 3, 4, 5, 4, 3, 2, 1]);
+    expect(queue.pop()).toBe(5);
-describe('Priority queue', () => {
-    const EXAMPLE_ARRAY = [2, 5, 3, 1, 4];
-    it('default (max heap)', () => {
-        const queue = new PriorityQueue();
-        queue.extend(EXAMPLE_ARRAY);
-        expect(queue.pop()).toBe(5);
-    });
-
-    it('min heap', () => {
-        const queue = new PriorityQueue((a, b) => a < b);
-        queue.extend(EXAMPLE_ARRAY);
-        expect(queue.pop()).toBe(1);
-    });
-
-    it('heap w/ max size', () => {
-        const queue = new PriorityQueue((a, b) => a > b, 3);
-        queue.extend([1, 2, 3, 4, 5, 4, 3, 2, 1]);
-        expect(queue.pop()).toBe(5);
-
-        // Test with random sizes
-        const sizes = [1, 3, 4, 5, 8, 9, 15, 16, 31, 32, 127, 128];
-        const arr = Array.from({ length: 100 }, _ => Math.random());
-        const max = Math.max(...arr);
-        for (const size of sizes) {
-            const queue = new PriorityQueue((a, b) => a > b, size);
-            queue.extend(arr);
-            expect(queue.pop()).toBe(max);
-            expect(queue.size).toBeLessThanOrEqual(size);
-        }
-    });
+    // Test with random sizes
+    const sizes = [1, 3, 4, 5, 8, 9, 15, 16, 31, 32, 127, 128];
+    const arr = Array.from({ length: 100 }, (_) => Math.random());
+    const max = Math.max(...arr);
+    for (const size of sizes) {
+      const queue = new PriorityQueue((a, b) => a > b, size);
+      queue.extend(arr);
+
expect(queue.pop()).toBe(max); + expect(queue.size).toBeLessThanOrEqual(size); + } + }); }); diff --git a/tests/init.js b/tests/init.js index e81a1f310..4a9d49177 100644 --- a/tests/init.js +++ b/tests/init.js @@ -7,71 +7,54 @@ import * as types from "node:util/types"; import { onnxruntimeBackend } from "onnxruntime-node/dist/backend"; import * as ONNX_COMMON from "onnxruntime-common"; - /** * A workaround to define a new backend for onnxruntime, which * will not throw an error when running tests with jest. * For more information, see: https://github.com/jestjs/jest/issues/11864#issuecomment-1261468011 */ export function init() { - // In rare cases (specifically when running unit tests with GitHub actions), possibly due to - // a large number of concurrent executions, onnxruntime might fallback to use the WASM backend. - // In this case, we set the number of threads to 1 to avoid errors like: - // - `TypeError: The worker script or module filename must be an absolute path or a relative path starting with './' or '../'. Received "blob:nodedata:..."` - ONNX_COMMON.env.wasm.numThreads = 1; - - let registerBackend = ONNX_COMMON.registerBackend; - - // Define the constructors to monkey-patch - const TYPED_ARRAYS_CONSTRUCTOR_NAMES = [ - "Int8Array", - "Int16Array", - "Int32Array", - "BigInt64Array", - "Uint8Array", - "Uint8ClampedArray", - "Uint16Array", - "Uint32Array", - "BigUint64Array", - "Float32Array", - "Float64Array", - ]; - - // Keep a reference to the original initialization method - const originalMethod = onnxruntimeBackend.init; - - // Monkey-patch the initialization function - onnxruntimeBackend.init = function (...args) { - // There is probably a better way to do this - Array.isArray = x => - typeof x === "object" && - x !== null && - typeof x.length === "number" && - x?.constructor.toString() === Array.toString(); - - // For each typed array constructor - for (const ctorName of TYPED_ARRAYS_CONSTRUCTOR_NAMES) { - // Get the constructor from the current context - const ctor = globalThis[ctorName]; - - // Get the corresponding test function from the `util` module - const value = types[`is${ctorName}`].bind(types); - - // Monkey-patch the constructor so "x instanceof ctor" returns "types[`is${ctorName}`](x)" - Object.defineProperty(ctor, Symbol.hasInstance, { - value, - writable: true, // writable=true is necessary to overwrite the default implementation (and allow subsequent overwrites) - configurable: false, - enumerable: false, - }); - } - - // Call the original method - return originalMethod.apply(this, args); - }; - - // Register the backend with the highest priority, so it is used instead of the default one - registerBackend("test", onnxruntimeBackend, Number.POSITIVE_INFINITY); + // In rare cases (specifically when running unit tests with GitHub actions), possibly due to + // a large number of concurrent executions, onnxruntime might fallback to use the WASM backend. + // In this case, we set the number of threads to 1 to avoid errors like: + // - `TypeError: The worker script or module filename must be an absolute path or a relative path starting with './' or '../'. 
Received "blob:nodedata:..."` + ONNX_COMMON.env.wasm.numThreads = 1; + + let registerBackend = ONNX_COMMON.registerBackend; + + // Define the constructors to monkey-patch + const TYPED_ARRAYS_CONSTRUCTOR_NAMES = ["Int8Array", "Int16Array", "Int32Array", "BigInt64Array", "Uint8Array", "Uint8ClampedArray", "Uint16Array", "Uint32Array", "BigUint64Array", "Float32Array", "Float64Array"]; + + // Keep a reference to the original initialization method + const originalMethod = onnxruntimeBackend.init; + + // Monkey-patch the initialization function + onnxruntimeBackend.init = function (...args) { + // There is probably a better way to do this + Array.isArray = (x) => typeof x === "object" && x !== null && typeof x.length === "number" && x?.constructor.toString() === Array.toString(); + + // For each typed array constructor + for (const ctorName of TYPED_ARRAYS_CONSTRUCTOR_NAMES) { + // Get the constructor from the current context + const ctor = globalThis[ctorName]; + + // Get the corresponding test function from the `util` module + const value = types[`is${ctorName}`].bind(types); + + // Monkey-patch the constructor so "x instanceof ctor" returns "types[`is${ctorName}`](x)" + Object.defineProperty(ctor, Symbol.hasInstance, { + value, + writable: true, // writable=true is necessary to overwrite the default implementation (and allow subsequent overwrites) + configurable: false, + enumerable: false, + }); + } + + // Call the original method + return originalMethod.apply(this, args); + }; + + // Register the backend with the highest priority, so it is used instead of the default one + registerBackend("test", onnxruntimeBackend, Number.POSITIVE_INFINITY); } export const MAX_MODEL_LOAD_TIME = 10_000; // 10 seconds diff --git a/tests/models.test.js b/tests/models.test.js index d7de4f381..f1bc7961c 100644 --- a/tests/models.test.js +++ b/tests/models.test.js @@ -2,148 +2,129 @@ * Test that models loaded outside of the `pipeline` function work correctly (e.g., `AutoModel.from_pretrained(...)`); */ -import { - AutoTokenizer, - AutoModel, - AutoProcessor, +import { AutoTokenizer, AutoModel, AutoProcessor, BertModel, GPT2Model, T5ForConditionalGeneration, CLIPTextModelWithProjection, CLIPVisionModelWithProjection, BertTokenizer, GPT2Tokenizer, T5Tokenizer, RawImage } from "../src/transformers.js"; - BertModel, - GPT2Model, - T5ForConditionalGeneration, - CLIPTextModelWithProjection, - CLIPVisionModelWithProjection, +import { init, MAX_TEST_EXECUTION_TIME } from "./init.js"; - BertTokenizer, - GPT2Tokenizer, - T5Tokenizer, - - RawImage, -} from '../src/transformers.js'; - -import { init, MAX_TEST_EXECUTION_TIME } from './init.js'; - -import { compare } from './test_utils.js'; +import { compare } from "./test_utils.js"; // Initialise the testing environment init(); -describe('Models', () => { - - describe('Loading different architecture types', () => { - - // List all models which will be tested - const models_to_test = [ - // [name, modelClass, tokenizerClass] - ['hf-internal-testing/tiny-random-BertForMaskedLM', BertModel, BertTokenizer], // Encoder-only - ['hf-internal-testing/tiny-random-GPT2LMHeadModel', GPT2Model, GPT2Tokenizer], // Decoder-only - ['hf-internal-testing/tiny-random-T5ForConditionalGeneration', T5ForConditionalGeneration, T5Tokenizer], // Encoder-decoder - ]; - - const texts = [ - 'Once upon a time', - 'I like to eat apples', - ]; - - for (const [model_id, modelClass, tokenizerClass] of models_to_test) { - - // Test that both the auto model and the specific model work - const tokenizers = 
[AutoTokenizer, tokenizerClass]; - const models = [AutoModel, modelClass]; - - for (let i = 0; i < tokenizers.length; ++i) { - const tokenizerClassToTest = tokenizers[i]; - const modelClassToTest = models[i]; - - it(`${model_id} (${modelClassToTest.name})`, async () => { - - // Load model and tokenizer - const tokenizer = await tokenizerClassToTest.from_pretrained(model_id); - const model = await modelClassToTest.from_pretrained(model_id); - - const tests = [ - texts[0], // single - texts, // batched - ] - for (const test of tests) { - const inputs = await tokenizer(test, { truncation: true, padding: true }); - if (model.config.is_encoder_decoder) { - inputs.decoder_input_ids = inputs.input_ids; - } - const output = await model(inputs); - - if (output.logits) { - // Ensure correct shapes - const expected_shape = [...inputs.input_ids.dims, model.config.vocab_size]; - const actual_shape = output.logits.dims; - compare(expected_shape, actual_shape); - } else if (output.last_hidden_state) { - const expected_shape = [...inputs.input_ids.dims, model.config.d_model]; - const actual_shape = output.last_hidden_state.dims; - compare(expected_shape, actual_shape); - } else { - console.warn('Unexpected output', output); - throw new Error('Unexpected output'); - } - } - - await model.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - +describe("Models", () => { + describe("Loading different architecture types", () => { + // List all models which will be tested + const models_to_test = [ + // [name, modelClass, tokenizerClass] + ["hf-internal-testing/tiny-random-BertForMaskedLM", BertModel, BertTokenizer], // Encoder-only + ["hf-internal-testing/tiny-random-GPT2LMHeadModel", GPT2Model, GPT2Tokenizer], // Decoder-only + ["hf-internal-testing/tiny-random-T5ForConditionalGeneration", T5ForConditionalGeneration, T5Tokenizer], // Encoder-decoder + ]; + + const texts = ["Once upon a time", "I like to eat apples"]; + + for (const [model_id, modelClass, tokenizerClass] of models_to_test) { + // Test that both the auto model and the specific model work + const tokenizers = [AutoTokenizer, tokenizerClass]; + const models = [AutoModel, modelClass]; + + for (let i = 0; i < tokenizers.length; ++i) { + const tokenizerClassToTest = tokenizers[i]; + const modelClassToTest = models[i]; + + it( + `${model_id} (${modelClassToTest.name})`, + async () => { + // Load model and tokenizer + const tokenizer = await tokenizerClassToTest.from_pretrained(model_id); + const model = await modelClassToTest.from_pretrained(model_id); + + const tests = [ + texts[0], // single + texts, // batched + ]; + for (const test of tests) { + const inputs = await tokenizer(test, { truncation: true, padding: true }); + if (model.config.is_encoder_decoder) { + inputs.decoder_input_ids = inputs.input_ids; + } + const output = await model(inputs); + + if (output.logits) { + // Ensure correct shapes + const expected_shape = [...inputs.input_ids.dims, model.config.vocab_size]; + const actual_shape = output.logits.dims; + compare(expected_shape, actual_shape); + } else if (output.last_hidden_state) { + const expected_shape = [...inputs.input_ids.dims, model.config.d_model]; + const actual_shape = output.last_hidden_state.dims; + compare(expected_shape, actual_shape); + } else { + console.warn("Unexpected output", output); + throw new Error("Unexpected output"); + } } - } - - }); - - describe('Running specific models', () => { - const models_to_test = [ - 'hf-internal-testing/tiny-random-CLIPModel', - ]; - it(`CLIP (text)`, async () => { - const model_id = 
models_to_test[0]; - - // Load tokenizer and text model - const tokenizer = await AutoTokenizer.from_pretrained(model_id); - const text_model = await CLIPTextModelWithProjection.from_pretrained(model_id, { revision: 'refs/pr/5' }); - - // Run tokenization - const texts = ['a photo of a car', 'a photo of a football match']; - const text_inputs = tokenizer(texts, { padding: true, truncation: true }); - - // Compute embeddings - const { text_embeds } = await text_model(text_inputs); - - // Ensure correct shapes - const expected_shape = [texts.length, text_model.config.projection_dim]; - const actual_shape = text_embeds.dims; - compare(expected_shape, actual_shape); - - await text_model.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - - it(`CLIP (vision)`, async () => { - const model_id = models_to_test[0]; - - // Load processor and vision model - const processor = await AutoProcessor.from_pretrained(model_id); - const vision_model = await CLIPVisionModelWithProjection.from_pretrained(model_id, { revision: 'refs/pr/5' }); - - // Read image and run processor - const image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg'); - const image_inputs = await processor(image); - - // Compute embeddings - const { image_embeds } = await vision_model(image_inputs); - - // Ensure correct shapes - const expected_shape = [1, vision_model.config.projection_dim]; - const actual_shape = image_embeds.dims; - compare(expected_shape, actual_shape); - - await vision_model.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - }); + await model.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + } + } + }); + + describe("Running specific models", () => { + const models_to_test = ["hf-internal-testing/tiny-random-CLIPModel"]; + it( + `CLIP (text)`, + async () => { + const model_id = models_to_test[0]; + + // Load tokenizer and text model + const tokenizer = await AutoTokenizer.from_pretrained(model_id); + const text_model = await CLIPTextModelWithProjection.from_pretrained(model_id, { revision: "refs/pr/5" }); + + // Run tokenization + const texts = ["a photo of a car", "a photo of a football match"]; + const text_inputs = tokenizer(texts, { padding: true, truncation: true }); + + // Compute embeddings + const { text_embeds } = await text_model(text_inputs); + + // Ensure correct shapes + const expected_shape = [texts.length, text_model.config.projection_dim]; + const actual_shape = text_embeds.dims; + compare(expected_shape, actual_shape); + + await text_model.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + `CLIP (vision)`, + async () => { + const model_id = models_to_test[0]; + + // Load processor and vision model + const processor = await AutoProcessor.from_pretrained(model_id); + const vision_model = await CLIPVisionModelWithProjection.from_pretrained(model_id, { revision: "refs/pr/5" }); + + // Read image and run processor + const image = await RawImage.read("https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg"); + const image_inputs = await processor(image); + + // Compute embeddings + const { image_embeds } = await vision_model(image_inputs); + + // Ensure correct shapes + const expected_shape = [1, vision_model.config.projection_dim]; + const actual_shape = image_embeds.dims; + compare(expected_shape, actual_shape); + + await vision_model.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); }); diff --git a/tests/pipelines.test.js b/tests/pipelines.test.js index 1bca594da..6bef83297 100644 --- 
a/tests/pipelines.test.js +++ b/tests/pipelines.test.js @@ -1,7 +1,6 @@ - -import { pipeline, cos_sim } from '../src/transformers.js'; -import { init, MAX_TEST_EXECUTION_TIME } from './init.js'; -import { compare, loadAudio } from './test_utils.js'; +import { pipeline, cos_sim } from "../src/transformers.js"; +import { init, MAX_TEST_EXECUTION_TIME } from "./init.js"; +import { compare, loadAudio } from "./test_utils.js"; // Initialise the testing environment init(); @@ -11,573 +10,550 @@ init(); // This is due to how model construction and destruction occurs, in `beforeAll` and `afterAll`, respectively. // As a result, each test is responsible for exactly one model, but we run multiple inputs through it. // By encapsulating model construction and destruction in a single `it` block, we avoid these memory issues. -xdescribe('Pipelines', () => { - - describe('Text classification', () => { - - // List all models which will be tested - const models = [ - 'Xenova/distilbert-base-uncased-finetuned-sst-2-english', - 'Xenova/toxic-bert', +xdescribe("Pipelines", () => { + describe("Text classification", () => { + // List all models which will be tested + const models = ["Xenova/distilbert-base-uncased-finetuned-sst-2-english", "Xenova/toxic-bert"]; + + // single_label_classification + it( + models[0], + async () => { + let classifier = await pipeline("text-classification", models[0]); + let texts = ["This was a masterpiece. Not completely faithful to the books, but enthralling from beginning to end. Might be my favorite of the three.", "I hated the movie"]; + + // single + { + let outputs = await classifier("I hated the movie"); + let expected = [{ label: "NEGATIVE", score: 0.9996212720870972 }]; + compare(outputs, expected); + } + + // single + topk + { + let outputs = await classifier("I hated the movie", { + topk: 2, + }); + let expected = [ + { label: "NEGATIVE", score: 0.9996212720870972 }, + { label: "POSITIVE", score: 0.0003787268069572747 }, + ]; + compare(outputs, expected); + } + + // batched + { + let outputs = await classifier(texts); + + let expected = [ + { label: "POSITIVE", score: 0.9993746876716614 }, + { label: "NEGATIVE", score: 0.9996694326400757 }, + ]; + + compare(outputs, expected); + } + + // batched + topk + { + let outputs = await classifier(texts, { + topk: 2, + }); + + let expected = [ + [ + { label: "POSITIVE", score: 0.9993746876716614 }, + { label: "NEGATIVE", score: 0.0006253048195503652 }, + ], + [ + { label: "NEGATIVE", score: 0.9996694326400757 }, + { label: "POSITIVE", score: 0.00033057318069040775 }, + ], + ]; + + compare(outputs, expected); + } + + await classifier.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + + // multi_label_classification + it( + models[1], + async () => { + let classifier = await pipeline("text-classification", models[1]); + let texts = [ + "I like you. I love you", // low scores + "I hate you.", // high scores ]; - // single_label_classification - it(models[0], async () => { - let classifier = await pipeline('text-classification', models[0]); - let texts = [ - "This was a masterpiece. Not completely faithful to the books, but enthralling from beginning to end. 
Might be my favorite of the three.", - "I hated the movie" - ]; - - // single + // single + { + let outputs = await classifier(texts); + let expected = [ + { label: "toxic", score: 0.0007729064091108739 }, + { label: "toxic", score: 0.9475088119506836 }, + ]; + compare(outputs, expected); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Token classification", () => { + // List all models which will be tested + const models = ["Xenova/bert-base-multilingual-cased-ner-hrl"]; + + it( + models[0], + async () => { + let classifier = await pipeline("token-classification", models[0]); + let texts = ["The Golden State Warriors are an American professional basketball team based in San Francisco.", "My name is Sarah and I live in London."]; + + // single + { + let outputs = await classifier(texts[0]); + + let expected = [ + { entity: "B-ORG", score: 0.9998535513877869, index: 2, word: "Golden", start: null, end: null }, + { entity: "I-ORG", score: 0.9998612999916077, index: 3, word: "State", start: null, end: null }, + { entity: "I-ORG", score: 0.999866247177124, index: 4, word: "Warriors", start: null, end: null }, + { entity: "B-LOC", score: 0.9997050166130066, index: 13, word: "San", start: null, end: null }, + { entity: "I-LOC", score: 0.9987282156944275, index: 14, word: "Francisco", start: null, end: null }, + ]; + + compare(outputs, expected, 0.05); + } + + // batched + { + let outputs = await classifier(texts); + + let expected = [ + [ + { entity: "B-ORG", score: 0.9998375773429871, index: 2, word: "Golden", start: null, end: null }, + { entity: "I-ORG", score: 0.9998642206192017, index: 3, word: "State", start: null, end: null }, + { entity: "I-ORG", score: 0.9998642802238464, index: 4, word: "Warriors", start: null, end: null }, + { entity: "B-LOC", score: 0.9996914863586426, index: 13, word: "San", start: null, end: null }, + { entity: "I-LOC", score: 0.9989780783653259, index: 14, word: "Francisco", start: null, end: null }, + ], + [ + { entity: "B-PER", score: 0.997977614402771, index: 4, word: "Sarah", start: null, end: null }, + { entity: "B-LOC", score: 0.9996902346611023, index: 9, word: "London", start: null, end: null }, + ], + ]; + + compare(outputs, expected, 0.05); + } + + await classifier.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Zero-shot classification", () => { + // List all models which will be tested + const models = ["Xenova/bart-large-mnli"]; + + it( + models[0], + async () => { + let classifier = await pipeline("zero-shot-classification", models[0]); + + let sequences_to_classify = ["one day I will see the world", "I love making pizza"]; + let candidate_labels = ["travel", "cooking", "dancing"]; + + // single + { + let outputs = await classifier(sequences_to_classify[0], candidate_labels); + let expected = { + sequence: "one day I will see the world", + labels: ["travel", "dancing", "cooking"], + scores: [0.4261703487477968, 0.2903585771517135, 0.28347107410048983], + }; + + compare(outputs, expected, 0.2); + } + + // batched + { + let outputs = await classifier(sequences_to_classify, candidate_labels); + let expected = [ { - let outputs = await classifier("I hated the movie"); - let expected = [ - { "label": "NEGATIVE", "score": 0.9996212720870972 } - ]; - compare(outputs, expected); - } - - // single + topk + sequence: "one day I will see the world", + labels: ["travel", "dancing", "cooking"], + scores: [0.4261703487477968, 0.2903585771517135, 0.28347107410048983], + }, { - let outputs = await classifier("I hated the movie", { - 
topk: 2 - }); - let expected = [ - { "label": "NEGATIVE", "score": 0.9996212720870972 }, - { "label": "POSITIVE", "score": 0.0003787268069572747 } - ]; - compare(outputs, expected); - } - - // batched + sequence: "I love making pizza", + labels: ["cooking", "travel", "dancing"], + scores: [0.4660367922118968, 0.2756005926506238, 0.2583626151374795], + }, + ]; + + compare(outputs, expected, 0.2); + } + + // batched + multilabel + { + let outputs = await classifier(sequences_to_classify, candidate_labels, { + multi_label: true, + }); + let expected = [ { - let outputs = await classifier(texts); - - let expected = [ - { "label": "POSITIVE", "score": 0.9993746876716614 }, - { "label": "NEGATIVE", "score": 0.9996694326400757 } - ]; - - compare(outputs, expected); - } - - - // batched + topk + sequence: "one day I will see the world", + labels: ["travel", "dancing", "cooking"], + scores: [0.7108286792234982, 0.5763787804099745, 0.44303326070949994], + }, { - let outputs = await classifier(texts, { - topk: 2 - }); - - let expected = [[ - { "label": "POSITIVE", "score": 0.9993746876716614 }, - { "label": "NEGATIVE", "score": 0.0006253048195503652 } - ], [ - { "label": "NEGATIVE", "score": 0.9996694326400757 }, - { "label": "POSITIVE", "score": 0.00033057318069040775 } - ]]; - - compare(outputs, expected); - } - - - await classifier.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - - // multi_label_classification - it(models[1], async () => { - let classifier = await pipeline('text-classification', models[1]); - let texts = [ - "I like you. I love you", // low scores - "I hate you." // high scores - ]; - - // single + sequence: "I love making pizza", + labels: ["cooking", "travel", "dancing"], + scores: [0.8527619536354446, 0.7899589317978243, 0.5838912691496106], + }, + ]; + + compare(outputs, expected); + } + + await classifier.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Masked language modelling", () => { + // List all models which will be tested + const models = ["Xenova/bert-base-uncased"]; + + it( + models[0], + async () => { + let unmasker = await pipeline("fill-mask", models[0]); + let texts = ["Once upon a [MASK].", "[MASK] is the capital of England."]; + + // single + { + let outputs = await unmasker(texts[0]); + let expected = [ { - let outputs = await classifier(texts); - let expected = [ - { label: 'toxic', score: 0.0007729064091108739 }, - { label: 'toxic', score: 0.9475088119506836 } - ] - compare(outputs, expected); - } - }, MAX_TEST_EXECUTION_TIME); - - - }); - - describe('Token classification', () => { - - // List all models which will be tested - const models = [ - 'Xenova/bert-base-multilingual-cased-ner-hrl', - ]; - - it(models[0], async () => { - let classifier = await pipeline('token-classification', models[0]); - let texts = [ - "The Golden State Warriors are an American professional basketball team based in San Francisco.", - "My name is Sarah and I live in London." 
- ]; - - // single + score: 0.9405396580696106, + token: 2051, + token_str: "time", + sequence: "once upon a time.", + }, { - let outputs = await classifier(texts[0]); - - let expected = [ - { entity: "B-ORG", score: 0.9998535513877869, index: 2, word: "Golden", start: null, end: null }, - { entity: "I-ORG", score: 0.9998612999916077, index: 3, word: "State", start: null, end: null }, - { entity: "I-ORG", score: 0.999866247177124, index: 4, word: "Warriors", start: null, end: null }, - { entity: "B-LOC", score: 0.9997050166130066, index: 13, word: "San", start: null, end: null }, - { entity: "I-LOC", score: 0.9987282156944275, index: 14, word: "Francisco", start: null, end: null } - ]; - - compare(outputs, expected, 0.05); - - } - - // batched + score: 0.01182964164763689, + token: 13342, + token_str: "mattress", + sequence: "once upon a mattress.", + }, { - let outputs = await classifier(texts); - - let expected = [ - [ - { entity: "B-ORG", score: 0.9998375773429871, index: 2, word: "Golden", start: null, end: null }, - { entity: "I-ORG", score: 0.9998642206192017, index: 3, word: "State", start: null, end: null }, - { entity: "I-ORG", score: 0.9998642802238464, index: 4, word: "Warriors", start: null, end: null }, - { entity: "B-LOC", score: 0.9996914863586426, index: 13, word: "San", start: null, end: null }, - { entity: "I-LOC", score: 0.9989780783653259, index: 14, word: "Francisco", start: null, end: null } - ], [ - { entity: "B-PER", score: 0.997977614402771, index: 4, word: "Sarah", start: null, end: null }, - { entity: "B-LOC", score: 0.9996902346611023, index: 9, word: "London", start: null, end: null } - ] - ]; - - compare(outputs, expected, 0.05); - } - - await classifier.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Zero-shot classification', () => { - - // List all models which will be tested - const models = [ - 'Xenova/bart-large-mnli', - ]; - - it(models[0], async () => { - let classifier = await pipeline('zero-shot-classification', models[0]); - - let sequences_to_classify = ['one day I will see the world', 'I love making pizza']; - let candidate_labels = ['travel', 'cooking', 'dancing']; - - // single + score: 0.0017291896510869265, + token: 6480, + token_str: "lifetime", + sequence: "once upon a lifetime.", + }, { - let outputs = await classifier(sequences_to_classify[0], candidate_labels); - let expected = { - sequence: "one day I will see the world", - labels: ["travel", "dancing", "cooking"], - scores: [0.4261703487477968, 0.2903585771517135, 0.28347107410048983] - } - - compare(outputs, expected, 0.2); - - } - - // batched + score: 0.0010079898638650775, + token: 2504, + token_str: "level", + sequence: "once upon a level.", + }, { - let outputs = await classifier(sequences_to_classify, candidate_labels); - let expected = [{ - sequence: "one day I will see the world", - labels: ["travel", "dancing", "cooking"], - scores: [0.4261703487477968, 0.2903585771517135, 0.28347107410048983] - }, { - sequence: "I love making pizza", - labels: ["cooking", "travel", "dancing"], - scores: [0.4660367922118968, 0.2756005926506238, 0.2583626151374795] - }]; - - compare(outputs, expected, 0.2); - - } - - - // batched + multilabel - { - let outputs = await classifier(sequences_to_classify, candidate_labels, { - multi_label: true - }) - let expected = [{ - sequence: "one day I will see the world", - labels: ["travel", "dancing", "cooking"], - scores: [0.7108286792234982, 0.5763787804099745, 0.44303326070949994] - }, { - sequence: "I love making pizza", - labels: 
["cooking", "travel", "dancing"], - scores: [0.8527619536354446, 0.7899589317978243, 0.5838912691496106] - }]; - - compare(outputs, expected); - - } - - await classifier.dispose(); - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Masked language modelling', () => { - - // List all models which will be tested - const models = [ - 'Xenova/bert-base-uncased', - ]; - - it(models[0], async () => { - let unmasker = await pipeline('fill-mask', models[0]); - let texts = [ - "Once upon a [MASK].", - "[MASK] is the capital of England." - ]; - - // single - { - let outputs = await unmasker(texts[0]); - let expected = [ - { - score: 0.9405396580696106, - token: 2051, - token_str: 'time', - sequence: 'once upon a time.' - }, - { - score: 0.01182964164763689, - token: 13342, - token_str: 'mattress', - sequence: 'once upon a mattress.' - }, - { - score: 0.0017291896510869265, - token: 6480, - token_str: 'lifetime', - sequence: 'once upon a lifetime.' - }, - { - score: 0.0010079898638650775, - token: 2504, - token_str: 'level', - sequence: 'once upon a level.' - }, - { - score: 0.0009655007743276656, - token: 2154, - token_str: 'day', - sequence: 'once upon a day.' - } - ]; - compare(outputs, expected); - - } - - - // batched - { - let outputs = await unmasker(texts); - - let expected = [[ - { - score: 0.9900539517402649, - token: 2051, - token_str: 'time', - sequence: 'once upon a time.' - }, - { - score: 0.0012258145725354552, - token: 13342, - token_str: 'mattress', - sequence: 'once upon a mattress.' - }, - { - score: 0.0002977887343149632, - token: 2096, - token_str: 'while', - sequence: 'once upon a while.' - }, - { - score: 0.0001899998023873195, - token: 6480, - token_str: 'lifetime', - sequence: 'once upon a lifetime.' - }, - { - score: 0.00017618606216274202, - token: 2558, - token_str: 'period', - sequence: 'once upon a period.' - } - ], - [ - { - score: 0.2863538861274719, - token: 2414, - token_str: 'london', - sequence: 'london is the capital of england.' - }, - { - score: 0.0607745461165905, - token: 2009, - token_str: 'it', - sequence: 'it is the capital of england.' - }, - { - score: 0.037455108016729355, - token: 6484, - token_str: 'birmingham', - sequence: 'birmingham is the capital of england.' - }, - { - score: 0.029375044628977776, - token: 5087, - token_str: 'manchester', - sequence: 'manchester is the capital of england.' - }, - { - score: 0.0292277242988348, - token: 7067, - token_str: 'bristol', - sequence: 'bristol is the capital of england.' - } - ]]; - - compare(outputs, expected); - - } - - await unmasker.dispose(); - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Question answering', () => { - let question = 'Who was Jim Henson?' - let context = 'Jim Henson was a nice puppet.' 
- - - // List all models which will be tested - const models = [ - 'Xenova/distilbert-base-uncased-distilled-squad', - ]; - - it(models[0], async () => { - let answerer = await pipeline('question-answering', models[0]); - - // single - { - let outputs = await answerer(question, context); - let expected = { answer: 'a nice puppet', score: 0.5664517526948352 }; - - compare(outputs, expected, 0.2); - } - - // single + topk - { - let outputs = await answerer(question, context, { - topk: 3, - }); - let expected = [ - { answer: 'a nice puppet', score: 0.5664517526948352 }, - { answer: 'nice puppet', score: 0.1698902336448853 }, - { answer: 'puppet', score: 0.14046057793125577 } - ]; - - compare(outputs, expected, 0.2); - - } - await answerer.dispose(); - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Summarization', () => { - - // List all models which will be tested - const models = [ - 'Xenova/distilbart-cnn-6-6', - 'Xenova/bart-large-cnn', - ]; - - let texts = [ - `The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.`, - `The Amazon rainforest (Portuguese: Floresta Amazônica or Amazônia; Spanish: Selva Amazónica, Amazonía or usually Amazonia; French: Forêt amazonienne; Dutch: Amazoneregenwoud), also known in English as Amazonia or the Amazon Jungle, is a moist broadleaf forest that covers most of the Amazon basin of South America. This basin encompasses 7,000,000 square kilometres (2,700,000 sq mi), of which 5,500,000 square kilometres (2,100,000 sq mi) are covered by the rainforest. This region includes territory belonging to nine nations. The majority of the forest is contained within Brazil, with 60% of the rainforest, followed by Peru with 13%, Colombia with 10%, and with minor amounts in Venezuela, Ecuador, Bolivia, Guyana, Suriname and French Guiana. States or departments in four nations contain "Amazonas" in their names. 
The Amazon represents over half of the planet's remaining rainforests, and comprises the largest and most biodiverse tract of tropical rainforest in the world, with an estimated 390 billion individual trees divided into 16,000 species.` - ]; - - it(models[0], async () => { - let summarizer = await pipeline('summarization', models[0]); - - // batched - { - let summary = await summarizer(texts, { - top_k: 0, - do_sample: false, - }); - expect(summary).toHaveLength(2); - expect(summary[0].summary_text.length).toBeGreaterThan(50); - expect(summary[1].summary_text.length).toBeGreaterThan(50); - } - await summarizer.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - - - it(models[1], async () => { - let summarizer = await pipeline('summarization', models[1]); - - // batched + `forced_bos_token_id` - { - let summary = await summarizer(texts[0], { - top_k: 0, - do_sample: false, - }); - expect(summary).toHaveLength(1); - expect(summary[0].summary_text.length).toBeGreaterThan(50); - } - - await summarizer.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Translation', () => { - - // List all models which will be tested - const models = [ - 'Xenova/t5-small', - - // Multilingual model - 'Xenova/nllb-200-distilled-600M', - ]; - - it(models[0], async () => { - let translator = await pipeline('translation_en_to_de', models[0]); - let texts = [ - 'Hello, how are you?', - 'My name is Maria.', - ] - - // single - { - let translation = await translator(texts[0], { - top_k: 0, - do_sample: false - }); - - let expected = [ - { "translation_text": "Hallo, wie sind Sie?" } - ]; - - compare(translation, expected); - } - - // batched - { - let output = await translator(texts, { - top_k: 0, - do_sample: false - }); - - let expected = [ - { 'translation_text': 'Hallo, wie sind Sie?' }, - { 'translation_text': 'Mein Name ist Maria.' } - ]; - - compare(output, expected); - - } - - await translator.dispose(); - }, MAX_TEST_EXECUTION_TIME); - - - it(models[1], async () => { - let translator = await pipeline('translation', models[1]); - let texts = [ - 'Hello world!', - 'I like to walk my dog.', - ] - - // single - { - let translation = await translator(texts[0], { - src_lang: 'eng_Latn', - tgt_lang: 'arb_Arab' - }); - - let expected = [ - { 'translation_text': 'مرحباً، يا عالم!' } - ]; - - compare(translation, expected); - }; - - // single + back-translation - { - let translation1 = await translator(texts[1], { - // src_lang: 'eng_Latn', - tgt_lang: 'ell_Grek' - }); - let translation2 = await translator(translation1[0].translation_text, { - src_lang: 'ell_Grek', - tgt_lang: 'eng_Latn' - }); - - let expected = [ - { translation_text: 'Μου αρέσει να περπατάω το σκυλί μου.' } - ] - - compare(translation1, expected); - - let expectedBack = [ - { translation_text: texts[1] } - ] - compare(translation2, expectedBack); - - } - - await translator.dispose(); - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Text-to-text generation', () => { - - // List all models which will be tested - const models = [ - 'Xenova/flan-t5-small', - 'Xenova/flan-t5-base', - ]; - - it(models[0], async () => { - let generator = await pipeline('text2text-generation', models[0]); - let text = "Premise: At my age you will probably have learnt one lesson. " + - "Hypothesis: It's not certain how many lessons you'll learn by your thirties. 
" + - "Does the premise entail the hypothesis?"; - - { - let outputs = await generator(text, { - top_k: 0, - do_sample: false - }); - expect(outputs).toHaveLength(1); - expect(outputs[0].generated_text.length).toBeGreaterThan(1); - } - - await generator.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - - it(models[1], async () => { - let generator = await pipeline('text2text-generation', models[1]); - let text = ` + score: 0.0009655007743276656, + token: 2154, + token_str: "day", + sequence: "once upon a day.", + }, + ]; + compare(outputs, expected); + } + + // batched + { + let outputs = await unmasker(texts); + + let expected = [ + [ + { + score: 0.9900539517402649, + token: 2051, + token_str: "time", + sequence: "once upon a time.", + }, + { + score: 0.0012258145725354552, + token: 13342, + token_str: "mattress", + sequence: "once upon a mattress.", + }, + { + score: 0.0002977887343149632, + token: 2096, + token_str: "while", + sequence: "once upon a while.", + }, + { + score: 0.0001899998023873195, + token: 6480, + token_str: "lifetime", + sequence: "once upon a lifetime.", + }, + { + score: 0.00017618606216274202, + token: 2558, + token_str: "period", + sequence: "once upon a period.", + }, + ], + [ + { + score: 0.2863538861274719, + token: 2414, + token_str: "london", + sequence: "london is the capital of england.", + }, + { + score: 0.0607745461165905, + token: 2009, + token_str: "it", + sequence: "it is the capital of england.", + }, + { + score: 0.037455108016729355, + token: 6484, + token_str: "birmingham", + sequence: "birmingham is the capital of england.", + }, + { + score: 0.029375044628977776, + token: 5087, + token_str: "manchester", + sequence: "manchester is the capital of england.", + }, + { + score: 0.0292277242988348, + token: 7067, + token_str: "bristol", + sequence: "bristol is the capital of england.", + }, + ], + ]; + + compare(outputs, expected); + } + + await unmasker.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Question answering", () => { + let question = "Who was Jim Henson?"; + let context = "Jim Henson was a nice puppet."; + + // List all models which will be tested + const models = ["Xenova/distilbert-base-uncased-distilled-squad"]; + + it( + models[0], + async () => { + let answerer = await pipeline("question-answering", models[0]); + + // single + { + let outputs = await answerer(question, context); + let expected = { answer: "a nice puppet", score: 0.5664517526948352 }; + + compare(outputs, expected, 0.2); + } + + // single + topk + { + let outputs = await answerer(question, context, { + topk: 3, + }); + let expected = [ + { answer: "a nice puppet", score: 0.5664517526948352 }, + { answer: "nice puppet", score: 0.1698902336448853 }, + { answer: "puppet", score: 0.14046057793125577 }, + ]; + + compare(outputs, expected, 0.2); + } + await answerer.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Summarization", () => { + // List all models which will be tested + const models = ["Xenova/distilbart-cnn-6-6", "Xenova/bart-large-cnn"]; + + let texts = [`The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. 
Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.`, `The Amazon rainforest (Portuguese: Floresta Amazônica or Amazônia; Spanish: Selva Amazónica, Amazonía or usually Amazonia; French: Forêt amazonienne; Dutch: Amazoneregenwoud), also known in English as Amazonia or the Amazon Jungle, is a moist broadleaf forest that covers most of the Amazon basin of South America. This basin encompasses 7,000,000 square kilometres (2,700,000 sq mi), of which 5,500,000 square kilometres (2,100,000 sq mi) are covered by the rainforest. This region includes territory belonging to nine nations. The majority of the forest is contained within Brazil, with 60% of the rainforest, followed by Peru with 13%, Colombia with 10%, and with minor amounts in Venezuela, Ecuador, Bolivia, Guyana, Suriname and French Guiana. States or departments in four nations contain "Amazonas" in their names. The Amazon represents over half of the planet's remaining rainforests, and comprises the largest and most biodiverse tract of tropical rainforest in the world, with an estimated 390 billion individual trees divided into 16,000 species.`]; + + it( + models[0], + async () => { + let summarizer = await pipeline("summarization", models[0]); + + // batched + { + let summary = await summarizer(texts, { + top_k: 0, + do_sample: false, + }); + expect(summary).toHaveLength(2); + expect(summary[0].summary_text.length).toBeGreaterThan(50); + expect(summary[1].summary_text.length).toBeGreaterThan(50); + } + await summarizer.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + models[1], + async () => { + let summarizer = await pipeline("summarization", models[1]); + + // batched + `forced_bos_token_id` + { + let summary = await summarizer(texts[0], { + top_k: 0, + do_sample: false, + }); + expect(summary).toHaveLength(1); + expect(summary[0].summary_text.length).toBeGreaterThan(50); + } + + await summarizer.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Translation", () => { + // List all models which will be tested + const models = [ + "Xenova/t5-small", + + // Multilingual model + "Xenova/nllb-200-distilled-600M", + ]; + + it( + models[0], + async () => { + let translator = await pipeline("translation_en_to_de", models[0]); + let texts = ["Hello, how are you?", "My name is Maria."]; + + // single + { + let translation = await translator(texts[0], { + top_k: 0, + do_sample: false, + }); + + let expected = [{ translation_text: "Hallo, wie sind Sie?" }]; + + compare(translation, expected); + } + + // batched + { + let output = await translator(texts, { + top_k: 0, + do_sample: false, + }); + + let expected = [{ translation_text: "Hallo, wie sind Sie?" }, { translation_text: "Mein Name ist Maria." }]; + + compare(output, expected); + } + + await translator.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + models[1], + async () => { + let translator = await pipeline("translation", models[1]); + let texts = ["Hello world!", "I like to walk my dog."]; + + // single + { + let translation = await translator(texts[0], { + src_lang: "eng_Latn", + tgt_lang: "arb_Arab", + }); + + let expected = [{ translation_text: "مرحباً، يا عالم!" 
}]; + + compare(translation, expected); + } + + // single + back-translation + { + let translation1 = await translator(texts[1], { + // src_lang: 'eng_Latn', + tgt_lang: "ell_Grek", + }); + let translation2 = await translator(translation1[0].translation_text, { + src_lang: "ell_Grek", + tgt_lang: "eng_Latn", + }); + + let expected = [{ translation_text: "Μου αρέσει να περπατάω το σκυλί μου." }]; + + compare(translation1, expected); + + let expectedBack = [{ translation_text: texts[1] }]; + compare(translation2, expectedBack); + } + + await translator.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Text-to-text generation", () => { + // List all models which will be tested + const models = ["Xenova/flan-t5-small", "Xenova/flan-t5-base"]; + + it( + models[0], + async () => { + let generator = await pipeline("text2text-generation", models[0]); + let text = "Premise: At my age you will probably have learnt one lesson. " + "Hypothesis: It's not certain how many lessons you'll learn by your thirties. " + "Does the premise entail the hypothesis?"; + + { + let outputs = await generator(text, { + top_k: 0, + do_sample: false, + }); + expect(outputs).toHaveLength(1); + expect(outputs[0].generated_text.length).toBeGreaterThan(1); + } + + await generator.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + models[1], + async () => { + let generator = await pipeline("text2text-generation", models[1]); + let text = ` Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now? A: Roger started with 5 balls. 2 cans of 3 tennis balls each is 6 tennis balls. @@ -586,1029 +562,1023 @@ xdescribe('Pipelines', () => { Q: A juggler can juggle 16 balls. Half of the balls are golf balls, and half of the golf balls are blue. 
How many blue golf balls are there?`; - // single - { - let outputs = await generator(text, { - top_k: 0, - do_sample: false - }); - expect(outputs).toHaveLength(1); - expect(outputs[0].generated_text.length).toBeGreaterThan(10); - } - await generator.dispose(); - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Text generation', () => { - - // List all models which will be tested - const models = [ - 'Xenova/distilgpt2', - - 'Xenova/codegen-350M-mono', - ]; - - it(models[0], async () => { - let generator = await pipeline('text-generation', models[0]); - let texts = [ - 'Once upon a time, there was a', - 'I enjoy walking with my cute dog', - ]; - - // single - { - let output = await generator(texts[0], { - max_new_tokens: 10, - top_k: 0, - do_sample: false - }) - expect(output).toHaveLength(1); - expect(output[0].generated_text.length).toBeGreaterThan(texts[0].length); - } - - // single + `num_beams` + `num_return_sequences` - { - let output = await generator(texts[0], { - max_new_tokens: 10, - num_beams: 2, - num_return_sequences: 2, - top_k: 0, - do_sample: false - }) - expect(output).toHaveLength(2); - expect(output[0].generated_text.length).toBeGreaterThan(texts[0].length); - expect(output[1].generated_text.length).toBeGreaterThan(texts[0].length); - - } - - // batched + `num_beams` + `num_return_sequences` - { - let output = await generator(texts, { - max_new_tokens: 10, - num_beams: 2, - num_return_sequences: 2, - top_k: 0, - do_sample: false - }); - expect(output).toHaveLength(2); - expect(output[0]).toHaveLength(2); - expect(output[0][0].generated_text.length).toBeGreaterThan(texts[0].length); - expect(output[0][1].generated_text.length).toBeGreaterThan(texts[0].length); - expect(output[1]).toHaveLength(2); - expect(output[1][0].generated_text.length).toBeGreaterThan(texts[1].length); - expect(output[1][1].generated_text.length).toBeGreaterThan(texts[1].length); - - } - - await generator.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - - - it(models[1], async () => { - let generator = await pipeline('text-generation', models[1]); - let code = 'def fib(n):'; - - // single + `added_tokens` - { - let output = await generator(code, { - max_new_tokens: 45, - top_k: 0, - do_sample: false - }) - expect(output).toHaveLength(1); - expect(output[0].generated_text.length).toBeGreaterThan(code.length); - } - await generator.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Feature extraction', () => { - - // List all models which will be tested - const models = [ - 'Xenova/all-MiniLM-L6-v2', - ]; - - it(models[0], async () => { - let extractor = await pipeline('feature-extraction', models[0]); - - // Provide sentences - let sentences = [ - 'This framework generates embeddings for each input sentence', - 'Sentences are passed as a list of string.', - 'The quick brown fox jumps over the lazy dog.' 
- ] - - // Without pooling or normalization - { - - let output = await extractor(sentences); - expect(output.dims).toHaveLength(3); - } - - // With pooling and normalization + compare features - { - let output = await extractor(sentences, { pooling: 'mean', normalize: true }); - expect(output.dims).toHaveLength(2); - - // Convert Tensor to JS list - output = output.tolist(); - - let pairwiseScores = [[output[0], output[1]], [output[0], output[2]], [output[1], output[2]]].map(x => cos_sim(...x)) - - let expected = [0.502872309810269, 0.11088411026413121, 0.09602621986931259] - compare(pairwiseScores, expected); - } - await extractor.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Speech-to-text generation', () => { - - // List all models which will be tested - const models = [ - // whisper - 'Xenova/whisper-tiny.en', // English-only - 'Xenova/whisper-small', // Multilingual - ['Xenova/whisper-tiny.en', 'output_attentions'], // English-only + `output_attentions` - ['Xenova/whisper-small', 'output_attentions'], // Multilingual + `output_attentions` - - // wav2vec2 - 'jonatasgrosman/wav2vec2-large-xlsr-53-english', - ]; - - it(models[0], async () => { - let transcriber = await pipeline('automatic-speech-recognition', models[0]); - - let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav'; - let audioData = await loadAudio(url); - - { // Transcribe English - let output = await transcriber(audioData); - expect(output.text.length).toBeGreaterThan(50); - // { text: " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country." } - } - - { // Transcribe English w/ timestamps. - let output = await transcriber(audioData, { return_timestamps: true }); - expect(output.text.length).toBeGreaterThan(50); - expect(output.chunks.length).toBeGreaterThan(0); - // { - // text: " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country." - // chunks: [ - // { timestamp: [0, 8], text: " And so my fellow Americans ask not what your country can do for you" } - // { timestamp: [8, 11], text: " ask what you can do for your country." } - // ] - // } - } - await transcriber.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - - it(models[1], async () => { - let transcriber = await pipeline('automatic-speech-recognition', models[1]); - - let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/french-audio.wav'; - let audioData = await loadAudio(url); - - { // Transcribe French - let output = await transcriber(audioData, { language: 'french', task: 'transcribe' }); - expect(output.text.length).toBeGreaterThan(20); - // { text: " J'adore, j'aime, je n'aime pas, je déteste." 
} + // single + { + let outputs = await generator(text, { + top_k: 0, + do_sample: false, + }); + expect(outputs).toHaveLength(1); + expect(outputs[0].generated_text.length).toBeGreaterThan(10); + } + await generator.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Text generation", () => { + // List all models which will be tested + const models = ["Xenova/distilgpt2", "Xenova/codegen-350M-mono"]; + + it( + models[0], + async () => { + let generator = await pipeline("text-generation", models[0]); + let texts = ["Once upon a time, there was a", "I enjoy walking with my cute dog"]; + + // single + { + let output = await generator(texts[0], { + max_new_tokens: 10, + top_k: 0, + do_sample: false, + }); + expect(output).toHaveLength(1); + expect(output[0].generated_text.length).toBeGreaterThan(texts[0].length); + } + + // single + `num_beams` + `num_return_sequences` + { + let output = await generator(texts[0], { + max_new_tokens: 10, + num_beams: 2, + num_return_sequences: 2, + top_k: 0, + do_sample: false, + }); + expect(output).toHaveLength(2); + expect(output[0].generated_text.length).toBeGreaterThan(texts[0].length); + expect(output[1].generated_text.length).toBeGreaterThan(texts[0].length); + } + + // batched + `num_beams` + `num_return_sequences` + { + let output = await generator(texts, { + max_new_tokens: 10, + num_beams: 2, + num_return_sequences: 2, + top_k: 0, + do_sample: false, + }); + expect(output).toHaveLength(2); + expect(output[0]).toHaveLength(2); + expect(output[0][0].generated_text.length).toBeGreaterThan(texts[0].length); + expect(output[0][1].generated_text.length).toBeGreaterThan(texts[0].length); + expect(output[1]).toHaveLength(2); + expect(output[1][0].generated_text.length).toBeGreaterThan(texts[1].length); + expect(output[1][1].generated_text.length).toBeGreaterThan(texts[1].length); + } + + await generator.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + models[1], + async () => { + let generator = await pipeline("text-generation", models[1]); + let code = "def fib(n):"; + + // single + `added_tokens` + { + let output = await generator(code, { + max_new_tokens: 45, + top_k: 0, + do_sample: false, + }); + expect(output).toHaveLength(1); + expect(output[0].generated_text.length).toBeGreaterThan(code.length); + } + await generator.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Feature extraction", () => { + // List all models which will be tested + const models = ["Xenova/all-MiniLM-L6-v2"]; + + it( + models[0], + async () => { + let extractor = await pipeline("feature-extraction", models[0]); + + // Provide sentences + let sentences = ["This framework generates embeddings for each input sentence", "Sentences are passed as a list of string.", "The quick brown fox jumps over the lazy dog."]; + + // Without pooling or normalization + { + let output = await extractor(sentences); + expect(output.dims).toHaveLength(3); + } + + // With pooling and normalization + compare features + { + let output = await extractor(sentences, { pooling: "mean", normalize: true }); + expect(output.dims).toHaveLength(2); + + // Convert Tensor to JS list + output = output.tolist(); + + let pairwiseScores = [ + [output[0], output[1]], + [output[0], output[2]], + [output[1], output[2]], + ].map((x) => cos_sim(...x)); + + let expected = [0.502872309810269, 0.11088411026413121, 0.09602621986931259]; + compare(pairwiseScores, expected); + } + await extractor.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Speech-to-text 
generation", () => { + // List all models which will be tested + const models = [ + // whisper + "Xenova/whisper-tiny.en", // English-only + "Xenova/whisper-small", // Multilingual + ["Xenova/whisper-tiny.en", "output_attentions"], // English-only + `output_attentions` + ["Xenova/whisper-small", "output_attentions"], // Multilingual + `output_attentions` + + // wav2vec2 + "jonatasgrosman/wav2vec2-large-xlsr-53-english", + ]; + + it( + models[0], + async () => { + let transcriber = await pipeline("automatic-speech-recognition", models[0]); + + let url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav"; + let audioData = await loadAudio(url); + + { + // Transcribe English + let output = await transcriber(audioData); + expect(output.text.length).toBeGreaterThan(50); + // { text: " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country." } + } + + { + // Transcribe English w/ timestamps. + let output = await transcriber(audioData, { return_timestamps: true }); + expect(output.text.length).toBeGreaterThan(50); + expect(output.chunks.length).toBeGreaterThan(0); + // { + // text: " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country." + // chunks: [ + // { timestamp: [0, 8], text: " And so my fellow Americans ask not what your country can do for you" } + // { timestamp: [8, 11], text: " ask what you can do for your country." } + // ] + // } + } + await transcriber.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + models[1], + async () => { + let transcriber = await pipeline("automatic-speech-recognition", models[1]); + + let url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/french-audio.wav"; + let audioData = await loadAudio(url); + + { + // Transcribe French + let output = await transcriber(audioData, { language: "french", task: "transcribe" }); + expect(output.text.length).toBeGreaterThan(20); + // { text: " J'adore, j'aime, je n'aime pas, je déteste." } + } + + { + // Translate French to English. + let output = await transcriber(audioData, { language: "french", task: "translate" }); + expect(output.text.length).toBeGreaterThan(20); + // { text: " I love, I like, I don't like, I hate." } + } + await transcriber.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + models[2].join(" + "), + async () => { + let transcriber = await pipeline("automatic-speech-recognition", m(models[2][0]), { + revision: models[2][1], + quantized: false, + }); + + let url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav"; + let audioData = await loadAudio(url); + + { + // Transcribe English w/ word-level timestamps. 
+ let output = await transcriber(audioData, { return_timestamps: "word" }); + const target = { + text: " And so my fellow Americans ask not what your country can do for you ask what you can do for your country.", + chunks: [ + { text: " And", timestamp: [0, 0.78] }, + { text: " so", timestamp: [0.78, 1.06] }, + { text: " my", timestamp: [1.06, 1.46] }, + { text: " fellow", timestamp: [1.46, 1.76] }, + { text: " Americans", timestamp: [1.76, 2.22] }, + { text: " ask", timestamp: [2.22, 3.88] }, + { text: " not", timestamp: [3.88, 4.52] }, + { text: " what", timestamp: [4.52, 5.68] }, + { text: " your", timestamp: [5.68, 6] }, + { text: " country", timestamp: [6, 6.36] }, + { text: " can", timestamp: [6.36, 6.76] }, + { text: " do", timestamp: [6.76, 7.02] }, + { text: " for", timestamp: [7.02, 7.24] }, + { text: " you", timestamp: [7.24, 8.02] }, + { text: " ask", timestamp: [8.28, 8.66] }, + { text: " what", timestamp: [8.66, 8.94] }, + { text: " you", timestamp: [8.94, 9.28] }, + { text: " can", timestamp: [9.28, 9.5] }, + { text: " do", timestamp: [9.5, 9.72] }, + { text: " for", timestamp: [9.72, 9.92] }, + { text: " your", timestamp: [9.92, 10.22] }, + { text: " country.", timestamp: [10.22, 13.36] }, + ], + }; + + compare(output, target); + } + + await transcriber.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + models[3].join(" + "), + async () => { + let transcriber = await pipeline("automatic-speech-recognition", m(models[3][0]), { + revision: models[3][1], + }); + + let url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/japanese-audio.wav"; + let audioData = await loadAudio(url); + + { + // Transcribe Japanese w/ word-level timestamps. + let output = await transcriber(audioData, { return_timestamps: "word", language: "japanese", task: "transcribe" }); + const target = { + text: "モリナガの美味しい牛乳は濃い青色に牛乳瓶を払ったゼザインのパック牛乳である。", + chunks: [ + { text: "モ", timestamp: [0, 0.56] }, + { text: "リ", timestamp: [0.56, 0.64] }, + { text: "ナ", timestamp: [0.64, 0.8] }, + { text: "ガ", timestamp: [0.8, 0.88] }, + { text: "の", timestamp: [0.88, 1.04] }, + { text: "美味", timestamp: [1.04, 1.22] }, + { text: "しい", timestamp: [1.22, 1.46] }, + { text: "牛", timestamp: [1.46, 1.76] }, + { text: "乳", timestamp: [1.76, 1.94] }, + { text: "は", timestamp: [1.94, 2.14] }, + { text: "濃", timestamp: [2.14, 2.34] }, + { text: "い", timestamp: [2.34, 2.48] }, + { text: "青", timestamp: [2.48, 2.62] }, + { text: "色", timestamp: [2.62, 2.84] }, + { text: "に", timestamp: [2.84, 3] }, + { text: "牛", timestamp: [3, 3.22] }, + { text: "乳", timestamp: [3.22, 3.42] }, + { text: "瓶", timestamp: [3.42, 3.58] }, + { text: "を", timestamp: [3.58, 3.82] }, + { text: "払", timestamp: [3.82, 4] }, + { text: "った", timestamp: [4, 4.32] }, + { text: "ゼ", timestamp: [4.32, 4.56] }, + { text: "ザ", timestamp: [4.56, 4.6] }, + { text: "イ", timestamp: [4.6, 4.74] }, + { text: "ン", timestamp: [4.74, 4.8] }, + { text: "の", timestamp: [4.8, 4.94] }, + { text: "パ", timestamp: [4.94, 5.12] }, + { text: "ック", timestamp: [5.12, 5.26] }, + { text: "牛", timestamp: [5.26, 5.52] }, + { text: "乳", timestamp: [5.52, 5.72] }, + { text: "で", timestamp: [5.72, 5.86] }, + { text: "ある。", timestamp: [5.86, 6.62] }, + ], + }; + + compare(output, target); + } + + await transcriber.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + models[4], + async () => { + let transcriber = await pipeline("automatic-speech-recognition", m(models[4])); + + let url = 
"https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav"; + let audioData = await loadAudio(url); + + { + // Transcribe + let output = await transcriber(audioData); + expect(output.text.length).toBeGreaterThan(50); + // { text: "and so my fellow america ask not what your country can do for you ask what you can do for your country" } + } + + await transcriber.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Text-to-speech generation", () => { + // List all models which will be tested + const models = ["Xenova/speecht5_tts", "Xenova/mms-tts-fra"]; + + it( + models[0], + async () => { + let synthesizer = await pipeline("text-to-speech", models[0], { + // NOTE: Although the quantized version produces incoherent results, + // it it is okay to use for testing. + // quantized: false, + }); + + let speaker_embeddings = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin"; + + { + // Generate English speech + let output = await synthesizer("Hello, my dog is cute", { speaker_embeddings }); + expect(output.audio.length).toBeGreaterThan(0); + expect(output.sampling_rate).toEqual(16000); + } + + await synthesizer.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + models[1], + async () => { + let synthesizer = await pipeline("text-to-speech", models[1]); + + { + // Generate French speech + let output = await synthesizer("Bonjour"); + expect(output.audio.length).toBeGreaterThan(0); + expect(output.sampling_rate).toEqual(16000); + } + + await synthesizer.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Audio classification", () => { + // List all models which will be tested + const models = ["Xenova/wav2vec2-large-xlsr-53-gender-recognition-librispeech"]; + + it( + models[0], + async () => { + let classifier = await pipeline("audio-classification", models[0]); + + let url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav"; + let audioData = await loadAudio(url); + + { + // Classify audio + let outputs = await classifier(audioData); + + let expected = [ + { score: 0.997512936592102, label: "male" }, + { score: 0.0024870133493095636, label: "female" }, + ]; + compare(outputs, expected); + } + + await classifier.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Image-to-text", () => { + // List all models which will be tested + const models = ["Xenova/vit-gpt2-image-captioning"]; + + it( + models[0], + async () => { + let captioner = await pipeline("image-to-text", models[0]); + + let url = "https://huggingface.co/datasets/mishig/sample_images/resolve/main/savanna.jpg"; + let urls = ["https://huggingface.co/datasets/mishig/sample_images/resolve/main/football-match.jpg", "https://huggingface.co/datasets/mishig/sample_images/resolve/main/airport.jpg"]; + + // single + { + let output = await captioner(url, { + top_k: 0, + do_sample: false, + }); + // let expected = [ + // { "generated_text": "a herd of giraffes and zebras grazing in a field" } + // ] + + expect(output).toHaveLength(1); + expect(output[0].generated_text.length).toBeGreaterThan(10); + } + + // single + generation options + { + let output = await captioner(url, { + max_new_tokens: 20, + num_beams: 2, + num_return_sequences: 2, + top_k: 0, + do_sample: false, + }); + // let expected = [ + // { "generated_text": "a herd of giraffes and zebras grazing in a field" }, + // { "generated_text": "a herd of giraffes and zebras in a grassy field" } + // ] + + expect(output).toHaveLength(2); + 
expect(output[0].generated_text.length).toBeGreaterThan(10); + expect(output[1].generated_text.length).toBeGreaterThan(10); + } + + // batched + { + let output = await captioner(urls, { + top_k: 0, + do_sample: false, + }); + // let expected = [ + // [{ "generated_text": "two men are kicking a soccer ball in a soccer game" }], + // [{ "generated_text": "a plane on the tarmac with a passenger bus" }] + // ] + + expect(output).toHaveLength(2); + expect(output[0]).toHaveLength(1); + expect(output[0][0].generated_text.length).toBeGreaterThan(10); + expect(output[1]).toHaveLength(1); + expect(output[1][0].generated_text.length).toBeGreaterThan(10); + } + + // batched + generation options + { + let output = await captioner(urls, { + max_new_tokens: 20, + num_beams: 2, + num_return_sequences: 2, + top_k: 0, + do_sample: false, + }); + // let expected = [ + // [ + // { "generated_text": "two men are kicking a soccer ball on a field" }, + // { "generated_text": "two men are kicking a soccer ball in a soccer game" } + // ], [ + // { "generated_text": "a plane on a tarmac with a group of buses" }, + // { "generated_text": "a plane on a tarmac with a group of people on the ground" } + // ] + // ]; + + expect(output).toHaveLength(2); + expect(output[0]).toHaveLength(2); + expect(output[0][0].generated_text.length).toBeGreaterThan(10); + expect(output[0][1].generated_text.length).toBeGreaterThan(10); + expect(output[1]).toHaveLength(2); + expect(output[1][0].generated_text.length).toBeGreaterThan(10); + expect(output[1][1].generated_text.length).toBeGreaterThan(10); + } + await captioner.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Image classification", () => { + // List all models which will be tested + const models = ["Xenova/vit-base-patch16-224"]; + + it( + models[0], + async () => { + let classifier = await pipeline("image-classification", models[0]); + + let url = "https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg"; + let urls = ["https://huggingface.co/datasets/mishig/sample_images/resolve/main/palace.jpg", "https://huggingface.co/datasets/mishig/sample_images/resolve/main/teapot.jpg"]; + + // single + { + let outputs = await classifier(url); + + let expected = [{ label: "tiger, Panthera tigris", score: 0.607988178730011 }]; + + compare(outputs, expected, 0.2); + } + + // single + topk + { + let outputs = await classifier(url, { + topk: 2, + }); + + let expected = [ + { label: "tiger, Panthera tigris", score: 0.607988178730011 }, + { label: "tiger cat", score: 0.3877776563167572 }, + ]; + + compare(outputs, expected, 0.2); + } + + // batched + { + let outputs = await classifier(urls); + + let expected = [ + { label: "palace", score: 0.9986862540245056 }, + { label: "teapot", score: 0.987880527973175 }, + ]; + + compare(outputs, expected); + } + + // batched + topk + { + let outputs = await classifier(urls, { + topk: 2, + }); + + let expected = [ + [ + { label: "palace", score: 0.9986862540245056 }, + { label: "castle", score: 0.00037879671435803175 }, + ], + [ + { label: "teapot", score: 0.987880527973175 }, + { label: "coffeepot", score: 0.006591461598873138 }, + ], + ]; + + compare(outputs, expected); + } + + await classifier.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Image segmentation", () => { + // List all models which will be tested + const models = ["Xenova/detr-resnet-50-panoptic", "Xenova/segformer_b2_clothes"]; + + it( + models[0], + async () => { + let segmenter = await pipeline("image-segmentation", models[0], 
{ + // Quantized version of model produces incorrect results + quantized: false, + }); + let img = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png"; + + // single + { + let outputs = await segmenter(img); + + let expected = [ + { score: 0.9916538596153259, label: "cat", mask: 58998 }, + { score: 0.9987397789955139, label: "remote", mask: 4164 }, + { score: 0.9994599223136902, label: "remote", mask: 2275 }, + { score: 0.9730215072631836, label: "couch", mask: 176980 }, + { score: 0.9993911385536194, label: "cat", mask: 52670 }, + ]; + + let outputLabels = outputs.map((x) => x.label); + let expectedLabels = expected.map((x) => x.label); + + expect(outputLabels).toHaveLength(expectedLabels.length); + expect(outputLabels.sort()).toEqual(expectedLabels.sort()); + } + + await segmenter.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + models[1], + async () => { + let segmenter = await pipeline("image-segmentation", models[1]); + let img = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/young-man-standing-and-leaning-on-car.jpg"; + + // single + { + let outputs = await segmenter(img); + + let expected = [{ label: "Background" }, { label: "Hair" }, { label: "Upper-clothes" }, { label: "Pants" }, { label: "Left-shoe" }, { label: "Right-shoe" }, { label: "Face" }, { label: "Left-leg" }, { label: "Right-leg" }, { label: "Left-arm" }, { label: "Right-arm" }]; + + let outputLabels = outputs.map((x) => x.label); + let expectedLabels = expected.map((x) => x.label); + + expect(outputLabels).toHaveLength(expectedLabels.length); + expect(outputLabels.sort()).toEqual(expectedLabels.sort()); + + // check that all scores are null, and masks have correct dimensions + for (let output of outputs) { + expect(output.score).toBeNull(); + expect(output.mask.width).toEqual(970); + expect(output.mask.height).toEqual(1455); + expect(output.mask.channels).toEqual(1); + } + } + + await segmenter.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Zero-shot image classification", () => { + // List all models which will be tested + const models = ["Xenova/clip-vit-base-patch32"]; + + it( + models[0], + async () => { + let classifier = await pipeline("zero-shot-image-classification", models[0]); + + let url = "https://huggingface.co/datasets/mishig/sample_images/resolve/main/football-match.jpg"; + let urls = ["https://huggingface.co/datasets/mishig/sample_images/resolve/main/football-match.jpg", "https://huggingface.co/datasets/mishig/sample_images/resolve/main/airport.jpg", "https://huggingface.co/datasets/mishig/sample_images/resolve/main/savanna.jpg"]; + + let classes = ["football", "airport", "animals"]; + + // single + { + let output = await classifier(url, classes); + + let expected = [ + { score: 0.9719080924987793, label: "football" }, + { score: 0.022564826533198357, label: "animals" }, + { score: 0.005527070723474026, label: "airport" }, + ]; + compare(output, expected, 0.1); + } + + // batched + { + let output = await classifier(urls, classes); + + let expected = [ + [ + { score: 0.9712504148483276, label: "football" }, + { score: 0.022469401359558105, label: "animals" }, + { score: 0.006280169822275639, label: "airport" }, + ], + [ + { score: 0.997433602809906, label: "airport" }, + { score: 0.0016500800848007202, label: "animals" }, + { score: 0.0009163151844404638, label: "football" }, + ], + [ + { score: 0.9851226806640625, label: "animals" }, + { score: 0.007516484707593918, label: "football" }, + { score: 
0.007360846735537052, label: "airport" }, + ], + ]; + compare(output, expected, 0.1); + } + await classifier.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Object detection", () => { + // List all models which will be tested + const models = ["Xenova/detr-resnet-50"]; + + it( + models[0], + async () => { + let detector = await pipeline("object-detection", models[0]); + + // TODO add batched test cases when supported + let url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg"; + let urls = ["https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/savanna.jpg"]; + + // single + threshold + { + let output = await detector(url, { + threshold: 0.9, + }); + + // let expected = [ + // { + // "score": 0.9977124929428101, + // "label": "remote", + // "box": { "xmin": 41, "ymin": 70, "xmax": 176, "ymax": 118 } + // }, + // { + // "score": 0.9984639883041382, + // "label": "remote", + // "box": { "xmin": 332, "ymin": 73, "xmax": 369, "ymax": 188 } + // }, + // { + // "score": 0.9964856505393982, + // "label": "couch", + // "box": { "xmin": 0, "ymin": 1, "xmax": 639, "ymax": 474 } + // }, + // { + // "score": 0.9988334774971008, + // "label": "cat", + // "box": { "xmin": 11, "ymin": 51, "xmax": 314, "ymax": 472 } + // }, + // { + // "score": 0.9982513785362244, + // "label": "cat", + // "box": { "xmin": 345, "ymin": 22, "xmax": 640, "ymax": 371 } + // } + // ] + + expect(output.length).toBeGreaterThan(0); + for (let cls of output) { + expect(typeof cls.score).toBe("number"); + expect(typeof cls.label).toBe("string"); + for (let key of ["xmin", "ymin", "xmax", "ymax"]) { + expect(typeof cls.box[key]).toBe("number"); } - - { // Translate French to English. - let output = await transcriber(audioData, { language: 'french', task: 'translate' }); - expect(output.text.length).toBeGreaterThan(20); - // { text: " I love, I like, I don't like, I hate." } - } - await transcriber.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - - it(models[2].join(' + '), async () => { - let transcriber = await pipeline('automatic-speech-recognition', m(models[2][0]), { - revision: models[2][1], - quantized: false, - }); - - - let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav'; - let audioData = await loadAudio(url); - - { // Transcribe English w/ word-level timestamps. 
- let output = await transcriber(audioData, { return_timestamps: 'word' }); - const target = { - "text": " And so my fellow Americans ask not what your country can do for you ask what you can do for your country.", - "chunks": [ - { "text": " And", "timestamp": [0, 0.78] }, - { "text": " so", "timestamp": [0.78, 1.06] }, - { "text": " my", "timestamp": [1.06, 1.46] }, - { "text": " fellow", "timestamp": [1.46, 1.76] }, - { "text": " Americans", "timestamp": [1.76, 2.22] }, - { "text": " ask", "timestamp": [2.22, 3.88] }, - { "text": " not", "timestamp": [3.88, 4.52] }, - { "text": " what", "timestamp": [4.52, 5.68] }, - { "text": " your", "timestamp": [5.68, 6] }, - { "text": " country", "timestamp": [6, 6.36] }, - { "text": " can", "timestamp": [6.36, 6.76] }, - { "text": " do", "timestamp": [6.76, 7.02] }, - { "text": " for", "timestamp": [7.02, 7.24] }, - { "text": " you", "timestamp": [7.24, 8.02] }, - { "text": " ask", "timestamp": [8.28, 8.66] }, - { "text": " what", "timestamp": [8.66, 8.94] }, - { "text": " you", "timestamp": [8.94, 9.28] }, - { "text": " can", "timestamp": [9.28, 9.5] }, - { "text": " do", "timestamp": [9.5, 9.72] }, - { "text": " for", "timestamp": [9.72, 9.92] }, - { "text": " your", "timestamp": [9.92, 10.22] }, - { "text": " country.", "timestamp": [10.22, 13.36] } - ] - } - - compare(output, target); + } + } + + // batched + threshold + percentage + { + let output = await detector(urls, { + threshold: 0.9, + percentage: true, + }); + // let expected = [[ + // { + // score: 0.9991137385368347, + // label: 'zebra', + // box: { xmin: 0.65165576338768, ymin: 0.685152679681778, xmax: 0.723189502954483, ymax: 0.8801506459712982 } + // }, + // { + // score: 0.998811662197113, + // label: 'zebra', + // box: { xmin: 0.20797613263130188, ymin: 0.6543092578649521, xmax: 0.4147692620754242, ymax: 0.9040975719690323 } + // }, + // { + // score: 0.9707837104797363, + // label: 'giraffe', + // box: { xmin: 0.02498096227645874, ymin: 0.40549489855766296, xmax: 0.38669759035110474, ymax: 0.7895723879337311 } + // }, + // { + // score: 0.9984336495399475, + // label: 'zebra', + // box: { xmin: 0.3540637195110321, ymin: 0.6370827257633209, xmax: 0.5765090882778168, ymax: 0.8480959832668304 } + // }, + // { + // score: 0.9986463785171509, + // label: 'giraffe', + // box: { xmin: 0.6763969212770462, ymin: 0.25748637318611145, xmax: 0.974339172244072, ymax: 0.8684568107128143 } + // } + // ]] + + expect(output).toHaveLength(urls.length); // Same number of inputs as outputs + + for (let i = 0; i < output.length; ++i) { + expect(output[i].length).toBeGreaterThan(0); + for (let cls of output[i]) { + expect(typeof cls.score).toBe("number"); + expect(typeof cls.label).toBe("string"); + for (let key of ["xmin", "ymin", "xmax", "ymax"]) { + expect(typeof cls.box[key]).toBe("number"); + } } - - await transcriber.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - - it(models[3].join(' + '), async () => { - let transcriber = await pipeline('automatic-speech-recognition', m(models[3][0]), { - revision: models[3][1], - }); - - let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/japanese-audio.wav'; - let audioData = await loadAudio(url); - - { // Transcribe Japanese w/ word-level timestamps. 
- let output = await transcriber(audioData, { return_timestamps: 'word', language: 'japanese', task: 'transcribe' }); - const target = { - "text": "モリナガの美味しい牛乳は濃い青色に牛乳瓶を払ったゼザインのパック牛乳である。", - "chunks": [ - { "text": "モ", "timestamp": [0, 0.56] }, - { "text": "リ", "timestamp": [0.56, 0.64] }, - { "text": "ナ", "timestamp": [0.64, 0.8] }, - { "text": "ガ", "timestamp": [0.8, 0.88] }, - { "text": "の", "timestamp": [0.88, 1.04] }, - { "text": "美味", "timestamp": [1.04, 1.22] }, - { "text": "しい", "timestamp": [1.22, 1.46] }, - { "text": "牛", "timestamp": [1.46, 1.76] }, - { "text": "乳", "timestamp": [1.76, 1.94] }, - { "text": "は", "timestamp": [1.94, 2.14] }, - { "text": "濃", "timestamp": [2.14, 2.34] }, - { "text": "い", "timestamp": [2.34, 2.48] }, - { "text": "青", "timestamp": [2.48, 2.62] }, - { "text": "色", "timestamp": [2.62, 2.84] }, - { "text": "に", "timestamp": [2.84, 3] }, - { "text": "牛", "timestamp": [3, 3.22] }, - { "text": "乳", "timestamp": [3.22, 3.42] }, - { "text": "瓶", "timestamp": [3.42, 3.58] }, - { "text": "を", "timestamp": [3.58, 3.82] }, - { "text": "払", "timestamp": [3.82, 4] }, - { "text": "った", "timestamp": [4, 4.32] }, - { "text": "ゼ", "timestamp": [4.32, 4.56] }, - { "text": "ザ", "timestamp": [4.56, 4.6] }, - { "text": "イ", "timestamp": [4.6, 4.74] }, - { "text": "ン", "timestamp": [4.74, 4.8] }, - { "text": "の", "timestamp": [4.8, 4.94] }, - { "text": "パ", "timestamp": [4.94, 5.12] }, - { "text": "ック", "timestamp": [5.12, 5.26] }, - { "text": "牛", "timestamp": [5.26, 5.52] }, - { "text": "乳", "timestamp": [5.52, 5.72] }, - { "text": "で", "timestamp": [5.72, 5.86] }, - { "text": "ある。", "timestamp": [5.86, 6.62] } - ] - } - - compare(output, target); + } + } + + await detector.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Zero-shot object detection", () => { + // List all models which will be tested + const models = ["Xenova/owlvit-base-patch32"]; + + it( + models[0], + async () => { + let detector = await pipeline("zero-shot-object-detection", models[0]); + + // single (default) + { + let url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png"; + let candidate_labels = ["human face", "rocket", "helmet", "american flag"]; + + let output = await detector(url, candidate_labels); + + // let expected = [ + // { + // score: 0.24392342567443848, + // label: 'human face', + // box: { xmin: 180, ymin: 67, xmax: 274, ymax: 175 } + // }, + // { + // score: 0.15129457414150238, + // label: 'american flag', + // box: { xmin: 0, ymin: 4, xmax: 106, ymax: 513 } + // }, + // { + // score: 0.13649864494800568, + // label: 'helmet', + // box: { xmin: 277, ymin: 337, xmax: 511, ymax: 511 } + // }, + // { + // score: 0.10262022167444229, + // label: 'rocket', + // box: { xmin: 352, ymin: -1, xmax: 463, ymax: 287 } + // } + // ] + + expect(output.length).toBeGreaterThan(0); + for (let cls of output) { + expect(typeof cls.score).toBe("number"); + expect(typeof cls.label).toBe("string"); + for (let key of ["xmin", "ymin", "xmax", "ymax"]) { + expect(typeof cls.box[key]).toBe("number"); } - - await transcriber.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - - - it(models[4], async () => { - let transcriber = await pipeline('automatic-speech-recognition', m(models[4])); - - let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav'; - let audioData = await loadAudio(url); - - { // Transcribe - let output = await transcriber(audioData); - expect(output.text.length).toBeGreaterThan(50); - // { text: "and so 
my fellow america ask not what your country can do for you ask what you can do for your country" } + } + } + + // topk + threshold + percentage + { + let url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beach.png"; + let candidate_labels = ["hat", "book", "sunglasses", "camera"]; + + let output = await detector(url, candidate_labels, { + topk: 4, + threshold: 0.05, + percentage: true, + }); + + // let expected = [ + // { + // score: 0.1606510728597641, + // label: 'sunglasses', + // box: { xmin: 347, ymin: 229, xmax: 429, ymax: 264 } + // }, + // { + // score: 0.08935828506946564, + // label: 'hat', + // box: { xmin: 38, ymin: 174, xmax: 258, ymax: 364 } + // }, + // { + // score: 0.08530698716640472, + // label: 'camera', + // box: { xmin: 187, ymin: 350, xmax: 260, ymax: 411 } + // }, + // { + // score: 0.08349756896495819, + // label: 'book', + // box: { xmin: 261, ymin: 280, xmax: 494, ymax: 425 } + // } + // ] + + expect(output.length).toBeGreaterThan(0); + for (let cls of output) { + expect(typeof cls.score).toBe("number"); + expect(typeof cls.label).toBe("string"); + for (let key of ["xmin", "ymin", "xmax", "ymax"]) { + expect(typeof cls.box[key]).toBe("number"); } - - await transcriber.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Text-to-speech generation', () => { - - // List all models which will be tested - const models = [ - 'Xenova/speecht5_tts', - 'Xenova/mms-tts-fra', - ]; - - it(models[0], async () => { - let synthesizer = await pipeline('text-to-speech', models[0], { - // NOTE: Although the quantized version produces incoherent results, - // it it is okay to use for testing. - // quantized: false, - }); - - let speaker_embeddings = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin'; - - { // Generate English speech - let output = await synthesizer('Hello, my dog is cute', { speaker_embeddings }); - expect(output.audio.length).toBeGreaterThan(0); - expect(output.sampling_rate).toEqual(16000); - } - - await synthesizer.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - - it(models[1], async () => { - let synthesizer = await pipeline('text-to-speech', models[1]); - - { // Generate French speech - let output = await synthesizer('Bonjour'); - expect(output.audio.length).toBeGreaterThan(0); - expect(output.sampling_rate).toEqual(16000); - } - - await synthesizer.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - - }); - - describe('Audio classification', () => { - - // List all models which will be tested - const models = [ - 'Xenova/wav2vec2-large-xlsr-53-gender-recognition-librispeech', - ]; - - it(models[0], async () => { - let classifier = await pipeline('audio-classification', models[0]); - - let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav'; - let audioData = await loadAudio(url); - - { // Classify audio - let outputs = await classifier(audioData); - - let expected = [ - { 'score': 0.997512936592102, 'label': 'male' }, - { 'score': 0.0024870133493095636, 'label': 'female' } - ]; - compare(outputs, expected); - } - - await classifier.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - - }); - - describe('Image-to-text', () => { - - // List all models which will be tested - const models = [ - 'Xenova/vit-gpt2-image-captioning', - ]; - - it(models[0], async () => { - let captioner = await pipeline('image-to-text', models[0]); - - let url = 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/savanna.jpg'; - let urls = [ - 
'https://huggingface.co/datasets/mishig/sample_images/resolve/main/football-match.jpg', - 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/airport.jpg' - ] - - // single - { - let output = await captioner(url, { - top_k: 0, - do_sample: false - }) - // let expected = [ - // { "generated_text": "a herd of giraffes and zebras grazing in a field" } - // ] - - expect(output).toHaveLength(1); - expect(output[0].generated_text.length).toBeGreaterThan(10); - } - - // single + generation options - { - let output = await captioner(url, { - max_new_tokens: 20, - num_beams: 2, - num_return_sequences: 2, - top_k: 0, - do_sample: false - }) - // let expected = [ - // { "generated_text": "a herd of giraffes and zebras grazing in a field" }, - // { "generated_text": "a herd of giraffes and zebras in a grassy field" } - // ] - - expect(output).toHaveLength(2); - expect(output[0].generated_text.length).toBeGreaterThan(10); - expect(output[1].generated_text.length).toBeGreaterThan(10); - - } - - // batched - { - let output = await captioner(urls, { - top_k: 0, - do_sample: false - }) - // let expected = [ - // [{ "generated_text": "two men are kicking a soccer ball in a soccer game" }], - // [{ "generated_text": "a plane on the tarmac with a passenger bus" }] - // ] - - expect(output).toHaveLength(2); - expect(output[0]).toHaveLength(1); - expect(output[0][0].generated_text.length).toBeGreaterThan(10); - expect(output[1]).toHaveLength(1); - expect(output[1][0].generated_text.length).toBeGreaterThan(10); - } - - // batched + generation options - { - let output = await captioner(urls, { - max_new_tokens: 20, - num_beams: 2, - num_return_sequences: 2, - top_k: 0, - do_sample: false - }) - // let expected = [ - // [ - // { "generated_text": "two men are kicking a soccer ball on a field" }, - // { "generated_text": "two men are kicking a soccer ball in a soccer game" } - // ], [ - // { "generated_text": "a plane on a tarmac with a group of buses" }, - // { "generated_text": "a plane on a tarmac with a group of people on the ground" } - // ] - // ]; - - expect(output).toHaveLength(2); - expect(output[0]).toHaveLength(2); - expect(output[0][0].generated_text.length).toBeGreaterThan(10); - expect(output[0][1].generated_text.length).toBeGreaterThan(10); - expect(output[1]).toHaveLength(2); - expect(output[1][0].generated_text.length).toBeGreaterThan(10); - expect(output[1][1].generated_text.length).toBeGreaterThan(10); - - } - await captioner.dispose(); - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Image classification', () => { - - // List all models which will be tested - const models = [ - 'Xenova/vit-base-patch16-224', - ]; - - it(models[0], async () => { - let classifier = await pipeline('image-classification', models[0]); - - let url = 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg'; - let urls = [ - 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/palace.jpg', - 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/teapot.jpg' - ] - - // single - { - let outputs = await classifier(url); - - let expected = [ - { "label": "tiger, Panthera tigris", "score": 0.607988178730011 } - ]; - - compare(outputs, expected, 0.2); - - } - - // single + topk - { - let outputs = await classifier(url, { - topk: 2 - }); - - let expected = [ - { "label": "tiger, Panthera tigris", "score": 0.607988178730011 }, - { "label": "tiger cat", "score": 0.3877776563167572 } - ]; - - compare(outputs, expected, 0.2); - } - - - // batched - { - let outputs = 
await classifier(urls); - - let expected = [ - { "label": "palace", "score": 0.9986862540245056 }, - { "label": "teapot", "score": 0.987880527973175 } - ]; - - compare(outputs, expected); - } - - // batched + topk - { - let outputs = await classifier(urls, { - topk: 2 - }); - - let expected = [ - [ - { "label": "palace", "score": 0.9986862540245056 }, - { "label": "castle", "score": 0.00037879671435803175 } - ], - [ - { "label": "teapot", "score": 0.987880527973175 }, - { "label": "coffeepot", "score": 0.006591461598873138 } - ] - ]; - - compare(outputs, expected); - } - - await classifier.dispose(); - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Image segmentation', () => { - - // List all models which will be tested - const models = [ - 'Xenova/detr-resnet-50-panoptic', - 'Xenova/segformer_b2_clothes', - ]; - - it(models[0], async () => { - let segmenter = await pipeline('image-segmentation', models[0], { - // Quantized version of model produces incorrect results - quantized: false, - }) - let img = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png'; - - // single - { - let outputs = await segmenter(img); - - let expected = [ - { score: 0.9916538596153259, label: 'cat', mask: 58998 }, - { score: 0.9987397789955139, label: 'remote', mask: 4164 }, - { score: 0.9994599223136902, label: 'remote', mask: 2275 }, - { score: 0.9730215072631836, label: 'couch', mask: 176980 }, - { score: 0.9993911385536194, label: 'cat', mask: 52670 } - ]; - - let outputLabels = outputs.map(x => x.label); - let expectedLabels = expected.map(x => x.label); - - expect(outputLabels).toHaveLength(expectedLabels.length); - expect(outputLabels.sort()).toEqual(expectedLabels.sort()) - } - - await segmenter.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - - it(models[1], async () => { - let segmenter = await pipeline('image-segmentation', models[1]); - let img = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/young-man-standing-and-leaning-on-car.jpg'; - - // single - { - let outputs = await segmenter(img); - - let expected = [ - { label: 'Background' }, - { label: 'Hair' }, - { label: 'Upper-clothes' }, - { label: 'Pants' }, - { label: 'Left-shoe' }, - { label: 'Right-shoe' }, - { label: 'Face' }, - { label: 'Left-leg' }, - { label: 'Right-leg' }, - { label: 'Left-arm' }, - { label: 'Right-arm' }, - ]; - - let outputLabels = outputs.map(x => x.label); - let expectedLabels = expected.map(x => x.label); - - expect(outputLabels).toHaveLength(expectedLabels.length); - expect(outputLabels.sort()).toEqual(expectedLabels.sort()) - - // check that all scores are null, and masks have correct dimensions - for (let output of outputs) { - expect(output.score).toBeNull(); - expect(output.mask.width).toEqual(970); - expect(output.mask.height).toEqual(1455); - expect(output.mask.channels).toEqual(1); - } - } - - await segmenter.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Zero-shot image classification', () => { - - // List all models which will be tested - const models = [ - 'Xenova/clip-vit-base-patch32', - ]; - - it(models[0], async () => { - let classifier = await pipeline('zero-shot-image-classification', models[0]); - - let url = 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/football-match.jpg'; - let urls = [ - 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/football-match.jpg', - 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/airport.jpg', - 
'https://huggingface.co/datasets/mishig/sample_images/resolve/main/savanna.jpg', - ] - - let classes = ['football', 'airport', 'animals']; - - // single - { - let output = await classifier(url, classes); - - let expected = [ - { score: 0.9719080924987793, label: 'football' }, - { score: 0.022564826533198357, label: 'animals' }, - { score: 0.005527070723474026, label: 'airport' } - ] - compare(output, expected, 0.1); - - } - - - // batched - { - let output = await classifier(urls, classes); - - let expected = [ - [ - { score: 0.9712504148483276, label: 'football' }, - { score: 0.022469401359558105, label: 'animals' }, - { score: 0.006280169822275639, label: 'airport' } - ], [ - { score: 0.997433602809906, label: 'airport' }, - { score: 0.0016500800848007202, label: 'animals' }, - { score: 0.0009163151844404638, label: 'football' } - ], [ - { score: 0.9851226806640625, label: 'animals' }, - { score: 0.007516484707593918, label: 'football' }, - { score: 0.007360846735537052, label: 'airport' } - ] - ]; - compare(output, expected, 0.1); - - } - await classifier.dispose(); - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Object detection', () => { - - // List all models which will be tested - const models = [ - 'Xenova/detr-resnet-50', - ]; - - it(models[0], async () => { - let detector = await pipeline('object-detection', models[0]); - - // TODO add batched test cases when supported - let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg'; - let urls = ['https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/savanna.jpg'] - - // single + threshold - { - let output = await detector(url, { - threshold: 0.9, - }); - - // let expected = [ - // { - // "score": 0.9977124929428101, - // "label": "remote", - // "box": { "xmin": 41, "ymin": 70, "xmax": 176, "ymax": 118 } - // }, - // { - // "score": 0.9984639883041382, - // "label": "remote", - // "box": { "xmin": 332, "ymin": 73, "xmax": 369, "ymax": 188 } - // }, - // { - // "score": 0.9964856505393982, - // "label": "couch", - // "box": { "xmin": 0, "ymin": 1, "xmax": 639, "ymax": 474 } - // }, - // { - // "score": 0.9988334774971008, - // "label": "cat", - // "box": { "xmin": 11, "ymin": 51, "xmax": 314, "ymax": 472 } - // }, - // { - // "score": 0.9982513785362244, - // "label": "cat", - // "box": { "xmin": 345, "ymin": 22, "xmax": 640, "ymax": 371 } - // } - // ] - - expect(output.length).toBeGreaterThan(0); - for (let cls of output) { - expect(typeof cls.score).toBe('number'); - expect(typeof cls.label).toBe('string'); - for (let key of ['xmin', 'ymin', 'xmax', 'ymax']) { - expect(typeof cls.box[key]).toBe('number'); - } - } - } - - // batched + threshold + percentage - { - let output = await detector(urls, { - threshold: 0.9, - percentage: true - }); - // let expected = [[ - // { - // score: 0.9991137385368347, - // label: 'zebra', - // box: { xmin: 0.65165576338768, ymin: 0.685152679681778, xmax: 0.723189502954483, ymax: 0.8801506459712982 } - // }, - // { - // score: 0.998811662197113, - // label: 'zebra', - // box: { xmin: 0.20797613263130188, ymin: 0.6543092578649521, xmax: 0.4147692620754242, ymax: 0.9040975719690323 } - // }, - // { - // score: 0.9707837104797363, - // label: 'giraffe', - // box: { xmin: 0.02498096227645874, ymin: 0.40549489855766296, xmax: 0.38669759035110474, ymax: 0.7895723879337311 } - // }, - // { - // score: 0.9984336495399475, - // label: 'zebra', - // box: { xmin: 0.3540637195110321, ymin: 0.6370827257633209, xmax: 0.5765090882778168, ymax: 
0.8480959832668304 } - // }, - // { - // score: 0.9986463785171509, - // label: 'giraffe', - // box: { xmin: 0.6763969212770462, ymin: 0.25748637318611145, xmax: 0.974339172244072, ymax: 0.8684568107128143 } - // } - // ]] - - expect(output).toHaveLength(urls.length); // Same number of inputs as outputs - - for (let i = 0; i < output.length; ++i) { - expect(output[i].length).toBeGreaterThan(0); - for (let cls of output[i]) { - expect(typeof cls.score).toBe('number'); - expect(typeof cls.label).toBe('string'); - for (let key of ['xmin', 'ymin', 'xmax', 'ymax']) { - expect(typeof cls.box[key]).toBe('number'); - } - } - } - - - } - - await detector.dispose(); - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Zero-shot object detection', () => { - - // List all models which will be tested - const models = [ - 'Xenova/owlvit-base-patch32', - ]; - - it(models[0], async () => { - let detector = await pipeline('zero-shot-object-detection', models[0]); - - - // single (default) - { - let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png'; - let candidate_labels = ['human face', 'rocket', 'helmet', 'american flag']; - - let output = await detector(url, candidate_labels); - - // let expected = [ - // { - // score: 0.24392342567443848, - // label: 'human face', - // box: { xmin: 180, ymin: 67, xmax: 274, ymax: 175 } - // }, - // { - // score: 0.15129457414150238, - // label: 'american flag', - // box: { xmin: 0, ymin: 4, xmax: 106, ymax: 513 } - // }, - // { - // score: 0.13649864494800568, - // label: 'helmet', - // box: { xmin: 277, ymin: 337, xmax: 511, ymax: 511 } - // }, - // { - // score: 0.10262022167444229, - // label: 'rocket', - // box: { xmin: 352, ymin: -1, xmax: 463, ymax: 287 } - // } - // ] - - expect(output.length).toBeGreaterThan(0); - for (let cls of output) { - expect(typeof cls.score).toBe('number'); - expect(typeof cls.label).toBe('string'); - for (let key of ['xmin', 'ymin', 'xmax', 'ymax']) { - expect(typeof cls.box[key]).toBe('number'); - } - } - } - - // topk + threshold + percentage - { - let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beach.png'; - let candidate_labels = ['hat', 'book', 'sunglasses', 'camera']; - - let output = await detector(url, candidate_labels, { - topk: 4, - threshold: 0.05, - percentage: true, - }); - - // let expected = [ - // { - // score: 0.1606510728597641, - // label: 'sunglasses', - // box: { xmin: 347, ymin: 229, xmax: 429, ymax: 264 } - // }, - // { - // score: 0.08935828506946564, - // label: 'hat', - // box: { xmin: 38, ymin: 174, xmax: 258, ymax: 364 } - // }, - // { - // score: 0.08530698716640472, - // label: 'camera', - // box: { xmin: 187, ymin: 350, xmax: 260, ymax: 411 } - // }, - // { - // score: 0.08349756896495819, - // label: 'book', - // box: { xmin: 261, ymin: 280, xmax: 494, ymax: 425 } - // } - // ] - - expect(output.length).toBeGreaterThan(0); - for (let cls of output) { - expect(typeof cls.score).toBe('number'); - expect(typeof cls.label).toBe('string'); - for (let key of ['xmin', 'ymin', 'xmax', 'ymax']) { - expect(typeof cls.box[key]).toBe('number'); - } - } - } - - await detector.dispose(); - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Image-to-image', () => { - - // List all models which will be tested - const models = [ - 'Xenova/swin2SR-classical-sr-x2-64', - ]; - - it(models[0], async () => { - let upscaler = await pipeline('image-to-image', models[0]); - - // Input is 3x3 => padded to 8x8 => upscaled to 16x16 - let url = 
'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/pattern_3x3.png'; - - // single - { - let outputs = await upscaler(url); - expect(outputs.width).toEqual(16); - expect(outputs.height).toEqual(16); - expect(outputs.channels).toEqual(3); - expect(outputs.data).toHaveLength(768); - } - - // batched - { - let outputs = await upscaler([url, url]); - expect(outputs).toHaveLength(2); - for (let output of outputs) { - expect(output.width).toEqual(16); - expect(output.height).toEqual(16); - expect(output.channels).toEqual(3); - expect(output.data).toHaveLength(768); - } - } - - await upscaler.dispose(); - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Depth estimation', () => { - - // List all models which will be tested - const models = [ - 'Xenova/dpt-hybrid-midas', - ]; - - it(models[0], async () => { - let depth_estimator = await pipeline('depth-estimation', models[0]); - - let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg'; - - // single - { - let { predicted_depth, depth } = await depth_estimator(url); - compare(predicted_depth.dims, [384, 384]); - expect(depth.width).toEqual(640); - expect(depth.height).toEqual(480); - expect(depth.channels).toEqual(1); - expect(depth.data).toHaveLength(307200); - } - - // batched - { - let outputs = await depth_estimator([url, url]); - expect(outputs).toHaveLength(2); - for (let output of outputs) { - let { predicted_depth, depth } = output; - compare(predicted_depth.dims, [384, 384]); - expect(depth.width).toEqual(640); - expect(depth.height).toEqual(480); - expect(depth.channels).toEqual(1); - expect(depth.data).toHaveLength(307200); - } - } - - await depth_estimator.dispose(); - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Document question answering', () => { - - // List all models which will be tested - const models = [ - 'Xenova/donut-base-finetuned-docvqa', - ]; - const image = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/invoice.png'; - const question = 'What is the invoice number?'; - - it(models[0], async () => { - let qa_pipeline = await pipeline('document-question-answering', models[0]); - - // basic - { - let output = await qa_pipeline(image, question); - let expected = [{ answer: 'us-001' }]; - compare(output, expected); - } - - await qa_pipeline.dispose(); - - }, MAX_TEST_EXECUTION_TIME); - }); - + } + } + + await detector.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Image-to-image", () => { + // List all models which will be tested + const models = ["Xenova/swin2SR-classical-sr-x2-64"]; + + it( + models[0], + async () => { + let upscaler = await pipeline("image-to-image", models[0]); + + // Input is 3x3 => padded to 8x8 => upscaled to 16x16 + let url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/pattern_3x3.png"; + + // single + { + let outputs = await upscaler(url); + expect(outputs.width).toEqual(16); + expect(outputs.height).toEqual(16); + expect(outputs.channels).toEqual(3); + expect(outputs.data).toHaveLength(768); + } + + // batched + { + let outputs = await upscaler([url, url]); + expect(outputs).toHaveLength(2); + for (let output of outputs) { + expect(output.width).toEqual(16); + expect(output.height).toEqual(16); + expect(output.channels).toEqual(3); + expect(output.data).toHaveLength(768); + } + } + + await upscaler.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Depth estimation", () => { + // List all models which will be tested + const models = 
["Xenova/dpt-hybrid-midas"]; + + it( + models[0], + async () => { + let depth_estimator = await pipeline("depth-estimation", models[0]); + + let url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg"; + + // single + { + let { predicted_depth, depth } = await depth_estimator(url); + compare(predicted_depth.dims, [384, 384]); + expect(depth.width).toEqual(640); + expect(depth.height).toEqual(480); + expect(depth.channels).toEqual(1); + expect(depth.data).toHaveLength(307200); + } + + // batched + { + let outputs = await depth_estimator([url, url]); + expect(outputs).toHaveLength(2); + for (let output of outputs) { + let { predicted_depth, depth } = output; + compare(predicted_depth.dims, [384, 384]); + expect(depth.width).toEqual(640); + expect(depth.height).toEqual(480); + expect(depth.channels).toEqual(1); + expect(depth.data).toHaveLength(307200); + } + } + + await depth_estimator.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Document question answering", () => { + // List all models which will be tested + const models = ["Xenova/donut-base-finetuned-docvqa"]; + const image = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/invoice.png"; + const question = "What is the invoice number?"; + + it( + models[0], + async () => { + let qa_pipeline = await pipeline("document-question-answering", models[0]); + + // basic + { + let output = await qa_pipeline(image, question); + let expected = [{ answer: "us-001" }]; + compare(output, expected); + } + + await qa_pipeline.dispose(); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); }); diff --git a/tests/processors.test.js b/tests/processors.test.js index ae6b11311..caf1ddf86 100644 --- a/tests/processors.test.js +++ b/tests/processors.test.js @@ -1,912 +1,1018 @@ - -import { env, AutoProcessor, RawImage } from '../src/transformers.js'; -import { init, MAX_TEST_EXECUTION_TIME } from './init.js'; -import { compare } from './test_utils.js'; +import { env, AutoProcessor, RawImage } from "../src/transformers.js"; +import { init, MAX_TEST_EXECUTION_TIME } from "./init.js"; +import { compare } from "./test_utils.js"; // Initialise the testing environment init(); env.allowLocalModels = false; env.useFSCache = false; -const sum = array => Number(array.reduce((a, b) => a + b, array instanceof BigInt64Array ? 0n : 0)); -const avg = array => sum(array) / array.length; +const sum = (array) => Number(array.reduce((a, b) => a + b, array instanceof BigInt64Array ? 
0n : 0)); +const avg = (array) => sum(array) / array.length; const IMAGE_CACHE = new Map(); const load_image = async (url) => { - const cached = IMAGE_CACHE.get(url); - if (cached) { - return cached; - } - const image = await RawImage.fromURL(url); - IMAGE_CACHE.set(url, image); - return image; -} + const cached = IMAGE_CACHE.get(url); + if (cached) { + return cached; + } + const image = await RawImage.fromURL(url); + IMAGE_CACHE.set(url, image); + return image; +}; const MODELS = { - swin2sr: 'Xenova/swin2SR-classical-sr-x2-64', - sam: 'Xenova/sam-vit-base', - 'donut-swin': 'Xenova/donut-base-finetuned-cord-v2', - resnet: 'Xenova/resnet-50', - vit: 'Xenova/vit-base-patch16-224', - mobilevit: 'Xenova/mobilevit-small', - mobilevit_2: 'Xenova/quickdraw-mobilevit-small', - mobilevit_3: 'Xenova/mobilevitv2-1.0-imagenet1k-256', - deit: 'Xenova/deit-tiny-distilled-patch16-224', - beit: 'Xenova/beit-base-patch16-224-pt22k-ft22k', - detr: 'Xenova/detr-resnet-50', - yolos: 'Xenova/yolos-small-300', - dpt: 'Xenova/dpt-hybrid-midas', - dpt_2: 'Xenova/depth-anything-small-hf', - glpn: 'Xenova/glpn-kitti', - nougat: 'Xenova/nougat-small', - owlvit: 'Xenova/owlvit-base-patch32', - clip: 'Xenova/clip-vit-base-patch16', - vitmatte: 'Xenova/vitmatte-small-distinctions-646', - dinov2: 'Xenova/dinov2-small-imagenet1k-1-layer', - // efficientnet: 'Xenova/efficientnet-b0', - florence2: 'Xenova/tiny-random-Florence2ForConditionalGeneration', -} + swin2sr: "Xenova/swin2SR-classical-sr-x2-64", + sam: "Xenova/sam-vit-base", + "donut-swin": "Xenova/donut-base-finetuned-cord-v2", + resnet: "Xenova/resnet-50", + vit: "Xenova/vit-base-patch16-224", + mobilevit: "Xenova/mobilevit-small", + mobilevit_2: "Xenova/quickdraw-mobilevit-small", + mobilevit_3: "Xenova/mobilevitv2-1.0-imagenet1k-256", + deit: "Xenova/deit-tiny-distilled-patch16-224", + beit: "Xenova/beit-base-patch16-224-pt22k-ft22k", + detr: "Xenova/detr-resnet-50", + yolos: "Xenova/yolos-small-300", + dpt: "Xenova/dpt-hybrid-midas", + dpt_2: "Xenova/depth-anything-small-hf", + glpn: "Xenova/glpn-kitti", + nougat: "Xenova/nougat-small", + owlvit: "Xenova/owlvit-base-patch32", + clip: "Xenova/clip-vit-base-patch16", + vitmatte: "Xenova/vitmatte-small-distinctions-646", + dinov2: "Xenova/dinov2-small-imagenet1k-1-layer", + // efficientnet: 'Xenova/efficientnet-b0', + florence2: "Xenova/tiny-random-Florence2ForConditionalGeneration", +}; const TEST_IMAGES = { - pattern_3x3: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/pattern_3x3.png', - pattern_3x5: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/pattern_3x5.png', - checkerboard_8x8: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_8x8.png', - checkerboard_64x32: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_64x32.png', - receipt: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/receipt.png', - tiger: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg', - paper: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/nougat_paper.png', - cats: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg', - - // grayscale image - skateboard: 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/ml-web-games/skateboard.png', - - vitmatte_image: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_image.png', - vitmatte_trimap: 
'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_trimap.png', - - beetle: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beetle.png', - book_cover: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/book-cover.png', -} - -describe('Processors', () => { - - describe('Image processors', () => { - - // Swin2SRImageProcessor - // - tests when padding is a number (do_pad=true, pad_size=8) - it(MODELS.swin2sr, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.swin2sr); - - { // Pad to multiple of 8 (3x3 -> 8x8) - const image = await load_image(TEST_IMAGES.pattern_3x3); - const { pixel_values } = await processor(image); - - compare(pixel_values.dims, [1, 3, 8, 8]); - compare(avg(pixel_values.data), 0.5458333368102709); - } - - { // Do not pad if already a multiple of 8 (8x8 -> 8x8) - const image = await load_image(TEST_IMAGES.checkerboard_8x8); - const { pixel_values } = await processor(image); - compare(pixel_values.dims, [1, 3, 8, 8]); - compare(avg(pixel_values.data), 0.5); - } - }, MAX_TEST_EXECUTION_TIME); - - // SamProcessor/SamImageProcessor - // - tests normal padding (do_pad=true, pad_size={"height":1024,"width":1024}) - // - In addition to the image, pass in a list of points - it(MODELS.sam, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.sam) - - { // without input points - const image = await load_image(TEST_IMAGES.pattern_3x3); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - compare(pixel_values.dims, [1, 3, 1024, 1024]); - compare(avg(pixel_values.data), -0.4505715670146813); - - compare(original_sizes, [[3, 3]]); - compare(reshaped_input_sizes, [[1024, 1024]]); - } - - { // with input points - const image = await load_image(TEST_IMAGES.pattern_3x3); - const { original_sizes, reshaped_input_sizes, input_points } = await processor(image, { - input_points: [[[1, 2]]], - }); - - compare(original_sizes, [[3, 3]]); - compare(reshaped_input_sizes, [[1024, 1024]]); - compare(input_points.tolist(), [[[[341.3333, 682.6667]]]]); - } - - { // multiple points with labels - const image = await load_image(TEST_IMAGES.pattern_3x3); - const { original_sizes, reshaped_input_sizes, input_points, input_labels } = await processor(image, { - input_points: [[[1, 2], [2, 1]]], - input_labels: [[1, 0]], - }); - - compare(original_sizes, [[3, 3]]); - compare(reshaped_input_sizes, [[1024, 1024]]); - compare(input_points.tolist(), [[[[341.3333, 682.6667], [682.6667, 341.3333]]]]); - compare(input_labels.tolist(), [[[1n, 0n]]]); - } - - { // with input boxes - const image = await load_image(TEST_IMAGES.pattern_3x3); - const { original_sizes, reshaped_input_sizes, input_boxes } = await processor(image, { - input_boxes: [[[0, 1, 2, 2]]], - }); - - compare(original_sizes, [[3, 3]]); - compare(reshaped_input_sizes, [[1024, 1024]]); - compare(input_boxes.tolist(), [[[0, 341.3333, 682.6667, 682.6667]]]); - } - }, MAX_TEST_EXECUTION_TIME); - - // DonutProcessor/DonutFeatureExtractor - // - tests thumbnail resizing (do_thumbnail=true, size=[960, 1280]) - // - tests padding after normalization (image_mean=image_std=0.5) - it(MODELS['donut-swin'], async () => { - const processor = await AutoProcessor.from_pretrained(MODELS['donut-swin']) - - { - const image = await load_image(TEST_IMAGES.receipt); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - compare(pixel_values.dims, [1, 3, 1280, 960]); - 
compare(avg(pixel_values.data), 0.1229388610053704); - - compare(original_sizes, [[864, 576]]); - compare(reshaped_input_sizes, [[1280, 853]]); - } - }, MAX_TEST_EXECUTION_TIME); - - // ConvNextFeatureExtractor - it(MODELS.resnet, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.resnet) - - { - const image = await load_image(TEST_IMAGES.tiger); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - compare(pixel_values.dims, [1, 3, 224, 224]); - compare(avg(pixel_values.data), 0.06262318789958954); - - compare(original_sizes, [[408, 612]]); - compare(reshaped_input_sizes, [[224, 224]]); - } - }, MAX_TEST_EXECUTION_TIME); - - // ViTFeatureExtractor - it(MODELS.vit, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.vit) - - { - const image = await load_image(TEST_IMAGES.tiger); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - compare(pixel_values.dims, [1, 3, 224, 224]); - compare(avg(pixel_values.data), -0.22706867939852762); - - compare(original_sizes, [[408, 612]]); - compare(reshaped_input_sizes, [[224, 224]]); - } - }, MAX_TEST_EXECUTION_TIME); - - // MobileViTFeatureExtractor - it(MODELS.mobilevit, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.mobilevit) - - { - const image = await load_image(TEST_IMAGES.tiger); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - compare(pixel_values.dims, [1, 3, 256, 256]); - compare(avg(pixel_values.data), 0.4599160496887033); - - compare(original_sizes, [[408, 612]]); - compare(reshaped_input_sizes, [[256, 256]]); - } - }, MAX_TEST_EXECUTION_TIME); - - // MobileViTFeatureExtractor - // - tests not converting to rgb (do_convert_rgb=false) - it(MODELS.mobilevit_2, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.mobilevit_2) - - { // Tests grayscale image - const image = await load_image(TEST_IMAGES.skateboard); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - compare(pixel_values.dims, [1, 1, 28, 28]); - compare(avg(pixel_values.data), 0.08558923671585128); - - compare(original_sizes, [[28, 28]]); - compare(reshaped_input_sizes, [[28, 28]]); - } - }, MAX_TEST_EXECUTION_TIME); - - // MobileViTImageProcessor - // - tests converting RGB to BGR (do_flip_channel_order=true) - it(MODELS.mobilevit_3, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.mobilevit_3) - - { - const image = await load_image(TEST_IMAGES.cats); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - compare(pixel_values.dims, [1, 3, 256, 256]); - compare(avg(pixel_values.data), 0.5215385556221008); - - compare(original_sizes, [[480, 640]]); - compare(reshaped_input_sizes, [[256, 256]]); - - // Ensure RGB to BGR conversion - compare(pixel_values.data.slice(0, 3), [0.24313725531101227, 0.250980406999588, 0.364705890417099]); - } - }, MAX_TEST_EXECUTION_TIME); - - // DeiTFeatureExtractor - it(MODELS.deit, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.deit) - - { - const image = await load_image(TEST_IMAGES.tiger); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - compare(pixel_values.dims, [1, 3, 224, 224]); - compare(avg(pixel_values.data), -0.2760336682859463); - - compare(original_sizes, [[408, 612]]); - compare(reshaped_input_sizes, [[224, 224]]); - } - }, 
MAX_TEST_EXECUTION_TIME); - - // BeitFeatureExtractor - it(MODELS.beit, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.beit) - - { - const image = await load_image(TEST_IMAGES.tiger); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - compare(pixel_values.dims, [1, 3, 224, 224]); - compare(avg(pixel_values.data), -0.22706867939852762); + pattern_3x3: "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/pattern_3x3.png", + pattern_3x5: "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/pattern_3x5.png", + checkerboard_8x8: "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_8x8.png", + checkerboard_64x32: "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_64x32.png", + receipt: "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/receipt.png", + tiger: "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg", + paper: "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/nougat_paper.png", + cats: "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg", + + // grayscale image + skateboard: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/ml-web-games/skateboard.png", + + vitmatte_image: "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_image.png", + vitmatte_trimap: "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_trimap.png", + + beetle: "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beetle.png", + book_cover: "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/book-cover.png", +}; + +describe("Processors", () => { + describe("Image processors", () => { + // Swin2SRImageProcessor + // - tests when padding is a number (do_pad=true, pad_size=8) + it( + MODELS.swin2sr, + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.swin2sr); + + { + // Pad to multiple of 8 (3x3 -> 8x8) + const image = await load_image(TEST_IMAGES.pattern_3x3); + const { pixel_values } = await processor(image); + + compare(pixel_values.dims, [1, 3, 8, 8]); + compare(avg(pixel_values.data), 0.5458333368102709); + } + + { + // Do not pad if already a multiple of 8 (8x8 -> 8x8) + const image = await load_image(TEST_IMAGES.checkerboard_8x8); + const { pixel_values } = await processor(image); + compare(pixel_values.dims, [1, 3, 8, 8]); + compare(avg(pixel_values.data), 0.5); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // SamProcessor/SamImageProcessor + // - tests normal padding (do_pad=true, pad_size={"height":1024,"width":1024}) + // - In addition to the image, pass in a list of points + it( + MODELS.sam, + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.sam); + + { + // without input points + const image = await load_image(TEST_IMAGES.pattern_3x3); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + compare(pixel_values.dims, [1, 3, 1024, 1024]); + compare(avg(pixel_values.data), -0.4505715670146813); + + compare(original_sizes, [[3, 3]]); + compare(reshaped_input_sizes, [[1024, 1024]]); + } + + { + // with input points + const image = await load_image(TEST_IMAGES.pattern_3x3); + const { original_sizes, reshaped_input_sizes, input_points } = await processor(image, { + input_points: [[[1, 2]]], + }); + + 
compare(original_sizes, [[3, 3]]); + compare(reshaped_input_sizes, [[1024, 1024]]); + compare(input_points.tolist(), [[[[341.3333, 682.6667]]]]); + } + + { + // multiple points with labels + const image = await load_image(TEST_IMAGES.pattern_3x3); + const { original_sizes, reshaped_input_sizes, input_points, input_labels } = await processor(image, { + input_points: [ + [ + [1, 2], + [2, 1], + ], + ], + input_labels: [[1, 0]], + }); + + compare(original_sizes, [[3, 3]]); + compare(reshaped_input_sizes, [[1024, 1024]]); + compare(input_points.tolist(), [ + [ + [ + [341.3333, 682.6667], + [682.6667, 341.3333], + ], + ], + ]); + compare(input_labels.tolist(), [[[1n, 0n]]]); + } + + { + // with input boxes + const image = await load_image(TEST_IMAGES.pattern_3x3); + const { original_sizes, reshaped_input_sizes, input_boxes } = await processor(image, { + input_boxes: [[[0, 1, 2, 2]]], + }); + + compare(original_sizes, [[3, 3]]); + compare(reshaped_input_sizes, [[1024, 1024]]); + compare(input_boxes.tolist(), [[[0, 341.3333, 682.6667, 682.6667]]]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // DonutProcessor/DonutFeatureExtractor + // - tests thumbnail resizing (do_thumbnail=true, size=[960, 1280]) + // - tests padding after normalization (image_mean=image_std=0.5) + it( + MODELS["donut-swin"], + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS["donut-swin"]); + + { + const image = await load_image(TEST_IMAGES.receipt); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 1280, 960]); + compare(avg(pixel_values.data), 0.1229388610053704); + + compare(original_sizes, [[864, 576]]); + compare(reshaped_input_sizes, [[1280, 853]]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // ConvNextFeatureExtractor + it( + MODELS.resnet, + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.resnet); + + { + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 224, 224]); + compare(avg(pixel_values.data), 0.06262318789958954); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[224, 224]]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // ViTFeatureExtractor + it( + MODELS.vit, + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.vit); + + { + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 224, 224]); + compare(avg(pixel_values.data), -0.22706867939852762); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[224, 224]]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // MobileViTFeatureExtractor + it( + MODELS.mobilevit, + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.mobilevit); + + { + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 256, 256]); + compare(avg(pixel_values.data), 0.4599160496887033); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[256, 256]]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // MobileViTFeatureExtractor + // - tests not converting to rgb (do_convert_rgb=false) + it( + MODELS.mobilevit_2, + async () => { + const processor = await 
AutoProcessor.from_pretrained(MODELS.mobilevit_2); + + { + // Tests grayscale image + const image = await load_image(TEST_IMAGES.skateboard); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 1, 28, 28]); + compare(avg(pixel_values.data), 0.08558923671585128); + + compare(original_sizes, [[28, 28]]); + compare(reshaped_input_sizes, [[28, 28]]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // MobileViTImageProcessor + // - tests converting RGB to BGR (do_flip_channel_order=true) + it( + MODELS.mobilevit_3, + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.mobilevit_3); + + { + const image = await load_image(TEST_IMAGES.cats); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 256, 256]); + compare(avg(pixel_values.data), 0.5215385556221008); + + compare(original_sizes, [[480, 640]]); + compare(reshaped_input_sizes, [[256, 256]]); + + // Ensure RGB to BGR conversion + compare(pixel_values.data.slice(0, 3), [0.24313725531101227, 0.250980406999588, 0.364705890417099]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // DeiTFeatureExtractor + it( + MODELS.deit, + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.deit); + + { + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 224, 224]); + compare(avg(pixel_values.data), -0.2760336682859463); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[224, 224]]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // BeitFeatureExtractor + it( + MODELS.beit, + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.beit); + + { + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 224, 224]); + compare(avg(pixel_values.data), -0.22706867939852762); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[224, 224]]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // DetrFeatureExtractor + it( + MODELS.detr, + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.detr); + + { + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes, pixel_mask } = await processor(image); + + compare(pixel_values.dims, [1, 3, 888, 1333]); + compare(avg(pixel_values.data), -0.27840224131001773); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[888, 1333]]); + + compare(pixel_mask.dims, [1, 64, 64]); + compare(avg(pixel_mask.data), 1); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // YolosFeatureExtractor + it( + MODELS.yolos, + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.yolos); + + { + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 888, 1333]); + compare(avg(pixel_values.data), -0.27840224131001773); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[888, 1333]]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // DPTFeatureExtractor + it( + MODELS.dpt, + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.dpt); + + { + // Tests grayscale image + const 
image = await load_image(TEST_IMAGES.cats); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 384, 384]); + compare(avg(pixel_values.data), 0.0372855559389454); + + compare(original_sizes, [[480, 640]]); + compare(reshaped_input_sizes, [[384, 384]]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // GLPNForDepthEstimation + // - tests `size_divisor` and no size (size_divisor=32) + it( + MODELS.glpn, + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.glpn); + + { + const image = await load_image(TEST_IMAGES.cats); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + compare(pixel_values.dims, [1, 3, 480, 640]); + compare(avg(pixel_values.data), 0.5186172404123327); + + compare(original_sizes, [[480, 640]]); + compare(reshaped_input_sizes, [[480, 640]]); + } + + { + // Tests input which is not a multiple of 32 ([408, 612] -> [384, 608]) + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 384, 608]); + compare(avg(pixel_values.data), 0.38628831535989555); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[384, 608]]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // NougatImageProcessor + // - tests padding after normalization (image_mean != 0.5, image_std != 0.5) + it( + MODELS.nougat, + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.nougat); + + { + const image = await load_image(TEST_IMAGES.paper); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 896, 672]); + compare(avg(pixel_values.data), 1.8447155005897355); + + compare(original_sizes, [[850, 685]]); + compare(reshaped_input_sizes, [[833, 672]]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // OwlViTFeatureExtractor + it(MODELS.owlvit, async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.owlvit); + { + const image = await load_image(TEST_IMAGES.cats); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 768, 768]); + compare(avg(pixel_values.data), 0.250620447910435); + + compare(original_sizes, [[480, 640]]); + compare(reshaped_input_sizes, [[768, 768]]); + } + }); - compare(original_sizes, [[408, 612]]); - compare(reshaped_input_sizes, [[224, 224]]); - } - }, MAX_TEST_EXECUTION_TIME); + // CLIPFeatureExtractor + // - tests center crop (do_center_crop=true, crop_size=224) + it( + MODELS.clip, + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.clip); + + { + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 224, 224]); + compare(avg(pixel_values.data), -0.06678297738282096); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[224, 224]]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // VitMatteImageProcessor + // - tests custom overrides + // - tests multiple inputs + // - tests `size_divisibility` and no size (size_divisibility=32) + // - tests do_pad and `size_divisibility` + it( + MODELS.vitmatte, + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.vitmatte); + + { + const image = await load_image(TEST_IMAGES.vitmatte_image); + const 
image2 = await load_image(TEST_IMAGES.vitmatte_trimap); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image, image2); + + compare(pixel_values.dims, [1, 4, 640, 960]); + expect(avg(pixel_values.data)).toBeCloseTo(-0.4028555154800415); + expect(pixel_values.data[0]).toBeCloseTo(-0.9921568632125854); + expect(pixel_values.data[1]).toBeCloseTo(-0.9921568632125854); + expect(pixel_values.data[5]).toBeCloseTo(-1.0); + expect(pixel_values.data[640]).toBeCloseTo(-0.6784313917160034); + expect(pixel_values.data[641]).toBeCloseTo(-0.6705882549285889); + expect(pixel_values.data[640 * 960]).toBeCloseTo(-1.0); + expect(pixel_values.data[640 * 960 + 1]).toBeCloseTo(-1.0); + expect(pixel_values.data.at(-1)).toBeCloseTo(0.0); + + compare(original_sizes, [[640, 960]]); + compare(reshaped_input_sizes, [[640, 960]]); + } + + { + const image = await load_image(TEST_IMAGES.pattern_3x5); + const image2 = await load_image(TEST_IMAGES.pattern_3x5); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image, image2); + + compare(pixel_values.dims, [1, 4, 32, 32]); + expect(avg(pixel_values.data)).toBeCloseTo(-0.00867417361587286); + expect(pixel_values.data[0]).toBeCloseTo(-0.9921568632125854); + expect(pixel_values.data[1]).toBeCloseTo(-0.9686274528503418); + expect(pixel_values.data[5]).toBeCloseTo(0.0); + expect(pixel_values.data[32]).toBeCloseTo(-0.9215686321258545); + expect(pixel_values.data[33]).toBeCloseTo(-0.8980392217636108); + expect(pixel_values.data.at(-1)).toBeCloseTo(0.0); + + compare(original_sizes, [[5, 3]]); + compare(reshaped_input_sizes, [[5, 3]]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // BitImageProcessor + it( + MODELS.dinov2, + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.dinov2); + + { + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 224, 224]); + compare(avg(pixel_values.data), 0.06262318789958954); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[224, 224]]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // DPTImageProcessor + // - tests ensure_multiple_of + // - tests keep_aspect_ratio + // - tests bankers rounding + it( + MODELS.dpt_2, + async () => { + const processor = await AutoProcessor.from_pretrained(MODELS.dpt_2); + + { + const image = await load_image(TEST_IMAGES.cats); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 518, 686]); + compare(avg(pixel_values.data), 0.30337387323379517); + + compare(original_sizes, [[480, 640]]); + compare(reshaped_input_sizes, [[518, 686]]); + } + + { + const image = await load_image(TEST_IMAGES.checkerboard_64x32); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + // NOTE: without bankers rounding, this would be [1, 3, 266, 518] + compare(pixel_values.dims, [1, 3, 252, 518]); + compare(avg(pixel_values.data), 0.2267402559518814); + + compare(original_sizes, [[32, 64]]); + compare(reshaped_input_sizes, [[252, 518]]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + // TODO: Add back + // // EfficientNetImageProcessor + // // - tests include_top + // it(MODELS.efficientnet, async () => { + // const processor = await AutoProcessor.from_pretrained(MODELS.efficientnet) + + // { + // const image = await load_image(TEST_IMAGES.cats); + // const { pixel_values, original_sizes, 
reshaped_input_sizes } = await processor(image); + + // compare(pixel_values.dims, [1, 3, 224, 224]); + // compare(avg(pixel_values.data), 0.3015307230282871); + + // compare(original_sizes, [[480, 640]]); + // compare(reshaped_input_sizes, [[224, 224]]); + // } + // }, MAX_TEST_EXECUTION_TIME); + }); + + describe("Audio processors", () => { + const audioPromise = new Promise(async (resolve) => { + const url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/mlk.npy"; + const buffer = await (await fetch(url)).arrayBuffer(); + const audio = Float32Array.from(new Float64Array(buffer)); + resolve(audio); + }); + it( + "WhisperFeatureExtractor", + async () => { + const audio = await audioPromise; + const processor = await AutoProcessor.from_pretrained("Xenova/whisper-tiny.en"); + const { input_features } = await processor(audio); + compare(input_features.dims, [1, 80, 3000]); + expect(avg(input_features.data)).toBeCloseTo(-0.2813588131551941); + expect(input_features.data[0]).toBeCloseTo(0.33168578147888184); + expect(input_features.data[1]).toBeCloseTo(0.30986475944519043); + expect(input_features.data[81]).toBeCloseTo(0.10727232694625854); + expect(input_features.data[3001]).toBeCloseTo(0.2555035352706909); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "ASTFeatureExtractor", + async () => { + const audio = await audioPromise; + const processor = await AutoProcessor.from_pretrained("Xenova/ast-finetuned-audioset-10-10-0.4593"); + { + // truncation + const { input_values } = await processor(audio); + compare(input_values.dims, [1, 1024, 128]); + + expect(avg(input_values.data)).toBeCloseTo(-0.04054912979309085); + expect(input_values.data[0]).toBeCloseTo(-0.5662586092948914); + expect(input_values.data[1]).toBeCloseTo(-1.0300861597061157); + expect(input_values.data[129]).toBeCloseTo(-1.084834098815918); + expect(input_values.data[1025]).toBeCloseTo(-1.1204065084457397); + } + { + // padding + const { input_values } = await processor(audio.slice(0, 1000)); + compare(input_values.dims, [1, 1024, 128]); // [1, 4, 128] -> (padded to) -> [1, 1024, 128] + + expect(avg(input_values.data)).toBeCloseTo(0.4647964835166931); + expect(input_values.data[0]).toBeCloseTo(-0.5662586092948914); + expect(input_values.data[1]).toBeCloseTo(-1.0300861597061157); + expect(input_values.data[129]).toBeCloseTo(-1.084834098815918); + + // padded values + expect(input_values.data[1025]).toBeCloseTo(0.46703237295150757); + expect(input_values.data[2049]).toBeCloseTo(0.46703237295150757); + expect(input_values.data[10000]).toBeCloseTo(0.46703237295150757); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "SeamlessM4TFeatureExtractor", + async () => { + const audio = await audioPromise; + const processor = await AutoProcessor.from_pretrained("Xenova/wav2vec2-bert-CV16-en"); + { + // normal + const { input_features, attention_mask } = await processor(audio); + compare(input_features.dims, [1, 649, 160]); + compare(attention_mask.dims, [1, 649]); + + expect(avg(input_features.data)).toBeCloseTo(-2.938903875815413e-8); + expect(input_features.data[0]).toBeCloseTo(1.1939343214035034); + expect(input_features.data[1]).toBeCloseTo(0.7874255180358887); + expect(input_features.data[160]).toBeCloseTo(-0.712975025177002); + expect(input_features.data[161]).toBeCloseTo(0.045802414417266846); + expect(input_features.data.at(-1)).toBeCloseTo(-1.3328346014022827); + + expect(sum(attention_mask.data)).toEqual(649); + } + { + // padding (pad_to_multiple_of=2) + const { input_features, attention_mask } = await 
processor(audio.slice(0, 10000)); + + // [1, 61, 80] -> [1, 62, 80] -> [1, 31, 160] + compare(input_features.dims, [1, 31, 160]); + compare(attention_mask.dims, [1, 31]); + + expect(avg(input_features.data)).toBeCloseTo(0.01612919569015503); + expect(input_features.data[0]).toBeCloseTo(0.9657132029533386); + expect(input_features.data[1]).toBeCloseTo(0.12912897765636444); + expect(input_features.data[160]).toBeCloseTo(-1.2364212274551392); + expect(input_features.data[161]).toBeCloseTo(-0.9703778028488159); + expect(input_features.data.at(-1)).toBeCloseTo(1); // padding value + + expect(sum(attention_mask.data)).toEqual(30); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "ClapFeatureExtractor", + async () => { + const audio = await audioPromise; + const processor = await AutoProcessor.from_pretrained("Xenova/clap-htsat-unfused"); + { + // truncation + // Since truncation uses a random strategy, we override + // Math.random to ensure that the test is deterministic + const originalRandom = Math.random; + Math.random = () => 0.5; + + let long_audio = new Float32Array(500000); + long_audio.set(audio); + long_audio.set(audio, long_audio.length - audio.length); + + const { input_features } = await processor(long_audio); + compare(input_features.dims, [1, 1, 1001, 64]); + + expect(avg(input_features.data)).toBeCloseTo(-37.94569396972656); + expect(input_features.data[0]).toBeCloseTo(-53.32647705078125); + expect(input_features.data[1]).toBeCloseTo(-47.76755142211914); + expect(input_features.data[65]).toBeCloseTo(-36.32261276245117); + expect(input_features.data[1002]).toBeCloseTo(-28.0314884185791); + expect(input_features.data[10000]).toBeCloseTo(-21.905902862548828); + expect(input_features.data[60000]).toBeCloseTo(-14.877863883972168); + expect(input_features.data[64062]).toBeCloseTo(-37.9784049987793); + expect(input_features.data[64063]).toBeCloseTo(-37.73963928222656); + + // Reset Math.random + Math.random = originalRandom; + } + { + // padding + const { input_features } = await processor(audio); + compare(input_features.dims, [1, 1, 1001, 64]); + + expect(avg(input_features.data)).toBeCloseTo(-34.99049377441406); + expect(input_features.data[0]).toBeCloseTo(-21.32573890686035); + expect(input_features.data[1]).toBeCloseTo(-26.168411254882812); + expect(input_features.data[65]).toBeCloseTo(-29.716018676757812); + expect(input_features.data[1002]).toBeCloseTo(-32.16273498535156); + expect(input_features.data[10000]).toBeCloseTo(-19.9283390045166); + + // padded values + expect(input_features.data[60000]).toBeCloseTo(-100.0); + expect(input_features.data[64062]).toBeCloseTo(-100.0); + expect(input_features.data[64063]).toBeCloseTo(-100.0); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "WeSpeakerFeatureExtractor", + async () => { + const processor = await AutoProcessor.from_pretrained("onnx-community/wespeaker-voxceleb-resnet34-LM"); + { + // default + const audio = new Float32Array(16000).map((_, i) => Math.sin(i / 100)); + const { input_features } = await processor(audio); + compare(input_features.dims, [1, 98, 80]); + + expect(avg(input_features.data)).toBeCloseTo(5.461731689138105e-8); + expect(input_features.data[0]).toBeCloseTo(-0.19300270080566406); + expect(input_features.data[1]).toBeCloseTo(-0.05825042724609375); + expect(input_features.data[78]).toBeCloseTo(0.2683420181274414); + expect(input_features.data[79]).toBeCloseTo(0.26250171661376953); + expect(input_features.data[80]).toBeCloseTo(0.19062232971191406); + 
expect(input_features.data.at(-2)).toBeCloseTo(-0.43694400787353516); + expect(input_features.data.at(-1)).toBeCloseTo(-0.4266204833984375); + } + + { + // pad to `min_num_frames` + const audio = new Float32Array(3).map((_, i) => Math.sin(i / 100)); + const { input_features } = await processor(audio); + compare(input_features.dims, [1, 9, 80]); + + expect(avg(input_features.data)).toBeCloseTo(-0.0000010093053181966146); + expect(input_features.data[0]).toBeCloseTo(20.761859893798828); + expect(input_features.data[1]).toBeCloseTo(21.02924346923828); + expect(input_features.data[78]).toBeCloseTo(19.083993911743164); + expect(input_features.data[79]).toBeCloseTo(18.003454208374023); + expect(input_features.data[80]).toBeCloseTo(-2.595233917236328); + expect(input_features.data.at(-2)).toBeCloseTo(-2.385499954223633); + expect(input_features.data.at(-1)).toBeCloseTo(-2.2504329681396484); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + + describe("Other processors", () => { + describe( + "FlorenceProcessor", + () => { + /** @type {import('../src/processors.js').Florence2Processor} */ + let processor; + let images = {}; + + beforeAll(async () => { + processor = await AutoProcessor.from_pretrained(MODELS.florence2); + images = { + beetle: await load_image(TEST_IMAGES.beetle), + book_cover: await load_image(TEST_IMAGES.book_cover), + }; + }); - // DetrFeatureExtractor - it(MODELS.detr, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.detr) + describe("Prompt construction", () => { + it("Construct prompt", async () => { + const text = ""; + const prompts = processor.construct_prompts(text); + const target = ["Locate the objects with category name in the image."]; + compare(prompts, target); + }); + + it("Construct prompts", async () => { + const texts = ["", "Locate the objects with category name in the image.", "cat"]; + const prompts = processor.construct_prompts(texts); + const target = ["Describe with a paragraph what is shown in the image.", "Locate the objects with category name in the image.", "Locate cat in the image."]; + compare(prompts, target); + }); + }); + describe("Post-process generation", () => { + const TESTS = [ { - const image = await load_image(TEST_IMAGES.tiger); - const { pixel_values, original_sizes, reshaped_input_sizes, pixel_mask } = await processor(image); - - compare(pixel_values.dims, [1, 3, 888, 1333]); - compare(avg(pixel_values.data), -0.27840224131001773); - - compare(original_sizes, [[408, 612]]); - compare(reshaped_input_sizes, [[888, 1333]]); - - compare(pixel_mask.dims, [1, 64, 64]); - compare(avg(pixel_mask.data), 1); - - } - }, MAX_TEST_EXECUTION_TIME); - - - // YolosFeatureExtractor - it(MODELS.yolos, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.yolos) - + task: "", + generated_text: "A green car parked in front of a yellow building.", + target: { "": "A green car parked in front of a yellow building." 
}, + image: "beetle", + }, { - const image = await load_image(TEST_IMAGES.tiger); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - compare(pixel_values.dims, [1, 3, 888, 1333]); - compare(avg(pixel_values.data), -0.27840224131001773); - - compare(original_sizes, [[408, 612]]); - compare(reshaped_input_sizes, [[888, 1333]]); - } - }, MAX_TEST_EXECUTION_TIME); - - // DPTFeatureExtractor - it(MODELS.dpt, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.dpt) - - { // Tests grayscale image - const image = await load_image(TEST_IMAGES.cats); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - compare(pixel_values.dims, [1, 3, 384, 384]); - compare(avg(pixel_values.data), 0.0372855559389454); - - compare(original_sizes, [[480, 640]]); - compare(reshaped_input_sizes, [[384, 384]]); - } - }, MAX_TEST_EXECUTION_TIME); - - // GLPNForDepthEstimation - // - tests `size_divisor` and no size (size_divisor=32) - it(MODELS.glpn, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.glpn) - + task: "", + generated_text: "The image shows a green Volkswagen Beetle parked in front of a yellow building with two brown doors. The sky is a mix of blue and white, and there are a few green trees in the background.", + target: { "": "The image shows a green Volkswagen Beetle parked in front of a yellow building with two brown doors. The sky is a mix of blue and white, and there are a few green trees in the background." }, + image: "beetle", + }, { - const image = await load_image(TEST_IMAGES.cats); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - compare(pixel_values.dims, [1, 3, 480, 640]); - compare(avg(pixel_values.data), 0.5186172404123327); - - compare(original_sizes, [[480, 640]]); - compare(reshaped_input_sizes, [[480, 640]]); - } - - { // Tests input which is not a multiple of 32 ([408, 612] -> [384, 608]) - const image = await load_image(TEST_IMAGES.tiger); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - compare(pixel_values.dims, [1, 3, 384, 608]); - compare(avg(pixel_values.data), 0.38628831535989555); - - compare(original_sizes, [[408, 612]]); - compare(reshaped_input_sizes, [[384, 608]]); - } - }, MAX_TEST_EXECUTION_TIME); - - // NougatImageProcessor - // - tests padding after normalization (image_mean != 0.5, image_std != 0.5) - it(MODELS.nougat, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.nougat) - + task: "", + generated_text: "The image shows a vintage Volkswagen Beetle car parked on a cobblestone street in front of a yellow building with two wooden doors. The car is painted in a bright turquoise color and has a white stripe running along the side. It has two doors on either side of the car, one on top of the other, and a small window on the front. The building appears to be old and dilapidated, with peeling paint and crumbling walls. The sky is blue and there are trees in the background.", + target: { "": "The image shows a vintage Volkswagen Beetle car parked on a cobblestone street in front of a yellow building with two wooden doors. The car is painted in a bright turquoise color and has a white stripe running along the side. It has two doors on either side of the car, one on top of the other, and a small window on the front. The building appears to be old and dilapidated, with peeling paint and crumbling walls. 
The sky is blue and there are trees in the background." }, + image: "beetle", + }, { - const image = await load_image(TEST_IMAGES.paper); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - compare(pixel_values.dims, [1, 3, 896, 672]); - compare(avg(pixel_values.data), 1.8447155005897355); - - compare(original_sizes, [[850, 685]]); - compare(reshaped_input_sizes, [[833, 672]]); - } - }, MAX_TEST_EXECUTION_TIME); - - // OwlViTFeatureExtractor - it(MODELS.owlvit, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.owlvit) + task: "", + generated_text: "cardoorwheel", + target: { + "": { + bboxes: [ + [34.24, 160.08, 597.44, 371.76], + [456.0, 97.68, 580.16, 261.84], + [450.88, 276.72, 554.56, 370.8], + [95.68, 280.56, 198.72, 371.28], + ], + labels: ["car", "door", "wheel", "wheel"], + }, + }, + image: "beetle", + }, { - const image = await load_image(TEST_IMAGES.cats); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - compare(pixel_values.dims, [1, 3, 768, 768]); - compare(avg(pixel_values.data), 0.250620447910435); - - compare(original_sizes, [[480, 640]]); - compare(reshaped_input_sizes, [[768, 768]]); - } - }); - - // CLIPFeatureExtractor - // - tests center crop (do_center_crop=true, crop_size=224) - it(MODELS.clip, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.clip) - + task: "", + generated_text: "turquoise Volkswagen Beetlewheel", + target: { + "": { + bboxes: [ + [33.6, 160.08, 596.8, 371.76], + [450.88, 276.72, 553.28, 370.8], + [95.04, 280.56, 197.44, 371.28], + ], + labels: ["turquoise Volkswagen Beetle", "wheel", "wheel"], + }, + }, + image: "beetle", + }, { - const image = await load_image(TEST_IMAGES.tiger); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - compare(pixel_values.dims, [1, 3, 224, 224]); - compare(avg(pixel_values.data), -0.06678297738282096); - - compare(original_sizes, [[408, 612]]); - compare(reshaped_input_sizes, [[224, 224]]); - } - }, MAX_TEST_EXECUTION_TIME); - - // VitMatteImageProcessor - // - tests custom overrides - // - tests multiple inputs - // - tests `size_divisibility` and no size (size_divisibility=32) - // - tests do_pad and `size_divisibility` - it(MODELS.vitmatte, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.vitmatte) - + task: "", + generated_text: "", + target: { + "": { + bboxes: [ + [33.6, 160.08, 596.8, 371.76], + [455.36, 97.68, 579.52, 261.84], + [450.88, 276.72, 553.28, 370.8], + [95.04, 280.56, 198.08, 371.28], + [226.88, 88.56, 332.48, 164.4], + [65.6, 266.64, 86.72, 295.92], + [271.68, 241.68, 302.4, 246.96], + [408.0, 308.4, 413.76, 320.88], + ], + labels: ["", "", "", "", "", "", "", ""], + }, + }, + image: "beetle", + }, { - const image = await load_image(TEST_IMAGES.vitmatte_image); - const image2 = await load_image(TEST_IMAGES.vitmatte_trimap); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image, image2); - - compare(pixel_values.dims, [1, 4, 640, 960]); - expect(avg(pixel_values.data)).toBeCloseTo(-0.4028555154800415); - expect(pixel_values.data[0]).toBeCloseTo(-0.9921568632125854); - expect(pixel_values.data[1]).toBeCloseTo(-0.9921568632125854); - expect(pixel_values.data[5]).toBeCloseTo(-1.0); - expect(pixel_values.data[640]).toBeCloseTo(-0.6784313917160034); - expect(pixel_values.data[641]).toBeCloseTo(-0.6705882549285889); - expect(pixel_values.data[640 * 
960]).toBeCloseTo(-1.0); - expect(pixel_values.data[640 * 960 + 1]).toBeCloseTo(-1.0); - expect(pixel_values.data.at(-1)).toBeCloseTo(0.0); - - compare(original_sizes, [[640, 960]]); - compare(reshaped_input_sizes, [[640, 960]]); - } - - + task: "", + text_input: "A green car parked in front of a yellow building.", + generated_text: "A green cara yellow building", + target: { + "": { + bboxes: [ + [34.88, 158.64, 583.36, 374.64], + [0.32, 4.08, 639.04, 305.04], + ], + labels: ["A green car", "a yellow building"], + }, + }, + image: "beetle", + }, + // { + // task: "", + // text_input: "a green car", + // generated_text: "", + // target: { + // '': { + // polygons: [[[[178.88, 181.68, 180.8, 180.72, 182.72, 180.72, 187.84, 177.84, 189.76, 177.84, 192.96, 175.92, 194.88, 175.92, 198.08, 174, 200.64, 173.04, 203.84, 172.08, 207.04, 170.64, 209.6, 169.68, 214.08, 168.72, 217.92, 167.76, 221.76, 166.8, 226.24, 165.84, 230.72, 164.88, 237.12, 163.92, 244.16, 162.96, 253.12, 162, 265.28, 161.04, 311.36, 161.04, 329.28, 162, 338.24, 162.96, 345.28, 163.92, 350.4, 164.88, 354.24, 165.84, 358.72, 166.8, 362.56, 167.76, 366.4, 168.72, 370.24, 169.68, 373.44, 170.64, 375.36, 172.08, 377.28, 174, 379.2, 176.88, 380.48, 179.76, 382.4, 181.68, 384.32, 185.04, 386.24, 187.92, 387.52, 190.8, 389.44, 192.72, 390.08, 196.08, 392, 198.96, 394.56, 201.84, 396.48, 204.72, 398.4, 208.08, 403.52, 212.88, 406.08, 213.84, 409.28, 216.72, 412.48, 220.08, 431.68, 220.08, 432.32, 221.04, 442.56, 222, 456.64, 222, 465.6, 222.96, 472.64, 223.92, 478.4, 224.88, 484.8, 225.84, 489.92, 226.8, 493.76, 227.76, 497.6, 228.72, 501.44, 229.68, 504.64, 231.12, 507.84, 232.08, 510.4, 233.04, 513.6, 234, 516.8, 235.92, 518.72, 235.92, 523.84, 238.8, 525.76, 238.8, 527.68, 239.76, 529.6, 241.68, 532.8, 242.64, 536, 245.04, 538.56, 247.92, 541.76, 249.84, 545.6, 251.76, 548.8, 252.72, 550.72, 252.72, 553.92, 253.68, 556.48, 255.6, 558.4, 255.6, 564.8, 258.96, 566.72, 260.88, 568.64, 260.88, 570.56, 261.84, 572.48, 263.76, 573.76, 265.68, 574.4, 268.56, 574.4, 271.92, 573.76, 272.88, 572.48, 275.76, 572.48, 279.6, 573.76, 285.84, 574.4, 286.8, 575.68, 289.68, 576.32, 292.56, 577.6, 298.8, 577.6, 301.68, 576.32, 302.64, 575.68, 310.8, 575.68, 312.72, 576.32, 313.68, 577.6, 316.56, 577.6, 320.88, 574.4, 321.84, 568.64, 322.8, 559.68, 322.8, 553.92, 323.76, 552.64, 332.88, 552, 336.72, 550.72, 339.6, 550.08, 342.96, 548.8, 344.88, 546.88, 346.8, 545.6, 349.68, 543.68, 352.56, 541.76, 355.92, 534.72, 362.64, 531.52, 364.56, 525.76, 367.92, 522.56, 368.88, 518.72, 369.84, 495.68, 369.84, 489.92, 368.88, 486.72, 367.92, 483.52, 366.96, 479.68, 364.56, 476.48, 362.64, 472.64, 359.76, 465.6, 352.56, 463.68, 349.68, 461.76, 346.8, 460.48, 344.88, 460.48, 342.96, 458.56, 339.6, 457.92, 336.72, 457.92, 334.8, 456.64, 332.88, 454.72, 330.96, 452.8, 331.92, 448.32, 336.72, 446.4, 337.68, 426.56, 336.72, 424.64, 336.72, 423.36, 337.68, 420.8, 338.64, 414.4, 339.6, 412.48, 339.6, 411.2, 338.64, 380.48, 337.68, 217.28, 337.68, 216, 338.64, 210.88, 339.6, 207.04, 339.6, 203.84, 338.64, 201.92, 337.68, 200, 335.76, 198.08, 334.8, 194.88, 334.8, 192.96, 336.72, 191.68, 338.64, 191.68, 340.56, 191.04, 342.96, 189.12, 344.88, 187.84, 347.76, 185.92, 349.68, 184.64, 352.56, 182.72, 355.92, 176.96, 361.68, 173.76, 363.6, 170.56, 365.52, 166.72, 367.92, 163.52, 368.88, 160.96, 369.84, 153.92, 370.8, 131.52, 370.8, 127.68, 369.84, 124.48, 368.88, 118.72, 365.52, 115.52, 363.6, 111.68, 360.72, 106.56, 355.92, 104.64, 352.56, 103.36, 349.68, 101.44, 347.76, 
100.8, 345.84, 99.52, 342.96, 99.52, 339.6, 98.88, 337.68, 95.68, 334.8, 93.76, 333.84, 86.72, 333.84, 80.32, 334.8, 79.68, 335.76, 74.56, 336.72, 66.24, 336.72, 63.68, 334.8, 53.44, 334.8, 50.24, 333.84, 48.32, 331.92, 48.32, 328.56, 50.24, 326.64, 51.52, 324.72, 51.52, 322.8, 44.48, 321.84, 40.64, 320.88, 38.72, 319.92, 37.44, 317.52, 36.16, 313.68, 36.16, 306.96, 38.72, 304.56, 42.56, 303.6, 46.4, 302.64, 55.36, 301.68, 65.6, 301.68, 67.52, 300.72, 69.44, 298.8, 70.72, 296.88, 70.72, 292.56, 69.44, 291.6, 68.8, 288.72, 67.52, 284.88, 67.52, 276.72, 68.8, 273.84, 69.44, 271.92, 72.64, 268.56, 74.56, 267.6, 77.76, 266.64, 79.68, 266.64, 81.6, 264.72, 80.32, 260.88, 81.6, 258.96, 83.52, 256.56, 88.64, 256.56, 90.56, 255.6, 92.48, 253.68, 92.48, 252.72, 97.6, 246.96, 114.88, 229.68, 117.44, 226.8, 122.56, 222.96, 125.76, 221.04, 126.4, 221.04, 129.6, 219.12, 133.44, 215.76, 138.56, 211.92, 143.68, 208.08, 149.44, 201.84, 153.92, 198.96, 154.56, 198.96, 157.76, 197.04, 162.88, 192.72, 168.64, 186.96, 171.84, 185.04, 176.96, 183.12, 178.88, 180.72]]]], + // labels: [''], + // } + // }, + // image: 'beetle', + // }, + // { + // task: "", + // text_input: "", + // generated_text: "", + // target: { + // '': { + // polygons: [[[[470.08, 288.24, 473.92, 285.36, 477.12, 283.44, 479.04, 282.48, 480.96, 282.48, 484.16, 280.56, 486.72, 279.6, 489.92, 278.64, 495.04, 277.68, 512.32, 277.68, 514.88, 278.64, 518.08, 279.6, 521.28, 281.52, 523.2, 281.52, 525.12, 283.44, 528.32, 284.4, 530.88, 286.32, 534.08, 288.24, 543.04, 297.36, 544.96, 300.24, 546.88, 303.12, 550.08, 309.36, 551.36, 312.24, 552, 315.12, 553.28, 319.44, 553.28, 332.4, 552, 337.2, 551.36, 340.08, 550.08, 343.44, 548.16, 347.28, 546.24, 350.16, 544.32, 353.04, 541.12, 357.36, 537.28, 361.2, 532.16, 365.04, 528.96, 366.96, 527.04, 367.92, 523.84, 368.88, 521.28, 369.84, 516.16, 371.28, 500.8, 371.28, 491.84, 369.84, 488, 368.88, 484.8, 367.92, 479.04, 365.04, 475.84, 363.12, 472, 360.24, 464.96, 353.04, 463.04, 350.16, 461.12, 347.28, 459.84, 345.36, 459.84, 343.44, 457.92, 340.08, 456.64, 337.2, 456, 334.32, 454.72, 330.48, 454.72, 316.08, 456, 311.28, 456.64, 307.44, 457.92, 304.08, 459.84, 301.2, 459.84, 299.28, 461.12, 297.36, 463.04, 294.48]]]], + // labels: [''], + // } + // }, + // image: 'beetle', + // }, + // { + // task: "", + // text_input: "a green car", + // generated_text: "a green car", + // target: { + // '': { + // bboxes: [[34.24, 158.64, 582.72, 374.16]], + // bboxes_labels: ['a green car'], + // polygons: [], + // polygons_labels: [], + // } + // }, + // image: 'beetle', + // }, { - const image = await load_image(TEST_IMAGES.pattern_3x5); - const image2 = await load_image(TEST_IMAGES.pattern_3x5); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image, image2); - - compare(pixel_values.dims, [1, 4, 32, 32]); - expect(avg(pixel_values.data)).toBeCloseTo(-0.00867417361587286); - expect(pixel_values.data[0]).toBeCloseTo(-0.9921568632125854); - expect(pixel_values.data[1]).toBeCloseTo(-0.9686274528503418); - expect(pixel_values.data[5]).toBeCloseTo(0.0); - expect(pixel_values.data[32]).toBeCloseTo(-0.9215686321258545); - expect(pixel_values.data[33]).toBeCloseTo(-0.8980392217636108); - expect(pixel_values.data.at(-1)).toBeCloseTo(0.0); - - compare(original_sizes, [[5, 3]]); - compare(reshaped_input_sizes, [[5, 3]]); - } - }, MAX_TEST_EXECUTION_TIME); - - // BitImageProcessor - it(MODELS.dinov2, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.dinov2) - + task: 
"", + text_input: "", + generated_text: "car", + target: { "": "car" }, + image: "beetle", + }, { - const image = await load_image(TEST_IMAGES.tiger); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - compare(pixel_values.dims, [1, 3, 224, 224]); - compare(avg(pixel_values.data), 0.06262318789958954); - - compare(original_sizes, [[408, 612]]); - compare(reshaped_input_sizes, [[224, 224]]); - } - }, MAX_TEST_EXECUTION_TIME); - - // DPTImageProcessor - // - tests ensure_multiple_of - // - tests keep_aspect_ratio - // - tests bankers rounding - it(MODELS.dpt_2, async () => { - const processor = await AutoProcessor.from_pretrained(MODELS.dpt_2) - + task: "", + text_input: "", + generated_text: "turquoise Volkswagen Beetle", + target: { "": "turquoise Volkswagen Beetle" }, + image: "beetle", + }, { - const image = await load_image(TEST_IMAGES.cats); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - compare(pixel_values.dims, [1, 3, 518, 686]); - compare(avg(pixel_values.data), 0.30337387323379517); - - compare(original_sizes, [[480, 640]]); - compare(reshaped_input_sizes, [[518, 686]]); - } - + task: "", + generated_text: "CUDAFOR ENGINEERSAn Introduction to High-PerformanceParallel ComputingDUANE STORTIMETE YURTOGLU", + target: { "": "CUDAFOR ENGINEERSAn Introduction to High-PerformanceParallel ComputingDUANE STORTIMETE YURTOGLU" }, + image: "book_cover", + }, { - const image = await load_image(TEST_IMAGES.checkerboard_64x32); - const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - // NOTE: without bankers rounding, this would be [1, 3, 266, 518] - compare(pixel_values.dims, [1, 3, 252, 518]); - compare(avg(pixel_values.data), 0.2267402559518814); - - compare(original_sizes, [[32, 64]]); - compare(reshaped_input_sizes, [[252, 518]]); - } - }, MAX_TEST_EXECUTION_TIME); - - // TODO: Add back - // // EfficientNetImageProcessor - // // - tests include_top - // it(MODELS.efficientnet, async () => { - // const processor = await AutoProcessor.from_pretrained(MODELS.efficientnet) - - // { - // const image = await load_image(TEST_IMAGES.cats); - // const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); - - // compare(pixel_values.dims, [1, 3, 224, 224]); - // compare(avg(pixel_values.data), 0.3015307230282871); - - // compare(original_sizes, [[480, 640]]); - // compare(reshaped_input_sizes, [[224, 224]]); - // } - // }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Audio processors', () => { - const audioPromise = new Promise(async (resolve) => { - const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/mlk.npy'; - const buffer = await (await fetch(url)).arrayBuffer(); - const audio = Float32Array.from(new Float64Array(buffer)); - resolve(audio); - }); - - it('WhisperFeatureExtractor', async () => { - const audio = await audioPromise; - const processor = await AutoProcessor.from_pretrained('Xenova/whisper-tiny.en'); - const { input_features } = await processor(audio); - compare(input_features.dims, [1, 80, 3000]); - expect(avg(input_features.data)).toBeCloseTo(-0.2813588131551941); - expect(input_features.data[0]).toBeCloseTo(0.33168578147888184); - expect(input_features.data[1]).toBeCloseTo(0.30986475944519043); - expect(input_features.data[81]).toBeCloseTo(0.10727232694625854); - expect(input_features.data[3001]).toBeCloseTo(0.2555035352706909); - }, MAX_TEST_EXECUTION_TIME); - - it('ASTFeatureExtractor', async () => { 
- const audio = await audioPromise; - const processor = await AutoProcessor.from_pretrained('Xenova/ast-finetuned-audioset-10-10-0.4593'); - { // truncation - const { input_values } = await processor(audio); - compare(input_values.dims, [1, 1024, 128]); - - expect(avg(input_values.data)).toBeCloseTo(-0.04054912979309085); - expect(input_values.data[0]).toBeCloseTo(-0.5662586092948914); - expect(input_values.data[1]).toBeCloseTo(-1.0300861597061157); - expect(input_values.data[129]).toBeCloseTo(-1.084834098815918); - expect(input_values.data[1025]).toBeCloseTo(-1.1204065084457397); - } - { // padding - const { input_values } = await processor(audio.slice(0, 1000)); - compare(input_values.dims, [1, 1024, 128]); // [1, 4, 128] -> (padded to) -> [1, 1024, 128] - - expect(avg(input_values.data)).toBeCloseTo(0.4647964835166931); - expect(input_values.data[0]).toBeCloseTo(-0.5662586092948914); - expect(input_values.data[1]).toBeCloseTo(-1.0300861597061157); - expect(input_values.data[129]).toBeCloseTo(-1.084834098815918); - - // padded values - expect(input_values.data[1025]).toBeCloseTo(0.46703237295150757); - expect(input_values.data[2049]).toBeCloseTo(0.46703237295150757); - expect(input_values.data[10000]).toBeCloseTo(0.46703237295150757); - } - }, MAX_TEST_EXECUTION_TIME); - - it('SeamlessM4TFeatureExtractor', async () => { - const audio = await audioPromise; - const processor = await AutoProcessor.from_pretrained('Xenova/wav2vec2-bert-CV16-en'); - { // normal - const { input_features, attention_mask } = await processor(audio); - compare(input_features.dims, [1, 649, 160]); - compare(attention_mask.dims, [1, 649]); - - expect(avg(input_features.data)).toBeCloseTo(-2.938903875815413e-08); - expect(input_features.data[0]).toBeCloseTo(1.1939343214035034); - expect(input_features.data[1]).toBeCloseTo(0.7874255180358887); - expect(input_features.data[160]).toBeCloseTo(-0.712975025177002); - expect(input_features.data[161]).toBeCloseTo(0.045802414417266846); - expect(input_features.data.at(-1)).toBeCloseTo(-1.3328346014022827); - - expect(sum(attention_mask.data)).toEqual(649); - } - { // padding (pad_to_multiple_of=2) - const { input_features, attention_mask } = await processor(audio.slice(0, 10000)); - - // [1, 61, 80] -> [1, 62, 80] -> [1, 31, 160] - compare(input_features.dims, [1, 31, 160]); - compare(attention_mask.dims, [1, 31]); - - expect(avg(input_features.data)).toBeCloseTo(0.01612919569015503); - expect(input_features.data[0]).toBeCloseTo(0.9657132029533386); - expect(input_features.data[1]).toBeCloseTo(0.12912897765636444); - expect(input_features.data[160]).toBeCloseTo(-1.2364212274551392); - expect(input_features.data[161]).toBeCloseTo(-0.9703778028488159); - expect(input_features.data.at(-1)).toBeCloseTo(1); // padding value - - expect(sum(attention_mask.data)).toEqual(30); - } - }, MAX_TEST_EXECUTION_TIME); - - it('ClapFeatureExtractor', async () => { - const audio = await audioPromise; - const processor = await AutoProcessor.from_pretrained('Xenova/clap-htsat-unfused'); - { // truncation - // Since truncation uses a random strategy, we override - // Math.random to ensure that the test is deterministic - const originalRandom = Math.random; - Math.random = () => 0.5; - - let long_audio = new Float32Array(500000); - long_audio.set(audio); - long_audio.set(audio, long_audio.length - audio.length); - - const { input_features } = await processor(long_audio); - compare(input_features.dims, [1, 1, 1001, 64]); - - expect(avg(input_features.data)).toBeCloseTo(-37.94569396972656); - 
expect(input_features.data[0]).toBeCloseTo(-53.32647705078125); - expect(input_features.data[1]).toBeCloseTo(-47.76755142211914); - expect(input_features.data[65]).toBeCloseTo(-36.32261276245117); - expect(input_features.data[1002]).toBeCloseTo(-28.0314884185791); - expect(input_features.data[10000]).toBeCloseTo(-21.905902862548828); - expect(input_features.data[60000]).toBeCloseTo(-14.877863883972168); - expect(input_features.data[64062]).toBeCloseTo(-37.9784049987793); - expect(input_features.data[64063]).toBeCloseTo(-37.73963928222656); - - // Reset Math.random - Math.random = originalRandom; - } - { // padding - const { input_features } = await processor(audio); - compare(input_features.dims, [1, 1, 1001, 64]); - - expect(avg(input_features.data)).toBeCloseTo(-34.99049377441406); - expect(input_features.data[0]).toBeCloseTo(-21.32573890686035); - expect(input_features.data[1]).toBeCloseTo(-26.168411254882812); - expect(input_features.data[65]).toBeCloseTo(-29.716018676757812); - expect(input_features.data[1002]).toBeCloseTo(-32.16273498535156); - expect(input_features.data[10000]).toBeCloseTo(-19.9283390045166); - - // padded values - expect(input_features.data[60000]).toBeCloseTo(-100.0); - expect(input_features.data[64062]).toBeCloseTo(-100.0); - expect(input_features.data[64063]).toBeCloseTo(-100.0); - } - - - }, MAX_TEST_EXECUTION_TIME); - - it('WeSpeakerFeatureExtractor', async () => { - - const processor = await AutoProcessor.from_pretrained('onnx-community/wespeaker-voxceleb-resnet34-LM'); - { // default - const audio = new Float32Array(16000).map((_, i) => Math.sin(i / 100)); - const { input_features } = await processor(audio); - compare(input_features.dims, [1, 98, 80]); - - expect(avg(input_features.data)).toBeCloseTo(5.461731689138105e-08); - expect(input_features.data[0]).toBeCloseTo(-0.19300270080566406); - expect(input_features.data[1]).toBeCloseTo(-0.05825042724609375); - expect(input_features.data[78]).toBeCloseTo(0.2683420181274414); - expect(input_features.data[79]).toBeCloseTo(0.26250171661376953); - expect(input_features.data[80]).toBeCloseTo(0.19062232971191406); - expect(input_features.data.at(-2)).toBeCloseTo(-0.43694400787353516); - expect(input_features.data.at(-1)).toBeCloseTo(-0.4266204833984375); - } - - { // pad to `min_num_frames` - const audio = new Float32Array(3).map((_, i) => Math.sin(i / 100)); - const { input_features } = await processor(audio); - compare(input_features.dims, [1, 9, 80]); - - expect(avg(input_features.data)).toBeCloseTo(-0.0000010093053181966146); - expect(input_features.data[0]).toBeCloseTo(20.761859893798828); - expect(input_features.data[1]).toBeCloseTo(21.02924346923828); - expect(input_features.data[78]).toBeCloseTo(19.083993911743164); - expect(input_features.data[79]).toBeCloseTo(18.003454208374023); - expect(input_features.data[80]).toBeCloseTo(-2.595233917236328); - expect(input_features.data.at(-2)).toBeCloseTo(-2.385499954223633); - expect(input_features.data.at(-1)).toBeCloseTo(-2.2504329681396484); - } - - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Other processors', () => { - - describe('FlorenceProcessor', () => { - /** @type {import('../src/processors.js').Florence2Processor} */ - let processor; - let images = {}; - - beforeAll(async () => { - processor = await AutoProcessor.from_pretrained(MODELS.florence2); - images = { - beetle: await load_image(TEST_IMAGES.beetle), - book_cover: await load_image(TEST_IMAGES.book_cover), - }; + task: "", + generated_text: "CUDAFOR ENGINEERSAn Introduction to 
High-PerformanceParallel ComputingDUANE STORTIMETE YURTOGLU", + target: { + "": { + quad_boxes: [ + [167.0435028076172, 50.25, 375.7974853515625, 50.25, 375.7974853515625, 114.75, 167.0435028076172, 114.75], + [144.8784942626953, 120.75, 375.7974853515625, 120.75, 375.7974853515625, 149.25, 144.8784942626953, 149.25], + [115.86249542236328, 165.25, 376.6034851074219, 166.25, 376.6034851074219, 184.25, 115.86249542236328, 183.25], + [239.9864959716797, 184.25, 376.6034851074219, 186.25, 376.6034851074219, 204.25, 239.9864959716797, 202.25], + [266.1814880371094, 441.25, 376.6034851074219, 441.25, 376.6034851074219, 456.25, 266.1814880371094, 456.25], + [252.0764923095703, 460.25, 376.6034851074219, 460.25, 376.6034851074219, 475.25, 252.0764923095703, 475.25], + ], + + // NOTE: Python version has a bug here, it should be "CUDA" instead of "CUDA" + labels: [/* 'CUDA' */ "CUDA", "FOR ENGINEERS", "An Introduction to High-Performance", "Parallel Computing", "DUANE STORTI", "METE YURTOGLU"], + }, + }, + image: "book_cover", + }, + ]; + + for (const { task, generated_text, target, image } of TESTS) { + it(task, () => { + const result = processor.post_process_generation(generated_text, task, images[image].size); + compare(result, target); }); - - describe('Prompt construction', () => { - it('Construct prompt', async () => { - const text = ""; - const prompts = processor.construct_prompts(text); - const target = [ - 'Locate the objects with category name in the image.' - ] - compare(prompts, target); - }); - - it('Construct prompts', async () => { - const texts = [ - "", - "Locate the objects with category name in the image.", - "cat" - ]; - const prompts = processor.construct_prompts(texts); - const target = [ - 'Describe with a paragraph what is shown in the image.', - 'Locate the objects with category name in the image.', - 'Locate cat in the image.' - ] - compare(prompts, target); - }); - }); - - describe('Post-process generation', () => { - const TESTS = [ - { - task: "", - generated_text: "A green car parked in front of a yellow building.", - target: { '': 'A green car parked in front of a yellow building.' }, - image: 'beetle', - }, - { - task: "", - generated_text: "The image shows a green Volkswagen Beetle parked in front of a yellow building with two brown doors. The sky is a mix of blue and white, and there are a few green trees in the background.", - target: { '': 'The image shows a green Volkswagen Beetle parked in front of a yellow building with two brown doors. The sky is a mix of blue and white, and there are a few green trees in the background.' }, - image: 'beetle', - }, - { - task: "", - generated_text: "The image shows a vintage Volkswagen Beetle car parked on a cobblestone street in front of a yellow building with two wooden doors. The car is painted in a bright turquoise color and has a white stripe running along the side. It has two doors on either side of the car, one on top of the other, and a small window on the front. The building appears to be old and dilapidated, with peeling paint and crumbling walls. The sky is blue and there are trees in the background.", - target: { '': 'The image shows a vintage Volkswagen Beetle car parked on a cobblestone street in front of a yellow building with two wooden doors. The car is painted in a bright turquoise color and has a white stripe running along the side. It has two doors on either side of the car, one on top of the other, and a small window on the front. 
The building appears to be old and dilapidated, with peeling paint and crumbling walls. The sky is blue and there are trees in the background.' }, - image: 'beetle', - }, - { - task: "", - generated_text: "cardoorwheel", - target: { - '': { - bboxes: [ - [34.24, 160.08, 597.44, 371.76], - [456.00, 97.68, 580.16, 261.84], - [450.88, 276.72, 554.56, 370.80], - [95.68, 280.56, 198.72, 371.28], - ], - labels: ['car', 'door', 'wheel', 'wheel'], - } - }, - image: 'beetle', - }, - { - task: "", - generated_text: "turquoise Volkswagen Beetlewheel", - target: { - '': { - bboxes: [ - [33.60, 160.08, 596.80, 371.76], - [450.88, 276.72, 553.28, 370.80], - [95.04, 280.56, 197.44, 371.28], - ], - labels: ['turquoise Volkswagen Beetle', 'wheel', 'wheel'] - } - }, - image: 'beetle', - }, - { - task: "", - generated_text: "", - target: { - '': { - bboxes: [ - [33.60, 160.08, 596.80, 371.76], - [455.36, 97.68, 579.52, 261.84], - [450.88, 276.72, 553.28, 370.80], - [95.04, 280.56, 198.08, 371.28], - [226.88, 88.56, 332.48, 164.40], - [65.60, 266.64, 86.72, 295.92], - [271.68, 241.68, 302.40, 246.96], - [408.00, 308.40, 413.76, 320.88] - ], - labels: ['', '', '', '', '', '', '', ''] - } - }, - image: 'beetle', - }, - { - task: "", - text_input: "A green car parked in front of a yellow building.", - generated_text: "A green cara yellow building", - target: { - '': { - bboxes: [ - [34.88, 158.64, 583.36, 374.64], - [0.32, 4.08, 639.04, 305.04], - ], - labels: ['A green car', 'a yellow building'] - } - }, - image: 'beetle', - }, - // { - // task: "", - // text_input: "a green car", - // generated_text: "", - // target: { - // '': { - // polygons: [[[[178.88, 181.68, 180.8, 180.72, 182.72, 180.72, 187.84, 177.84, 189.76, 177.84, 192.96, 175.92, 194.88, 175.92, 198.08, 174, 200.64, 173.04, 203.84, 172.08, 207.04, 170.64, 209.6, 169.68, 214.08, 168.72, 217.92, 167.76, 221.76, 166.8, 226.24, 165.84, 230.72, 164.88, 237.12, 163.92, 244.16, 162.96, 253.12, 162, 265.28, 161.04, 311.36, 161.04, 329.28, 162, 338.24, 162.96, 345.28, 163.92, 350.4, 164.88, 354.24, 165.84, 358.72, 166.8, 362.56, 167.76, 366.4, 168.72, 370.24, 169.68, 373.44, 170.64, 375.36, 172.08, 377.28, 174, 379.2, 176.88, 380.48, 179.76, 382.4, 181.68, 384.32, 185.04, 386.24, 187.92, 387.52, 190.8, 389.44, 192.72, 390.08, 196.08, 392, 198.96, 394.56, 201.84, 396.48, 204.72, 398.4, 208.08, 403.52, 212.88, 406.08, 213.84, 409.28, 216.72, 412.48, 220.08, 431.68, 220.08, 432.32, 221.04, 442.56, 222, 456.64, 222, 465.6, 222.96, 472.64, 223.92, 478.4, 224.88, 484.8, 225.84, 489.92, 226.8, 493.76, 227.76, 497.6, 228.72, 501.44, 229.68, 504.64, 231.12, 507.84, 232.08, 510.4, 233.04, 513.6, 234, 516.8, 235.92, 518.72, 235.92, 523.84, 238.8, 525.76, 238.8, 527.68, 239.76, 529.6, 241.68, 532.8, 242.64, 536, 245.04, 538.56, 247.92, 541.76, 249.84, 545.6, 251.76, 548.8, 252.72, 550.72, 252.72, 553.92, 253.68, 556.48, 255.6, 558.4, 255.6, 564.8, 258.96, 566.72, 260.88, 568.64, 260.88, 570.56, 261.84, 572.48, 263.76, 573.76, 265.68, 574.4, 268.56, 574.4, 271.92, 573.76, 272.88, 572.48, 275.76, 572.48, 279.6, 573.76, 285.84, 574.4, 286.8, 575.68, 289.68, 576.32, 292.56, 577.6, 298.8, 577.6, 301.68, 576.32, 302.64, 575.68, 310.8, 575.68, 312.72, 576.32, 313.68, 577.6, 316.56, 577.6, 320.88, 574.4, 321.84, 568.64, 322.8, 559.68, 322.8, 553.92, 323.76, 552.64, 332.88, 552, 336.72, 550.72, 339.6, 550.08, 342.96, 548.8, 344.88, 546.88, 346.8, 545.6, 349.68, 543.68, 352.56, 541.76, 355.92, 534.72, 362.64, 531.52, 364.56, 525.76, 367.92, 522.56, 368.88, 518.72, 369.84, 
495.68, 369.84, 489.92, 368.88, 486.72, 367.92, 483.52, 366.96, 479.68, 364.56, 476.48, 362.64, 472.64, 359.76, 465.6, 352.56, 463.68, 349.68, 461.76, 346.8, 460.48, 344.88, 460.48, 342.96, 458.56, 339.6, 457.92, 336.72, 457.92, 334.8, 456.64, 332.88, 454.72, 330.96, 452.8, 331.92, 448.32, 336.72, 446.4, 337.68, 426.56, 336.72, 424.64, 336.72, 423.36, 337.68, 420.8, 338.64, 414.4, 339.6, 412.48, 339.6, 411.2, 338.64, 380.48, 337.68, 217.28, 337.68, 216, 338.64, 210.88, 339.6, 207.04, 339.6, 203.84, 338.64, 201.92, 337.68, 200, 335.76, 198.08, 334.8, 194.88, 334.8, 192.96, 336.72, 191.68, 338.64, 191.68, 340.56, 191.04, 342.96, 189.12, 344.88, 187.84, 347.76, 185.92, 349.68, 184.64, 352.56, 182.72, 355.92, 176.96, 361.68, 173.76, 363.6, 170.56, 365.52, 166.72, 367.92, 163.52, 368.88, 160.96, 369.84, 153.92, 370.8, 131.52, 370.8, 127.68, 369.84, 124.48, 368.88, 118.72, 365.52, 115.52, 363.6, 111.68, 360.72, 106.56, 355.92, 104.64, 352.56, 103.36, 349.68, 101.44, 347.76, 100.8, 345.84, 99.52, 342.96, 99.52, 339.6, 98.88, 337.68, 95.68, 334.8, 93.76, 333.84, 86.72, 333.84, 80.32, 334.8, 79.68, 335.76, 74.56, 336.72, 66.24, 336.72, 63.68, 334.8, 53.44, 334.8, 50.24, 333.84, 48.32, 331.92, 48.32, 328.56, 50.24, 326.64, 51.52, 324.72, 51.52, 322.8, 44.48, 321.84, 40.64, 320.88, 38.72, 319.92, 37.44, 317.52, 36.16, 313.68, 36.16, 306.96, 38.72, 304.56, 42.56, 303.6, 46.4, 302.64, 55.36, 301.68, 65.6, 301.68, 67.52, 300.72, 69.44, 298.8, 70.72, 296.88, 70.72, 292.56, 69.44, 291.6, 68.8, 288.72, 67.52, 284.88, 67.52, 276.72, 68.8, 273.84, 69.44, 271.92, 72.64, 268.56, 74.56, 267.6, 77.76, 266.64, 79.68, 266.64, 81.6, 264.72, 80.32, 260.88, 81.6, 258.96, 83.52, 256.56, 88.64, 256.56, 90.56, 255.6, 92.48, 253.68, 92.48, 252.72, 97.6, 246.96, 114.88, 229.68, 117.44, 226.8, 122.56, 222.96, 125.76, 221.04, 126.4, 221.04, 129.6, 219.12, 133.44, 215.76, 138.56, 211.92, 143.68, 208.08, 149.44, 201.84, 153.92, 198.96, 154.56, 198.96, 157.76, 197.04, 162.88, 192.72, 168.64, 186.96, 171.84, 185.04, 176.96, 183.12, 178.88, 180.72]]]], - // labels: [''], - // } - // }, - // image: 'beetle', - // }, - // { - // task: "", - // text_input: "", - // generated_text: "", - // target: { - // '': { - // polygons: [[[[470.08, 288.24, 473.92, 285.36, 477.12, 283.44, 479.04, 282.48, 480.96, 282.48, 484.16, 280.56, 486.72, 279.6, 489.92, 278.64, 495.04, 277.68, 512.32, 277.68, 514.88, 278.64, 518.08, 279.6, 521.28, 281.52, 523.2, 281.52, 525.12, 283.44, 528.32, 284.4, 530.88, 286.32, 534.08, 288.24, 543.04, 297.36, 544.96, 300.24, 546.88, 303.12, 550.08, 309.36, 551.36, 312.24, 552, 315.12, 553.28, 319.44, 553.28, 332.4, 552, 337.2, 551.36, 340.08, 550.08, 343.44, 548.16, 347.28, 546.24, 350.16, 544.32, 353.04, 541.12, 357.36, 537.28, 361.2, 532.16, 365.04, 528.96, 366.96, 527.04, 367.92, 523.84, 368.88, 521.28, 369.84, 516.16, 371.28, 500.8, 371.28, 491.84, 369.84, 488, 368.88, 484.8, 367.92, 479.04, 365.04, 475.84, 363.12, 472, 360.24, 464.96, 353.04, 463.04, 350.16, 461.12, 347.28, 459.84, 345.36, 459.84, 343.44, 457.92, 340.08, 456.64, 337.2, 456, 334.32, 454.72, 330.48, 454.72, 316.08, 456, 311.28, 456.64, 307.44, 457.92, 304.08, 459.84, 301.2, 459.84, 299.28, 461.12, 297.36, 463.04, 294.48]]]], - // labels: [''], - // } - // }, - // image: 'beetle', - // }, - // { - // task: "", - // text_input: "a green car", - // generated_text: "a green car", - // target: { - // '': { - // bboxes: [[34.24, 158.64, 582.72, 374.16]], - // bboxes_labels: ['a green car'], - // polygons: [], - // polygons_labels: [], - // } - // }, - // 
image: 'beetle', - // }, - { - task: "", - text_input: "", - generated_text: "car", - target: { '': 'car' }, - image: 'beetle', - }, - { - task: "", - text_input: "", - generated_text: "turquoise Volkswagen Beetle", - target: { '': 'turquoise Volkswagen Beetle' }, - image: 'beetle', - }, - { - task: "", - generated_text: "CUDAFOR ENGINEERSAn Introduction to High-PerformanceParallel ComputingDUANE STORTIMETE YURTOGLU", - target: { '': 'CUDAFOR ENGINEERSAn Introduction to High-PerformanceParallel ComputingDUANE STORTIMETE YURTOGLU' }, - image: 'book_cover', - }, - { - task: "", - generated_text: "CUDAFOR ENGINEERSAn Introduction to High-PerformanceParallel ComputingDUANE STORTIMETE YURTOGLU", - target: { - '': { - quad_boxes: - [ - [167.0435028076172, 50.25, 375.7974853515625, 50.25, 375.7974853515625, 114.75, 167.0435028076172, 114.75], - [144.8784942626953, 120.75, 375.7974853515625, 120.75, 375.7974853515625, 149.25, 144.8784942626953, 149.25], - [115.86249542236328, 165.25, 376.6034851074219, 166.25, 376.6034851074219, 184.25, 115.86249542236328, 183.25], - [239.9864959716797, 184.25, 376.6034851074219, 186.25, 376.6034851074219, 204.25, 239.9864959716797, 202.25], - [266.1814880371094, 441.25, 376.6034851074219, 441.25, 376.6034851074219, 456.25, 266.1814880371094, 456.25], - [252.0764923095703, 460.25, 376.6034851074219, 460.25, 376.6034851074219, 475.25, 252.0764923095703, 475.25], - ], - - // NOTE: Python version has a bug here, it should be "CUDA" instead of "CUDA" - labels: [/* 'CUDA' */ 'CUDA', 'FOR ENGINEERS', 'An Introduction to High-Performance', 'Parallel Computing', 'DUANE STORTI', 'METE YURTOGLU'], - } - }, - image: 'book_cover', - }, - ] - - for (const { task, generated_text, target, image } of TESTS) { - it(task, () => { - const result = processor.post_process_generation(generated_text, task, images[image].size); - compare(result, target); - }); - } - }); - }, MAX_TEST_EXECUTION_TIME); - }); + } + }); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); }); diff --git a/tests/tensor.test.js b/tests/tensor.test.js index a339b22b7..ca5367175 100644 --- a/tests/tensor.test.js +++ b/tests/tensor.test.js @@ -1,202 +1,166 @@ +import { Tensor, cat, mean, stack, layer_norm } from "../src/transformers.js"; +import { compare } from "./test_utils.js"; + +describe("Tensor operations", () => { + describe("cat", () => { + it("should concatenate on dim=0", async () => { + const t1 = new Tensor("float32", [1, 2, 3], [1, 3]); + const t2 = new Tensor("float32", [4, 5, 6, 7, 8, 9], [2, 3]); + const t3 = new Tensor("float32", [10, 11, 12], [1, 3]); + + const target1 = new Tensor("float32", [1, 2, 3, 4, 5, 6, 7, 8, 9], [3, 3]); + const target2 = new Tensor("float32", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [4, 3]); + + // 2 tensors + const concatenated1 = cat([t1, t2], 0); + compare(concatenated1, target1, 1e-3); + + // 3 tensors + const concatenated2 = cat([t1, t2, t3], 0); + compare(concatenated2, target2, 1e-3); + }); -import { Tensor, cat, mean, stack, layer_norm } from '../src/transformers.js'; -import { compare } from './test_utils.js'; - -describe('Tensor operations', () => { - - describe('cat', () => { - - it('should concatenate on dim=0', async () => { - const t1 = new Tensor('float32', [1, 2, 3], [1, 3]); - const t2 = new Tensor('float32', [4, 5, 6, 7, 8, 9], [2, 3]); - const t3 = new Tensor('float32', [10, 11, 12], [1, 3]); - - const target1 = new Tensor('float32', [1, 2, 3, 4, 5, 6, 7, 8, 9], [3, 3]); - const target2 = new Tensor('float32', [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [4, 3]); 
- - // 2 tensors - const concatenated1 = cat([t1, t2], 0); - compare(concatenated1, target1, 1e-3); - - // 3 tensors - const concatenated2 = cat([t1, t2, t3], 0); - compare(concatenated2, target2, 1e-3); - }); - - - it('should concatenate on dim=1', async () => { - const t1 = new Tensor('float32', [1, 2, 3, -1, -2, -3], [2, 3, 1]); - const t2 = new Tensor('float32', [4, -4], [2, 1, 1]); - const t3 = new Tensor('float32', [5, 6, -5, -6], [2, 2, 1]); - - const target1 = new Tensor('float32', [1, 2, 3, 4, -1, -2, -3, -4], [2, 4, 1]); - const target2 = new Tensor('float32', [1, 2, 3, 4, 5, 6, -1, -2, -3, -4, -5, -6], [2, 6, 1]); - - // 2 tensors - const concatenated1 = cat([t1, t2], 1); - compare(concatenated1, target1, 1e-3); - - // 3 tensors - const concatenated2 = cat([t1, t2, t3], 1); - compare(concatenated2, target2, 1e-3); - }); - + it("should concatenate on dim=1", async () => { + const t1 = new Tensor("float32", [1, 2, 3, -1, -2, -3], [2, 3, 1]); + const t2 = new Tensor("float32", [4, -4], [2, 1, 1]); + const t3 = new Tensor("float32", [5, 6, -5, -6], [2, 2, 1]); - it('should concatenate on dim=-2', async () => { + const target1 = new Tensor("float32", [1, 2, 3, 4, -1, -2, -3, -4], [2, 4, 1]); + const target2 = new Tensor("float32", [1, 2, 3, 4, 5, 6, -1, -2, -3, -4, -5, -6], [2, 6, 1]); - const t1 = new Tensor('float32', [1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16], [2, 1, 3, 2]); - const t2 = new Tensor('float32', [7, 8, 9, 10, 17, 18, 19, 20], [2, 1, 2, 2]); + // 2 tensors + const concatenated1 = cat([t1, t2], 1); + compare(concatenated1, target1, 1e-3); - const target = new Tensor('float32', [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [2, 1, 5, 2]); + // 3 tensors + const concatenated2 = cat([t1, t2, t3], 1); + compare(concatenated2, target2, 1e-3); + }); - const concatenated = cat([t1, t2], -2); + it("should concatenate on dim=-2", async () => { + const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16], [2, 1, 3, 2]); + const t2 = new Tensor("float32", [7, 8, 9, 10, 17, 18, 19, 20], [2, 1, 2, 2]); - compare(concatenated, target, 1e-3); + const target = new Tensor("float32", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [2, 1, 5, 2]); - }); + const concatenated = cat([t1, t2], -2); - // TODO add tests for errors + compare(concatenated, target, 1e-3); }); - describe('stack', () => { + // TODO add tests for errors + }); - const t1 = new Tensor('float32', [0, 1, 2, 3, 4, 5], [1, 3, 2]); + describe("stack", () => { + const t1 = new Tensor("float32", [0, 1, 2, 3, 4, 5], [1, 3, 2]); - it('should stack on dim=0', async () => { - const target1 = new Tensor('float32', [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5], [2, 1, 3, 2]); - const target2 = new Tensor('float32', [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5], [3, 1, 3, 2]); + it("should stack on dim=0", async () => { + const target1 = new Tensor("float32", [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5], [2, 1, 3, 2]); + const target2 = new Tensor("float32", [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5], [3, 1, 3, 2]); - // 2 tensors - const stacked1 = stack([t1, t1], 0); - compare(stacked1, target1, 1e-3); + // 2 tensors + const stacked1 = stack([t1, t1], 0); + compare(stacked1, target1, 1e-3); - // 3 tensors - const stacked2 = stack([t1, t1, t1], 0); - compare(stacked2, target2, 1e-3); - }); + // 3 tensors + const stacked2 = stack([t1, t1, t1], 0); + compare(stacked2, target2, 1e-3); + }); - it('should stack on dim=1', async () => { - const target1 = new 
Tensor('float32', [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5], [1, 2, 3, 2]); - const target2 = new Tensor('float32', [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5], [1, 3, 3, 2]); + it("should stack on dim=1", async () => { + const target1 = new Tensor("float32", [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5], [1, 2, 3, 2]); + const target2 = new Tensor("float32", [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5], [1, 3, 3, 2]); - // 2 tensors - const stacked1 = stack([t1, t1], 1); - compare(stacked1, target1, 1e-3); + // 2 tensors + const stacked1 = stack([t1, t1], 1); + compare(stacked1, target1, 1e-3); - // 3 tensors - const stacked2 = stack([t1, t1, t1], 1); - compare(stacked2, target2, 1e-3); - }); + // 3 tensors + const stacked2 = stack([t1, t1, t1], 1); + compare(stacked2, target2, 1e-3); + }); - it('should stack on dim=-1', async () => { - const target1 = new Tensor('float32', [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5], [1, 3, 2, 2]); - const target2 = new Tensor('float32', [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5], [1, 3, 2, 3]); + it("should stack on dim=-1", async () => { + const target1 = new Tensor("float32", [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5], [1, 3, 2, 2]); + const target2 = new Tensor("float32", [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5], [1, 3, 2, 3]); - // 2 tensors - const stacked1 = stack([t1, t1], -1); - compare(stacked1, target1, 1e-3); + // 2 tensors + const stacked1 = stack([t1, t1], -1); + compare(stacked1, target1, 1e-3); - // 3 tensors - const stacked2 = stack([t1, t1, t1], -1); - compare(stacked2, target2, 1e-3); - }); + // 3 tensors + const stacked2 = stack([t1, t1, t1], -1); + compare(stacked2, target2, 1e-3); }); - - describe('permute', () => { - it('should permute', async () => { - const x = new Tensor( - 'float32', - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], - [2, 3, 4], - ); - // Permute axes to (0, 1, 2) - No change - const permuted_1 = x.permute(0, 1, 2); - const target_1 = x; - compare(permuted_1, target_1, 1e-3); - - // Permute axes to (0, 2, 1) - const permuted_2 = x.permute(0, 2, 1); - const target_2 = new Tensor( - 'float32', - [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11, 12, 16, 20, 13, 17, 21, 14, 18, 22, 15, 19, 23], - [2, 4, 3], - ); - compare(permuted_2, target_2, 1e-3); - - // Permute axes to (1, 0, 2) - const permuted_3 = x.permute(1, 0, 2); - const target_3 = new Tensor( - 'float32', - [0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23], - [3, 2, 4], - ); - compare(permuted_3, target_3, 1e-3); - - // Permute axes to (1, 2, 0) - const permuted_4 = x.permute(1, 2, 0); - const target_4 = new Tensor( - 'float32', - [0, 12, 1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23], - [3, 4, 2], - ); - compare(permuted_4, target_4, 1e-3); - - // Permute axes to (2, 0, 1) - const permuted_5 = x.permute(2, 0, 1); - const target_5 = new Tensor( - 'float32', - [0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23], - [4, 2, 3], - ); - compare(permuted_5, target_5, 1e-3); - - // Permute axes to (2, 1, 0) - const permuted_6 = x.permute(2, 1, 0); - const target_6 = new Tensor( - 'float32', - [0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23], - [4, 3, 2], - ); - compare(permuted_6, target_6, 1e-3); - }); + }); + + describe("permute", () => { + it("should permute", async () => { + const x = new Tensor("float32", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 
18, 19, 20, 21, 22, 23], [2, 3, 4]); + // Permute axes to (0, 1, 2) - No change + const permuted_1 = x.permute(0, 1, 2); + const target_1 = x; + compare(permuted_1, target_1, 1e-3); + + // Permute axes to (0, 2, 1) + const permuted_2 = x.permute(0, 2, 1); + const target_2 = new Tensor("float32", [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11, 12, 16, 20, 13, 17, 21, 14, 18, 22, 15, 19, 23], [2, 4, 3]); + compare(permuted_2, target_2, 1e-3); + + // Permute axes to (1, 0, 2) + const permuted_3 = x.permute(1, 0, 2); + const target_3 = new Tensor("float32", [0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23], [3, 2, 4]); + compare(permuted_3, target_3, 1e-3); + + // Permute axes to (1, 2, 0) + const permuted_4 = x.permute(1, 2, 0); + const target_4 = new Tensor("float32", [0, 12, 1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23], [3, 4, 2]); + compare(permuted_4, target_4, 1e-3); + + // Permute axes to (2, 0, 1) + const permuted_5 = x.permute(2, 0, 1); + const target_5 = new Tensor("float32", [0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23], [4, 2, 3]); + compare(permuted_5, target_5, 1e-3); + + // Permute axes to (2, 1, 0) + const permuted_6 = x.permute(2, 1, 0); + const target_6 = new Tensor("float32", [0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23], [4, 3, 2]); + compare(permuted_6, target_6, 1e-3); }); + }); - describe('mean', () => { - it('should calculate mean', async () => { - const t1 = new Tensor('float32', [1, 2, 3, 4, 5, 6], [2, 3, 1]); - - const target = new Tensor('float32', [3.5], []); + describe("mean", () => { + it("should calculate mean", async () => { + const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [2, 3, 1]); - const target0 = new Tensor('float32', [2.5, 3.5, 4.5], [3, 1]); - const target1 = new Tensor('float32', [2, 5], [2, 1]); - const target2 = new Tensor('float32', [1, 2, 3, 4, 5, 6], [2, 3]); + const target = new Tensor("float32", [3.5], []); - let avg = mean(t1); - compare(avg, target, 1e-3); + const target0 = new Tensor("float32", [2.5, 3.5, 4.5], [3, 1]); + const target1 = new Tensor("float32", [2, 5], [2, 1]); + const target2 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [2, 3]); - let avg0 = mean(t1, 0); - compare(avg0, target0, 1e-3); + let avg = mean(t1); + compare(avg, target, 1e-3); - let avg1 = mean(t1, 1); - compare(avg1, target1, 1e-3); + let avg0 = mean(t1, 0); + compare(avg0, target0, 1e-3); - let avg2 = mean(t1, 2); - compare(avg2, target2, 1e-3); + let avg1 = mean(t1, 1); + compare(avg1, target1, 1e-3); - }) + let avg2 = mean(t1, 2); + compare(avg2, target2, 1e-3); }); + }); - describe('layer_norm', () => { - it('should calculate layer norm', async () => { - const t1 = new Tensor('float32', [1, 2, 3, 4, 5, 6], [2, 3]); + describe("layer_norm", () => { + it("should calculate layer norm", async () => { + const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [2, 3]); - const target = new Tensor('float32', [ - -1.2247356176376343, 0.0, 1.2247356176376343, - -1.2247357368469238, -1.1920928955078125e-07, 1.2247354984283447, - ], [2, 3]); + const target = new Tensor("float32", [-1.2247356176376343, 0.0, 1.2247356176376343, -1.2247357368469238, -1.1920928955078125e-7, 1.2247354984283447], [2, 3]); - const norm = layer_norm(t1, [t1.dims.at(-1)]); - compare(norm, target, 1e-3); - }); + const norm = layer_norm(t1, [t1.dims.at(-1)]); + compare(norm, target, 1e-3); }); + }); }); diff --git a/tests/tensor_ops.test.js 
b/tests/tensor_ops.test.js index 67d111074..ed9200388 100644 --- a/tests/tensor_ops.test.js +++ b/tests/tensor_ops.test.js @@ -1,184 +1,191 @@ -import { Tensor, interpolate_4d, matmul, rfft } from '../src/transformers.js'; -import { init } from './init.js'; +import { Tensor, interpolate_4d, matmul, rfft } from "../src/transformers.js"; +import { init } from "./init.js"; // Initialise the testing environment init(); function expectToBeCloseToArray(actual, expected) { - expect(actual.length).toEqual(expected.length) - actual.forEach((x, i) => expect(x).toBeCloseTo(expected[i])) + expect(actual.length).toEqual(expected.length); + actual.forEach((x, i) => expect(x).toBeCloseTo(expected[i])); } function range(start, stop = undefined, step = 1) { - if (stop === undefined) { - stop = start; - start = 0; - } - - const result = []; - for (let i = start; i < stop; i += step) { - result.push(i); - } - return result; + if (stop === undefined) { + stop = start; + start = 0; + } + + const result = []; + for (let i = start; i < stop; i += step) { + result.push(i); + } + return result; } -describe('Tensor operations', () => { - - describe('interpolate', () => { - const input = new Tensor('float32', new Float32Array(2 * 3 * 4 * 5).map((_, i) => i), [2, 3, 4, 5]); - - const size = [2, 3, 3, 2]; - it('bilinear', async () => { - const resized = await interpolate_4d( - input, - { mode: 'bilinear', size }, - ); - const target = new Float32Array([ - [ - [ - [1.5833335, 4.0833335], - [8.25, 10.75], - [14.916668, 17.416668] - ], - [ - [21.583332, 24.083334], - [28.25, 30.75], - [34.916668, 37.416668] - ], - [ - [41.583332, 44.083332], - [48.25, 50.75], - [54.916668, 57.416668] - ] - ], - [ - [ - [61.583332, 64.083336], - [68.25, 70.75], - [74.916664, 77.41667] - ], - [ - [81.58333, 84.083336], - [88.25, 90.75], - [94.91667, 97.41667] - ], - [ - [101.583336, 104.08333], - [108.25, 110.75], - [114.916664, 117.416664] - ] - ] - ].flat(Infinity)); - - expectToBeCloseToArray(target, resized.data); - }); - - it('bicubic', async () => { - const resized = await interpolate_4d( - input, - { mode: 'bicubic', size }, - ); - - const target = new Float32Array([ - [ - [ - [1.2987545, 3.9628172], - [8.167969, 10.832031], - [15.037184, 17.701244] - ], - [ - [21.298756, 23.962818], - [28.167969, 30.832031], - [35.037186, 37.701252] - ], - [ - [41.298756, 43.96282], - [48.16797, 50.83203], - [55.037193, 57.701256] - ] - ], - [ - [ - [61.29875, 63.96282], - [68.16797, 70.83203], - [75.03719, 77.701256] - ], - [ - [81.29875, 83.96282], - [88.16797, 90.83203], - [95.03721, 97.70126] - ], - [ - [101.29875, 103.962814], - [108.16797, 110.83203], - [115.03721, 117.70127] - ] - ] - ].flat(Infinity)); - - expectToBeCloseToArray(target, resized.data); - }); +describe("Tensor operations", () => { + describe("interpolate", () => { + const input = new Tensor( + "float32", + new Float32Array(2 * 3 * 4 * 5).map((_, i) => i), + [2, 3, 4, 5], + ); + + const size = [2, 3, 3, 2]; + it("bilinear", async () => { + const resized = await interpolate_4d(input, { mode: "bilinear", size }); + const target = new Float32Array( + [ + [ + [ + [1.5833335, 4.0833335], + [8.25, 10.75], + [14.916668, 17.416668], + ], + [ + [21.583332, 24.083334], + [28.25, 30.75], + [34.916668, 37.416668], + ], + [ + [41.583332, 44.083332], + [48.25, 50.75], + [54.916668, 57.416668], + ], + ], + [ + [ + [61.583332, 64.083336], + [68.25, 70.75], + [74.916664, 77.41667], + ], + [ + [81.58333, 84.083336], + [88.25, 90.75], + [94.91667, 97.41667], + ], + [ + [101.583336, 104.08333], + 
[108.25, 110.75], + [114.916664, 117.416664], + ], + ], + ].flat(Infinity), + ); + + expectToBeCloseToArray(target, resized.data); }); - describe('matmul', () => { - it('(2, 5) @ (5, 4) -> (2, 4)', async () => { - const a = new Tensor('float32', range(10), [2, 5]); - const b = new Tensor('float32', range(20), [5, 4]); + it("bicubic", async () => { + const resized = await interpolate_4d(input, { mode: "bicubic", size }); + + const target = new Float32Array( + [ + [ + [ + [1.2987545, 3.9628172], + [8.167969, 10.832031], + [15.037184, 17.701244], + ], + [ + [21.298756, 23.962818], + [28.167969, 30.832031], + [35.037186, 37.701252], + ], + [ + [41.298756, 43.96282], + [48.16797, 50.83203], + [55.037193, 57.701256], + ], + ], + [ + [ + [61.29875, 63.96282], + [68.16797, 70.83203], + [75.03719, 77.701256], + ], + [ + [81.29875, 83.96282], + [88.16797, 90.83203], + [95.03721, 97.70126], + ], + [ + [101.29875, 103.962814], + [108.16797, 110.83203], + [115.03721, 117.70127], + ], + ], + ].flat(Infinity), + ); + + expectToBeCloseToArray(target, resized.data); + }); + }); + + describe("matmul", () => { + it("(2, 5) @ (5, 4) -> (2, 4)", async () => { + const a = new Tensor("float32", range(10), [2, 5]); + const b = new Tensor("float32", range(20), [5, 4]); - const result = await matmul(a, b); + const result = await matmul(a, b); - const target = new Float32Array([ - [120.0, 130.0, 140.0, 150.0], - [320.0, 355.0, 390.0, 425.0], - ].flat()); + const target = new Float32Array( + [ + [120.0, 130.0, 140.0, 150.0], + [320.0, 355.0, 390.0, 425.0], + ].flat(), + ); - expectToBeCloseToArray(target, result.data); - }); + expectToBeCloseToArray(target, result.data); + }); + }); + + describe("rfft", () => { + it("non-power of 2", async () => { + const rows = 2; + const cols = 3; + const input = new Tensor("float32", range(rows * cols), [rows, cols]); + const dim = new Tensor("int64", [-1n], []); + const result = await rfft(input, dim); + + const target = new Float32Array( + [ + [ + [3, 0], + [-1.5, 0.8660262823104858], + ], + [ + [12, 0], + [-1.5, 0.866027295589447], + ], + ].flat(Infinity), + ); + + expectToBeCloseToArray(target, result.data); }); - describe('rfft', () => { - it('non-power of 2', async () => { - const rows = 2; - const cols = 3; - const input = new Tensor('float32', range(rows * cols), [rows, cols]); - const dim = new Tensor('int64', [-1n], []); - const result = await rfft(input, dim); - - const target = new Float32Array([ - [ - [3, 0], - [-1.5, 0.8660262823104858], - ], - [ - [12, 0], - [-1.5, 0.866027295589447], - ], - ].flat(Infinity)); - - expectToBeCloseToArray(target, result.data); - }); - - it('power of 2', async () => { - const rows = 2; - const cols = 4; - const input = new Tensor('float32', range(rows * cols), [rows, cols]); - const dim = new Tensor('int64', [-1n], []); - - const result = await rfft(input, dim); - const target = new Float32Array([ - [ - [6, 0], - [-2, 2], - [-2, 0], - ], - [ - [22, 0], - [-2, 2], - [-2, 0], - ], - ].flat(Infinity)); - - expectToBeCloseToArray(target, result.data); - }); + it("power of 2", async () => { + const rows = 2; + const cols = 4; + const input = new Tensor("float32", range(rows * cols), [rows, cols]); + const dim = new Tensor("int64", [-1n], []); + + const result = await rfft(input, dim); + const target = new Float32Array( + [ + [ + [6, 0], + [-2, 2], + [-2, 0], + ], + [ + [22, 0], + [-2, 2], + [-2, 0], + ], + ].flat(Infinity), + ); + + expectToBeCloseToArray(target, result.data); }); + }); }); diff --git a/tests/test_utils.js 
b/tests/test_utils.js index 2a05c657f..9928bf75b 100644 --- a/tests/test_utils.js +++ b/tests/test_utils.js @@ -1,32 +1,30 @@ - - export async function loadAudio(url) { - // NOTE: Since the Web Audio API is not available in Node.js, we will need to use the `wavefile` library to obtain the raw audio data. - // For more information, see: https://huggingface.co/docs/transformers.js/guides/node-audio-processing - let wavefile = (await import('wavefile')).default; - - // Load audio data - let buffer = Buffer.from(await fetch(url).then(x => x.arrayBuffer())) + // NOTE: Since the Web Audio API is not available in Node.js, we will need to use the `wavefile` library to obtain the raw audio data. + // For more information, see: https://huggingface.co/docs/transformers.js/guides/node-audio-processing + let wavefile = (await import("wavefile")).default; - // Read .wav file and convert it to required format - let wav = new wavefile.WaveFile(buffer); - wav.toBitDepth('32f'); // Pipeline expects input as a Float32Array - wav.toSampleRate(16000); // Whisper expects audio with a sampling rate of 16000 - let audioData = wav.getSamples(); - if (Array.isArray(audioData)) { - if (audioData.length > 1) { - const SCALING_FACTOR = Math.sqrt(2); + // Load audio data + let buffer = Buffer.from(await fetch(url).then((x) => x.arrayBuffer())); - // Merge channels (into first channel to save memory) - for (let i = 0; i < audioData[0].length; ++i) { - audioData[0][i] = SCALING_FACTOR * (audioData[0][i] + audioData[1][i]) / 2; - } - } + // Read .wav file and convert it to required format + let wav = new wavefile.WaveFile(buffer); + wav.toBitDepth("32f"); // Pipeline expects input as a Float32Array + wav.toSampleRate(16000); // Whisper expects audio with a sampling rate of 16000 + let audioData = wav.getSamples(); + if (Array.isArray(audioData)) { + if (audioData.length > 1) { + const SCALING_FACTOR = Math.sqrt(2); - // Select first channel - audioData = audioData[0]; + // Merge channels (into first channel to save memory) + for (let i = 0; i < audioData[0].length; ++i) { + audioData[0][i] = (SCALING_FACTOR * (audioData[0][i] + audioData[1][i])) / 2; + } } - return audioData; + + // Select first channel + audioData = audioData[0]; + } + return audioData; } /** * Deep equality test (for arrays and objects) with tolerance for floating point numbers @@ -35,38 +33,33 @@ export async function loadAudio(url) { * @param {number} tol Tolerance for floating point numbers */ export function compare(val1, val2, tol = 0.1) { - if ( - (val1 !== null && val2 !== null) && - (typeof val1 === 'object' && typeof val2 === 'object') - ) { - // Both are non-null objects - - if (Array.isArray(val1) && Array.isArray(val2)) { - expect(val1).toHaveLength(val2.length); + if (val1 !== null && val2 !== null && typeof val1 === "object" && typeof val2 === "object") { + // Both are non-null objects - for (let i = 0; i < val1.length; ++i) { - compare(val1[i], val2[i], tol); - } + if (Array.isArray(val1) && Array.isArray(val2)) { + expect(val1).toHaveLength(val2.length); - } else { - expect(Object.keys(val1)).toHaveLength(Object.keys(val2).length); + for (let i = 0; i < val1.length; ++i) { + compare(val1[i], val2[i], tol); + } + } else { + expect(Object.keys(val1)).toHaveLength(Object.keys(val2).length); - for (let key in val1) { - compare(val1[key], val2[key], tol); - } - } + for (let key in val1) { + compare(val1[key], val2[key], tol); + } + } + } else { + // At least one of them is not an object + // First check that both have the same type + 
expect(typeof val1).toEqual(typeof val2); + if (typeof val1 === "number" && (!Number.isInteger(val1) || !Number.isInteger(val2))) { + // If both are numbers and at least one of them is not an integer + expect(val1).toBeCloseTo(val2, -Math.log10(tol)); } else { - // At least one of them is not an object - // First check that both have the same type - expect(typeof val1).toEqual(typeof val2); - - if (typeof val1 === 'number' && (!Number.isInteger(val1) || !Number.isInteger(val2))) { - // If both are numbers and at least one of them is not an integer - expect(val1).toBeCloseTo(val2, -Math.log10(tol)); - } else { - // Perform equality test - expect(val1).toEqual(val2); - } + // Perform equality test + expect(val1).toEqual(val2); } -} \ No newline at end of file + } +} diff --git a/tests/tiny_random.test.js b/tests/tiny_random.test.js index 82f79ac04..5dd255f46 100644 --- a/tests/tiny_random.test.js +++ b/tests/tiny_random.test.js @@ -1,2299 +1,2433 @@ - - import { - // Tokenizers - CodeGenTokenizer, - LlamaTokenizer, - CohereTokenizer, - GemmaTokenizer, - GPT2Tokenizer, - GPTNeoXTokenizer, - BloomTokenizer, - BertTokenizer, - T5Tokenizer, - WhisperTokenizer, - BartTokenizer, - MarianTokenizer, - PreTrainedTokenizer, - AutoTokenizer, - - // Processors - CLIPImageProcessor, - AutoProcessor, - Processor, - Florence2Processor, - - // Models - LlamaForCausalLM, - CohereModel, - CohereForCausalLM, - GemmaForCausalLM, - Gemma2ForCausalLM, - OPTForCausalLM, - GPTNeoXForCausalLM, - GPTJForCausalLM, - BloomForCausalLM, - GPTBigCodeForCausalLM, - GPT2LMHeadModel, - JAISLMHeadModel, - MptForCausalLM, - CodeGenForCausalLM, - MistralForCausalLM, - GPTNeoForCausalLM, - BertForMaskedLM, - BertForSequenceClassification, - T5ForConditionalGeneration, - T5Model, - BertModel, - BertForTokenClassification, - BertForQuestionAnswering, - MusicgenForConditionalGeneration, - LlavaForConditionalGeneration, - WhisperForConditionalGeneration, - VisionEncoderDecoderModel, - Florence2ForConditionalGeneration, - MarianMTModel, - - // Pipelines - pipeline, - FillMaskPipeline, - TextClassificationPipeline, - TextGenerationPipeline, - ImageClassificationPipeline, - ZeroShotImageClassificationPipeline, - TokenClassificationPipeline, - QuestionAnsweringPipeline, - - // Other - full, - RawImage, -} from '../src/transformers.js'; - -import { init, MAX_MODEL_LOAD_TIME, MAX_TEST_EXECUTION_TIME, MAX_MODEL_DISPOSE_TIME } from './init.js'; -import { compare } from './test_utils.js'; + // Tokenizers + CodeGenTokenizer, + LlamaTokenizer, + CohereTokenizer, + GemmaTokenizer, + GPT2Tokenizer, + GPTNeoXTokenizer, + BloomTokenizer, + BertTokenizer, + T5Tokenizer, + WhisperTokenizer, + BartTokenizer, + MarianTokenizer, + PreTrainedTokenizer, + AutoTokenizer, + + // Processors + CLIPImageProcessor, + AutoProcessor, + Processor, + Florence2Processor, + + // Models + LlamaForCausalLM, + CohereModel, + CohereForCausalLM, + GemmaForCausalLM, + Gemma2ForCausalLM, + OPTForCausalLM, + GPTNeoXForCausalLM, + GPTJForCausalLM, + BloomForCausalLM, + GPTBigCodeForCausalLM, + GPT2LMHeadModel, + JAISLMHeadModel, + MptForCausalLM, + CodeGenForCausalLM, + MistralForCausalLM, + GPTNeoForCausalLM, + BertForMaskedLM, + BertForSequenceClassification, + T5ForConditionalGeneration, + T5Model, + BertModel, + BertForTokenClassification, + BertForQuestionAnswering, + MusicgenForConditionalGeneration, + LlavaForConditionalGeneration, + WhisperForConditionalGeneration, + VisionEncoderDecoderModel, + Florence2ForConditionalGeneration, + MarianMTModel, + + // Pipelines + 
pipeline, + FillMaskPipeline, + TextClassificationPipeline, + TextGenerationPipeline, + ImageClassificationPipeline, + ZeroShotImageClassificationPipeline, + TokenClassificationPipeline, + QuestionAnsweringPipeline, + + // Other + full, + RawImage, +} from "../src/transformers.js"; + +import { init, MAX_MODEL_LOAD_TIME, MAX_TEST_EXECUTION_TIME, MAX_MODEL_DISPOSE_TIME } from "./init.js"; +import { compare } from "./test_utils.js"; init(); const DEFAULT_MODEL_OPTIONS = { - dtype: 'fp32', -} -describe('Tiny random models', () => { - - describe('bert', () => { - describe('BertModel', () => { - const model_id = 'hf-internal-testing/tiny-random-BertModel'; - - /** @type {BertModel} */ - let model; - /** @type {BertTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await BertModel.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await BertTokenizer.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const { last_hidden_state } = await model(inputs); - expect(last_hidden_state.dims).toEqual([1, 7, 32]); - expect(last_hidden_state.mean().item()).toBeCloseTo(0.0, 5); - - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const { last_hidden_state } = await model(inputs); - expect(last_hidden_state.dims).toEqual([2, 12, 32]); - expect(last_hidden_state.mean().item()).toBeCloseTo(1.4901161193847656e-08, 5); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); - }); - - describe('BertForMaskedLM', () => { - const model_id = 'hf-internal-testing/tiny-random-BertForMaskedLM'; - - const texts = [ - 'The goal of life is [MASK].', - 'Paris is the [MASK] of France.', - ]; - - /** @type {BertForMaskedLM} */ - let model; - /** @type {BertTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await BertForMaskedLM.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await BertTokenizer.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer(texts[0]); - const { logits } = await model(inputs); - expect(logits.dims).toEqual([1, 19, 1124]); - expect(logits.mean().item()).toBeCloseTo(0.0016587056452408433, 5); - - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(texts, { padding: true }); - const { logits } = await model(inputs); - expect(logits.dims).toEqual([2, 22, 1124]); - expect(logits.mean().item()).toBeCloseTo(0.0017160633578896523, 5); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + dtype: "fp32", +}; +describe("Tiny random models", () => { + describe("bert", () => { + describe("BertModel", () => { + const model_id = "hf-internal-testing/tiny-random-BertModel"; + + /** @type {BertModel} */ + let model; + /** @type {BertTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await BertModel.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await BertTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const { last_hidden_state } = await model(inputs); + expect(last_hidden_state.dims).toEqual([1, 7, 32]); + 
expect(last_hidden_state.mean().item()).toBeCloseTo(0.0, 5); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const { last_hidden_state } = await model(inputs); + expect(last_hidden_state.dims).toEqual([2, 12, 32]); + expect(last_hidden_state.mean().item()).toBeCloseTo(1.4901161193847656e-8, 5); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); - describe('BertForSequenceClassification', () => { - const model_id = 'hf-internal-testing/tiny-random-BertForSequenceClassification'; - - /** @type {BertForSequenceClassification} */ - let model; - /** @type {BertTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await BertForSequenceClassification.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await BertTokenizer.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const { logits } = await model(inputs); - const target = [ - [0.00043986947275698185, -0.030218850821256638], - ].flat(); - expect(logits.dims).toEqual([1, 2]); - logits.tolist().flat().forEach((item, i) => { - expect(item).toBeCloseTo(target[i], 5); - }); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const { logits } = await model(inputs); - const target = [ - [0.00043986947275698185, -0.030218850821256638], - [0.0003853091038763523, -0.03022204339504242] - ].flat(); - expect(logits.dims).toEqual([2, 2]); - logits.tolist().flat().forEach((item, i) => { - expect(item).toBeCloseTo(target[i], 5); - }); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); - }); + describe("BertForMaskedLM", () => { + const model_id = "hf-internal-testing/tiny-random-BertForMaskedLM"; - describe('BertForTokenClassification', () => { - const model_id = 'hf-internal-testing/tiny-random-BertForTokenClassification'; - - /** @type {BertForTokenClassification} */ - let model; - /** @type {BertTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await BertForTokenClassification.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await BertTokenizer.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const { logits } = await model(inputs); - expect(logits.dims).toEqual([1, 7, 2]); - expect(logits.mean().item()).toBeCloseTo(0.07089076191186905, 5); - - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const { logits } = await model(inputs); - expect(logits.dims).toEqual([2, 12, 2]); - expect(logits.mean().item()).toBeCloseTo(0.04702216014266014, 5); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); - }); + const texts = ["The goal of life is [MASK].", "Paris is the [MASK] of France."]; - describe('BertForQuestionAnswering', () => { - const model_id = 'hf-internal-testing/tiny-random-BertForQuestionAnswering'; - - /** @type {BertForQuestionAnswering} */ - let model; - /** @type {BertTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await 
BertForQuestionAnswering.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await BertTokenizer.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const { start_logits, end_logits } = await model(inputs); - expect(start_logits.dims).toEqual([1, 7]); - expect(start_logits.mean().item()).toBeCloseTo(0.12772157788276672, 5); - expect(end_logits.dims).toEqual([1, 7]); - expect(end_logits.mean().item()).toBeCloseTo(0.11811424791812897, 5); - - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const { start_logits, end_logits } = await model(inputs); - expect(start_logits.dims).toEqual([2, 12]); - expect(start_logits.mean().item()).toBeCloseTo(0.12843115627765656, 5); - expect(end_logits.dims).toEqual([2, 12]); - expect(end_logits.mean().item()).toBeCloseTo(0.11745202541351318, 5); - - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + /** @type {BertForMaskedLM} */ + let model; + /** @type {BertTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await BertForMaskedLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await BertTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer(texts[0]); + const { logits } = await model(inputs); + expect(logits.dims).toEqual([1, 19, 1124]); + expect(logits.mean().item()).toBeCloseTo(0.0016587056452408433, 5); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(texts, { padding: true }); + const { logits } = await model(inputs); + expect(logits.dims).toEqual([2, 22, 1124]); + expect(logits.mean().item()).toBeCloseTo(0.0017160633578896523, 5); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - describe('t5', () => { - - describe('T5Model', () => { - const model_id = 'hf-internal-testing/tiny-random-T5Model'; - - /** @type {T5Model} */ - let model; - /** @type {T5Tokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await T5Model.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await T5Tokenizer.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - it('forward', async () => { - // Example adapted from https://huggingface.co/google-t5/t5-small#how-to-get-started-with-the-model - const inputs = tokenizer( - "Studies have been shown that owning a dog is good for you", - ); - const { input_ids: decoder_input_ids } = tokenizer( - "Studies show that", - ); - - const { last_hidden_state } = await model({ ...inputs, decoder_input_ids }); - expect(last_hidden_state.dims).toEqual([1, 4, 32]); - expect(last_hidden_state.mean().item()).toBeCloseTo(7.492632721550763e-05, 8); - }); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + describe("BertForSequenceClassification", () => { + const model_id = "hf-internal-testing/tiny-random-BertForSequenceClassification"; + + /** @type {BertForSequenceClassification} */ + let model; + /** @type {BertTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await BertForSequenceClassification.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); - 
describe('T5ForConditionalGeneration', () => { - const model_id = 'hf-internal-testing/tiny-random-T5ForConditionalGeneration'; - - /** @type {T5ForConditionalGeneration} */ - let model; - /** @type {T5Tokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await T5ForConditionalGeneration.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await T5Tokenizer.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - it('forward', async () => { - // Example adapted from https://huggingface.co/google-t5/t5-small#how-to-get-started-with-the-model - const inputs = tokenizer( - "Studies have been shown that owning a dog is good for you", - ); - const { input_ids: decoder_input_ids } = tokenizer( - "Studies show that", - ); - - const model = await T5ForConditionalGeneration.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS); - const outputs = await model({ ...inputs, decoder_input_ids }); - expect(outputs.logits.dims).toEqual([1, 4, 32100]); - expect(outputs.logits.mean().item()).toBeCloseTo(8.867568901393952e-09, 12); + tokenizer = await BertTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const { logits } = await model(inputs); + const target = [[0.00043986947275698185, -0.030218850821256638]].flat(); + expect(logits.dims).toEqual([1, 2]); + logits + .tolist() + .flat() + .forEach((item, i) => { + expect(item).toBeCloseTo(target[i], 5); + }); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const { logits } = await model(inputs); + const target = [ + [0.00043986947275698185, -0.030218850821256638], + [0.0003853091038763523, -0.03022204339504242], + ].flat(); + expect(logits.dims).toEqual([2, 2]); + logits + .tolist() + .flat() + .forEach((item, i) => { + expect(item).toBeCloseTo(target[i], 5); }); + }, + MAX_TEST_EXECUTION_TIME, + ); - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n], - [0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); - }); + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - describe('marian', () => { - - describe('MarianMTModel', () => { - const model_id = 'onnx-community/tiny-random-MarianMTModel'; - - /** @type {MarianMTModel} */ - let model; - /** @type {MarianTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await MarianMTModel.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await MarianTokenizer.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [3n, 40672n, 8358n, 32810n, 32810n, 32810n, 32810n, 35687n, 33073n, 6870n], - ]); - }, 
MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [3n, 40672n, 8358n, 32810n, 32810n, 32810n, 32810n, 35687n, 33073n, 6870n], - [3n, 40672n, 8358n, 32810n, 32810n, 32810n, 32810n, 35687n, 33073n, 6870n], - ]); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + describe("BertForTokenClassification", () => { + const model_id = "hf-internal-testing/tiny-random-BertForTokenClassification"; + + /** @type {BertForTokenClassification} */ + let model; + /** @type {BertTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await BertForTokenClassification.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await BertTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const { logits } = await model(inputs); + expect(logits.dims).toEqual([1, 7, 2]); + expect(logits.mean().item()).toBeCloseTo(0.07089076191186905, 5); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const { logits } = await model(inputs); + expect(logits.dims).toEqual([2, 12, 2]); + expect(logits.mean().item()).toBeCloseTo(0.04702216014266014, 5); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - describe('musicgen', () => { - describe('MusicgenForConditionalGeneration', () => { - const model_id = 'hf-internal-testing/tiny-random-MusicgenForConditionalGeneration'; - - // Example adapted from https://huggingface.co/docs/transformers/model_doc/musicgen#text-conditional-generation - const texts = [ - "80s pop track with bassy drums and synth", - "90s rock song with loud guitars and heavy drums", - ]; - - /** @type {MusicgenForConditionalGeneration} */ - let model; - /** @type {T5Tokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await MusicgenForConditionalGeneration.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await T5Tokenizer.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - it('forward', async () => { - // Example from https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenForConditionalGeneration.forward.example - const inputs = tokenizer(texts, { padding: true }); - const pad_token_id = BigInt(model.generation_config.pad_token_id); - const decoder_input_ids = full( - [inputs.input_ids.dims[0] * model.config.decoder.num_codebooks, 1], - pad_token_id, - ); - const { logits } = await model({ ...inputs, decoder_input_ids }); - expect(logits.dims).toEqual([8, 1, 99]); - expect(logits.mean().item()).toBeCloseTo(-0.0018370470497757196, 5); - }); - - it('batch_size=1', async () => { - const inputs = tokenizer(texts[0]); - const audio_values = await model.generate({ ...inputs, max_length: 10 }); - expect(audio_values.dims).toEqual([1, 1, 1920]); - expect(audio_values.mean().item()).toBeCloseTo(0.16644205152988434, 5); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(texts, { padding: true }); - const audio_values = await model.generate({ ...inputs, max_length: 10 }); - 
expect(audio_values.dims).toEqual([2, 1, 1920]); - expect(audio_values.mean().item()).toBeCloseTo(0.16644206643104553, 5); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + describe("BertForQuestionAnswering", () => { + const model_id = "hf-internal-testing/tiny-random-BertForQuestionAnswering"; + + /** @type {BertForQuestionAnswering} */ + let model; + /** @type {BertTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await BertForQuestionAnswering.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await BertTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const { start_logits, end_logits } = await model(inputs); + expect(start_logits.dims).toEqual([1, 7]); + expect(start_logits.mean().item()).toBeCloseTo(0.12772157788276672, 5); + expect(end_logits.dims).toEqual([1, 7]); + expect(end_logits.mean().item()).toBeCloseTo(0.11811424791812897, 5); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const { start_logits, end_logits } = await model(inputs); + expect(start_logits.dims).toEqual([2, 12]); + expect(start_logits.mean().item()).toBeCloseTo(0.12843115627765656, 5); + expect(end_logits.dims).toEqual([2, 12]); + expect(end_logits.mean().item()).toBeCloseTo(0.11745202541351318, 5); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - - describe('whisper', () => { - - describe('WhisperForConditionalGeneration', () => { - const model_id = 'Xenova/tiny-random-WhisperForConditionalGeneration'; - - /** @type {WhisperForConditionalGeneration} */ - let model; - /** @type {WhisperTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await WhisperForConditionalGeneration.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await WhisperTokenizer.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - describe('prefix tokens', () => { - const input_features = full([1, 80, 3000], 0.0); - - describe('English-only', () => { - it('default', async () => { - const outputs = await model.generate({ - input_features, - is_multilingual: false, - max_new_tokens: 1, - }); - - expect(outputs.tolist()).toEqual([[ - /* Prefix */ 50258n, 50363n, /* Generated */ 45084n, - ]]); - }); - it('return_timestamps=true', async () => { - const outputs = await model.generate({ - input_features, - is_multilingual: false, - max_new_tokens: 1, - return_timestamps: true, - }); - - expect(outputs.tolist()).toEqual([[ - /* Prefix */ 50258n, /* Generated */ 50366n, - ]]); - }); - }); - - describe('multilingual', () => { - it('language unset; task unset', async () => { - // language defaults to 'en' - // task defaults to 'transcribe' - - const outputs = await model.generate({ - input_features, - max_new_tokens: 1, - }); - - expect(outputs.tolist()).toEqual([[ - /* Prefix */ 50258n, 50259n, 50359n, 50363n, /* Generated */ 45084n, - ]]); - }); - - it('language set; task unset', async () => { - // task defaults to 'transcribe' - const outputs = await model.generate({ - input_features, - max_new_tokens: 1, - language: 'af', - }); - - expect(outputs.tolist()).toEqual([[ - /* Prefix */ 50258n, 50327n, 50359n, 50363n, /* Generated */ 45084n, - ]]); - }); - - it('language set; task 
set', async () => { - const outputs = await model.generate({ - input_features, - max_new_tokens: 1, - language: 'zh', - task: 'translate', - }); - - expect(outputs.tolist()).toEqual([[ - /* Prefix */ 50258n, 50260n, 50358n, 50363n, /* Generated */ 45084n, - ]]); - }); - - it('return_timestamps=true', async () => { - const outputs = await model.generate({ - input_features, - max_new_tokens: 1, - language: 'en', - task: 'transcribe', - return_timestamps: true, - }); - - expect(outputs.tolist()).toEqual([[ - /* Prefix */ 50258n, 50259n, 50359n, /* Generated */ 50400n, - ]]); - }); - }); + }); + + describe("t5", () => { + describe("T5Model", () => { + const model_id = "hf-internal-testing/tiny-random-T5Model"; + + /** @type {T5Model} */ + let model; + /** @type {T5Tokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await T5Model.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await T5Tokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it("forward", async () => { + // Example adapted from https://huggingface.co/google-t5/t5-small#how-to-get-started-with-the-model + const inputs = tokenizer("Studies have been shown that owning a dog is good for you"); + const { input_ids: decoder_input_ids } = tokenizer("Studies show that"); + + const { last_hidden_state } = await model({ ...inputs, decoder_input_ids }); + expect(last_hidden_state.dims).toEqual([1, 4, 32]); + expect(last_hidden_state.mean().item()).toBeCloseTo(7.492632721550763e-5, 8); + }); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + describe("T5ForConditionalGeneration", () => { + const model_id = "hf-internal-testing/tiny-random-T5ForConditionalGeneration"; + + /** @type {T5ForConditionalGeneration} */ + let model; + /** @type {T5Tokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await T5ForConditionalGeneration.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await T5Tokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it("forward", async () => { + // Example adapted from https://huggingface.co/google-t5/t5-small#how-to-get-started-with-the-model + const inputs = tokenizer("Studies have been shown that owning a dog is good for you"); + const { input_ids: decoder_input_ids } = tokenizer("Studies show that"); + + const model = await T5ForConditionalGeneration.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS); + const outputs = await model({ ...inputs, decoder_input_ids }); + expect(outputs.logits.dims).toEqual([1, 4, 32100]); + expect(outputs.logits.mean().item()).toBeCloseTo(8.867568901393952e-9, 12); + }); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([[0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n], + [0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe("marian", () => { + describe("MarianMTModel", () => { + const model_id = 
"onnx-community/tiny-random-MarianMTModel"; + + /** @type {MarianMTModel} */ + let model; + /** @type {MarianTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await MarianMTModel.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await MarianTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([[3n, 40672n, 8358n, 32810n, 32810n, 32810n, 32810n, 35687n, 33073n, 6870n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [3n, 40672n, 8358n, 32810n, 32810n, 32810n, 32810n, 35687n, 33073n, 6870n], + [3n, 40672n, 8358n, 32810n, 32810n, 32810n, 32810n, 35687n, 33073n, 6870n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe("musicgen", () => { + describe("MusicgenForConditionalGeneration", () => { + const model_id = "hf-internal-testing/tiny-random-MusicgenForConditionalGeneration"; + + // Example adapted from https://huggingface.co/docs/transformers/model_doc/musicgen#text-conditional-generation + const texts = ["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"]; + + /** @type {MusicgenForConditionalGeneration} */ + let model; + /** @type {T5Tokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await MusicgenForConditionalGeneration.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await T5Tokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it("forward", async () => { + // Example from https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenForConditionalGeneration.forward.example + const inputs = tokenizer(texts, { padding: true }); + const pad_token_id = BigInt(model.generation_config.pad_token_id); + const decoder_input_ids = full([inputs.input_ids.dims[0] * model.config.decoder.num_codebooks, 1], pad_token_id); + const { logits } = await model({ ...inputs, decoder_input_ids }); + expect(logits.dims).toEqual([8, 1, 99]); + expect(logits.mean().item()).toBeCloseTo(-0.0018370470497757196, 5); + }); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer(texts[0]); + const audio_values = await model.generate({ ...inputs, max_length: 10 }); + expect(audio_values.dims).toEqual([1, 1, 1920]); + expect(audio_values.mean().item()).toBeCloseTo(0.16644205152988434, 5); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(texts, { padding: true }); + const audio_values = await model.generate({ ...inputs, max_length: 10 }); + expect(audio_values.dims).toEqual([2, 1, 1920]); + expect(audio_values.mean().item()).toBeCloseTo(0.16644206643104553, 5); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe("whisper", () => { + describe("WhisperForConditionalGeneration", () => { + const model_id = "Xenova/tiny-random-WhisperForConditionalGeneration"; + + /** @type {WhisperForConditionalGeneration} */ + let model; + /** @type 
{WhisperTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await WhisperForConditionalGeneration.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await WhisperTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + describe("prefix tokens", () => { + const input_features = full([1, 80, 3000], 0.0); + + describe("English-only", () => { + it("default", async () => { + const outputs = await model.generate({ + input_features, + is_multilingual: false, + max_new_tokens: 1, }); - describe('decoder_start_ids', () => { - const input_features = full([1, 80, 3000], 0.0); - - it('broadcast inputs', async () => { - const { decoder_start_token_id, lang_to_id, task_to_id, no_timestamps_token_id } = model.generation_config; - - const outputs = await model.generate({ - input_features, // batch size 1 - max_new_tokens: 1, - decoder_input_ids: [ // batch size 2 - // <|startoftranscript|> <|lang_id|> <|task|> [<|notimestamps|>] - [decoder_start_token_id, lang_to_id['<|en|>'], task_to_id['translate'], no_timestamps_token_id], - [decoder_start_token_id, lang_to_id['<|fr|>'], task_to_id['transcribe'], no_timestamps_token_id], - ], - }); - expect(outputs.tolist()).toEqual([ - [/* Prefix */ 50258n, 50259n, 50358n, 50363n, /* Generated */ 45084n], - [/* Prefix */ 50258n, 50265n, 50359n, 50363n, /* Generated */ 45084n], - ]); - }); + expect(outputs.tolist()).toEqual([[/* Prefix */ 50258n, 50363n, /* Generated */ 45084n]]); + }); + it("return_timestamps=true", async () => { + const outputs = await model.generate({ + input_features, + is_multilingual: false, + max_new_tokens: 1, + return_timestamps: true, }); - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + expect(outputs.tolist()).toEqual([[/* Prefix */ 50258n, /* Generated */ 50366n]]); + }); }); - }); - - describe('llava', () => { - const prompts = [ - // Example adapted from https://huggingface.co/docs/transformers/model_doc/llava#transformers.LlavaForConditionalGeneration.forward.example - "\nUSER: What's the content of the image?\nASSISTANT:", - "Hi", - ] + describe("multilingual", () => { + it("language unset; task unset", async () => { + // language defaults to 'en' + // task defaults to 'transcribe' - // Empty white image - const dims = [224, 224, 3]; - const image = new RawImage(new Uint8ClampedArray(dims[0] * dims[1] * dims[2]).fill(255), ...dims); - - describe('LlavaForConditionalGeneration', () => { - const model_id = 'Xenova/tiny-random-LlavaForConditionalGeneration'; - - /** @type {LlavaForConditionalGeneration} */ - let model; - /** @type {LlamaTokenizer} */ - let tokenizer; - /** @type {CLIPImageProcessor} */ - let processor; - beforeAll(async () => { - model = await LlavaForConditionalGeneration.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await LlamaTokenizer.from_pretrained(model_id); - processor = await AutoProcessor.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - it('forward', async () => { - const text_inputs = tokenizer(prompts[0]); - const vision_inputs = await processor(image); - const inputs = { ...text_inputs, ...vision_inputs }; - - const { logits } = await model(inputs); - expect(logits.dims).toEqual([1, 244, 32002]); - expect(logits.mean().item()).toBeCloseTo(-0.0005755752790719271, 8); + const outputs = await model.generate({ + input_features, + max_new_tokens: 1, }); - it('batch_size=1', async () => { - const text_inputs = tokenizer(prompts[0]); - const 
vision_inputs = await processor(image); - const inputs = { ...text_inputs, ...vision_inputs }; - - const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 }); - expect(generate_ids.tolist()).toEqual([ - [1n, 32000n, 29871n, 13n, 11889n, 29901n, 1724n, 29915n, 29879n, 278n, 2793n, 310n, 278n, 1967n, 29973n, 13n, 22933n, 9047n, 13566n, 29901n, 21557n, 16781n, 27238n, 8279n, 20454n, 11927n, 12462n, 12306n, 2414n, 7561n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const text_inputs = tokenizer(prompts, { padding: true }); - const vision_inputs = await processor([image, image]); - const inputs = { ...text_inputs, ...vision_inputs }; - - const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 }); - expect(generate_ids.tolist()).toEqual([ - [1n, 32000n, 29871n, 13n, 11889n, 29901n, 1724n, 29915n, 29879n, 278n, 2793n, 310n, 278n, 1967n, 29973n, 13n, 22933n, 9047n, 13566n, 29901n, 21557n, 16781n, 27238n, 8279n, 20454n, 11927n, 12462n, 12306n, 2414n, 7561n], - [0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 1n, 32000n, 6324n, 1217n, 22958n, 22913n, 10381n, 148n, 31410n, 31736n, 7358n, 9150n, 28635n] - ]); - - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); - }); - }); + expect(outputs.tolist()).toEqual([[/* Prefix */ 50258n, 50259n, 50359n, 50363n, /* Generated */ 45084n]]); + }); + it("language set; task unset", async () => { + // task defaults to 'transcribe' + const outputs = await model.generate({ + input_features, + max_new_tokens: 1, + language: "af", + }); - describe('florence2', () => { + expect(outputs.tolist()).toEqual([[/* Prefix */ 50258n, 50327n, 50359n, 50363n, /* Generated */ 45084n]]); + }); - const texts = [ - 'Describe with a paragraph what is shown in the image.', - 'Locate the objects with category name in the image.', - ] + it("language set; task set", async () => { + const outputs = await model.generate({ + input_features, + max_new_tokens: 1, + language: "zh", + task: "translate", + }); - // Empty white image - const dims = [224, 224, 3]; - const image = new RawImage(new Uint8ClampedArray(dims[0] * dims[1] * dims[2]).fill(255), ...dims); + expect(outputs.tolist()).toEqual([[/* Prefix */ 50258n, 50260n, 50358n, 50363n, /* Generated */ 45084n]]); + }); - describe('Florence2ForConditionalGeneration', () => { - const model_id = 'Xenova/tiny-random-Florence2ForConditionalGeneration'; - - /** @type {Florence2ForConditionalGeneration} */ - let model; - /** @type {BartTokenizer} */ - let tokenizer; - /** @type {Florence2Processor} */ - let processor; - beforeAll(async () => { - model = await Florence2ForConditionalGeneration.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await BartTokenizer.from_pretrained(model_id); - processor = await AutoProcessor.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - it('forward', async () => { - const text_inputs = tokenizer(texts[0]); - const vision_inputs = await processor(image); - const inputs = { - ...text_inputs, - ...vision_inputs, - decoder_input_ids: full([1, 1], 2n), - }; - - const { logits } = await model(inputs); - expect(logits.dims).toEqual([1, 1, 51289]); + it("return_timestamps=true", async () => { + const outputs = await model.generate({ + input_features, + max_new_tokens: 1, + language: "en", + task: "transcribe", + return_timestamps: true, }); - it('batch_size=1', async () => { - const text_inputs = tokenizer(texts[0]); - 
{ - const generate_ids = await model.generate({ ...text_inputs, max_new_tokens: 10 }); - expect(generate_ids.tolist()).toEqual([ - [2n, 0n, 0n, 0n, 1n, 0n, 0n, 2n] - ]); - } - { - const vision_inputs = await processor(image); - const inputs = { ...text_inputs, ...vision_inputs }; - - const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 }); - expect(generate_ids.tolist()).toEqual([ - [2n, 0n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 2n] - ]); - } - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const text_inputs = tokenizer(texts, { padding: true }); - { - const generate_ids = await model.generate({ ...text_inputs, max_new_tokens: 10 }); - expect(generate_ids.tolist()).toEqual([ - [2n, 0n, 0n, 0n, 1n, 0n, 0n, 2n], - [2n, 0n, 0n, 0n, 1n, 0n, 0n, 2n] - ]); - } - { - const vision_inputs = await processor([image, image]); - const inputs = { ...text_inputs, ...vision_inputs }; - - const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 }); - expect(generate_ids.tolist()).toEqual([ - [2n, 0n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 2n], - [2n, 0n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 2n] - ]); - } - - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + expect(outputs.tolist()).toEqual([[/* Prefix */ 50258n, 50259n, 50359n, /* Generated */ 50400n]]); + }); }); - }); + }); + + describe("decoder_start_ids", () => { + const input_features = full([1, 80, 3000], 0.0); + + it("broadcast inputs", async () => { + const { decoder_start_token_id, lang_to_id, task_to_id, no_timestamps_token_id } = model.generation_config; + + const outputs = await model.generate({ + input_features, // batch size 1 + max_new_tokens: 1, + decoder_input_ids: [ + // batch size 2 + // <|startoftranscript|> <|lang_id|> <|task|> [<|notimestamps|>] + [decoder_start_token_id, lang_to_id["<|en|>"], task_to_id["translate"], no_timestamps_token_id], + [decoder_start_token_id, lang_to_id["<|fr|>"], task_to_id["transcribe"], no_timestamps_token_id], + ], + }); + expect(outputs.tolist()).toEqual([ + [/* Prefix */ 50258n, 50259n, 50358n, 50363n, /* Generated */ 45084n], + [/* Prefix */ 50258n, 50265n, 50359n, 50363n, /* Generated */ 45084n], + ]); + }); + }); - describe('vision-encoder-decoder', () => { - - describe('VisionEncoderDecoderModel', () => { - const model_id = 'hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2'; - - /** @type {VisionEncoderDecoderModel} */ - let model; - /** @type {GPT2Tokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await VisionEncoderDecoderModel.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await GPT2Tokenizer.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const outputs = await model.generate({ - pixel_values: full([1, 3, 30, 30], -1.0), - max_length: 5, - }); - expect(outputs.tolist()).toEqual([ - [0n, 400n, 400n, 400n, 400n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - // TODO: Add back - // it('batch_size>1', async () => { - // const outputs = await model.generate({ - // pixel_values: cat([ - // full([1, 3, 30, 30], -1.0), - // full([1, 3, 30, 30], 0.0), - // ]), - // max_length: 5, - // }); - // expect(outputs.tolist()).toEqual([ - // // Generation continues - // [0n, 400n, 400n, 400n, 400n], - - // // Finishes early. 
1023 is the padding token - // [0n, 0n, 1023n, 1023n, 1023n], - // ]); - // }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe("llava", () => { + const prompts = [ + // Example adapted from https://huggingface.co/docs/transformers/model_doc/llava#transformers.LlavaForConditionalGeneration.forward.example + "\nUSER: What's the content of the image?\nASSISTANT:", + "Hi", + ]; + + // Empty white image + const dims = [224, 224, 3]; + const image = new RawImage(new Uint8ClampedArray(dims[0] * dims[1] * dims[2]).fill(255), ...dims); + + describe("LlavaForConditionalGeneration", () => { + const model_id = "Xenova/tiny-random-LlavaForConditionalGeneration"; + + /** @type {LlavaForConditionalGeneration} */ + let model; + /** @type {LlamaTokenizer} */ + let tokenizer; + /** @type {CLIPImageProcessor} */ + let processor; + beforeAll(async () => { + model = await LlavaForConditionalGeneration.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await LlamaTokenizer.from_pretrained(model_id); + processor = await AutoProcessor.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it("forward", async () => { + const text_inputs = tokenizer(prompts[0]); + const vision_inputs = await processor(image); + const inputs = { ...text_inputs, ...vision_inputs }; + + const { logits } = await model(inputs); + expect(logits.dims).toEqual([1, 244, 32002]); + expect(logits.mean().item()).toBeCloseTo(-0.0005755752790719271, 8); + }); + + it( + "batch_size=1", + async () => { + const text_inputs = tokenizer(prompts[0]); + const vision_inputs = await processor(image); + const inputs = { ...text_inputs, ...vision_inputs }; + + const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 }); + expect(generate_ids.tolist()).toEqual([[1n, 32000n, 29871n, 13n, 11889n, 29901n, 1724n, 29915n, 29879n, 278n, 2793n, 310n, 278n, 1967n, 29973n, 13n, 22933n, 9047n, 13566n, 29901n, 21557n, 16781n, 27238n, 8279n, 20454n, 11927n, 12462n, 12306n, 2414n, 7561n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const text_inputs = tokenizer(prompts, { padding: true }); + const vision_inputs = await processor([image, image]); + const inputs = { ...text_inputs, ...vision_inputs }; + + const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 }); + expect(generate_ids.tolist()).toEqual([ + [1n, 32000n, 29871n, 13n, 11889n, 29901n, 1724n, 29915n, 29879n, 278n, 2793n, 310n, 278n, 1967n, 29973n, 13n, 22933n, 9047n, 13566n, 29901n, 21557n, 16781n, 27238n, 8279n, 20454n, 11927n, 12462n, 12306n, 2414n, 7561n], + [0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 1n, 32000n, 6324n, 1217n, 22958n, 22913n, 10381n, 148n, 31410n, 31736n, 7358n, 9150n, 28635n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - describe('opt', () => { - describe('OPTForCausalLM', () => { - const model_id = 'hf-internal-testing/tiny-random-OPTForCausalLM'; - /** @type {OPTForCausalLM} */ - let model; - /** @type {GPT2Tokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await OPTForCausalLM.from_pretrained(model_id, { - // TODO move to config - revision: 'refs/pr/2', - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await GPT2Tokenizer.from_pretrained(model_id, { - // TODO update this - 
revision: 'refs/pr/3', - }); - tokenizer.padding_side = 'left'; - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [2n, 42891n, 39144n, 39144n, 39144n, 39144n, 39144n, 39144n, 39144n, 39144n], - ]); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [1n, 2n, 42891n, 39144n, 39144n, 39144n, 39144n, 39144n, 39144n, 39144n], - [2n, 42891n, 232n, 24680n, 24680n, 24680n, 24680n, 24680n, 24680n, 24680n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("florence2", () => { + const texts = ["Describe with a paragraph what is shown in the image.", "Locate the objects with category name in the image."]; + + // Empty white image + const dims = [224, 224, 3]; + const image = new RawImage(new Uint8ClampedArray(dims[0] * dims[1] * dims[2]).fill(255), ...dims); + + describe("Florence2ForConditionalGeneration", () => { + const model_id = "Xenova/tiny-random-Florence2ForConditionalGeneration"; + + /** @type {Florence2ForConditionalGeneration} */ + let model; + /** @type {BartTokenizer} */ + let tokenizer; + /** @type {Florence2Processor} */ + let processor; + beforeAll(async () => { + model = await Florence2ForConditionalGeneration.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await BartTokenizer.from_pretrained(model_id); + processor = await AutoProcessor.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it("forward", async () => { + const text_inputs = tokenizer(texts[0]); + const vision_inputs = await processor(image); + const inputs = { + ...text_inputs, + ...vision_inputs, + decoder_input_ids: full([1, 1], 2n), + }; + + const { logits } = await model(inputs); + expect(logits.dims).toEqual([1, 1, 51289]); + }); + + it( + "batch_size=1", + async () => { + const text_inputs = tokenizer(texts[0]); + { + const generate_ids = await model.generate({ ...text_inputs, max_new_tokens: 10 }); + expect(generate_ids.tolist()).toEqual([[2n, 0n, 0n, 0n, 1n, 0n, 0n, 2n]]); + } + { + const vision_inputs = await processor(image); + const inputs = { ...text_inputs, ...vision_inputs }; + + const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 }); + expect(generate_ids.tolist()).toEqual([[2n, 0n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 2n]]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const text_inputs = tokenizer(texts, { padding: true }); + { + const generate_ids = await model.generate({ ...text_inputs, max_new_tokens: 10 }); + expect(generate_ids.tolist()).toEqual([ + [2n, 0n, 0n, 0n, 1n, 0n, 0n, 2n], + [2n, 0n, 0n, 0n, 1n, 0n, 0n, 2n], + ]); + } + { + const vision_inputs = await processor([image, image]); + const inputs = { ...text_inputs, ...vision_inputs }; + + const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 }); + expect(generate_ids.tolist()).toEqual([ + [2n, 0n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 2n], + [2n, 0n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 2n], + ]); + } + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - - describe('llama', 
() => { - describe('LlamaForCausalLM', () => { - const model_id = 'hf-internal-testing/tiny-random-LlamaForCausalLM'; - /** @type {LlamaForCausalLM} */ - let model; - /** @type {LlamaTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await LlamaForCausalLM.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await LlamaTokenizer.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [1n, 22172n, 18547n, 8143n, 22202n, 9456n, 17213n, 15330n, 26591n, 15721n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [0n, 1n, 22172n, 18547n, 8143n, 22202n, 9456n, 17213n, 15330n, 26591n], - [1n, 22172n, 3186n, 24786n, 19169n, 20222n, 29993n, 27146n, 27426n, 24562n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("vision-encoder-decoder", () => { + describe("VisionEncoderDecoderModel", () => { + const model_id = "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2"; + + /** @type {VisionEncoderDecoderModel} */ + let model; + /** @type {GPT2Tokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await VisionEncoderDecoderModel.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await GPT2Tokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const outputs = await model.generate({ + pixel_values: full([1, 3, 30, 30], -1.0), + max_length: 5, + }); + expect(outputs.tolist()).toEqual([[0n, 400n, 400n, 400n, 400n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + // TODO: Add back + // it('batch_size>1', async () => { + // const outputs = await model.generate({ + // pixel_values: cat([ + // full([1, 3, 30, 30], -1.0), + // full([1, 3, 30, 30], 0.0), + // ]), + // max_length: 5, + // }); + // expect(outputs.tolist()).toEqual([ + // // Generation continues + // [0n, 400n, 400n, 400n, 400n], + + // // Finishes early. 
1023 is the padding token + // [0n, 0n, 1023n, 1023n, 1023n], + // ]); + // }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - - describe('cohere', () => { - describe('CohereModel', () => { - const model_id = 'hf-internal-testing/tiny-random-CohereModel'; - /** @type {CohereModel} */ - let model; - /** @type {CohereTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await CohereModel.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await CohereTokenizer.from_pretrained(model_id); - tokenizer.padding_side = 'left'; - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const { last_hidden_state } = await model(inputs); - expect(last_hidden_state.dims).toEqual([1, 4, 32]); - expect(last_hidden_state.mean().item()).toBeCloseTo(0.0, 5); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const { last_hidden_state } = await model(inputs); - expect(last_hidden_state.dims).toEqual([2, 6, 32]); - expect(last_hidden_state.mean().item()).toBeCloseTo(9.934107758624577e-09, 5); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + }); + describe("opt", () => { + describe("OPTForCausalLM", () => { + const model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"; + /** @type {OPTForCausalLM} */ + let model; + /** @type {GPT2Tokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await OPTForCausalLM.from_pretrained(model_id, { + // TODO move to config + revision: "refs/pr/2", + ...DEFAULT_MODEL_OPTIONS, }); - - describe('CohereForCausalLM', () => { - const model_id = 'hf-internal-testing/tiny-random-CohereForCausalLM'; - /** @type {CohereForCausalLM} */ - let model; - /** @type {CohereTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await CohereForCausalLM.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await CohereTokenizer.from_pretrained(model_id); - tokenizer.padding_side = 'left'; - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [5n, 203n, 790n, 87n, 87n, 87n, 87n, 87n, 87n, 87n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [0n, 0n, 5n, 203n, 790n, 87n, 87n, 87n, 87n, 87n], - [5n, 203n, 790n, 87n, 214n, 741n, 741n, 741n, 741n, 741n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + tokenizer = await GPT2Tokenizer.from_pretrained(model_id, { + // TODO update this + revision: "refs/pr/3", }); + tokenizer.padding_side = "left"; + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([[2n, 42891n, 39144n, 39144n, 39144n, 39144n, 39144n, 39144n, 39144n, 39144n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", 
"hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [1n, 2n, 42891n, 39144n, 39144n, 39144n, 39144n, 39144n, 39144n, 39144n], + [2n, 42891n, 232n, 24680n, 24680n, 24680n, 24680n, 24680n, 24680n, 24680n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - - - describe('gemma', () => { - describe('GemmaForCausalLM', () => { - const model_id = 'Xenova/tiny-random-GemmaForCausalLM'; - /** @type {GemmaForCausalLM} */ - let model; - /** @type {GemmaTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await GemmaForCausalLM.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await GemmaTokenizer.from_pretrained(model_id); - tokenizer.padding_side = 'left'; - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [2n, 17534n, 254059n, 254059n, 254059n, 254059n, 254059n, 254059n, 254059n, 254059n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [0n, 2n, 17534n, 254059n, 254059n, 254059n, 254059n, 254059n, 254059n, 254059n], - [2n, 17534n, 2134n, 71055n, 71055n, 71055n, 71055n, 71055n, 71055n, 71055n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("llama", () => { + describe("LlamaForCausalLM", () => { + const model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"; + /** @type {LlamaForCausalLM} */ + let model; + /** @type {LlamaTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await LlamaForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await LlamaTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([[1n, 22172n, 18547n, 8143n, 22202n, 9456n, 17213n, 15330n, 26591n, 15721n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 1n, 22172n, 18547n, 8143n, 22202n, 9456n, 17213n, 15330n, 26591n], + [1n, 22172n, 3186n, 24786n, 19169n, 20222n, 29993n, 27146n, 27426n, 24562n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - - describe('gemma', () => { - describe('Gemma2ForCausalLM', () => { - const model_id = 'hf-internal-testing/tiny-random-Gemma2ForCausalLM'; - /** @type {Gemma2ForCausalLM} */ - let model; - /** @type {GemmaTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await Gemma2ForCausalLM.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await GemmaTokenizer.from_pretrained(model_id); - tokenizer.padding_side = 'left'; - }, MAX_MODEL_LOAD_TIME); 
- - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [2n, 17534n, 127534n, 160055n, 160055n, 160055n, 160055n, 160055n, 160055n, 160055n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [0n, 2n, 17534n, 127534n, 127534n, 215341n, 215341n, 215341n, 215341n, 215341n], - [2n, 17534n, 2134n, 107508n, 160055n, 160055n, 160055n, 160055n, 160055n, 160055n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("cohere", () => { + describe("CohereModel", () => { + const model_id = "hf-internal-testing/tiny-random-CohereModel"; + /** @type {CohereModel} */ + let model; + /** @type {CohereTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await CohereModel.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await CohereTokenizer.from_pretrained(model_id); + tokenizer.padding_side = "left"; + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const { last_hidden_state } = await model(inputs); + expect(last_hidden_state.dims).toEqual([1, 4, 32]); + expect(last_hidden_state.mean().item()).toBeCloseTo(0.0, 5); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const { last_hidden_state } = await model(inputs); + expect(last_hidden_state.dims).toEqual([2, 6, 32]); + expect(last_hidden_state.mean().item()).toBeCloseTo(9.934107758624577e-9, 5); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - describe('gpt_neo', () => { - describe('GPTNeoForCausalLM', () => { - const model_id = 'hf-internal-testing/tiny-random-GPTNeoForCausalLM'; - /** @type {GPTNeoForCausalLM} */ - let model; - /** @type {GPT2Tokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await GPTNeoForCausalLM.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await GPT2Tokenizer.from_pretrained(model_id); - tokenizer.padding_side = 'left'; - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [258n, 863n, 79n, 79n, 79n, 949n, 949n, 949n, 949n, 949n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [0n, 0n, 258n, 863n, 79n, 79n, 79n, 949n, 949n, 949n], - [258n, 863n, 79n, 269n, 813n, 849n, 849n, 849n, 849n, 849n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + describe("CohereForCausalLM", () => { + const model_id = "hf-internal-testing/tiny-random-CohereForCausalLM"; + /** @type {CohereForCausalLM} */ + let model; + /** @type {CohereTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await 
CohereForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await CohereTokenizer.from_pretrained(model_id); + tokenizer.padding_side = "left"; + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([[5n, 203n, 790n, 87n, 87n, 87n, 87n, 87n, 87n, 87n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 5n, 203n, 790n, 87n, 87n, 87n, 87n, 87n], + [5n, 203n, 790n, 87n, 214n, 741n, 741n, 741n, 741n, 741n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - - describe('gpt_neox', () => { - describe('GPTNeoXForCausalLM', () => { - const model_id = 'hf-internal-testing/tiny-random-GPTNeoXForCausalLM'; - /** @type {GPTNeoXForCausalLM} */ - let model; - /** @type {GPTNeoXTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await GPTNeoXForCausalLM.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await GPTNeoXTokenizer.from_pretrained(model_id); - tokenizer.padding_side = 'left'; - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [259n, 864n, 80n, 881n, 502n, 895n, 938n, 668n, 502n, 895n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [0n, 0n, 259n, 864n, 80n, 881n, 502n, 895n, 938n, 668n], - [259n, 864n, 80n, 270n, 814n, 522n, 112n, 268n, 503n, 468n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("gemma", () => { + describe("GemmaForCausalLM", () => { + const model_id = "Xenova/tiny-random-GemmaForCausalLM"; + /** @type {GemmaForCausalLM} */ + let model; + /** @type {GemmaTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await GemmaForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await GemmaTokenizer.from_pretrained(model_id); + tokenizer.padding_side = "left"; + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([[2n, 17534n, 254059n, 254059n, 254059n, 254059n, 254059n, 254059n, 254059n, 254059n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 2n, 17534n, 254059n, 254059n, 254059n, 254059n, 254059n, 254059n, 254059n], + [2n, 17534n, 2134n, 71055n, 71055n, 71055n, 71055n, 71055n, 71055n, 71055n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await 
model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - - describe('gptj', () => { - describe('GPTJForCausalLM', () => { - const model_id = 'hf-internal-testing/tiny-random-GPTJForCausalLM'; - /** @type {GPTJForCausalLM} */ - let model; - /** @type {GPTNeoXTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await GPTJForCausalLM.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await GPTNeoXTokenizer.from_pretrained(model_id); - tokenizer.padding_side = 'left'; - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [258n, 863n, 79n, 102n, 401n, 773n, 889n, 159n, 957n, 869n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [0n, 0n, 258n, 863n, 79n, 102n, 401n, 773n, 889n, 159n], - [258n, 863n, 79n, 269n, 813n, 879n, 175n, 39n, 141n, 1000n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("gemma", () => { + describe("Gemma2ForCausalLM", () => { + const model_id = "hf-internal-testing/tiny-random-Gemma2ForCausalLM"; + /** @type {Gemma2ForCausalLM} */ + let model; + /** @type {GemmaTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await Gemma2ForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await GemmaTokenizer.from_pretrained(model_id); + tokenizer.padding_side = "left"; + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([[2n, 17534n, 127534n, 160055n, 160055n, 160055n, 160055n, 160055n, 160055n, 160055n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 2n, 17534n, 127534n, 127534n, 215341n, 215341n, 215341n, 215341n, 215341n], + [2n, 17534n, 2134n, 107508n, 160055n, 160055n, 160055n, 160055n, 160055n, 160055n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - - describe('bloom', () => { - describe('BloomForCausalLM', () => { - const model_id = 'hf-internal-testing/tiny-random-BloomForCausalLM'; - /** @type {BloomForCausalLM} */ - let model; - /** @type {BloomTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await BloomForCausalLM.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await BloomTokenizer.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [198n, 803n, 82n, 82n, 82n, 82n, 82n, 82n, 82n, 82n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const outputs = 
await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [3n, 3n, 198n, 803n, 82n, 82n, 82n, 82n, 82n, 82n], - [198n, 803n, 82n, 209n, 753n, 753n, 753n, 753n, 753n, 753n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("gpt_neo", () => { + describe("GPTNeoForCausalLM", () => { + const model_id = "hf-internal-testing/tiny-random-GPTNeoForCausalLM"; + /** @type {GPTNeoForCausalLM} */ + let model; + /** @type {GPT2Tokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await GPTNeoForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await GPT2Tokenizer.from_pretrained(model_id); + tokenizer.padding_side = "left"; + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([[258n, 863n, 79n, 79n, 79n, 949n, 949n, 949n, 949n, 949n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 258n, 863n, 79n, 79n, 79n, 949n, 949n, 949n], + [258n, 863n, 79n, 269n, 813n, 849n, 849n, 849n, 849n, 849n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - - describe('gpt_bigcode', () => { - describe('GPTBigCodeForCausalLM', () => { - const model_id = 'hf-internal-testing/tiny-random-GPTBigCodeForCausalLM'; - /** @type {GPTBigCodeForCausalLM} */ - let model; - /** @type {GPT2Tokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await GPTBigCodeForCausalLM.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await GPT2Tokenizer.from_pretrained(model_id); - tokenizer.padding_side = 'left'; - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [258n, 863n, 79n, 79n, 79n, 79n, 79n, 79n, 79n, 79n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [0n, 0n, 258n, 863n, 79n, 79n, 79n, 79n, 79n, 79n], - [258n, 863n, 79n, 269n, 813n, 832n, 93n, 93n, 93n, 93n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("gpt_neox", () => { + describe("GPTNeoXForCausalLM", () => { + const model_id = "hf-internal-testing/tiny-random-GPTNeoXForCausalLM"; + /** @type {GPTNeoXForCausalLM} */ + let model; + /** @type {GPTNeoXTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await GPTNeoXForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await GPTNeoXTokenizer.from_pretrained(model_id); + tokenizer.padding_side = "left"; + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + 
max_length: 10, + }); + expect(outputs.tolist()).toEqual([[259n, 864n, 80n, 881n, 502n, 895n, 938n, 668n, 502n, 895n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 259n, 864n, 80n, 881n, 502n, 895n, 938n, 668n], + [259n, 864n, 80n, 270n, 814n, 522n, 112n, 268n, 503n, 468n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - - describe('gpt2', () => { - describe('GPT2LMHeadModel', () => { - const model_id = 'hf-internal-testing/tiny-random-GPT2LMHeadModel'; - /** @type {GPT2LMHeadModel} */ - let model; - /** @type {GPT2Tokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await GPT2LMHeadModel.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await GPT2Tokenizer.from_pretrained(model_id); - tokenizer.padding_side = 'left'; - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [258n, 863n, 79n, 79n, 79n, 79n, 79n, 79n, 79n, 243n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [0n, 0n, 258n, 863n, 79n, 79n, 79n, 79n, 79n, 79n], - [258n, 863n, 79n, 269n, 813n, 813n, 813n, 813n, 813n, 813n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("gptj", () => { + describe("GPTJForCausalLM", () => { + const model_id = "hf-internal-testing/tiny-random-GPTJForCausalLM"; + /** @type {GPTJForCausalLM} */ + let model; + /** @type {GPTNeoXTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await GPTJForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await GPTNeoXTokenizer.from_pretrained(model_id); + tokenizer.padding_side = "left"; + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([[258n, 863n, 79n, 102n, 401n, 773n, 889n, 159n, 957n, 869n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 258n, 863n, 79n, 102n, 401n, 773n, 889n, 159n], + [258n, 863n, 79n, 269n, 813n, 879n, 175n, 39n, 141n, 1000n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - - describe('jais', () => { - describe('JAISLMHeadModel', () => { - const model_id = 'onnx-community/tiny-random-jais'; - /** @type {JAISLMHeadModel} */ - let model; - /** @type {PreTrainedTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await JAISLMHeadModel.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await 
PreTrainedTokenizer.from_pretrained(model_id); - tokenizer.padding_side = 'left'; - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n], - ]); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [0n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n], - [55422n, 2838n, 2838n, 2838n, 2838n, 2838n, 2838n, 2838n, 2838n, 2838n], - ]); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("bloom", () => { + describe("BloomForCausalLM", () => { + const model_id = "hf-internal-testing/tiny-random-BloomForCausalLM"; + /** @type {BloomForCausalLM} */ + let model; + /** @type {BloomTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await BloomForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await BloomTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([[198n, 803n, 82n, 82n, 82n, 82n, 82n, 82n, 82n, 82n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [3n, 3n, 198n, 803n, 82n, 82n, 82n, 82n, 82n, 82n], + [198n, 803n, 82n, 209n, 753n, 753n, 753n, 753n, 753n, 753n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - - describe('mpt', () => { - describe('MptForCausalLM', () => { - const model_id = 'hf-internal-testing/tiny-random-MptForCausalLM'; - /** @type {MptForCausalLM} */ - let model; - /** @type {GPTNeoXTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await MptForCausalLM.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await GPTNeoXTokenizer.from_pretrained(model_id); - tokenizer.padding_side = 'left'; - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [259n, 864n, 80n, 80n, 80n, 80n, 80n, 80n, 80n, 80n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [0n, 0n, 259n, 864n, 80n, 80n, 80n, 80n, 80n, 80n], - [259n, 864n, 80n, 270n, 814n, 293n, 293n, 293n, 293n, 293n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("gpt_bigcode", () => { + describe("GPTBigCodeForCausalLM", () => { + const model_id = 
"hf-internal-testing/tiny-random-GPTBigCodeForCausalLM"; + /** @type {GPTBigCodeForCausalLM} */ + let model; + /** @type {GPT2Tokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await GPTBigCodeForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await GPT2Tokenizer.from_pretrained(model_id); + tokenizer.padding_side = "left"; + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([[258n, 863n, 79n, 79n, 79n, 79n, 79n, 79n, 79n, 79n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 258n, 863n, 79n, 79n, 79n, 79n, 79n, 79n], + [258n, 863n, 79n, 269n, 813n, 832n, 93n, 93n, 93n, 93n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - - describe('codegen', () => { - describe('CodeGenForCausalLM', () => { - const model_id = 'hf-internal-testing/tiny-random-CodeGenForCausalLM'; - /** @type {CodeGenForCausalLM} */ - let model; - /** @type {CodeGenTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await CodeGenForCausalLM.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await CodeGenTokenizer.from_pretrained(model_id); - tokenizer.padding_side = 'left'; - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [258n, 863n, 79n, 437n, 334n, 450n, 294n, 621n, 375n, 385n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [0n, 0n, 258n, 863n, 79n, 437n, 334n, 450n, 294n, 621n], - [258n, 863n, 79n, 269n, 813n, 759n, 113n, 295n, 574n, 987n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("gpt2", () => { + describe("GPT2LMHeadModel", () => { + const model_id = "hf-internal-testing/tiny-random-GPT2LMHeadModel"; + /** @type {GPT2LMHeadModel} */ + let model; + /** @type {GPT2Tokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await GPT2LMHeadModel.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await GPT2Tokenizer.from_pretrained(model_id); + tokenizer.padding_side = "left"; + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([[258n, 863n, 79n, 79n, 79n, 79n, 79n, 79n, 79n, 243n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 258n, 863n, 79n, 79n, 79n, 79n, 79n, 79n], + [258n, 863n, 
79n, 269n, 813n, 813n, 813n, 813n, 813n, 813n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - - describe('mistral', () => { - describe('MistralForCausalLM', () => { - const model_id = 'hf-internal-testing/tiny-random-MistralForCausalLM'; - /** @type {MistralForCausalLM} */ - let model; - /** @type {LlamaTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await MistralForCausalLM.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await LlamaTokenizer.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('hello'); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [1n, 6312n, 28709n, 24704n, 8732n, 1310n, 9808n, 13771n, 27309n, 4779n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - it('batch_size>1', async () => { - const inputs = tokenizer(['hello', 'hello world'], { padding: true }); - const outputs = await model.generate({ - ...inputs, - max_length: 10, - }); - expect(outputs.tolist()).toEqual([ - [2n, 1n, 6312n, 28709n, 24704n, 8732n, 1310n, 9808n, 13771n, 27309n], - [1n, 6312n, 28709n, 1526n, 8687n, 5690n, 1770n, 30811n, 12501n, 3325n] - ]); - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("jais", () => { + describe("JAISLMHeadModel", () => { + const model_id = "onnx-community/tiny-random-jais"; + /** @type {JAISLMHeadModel} */ + let model; + /** @type {PreTrainedTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await JAISLMHeadModel.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await PreTrainedTokenizer.from_pretrained(model_id); + tokenizer.padding_side = "left"; + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([[55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n, 55422n], + [55422n, 2838n, 2838n, 2838n, 2838n, 2838n, 2838n, 2838n, 2838n, 2838n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); -}); - -describe('Tiny random pipelines', () => { - describe('fill-mask', () => { - const model_id = 'hf-internal-testing/tiny-random-BertForMaskedLM'; - - /** @type {FillMaskPipeline} */ - let pipe; - beforeAll(async () => { - pipe = await pipeline('fill-mask', model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - }, MAX_MODEL_LOAD_TIME); - - describe('batch_size=1', () => { - it('default (top_k=5)', async () => { - const output = await pipe('a [MASK] c'); - const target = [ - { score: 0.0013377574505284429, token: 854, token_str: '##ο', sequence: 'aο c' }, - { score: 0.001248967950232327, token: 962, token_str: '##ち', sequence: 'aち c' }, - { score: 0.0012304208939895034, token: 933, token_str: '##ع', sequence: 'aع c' }, - { score: 
0.0012301815440878272, token: 313, token_str: 'ფ', sequence: 'a ფ c' }, - { score: 0.001222139224410057, token: 624, token_str: '未', sequence: 'a 未 c' }, - ] - compare(output, target, 1e-5); - }); - it('custom (top_k=2)', async () => { - const output = await pipe('a [MASK] c', { top_k: 2 }); - const target = [ - { score: 0.0013377574505284429, token: 854, token_str: '##ο', sequence: 'aο c' }, - { score: 0.001248967950232327, token: 962, token_str: '##ち', sequence: 'aち c' }, - ] - compare(output, target, 1e-5); - }); + }); + + describe("mpt", () => { + describe("MptForCausalLM", () => { + const model_id = "hf-internal-testing/tiny-random-MptForCausalLM"; + /** @type {MptForCausalLM} */ + let model; + /** @type {GPTNeoXTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await MptForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); - - describe('batch_size>1', () => { - it('default (top_k=5)', async () => { - const output = await pipe([ - 'a [MASK] c', - 'a b [MASK] c', - ]); - const target = [ - [ - { score: 0.0013377574505284429, token: 854, token_str: '##ο', sequence: 'aο c' }, - { score: 0.001248967950232327, token: 962, token_str: '##ち', sequence: 'aち c' }, - { score: 0.0012304208939895034, token: 933, token_str: '##ع', sequence: 'aع c' }, - { score: 0.0012301815440878272, token: 313, token_str: 'ფ', sequence: 'a ფ c' }, - { score: 0.001222139224410057, token: 624, token_str: '未', sequence: 'a 未 c' } - ], - [ - { score: 0.0013287801994010806, token: 962, token_str: '##ち', sequence: 'a bち c' }, - { score: 0.0012486606137827039, token: 823, token_str: '##ن', sequence: 'a bن c' }, - { score: 0.0012320734094828367, token: 1032, token_str: '##ც', sequence: 'a bც c' }, - { score: 0.0012295148335397243, token: 854, token_str: '##ο', sequence: 'a bο c' }, - { score: 0.0012277684872969985, token: 624, token_str: '未', sequence: 'a b 未 c' } - ] - ] - compare(output, target, 1e-5); - }); - it('custom (top_k=2)', async () => { - const output = await pipe([ - 'a [MASK] c', - 'a b [MASK] c', - ], { top_k: 2 }); - const target = [ - [ - { score: 0.0013377574505284429, token: 854, token_str: '##ο', sequence: 'aο c' }, - { score: 0.001248967950232327, token: 962, token_str: '##ち', sequence: 'aち c' } - ], - [ - { score: 0.0013287801994010806, token: 962, token_str: '##ち', sequence: 'a bち c' }, - { score: 0.0012486606137827039, token: 823, token_str: '##ن', sequence: 'a bن c' }, - ] - ] - compare(output, target, 1e-5); - }); - }); - - afterAll(async () => { - await pipe?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + tokenizer = await GPTNeoXTokenizer.from_pretrained(model_id); + tokenizer.padding_side = "left"; + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([[259n, 864n, 80n, 80n, 80n, 80n, 80n, 80n, 80n, 80n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 259n, 864n, 80n, 80n, 80n, 80n, 80n, 80n], + [259n, 864n, 80n, 270n, 814n, 293n, 293n, 293n, 293n, 293n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); }); - - describe('text-classification', () => { - const model_id = 
'hf-internal-testing/tiny-random-BertForSequenceClassification'; - - /** @type {TextClassificationPipeline} */ - let pipe; - beforeAll(async () => { - pipe = await pipeline('text-classification', model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - }, MAX_MODEL_LOAD_TIME); - - describe('batch_size=1', () => { - it('default (top_k=1)', async () => { - const output = await pipe('a'); - const target = [ - { label: 'LABEL_0', score: 0.5076976418495178 } - ] - compare(output, target, 1e-5); - }); - it('custom (top_k=2)', async () => { - const output = await pipe('a', { top_k: 2 }); - const target = [ - { label: 'LABEL_0', score: 0.5076976418495178 }, - { label: 'LABEL_1', score: 0.49230238795280457 } - ] - compare(output, target, 1e-5); - }); + }); + + describe("codegen", () => { + describe("CodeGenForCausalLM", () => { + const model_id = "hf-internal-testing/tiny-random-CodeGenForCausalLM"; + /** @type {CodeGenForCausalLM} */ + let model; + /** @type {CodeGenTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await CodeGenForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); - - describe('batch_size>1', () => { - it('default (top_k=1)', async () => { - const output = await pipe(['a', 'b c']); - const target = [ - { label: 'LABEL_0', score: 0.5076976418495178 }, - { label: 'LABEL_0', score: 0.5077522993087769 }, - ] - compare(output, target, 1e-5); - }); - it('custom (top_k=2)', async () => { - const output = await pipe(['a', 'b c'], { top_k: 2 }); - const target = [ - [ - { label: 'LABEL_0', score: 0.5076976418495178 }, - { label: 'LABEL_1', score: 0.49230238795280457 } - ], - [ - { label: 'LABEL_0', score: 0.5077522993087769 }, - { label: 'LABEL_1', score: 0.49224773049354553 } - ] - ]; - compare(output, target, 1e-5); - }); - - it('multi_label_classification', async () => { - - const problem_type = pipe.model.config.problem_type; - pipe.model.config.problem_type = 'multi_label_classification'; - - const output = await pipe(['a', 'b c'], { top_k: 2 }); - const target = [ - [ - { label: 'LABEL_0', score: 0.5001373887062073 }, - { label: 'LABEL_1', score: 0.49243971705436707 } - ], - [ - { label: 'LABEL_0', score: 0.5001326203346252 }, - { label: 'LABEL_1', score: 0.492380291223526 } - ] - ]; - compare(output, target, 1e-5); - - // Reset problem type - pipe.model.config.problem_type = problem_type; - }); + tokenizer = await CodeGenTokenizer.from_pretrained(model_id); + tokenizer.padding_side = "left"; + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([[258n, 863n, 79n, 437n, 334n, 450n, 294n, 621n, 375n, 385n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 258n, 863n, 79n, 437n, 334n, 450n, 294n, 621n], + [258n, 863n, 79n, 269n, 813n, 759n, 113n, 295n, 574n, 987n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe("mistral", () => { + describe("MistralForCausalLM", () => { + const model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"; + /** @type {MistralForCausalLM} */ + let model; + /** @type {LlamaTokenizer} */ 
+ let tokenizer; + beforeAll(async () => { + model = await MistralForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, }); + tokenizer = await LlamaTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([[1n, 6312n, 28709n, 24704n, 8732n, 1310n, 9808n, 13771n, 27309n, 4779n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [2n, 1n, 6312n, 28709n, 24704n, 8732n, 1310n, 9808n, 13771n, 27309n], + [1n, 6312n, 28709n, 1526n, 8687n, 5690n, 1770n, 30811n, 12501n, 3325n], + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); +}); - afterAll(async () => { - await pipe?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); +describe("Tiny random pipelines", () => { + describe("fill-mask", () => { + const model_id = "hf-internal-testing/tiny-random-BertForMaskedLM"; + + /** @type {FillMaskPipeline} */ + let pipe; + beforeAll(async () => { + pipe = await pipeline("fill-mask", model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + }, MAX_MODEL_LOAD_TIME); + + describe("batch_size=1", () => { + it("default (top_k=5)", async () => { + const output = await pipe("a [MASK] c"); + const target = [ + { score: 0.0013377574505284429, token: 854, token_str: "##ο", sequence: "aο c" }, + { score: 0.001248967950232327, token: 962, token_str: "##ち", sequence: "aち c" }, + { score: 0.0012304208939895034, token: 933, token_str: "##ع", sequence: "aع c" }, + { score: 0.0012301815440878272, token: 313, token_str: "ფ", sequence: "a ფ c" }, + { score: 0.001222139224410057, token: 624, token_str: "未", sequence: "a 未 c" }, + ]; + compare(output, target, 1e-5); + }); + it("custom (top_k=2)", async () => { + const output = await pipe("a [MASK] c", { top_k: 2 }); + const target = [ + { score: 0.0013377574505284429, token: 854, token_str: "##ο", sequence: "aο c" }, + { score: 0.001248967950232327, token: 962, token_str: "##ち", sequence: "aち c" }, + ]; + compare(output, target, 1e-5); + }); }); - describe('token-classification', () => { - const model_id = 'hf-internal-testing/tiny-random-BertForTokenClassification'; + describe("batch_size>1", () => { + it("default (top_k=5)", async () => { + const output = await pipe(["a [MASK] c", "a b [MASK] c"]); + const target = [ + [ + { score: 0.0013377574505284429, token: 854, token_str: "##ο", sequence: "aο c" }, + { score: 0.001248967950232327, token: 962, token_str: "##ち", sequence: "aち c" }, + { score: 0.0012304208939895034, token: 933, token_str: "##ع", sequence: "aع c" }, + { score: 0.0012301815440878272, token: 313, token_str: "ფ", sequence: "a ფ c" }, + { score: 0.001222139224410057, token: 624, token_str: "未", sequence: "a 未 c" }, + ], + [ + { score: 0.0013287801994010806, token: 962, token_str: "##ち", sequence: "a bち c" }, + { score: 0.0012486606137827039, token: 823, token_str: "##ن", sequence: "a bن c" }, + { score: 0.0012320734094828367, token: 1032, token_str: "##ც", sequence: "a bც c" }, + { score: 0.0012295148335397243, token: 854, token_str: "##ο", sequence: "a bο c" }, + { score: 0.0012277684872969985, token: 624, token_str: "未", 
sequence: "a b 未 c" }, + ], + ]; + compare(output, target, 1e-5); + }); + it("custom (top_k=2)", async () => { + const output = await pipe(["a [MASK] c", "a b [MASK] c"], { top_k: 2 }); + const target = [ + [ + { score: 0.0013377574505284429, token: 854, token_str: "##ο", sequence: "aο c" }, + { score: 0.001248967950232327, token: 962, token_str: "##ち", sequence: "aち c" }, + ], + [ + { score: 0.0013287801994010806, token: 962, token_str: "##ち", sequence: "a bち c" }, + { score: 0.0012486606137827039, token: 823, token_str: "##ن", sequence: "a bن c" }, + ], + ]; + compare(output, target, 1e-5); + }); + }); - /** @type {TokenClassificationPipeline} */ - let pipe; - beforeAll(async () => { - pipe = await pipeline('token-classification', model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - }, MAX_MODEL_LOAD_TIME); - - describe('batch_size=1', () => { - it('default', async () => { - const output = await pipe('1 2 3'); - - // TODO: Add start/end to target - const target = [ - { - entity: 'LABEL_0', score: 0.5292708, index: 1, word: '1', - // 'start': 0, 'end': 1 - }, - { - entity: 'LABEL_0', score: 0.5353687, index: 2, word: '2', - // 'start': 2, 'end': 3 - }, - { - entity: 'LABEL_1', score: 0.51381934, index: 3, word: '3', - // 'start': 4, 'end': 5 - } - ] - compare(output, target, 1e-5); - }); - it('custom (ignore_labels set)', async () => { - const output = await pipe('1 2 3', { ignore_labels: ['LABEL_0'] }); - const target = [ - { - entity: 'LABEL_1', score: 0.51381934, index: 3, word: '3', - // 'start': 4, 'end': 5 - } - ] - compare(output, target, 1e-5); - }); - }); + afterAll(async () => { + await pipe?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("text-classification", () => { + const model_id = "hf-internal-testing/tiny-random-BertForSequenceClassification"; + + /** @type {TextClassificationPipeline} */ + let pipe; + beforeAll(async () => { + pipe = await pipeline("text-classification", model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + }, MAX_MODEL_LOAD_TIME); + + describe("batch_size=1", () => { + it("default (top_k=1)", async () => { + const output = await pipe("a"); + const target = [{ label: "LABEL_0", score: 0.5076976418495178 }]; + compare(output, target, 1e-5); + }); + it("custom (top_k=2)", async () => { + const output = await pipe("a", { top_k: 2 }); + const target = [ + { label: "LABEL_0", score: 0.5076976418495178 }, + { label: "LABEL_1", score: 0.49230238795280457 }, + ]; + compare(output, target, 1e-5); + }); + }); - describe('batch_size>1', () => { - it('default', async () => { - const output = await pipe(['1 2 3', '4 5']); - const target = [ - [ - { - entity: 'LABEL_0', score: 0.5292708, index: 1, word: '1', - // 'start': 0, 'end': 1 - }, - { - entity: 'LABEL_0', score: 0.5353687, index: 2, word: '2', - // 'start': 2, 'end': 3 - }, - { - entity: 'LABEL_1', score: 0.51381934, index: 3, word: '3', - // 'start': 4, 'end': 5 - } - ], - [ - { - entity: 'LABEL_0', score: 0.5432807, index: 1, word: '4', - // 'start': 0, 'end': 1 - }, - { - entity: 'LABEL_1', score: 0.5007693, index: 2, word: '5', - // 'start': 2, 'end': 3 - } - ] - ] - compare(output, target, 1e-5); - }); - it('custom (ignore_labels set)', async () => { - const output = await pipe(['1 2 3', '4 5'], { ignore_labels: ['LABEL_0'] }); - const target = [ - [ - { - entity: 'LABEL_1', score: 0.51381934, index: 3, word: '3', - // 'start': 4, 'end': 5 - } - ], - [ - { - entity: 'LABEL_1', score: 0.5007693, index: 2, word: '5', - // 'start': 2, 'end': 3 - } - ] - ] 
- compare(output, target, 1e-5); - }); - }); + describe("batch_size>1", () => { + it("default (top_k=1)", async () => { + const output = await pipe(["a", "b c"]); + const target = [ + { label: "LABEL_0", score: 0.5076976418495178 }, + { label: "LABEL_0", score: 0.5077522993087769 }, + ]; + compare(output, target, 1e-5); + }); + it("custom (top_k=2)", async () => { + const output = await pipe(["a", "b c"], { top_k: 2 }); + const target = [ + [ + { label: "LABEL_0", score: 0.5076976418495178 }, + { label: "LABEL_1", score: 0.49230238795280457 }, + ], + [ + { label: "LABEL_0", score: 0.5077522993087769 }, + { label: "LABEL_1", score: 0.49224773049354553 }, + ], + ]; + compare(output, target, 1e-5); + }); + + it("multi_label_classification", async () => { + const problem_type = pipe.model.config.problem_type; + pipe.model.config.problem_type = "multi_label_classification"; + + const output = await pipe(["a", "b c"], { top_k: 2 }); + const target = [ + [ + { label: "LABEL_0", score: 0.5001373887062073 }, + { label: "LABEL_1", score: 0.49243971705436707 }, + ], + [ + { label: "LABEL_0", score: 0.5001326203346252 }, + { label: "LABEL_1", score: 0.492380291223526 }, + ], + ]; + compare(output, target, 1e-5); - afterAll(async () => { - await pipe?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + // Reset problem type + pipe.model.config.problem_type = problem_type; + }); }); - describe('question-answering', () => { - const model_id = 'hf-internal-testing/tiny-random-BertForQuestionAnswering'; - - /** @type {QuestionAnsweringPipeline} */ - let pipe; - beforeAll(async () => { - pipe = await pipeline('question-answering', model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - }, MAX_MODEL_LOAD_TIME); + afterAll(async () => { + await pipe?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("token-classification", () => { + const model_id = "hf-internal-testing/tiny-random-BertForTokenClassification"; + + /** @type {TokenClassificationPipeline} */ + let pipe; + beforeAll(async () => { + pipe = await pipeline("token-classification", model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + }, MAX_MODEL_LOAD_TIME); + + describe("batch_size=1", () => { + it("default", async () => { + const output = await pipe("1 2 3"); + + // TODO: Add start/end to target + const target = [ + { + entity: "LABEL_0", + score: 0.5292708, + index: 1, + word: "1", + // 'start': 0, 'end': 1 + }, + { + entity: "LABEL_0", + score: 0.5353687, + index: 2, + word: "2", + // 'start': 2, 'end': 3 + }, + { + entity: "LABEL_1", + score: 0.51381934, + index: 3, + word: "3", + // 'start': 4, 'end': 5 + }, + ]; + compare(output, target, 1e-5); + }); + it("custom (ignore_labels set)", async () => { + const output = await pipe("1 2 3", { ignore_labels: ["LABEL_0"] }); + const target = [ + { + entity: "LABEL_1", + score: 0.51381934, + index: 3, + word: "3", + // 'start': 4, 'end': 5 + }, + ]; + compare(output, target, 1e-5); + }); + }); - describe('batch_size=1', () => { - it('default (top_k=1)', async () => { - const output = await pipe('a', 'b c'); - const target = { score: 0.11395696550607681, /* start: 0, end: 1, */ answer: 'b' }; - compare(output, target, 1e-5); - }); - it('custom (top_k=3)', async () => { - const output = await pipe('a', 'b c', { top_k: 3 }); - const target = [ - { score: 0.11395696550607681, /* start: 0, end: 1, */ answer: 'b' }, - { score: 0.11300431191921234, /* start: 2, end: 3, */ answer: 'c' }, - { score: 0.10732574015855789, /* start: 0, end: 3, */ answer: 'b c' } - ] - 
compare(output, target, 1e-5); - }); - }); + describe("batch_size>1", () => { + it("default", async () => { + const output = await pipe(["1 2 3", "4 5"]); + const target = [ + [ + { + entity: "LABEL_0", + score: 0.5292708, + index: 1, + word: "1", + // 'start': 0, 'end': 1 + }, + { + entity: "LABEL_0", + score: 0.5353687, + index: 2, + word: "2", + // 'start': 2, 'end': 3 + }, + { + entity: "LABEL_1", + score: 0.51381934, + index: 3, + word: "3", + // 'start': 4, 'end': 5 + }, + ], + [ + { + entity: "LABEL_0", + score: 0.5432807, + index: 1, + word: "4", + // 'start': 0, 'end': 1 + }, + { + entity: "LABEL_1", + score: 0.5007693, + index: 2, + word: "5", + // 'start': 2, 'end': 3 + }, + ], + ]; + compare(output, target, 1e-5); + }); + it("custom (ignore_labels set)", async () => { + const output = await pipe(["1 2 3", "4 5"], { ignore_labels: ["LABEL_0"] }); + const target = [ + [ + { + entity: "LABEL_1", + score: 0.51381934, + index: 3, + word: "3", + // 'start': 4, 'end': 5 + }, + ], + [ + { + entity: "LABEL_1", + score: 0.5007693, + index: 2, + word: "5", + // 'start': 2, 'end': 3 + }, + ], + ]; + compare(output, target, 1e-5); + }); + }); - afterAll(async () => { - await pipe?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + afterAll(async () => { + await pipe?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("question-answering", () => { + const model_id = "hf-internal-testing/tiny-random-BertForQuestionAnswering"; + + /** @type {QuestionAnsweringPipeline} */ + let pipe; + beforeAll(async () => { + pipe = await pipeline("question-answering", model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + }, MAX_MODEL_LOAD_TIME); + + describe("batch_size=1", () => { + it("default (top_k=1)", async () => { + const output = await pipe("a", "b c"); + const target = { score: 0.11395696550607681, /* start: 0, end: 1, */ answer: "b" }; + compare(output, target, 1e-5); + }); + it("custom (top_k=3)", async () => { + const output = await pipe("a", "b c", { top_k: 3 }); + const target = [ + { score: 0.11395696550607681, /* start: 0, end: 1, */ answer: "b" }, + { score: 0.11300431191921234, /* start: 2, end: 3, */ answer: "c" }, + { score: 0.10732574015855789, /* start: 0, end: 3, */ answer: "b c" }, + ]; + compare(output, target, 1e-5); + }); }); - describe('image-classification', () => { - const model_id = 'hf-internal-testing/tiny-random-vit'; - const urls = [ - 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/white-image.png', - 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/blue-image.png', + afterAll(async () => { + await pipe?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("image-classification", () => { + const model_id = "hf-internal-testing/tiny-random-vit"; + const urls = ["https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/white-image.png", "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/blue-image.png"]; + + /** @type {ImageClassificationPipeline} */ + let pipe; + beforeAll(async () => { + pipe = await pipeline("image-classification", model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + }, MAX_MODEL_LOAD_TIME); + + describe("batch_size=1", () => { + it("default (top_k=5)", async () => { + const output = await pipe(urls[0]); + const target = [ + { label: "LABEL_1", score: 0.5020533800125122 }, + { label: "LABEL_0", score: 0.4979466497898102 }, ]; + compare(output, target, 1e-5); + }); + it("custom (top_k=1)", async () => { + const output = 
await pipe(urls[0], { top_k: 1 }); + const target = [{ label: "LABEL_1", score: 0.5020533800125122 }]; + compare(output, target, 1e-5); + }); + }); - /** @type {ImageClassificationPipeline} */ - let pipe; - beforeAll(async () => { - pipe = await pipeline('image-classification', model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - }, MAX_MODEL_LOAD_TIME); + describe("batch_size>1", () => { + it("default (top_k=5)", async () => { + const output = await pipe(urls); + const target = [ + [ + { label: "LABEL_1", score: 0.5020533800125122 }, + { label: "LABEL_0", score: 0.4979466497898102 }, + ], + [ + { label: "LABEL_1", score: 0.519227921962738 }, + { label: "LABEL_0", score: 0.4807720482349396 }, + ], + ]; + compare(output, target, 1e-5); + }); + it("custom (top_k=1)", async () => { + const output = await pipe(urls, { top_k: 1 }); + const target = [[{ label: "LABEL_1", score: 0.5020533800125122 }], [{ label: "LABEL_1", score: 0.519227921962738 }]]; + compare(output, target, 1e-5); + }); + }); - describe('batch_size=1', () => { + afterAll(async () => { + await pipe?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("zero-shot-image-classification", () => { + const model_id = "hf-internal-testing/tiny-random-GroupViTModel"; + + // Example adapted from https://huggingface.co/docs/transformers/en/model_doc/groupvit + const urls = ["https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/white-image.png", "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/blue-image.png"]; + const labels = ["cat", "dog"]; + const hypothesis_template = "a photo of a {}"; + + /** @type {ZeroShotImageClassificationPipeline} */ + let pipe; + beforeAll(async () => { + pipe = await pipeline("zero-shot-image-classification", model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + }, MAX_MODEL_LOAD_TIME); + + describe("batch_size=1", () => { + it("default", async () => { + const output = await pipe(urls[0], labels); + const target = [ + { score: 0.5990662574768066, label: "cat" }, + { score: 0.40093377232551575, label: "dog" }, + ]; + compare(output, target, 1e-5); + }); + it("custom (w/ hypothesis_template)", async () => { + const output = await pipe(urls[0], labels, { hypothesis_template }); + const target = [ + { score: 0.5527022480964661, label: "cat" }, + { score: 0.44729775190353394, label: "dog" }, + ]; + compare(output, target, 1e-5); + }); + }); - it('default (top_k=5)', async () => { - const output = await pipe(urls[0]); - const target = [ - { label: 'LABEL_1', score: 0.5020533800125122 }, - { label: 'LABEL_0', score: 0.4979466497898102 } - ] - compare(output, target, 1e-5); - }); - it('custom (top_k=1)', async () => { - const output = await pipe(urls[0], { top_k: 1 }); - const target = [{ label: 'LABEL_1', score: 0.5020533800125122 }] - compare(output, target, 1e-5); - }); - }); + describe("batch_size>1", () => { + it("default", async () => { + const output = await pipe(urls, labels); + const target = [ + [ + { score: 0.5990662574768066, label: "cat" }, + { score: 0.40093377232551575, label: "dog" }, + ], + [ + { score: 0.5006340146064758, label: "dog" }, + { score: 0.49936598539352417, label: "cat" }, + ], + ]; + compare(output, target, 1e-5); + }); + it("custom (w/ hypothesis_template)", async () => { + const output = await pipe(urls, labels, { hypothesis_template }); + const target = [ + [ + { score: 0.5527022480964661, label: "cat" }, + { score: 0.44729775190353394, label: "dog" }, + ], + [ + { score: 0.5395973324775696, 
label: "cat" }, + { score: 0.46040263772010803, label: "dog" }, + ], + ]; + compare(output, target, 1e-5); + }); + }); - describe('batch_size>1', () => { - it('default (top_k=5)', async () => { - const output = await pipe(urls); - const target = [ - [ - { label: 'LABEL_1', score: 0.5020533800125122 }, - { label: 'LABEL_0', score: 0.4979466497898102 } - ], - [ - { label: 'LABEL_1', score: 0.519227921962738 }, - { label: 'LABEL_0', score: 0.4807720482349396 } - ] - ] - compare(output, target, 1e-5); - }); - it('custom (top_k=1)', async () => { - const output = await pipe(urls, { top_k: 1 }); - const target = [ - [{ label: 'LABEL_1', score: 0.5020533800125122 }], - [{ label: 'LABEL_1', score: 0.519227921962738 }] - ] - compare(output, target, 1e-5); - }); - }); + afterAll(async () => { + await pipe?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("audio-classification", () => { + const model_id = "hf-internal-testing/tiny-random-unispeech"; + const audios = [new Float32Array(16000).fill(0), Float32Array.from({ length: 16000 }, (_, i) => i)]; + + /** @type {ImageClassificationPipeline} */ + let pipe; + beforeAll(async () => { + pipe = await pipeline("audio-classification", model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + }, MAX_MODEL_LOAD_TIME); + + describe("batch_size=1", () => { + it("default (top_k=5)", async () => { + const output = await pipe(audios[0]); + const target = [ + { score: 0.5043687224388123, label: "LABEL_0" }, + { score: 0.4956313371658325, label: "LABEL_1" }, + ]; + compare(output, target, 1e-5); + }); + it("custom (top_k=1)", async () => { + const output = await pipe(audios[0], { top_k: 1 }); + const target = [{ score: 0.5043687224388123, label: "LABEL_0" }]; + compare(output, target, 1e-5); + }); + }); - afterAll(async () => { - await pipe?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); + describe("batch_size>1", () => { + it("default (top_k=5)", async () => { + const output = await pipe(audios); + const target = [ + [ + { score: 0.5043687224388123, label: "LABEL_0" }, + { score: 0.4956313371658325, label: "LABEL_1" }, + ], + [ + { score: 0.5187293887138367, label: "LABEL_0" }, + { score: 0.4812707006931305, label: "LABEL_1" }, + ], + ]; + compare(output, target, 1e-5); + }); + it("custom (top_k=1)", async () => { + const output = await pipe(audios, { top_k: 1 }); + const target = [[{ score: 0.5043687224388123, label: "LABEL_0" }], [{ score: 0.5187293887138367, label: "LABEL_0" }]]; + compare(output, target, 1e-5); + }); }); - describe('zero-shot-image-classification', () => { - const model_id = 'hf-internal-testing/tiny-random-GroupViTModel'; + afterAll(async () => { + await pipe?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("text-generation", () => { + const model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"; + + /** @type {TextGenerationPipeline} */ + let pipe; + beforeAll(async () => { + pipe = await pipeline("text-generation", model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + }, MAX_MODEL_LOAD_TIME); + + describe("batch_size=1", () => { + const text_input = "hello"; + const generated_text_target = "erdingsAndroid Load"; + const text_target = [{ generated_text: text_input + generated_text_target }]; + const new_text_target = [{ generated_text: generated_text_target }]; + + const chat_input = [ + { role: "system", content: "a" }, + { role: "user", content: "b" }, + ]; + const chat_target = [ + { + generated_text: [ + { role: "system", content: "a" }, + { role: "user", content: "b" }, + { 
role: "assistant", content: " Southern abund Load" }, + ], + }, + ]; + + it("text input (single)", async () => { + const output = await pipe(text_input, { max_new_tokens: 3 }); + compare(output, text_target); + }); + it("text input (list)", async () => { + const output = await pipe([text_input], { max_new_tokens: 3 }); + compare(output, [text_target]); + }); + + it("text input (single) - return_full_text=false", async () => { + const output = await pipe(text_input, { max_new_tokens: 3, return_full_text: false }); + compare(output, new_text_target); + }); + it("text input (list) - return_full_text=false", async () => { + const output = await pipe([text_input], { max_new_tokens: 3, return_full_text: false }); + compare(output, [new_text_target]); + }); + + it("chat input (single)", async () => { + const output = await pipe(chat_input, { max_new_tokens: 3 }); + compare(output, chat_target); + }); + it("chat input (list)", async () => { + const output = await pipe([chat_input], { max_new_tokens: 3 }); + compare(output, [chat_target]); + }); + }); - // Example adapted from https://huggingface.co/docs/transformers/en/model_doc/groupvit - const urls = [ - 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/white-image.png', - 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/blue-image.png', + // TODO: Fix batch_size>1 + // describe('batch_size>1', () => { + // it('default', async () => { + // const output = await pipe(['hello', 'hello world']); + // const target = [ + // [{generated_text: 'helloerdingsAndroid Load'}], + // [{generated_text: 'hello world zerosMillнал'}], + // ]; + // compare(output, target); + // }); + // }); + + afterAll(async () => { + await pipe?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("object-detection", () => { + const model_id = "hf-internal-testing/tiny-random-DetrForObjectDetection"; + const urls = ["https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/white-image.png", "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/blue-image.png"]; + + /** @type {ImageClassificationPipeline} */ + let pipe; + beforeAll(async () => { + pipe = await pipeline("object-detection", model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + }, MAX_MODEL_LOAD_TIME); + + describe("batch_size=1", () => { + it("default (threshold unset)", async () => { + const output = await pipe(urls[0]); + const target = []; + compare(output, target, 1e-5); + }); + it("default (threshold=0)", async () => { + const output = await pipe(urls[0], { threshold: 0 }); + const target = [ + { score: 0.020360443741083145, label: "LABEL_31", box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, + { score: 0.020360419526696205, label: "LABEL_31", box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, + { score: 0.02036038413643837, label: "LABEL_31", box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, + { score: 0.020360447466373444, label: "LABEL_31", box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, + { score: 0.020360389724373817, label: "LABEL_31", box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, + { score: 0.020360423251986504, label: "LABEL_31", box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, + { score: 0.02036040835082531, label: "LABEL_31", box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, + { score: 0.020360363647341728, label: "LABEL_31", box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, + { score: 0.020360389724373817, label: "LABEL_31", box: { xmin: 56, ymin: 55, xmax: 
169, ymax: 167 } }, + { score: 0.020360389724373817, label: "LABEL_31", box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, + { score: 0.020360343158245087, label: "LABEL_31", box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, + { score: 0.020360423251986504, label: "LABEL_31", box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, ]; - const labels = ['cat', 'dog']; - const hypothesis_template = 'a photo of a {}'; + compare(output, target, 1e-5); + }); + }); + // TODO: Add batched support to object detection pipeline + // describe('batch_size>1', () => { + // it('default (threshold unset)', async () => { + // const output = await pipe(urls); + // console.log(output); + // const target = []; + // compare(output, target, 1e-5); + // }); + // it('default (threshold=0)', async () => { + // const output = await pipe(urls, { threshold: 0 }); + // console.log(output); + // const target = []; + // compare(output, target, 1e-5); + // }); + // }); + + afterAll(async () => { + await pipe?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); +}); - /** @type {ZeroShotImageClassificationPipeline} */ - let pipe; - beforeAll(async () => { - pipe = await pipeline('zero-shot-image-classification', model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - }, MAX_MODEL_LOAD_TIME); - - describe('batch_size=1', () => { - it('default', async () => { - const output = await pipe(urls[0], labels); - const target = [ - { score: 0.5990662574768066, label: 'cat' }, - { score: 0.40093377232551575, label: 'dog' } - ] - compare(output, target, 1e-5); - }); - it('custom (w/ hypothesis_template)', async () => { - const output = await pipe(urls[0], labels, { hypothesis_template }); - const target = [ - { score: 0.5527022480964661, label: 'cat' }, - { score: 0.44729775190353394, label: 'dog' } - ] - compare(output, target, 1e-5); - }); +describe("PKV caching", () => { + describe("LlamaForCausalLM", () => { + const model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"; + /** @type {LlamaForCausalLM} */ + let model; + /** @type {LlamaTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await LlamaForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await LlamaTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("1"); + + // Generate first sequence w/o PKV + // NOTE: `return_dict_in_generate=true` is required to get PKV + const { past_key_values, sequences } = await model.generate({ + ...inputs, + max_new_tokens: 5, + do_sample: false, + return_dict_in_generate: true, }); - describe('batch_size>1', () => { - it('default', async () => { - const output = await pipe(urls, labels); - const target = [ - [ - { score: 0.5990662574768066, label: 'cat' }, - { score: 0.40093377232551575, label: 'dog' } - ], - [ - { score: 0.5006340146064758, label: 'dog' }, - { score: 0.49936598539352417, label: 'cat' } - ] - ] - compare(output, target, 1e-5); - }); - it('custom (w/ hypothesis_template)', async () => { - const output = await pipe(urls, labels, { hypothesis_template }); - const target = [ - [ - { score: 0.5527022480964661, label: 'cat' }, - { score: 0.44729775190353394, label: 'dog' } - ], - [ - { score: 0.5395973324775696, label: 'cat' }, - { score: 0.46040263772010803, label: 'dog' } - ] - ] - compare(output, target, 1e-5); - }); + // Update output with new text + const decoded = tokenizer.batch_decode(sequences, { + skip_special_tokens: false, + })[0]; + const 
new_inputs = tokenizer(decoded + "2", { + add_special_tokens: false, }); - afterAll(async () => { - await pipe?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); - }); - - describe('audio-classification', () => { - const model_id = 'hf-internal-testing/tiny-random-unispeech'; - const audios = [ - new Float32Array(16000).fill(0), - Float32Array.from({ length: 16000 }, (_, i) => i), - ] - - /** @type {ImageClassificationPipeline} */ - let pipe; - beforeAll(async () => { - pipe = await pipeline('audio-classification', model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - }, MAX_MODEL_LOAD_TIME); - - describe('batch_size=1', () => { - - it('default (top_k=5)', async () => { - const output = await pipe(audios[0]); - const target = [ - { score: 0.5043687224388123, label: 'LABEL_0' }, - { score: 0.4956313371658325, label: 'LABEL_1' } - ] - compare(output, target, 1e-5); - }); - it('custom (top_k=1)', async () => { - const output = await pipe(audios[0], { top_k: 1 }); - const target = [{ score: 0.5043687224388123, label: 'LABEL_0' }] - compare(output, target, 1e-5); - }); + // Run w/o PKV + const generated_ids = await model.generate({ + ...new_inputs, + max_new_tokens: 3, + do_sample: false, }); - describe('batch_size>1', () => { - it('default (top_k=5)', async () => { - const output = await pipe(audios); - const target = [ - [ - { score: 0.5043687224388123, label: 'LABEL_0' }, - { score: 0.4956313371658325, label: 'LABEL_1' } - ], - [ - { score: 0.5187293887138367, label: 'LABEL_0' }, - { score: 0.4812707006931305, label: 'LABEL_1' } - ] - ] - compare(output, target, 1e-5); - }); - it('custom (top_k=1)', async () => { - const output = await pipe(audios, { top_k: 1 }); - const target = [ - [{ score: 0.5043687224388123, label: 'LABEL_0' }], - [{ score: 0.5187293887138367, label: 'LABEL_0' }] - ] - compare(output, target, 1e-5); - }); + // Run w/ PKV + const generated_ids_pkv = await model.generate({ + ...new_inputs, + past_key_values, + max_new_tokens: 3, + do_sample: false, }); - afterAll(async () => { - await pipe?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); - }); - - describe('text-generation', () => { - const model_id = 'hf-internal-testing/tiny-random-LlamaForCausalLM'; - - /** @type {TextGenerationPipeline} */ - let pipe; - beforeAll(async () => { - pipe = await pipeline('text-generation', model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - }, MAX_MODEL_LOAD_TIME); - - describe('batch_size=1', () => { - const text_input = 'hello'; - const generated_text_target = 'erdingsAndroid Load'; - const text_target = [{ generated_text: text_input + generated_text_target }] - const new_text_target = [{ generated_text: generated_text_target }] - - const chat_input = [ - { role: 'system', content: 'a' }, - { role: 'user', content: 'b' }, - ] - const chat_target = [{ - generated_text: [ - { role: 'system', 'content': 'a' }, - { role: 'user', 'content': 'b' }, - { role: 'assistant', 'content': ' Southern abund Load' }, - ], - }] - - it('text input (single)', async () => { - const output = await pipe(text_input, { max_new_tokens: 3 }); - compare(output, text_target); - }); - it('text input (list)', async () => { - const output = await pipe([text_input], { max_new_tokens: 3 }); - compare(output, [text_target]); - }); - - it('text input (single) - return_full_text=false', async () => { - const output = await pipe(text_input, { max_new_tokens: 3, return_full_text: false }); - compare(output, new_text_target); - }); - it('text input (list) - return_full_text=false', async () => { - 
const output = await pipe([text_input], { max_new_tokens: 3, return_full_text: false }); - compare(output, [new_text_target]); - }); + const target = [[1n, 259n, 29896n, 24959n, 22063n, 17192n, 12189n, 22468n, 29906n, 3399n, 24823n, 26470n]]; + + expect(generated_ids.tolist()).toEqual(target); + expect(generated_ids_pkv.tolist()).toEqual(target); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + + describe("LlavaForConditionalGeneration", () => { + const model_id = "Xenova/tiny-random-LlavaForConditionalGeneration"; + /** @type {LlavaForConditionalGeneration} */ + let model; + /** @type {PreTrainedTokenizer} */ + let tokenizer; + /** @type {Processor} */ + let processor; + beforeAll(async () => { + model = await LlavaForConditionalGeneration.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await AutoTokenizer.from_pretrained(model_id); + processor = await AutoProcessor.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const text_inputs = tokenizer("hello"); - it('chat input (single)', async () => { - const output = await pipe(chat_input, { max_new_tokens: 3 }); - compare(output, chat_target); - }); - it('chat input (list)', async () => { - const output = await pipe([chat_input], { max_new_tokens: 3 }); - compare(output, [chat_target]); - }); + // Empty white image + const dims = [224, 224, 3]; + const image = new RawImage(new Uint8ClampedArray(dims[0] * dims[1] * dims[2]).fill(255), ...dims); + const vision_inputs = await processor(image); + + // Generate first sequence w/o PKV + // NOTE: `return_dict_in_generate=true` is required to get PKV + const { past_key_values, sequences } = await model.generate({ + ...text_inputs, + ...vision_inputs, + max_new_tokens: 5, + do_sample: false, + return_dict_in_generate: true, }); - // TODO: Fix batch_size>1 - // describe('batch_size>1', () => { - // it('default', async () => { - // const output = await pipe(['hello', 'hello world']); - // const target = [ - // [{generated_text: 'helloerdingsAndroid Load'}], - // [{generated_text: 'hello world zerosMillнал'}], - // ]; - // compare(output, target); - // }); - // }); - - afterAll(async () => { - await pipe?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); - }); - - describe('object-detection', () => { - const model_id = 'hf-internal-testing/tiny-random-DetrForObjectDetection'; - const urls = [ - 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/white-image.png', - 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/blue-image.png', - ]; - - /** @type {ImageClassificationPipeline} */ - let pipe; - beforeAll(async () => { - pipe = await pipeline('object-detection', model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - }, MAX_MODEL_LOAD_TIME); - - describe('batch_size=1', () => { - - it('default (threshold unset)', async () => { - const output = await pipe(urls[0]); - const target = []; - compare(output, target, 1e-5); - }); - it('default (threshold=0)', async () => { - const output = await pipe(urls[0], { threshold: 0 }); - const target = [ - { score: 0.020360443741083145, label: 'LABEL_31', box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, - { score: 0.020360419526696205, label: 'LABEL_31', box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, - { score: 0.02036038413643837, label: 'LABEL_31', box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, - { score: 
0.020360447466373444, label: 'LABEL_31', box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, - { score: 0.020360389724373817, label: 'LABEL_31', box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, - { score: 0.020360423251986504, label: 'LABEL_31', box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, - { score: 0.02036040835082531, label: 'LABEL_31', box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, - { score: 0.020360363647341728, label: 'LABEL_31', box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, - { score: 0.020360389724373817, label: 'LABEL_31', box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, - { score: 0.020360389724373817, label: 'LABEL_31', box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, - { score: 0.020360343158245087, label: 'LABEL_31', box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } }, - { score: 0.020360423251986504, label: 'LABEL_31', box: { xmin: 56, ymin: 55, xmax: 169, ymax: 167 } } - ]; - compare(output, target, 1e-5); - }); + // Update output with new text + const decoded = tokenizer.batch_decode(sequences).map((x) => x + "new"); + const new_inputs = tokenizer(decoded, { + add_special_tokens: false, }); - // TODO: Add batched support to object detection pipeline - // describe('batch_size>1', () => { - // it('default (threshold unset)', async () => { - // const output = await pipe(urls); - // console.log(output); - // const target = []; - // compare(output, target, 1e-5); - // }); - // it('default (threshold=0)', async () => { - // const output = await pipe(urls, { threshold: 0 }); - // console.log(output); - // const target = []; - // compare(output, target, 1e-5); - // }); - // }); - - afterAll(async () => { - await pipe?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); - }); -}); - - -describe('PKV caching', () => { - describe('LlamaForCausalLM', () => { - const model_id = 'hf-internal-testing/tiny-random-LlamaForCausalLM'; - /** @type {LlamaForCausalLM} */ - let model; - /** @type {LlamaTokenizer} */ - let tokenizer; - beforeAll(async () => { - model = await LlamaForCausalLM.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await LlamaTokenizer.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const inputs = tokenizer('1'); - - // Generate first sequence w/o PKV - // NOTE: `return_dict_in_generate=true` is required to get PKV - const { past_key_values, sequences } = await model.generate({ - ...inputs, - max_new_tokens: 5, - do_sample: false, - return_dict_in_generate: true, - }); - - // Update output with new text - const decoded = tokenizer.batch_decode(sequences, { - skip_special_tokens: false - })[0]; - const new_inputs = tokenizer(decoded + '2', { - add_special_tokens: false, - }); - - // Run w/o PKV - const generated_ids = await model.generate({ - ...new_inputs, - max_new_tokens: 3, - do_sample: false, - }); - - // Run w/ PKV - const generated_ids_pkv = await model.generate({ - ...new_inputs, - past_key_values, - max_new_tokens: 3, - do_sample: false, - }); - - const target = [[1n, 259n, 29896n, 24959n, 22063n, 17192n, 12189n, 22468n, 29906n, 3399n, 24823n, 26470n]]; - - expect(generated_ids.tolist()).toEqual(target); - expect(generated_ids_pkv.tolist()).toEqual(target); - - }, MAX_TEST_EXECUTION_TIME); - - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); - }); - - describe('LlavaForConditionalGeneration', () => { - const model_id = 'Xenova/tiny-random-LlavaForConditionalGeneration'; - /** @type {LlavaForConditionalGeneration} */ 
- let model; - /** @type {PreTrainedTokenizer} */ - let tokenizer; - /** @type {Processor} */ - let processor; - beforeAll(async () => { - model = await LlavaForConditionalGeneration.from_pretrained(model_id, { - // TODO move to config - ...DEFAULT_MODEL_OPTIONS, - }); - tokenizer = await AutoTokenizer.from_pretrained(model_id); - processor = await AutoProcessor.from_pretrained(model_id); - }, MAX_MODEL_LOAD_TIME); - - it('batch_size=1', async () => { - const text_inputs = tokenizer('hello'); - - // Empty white image - const dims = [224, 224, 3]; - const image = new RawImage(new Uint8ClampedArray(dims[0] * dims[1] * dims[2]).fill(255), ...dims); - const vision_inputs = await processor(image); - - // Generate first sequence w/o PKV - // NOTE: `return_dict_in_generate=true` is required to get PKV - const { past_key_values, sequences } = await model.generate({ - ...text_inputs, - ...vision_inputs, - max_new_tokens: 5, - do_sample: false, - return_dict_in_generate: true, - }); - - // Update output with new text - const decoded = tokenizer.batch_decode(sequences).map(x => x + 'new'); - const new_inputs = tokenizer(decoded, { - add_special_tokens: false, - }); - - // Run w/o PKV - const generated_ids = await model.generate({ - ...new_inputs, - ...vision_inputs, - max_new_tokens: 3, - do_sample: false, - }); - - // Run w/ PKV - const generated_ids_pkv = await model.generate({ - ...new_inputs, - past_key_values, - max_new_tokens: 3, - do_sample: false, - }); - - const target = [[1n, 32000n, 29871n, 23927n, 359n, 1519n, 568n, 5769n, 1330n, 21544n, 11568n, 1482n, 7258n, 1250n, 16117n]]; - expect(generated_ids.tolist()).toEqual(target); - expect(generated_ids_pkv.tolist()).toEqual(target); + // Run w/o PKV + const generated_ids = await model.generate({ + ...new_inputs, + ...vision_inputs, + max_new_tokens: 3, + do_sample: false, + }); - }, MAX_TEST_EXECUTION_TIME); + // Run w/ PKV + const generated_ids_pkv = await model.generate({ + ...new_inputs, + past_key_values, + max_new_tokens: 3, + do_sample: false, + }); - afterAll(async () => { - await model?.dispose(); - }, MAX_MODEL_DISPOSE_TIME); - }); + const target = [[1n, 32000n, 29871n, 23927n, 359n, 1519n, 568n, 5769n, 1330n, 21544n, 11568n, 1482n, 7258n, 1250n, 16117n]]; + expect(generated_ids.tolist()).toEqual(target); + expect(generated_ids_pkv.tolist()).toEqual(target); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); }); diff --git a/tests/tokenizers.test.js b/tests/tokenizers.test.js index d2154b58a..c8c47961a 100644 --- a/tests/tokenizers.test.js +++ b/tests/tokenizers.test.js @@ -1,745 +1,801 @@ +import { AutoTokenizer, WhisperTokenizer } from "../src/tokenizers.js"; +import * as TOKENIZER_TESTS from "./models/all_tokenization_tests.js"; -import { AutoTokenizer, WhisperTokenizer } from '../src/tokenizers.js'; -import * as TOKENIZER_TESTS from './models/all_tokenization_tests.js'; - -import { compare } from './test_utils.js'; +import { compare } from "./test_utils.js"; const MAX_LOAD_TIME = 10_000; const MAX_EXECUTION_TIME = 10_000; -describe('Tokenizers (model-specific)', () => { - for (const [tokenizer_name, { TOKENIZER_CLASS, TEST_CONFIG }] of Object.entries(TOKENIZER_TESTS)) { - describe(tokenizer_name, () => { - for (const model_id in TEST_CONFIG) { - describe(model_id, () => { - let tokenizer; - beforeAll(async () => { - tokenizer = await TOKENIZER_CLASS.from_pretrained(model_id); - }, MAX_LOAD_TIME); - - for (const [test_name, test_case] of 
Object.entries(TEST_CONFIG[model_id])) { - test(test_name, () => { - const ids = tokenizer.encode(test_case.text); - expect(ids).toEqual(test_case.ids); - const tokens = tokenizer.tokenize(test_case.text); - expect(tokens).toEqual(test_case.tokens); - const decoded = tokenizer.decode(test_case.ids); - expect(decoded).toEqual(test_case.decoded); - }); - } - }); - } +describe("Tokenizers (model-specific)", () => { + for (const [tokenizer_name, { TOKENIZER_CLASS, TEST_CONFIG }] of Object.entries(TOKENIZER_TESTS)) { + describe(tokenizer_name, () => { + for (const model_id in TEST_CONFIG) { + describe(model_id, () => { + let tokenizer; + beforeAll(async () => { + tokenizer = await TOKENIZER_CLASS.from_pretrained(model_id); + }, MAX_LOAD_TIME); + + for (const [test_name, test_case] of Object.entries(TEST_CONFIG[model_id])) { + test(test_name, () => { + const ids = tokenizer.encode(test_case.text); + expect(ids).toEqual(test_case.ids); + const tokens = tokenizer.tokenize(test_case.text); + expect(tokens).toEqual(test_case.tokens); + const decoded = tokenizer.decode(test_case.ids); + expect(decoded).toEqual(test_case.decoded); + }); + } }); - } + } + }); + } }); // Tests to ensure that no matter what, the correct tokenization is returned. // This is necessary since there are sometimes bugs in the transformers library. -describe('Tokenizers (hard-coded)', () => { - const TESTS = { - 'Xenova/llama-tokenizer': [ // Test legacy compatibility - { - // legacy unset => legacy=true - // NOTE: While incorrect, it is necessary to match legacy behaviour - data: { - "\n": [1, 29871, 13], - }, - legacy: null, - }, - { - // override legacy=true (same results as above) - data: { - "\n": [1, 29871, 13], - }, - legacy: true, - }, - { - // override legacy=false (fixed results) - data: { - "\n": [1, 13], - }, - legacy: false, +describe("Tokenizers (hard-coded)", () => { + const TESTS = { + "Xenova/llama-tokenizer": [ + // Test legacy compatibility + { + // legacy unset => legacy=true + // NOTE: While incorrect, it is necessary to match legacy behaviour + data: { + "\n": [1, 29871, 13], + }, + legacy: null, + }, + { + // override legacy=true (same results as above) + data: { + "\n": [1, 29871, 13], + }, + legacy: true, + }, + { + // override legacy=false (fixed results) + data: { + "\n": [1, 13], + }, + legacy: false, + }, + ], + + "Xenova/llama-tokenizer_new": [ + // legacy=false + { + data: { + " 1 2 3 4 ": [259, 2, 29871, 29896, 259, 29906, 1678, 29941, 268, 29946, 1678], + "\n": [1, 13], + "test": [2, 1688, 2], + " test ": [259, 2, 1243, 29871, 2, 29871], + "A\n'll": [319, 13, 29915, 645], + "Hey . how are you": [18637, 29871, 2, 29889, 920, 526, 366], + " Hi Hello ": [259, 6324, 29871, 15043, 259], + }, + reversible: true, + legacy: null, + }, + { + // override legacy=true (incorrect results, but necessary to match legacy behaviour) + data: { + "\n": [1, 29871, 13], + }, + legacy: true, + }, + ], + + // legacy=false + "Xenova/t5-tokenizer-new": [ + { + data: { + // https://github.com/huggingface/transformers/pull/26678 + // ['▁Hey', '▁', '', '.', '▁how', '▁are', '▁you'] + "Hey . 
how are you": [9459, 3, 1, 5, 149, 33, 25], + }, + reversible: true, + legacy: null, + }, + { + data: { + "\n": [1, 3], + "A\n'll": [71, 3, 31, 195], + }, + reversible: false, + legacy: null, + }, + ], + }; + + // Re-use the same tests for the llama2 tokenizer + TESTS["Xenova/llama2-tokenizer"] = TESTS["Xenova/llama-tokenizer_new"]; + + for (const [tokenizerName, test_data] of Object.entries(TESTS)) { + it( + tokenizerName, + async () => { + for (const { data, reversible, legacy } of test_data) { + const tokenizer = await AutoTokenizer.from_pretrained(tokenizerName, { legacy }); + + for (const [text, expected] of Object.entries(data)) { + const token_ids = tokenizer.encode(text, { add_special_tokens: false }); + expect(token_ids).toEqual(expected); + + // If reversible, test that decoding produces the original text + if (reversible) { + const decoded = tokenizer.decode(token_ids); + expect(decoded).toEqual(text); } - ], + } + } + }, + MAX_EXECUTION_TIME, + ); + } +}); - 'Xenova/llama-tokenizer_new': [ // legacy=false - { - data: { - " 1 2 3 4 ": [259, 2, 29871, 29896, 259, 29906, 1678, 29941, 268, 29946, 1678], - "\n": [1, 13], - "test": [2, 1688, 2], - " test ": [259, 2, 1243, 29871, 2, 29871], - "A\n'll": [319, 13, 29915, 645], - "Hey . how are you": [18637, 29871, 2, 29889, 920, 526, 366], - " Hi Hello ": [259, 6324, 29871, 15043, 259], - }, - reversible: true, - legacy: null, - }, - { // override legacy=true (incorrect results, but necessary to match legacy behaviour) - data: { - "\n": [1, 29871, 13], - }, - legacy: true, - }, +describe("Tokenizer padding/truncation", () => { + const inputs = ["a", "b c"]; + const text_pair = ["d e", "f g h"]; + + it("should create a jagged array", async () => { + const tokenizer = await AutoTokenizer.from_pretrained("Xenova/bert-base-uncased"); + + { + // support jagged array if `return_tensor=false` + const output = tokenizer(inputs, { + return_tensor: false, + }); + const expected = { + input_ids: [ + [101, 1037, 102], + [101, 1038, 1039, 102], ], - - // legacy=false - 'Xenova/t5-tokenizer-new': [ - { - data: { - // https://github.com/huggingface/transformers/pull/26678 - // ['▁Hey', '▁', '', '.', '▁how', '▁are', '▁you'] - "Hey . 
how are you": [9459, 3, 1, 5, 149, 33, 25], - }, - reversible: true, - legacy: null, - }, - { - data: { - "\n": [1, 3], - "A\n'll": [71, 3, 31, 195], - }, - reversible: false, - legacy: null, - } + attention_mask: [ + [1, 1, 1], + [1, 1, 1, 1], ], + token_type_ids: [ + [0, 0, 0], + [0, 0, 0, 0], + ], + }; + compare(output, expected); } - // Re-use the same tests for the llama2 tokenizer - TESTS['Xenova/llama2-tokenizer'] = TESTS['Xenova/llama-tokenizer_new']; - - for (const [tokenizerName, test_data] of Object.entries(TESTS)) { - - it(tokenizerName, async () => { - for (const { data, reversible, legacy } of test_data) { - const tokenizer = await AutoTokenizer.from_pretrained(tokenizerName, { legacy }); - - for (const [text, expected] of Object.entries(data)) { - const token_ids = tokenizer.encode(text, { add_special_tokens: false }); - expect(token_ids).toEqual(expected); - - // If reversible, test that decoding produces the original text - if (reversible) { - const decoded = tokenizer.decode(token_ids); - expect(decoded).toEqual(text); - } - } - } - }, MAX_EXECUTION_TIME); + { + const output = tokenizer(inputs, { + return_tensor: false, + truncation: true, + add_special_tokens: false, + }); + const expected = { + input_ids: [[1037], [1038, 1039]], + attention_mask: [[1], [1, 1]], + token_type_ids: [[0], [0, 0]], + }; + compare(output, expected); } -}); - -describe('Tokenizer padding/truncation', () => { - const inputs = ['a', 'b c']; - const text_pair = ['d e', 'f g h']; - - it('should create a jagged array', async () => { - const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); - - { // support jagged array if `return_tensor=false` - const output = tokenizer(inputs, { - return_tensor: false, - }) - const expected = { - input_ids: [[101, 1037, 102], [101, 1038, 1039, 102]], - attention_mask: [[1, 1, 1], [1, 1, 1, 1]], - token_type_ids: [[0, 0, 0], [0, 0, 0, 0]] - } - compare(output, expected); - } - - { - const output = tokenizer(inputs, { - return_tensor: false, - truncation: true, - add_special_tokens: false, - }) - const expected = { - input_ids: [[1037], [1038, 1039]], - attention_mask: [[1], [1, 1]], - token_type_ids: [[0], [0, 0]] - } - compare(output, expected); - } - }) - - it('should create a tensor', async () => { - const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); - - { // Expected to throw error if jagged array - expect(() => tokenizer(inputs)).toThrowError('Unable to create tensor'); - } - - { // Truncation - const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, { - truncation: true, - max_length: 1, - add_special_tokens: false, - }) + }); + + it( + "should create a tensor", + async () => { + const tokenizer = await AutoTokenizer.from_pretrained("Xenova/bert-base-uncased"); + + { + // Expected to throw error if jagged array + expect(() => tokenizer(inputs)).toThrowError("Unable to create tensor"); + } + + { + // Truncation + const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, { + truncation: true, + max_length: 1, + add_special_tokens: false, + }); - expect(input_ids.tolist()).toEqual([[1037n], [1038n]]) - expect(attention_mask.tolist()).toEqual([[1n], [1n]]) - expect(token_type_ids.tolist()).toEqual([[0n], [0n]]) - } - { // Truncation w/ text pair - // TODO - } + expect(input_ids.tolist()).toEqual([[1037n], [1038n]]); + expect(attention_mask.tolist()).toEqual([[1n], [1n]]); + expect(token_type_ids.tolist()).toEqual([[0n], [0n]]); + } + { + // Truncation w/ text pair + // TODO + 
} + + { + // Padding + const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, { + padding: true, + add_special_tokens: false, + }); - { // Padding - const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, { - padding: true, - add_special_tokens: false, - }) + expect(input_ids.tolist()).toEqual([ + [1037n, 0n], + [1038n, 1039n], + ]); + expect(attention_mask.tolist()).toEqual([ + [1n, 0n], + [1n, 1n], + ]); + expect(token_type_ids.tolist()).toEqual([ + [0n, 0n], + [0n, 0n], + ]); + } + { + // Padding w/ text pair + const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, { + text_pair, + padding: true, + add_special_tokens: false, + }); - expect(input_ids.tolist()).toEqual([[1037n, 0n], [1038n, 1039n]]) - expect(attention_mask.tolist()).toEqual([[1n, 0n], [1n, 1n]]) - expect(token_type_ids.tolist()).toEqual([[0n, 0n], [0n, 0n]]) - } - { // Padding w/ text pair - const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, { - text_pair, - padding: true, - add_special_tokens: false, - }) - - expect(input_ids.tolist()).toEqual([ - [1037n, 1040n, 1041n, 0n, 0n], - [1038n, 1039n, 1042n, 1043n, 1044n], - ]); - expect(attention_mask.tolist()).toEqual([ - [1n, 1n, 1n, 0n, 0n], - [1n, 1n, 1n, 1n, 1n], - ]); - expect(token_type_ids.tolist()).toEqual([ - [0n, 1n, 1n, 0n, 0n], - [0n, 0n, 1n, 1n, 1n], - ]); - } + expect(input_ids.tolist()).toEqual([ + [1037n, 1040n, 1041n, 0n, 0n], + [1038n, 1039n, 1042n, 1043n, 1044n], + ]); + expect(attention_mask.tolist()).toEqual([ + [1n, 1n, 1n, 0n, 0n], + [1n, 1n, 1n, 1n, 1n], + ]); + expect(token_type_ids.tolist()).toEqual([ + [0n, 1n, 1n, 0n, 0n], + [0n, 0n, 1n, 1n, 1n], + ]); + } + + { + // Truncation + padding + const { input_ids, attention_mask, token_type_ids } = tokenizer(["a", "b c", "d e f"], { + padding: true, + truncation: true, + add_special_tokens: false, + max_length: 2, + }); - { // Truncation + padding - const { input_ids, attention_mask, token_type_ids } = tokenizer(['a', 'b c', 'd e f'], { - padding: true, - truncation: true, - add_special_tokens: false, - max_length: 2, - }) - - expect(input_ids.tolist()).toEqual([[1037n, 0n], [1038n, 1039n], [1040n, 1041n]]) - expect(attention_mask.tolist()).toEqual([[1n, 0n], [1n, 1n], [1n, 1n]]) - expect(token_type_ids.tolist()).toEqual([[0n, 0n], [0n, 0n], [0n, 0n]]) - } - }, MAX_EXECUTION_TIME); + expect(input_ids.tolist()).toEqual([ + [1037n, 0n], + [1038n, 1039n], + [1040n, 1041n], + ]); + expect(attention_mask.tolist()).toEqual([ + [1n, 0n], + [1n, 1n], + [1n, 1n], + ]); + expect(token_type_ids.tolist()).toEqual([ + [0n, 0n], + [0n, 0n], + [0n, 0n], + ]); + } + }, + MAX_EXECUTION_TIME, + ); }); -describe('Token type ids', () => { - it('should correctly add token type ids', async () => { - const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); - - const model_inputs = tokenizer( - ['a b c', 'd'], - { - text_pair: ['e f', 'g h'], - padding: true, - truncation: true, - return_tensor: false, - } - ); - +describe("Token type ids", () => { + it( + "should correctly add token type ids", + async () => { + const tokenizer = await AutoTokenizer.from_pretrained("Xenova/bert-base-uncased"); + + const model_inputs = tokenizer(["a b c", "d"], { + text_pair: ["e f", "g h"], + padding: true, + truncation: true, + return_tensor: false, + }); + + const expected = { + input_ids: [ + [101, 1037, 1038, 1039, 102, 1041, 1042, 102], + [101, 1040, 102, 1043, 1044, 102, 0, 0], + ], + token_type_ids: [ + [0, 0, 0, 0, 0, 1, 1, 1], + [0, 0, 0, 
1, 1, 1, 0, 0], + ], + attention_mask: [ + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 0, 0], + ], + }; + + compare(model_inputs, expected); + }, + MAX_EXECUTION_TIME, + ); + + it( + "should add token type ids if user requests them", + async () => { + const tokenizer = await AutoTokenizer.from_pretrained("Xenova/llama3-tokenizer-new"); + + { + // Without text pair + const model_inputs = tokenizer("hello", { + return_tensor: false, + return_token_type_ids: true, + }); const expected = { - input_ids: [ - [101, 1037, 1038, 1039, 102, 1041, 1042, 102], - [101, 1040, 102, 1043, 1044, 102, 0, 0], - ], - token_type_ids: [ - [0, 0, 0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1, 0, 0], - ], - attention_mask: [ - [1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 0, 0], - ], - } - + input_ids: [128000, 15339], + attention_mask: [1, 1], + token_type_ids: [0, 0], + }; compare(model_inputs, expected); - - }, MAX_EXECUTION_TIME); - - it('should add token type ids if user requests them', async () => { - const tokenizer = await AutoTokenizer.from_pretrained('Xenova/llama3-tokenizer-new'); - - { // Without text pair - const model_inputs = tokenizer( - 'hello', - { - return_tensor: false, - return_token_type_ids: true, - } - ); - const expected = { - input_ids: [128000, 15339], - attention_mask: [1, 1], - token_type_ids: [0, 0] - } - compare(model_inputs, expected); - } - - { // With text pair - const model_inputs = tokenizer( - 'hello', - { - text_pair: 'world', - return_tensor: false, - return_token_type_ids: true, - } - ); - const expected = { - input_ids: [128000, 15339, 128000, 14957], - attention_mask: [1, 1, 1, 1], - token_type_ids: [0, 0, 1, 1] - } - compare(model_inputs, expected); - } - - }, MAX_EXECUTION_TIME); + } + + { + // With text pair + const model_inputs = tokenizer("hello", { + text_pair: "world", + return_tensor: false, + return_token_type_ids: true, + }); + const expected = { + input_ids: [128000, 15339, 128000, 14957], + attention_mask: [1, 1, 1, 1], + token_type_ids: [0, 0, 1, 1], + }; + compare(model_inputs, expected); + } + }, + MAX_EXECUTION_TIME, + ); }); -describe('Edge cases', () => { - it('should not crash when encoding a very long string', async () => { - let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small'); - - let text = String.prototype.repeat.call('Hello world! ', 50000); - let encoded = tokenizer(text); - expect(encoded.input_ids.data.length).toBeGreaterThan(100000); - }, MAX_EXECUTION_TIME); - - it('should not take too long', async () => { - let tokenizer = await AutoTokenizer.from_pretrained('Xenova/all-MiniLM-L6-v2'); - - let text = String.prototype.repeat.call('a', 50000); - let token_ids = tokenizer.encode(text); - compare(token_ids, [101, 100, 102]) - }, 5000); // NOTE: 5 seconds - - it('Special/added tokens with earlier partial matches', async () => { - let tokenizer = await AutoTokenizer.from_pretrained('Xenova/gemini-nano'); - { - let token_ids = tokenizer.encode('\n', { add_special_tokens: false }); - compare(token_ids, [108]) - } - { - let token_ids = tokenizer.encode('\n\n', { add_special_tokens: false }); - compare(token_ids, [109]) // Should not be [108, 108] - } - }, MAX_EXECUTION_TIME); +describe("Edge cases", () => { + it( + "should not crash when encoding a very long string", + async () => { + let tokenizer = await AutoTokenizer.from_pretrained("Xenova/t5-small"); + + let text = String.prototype.repeat.call("Hello world! 
", 50000); + let encoded = tokenizer(text); + expect(encoded.input_ids.data.length).toBeGreaterThan(100000); + }, + MAX_EXECUTION_TIME, + ); + + it("should not take too long", async () => { + let tokenizer = await AutoTokenizer.from_pretrained("Xenova/all-MiniLM-L6-v2"); + + let text = String.prototype.repeat.call("a", 50000); + let token_ids = tokenizer.encode(text); + compare(token_ids, [101, 100, 102]); + }, 5000); // NOTE: 5 seconds + + it( + "Special/added tokens with earlier partial matches", + async () => { + let tokenizer = await AutoTokenizer.from_pretrained("Xenova/gemini-nano"); + { + let token_ids = tokenizer.encode("\n", { add_special_tokens: false }); + compare(token_ids, [108]); + } + { + let token_ids = tokenizer.encode("\n\n", { add_special_tokens: false }); + compare(token_ids, [109]); // Should not be [108, 108] + } + }, + MAX_EXECUTION_TIME, + ); }); -describe('Extra decoding tests', () => { - it('should be able to decode the output of encode', async () => { - let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); - - let text = 'hello world!'; - - // Ensure all the following outputs are the same: - // 1. Tensor of ids: allow decoding of 1D or 2D tensors. - let encodedTensor = tokenizer(text); - let decoded1 = tokenizer.decode(encodedTensor.input_ids, { skip_special_tokens: true }); - let decoded2 = tokenizer.batch_decode(encodedTensor.input_ids, { skip_special_tokens: true })[0]; - expect(decoded1).toEqual(text); - expect(decoded2).toEqual(text); - - // 2. List of ids - let encodedList = tokenizer(text, { return_tensor: false }); - let decoded3 = tokenizer.decode(encodedList.input_ids, { skip_special_tokens: true }); - let decoded4 = tokenizer.batch_decode([encodedList.input_ids], { skip_special_tokens: true })[0]; - expect(decoded3).toEqual(text); - expect(decoded4).toEqual(text); - - }, MAX_EXECUTION_TIME); +describe("Extra decoding tests", () => { + it( + "should be able to decode the output of encode", + async () => { + let tokenizer = await AutoTokenizer.from_pretrained("Xenova/bert-base-uncased"); + + let text = "hello world!"; + + // Ensure all the following outputs are the same: + // 1. Tensor of ids: allow decoding of 1D or 2D tensors. + let encodedTensor = tokenizer(text); + let decoded1 = tokenizer.decode(encodedTensor.input_ids, { skip_special_tokens: true }); + let decoded2 = tokenizer.batch_decode(encodedTensor.input_ids, { skip_special_tokens: true })[0]; + expect(decoded1).toEqual(text); + expect(decoded2).toEqual(text); + + // 2. List of ids + let encodedList = tokenizer(text, { return_tensor: false }); + let decoded3 = tokenizer.decode(encodedList.input_ids, { skip_special_tokens: true }); + let decoded4 = tokenizer.batch_decode([encodedList.input_ids], { skip_special_tokens: true })[0]; + expect(decoded3).toEqual(text); + expect(decoded4).toEqual(text); + }, + MAX_EXECUTION_TIME, + ); }); -describe('Chat templates', () => { - it('should generate a chat template', async () => { - const tokenizer = await AutoTokenizer.from_pretrained("Xenova/mistral-tokenizer-v1"); - - const chat = [ - { "role": "user", "content": "Hello, how are you?" }, - { "role": "assistant", "content": "I'm doing great. How can I help you today?" }, - { "role": "user", "content": "I'd like to show off how chat templating works!" 
}, - ] +describe("Chat templates", () => { + it("should generate a chat template", async () => { + const tokenizer = await AutoTokenizer.from_pretrained("Xenova/mistral-tokenizer-v1"); - const text = tokenizer.apply_chat_template(chat, { tokenize: false }); + const chat = [ + { role: "user", content: "Hello, how are you?" }, + { role: "assistant", content: "I'm doing great. How can I help you today?" }, + { role: "user", content: "I'd like to show off how chat templating works!" }, + ]; - expect(text).toEqual("[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]"); - - const input_ids = tokenizer.apply_chat_template(chat, { tokenize: true, return_tensor: false }); - compare(input_ids, [1, 733, 16289, 28793, 22557, 28725, 910, 460, 368, 28804, 733, 28748, 16289, 28793, 28737, 28742, 28719, 2548, 1598, 28723, 1602, 541, 315, 1316, 368, 3154, 28804, 2, 28705, 733, 16289, 28793, 315, 28742, 28715, 737, 298, 1347, 805, 910, 10706, 5752, 1077, 3791, 28808, 733, 28748, 16289, 28793]) - }); + const text = tokenizer.apply_chat_template(chat, { tokenize: false }); - it('should support multiple chat templates', async () => { - - const tokenizer = await AutoTokenizer.from_pretrained("Xenova/c4ai-command-r-v01-tokenizer") - - // define conversation input: - const conversation = [ - { role: "user", content: "Whats the biggest penguin in the world?" } - ] - // define documents to ground on: - const documents = [ - { title: "Tall penguins", text: "Emperor penguins are the tallest growing up to 122 cm in height." }, - { title: "Penguin habitats", text: "Emperor penguins only live in Antarctica." } - ] - - // render the RAG prompt as a string: - const grounded_generation_prompt = tokenizer.apply_chat_template( - conversation, - { - chat_template: "rag", - tokenize: false, - add_generation_prompt: true, - - documents, - citation_mode: "accurate", // or "fast" - } - ) - expect(grounded_generation_prompt).toEqual( - "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble\nThe instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.\n\n" + - "# System Preamble\n## Basic Rules\nYou are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.\n\n" + - "# User Preamble\n## Task and Context\nYou help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. 
You should focus on serving the user's needs as best you can, which will be wide-ranging.\n\n## Style Guide\nUnless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.<|END_OF_TURN_TOKEN|>" + - "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>Whats the biggest penguin in the world?<|END_OF_TURN_TOKEN|>" + - "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>\nDocument: 0\ntitle: Tall penguins\ntext: Emperor penguins are the tallest growing up to 122 cm in height.\n\nDocument: 1\ntitle: Penguin habitats\ntext: Emperor penguins only live in Antarctica.\n<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Carefully perform the following instructions, in order, starting each with a new line.\nFirstly, Decide which of the retrieved documents are relevant to the user's last input by writing 'Relevant Documents:' followed by comma-separated list of document numbers. If none are relevant, you should instead write 'None'.\nSecondly, Decide which of the retrieved documents contain facts that should be cited in a good answer to the user's last input by writing 'Cited Documents:' followed a comma-separated list of document numbers. If you dont want to cite any of them, you should instead write 'None'.\nThirdly, Write 'Answer:' followed by a response to the user's last input in high quality natural english. Use the retrieved documents to help you. Do not insert any citations or grounding markup.\nFinally, Write 'Grounded answer:' followed by a response to the user's last input in high quality natural english. Use the symbols and to indicate when a fact comes from a document in the search result, e.g my fact for a fact from document 0.<|END_OF_TURN_TOKEN|>" + - "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" - ); - }); + expect(text).toEqual("[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]"); - it('should support user-defined chat template', async () => { - const tokenizer = await AutoTokenizer.from_pretrained("Xenova/llama-tokenizer"); - - const chat = [ - { role: 'user', content: 'Hello, how are you?' }, - { role: 'assistant', content: "I'm doing great. How can I help you today?" }, - { role: 'user', content: "I'd like to show off how chat templating works!" 
}, - ] - - // https://discuss.huggingface.co/t/issue-with-llama-2-chat-template-and-out-of-date-documentation/61645/3 - const chat_template = ( - "{% if messages[0]['role'] == 'system' %}" + - "{% set loop_messages = messages[1:] %}" + // Extract system message if it's present - "{% set system_message = messages[0]['content'] %}" + - "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" + - "{% set loop_messages = messages %}" + // Or use the default system message if the flag is set - "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" + - "{% else %}" + - "{% set loop_messages = messages %}" + - "{% set system_message = false %}" + - "{% endif %}" + - "{% if loop_messages|length == 0 and system_message %}" + // Special handling when only sys message present - "{{ bos_token + '[INST] <>\\n' + system_message + '\\n<>\\n\\n [/INST]' }}" + - "{% endif %}" + - "{% for message in loop_messages %}" + // Loop over all non-system messages - "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + - "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + - "{% endif %}" + - "{% if loop.index0 == 0 and system_message != false %}" + // Embed system message in first message - "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" + - "{% else %}" + - "{% set content = message['content'] %}" + - "{% endif %}" + - "{% if message['role'] == 'user' %}" + // After all of that, handle messages/roles in a fairly normal way - "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" + - "{% elif message['role'] == 'system' %}" + - "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" + - "{% elif message['role'] == 'assistant' %}" + - "{{ ' ' + content.strip() + ' ' + eos_token }}" + - "{% endif %}" + - "{% endfor %}" - ) - .replaceAll('USE_DEFAULT_PROMPT', true) - .replaceAll('DEFAULT_SYSTEM_MESSAGE', 'You are a helpful, respectful and honest assistant.'); - - const text = tokenizer.apply_chat_template(chat, { tokenize: false, return_tensor: false, chat_template }); - - expect(text).toEqual("[INST] <>\nYou are a helpful, respectful and honest assistant.\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]"); - - // TODO: Add test for token_ids once bug in transformers is fixed. - }); + const input_ids = tokenizer.apply_chat_template(chat, { tokenize: true, return_tensor: false }); + compare(input_ids, [1, 733, 16289, 28793, 22557, 28725, 910, 460, 368, 28804, 733, 28748, 16289, 28793, 28737, 28742, 28719, 2548, 1598, 28723, 1602, 541, 315, 1316, 368, 3154, 28804, 2, 28705, 733, 16289, 28793, 315, 28742, 28715, 737, 298, 1347, 805, 910, 10706, 5752, 1077, 3791, 28808, 733, 28748, 16289, 28793]); + }); - it('should support default parameters', async () => { - const tokenizer = await AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer"); + it("should support multiple chat templates", async () => { + const tokenizer = await AutoTokenizer.from_pretrained("Xenova/c4ai-command-r-v01-tokenizer"); - // Example adapted from https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct#tool-use-with-transformers - const chat = [ - { "role": "system", "content": "You are a bot that responds to weather queries." }, - { "role": "user", "content": "Hey, what's the temperature in Paris right now?" 
} - ] - const tools = [ - { 'type': 'function', 'function': { 'name': 'get_current_temperature', 'description': 'Get the current temperature at a location.', 'parameters': { 'type': 'object', 'properties': { 'location': { 'type': 'string', 'description': 'The location to get the temperature for, in the format "City, Country"' } }, 'required': ['location'] }, 'return': { 'type': 'number', 'description': 'The current temperature at the specified location in the specified units, as a float.' } } }, - ] + // define conversation input: + const conversation = [{ role: "user", content: "Whats the biggest penguin in the world?" }]; + // define documents to ground on: + const documents = [ + { title: "Tall penguins", text: "Emperor penguins are the tallest growing up to 122 cm in height." }, + { title: "Penguin habitats", text: "Emperor penguins only live in Antarctica." }, + ]; - { // `tools` unset (will default to `null`) - const text = tokenizer.apply_chat_template(chat, { tokenize: false }); - expect(text).toEqual("<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a bot that responds to weather queries.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHey, what's the temperature in Paris right now?<|eot_id|>"); + // render the RAG prompt as a string: + const grounded_generation_prompt = tokenizer.apply_chat_template(conversation, { + chat_template: "rag", + tokenize: false, + add_generation_prompt: true, - const input_ids = tokenizer.apply_chat_template(chat, { tokenize: true, return_tensor: false }); - compare(input_ids, [128000, 128006, 9125, 128007, 271, 38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1627, 10263, 220, 2366, 19, 271, 2675, 527, 264, 11164, 430, 31680, 311, 9282, 20126, 13, 128009, 128006, 882, 128007, 271, 19182, 11, 1148, 596, 279, 9499, 304, 12366, 1314, 1457, 30, 128009]) - } + documents, + citation_mode: "accurate", // or "fast" + }); + expect(grounded_generation_prompt).toEqual("<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble\nThe instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.\n\n" + "# System Preamble\n## Basic Rules\nYou are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.\n\n" + "# User Preamble\n## Task and Context\nYou help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. 
You should focus on serving the user's needs as best you can, which will be wide-ranging.\n\n## Style Guide\nUnless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.<|END_OF_TURN_TOKEN|>" + "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>Whats the biggest penguin in the world?<|END_OF_TURN_TOKEN|>" + "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>\nDocument: 0\ntitle: Tall penguins\ntext: Emperor penguins are the tallest growing up to 122 cm in height.\n\nDocument: 1\ntitle: Penguin habitats\ntext: Emperor penguins only live in Antarctica.\n<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Carefully perform the following instructions, in order, starting each with a new line.\nFirstly, Decide which of the retrieved documents are relevant to the user's last input by writing 'Relevant Documents:' followed by comma-separated list of document numbers. If none are relevant, you should instead write 'None'.\nSecondly, Decide which of the retrieved documents contain facts that should be cited in a good answer to the user's last input by writing 'Cited Documents:' followed a comma-separated list of document numbers. If you dont want to cite any of them, you should instead write 'None'.\nThirdly, Write 'Answer:' followed by a response to the user's last input in high quality natural english. Use the retrieved documents to help you. Do not insert any citations or grounding markup.\nFinally, Write 'Grounded answer:' followed by a response to the user's last input in high quality natural english. Use the symbols and to indicate when a fact comes from a document in the search result, e.g my fact for a fact from document 0.<|END_OF_TURN_TOKEN|>" + "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"); + }); + + it("should support user-defined chat template", async () => { + const tokenizer = await AutoTokenizer.from_pretrained("Xenova/llama-tokenizer"); + + const chat = [ + { role: "user", content: "Hello, how are you?" }, + { role: "assistant", content: "I'm doing great. How can I help you today?" }, + { role: "user", content: "I'd like to show off how chat templating works!" 
}, + ]; + + // https://discuss.huggingface.co/t/issue-with-llama-2-chat-template-and-out-of-date-documentation/61645/3 + const chat_template = ( + "{% if messages[0]['role'] == 'system' %}" + + "{% set loop_messages = messages[1:] %}" + // Extract system message if it's present + "{% set system_message = messages[0]['content'] %}" + + "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" + + "{% set loop_messages = messages %}" + // Or use the default system message if the flag is set + "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" + + "{% else %}" + + "{% set loop_messages = messages %}" + + "{% set system_message = false %}" + + "{% endif %}" + + "{% if loop_messages|length == 0 and system_message %}" + // Special handling when only sys message present + "{{ bos_token + '[INST] <>\\n' + system_message + '\\n<>\\n\\n [/INST]' }}" + + "{% endif %}" + + "{% for message in loop_messages %}" + // Loop over all non-system messages + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + + "{% endif %}" + + "{% if loop.index0 == 0 and system_message != false %}" + // Embed system message in first message + "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" + + "{% else %}" + + "{% set content = message['content'] %}" + + "{% endif %}" + + "{% if message['role'] == 'user' %}" + // After all of that, handle messages/roles in a fairly normal way + "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" + + "{% elif message['role'] == 'system' %}" + + "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" + + "{% elif message['role'] == 'assistant' %}" + + "{{ ' ' + content.strip() + ' ' + eos_token }}" + + "{% endif %}" + + "{% endfor %}" + ) + .replaceAll("USE_DEFAULT_PROMPT", true) + .replaceAll("DEFAULT_SYSTEM_MESSAGE", "You are a helpful, respectful and honest assistant."); + + const text = tokenizer.apply_chat_template(chat, { tokenize: false, return_tensor: false, chat_template }); + + expect(text).toEqual("[INST] <>\nYou are a helpful, respectful and honest assistant.\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]"); + + // TODO: Add test for token_ids once bug in transformers is fixed. + }); + + it("should support default parameters", async () => { + const tokenizer = await AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer"); + + // Example adapted from https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct#tool-use-with-transformers + const chat = [ + { role: "system", content: "You are a bot that responds to weather queries." }, + { role: "user", content: "Hey, what's the temperature in Paris right now?" }, + ]; + const tools = [{ type: "function", function: { name: "get_current_temperature", description: "Get the current temperature at a location.", parameters: { type: "object", properties: { location: { type: "string", description: 'The location to get the temperature for, in the format "City, Country"' } }, required: ["location"] }, return: { type: "number", description: "The current temperature at the specified location in the specified units, as a float." 
} } }]; + + { + // `tools` unset (will default to `null`) + const text = tokenizer.apply_chat_template(chat, { tokenize: false }); + expect(text).toEqual("<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a bot that responds to weather queries.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHey, what's the temperature in Paris right now?<|eot_id|>"); + + const input_ids = tokenizer.apply_chat_template(chat, { tokenize: true, return_tensor: false }); + compare(input_ids, [128000, 128006, 9125, 128007, 271, 38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1627, 10263, 220, 2366, 19, 271, 2675, 527, 264, 11164, 430, 31680, 311, 9282, 20126, 13, 128009, 128006, 882, 128007, 271, 19182, 11, 1148, 596, 279, 9499, 304, 12366, 1314, 1457, 30, 128009]); + } - { // `tools` set - const text = tokenizer.apply_chat_template(chat, { tools, tokenize: false }); - expect(text).toEqual("<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a bot that responds to weather queries.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"type\": \"function\",\n \"function\": {\n \"name\": \"get_current_temperature\",\n \"description\": \"Get the current temperature at a location.\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"location\": {\n \"type\": \"string\",\n \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\"\"\n }\n },\n \"required\": [\n \"location\"\n ]\n },\n \"return\": {\n \"type\": \"number\",\n \"description\": \"The current temperature at the specified location in the specified units, as a float.\"\n }\n }\n}\n\nHey, what's the temperature in Paris right now?<|eot_id|>"); + { + // `tools` set + const text = tokenizer.apply_chat_template(chat, { tools, tokenize: false }); + expect(text).toEqual('<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a bot that responds to weather queries.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {"name": function name, "parameters": dictionary of argument name and its value}.Do not use variables.\n\n{\n "type": "function",\n "function": {\n "name": "get_current_temperature",\n "description": "Get the current temperature at a location.",\n "parameters": {\n "type": "object",\n "properties": {\n "location": {\n "type": "string",\n "description": "The location to get the temperature for, in the format \\"City, Country\\""\n }\n },\n "required": [\n "location"\n ]\n },\n "return": {\n "type": "number",\n "description": "The current temperature at the specified location in the specified units, as a float."\n }\n }\n}\n\nHey, what\'s the temperature in Paris right now?<|eot_id|>'); - const input_ids = tokenizer.apply_chat_template(chat, { tools, tokenize: true, return_tensor: false }); - compare(input_ids, [128000, 128006, 9125, 
128007, 271, 13013, 25, 6125, 27993, 198, 38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1627, 10263, 220, 2366, 19, 271, 2675, 527, 264, 11164, 430, 31680, 311, 9282, 20126, 13, 128009, 128006, 882, 128007, 271, 22818, 279, 2768, 5865, 11, 4587, 6013, 449, 264, 4823, 369, 264, 734, 1650, 449, 1202, 6300, 6105, 430, 1888, 11503, 279, 2728, 10137, 382, 66454, 304, 279, 3645, 5324, 609, 794, 734, 836, 11, 330, 14105, 794, 11240, 315, 5811, 836, 323, 1202, 907, 7966, 5519, 539, 1005, 7482, 382, 517, 262, 330, 1337, 794, 330, 1723, 761, 262, 330, 1723, 794, 341, 286, 330, 609, 794, 330, 456, 11327, 54625, 761, 286, 330, 4789, 794, 330, 1991, 279, 1510, 9499, 520, 264, 3813, 10560, 286, 330, 14105, 794, 341, 310, 330, 1337, 794, 330, 1735, 761, 310, 330, 13495, 794, 341, 394, 330, 2588, 794, 341, 504, 330, 1337, 794, 330, 928, 761, 504, 330, 4789, 794, 330, 791, 3813, 311, 636, 279, 9499, 369, 11, 304, 279, 3645, 7393, 13020, 11, 14438, 2153, 702, 394, 457, 310, 1173, 310, 330, 6413, 794, 2330, 394, 330, 2588, 702, 310, 5243, 286, 1173, 286, 330, 693, 794, 341, 310, 330, 1337, 794, 330, 4174, 761, 310, 330, 4789, 794, 330, 791, 1510, 9499, 520, 279, 5300, 3813, 304, 279, 5300, 8316, 11, 439, 264, 2273, 10246, 286, 457, 262, 457, 633, 19182, 11, 1148, 596, 279, 9499, 304, 12366, 1314, 1457, 30, 128009]) - } - }); + const input_ids = tokenizer.apply_chat_template(chat, { tools, tokenize: true, return_tensor: false }); + compare(input_ids, [128000, 128006, 9125, 128007, 271, 13013, 25, 6125, 27993, 198, 38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1627, 10263, 220, 2366, 19, 271, 2675, 527, 264, 11164, 430, 31680, 311, 9282, 20126, 13, 128009, 128006, 882, 128007, 271, 22818, 279, 2768, 5865, 11, 4587, 6013, 449, 264, 4823, 369, 264, 734, 1650, 449, 1202, 6300, 6105, 430, 1888, 11503, 279, 2728, 10137, 382, 66454, 304, 279, 3645, 5324, 609, 794, 734, 836, 11, 330, 14105, 794, 11240, 315, 5811, 836, 323, 1202, 907, 7966, 5519, 539, 1005, 7482, 382, 517, 262, 330, 1337, 794, 330, 1723, 761, 262, 330, 1723, 794, 341, 286, 330, 609, 794, 330, 456, 11327, 54625, 761, 286, 330, 4789, 794, 330, 1991, 279, 1510, 9499, 520, 264, 3813, 10560, 286, 330, 14105, 794, 341, 310, 330, 1337, 794, 330, 1735, 761, 310, 330, 13495, 794, 341, 394, 330, 2588, 794, 341, 504, 330, 1337, 794, 330, 928, 761, 504, 330, 4789, 794, 330, 791, 3813, 311, 636, 279, 9499, 369, 11, 304, 279, 3645, 7393, 13020, 11, 14438, 2153, 702, 394, 457, 310, 1173, 310, 330, 6413, 794, 2330, 394, 330, 2588, 702, 310, 5243, 286, 1173, 286, 330, 693, 794, 341, 310, 330, 1337, 794, 330, 4174, 761, 310, 330, 4789, 794, 330, 791, 1510, 9499, 520, 279, 5300, 3813, 304, 279, 5300, 8316, 11, 439, 264, 2273, 10246, 286, 457, 262, 457, 633, 19182, 11, 1148, 596, 279, 9499, 304, 12366, 1314, 1457, 30, 128009]); + } + }); }); -describe('Decode ASR', () => { - it('should decode ASR outputs', async () => { - const tokenizer = await WhisperTokenizer.from_pretrained('onnx-community/whisper-tiny.en_timestamped'); - - const model_outputs = [ - { - stride: [30, 0, 5], - tokens: [50257n, 50362n, 8410n, 7283n, 0n, 2329n, 8410n, 7283n, 0n, 2094n, 470n, 1309n, 534n, 10625n, 307n, 10625n, 13n, 34668n, 345n, 531n, 9439n, 11n, 523n, 655n, 8410n, 7283n, 0n, 39134n, 16592n, 10625n, 0n, 9440n, 36n, 26751n, 0n, 25848n, 8410n, 7283n, 0n, 2773n, 661n, 4320n, 1943n, 981n, 345n, 821n, 8066n, 7765n, 510n, 290n, 670n, 1327n, 379n, 340n, 13n, 10528n, 318n, 5340n, 0n, 50256n], - token_timestamps: [0, 0, 0, 3.78, 4.22, 
5.34, 6.04, 6.56, 7, 7.92, 8.58, 8.58, 8.88, 9.14, 9.54, 9.94, 10.58, 11.38, 11.88, 12.42, 12.62, 13, 13.36, 13.64, 14.26, 14.76, 15.12, 15.4, 15.74, 16.12, 16.66, 17.14, 17.24, 17.24, 17.72, 18.38, 18.6, 19.38, 19.92, 22.66, 22.9, 23.24, 23.5, 24.14, 24.56, 24.7, 24.72, 24.94, 25.18, 25.54, 25.72, 26.02, 26.34, 26.44, 26.84, 27.04, 27.16, 27.54, 28.06, 29.92] - }, - { - stride: [30, 5, 5], - tokens: [50257n, 50362n, 2773n, 661n, 4320n, 1943n, 981n, 345n, 821n, 8066n, 7765n, 510n, 290n, 670n, 1327n, 379n, 340n, 13n, 10528n, 318n, 5340n, 13n, 921n, 815n, 651n, 284n, 262n, 966n, 810n, 2687n, 2073n, 561n, 11238n, 290n, 345n, 821n, 407n, 8066n, 2245n, 612n, 13n, 1400n, 11n, 644n, 389n, 345n, 4953n, 329n, 30n, 2141n, 340n, 0n, 2329n, 466n, 340n, 0n, 3363n, 345n, 460n, 0n, 2329n, 466n, 340n, 0n, 50256n], - token_timestamps: [0, 0, 0, 2.92, 3.24, 3.48, 4.14, 4.56, 4.7, 4.74, 4.92, 5.18, 5.54, 5.72, 6.04, 6.34, 6.46, 6.84, 7.04, 7.16, 7.54, 8.12, 10.16, 10.7, 10.9, 11.12, 11.24, 11.48, 11.84, 12.44, 12.82, 13.2, 13.46, 13.72, 14.06, 14.28, 14.34, 14.56, 14.8, 15.16, 15.9, 16.42, 16.82, 16.86, 17.02, 17.1, 17.22, 17.56, 18.06, 19.28, 19.62, 20.26, 21.96, 22.64, 24.28, 24.76, 25.18, 25.56, 25.78, 26.28, 27.12, 27.54, 27.82, 28.22, 29.48] - }, - { - stride: [23.7728125, 5, 0], - tokens: [50257n, 50362n, 2329n, 466n, 340n, 0n, 3363n, 345n, 460n, 0n, 2329n, 466n, 340n, 0n, 1002n, 345n, 821n, 10032n, 286n, 3599n, 625n, 11n, 2245n, 3501n, 510n, 13n, 50256n], - token_timestamps: [0, 0, 0, 2.44, 4.3, 5.04, 5.06, 5.56, 5.8, 6.32, 7.12, 7.56, 7.8, 8.9, 10.92, 12.96, 13.28, 13.28, 13.44, 13.72, 13.96, 14.84, 15.5, 16.06, 16.86, 17.88, 20.92] - } - ]; - - const target = [ - " DO IT! Just DO IT! Don't let your dreams be dreams. Yesterday you said tomorrow, so just DO IT! MAKE YOUR dreams! COME TRUE! JUST DO IT! Some people dream success while you're gonna wake up and work hard at it. Nothing is impossible. You should get to the point where anyone else would quit and you're not gonna stop there. No, what are you waiting for? Do it! Just do it! Yes you can! Just do it! 
If you're tired of starting over, stop giving up.", - { - chunks: [ - { text: ' DO', timestamp: [0.0, 3.78] }, - { text: ' IT!', timestamp: [3.78, 4.24 /* 5.34 */] }, - { text: ' Just', timestamp: [5.34, 6.04] }, - { text: ' DO', timestamp: [6.04, 6.56] }, - { text: ' IT!', timestamp: [6.56, 7.02 /* 7.92 */] }, - { text: " Don't", timestamp: [7.92, 8.58] }, - { text: ' let', timestamp: [8.58, 8.88] }, - { text: ' your', timestamp: [8.88, 9.14] }, - { text: ' dreams', timestamp: [9.14, 9.54] }, - { text: ' be', timestamp: [9.54, 9.94] }, - { text: ' dreams.', timestamp: [9.94, 10.6 /* 11.38 */] }, - { text: ' Yesterday', timestamp: [11.38, 11.88] }, - { text: ' you', timestamp: [11.88, 12.42] }, - { text: ' said', timestamp: [12.42, 12.62] }, - { text: ' tomorrow,', timestamp: [12.62, 13.02 /* 13.36 */] }, - { text: ' so', timestamp: [13.36, 13.64] }, - { text: ' just', timestamp: [13.64, 14.26] }, - { text: ' DO', timestamp: [14.26, 14.76] }, - { text: ' IT!', timestamp: [14.76, 15.14 /* 15.4 */] }, - { text: ' MAKE', timestamp: [15.4, 15.74] }, - { text: ' YOUR', timestamp: [15.74, 16.12] }, - { text: ' dreams!', timestamp: [16.12, 16.68 /* 17.14 */] }, - { text: ' COME', timestamp: [17.14, 17.24] }, - { text: ' TRUE!', timestamp: [17.24, 17.74 /* 18.38 */] }, - { text: ' JUST', timestamp: [18.38, 18.6] }, - { text: ' DO', timestamp: [18.6, 19.38] }, - { text: ' IT!', timestamp: [19.38, 19.94 /* 22.66 */] }, - { text: ' Some', timestamp: [22.66, 22.9] }, - { text: ' people', timestamp: [22.9, 23.24] }, - { text: ' dream', timestamp: [23.24, 23.5] }, - { text: ' success', timestamp: [23.5, 24.14] }, - { text: ' while', timestamp: [24.14, 24.56] }, - { text: " you're", timestamp: [24.56, 24.72] }, - { text: ' gonna', timestamp: [24.72, 24.94] }, - { text: ' wake', timestamp: [24.94, 25.18] }, - { text: ' up', timestamp: [25.18, 25.54] }, - { text: ' and', timestamp: [25.54, 25.72] }, - { text: ' work', timestamp: [25.72, 26.04] }, - { text: ' hard', timestamp: [26.04, 26.34] }, - { text: ' at', timestamp: [26.34, 26.46] }, - { text: ' it.', timestamp: [26.46, 26.86 /* 27.04 */] }, - { text: ' Nothing', timestamp: [27.04, 27.16] }, - { text: ' is', timestamp: [27.16, 27.54] }, - { text: ' impossible.', timestamp: [27.54, 28.14 /* 30.16 */] }, - { text: ' You', timestamp: [30.16, 30.7] }, - { text: ' should', timestamp: [30.7, 30.9] }, - { text: ' get', timestamp: [30.9, 31.12] }, - { text: ' to', timestamp: [31.12, 31.24] }, - { text: ' the', timestamp: [31.24, 31.48] }, - { text: ' point', timestamp: [31.48, 31.84] }, - { text: ' where', timestamp: [31.84, 32.44] }, - { text: ' anyone', timestamp: [32.44, 32.82] }, - { text: ' else', timestamp: [32.82, 33.2] }, - { text: ' would', timestamp: [33.2, 33.46] }, - { text: ' quit', timestamp: [33.46, 33.72] }, - { text: ' and', timestamp: [33.72, 34.06] }, - { text: " you're", timestamp: [34.06, 34.34] }, - { text: ' not', timestamp: [34.34, 34.56] }, - { text: ' gonna', timestamp: [34.56, 34.8] }, - { text: ' stop', timestamp: [34.8, 35.16] }, - { text: ' there.', timestamp: [35.16, 35.92 /* 36.42 */] }, - { text: ' No,', timestamp: [36.42, 36.84 /* 36.86 */] }, - { text: ' what', timestamp: [36.86, 37.02] }, - { text: ' are', timestamp: [37.02, 37.1] }, - { text: ' you', timestamp: [37.1, 37.22] }, - { text: ' waiting', timestamp: [37.22, 37.56] }, - { text: ' for?', timestamp: [37.56, 38.08 /* 39.28 */] }, - { text: ' Do', timestamp: [39.28, 39.62] }, - { text: ' it!', timestamp: [39.62, 40.28 /* 41.96 */] }, - { text: ' Just', timestamp: 
[41.96, 42.64] }, - { text: ' do', timestamp: [42.64, 44.28] }, - { text: ' it!', timestamp: [44.28, 44.78 /* 45.18 */] }, - { text: ' Yes', timestamp: [45.18, 45.56] }, - { text: ' you', timestamp: [45.56, 45.78] }, - { text: ' can!', timestamp: [45.8, 46.34 /* 47.12 */] }, - { text: ' Just', timestamp: [47.12, 47.56] }, - { text: ' do', timestamp: [47.56, 47.8] }, - { text: ' it!', timestamp: [47.8, 48.92 /* 50.92 */] }, - { text: ' If', timestamp: [50.92, 52.96] }, - { text: " you're", timestamp: [52.96, 53.28] }, - { text: ' tired', timestamp: [53.28, 53.44] }, - { text: ' of', timestamp: [53.44, 53.72] }, - { text: ' starting', timestamp: [53.72, 53.96] }, - { text: ' over,', timestamp: [53.96, 54.86 /* 55.5 */] }, - { text: ' stop', timestamp: [55.5, 56.06] }, - { text: ' giving', timestamp: [56.06, 56.86] }, - { text: ' up.', timestamp: [56.86, 57.9 /* 60.92 */] } - ] - } - ] - - compare(tokenizer._decode_asr(model_outputs, { - return_timestamps: 'word', - time_precision: 0.02, - force_full_sequences: false, - }), target, 1e-2); - - }, MAX_EXECUTION_TIME); - - it('should handle overlapping edge case', async () => { - const tokenizer = await WhisperTokenizer.from_pretrained('onnx-community/whisper-tiny.en_timestamped'); - - const model_outputs = [ - { - stride: [30, 0, 5], - tokens: [50257n, 50362n, 8410n, 7283n, 0n, 2329n, 8410n, 7283n, 0n, 2094n, 470n, 1309n, 534n, 10625n, 307n, 10625n, 13n, 34668n, 11n, 345n, 531n, 9439n, 11n, 523n, 655n, 8410n, 7283n, 0n, 39134n, 16592n, 10560n, 3955n, 50n, 0n, 7102n, 5446n, 46n, 0n, 25848n, 8410n, 7283n, 0n, 2773n, 661n, 4320n, 1943n, 981n, 345n, 821n, 8066n, 7765n, 510n, 290n, 670n, 1327n, 379n, 340n, 13n, 10528n, 318n, 5340n, 13n, 50256n], - token_timestamps: [0, 0, 0, 3.78, 4.22, 5.26, 6.04, 6.54, 7, 7.94, 8.58, 8.58, 8.88, 9.16, 9.54, 9.94, 10.6, 11.38, 11.88, 12.38, 12.44, 12.62, 13, 13.36, 13.64, 14.24, 14.74, 15.12, 15.4, 15.74, 16.1, 16.54, 16.54, 16.78, 17.08, 17.2, 17.36, 17.56, 18.08, 18.58, 19.38, 19.88, 22.54, 22.9, 23.24, 23.5, 24.14, 24.56, 24.7, 24.94, 24.94, 25.18, 25.54, 25.72, 26.04, 26.34, 26.46, 26.84, 27.04, 27.14, 27.54, 28.06, 29.92] - }, - { - stride: [30, 5, 5], - tokens: [50257n, 50362n, 2773n, 661n, 4320n, 1943n, 981n, 345n, 821n, 8066n, 7765n, 510n, 290n, 670n, 1327n, 379n, 340n, 13n, 10528n, 318n, 5340n, 13n, 921n, 815n, 651n, 284n, 262n, 966n, 810n, 2687n, 2073n, 561n, 11238n, 290n, 345n, 821n, 407n, 8066n, 2245n, 612n, 13n, 1400n, 11n, 644n, 389n, 345n, 4953n, 329n, 30n, 2141n, 340n, 0n, 2329n, 466n, 340n, 0n, 3363n, 11n, 345n, 460n, 0n, 2329n, 466n, 340n, 0n, 50256n], - token_timestamps: [0, 0, 0, 2.92, 3.24, 3.5, 4.14, 4.56, 4.7, 4.74, 4.92, 5.18, 5.54, 5.74, 6.04, 6.34, 6.46, 6.84, 7.04, 7.18, 7.56, 8.12, 9.68, 10.7, 10.88, 11.1, 11.24, 11.48, 11.82, 12.46, 12.82, 13.2, 13.46, 13.72, 14.08, 14.28, 14.34, 14.56, 14.82, 15.16, 15.72, 16.42, 16.82, 16.86, 17, 17.1, 17.2, 17.56, 18.06, 19.28, 19.6, 20.28, 21.96, 22.64, 24.28, 24.76, 25.18, 25.56, 25.56, 25.84, 26.36, 27.12, 27.54, 27.82, 28.16, 29.48] - }, - { - stride: [23.7728125, 5, 0], - tokens: [50257n, 50362n, 2329n, 466n, 340n, 0n, 3363n, 345n, 460n, 0n, 2329n, 466n, 340n, 0n, 1002n, 534n, 15867n, 318n, 3599n, 625n, 11n, 2245n, 3501n, 510n, 13n, 50256n], - token_timestamps: [0, 0, 0, 2.44, 4.3, 5.04, 5.06, 5.56, 5.8, 6.32, 7.12, 7.56, 7.8, 8.72, 10.04, 12.96, 13.3, 13.44, 13.72, 13.98, 14.86, 15.5, 16, 16.88, 17.76, 20.9] - } - ]; - - const target = [ - " DO IT! Just DO IT! Don't let your dreams be dreams. Yesterday, you said tomorrow, so just DO IT! 
MAKE YOUR DRIMS! CONTRO! JUST DO IT! Some people dream success while you're gonna wake up and work hard at it. Nothing is impossible. You should get to the point where anyone else would quit and you're not gonna stop there. No, what are you waiting for? Do it! Just do it! Yes, you can! Just do it! If your tire is starting over, stop giving up.", - { - chunks: [ - { text: " DO", timestamp: [0, 3.78] }, - { text: " IT!", timestamp: [3.78, 4.24] }, - { text: " Just", timestamp: [5.26, 6.04] }, - { text: " DO", timestamp: [6.04, 6.54] }, - { text: " IT!", timestamp: [6.54, 7.02] }, - { text: " Don't", timestamp: [7.94, 8.58] }, - { text: " let", timestamp: [8.58, 8.88] }, - { text: " your", timestamp: [8.88, 9.16] }, - { text: " dreams", timestamp: [9.16, 9.54] }, - { text: " be", timestamp: [9.54, 9.94] }, - { text: " dreams.", timestamp: [9.94, 10.62] }, - { text: " Yesterday,", timestamp: [11.38, 11.9] }, - { text: " you", timestamp: [12.38, 12.44] }, - { text: " said", timestamp: [12.44, 12.62] }, - { text: " tomorrow,", timestamp: [12.62, 13.02] }, - { text: " so", timestamp: [13.36, 13.64] }, - { text: " just", timestamp: [13.64, 14.24] }, - { text: " DO", timestamp: [14.24, 14.74] }, - { text: " IT!", timestamp: [14.74, 15.14] }, - { text: " MAKE", timestamp: [15.4, 15.74] }, - { text: " YOUR", timestamp: [15.74, 16.1] }, - { text: " DRIMS!", timestamp: [16.1, 16.8] }, - { text: " CONTRO!", timestamp: [17.08, 17.58] }, - { text: " JUST", timestamp: [18.08, 18.58] }, - { text: " DO", timestamp: [18.58, 19.38] }, - { text: " IT!", timestamp: [19.38, 19.9] }, - { text: " Some", timestamp: [22.54, 22.9] }, - { text: " people", timestamp: [22.9, 23.24] }, - { text: " dream", timestamp: [23.24, 23.5] }, - { text: " success", timestamp: [23.5, 24.14] }, - { text: " while", timestamp: [24.14, 24.56] }, - { text: " you're", timestamp: [24.56, 24.94] }, - { text: " gonna", timestamp: [24.94, 24.94] }, - { text: " wake", timestamp: [24.94, 25.18] }, - { text: " up", timestamp: [25.18, 25.54] }, - { text: " and", timestamp: [25.54, 25.74] }, - { text: " work", timestamp: [25.74, 26.04] }, - { text: " hard", timestamp: [26.04, 26.34] }, - { text: " at", timestamp: [26.34, 26.46] }, - { text: " it.", timestamp: [26.46, 26.86] }, - { text: " Nothing", timestamp: [27.04, 27.18] }, - { text: " is", timestamp: [27.18, 27.56] }, - { text: " impossible.", timestamp: [27.56, 28.14] }, - { text: " You", timestamp: [29.68, 30.7] }, - { text: " should", timestamp: [30.7, 30.88] }, - { text: " get", timestamp: [30.88, 31.1] }, - { text: " to", timestamp: [31.1, 31.24] }, - { text: " the", timestamp: [31.24, 31.48] }, - { text: " point", timestamp: [31.48, 31.82] }, - { text: " where", timestamp: [31.82, 32.46] }, - { text: " anyone", timestamp: [32.46, 32.82] }, - { text: " else", timestamp: [32.82, 33.2] }, - { text: " would", timestamp: [33.2, 33.46] }, - { text: " quit", timestamp: [33.46, 33.72] }, - { text: " and", timestamp: [33.72, 34.08] }, - { text: " you're", timestamp: [34.08, 34.34] }, - { text: " not", timestamp: [34.34, 34.56] }, - { text: " gonna", timestamp: [34.56, 34.82] }, - { text: " stop", timestamp: [34.82, 35.16] }, - { text: " there.", timestamp: [35.16, 35.74] }, - { text: " No,", timestamp: [36.42, 36.84] }, - { text: " what", timestamp: [36.86, 37] }, - { text: " are", timestamp: [37, 37.1] }, - { text: " you", timestamp: [37.1, 37.2] }, - { text: " waiting", timestamp: [37.2, 37.56] }, - { text: " for?", timestamp: [37.56, 38.08] }, - { text: " Do", timestamp: [39.28, 39.6] }, - { 
text: " it!", timestamp: [39.6, 40.3] }, - { text: " Just", timestamp: [41.96, 42.64] }, - { text: " do", timestamp: [42.64, 44.28] }, - { text: " it!", timestamp: [44.28, 44.78] }, - { text: " Yes,", timestamp: [45.18, 45.56] }, - { text: " you", timestamp: [45.56, 45.84] }, - { text: " can!", timestamp: [45.8, 46.34] }, - { text: " Just", timestamp: [47.12, 47.56] }, - { text: " do", timestamp: [47.56, 47.8] }, - { text: " it!", timestamp: [47.8, 48.74] }, - { text: " If", timestamp: [50.04, 52.96] }, - { text: " your", timestamp: [52.96, 53.3] }, - { text: " tire", timestamp: [53.3, 53.44] }, - { text: " is", timestamp: [53.44, 53.72] }, - { text: " starting", timestamp: [53.72, 53.98] }, - { text: " over,", timestamp: [53.98, 54.88] }, - { text: " stop", timestamp: [55.5, 56] }, - { text: " giving", timestamp: [56, 56.88] }, - { text: " up.", timestamp: [56.88, 57.78] }, - ] - } - ] - - compare(tokenizer._decode_asr(model_outputs, { - return_timestamps: 'word', - time_precision: 0.02, - force_full_sequences: false, - }), target, 1e-2); +describe("Decode ASR", () => { + it( + "should decode ASR outputs", + async () => { + const tokenizer = await WhisperTokenizer.from_pretrained("onnx-community/whisper-tiny.en_timestamped"); - }, MAX_EXECUTION_TIME); + const model_outputs = [ + { + stride: [30, 0, 5], + tokens: [50257n, 50362n, 8410n, 7283n, 0n, 2329n, 8410n, 7283n, 0n, 2094n, 470n, 1309n, 534n, 10625n, 307n, 10625n, 13n, 34668n, 345n, 531n, 9439n, 11n, 523n, 655n, 8410n, 7283n, 0n, 39134n, 16592n, 10625n, 0n, 9440n, 36n, 26751n, 0n, 25848n, 8410n, 7283n, 0n, 2773n, 661n, 4320n, 1943n, 981n, 345n, 821n, 8066n, 7765n, 510n, 290n, 670n, 1327n, 379n, 340n, 13n, 10528n, 318n, 5340n, 0n, 50256n], + token_timestamps: [0, 0, 0, 3.78, 4.22, 5.34, 6.04, 6.56, 7, 7.92, 8.58, 8.58, 8.88, 9.14, 9.54, 9.94, 10.58, 11.38, 11.88, 12.42, 12.62, 13, 13.36, 13.64, 14.26, 14.76, 15.12, 15.4, 15.74, 16.12, 16.66, 17.14, 17.24, 17.24, 17.72, 18.38, 18.6, 19.38, 19.92, 22.66, 22.9, 23.24, 23.5, 24.14, 24.56, 24.7, 24.72, 24.94, 25.18, 25.54, 25.72, 26.02, 26.34, 26.44, 26.84, 27.04, 27.16, 27.54, 28.06, 29.92], + }, + { + stride: [30, 5, 5], + tokens: [50257n, 50362n, 2773n, 661n, 4320n, 1943n, 981n, 345n, 821n, 8066n, 7765n, 510n, 290n, 670n, 1327n, 379n, 340n, 13n, 10528n, 318n, 5340n, 13n, 921n, 815n, 651n, 284n, 262n, 966n, 810n, 2687n, 2073n, 561n, 11238n, 290n, 345n, 821n, 407n, 8066n, 2245n, 612n, 13n, 1400n, 11n, 644n, 389n, 345n, 4953n, 329n, 30n, 2141n, 340n, 0n, 2329n, 466n, 340n, 0n, 3363n, 345n, 460n, 0n, 2329n, 466n, 340n, 0n, 50256n], + token_timestamps: [0, 0, 0, 2.92, 3.24, 3.48, 4.14, 4.56, 4.7, 4.74, 4.92, 5.18, 5.54, 5.72, 6.04, 6.34, 6.46, 6.84, 7.04, 7.16, 7.54, 8.12, 10.16, 10.7, 10.9, 11.12, 11.24, 11.48, 11.84, 12.44, 12.82, 13.2, 13.46, 13.72, 14.06, 14.28, 14.34, 14.56, 14.8, 15.16, 15.9, 16.42, 16.82, 16.86, 17.02, 17.1, 17.22, 17.56, 18.06, 19.28, 19.62, 20.26, 21.96, 22.64, 24.28, 24.76, 25.18, 25.56, 25.78, 26.28, 27.12, 27.54, 27.82, 28.22, 29.48], + }, + { + stride: [23.7728125, 5, 0], + tokens: [50257n, 50362n, 2329n, 466n, 340n, 0n, 3363n, 345n, 460n, 0n, 2329n, 466n, 340n, 0n, 1002n, 345n, 821n, 10032n, 286n, 3599n, 625n, 11n, 2245n, 3501n, 510n, 13n, 50256n], + token_timestamps: [0, 0, 0, 2.44, 4.3, 5.04, 5.06, 5.56, 5.8, 6.32, 7.12, 7.56, 7.8, 8.9, 10.92, 12.96, 13.28, 13.28, 13.44, 13.72, 13.96, 14.84, 15.5, 16.06, 16.86, 17.88, 20.92], + }, + ]; + + const target = [ + " DO IT! Just DO IT! Don't let your dreams be dreams. Yesterday you said tomorrow, so just DO IT! 
MAKE YOUR dreams! COME TRUE! JUST DO IT! Some people dream success while you're gonna wake up and work hard at it. Nothing is impossible. You should get to the point where anyone else would quit and you're not gonna stop there. No, what are you waiting for? Do it! Just do it! Yes you can! Just do it! If you're tired of starting over, stop giving up.", + { + chunks: [ + { text: " DO", timestamp: [0.0, 3.78] }, + { text: " IT!", timestamp: [3.78, 4.24 /* 5.34 */] }, + { text: " Just", timestamp: [5.34, 6.04] }, + { text: " DO", timestamp: [6.04, 6.56] }, + { text: " IT!", timestamp: [6.56, 7.02 /* 7.92 */] }, + { text: " Don't", timestamp: [7.92, 8.58] }, + { text: " let", timestamp: [8.58, 8.88] }, + { text: " your", timestamp: [8.88, 9.14] }, + { text: " dreams", timestamp: [9.14, 9.54] }, + { text: " be", timestamp: [9.54, 9.94] }, + { text: " dreams.", timestamp: [9.94, 10.6 /* 11.38 */] }, + { text: " Yesterday", timestamp: [11.38, 11.88] }, + { text: " you", timestamp: [11.88, 12.42] }, + { text: " said", timestamp: [12.42, 12.62] }, + { text: " tomorrow,", timestamp: [12.62, 13.02 /* 13.36 */] }, + { text: " so", timestamp: [13.36, 13.64] }, + { text: " just", timestamp: [13.64, 14.26] }, + { text: " DO", timestamp: [14.26, 14.76] }, + { text: " IT!", timestamp: [14.76, 15.14 /* 15.4 */] }, + { text: " MAKE", timestamp: [15.4, 15.74] }, + { text: " YOUR", timestamp: [15.74, 16.12] }, + { text: " dreams!", timestamp: [16.12, 16.68 /* 17.14 */] }, + { text: " COME", timestamp: [17.14, 17.24] }, + { text: " TRUE!", timestamp: [17.24, 17.74 /* 18.38 */] }, + { text: " JUST", timestamp: [18.38, 18.6] }, + { text: " DO", timestamp: [18.6, 19.38] }, + { text: " IT!", timestamp: [19.38, 19.94 /* 22.66 */] }, + { text: " Some", timestamp: [22.66, 22.9] }, + { text: " people", timestamp: [22.9, 23.24] }, + { text: " dream", timestamp: [23.24, 23.5] }, + { text: " success", timestamp: [23.5, 24.14] }, + { text: " while", timestamp: [24.14, 24.56] }, + { text: " you're", timestamp: [24.56, 24.72] }, + { text: " gonna", timestamp: [24.72, 24.94] }, + { text: " wake", timestamp: [24.94, 25.18] }, + { text: " up", timestamp: [25.18, 25.54] }, + { text: " and", timestamp: [25.54, 25.72] }, + { text: " work", timestamp: [25.72, 26.04] }, + { text: " hard", timestamp: [26.04, 26.34] }, + { text: " at", timestamp: [26.34, 26.46] }, + { text: " it.", timestamp: [26.46, 26.86 /* 27.04 */] }, + { text: " Nothing", timestamp: [27.04, 27.16] }, + { text: " is", timestamp: [27.16, 27.54] }, + { text: " impossible.", timestamp: [27.54, 28.14 /* 30.16 */] }, + { text: " You", timestamp: [30.16, 30.7] }, + { text: " should", timestamp: [30.7, 30.9] }, + { text: " get", timestamp: [30.9, 31.12] }, + { text: " to", timestamp: [31.12, 31.24] }, + { text: " the", timestamp: [31.24, 31.48] }, + { text: " point", timestamp: [31.48, 31.84] }, + { text: " where", timestamp: [31.84, 32.44] }, + { text: " anyone", timestamp: [32.44, 32.82] }, + { text: " else", timestamp: [32.82, 33.2] }, + { text: " would", timestamp: [33.2, 33.46] }, + { text: " quit", timestamp: [33.46, 33.72] }, + { text: " and", timestamp: [33.72, 34.06] }, + { text: " you're", timestamp: [34.06, 34.34] }, + { text: " not", timestamp: [34.34, 34.56] }, + { text: " gonna", timestamp: [34.56, 34.8] }, + { text: " stop", timestamp: [34.8, 35.16] }, + { text: " there.", timestamp: [35.16, 35.92 /* 36.42 */] }, + { text: " No,", timestamp: [36.42, 36.84 /* 36.86 */] }, + { text: " what", timestamp: [36.86, 37.02] }, + { text: " are", timestamp: [37.02, 
37.1] }, + { text: " you", timestamp: [37.1, 37.22] }, + { text: " waiting", timestamp: [37.22, 37.56] }, + { text: " for?", timestamp: [37.56, 38.08 /* 39.28 */] }, + { text: " Do", timestamp: [39.28, 39.62] }, + { text: " it!", timestamp: [39.62, 40.28 /* 41.96 */] }, + { text: " Just", timestamp: [41.96, 42.64] }, + { text: " do", timestamp: [42.64, 44.28] }, + { text: " it!", timestamp: [44.28, 44.78 /* 45.18 */] }, + { text: " Yes", timestamp: [45.18, 45.56] }, + { text: " you", timestamp: [45.56, 45.78] }, + { text: " can!", timestamp: [45.8, 46.34 /* 47.12 */] }, + { text: " Just", timestamp: [47.12, 47.56] }, + { text: " do", timestamp: [47.56, 47.8] }, + { text: " it!", timestamp: [47.8, 48.92 /* 50.92 */] }, + { text: " If", timestamp: [50.92, 52.96] }, + { text: " you're", timestamp: [52.96, 53.28] }, + { text: " tired", timestamp: [53.28, 53.44] }, + { text: " of", timestamp: [53.44, 53.72] }, + { text: " starting", timestamp: [53.72, 53.96] }, + { text: " over,", timestamp: [53.96, 54.86 /* 55.5 */] }, + { text: " stop", timestamp: [55.5, 56.06] }, + { text: " giving", timestamp: [56.06, 56.86] }, + { text: " up.", timestamp: [56.86, 57.9 /* 60.92 */] }, + ], + }, + ]; + + compare( + tokenizer._decode_asr(model_outputs, { + return_timestamps: "word", + time_precision: 0.02, + force_full_sequences: false, + }), + target, + 1e-2, + ); + }, + MAX_EXECUTION_TIME, + ); + + it( + "should handle overlapping edge case", + async () => { + const tokenizer = await WhisperTokenizer.from_pretrained("onnx-community/whisper-tiny.en_timestamped"); + + const model_outputs = [ + { + stride: [30, 0, 5], + tokens: [50257n, 50362n, 8410n, 7283n, 0n, 2329n, 8410n, 7283n, 0n, 2094n, 470n, 1309n, 534n, 10625n, 307n, 10625n, 13n, 34668n, 11n, 345n, 531n, 9439n, 11n, 523n, 655n, 8410n, 7283n, 0n, 39134n, 16592n, 10560n, 3955n, 50n, 0n, 7102n, 5446n, 46n, 0n, 25848n, 8410n, 7283n, 0n, 2773n, 661n, 4320n, 1943n, 981n, 345n, 821n, 8066n, 7765n, 510n, 290n, 670n, 1327n, 379n, 340n, 13n, 10528n, 318n, 5340n, 13n, 50256n], + token_timestamps: [0, 0, 0, 3.78, 4.22, 5.26, 6.04, 6.54, 7, 7.94, 8.58, 8.58, 8.88, 9.16, 9.54, 9.94, 10.6, 11.38, 11.88, 12.38, 12.44, 12.62, 13, 13.36, 13.64, 14.24, 14.74, 15.12, 15.4, 15.74, 16.1, 16.54, 16.54, 16.78, 17.08, 17.2, 17.36, 17.56, 18.08, 18.58, 19.38, 19.88, 22.54, 22.9, 23.24, 23.5, 24.14, 24.56, 24.7, 24.94, 24.94, 25.18, 25.54, 25.72, 26.04, 26.34, 26.46, 26.84, 27.04, 27.14, 27.54, 28.06, 29.92], + }, + { + stride: [30, 5, 5], + tokens: [50257n, 50362n, 2773n, 661n, 4320n, 1943n, 981n, 345n, 821n, 8066n, 7765n, 510n, 290n, 670n, 1327n, 379n, 340n, 13n, 10528n, 318n, 5340n, 13n, 921n, 815n, 651n, 284n, 262n, 966n, 810n, 2687n, 2073n, 561n, 11238n, 290n, 345n, 821n, 407n, 8066n, 2245n, 612n, 13n, 1400n, 11n, 644n, 389n, 345n, 4953n, 329n, 30n, 2141n, 340n, 0n, 2329n, 466n, 340n, 0n, 3363n, 11n, 345n, 460n, 0n, 2329n, 466n, 340n, 0n, 50256n], + token_timestamps: [0, 0, 0, 2.92, 3.24, 3.5, 4.14, 4.56, 4.7, 4.74, 4.92, 5.18, 5.54, 5.74, 6.04, 6.34, 6.46, 6.84, 7.04, 7.18, 7.56, 8.12, 9.68, 10.7, 10.88, 11.1, 11.24, 11.48, 11.82, 12.46, 12.82, 13.2, 13.46, 13.72, 14.08, 14.28, 14.34, 14.56, 14.82, 15.16, 15.72, 16.42, 16.82, 16.86, 17, 17.1, 17.2, 17.56, 18.06, 19.28, 19.6, 20.28, 21.96, 22.64, 24.28, 24.76, 25.18, 25.56, 25.56, 25.84, 26.36, 27.12, 27.54, 27.82, 28.16, 29.48], + }, + { + stride: [23.7728125, 5, 0], + tokens: [50257n, 50362n, 2329n, 466n, 340n, 0n, 3363n, 345n, 460n, 0n, 2329n, 466n, 340n, 0n, 1002n, 534n, 15867n, 318n, 3599n, 625n, 11n, 2245n, 3501n, 
510n, 13n, 50256n], + token_timestamps: [0, 0, 0, 2.44, 4.3, 5.04, 5.06, 5.56, 5.8, 6.32, 7.12, 7.56, 7.8, 8.72, 10.04, 12.96, 13.3, 13.44, 13.72, 13.98, 14.86, 15.5, 16, 16.88, 17.76, 20.9], + }, + ]; + + const target = [ + " DO IT! Just DO IT! Don't let your dreams be dreams. Yesterday, you said tomorrow, so just DO IT! MAKE YOUR DRIMS! CONTRO! JUST DO IT! Some people dream success while you're gonna wake up and work hard at it. Nothing is impossible. You should get to the point where anyone else would quit and you're not gonna stop there. No, what are you waiting for? Do it! Just do it! Yes, you can! Just do it! If your tire is starting over, stop giving up.", + { + chunks: [ + { text: " DO", timestamp: [0, 3.78] }, + { text: " IT!", timestamp: [3.78, 4.24] }, + { text: " Just", timestamp: [5.26, 6.04] }, + { text: " DO", timestamp: [6.04, 6.54] }, + { text: " IT!", timestamp: [6.54, 7.02] }, + { text: " Don't", timestamp: [7.94, 8.58] }, + { text: " let", timestamp: [8.58, 8.88] }, + { text: " your", timestamp: [8.88, 9.16] }, + { text: " dreams", timestamp: [9.16, 9.54] }, + { text: " be", timestamp: [9.54, 9.94] }, + { text: " dreams.", timestamp: [9.94, 10.62] }, + { text: " Yesterday,", timestamp: [11.38, 11.9] }, + { text: " you", timestamp: [12.38, 12.44] }, + { text: " said", timestamp: [12.44, 12.62] }, + { text: " tomorrow,", timestamp: [12.62, 13.02] }, + { text: " so", timestamp: [13.36, 13.64] }, + { text: " just", timestamp: [13.64, 14.24] }, + { text: " DO", timestamp: [14.24, 14.74] }, + { text: " IT!", timestamp: [14.74, 15.14] }, + { text: " MAKE", timestamp: [15.4, 15.74] }, + { text: " YOUR", timestamp: [15.74, 16.1] }, + { text: " DRIMS!", timestamp: [16.1, 16.8] }, + { text: " CONTRO!", timestamp: [17.08, 17.58] }, + { text: " JUST", timestamp: [18.08, 18.58] }, + { text: " DO", timestamp: [18.58, 19.38] }, + { text: " IT!", timestamp: [19.38, 19.9] }, + { text: " Some", timestamp: [22.54, 22.9] }, + { text: " people", timestamp: [22.9, 23.24] }, + { text: " dream", timestamp: [23.24, 23.5] }, + { text: " success", timestamp: [23.5, 24.14] }, + { text: " while", timestamp: [24.14, 24.56] }, + { text: " you're", timestamp: [24.56, 24.94] }, + { text: " gonna", timestamp: [24.94, 24.94] }, + { text: " wake", timestamp: [24.94, 25.18] }, + { text: " up", timestamp: [25.18, 25.54] }, + { text: " and", timestamp: [25.54, 25.74] }, + { text: " work", timestamp: [25.74, 26.04] }, + { text: " hard", timestamp: [26.04, 26.34] }, + { text: " at", timestamp: [26.34, 26.46] }, + { text: " it.", timestamp: [26.46, 26.86] }, + { text: " Nothing", timestamp: [27.04, 27.18] }, + { text: " is", timestamp: [27.18, 27.56] }, + { text: " impossible.", timestamp: [27.56, 28.14] }, + { text: " You", timestamp: [29.68, 30.7] }, + { text: " should", timestamp: [30.7, 30.88] }, + { text: " get", timestamp: [30.88, 31.1] }, + { text: " to", timestamp: [31.1, 31.24] }, + { text: " the", timestamp: [31.24, 31.48] }, + { text: " point", timestamp: [31.48, 31.82] }, + { text: " where", timestamp: [31.82, 32.46] }, + { text: " anyone", timestamp: [32.46, 32.82] }, + { text: " else", timestamp: [32.82, 33.2] }, + { text: " would", timestamp: [33.2, 33.46] }, + { text: " quit", timestamp: [33.46, 33.72] }, + { text: " and", timestamp: [33.72, 34.08] }, + { text: " you're", timestamp: [34.08, 34.34] }, + { text: " not", timestamp: [34.34, 34.56] }, + { text: " gonna", timestamp: [34.56, 34.82] }, + { text: " stop", timestamp: [34.82, 35.16] }, + { text: " there.", timestamp: [35.16, 35.74] }, + { 
text: " No,", timestamp: [36.42, 36.84] }, + { text: " what", timestamp: [36.86, 37] }, + { text: " are", timestamp: [37, 37.1] }, + { text: " you", timestamp: [37.1, 37.2] }, + { text: " waiting", timestamp: [37.2, 37.56] }, + { text: " for?", timestamp: [37.56, 38.08] }, + { text: " Do", timestamp: [39.28, 39.6] }, + { text: " it!", timestamp: [39.6, 40.3] }, + { text: " Just", timestamp: [41.96, 42.64] }, + { text: " do", timestamp: [42.64, 44.28] }, + { text: " it!", timestamp: [44.28, 44.78] }, + { text: " Yes,", timestamp: [45.18, 45.56] }, + { text: " you", timestamp: [45.56, 45.84] }, + { text: " can!", timestamp: [45.8, 46.34] }, + { text: " Just", timestamp: [47.12, 47.56] }, + { text: " do", timestamp: [47.56, 47.8] }, + { text: " it!", timestamp: [47.8, 48.74] }, + { text: " If", timestamp: [50.04, 52.96] }, + { text: " your", timestamp: [52.96, 53.3] }, + { text: " tire", timestamp: [53.3, 53.44] }, + { text: " is", timestamp: [53.44, 53.72] }, + { text: " starting", timestamp: [53.72, 53.98] }, + { text: " over,", timestamp: [53.98, 54.88] }, + { text: " stop", timestamp: [55.5, 56] }, + { text: " giving", timestamp: [56, 56.88] }, + { text: " up.", timestamp: [56.88, 57.78] }, + ], + }, + ]; + + compare( + tokenizer._decode_asr(model_outputs, { + return_timestamps: "word", + time_precision: 0.02, + force_full_sequences: false, + }), + target, + 1e-2, + ); + }, + MAX_EXECUTION_TIME, + ); }); - diff --git a/tests/utils.test.js b/tests/utils.test.js index b98c8fa14..19df6e94b 100644 --- a/tests/utils.test.js +++ b/tests/utils.test.js @@ -1,72 +1,62 @@ - -import { AutoProcessor, hamming, hanning, mel_filter_bank } from '../src/transformers.js'; -import { getFile } from '../src/utils/hub.js'; - -import { MAX_TEST_EXECUTION_TIME } from './init.js'; -import { compare } from './test_utils.js'; - -describe('Utilities', () => { - - describe('Audio utilities', () => { - - it('should calculate MEL filters', async () => { - - // NOTE: Uses official HF implementation as reference: - const processor = await AutoProcessor.from_pretrained('openai/whisper-tiny.en'); - const config = processor.feature_extractor.config; - - // True MEL filters - const original_mel_filters = config.mel_filters; - - // Calculated MEL filters - const calculated_mel_filters = mel_filter_bank( - Math.floor(1 + config.n_fft / 2), // num_frequency_bins - config.feature_size, // num_mel_filters - 0.0, // min_frequency - 8000.0, // max_frequency - config.sampling_rate, // sampling_rate - "slaney", // norm - "slaney", // mel_scale - ); - - const original = original_mel_filters.flat(); - const calculated = calculated_mel_filters.flat(); - - // Compute max difference - const maxdiff = original.reduce((maxdiff, _, i) => { - const diff = Math.abs(original[i] - calculated[i]); - return Math.max(maxdiff, diff); - }, -Infinity); - expect(maxdiff).toBeGreaterThanOrEqual(0); - expect(maxdiff).toBeLessThan(1e-6); - - }, MAX_TEST_EXECUTION_TIME); - - it('should calculate window', async () => { - compare( - hanning(10), - new Float64Array( - [0.0, 0.11697777844051105, 0.41317591116653485, 0.75, 0.9698463103929542, 0.9698463103929542, 0.75, 0.41317591116653485, 0.11697777844051105, 0.0] - ) - ); - compare( - hamming(10), - new Float64Array( - [0.08000000000000002, 0.1876195561652702, 0.46012183827321207, 0.7700000000000001, 0.9722586055615179, 0.9722586055615179, 0.7700000000000001, 0.46012183827321207, 0.1876195561652702, 0.08000000000000002], - ) - ); - - }, MAX_TEST_EXECUTION_TIME); - }); - - describe('Hub utilities', () => { - - 
-        it('Read data from blob', async () => {
-            const blob = new Blob(['Hello, world!'], { type: 'text/plain' });
-            const blobUrl = URL.createObjectURL(blob);
-            const data = await getFile(blobUrl);
-            expect(await data.text()).toBe('Hello, world!');
-        });
-
+import { AutoProcessor, hamming, hanning, mel_filter_bank } from "../src/transformers.js";
+import { getFile } from "../src/utils/hub.js";
+
+import { MAX_TEST_EXECUTION_TIME } from "./init.js";
+import { compare } from "./test_utils.js";
+
+describe("Utilities", () => {
+  describe("Audio utilities", () => {
+    it(
+      "should calculate MEL filters",
+      async () => {
+        // NOTE: Uses official HF implementation as reference:
+        const processor = await AutoProcessor.from_pretrained("openai/whisper-tiny.en");
+        const config = processor.feature_extractor.config;
+
+        // True MEL filters
+        const original_mel_filters = config.mel_filters;
+
+        // Calculated MEL filters
+        const calculated_mel_filters = mel_filter_bank(
+          Math.floor(1 + config.n_fft / 2), // num_frequency_bins
+          config.feature_size, // num_mel_filters
+          0.0, // min_frequency
+          8000.0, // max_frequency
+          config.sampling_rate, // sampling_rate
+          "slaney", // norm
+          "slaney", // mel_scale
+        );
+
+        const original = original_mel_filters.flat();
+        const calculated = calculated_mel_filters.flat();
+
+        // Compute max difference
+        const maxdiff = original.reduce((maxdiff, _, i) => {
+          const diff = Math.abs(original[i] - calculated[i]);
+          return Math.max(maxdiff, diff);
+        }, -Infinity);
+        expect(maxdiff).toBeGreaterThanOrEqual(0);
+        expect(maxdiff).toBeLessThan(1e-6);
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+
+    it(
+      "should calculate window",
+      async () => {
+        compare(hanning(10), new Float64Array([0.0, 0.11697777844051105, 0.41317591116653485, 0.75, 0.9698463103929542, 0.9698463103929542, 0.75, 0.41317591116653485, 0.11697777844051105, 0.0]));
+        compare(hamming(10), new Float64Array([0.08000000000000002, 0.1876195561652702, 0.46012183827321207, 0.7700000000000001, 0.9722586055615179, 0.9722586055615179, 0.7700000000000001, 0.46012183827321207, 0.1876195561652702, 0.08000000000000002]));
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+  });
+
+  describe("Hub utilities", () => {
+    it("Read data from blob", async () => {
+      const blob = new Blob(["Hello, world!"], { type: "text/plain" });
+      const blobUrl = URL.createObjectURL(blob);
+      const data = await getFile(blobUrl);
+      expect(await data.text()).toBe("Hello, world!");
     });
+  });
 });
diff --git a/tests/utils/generation.test.js b/tests/utils/generation.test.js
index aa9edc3f9..27b6a28d5 100644
--- a/tests/utils/generation.test.js
+++ b/tests/utils/generation.test.js
@@ -1,144 +1,166 @@
-
-import { AutoTokenizer } from '../../src/tokenizers.js';
-import { AutoModelForSeq2SeqLM, AutoModelForCausalLM } from '../../src/models.js';
-import { init, MAX_TEST_EXECUTION_TIME, MAX_MODEL_LOAD_TIME, MAX_MODEL_DISPOSE_TIME } from '../init.js';
+import { AutoTokenizer } from "../../src/tokenizers.js";
+import { AutoModelForSeq2SeqLM, AutoModelForCausalLM } from "../../src/models.js";
+import { init, MAX_TEST_EXECUTION_TIME, MAX_MODEL_LOAD_TIME, MAX_MODEL_DISPOSE_TIME } from "../init.js";
 
 // Initialise the testing environment
 init();
 
 // Helper function to generate text
 const generate = async (model, tokenizer, text, options) => {
-    const inputs = tokenizer(text);
-    return await model.generate({
-        ...inputs,
-        ...options,
-    });
-}
-
-describe('Generation parameters', () => {
-
-    // List all models which will be tested
-    const models = [
-        'hf-internal-testing/tiny-random-T5ForConditionalGeneration', //
-        'hf-internal-testing/tiny-random-LlamaForCausalLM', // decoder-only
-    ];
-    const DUMMY_TEXT = 'hello';
-
-    describe(`encoder-decoder (${models[0]})`, () => {
-        const model_id = models[0];
-
-        let model;
-        let tokenizer;
-        beforeAll(async () => {
-            model = await AutoModelForSeq2SeqLM.from_pretrained(model_id);
-            tokenizer = await AutoTokenizer.from_pretrained(model_id);
-        }, MAX_MODEL_LOAD_TIME);
-
-        // NOTE: Since `max_length` defaults to 20, this case also tests that.
-        it('default', async () => {
-            const outputs = await generate(model, tokenizer, DUMMY_TEXT, {});
-            expect(outputs.dims.at(-1)).toEqual(20);
-        }, MAX_TEST_EXECUTION_TIME);
-
-        it('max_new_tokens', async () => {
-            const MAX_NEW_TOKENS = 5;
-            const outputs = await generate(model, tokenizer, DUMMY_TEXT, {
-                max_new_tokens: MAX_NEW_TOKENS,
-            });
-            expect(outputs.dims.at(-1)).toEqual(MAX_NEW_TOKENS + 1); // + 1 due to forced BOS token
-        }, MAX_TEST_EXECUTION_TIME);
-
-        it('min_length', async () => {
-            const MIN_LENGTH = 3;
-            const MAX_LENGTH = 5;
-            const outputs = await generate(model, tokenizer, DUMMY_TEXT, {
-                eos_token_id: 0,
-                min_length: MIN_LENGTH,
-                max_length: MAX_LENGTH,
-            });
-            expect(outputs.tolist()).toEqual([
-                [0n, 11924n, 11924n, 11924n, 11924n],
-            ]);
-            expect(outputs.dims.at(-1)).toBeGreaterThanOrEqual(MIN_LENGTH);
-        }, MAX_TEST_EXECUTION_TIME);
-
-        it('min_new_tokens', async () => {
-            const MIN_NEW_TOKENS = 2;
-            const MAX_LENGTH = 5;
-            const outputs = await generate(model, tokenizer, DUMMY_TEXT, {
-                eos_token_id: 0,
-                min_new_tokens: MIN_NEW_TOKENS,
-                max_length: MAX_LENGTH,
-            });
-            expect(outputs.tolist()).toEqual([
-                [0n, 11924n, 11924n, 11924n, 11924n],
-            ]);
-            expect(outputs.dims.at(-1)).toBeGreaterThanOrEqual(MIN_NEW_TOKENS);
-        }, MAX_TEST_EXECUTION_TIME);
-
-        afterAll(async () => {
-            await model?.dispose();
-        }, MAX_MODEL_DISPOSE_TIME);
-    });
-
-    describe(`decoder-only (${models[1]})`, () => {
-        const model_id = models[1];
-
-        let model;
-        let tokenizer;
-        beforeAll(async () => {
-            model = await AutoModelForCausalLM.from_pretrained(model_id);
-            tokenizer = await AutoTokenizer.from_pretrained(model_id);
-        }, MAX_MODEL_LOAD_TIME);
-
-        // NOTE: Since `max_length` defaults to 20, this case also tests that.
-        it('default', async () => {
-            const outputs = await generate(model, tokenizer, DUMMY_TEXT, {});
-            expect(outputs.dims.at(-1)).toEqual(20);
-        }, MAX_TEST_EXECUTION_TIME);
-
-        it('max_new_tokens', async () => {
-            const MAX_NEW_TOKENS = 5;
-            const PROMPT_LENGTH = 2; // BOS + DUMMY_TEXT
-            const outputs = await generate(model, tokenizer, DUMMY_TEXT, {
-                max_new_tokens: MAX_NEW_TOKENS,
-            });
-            const expected_length = PROMPT_LENGTH + MAX_NEW_TOKENS;
-            expect(outputs.dims.at(-1)).toEqual(expected_length);
-        }, MAX_TEST_EXECUTION_TIME);
-
-        it('min_length', async () => {
-            const MIN_LENGTH = 4;
-            const outputs = await generate(model, tokenizer, DUMMY_TEXT, {
-                eos_token_id: [
-                    18547, // min_length will suppress this token (generated by default)
-                    16012, // stop at this token
-                ],
-                min_length: MIN_LENGTH,
-            });
-            expect(outputs.tolist()).toEqual([
-                [1n, 22172n, 31583n, 18824n, 16621n, 8136n, 16012n],
-            ]);
-            expect(outputs.dims.at(-1)).toBeGreaterThanOrEqual(MIN_LENGTH);
-        }, MAX_TEST_EXECUTION_TIME);
-
-        it('min_new_tokens', async () => {
-            const MIN_NEW_TOKENS = 2;
-            const outputs = await generate(model, tokenizer, DUMMY_TEXT, {
-                eos_token_id: [
-                    18547, // min_new_tokens will suppress this token (generated by default)
-                    16012, // stop at this token
-                ],
-                min_new_tokens: MIN_NEW_TOKENS,
-            });
-            expect(outputs.tolist()).toEqual([
-                [1n, 22172n, 31583n, 18824n, 16621n, 8136n, 16012n],
-            ]);
-            expect(outputs.dims.at(-1)).toBeGreaterThanOrEqual(MIN_NEW_TOKENS);
-        }, MAX_TEST_EXECUTION_TIME);
-
-        afterAll(async () => {
-            await model?.dispose();
-        }, MAX_MODEL_DISPOSE_TIME);
-    });
+  const inputs = tokenizer(text);
+  return await model.generate({
+    ...inputs,
+    ...options,
+  });
+};
+
+describe("Generation parameters", () => {
+  // List all models which will be tested
+  const models = [
+    "hf-internal-testing/tiny-random-T5ForConditionalGeneration", //
+    "hf-internal-testing/tiny-random-LlamaForCausalLM", // decoder-only
+  ];
+  const DUMMY_TEXT = "hello";
+
+  describe(`encoder-decoder (${models[0]})`, () => {
+    const model_id = models[0];
+
+    let model;
+    let tokenizer;
+    beforeAll(async () => {
+      model = await AutoModelForSeq2SeqLM.from_pretrained(model_id);
+      tokenizer = await AutoTokenizer.from_pretrained(model_id);
+    }, MAX_MODEL_LOAD_TIME);
+
+    // NOTE: Since `max_length` defaults to 20, this case also tests that.
+    it(
+      "default",
+      async () => {
+        const outputs = await generate(model, tokenizer, DUMMY_TEXT, {});
+        expect(outputs.dims.at(-1)).toEqual(20);
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+
+    it(
+      "max_new_tokens",
+      async () => {
+        const MAX_NEW_TOKENS = 5;
+        const outputs = await generate(model, tokenizer, DUMMY_TEXT, {
+          max_new_tokens: MAX_NEW_TOKENS,
+        });
+        expect(outputs.dims.at(-1)).toEqual(MAX_NEW_TOKENS + 1); // + 1 due to forced BOS token
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+
+    it(
+      "min_length",
+      async () => {
+        const MIN_LENGTH = 3;
+        const MAX_LENGTH = 5;
+        const outputs = await generate(model, tokenizer, DUMMY_TEXT, {
+          eos_token_id: 0,
+          min_length: MIN_LENGTH,
+          max_length: MAX_LENGTH,
+        });
+        expect(outputs.tolist()).toEqual([[0n, 11924n, 11924n, 11924n, 11924n]]);
+        expect(outputs.dims.at(-1)).toBeGreaterThanOrEqual(MIN_LENGTH);
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+
+    it(
+      "min_new_tokens",
+      async () => {
+        const MIN_NEW_TOKENS = 2;
+        const MAX_LENGTH = 5;
+        const outputs = await generate(model, tokenizer, DUMMY_TEXT, {
+          eos_token_id: 0,
+          min_new_tokens: MIN_NEW_TOKENS,
+          max_length: MAX_LENGTH,
+        });
+        expect(outputs.tolist()).toEqual([[0n, 11924n, 11924n, 11924n, 11924n]]);
+        expect(outputs.dims.at(-1)).toBeGreaterThanOrEqual(MIN_NEW_TOKENS);
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+
+    afterAll(async () => {
+      await model?.dispose();
+    }, MAX_MODEL_DISPOSE_TIME);
+  });
+
+  describe(`decoder-only (${models[1]})`, () => {
+    const model_id = models[1];
+
+    let model;
+    let tokenizer;
+    beforeAll(async () => {
+      model = await AutoModelForCausalLM.from_pretrained(model_id);
+      tokenizer = await AutoTokenizer.from_pretrained(model_id);
+    }, MAX_MODEL_LOAD_TIME);
+
+    // NOTE: Since `max_length` defaults to 20, this case also tests that.
+    it(
+      "default",
+      async () => {
+        const outputs = await generate(model, tokenizer, DUMMY_TEXT, {});
+        expect(outputs.dims.at(-1)).toEqual(20);
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+
+    it(
+      "max_new_tokens",
+      async () => {
+        const MAX_NEW_TOKENS = 5;
+        const PROMPT_LENGTH = 2; // BOS + DUMMY_TEXT
+        const outputs = await generate(model, tokenizer, DUMMY_TEXT, {
+          max_new_tokens: MAX_NEW_TOKENS,
+        });
+        const expected_length = PROMPT_LENGTH + MAX_NEW_TOKENS;
+        expect(outputs.dims.at(-1)).toEqual(expected_length);
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+
+    it(
+      "min_length",
+      async () => {
+        const MIN_LENGTH = 4;
+        const outputs = await generate(model, tokenizer, DUMMY_TEXT, {
+          eos_token_id: [
+            18547, // min_length will suppress this token (generated by default)
+            16012, // stop at this token
+          ],
+          min_length: MIN_LENGTH,
+        });
+        expect(outputs.tolist()).toEqual([[1n, 22172n, 31583n, 18824n, 16621n, 8136n, 16012n]]);
+        expect(outputs.dims.at(-1)).toBeGreaterThanOrEqual(MIN_LENGTH);
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+
+    it(
+      "min_new_tokens",
+      async () => {
+        const MIN_NEW_TOKENS = 2;
+        const outputs = await generate(model, tokenizer, DUMMY_TEXT, {
+          eos_token_id: [
+            18547, // min_new_tokens will suppress this token (generated by default)
+            16012, // stop at this token
+          ],
+          min_new_tokens: MIN_NEW_TOKENS,
+        });
+        expect(outputs.tolist()).toEqual([[1n, 22172n, 31583n, 18824n, 16621n, 8136n, 16012n]]);
+        expect(outputs.dims.at(-1)).toBeGreaterThanOrEqual(MIN_NEW_TOKENS);
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+
+    afterAll(async () => {
+      await model?.dispose();
+    }, MAX_MODEL_DISPOSE_TIME);
+  });
 });
diff --git a/tests/utils/hub.test.js b/tests/utils/hub.test.js
index 8b57b5965..19077f009 100644
--- a/tests/utils/hub.test.js
+++ b/tests/utils/hub.test.js
@@ -1,36 +1,40 @@
+import { AutoModel, PreTrainedModel } from "../../src/models.js";
-
-import { AutoModel, PreTrainedModel } from '../../src/models.js';
-
-import { MAX_TEST_EXECUTION_TIME } from '../init.js';
+import { MAX_TEST_EXECUTION_TIME } from "../init.js";
 
 // TODO: Set cache folder to a temp directory
-describe('Hub', () => {
-
-    describe('Loading models', () => {
-
-        it('should load a model from the local cache', async () => {
-            // 1. Local model exists (doesn't matter about status of remote file since local is tried first)
-            const model = await AutoModel.from_pretrained('hf-internal-testing/tiny-random-T5ForConditionalGeneration');
-            expect(model).toBeInstanceOf(PreTrainedModel);
-        }, MAX_TEST_EXECUTION_TIME);
-
-        it('should load a model from the remote cache', async () => {
-            // 2. Local model doesn't exist, remote file exists
-            // This tests that fallback functionality is working
-            const model = await AutoModel.from_pretrained('hf-internal-testing/tiny-random-T5ForConditionalGeneration');
-            expect(model).toBeInstanceOf(PreTrainedModel);
-        }, MAX_TEST_EXECUTION_TIME);
-
-        it('should fail to load a model', async () => {
-            // 3. Local model doesn't exist, remote file doesn't exist
-            // This tests that error handling is working.
-            await expect(
-                AutoModel.from_pretrained('hf-internal-testing/this-model-does-not-exist')
-            ).rejects
-                .toBeInstanceOf(Error);
-        }, MAX_TEST_EXECUTION_TIME);
-    });
-
+describe("Hub", () => {
+  describe("Loading models", () => {
+    it(
+      "should load a model from the local cache",
+      async () => {
+        // 1. Local model exists (doesn't matter about status of remote file since local is tried first)
+        const model = await AutoModel.from_pretrained("hf-internal-testing/tiny-random-T5ForConditionalGeneration");
+        expect(model).toBeInstanceOf(PreTrainedModel);
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+
+    it(
+      "should load a model from the remote cache",
+      async () => {
+        // 2. Local model doesn't exist, remote file exists
+        // This tests that fallback functionality is working
+        const model = await AutoModel.from_pretrained("hf-internal-testing/tiny-random-T5ForConditionalGeneration");
+        expect(model).toBeInstanceOf(PreTrainedModel);
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+
+    it(
+      "should fail to load a model",
+      async () => {
+        // 3. Local model doesn't exist, remote file doesn't exist
+        // This tests that error handling is working.
+        await expect(AutoModel.from_pretrained("hf-internal-testing/this-model-does-not-exist")).rejects.toBeInstanceOf(Error);
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+  });
 });
diff --git a/tests/utils/logits_process.test.js b/tests/utils/logits_process.test.js
index d5c9ce97f..5da188ed4 100644
--- a/tests/utils/logits_process.test.js
+++ b/tests/utils/logits_process.test.js
@@ -1,12 +1,11 @@
-
 import {
-    // Pipelines
-    pipeline,
-    TextGenerationPipeline,
-} from '../../src/transformers.js';
+  // Pipelines
+  pipeline,
+  TextGenerationPipeline,
+} from "../../src/transformers.js";
 
-import { init } from '../init.js';
-import { compare } from '../test_utils.js';
+import { init } from "../init.js";
+import { compare } from "../test_utils.js";
 
 init();
 
 const MAX_MODEL_LOAD_TIME = 10_000; // 10 seconds
@@ -14,68 +13,76 @@ const MAX_TEST_EXECUTION_TIME = 10_000; // 10 seconds
 const MAX_MODEL_DISPOSE_TIME = 1_000; // 1 second
 
 const DEFAULT_MODEL_OPTIONS = {
-    dtype: 'fp32',
-}
-
-describe('Logits Processors', () => {
-
-    describe('text-generation', () => {
-        const model_id = 'hf-internal-testing/tiny-random-LlamaForCausalLM';
-
-        /** @type {TextGenerationPipeline} */
-        let pipe;
-        beforeAll(async () => {
-            pipe = await pipeline('text-generation', model_id, {
-                // TODO move to config
-                ...DEFAULT_MODEL_OPTIONS,
-            });
-        }, MAX_MODEL_LOAD_TIME);
-
-        describe('bad_word_ids', () => {
-            it('basic', async () => {
-                const text_input = 'hello';
-
-                const generated_text_target = ' Bert explicit wed digasset';
-                const text_target = [{ generated_text: text_input + generated_text_target }]
-
-                const output = await pipe(text_input, {
-                    max_new_tokens: 5, bad_words_ids: [
-                        // default: [22172n, 18547n, 8136n, 16012n, 28064n, 11361n]
-                        [18547],
-
-                        // block #1: [22172n, 16662n, 6261n, 18916n, 29109n, 799n]
-                        [6261, 18916],
-                    ]
-                });
-                compare(output, text_target);
-            }, MAX_TEST_EXECUTION_TIME);
-
-            it('many bad words', async () => {
-                const text_input = 'hello';
-
-                const generated_text_target = 'erdingsdeletearus)?nor';
-                const text_target = [{ generated_text: text_input + generated_text_target }]
-
-                // Construct long list of bad words
-                const bad_words_ids = [];
-                // default: [22172n, 18547n, 8136n, 16012n, 28064n, 11361n]
-                for (let i = 0; i < 100000; ++i) {
-                    bad_words_ids.push([i * 2]); // block all even numbers
-                }
-                // block #1: [22172n, 18547n, 8143n, 30327n, 20061n, 18193n]
-                bad_words_ids.push([8143, 30327]);
-
-                // block #2: [22172n, 18547n, 8143n, 29485n, 3799n, 29331n]
-                bad_words_ids.push([18547, 8143, 29485]);
-
-                // block #3: [22172n, 18547n, 8143n, 26465n, 6877n, 15459n]
-                const output = await pipe(text_input, { max_new_tokens: 5, bad_words_ids });
-                compare(output, text_target);
-            }, MAX_TEST_EXECUTION_TIME);
-        });
-
-        afterAll(async () => {
-            await pipe?.dispose();
-        }, MAX_MODEL_DISPOSE_TIME);
+  dtype: "fp32",
+};
+
+describe("Logits Processors", () => {
+  describe("text-generation", () => {
+    const model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM";
+
+    /** @type {TextGenerationPipeline} */
+    let pipe;
+    beforeAll(async () => {
+      pipe = await pipeline("text-generation", model_id, {
+        // TODO move to config
+        ...DEFAULT_MODEL_OPTIONS,
+      });
+    }, MAX_MODEL_LOAD_TIME);
+
+    describe("bad_word_ids", () => {
+      it(
+        "basic",
+        async () => {
+          const text_input = "hello";
+
+          const generated_text_target = " Bert explicit wed digasset";
+          const text_target = [{ generated_text: text_input + generated_text_target }];
+
+          const output = await pipe(text_input, {
+            max_new_tokens: 5,
+            bad_words_ids: [
+              // default: [22172n, 18547n, 8136n, 16012n, 28064n, 11361n]
+              [18547],
+
+              // block #1: [22172n, 16662n, 6261n, 18916n, 29109n, 799n]
+              [6261, 18916],
+            ],
+          });
+          compare(output, text_target);
+        },
+        MAX_TEST_EXECUTION_TIME,
+      );
+
+      it(
+        "many bad words",
+        async () => {
+          const text_input = "hello";
+
+          const generated_text_target = "erdingsdeletearus)?nor";
+          const text_target = [{ generated_text: text_input + generated_text_target }];
+
+          // Construct long list of bad words
+          const bad_words_ids = [];
+          // default: [22172n, 18547n, 8136n, 16012n, 28064n, 11361n]
+          for (let i = 0; i < 100000; ++i) {
+            bad_words_ids.push([i * 2]); // block all even numbers
+          }
+          // block #1: [22172n, 18547n, 8143n, 30327n, 20061n, 18193n]
+          bad_words_ids.push([8143, 30327]);
+
+          // block #2: [22172n, 18547n, 8143n, 29485n, 3799n, 29331n]
+          bad_words_ids.push([18547, 8143, 29485]);
+
+          // block #3: [22172n, 18547n, 8143n, 26465n, 6877n, 15459n]
+          const output = await pipe(text_input, { max_new_tokens: 5, bad_words_ids });
+          compare(output, text_target);
+        },
+        MAX_TEST_EXECUTION_TIME,
+      );
     });
+
+    afterAll(async () => {
+      await pipe?.dispose();
+    }, MAX_MODEL_DISPOSE_TIME);
+  });
 });