diff --git a/README.md b/README.md index 052af7a0..e8b44de6 100644 --- a/README.md +++ b/README.md @@ -41,9 +41,9 @@ import { FetchStore } from "@zarrita/storage"; const store = new FetchStore("http://localhost:8080/data.zarr"); // open array from root (note that dtype is unknown) -const arr = await zarr.open.v2(store, { kind: "array" }); // zarr.Array +const arr = await zarr.open.v2(store, { kind: "array" }); // zarr.Array -arr; // zarr.Array +arr; // zarr.Array arr.shape; // [5, 10] arr.chunk_shape; // [2, 5] arr.dtype; // "int32" diff --git a/packages/core/__tests__/hierarchy.test.ts b/packages/core/__tests__/hierarchy.test.ts new file mode 100644 index 00000000..f58e41d1 --- /dev/null +++ b/packages/core/__tests__/hierarchy.test.ts @@ -0,0 +1,114 @@ +import { describe, expect, expectTypeOf, test } from "vitest"; +import { Array, Group } from "../src/hierarchy.js"; +import type { + ArrayMetadata, + BigintDataType, + Bool, + DataType, + Int8, + NumberDataType, + Raw, +} from "../src/metadata.js"; + +const array_metadata = { + zarr_format: 3, + node_type: "array", + data_type: "int8", + shape: [10, 10], + chunk_grid: { + name: "regular", + configuration: { + chunk_shape: [5, 5], + }, + }, + chunk_key_encoding: { + name: "default", + configuration: { + separator: "/", + }, + }, + codecs: [], + fill_value: 0, + attributes: { answer: 42 }, +} satisfies ArrayMetadata; + +describe("Array", () => { + test("constructor", async () => { + let arr = new Array(new Map(), "/", array_metadata); + expect({ + shape: arr.shape, + chunk_shape: arr.chunk_shape, + dtype: arr.dtype, + fill_value: arr.fill_value, + attrs: await arr.attrs(), + path: arr.path, + codec_pipeline: arr.codec_pipeline, + store: arr.store, + }).toMatchInlineSnapshot(` + { + "attrs": { + "answer": 42, + }, + "chunk_shape": [ + 5, + 5, + ], + "codec_pipeline": { + "decode": [Function], + "encode": [Function], + }, + "dtype": "int8", + "fill_value": 0, + "path": "/", + "shape": [ + 10, + 10, + ], + "store": Map {}, + } + `); + }); + + test("Array.is", () => { + let arr = new Array(new Map(), "/", array_metadata as any); + expectTypeOf(arr.dtype).toMatchTypeOf(); + if (arr.is("bigint")) { + expectTypeOf(arr.dtype).toMatchTypeOf(); + } + if (arr.is("number")) { + expectTypeOf(arr.dtype).toMatchTypeOf(); + } + if (arr.is("bool")) { + expectTypeOf(arr.dtype).toMatchTypeOf(); + } + if (arr.is("raw")) { + expectTypeOf(arr.dtype).toMatchTypeOf(); + } + if (arr.is("int8")) { + expectTypeOf(arr.dtype).toMatchTypeOf(); + } + }); +}); + +describe("Group", () => { + test("constructor", async () => { + let grp = new Group(new Map(), "/", { + zarr_format: 3, + node_type: "group", + attributes: { answer: 42 }, + }); + expect({ + attrs: await grp.attrs(), + path: grp.path, + store: grp.store, + }).toMatchInlineSnapshot(` + { + "attrs": { + "answer": 42, + }, + "path": "/", + "store": Map {}, + } + `); + }); +}); diff --git a/packages/core/__tests__/is-dtype.test.ts b/packages/core/__tests__/is-dtype.test.ts deleted file mode 100644 index e6e646bc..00000000 --- a/packages/core/__tests__/is-dtype.test.ts +++ /dev/null @@ -1,115 +0,0 @@ -// @ts-nocheck -import { assert, test } from "vitest"; -import { is_dtype } from "../src/util.js"; - -test.skip("is number", () => { - assert.ok(is_dtype("|i1", "number")); - assert.ok(is_dtype("i2", "number")); - assert.ok(is_dtype("i4", "number")); - - assert.ok(is_dtype("|u1", "number")); - assert.ok(is_dtype("u2", "number")); - assert.ok(is_dtype("u4", "number")); - - assert.ok(is_dtype("f4", "number")); - assert.ok(is_dtype("f8", "number")); -}); - -test.skip("is not number (bigint)", () => { - assert.notOk(is_dtype(">i8", "number")); - assert.notOk(is_dtype("u8", "number")); - assert.notOk(is_dtype(" { - assert.notOk(is_dtype(" { - assert.ok(is_dtype(">i8", "bigint")); - assert.ok(is_dtype("u8", "bigint")); - assert.ok(is_dtype(" { - assert.notOk(is_dtype("|i1", "bigint")); - assert.notOk(is_dtype("i2", "bigint")); - assert.notOk(is_dtype("i4", "bigint")); - - assert.notOk(is_dtype("|u1", "bigint")); - assert.notOk(is_dtype("u2", "bigint")); - assert.notOk(is_dtype("u4", "bigint")); - - assert.notOk(is_dtype("f4", "bigint")); - assert.notOk(is_dtype("f8", "bigint")); -}); - -test.skip("is not bigint (string)", () => { - assert.notOk(is_dtype("U43", "bigint")); - assert.notOk(is_dtype("|S1", "bigint")); - assert.notOk(is_dtype("|S24", "bigint")); -}); - -test.skip("is exact", () => { - assert.ok(is_dtype("f4", ">f4")); - assert.ok(is_dtype("|u1", "|u1")); - assert.ok(is_dtype(" { - assert.notOk(is_dtype("i2")); - assert.notOk(is_dtype("|S22", "|S225")); -}); - -test.skip("is fuzzy", () => { - // number - assert.ok(is_dtype("|i1", "i1")); - assert.ok(is_dtype("i2", "i2")); - assert.ok(is_dtype("i4", "i4")); - assert.ok(is_dtype("|u1", "u1")); - assert.ok(is_dtype("u2", "u2")); - assert.ok(is_dtype("u4", "u4")); - assert.ok(is_dtype("f4", "f4")); - assert.ok(is_dtype("f8", "f8")); - - // bigint - assert.ok(is_dtype(">i8", "i8")); - assert.ok(is_dtype("u8", "u8")); - assert.ok(is_dtype("U8", "U8")); - assert.ok(is_dtype(" { expect(get_strides(shape, order)).toStrictEqual(expected); }); }); + +describe("is_dtype", () => { + test.each<[DataType, boolean]>([ + ["int8", true], + ["int16", true], + ["int32", true], + ["uint8", true], + ["uint16", true], + ["uint32", true], + ["float32", true], + ["float64", true], + ["bool", false], + ["int64", false], + ["uint64", false], + ["r42", false], + ])("is_dtype(%s, 'number') -> %s", (dtype, expected) => { + expect(is_dtype(dtype, "number")).toBe(expected); + }); + + test.each<[DataType, boolean]>([ + ["int8", false], + ["int16", false], + ["int32", false], + ["uint8", false], + ["uint16", false], + ["uint32", false], + ["float32", false], + ["float64", false], + ["bool", false], + ["int64", true], + ["uint64", true], + ["r42", false], + ])("is_dtype(%s, 'bigint') -> %s", (dtype, expected) => { + expect(is_dtype(dtype, "bigint")).toBe(expected); + }); + + test.each<[DataType, boolean]>([ + ["int8", false], + ["int16", false], + ["int32", false], + ["uint8", false], + ["uint16", false], + ["uint32", false], + ["float32", false], + ["float64", false], + ["bool", false], + ["int64", false], + ["uint64", false], + ["r42", true], + ])("is_dtype(%s, 'raw') -> %s", (dtype, expected) => { + expect(is_dtype(dtype, "raw")).toBe(expected); + }); + + test.each([ + "int8", + "int16", + "int32", + "uint8", + "uint16", + "uint32", + "float32", + "float64", + "bool", + "int64", + "uint64", + "r42", + ])("is_dtype(%s, %s) -> true", (dtype) => { + expect(is_dtype(dtype, dtype)).toBe(true); + }); +}); diff --git a/packages/core/src/hierarchy.ts b/packages/core/src/hierarchy.ts index 828d895f..1ebe98ab 100644 --- a/packages/core/src/hierarchy.ts +++ b/packages/core/src/hierarchy.ts @@ -7,7 +7,14 @@ import type { Scalar, } from "./metadata.js"; import { create_codec_pipeline } from "./codecs.js"; -import { encode_chunk_key, json_decode_object, v2_marker } from "./util.js"; +import { + type DataTypeQuery, + encode_chunk_key, + is_dtype, + json_decode_object, + type NarrowDataType, + v2_marker, +} from "./util.js"; import { KeyError } from "./errors.js"; export class Location { @@ -136,4 +143,27 @@ export class Array< } return this.#attributes ?? {}; } + + /** + * A helper method to narrow `zarr.Array` Dtype. + * + * ```typescript + * let arr: zarr.Array = zarr.open(store, { kind: "array" }); + * + * // Option 1: narrow by scalar type (e.g. "bool", "raw", "bigint", "number") + * if (arr.is("bigint")) { + * // zarr.Array<"int64" | "uint64", FetchStore> + * } + * + * // Option 3: exact match + * if (arr.is("float32")) { + * // zarr.Array<"float32", FetchStore, "/"> + * } + * ``` + */ + is( + query: Query, + ): this is Array, Store> { + return is_dtype(this.dtype, query); + } } diff --git a/packages/core/src/util.ts b/packages/core/src/util.ts index d63c92da..c51660be 100644 --- a/packages/core/src/util.ts +++ b/packages/core/src/util.ts @@ -7,10 +7,13 @@ import { import type { ArrayMetadata, ArrayMetadataV2, + BigintDataType, CodecMetadata, DataType, GroupMetadata, GroupMetadataV2, + NumberDataType, + Raw, TypedArrayConstructor, } from "./metadata.js"; @@ -179,3 +182,31 @@ export function v2_to_v3_group_metadata(_meta: GroupMetadataV2): GroupMetadata { attributes: { [v2_marker]: true }, }; } + +export type DataTypeQuery = + | DataType + | "number" + | "bigint" + | "raw"; + +export type NarrowDataType< + Dtype extends DataType, + Query extends DataTypeQuery, +> = Query extends "number" ? NumberDataType + : Query extends "bigint" ? BigintDataType + : Query extends "raw" ? Raw + : Extract; + +export function is_dtype( + dtype: DataType, + query: Query, +): dtype is NarrowDataType { + if (query !== "raw" && query !== "number" && query !== "bigint") { + return dtype === query; + } + const is_raw = dtype.startsWith("r"); + if (query === "raw") return is_raw; + const is_bigint = dtype === "int64" || dtype === "uint64"; + if (query === "bigint") return is_bigint; + return !is_raw && !is_bigint && !(dtype === "bool"); +}