Skip to content

Commit

Permalink
add unique_consecutive_op
Browse files Browse the repository at this point in the history
  • Loading branch information
firestonelib committed Jul 26, 2021
1 parent e911118 commit 28b2fa0
Showing 1 changed file with 11 additions and 36 deletions.
47 changes: 11 additions & 36 deletions paddle/fluid/operators/unique_consecutive_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,39 +24,15 @@ limitations under the License. */
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/operators/unique_op.h"

namespace paddle {
namespace operators {

static std::vector<framework::Tensor> Unbind(const framework::Tensor& in) {
int64_t size = in.dims()[0];
std::vector<framework::Tensor> tensors(size);
for (int64_t i = 0; i < size; ++i) {
tensors[i] = in.Slice(i, i + 1);
}
return tensors;
}

template <typename T>
static bool Equal(const framework::Tensor& a, const framework::Tensor& b) {
if (a.numel() != b.numel()) {
return false;
}
for (int64_t i = 0; i < a.numel(); ++i) {
if (a.data<T>()[i] != b.data<T>()[i]) {
return false;
}
}
return true;
}

template <typename InT, typename IndexT>
static void UniqueFlattendTensor(const framework::ExecutionContext& context,
const framework::Tensor& in,
framework::Tensor* out, bool return_inverse,
bool return_counts) {
static void UniqueConsecutiveFlattendTensor(
const framework::ExecutionContext& context, const framework::Tensor& in,
framework::Tensor* out, bool return_inverse, bool return_counts) {
const InT* in_data = in.data<InT>();

std::vector<InT> out_vec(in.numel());
std::vector<IndexT> inverse_vec(in.numel());
std::vector<IndexT> counts_vec(in.numel());
Expand Down Expand Up @@ -209,17 +185,16 @@ static void UniqueConsecutiveDim(const framework::ExecutionContext& context,
}

template <typename DeviceContext, typename InT>
struct UniqueFlattendTensorFunctor {
struct UniqueConsecutiveFlattendTensorFunctor {
const framework::ExecutionContext& ctx_;
const framework::Tensor& in_;
framework::Tensor* out_;
const bool return_inverse_;
const bool return_counts_;

UniqueFlattendTensorFunctor(const framework::ExecutionContext& context,
const framework::Tensor& in,
framework::Tensor* out, bool return_inverse,
bool return_counts)
UniqueConsecutiveFlattendTensorFunctor(
const framework::ExecutionContext& context, const framework::Tensor& in,
framework::Tensor* out, bool return_inverse, bool return_counts)
: ctx_(context),
in_(in),
out_(out),
Expand All @@ -228,8 +203,8 @@ struct UniqueFlattendTensorFunctor {

template <typename IndexT>
void apply() const {
UniqueFlattendTensor<InT, IndexT>(ctx_, in_, out_, return_inverse_,
return_counts_);
UniqueConsecutiveFlattendTensor<InT, IndexT>(
ctx_, in_, out_, return_inverse_, return_counts_);
}
};

Expand Down Expand Up @@ -283,7 +258,7 @@ class UniqueConsecutiveKernel : public framework::OpKernel<T> {

if (axis_vec.empty()) {
framework::VisitDataTypeTiny(
data_type, UniqueFlattendTensorFunctor<DeviceContext, T>(
data_type, UniqueConsecutiveFlattendTensorFunctor<DeviceContext, T>(
context, *x, out, return_inverse, return_counts));
} else {
int axis = axis_vec[0];
Expand Down

0 comments on commit 28b2fa0

Please sign in to comment.