Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 52 additions & 17 deletions paddle/phi/kernels/funcs/unique_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,35 @@
namespace phi {
namespace funcs {

template <typename T>
static bool NaNSafeEqual(const T& a, const T& b) {
if constexpr (std::is_floating_point_v<T>) {
if (std::isnan(a) && std::isnan(b)) {
return &a == &b;
}
if (std::isnan(a) || std::isnan(b)) {
return false;
}
}
return a == b;
}

template <typename T>
static bool NaNSafeLess(const T& a, const T& b) {
if constexpr (std::is_floating_point_v<T>) {
if (std::isnan(a) && !std::isnan(b)) {
return false;
}
if (!std::isnan(a) && std::isnan(b)) {
return true;
}
if (std::isnan(a) && std::isnan(b)) {
return &a < &b;
}
}
return a < b;
}

template <typename Context, typename InT>
struct UniqueOpFunctor {
const Context& dev_ctx_;
Expand Down Expand Up @@ -122,7 +151,7 @@ static bool Equal(const DenseTensor& a, const DenseTensor& b) {
return false;
}
for (int64_t i = 0; i < a.numel(); ++i) {
if (a.data<T>()[i] != b.data<T>()[i]) {
if (!NaNSafeEqual(a.data<T>()[i], b.data<T>()[i])) {
return false;
}
}
Expand All @@ -140,7 +169,15 @@ static void UniqueFlattenedTensor(const Context& dev_ctx,
bool return_inverse,
bool return_counts) {
const InT* in_data = in.data<InT>();
std::set<InT> unique(in_data, in_data + in.numel());

auto nan_safe_comp = [](const InT& a, const InT& b) {
return NaNSafeLess(a, b);
};
std::set<InT, decltype(nan_safe_comp)> unique(nan_safe_comp);
for (int64_t i = 0; i < in.numel(); ++i) {
unique.insert(in_data[i]);
}

out->Resize(common::make_ddim({static_cast<int64_t>(unique.size())}));
auto* out_data = dev_ctx.template Alloc<InT>(out);
std::copy(unique.begin(), unique.end(), out_data);
Expand All @@ -162,29 +199,27 @@ static void UniqueFlattenedTensor(const Context& dev_ctx,
if (return_inverse) {
index->Resize(common::make_ddim({in.numel()}));
auto inverse_data = dev_ctx.template Alloc<IndexT>(index);
std::unordered_map<InT, IndexT> inverse_map;
inverse_map.reserve(out->numel());
for (int64_t i = 0; i < out->numel(); ++i) {
inverse_map[out_data[i]] = i;
}
for (int64_t i = 0; i < in.numel(); ++i) {
inverse_data[i] = inverse_map[in_data[i]];
for (int64_t j = 0; j < out->numel(); ++j) {
if (NaNSafeEqual(in_data[i], out_data[j])) {
inverse_data[i] = j;
break;
}
}
}
}

if (return_counts) {
count->Resize(common::make_ddim({out->numel()}));
auto count_data = dev_ctx.template Alloc<IndexT>(count);
std::unordered_map<InT, IndexT> counts_map;
counts_map.reserve(out->numel());
for (int64_t i = 0; i < out->numel(); ++i) {
counts_map[out_data[i]] = 0;
}
for (int64_t i = 0; i < in.numel(); i++) {
counts_map[in_data[i]] += 1;
}
for (int64_t i = 0; i < out->numel(); i++) {
count_data[i] = counts_map[out_data[i]];
IndexT cnt = 0;
for (int64_t j = 0; j < in.numel(); ++j) {
if (NaNSafeEqual(out_data[i], in_data[j])) {
cnt++;
}
}
count_data[i] = cnt;
}
}
}
Expand Down