Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NVTX handlers #2765

Merged
merged 24 commits into from
Aug 17, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/handlers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,11 @@ Decollate batch
.. autoclass:: DecollateBatch
:members:

NVTX Handlers
---------------
drbeh marked this conversation as resolved.
Show resolved Hide resolved
.. automodule:: monai.handlers.nvtx_handlers
:members:

Utilities
---------
.. automodule:: monai.handlers.utils
Expand Down
1 change: 1 addition & 0 deletions monai/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .mean_dice import MeanDice
from .metric_logger import MetricLogger, MetricLoggerKeys
from .metrics_saver import MetricsSaver
from .nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler
from .parameter_scheduler import ParamSchedulerHandler
from .postprocessing import PostProcessing
from .regression_metrics import MeanAbsoluteError, MeanSquaredError, PeakSignalToNoiseRatio, RootMeanSquaredError
Expand Down
176 changes: 176 additions & 0 deletions monai/handlers/nvtx_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Wrapper around NVIDIA Tools Extension for profiling MONAI ignite workflow
"""

from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union

from monai.config import IgniteInfo
from monai.utils import min_version, optional_import

_nvtx, _ = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?")
if TYPE_CHECKING:
from ignite.engine import Engine, Events
else:
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")


__all__ = ["RangeHandler", "RangePushHandler", "RangePopHandler", "MarkHandler"]


class RangeHandler:
    """
    Attach an NVTX range to a pair of Ignite events:
    it pushes an NVTX range at the first event and pops it at the second event.
    Stores the zero-based depth of the range that is started.

    Args:
        events: a string or a pair of Ignite events to attach a range to.
            If a single string is provided, it should describe the base name of a pair of default Ignite events
            with _STARTED and _COMPLETED postfix (like "EPOCH" for EPOCH_STARTED and EPOCH_COMPLETED).
            The common accepted events are: BATCH, ITERATION, EPOCH, and ENGINE.
            For the complete list of Events,
            check https://pytorch.org/ignite/generated/ignite.engine.events.Events.html.

        msg: ASCII message to associate with the range.
            If not provided, the name of the first event will be assigned to the NVTX range.

    Raises:
        ValueError: when ``events`` is neither a string nor a pair of Ignite events.
    """

    def __init__(self, events: Union[str, Tuple[Events, Events]], msg: Optional[str] = None) -> None:
        if isinstance(events, str):
            # Resolve a base name such as "EPOCH" to the (EPOCH_STARTED, EPOCH_COMPLETED) pair.
            if events.upper() in ("ENGINE", ""):
                event_prefix = ""
            elif events.upper() == "BATCH":
                # Ignite names the batch events GET_BATCH_STARTED / GET_BATCH_COMPLETED.
                event_prefix = "GET_BATCH_"
            else:
                event_prefix = events.upper() + "_"
            self.events = (
                getattr(Events, event_prefix + "STARTED"),
                getattr(Events, event_prefix + "COMPLETED"),
            )
            if msg is None:
                msg = events
        elif isinstance(events, Sequence):
            if len(events) != 2:
                raise ValueError(f"Exactly two Ignite events should be provided [received {len(events)}].")
            if not isinstance(events[0], Events):
                raise ValueError("The provided first event is not an Ignite event!")
            if not isinstance(events[1], Events):
                raise ValueError("The provided second event is not an Ignite event!")
            self.events = events
            if msg is None:
                # Fix: the previous `events[0].strip("_")[0]` kept only the first character of the
                # name; use the full event name (e.g. "EPOCH_STARTED") as documented above.
                msg = events[0].name
        else:
            raise ValueError("The start/end events should either be a string or tuple/list of two Ignite events.")
        self.msg = msg
        # Zero-based depth returned by nvtx.rangePushA; None until the range is pushed.
        self.depth: Optional[int] = None

    def attach(self, engine: Engine) -> None:
        """
        Attach an NVTX range to a pair of Ignite events.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        engine.add_event_handler(self.events[0], self.range_push)
        engine.add_event_handler(self.events[1], self.range_pop)

    def range_push(self):
        """Push the NVTX range and record its zero-based depth."""
        self.depth = _nvtx.rangePushA(self.msg)

    def range_pop(self):
        """Pop the innermost open NVTX range."""
        _nvtx.rangePop()


