Skip to content

Commit 250dc13

Browse files
akshaychitneniAkshay Chitneni
andauthored
feat(trainer): KEP-2655: Support provisioning of cache with Kubeflow SDK (#112)
Signed-off-by: Akshay Chitneni <achitneni@apple.com> Co-authored-by: Akshay Chitneni <achitneni@apple.com>
1 parent 80f6b0e commit 250dc13

File tree

7 files changed

+252
-30
lines changed

7 files changed

+252
-30
lines changed

CONTRIBUTING.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Specific programmatically generated files listed in the `exclude` field in [.pre
2929
To check formatting:
3030

3131
```shell
32-
make verify
32+
make verify
3333
```
3434

3535
## Testing
@@ -73,4 +73,4 @@ For any significant features or enhancement for Kubeflow SDK project we follow t
7373
[Kubeflow Enhancement Proposal process](https://github.com/kubeflow/community/tree/master/proposals).
7474

7575
If you want to submit a significant change to the Kubeflow Trainer, please submit a new KEP under
76-
[./docs/proposals](./docs/proposals/) directory.
76+
[./docs/proposals](./docs/proposals/) directory.

docs/proposals/2-trainer-local-execution/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ The proposed local execution mode will allow engineers to quickly test their mod
2727

2828
## Proposal
2929

30-
The local execution mode will allow users to run training jobs in container runtime environment on their local machines, mimicking the larger Kubeflow setup but without requiring Kubernetes.
30+
The local execution mode will allow users to run training jobs in a container runtime environment on their local machines, mimicking the larger Kubeflow setup but without requiring Kubernetes.
3131

3232
![Architecture Diagram](high-level-arch.svg)
3333

kubeflow/trainer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from kubeflow.trainer.types.types import (
2828
BuiltinTrainer,
2929
CustomTrainer,
30+
DataCacheInitializer,
3031
DataFormat,
3132
DataType,
3233
HuggingFaceDatasetInitializer,
@@ -44,6 +45,7 @@
4445
__all__ = [
4546
"BuiltinTrainer",
4647
"CustomTrainer",
48+
"DataCacheInitializer",
4749
"DataFormat",
4850
"DATASET_PATH",
4951
"DataType",

kubeflow/trainer/types/types.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from dataclasses import dataclass, field
1717
from datetime import datetime
1818
from enum import Enum
19-
from typing import Callable, Optional
19+
from typing import Callable, Optional, Union
2020

2121
from kubeflow.trainer.constants import constants
2222

@@ -258,10 +258,58 @@ class TrainJob:
258258
# TODO (andreyvelich): Discuss how to keep these configurations in sync with pkg.initializers.types
259259
@dataclass
class HuggingFaceDatasetInitializer:
    """Settings for a dataset that is fetched from the HuggingFace Hub.

    Args:
        storage_uri (`str`): Location of the dataset on the Hub. The `hf://`
            scheme may be omitted; it is prepended when the TrainJob dataset
            initializer is built.
        access_token (`Optional[str]`): Token forwarded to the initializer as an
            environment variable. Nothing is forwarded when it is unset or empty.
    """

    storage_uri: str
    access_token: Optional[str] = None

264266

267+
@dataclass
class DataCacheInitializer:
    """Configuration for distributed data caching system for training workloads.

    Args:
        storage_uri (`str`): The URI for the cached data in the format
            'cache://<SCHEMA_NAME>/<TABLE_NAME>'. This specifies the location
            where the data cache will be stored and accessed.
        metadata_loc (`str`): The metadata file path of an iceberg table.
        num_data_nodes (`int`): The number of data nodes in the distributed cache
            system. Must be greater than 1.
        head_cpu (`Optional[str]`): The CPU resources to allocate for the cache head node.
        head_mem (`Optional[str]`): The memory resources to allocate for the cache head node.
        worker_cpu (`Optional[str]`): The CPU resources to allocate for each cache worker node.
        worker_mem (`Optional[str]`): The memory resources to allocate for each cache worker node.
        iam_role (`Optional[str]`): The IAM role to use for accessing metadata_loc file.

    Raises:
        ValueError: If `num_data_nodes` is not greater than 1, or if `storage_uri`
            is not of the form 'cache://<SCHEMA_NAME>/<TABLE_NAME>' with a
            non-empty schema name and a non-empty table name.
    """

    storage_uri: str
    metadata_loc: str
    num_data_nodes: int
    head_cpu: Optional[str] = None
    head_mem: Optional[str] = None
    worker_cpu: Optional[str] = None
    worker_mem: Optional[str] = None
    iam_role: Optional[str] = None

    def __post_init__(self):
        """Validate DataCacheInitializer parameters."""
        if self.num_data_nodes <= 1:
            raise ValueError(f"num_data_nodes must be greater than 1, got {self.num_data_nodes}")

        # Validate storage_uri format
        if not self.storage_uri.startswith("cache://"):
            raise ValueError(f"storage_uri must start with 'cache://', got {self.storage_uri}")

        uri_path = self.storage_uri[len("cache://") :]
        parts = uri_path.split("/")

        # Counting segments alone is not enough: "cache:///table" and
        # "cache://schema/" both split into two parts, one of them empty.
        if len(parts) != 2 or not all(parts):
            raise ValueError(
                f"storage_uri must be in format "
                f"'cache://<SCHEMA_NAME>/<TABLE_NAME>', got {self.storage_uri}"
            )
311+
312+
265313
# Configuration for the HuggingFace model initializer.
266314
@dataclass
267315
class HuggingFaceModelInitializer:
@@ -274,11 +322,11 @@ class Initializer:
274322
"""Initializer defines configurations for dataset and pre-trained model initialization
275323
276324
Args:
277-
dataset (`Optional[HuggingFaceDatasetInitializer]`): The configuration for one of the
278-
supported dataset initializers.
325+
dataset (`Optional[Union[HuggingFaceDatasetInitializer, DataCacheInitializer]]`):
326+
The configuration for one of the supported dataset initializers.
279327
model (`Optional[HuggingFaceModelInitializer]`): The configuration for one of the
280328
supported model initializers.
281329
"""
282330

283-
dataset: Optional[HuggingFaceDatasetInitializer] = None
331+
dataset: Optional[Union[HuggingFaceDatasetInitializer, DataCacheInitializer]] = None
284332
model: Optional[HuggingFaceModelInitializer] = None
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright 2025 The Kubeflow Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from kubeflow.trainer.test.common import FAILED, SUCCESS, TestCase
18+
from kubeflow.trainer.types import types
19+
20+
21+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid datacacheinitializer creation",
            expected_status=SUCCESS,
            config={
                "storage_uri": "cache://test_schema/test_table",
                "num_data_nodes": 3,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_output=None,
        ),
        TestCase(
            name="invalid num_data_nodes raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "cache://test_schema/test_table",
                "num_data_nodes": 1,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="zero num_data_nodes raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "cache://test_schema/test_table",
                "num_data_nodes": 0,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="negative num_data_nodes raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "cache://test_schema/test_table",
                "num_data_nodes": -1,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="invalid storage_uri without cache:// prefix raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "invalid://test_schema/test_table",
                "num_data_nodes": 3,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="invalid storage_uri format raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "cache://test_schema",
                "num_data_nodes": 3,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="invalid storage_uri with too many parts raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "cache://test_schema/test_table/extra",
                "num_data_nodes": 3,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
    ],
)
def test_data_cache_initializer(test_case: TestCase):
    """Check that DataCacheInitializer accepts valid configs and rejects invalid ones."""
    print("Executing test:", test_case.name)
    cfg = test_case.config

    try:
        created = types.DataCacheInitializer(
            storage_uri=cfg["storage_uri"],
            num_data_nodes=cfg["num_data_nodes"],
            metadata_loc=cfg["metadata_loc"],
        )

        assert test_case.expected_status == SUCCESS
        # Compare only what the case supplied; defaulted fields are not checked.
        for field_name, supplied in cfg.items():
            assert getattr(created, field_name) == supplied

    except Exception as err:
        assert test_case.expected_status == FAILED
        assert type(err) is test_case.expected_error
    print("test execution complete")

kubeflow/trainer/utils/utils.py

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from dataclasses import fields
1516
import inspect
1617
import os
1718
import textwrap
18-
from typing import Any, Callable, Optional
19+
from typing import Any, Callable, Optional, Union
1920
from urllib.parse import urlparse
2021

2122
from kubeflow_trainer_api import models
@@ -563,34 +564,57 @@ def get_args_from_dataset_preprocess_config(
563564

564565

565566
def get_dataset_initializer(
    dataset: Optional[
        Union[types.HuggingFaceDatasetInitializer, types.DataCacheInitializer]
    ] = None,
) -> Optional[models.TrainerV1alpha1DatasetInitializer]:
    """
    Build the TrainJob dataset initializer that corresponds to the given config.

    A HuggingFace config yields a storage URI carrying the `hf://` scheme
    (prepended when missing) plus the access token env var when a token is set.
    A data cache config yields its storage URI unchanged plus env vars for the
    cluster size, the metadata location, and every optional field that is set.
    Any other input, including `None`, yields `None`.
    """
    if isinstance(dataset, types.HuggingFaceDatasetInitializer):
        hf_uri = dataset.storage_uri
        if not hf_uri.startswith("hf://"):
            hf_uri = "hf://" + hf_uri

        token_env = None
        if dataset.access_token:
            token_env = [
                models.IoK8sApiCoreV1EnvVar(
                    name=constants.INITIALIZER_ENV_ACCESS_TOKEN,
                    value=dataset.access_token,
                ),
            ]

        return models.TrainerV1alpha1DatasetInitializer(storageUri=hf_uri, env=token_env)

    if isinstance(dataset, types.DataCacheInitializer):
        # The required fields are mapped by hand: storage_uri becomes the storage
        # URI, the other two get dedicated env var names.
        handled_explicitly = {"storage_uri", "metadata_loc", "num_data_nodes"}

        cache_env = [
            # CLUSTER_SIZE is one more than the number of data nodes
            # (presumably to account for the head node; confirm with the cache runtime).
            models.IoK8sApiCoreV1EnvVar(name="CLUSTER_SIZE", value=str(dataset.num_data_nodes + 1)),
            models.IoK8sApiCoreV1EnvVar(name="METADATA_LOC", value=dataset.metadata_loc),
        ]

        # Each remaining field that is set is exported under its upper-cased name.
        for spec in fields(dataset):
            if spec.name in handled_explicitly:
                continue
            setting = getattr(dataset, spec.name)
            if setting is not None:
                cache_env.append(models.IoK8sApiCoreV1EnvVar(name=spec.name.upper(), value=setting))

        return models.TrainerV1alpha1DatasetInitializer(
            storageUri=dataset.storage_uri, env=cache_env or None
        )

    return None
594618

595619

596620
def get_model_initializer(

kubeflow/trainer/utils/utils_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,3 +255,36 @@ def test_get_command_using_train_func(test_case: TestCase):
255255
except Exception as e:
256256
assert type(e) is test_case.expected_error
257257
print("test execution complete")
258+
259+
260+
def test_get_dataset_initializer():
    """Verify that a DataCacheInitializer is translated into the expected env vars."""
    cache_config = types.DataCacheInitializer(
        storage_uri="cache://test_schema/test_table",
        num_data_nodes=3,
        metadata_loc="s3://bucket/metadata",
        head_cpu="1",
        head_mem="1Gi",
        worker_cpu="2",
        worker_mem="2Gi",
        iam_role="arn:aws:iam::123456789012:role/test-role",
    )

    result = utils.get_dataset_initializer(cache_config)

    assert result is not None
    assert result.env is not None
    actual_env = {item.name: item.value for item in result.env}

    expected_env = {
        # Derived from num_data_nodes (3 data nodes -> "4").
        "CLUSTER_SIZE": "4",
        # Copied from metadata_loc.
        "METADATA_LOC": "s3://bucket/metadata",
        # Optional fields, exported under their upper-cased names.
        "HEAD_CPU": "1",
        "HEAD_MEM": "1Gi",
        "WORKER_CPU": "2",
        "WORKER_MEM": "2Gi",
        "IAM_ROLE": "arn:aws:iam::123456789012:role/test-role",
    }
    for env_name, env_value in expected_env.items():
        assert actual_env[env_name] == env_value

0 commit comments

Comments
 (0)