🐛 [Bug] Segmentation Fault when running on Jetson Orin NX #1891


Closed
janblumenkamp opened this issue May 5, 2023 · 6 comments
Labels: bug, No Activity, platform: aarch64

Comments

@janblumenkamp

Bug Description

I built Torch-TensorRT for the Jetson Orin NX. I followed the instructions here and am building on the Jetson from the pyt2.0 branch, which provides the Torch-TensorRT 1.4.0 RC.

To Reproduce

I use the test script from here. When I try to run it, I get the following output:

/usr/local/lib/python3.8/dist-packages/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py:840: UserWarning: Unable to import torchvision related libraries.: No module named 'torchvision'. Please install torchvision lib in order to lower stochastic_depth
  warnings.warn(
WARNING: [Torch-TensorRT TorchScript Conversion Context] - Unable to determine GPU memory usage
WARNING: [Torch-TensorRT TorchScript Conversion Context] - Unable to determine GPU memory usage
WARNING: [Torch-TensorRT TorchScript Conversion Context] - CUDA initialization failure with error: 222. Please check your CUDA installation:  http://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html
Segmentation fault (core dumped)
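The linked test script is not reproduced in the issue; an assumed minimal reproduction, reconstructed from the lowered graph in the debug log further down (a `Conv2d(3, 32, 3)` followed by a tanh-approximated GELU on a `(1, 3, 5, 5)` input), might look like this — the actual script may differ:

```python
# Assumed minimal repro, reconstructed from the lowered graph in the debug
# log (Conv2d(3, 32, 3) + tanh-approximate GELU, input (1, 3, 5, 5)).
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 32, kernel_size=3)
        self.gelu = nn.GELU(approximate="tanh")

    def forward(self, x):
        return self.gelu(self.conv(x))


device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model().eval().to(device)
x = torch.randn(1, 3, 5, 5, device=device)
y = model(x)  # eager sanity check; output shape is (1, 32, 3, 3)

if torch.cuda.is_available():
    # The compile step is where the segfault occurs on the Orin NX;
    # requires torch_tensorrt to be installed.
    import torch_tensorrt
    trt_mod = torch_tensorrt.compile(
        torch.jit.script(model),
        inputs=[torch_tensorrt.Input(x.shape)],
        enabled_precisions={torch.float},
    )
```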

Expected behavior

The test model is converted successfully.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 1.4.0 (pyt2.0 branch ec06d6f)
  • PyTorch Version (e.g. 1.0): 2.0.0a0+8aa34602.nv23.03 (installed from NVidia as instructed here)
  • CPU Architecture: aarch64
  • OS (e.g., Linux): Linux-5.10.104-tegra-aarch64-with-glibc2.29
  • How you installed PyTorch (conda, pip, libtorch, source): NVidia compiled version, torch works on CUDA.
  • Build command you used (if compiling from source):
    • bazel build //:libtorchtrt --platforms //toolchains:jetpack_5.0
    • python3 setup.py install --use-cxx11-abi --jetpack-version 5.0
  • Are you using local sources or building from archives: Local sources
  • Python version: 3.8.10 (default, Mar 13 2023, 10:26:41) [GCC 9.4.0] (64-bit runtime)
  • CUDA version: 11.4 (per the collect_env output below)
  • GPU models and configuration: Jetson Orin NX
  • Any other relevant information: N/A
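As a side note, CUDA error 222 in the log above appears to correspond to CUDA_ERROR_UNSUPPORTED_PTX_VERSION, which usually indicates a mismatch between the CUDA toolkit a library was built against and what the installed driver supports. A quick sanity check of the PyTorch side (a sketch, not from the original report):

```python
# Check which CUDA toolkit PyTorch was built against and whether the CUDA
# runtime initializes at all. On JetPack 5.x this should report CUDA 11.4.
import torch

print("torch:", torch.__version__)
print("built with CUDA:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
```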

Additional context

Output of python3 -m torch.utils.collect_env:

nvidia@tegra-ubuntu:~$ python3 -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 2.0.0a0+8aa34602.nv23.03
Is debug build: False
CUDA used to build PyTorch: 11.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (aarch64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.8.10 (default, Mar 13 2023, 10:26:41)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.10.104-tegra-aarch64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/aarch64-linux-gnu/libcudnn.so.8.6.0
/usr/lib/aarch64-linux-gnu/libcudnn_adv_infer.so.8.6.0
/usr/lib/aarch64-linux-gnu/libcudnn_adv_train.so.8.6.0
/usr/lib/aarch64-linux-gnu/libcudnn_cnn_infer.so.8.6.0
/usr/lib/aarch64-linux-gnu/libcudnn_cnn_train.so.8.6.0
/usr/lib/aarch64-linux-gnu/libcudnn_ops_infer.so.8.6.0
/usr/lib/aarch64-linux-gnu/libcudnn_ops_train.so.8.6.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: False

CPU:
Architecture:                    aarch64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
CPU(s):                          8
On-line CPU(s) list:             0-7
Thread(s) per core:              1
Core(s) per socket:              4
Socket(s):                       2
Vendor ID:                       ARM
Model:                           1
Model name:                      ARMv8 Processor rev 1 (v8l)
Stepping:                        r0p1
CPU max MHz:                     1984.0000
CPU min MHz:                     115.2000
BogoMIPS:                        62.50
L1d cache:                       512 KiB
L1i cache:                       512 KiB
L2 cache:                        2 MiB
L3 cache:                        4 MiB
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:        Mitigation; __user pointer sanitization
Vulnerability Spectre v2:        Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm lrcpc dcpop asimddp uscat ilrcpc flagm

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.19.4
[pip3] torch==2.0.0a0+8aa34602.nv23.3
[pip3] torch-tensorrt==1.3.0+3d6a1ba5
[conda] Could not collect

janblumenkamp commented May 5, 2023

To add some more context, this is the stacktrace when running python in gdb:

nvidia@tegra-ubuntu:~$ gdb python3
GNU gdb (Ubuntu 9.2-0ubuntu1~20.04.1) 9.2
Copyright (C) 2020 Free Software Foundation, Inc.
License GPLv3+: GNU GPL version 3 or later <http://gnu.org/licenses/gpl.html>
This is free software: you are free to change and redistribute it.
There is NO WARRANTY, to the extent permitted by law.
Type "show copying" and "show warranty" for details.
This GDB was configured as "aarch64-linux-gnu".
Type "show configuration" for configuration details.
For bug reporting instructions, please see:
<http://www.gnu.org/software/gdb/bugs/>.
Find the GDB manual and other documentation resources online at:
    <http://www.gnu.org/software/gdb/documentation/>.

For help, type "help".
Type "apropos word" to search for commands related to "word"...
Reading symbols from python3...
(No debugging symbols found in python3)
(gdb) run test_tensorrt.py 
Starting program: /usr/bin/python3 test_tensorrt.py
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/aarch64-linux-gnu/libthread_db.so.1".
[Detaching after fork from child process 139619]
[New Thread 0xffffa42921e0 (LWP 139620)]
[New Thread 0xffffa3a911e0 (LWP 139621)]
[New Thread 0xffffa12901e0 (LWP 139622)]
[New Thread 0xffffa0a8f1e0 (LWP 139623)]
[New Thread 0xffff9e28e1e0 (LWP 139624)]
[New Thread 0xffff9da8d1e0 (LWP 139625)]
[New Thread 0xffff9c28c1e0 (LWP 139626)]
[New Thread 0xffff951f51e0 (LWP 139627)]
[New Thread 0xffff949f41e0 (LWP 139628)]
[New Thread 0xffff921f31e0 (LWP 139629)]
[New Thread 0xffff909f21e0 (LWP 139630)]
[New Thread 0xffff8f1f11e0 (LWP 139631)]
[New Thread 0xffff8d9f01e0 (LWP 139632)]
[New Thread 0xffff8c1ef1e0 (LWP 139633)]
/usr/local/lib/python3.8/dist-packages/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py:840: UserWarning: Unable to import torchvision related libraries.: No module named 'torchvision'. Please install torchvision lib in order to lower stochastic_depth
  warnings.warn(
[New Thread 0xffff6716d1e0 (LWP 139634)]
WARNING: [Torch-TensorRT TorchScript Conversion Context] - Unable to determine GPU memory usage
WARNING: [Torch-TensorRT TorchScript Conversion Context] - Unable to determine GPU memory usage
WARNING: [Torch-TensorRT TorchScript Conversion Context] - CUDA initialization failure with error: 222. Please check your CUDA installation:  http://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html

Thread 1 "python3" received signal SIGSEGV, Segmentation fault.
0x0000ffff6d5a9bc4 in torch_tensorrt::core::conversion::ConversionCtx::ConversionCtx(torch_tensorrt::core::conversion::BuilderSettings) ()
   from /usr/local/lib/python3.8/dist-packages/torch_tensorrt/lib/libtorchtrt.so
(gdb) backtrace
#0  0x0000ffff6d5a9bc4 in torch_tensorrt::core::conversion::ConversionCtx::ConversionCtx(torch_tensorrt::core::conversion::BuilderSettings) ()
   from /usr/local/lib/python3.8/dist-packages/torch_tensorrt/lib/libtorchtrt.so
#1  0x0000ffff6d503658 in torch_tensorrt::core::conversion::ConvertBlockToEngine[abi:cxx11](torch::jit::Block const*, torch_tensorrt::core::conversion::ConversionInfo, std::map<torch::jit::Value*, c10::IValue, std::less<torch::jit::Value*>, std::allocator<std::pair<torch::jit::Value* const, c10::IValue> > >&) () from /usr/local/lib/python3.8/dist-packages/torch_tensorrt/lib/libtorchtrt.so
#2  0x0000ffff6d4c19ac in torch_tensorrt::core::CompileGraph(torch::jit::Module const&, torch_tensorrt::core::CompileSpec) () from /usr/local/lib/python3.8/dist-packages/torch_tensorrt/lib/libtorchtrt.so
#3  0x0000ffff6d6b70c0 in torch_tensorrt::pyapi::CompileGraph (mod=..., info=...) at torch_tensorrt/csrc/torch_tensorrt_py.cpp:155
#4  0x0000ffff6d6e5ae0 in pybind11::detail::argument_loader<torch::jit::Module const&, torch_tensorrt::pyapi::CompileSpec&>::call_impl<torch::jit::Module, torch::jit::Module (*&)(torch::jit::Module const&, torch_tensorrt::pyapi::CompileSpec&), 0ul, 1ul, pybind11::detail::void_type>(torch::jit::Module (*&)(torch::jit::Module const&, torch_tensorrt::pyapi::CompileSpec&), std::integer_sequence<unsigned long, 0ul, 1ul>, pybind11::detail::void_type&&) && (f=<optimized out>, this=0xffffffffdd78) at /usr/local/lib/python3.8/dist-packages/torch/include/pybind11/detail/../detail/type_caster_base.h:978
#5  pybind11::detail::argument_loader<torch::jit::Module const&, torch_tensorrt::pyapi::CompileSpec&>::call<torch::jit::Module, pybind11::detail::void_type, torch::jit::Module (*&)(torch::jit::Module const&, torch_tensorrt::pyapi::CompileSpec&)>(torch::jit::Module (*&)(torch::jit::Module const&, torch_tensorrt::pyapi::CompileSpec&)) && (f=<optimized out>, this=0xffffffffdd78)
    at /usr/local/lib/python3.8/dist-packages/torch/include/pybind11/detail/../cast.h:1408
#6  pybind11::cpp_function::initialize<torch::jit::Module (*&)(torch::jit::Module const&, torch_tensorrt::pyapi::CompileSpec&), torch::jit::Module, torch::jit::Module const&, torch_tensorrt::pyapi::CompileSpec&, pybind11::name, pybind11::scope, pybind11::sibling, char [128]>(torch::jit::Module (*&)(torch::jit::Module const&, torch_tensorrt::pyapi::CompileSpec&), torch::jit::Module (*)(torch::jit::Module const&, torch_tensorrt::pyapi::CompileSpec&), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, char const (&) [128])::{lambda(pybind11::detail::function_call&)#3}::operator()(pybind11::detail::function_call&) const (call=..., this=0x0) at /usr/local/lib/python3.8/dist-packages/torch/include/pybind11/pybind11.h:249
#7  pybind11::cpp_function::initialize<torch::jit::Module (*&)(torch::jit::Module const&, torch_tensorrt::pyapi::CompileSpec&), torch::jit::Module, torch::jit::Module const&, torch_tensorrt::pyapi::CompileSpec&, pybind11::name, pybind11::scope, pybind11::sibling, char [128]>(torch::jit::Module (*&)(torch::jit::Module const&, torch_tensorrt::pyapi::CompileSpec&), torch::jit::Module (*)(torch::jit::Module const&, torch_tensorrt::pyapi::CompileSpec&), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, char const (&) [128])::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) () at /usr/local/lib/python3.8/dist-packages/torch/include/pybind11/pybind11.h:224
#8  0x0000ffff6d6e6828 in pybind11::cpp_function::dispatcher (self=<optimized out>, args_in=0xffff8997ac40, kwargs_in=0x0) at /usr/local/lib/python3.8/dist-packages/torch/include/pybind11/pybind11.h:929
#9  0x0000000000593248 in PyCFunction_Call ()
#10 0x00000000005935e4 in _PyObject_MakeTpCall ()
#11 0x0000000000503a68 in _PyEval_EvalFrameDefault ()
#12 0x00000000004fcf64 in _PyEval_EvalCodeWithName ()
#13 0x0000000000596a44 in _PyFunction_Vectorcall ()
#14 0x0000000000592bb8 in PyObject_Call ()
#15 0x00000000005001c8 in _PyEval_EvalFrameDefault ()
#16 0x00000000004fcf64 in _PyEval_EvalCodeWithName ()
#17 0x0000000000596a44 in _PyFunction_Vectorcall ()
#18 0x0000000000592bb8 in PyObject_Call ()
#19 0x00000000005001c8 in _PyEval_EvalFrameDefault ()
#20 0x0000000000596850 in _PyFunction_Vectorcall ()
#21 0x00000000004fec7c in _PyEval_EvalFrameDefault ()
#22 0x00000000004fcf64 in _PyEval_EvalCodeWithName ()
#23 0x0000000000661740 in PyEval_EvalCode ()
#24 0x000000000064d7f0 in ?? ()
#25 0x000000000064d8bc in ?? ()
#26 0x000000000064d9b4 in ?? ()
#27 0x000000000064ddc8 in PyRun_SimpleFileExFlags ()
#28 0x000000000069d244 in Py_RunMain ()
#29 0x000000000069da6c in Py_BytesMain ()
#30 0x0000fffff7e56e10 in __libc_start_main (main=0x59acfc <_start+56>, argc=2, argv=0xfffffffff098, init=<optimized out>, fini=<optimized out>, rtld_fini=<optimized out>, stack_end=<optimized out>)
    at ../csu/libc-start.c:308
#31 0x000000000059acf8 in _start ()
Backtrace stopped: previous frame identical to this frame (corrupt stack?)

And this is the output when enabling debug logging via torchtrt.logging.set_reportable_log_level(torchtrt.logging.Level.Debug):

/usr/local/lib/python3.8/dist-packages/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py:840: UserWarning: Unable to import torchvision related libraries.: No module named 'torchvision'. Please install torchvision lib in order to lower stochastic_depth
  warnings.warn(
INFO: [Torch-TensorRT] - ir was set to default, using TorchScript as ir
DEBUG: [Torch-TensorRT] - TensorRT Compile Spec: {
    "Inputs": [
Input(shape=(1,3,5,5,), dtype=Unknown data type, format=Contiguous/Linear/NCHW)    ]
    "Enabled Precision": [Float, ]
    "TF32 Disabled": 0
    "Sparsity": 0
    "Refit": 0
    "Debug": 0
    "Device":  {
        "device_type": GPU
        "allow_gpu_fallback": False
        "gpu_id": 0
        "dla_core": -1
    }

    "Engine Capability": Default
    "Num Avg Timing Iters": 1
    "Workspace Size": 0
    "DLA SRAM Size": 1048576
    "DLA Local DRAM Size": 1073741824
    "DLA Global DRAM Size": 536870912
    "Truncate long and double": 0
    "Torch Fallback":  {
        "enabled": True
        "min_block_size": 3
        "forced_fallback_operators": [
        ]
        "forced_fallback_modules": [
        ]
    }
}
DEBUG: [Torch-TensorRT] - init_compile_spec with input vector
DEBUG: [Torch-TensorRT] - Settings requested for Lowering:
    torch_executed_modules: [
    ]
DEBUG: [Torch-TensorRT] - RemoveNOPs - Note: Removing operators that have no meaning in TRT
INFO: [Torch-TensorRT] - Lowered Graph: graph(%x.1 : Tensor):
  %self.conv.weight.1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.conv.bias.1 : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %5 : int = prim::Constant[value=1]() # /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py:459:45
  %4 : int[] = prim::Constant[value=[1, 1]]()
  %3 : int[] = prim::Constant[value=[0, 0]]()
  %54 : bool = prim::Constant[value=0]()
  %55 : int[] = prim::Constant[value=[0, 0]]()
  %56 : Tensor = aten::_convolution(%x.1, %self.conv.weight.1, %self.conv.bias.1, %4, %3, %4, %54, %55, %5, %54, %54, %54, %54)
  %40 : float = prim::Constant[value=0.044714999999999998]()
  %41 : float = prim::Constant[value=0.79788456080000003]()
  %42 : float = prim::Constant[value=1.]()
  %43 : float = prim::Constant[value=0.5]()
  %44 : int = prim::Constant[value=1]()
  %45 : Tensor = aten::mul(%56, %43)
  %46 : Tensor = aten::mul(%56, %41)
  %47 : Tensor = aten::mul(%56, %40)
  %48 : Tensor = aten::mul(%47, %56)
  %49 : Tensor = aten::add(%48, %42, %44)
  %50 : Tensor = aten::mul(%46, %49)
  %51 : Tensor = aten::tanh(%50)
  %52 : Tensor = aten::add(%51, %42, %44)
  %53 : Tensor = aten::mul(%45, %52)
  return (%53)

DEBUG: [Torch-TensorRT] - Found 1 inputs to graph
DEBUG: [Torch-TensorRT] - Handle input of debug name: x.1
DEBUG: [Torch-TensorRT] - Paring 0: x.1 : Input(shape: [1, 3, 5, 5], dtype: Float32, format: NCHW\Contiguous\Linear)
DEBUG: [Torch-TensorRT] - Found 1 inputs to graph
DEBUG: [Torch-TensorRT] - Handle input of debug name: x.1
DEBUG: [Torch-TensorRT] - In MapInputsAndDetermineDTypes, the g->inputs() size is 1, CollectionInputSpecMap size is 1
INFO: [Torch-TensorRT] - Since input type is not explicitly defined, infering using first tensor calculation
  Inferred input x.1 has type Float
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %self.conv.weight.1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %self.conv.bias.1 : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %5 : int = prim::Constant[value=1]() # /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py:459:45 (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %4 : int[] = prim::Constant[value=[1, 1]]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %3 : int[] = prim::Constant[value=[0, 0]]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %54 : bool = prim::Constant[value=0]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %55 : int[] = prim::Constant[value=[0, 0]]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %40 : float = prim::Constant[value=0.044714999999999998]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %41 : float = prim::Constant[value=0.79788456080000003]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %42 : float = prim::Constant[value=1.]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %43 : float = prim::Constant[value=0.5]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %44 : int = prim::Constant[value=1]() (NodeConverterRegistry.Convertable)
INFO: [Torch-TensorRT] - Skipping partitioning since model is fully supported
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %self.conv.weight.1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %self.conv.bias.1 : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %5 : int = prim::Constant[value=1]() # /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py:459:45 (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %4 : int[] = prim::Constant[value=[1, 1]]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %3 : int[] = prim::Constant[value=[0, 0]]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %54 : bool = prim::Constant[value=0]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %55 : int[] = prim::Constant[value=[0, 0]]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %40 : float = prim::Constant[value=0.044714999999999998]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %41 : float = prim::Constant[value=0.79788456080000003]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %42 : float = prim::Constant[value=1.]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %43 : float = prim::Constant[value=0.5]() (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %44 : int = prim::Constant[value=1]() (NodeConverterRegistry.Convertable)
WARNING: [Torch-TensorRT TorchScript Conversion Context] - Unable to determine GPU memory usage
WARNING: [Torch-TensorRT TorchScript Conversion Context] - Unable to determine GPU memory usage
INFO: [Torch-TensorRT TorchScript Conversion Context] - [MemUsageChange] Init CUDA: CPU +0, GPU +0, now: CPU 602, GPU 0 (MiB)
WARNING: [Torch-TensorRT TorchScript Conversion Context] - CUDA initialization failure with error: 222. Please check your CUDA installation:  http://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html
Segmentation fault (core dumped)
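For completeness, the debug level used for the log above can be enabled as follows (a sketch, guarded so it is harmless where torch_tensorrt is not installed):

```python
# Enable Torch-TensorRT debug logging globally; the import guard keeps this
# snippet harmless on machines without torch_tensorrt.
try:
    import torch_tensorrt as torchtrt
except ImportError:
    torchtrt = None

if torchtrt is not None:
    torchtrt.logging.set_reportable_log_level(torchtrt.logging.Level.Debug)
```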

@narendasan
Collaborator

@bowang007 can you take a look?

narendasan added the platform: aarch64 label on May 8, 2023
@bowang007
Collaborator

@narendasan sure.
Let me test this model locally and then on Orin later.

@janblumenkamp
Author

Hi @bowang007, just wanted to check if you had a chance already to verify this?

@github-actions

This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.
