From 37c40f7476383e83ec50aec1891b0f4d5e8898c1 Mon Sep 17 00:00:00 2001
From: kerthcet
Date: Tue, 21 Jan 2025 15:05:37 +0800
Subject: [PATCH] Support hostpath models

Signed-off-by: kerthcet
---
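Note: util.ParseURI is relied on but not changed by this patch. The new
dispatch assumes it splits "protocol://value" at the first "://" and
normalizes the protocol so it matches the upper-case constants ("OSS",
"OLLAMA", "HOST"); the exact signature may differ (the third return value is
discarded at the call site). A minimal, hypothetical stand-in sketch:

    package main

    import (
    	"fmt"
    	"strings"
    )

    // parseURI is a stand-in for util.ParseURI, sketching the behavior the
    // new switch below assumes: split at the first "://" and upper-case the
    // protocol so "host://..." matches the HostPath constant "HOST".
    func parseURI(uri string) (protocol, value string, ok bool) {
    	parts := strings.SplitN(uri, "://", 2)
    	if len(parts) != 2 {
    		return "", "", false
    	}
    	return strings.ToUpper(parts[0]), parts[1], true
    }

    func main() {
    	// host:///workspace/Qwen2-0.5B-Instruct -> HOST, /workspace/Qwen2-0.5B-Instruct
    	// (three slashes: two from "://", one leading the absolute host path)
    	fmt.Println(parseURI("host:///workspace/Qwen2-0.5B-Instruct"))
    	fmt.Println(parseURI("ollama://llama3.3")) // -> OLLAMA, llama3.3
    }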
 api/core/v1alpha1/model_types.go          |  6 +-
 config/crd/bases/llmaz.io_openmodels.yaml |  6 +-
 docs/examples/README.md                   |  5 ++
 docs/examples/hostpath/model.yaml         | 13 ++++
 docs/examples/hostpath/playground.yaml    |  8 +++
 .../model_source/modelsource.go           | 11 ++--
 pkg/controller_helper/model_source/uri.go | 59 +++++++++++++++----
 pkg/webhook/openmodel_webhook.go          |  5 +-
 test/integration/webhook/model_test.go    |  6 ++
 test/util/wrapper/playground.go           |  2 +-
 10 files changed, 98 insertions(+), 23 deletions(-)
 create mode 100644 docs/examples/hostpath/model.yaml
 create mode 100644 docs/examples/hostpath/playground.yaml
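Note: with this change, ModelPath() resolves per protocol roughly as follows.
The CONTAINER_MODEL_PATH prefix is assumed to be /workspace/models/ (its
definition is outside this diff, but the doc comment in uri.go uses that
value), and the host path comes from the new example manifest:

    host:///workspace/Qwen2-0.5B-Instruct   -> /workspace/Qwen2-0.5B-Instruct (host paths are mounted into the container at the same path, so they are returned as-is)
    oss://bucket.endpoint/models/model.gguf -> /workspace/models/model.gguf
    oss://bucket.endpoint/models/Qwen2-0.5B -> /workspace/models/models--Qwen2-0.5B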
diff --git a/api/core/v1alpha1/model_types.go b/api/core/v1alpha1/model_types.go
index 017edbf..b82f721 100644
--- a/api/core/v1alpha1/model_types.go
+++ b/api/core/v1alpha1/model_types.go
@@ -82,8 +82,10 @@ type ModelSource struct {
 	// ModelHub represents the model registry for model downloads.
 	// +optional
 	ModelHub *ModelHub `json:"modelHub,omitempty"`
-	// URI represents a various kinds of model sources following the uri protocol, e.g.:
-	// - OSS: oss://<bucket>.<endpoint>/<path-to-your-model>
+	// URI represents various kinds of model sources following the uri protocol, protocol://<address>, e.g.
+	// - oss://<bucket>.<endpoint>/<path-to-your-model>
+	// - ollama://llama3.3
+	// - host://<path-to-your-model>
 	//
 	// +optional
 	URI *URIProtocol `json:"uri,omitempty"`
diff --git a/config/crd/bases/llmaz.io_openmodels.yaml b/config/crd/bases/llmaz.io_openmodels.yaml
index 6dac63c..61c561f 100644
--- a/config/crd/bases/llmaz.io_openmodels.yaml
+++ b/config/crd/bases/llmaz.io_openmodels.yaml
@@ -155,8 +155,10 @@ spec:
                 type: object
               uri:
                 description: |-
-                  URI represents a various kinds of model sources following the uri protocol, e.g.:
-                  - OSS: oss://<bucket>.<endpoint>/<path-to-your-model>
+                  URI represents various kinds of model sources following the uri protocol, protocol://<address>, e.g.
+                  - oss://<bucket>.<endpoint>/<path-to-your-model>
+                  - ollama://llama3.3
+                  - host://<path-to-your-model>
                 type: string
             type: object
           required:
diff --git a/docs/examples/README.md b/docs/examples/README.md
index ce7ea8e..05fb7ee 100644
--- a/docs/examples/README.md
+++ b/docs/examples/README.md
@@ -13,6 +13,7 @@ We provide a set of examples to help you serve large language models, by default
 - [Deploy models via ollama](#ollama)
 - [Speculative Decoding with vLLM](#speculative-decoding-with-vllm)
 - [Deploy multi-host inference](#multi-host-inference)
+- [Deploy host models](#deploy-host-models)
 
 ### Deploy models from Huggingface
 
@@ -59,3 +60,7 @@ By default, we use [vLLM](https://github.com/vllm-project/vllm) as the inference
 ### Multi-Host Inference
 
 Model size is growing bigger and bigger, Llama 3.1 405B FP16 LLM requires more than 750 GB GPU for weights only, leaving kv cache unconsidered, even with 8 x H100 Nvidia GPUs, 80 GB size of HBM each, can not fit in a single host, requires a multi-host deployment, see [example](./multi-nodes/) here.
+
+### Deploy Host Models
+
+Models can be preloaded on the hosts in advance, which is especially useful for extremely big models, see the [example](./hostpath/) on serving local models.
diff --git a/docs/examples/hostpath/model.yaml b/docs/examples/hostpath/model.yaml
new file mode 100644
index 0000000..f5d6f54
--- /dev/null
+++ b/docs/examples/hostpath/model.yaml
@@ -0,0 +1,13 @@
+apiVersion: llmaz.io/v1alpha1
+kind: OpenModel
+metadata:
+  name: qwen2-0--5b-instruct
+spec:
+  familyName: qwen2
+  source:
+    uri: host:///workspace/Qwen2-0.5B-Instruct
+  inferenceConfig:
+    flavors:
+      - name: t4 # GPU type
+        requests:
+          nvidia.com/gpu: 1
diff --git a/docs/examples/hostpath/playground.yaml b/docs/examples/hostpath/playground.yaml
new file mode 100644
index 0000000..c7a3381
--- /dev/null
+++ b/docs/examples/hostpath/playground.yaml
@@ -0,0 +1,8 @@
+apiVersion: inference.llmaz.io/v1alpha1
+kind: Playground
+metadata:
+  name: qwen2-0--5b-instruct
+spec:
+  replicas: 1
+  modelClaim:
+    modelName: qwen2-0--5b-instruct
diff --git a/pkg/controller_helper/model_source/modelsource.go b/pkg/controller_helper/model_source/modelsource.go
index 18bd640..352733a 100644
--- a/pkg/controller_helper/model_source/modelsource.go
+++ b/pkg/controller_helper/model_source/modelsource.go
@@ -71,13 +71,16 @@ func NewModelSourceProvider(model *coreapi.OpenModel) ModelSourceProvider {
 
 	if model.Spec.Source.URI != nil {
 		// We'll validate the format in the webhook, so generally no error should happen here.
-		protocol, address, _ := util.ParseURI(string(*model.Spec.Source.URI))
-		provider := &URIProvider{modelName: model.Name, protocol: protocol, modelAddress: address}
+		protocol, value, _ := util.ParseURI(string(*model.Spec.Source.URI))
+		provider := &URIProvider{modelName: model.Name, protocol: protocol}
 
 		switch protocol {
 		case OSS:
-			provider.endpoint, provider.bucket, provider.modelPath, _ = util.ParseOSS(address)
-		case OLLAMA:
+			provider.endpoint, provider.bucket, provider.modelPath, _ = util.ParseOSS(value)
+		case HostPath:
+			provider.modelPath = value
+		case Ollama:
+			provider.modelPath = value
 		default:
 			// This should be validated at webhooks.
 			panic("protocol not supported")
diff --git a/pkg/controller_helper/model_source/uri.go b/pkg/controller_helper/model_source/uri.go
index fb8a049..fb2a6f9 100644
--- a/pkg/controller_helper/model_source/uri.go
+++ b/pkg/controller_helper/model_source/uri.go
@@ -26,22 +26,24 @@ import (
 var _ ModelSourceProvider = &URIProvider{}
 
 const (
-	OSS    = "OSS"
-	OLLAMA = "OLLAMA"
+	OSS      = "OSS"
+	Ollama   = "OLLAMA"
+	HostPath = "HOST"
 )
 
 type URIProvider struct {
-	modelName    string
-	protocol     string
-	bucket       string
-	endpoint     string
-	modelPath    string
-	modelAddress string
+	modelName string
+	protocol  string
+	bucket    string
+	endpoint  string
+	modelPath string
 }
 
 func (p *URIProvider) ModelName() string {
-	if p.protocol == OLLAMA {
-		return p.modelAddress
+	if p.protocol == Ollama {
+		// modelPath stores the ollama model name here,
+		// while modelName is the name of the Model CRD.
+		return p.modelPath
 	}
 	return p.modelName
 }
@@ -54,18 +56,51 @@ func (p *URIProvider) ModelName() string {
 // - uri: bucket.endpoint/modelPath/model.gguf
 //   modelPath: /workspace/models/model.gguf
 func (p *URIProvider) ModelPath() string {
+	if p.protocol == HostPath {
+		return p.modelPath
+	}
+
+	// The protocol is OSS.
 	splits := strings.Split(p.modelPath, "/")
 
-	if strings.Contains(p.modelPath, ".") {
+	if strings.Contains(p.modelPath, ".gguf") {
 		return CONTAINER_MODEL_PATH + splits[len(splits)-1]
 	}
 	return CONTAINER_MODEL_PATH + "models--" + splits[len(splits)-1]
 }
 
 func (p *URIProvider) InjectModelLoader(template *corev1.PodTemplateSpec, index int) {
-	if p.protocol == OLLAMA {
+	// No additional operations are needed for Ollama, models are loaded at runtime.
+	if p.protocol == Ollama {
 		return
 	}
+
+	if p.protocol == HostPath {
+		template.Spec.Volumes = append(template.Spec.Volumes, corev1.Volume{
+			Name: MODEL_VOLUME_NAME,
+			VolumeSource: corev1.VolumeSource{
+				HostPath: &corev1.HostPathVolumeSource{
+					Path: p.modelPath,
+				},
+			},
+		})
+
+		for i, container := range template.Spec.Containers {
+			// Only the model runner container mounts the model.
+			if container.Name == MODEL_RUNNER_CONTAINER_NAME {
+				template.Spec.Containers[i].VolumeMounts = append(template.Spec.Containers[i].VolumeMounts, corev1.VolumeMount{
+					Name:      MODEL_VOLUME_NAME,
+					MountPath: p.modelPath,
+					ReadOnly:  true,
+				})
+			}
+		}
+		return
+	}
+
+	// Other protocols (e.g. OSS) download the model with an init container.
+
 	initContainerName := MODEL_LOADER_CONTAINER_NAME
 	if index != 0 {
 		initContainerName += "-" + strconv.Itoa(index)
diff --git a/pkg/webhook/openmodel_webhook.go b/pkg/webhook/openmodel_webhook.go
index 6b188c5..9ede0ce 100644
--- a/pkg/webhook/openmodel_webhook.go
+++ b/pkg/webhook/openmodel_webhook.go
@@ -47,8 +47,9 @@ func SetupOpenModelWebhook(mgr ctrl.Manager) error {
 var _ webhook.CustomDefaulter = &OpenModelWebhook{}
 
 var SUPPORTED_OBJ_STORES = map[string]struct{}{
-	modelSource.OSS:    {},
-	modelSource.OLLAMA: {},
+	modelSource.OSS:      {},
+	modelSource.Ollama:   {},
+	modelSource.HostPath: {},
 }
 
 // Default implements webhook.Defaulter so a webhook will be registered for the type
diff --git a/test/integration/webhook/model_test.go b/test/integration/webhook/model_test.go
index 193fe6e..fe46fd3 100644
--- a/test/integration/webhook/model_test.go
+++ b/test/integration/webhook/model_test.go
@@ -105,6 +105,12 @@ var _ = ginkgo.Describe("model default and validation", func() {
 			},
 			failed: false,
 		}),
+		ginkgo.Entry("model creation with host protocol", &testValidatingCase{
+			model: func() *coreapi.OpenModel {
+				return wrapper.MakeModel("llama3-8b").FamilyName("llama3").ModelSourceWithURI("host:///models/meta-llama-3-8B").Obj()
+			},
+			failed: false,
+		}),
 		ginkgo.Entry("model creation with protocol unknown URI", &testValidatingCase{
 			model: func() *coreapi.OpenModel {
 				return wrapper.MakeModel("llama3-8b").FamilyName("llama3").ModelSourceWithURI("unknown://bucket.endpoint/models/meta-llama-3-8B").Obj()
diff --git a/test/util/wrapper/playground.go b/test/util/wrapper/playground.go
index 816897c..15541d3 100644
--- a/test/util/wrapper/playground.go
+++ b/test/util/wrapper/playground.go
@@ -160,7 +160,7 @@ func (w *PlaygroundWrapper) BackendRuntimeLimit(r, v string) *PlaygroundWrapper
 	return w
 }
 
-func (w *PlaygroundWrapper) ElasticConfig(maxReplicas, minReplicas int32) *PlaygroundWrapper {
+func (w *PlaygroundWrapper) ElasticConfig(minReplicas, maxReplicas int32) *PlaygroundWrapper {
 	w.Spec.ElasticConfig = &inferenceapi.ElasticConfig{
 		MaxReplicas: ptr.To[int32](maxReplicas),
 		MinReplicas: ptr.To[int32](minReplicas),
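-- 
To try the new protocol end to end, assuming llmaz is already installed and
the model files exist under /workspace/Qwen2-0.5B-Instruct on every node the
runner may be scheduled to (hostPath volumes are node-local):

    kubectl apply -f docs/examples/hostpath/model.yaml
    kubectl apply -f docs/examples/hostpath/playground.yaml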