@@ -347,6 +347,21 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
347347 if (cfg.partition_info .enabled ) {
348348 return CompileGraphWithFallback (mod, cfg);
349349 }
350+ auto device_spec = cfg.convert_info .engine_settings .device ;
351+
352+ // Default GPU workspace size: 1 GB.
353+ // Use a 256 MB workspace on Jetson Nano/TX1-class platforms, whose compute capability is 5.x.
354+ auto workspace_size = cfg.convert_info .engine_settings .workspace_size ;
355+ cudaDeviceProp device_prop;
356+ cudaGetDeviceProperties (&device_prop, device_spec.gpu_id );
357+ if (workspace_size == 0 ) {
358+ if (device_prop.major < 6 ) {
359+ cfg.convert_info .engine_settings .workspace_size = 256 * (1 << 20 );
360+ } else {
361+ cfg.convert_info .engine_settings .workspace_size = 1 << 30 ;
362+ }
363+ }
364+
350365 // TODO: Should be doing a functional transform but need PR #31978
351366 // [jit] More robust mangling
352367 // torch::jit::script::Module new_mod = mod.clone();
@@ -357,7 +372,6 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
357372 if (method.name ().compare (" forward" ) == 0 ) {
358373 auto engine = ConvertGraphToTRTEngine (mod, method.name (), cfg);
359374 auto new_g = std::make_shared<torch::jit::Graph>();
360- auto device_spec = cfg.convert_info .engine_settings .device ;
361375 auto cuda_device = runtime::CudaDevice (device_spec.gpu_id , device_spec.device_type );
362376 AddEngineToGraph (new_mod, new_g, engine, cuda_device);
363377 auto new_method = new_mod._ivalue ()->compilation_unit ()->create_function (method.name (), new_g);
0 commit comments