Skip to content

Commit

Permalink
Change WDS index version representation to integer + refactor version…
Browse files Browse the repository at this point in the history
… utilties. (#4708)

* Change WDS index version representation to integer.
* Use a common utility for generating an integer version number.
* Move version utilities to a standalone header
* Make version utilities more generic

Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
  • Loading branch information
mzient authored Mar 14, 2023
1 parent bbcb667 commit ebd85f0
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 39 deletions.
4 changes: 2 additions & 2 deletions dali/imgcodec/decoders/nvjpeg/nvjpeg.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include <string>
#include <utility>
#include "dali/core/device_guard.h"
#include "dali/core/util.h"
#include "dali/core/version_util.h"
#include "dali/imgcodec/decoders/nvjpeg/nvjpeg.h"
#include "dali/imgcodec/decoders/nvjpeg/nvjpeg_helper.h"
#include "dali/imgcodec/decoders/nvjpeg/nvjpeg_memory.h"
Expand All @@ -38,7 +38,7 @@ int nvjpegGetVersion() {
GetVersionProperty(nvjpegGetProperty, &major, MAJOR_VERSION, NVJPEG_STATUS_SUCCESS);
GetVersionProperty(nvjpegGetProperty, &minor, MINOR_VERSION, NVJPEG_STATUS_SUCCESS);
GetVersionProperty(nvjpegGetProperty, &patch, PATCH_LEVEL, NVJPEG_STATUS_SUCCESS);
return GetVersionNumber(major, minor, patch);
return MakeVersionNumber(major, minor, patch);
}

} // namespace
Expand Down
6 changes: 3 additions & 3 deletions dali/kernels/signal/fft/cufft_helper.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -13,7 +13,7 @@
// limitations under the License.

#include "dali/kernels/signal/fft/cufft_helper.h"
#include "dali/core/util.h"
#include "dali/core/version_util.h"

namespace dali {

Expand All @@ -24,7 +24,7 @@ DLL_PUBLIC int cufftGetVersion() {
GetVersionProperty(cufftGetProperty, &major, MAJOR_VERSION, CUFFT_SUCCESS);
GetVersionProperty(cufftGetProperty, &minor, MINOR_VERSION, CUFFT_SUCCESS);
GetVersionProperty(cufftGetProperty, &patch, PATCH_LEVEL, CUFFT_SUCCESS);
return GetVersionNumber(major, minor, patch);
return MakeVersionNumber(major, minor, patch);
}

} // namespace dali
6 changes: 3 additions & 3 deletions dali/operators/decoder/nvjpeg/nvjpeg_helper.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -13,7 +13,7 @@
// limitations under the License.

#include "dali/operators/decoder/nvjpeg/nvjpeg_helper.h"
#include "dali/core/util.h"
#include "dali/core/version_util.h"

