Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add arm support #7500

Merged
merged 11 commits into from
Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ FROM ${PYTORCH_IMAGE}

LABEL maintainer="monai.contact@gmail.com"

# TODO: remark for issue [revise the dockerfile](https://github.com/zarr-developers/numcodecs/issues/431)
WORKDIR /opt
RUN git clone --recursive https://github.com/zarr-developers/numcodecs.git
WORKDIR /opt/numcodecs
RUN pip wheel .
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved

WORKDIR /opt/monai

# install full deps
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ mypy>=1.5.0
ninja
torchvision
psutil
cucim>=23.2.0; platform_system == "Linux"
cucim-cu12; platform_system == "Linux" and python_version >= "3.9" and python_version <= "3.10"
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
openslide-python
imagecodecs; platform_system == "Linux" or platform_system == "Darwin"
tifffile; platform_system == "Linux" or platform_system == "Darwin"
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ all =
tqdm>=4.47.0
lmdb
psutil
cucim>=23.2.0
cucim-cu12; python_version >= '3.9' and python_version <= '3.10'
openslide-python
tifffile
imagecodecs
Expand Down Expand Up @@ -111,7 +111,7 @@ lmdb =
psutil =
psutil
cucim =
cucim>=23.2.0
cucim-cu12
openslide =
openslide-python
tifffile =
Expand Down
19 changes: 13 additions & 6 deletions tests/test_convert_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

import itertools
import platform
import unittest

import torch
Expand All @@ -29,6 +30,12 @@
TESTS = list(itertools.product(TORCH_DEVICE_OPTIONS, [True, False], [True, False]))
TESTS_ORT = list(itertools.product(TORCH_DEVICE_OPTIONS, [True]))

ON_AARCH64 = platform.machine() == "aarch64"
if ON_AARCH64:
rtol, atol = 1e-1, 1e-2
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
else:
rtol, atol = 1e-3, 1e-4

onnx, _ = optional_import("onnx")


Expand Down Expand Up @@ -56,8 +63,8 @@ def test_unet(self, device, use_trace, use_ort):
device=device,
use_ort=use_ort,
use_trace=use_trace,
rtol=1e-3,
atol=1e-4,
rtol=rtol,
atol=atol,
)
else:
# https://github.com/pytorch/pytorch/blob/release/1.9/torch/onnx/__init__.py#L182
Expand All @@ -72,8 +79,8 @@ def test_unet(self, device, use_trace, use_ort):
device=device,
use_ort=use_ort,
use_trace=use_trace,
rtol=1e-3,
atol=1e-4,
rtol=rtol,
atol=atol,
)
self.assertTrue(isinstance(onnx_model, onnx.ModelProto))

Expand Down Expand Up @@ -107,8 +114,8 @@ def test_seg_res_net(self, device, use_ort):
device=device,
use_ort=use_ort,
use_trace=True,
rtol=1e-3,
atol=1e-4,
rtol=rtol,
atol=atol,
)
self.assertTrue(isinstance(onnx_model, onnx.ModelProto))

Expand Down
9 changes: 8 additions & 1 deletion tests/test_dynunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

import platform
import unittest
from typing import Any, Sequence

Expand All @@ -24,6 +25,12 @@

InstanceNorm3dNVFuser, _ = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")

ON_AARCH64 = platform.machine() == "aarch64"
if ON_AARCH64:
rtol, atol = 1e-2, 1e-2
else:
rtol, atol = 1e-4, 1e-4

device = "cuda" if torch.cuda.is_available() else "cpu"

strides: Sequence[Sequence[int] | int]
Expand Down Expand Up @@ -159,7 +166,7 @@ def test_consistency(self, input_param, input_shape, _):
with eval_mode(net_fuser):
result_fuser = net_fuser(input_tensor)

assert_allclose(result, result_fuser, rtol=1e-4, atol=1e-4)
assert_allclose(result, result_fuser, rtol=rtol, atol=atol)


class TestDynUNetDeepSupervision(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_rand_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_rand_affine(self, input_param, input_data, expected_val):
g.set_random_state(123)
result = g(**input_data)
g.rand_affine_grid.affine = torch.eye(4, dtype=torch.float64) # reset affine
test_resampler_lazy(g, result, input_param, input_data, seed=123)
test_resampler_lazy(g, result, input_param, input_data, seed=123, rtol=_rtol)
if input_param.get("cache_grid", False):
self.assertTrue(g._cached_grid is not None)
assert_allclose(result, expected_val, rtol=_rtol, atol=1e-4, type_test="tensor")
Expand Down
4 changes: 3 additions & 1 deletion tests/test_rand_affined.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ def test_rand_affined(self, input_param, input_data, expected_val, track_meta):
lazy_init_param["keys"], lazy_init_param["mode"] = key, mode
resampler = RandAffined(**lazy_init_param).set_random_state(123)
expected_output = resampler(**call_param)
test_resampler_lazy(resampler, expected_output, lazy_init_param, call_param, seed=123, output_key=key)
test_resampler_lazy(
resampler, expected_output, lazy_init_param, call_param, seed=123, output_key=key, rtol=_rtol
)
resampler.lazy = False

if input_param.get("cache_grid", False):
Expand Down
9 changes: 8 additions & 1 deletion tests/test_spatial_resampled.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

import platform
import unittest

import numpy as np
Expand All @@ -23,6 +24,12 @@
from tests.lazy_transforms_utils import test_resampler_lazy
from tests.utils import TEST_DEVICES, assert_allclose

ON_AARCH64 = platform.machine() == "aarch64"
if ON_AARCH64:
rtol, atol = 1e-1, 1e-2
else:
rtol, atol = 1e-3, 1e-4

TESTS = []

destinations_3d = [
Expand Down Expand Up @@ -104,7 +111,7 @@ def test_flips_inverse(self, img, device, dst_affine, kwargs, expected_output):

# check lazy
lazy_xform = SpatialResampled(**init_param)
test_resampler_lazy(lazy_xform, output_data, init_param, call_param, output_key="img")
test_resampler_lazy(lazy_xform, output_data, init_param, call_param, output_key="img", rtol=rtol, atol=atol)

# check inverse
inverted = xform.inverse(output_data)["img"]
Expand Down
Loading