|
14 | 14 |
|
15 | 15 | import os
|
16 | 16 | import numpy as np
|
17 |
| -import shutil |
18 |
| -import sys |
| 17 | +import platform |
19 | 18 | import cv2
|
20 | 19 | from nvidia.dali.pipeline import Pipeline
|
21 | 20 | import nvidia.dali.ops as ops
|
22 |
| -import nvidia.dali.types as types |
23 |
| -from random import shuffle |
24 | 21 | from test_utils import get_dali_extra_path
|
25 | 22 |
|
26 | 23 | test_data_root = get_dali_extra_path()
|
27 | 24 | images_dir = os.path.join(test_data_root, 'db', 'imgproc')
|
28 | 25 |
|
| 26 | +is_of_supported_var = None |
| 27 | +def is_of_supported(device_id=0): |
| 28 | + global is_of_supported_var |
| 29 | + if is_of_supported_var is not None: |
| 30 | + return is_of_supported_var |
| 31 | + |
| 32 | + compute_cap = 0 |
| 33 | + driver_version_major = 0 |
| 34 | + try: |
| 35 | + import pynvml |
| 36 | + pynvml.nvmlInit() |
| 37 | + handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) |
| 38 | + compute_cap = pynvml.nvmlDeviceGetCudaComputeCapability(handle) |
| 39 | + compute_cap = compute_cap[0] + compute_cap[1] / 10. |
| 40 | + driver_version = pynvml.nvmlSystemGetDriverVersion().decode('utf-8') |
| 41 | + driver_version_major = int(driver_version.split('.')[0]) |
| 42 | + except ModuleNotFoundError: |
| 43 | + print("NVML not found") |
| 44 | + |
| 45 | + # there is an issue with OpticalFlow driver in R495 and newer on aarch64 platform |
| 46 | + is_of_supported_var = compute_cap >= 7.5 and (platform.machine() == "x86_64" or driver_version_major < 495) |
| 47 | + return is_of_supported_var |
| 48 | + |
29 | 49 | def get_mapping(shape):
|
30 | 50 | h, w = shape
|
31 | 51 | x = np.arange(w, dtype=np.float32) + 0.5
|
@@ -192,7 +212,9 @@ def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
|
192 | 212 | interactive = False
|
193 | 213 |
|
194 | 214 | def test_optflow():
|
195 |
| - pipe = OFPipeline(3, 0); |
| 215 | + if not is_of_supported(): |
| 216 | + raise nose.SkipTest('Optical Flow is not supported on this platform') |
| 217 | + pipe = OFPipeline(3, 0) |
196 | 218 | pipe.build()
|
197 | 219 | out = pipe.run()
|
198 | 220 | seq = out[0].at(0)
|
|
0 commit comments