diff --git a/.changeset/wild-kings-nail.md b/.changeset/wild-kings-nail.md new file mode 100644 index 0000000..a2b9846 --- /dev/null +++ b/.changeset/wild-kings-nail.md @@ -0,0 +1,5 @@ +--- +"tfjs-image-node": minor +--- + +add option to use specified metadata diff --git a/README.md b/README.md index 3599cec..98a892f 100644 --- a/README.md +++ b/README.md @@ -86,11 +86,22 @@ const image = "https://www.stgeorges.nhs.uk/wp-content/uploads/2014/03/hand-2.jp PLATFORM - "node" or "classic" + "node" or "classic" (optional) Choose the platform to use for the computation of the prediction. If you want to use the tfjs-node platform, use "node" as the parameter, otherwise use "classic". + + + METADATA + + + metadata.json (optional) + + + If you want to specify a set of metadata for the model. + + diff --git a/src/index.ts b/src/index.ts index 8c3e082..5403e1f 100644 --- a/src/index.ts +++ b/src/index.ts @@ -3,8 +3,16 @@ const tfNode = require("@tensorflow/tfjs-node"); const tfJs = require("@tensorflow/tfjs"); let tf: any; -interface IMetadata extends JSON { +interface IMetadata { + tfjsVersion: string; + tmVersion: string; + packageVersion: string; + packageName: string; + timeStamp: string; + userMetadata: {}; + modelName: string; labels: string[]; + imageSize: number; } type ResultType = { @@ -15,7 +23,8 @@ type ResultType = { type ClassifyImageType = ( MODEL_DIR_PATH: string, IMAGE_FILE_PATH: string, - PLATFORM?: "node" | "classic" + PLATFORM?: "node" | "classic", + METADATA?: IMetadata ) => Promise; const filterInputPath = (inputPath: string) => { @@ -28,7 +37,8 @@ const filterInputPath = (inputPath: string) => { const classifyImage: ClassifyImageType = async ( MODEL_DIR_PATH, IMAGE_FILE_PATH, - PLATFORM = "node" + PLATFORM = "node", + METADATA ) => { PLATFORM === "node" ? (tf = tfNode) : (tf = tfJs); @@ -36,21 +46,20 @@ const classifyImage: ClassifyImageType = async ( return new Error("MISSING_PARAMETER"); } - MODEL_DIR_PATH = filterInputPath(MODEL_DIR_PATH); - - const res = await fetch(`${MODEL_DIR_PATH}/metadata.json`); - if (res.status !== 200) { - return new Error("METADATA_NOT_FOUND"); + let labels: string[]; + + if (!METADATA) { + const res = await fetch(`${MODEL_DIR_PATH}/metadata.json`); + if (res.status !== 200) { + return new Error("METADATA_NOT_FOUND" + res); + } else { + const json = await res.json(); + labels = json["labels"]; + } + } else { + labels = METADATA["labels"]; } - const METADATA: IMetadata = await res.json(); - - if (METADATA["labels"].length === 0 || METADATA["labels"]! instanceof Array) { - return new Error("NO_METADATA_LABELS"); - } - - let labels: string[] = METADATA["labels"]; - const model = await tf.loadLayersModel(`${MODEL_DIR_PATH}/model.json`); const image = await Jimp.read(IMAGE_FILE_PATH); diff --git a/test/classifyImage.test.ts b/test/classifyImage.test.ts index a7fa4d9..ac00aca 100644 --- a/test/classifyImage.test.ts +++ b/test/classifyImage.test.ts @@ -1,5 +1,6 @@ import { describe, it, expect } from "vitest"; import classifyImage from "../src"; +import * as metadata from "./testFiles/metadata.json"; const model = "https://teachablemachine.withgoogle.com/models/jAIOHvmge"; const imageHand = "https://www.stgeorges.nhs.uk/wp-content/uploads/2014/03/hand-2.jpeg"; @@ -46,7 +47,17 @@ describe("classifyImage function - Node", async () => { expect(result[0].label).toBe("Hand"); } }); + it("works with specified metadata", async () => { + const result = await classifyImage(model, imageNoHand, undefined, metadata); + if (result instanceof Error) { + return new Error(); + } else { + expect(result[0].probability).not.toBe(null); + } + }); }); + + /* ERROR BOUNDRIES */ describe("Error boundries", async () => { it("returns an error when missing a parameter", async () => { //@ts-expect-error @@ -56,6 +67,7 @@ describe("classifyImage function - Node", async () => { }); }); + /* IMAGE TYPES */ describe("Image types", async () => { it("returns a result on url image-input", async () => { const result = await classifyImage(model, imageHand); diff --git a/test/classifyImageClassic.test.ts b/test/classifyImageClassic.test.ts index 8a37ddb..7c96963 100644 --- a/test/classifyImageClassic.test.ts +++ b/test/classifyImageClassic.test.ts @@ -1,5 +1,6 @@ import { describe, it, expect } from "vitest"; import classifyImage from "../src"; +import * as metadata from "./testFiles/metadata.json"; const model = "https://teachablemachine.withgoogle.com/models/jAIOHvmge"; const imageHand = "https://www.stgeorges.nhs.uk/wp-content/uploads/2014/03/hand-2.jpeg"; @@ -38,6 +39,15 @@ describe("classifyImage function - Classic", async () => { } }); + it("works with specified metadata", async () => { + const result = await classifyImage(model, imageNoHand, undefined, metadata); + if (result instanceof Error) { + return new Error(); + } else { + expect(result[0].probability).not.toBe(null); + } + }); + it("returns when MODEL_DIR_PATH ends with slash", async () => { const result = await classifyImage(model + "/", imageHand); if (result instanceof Error) { diff --git a/test/testFiles/metadata.json b/test/testFiles/metadata.json new file mode 100644 index 0000000..e9a6292 --- /dev/null +++ b/test/testFiles/metadata.json @@ -0,0 +1 @@ +{"tfjsVersion":"1.3.1","tmVersion":"2.4.7","packageVersion":"0.8.4-alpha2","packageName":"@teachablemachine/image","timeStamp":"2023-10-21T07:45:22.875Z","userMetadata":{},"modelName":"tm-my-image-model","labels":["Hand","No hand"],"imageSize":224} \ No newline at end of file diff --git a/tsconfig.json b/tsconfig.json index 2df40f8..a6d2bbc 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -7,6 +7,7 @@ "esModuleInterop": true, "forceConsistentCasingInFileNames": true, "strict": true, - "skipLibCheck": true + "skipLibCheck": true, + "resolveJsonModule": true } }