Skip to content

Commit 2c0dcc5

Browse files
committed
gracefully skip TTFT model training when not enough samples (no undefined vars)
1 parent 95f30f6 commit 2c0dcc5

File tree

1 file changed

+15
-17
lines changed

1 file changed

+15
-17
lines changed

latencypredictor-v1/training_server.py

Lines changed: 15 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -628,22 +628,22 @@ def train(self):
628628
if len(df_ttft) >= settings.MIN_SAMPLES_FOR_RETRAIN:
629629
# Updated TTFT features to include prefix_cache_score
630630
ttft_feature_cols_tree = [
631-
'kv_cache_percentage','input_token_length','num_request_waiting',
632-
'num_request_running','prefix_cache_score','effective_input_tokens','prefill_score_bucket'
633-
]
634-
ttft_feature_cols_br = [
635-
'kv_cache_percentage','input_token_length','num_request_waiting',
636-
'num_request_running','prefix_cache_score','effective_input_tokens'
637-
]
638-
639-
# Build X_ttft for all model types, then trim for BR
640-
X_ttft = df_ttft[ttft_feature_cols_tree]
641-
if self.model_type == ModelType.BAYESIAN_RIDGE:
642-
X_ttft = X_ttft[ttft_feature_cols_br]
631+
'kv_cache_percentage','input_token_length','num_request_waiting',
632+
'num_request_running','prefix_cache_score','effective_input_tokens','prefill_score_bucket'
633+
]
634+
ttft_feature_cols_br = [
635+
'kv_cache_percentage','input_token_length','num_request_waiting',
636+
'num_request_running','prefix_cache_score','effective_input_tokens'
637+
]
638+
639+
# Build X_ttft for all model types, then trim for BR
640+
X_ttft = df_ttft[ttft_feature_cols_tree]
641+
if self.model_type == ModelType.BAYESIAN_RIDGE:
642+
X_ttft = X_ttft[ttft_feature_cols_br]
643643

644-
y_ttft = raw_ttft['actual_ttft_ms']
644+
y_ttft = raw_ttft['actual_ttft_ms']
645645

646-
try:
646+
try:
647647
# raw_ttft still has the original columns including 'prefix_cache_score'
648648
raw_ttft['_prefix_bucket'] = raw_ttft['prefix_cache_score'].clip(0, 1).apply(
649649
lambda s: min(int(s * self.prefix_buckets), self.prefix_buckets - 1)
@@ -677,8 +677,6 @@ def train(self):
677677
new_ttft_model, new_ttft_scaler, test_records, cols, 'actual_ttft_ms'
678678
)
679679

680-
681-
682680
if ql is not None:
683681
self.ttft_quantile_loss_scores.append(ql)
684682
self.ttft_coverage_scores.append(coverage)
@@ -690,7 +688,7 @@ def train(self):
690688
else:
691689
logging.info(f"TTFT model trained on {len(df_ttft)} samples. Quantile metrics = N/A (insufficient test data)")
692690

693-
except Exception:
691+
except Exception:
694692
logging.error("Error training TTFT model", exc_info=True)
695693

696694

0 commit comments

Comments
 (0)