diff --git a/src/turnkeyml/llm/tools/ort_genai/oga.py b/src/turnkeyml/llm/tools/ort_genai/oga.py index de5a14a..633708c 100644 --- a/src/turnkeyml/llm/tools/ort_genai/oga.py +++ b/src/turnkeyml/llm/tools/ort_genai/oga.py @@ -35,7 +35,7 @@ oga_model_builder_cache_path = "model_builder" # Mapping from processor to executiion provider, used in pathnames and by model_builder -execution_providers = {"cpu": "cpu", "npu": "npu", "igpu": "dml"} +execution_providers = {"cpu": "cpu", "npu": "npu", "igpu": "dml", "cuda": "cuda"} class OrtGenaiTokenizer(TokenizerAdapter): @@ -248,7 +248,7 @@ def parser(add_help: bool = True) -> argparse.ArgumentParser: parser.add_argument( "-d", "--device", - choices=["igpu", "npu", "cpu"], + choices=["igpu", "npu", "cpu", "cuda"], default="igpu", help="Which device to load the model on to (default: igpu)", )