From 22954143b029e20586bfaa2050aac6d6323b0022 Mon Sep 17 00:00:00 2001 From: Jon Watte Date: Fri, 29 Nov 2024 12:20:58 -0800 Subject: [PATCH] safetensors-parser: parse safetensors files --- .gitignore | 3 + LICENSE.md | 18 ++ README.md | 220 +++++++++++++++++++ dist/src/index.d.ts | 49 +++++ dist/src/index.d.ts.map | 1 + dist/src/index.js | 456 ++++++++++++++++++++++++++++++++++++++ dist/src/index.js.map | 1 + package-lock.json | 30 +++ package.json | 34 +++ src/index.ts | 470 ++++++++++++++++++++++++++++++++++++++++ test/index-test.ts | 111 ++++++++++ tsconfig.json | 32 +++ 12 files changed, 1425 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE.md create mode 100644 README.md create mode 100644 dist/src/index.d.ts create mode 100644 dist/src/index.d.ts.map create mode 100644 dist/src/index.js create mode 100644 dist/src/index.js.map create mode 100644 package-lock.json create mode 100644 package.json create mode 100644 src/index.ts create mode 100644 test/index-test.ts create mode 100644 tsconfig.json diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5331981 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +node_modules +dist/tsconfig.tsbuildinfo +dist/test diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..1929e42 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,18 @@ +Copyright 2024 Reve AI Authors + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the “Software”), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..dafed32 --- /dev/null +++ b/README.md @@ -0,0 +1,220 @@ +# Safetensors Parser for Javascript + +"safetensors" is the highest-performance file format in wide use within the +pytorch machine learning community. It's a very simple format, and the Python +libraries for dealing with the format are generally of good quality. However, +sometimes you need to send tensors around some web infrastructure that might +include a browser, or a node or deno server, and then what do you do? + +This library includes a utility class to build and save safetensor files, as +well as parse, load, and inspect them. It validates the files and throws an +error when the file is somehow not up to snuff. Some of those validations can +be turned off if you like not knowing things. + +Safetensors-parser attempts to be memory efficient, inasmuchas anything in this +area can be. Each separate tensor from a parsed file references the underlying +byte array. Each tensor is written as a separate chunk if you provide a write +callback to saving tensors. This could allow you to stream tensor data over a +network without having to keep all of them in RAM at once. + +## Usage: Loading + +```typescript +import { parseSafeTensors } from "@reve-ai/safetensors-parser"; + +const stuff: UInt8array = ...; + +const tensorMap = parseSafeTensors(stuff); +const myTensor = tensorMap.getTensor("my-tensor"); +``` + +## Usage: Saving + +```typescript +import { saveSafeTensors, TensorMap } from "@reve-ai/safetensors-parser"; + +const tensorMap = new TensorMap(); +tensorMap.setMetadata("creator", "me"); +tensorMap.addTensor( + "identity", + new UInt8array([1, 0, 0, 0, 1, 0, 0, 0, 1], "UINT8", [3, 3]) +); + +// Use the default writer, which returns the full byte array. +// If you use a custom write callback, nothing will be returned. +const stuff: UInt8array = saveSafeTensors(tensorMap); +``` + +# FAQ + +These questions might have been asked by some person at some point, making them +more frequently asked than questions that nobody has asked. + +## Why does the package.json have no bundler? + +This package is a proper module, provided in a single file (dist/src/index.js) +which you can just import as-is, no bundler required. + +## Why does the package.json have no test runner? + +The tests live in a single file that you can just run with node.js. + +## Can you add convnient dependency wrappers for all my favorite frameworks and image loading libraries? + +This package has no runtime dependencies, and the only development time +dependency is typescript, and it will stay that way. + +## I like the Buffer class better than the UInt8array class. + +That's not a question. Also, `Buffer` is not available in the browser; this +library is intended to be usable both in a browser and on a server. + +# Detailed Usage + +Other than `parseSafeTensors` and `saveSafeTensors`, the rest of the functions +are largely internal but are exported in case you want to use them or test them. +The `TensorMap` class is intended for direct usage, and you could also use +`TensorRef` directly. + +## TensorMap + +```typescript +class TensorMap { + constructor(); + getTensor(name: TensorName): TensorRef | undefined; + addTensor( + name: TensorName, + bytes: Bytes, + format: Format, + shape: Shape + ): TensorRef; + addTensor(name: TensorName, tensor: TensorRef): TensorRef; + setTensor(name: TensorName, tensor: TensorRef): TensorRef; + getOrMakeTensor(name: TensorName, factory: () => TensorRef): TensorRef; + getMetaValue(name: TensorName): string | undefined; + setMetaValue(name: TensorName, value: string): void; + get allMetadata(): Map; + get allTensors(): Map; + setAllMetadata(metadata: Map): void; + removeTensor(name: TensorName | TensorRef): void; +} +``` + +`TensorMap` is a collection of multiple tensors. It can be manually constructed +or loaded from a safetensors file. It provides methods to manage tensors and +metadata. + +- `constructor()`: Creates a new empty TensorMap. +- `getTensor(name)`: Retrieves a tensor by name, or returns undefined if not found. +- `addTensor(name, ...)`: Adds a new tensor to the map. Throws an error if the name already exists. +- `setTensor(name, tensor)`: Sets a tensor, replacing any existing tensor with the same name. +- `setMetaValue(name, value)`: Sets a metadata value. +- `allMetadata`: Getter that returns all metadata as a Map. +- `allTensors`: Getter that returns all tensors as a Map. +- `setAllMetadata(metadata)`: Sets all metadata at once. +- `removeTensor(name)`: Removes a tensor from the map. + +## TensorRef + +```typescript +class TensorRef { + constructor(name: TensorName, bytes: Bytes, format: Format, shape: Shape); + get parent(): TensorMap | undefined; + get name(): string; + set name(val: string); + removeIfParented(): void; + sanityCheck(): void; +} +``` + +`TensorRef` represents a specific tensor within a safetensors archive. It can be +created standalone or obtained from a TensorMap. + +- `constructor(name, bytes, format, shape)`: Creates a new TensorRef with the given properties. +- `parent`: Getter that returns the parent TensorMap, if any. +- `name`: Getter and setter for the tensor name. Setting the name will update it in the parent TensorMap if present. +- `removeIfParented()`: Removes the tensor from its parent TensorMap, if it has one. +- `sanityCheck()`: Performs sanity checks on the tensor, ensuring its size matches the declared shape and format. + +Note: The `bytes`, `format`, and `shape` properties are read-only and can be accessed directly on the TensorRef instance. + +## parseSafeTensors + +```typescript +function parseSafeTensors( + bytes: Uint8Array, + ignoreInvalid?: boolean +): TensorMap; +``` + +Parses a safetensors file (as a Uint8Array) and returns a TensorMap containing +the tensors and metadata stored in that file. If the file is not fully compliant +with the spec, an error will be thrown. You can pass `ignoreInvalid=true` to +attempt parsing anyway. + +## saveSafeTensors + +```typescript +function saveSafeTensors(tensorMap: TensorMap): Uint8Array; +function saveSafeTensors( + tensorMap: TensorMap, + write: (data: Uint8Array) => void +): undefined; +``` + +Generates the contents of a safetensors file from a given TensorMap. It can +either return a Uint8Array containing the file contents or call a provided write +function with chunks of data. + +## sanityCheckTensorsHeader + +```typescript +function sanityCheckTensorsHeader( + ignoreInvalid: boolean, + bytes: Uint8Array, + filesize: number +): number; +``` + +Verifies that the header of a safetensors file seems legitimate. Returns the +size of the JSON chunk that starts at offset 8. + +## unsafeGetHeaderSize + +```typescript +function unsafeGetHeaderSize(bytes: Uint8Array): number; +``` + +Retrieves the header size from the first 4 bytes of a safetensors file. This +function is unsafe as it doesn't perform any checks. + +## sanityCheckTensorsParsed + +```typescript +function sanityCheckTensorsParsed( + ignoreInvalid: boolean, + j: Object, + chunksize: number +): void; +``` + +Performs sanity checks on the parsed JSON content of a safetensors file. + +## sanityCheck + +```typescript +function sanityCheck(tensor: TensorRef): void; +function sanityCheck(bytes: Uint8Array, format: string, shape: number[]): void; +``` + +Performs sanity checks on a tensor, ensuring that its size matches its declared +shape and format. + +## formatLength + +```typescript +function formatLength(format: string): number; +``` + +Returns the byte length of a given tensor format. This may be a fractional value +for formats like 4-bit data. diff --git a/dist/src/index.d.ts b/dist/src/index.d.ts new file mode 100644 index 0000000..f56fe17 --- /dev/null +++ b/dist/src/index.d.ts @@ -0,0 +1,49 @@ +export type Bytes = Uint8Array; +export type Format = string; +export type Integer = number; +export type OffsetPair = [Integer, Integer]; +export type Shape = Integer[]; +export type TensorName = string; +export declare class IgnorableError extends Error { + constructor(message: string); +} +export declare function parseSafeTensors(bytes: Bytes, ignoreInvalid?: boolean): TensorMap; +export declare function saveSafeTensors(tensorMap: TensorMap): Uint8Array; +export declare function saveSafeTensors(tensorMap: TensorMap, write: (data: Bytes) => void): undefined; +export declare class TensorMap { + refs: Map; + metadata: Map; + constructor(); + getTensor(name: TensorName): TensorRef | undefined; + addTensor(name: TensorName, bytes: Bytes, format: Format, shape: Shape): TensorRef; + addTensor(name: TensorName, tensor: TensorRef): TensorRef; + setTensor(name: TensorName, tensor: TensorRef): TensorRef; + getOrMakeTensor(name: TensorName, factory: () => TensorRef): TensorRef; + getMetaValue(name: TensorName): string | undefined; + setMetaValue(name: TensorName, value: string): void; + get allMetadata(): Map; + get allTensors(): Map; + setAllMetadata(metadata: Map): void; + removeTensor(name: TensorName | TensorRef): void; + private _setTensor; +} +export declare class TensorRef { + readonly bytes: Bytes; + readonly format: Format; + readonly shape: Shape; + constructor(name: TensorName, bytes: Bytes, format: Format, shape: Shape); + get parent(): TensorMap | undefined; + private _name; + _parent?: TensorMap; + get name(): string; + set name(val: string); + removeIfParented(): void; + sanityCheck(): void; +} +export declare function sanityCheckTensorsHeader(ignoreInvalid: boolean, bytes: Bytes, filesize: Integer): Integer; +export declare function unsafeGetHeaderSize(bytes: Bytes): Integer; +export declare function sanityCheckTensorsParsed(ignoreInvalid: boolean, j: Object, chunksize: Integer): void; +export declare function sanityCheck(tensor: TensorRef): void; +export declare function sanityCheck(bytes: Bytes, format: Format, shape: Shape): void; +export declare function formatLength(format: Format): Integer; +//# sourceMappingURL=index.d.ts.map \ No newline at end of file diff --git a/dist/src/index.d.ts.map b/dist/src/index.d.ts.map new file mode 100644 index 0000000..fd60dbd --- /dev/null +++ b/dist/src/index.d.ts.map @@ -0,0 +1 @@ +{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../src/index.ts"],"names":[],"mappings":"AACA,MAAM,MAAM,KAAK,GAAG,UAAU,CAAC;AAG/B,MAAM,MAAM,MAAM,GAAG,MAAM,CAAC;AAE5B,MAAM,MAAM,OAAO,GAAG,MAAM,CAAC;AAE7B,MAAM,MAAM,UAAU,GAAG,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC;AAE5C,MAAM,MAAM,KAAK,GAAG,OAAO,EAAE,CAAC;AAE9B,MAAM,MAAM,UAAU,GAAG,MAAM,CAAC;AAEhC,qBAAa,cAAe,SAAQ,KAAK;gBAC5B,OAAO,EAAE,MAAM;CAI3B;AAMD,wBAAgB,gBAAgB,CAAC,KAAK,EAAE,KAAK,EAAE,aAAa,CAAC,EAAE,OAAO,GAAG,SAAS,CAejF;AAUD,wBAAgB,eAAe,CAAC,SAAS,EAAE,SAAS,GAAG,UAAU,CAAA;AACjE,wBAAgB,eAAe,CAAC,SAAS,EAAE,SAAS,EAAE,KAAK,EAAE,CAAC,IAAI,EAAE,KAAK,KAAK,IAAI,GAAG,SAAS,CAAA;AAiG9F,qBAAa,SAAS;IACrB,IAAI,EAAE,GAAG,CAAC,MAAM,EAAE,SAAS,CAAC,CAAa;IACzC,QAAQ,EAAE,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAa;;IAG1C,SAAS,CAAC,IAAI,EAAE,UAAU,GAAG,SAAS,GAAG,SAAS;IAMlD,SAAS,CAAC,IAAI,EAAE,UAAU,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,GAAG,SAAS;IAClF,SAAS,CAAC,IAAI,EAAE,UAAU,EAAE,MAAM,EAAE,SAAS,GAAG,SAAS;IAiBzD,SAAS,CAAC,IAAI,EAAE,UAAU,EAAE,MAAM,EAAE,SAAS,GAAG,SAAS;IAOzD,eAAe,CAAC,IAAI,EAAE,UAAU,EAAE,OAAO,EAAE,MAAM,SAAS,GAAG,SAAS;IAWtE,YAAY,CAAC,IAAI,EAAE,UAAU,GAAG,MAAM,GAAG,SAAS;IAI/C,YAAY,CAAC,IAAI,EAAE,UAAU,EAAE,KAAK,EAAE,MAAM;IAI/C,IAAI,WAAW,IAAI,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAErC;IAED,IAAI,UAAU,IAAI,GAAG,CAAC,MAAM,EAAE,SAAS,CAAC,CAEvC;IAEE,cAAc,CAAC,QAAQ,EAAE,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC;IAK/C,YAAY,CAAC,IAAI,EAAE,UAAU,GAAG,SAAS;IAiBzC,OAAO,CAAC,UAAU;CAQlB;AAID,qBAAa,SAAS;IAGS,QAAQ,CAAC,KAAK,EAAE,KAAK;IAAE,QAAQ,CAAC,MAAM,EAAE,MAAM;IAAE,QAAQ,CAAC,KAAK,EAAE,KAAK;gBAAvF,IAAI,EAAE,UAAU,EAAW,KAAK,EAAE,KAAK,EAAW,MAAM,EAAE,MAAM,EAAW,KAAK,EAAE,KAAK;IAInG,IAAI,MAAM,IAAI,SAAS,GAAG,SAAS,CAAyB;IAC5D,OAAO,CAAC,KAAK,CAAa;IAC1B,OAAO,CAAC,EAAE,SAAS,CAAC;IAGpB,IAAI,IAAI,IAAI,MAAM,CAAuB;IACzC,IAAI,IAAI,CAAC,GAAG,EAAE,MAAM,EAWnB;IAED,gBAAgB;IAMhB,WAAW;CAGX;AAOD,wBAAgB,wBAAwB,CAAC,aAAa,EAAE,OAAO,EAAE,KAAK,EAAE,KAAK,EAAE,QAAQ,EAAE,OAAO,GAAG,OAAO,CAYzG;AAED,wBAAgB,mBAAmB,CAAC,KAAK,EAAE,KAAK,GAAG,OAAO,CAEzD;AAeD,wBAAgB,wBAAwB,CAAC,aAAa,EAAE,OAAO,EAAE,CAAC,EAAE,MAAM,EAAE,SAAS,EAAE,OAAO,GAAG,IAAI,CAUpG;AAmGD,wBAAgB,WAAW,CAAC,MAAM,EAAE,SAAS,GAAG,IAAI,CAAA;AACpD,wBAAgB,WAAW,CAAC,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,GAAG,IAAI,CAAA;AA6B7E,wBAAgB,YAAY,CAAC,MAAM,EAAE,MAAM,GAAG,OAAO,CAYpD"} \ No newline at end of file diff --git a/dist/src/index.js b/dist/src/index.js new file mode 100644 index 0000000..276db2c --- /dev/null +++ b/dist/src/index.js @@ -0,0 +1,456 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.TensorRef = exports.TensorMap = exports.IgnorableError = void 0; +exports.parseSafeTensors = parseSafeTensors; +exports.saveSafeTensors = saveSafeTensors; +exports.sanityCheckTensorsHeader = sanityCheckTensorsHeader; +exports.unsafeGetHeaderSize = unsafeGetHeaderSize; +exports.sanityCheckTensorsParsed = sanityCheckTensorsParsed; +exports.sanityCheck = sanityCheck; +exports.formatLength = formatLength; +class IgnorableError extends Error { + constructor(message) { + super(message + " You can ignore this error with ignoreInvalid=true in parseSafeTensors()."); + this.name = "IgnorableError"; + } +} +exports.IgnorableError = IgnorableError; +// Given a safetensors file (as a Uint8Array) return the TensorMap of the +// tensors and metadata stored in that file. If the file is not fully compliant +// with the spec (additional properties, etc) an error will be thrown. If you'd +// like to try using it anyway, you can pass ignoreInvalid=true. +function parseSafeTensors(bytes, ignoreInvalid) { + sanityCheckTensorsHeader(!!ignoreInvalid, bytes, bytes.length); + const hdrsize = unsafeGetHeaderSize(bytes); + const ret = new TensorMap(); + const j = JSON.parse(new TextDecoder().decode(bytes.slice(8, 8 + hdrsize))); + sanityCheckTensorsParsed(!!ignoreInvalid, j, bytes.length - 8 - hdrsize); + for (const [name, val] of Object.entries(j)) { + if (name === "__metadata__") { + ret.setAllMetadata(new Map(Object.entries(val))); + } + else { + const { dtype, shape, data_offsets } = val; + ret.addTensor(name, bytes.slice(8 + hdrsize + data_offsets[0], 8 + hdrsize + data_offsets[1]), dtype, shape); + } + } + return ret; +} +function saveSafeTensors(tensorMap, write) { + const save = new TensorSaver(tensorMap, write); + save.calcOffset(); + save.calcHeader(); + save.writeHeader(); + save.writeTensors(); + return save.ret; +} +const padder = new Uint8Array([32, 32, 32, 32, 32, 32, 32, 32]); +// Simple helper to format the safetensors file with header and data chunks. +class TensorSaver { + tensorMap; + write; + hdr = {}; + constructor(tensorMap, write) { + this.tensorMap = tensorMap; + this.write = write; + this.setMetadata(); + } + tes; + offset; + hdrBuf; + lenbuf; + hblen; + ret; + setMetadata() { + const md = {}; + let toset = false; + for (const [name, val] of this.tensorMap.allMetadata.entries()) { + md[name] = val; + toset = true; + } + if (toset) { + this.hdr["__metadata__"] = md; + } + } + calcOffset() { + let offset = 0; + this.tes = Array.from(this.tensorMap.allTensors.entries()); + for (const [name, tensor] of this.tes) { + this.hdr[name] = { + dtype: tensor.format, + shape: tensor.shape, + data_offsets: [offset, offset + tensor.bytes.length] + }; + offset += (tensor.bytes.length + 7) & ~7; + } + this.offset = offset; + return offset; + } + calcHeader() { + const hdrBuf = (new TextEncoder()).encode(JSON.stringify(this.hdr)); + const hblen = (hdrBuf.length + 7) & ~7; + if (hblen > 100 * 1024 * 1024) { + throw new Error("The metadata is too large to be saved in a safetensors file."); + } + this.lenbuf = new Uint8Array([hblen & 0xff, (hblen >> 8) & 0xff, (hblen >> 16) & 0xff, (hblen >> 24) & 0xff, 0, 0, 0, 0]); + this.hdrBuf = hdrBuf; + this.hblen = hblen; + return hblen; + } + writeHeader() { + if (!this.write) { + this.ret = new Uint8Array(this.offset + this.hblen + 8); + let wrote = 0; + this.write = (data) => { + this.ret.set(data, wrote); + wrote += data.length; + }; + } + this.write(this.lenbuf); + this.write(this.hdrBuf); + if (this.hblen > this.hdrBuf.length) { + this.write(padder.slice(0, this.hblen - this.hdrBuf.length)); + } + } + writeTensors() { + let written = 0; + for (const [name, tensor] of this.tes) { + this.write(tensor.bytes); + written += tensor.bytes.length; + if (tensor.bytes.length & 7) { + this.write(padder.slice(0, 8 - (tensor.bytes.length & 7))); + written += 8 - (tensor.bytes.length & 7); + } + const end = this.hdr[name]["data_offsets"][1]; + if (((end + 7) & ~7) !== written) { + throw new Error(`Internal tensor alignment problem: "${name}": ${end} !== ${written}`); + } + } + } +} +// TensorMap is a collection of possibly multiple tensors. It has either been manually +// constructed, or loaded from a safetensors file. It can be used to return data (in +// little endian format) or to save the tensors to a safetensors formatted file. Note +// that tensor refs are not cloned -- if you mutate the underlying data, you will +// indirectly mutate the TensorMap, too! +class TensorMap { + refs = new Map(); + metadata = new Map(); + constructor() { } + // If the tensor exists in the map, return it, else return undefined. + getTensor(name) { + return this.refs.get(name); + } + addTensor(name, data, format, shape) { + if (this.refs.get(name)) { + throw new Error(`The name "${name}" already exists in the TensorMap.`); + } + if (data instanceof TensorRef) { + return this._setTensor(name, data); + } + if (!format || !shape) { + throw new Error(`You must provide format and shape when adding tensor "${name}".`); + } + else { + return this._setTensor(name, new TensorRef(name, data, format, shape)); + } + } + // Set this name to be this tensor, no matter whether it exists or not. Remove + // any previous tensor of the same name (which may be the same value!) + setTensor(name, tensor) { + this.removeTensor(name); + return this.addTensor(name, tensor); + } + // If the tensor exists, then return it. Otherwise, create it with the factory + // function, add it to the map, and return it. + getOrMakeTensor(name, factory) { + let ten = this.refs.get(name); + if (ten) { + return ten; + } + ten = factory(); + this.refs.set(name, ten); + return ten; + } + // If some metadata value exists, return it, else return undefined. + getMetaValue(name) { + return this.metadata.get(name); + } + setMetaValue(name, value) { + this.metadata.set(name, value); + } + get allMetadata() { + return this.metadata; + } + get allTensors() { + return this.refs; + } + setAllMetadata(metadata) { + this.metadata = metadata; + } + // Remove a tensor from the map. It's OK if it doesn't exist. + removeTensor(name) { + let ten; + if (typeof name === "string") { + ten = this.refs.get(name); + if (!ten) { + return; + } + } + else { + ten = name; + } + if (ten.parent !== this) { + throw new Error("You can only remove a tensor from the TensorMap that it belongs to."); + } + delete ten._parent; + this.refs.delete(ten.name); + } + _setTensor(name, tensor) { + if (tensor.parent && tensor.parent !== this) { + throw new Error("You can only add a tensor to a TensorMap that it doesn't already belong to."); + } + this.refs.set(name, tensor); + tensor._parent = this; + return tensor; + } +} +exports.TensorMap = TensorMap; +; +// TensorRef is a specific tensor within a safetensors archive. You can make one free +// standing, or get one from an archive. +class TensorRef { + bytes; + format; + shape; + // Make a new tensorref referencing the given data. The new tensorref is not parented. + // Note that the shape and format are not sanity checked! Hopefully the bytes are accurate. + constructor(name, bytes, format, shape) { + this.bytes = bytes; + this.format = format; + this.shape = shape; + this._name = name; + } + // A tensorref may have a parent tensor map, or may be freestanding. + get parent() { return this._parent; } + _name; + _parent; + // If you mutate the name of a tensor in a tensor map, it will rename the tensor + // in the map. It's an error to rename to a name that already exists in the map. + get name() { return this._name; } + set name(val) { + if (this._parent) { + if (this._parent.refs.get(val)) { + throw new Error(`The name ${val} already exists in the parent TensorMap.`); + } + this._parent.refs.delete(this._name); + } + this._name = val; + if (this._parent) { + this._parent.refs.set(this._name, this); + } + } + // If the tensor is parented, remove it from the parent. + removeIfParented() { + if (this._parent) { + this._parent.removeTensor(this); + } + } + // Check whether the tensor is sane + sanityCheck() { + sanityCheck(this); + } +} +exports.TensorRef = TensorRef; +; +// Given the bytes of a safetensors file, verify that the header seems legitimate. +// Note that you only need 10 bytes to check this. The value returned is the size +// of the JSON chunk that starts at offset 8, so if the total file size is smaller +// than 8+return, the file is likely truncated. +function sanityCheckTensorsHeader(ignoreInvalid, bytes, filesize) { + if (!arrayCompare(Array.from(bytes.slice(4, 9)), [0, 0, 0, 0, "{".charCodeAt(0)])) { + throw new Error("The file header is not a valid safetensors file."); + } + const val = unsafeGetHeaderSize(bytes); + if (val > 100 * 1024 * 1024 && !ignoreInvalid) { + throw new IgnorableError("The file header is too long to be a valid safetensors file."); + } + if (filesize < 8 + val) { + throw new Error("The safetensors file seems truncated or otherwise not valid."); + } + return val; +} +function unsafeGetHeaderSize(bytes) { + return ((bytes[0] | 0) + ((bytes[1] | 0) * 256) + ((bytes[2] | 0) * 256 * 256) + ((bytes[3] | 0) * 256 * 256 * 256)) | 0; +} +// Return true if each element compares equal in the arrays +function arrayCompare(a, b) { + if (a.length !== b.length) { + return false; + } + for (let i = 0; i < a.length; i++) { + if (a[i] !== b[i]) { + return false; + } + } + return true; +} +function sanityCheckTensorsParsed(ignoreInvalid, j, chunksize) { + const covered = []; + for (const [name, value] of Object.entries(j)) { + if (name === "__metadata__") { + checkMetadata(ignoreInvalid, value); + continue; + } + const tac = new TensorAttributeChecker(chunksize, covered, ignoreInvalid); + tac.checkTensorValue(name, value); + } +} +function checkMetadata(ignoreInvalid, value) { + for (const [key, strval] of Object.entries(value)) { + if (typeof strval !== "string" && !ignoreInvalid) { + throw new IgnorableError(`The metadata value ${key} is not a string (${typeof strval}).`); + } + } +} +class TensorAttributeChecker { + chunksize; + covered; + ignoreInvalid; + has_dtype; + has_shape; + has_data_offsets; + constructor(chunksize, covered, ignoreInvalid = false, has_dtype = false, has_shape = false, has_data_offsets = false) { + this.chunksize = chunksize; + this.covered = covered; + this.ignoreInvalid = ignoreInvalid; + this.has_dtype = has_dtype; + this.has_shape = has_shape; + this.has_data_offsets = has_data_offsets; + } + dtype(key, val) { + if (typeof val !== "string") { + throw new Error(`The dtype for ${key} is not a string (${typeof val}).`); + } + this.has_dtype = true; + } + shape(key, val) { + this.checkTensorShape(key, val); + this.has_shape = true; + } + data_offsets(key, val) { + this.checkDataOffsets(key, val); + this.has_data_offsets = true; + } + attrOk(key) { + return key === "dtype" || key === "shape" || key === "data_offsets"; + } + assertAttributePresence(name) { + if (!this.has_dtype) { + throw new Error(`The tensor "${name}" is missing the dtype key.`); + } + if (!this.has_shape) { + throw new Error(`The tensor "${name}" is missing the shape key.`); + } + if (!this.has_data_offsets) { + throw new Error(`The tensor "${name}" is missing the data_offsets key.`); + } + } + checkTensorShape(key, val) { + if (!Array.isArray(val)) { + throw new Error(`The shape for ${key} is not an array (${typeof val})`); + } + for (const v of val) { + if (typeof v !== "number" && !this.ignoreInvalid) { + throw new IgnorableError(`The shape for ${key} is not an array of numbers (${typeof v})`); + } + } + } + checkDataOffsets(key, obj) { + if (!Array.isArray(obj) || obj.length !== 2) { + throw new Error(`The data_offsets for ${key} is not an array with length 2 (${typeof obj})`); + } + this.checkDataOffsetBasics(key, obj); + for (const [start, end] of this.covered) { + this.checkTensorOverlap(key, start, end, obj); + } + this.covered.push(obj); + } + checkDataOffsetBasics(key, val) { + let err = null; + if (!pairInRange(val, this.chunksize) && !this.ignoreInvalid) { + err = new IgnorableError(`The data_offsets for ${key} is out of range (${val[0]}-${val[1]} versus ${this.chunksize}).`); + } + if (err && !this.ignoreInvalid) { + throw err; + } + } + checkTensorOverlap(key, start, end, val) { + let err = null; + if (val[1] > start && val[1] <= end) { + err = new IgnorableError(`The data_offsets end for ${key} overlaps with another tensor (${val[1]} > ${start} && ${val[1]} <= ${end}).`); + } + if (val[0] >= start && val[0] < end) { + err = new IgnorableError(`The data_offsets start for ${key} overlaps with another tensor (${val[0]} >= ${start} && ${val[0]} < ${end}).`); + } + if (err && !this.ignoreInvalid) { + throw err; + } + } + checkTensorValue(name, value) { + for (const [key, val] of Object.entries(value)) { + if (!this.attrOk(key)) { + if (!this.ignoreInvalid) { + throw new IgnorableError(`The tensor description for "${name}" has an invalid key "${key}".`); + } + continue; + } + this[key](name, val); + } + this.assertAttributePresence(name); + } +} +function pairInRange(pair, chunksize) { + return pair[0] >= 0 && pair[0] <= chunksize && pair[1] >= 0 && pair[1] <= chunksize && pair[0] <= pair[1]; +} +function sanityCheck(bytes, format, shape) { + if (bytes instanceof TensorRef) { + _sanityCheck(bytes.name, bytes.bytes, bytes.format, bytes.shape); + } + else { + _sanityCheck("created", bytes, format, shape); + } +} +function _sanityCheck(name, bytes, format, shape) { + if (!format || !shape) { + throw new Error(`The tensor "${name}" is missing format and shape information.`); + } + // Sanity check the shape. + if (shape.length === 0) { + shape = [1]; // a scalar + } + let num = 1; + for (let i = 0; i < shape.length; i++) { + num *= shape[i]; + } + const fl = formatLength(format); + if (bytes.length !== num * fl) { + throw new IgnorableError(`The tensor "${name}" is the wrong size ${bytes.length} bytes for its shape ${JSON.stringify(shape)} and format "${format}": should be ${num * fl} bytes.`); + } +} +// Get the byte length of a given format constant. This may be a fractional +// value if the format is something like 4-bit data. +function formatLength(format) { + const num = format.match(/\d+/); + if (!num) { + throw new Error(`The element format "${format}" is not a valid format (must contain bit size).`); + } + const parsed = parseInt(num[0], 10) | 0; + // This is an attempt to support sub-byte precision. It will work + // for many cases, but doesn't deal with rounding-up of un-aligned sizes. + if ((parsed | 0) & ((parsed | 0) - 1)) { + throw new Error(`The element format "${format}" is not a valid format (must be power of 2).`); + } + return parsed / 8.0; +} +//# sourceMappingURL=index.js.map \ No newline at end of file diff --git a/dist/src/index.js.map b/dist/src/index.js.map new file mode 100644 index 0000000..b7b5ee1 --- /dev/null +++ b/dist/src/index.js.map @@ -0,0 +1 @@ +{"version":3,"file":"index.js","sourceRoot":"","sources":["../../src/index.ts"],"names":[],"mappings":";;;AAyBA,4CAeC;AAYD,0CAOC;AAoOD,4DAYC;AAED,kDAEC;AAeD,4DAUC;AAqGD,kCAMC;AAsBD,oCAYC;AAvcD,MAAa,cAAe,SAAQ,KAAK;IACxC,YAAY,OAAe;QAC1B,KAAK,CAAC,OAAO,GAAG,2EAA2E,CAAC,CAAC;QAC7F,IAAI,CAAC,IAAI,GAAG,gBAAgB,CAAC;IAC9B,CAAC;CACD;AALD,wCAKC;AAED,yEAAyE;AACzE,+EAA+E;AAC/E,+EAA+E;AAC/E,gEAAgE;AAChE,SAAgB,gBAAgB,CAAC,KAAY,EAAE,aAAuB;IACrE,wBAAwB,CAAC,CAAC,CAAC,aAAa,EAAE,KAAK,EAAE,KAAK,CAAC,MAAM,CAAC,CAAC;IAC5D,MAAM,OAAO,GAAG,mBAAmB,CAAC,KAAK,CAAC,CAAC;IAC3C,MAAM,GAAG,GAAG,IAAI,SAAS,EAAE,CAAC;IAC5B,MAAM,CAAC,GAAW,IAAI,CAAC,KAAK,CAAC,IAAI,WAAW,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,GAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IACrF,wBAAwB,CAAC,CAAC,CAAC,aAAa,EAAE,CAAC,EAAE,KAAK,CAAC,MAAM,GAAG,CAAC,GAAG,OAAO,CAAC,CAAC;IACtE,KAAK,MAAM,CAAC,IAAI,EAAE,GAAG,CAAC,IAAI,MAAM,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC;QAC1C,IAAI,IAAI,KAAK,cAAc,EAAE,CAAC;YAC1B,GAAG,CAAC,cAAc,CAAC,IAAI,GAAG,CAAiB,MAAM,CAAC,OAAO,CAAC,GAAa,CAAC,CAAC,CAAC,CAAC;QAC/E,CAAC;aAAM,CAAC;YACJ,MAAM,EAAC,KAAK,EAAE,KAAK,EAAE,YAAY,EAAC,GAAG,GAAgE,CAAC;YACtG,GAAG,CAAC,SAAS,CAAC,IAAI,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,GAAC,OAAO,GAAC,YAAY,CAAC,CAAC,CAAC,EAAE,CAAC,GAAC,OAAO,GAAC,YAAY,CAAC,CAAC,CAAC,CAAC,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC;QACzG,CAAC;IACL,CAAC;IACD,OAAO,GAAG,CAAC;AACf,CAAC;AAYD,SAAgB,eAAe,CAAC,SAAoB,EAAE,KAA6B;IAClF,MAAM,IAAI,GAAG,IAAI,WAAW,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;IAC/C,IAAI,CAAC,UAAU,EAAE,CAAC;IAClB,IAAI,CAAC,UAAU,EAAE,CAAC;IAClB,IAAI,CAAC,WAAW,EAAE,CAAC;IACnB,IAAI,CAAC,YAAY,EAAE,CAAC;IACpB,OAAO,IAAI,CAAC,GAAG,CAAC;AACjB,CAAC;AAED,MAAM,MAAM,GAAG,IAAI,UAAU,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;AAEhE,4EAA4E;AAC5E,MAAM,WAAW;IAEK;IAA8B;IAD1C,GAAG,GAA8H,EAAE,CAAC;IAC7I,YAAqB,SAAoB,EAAU,KAA6B;QAA3D,cAAS,GAAT,SAAS,CAAW;QAAU,UAAK,GAAL,KAAK,CAAwB;QAC/E,IAAI,CAAC,WAAW,EAAE,CAAC;IACpB,CAAC;IACD,GAAG,CAA8B;IACjC,MAAM,CAAW;IACjB,MAAM,CAAc;IACpB,MAAM,CAAc;IACpB,KAAK,CAAW;IAChB,GAAG,CAAc;IACjB,WAAW;QACV,MAAM,EAAE,GAAG,EAA6B,CAAC;QACzC,IAAI,KAAK,GAAI,KAAK,CAAC;QACnB,KAAK,MAAM,CAAC,IAAI,EAAE,GAAG,CAAC,IAAI,IAAI,CAAC,SAAS,CAAC,WAAW,CAAC,OAAO,EAAE,EAAE,CAAC;YAChE,EAAE,CAAC,IAAI,CAAC,GAAG,GAAG,CAAC;YACf,KAAK,GAAG,IAAI,CAAC;QACd,CAAC;QACD,IAAI,KAAK,EAAE,CAAC;YACX,IAAI,CAAC,GAAG,CAAC,cAAc,CAAC,GAAG,EAAE,CAAC;QAC/B,CAAC;IACF,CAAC;IACD,UAAU;QACT,IAAI,MAAM,GAAG,CAAC,CAAC;QACf,IAAI,CAAC,GAAG,GAAG,KAAK,CAAC,IAAI,CAAC,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,OAAO,EAAE,CAAC,CAAC;QAC3D,KAAK,MAAM,CAAC,IAAI,EAAE,MAAM,CAAC,IAAI,IAAI,CAAC,GAAG,EAAE,CAAC;YACvC,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,GAAG;gBAChB,KAAK,EAAE,MAAM,CAAC,MAAM;gBACpB,KAAK,EAAE,MAAM,CAAC,KAAK;gBACnB,YAAY,EAAE,CAAC,MAAM,EAAE,MAAM,GAAG,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC;aACpD,CAAC;YACF,MAAM,IAAI,CAAC,MAAM,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;QAC1C,CAAC;QACD,IAAI,CAAC,MAAM,GAAG,MAAM,CAAC;QACrB,OAAO,MAAM,CAAC;IACf,CAAC;IACD,UAAU;QACT,MAAM,MAAM,GAAG,CAAC,IAAI,WAAW,EAAE,CAAC,CAAC,MAAM,CAAC,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC;QACpE,MAAM,KAAK,GAAG,CAAC,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;QACvC,IAAI,KAAK,GAAG,GAAG,GAAC,IAAI,GAAC,IAAI,EAAE,CAAC;YAC3B,MAAM,IAAI,KAAK,CAAC,8DAA8D,CAAC,CAAC;QACjF,CAAC;QACD,IAAI,CAAC,MAAM,GAAG,IAAI,UAAU,CAAC,CAAC,KAAK,GAAG,IAAI,EAAE,CAAC,KAAK,IAAI,CAAC,CAAC,GAAG,IAAI,EAAE,CAAC,KAAK,IAAI,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,KAAK,IAAI,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC1H,IAAI,CAAC,MAAM,GAAG,MAAM,CAAC;QACrB,IAAI,CAAC,KAAK,GAAG,KAAK,CAAC;QACnB,OAAO,KAAK,CAAC;IACd,CAAC;IACD,WAAW;QACV,IAAI,CAAC,IAAI,CAAC,KAAK,EAAE,CAAC;YACjB,IAAI,CAAC,GAAG,GAAG,IAAI,UAAU,CAAC,IAAI,CAAC,MAAO,GAAE,IAAI,CAAE,KAAM,GAAG,CAAC,CAAC,CAAC;YAC1D,IAAI,KAAK,GAAG,CAAC,CAAC;YACd,IAAI,CAAC,KAAK,GAAG,CAAC,IAAW,EAAE,EAAE;gBAC5B,IAAI,CAAC,GAAI,CAAC,GAAG,CAAC,IAAI,EAAE,KAAK,CAAC,CAAC;gBAC3B,KAAK,IAAI,IAAI,CAAC,MAAM,CAAC;YACtB,CAAC,CAAC;QACH,CAAC;QACD,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,MAAO,CAAC,CAAC;QACzB,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,MAAO,CAAC,CAAC;QACzB,IAAI,IAAI,CAAC,KAAM,GAAG,IAAI,CAAC,MAAO,CAAC,MAAM,EAAE,CAAC;YACvC,IAAI,CAAC,KAAK,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,KAAM,GAAG,IAAI,CAAC,MAAO,CAAC,MAAM,CAAC,CAAC,CAAC;QAChE,CAAC;IACF,CAAC;IACD,YAAY;QACX,IAAI,OAAO,GAAG,CAAC,CAAC;QAChB,KAAK,MAAM,CAAC,IAAI,EAAE,MAAM,CAAC,IAAI,IAAI,CAAC,GAAI,EAAE,CAAC;YACxC,IAAI,CAAC,KAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;YAC1B,OAAO,IAAI,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC;YAC/B,IAAI,MAAM,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC;gBAC7B,IAAI,CAAC,KAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,MAAM,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC;gBAC5D,OAAO,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;YAC1C,CAAC;YACD,MAAM,GAAG,GAAa,IAAI,CAAC,GAAG,CAAC,IAAI,CAAgC,CAAC,cAAc,CAAC,CAAC,CAAC,CAAC,CAAC;YACvF,IAAI,CAAC,CAAC,GAAG,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,KAAK,OAAO,EAAE,CAAC;gBAClC,MAAM,IAAI,KAAK,CAAC,uCAAuC,IAAI,MAAM,GAAG,QAAQ,OAAO,EAAE,CAAC,CAAC;YACxF,CAAC;QACF,CAAC;IACF,CAAC;CACD;AAED,sFAAsF;AACtF,oFAAoF;AACpF,qFAAqF;AACrF,iFAAiF;AACjF,wCAAwC;AACxC,MAAa,SAAS;IACrB,IAAI,GAA2B,IAAI,GAAG,EAAE,CAAC;IACzC,QAAQ,GAAwB,IAAI,GAAG,EAAE,CAAC;IAC1C,gBAAe,CAAC;IAChB,qEAAqE;IACrE,SAAS,CAAC,IAAgB;QACzB,OAAO,IAAI,CAAC,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;IAC5B,CAAC;IAMD,SAAS,CAAC,IAAgB,EAAE,IAAuB,EAAE,MAAe,EAAE,KAAa;QAClF,IAAI,IAAI,CAAC,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC;YACzB,MAAM,IAAI,KAAK,CAAC,aAAa,IAAI,oCAAoC,CAAC,CAAC;QACxE,CAAC;QACD,IAAI,IAAI,YAAY,SAAS,EAAE,CAAC;YAC/B,OAAO,IAAI,CAAC,UAAU,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC;QACpC,CAAC;QACD,IAAI,CAAC,MAAM,IAAI,CAAC,KAAK,EAAE,CAAC;YACvB,MAAM,IAAI,KAAK,CAAC,yDAAyD,IAAI,IAAI,CAAC,CAAC;QACpF,CAAC;aAAM,CAAC;YACP,OAAO,IAAI,CAAC,UAAU,CAAC,IAAI,EAAE,IAAI,SAAS,CAAC,IAAI,EAAE,IAAI,EAAE,MAAM,EAAE,KAAK,CAAC,CAAC,CAAC;QACxE,CAAC;IACF,CAAC;IAED,8EAA8E;IAC9E,sEAAsE;IACtE,SAAS,CAAC,IAAgB,EAAE,MAAiB;QAC5C,IAAI,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;QACxB,OAAO,IAAI,CAAC,SAAS,CAAC,IAAI,EAAE,MAAM,CAAC,CAAC;IACrC,CAAC;IAED,8EAA8E;IAC9E,8CAA8C;IAC9C,eAAe,CAAC,IAAgB,EAAE,OAAwB;QACzD,IAAI,GAAG,GAAG,IAAI,CAAC,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;QAC9B,IAAI,GAAG,EAAE,CAAC;YACT,OAAO,GAAG,CAAC;QACZ,CAAC;QACD,GAAG,GAAG,OAAO,EAAE,CAAC;QAChB,IAAI,CAAC,IAAI,CAAC,GAAG,CAAC,IAAI,EAAE,GAAG,CAAC,CAAC;QACzB,OAAO,GAAG,CAAC;IACZ,CAAC;IAED,mEAAmE;IACnE,YAAY,CAAC,IAAgB;QAC5B,OAAO,IAAI,CAAC,QAAQ,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;IAChC,CAAC;IAEE,YAAY,CAAC,IAAgB,EAAE,KAAa;QACxC,IAAI,CAAC,QAAQ,CAAC,GAAG,CAAC,IAAI,EAAE,KAAK,CAAC,CAAC;IACnC,CAAC;IAEJ,IAAI,WAAW;QACd,OAAO,IAAI,CAAC,QAAQ,CAAC;IACtB,CAAC;IAED,IAAI,UAAU;QACb,OAAO,IAAI,CAAC,IAAI,CAAC;IAClB,CAAC;IAEE,cAAc,CAAC,QAA6B;QACxC,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;IAC7B,CAAC;IAEJ,6DAA6D;IAC7D,YAAY,CAAC,IAA4B;QACxC,IAAI,GAA0B,CAAC;QAC/B,IAAI,OAAO,IAAI,KAAK,QAAQ,EAAE,CAAC;YAC9B,GAAG,GAAG,IAAI,CAAC,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;YAC1B,IAAI,CAAC,GAAG,EAAE,CAAC;gBACV,OAAO;YACR,CAAC;QACF,CAAC;aAAM,CAAC;YACP,GAAG,GAAG,IAAI,CAAC;QACZ,CAAC;QACD,IAAI,GAAG,CAAC,MAAM,KAAK,IAAI,EAAE,CAAC;YACzB,MAAM,IAAI,KAAK,CAAC,qEAAqE,CAAC,CAAC;QACxF,CAAC;QACD,OAAO,GAAG,CAAC,OAAO,CAAC;QACnB,IAAI,CAAC,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;IAC5B,CAAC;IAEO,UAAU,CAAC,IAAgB,EAAE,MAAiB;QACrD,IAAI,MAAM,CAAC,MAAM,IAAI,MAAM,CAAC,MAAM,KAAK,IAAI,EAAE,CAAC;YAC7C,MAAM,IAAI,KAAK,CAAC,6EAA6E,CAAC,CAAC;QAChG,CAAC;QACD,IAAI,CAAC,IAAI,CAAC,GAAG,CAAC,IAAI,EAAE,MAAM,CAAC,CAAC;QAC5B,MAAM,CAAC,OAAO,GAAG,IAAI,CAAC;QACtB,OAAO,MAAM,CAAC;IACf,CAAC;CACD;AA7FD,8BA6FC;AAAA,CAAC;AAEF,qFAAqF;AACrF,wCAAwC;AACxC,MAAa,SAAS;IAGkB;IAAuB;IAAyB;IAFvF,sFAAsF;IACtF,2FAA2F;IAC3F,YAAY,IAAgB,EAAW,KAAY,EAAW,MAAc,EAAW,KAAY;QAA5D,UAAK,GAAL,KAAK,CAAO;QAAW,WAAM,GAAN,MAAM,CAAQ;QAAW,UAAK,GAAL,KAAK,CAAO;QAClG,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC;IACnB,CAAC;IACD,oEAAoE;IACpE,IAAI,MAAM,KAA4B,OAAO,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC;IACpD,KAAK,CAAa;IAC1B,OAAO,CAAa;IACpB,gFAAgF;IAChF,gFAAgF;IAChF,IAAI,IAAI,KAAa,OAAO,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC;IACzC,IAAI,IAAI,CAAC,GAAW;QACnB,IAAI,IAAI,CAAC,OAAO,EAAE,CAAC;YAClB,IAAI,IAAI,CAAC,OAAO,CAAC,IAAI,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,CAAC;gBAChC,MAAM,IAAI,KAAK,CAAC,YAAY,GAAG,0CAA0C,CAAC,CAAC;YAC5E,CAAC;YACD,IAAI,CAAC,OAAO,CAAC,IAAI,CAAC,MAAM,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;QACtC,CAAC;QACD,IAAI,CAAC,KAAK,GAAG,GAAG,CAAC;QACjB,IAAI,IAAI,CAAC,OAAO,EAAE,CAAC;YAClB,IAAI,CAAC,OAAO,CAAC,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,KAAK,EAAE,IAAI,CAAC,CAAC;QACzC,CAAC;IACF,CAAC;IACD,wDAAwD;IACxD,gBAAgB;QACf,IAAI,IAAI,CAAC,OAAO,EAAE,CAAC;YAClB,IAAI,CAAC,OAAO,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;QACjC,CAAC;IACF,CAAC;IACD,mCAAmC;IACnC,WAAW;QACV,WAAW,CAAC,IAAI,CAAC,CAAC;IACnB,CAAC;CACD;AAnCD,8BAmCC;AAAA,CAAC;AAGF,kFAAkF;AAClF,iFAAiF;AACjF,kFAAkF;AAClF,+CAA+C;AAC/C,SAAgB,wBAAwB,CAAC,aAAsB,EAAE,KAAY,EAAE,QAAiB;IAC5F,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,GAAG,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;QAC/E,MAAM,IAAI,KAAK,CAAC,kDAAkD,CAAC,CAAC;IACxE,CAAC;IACD,MAAM,GAAG,GAAG,mBAAmB,CAAC,KAAK,CAAC,CAAC;IACvC,IAAI,GAAG,GAAG,GAAG,GAAC,IAAI,GAAC,IAAI,IAAI,CAAC,aAAa,EAAE,CAAC;QACxC,MAAM,IAAI,cAAc,CAAC,6DAA6D,CAAC,CAAC;IAC5F,CAAC;IACD,IAAI,QAAQ,GAAG,CAAC,GAAC,GAAG,EAAE,CAAC;QACnB,MAAM,IAAI,KAAK,CAAC,8DAA8D,CAAC,CAAC;IACpF,CAAC;IACD,OAAO,GAAG,CAAC;AACf,CAAC;AAED,SAAgB,mBAAmB,CAAC,KAAY;IAC5C,OAAO,CAAC,CAAC,KAAK,CAAC,CAAC,CAAE,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,CAAE,GAAG,CAAC,CAAC,GAAG,GAAG,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,CAAE,GAAG,CAAC,CAAC,GAAG,GAAG,GAAG,GAAG,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,CAAE,GAAG,CAAC,CAAC,GAAG,GAAG,GAAG,GAAG,GAAG,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;AACjI,CAAC;AAED,2DAA2D;AAC3D,SAAS,YAAY,CAAC,CAAQ,EAAE,CAAQ;IACpC,IAAI,CAAC,CAAC,MAAM,KAAK,CAAC,CAAC,MAAM,EAAE,CAAC;QACxB,OAAO,KAAK,CAAC;IACjB,CAAC;IACD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QAChC,IAAI,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAChB,OAAO,KAAK,CAAC;QACjB,CAAC;IACL,CAAC;IACD,OAAO,IAAI,CAAC;AAChB,CAAC;AAED,SAAgB,wBAAwB,CAAC,aAAsB,EAAE,CAAS,EAAE,SAAkB;IAC1F,MAAM,OAAO,GAAiB,EAAE,CAAC;IACjC,KAAK,MAAM,CAAC,IAAI,EAAE,KAAK,CAAC,IAAI,MAAM,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC;QAC5C,IAAI,IAAI,KAAK,cAAc,EAAE,CAAC;YAC1B,aAAa,CAAC,aAAa,EAAE,KAAK,CAAC,CAAC;YACpC,SAAS;QACb,CAAC;QACP,MAAM,GAAG,GAAG,IAAI,sBAAsB,CAAC,SAAS,EAAE,OAAO,EAAE,aAAa,CAAC,CAAC;QACpE,GAAG,CAAC,gBAAgB,CAAC,IAAI,EAAE,KAAK,CAAC,CAAC;IACtC,CAAC;AACL,CAAC;AAED,SAAS,aAAa,CAAC,aAAsB,EAAE,KAAa;IACxD,KAAK,MAAM,CAAC,GAAG,EAAE,MAAM,CAAC,IAAI,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE,CAAC;QAChD,IAAI,OAAO,MAAM,KAAK,QAAQ,IAAI,CAAC,aAAa,EAAE,CAAC;YAC/C,MAAM,IAAI,cAAc,CAAC,sBAAsB,GAAG,qBAAqB,OAAO,MAAM,IAAI,CAAC,CAAC;QAC9F,CAAC;IACL,CAAC;AACL,CAAC;AAED,MAAM,sBAAsB;IACH;IAA6B;IAAgC;IAAuC;IAAmC;IAAmC;IAA/L,YAAqB,SAAkB,EAAW,OAAqB,EAAW,gBAAyB,KAAK,EAAS,YAAqB,KAAK,EAAS,YAAqB,KAAK,EAAS,mBAA4B,KAAK;QAA3M,cAAS,GAAT,SAAS,CAAS;QAAW,YAAO,GAAP,OAAO,CAAc;QAAW,kBAAa,GAAb,aAAa,CAAiB;QAAS,cAAS,GAAT,SAAS,CAAiB;QAAS,cAAS,GAAT,SAAS,CAAiB;QAAS,qBAAgB,GAAhB,gBAAgB,CAAiB;IAAG,CAAC;IACpO,KAAK,CAAC,GAAW,EAAE,GAAQ;QACvB,IAAI,OAAO,GAAG,KAAK,QAAQ,EAAE,CAAC;YAC1B,MAAM,IAAI,KAAK,CAAC,iBAAiB,GAAG,qBAAqB,OAAO,GAAG,IAAI,CAAC,CAAC;QAC7E,CAAC;QACD,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC;IAC1B,CAAC;IACD,KAAK,CAAC,GAAW,EAAE,GAAQ;QACvB,IAAI,CAAC,gBAAgB,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC;QAChC,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC;IAC1B,CAAC;IACD,YAAY,CAAC,GAAW,EAAE,GAAQ;QAC9B,IAAI,CAAC,gBAAgB,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC;QAChC,IAAI,CAAC,gBAAgB,GAAG,IAAI,CAAC;IACjC,CAAC;IACD,MAAM,CAAC,GAAW;QACd,OAAO,GAAG,KAAK,OAAO,IAAI,GAAG,KAAK,OAAO,IAAI,GAAG,KAAK,cAAc,CAAC;IACxE,CAAC;IACD,uBAAuB,CAAC,IAAgB;QACpC,IAAI,CAAC,IAAI,CAAC,SAAS,EAAE,CAAC;YAClB,MAAM,IAAI,KAAK,CAAC,eAAe,IAAI,6BAA6B,CAAC,CAAC;QACtE,CAAC;QACD,IAAI,CAAC,IAAI,CAAC,SAAS,EAAE,CAAC;YAClB,MAAM,IAAI,KAAK,CAAC,eAAe,IAAI,6BAA6B,CAAC,CAAC;QACtE,CAAC;QACD,IAAI,CAAC,IAAI,CAAC,gBAAgB,EAAE,CAAC;YACzB,MAAM,IAAI,KAAK,CAAC,eAAe,IAAI,oCAAoC,CAAC,CAAC;QAC7E,CAAC;IACL,CAAC;IACO,gBAAgB,CAAC,GAAW,EAAE,GAAW;QAC7C,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,GAAG,CAAC,EAAE,CAAC;YACtB,MAAM,IAAI,KAAK,CAAC,iBAAiB,GAAG,qBAAqB,OAAO,GAAG,GAAG,CAAC,CAAC;QAC5E,CAAC;QACD,KAAK,MAAM,CAAC,IAAI,GAAG,EAAE,CAAC;YAClB,IAAI,OAAO,CAAC,KAAK,QAAQ,IAAI,CAAC,IAAI,CAAC,aAAa,EAAE,CAAC;gBAC/C,MAAM,IAAI,cAAc,CAAC,iBAAiB,GAAG,gCAAgC,OAAO,CAAC,GAAG,CAAC,CAAC;YAC9F,CAAC;QACL,CAAC;IACL,CAAC;IACO,gBAAgB,CAAC,GAAW,EAAE,GAAW;QACnD,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,GAAG,CAAC,IAAI,GAAG,CAAC,MAAM,KAAK,CAAC,EAAE,CAAC;YAC7C,MAAM,IAAI,KAAK,CAAC,wBAAwB,GAAG,mCAAmC,OAAO,GAAG,GAAG,CAAC,CAAC;QAC9F,CAAC;QACD,IAAI,CAAC,qBAAqB,CAAC,GAAG,EAAE,GAAiB,CAAC,CAAC;QAC7C,KAAK,MAAM,CAAC,KAAK,EAAE,GAAG,CAAC,IAAI,IAAI,CAAC,OAAO,EAAE,CAAC;YACtC,IAAI,CAAC,kBAAkB,CAAC,GAAG,EAAE,KAAK,EAAE,GAAG,EAAE,GAAiB,CAAC,CAAC;QAChE,CAAC;QACD,IAAI,CAAC,OAAO,CAAC,IAAI,CAAC,GAAiB,CAAC,CAAC;IACzC,CAAC;IACJ,qBAAqB,CAAC,GAAW,EAAE,GAAe;QACjD,IAAI,GAAG,GAAiB,IAAI,CAAC;QAC7B,IAAI,CAAC,WAAW,CAAC,GAAG,EAAE,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,IAAI,CAAC,aAAa,EAAE,CAAC;YAC9D,GAAG,GAAG,IAAI,cAAc,CAAC,wBAAwB,GAAG,qBAAqB,GAAG,CAAC,CAAC,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC,WAAW,IAAI,CAAC,SAAS,IAAI,CAAC,CAAC;QACzH,CAAC;QACD,IAAI,GAAG,IAAI,CAAC,IAAI,CAAC,aAAa,EAAE,CAAC;YAChC,MAAM,GAAG,CAAC;QACX,CAAC;IACF,CAAC;IACD,kBAAkB,CAAC,GAAW,EAAE,KAAc,EAAE,GAAY,EAAE,GAAe;QAC5E,IAAI,GAAG,GAAiB,IAAI,CAAC;QAC7B,IAAI,GAAG,CAAC,CAAC,CAAC,GAAG,KAAK,IAAI,GAAG,CAAC,CAAC,CAAC,IAAI,GAAG,EAAE,CAAC;YACrC,GAAG,GAAG,IAAI,cAAc,CAAC,4BAA4B,GAAG,kCAAkC,GAAG,CAAC,CAAC,CAAC,MAAM,KAAK,OAAO,GAAG,CAAC,CAAC,CAAC,OAAO,GAAG,IAAI,CAAC,CAAC;QACzI,CAAC;QACD,IAAI,GAAG,CAAC,CAAC,CAAC,IAAI,KAAK,IAAI,GAAG,CAAC,CAAC,CAAC,GAAG,GAAG,EAAE,CAAC;YACrC,GAAG,GAAG,IAAI,cAAc,CAAC,8BAA8B,GAAG,kCAAkC,GAAG,CAAC,CAAC,CAAC,OAAO,KAAK,OAAO,GAAG,CAAC,CAAC,CAAC,MAAM,GAAG,IAAI,CAAC,CAAC;QAC3I,CAAC;QACD,IAAI,GAAG,IAAI,CAAC,IAAI,CAAC,aAAa,EAAE,CAAC;YAChC,MAAM,GAAG,CAAC;QACX,CAAC;IACF,CAAC;IACD,gBAAgB,CAAC,IAAgB,EAAE,KAAa;QAC/C,KAAK,MAAM,CAAC,GAAG,EAAE,GAAG,CAAC,IAAI,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE,CAAC;YAChD,IAAI,CAAC,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,CAAC;gBACvB,IAAI,CAAC,IAAI,CAAC,aAAa,EAAE,CAAC;oBACzB,MAAM,IAAI,cAAc,CAAC,+BAA+B,IAAI,yBAAyB,GAAG,IAAI,CAAC,CAAC;gBAC/F,CAAC;gBACD,SAAS;YACV,CAAC;YACA,IAAI,CAAC,GAAmC,CAAqC,CAAC,IAAI,EAAE,GAAG,CAAC,CAAC;QAC3F,CAAC;QACD,IAAI,CAAC,uBAAuB,CAAC,IAAI,CAAC,CAAC;IACpC,CAAC;CACD;AAED,SAAS,WAAW,CAAC,IAAgB,EAAE,SAAkB;IACxD,OAAO,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,SAAS,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,SAAS,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,IAAI,CAAC,CAAC,CAAC,CAAC;AAC3G,CAAC;AAID,SAAgB,WAAW,CAAC,KAAwB,EAAE,MAAe,EAAE,KAAa;IACnF,IAAI,KAAK,YAAY,SAAS,EAAE,CAAC;QAChC,YAAY,CAAC,KAAK,CAAC,IAAI,EAAE,KAAK,CAAC,KAAK,EAAE,KAAK,CAAC,MAAM,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC;IAClE,CAAC;SAAM,CAAC;QACP,YAAY,CAAC,SAAS,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,CAAC,CAAA;IAC9C,CAAC;AACF,CAAC;AAED,SAAS,YAAY,CAAC,IAAgB,EAAE,KAAY,EAAE,MAAe,EAAE,KAAa;IACnF,IAAI,CAAC,MAAM,IAAI,CAAC,KAAK,EAAE,CAAC;QACvB,MAAM,IAAI,KAAK,CAAC,eAAe,IAAI,4CAA4C,CAAC,CAAC;IAClF,CAAC;IACD,0BAA0B;IAC1B,IAAI,KAAK,CAAC,MAAM,KAAK,CAAC,EAAE,CAAC;QACxB,KAAK,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,WAAW;IACzB,CAAC;IACD,IAAI,GAAG,GAAG,CAAC,CAAC;IACZ,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QACvC,GAAG,IAAI,KAAK,CAAC,CAAC,CAAE,CAAC;IAClB,CAAC;IACD,MAAM,EAAE,GAAG,YAAY,CAAC,MAAM,CAAC,CAAC;IAChC,IAAI,KAAK,CAAC,MAAM,KAAK,GAAG,GAAG,EAAE,EAAE,CAAC;QAC/B,MAAM,IAAI,cAAc,CAAC,eAAe,IAAI,uBAAuB,KAAK,CAAC,MAAM,wBAAwB,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,gBAAgB,MAAM,gBAAgB,GAAG,GAAG,EAAE,SAAS,CAAC,CAAC;IACtL,CAAC;AACF,CAAC;AAED,2EAA2E;AAC3E,oDAAoD;AACpD,SAAgB,YAAY,CAAC,MAAc;IAC1C,MAAM,GAAG,GAAG,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;IAChC,IAAI,CAAC,GAAG,EAAE,CAAC;QACV,MAAM,IAAI,KAAK,CAAC,uBAAuB,MAAM,kDAAkD,CAAC,CAAC;IAClG,CAAC;IACD,MAAM,MAAM,GAAG,QAAQ,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,GAAG,CAAC,CAAC;IACxC,iEAAiE;IACjE,yEAAyE;IACzE,IAAI,CAAC,MAAM,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,MAAM,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC;QACvC,MAAM,IAAI,KAAK,CAAC,uBAAuB,MAAM,+CAA+C,CAAC,CAAC;IAC/F,CAAC;IACD,OAAO,MAAM,GAAC,GAAG,CAAC;AACnB,CAAC"} \ No newline at end of file diff --git a/package-lock.json b/package-lock.json new file mode 100644 index 0000000..3beb9f9 --- /dev/null +++ b/package-lock.json @@ -0,0 +1,30 @@ +{ + "name": "safetensors-parser", + "version": "0.0.1", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "safetensors-parser", + "version": "0.0.1", + "license": "MIT", + "devDependencies": { + "typescript": "^5.0.4" + } + }, + "node_modules/typescript": { + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.7.2.tgz", + "integrity": "sha512-i5t66RHxDvVN40HfDd1PsEThGNnlMCMT3jMUuoh9/0TaqWevNontacunWyN02LA9/fIbEWlcHZcgTKb9QoaLfg==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + } + } +} diff --git a/package.json b/package.json new file mode 100644 index 0000000..5d22c42 --- /dev/null +++ b/package.json @@ -0,0 +1,34 @@ +{ + "name": "@reve-ai/safetensors-parser", + "version": "0.0.1", + "description": "Parse and generate safetensor files for use with pytorch.", + "exports": { + ".": "./dist/src/index.js" + }, + "files": [ + "dist/src", + "package.json", + "README.md", + "LICENSE.md" + ], + "scripts": { + "build": "rm -rf dist; mkdir -p dist; tsc", + "test": "npm run build && node dist/test/index-test.js" + }, + "repository": { + "type": "git", + "url": "github.com/reve-ai/safetensors-parser" + }, + "devDependencies": { + "typescript": "^5.0.4" + }, + "keywords": [ + "huggingface", + "tensor", + "safetensor", + "pytorch", + "typescript" + ], + "author": "Jon Watte", + "license": "MIT" +} diff --git a/src/index.ts b/src/index.ts new file mode 100644 index 0000000..e634619 --- /dev/null +++ b/src/index.ts @@ -0,0 +1,470 @@ +// Bytes is the byte array type (Uint8Array) +export type Bytes = Uint8Array; +// Format is typically BF16 or FP16 or somesuch, but it's not really checked +// or used by the safetensors library itself -- it just deals with byte arrays. +export type Format = string; +// The only numbers we deal with are integers, so document that. +export type Integer = number; +// Offset pair is [start, end) of the array sub-range. +export type OffsetPair = [Integer, Integer]; +// Shape is an n-dimensional array of integers. +export type Shape = Integer[]; +// Tensor names are strings. +export type TensorName = string; + +export class IgnorableError extends Error { + constructor(message: string) { + super(message + " You can ignore this error with ignoreInvalid=true in parseSafeTensors()."); + this.name = "IgnorableError"; + } +} + +// Given a safetensors file (as a Uint8Array) return the TensorMap of the +// tensors and metadata stored in that file. If the file is not fully compliant +// with the spec (additional properties, etc) an error will be thrown. If you'd +// like to try using it anyway, you can pass ignoreInvalid=true. +export function parseSafeTensors(bytes: Bytes, ignoreInvalid?: boolean): TensorMap { + sanityCheckTensorsHeader(!!ignoreInvalid, bytes, bytes.length); + const hdrsize = unsafeGetHeaderSize(bytes); + const ret = new TensorMap(); + const j: Object = JSON.parse(new TextDecoder().decode(bytes.slice(8, 8+hdrsize))); + sanityCheckTensorsParsed(!!ignoreInvalid, j, bytes.length - 8 - hdrsize); + for (const [name, val] of Object.entries(j)) { + if (name === "__metadata__") { + ret.setAllMetadata(new Map(Object.entries(val as Object))); + } else { + const {dtype, shape, data_offsets} = val as { dtype: string, shape: Shape, data_offsets: OffsetPair }; + ret.addTensor(name, bytes.slice(8+hdrsize+data_offsets[0], 8+hdrsize+data_offsets[1]), dtype, shape); + } + } + return ret; +} + +// Given a TensorMap, generate the contents of a safetensors file that contains +// those tensors, and the added metadata. This will throw if you run out of memory, +// OR if you try to save a tensor with metadata that is too large. There's a +// maximum header/metadata limit of 100 MB total in the file format spec. +// You can write by passing in an empty Uint8Array, or you can pass in a function +// that will be called with chunks of data in order of writing. +// This implementation will align the start of each tensor on a multiple of 8 bytes +// if it's less than that. +export function saveSafeTensors(tensorMap: TensorMap): Uint8Array +export function saveSafeTensors(tensorMap: TensorMap, write: (data: Bytes) => void): undefined +export function saveSafeTensors(tensorMap: TensorMap, write?: (data: Bytes) => void): Uint8Array | undefined { + const save = new TensorSaver(tensorMap, write); + save.calcOffset(); + save.calcHeader(); + save.writeHeader(); + save.writeTensors(); + return save.ret; +} + +const padder = new Uint8Array([32, 32, 32, 32, 32, 32, 32, 32]); + +// Simple helper to format the safetensors file with header and data chunks. +class TensorSaver { + readonly hdr: { [key: string]: {dtype: string, shape: Shape, data_offsets: OffsetPair} } & { "__metadata__"?: {[key: string]: string} } = {}; + constructor(readonly tensorMap: TensorMap, private write?: (data: Bytes) => void) { + this.setMetadata(); + } + tes?: Array<[string, TensorRef]>; + offset?: Integer; + hdrBuf?: Uint8Array; + lenbuf?: Uint8Array; + hblen?: Integer; + ret?: Uint8Array; + setMetadata() { + const md = {} as {[key: string]: string}; + let toset = false; + for (const [name, val] of this.tensorMap.allMetadata.entries()) { + md[name] = val; + toset = true; + } + if (toset) { + this.hdr["__metadata__"] = md; + } + } + calcOffset(): Integer { + let offset = 0; + this.tes = Array.from(this.tensorMap.allTensors.entries()); + for (const [name, tensor] of this.tes) { + this.hdr[name] = { + dtype: tensor.format, + shape: tensor.shape, + data_offsets: [offset, offset + tensor.bytes.length] + }; + offset += (tensor.bytes.length + 7) & ~7; + } + this.offset = offset; + return offset; + } + calcHeader(): Integer { + const hdrBuf = (new TextEncoder()).encode(JSON.stringify(this.hdr)); + const hblen = (hdrBuf.length + 7) & ~7; + if (hblen > 100*1024*1024) { + throw new Error("The metadata is too large to be saved in a safetensors file."); + } + this.lenbuf = new Uint8Array([hblen & 0xff, (hblen >> 8) & 0xff, (hblen >> 16) & 0xff, (hblen >> 24) & 0xff, 0, 0, 0, 0]); + this.hdrBuf = hdrBuf; + this.hblen = hblen; + return hblen; + } + writeHeader() { + if (!this.write) { + this.ret = new Uint8Array(this.offset! +this. hblen! + 8); + let wrote = 0; + this.write = (data: Bytes) => { + this.ret!.set(data, wrote); + wrote += data.length; + }; + } + this.write(this.lenbuf!); + this.write(this.hdrBuf!); + if (this.hblen! > this.hdrBuf!.length) { + this.write(padder.slice(0, this.hblen! - this.hdrBuf!.length)); + } + } + writeTensors() { + let written = 0; + for (const [name, tensor] of this.tes!) { + this.write!(tensor.bytes); + written += tensor.bytes.length; + if (tensor.bytes.length & 7) { + this.write!(padder.slice(0, 8 - (tensor.bytes.length & 7))); + written += 8 - (tensor.bytes.length & 7); + } + const end: Integer = (this.hdr[name] as {data_offsets: OffsetPair})["data_offsets"][1]; + if (((end + 7) & ~7) !== written) { + throw new Error(`Internal tensor alignment problem: "${name}": ${end} !== ${written}`); + } + } + } +} + +// TensorMap is a collection of possibly multiple tensors. It has either been manually +// constructed, or loaded from a safetensors file. It can be used to return data (in +// little endian format) or to save the tensors to a safetensors formatted file. Note +// that tensor refs are not cloned -- if you mutate the underlying data, you will +// indirectly mutate the TensorMap, too! +export class TensorMap { + refs: Map = new Map(); + metadata: Map = new Map(); + constructor() {} + // If the tensor exists in the map, return it, else return undefined. + getTensor(name: TensorName): TensorRef | undefined { + return this.refs.get(name); + } + + // Add a new tensor to the map. If another tensor with the same name already exists, + // fail. Use setTensor() if you don't care. + addTensor(name: TensorName, bytes: Bytes, format: Format, shape: Shape): TensorRef + addTensor(name: TensorName, tensor: TensorRef): TensorRef + addTensor(name: TensorName, data: TensorRef | Bytes, format?: Format, shape?: Shape): TensorRef { + if (this.refs.get(name)) { + throw new Error(`The name "${name}" already exists in the TensorMap.`); + } + if (data instanceof TensorRef) { + return this._setTensor(name, data); + } + if (!format || !shape) { + throw new Error(`You must provide format and shape when adding tensor "${name}".`); + } else { + return this._setTensor(name, new TensorRef(name, data, format, shape)); + } + } + + // Set this name to be this tensor, no matter whether it exists or not. Remove + // any previous tensor of the same name (which may be the same value!) + setTensor(name: TensorName, tensor: TensorRef): TensorRef { + this.removeTensor(name); + return this.addTensor(name, tensor); + } + + // If the tensor exists, then return it. Otherwise, create it with the factory + // function, add it to the map, and return it. + getOrMakeTensor(name: TensorName, factory: () => TensorRef): TensorRef { + let ten = this.refs.get(name); + if (ten) { + return ten; + } + ten = factory(); + this.refs.set(name, ten); + return ten; + } + + // If some metadata value exists, return it, else return undefined. + getMetaValue(name: TensorName): string | undefined { + return this.metadata.get(name); + } + + setMetaValue(name: TensorName, value: string) { + this.metadata.set(name, value); + } + + get allMetadata(): Map { + return this.metadata; + } + + get allTensors(): Map { + return this.refs; + } + + setAllMetadata(metadata: Map) { + this.metadata = metadata; + } + + // Remove a tensor from the map. It's OK if it doesn't exist. + removeTensor(name: TensorName | TensorRef) { + let ten: TensorRef | undefined; + if (typeof name === "string") { + ten = this.refs.get(name); + if (!ten) { + return; + } + } else { + ten = name; + } + if (ten.parent !== this) { + throw new Error("You can only remove a tensor from the TensorMap that it belongs to."); + } + delete ten._parent; + this.refs.delete(ten.name); + } + + private _setTensor(name: TensorName, tensor: TensorRef): TensorRef { + if (tensor.parent && tensor.parent !== this) { + throw new Error("You can only add a tensor to a TensorMap that it doesn't already belong to."); + } + this.refs.set(name, tensor); + tensor._parent = this; + return tensor; + } +}; + +// TensorRef is a specific tensor within a safetensors archive. You can make one free +// standing, or get one from an archive. +export class TensorRef { + // Make a new tensorref referencing the given data. The new tensorref is not parented. + // Note that the shape and format are not sanity checked! Hopefully the bytes are accurate. + constructor(name: TensorName, readonly bytes: Bytes, readonly format: Format, readonly shape: Shape) { + this._name = name; + } + // A tensorref may have a parent tensor map, or may be freestanding. + get parent(): TensorMap | undefined { return this._parent; } + private _name: TensorName; + _parent?: TensorMap; + // If you mutate the name of a tensor in a tensor map, it will rename the tensor + // in the map. It's an error to rename to a name that already exists in the map. + get name(): string { return this._name; } + set name(val: string) { + if (this._parent) { + if (this._parent.refs.get(val)) { + throw new Error(`The name ${val} already exists in the parent TensorMap.`); + } + this._parent.refs.delete(this._name); + } + this._name = val; + if (this._parent) { + this._parent.refs.set(this._name, this); + } + } + // If the tensor is parented, remove it from the parent. + removeIfParented() { + if (this._parent) { + this._parent.removeTensor(this); + } + } + // Check whether the tensor is sane + sanityCheck() { + sanityCheck(this); + } +}; + + +// Given the bytes of a safetensors file, verify that the header seems legitimate. +// Note that you only need 10 bytes to check this. The value returned is the size +// of the JSON chunk that starts at offset 8, so if the total file size is smaller +// than 8+return, the file is likely truncated. +export function sanityCheckTensorsHeader(ignoreInvalid: boolean, bytes: Bytes, filesize: Integer): Integer { + if (!arrayCompare(Array.from(bytes.slice(4,9)), [0, 0, 0, 0, "{".charCodeAt(0)])) { + throw new Error("The file header is not a valid safetensors file."); + } + const val = unsafeGetHeaderSize(bytes); + if (val > 100*1024*1024 && !ignoreInvalid) { + throw new IgnorableError("The file header is too long to be a valid safetensors file."); + } + if (filesize < 8+val) { + throw new Error("The safetensors file seems truncated or otherwise not valid."); + } + return val; +} + +export function unsafeGetHeaderSize(bytes: Bytes): Integer { + return ((bytes[0]! | 0) + ((bytes[1]! | 0) * 256) + ((bytes[2]! | 0) * 256 * 256) + ((bytes[3]! | 0) * 256 * 256 * 256)) | 0; +} + +// Return true if each element compares equal in the arrays +function arrayCompare(a: Shape, b: Shape): boolean { + if (a.length !== b.length) { + return false; + } + for (let i = 0; i < a.length; i++) { + if (a[i] !== b[i]) { + return false; + } + } + return true; +} + +export function sanityCheckTensorsParsed(ignoreInvalid: boolean, j: Object, chunksize: Integer): void { + const covered: OffsetPair[] = []; + for (const [name, value] of Object.entries(j)) { + if (name === "__metadata__") { + checkMetadata(ignoreInvalid, value); + continue; + } + const tac = new TensorAttributeChecker(chunksize, covered, ignoreInvalid); + tac.checkTensorValue(name, value); + } +} + +function checkMetadata(ignoreInvalid: boolean, value: Object): void { + for (const [key, strval] of Object.entries(value)) { + if (typeof strval !== "string" && !ignoreInvalid) { + throw new IgnorableError(`The metadata value ${key} is not a string (${typeof strval}).`); + } + } +} + +class TensorAttributeChecker { + constructor(readonly chunksize: Integer, readonly covered: OffsetPair[], readonly ignoreInvalid: boolean = false, public has_dtype: boolean = false, public has_shape: boolean = false, public has_data_offsets: boolean = false) {} + dtype(key: string, val: any) { + if (typeof val !== "string") { + throw new Error(`The dtype for ${key} is not a string (${typeof val}).`); + } + this.has_dtype = true; + } + shape(key: string, val: any) { + this.checkTensorShape(key, val); + this.has_shape = true; + } + data_offsets(key: string, val: any) { + this.checkDataOffsets(key, val); + this.has_data_offsets = true; + } + attrOk(key: string): boolean { + return key === "dtype" || key === "shape" || key === "data_offsets"; + } + assertAttributePresence(name: TensorName) { + if (!this.has_dtype) { + throw new Error(`The tensor "${name}" is missing the dtype key.`); + } + if (!this.has_shape) { + throw new Error(`The tensor "${name}" is missing the shape key.`); + } + if (!this.has_data_offsets) { + throw new Error(`The tensor "${name}" is missing the data_offsets key.`); + } + } + private checkTensorShape(key: string, val: Object): void { + if (!Array.isArray(val)) { + throw new Error(`The shape for ${key} is not an array (${typeof val})`); + } + for (const v of val) { + if (typeof v !== "number" && !this.ignoreInvalid) { + throw new IgnorableError(`The shape for ${key} is not an array of numbers (${typeof v})`); + } + } + } + private checkDataOffsets(key: string, obj: Object): void { + if (!Array.isArray(obj) || obj.length !== 2) { + throw new Error(`The data_offsets for ${key} is not an array with length 2 (${typeof obj})`); + } + this.checkDataOffsetBasics(key, obj as OffsetPair); + for (const [start, end] of this.covered) { + this.checkTensorOverlap(key, start, end, obj as OffsetPair); + } + this.covered.push(obj as OffsetPair); + } + checkDataOffsetBasics(key: string, val: OffsetPair) { + let err: Error | null = null; + if (!pairInRange(val, this.chunksize) && !this.ignoreInvalid) { + err = new IgnorableError(`The data_offsets for ${key} is out of range (${val[0]}-${val[1]} versus ${this.chunksize}).`); + } + if (err && !this.ignoreInvalid) { + throw err; + } + } + checkTensorOverlap(key: string, start: Integer, end: Integer, val: OffsetPair): void { + let err: Error | null = null; + if (val[1] > start && val[1] <= end) { + err = new IgnorableError(`The data_offsets end for ${key} overlaps with another tensor (${val[1]} > ${start} && ${val[1]} <= ${end}).`); + } + if (val[0] >= start && val[0] < end) { + err = new IgnorableError(`The data_offsets start for ${key} overlaps with another tensor (${val[0]} >= ${start} && ${val[0]} < ${end}).`); + } + if (err && !this.ignoreInvalid) { + throw err; + } + } + checkTensorValue(name: TensorName, value: Object): void { + for (const [key, val] of Object.entries(value)) { + if (!this.attrOk(key)) { + if (!this.ignoreInvalid) { + throw new IgnorableError(`The tensor description for "${name}" has an invalid key "${key}".`); + } + continue; + } + (this[key as keyof TensorAttributeChecker] as (s: string, u: unknown) => void)(name, val); + } + this.assertAttributePresence(name); + } +} + +function pairInRange(pair: OffsetPair, chunksize: Integer): boolean { + return pair[0] >= 0 && pair[0] <= chunksize && pair[1] >= 0 && pair[1] <= chunksize && pair[0] <= pair[1]; +} + +export function sanityCheck(tensor: TensorRef): void +export function sanityCheck(bytes: Bytes, format: Format, shape: Shape): void +export function sanityCheck(bytes: TensorRef | Bytes, format?: Format, shape?: Shape): void { + if (bytes instanceof TensorRef) { + _sanityCheck(bytes.name, bytes.bytes, bytes.format, bytes.shape); + } else { + _sanityCheck("created", bytes, format, shape) + } +} + +function _sanityCheck(name: TensorName, bytes: Bytes, format?: Format, shape?: Shape) { + if (!format || !shape) { + throw new Error(`The tensor "${name}" is missing format and shape information.`); + } + // Sanity check the shape. + if (shape.length === 0) { + shape = [1]; // a scalar + } + let num = 1; + for (let i = 0; i < shape.length; i++) { + num *= shape[i]!; + } + const fl = formatLength(format); + if (bytes.length !== num * fl) { + throw new IgnorableError(`The tensor "${name}" is the wrong size ${bytes.length} bytes for its shape ${JSON.stringify(shape)} and format "${format}": should be ${num * fl} bytes.`); + } +} + +// Get the byte length of a given format constant. This may be a fractional +// value if the format is something like 4-bit data. +export function formatLength(format: Format): Integer { + const num = format.match(/\d+/); + if (!num) { + throw new Error(`The element format "${format}" is not a valid format (must contain bit size).`); + } + const parsed = parseInt(num[0], 10) | 0; + // This is an attempt to support sub-byte precision. It will work + // for many cases, but doesn't deal with rounding-up of un-aligned sizes. + if ((parsed | 0) & ((parsed | 0) - 1)) { + throw new Error(`The element format "${format}" is not a valid format (must be power of 2).`); + } + return parsed/8.0; +} diff --git a/test/index-test.ts b/test/index-test.ts new file mode 100644 index 0000000..b2160c8 --- /dev/null +++ b/test/index-test.ts @@ -0,0 +1,111 @@ +import { parseSafeTensors, saveSafeTensors, TensorMap, TensorRef } from "@reve-ai/safetensors-parser"; + +const bytes = new Uint8Array([ + // header length: 64 + 64, 0, 0, 0, 0, 0, 0, 0, + // header (whitespace padded up to 64 bytes) + 0x7b, 0x22, 0x74, 0x65, 0x6e, 0x22, 0x3a, 0x7b, 0x22, 0x64, 0x74, 0x79, + 0x70, 0x65, 0x22, 0x3a, 0x22, 0x42, 0x46, 0x31, 0x36, 0x22, 0x2c, 0x22, + 0x73, 0x68, 0x61, 0x70, 0x65, 0x22, 0x3a, 0x5b, 0x32, 0x2c, 0x33, 0x5d, + 0x2c, 0x22, 0x64, 0x61, 0x74, 0x61, 0x5f, 0x6f, 0x66, 0x66, 0x73, 0x65, + 0x74, 0x73, 0x22, 0x3a, 0x5b, 0x30, 0x2c, 0x31, 0x32, 0x5d, 0x7d, 0x7d, + 0x20, 0x20, 0x20, 0x20, + // tensor data + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, + 0x20, 0x20, 0x20, 0x20, +]); + +const stuff = parseSafeTensors(bytes); +const ten = stuff.getTensor("ten"); +if (!ten) { + throw new Error("No tensor named ten"); +} +if (ten.format !== "BF16") { + throw new Error(`The tensor format is not BF16: ${ten.format}`); +} +if (ten.shape.length !== 2) { + throw new Error(`The tensor shape is not length 2: ${ten.shape.length}`); +} +if (ten.shape[0] !== 2) { + throw new Error(`The tensor shape is not [2, 3]: ${JSON.stringify(ten.shape)}`); +} +if (ten.shape[1] !== 3) { + throw new Error(`The tensor shape is not [2, 3]: ${JSON.stringify(ten.shape)}`); +} +if (ten.bytes.length !== 12) { + throw new Error(`The tensor bytes is not length 12: ${ten.bytes.length}`); +} +if (JSON.stringify(Array.from(ten.bytes)) !== JSON.stringify([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])) { + throw new Error(`The tensor bytes is not [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]: ${JSON.stringify(Array.from(ten.bytes))}`); +} + +const parsed = saveSafeTensors(stuff); + +const a1 = Array.from(bytes); +const a2 = Array.from(parsed); + +if (a1.length !== a2.length) { + throw new Error(`The lengths are different: ${a1.length} !== ${a2.length}`); +} +for (let i = 0; i < a1.length; i++) { + if (a1[i] !== a2[i]) { + // This check is a little brittle, because there's no guarantee + // that the JSON.stringify() will produce the same output as we + // read in the JSON.parse(). + throw new Error(`Byte ${i} is different: ${a1[i]} !== ${a2[i]}`); + } +} + +const tm1 = new TensorMap(); +tm1.addTensor("snark", new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), "UINT8", [1, 3, 4]); +tm1.addTensor("snork", new Uint8Array([0, 1, 1, 0]), "UINT8", [4]); +tm1.addTensor("snurk", new TensorRef("snurk", new Uint8Array([0, 1, 1, 0]), "UINT8", [2, 2])); +tm1.setMetaValue("foo", "bar"); +tm1.allMetadata.set("baz", "quux"); + +const tm2 = parseSafeTensors(saveSafeTensors(tm1)); + +if (tm1.allTensors.size !== tm2.allTensors.size) { + throw new Error(`The tensor maps have different sizes ${tm1.allTensors.size} !== ${tm2.allTensors.size}`); +} +const tm1td = Array.from(tm1.allTensors.entries()).sort(); +const tm2td = Array.from(tm2.allTensors.entries()).sort(); +if (tm1td.length !== tm2td.length) { + throw new Error(`The tensor maps have different lengths ${tm1td.length} !== ${tm2td.length}`); +} + for (let i = 0; i < tm1td.length; i++) { + if (tm1td[i]![0]! !== tm2td[i]![0]!) { + throw new Error(`The tensor maps have different keys: ${tm1td[i]![0]!} !== ${tm2td[i]![0]!}`); + } + if (tm1td[i]![1]!.name !== tm2td[i]![1]!.name) { + throw new Error(`The tensor maps have different names: ${tm1td[i]![1]!.name} !== ${tm2td[i]![1]!.name}`); + } + if (tm1td[i]![1]!.format !== tm2td[i]![1]!.format) { + throw new Error(`The tensor maps have different formats: ${tm1td[i]![1]!.format} !== ${tm2td[i]![1]!.format}`); + } + if (JSON.stringify(tm1td[i]![1]!.shape) !== JSON.stringify(tm2td[i]![1]!.shape)) { + throw new Error(`The tensor maps have different shapes: ${JSON.stringify(tm1td[i]![1]!.shape)} !== ${JSON.stringify(tm2td[i]![1]!.shape)}`); + } + if (JSON.stringify(Array.from(tm1td[i]![1]!.bytes)) !== JSON.stringify(Array.from(tm2td[i]![1]!.bytes))) { + throw new Error(`The tensor maps have different bytes: ${JSON.stringify(Array.from(tm1td[i]![1]!.bytes))} !== ${JSON.stringify(Array.from(tm2td[i]![1]!.bytes))}`); + } +} + +if (tm1.allMetadata.size !== tm2.allMetadata.size) { + throw new Error(`The metadata maps have different sizes ${tm1.allMetadata.size} !== ${tm2.allMetadata.size}`); +} +const tm1md = Array.from(tm1.allMetadata.entries()).sort(); +const tm2md = Array.from(tm2.allMetadata.entries()).sort(); +if (tm1md.length !== tm2md.length) { + throw new Error(`The metadata maps have different lengths ${tm1md.length} !== ${tm2md.length}`); +} +for (let i = 0; i < tm1md.length; i++) { + if (tm1md[i]![0]! !== tm2md[i]![0]!) { + throw new Error(`The metadata maps have different keys: ${tm1md[i]![0]!} !== ${tm2md[i]![0]!}`); + } + if (tm1md[i]![1]! !== tm2md[i]![1]!) { + throw new Error(`The metadata maps have different values: ${tm1md[i]![1]!} !== ${tm2md[i]![1]!}`); + } +} + +console.log("All tests passed."); diff --git a/tsconfig.json b/tsconfig.json new file mode 100644 index 0000000..c6dd37e --- /dev/null +++ b/tsconfig.json @@ -0,0 +1,32 @@ +{ + "compilerOptions": { + "allowImportingTsExtensions": false, + "composite": true, + "declaration": true, + "declarationMap": true, + "esModuleInterop": true, + "isolatedModules": true, + "lib": ["es2023", "WebWorker"], + "module": "nodenext", + "moduleResolution": "nodenext", + "noEmit": false, + "noErrorTruncation": true, + "noFallthroughCasesInSwitch": true, + "noImplicitOverride": true, + "noImplicitReturns": true, + "noUncheckedIndexedAccess": true, + "noUnusedLocals": true, + "outDir": "dist", + "pretty": true, + "resolveJsonModule": false, + "rootDirs": ["src", "test"], + "skipLibCheck": true, + "sourceMap": true, + "strict": true, + "target": "es2023", + "tsBuildInfoFile": "dist/tsconfig.tsbuildinfo", + "types": [], + "verbatimModuleSyntax": false + }, + "include": ["./src", "./test"] +}