
AutoShape Usage #7128

Closed
sezer-muhammed opened this issue Mar 24, 2022 · 13 comments · Fixed by #7560
Labels: bug (Something isn't working), question (Further information is requested)

@sezer-muhammed commented Mar 24, 2022

Question

Hi,

I want to use a TensorRT model in my code.
There are some repos for this, but they are complicated; YOLOv5's code, on the other hand, is clear and easy to follow.
So I want to use the TRT model just like detect.py does.

There is the DetectMultiBackend class and the AutoShape class.

The torch hub model uses AutoShape and returns a Detections object, which is very useful.

As far as I can see in the code, I should be able to wrap a TRT model in the AutoShape class and get a Detections object back.

But I cannot figure out how to do it.

How can this be done?


sezer-muhammed added the question (Further information is requested) label Mar 24, 2022
@glenn-jocher (Member) commented Mar 24, 2022

@sezer-muhammed usage examples are shown right after export. Just follow the TRT export code and then do whatever you want with the exported model:

# TensorRT export example
!pip install -U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com  # install
!python export.py --weights yolov5s.pt --include engine --imgsz 640 640 --device 0  # export
!python detect.py --weights yolov5s.engine --imgsz 640 640 --device 0  # inference

ONNX output example:

[Screenshot: export console output with usage examples]
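
For readers who want to consume the exported engine directly instead of through detect.py, here is a minimal, unofficial sketch of running it with DetectMultiBackend and the non_max_suppression helper (a simplified version of what detect.py does; real frames would also need letterbox resizing). It assumes it is run from inside the yolov5/ repository so that the models and utils imports resolve, and that yolov5s.engine was produced by the export command above; the 640x640 input size and the 0.45 thresholds are illustrative choices, not requirements.

import numpy as np
import torch

from models.common import DetectMultiBackend
from utils.general import non_max_suppression

device = torch.device('cuda:0')
model = DetectMultiBackend('yolov5s.engine', device=device)  # load the exported TensorRT engine
model.warmup(imgsz=(1, 3, 640, 640))                         # one dummy pass to initialize the context

im = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)              # dummy HWC uint8 frame
im = torch.from_numpy(im).permute(2, 0, 1)[None].float().to(device) / 255  # BCHW float tensor in [0, 1]

pred = model(im)                                                   # raw predictions
pred = non_max_suppression(pred, conf_thres=0.45, iou_thres=0.45)  # one Nx6 tensor (xyxy, conf, cls) per image
print(pred[0])

Note that this returns raw detection tensors rather than the Detections object that AutoShape produces, which is what the question below is about.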

@sezer-muhammed (Author) commented Mar 24, 2022

Hi, I tried this code:

self.model = DetectMultiBackend(model_path, device=torch.device(0), dnn=False, fp16=False, data="yolov5/data/head.yaml")
self.model.warmup()
self.model = AutoShape(self.model)
self.model.conf = 0.45
frame = np.random.randint(0, 255, (640, 640, 3))
self.model(frame).print()

It works as expected with .pt models.
When I try to use a .engine model, it gives this error:

sezer@sezer:~/Desktop/metu-bitirme-face$ python3 main_code.py --verbose --model ../crowdhuman_yolov5m.engine 
/home/sezer/Desktop/metu-bitirme-face/deep_sort/deep/reid/torchreid/metrics/rank.py:11: UserWarning: Cython evaluation (very fast so highly recommended) is unavailable, now use python evaluation.
  warnings.warn(

Available ReID models for automatic download
['resnet50_market1501', 'resnet50_dukemtmcreid', 'resnet50_msmt17', 'resnet50_fc512_market1501', 'resnet50_fc512_dukemtmcreid', 'resnet50_fc512_msmt17', 'mlfn_market1501', 'mlfn_dukemtmcreid', 'mlfn_msmt17', 'hacnn_market1501', 'hacnn_dukemtmcreid', 'hacnn_msmt17', 'mobilenetv2_x1_0_market1501', 'mobilenetv2_x1_0_dukemtmcreid', 'mobilenetv2_x1_0_msmt17', 'mobilenetv2_x1_4_market1501', 'mobilenetv2_x1_4_dukemtmcreid', 'mobilenetv2_x1_4_msmt17', 'osnet_x1_0_market1501', 'osnet_x1_0_dukemtmcreid', 'osnet_x1_0_msmt17', 'osnet_x0_75_market1501', 'osnet_x0_75_dukemtmcreid', 'osnet_x0_75_msmt17', 'osnet_x0_5_market1501', 'osnet_x0_5_dukemtmcreid', 'osnet_x0_5_msmt17', 'osnet_x0_25_market1501', 'osnet_x0_25_dukemtmcreid', 'osnet_x0_25_msmt17', 'resnet50_MSMT17', 'osnet_x1_0_MSMT17', 'osnet_x0_75_MSMT17', 'osnet_x0_5_MSMT17', 'osnet_x0_25_MSMT17', 'osnet_ibn_x1_0_MSMT17', 'osnet_ain_x1_0_MSMT17']
Loading ../crowdhuman_yolov5m.engine for TensorRT inference...
[03/24/2022-16:16:17] [TRT] [I] [MemUsageChange] Init CUDA: CPU +444, GPU +0, now: CPU 656, GPU 739 (MiB)
[03/24/2022-16:16:17] [TRT] [I] Loaded engine size: 45 MiB
[03/24/2022-16:16:17] [TRT] [I] [MemUsageSnapshot] deserializeCudaEngine begin: CPU 702 MiB, GPU 739 MiB
[03/24/2022-16:16:17] [TRT] [W] TensorRT was linked against cuBLAS/cuBLAS LT 11.6.1 but loaded cuBLAS/cuBLAS LT 11.5.1
[03/24/2022-16:16:17] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +748, GPU +320, now: CPU 1461, GPU 1101 (MiB)
[03/24/2022-16:16:18] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +613, GPU +266, now: CPU 2074, GPU 1367 (MiB)
[03/24/2022-16:16:18] [TRT] [W] TensorRT was linked against cuDNN 8.2.1 but loaded cuDNN 8.2.0
[03/24/2022-16:16:18] [TRT] [I] [MemUsageSnapshot] deserializeCudaEngine end: CPU 2074 MiB, GPU 1349 MiB
[03/24/2022-16:16:20] [TRT] [I] [MemUsageSnapshot] ExecutionContext creation begin: CPU 5375 MiB, GPU 3253 MiB
[03/24/2022-16:16:20] [TRT] [W] TensorRT was linked against cuBLAS/cuBLAS LT 11.6.1 but loaded cuBLAS/cuBLAS LT 11.5.1
[03/24/2022-16:16:20] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +1, GPU +10, now: CPU 5376, GPU 3263 (MiB)
[03/24/2022-16:16:20] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 5376, GPU 3271 (MiB)
[03/24/2022-16:16:20] [TRT] [W] TensorRT was linked against cuDNN 8.2.1 but loaded cuDNN 8.2.0
[03/24/2022-16:16:20] [TRT] [I] [MemUsageSnapshot] ExecutionContext creation end: CPU 5376 MiB, GPU 3335 MiB
Adding AutoShape... 
[03/24/2022-16:16:21] [TRT] [E] 1: [executionContext.cpp::executeInternal::654] Error Code 1: Cuda Runtime (an illegal memory access was encountered)
Traceback (most recent call last):
  File "main_code.py", line 24, in <module>
    manager = ids_info(args.model, args.tracker, args.faces, args.verbose)
  File "/home/sezer/Desktop/metu-bitirme-face/paketler.py", line 68, in __init__
    self.model(frame).print()
  File "/home/sezer/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/sezer/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/sezer/Desktop/metu-bitirme-face/yolov5/models/common.py", line 561, in forward
    t.append(time_sync())
  File "/home/sezer/Desktop/metu-bitirme-face/./yolov5/utils/torch_utils.py", line 86, in time_sync
    torch.cuda.synchronize()
  File "/home/sezer/.local/lib/python3.8/site-packages/torch/cuda/__init__.py", line 493, in synchronize
    return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
[03/24/2022-16:16:21] [TRT] [E] 1: [defaultAllocator.cpp::deallocate::35] Error Code 1: Cuda Runtime (an illegal memory access was encountered)
[03/24/2022-16:16:21] [TRT] [E] 1: [defaultAllocator.cpp::deallocate::35] Error Code 1: Cuda Runtime (an illegal memory access was encountered)
[03/24/2022-16:16:21] [TRT] [E] 1: [cudaResources.cpp::~ScopedCudaStream::47] Error Code 1: Cuda Runtime (an illegal memory access was encountered)
[03/24/2022-16:16:21] [TRT] [E] 1: [cudaResources.cpp::~ScopedCudaEvent::24] Error Code 1: Cuda Runtime (an illegal memory access was encountered)
[03/24/2022-16:16:21] [TRT] [E] 1: [cudaResources.cpp::~ScopedCudaEvent::24] Error Code 1: Cuda Runtime (an illegal memory access was encountered)

Loading the model with DetectMultiBackend works, but AutoShape gives this error when I try to run it.

Is there a solution for this?

Can a .engine file be used with AutoShape?

@glenn-jocher (Member)

@sezer-muhammed the AutoShape and DetectMultiBackend classes are used internally and are not externally exposed, so their usage is undocumented.

@YoungjaeDev

@glenn-jocher

Why exactly does running the engine through AutoShape's forward cause the error below?

File "./yolov5/models/common.py", line 569, in forward
  t.append(time_sync())
File "./yolov5/utils/torch_utils.py", line 88, in time_sync
  torch.cuda.synchronize()
File "/home/nvidia/miniforge3/envs/nbas/lib/python3.6/site-packages/torch/cuda/__init__.py", line 402, in synchronize
  return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered

@glenn-jocher (Member) commented Apr 23, 2022

@youngjae-avikus yes this is a known YOLOv5 TRT issue with AutoShape. The following code works correctly:

python export.py --weights yolov5s.pt --include engine
python detect.py --weights yolov5s.engine

But if we use the same model for AutoShape inference we get the above CUDA error you mentioned. I have no idea why, I've looked into it several times and can't find the cause. If you have any ideas or discover a solution please let us know!

EDIT: It's extra strange because both methods are essentially the same under the hood. They both use the DetectMultiBackend class here for inference:

yolov5/models/common.py

Lines 279 to 293 in 404b4fe

class DetectMultiBackend(nn.Module):
    # YOLOv5 MultiBackend class for python inference on various backends
    def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False):
        # Usage:
        #   PyTorch:               weights = *.pt
        #   TorchScript:                     *.torchscript
        #   ONNX Runtime:                    *.onnx
        #   ONNX OpenCV DNN:                 *.onnx with --dnn
        #   OpenVINO:                        *.xml
        #   CoreML:                          *.mlmodel
        #   TensorRT:                        *.engine
        #   TensorFlow SavedModel:           *_saved_model
        #   TensorFlow GraphDef:             *.pb
        #   TensorFlow Lite:                 *.tflite
        #   TensorFlow Edge TPU:             *_edgetpu.tflite

@YoungjaeDev commented Apr 24, 2022

@glenn-jocher

Ah, I found the cause, and after modifying the code it temporarily works with the engine file.
If you look at the AutoShape forward function below, the device of the weights is only read from the model parameters in the .pt case.
When an engine file is loaded, p.device.type is 'cpu', so the input tensor is placed on the CPU.

yolov5/models/common.py

Lines 533 to 538 in 7043872

t = [time_sync()]
p = next(self.model.parameters()) if self.pt else torch.zeros(1)  # for device and type
autocast = self.amp and (p.device.type != 'cpu')  # Automatic Mixed Precision (AMP) inference
if isinstance(imgs, torch.Tensor):  # torch
    with amp.autocast(autocast):
        return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference

So I temporarily changed it to the code below, and that works:

t = [time_sync()]
p = next(self.model.parameters()) if self.pt else torch.zeros(1)  # for device and type
autocast = self.amp and (p.device.type != 'cpu')  # Automatic Mixed Precision (AMP) inference
if isinstance(imgs, torch.Tensor):  # torch
    with amp.autocast(autocast):
        return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference

device = select_device('cuda:0') if torch.cuda.is_available() else 'cpu'  # added: pick the actual inference device
p = p.to(device)  # added: move the reference tensor so the later .to(p.device) calls target the GPU

or

p = next(self.model.parameters()) if self.pt else torch.zeros(1)  # for device and type
       
if self.dmb:
    device = self.model.device
    p = p.to(device)

autocast = self.amp and (p.device.type != 'cpu')  # Automatic Mixed Precision (AMP) inference
if isinstance(imgs, torch.Tensor):  # torch
    with amp.autocast(autocast):
        return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference

It's not my call whether this part should be changed; please check.

@glenn-jocher (Member)

@youngjae-avikus ohhh yes I see. We have a device issue. Thanks for the fixes, I'll test them!

glenn-jocher added the bug (Something isn't working) and TODO labels Apr 24, 2022
glenn-jocher added a commit that referenced this issue Apr 24, 2022
Solution proposed in #7128 to TRT PyTorch Hub CUDA illegal memory errors.
glenn-jocher added a commit that referenced this issue Apr 24, 2022
Solution proposed in #7128 to TRT PyTorch Hub CUDA illegal memory errors.
glenn-jocher linked pull request #7560 Apr 24, 2022 that will close this issue
@glenn-jocher (Member) commented Apr 24, 2022

@sezer-muhammed @youngjae-avikus good news 😃! Your original issue may now be fixed ✅ in PR #7560 implementing a solution by @youngjae-avikus.


To receive this update:

  • Git – run git pull from within your yolov5/ directory, or git clone https://github.com/ultralytics/yolov5 again
  • PyTorch Hub – force-reload with model = torch.hub.load('ultralytics/yolov5', 'yolov5s', force_reload=True)
  • Notebooks – view the updated notebooks (Colab, Kaggle)
  • Docker – run sudo docker pull ultralytics/yolov5:latest to update your image

Thank you for spotting this issue and informing us of the problem. Please let us know if this update resolves the issue for you, and feel free to inform us of any other issues you discover or feature requests that come to mind. Happy trainings with YOLOv5 🚀!
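
As a concrete follow-up, here is a hedged sketch of the TRT-with-AutoShape path this PR is meant to enable, using the standard PyTorch Hub custom-weights entry point. It assumes yolov5s.engine was exported locally with export.py as shown earlier in the thread; the example image URL and the 0.45 confidence threshold are only illustrative.

import torch

# load a locally exported TensorRT engine through PyTorch Hub; AutoShape is applied automatically
model = torch.hub.load('ultralytics/yolov5', 'custom', path='yolov5s.engine', force_reload=True)
model.conf = 0.45  # confidence threshold

results = model('https://ultralytics.com/images/zidane.jpg')  # returns a Detections object
results.print()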

@YoungjaeDev

@glenn-jocher

I don't know whether this is something that needs to be solved, but if the engine is exported at 384x640, an error occurs in the code below because the model is not a .pt file:

yolov5/models/common.py

Lines 559 to 562 in 950a85d

shape1 = [make_divisible(x, self.stride) if self.pt else size for x in np.array(shape1).max(0)] # inf shape
x = [letterbox(im, shape1, auto=False)[0] for im in imgs] # pad
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
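
To make the failure mode concrete, here is a small sketch (my own re-implementation, only approximating the AutoShape.forward logic quoted above) of how the inference shape is chosen. The make_divisible helper mirrors the one in yolov5/utils/general.py, and the 720x1280 frame and stride 32 are assumed values for illustration.

import math

def make_divisible(x, divisor):
    # round x up to the nearest multiple of divisor (re-implementation of the yolov5 helper)
    return math.ceil(x / divisor) * divisor

stride, size = 32, 640
shape0 = (720, 1280)              # original HxW of the incoming frame (assumed)
g = size / max(shape0)            # gain, as in AutoShape.forward
shape1 = [y * g for y in shape0]  # scaled shape -> [360.0, 640.0]

pt_shape = [make_divisible(x, stride) for x in shape1]  # .pt path     -> [384, 640]
other    = [size for _ in shape1]                       # non-.pt path -> [640, 640]

print(pt_shape, other)  # a 384x640 engine is handed a 640x640 tensor, hence the error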

@glenn-jocher (Member)

@sezer-muhammed in general the letterbox() function pads images according to minimum stride constraints, but the result varies by image, so you should be very careful when exporting fixed-size rectangular models. The easiest way to get started is to export a square model, and then run some experiments to see what is needed to force rectangular inference at the size the model expects.

@YoungjaeDev

@glenn-jocher

If the input is an RTSP video stream with a fixed image size, it's better to force a fixed inference size so it runs faster, right?

@glenn-jocher (Member)

@youngjae-avikus you could check what shape detect.py turns the stream into, and then see whether an engine exported at that shape works with detect.py.
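
One hedged way to follow that suggestion, reusing the command style from earlier in this thread (the RTSP URL is a placeholder and 384x640 is only an example of a shape detect.py might report for a 16:9 stream):

python detect.py --weights yolov5s.pt --source 'rtsp://<your-stream-url>'  # note the HxW (e.g. 384x640) printed per frame
python export.py --weights yolov5s.pt --include engine --imgsz 384 640 --device 0  # export at that fixed shape
python detect.py --weights yolov5s.engine --imgsz 384 640 --device 0 --source 'rtsp://<your-stream-url>'  # verify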

BjarneKuehl pushed a commit to fhkiel-mlaip/yolov5 that referenced this issue Aug 26, 2022
Solution proposed in ultralytics#7128 to TRT PyTorch Hub CUDA illegal memory errors.
ctjanuhowski pushed a commit to ctjanuhowski/yolov5 that referenced this issue Sep 8, 2022
Solution proposed in ultralytics#7128 to TRT PyTorch Hub CUDA illegal memory errors.