diff --git a/.changeset/wild-kings-nail.md b/.changeset/wild-kings-nail.md
new file mode 100644
index 0000000..a2b9846
--- /dev/null
+++ b/.changeset/wild-kings-nail.md
@@ -0,0 +1,5 @@
+---
+"tfjs-image-node": minor
+---
+
+add option to use specified metadata
diff --git a/README.md b/README.md
index 3599cec..98a892f 100644
--- a/README.md
+++ b/README.md
@@ -86,11 +86,22 @@ const image = "https://www.stgeorges.nhs.uk/wp-content/uploads/2014/03/hand-2.jp
PLATFORM
- "node" or "classic"
+ "node" or "classic" (optional)
|
Choose the platform to use for the computation of the prediction. If you want to use the tfjs-node platform, use "node" as the parameter, otherwise use "classic".
|
+
+
+ METADATA
+ |
+
+ metadata.json (optional)
+ |
+
+ If you want to specify a set of metadata for the model.
+ |
+
diff --git a/src/index.ts b/src/index.ts
index 8c3e082..5403e1f 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -3,8 +3,16 @@ const tfNode = require("@tensorflow/tfjs-node");
const tfJs = require("@tensorflow/tfjs");
let tf: any;
-interface IMetadata extends JSON {
+interface IMetadata {
+ tfjsVersion: string;
+ tmVersion: string;
+ packageVersion: string;
+ packageName: string;
+ timeStamp: string;
+ userMetadata: {};
+ modelName: string;
labels: string[];
+ imageSize: number;
}
type ResultType = {
@@ -15,7 +23,8 @@ type ResultType = {
type ClassifyImageType = (
MODEL_DIR_PATH: string,
IMAGE_FILE_PATH: string,
- PLATFORM?: "node" | "classic"
+ PLATFORM?: "node" | "classic",
+ METADATA?: IMetadata
) => Promise;
const filterInputPath = (inputPath: string) => {
@@ -28,7 +37,8 @@ const filterInputPath = (inputPath: string) => {
const classifyImage: ClassifyImageType = async (
MODEL_DIR_PATH,
IMAGE_FILE_PATH,
- PLATFORM = "node"
+ PLATFORM = "node",
+ METADATA
) => {
PLATFORM === "node" ? (tf = tfNode) : (tf = tfJs);
@@ -36,21 +46,20 @@ const classifyImage: ClassifyImageType = async (
return new Error("MISSING_PARAMETER");
}
- MODEL_DIR_PATH = filterInputPath(MODEL_DIR_PATH);
-
- const res = await fetch(`${MODEL_DIR_PATH}/metadata.json`);
- if (res.status !== 200) {
- return new Error("METADATA_NOT_FOUND");
+ let labels: string[];
+
+ if (!METADATA) {
+ const res = await fetch(`${MODEL_DIR_PATH}/metadata.json`);
+ if (res.status !== 200) {
+ return new Error("METADATA_NOT_FOUND" + res);
+ } else {
+ const json = await res.json();
+ labels = json["labels"];
+ }
+ } else {
+ labels = METADATA["labels"];
}
- const METADATA: IMetadata = await res.json();
-
- if (METADATA["labels"].length === 0 || METADATA["labels"]! instanceof Array) {
- return new Error("NO_METADATA_LABELS");
- }
-
- let labels: string[] = METADATA["labels"];
-
const model = await tf.loadLayersModel(`${MODEL_DIR_PATH}/model.json`);
const image = await Jimp.read(IMAGE_FILE_PATH);
diff --git a/test/classifyImage.test.ts b/test/classifyImage.test.ts
index a7fa4d9..ac00aca 100644
--- a/test/classifyImage.test.ts
+++ b/test/classifyImage.test.ts
@@ -1,5 +1,6 @@
import { describe, it, expect } from "vitest";
import classifyImage from "../src";
+import * as metadata from "./testFiles/metadata.json";
const model = "https://teachablemachine.withgoogle.com/models/jAIOHvmge";
const imageHand = "https://www.stgeorges.nhs.uk/wp-content/uploads/2014/03/hand-2.jpeg";
@@ -46,7 +47,17 @@ describe("classifyImage function - Node", async () => {
expect(result[0].label).toBe("Hand");
}
});
+ it("works with specified metadata", async () => {
+ const result = await classifyImage(model, imageNoHand, undefined, metadata);
+ if (result instanceof Error) {
+ return new Error();
+ } else {
+ expect(result[0].probability).not.toBe(null);
+ }
+ });
});
+
+ /* ERROR BOUNDRIES */
describe("Error boundries", async () => {
it("returns an error when missing a parameter", async () => {
//@ts-expect-error
@@ -56,6 +67,7 @@ describe("classifyImage function - Node", async () => {
});
});
+ /* IMAGE TYPES */
describe("Image types", async () => {
it("returns a result on url image-input", async () => {
const result = await classifyImage(model, imageHand);
diff --git a/test/classifyImageClassic.test.ts b/test/classifyImageClassic.test.ts
index 8a37ddb..7c96963 100644
--- a/test/classifyImageClassic.test.ts
+++ b/test/classifyImageClassic.test.ts
@@ -1,5 +1,6 @@
import { describe, it, expect } from "vitest";
import classifyImage from "../src";
+import * as metadata from "./testFiles/metadata.json";
const model = "https://teachablemachine.withgoogle.com/models/jAIOHvmge";
const imageHand = "https://www.stgeorges.nhs.uk/wp-content/uploads/2014/03/hand-2.jpeg";
@@ -38,6 +39,15 @@ describe("classifyImage function - Classic", async () => {
}
});
+ it("works with specified metadata", async () => {
+ const result = await classifyImage(model, imageNoHand, undefined, metadata);
+ if (result instanceof Error) {
+ return new Error();
+ } else {
+ expect(result[0].probability).not.toBe(null);
+ }
+ });
+
it("returns when MODEL_DIR_PATH ends with slash", async () => {
const result = await classifyImage(model + "/", imageHand);
if (result instanceof Error) {
diff --git a/test/testFiles/metadata.json b/test/testFiles/metadata.json
new file mode 100644
index 0000000..e9a6292
--- /dev/null
+++ b/test/testFiles/metadata.json
@@ -0,0 +1 @@
+{"tfjsVersion":"1.3.1","tmVersion":"2.4.7","packageVersion":"0.8.4-alpha2","packageName":"@teachablemachine/image","timeStamp":"2023-10-21T07:45:22.875Z","userMetadata":{},"modelName":"tm-my-image-model","labels":["Hand","No hand"],"imageSize":224}
\ No newline at end of file
diff --git a/tsconfig.json b/tsconfig.json
index 2df40f8..a6d2bbc 100644
--- a/tsconfig.json
+++ b/tsconfig.json
@@ -7,6 +7,7 @@
"esModuleInterop": true,
"forceConsistentCasingInFileNames": true,
"strict": true,
- "skipLibCheck": true
+ "skipLibCheck": true,
+ "resolveJsonModule": true
}
}