namespace dali {

Expand All @@ -24,7 +24,7 @@ DLL_PUBLIC int nvjpegGetVersion() {
GetVersionProperty(nvjpegGetProperty, &major, MAJOR_VERSION, NVJPEG_STATUS_SUCCESS);
GetVersionProperty(nvjpegGetProperty, &minor, MINOR_VERSION, NVJPEG_STATUS_SUCCESS);
GetVersionProperty(nvjpegGetProperty, &patch, PATCH_LEVEL, NVJPEG_STATUS_SUCCESS);
return GetVersionNumber(major, minor, patch);
return MakeVersionNumber(major, minor, patch);
}

} // namespace dali
29 changes: 22 additions & 7 deletions dali/operators/reader/loader/webdataset_loader.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,7 @@
#include <tuple>
#include <utility>
#include "dali/core/common.h"
#include "dali/core/version_util.h"
#include "dali/core/error_handling.h"
#include "dali/operators/reader/loader/webdataset/tar_utils.h"
#include "dali/pipeline/data/types.h"
Expand Down Expand Up @@ -55,7 +56,7 @@ inline MissingExtBehavior ParseMissingExtBehavior(std::string missing_component_
inline void ParseSampleDesc(std::vector<SampleDesc>& samples_container,
std::vector<ComponentDesc>& components_container,
std::ifstream& index_file, const std::string& index_path, int64_t line,
const std::string& index_version) {
int index_version) {
// Preparing the SampleDesc
samples_container.emplace_back();
samples_container.back().components =
Expand All @@ -70,7 +71,7 @@ inline void ParseSampleDesc(std::vector<SampleDesc>& samples_container,
// Reading consecutive components
ComponentDesc component;
while (components_stream >> component.ext) {
if (index_version == "v1.2") {
if (index_version == MakeVersionNumber(1, 2)) {
DALI_ENFORCE(
components_stream >> component.offset >> component.size >> component.filename,
IndexFileErrMsg(
Expand All @@ -97,6 +98,18 @@ inline void ParseSampleDesc(std::vector<SampleDesc>& samples_container,
IndexFileErrMsg(index_path, line, "no extensions provided for the sample"));
}

inline int ParseIndexVersion(const string& version_str) {
const char *s = version_str.c_str();
assert(*s == 'v');
s++;
int major = atoi(s);
s = strchr(s, '.');
assert(s);
s++;
int minor = atoi(s);
return MakeVersionNumber(major, minor);
}

inline void ParseIndexFile(std::vector<SampleDesc>& samples_container,
std::vector<ComponentDesc>& components_container,
const std::string& index_path) {
Expand All @@ -106,13 +119,15 @@ inline void ParseIndexFile(std::vector<SampleDesc>& samples_container,
std::string global_meta;
getline(index_file, global_meta);
std::stringstream global_meta_stream(global_meta);
std::string index_version;
DALI_ENFORCE(global_meta_stream >> index_version,
std::string index_version_str;
DALI_ENFORCE(global_meta_stream >> index_version_str,
IndexFileErrMsg(index_path, 0, "no version signature found"));
DALI_ENFORCE(
kSupportedIndexVersions.count(index_version) > 0,
kSupportedIndexVersions.count(index_version_str) > 0,
IndexFileErrMsg(index_path, 0,
make_string("Unsupported version of the index file (", index_version, ").")));
make_string("Unsupported version of the index file (",
index_version_str, ").")));
int index_version = ParseIndexVersion(index_version_str);

// Getting the number of samples in the index file
int64_t sample_desc_num_signed;
Expand Down
6 changes: 3 additions & 3 deletions dali/operators/util/npp.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -15,15 +15,15 @@
#include "dali/npp/npp.h"
#include "dali/core/error_handling.h"
#include "dali/core/cuda_error.h"
#include "dali/core/util.h"
#include "dali/core/version_util.h"

namespace dali {

DLL_PUBLIC int NPPGetVersion() {
auto version_s = nppGetLibVersion();
int version = -1;
if (version_s) {
version = GetVersionNumber(version_s->major, version_s->minor, version_s->build);
version = MakeVersionNumber(version_s->major, version_s->minor, version_s->build);
}
return version;
}
Expand Down
22 changes: 1 addition & 21 deletions include/dali/core/util.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2018-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -442,24 +442,4 @@ inline auto flatten(span<const std::array<T, N>, extent> in) {

} // namespace dali

/// @brief Returns the product of all elements in shape
/// @param getter - function obtaining property
/// @param value - value where to put the property, in case of failure it gets -1
/// @param property - enum for the select right property
/// @param success_status - enum value for success status of getter
template <typename F, typename E, typename S>
void GetVersionProperty(F getter, int *value, E property, S success_status) {
if (getter(property, value) != success_status) {
*value = -1;
}
}

// gets single int that can be represented as int value
static int GetVersionNumber(int major, int minor, int patch) {
if (major < 0 || minor < 0 || patch < 0) {
return -1;
}
return major*1000 + minor*10 + patch;
}

#endif // DALI_CORE_UTIL_H_
43 changes: 43 additions & 0 deletions include/dali/core/version_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_CORE_VERSION_UTIL_H_
#define DALI_CORE_VERSION_UTIL_H_

namespace dali {

/// @brief Returns the product of all elements in shape
/// @param getter - function obtaining property
/// @param value - value where to put the property, in case of failure it gets `invalid_value`
/// @param property - enum for the select right property
/// @param success_status - enum value for success status of getter
/// @param invalid_value - the value returned when the getter fails
template <typename F, typename V, typename E, typename S>
void GetVersionProperty(F getter, V *value, E property, S success_status, V invalid_value = -1) {
if (getter(property, value) != success_status) {
*value = invalid_value;
}
}

// gets single int that can be represented as int value
constexpr int MakeVersionNumber(int major, int minor, int patch = 0) {
if (major < 0 || minor < 0 || patch < 0) {
return -1;
}
return major*1000 + minor*10 + patch;
}

} // namespace dali

#endif // DALI_CORE_VERSION_UTIL_H_

0 comments on commit ebd85f0

Please sign in to comment.