Skip to content

Commit

Permalink
Use forward-declaration of OpCheckpoint in operator.h - most ops are …
Browse files Browse the repository at this point in the history
…stateless.

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
  • Loading branch information
mzient committed Jun 11, 2024
1 parent 32087da commit 5cbffaf
Show file tree
Hide file tree
Showing 12 changed files with 17 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "dali/operators/decoder/host/fused/host_decoder_random_crop.h"
#include "dali/pipeline/operator/common.h"
#include "dali/pipeline/operator/checkpointing/snapshot_serializer.h"
#include "dali/pipeline/operator/checkpointing/op_checkpoint.h"

namespace dali {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "dali/operators/decoder/nvjpeg/fused/nvjpeg_decoder_random_crop.h"
#include "dali/pipeline/operator/checkpointing/snapshot_serializer.h"
#include "dali/util/random_crop_generator.h"
#include "dali/pipeline/operator/checkpointing/op_checkpoint.h"

namespace dali {

Expand Down
1 change: 1 addition & 0 deletions dali/operators/image/crop/random_crop_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "dali/pipeline/operator/common.h"
#include "dali/operators/image/crop/random_crop_attr.h"
#include "dali/pipeline/operator/checkpointing/snapshot_serializer.h"
#include "dali/pipeline/operator/checkpointing/op_checkpoint.h"

namespace dali {

Expand Down
3 changes: 2 additions & 1 deletion dali/operators/image/resize/random_resized_crop.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,6 +29,7 @@
#include "dali/operators/image/crop/random_crop_attr.h"
#include "dali/kernels/imgproc/resample/params.h"
#include "dali/pipeline/operator/checkpointing/snapshot_serializer.h"
#include "dali/pipeline/operator/checkpointing/op_checkpoint.h"

namespace dali {

Expand Down
1 change: 1 addition & 0 deletions dali/operators/imgcodec/roi_image_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "dali/operators/image/crop/random_crop_attr.h"
#include "dali/operators/imgcodec/imgcodec.h"
#include "dali/pipeline/operator/checkpointing/snapshot_serializer.h"
#include "dali/pipeline/operator/checkpointing/op_checkpoint.h"
#include "dali/pipeline/operator/common.h"
#include "dali/pipeline/operator/operator.h"
#include "dali/util/crop_window.h"
Expand Down
4 changes: 2 additions & 2 deletions dali/operators/random/rng_base_cpu_test.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -61,7 +61,7 @@ class RNGCheckpointingTest : public ::testing::Test {

// save and restore the pipeline
auto node = original_pipe.GetOperatorNode("rng_op");
OpCheckpoint cpt(node->spec);
OpCheckpoint cpt("rng_op");
node->op->SaveState(cpt, {});
restored_pipe.GetOperatorNode("rng_op")->op->RestoreState(cpt);

Expand Down
1 change: 1 addition & 0 deletions dali/operators/random/rng_checkpointing_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "dali/pipeline/operator/checkpointing/snapshot_serializer.h"
#include "dali/pipeline/util/batch_rng.h"
#include "dali/operators/util/randomizer.cuh"
#include "dali/pipeline/operator/checkpointing/op_checkpoint.h"

namespace dali {
namespace rng {
Expand Down
1 change: 1 addition & 0 deletions dali/operators/reader/reader_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "dali/operators/reader/loader/loader.h"
#include "dali/operators/reader/parser/parser.h"
#include "dali/pipeline/operator/checkpointing/snapshot_serializer.h"
#include "dali/pipeline/operator/checkpointing/op_checkpoint.h"
#include "dali/pipeline/operator/name_utils.h"
#include "dali/pipeline/operator/operator.h"

Expand Down
1 change: 1 addition & 0 deletions dali/pipeline/executor/lowered_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <string>
#include <memory>
#include <fstream>
#include <optional>
#include <set>

#include "dali/core/common.h"
Expand Down
3 changes: 2 additions & 1 deletion dali/pipeline/operator/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@
#include "dali/pipeline/operator/op_schema.h"
#include "dali/pipeline/operator/op_spec.h"
#include "dali/pipeline/operator/operator_factory.h"
#include "dali/pipeline/operator/checkpointing/op_checkpoint.h"
#include "dali/pipeline/util/batch_utils.h"
#include "dali/pipeline/workspace/workspace.h"
#include "dali/pipeline/workspace/sample_workspace.h"
#include "dali/pipeline/util/thread_pool.h"

namespace dali {

class OpCheckpoint;

struct DLL_PUBLIC ReaderMeta {
Index epoch_size = -1; // raw epoch size
Index epoch_size_padded = -1; // epoch size with the padding at the end
Expand Down
4 changes: 2 additions & 2 deletions dali/test/dali_test_checkpointing.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -68,7 +68,7 @@ class PipelineWrapper {
clone.Build();

for (const auto &[spec, id] : ops_) {
OpCheckpoint cpt(spec);
OpCheckpoint cpt(id);
// TODO(mstaniewski): provide a stream, so operators with state kept
// in device memory can be tested.
GetOperator(id)->SaveState(cpt, {});
Expand Down
3 changes: 2 additions & 1 deletion dali/test/operators/dummy_op.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,7 @@
#include <vector>

#include "dali/pipeline/operator/operator.h"
#include "dali/pipeline/operator/checkpointing/op_checkpoint.h"

namespace dali {

Expand Down

0 comments on commit 5cbffaf

Please sign in to comment.