Skip to content

Commit

Permalink
Add cuda support when loading local onnx model
Browse files Browse the repository at this point in the history
Signed-off-by: David Fan <jiafa@microsoft.com>
  • Loading branch information
jiafatom committed Dec 13, 2024
1 parent 63f11db commit c8e1eaa
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
19 changes: 17 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@
"invoke>=2.0.0",
"onnx>=1.11.0",
"onnxmltools==1.10.0",
"onnxruntime >=1.10.1;platform_system=='Linux'",
"onnxruntime-directml>=1.19.0;platform_system=='Windows'",
"torch>=1.12.1",
"pyyaml>=5.4",
"typeguard>=2.3.13",
Expand All @@ -49,6 +47,8 @@
],
extras_require={
"llm": [
"onnxruntime >=1.10.1;platform_system=='Linux'",
"onnxruntime-directml>=1.19.0;platform_system=='Windows'",
"tqdm",
"torch>=2.0.0",
"transformers",
Expand All @@ -60,6 +60,8 @@
"uvicorn[standard]",
],
"llm-oga-dml": [
"onnxruntime >=1.10.1;platform_system=='Linux'",
"onnxruntime-directml>=1.19.0;platform_system=='Windows'",
"onnxruntime-genai-directml==0.4.0",
"tqdm",
"torch>=2.0.0,<2.4",
Expand All @@ -71,6 +73,19 @@
"fastapi",
"uvicorn[standard]",
],
"llm-oga-cuda": [
"onnxruntime-gpu>=1.19.0",
"onnxruntime-genai-cuda==0.4.0",
"tqdm",
"torch>=2.0.0,<2.4",
"transformers<4.45.0",
"accelerate",
"py-cpuinfo",
"sentencepiece",
"datasets",
"fastapi",
"uvicorn[standard]",
],
"llm-oga-npu": [
"transformers",
"torch",
Expand Down
5 changes: 3 additions & 2 deletions src/turnkeyml/llm/tools/ort_genai/oga.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
oga_model_builder_cache_path = "model_builder"

# Mapping from processor to execution provider, used in pathnames and by model_builder
execution_providers = {"cpu": "cpu", "npu": "npu", "igpu": "dml"}
execution_providers = {"cpu": "cpu", "npu": "npu", "igpu": "dml", "cuda": "cuda"}


class OrtGenaiTokenizer(TokenizerAdapter):
Expand Down Expand Up @@ -248,7 +248,7 @@ def parser(add_help: bool = True) -> argparse.ArgumentParser:
parser.add_argument(
"-d",
"--device",
choices=["igpu", "npu", "cpu"],
choices=["igpu", "npu", "cpu", "cuda"],
default="igpu",
help="Which device to load the model on to (default: igpu)",
)
Expand Down Expand Up @@ -312,6 +312,7 @@ def run(
"cpu": {"int4": "*/*", "fp32": "*/*"},
"igpu": {"int4": "*/*", "fp16": "*/*"},
"npu": {"int4": "amd/**-onnx-ryzen-strix"},
"cuda": {"int4": "*/*", "fp16": "*/*"},
}
hf_supported = (
device in hf_supported_models
Expand Down

0 comments on commit c8e1eaa

Please sign in to comment.