This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-68] Random shuffle implementation #10048

Merged: 15 commits into apache:master from asitstands:shuffle_op_np on Mar 20, 2018

Conversation

@asitstands (Contributor) commented Mar 9, 2018

Description

This operator randomly shuffles an NDArray along the first axis. The order of the elements in each subarray does not change. For example, if an NDArray x is shuffled, the order of the subarrays x[i] changes randomly, but the order of the elements within each x[i] does not. This PR is a revision of #9991.
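
To illustrate the semantics with the model operator, numpy.random.shuffle permutes only the first axis; one possible outcome:

import numpy as np

x = np.arange(12).reshape(3, 4)
np.random.shuffle(x)   # permutes the rows, never the elements within a row
# e.g. x is now
# [[ 8  9 10 11]
#  [ 0  1  2  3]
#  [ 4  5  6  7]]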

Implementation

  1. On CPU, shuffling a 1D array is delegated to __gnu_parallel::random_shuffle, which utilizes OpenMP, when building with clang on Linux or with gcc on any OS; on other platforms it is delegated to std::shuffle. For a multidimensional array, the usual Fisher-Yates shuffle over the first axis is implemented.
  2. On GPU, for a multidimensional array, it shuffles an array of indices representing the subarrays and then rearranges the elements of the data array according to the shuffled index array. To shuffle the index array, a random key is generated for each index and the indices are sorted by the keys. If the given array is 1D, it sorts the array directly with random keys instead of sorting the indices. The sorting is delegated to mshadow's SortByKey, which in turn delegates to thrust's sort_by_key. This implementation is essentially equivalent to __gnu_parallel::random_shuffle, except that the random key generation and sorting are fused in the GNU implementation. Both strategies are sketched below.
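
A minimal numpy sketch of the two strategies (the helper names are illustrative; the actual implementation is the C++/CUDA code in this PR):

import numpy as np

# CPU strategy: Fisher-Yates over the first axis, swapping whole subarrays.
def shuffle_fisher_yates(x, rng):
    for i in range(len(x) - 1, 0, -1):
        j = rng.randint(0, i + 1)   # uniform on [0, i]
        if i != j:
            x[[i, j]] = x[[j, i]]   # swap subarrays x[i] and x[j]
    return x

# GPU strategy: draw one random key per subarray, sort the subarray indices
# by the keys, then gather the subarrays in the sorted order.
def shuffle_sort_by_key(x, rng):
    keys = rng.rand(len(x))
    return x[np.argsort(keys)]

rng = np.random.RandomState(0)
print(shuffle_fisher_yates(np.arange(12).reshape(4, 3), rng))
print(shuffle_sort_by_key(np.arange(12).reshape(4, 3), rng))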

Note

This operator is modeled on numpy.random.shuffle.

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • To the best of my knowledge, examples are either not affected by this change or have been fixed to be compatible with it

@asitstands changed the title from "Random shuffle implementation" to "[MXNET-68] Random shuffle implementation" on Mar 9, 2018

namespace {

template<typename DType, typename Rand>
Contributor:

no need to indent in namespace.

Contributor Author:

Done.

std::uniform_int_distribution<index_t> dist(0, n - 1);
return dist(*prnd);
};
for (index_t i = first_axis_len - 1; i > 0; --i) {
Contributor:

Add CHECK_GT(first_axis_len, 0U); above this line.

Contributor Author:

Done.

}

template<typename DType, typename Rand>
void ShuffleND(DType* out, index_t size, index_t first_axis_len, Rand* prnd) {
Contributor:

Add const qualifier to the arguments if possible.

Contributor Author:

Done.

return;
}
CHECK_NE(req[0], kAddTo) << "Shuffle does not support AddTo";
const TShape input_shape = inputs[0].shape_;
Contributor:

const TShape&

Contributor Author:

Done.

const index_t size = inputs[0].Size();
const index_t first_axis_len = input_shape[0];
Stream<cpu> *s = ctx.get_stream<cpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Contributor:

What's the reason for not supporting integer types?

Contributor Author:

Fixed with MSHADOW_TYPE_SWITCH

@@ -247,3 +247,34 @@ def multinomial(data, shape=_Null, get_prob=True, **kwargs):
reward as head gradient w.r.t. this array to estimate gradient.
"""
return _internal._sample_multinomial(data, shape, get_prob, **kwargs)


def shuffle(data, **kwargs):
Contributor:

Python interface seems unnecessary. You can register the operator with name _random_shuffle in C++.

@@ -431,3 +431,35 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, **kwargs):
<NDArray 2 @cpu(0)>
"""
return _internal._sample_multinomial(data, shape, get_prob, out=out, **kwargs)


def shuffle(data, **kwargs):
Contributor:

Python interface seems unnecessary. You can register the operator with name _random_shuffle in C++.

Contributor:

Still valid, @reminisce?

@@ -552,6 +554,56 @@ def compute_expected_prob():
mx.test_utils.assert_almost_equal(exp_cnt_sampled.asnumpy(), exp_cnt[sampled_classes].asnumpy(), rtol=1e-1, atol=1e-2)
mx.test_utils.assert_almost_equal(exp_cnt_true.asnumpy(), exp_cnt[true_classes].asnumpy(), rtol=1e-1, atol=1e-2)

@with_seed()
def test_shuffle():
Contributor:

Could you elaborate on your reasoning for this unit test? How did you prove it's a uniformly random shuffle?

Contributor Author (@asitstands, Mar 10, 2018):

It just counts the occurrences of each possible outcome and checks that they are uniformly distributed. I added some comments explaining the test.
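
A minimal standalone sketch of the counting idea, using numpy's shuffle as a stand-in for the operator under test:

import math
import itertools
from collections import Counter
import numpy as np

def check_uniform_shuffle(n_rows=3, repeat=20000, tol=0.01):
    # Shuffle repeatedly and compare each permutation's observed frequency
    # with the ideal probability 1/n!.
    rng = np.random.RandomState(42)
    counts = Counter()
    for _ in range(repeat):
        y = np.arange(n_rows)
        rng.shuffle(y)
        counts[tuple(int(v) for v in y)] += 1
    n_perms = math.factorial(n_rows)
    assert len(counts) == n_perms  # every permutation should have appeared
    for p in itertools.permutations(range(n_rows)):
        assert abs(counts[p] / float(repeat) - 1.0 / n_perms) < tol

check_uniform_shuffle()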

@@ -552,6 +554,56 @@ def compute_expected_prob():
mx.test_utils.assert_almost_equal(exp_cnt_sampled.asnumpy(), exp_cnt[sampled_classes].asnumpy(), rtol=1e-1, atol=1e-2)
mx.test_utils.assert_almost_equal(exp_cnt_true.asnumpy(), exp_cnt[true_classes].asnumpy(), rtol=1e-1, atol=1e-2)

@with_seed()
def test_shuffle():
def hash(arr):
Contributor:

Concerned about the speed of the hash function.

Contributor Author (@asitstands, Mar 10, 2018):

We need to distinguish the order of the elements in the array, so the hash has to take all elements into account. This is a simple one-to-one hash, but it does its job, and it does not seem to be a bottleneck.

Contributor:

I understand your logic here. I think there are too many blocking calls involved (asscalar). We want to do our best to squeeze the time of unit tests. Since your input array is small, you can use hash(str(arr)), which should be more efficient (in terms of speed) and more reliable (in terms of hash value collisions) than the current implementation.
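
For concreteness, a sketch of the suggested hash (assuming arr is an MXNet NDArray; asnumpy() is a single blocking call that copies the whole array to the host):

def ndarray_hash(arr):
    # Hash the full string form of the array. Distinct element orders give
    # distinct strings for small arrays; note that str() elides elements of
    # very large numpy arrays, so this is only suitable for small inputs.
    return hash(str(arr.asnumpy()))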

d = mx.sym.sort(c, axis=0)
assert (d.eval(a=data, ctx=mx.current_context())[0] == data).prod() == 1

test(mx.nd.arange(0, 3), 10, 20000)
Contributor:

The array sizes in the following three test cases don't seem big enough.

Contributor Author (@asitstands, Mar 10, 2018):

Verifying the uniformity of the distribution of the outcomes requires a very large number of samples for larger arrays; the number of samples needed grows factorially with the length of the first axis. I added a weaker test for larger arrays.
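
A sketch of one such weaker check, consistent with the sort-based assertion in the diff context above (numpy as a stand-in for the operator):

import numpy as np

def weak_shuffle_check(n_rows=1000, n_cols=10):
    # For large arrays: the result must be a row permutation of the input,
    # so sorting along the first axis restores the original.
    rng = np.random.RandomState(0)
    x = np.arange(n_rows * n_cols).reshape(n_rows, n_cols)
    y = x.copy()
    rng.shuffle(y)                                # shuffles rows (axis 0)
    assert not np.array_equal(x, y)               # almost surely changed
    assert np.array_equal(np.sort(y, axis=0), x)  # rows are intact

weak_shuffle_check()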

// No parameter is declared.
// No backward computation is registered. Shuffling is not differentiable.

NNVM_REGISTER_OP(_shuffle)
Contributor:

Ok, I forgot @piiswrong refactored it. It makes sense to keep the python interface.

for (index_t i = first_axis_len - 1; i > 0; --i) {
const index_t j = rand_n(i + 1);
if (i != j) {
std::swap_ranges(out + stride * i, out + stride * (i + 1), out + stride * j);
Contributor:

It's okay to use a single thread for swapping two ranges for now. We may need to consider using multiple threads to saturate the memory bandwidth when the number of elements per row is big, if this is identified as a bottleneck.

Contributor:

It would be good if we could have a benchmark for this

Contributor Author (@asitstands, Mar 12, 2018):

I guess the optimization may not be trivial. Anyway, here are some tests with a very naive OpenMP parallelization. It simply splits the ranges to swap into multiple pieces and gives each piece to an OpenMP thread. Multiple threads benefit arrays with a large number of elements per row when running on two Xeon E5-2680 CPUs, but there is no gain on a single i7-7700. For small arrays, multiple threads perform very poorly on either CPU. There could be more sophisticated optimizations for this kind of memory copy, but I have no idea.

Test with two Xeon E5-2680 CPUs.

# ./a.out num_rows num_cols num_repeats num_threads
# measures the running time of the two implementations in microseconds.

> ./a.out 100 10000000 10 4
multi  : 3861601 us
single : 9080845 us

> ./a.out 100 1000000 10 4
multi  : 338396 us
single : 861971 us

> ./a.out 100 100000 10 4
multi  : 21387 us
single : 57533 us

> ./a.out 100 10000 10 4
multi  : 6956 us
single : 4073 us

> ./a.out 100 1000 10 4
multi  : 5886 us
single : 597 us

> ./a.out 100 100 10 4
multi  : 4606 us
single : 139 us

Test with i7-7700.


> ./a.out 100 10000000 10 4
multi  : 10015002 us
single : 9327057 us

> ./a.out 100 1000000 10 4
multi  : 969582 us
single : 918764 us

> ./a.out 100 100000 10 4
multi  : 77717 us
single : 75001 us

> ./a.out 100 10000 10 4
multi  : 1850 us
single : 2016 us

> ./a.out 100 1000 10 4
multi  : 1911 us
single : 209 us

> ./a.out 100 10000000 10 2
multi  : 9478994 us
single : 9451969 us

> ./a.out 100 1000000 10 2
multi  : 936728 us
single : 918129 us

> ./a.out 100 100000 10 2
multi  : 75222 us
single : 75331 us

> ./a.out 100 10000 10 2
multi  : 1885 us
single : 1953 us

> ./a.out 100 1000 10 2
multi  : 1425 us
single : 204 us

Here is the code.

#include <iostream>
#include <algorithm>
#include <random>
#include <chrono>
#include <numeric>  // std::iota
#include <vector>   // std::vector

using index_t = unsigned int;

// The current implementation
template<typename DType, typename Rand>
void ShuffleND(DType* const out, const index_t size,
               const index_t first_axis_len, Rand* const prnd) {
  const index_t stride = size / first_axis_len;
  auto rand_n = [prnd](index_t n) {
    std::uniform_int_distribution<index_t> dist(0, n - 1);
    return dist(*prnd);
  };
  for (index_t i = first_axis_len - 1; i > 0; --i) {
    const index_t j = rand_n(i + 1);
    if (i != j) {
      std::swap_ranges(out + stride * i, out + stride * (i + 1), out + stride * j);
    }
  }
}

// Naive parallelization with openmp
template<typename DType, typename Rand>
void ShuffleND_M(const unsigned int n_threads, DType* const out, const index_t size,
                 const index_t first_axis_len, Rand* const prnd) {
  const index_t stride = size / first_axis_len;
  auto rand_n = [prnd](index_t n) {
    std::uniform_int_distribution<index_t> dist(0, n - 1);
    return dist(*prnd);
  };
  for (index_t i = first_axis_len - 1; i > 0; --i) {
    const index_t j = rand_n(i + 1);
    if (i != j) {
      // This loop is different from the current implementation.
      #pragma omp parallel for num_threads(n_threads)
      for(unsigned int k = 0; k < n_threads; ++k) {
        std::swap_ranges(out + stride * i + k * stride / n_threads,
                         out + stride * i + (k + 1) * stride / n_threads,
                         out + stride * j + k * stride / n_threads);
      }
    }
  }
}

int main(int argc, char* argv[]) {
  using namespace std;
  using namespace std::chrono;

  const size_t n_rows = stol(argv[1]);
  const size_t n_cols = stol(argv[2]);
  const size_t n_repeats = stol(argv[3]);
  const unsigned int n_threads = stol(argv[4]);

  vector<float> vec(n_rows * n_cols);
  iota(vec.begin(), vec.end(), 0);
  mt19937 rnd((random_device())());

  high_resolution_clock::time_point t1;
  high_resolution_clock::time_point t2;

  t1 = high_resolution_clock::now();
  for(unsigned int i = 0; i < n_repeats; ++i) {
    ShuffleND_M(n_threads, vec.data(), vec.size(), n_rows, &rnd);
  }
  t2 = high_resolution_clock::now();
  cout << "multi  : " << duration_cast<microseconds>(t2 - t1).count() << " us" << endl;

  t1 = high_resolution_clock::now();
  for(unsigned int i = 0; i < n_repeats; ++i) {
    ShuffleND(vec.data(), vec.size(), n_rows, &rnd);
  }
  t2 = high_resolution_clock::now();
  cout << "single : " << duration_cast<microseconds>(t2 - t1).count() << " us" << endl;

  return 0;
}

c = count.get(h, 0)
count[h] = c + 1
# Check the total number of possible outcomes.
assert len(count) == math.factorial(data.shape[0])
Contributor:

In theory this is true. In practice, it could fail if the number of rows is big enough. Add a proper comment here so that people who add more test cases later know what happened when it fails.
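
To quantify how this assert scales, a small coupon-collector estimate (illustrative only, not part of the PR): the expected number of uniform draws needed to observe all n! permutations at least once is about n! * ln(n!), so the required sample count explodes as the number of rows grows.

import math

def samples_to_cover_all_permutations(n_rows):
    k = math.factorial(n_rows)
    harmonic = sum(1.0 / i for i in range(1, k + 1))
    return k * harmonic  # coupon-collector expectation, ~ k * ln(k)

for n in (2, 3, 4, 5):
    print(n, math.factorial(n), int(samples_to_cover_all_permutations(n)))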

Contributor Author:

I revised the comments in the test.

# Test small arrays with different shapes
testSmall(mx.nd.arange(0, 3), 100, 20000)
testSmall(mx.nd.arange(0, 9).reshape((3, 3)), 100, 20000)
testSmall(mx.nd.arange(0, 12).reshape((2, 2, 3)), 100, 20000)
Contributor:

Suggest using a bigger number of rows instead of 2, which is too trivial.

Contributor Author:

Actually, any number >= 2 is equivalent. Anyway, I changed it to 3 :)

assert len(count) == math.factorial(data.shape[0])
# The outcomes must be uniformly distributed.
for p in itertools.permutations(range(0, data.size - stride + 1, stride)):
assert abs(count[hash(mx.nd.array(p))] / repeat2 - 1 / math.factorial(data.shape[0])) < 0.01
Contributor:

Did you forget to convert integers to floats before performing divisions?

Contributor Author (@asitstands, Mar 11, 2018):

I forgot about the difference between Python 2 and 3. Fixed.
The hash is also changed. Actually the test is faster now! Thanks.
(There seems to be no way to reply to the comment about the hash?)
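
For the record, the pitfall being fixed (illustrative snippet):

import math

n = math.factorial(3)
p = 1 / n    # Python 2: integer division yields 0; Python 3: 0.1666...
p = 1.0 / n  # explicit float arithmetic behaves the same on both versions
# (Alternatively: from __future__ import division at the top of the file.)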

@reminisce (Contributor):
Agree that multi-threaded copy is not trivial to optimize due to hardware/data-size differences. This PR has provided a functional shuffle op and unit tests with sufficient coverage.
@piiswrong Do you mind taking a look and deciding whether to merge it?

@piiswrong piiswrong merged commit 9bbdc16 into apache:master Mar 20, 2018
ashokei pushed a commit to ashokei/incubator-mxnet that referenced this pull request Mar 27, 2018
* Random shuffle implementation

* Refactoring to avoid a preprocessing problem in Windows build

* Cosmetic changes

* Typo

* Adding const keyword at several places

* Fix the bug that integer arrays are not allowed

* Revise the comments to explain the unit test

* Add a check for correct array shape

* Revised unit test with larger arrays

* Replace the custom hash with 'str'

* Fix a bug due to the integer arithmetic in python2

* Revise comments for the unit test

* Fix the invalid fix in the commit f240714

* Update random.md

* Update random.md
@marcoabreu (Contributor):
#10277

jinhuang415 pushed a commit to jinhuang415/incubator-mxnet that referenced this pull request Mar 30, 2018
@asitstands asitstands deleted the shuffle_op_np branch May 22, 2018 07:16
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018
@eric-haibin-lin (Member):
@asitstands this creates different results on different platforms. It would be ideal if the behavior were consistent (like numpy.shuffle).

@asitstands (Contributor Author):
@eric-haibin-lin I agree that consistency across different platforms is highly desirable. On GPU, I think the result is already consistent, as the implementation does not delegate the shuffle to another library. I'll try to implement the CPU shuffle independent of the compiler and OS soon.

@leezu (Contributor) commented Mar 15, 2020:

@asitstands why is this feature enabled under __clang_major__ > 4 && __linux__?
I don't think clang implements the GNU parallel extensions.

The condition __clang_major__ > 4 && __linux__ leads to errors if we compile with clang and OpenMP disabled. Let's track the OS-specific behaviour issue in #17836
