Error while running live_demo.ipynb, __len__() should return >= 0 #178

Open
Po-Ting-lin opened this issue Jun 14, 2024 · 3 comments

@Po-Ting-lin

Hi all,

I tried to run this repo on a Jetson Nano, following the same steps as in the instructions.
However, converting the model to TensorRT failed.
It seems the dimensions are not correct?

>>> import json
>>> import trt_pose.coco
>>> 
>>> with open('human_pose.json', 'r') as f:
...     human_pose = json.load(f)
... 
>>> topology = trt_pose.coco.coco_category_to_topology(human_pose)
>>> import trt_pose.models
>>> 
>>> num_parts = len(human_pose['keypoints'])
>>> num_links = len(human_pose['skeleton'])
>>> 
>>> model = trt_pose.models.resnet18_baseline_att(num_parts, 2 * num_links).cuda().eval()
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/bt/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|████████████████████████████████████████████████████████████████████████████████| 44.7M/44.7M [00:17<00:00, 2.72MB/s]
>>> import torch
>>> 
>>> MODEL_WEIGHTS = 'resnet18_baseline_att_224x224_A_epoch_249.pth'
>>> 
>>> model.load_state_dict(torch.load(MODEL_WEIGHTS))
<All keys matched successfully>
>>> WIDTH = 224
>>> HEIGHT = 224
>>> 
>>> data = torch.zeros((1, 3, HEIGHT, WIDTH)).cuda()
>>> import torch2trt
>>> 
>>> model_trt = torch2trt.torch2trt(model, [data], fp16_mode=True, max_workspace_size=1<<25)
[06/14/2024-10:04:27] [TRT] [E] 3: 1.cmap_up.0:0:DECONVOLUTION:GPU:kernel weights has count 2097152 but 4194304 was expected
[06/14/2024-10:04:27] [TRT] [E] 4: 1.cmap_up.0:0:DECONVOLUTION:GPU: count of 2097152 weights in kernel, but kernel dimensions (4,4) with 512 input channels, 512 output channels and 1 groups were specified. Expected Weights count is 512 * 4*4 * 512 / 1 = 4194304
[06/14/2024-10:04:28] [TRT] [E] 3: 1.cmap_up.0:0:DECONVOLUTION:GPU:kernel weights has count 2097152 but 4194304 was expected
[06/14/2024-10:04:28] [TRT] [E] 4: 1.cmap_up.0:0:DECONVOLUTION:GPU: count of 2097152 weights in kernel, but kernel dimensions (4,4) with 512 input channels, 512 output channels and 1 groups were specified. Expected Weights count is 512 * 4*4 * 512 / 1 = 4194304
[06/14/2024-10:04:28] [TRT] [E] 3: 1.cmap_up.0:0:DECONVOLUTION:GPU:kernel weights has count 2097152 but 4194304 was expected
[06/14/2024-10:04:28] [TRT] [E] 4: 1.cmap_up.0:0:DECONVOLUTION:GPU: count of 2097152 weights in kernel, but kernel dimensions (4,4) with 512 input channels, 512 output channels and 1 groups were specified. Expected Weights count is 512 * 4*4 * 512 / 1 = 4194304
[06/14/2024-10:04:28] [TRT] [E] 3: 1.cmap_up.0:0:DECONVOLUTION:GPU:kernel weights has count 2097152 but 4194304 was expected
[06/14/2024-10:04:28] [TRT] [E] 4: 1.cmap_up.0:0:DECONVOLUTION:GPU: count of 2097152 weights in kernel, but kernel dimensions (4,4) with 512 input channels, 512 output channels and 1 groups were specified. Expected Weights count is 512 * 4*4 * 512 / 1 = 4194304
[06/14/2024-10:04:28] [TRT] [E] 3: 1.cmap_up.0:0:DECONVOLUTION:GPU:kernel weights has count 2097152 but 4194304 was expected
[06/14/2024-10:04:28] [TRT] [E] 4: 1.cmap_up.0:0:DECONVOLUTION:GPU: count of 2097152 weights in kernel, but kernel dimensions (4,4) with 512 input channels, 512 output channels and 1 groups were specified. Expected Weights count is 512 * 4*4 * 512 / 1 = 4194304
[06/14/2024-10:04:29] [TRT] [E] 3: 1.cmap_up.0:0:DECONVOLUTION:GPU:kernel weights has count 2097152 but 4194304 was expected
[06/14/2024-10:04:29] [TRT] [E] 4: 1.cmap_up.0:0:DECONVOLUTION:GPU: count of 2097152 weights in kernel, but kernel dimensions (4,4) with 512 input channels, 512 output channels and 1 groups were specified. Expected Weights count is 512 * 4*4 * 512 / 1 = 4194304
[06/14/2024-10:04:29] [TRT] [E] 3: 1.cmap_up.0:0:DECONVOLUTION:GPU:kernel weights has count 2097152 but 4194304 was expected
[06/14/2024-10:04:29] [TRT] [E] 4: 1.cmap_up.0:0:DECONVOLUTION:GPU: count of 2097152 weights in kernel, but kernel dimensions (4,4) with 512 input channels, 512 output channels and 1 groups were specified. Expected Weights count is 512 * 4*4 * 512 / 1 = 4194304
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.6/dist-packages/torch2trt-0.5.0-py3.6-linux-aarch64.egg/torch2trt/torch2trt.py", line 643, in torch2trt
    outputs = module(*inputs)
  File "/home/bt/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1120, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/bt/.local/lib/python3.6/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/bt/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1120, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/trt_pose-0.0.1-py3.6-linux-aarch64.egg/trt_pose/models/common.py", line 76, in forward
    return self.cmap_conv(xc * ac), self.paf_conv(xp * ap)
  File "/usr/local/lib/python3.6/dist-packages/torch2trt-0.5.0-py3.6-linux-aarch64.egg/torch2trt/torch2trt.py", line 262, in wrapper
    converter["converter"](ctx)
  File "/usr/local/lib/python3.6/dist-packages/torch2trt-0.5.0-py3.6-linux-aarch64.egg/torch2trt/converters/native_converters.py", line 1496, in convert_mul
    input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], len(output.shape))
  File "/usr/local/lib/python3.6/dist-packages/torch2trt-0.5.0-py3.6-linux-aarch64.egg/torch2trt/torch2trt.py", line 146, in broadcast_trt_tensors
    if len(t.shape) < broadcast_ndim:
ValueError: __len__() should return >= 0
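The final ValueError comes from Python itself: the built-in len() raises it whenever an object's __len__ returns a negative number. A likely reading (an assumption, not confirmed in this thread) is that the deconvolution layer failed to build because of the weight-count errors above, so the TensorRT tensor handed to broadcast_trt_tensors reported an invalid shape with a rank of -1. The sketch below reproduces only the Python side of that, using a hypothetical InvalidShape class as a stand-in for a real TensorRT shape object:

```python
class InvalidShape:
    """Hypothetical stand-in for a shape object; ndim=-1 models an unknown rank."""

    def __init__(self, ndim):
        self.ndim = ndim

    def __len__(self):
        # The built-in len() rejects negative values returned here with a ValueError.
        return self.ndim


def safe_rank(shape):
    """Return len(shape), or None when the object reports a negative length."""
    try:
        return len(shape)
    except ValueError as exc:
        print("len() failed:", exc)
        return None


print(safe_rank(InvalidShape(4)))   # a valid 4-D shape
print(safe_rank(InvalidShape(-1)))  # prints the same message as the traceback, then None
```

In other words, the traceback is probably a downstream symptom; the deconvolution errors earlier in the log are the first thing to investigate.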
  • The weights are from the instructions: resnet18_baseline_att_224x224_A_epoch_249.pth
  • Python==3.6
  • Torch==1.10.0
  • Torchvision==0.11.1
  • torch2trt==0.5.0
  • sudo apt-cache show nvidia-jetpack
Package: nvidia-jetpack
Version: 4.6.4-b39
Architecture: arm64
Maintainer: NVIDIA Corporation
Installed-Size: 194
Depends: nvidia-l4t-jetson-multimedia-api (>> 32.7-0), nvidia-l4t-jetson-multimedia-api (<< 32.8-0), nvidia-cuda (= 4.6.4-b39), nvidia-tensorrt (= 4.6.4-b39), nvidia-nsight-sys (= 4.6.4-b39), nvidia-cudnn8 (= 4.6.4-b39), nvidia-opencv (= 4.6.4-b39), nvidia-container (= 4.6.4-b39), nvidia-visionworks (= 4.6.4-b39), nvidia-vpi (= 4.6.4-b39)
Homepage: http://developer.nvidia.com/jetson
Priority: standard
Section: metapackages
Filename: pool/main/n/nvidia-jetpack/nvidia-jetpack_4.6.4-b39_arm64.deb
Size: 29388
SHA256: adf7a6660f73cdc4f95bc15c48d8588688e3afa5ee18bfd5b3a3caa3a458aa02
SHA1: 5abbe0df74f71579c1a0ee30ab7c2c236e1bcdbb
MD5sum: ec293a56d17f2b2793448d621811330d
Description: NVIDIA Jetpack Meta Package
Description-md5: ad1462289bdbc54909ae109d1d32c0a8

Does anyone know why? Thanks!
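The numbers in the deconvolution errors can be checked by hand. The log prints the formula TensorRT uses (input channels × kH × kW × output channels / groups), and the count it actually received is exactly half of that. Since load_state_dict reported that all keys matched, the PyTorch layer itself should hold the full-size weights, which suggests the mismatch appears during conversion. The sketch below only redoes the arithmetic; the helper name is hypothetical, and the factor of 1/2 does not by itself identify the cause:

```python
def expected_deconv_weights(in_ch, out_ch, kh, kw, groups=1):
    """Kernel weight count for a deconvolution, using the formula printed in the TRT log."""
    return in_ch * kh * kw * out_ch // groups


# PyTorch stores a ConvTranspose2d weight with shape
# (in_channels, out_channels // groups, kH, kW), which has the same element count.
torch_weight_shape = (512, 512 // 1, 4, 4)
torch_numel = 1
for size in torch_weight_shape:
    torch_numel *= size

expected = expected_deconv_weights(512, 512, 4, 4)
received = 2097152  # count reported in the "[TRT] [E] 3" lines of the log

print(expected, torch_numel)  # 4194304 4194304
print(received / expected)    # 0.5
```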

@Po-Ting-lin Po-Ting-lin changed the title Error while running live_demo.ipynb Error while running live_demo.ipynb, __len__() should return >= 0 Jun 14, 2024
@CMeiyi

CMeiyi commented Jun 20, 2024

Have you found a solution?
I had the same problem.

@jiri-kula

@Po-Ting-lin I read your log thoroughly, and it seems we have exactly the same setup and the same type of error. The only difference is that I got stuck on the trt_pose_hand project, which is a derivative of trt_pose.

What fixed the issue for me was going through the initial dependency setup again, paying closer attention to downloading the correct package versions and not mixing them up.

This is the procedure that I have followed, after which I was able to run the notebook.

Setup

wget https://nvidia.box.com/shared/static/fjtbno0vpo676a25cgvuqc1wty0fkkg6.whl -O torch-1.10.0-cp36-cp36m-linux_aarch64.whl
sudo apt-get install    \
    python3-pip         \
    libopenblas-base    \
    libopenmpi-dev      \
    libomp-dev
pip3 install 'Cython<3'
pip3 install numpy torch-1.10.0-cp36-cp36m-linux_aarch64.whl
sudo apt-get install    \
    libjpeg-dev         \
    zlib1g-dev          \
    libpython3-dev      \
    libopenblas-dev     \
    libavcodec-dev      \
    libavformat-dev     \
    libswscale-dev
git clone --branch v0.11.1 https://github.com/pytorch/vision torchvision
cd torchvision
export BUILD_VERSION=0.11.1  # the torchvision version, without the leading "v"
pip3 install 'pillow<7'
python3 setup.py install --user

That fixed it for me. At the very least, it shows that the code works, and it suggests that the problem lies in the environment (environment variables or installed package versions) rather than in the script itself.
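Since the environment seems to be the culprit, it can help to print the versions Python actually imports, because pip, pip3 and --user installs can leave several copies side by side. This is a minimal sketch using only the standard library; the module list is an assumption based on the packages mentioned in this thread, and not every package is guaranteed to define __version__, so a missing attribute is reported as "unknown":

```python
import importlib


def installed_versions(names=("torch", "torchvision", "tensorrt", "torch2trt")):
    """Return {module name: version}, with None for modules that fail to import."""
    versions = {}
    for name in names:
        try:
            module = importlib.import_module(name)
        except ImportError:
            versions[name] = None
            continue
        # Not every package defines __version__, so fall back to "unknown".
        versions[name] = getattr(module, "__version__", "unknown")
    return versions


for name, version in installed_versions().items():
    print("{:<12} {}".format(name, version if version is not None else "not installed"))
```

The output can then be compared against the versions listed in the original report (Torch 1.10.0, Torchvision 0.11.1, torch2trt 0.5.0).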

@Chang558

Chang558 commented Sep 23, 2024

Thanks, @jiri-kula! I encountered the same issue, but I managed to resolve it after seeing your explanation.

In my case, the issue occurred because torch2trt does not support PyTorch 2.x, and I'm using PyTorch 2.x.

Some suggest downgrading to PyTorch 1.x, but I can't do that because another project I'm working on requires PyTorch 2.x. So, instead of using torch2trt, I exported the model to ONNX and built a TensorRT engine from it directly.
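The choice described here can be written as a small guard. This is a hypothetical helper, not part of torch2trt or trt_pose; it simply encodes the claim in this comment that torch2trt does not work with PyTorch 2.x. In practice you would pass it torch.__version__:

```python
def torch_major_version(version_string):
    """Major version from strings such as '1.10.0' or '2.1.0+cu121'."""
    return int(version_string.split(".")[0])


def conversion_route(version_string):
    """Hypothetical helper: choose a conversion path from the PyTorch version.

    Assumes torch2trt only works with PyTorch 1.x, so 2.x goes through
    ONNX export and the TensorRT builder instead.
    """
    if torch_major_version(version_string) >= 2:
        return "onnx+tensorrt"
    return "torch2trt"


print(conversion_route("1.10.0"))       # torch2trt
print(conversion_route("2.1.0+cu121"))  # onnx+tensorrt
```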

I hope this solution can help others who are facing similar issues.

This is the code I used:

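import torch
import tensorrt as trt

# Assumes `model` is the trt_pose model built and loaded as in the notebook, e.g.
# trt_pose.models.resnet18_baseline_att(num_parts, 2 * num_links).cuda().eval()
print('-----------------------Model Load Part------------------')

WIDTH = 224
HEIGHT = 224

data = torch.zeros((1, 3, HEIGHT, WIDTH)).cuda()

# The model returns two tensors (cmap, paf), so both outputs are named
# and both get a dynamic batch axis.
torch.onnx.export(
    model,
    data,
    "trt-pose.onnx",
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['cmap', 'paf'],
    dynamic_axes={'input': {0: 'batch_size'},
                  'cmap': {0: 'batch_size'},
                  'paf': {0: 'batch_size'}}
)

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

parser = trt.OnnxParser(network, TRT_LOGGER)

with open("trt-pose.onnx", 'rb') as mf:
    if not parser.parse(mf.read()):
        print("ERROR: Failed to parse the ONNX file.")
        for error in range(parser.num_errors):
            print(parser.get_error(error))
        raise SystemExit(1)

config = builder.create_builder_config()
# Upper bound on the builder's scratch memory: 1 << 35 bytes = 32 GiB.
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 35)

if builder.platform_has_fast_fp16:
    config.set_flag(trt.BuilderFlag.FP16)

# Batch sizes from 1 to 32 are accepted; kernels are tuned for a batch of 8.
profile = builder.create_optimization_profile()
input_tensor_name = 'input'
profile.set_shape(input_tensor_name, min=(1, 3, 224, 224), opt=(8, 3, 224, 224), max=(32, 3, 224, 224))
config.add_optimization_profile(profile)
# build_engine is deprecated in TensorRT 8 and removed in TensorRT 10;
# newer versions use builder.build_serialized_network(network, config) instead.
engine = builder.build_engine(network, config)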
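Before building, it can help to sanity-check the optimization profile. As far as I know, TensorRT requires min <= opt <= max in every dimension, and any dimension that is static in the network must have min, opt and max equal to that static value; in the export above only the batch axis is dynamic. The check_profile helper below is hypothetical and only mirrors those rules in plain Python; it does not call TensorRT:

```python
def check_profile(static_shape, min_shape, opt_shape, max_shape):
    """Hypothetical check mirroring TensorRT's optimization-profile rules.

    static_shape uses -1 for dynamic dimensions (here only the batch axis).
    """
    shapes = (min_shape, opt_shape, max_shape)
    if any(len(s) != len(static_shape) for s in shapes):
        return False
    for dim, lo, opt, hi in zip(static_shape, min_shape, opt_shape, max_shape):
        # Every dimension must satisfy min <= opt <= max.
        if not (lo <= opt <= hi):
            return False
        # Static dimensions must be pinned to the network's own value.
        if dim != -1 and not (lo == opt == hi == dim):
            return False
    return True


network_input = (-1, 3, 224, 224)
print(check_profile(network_input, (1, 3, 224, 224), (8, 3, 224, 224), (32, 3, 224, 224)))  # True
print(check_profile(network_input, (1, 3, 224, 224), (8, 3, 256, 256), (32, 3, 224, 224)))  # False
```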
