[MXNET-68] Random shuffle implementation #10048
@@ -23,7 +23,7 @@
 __all__ = ['uniform', 'normal', 'poisson', 'exponential', 'gamma', 'multinomial',
-           'negative_binomial', 'generalized_negative_binomial']
+           'negative_binomial', 'generalized_negative_binomial', 'shuffle']


 def _random_helper(random, sampler, params, shape, dtype, kwargs):
@@ -247,3 +247,34 @@ def multinomial(data, shape=_Null, get_prob=True, **kwargs):
    reward as head gradient w.r.t. this array to estimate gradient.
    """
    return _internal._sample_multinomial(data, shape, get_prob, **kwargs)


def shuffle(data, **kwargs):
Review comment: Python interface seems unnecessary. You can register the operator with name _random_shuffle in C++.
Reply: Still valid @reminisce?
    """Shuffle the elements randomly.

    This shuffles the array along the first axis.
    The order of the elements in each subarray does not change.
    For example, if a 2D array is given, the order of the rows randomly changes,
    but the order of the elements in each row does not change.

    Parameters
    ----------
    data : NDArray
        Input data array.

    Examples
    --------
    >>> data = mx.nd.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
    >>> a = mx.sym.Variable('a')
    >>> b = mx.sym.random.shuffle(a)
    >>> b.eval(a=data)
    [[ 0.  1.  2.]
     [ 6.  7.  8.]
     [ 3.  4.  5.]]
    <NDArray 3x3 @cpu(0)>
    >>> b.eval(a=data)
    [[ 3.  4.  5.]
     [ 0.  1.  2.]
     [ 6.  7.  8.]]
    <NDArray 3x3 @cpu(0)>
    """
    return _internal._shuffle(data, **kwargs)
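To make the documented semantics concrete outside MXNet, here is a minimal NumPy sketch of "shuffle along the first axis only" (illustrative; `shuffle_first_axis` is a name invented here, not part of this PR):

```python
import numpy as np

def shuffle_first_axis(arr, rng=None):
    """Permute only the first axis; the contents of each row stay intact."""
    rng = rng if rng is not None else np.random.RandomState()
    perm = rng.permutation(arr.shape[0])  # random permutation of row indices
    return arr[perm]

a = np.arange(9).reshape(3, 3)
print(shuffle_first_axis(a))  # rows reordered, each row unchanged
```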
shuffle_op.cc (new file)
@@ -0,0 +1,134 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2018 by Contributors
 * \file shuffle_op.cc
 * \brief Operator to shuffle elements of an NDArray
 */
#if (__GNUC__ > 4 && !defined(__clang__major__)) || (__clang_major__ > 4 && __linux__)
#define USE_GNU_PARALLEL_SHUFFLE
#endif

#include <mxnet/operator_util.h>
#include <algorithm>
#include <random>
#include <vector>
#ifdef USE_GNU_PARALLEL_SHUFFLE
#include <parallel/algorithm>
#endif
#include "../elemwise_op_common.h"

namespace mxnet {
namespace op {

namespace {

template<typename DType, typename Rand>
void Shuffle1D(DType* const out, const index_t size, Rand* const prnd) {
#ifdef USE_GNU_PARALLEL_SHUFFLE
  auto rand_n = [prnd](index_t n) {
    std::uniform_int_distribution<index_t> dist(0, n - 1);
    return dist(*prnd);
  };
  __gnu_parallel::random_shuffle(out, out + size, rand_n);
#else
  std::shuffle(out, out + size, *prnd);
#endif
}

template<typename DType, typename Rand>
void ShuffleND(DType* const out, const index_t size, const index_t first_axis_len,
               Rand* const prnd) {
  // Fisher-Yates shuffling
  const index_t stride = size / first_axis_len;
  auto rand_n = [prnd](index_t n) {
    std::uniform_int_distribution<index_t> dist(0, n - 1);
    return dist(*prnd);
  };
  CHECK_GT(first_axis_len, 0U);
  for (index_t i = first_axis_len - 1; i > 0; --i) {
    const index_t j = rand_n(i + 1);
    if (i != j) {
      std::swap_ranges(out + stride * i, out + stride * (i + 1), out + stride * j);
Review comment: It's okay to use a single thread for swapping two ranges for now. We may need to consider using multiple threads to saturate the memory bandwidth if the number of elements per row is big, should this be identified as a bottleneck.

Review comment: It would be good if we could have a benchmark for this.

Review comment: I guess that the optimization may not be trivial. Anyway, here are some tests with a very naive OpenMP parallelization: it simply splits the ranges to swap into multiple pieces and gives each piece to an OpenMP thread. Multiple threads benefit arrays with a large number of elements per row when run on two Xeon E5-2680 CPUs, but there is no gain on a single i7-7700. For small arrays, multiple threads perform very poorly on either CPU. There could be more sophisticated optimizations for this kind of memory copy, but I have no idea.

Test with two Xeon E5-2680 CPUs.

Test with i7-7700.

Here is the code:

```cpp
#include <iostream>
#include <algorithm>
#include <numeric>   // for std::iota
#include <random>
#include <vector>
#include <chrono>

using index_t = unsigned int;

// The current implementation
template<typename DType, typename Rand>
void ShuffleND(DType* const out, const index_t size,
               const index_t first_axis_len, Rand* const prnd) {
  const index_t stride = size / first_axis_len;
  auto rand_n = [prnd](index_t n) {
    std::uniform_int_distribution<index_t> dist(0, n - 1);
    return dist(*prnd);
  };
  for (index_t i = first_axis_len - 1; i > 0; --i) {
    const index_t j = rand_n(i + 1);
    if (i != j) {
      std::swap_ranges(out + stride * i, out + stride * (i + 1), out + stride * j);
    }
  }
}

// Naive parallelization with OpenMP
template<typename DType, typename Rand>
void ShuffleND_M(const unsigned int n_threads, DType* const out, const index_t size,
                 const index_t first_axis_len, Rand* const prnd) {
  const index_t stride = size / first_axis_len;
  auto rand_n = [prnd](index_t n) {
    std::uniform_int_distribution<index_t> dist(0, n - 1);
    return dist(*prnd);
  };
  for (index_t i = first_axis_len - 1; i > 0; --i) {
    const index_t j = rand_n(i + 1);
    if (i != j) {
      // This loop is different from the current implementation.
      #pragma omp parallel for num_threads(n_threads)
      for (unsigned int k = 0; k < n_threads; ++k) {
        std::swap_ranges(out + stride * i + k * stride / n_threads,
                         out + stride * i + (k + 1) * stride / n_threads,
                         out + stride * j + k * stride / n_threads);
      }
    }
  }
}

int main(int argc, char* argv[]) {
  using namespace std;
  using namespace std::chrono;
  const size_t n_rows = stol(argv[1]);
  const size_t n_cols = stol(argv[2]);
  const size_t n_repeats = stol(argv[3]);
  const unsigned int n_threads = stol(argv[4]);
  vector<float> vec(n_rows * n_cols);
  iota(vec.begin(), vec.end(), 0);
  mt19937 rnd((random_device())());
  high_resolution_clock::time_point t1;
  high_resolution_clock::time_point t2;
  t1 = high_resolution_clock::now();
  for (unsigned int i = 0; i < n_repeats; ++i) {
    ShuffleND_M(n_threads, vec.data(), vec.size(), n_rows, &rnd);
  }
  t2 = high_resolution_clock::now();
  cout << "multi  : " << duration_cast<microseconds>(t2 - t1).count() << " us" << endl;
  t1 = high_resolution_clock::now();
  for (unsigned int i = 0; i < n_repeats; ++i) {
    ShuffleND(vec.data(), vec.size(), n_rows, &rnd);
  }
  t2 = high_resolution_clock::now();
  cout << "single : " << duration_cast<microseconds>(t2 - t1).count() << " us" << endl;
  return 0;
}
```
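(Judging from `main` above, the benchmark presumably builds with OpenMP enabled, e.g. `g++ -O2 -fopenmp`, and takes its arguments in the order `n_rows n_cols n_repeats n_threads`.)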
    }
  }
}

}  // namespace

void ShuffleForwardCPU(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  if (req[0] == kNullOp) {
    return;
  }
  CHECK_NE(req[0], kAddTo) << "Shuffle does not support AddTo";
  const TShape& input_shape = inputs[0].shape_;
  const index_t size = inputs[0].Size();
  const index_t first_axis_len = input_shape[0];
  Stream<cpu> *s = ctx.get_stream<cpu>();
  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<cpu, 1, DType> in = inputs[0].get_with_shape<cpu, 1, DType>(Shape1(size), s);
    Tensor<cpu, 1, DType> out = outputs[0].get_with_shape<cpu, 1, DType>(Shape1(size), s);
    auto& prnd = ctx.requested[0].get_random<cpu, index_t>(ctx.get_stream<cpu>())->GetRndEngine();
    if (req[0] != kWriteInplace) {
      std::copy(in.dptr_, in.dptr_ + size, out.dptr_);
    }
    if (input_shape.ndim() == 1) {
      Shuffle1D(out.dptr_, size, &prnd);
    } else {
      ShuffleND(out.dptr_, size, first_axis_len, &prnd);
    }
  });
}

// No parameter is declared.
// No backward computation is registered. Shuffling is not differentiable.

NNVM_REGISTER_OP(_shuffle)
Review comment: Why register as internal? You can register with name
Reply: It does not work.
Reply: Ok, I forgot @piiswrong refactored it. It makes sense to keep the python interface.
.add_alias("shuffle")
.describe(R"code(Randomly shuffle the elements.

This shuffles the array along the first axis.
The order of the elements in each subarray does not change.
For example, if a 2D array is given, the order of the rows randomly changes,
but the order of the elements in each row does not change.
)code")
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FResourceRequest>("FResourceRequest",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<ResourceRequest>{ResourceRequest::kRandom, ResourceRequest::kTempSpace};
  })
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
  [](const NodeAttrs& attrs) {
    return std::vector<std::pair<int, int>>{{0, 0}};
  })
.set_attr<FCompute>("FCompute<cpu>", ShuffleForwardCPU)
.add_argument("data", "NDArray-or-Symbol", "Data to be shuffled.");

}  // namespace op
}  // namespace mxnet
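The CPU ND path above is a single-threaded Fisher-Yates pass over rows: it walks the first axis backwards and swaps whole stride-length blocks. A pure-Python sketch of that loop may help when reading the C++ (illustrative only; `fisher_yates_rows` and `flat` are names invented here, not part of the PR):

```python
import random

def fisher_yates_rows(flat, n_rows):
    """Mirror of ShuffleND: walk i from the last row down, draw j uniformly
    from [0, i], and swap the two stride-length row blocks."""
    stride = len(flat) // n_rows  # elements per row
    for i in range(n_rows - 1, 0, -1):
        j = random.randint(0, i)  # inclusive bounds, like dist(0, n - 1) with n = i + 1
        if i != j:
            flat[stride * i:stride * (i + 1)], flat[stride * j:stride * (j + 1)] = \
                flat[stride * j:stride * (j + 1)], flat[stride * i:stride * (i + 1)]

data = list(range(12))
fisher_yates_rows(data, 4)  # four rows of three elements each
```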
shuffle_op.cu (new file)
@@ -0,0 +1,106 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2018 by Contributors
 * \file shuffle_op.cu
 * \brief Operator to shuffle elements of an NDArray
 */
#include <mxnet/operator_util.h>
#include <algorithm>
#include <random>
#include <vector>
#include "../elemwise_op_common.h"
#include "../tensor/init_op.h"

namespace mxnet {
namespace op {

namespace {

struct CopyForShuffle {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, const DType* const in, DType* out,
                                  const index_t* indices, const index_t stride) {
    out[i] = in[indices[i / stride] * stride + i % stride];
  }
};

}  // namespace

void ShuffleForwardGPU(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  if (req[0] == kNullOp) {
    return;
  }
  CHECK_NE(req[0], kAddTo) << "Shuffle does not support AddTo";
  const TShape& input_shape = inputs[0].shape_;
  const index_t size = inputs[0].Size();
  const index_t first_axis_len = input_shape[0];
  const index_t stride = size / first_axis_len;
  Stream<gpu> *s = ctx.get_stream<gpu>();
  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    using KeyType = index_t;
    Tensor<gpu, 1, DType> in = inputs[0].get_with_shape<gpu, 1, DType>(Shape1(size), s);
    Tensor<gpu, 1, DType> out = outputs[0].get_with_shape<gpu, 1, DType>(Shape1(size), s);
    Random<gpu, KeyType> *prnd = ctx.requested[0].get_random<gpu, KeyType>(s);
    if (input_shape.ndim() == 1) {
      if (req[0] != kWriteInplace) {
        Copy(out, in, s);
      }
      Tensor<gpu, 1, KeyType> keys =
        ctx.requested[1].get_space_typed<gpu, 1, KeyType>(Shape1(size), s);
      prnd->GetRandInt(keys);
      SortByKey(keys, out, true);
    } else {
      const size_t tmp_space_size = req[0] == kWriteInplace ?
        2 * first_axis_len * sizeof(index_t) + size * sizeof(DType) :
        2 * first_axis_len * sizeof(index_t);
      Tensor<gpu, 1, char> tmp_space =
        ctx.requested[1].get_space_typed<gpu, 1, char>(Shape1(tmp_space_size), s);
      char* tmp_space_ptr = tmp_space.dptr_;
      Tensor<gpu, 1, index_t> indices(reinterpret_cast<index_t*>(tmp_space_ptr),
                                      Shape1(first_axis_len), s);
      tmp_space_ptr += sizeof(index_t) * first_axis_len;
      Kernel<range_fwd, gpu>::Launch(s, first_axis_len, 1, 0U, 1U, kWriteTo, indices.dptr_);
      Tensor<gpu, 1, KeyType> keys(reinterpret_cast<KeyType*>(tmp_space_ptr),
                                   Shape1(first_axis_len), s);
      tmp_space_ptr += sizeof(KeyType) * first_axis_len;
      prnd->GetRandInt(keys);
      SortByKey(keys, indices, true);
      if (req[0] == kWriteInplace) {
        Tensor<gpu, 1, DType> buf(reinterpret_cast<DType*>(tmp_space_ptr), Shape1(size), s);
        Copy(buf, in, s);
        Kernel<CopyForShuffle, gpu>::Launch(s, size, buf.dptr_, out.dptr_, indices.dptr_, stride);
      } else {
        Kernel<CopyForShuffle, gpu>::Launch(s, size, in.dptr_, out.dptr_, indices.dptr_, stride);
      }
    }
  });
}

NNVM_REGISTER_OP(_shuffle)
.set_attr<FCompute>("FCompute<gpu>", ShuffleForwardGPU);

}  // namespace op
}  // namespace mxnet
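Unlike the CPU path, the GPU path derives the permutation by drawing one random key per row, sorting the row indices by those keys (`SortByKey`), and then gathering rows with the `CopyForShuffle` kernel. The same idea in NumPy terms (illustrative sketch; `sort_based_shuffle` is a name invented here, not the PR's API):

```python
import numpy as np

def sort_based_shuffle(arr, rng=None):
    """Mirror of the GPU ND case: random key per row -> argsort -> gather."""
    rng = rng if rng is not None else np.random.RandomState()
    keys = rng.randint(0, np.iinfo(np.int32).max, size=arr.shape[0])
    perm = np.argsort(keys)  # plays the role of SortByKey(keys, indices)
    stride = arr.size // arr.shape[0]
    flat, out = arr.ravel(), np.empty(arr.size, dtype=arr.dtype)
    for i in range(arr.size):  # same index arithmetic as CopyForShuffle
        out[i] = flat[perm[i // stride] * stride + i % stride]
    return out.reshape(arr.shape)
```

One consequence of this design: if two rows happen to draw the same key, their relative order is decided by the sort rather than by chance, so the permutation is only approximately uniform under key collisions; with 32-bit keys this is rare in practice.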