-
Notifications
You must be signed in to change notification settings - Fork 214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
gguf: better type usage #655
Changes from 6 commits
e061017
9a9e771
6d704bc
8767726
c2afbdc
5f547dd
31bac8b
74e8cfd
8c1bce0
2e62e41
a2250d3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import { describe, it } from "vitest"; | ||
import type { gguf } from "./gguf"; | ||
import type { GGUFMetadata, GGUFParseOutput, GGUFType } from "./types"; | ||
|
||
describe("gguf-types", () => { | ||
it("gguf() type can be casted between STRICT and NON_STRICT (at compile time)", async () => { | ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
const result: Awaited<ReturnType<typeof gguf>> = null as any; | ||
const strictType = result as GGUFParseOutput<GGUFType.STRICT>; | ||
// @ts-expect-error because the key "abc" does not exist | ||
strictType.metadata.abc = 123; | ||
const nonStrictType = result as GGUFParseOutput<GGUFType.NON_STRICT>; | ||
nonStrictType.metadata.abc = 123; // PASS, because it can be anything | ||
// @ts-expect-error because ArrayBuffer is not a MetadataValue | ||
nonStrictType.metadata.fff = ArrayBuffer; | ||
}); | ||
|
||
it("GGUFType.NON_STRICT should be correct (at compile time)", async () => { | ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
const model: GGUFMetadata<GGUFType.NON_STRICT> = null as any; | ||
model.kv_count = 123n; | ||
model.abc = 456; // PASS, because it can be anything | ||
}); | ||
|
||
it("GGUFType.STRICT should be correct (at compile time)", async () => { | ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
const model: GGUFMetadata<GGUFType.STRICT> = null as any; | ||
|
||
if (model["general.architecture"] === "whisper") { | ||
model["encoder.whisper.block_count"] = 0; | ||
// @ts-expect-error because it must be a number | ||
model["encoder.whisper.block_count"] = "abc"; | ||
} | ||
|
||
if (model["tokenizer.ggml.model"] === undefined) { | ||
// @ts-expect-error because it's undefined | ||
model["tokenizer.ggml.eos_token_id"] = 1; | ||
} | ||
if (model["tokenizer.ggml.model"] === "gpt2") { | ||
// @ts-expect-error because it must be a number | ||
model["tokenizer.ggml.eos_token_id"] = undefined; | ||
model["tokenizer.ggml.eos_token_id"] = 1; | ||
} | ||
|
||
if (model["general.architecture"] === "mamba") { | ||
model["mamba.ssm.conv_kernel"] = 0; | ||
// @ts-expect-error because it must be a number | ||
model["mamba.ssm.conv_kernel"] = "abc"; | ||
} | ||
if (model["general.architecture"] === "llama") { | ||
// @ts-expect-error llama does not have ssm.* keys | ||
model["mamba.ssm.conv_kernel"] = 0; | ||
} | ||
}); | ||
}); |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,21 +50,32 @@ export enum GGUFValueType { | |
const ARCHITECTURES = [...LLM_ARCHITECTURES, "rwkv", "whisper"] as const; | ||
export type Architecture = (typeof ARCHITECTURES)[number]; | ||
|
||
interface General { | ||
"general.architecture": Architecture; | ||
"general.name": string; | ||
"general.file_type": number; | ||
"general.quantization_version": number; | ||
export interface GGUFGeneralInfo<TArchitecture extends Architecture> { | ||
"general.architecture": TArchitecture; | ||
"general.name"?: string; | ||
"general.file_type"?: number; | ||
"general.quantization_version"?: number; | ||
} | ||
|
||
type ModelMetadata = Whisper | RWKV | TransformerLLM; | ||
interface NoModelMetadata { | ||
"general.architecture"?: undefined; | ||
} | ||
|
||
export type ModelBase< | ||
TArchitecture extends | ||
| Architecture | ||
| `encoder.${Extract<Architecture, "whisper">}` | ||
| `decoder.${Extract<Architecture, "whisper">}`, | ||
> = { [K in `${TArchitecture}.layer_count`]: number } & { [K in `${TArchitecture}.feed_forward_length`]: number } & { | ||
[K in `${TArchitecture}.context_length`]: number; | ||
} & { [K in `${TArchitecture}.embedding_length`]: number } & { [K in `${TArchitecture}.block_count`]: number }; | ||
> = Record< | ||
| `${TArchitecture}.context_length` | ||
| `${TArchitecture}.block_count` | ||
| `${TArchitecture}.embedding_length` | ||
| `${TArchitecture}.feed_forward_length`, | ||
number | ||
>; | ||
|
||
/// Tokenizer | ||
|
||
type TokenizerModel = "no_vocab" | "llama" | "gpt2" | "bert"; | ||
interface Tokenizer { | ||
|
@@ -75,21 +86,43 @@ interface Tokenizer { | |
"tokenizer.ggml.bos_token_id": number; | ||
"tokenizer.ggml.eos_token_id": number; | ||
"tokenizer.ggml.add_bos_token": boolean; | ||
"tokenizer.chat_template": string; | ||
"tokenizer.chat_template"?: string; | ||
} | ||
interface NoTokenizer { | ||
"tokenizer.ggml.model"?: undefined; | ||
} | ||
|
||
/// Models outside of llama.cpp: "rwkv" and "whisper" | ||
|
||
export type RWKV = ModelBase<"rwkv"> & { "rwkv.architecture_version": number }; | ||
export type LLM = TransformerLLM | RWKV; | ||
export type Whisper = ModelBase<"encoder.whisper"> & ModelBase<"decoder.whisper">; | ||
export type Model = (LLM | Whisper) & Partial<Tokenizer>; | ||
export type RWKV = GGUFGeneralInfo<"rwkv"> & | ||
ModelBase<"rwkv"> & { | ||
"rwkv.architecture_version": number; | ||
}; | ||
|
||
export type GGUFMetadata = { | ||
// TODO: whisper.cpp doesn't yet support gguf. This maybe changed in the future. | ||
export type Whisper = GGUFGeneralInfo<"whisper"> & | ||
ModelBase<"encoder.whisper"> & | ||
ModelBase<"decoder.whisper"> & { | ||
"whisper.encoder.mels_count": number; | ||
"whisper.encoder.attention.head_count": number; | ||
"whisper.decoder.attention.head_count": number; | ||
}; | ||
|
||
/// Types for parse output | ||
|
||
export enum GGUFType { | ||
STRICT, | ||
NON_STRICT, | ||
} | ||
|
||
export type GGUFMetadata<TGGUFType extends GGUFType = GGUFType.STRICT> = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should probably switch to something like https://github.com/sindresorhus/type-fest/blob/main/source/except.d.ts interface GGUFMetadataOptions {
/**
* ...
*
* @default true
*/
strict: boolean;
}
export GGUFMetadata<Options extends GGUFMetadataOptions = { strict: true}> {
... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea 👍 I've never thought about this before. Implemented in 31bac8b |
||
version: Version; | ||
tensor_count: bigint; | ||
kv_count: bigint; | ||
} & Partial<General> & | ||
Partial<Model> & | ||
Record<string, MetadataValue>; | ||
} & GGUFModelKV & | ||
(TGGUFType extends GGUFType.STRICT ? unknown : Record<string, MetadataValue>); | ||
|
||
export type GGUFModelKV = (NoModelMetadata | ModelMetadata) & (NoTokenizer | Tokenizer); | ||
|
||
export interface GGUFTensorInfo { | ||
name: string; | ||
|
@@ -99,7 +132,7 @@ export interface GGUFTensorInfo { | |
offset: bigint; | ||
} | ||
|
||
export interface GGUFParseOutput { | ||
metadata: GGUFMetadata; | ||
export interface GGUFParseOutput<TGGUFType extends GGUFType = GGUFType.STRICT> { | ||
metadata: GGUFMetadata<TGGUFType>; | ||
tensorInfos: GGUFTensorInfo[]; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's use
// @ts-expect-error
instead of// @ts-ignore
in general(no need for eslint-disable this way)
Here I think you can change
const metadata: GGUFMetadata
toconst metadata: GGUFMetadata<GGUFType.NON_STRICT>
to remove the error (not sure if it's the best fix)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in 31bac8b