Skip to content

Commit 9e1474e

Browse files
JanuszLcyyever
authored andcommitted
Turn off OpticalFlow test on aarch64 platform for driver r495.x and newer (NVIDIA#3566)
- there is an OpticalFlow driver crash on the aarch64 platform for driver r495.x and newer. This PR blacklist this OpticalFlow tests for aarch64 for this set of drivers Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
1 parent 75a00d3 commit 9e1474e

File tree

2 files changed

+28
-26
lines changed

2 files changed

+28
-26
lines changed

dali/test/python/test_dali_variable_batch_size.py

+1-21
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import random
3232
import nose
3333
from nose.plugins.attrib import attr
34+
from test_optical_flow import is_of_supported
3435

3536
"""
3637
How to test variable (iter-to-iter) batch size for a given op?
@@ -61,27 +62,6 @@
6162
"""
6263

6364

64-
is_of_supported_var = None
65-
def is_of_supported(device_id=0):
66-
global is_of_supported_var
67-
if is_of_supported_var is not None:
68-
return is_of_supported_var
69-
70-
compute_cap = 0
71-
try:
72-
import pynvml
73-
pynvml.nvmlInit()
74-
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
75-
compute_cap = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
76-
compute_cap = compute_cap[0] + compute_cap[1] / 10.
77-
except ModuleNotFoundError:
78-
print("NVML not found")
79-
pass
80-
81-
is_of_supported_var = compute_cap >= 7.5
82-
return is_of_supported_var
83-
84-
8565
def generate_data(max_batch_size, n_iter, sample_shape, lo=0., hi=1., dtype=np.float32):
8666
"""
8767
Generates an epoch of data, that will be used for variable batch size verification.

dali/test/python/test_optical_flow.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,38 @@
1414

1515
import os
1616
import numpy as np
17-
import shutil
18-
import sys
17+
import platform
1918
import cv2
2019
from nvidia.dali.pipeline import Pipeline
2120
import nvidia.dali.ops as ops
22-
import nvidia.dali.types as types
23-
from random import shuffle
2421
from test_utils import get_dali_extra_path
2522

2623
test_data_root = get_dali_extra_path()
2724
images_dir = os.path.join(test_data_root, 'db', 'imgproc')
2825

26+
is_of_supported_var = None
27+
def is_of_supported(device_id=0):
28+
global is_of_supported_var
29+
if is_of_supported_var is not None:
30+
return is_of_supported_var
31+
32+
compute_cap = 0
33+
driver_version_major = 0
34+
try:
35+
import pynvml
36+
pynvml.nvmlInit()
37+
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
38+
compute_cap = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
39+
compute_cap = compute_cap[0] + compute_cap[1] / 10.
40+
driver_version = pynvml.nvmlSystemGetDriverVersion().decode('utf-8')
41+
driver_version_major = int(driver_version.split('.')[0])
42+
except ModuleNotFoundError:
43+
print("NVML not found")
44+
45+
# there is an issue with OpticalFlow driver in R495 and newer on aarch64 platform
46+
is_of_supported_var = compute_cap >= 7.5 and (platform.machine() == "x86_64" or driver_version_major < 495)
47+
return is_of_supported_var
48+
2949
def get_mapping(shape):
3050
h, w = shape
3151
x = np.arange(w, dtype=np.float32) + 0.5
@@ -192,7 +212,9 @@ def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
192212
interactive = False
193213

194214
def test_optflow():
195-
pipe = OFPipeline(3, 0);
215+
if not is_of_supported():
216+
raise nose.SkipTest('Optical Flow is not supported on this platform')
217+
pipe = OFPipeline(3, 0)
196218
pipe.build()
197219
out = pipe.run()
198220
seq = out[0].at(0)

0 commit comments

Comments
 (0)