Skip to content

Commit 8cc56d9

Browse files
committed
add integration test, rename datasoter interface, change split to targets.
1 parent 8923424 commit 8cc56d9

File tree

10 files changed

+93
-31
lines changed

10 files changed

+93
-31
lines changed

apix/v1alpha2/inferencemodelrewrite_types.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ type InferenceModelRewriteRule struct {
9292
// +optional
9393
// +kubebuilder:validation:MinItems=1
9494
//
95-
Targets []TargetModel `json:"split,omitempty"`
95+
Targets []TargetModel `json:"targets,omitempty"`
9696
}
9797

9898
// TargetModel defines a weighted model destination for traffic distribution.

client-go/applyconfiguration/apix/v1alpha2/inferencemodelrewriterule.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

config/crd/bases/inference.networking.x-k8s.io_inferencemodelrewrites.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ spec:
108108
- model
109109
type: object
110110
type: array
111-
split:
111+
targets:
112112
items:
113113
description: TargetModel defines a weighted model destination
114114
for traffic distribution.

pkg/epp/controller/inferencemodelrewrite_reconciler.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,20 @@ func (c *InferenceModelRewriteReconciler) Reconcile(ctx context.Context, req ctr
5454
notFound = true
5555
}
5656

57-
if notFound || !infModelRewrite.DeletionTimestamp.IsZero() || infModelRewrite.Spec.PoolRef == nil ||
57+
isDeleted := !infModelRewrite.DeletionTimestamp.IsZero()
58+
isPooRefUnmatch := infModelRewrite.Spec.PoolRef == nil ||
5859
infModelRewrite.Spec.PoolRef.Name != v1alpha2.ObjectName(c.PoolGKNN.Name) ||
59-
(infModelRewrite.Spec.PoolRef.Group != v1alpha2.Group(c.PoolGKNN.Group) && infModelRewrite.Spec.PoolRef.Group != "inference.networking.x-k8s.io") {
60+
infModelRewrite.Spec.PoolRef.Group != v1alpha2.Group(c.PoolGKNN.Group)
61+
62+
if notFound || isDeleted || isPooRefUnmatch {
6063
// InferenceModelRewrite object got deleted or changed the referenced pool.
61-
c.Datastore.RewriteDelete(req.NamespacedName)
64+
c.Datastore.ModelRewriteDelete(req.NamespacedName)
6265
return ctrl.Result{}, nil
6366
}
6467

6568
// Add or update if the InferenceModelRewrite instance has a creation timestamp older than the existing entry of the model.
6669
logger = logger.WithValues("poolRef", infModelRewrite.Spec.PoolRef)
67-
c.Datastore.RewriteSet(infModelRewrite)
70+
c.Datastore.ModelRewriteSet(infModelRewrite)
6871
logger.Info("Added/Updated InferenceModelRewrite")
6972

7073
return ctrl.Result{}, nil

pkg/epp/controller/inferencemodelrewrite_reconciler_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ func TestInferenceModelRewriteReconciler(t *testing.T) {
185185
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
186186
ds := datastore.NewDatastore(t.Context(), pmf, 0)
187187
for _, r := range test.rewritesInStore {
188-
ds.RewriteSet(r)
188+
ds.ModelRewriteSet(r)
189189
}
190190
endpointPool := poolutil.InferencePoolToEndpointPool(poolForRewrite)
191191
_ = ds.PoolSet(context.Background(), fakeClient, endpointPool)
@@ -210,8 +210,8 @@ func TestInferenceModelRewriteReconciler(t *testing.T) {
210210
t.Errorf("Unexpected result diff (+got/-want): %s", diff)
211211
}
212212

213-
if len(test.wantRewrites) != len(ds.RewriteGetAll()) {
214-
t.Errorf("Unexpected number of rewrites; want: %d, got:%d", len(test.wantRewrites), len(ds.RewriteGetAll()))
213+
if len(test.wantRewrites) != len(ds.ModelRewriteGetAll()) {
214+
t.Errorf("Unexpected number of rewrites; want: %d, got:%d", len(test.wantRewrites), len(ds.ModelRewriteGetAll()))
215215
}
216216

217217
if diff := diffStoreRewrites(ds, test.wantRewrites); diff != "" {
@@ -226,7 +226,7 @@ func diffStoreRewrites(ds datastore.Datastore, wantRewrites []*v1alpha2.Inferenc
226226
wantRewrites = []*v1alpha2.InferenceModelRewrite{}
227227
}
228228

229-
gotRewrites := ds.RewriteGetAll()
229+
gotRewrites := ds.ModelRewriteGetAll()
230230
if diff := cmp.Diff(wantRewrites, gotRewrites); diff != "" {
231231
return "rewrites:" + diff
232232
}

pkg/epp/datastore/datastore.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@ type Datastore interface {
6060
ObjectiveGetAll() []*v1alpha2.InferenceObjective
6161

6262
// InferenceModelRewrite operations
63-
RewriteSet(infModelRewrite *v1alpha2.InferenceModelRewrite)
64-
RewriteDelete(namespacedName types.NamespacedName)
65-
RewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule
66-
RewriteGetAll() []*v1alpha2.InferenceModelRewrite
63+
ModelRewriteSet(infModelRewrite *v1alpha2.InferenceModelRewrite)
64+
ModelRewriteDelete(namespacedName types.NamespacedName)
65+
ModelRewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule
66+
ModelRewriteGetAll() []*v1alpha2.InferenceModelRewrite
6767

6868
// PodList lists pods matching the given predicate.
6969
PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics
@@ -81,7 +81,7 @@ func NewDatastore(parentCtx context.Context, epFactory datalayer.EndpointFactory
8181
pool: nil,
8282
mu: sync.RWMutex{},
8383
objectives: make(map[string]*v1alpha2.InferenceObjective),
84-
rewrites: newModelRewriteStore(),
84+
modelRewrites: newModelRewriteStore(),
8585
pods: &sync.Map{},
8686
modelServerMetricsPort: modelServerMetricsPort,
8787
epf: epFactory,
@@ -103,8 +103,8 @@ type datastore struct {
103103
pool *datalayer.EndpointPool
104104
// key: InferenceObjective name, value: *InferenceObjective
105105
objectives map[string]*v1alpha2.InferenceObjective
106-
// rewrites store for InferenceModelRewrite objects.
107-
rewrites *modelRewriteStore
106+
// modelRewrites store for InferenceModelRewrite objects.
107+
modelRewrites *modelRewriteStore
108108
// key: types.NamespacedName, value: backendmetrics.PodMetrics
109109
pods *sync.Map
110110
// modelServerMetricsPort metrics port from EPP command line argument
@@ -118,7 +118,7 @@ func (ds *datastore) Clear() {
118118
defer ds.mu.Unlock()
119119
ds.pool = nil
120120
ds.objectives = make(map[string]*v1alpha2.InferenceObjective)
121-
ds.rewrites = newModelRewriteStore()
121+
ds.modelRewrites = newModelRewriteStore()
122122
// stop all pods go routines before clearing the pods map.
123123
ds.pods.Range(func(_, v any) bool {
124124
ds.epf.ReleaseEndpoint(v.(backendmetrics.PodMetrics))
@@ -210,28 +210,28 @@ func (ds *datastore) ObjectiveGetAll() []*v1alpha2.InferenceObjective {
210210
return res
211211
}
212212

213-
func (ds *datastore) RewriteSet(infModelRewrite *v1alpha2.InferenceModelRewrite) {
213+
func (ds *datastore) ModelRewriteSet(infModelRewrite *v1alpha2.InferenceModelRewrite) {
214214
ds.mu.Lock()
215215
defer ds.mu.Unlock()
216-
ds.rewrites.set(infModelRewrite)
216+
ds.modelRewrites.set(infModelRewrite)
217217
}
218218

219-
func (ds *datastore) RewriteDelete(namespacedName types.NamespacedName) {
219+
func (ds *datastore) ModelRewriteDelete(namespacedName types.NamespacedName) {
220220
ds.mu.Lock()
221221
defer ds.mu.Unlock()
222-
ds.rewrites.delete(namespacedName)
222+
ds.modelRewrites.delete(namespacedName)
223223
}
224224

225-
func (ds *datastore) RewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule {
225+
func (ds *datastore) ModelRewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule {
226226
ds.mu.RLock()
227227
defer ds.mu.RUnlock()
228-
return ds.rewrites.getRule(modelName)
228+
return ds.modelRewrites.getRule(modelName)
229229
}
230230

231-
func (ds *datastore) RewriteGetAll() []*v1alpha2.InferenceModelRewrite {
231+
func (ds *datastore) ModelRewriteGetAll() []*v1alpha2.InferenceModelRewrite {
232232
ds.mu.RLock()
233233
defer ds.mu.RUnlock()
234-
return ds.rewrites.getAll()
234+
return ds.modelRewrites.getAll()
235235
}
236236

237237
// /// Pods/endpoints APIs ///

pkg/epp/requestcontrol/director.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ type Datastore interface {
5050
PoolGet() (*datalayer.EndpointPool, error)
5151
ObjectiveGet(objectiveName string) *v1alpha2.InferenceObjective
5252
PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics
53-
RewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule
53+
ModelRewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule
5454
}
5555

5656
// Scheduler defines the interface required by the Director for scheduling.
@@ -194,7 +194,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
194194
}
195195

196196
func (d *Director) applyWeightedModelRewrite(reqCtx *handlers.RequestContext) {
197-
rewriteRule := d.datastore.RewriteGet(reqCtx.IncomingModelName)
197+
rewriteRule := d.datastore.ModelRewriteGet(reqCtx.IncomingModelName)
198198
if rewriteRule == nil {
199199
return
200200
}

pkg/epp/requestcontrol/director_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ func (m mockProducedDataType) Clone() datalayer.Cloneable {
169169
return mockProducedDataType{value: m.value}
170170
}
171171

172-
func (ds *mockDatastore) RewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule {
172+
func (ds *mockDatastore) ModelRewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule {
173173
// This mock implementation simulates the precedence logic for simplicity.
174174
// It finds the oldest rewrite that has a rule matching the modelName.
175175
var matchingRewrites []*v1alpha2.InferenceModelRewrite
@@ -268,7 +268,7 @@ func TestDirector_HandleRequest(t *testing.T) {
268268
ds.ObjectiveSet(ioFoodReview)
269269
ds.ObjectiveSet(ioFoodReviewResolve)
270270
ds.ObjectiveSet(ioFoodReviewSheddable)
271-
ds.RewriteSet(rewrite)
271+
ds.ModelRewriteSet(rewrite)
272272

273273
scheme := runtime.NewScheme()
274274
_ = clientgoscheme.AddToScheme(scheme)

test/integration/epp/hermetic_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ const (
9090
// Model Names
9191
modelMyModel = "my-model"
9292
modelMyModelTarget = "my-model-12345"
93+
modelToBeWritten = "model-to-be-rewritten"
94+
modelAfterRewrite = "rewritten-model"
9395
modelSQLLora = "sql-lora"
9496
modelSQLLoraTarget = "sql-lora-1fdg2"
9597
modelSheddable = "sql-lora-sheddable"
@@ -981,6 +983,42 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) {
981983
},
982984
},
983985
},
986+
{
987+
name: "rewrite request model",
988+
requests: integrationutils.GenerateStreamedRequestSet(logger, "test-rewrite", modelToBeWritten, modelToBeWritten, nil),
989+
// Pod 0 will be picked.
990+
// Expected flow:
991+
// 1. Request asks for "model-to-be-rewritte"
992+
// 2. Rewrite rule transforms "model-to-be-rewritten" -> "rewritten-model"
993+
// 3. EPP sends request to backend with model "rewritten-model"
994+
pods: newPodStates(
995+
podState{index: 0, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", "rewritten-model"}},
996+
),
997+
wantMetrics: map[string]string{
998+
"inference_objective_request_total": inferenceObjectiveRequestTotal([]label{
999+
{"model_name", modelToBeWritten},
1000+
{"target_model_name", modelAfterRewrite},
1001+
}),
1002+
},
1003+
wantErr: false,
1004+
wantResponses: integrationutils.NewRequestBufferedResponse(
1005+
"192.168.1.1:8000",
1006+
// Note: The prompt remains "test-rewrite", but the model in the JSON body is updated to the *rewritten target* model.
1007+
fmt.Sprintf(`{"max_tokens":100,"model":%q,"prompt":"test-rewrite","temperature":0}`, modelAfterRewrite),
1008+
&configPb.HeaderValueOption{
1009+
Header: &configPb.HeaderValue{
1010+
Key: "hi",
1011+
RawValue: []byte("mom"),
1012+
},
1013+
},
1014+
&configPb.HeaderValueOption{
1015+
Header: &configPb.HeaderValue{
1016+
Key: requtil.RequestIdHeaderKey,
1017+
RawValue: []byte("test-request-id"),
1018+
},
1019+
},
1020+
),
1021+
},
9841022
}
9851023

9861024
for _, test := range tests {
@@ -1247,6 +1285,7 @@ func BeforeSuite() func() {
12471285
_ = testEnv.Stop()
12481286
_ = k8sClient.DeleteAllOf(context.Background(), &v1.InferencePool{})
12491287
_ = k8sClient.DeleteAllOf(context.Background(), &v1alpha2.InferenceObjective{})
1288+
_ = k8sClient.DeleteAllOf(context.Background(), &v1alpha2.InferenceModelRewrite{})
12501289
}
12511290
}
12521291

@@ -1299,6 +1338,11 @@ func managerTestOptions(namespace, name string, metricsServerOptions metricsserv
12991338
namespace: {},
13001339
},
13011340
},
1341+
&v1alpha2.InferenceModelRewrite{}: {
1342+
Namespaces: map[string]cache.Config{
1343+
namespace: {},
1344+
},
1345+
},
13021346
},
13031347
},
13041348
Controller: crconfig.Controller{

test/testdata/inferencepool-with-model-hermetic.yaml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,18 @@ spec:
6262
priority: 2
6363
poolRef:
6464
name: vllm-llama3-8b-instruct-pool
65+
---
66+
apiVersion: inference.networking.x-k8s.io/v1alpha2
67+
kind: InferenceModelRewrite
68+
metadata:
69+
name: rewrite-test
70+
namespace: default
71+
spec:
72+
poolRef:
73+
name: vllm-llama3-8b-instruct-pool
74+
rules:
75+
- matches:
76+
- model:
77+
value: model-to-be-rewritten
78+
targets:
79+
- modelRewrite: rewritten-model

0 commit comments

Comments
 (0)