@@ -11,75 +11,50 @@ namespace core {
1111namespace runtime {
1212
1313// Checks if the context switch requred for device ID
14- bool is_switch_required (const CudaDevice& curr_device, const CudaDevice& conf_device ) {
14+ bool is_switch_required (const CudaDevice& curr_device, const CudaDevice& engine_device ) {
1515 // If SM capability is not the same as configured then switch
16- if ((curr_device.major != conf_device .major ) || (curr_device.minor != conf_device .minor )) {
16+ if ((curr_device.major != engine_device .major ) || (curr_device.minor != engine_device .minor )) {
1717 LOG_WARNING (
18- " Configured SM capability " << conf_device .getSMCapability ()
18+ " Configured SM capability " << engine_device .getSMCapability ()
1919 << " does not match with current device SM capability "
2020 << curr_device.getSMCapability () << " (" << curr_device
2121 << " ). Switching device context" );
2222 return true ;
2323 }
2424
2525 // GPU case
26- if (conf_device .device_type == nvinfer1::DeviceType::kGPU ) {
27- if (curr_device.device_name != conf_device .device_name ) {
26+ if (engine_device .device_type == nvinfer1::DeviceType::kGPU ) {
27+ if (curr_device.device_name != engine_device .device_name ) {
2828 LOG_WARNING (
29- " Program compiled for " << conf_device .device_name << " but current CUDA device is " << curr_device
29+ " Program compiled for " << engine_device .device_name << " but current CUDA device is " << curr_device
3030 << " . Attempting to switch device context for better compatibility" );
3131 return true ;
3232 }
3333 }
3434
35- if (curr_device.id != conf_device .id ) {
35+ if (curr_device.id != engine_device .id ) {
3636 LOG_WARNING (
37- " Configured Device ID: " << conf_device .id << " is different that current device ID: " << curr_device.id
38- << " . Moving input tensors to device: " << conf_device .id );
37+ " Configured Device ID: " << engine_device .id << " is different that current device ID: " << curr_device.id
38+ << " . Moving input tensors to device: " << engine_device .id );
3939 return true ;
4040 }
4141
4242 return false ;
4343}
4444
45- CudaDevice select_cuda_device (const CudaDevice& conf_device) {
46- int64_t device_id = -1 ;
47- auto dla_supported = get_dla_supported_SMs ();
48-
49- auto device_list = get_available_device_list ().get_devices ();
50-
51- CudaDevice new_target_device;
52-
53- for (auto device : device_list) {
54- auto compute_cap = device.second .getSMCapability ();
55- // In case of DLA select the DLA supported device ID
56- if (conf_device.device_type == nvinfer1::DeviceType::kDLA ) {
57- if (dla_supported.find (compute_cap) != dla_supported.end () &&
58- dla_supported[compute_cap] == device.second .device_name ) {
59- device_id = device.second .id ;
60- new_target_device = CudaDevice (device_id, nvinfer1::DeviceType::kDLA );
61- break ;
62- }
63- } else if (conf_device.device_type == nvinfer1::DeviceType::kGPU ) {
64- auto conf_sm = conf_device.getSMCapability ();
65- if (compute_cap == conf_sm && device.second .device_name == conf_device.device_name ) {
66- device_id = device.second .id ;
67- new_target_device = CudaDevice (device_id, nvinfer1::DeviceType::kGPU );
68- break ;
69- }
70- } else {
71- TRTORCH_THROW_ERROR (" Unknown target device type detected from the compiled program (runtime.select_cuda_device)" );
72- break ;
73- }
74- }
45+ CudaDevice select_cuda_device (const CudaDevice& engine_device) {
46+ auto new_target_device_opt = get_most_compatible_device (engine_device);
7547
7648 // REVIEW: THIS DOES NOT LIST DLA PROBABLY, WHICH WE SHOULD
49+ // TODO: I think this logic could be way simpler at execution time since if the tensors arent on the right
50+ // device, its not going to run. We should just set device to engine device and maybe reset and memcpy tensors
51+ // back to orginal device if needed.
7752 TRTORCH_CHECK (
78- device_id >= 0 ,
53+ new_target_device_opt ,
7954 " No compatible device found on system to run program.\n Program targets "
80- << conf_device << " \n Available targets: \n "
55+ << engine_device << " \n Available targets: \n "
8156 << get_available_device_list ().dump_list () << " \n (runtime.select_cuda_device)" );
82- return new_target_device ;
57+ return new_target_device_opt. value () ;
8358}
8459
8560std::vector<at::Tensor> execute_engine (std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
@@ -96,7 +71,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
9671 std::string target_device = " cuda:" + std::to_string (device.id );
9772
9873 for (auto & in : inputs) {
99- in = in.to (at:: kCUDA );
74+ in = in.to (torch::Device (target_device) );
10075 }
10176 }
10277
0 commit comments