Nvtx transform (#2713)
drbeh authored Aug 12, 2021
1 parent f9bc713 commit 8726dd5
Showing 4 changed files with 334 additions and 0 deletions.
29 changes: 29 additions & 0 deletions docs/source/transforms.rst
@@ -341,6 +341,35 @@ IO
:members:
:special-members: __call__


NVIDIA Tools Extension (NVTX)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

`RangePush`
"""""""""""
.. autoclass:: RangePush

`RandRangePush`
"""""""""""""""
.. autoclass:: RandRangePush

`RangePop`
""""""""""
.. autoclass:: RangePop

`RandRangePop`
""""""""""""""
.. autoclass:: RandRangePop

`Mark`
""""""
.. autoclass:: Mark

`RandMark`
""""""""""
.. autoclass:: RandMark


Post-processing
^^^^^^^^^^^^^^^

26 changes: 26 additions & 0 deletions monai/transforms/__init__.py
@@ -195,6 +195,32 @@
from .inverse_batch_transform import BatchInverseTransform, Decollated
from .io.array import LoadImage, SaveImage
from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
from .nvtx import (
Mark,
Markd,
MarkD,
MarkDict,
RandMark,
RandMarkd,
RandMarkD,
RandMarkDict,
RandRangePop,
RandRangePopd,
RandRangePopD,
RandRangePopDict,
RandRangePush,
RandRangePushd,
RandRangePushD,
RandRangePushDict,
RangePop,
RangePopd,
RangePopD,
RangePopDict,
RangePush,
RangePushd,
RangePushD,
RangePushDict,
)
from .post.array import (
Activations,
AsDiscrete,
125 changes: 125 additions & 0 deletions monai/transforms/nvtx.py
@@ -0,0 +1,125 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Wrappers around the NVIDIA Tools Extension (NVTX) for profiling MONAI transformations.
"""

from monai.transforms.transform import RandomizableTransform, Transform
from monai.utils import optional_import

_nvtx, _ = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?")

__all__ = [
"Mark",
"Markd",
"MarkD",
"MarkDict",
"RandMark",
"RandMarkd",
"RandMarkD",
"RandMarkDict",
"RandRangePop",
"RandRangePopd",
"RandRangePopD",
"RandRangePopDict",
"RandRangePush",
"RandRangePushd",
"RandRangePushD",
"RandRangePushDict",
"RangePop",
"RangePopd",
"RangePopD",
"RangePopDict",
"RangePush",
"RangePushd",
"RangePushD",
"RangePushDict",
]


class RangePush(Transform):
"""
    Pushes a range onto a stack of nested range spans.
    Stores the zero-based depth of the range that is started.

    Args:
msg: ASCII message to associate with range
"""

def __init__(self, msg: str) -> None:
self.msg = msg
self.depth = None

def __call__(self, data):
self.depth = _nvtx.rangePushA(self.msg)
return data


class RandRangePush(RangePush, RandomizableTransform):
"""
    Pushes a range onto a stack of nested range spans (RandomizableTransform).
    Stores the zero-based depth of the range that is started.

    Args:
msg: ASCII message to associate with range
"""


class RangePop(Transform):
"""
    Pops a range off the stack of nested range spans.
"""

def __call__(self, data):
_nvtx.rangePop()
return data


class RandRangePop(RangePop, RandomizableTransform):
"""
    Pops a range off the stack of nested range spans (RandomizableTransform).
"""


class Mark(Transform):
"""
    Mark an instantaneous event that occurred at some point.

    Args:
msg: ASCII message to associate with the event.
"""

def __init__(self, msg: str) -> None:
self.msg = msg

def __call__(self, data):
_nvtx.markA(self.msg)
return data


class RandMark(Mark, RandomizableTransform):
"""
    Mark an instantaneous event that occurred at some point (RandomizableTransform).

    Args:
msg: ASCII message to associate with the event.
"""


MarkDict = MarkD = Markd = Mark
RandMarkDict = RandMarkD = RandMarkd = RandMark
RandRangePopDict = RandRangePopD = RandRangePopd = RandRangePop
RandRangePushDict = RandRangePushD = RandRangePushd = RandRangePush
RangePopDict = RangePopD = RangePopd = RangePop
RangePushDict = RangePushD = RangePushd = RangePush
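
Because these transforms simply emit an NVTX call and return data unchanged, the dictionary-style names above are plain aliases of the array classes rather than separate MapTransform subclasses, so the same object works in both array and dictionary pipelines. A minimal sketch of this equivalence, assuming a CUDA build of PyTorch so that torch._C._nvtx is importable:

from monai.transforms import Mark, MarkD

assert Mark is MarkD  # the dictionary variants are aliases, not wrappers
mark = MarkD("checkpoint")
mark({"image": [1, 2, 3]})  # emits an NVTX marker and returns the dict untouched
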
154 changes: 154 additions & 0 deletions tests/test_nvtx_transform.py
@@ -0,0 +1,154 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import Compose, Flip, RandFlip, RandFlipD, Randomizable, ToTensor, ToTensorD
from monai.transforms.nvtx import (
Mark,
MarkD,
RandMark,
RandMarkD,
RandRangePop,
RandRangePopD,
RandRangePush,
RandRangePushD,
RangePop,
RangePopD,
RangePush,
RangePushD,
)
from monai.utils import optional_import

_, has_nvtx = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?")


TEST_CASE_ARRAY_0 = [
np.random.randn(3, 3),
]
TEST_CASE_ARRAY_1 = [
np.random.randn(3, 10, 10),
]
TEST_CASE_DICT_0 = [
{"image": np.random.randn(3, 3)},
]
TEST_CASE_DICT_1 = [
{"image": np.random.randn(3, 10, 10)},
]


class TestNVTXTransforms(unittest.TestCase):
@parameterized.expand(
[
TEST_CASE_ARRAY_0,
TEST_CASE_ARRAY_1,
TEST_CASE_DICT_0,
TEST_CASE_DICT_1,
]
)
@unittest.skipUnless(has_nvtx, "CUDA is required for NVTX!")
    def test_nvtx_transforms_alone(self, input):
transforms = Compose(
[
Mark("Mark: Transform Starts!"),
RangePush("Range: RandFlipD"),
RangePop(),
RandRangePush("Range: ToTensorD"),
RandRangePop(),
RandMark("Mark: Transform Ends!"),
]
)
output = transforms(input)
self.assertEqual(id(input), id(output))

        # The first randomizable transform in the chain should be RandRangePush (a RangePush subclass)
for tran in transforms.transforms:
if isinstance(tran, Randomizable):
self.assertIsInstance(tran, RangePush)
break

@parameterized.expand([TEST_CASE_ARRAY_0, TEST_CASE_ARRAY_1])
@unittest.skipUnless(has_nvtx, "CUDA is required for NVTX!")
    def test_nvtx_transforms_array(self, input):
transforms = Compose(
[
RandMark("Mark: Transform Starts!"),
RandRangePush("Range: RandFlip"),
RandFlip(prob=0.0),
RandRangePop(),
RangePush("Range: ToTensor"),
ToTensor(),
RangePop(),
Mark("Mark: Transform Ends!"),
]
)
output = transforms(input)
self.assertIsInstance(output, torch.Tensor)
np.testing.assert_array_equal(input, output)

transforms = Compose(
[
RandMark("Mark: Transform Starts!"),
RandRangePush("Range: RandFlip"),
RandFlip(prob=1.0),
RandRangePop(),
RangePush("Range: ToTensor"),
ToTensor(),
RangePop(),
Mark("Mark: Transform Ends!"),
]
)
output = transforms(input)
self.assertIsInstance(output, torch.Tensor)
np.testing.assert_array_equal(input, Flip()(output.numpy()))

@parameterized.expand([TEST_CASE_DICT_0, TEST_CASE_DICT_1])
@unittest.skipUnless(has_nvtx, "CUDA is required for NVTX!")
    def test_nvtx_transformsd(self, input):
transforms = Compose(
[
RandMarkD("Mark: Transform Starts!"),
RandRangePushD("Range: RandFlipD"),
RandFlipD(keys="image", prob=0.0),
RandRangePopD(),
RangePushD("Range: ToTensorD"),
                ToTensorD(keys="image"),
RangePopD(),
MarkD("Mark: Transform Ends!"),
]
)
output = transforms(input)
self.assertIsInstance(output["image"], torch.Tensor)
np.testing.assert_array_equal(input["image"], output["image"])

transforms = Compose(
[
RandMarkD("Mark: Transform Starts!"),
RandRangePushD("Range: RandFlipD"),
RandFlipD(keys="image", prob=1.0),
RandRangePopD(),
RangePushD("Range: ToTensorD"),
                ToTensorD(keys="image"),
RangePopD(),
MarkD("Mark: Transform Ends!"),
]
)
output = transforms(input)
self.assertIsInstance(output["image"], torch.Tensor)
np.testing.assert_array_equal(input["image"], Flip()(output["image"].numpy()))


if __name__ == "__main__":
unittest.main()

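A minimal end-to-end sketch of how these transforms might be used in practice, assuming a CUDA build of PyTorch (torch._C._nvtx is unavailable otherwise) and that the script is launched under a profiler such as Nsight Systems, for example nsys profile -t nvtx python profile_transforms.py (a hypothetical script name); without a profiler attached the marks and ranges are emitted but nothing records them:

import numpy as np

from monai.transforms import Compose, Mark, RangePop, RangePush, ToTensor

# Delimit the ToTensor call with an NVTX range and mark the pipeline boundaries.
pipeline = Compose(
    [
        Mark("pipeline start"),
        RangePush("ToTensor"),
        ToTensor(),
        RangePop(),
        Mark("pipeline end"),
    ]
)

output = pipeline(np.random.randn(3, 10, 10))  # input passes through unchanged, returned as a torch.Tensor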