class RangePushHandler:
    """
    Pushes an NVTX range onto the stack of nested range spans when a given Ignite event fires.
    Stores the zero-based depth of the range that is started.

    Args:
        event: the Ignite event (or its name as a string) that triggers the range push.
        msg: ASCII message to associate with the range.
            Defaults to the name of the event.
    """

    def __init__(self, event: Events, msg: Optional[str] = None) -> None:
        # A string is resolved to the corresponding Ignite event by name.
        self.event = getattr(Events, event) if isinstance(event, str) else event
        # Label the range with the event name unless an explicit message was given.
        self.msg = self.event.name if msg is None else msg
        self.depth = None

    def attach(self, engine: Engine) -> None:
        """
        Push an NVTX range at a specific Ignite event.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        engine.add_event_handler(self.event, self.range_push)

    def range_push(self):
        """Start the NVTX range and remember its zero-based depth."""
        self.depth = _nvtx.rangePushA(self.msg)


class RangePopHandler:
    """
    Pops the most recently pushed NVTX range when a given Ignite event fires.

    Args:
        event: the Ignite event (or its name as a string) that triggers the range pop.
    """

    def __init__(self, event: Events) -> None:
        # A string is resolved to the corresponding Ignite event by name.
        self.event = getattr(Events, event) if isinstance(event, str) else event

    def attach(self, engine: Engine) -> None:
        """
        Pop an NVTX range at a specific Ignite event.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        engine.add_event_handler(self.event, self.range_pop)

    def range_pop(self):
        """Close the innermost open NVTX range."""
        _nvtx.rangePop()


class MarkHandler:
    """
    Records an instantaneous NVTX marker when a given Ignite event fires.

    Args:
        event: the Ignite event (or its name as a string) that triggers the mark.
        msg: ASCII message to associate with the mark.
            Defaults to the name of the event.
    """

    def __init__(self, event: Events, msg: Optional[str] = None) -> None:
        # A string is resolved to the corresponding Ignite event by name.
        self.event = getattr(Events, event) if isinstance(event, str) else event
        # Label the mark with the event name unless an explicit message was given.
        self.msg = self.event.name if msg is None else msg

    def attach(self, engine: Engine) -> None:
        """
        Add an NVTX mark at a specific Ignite event.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        engine.add_event_handler(self.event, self.mark)

    def mark(self):
        """Emit the instantaneous NVTX marker."""
        _nvtx.markA(self.msg)
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def run_testsuit():
"test_handler_mean_dice",
"test_handler_metrics_saver",
"test_handler_metrics_saver_dist",
"test_handler_nvtx",
"test_handler_parameter_scheduler",
"test_handler_post_processing",
"test_handler_prob_map_producer",
Expand Down
100 changes: 100 additions & 0 deletions tests/test_handler_nvtx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from ignite.engine import Events
from parameterized import parameterized

from monai.engines import SupervisedEvaluator
from monai.handlers import StatsHandler
from monai.handlers.nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler
from monai.utils import optional_import

_, has_nvtx = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?")

# A 1x2x2x1 input batch with all-positive values (PReLU passes these through unchanged).
TENSOR_0 = torch.tensor([[[[1.0], [2.0]], [[3.0], [4.0]]]])

# A 1x2x2x1 input batch containing zero and negative values.
TENSOR_1 = torch.tensor([[[[0.0], [-2.0]], [[-3.0], [4.0]]]])

# Expected decollated output for TENSOR_1 after the network and the +1.0 postprocessing.
TENSOR_1_EXPECTED = torch.tensor([[[1.0], [0.5]], [[0.25], [5.0]]])

# Each case: (data loader content, expected prediction for the first decollated item).
TEST_CASE_0 = [[{"image": TENSOR_0}], TENSOR_0[0] + 1.0]
TEST_CASE_1 = [[{"image": TENSOR_1}], TENSOR_1_EXPECTED]


# Fix: the class was named TestHandlerDecollateBatch, copy-pasted from the decollate-batch
# test; this file (test_handler_nvtx.py) exercises the NVTX handlers.
class TestHandlerNvtx(unittest.TestCase):
    """Run a SupervisedEvaluator with NVTX handlers attached and verify its output."""

    @parameterized.expand(
        [
            TEST_CASE_0,
            TEST_CASE_1,
        ]
    )
    @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX!")
    def test_compute(self, data, expected):

        # Set up handlers covering every NVTX handler variant
        handlers = [
            # Mark with Ignite Event
            MarkHandler(Events.STARTED),
            # Mark with literal
            MarkHandler("EPOCH_STARTED"),
            # Define a range between BATCH_STARTED and BATCH_COMPLETED
            RangeHandler("Batch"),
            # Define the start of range using literal
            RangePushHandler("ITERATION_STARTED"),
            # Define the end of range using Ignite Event
            RangePopHandler(Events.ITERATION_COMPLETED),
            # Other handlers
            StatsHandler(tag_name="train"),
        ]

        # Set up an engine whose network (PReLU) plus the +1.0 postprocessing
        # produces the expected tensors defined above
        engine = SupervisedEvaluator(
            device=torch.device("cpu:0"),
            val_data_loader=data,
            epoch_length=1,
            network=torch.nn.PReLU(),
            postprocessing=lambda x: dict(pred=x["pred"] + 1.0),
            decollate=True,
            val_handlers=handlers,
        )
        # Run the engine
        engine.run()

        # Get the first decollated output item from the engine
        output = engine.state.output[0]

        torch.testing.assert_allclose(output["pred"], expected)


if __name__ == "__main__":
unittest.main()