diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index b13e6e469e..19501cf408 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -327,11 +327,10 @@ def copy_to_worker_0(sess: di.Session, host_array): def get_tvm_model(artifact_path, model, quantization, num_shards, dev): - if num_shards > 1: - model_artifact_path = os.path.join( - artifact_path, f"{model}-{quantization}-presharded-{num_shards}gpu" - ) - else: + model_artifact_path = os.path.join( + artifact_path, f"{model}-{quantization}-presharded-{num_shards}gpu" + ) + if not os.path.exists(model_artifact_path): model_artifact_path = os.path.join(artifact_path, f"{model}-{quantization}") lib_path = os.path.join(model_artifact_path, f"{model}-{quantization}-cuda.so")