Skip to content

Commit

Permalink
add unique_consecutive op
Browse files Browse the repository at this point in the history
  • Loading branch information
firestonelib committed Jul 27, 2021
1 parent 62ccea3 commit 5f0fd5f
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 151 deletions.
98 changes: 97 additions & 1 deletion paddle/fluid/operators/unique_consecutive_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,109 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/tensor_util.h" // TensorToVector()
#include "paddle/fluid/operators/unique_consecutive_op.h" // TransComute()
#include "paddle/fluid/operators/unique_utils.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// Binary function 'equal_to'
template <typename InT>
struct BinaryEqual {
int64_t col;
const InT* in_trans_data;
BinaryEqual(int64_t _col, const InT* _in_trans_data)
: col(_col), in_trans_data(_in_trans_data) {}
__device__ bool operator()(int64_t a, int64_t b) const {
for (int64_t i = 0; i < col; ++i) {
InT lhs = in_trans_data[i + a * col];
InT rhs = in_trans_data[i + b * col];
if (lhs != rhs) {
return false;
}
}
return true;
}
};

// Binary function 'not_equal_to'
template <typename InT>
struct BinaryNotEqual {
int64_t col;
const InT* in_trans_data;
BinaryNotEqual(int64_t _col, const InT* _in_trans_data)
: col(_col), in_trans_data(_in_trans_data) {}
__device__ bool operator()(int64_t a, int64_t b) const {
for (int64_t i = 0; i < col; ++i) {
InT lhs = in_trans_data[i + a * col];
InT rhs = in_trans_data[i + b * col];
if (lhs != rhs) {
return true;
}
}
return false;
}
};

// index_select() function for Tensor
template <typename InT, typename IndexT>
void IndexSelect(const framework::ExecutionContext& context,
const Tensor& input, const Tensor& index, Tensor* output,
int dim) {
auto input_dim = input.dims();
auto input_dim_size = input_dim.size();
auto output_dim = output->dims();

auto slice_size = 1;
for (auto i = dim + 1; i < input_dim_size; i++) {
slice_size *= input_dim[i];
}
auto input_width = slice_size * input_dim[dim];
auto output_width = slice_size * output_dim[dim];

auto outer_nums = 1;
for (auto i = 0; i < dim; i++) {
outer_nums *= input_dim[i];
}
auto index_size = index.dims()[0];
std::vector<InT> input_vec;
std::vector<IndexT> index_vec;
TensorToVector(input, context.device_context(), &input_vec);
TensorToVector(index, context.device_context(), &index_vec);
std::vector<InT> out_vec(output->numel());
for (int i = 0; i < index_size; i++) {
PADDLE_ENFORCE_GE(
index_vec[i], 0,
platform::errors::InvalidArgument(
"Variable value (index) of OP(index_select) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
input_dim[dim], index_vec[i]));
PADDLE_ENFORCE_LT(
index_vec[i], input_dim[dim],
platform::errors::InvalidArgument(
"Variable value (index) of OP(index_select) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
input_dim[dim], index_vec[i]));
}
for (auto i = 0; i < outer_nums; i++) {
auto input_start_offset = i * input_width;
auto output_start_offset = i * output_width;

for (auto j = 0; j < index_size; j++) {
IndexT index_value = index_vec[j];
for (auto k = 0; k < slice_size; k++) {
out_vec[output_start_offset + j * slice_size + k] =
input_vec[input_start_offset + index_value * slice_size + k];
}
}
}
output->mutable_data<InT>(context.GetPlace());
framework::TensorFromVector(out_vec, context.device_context(), output);
output->Resize(output_dim);
}

// The core logic of computing Unique Consecutive for a flattend Tensor
template <typename InT, typename IndexT, typename equal_T, typename not_equal_T>
static void UniqueConsecutiveFlattendCUDATensor(
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/operators/unique_consecutive_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,25 @@ static void UniqueConsecutiveFlattendTensor(
inverse_vec[i] = p - out_vec.data();
}
}

int64_t output_size = p - out_vec.data() + 1;
if (return_counts) {
*q = in.numel() - last;
counts_vec.resize(output_size);
}
out_vec.resize(output_size);

out->Resize(framework::make_ddim({output_size}));
auto* out_data = out->mutable_data<InT>(context.GetPlace());
std::copy(out_vec.begin(), out_vec.end(), out_data);

if (return_inverse) {
auto* inverse = context.Output<framework::Tensor>("Index");
inverse->Resize(framework::make_ddim({in.numel()}));
auto* inverse_data = inverse->mutable_data<IndexT>(context.GetPlace());
std::copy(inverse_vec.begin(), inverse_vec.end(), inverse_data);
}

if (return_counts) {
auto* count = context.Output<framework::Tensor>("Counts");
count->Resize(framework::make_ddim({out->numel()}));
Expand All @@ -83,10 +87,13 @@ static ForwardIt UniqueConsecutiveDimImpl(
if (first == last) {
return last;
}

(*inverse_vec)[sorted_indices_vec[0]] = 0;
(*counts_vec)[0] = 1;

ForwardIt begin = first;
ForwardIt result = first;

while (++first != last) {
int64_t idx_first = std::distance(begin, first);
int64_t idx_result = std::distance(begin, result);
Expand Down Expand Up @@ -126,10 +133,12 @@ static void UniqueConsecutiveDim(const framework::ExecutionContext& context,
framework::DDim in_trans_flat_dims =
framework::flatten_to_2d(in_trans_dims, 1);
in_trans.Resize(in_trans_flat_dims);

std::vector<IndexT> sorted_indices_vec(in_trans.dims()[0]);
std::iota(sorted_indices_vec.begin(), sorted_indices_vec.end(), 0);
int64_t col = in_trans.dims()[1];
const InT* in_trans_data = in_trans.data<InT>();

// sort tensor according to indices
framework::Tensor input_sorted;
input_sorted.Resize(in_trans_dims);
Expand Down Expand Up @@ -241,6 +250,7 @@ class UniqueConsecutiveKernel : public framework::OpKernel<T> {
std::vector<int> axis_vec = context.Attr<std::vector<int>>("axis");
bool return_inverse = context.Attr<bool>("return_inverse");
bool return_counts = context.Attr<bool>("return_counts");

if (axis_vec.empty()) {
framework::VisitDataTypeTiny(
data_type, UniqueConsecutiveFlattendTensorFunctor<DeviceContext, T>(
Expand Down
150 changes: 0 additions & 150 deletions paddle/fluid/operators/unique_utils.h

This file was deleted.

0 comments on commit 5f0fd5f

Please sign in to comment.