Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SDK] Fix trainer error: Update the version of base image and add "num_labels" for downloading pretrained models #2230

Merged
merged 5 commits on Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/python/kubeflow/storage_initializer/hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class HuggingFaceModelParams:
model_uri: str
transformer_type: TRANSFORMER_TYPES
access_token: str = None
num_labels: Optional[int] = None
helenxie-bit marked this conversation as resolved.
Show resolved Hide resolved

def __post_init__(self):
# Custom checks or validations can be added here
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/kubeflow/trainer/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Use an official PyTorch runtime as a parent image
FROM nvcr.io/nvidia/pytorch:23.10-py3
FROM nvcr.io/nvidia/pytorch:24.06-py3

# Set the working directory in the container
WORKDIR /app
Expand Down
35 changes: 25 additions & 10 deletions sdk/python/kubeflow/trainer/hf_llm_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,26 @@
logger.setLevel(logging.INFO)


def setup_model_and_tokenizer(model_uri, transformer_type, model_dir):
def setup_model_and_tokenizer(model_uri, transformer_type, model_dir, num_labels):
# Set up the model and tokenizer
parsed_uri = urlparse(model_uri)
model_name = parsed_uri.netloc + parsed_uri.path

model = transformer_type.from_pretrained(
pretrained_model_name_or_path=model_name,
cache_dir=model_dir,
local_files_only=True,
trust_remote_code=True,
)
if num_labels > 0:
model = transformer_type.from_pretrained(
pretrained_model_name_or_path=model_name,
cache_dir=model_dir,
local_files_only=True,
trust_remote_code=True,
num_labels=num_labels,
)
else:
model = transformer_type.from_pretrained(
pretrained_model_name_or_path=model_name,
cache_dir=model_dir,
local_files_only=True,
trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=model_name,
Expand Down Expand Up @@ -151,6 +160,7 @@ def parse_arguments():

parser.add_argument("--model_uri", help="model uri")
parser.add_argument("--transformer_type", help="model transformer type")
parser.add_argument("--num_labels", help="number of classes")
parser.add_argument("--model_dir", help="directory containing model")
parser.add_argument("--dataset_dir", help="directory containing dataset")
parser.add_argument("--lora_config", help="lora_config")
Expand All @@ -177,9 +187,14 @@ def parse_arguments():
transformer_type = getattr(transformers, args.transformer_type)

logger.info("Setup model and tokenizer")
model, tokenizer = setup_model_and_tokenizer(
args.model_uri, transformer_type, args.model_dir
)
if args.num_labels == "None":
model, tokenizer = setup_model_and_tokenizer(
args.model_uri, transformer_type, args.model_dir, 0
)
else:
model, tokenizer = setup_model_and_tokenizer(
args.model_uri, transformer_type, args.model_dir, int(args.num_labels)
)
helenxie-bit marked this conversation as resolved.
Show resolved Hide resolved

logger.info("Preprocess dataset")
train_data, eval_data = load_and_preprocess_data(
Expand Down
2 changes: 2 additions & 0 deletions sdk/python/kubeflow/training/api/training_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ def train(
model_provider_parameters.model_uri,
"--transformer_type",
model_provider_parameters.transformer_type.__name__,
"--num_labels",
str(model_provider_parameters.num_labels),
"--model_dir",
VOLUME_PATH_MODEL,
"--dataset_dir",
Expand Down
Loading