Support hostpath models #247

Merged
merged 1 commit on Jan 21, 2025
6 changes: 4 additions & 2 deletions api/core/v1alpha1/model_types.go
@@ -82,8 +82,10 @@ type ModelSource struct {
// ModelHub represents the model registry for model downloads.
// +optional
ModelHub *ModelHub `json:"modelHub,omitempty"`
// URI represents a various kinds of model sources following the uri protocol, e.g.:
// - OSS: oss://<bucket>.<endpoint>/<path-to-your-model>
// URI represents various kinds of model sources following the URI protocol, protocol://<address>, e.g.:
// - oss://<bucket>.<endpoint>/<path-to-your-model>
// - ollama://llama3.3
// - host://<path-to-your-model>
//
// +optional
URI *URIProtocol `json:"uri,omitempty"`
6 changes: 4 additions & 2 deletions config/crd/bases/llmaz.io_openmodels.yaml
@@ -155,8 +155,10 @@ spec:
type: object
uri:
description: |-
URI represents a various kinds of model sources following the uri protocol, e.g.:
- OSS: oss://<bucket>.<endpoint>/<path-to-your-model>
URI represents various kinds of model sources following the URI protocol, protocol://<address>, e.g.:
- oss://<bucket>.<endpoint>/<path-to-your-model>
- ollama://llama3.3
- host://<path-to-your-model>
type: string
type: object
required:
5 changes: 5 additions & 0 deletions docs/examples/README.md
@@ -13,6 +13,7 @@ We provide a set of examples to help you serve large language models, by default
- [Deploy models via ollama](#ollama)
- [Speculative Decoding with vLLM](#speculative-decoding-with-vllm)
- [Deploy multi-host inference](#multi-host-inference)
- [Deploy host models](#deploy-host-models)

### Deploy models from Huggingface

@@ -59,3 +60,7 @@ By default, we use [vLLM](https://github.com/vllm-project/vllm) as the inference
### Multi-Host Inference

Model sizes keep growing: Llama 3.1 405B in FP16 needs more than 750 GB of GPU memory for the weights alone, leaving the KV cache unconsidered. Even 8 x Nvidia H100 GPUs with 80 GB of HBM each cannot fit it on a single host, so a multi-host deployment is required, see the [example](./multi-nodes/) here.

### Deploy Host Models

Models can be preloaded onto the hosts in advance, which is especially useful for extremely large models; see the [example](./hostpath/) for serving local models.
13 changes: 13 additions & 0 deletions docs/examples/hostpath/model.yaml
@@ -0,0 +1,13 @@
apiVersion: llmaz.io/v1alpha1
kind: OpenModel
metadata:
name: qwen2-0--5b-instruct
spec:
familyName: qwen2
source:
uri: host:///workspace/Qwen2-0.5B-Instruct
inferenceConfig:
flavors:
- name: t4 # GPU type
requests:
nvidia.com/gpu: 1
8 changes: 8 additions & 0 deletions docs/examples/hostpath/playground.yaml
@@ -0,0 +1,8 @@
apiVersion: inference.llmaz.io/v1alpha1
kind: Playground
metadata:
name: qwen2-0--5b-instruct
spec:
replicas: 1
modelClaim:
modelName: qwen2-0--5b-instruct
11 changes: 7 additions & 4 deletions pkg/controller_helper/model_source/modelsource.go
@@ -71,13 +71,16 @@ func NewModelSourceProvider(model *coreapi.OpenModel) ModelSourceProvider {

if model.Spec.Source.URI != nil {
// We'll validate the format in the webhook, so generally no error should happen here.
protocol, address, _ := util.ParseURI(string(*model.Spec.Source.URI))
provider := &URIProvider{modelName: model.Name, protocol: protocol, modelAddress: address}
protocol, value, _ := util.ParseURI(string(*model.Spec.Source.URI))
provider := &URIProvider{modelName: model.Name, protocol: protocol}

switch protocol {
case OSS:
provider.endpoint, provider.bucket, provider.modelPath, _ = util.ParseOSS(address)
case OLLAMA:
provider.endpoint, provider.bucket, provider.modelPath, _ = util.ParseOSS(value)
case HostPath:
provider.modelPath = value
case Ollama:
provider.modelPath = value
default:
// This should be validated at webhooks.
panic("protocol not supported")
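The switch above hinges on `util.ParseURI`, whose implementation is not part of this diff. Below is only a minimal sketch of the behaviour the call site appears to rely on: splitting `protocol://<address>` and upper-casing the protocol so that lowercase URIs like `host:///workspace/Qwen2-0.5B-Instruct` match the `HOST`/`OSS`/`OLLAMA` constants. The function body, the error message, and the package name are assumptions, not the repository's code.

```go
// Sketch only: an assumed shape for util.ParseURI, inferred from its call site above.
package util

import (
	"fmt"
	"strings"
)

// ParseURI splits "protocol://<address>" into its two halves and normalises the
// protocol to upper case so it can be compared against the OSS/OLLAMA/HOST constants.
func ParseURI(uri string) (protocol string, address string, err error) {
	parts := strings.SplitN(uri, "://", 2)
	if len(parts) != 2 {
		return "", "", fmt.Errorf("uri %q does not follow protocol://<address>", uri)
	}
	return strings.ToUpper(parts[0]), parts[1], nil
}
```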
59 changes: 47 additions & 12 deletions pkg/controller_helper/model_source/uri.go
@@ -26,22 +26,24 @@ import (
var _ ModelSourceProvider = &URIProvider{}

const (
OSS = "OSS"
OLLAMA = "OLLAMA"
OSS = "OSS"
Ollama = "OLLAMA"
HostPath = "HOST"
)

type URIProvider struct {
modelName string
protocol string
bucket string
endpoint string
modelPath string
modelAddress string
modelName string
protocol string
bucket string
endpoint string
modelPath string
}

func (p *URIProvider) ModelName() string {
if p.protocol == OLLAMA {
return p.modelAddress
if p.protocol == Ollama {
// modelPath stores the ollama model name here,
// while modelName is the name of the Model CRD.
return p.modelPath
}
return p.modelName
}
@@ -54,18 +56,51 @@ func (p *URIProvider) ModelName() string {
// - uri: bucket.endpoint/modelPath/model.gguf
// modelPath: /workspace/models/model.gguf
func (p *URIProvider) ModelPath() string {
if p.protocol == HostPath {
return p.modelPath
}

// protocol is oss.

splits := strings.Split(p.modelPath, "/")

if strings.Contains(p.modelPath, ".") {
if strings.Contains(p.modelPath, ".gguf") {
return CONTAINER_MODEL_PATH + splits[len(splits)-1]
}
return CONTAINER_MODEL_PATH + "models--" + splits[len(splits)-1]
}

func (p *URIProvider) InjectModelLoader(template *corev1.PodTemplateSpec, index int) {
if p.protocol == OLLAMA {
// No additional operations are needed for Ollama; the model is loaded at runtime.
if p.protocol == Ollama {
return
}

if p.protocol == HostPath {
template.Spec.Volumes = append(template.Spec.Volumes, corev1.Volume{
Name: MODEL_VOLUME_NAME,
VolumeSource: corev1.VolumeSource{
HostPath: &corev1.HostPathVolumeSource{
Path: p.modelPath,
},
},
})

for i, container := range template.Spec.Containers {
// Only the model runner container needs the host path mounted.
if container.Name == MODEL_RUNNER_CONTAINER_NAME {
template.Spec.Containers[i].VolumeMounts = append(template.Spec.Containers[i].VolumeMounts, corev1.VolumeMount{
Name: MODEL_VOLUME_NAME,
MountPath: p.modelPath,
ReadOnly: true,
})
}
}
return
}

// Other protocols.

initContainerName := MODEL_LOADER_CONTAINER_NAME
if index != 0 {
initContainerName += "-" + strconv.Itoa(index)
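To make the HostPath branch above concrete, here is a small illustrative snippet (not taken from the PR) showing what the injection does for a host model. It assumes it sits in the same package as `URIProvider`, so `HostPath`, `MODEL_RUNNER_CONTAINER_NAME`, `MODEL_VOLUME_NAME`, and the `corev1` import (`k8s.io/api/core/v1`) are the names already used by uri.go; the helper function itself is hypothetical.

```go
// exampleHostPathInjection is a hypothetical helper, shown only to illustrate the
// effect of InjectModelLoader for the host protocol.
func exampleHostPathInjection() *corev1.PodTemplateSpec {
	provider := &URIProvider{
		modelName: "qwen2-0--5b-instruct",
		protocol:  HostPath,
		modelPath: "/workspace/Qwen2-0.5B-Instruct",
	}
	template := &corev1.PodTemplateSpec{
		Spec: corev1.PodSpec{
			Containers: []corev1.Container{{Name: MODEL_RUNNER_CONTAINER_NAME}},
		},
	}
	// No model-loader init container is added for host models; the node path is
	// simply mounted read-only into the runner container, and ModelPath() returns
	// the same path (/workspace/Qwen2-0.5B-Instruct) for the backend to load from.
	provider.InjectModelLoader(template, 0)
	return template
}
```

Because the mount path equals the node path, the backend sees the model at exactly the location given in the host:// URI, which is presumably why no extra download step is needed for this protocol.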
5 changes: 3 additions & 2 deletions pkg/webhook/openmodel_webhook.go
@@ -47,8 +47,9 @@ func SetupOpenModelWebhook(mgr ctrl.Manager) error {
var _ webhook.CustomDefaulter = &OpenModelWebhook{}

var SUPPORTED_OBJ_STORES = map[string]struct{}{
modelSource.OSS: {},
modelSource.OLLAMA: {},
modelSource.OSS: {},
modelSource.Ollama: {},
modelSource.HostPath: {},
}

// Default implements webhook.Defaulter so a webhook will be registered for the type
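The map above is the whitelist of supported protocols; the validation logic itself is outside this hunk. Below is a minimal sketch of how such a check could look, assuming the webhook extracts the protocol from the URI and tests membership in `SUPPORTED_OBJ_STORES`. The function `validateModelURI`, the inlined parsing (the real code may delegate to `util.ParseURI`), and the `fmt`/`strings` imports are illustrative assumptions, not the repository's actual code.

```go
import (
	"fmt"
	"strings"
)

// validateModelURI is a hypothetical illustration of protocol validation against
// SUPPORTED_OBJ_STORES, assumed to live in the same package as the map above.
func validateModelURI(uri string) error {
	parts := strings.SplitN(uri, "://", 2)
	if len(parts) != 2 {
		return fmt.Errorf("model URI %q does not follow protocol://<address>", uri)
	}
	if _, ok := SUPPORTED_OBJ_STORES[strings.ToUpper(parts[0])]; !ok {
		return fmt.Errorf("unsupported model source protocol %q", parts[0])
	}
	return nil
}
```

The integration test added below exercises exactly this path with `host:///models/meta-llama-3-8B` and expects validation not to fail.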
6 changes: 6 additions & 0 deletions test/integration/webhook/model_test.go
@@ -105,6 +105,12 @@ var _ = ginkgo.Describe("model default and validation", func() {
},
failed: false,
}),
ginkgo.Entry("model creation with host protocol", &testValidatingCase{
model: func() *coreapi.OpenModel {
return wrapper.MakeModel("llama3-8b").FamilyName("llama3").ModelSourceWithURI("host:///models/meta-llama-3-8B").Obj()
},
failed: false,
}),
ginkgo.Entry("model creation with protocol unknown URI", &testValidatingCase{
model: func() *coreapi.OpenModel {
return wrapper.MakeModel("llama3-8b").FamilyName("llama3").ModelSourceWithURI("unknown://bucket.endpoint/models/meta-llama-3-8B").Obj()
2 changes: 1 addition & 1 deletion test/util/wrapper/playground.go
@@ -160,7 +160,7 @@ func (w *PlaygroundWrapper) BackendRuntimeLimit(r, v string) *PlaygroundWrapper {
return w
}

func (w *PlaygroundWrapper) ElasticConfig(maxReplicas, minReplicas int32) *PlaygroundWrapper {
func (w *PlaygroundWrapper) ElasticConfig(minReplicas, maxReplicas int32) *PlaygroundWrapper {
w.Spec.ElasticConfig = &inferenceapi.ElasticConfig{
MaxReplicas: ptr.To[int32](maxReplicas),
MinReplicas: ptr.To[int32](minReplicas),