Skip to content

Commit

Permalink
feat: user-defined inference device for CUDA and OpenVINO (#2212)
Browse files Browse the repository at this point in the history
user-defined inference device

configuration via the main_gpu parameter
  • Loading branch information
fakezeta authored May 2, 2024
1 parent 6a7a799 commit 4690b53
Showing 1 changed file with 19 additions and 11 deletions.
30 changes: 19 additions & 11 deletions backend/python/transformers/transformers_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def LoadModel(self, request, context):
quantization = None

if self.CUDA:
if request.Device:
device_map=request.Device
if request.MainGPU:
device_map=request.MainGPU
else:
device_map="cuda:0"
if request.Quantization == "bnb_4bit":
Expand Down Expand Up @@ -143,28 +143,36 @@ def LoadModel(self, request, context):
from optimum.intel.openvino import OVModelForCausalLM
from openvino.runtime import Core

if "GPU" in Core().available_devices:
device_map="GPU"
if request.MainGPU:
device_map=request.MainGPU
else:
device_map="CPU"
device_map="AUTO"
devices = Core().available_devices
if "GPU" in " ".join(devices):
device_map="AUTO:GPU"

self.model = OVModelForCausalLM.from_pretrained(model_name,
compile=True,
trust_remote_code=request.TrustRemoteCode,
ov_config={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"},
ov_config={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"},
device=device_map)
self.OV = True
elif request.Type == "OVModelForFeatureExtraction":
from optimum.intel.openvino import OVModelForFeatureExtraction
from openvino.runtime import Core

if "GPU" in Core().available_devices:
device_map="GPU"
if request.MainGPU:
device_map=request.MainGPU
else:
device_map="CPU"
device_map="AUTO"
devices = Core().available_devices
if "GPU" in " ".join(devices):
device_map="AUTO:GPU"

self.model = OVModelForFeatureExtraction.from_pretrained(model_name,
compile=True,
trust_remote_code=request.TrustRemoteCode,
ov_config={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"},
ov_config={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT", "GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"},
export=True,
device=device_map)
self.OV = True
Expand Down Expand Up @@ -371,4 +379,4 @@ async def serve(address):
)
args = parser.parse_args()

asyncio.run(serve(args.addr))
asyncio.run(serve(args.addr))

0 comments on commit 4690b53

Please sign in to comment.