2323namespace phi {
2424namespace funcs {
2525
template <typename T>
static bool NaNSafeEqual(const T& lhs, const T& rhs) {
  // Equality that, unlike IEEE-754 `==`, treats NaN as equal to NaN so that
  // dedup/unique logic collapses all NaN payloads into a single bucket.
  // For non-floating-point T this is plain operator==.
  if constexpr (std::is_floating_point_v<T>) {
    const bool lhs_nan = std::isnan(lhs);
    const bool rhs_nan = std::isnan(rhs);
    if (lhs_nan || rhs_nan) {
      // Equal only when both sides are NaN.
      return lhs_nan && rhs_nan;
    }
  }
  return lhs == rhs;
}
38+
template <typename T>
static bool NaNSafeLess(const T& lhs, const T& rhs) {
  // Strict weak ordering that sorts NaN after every finite/infinite value;
  // two NaNs compare equivalent (neither is less than the other), which is
  // what makes this usable as a std::set comparator where raw `<` on NaN
  // would violate the ordering requirements.
  if constexpr (std::is_floating_point_v<T>) {
    if (std::isnan(lhs)) {
      return false;  // NaN is never less than anything, including NaN.
    }
    if (std::isnan(rhs)) {
      return true;  // Every non-NaN value sorts before NaN.
    }
  }
  return lhs < rhs;
}
54+
2655template <typename Context, typename InT>
2756struct UniqueOpFunctor {
2857 const Context& dev_ctx_;
@@ -122,7 +151,7 @@ static bool Equal(const DenseTensor& a, const DenseTensor& b) {
122151 return false ;
123152 }
124153 for (int64_t i = 0 ; i < a.numel (); ++i) {
125- if (a.data <T>()[i] != b.data <T>()[i]) {
154+ if (! NaNSafeEqual ( a.data <T>()[i], b.data <T>()[i]) ) {
126155 return false ;
127156 }
128157 }
@@ -140,7 +169,15 @@ static void UniqueFlattenedTensor(const Context& dev_ctx,
140169 bool return_inverse,
141170 bool return_counts) {
142171 const InT* in_data = in.data <InT>();
143- std::set<InT> unique (in_data, in_data + in.numel ());
172+
173+ auto nan_safe_comp = [](const InT& a, const InT& b) {
174+ return NaNSafeLess (a, b);
175+ };
176+ std::set<InT, decltype (nan_safe_comp)> unique (nan_safe_comp);
177+ for (int64_t i = 0 ; i < in.numel (); ++i) {
178+ unique.insert (in_data[i]);
179+ }
180+
144181 out->Resize (common::make_ddim ({static_cast <int64_t >(unique.size ())}));
145182 auto * out_data = dev_ctx.template Alloc <InT>(out);
146183 std::copy (unique.begin (), unique.end (), out_data);
@@ -162,29 +199,27 @@ static void UniqueFlattenedTensor(const Context& dev_ctx,
162199 if (return_inverse) {
163200 index->Resize (common::make_ddim ({in.numel ()}));
164201 auto inverse_data = dev_ctx.template Alloc <IndexT>(index);
165- std::unordered_map<InT, IndexT> inverse_map;
166- inverse_map.reserve (out->numel ());
167- for (int64_t i = 0 ; i < out->numel (); ++i) {
168- inverse_map[out_data[i]] = i;
169- }
170202 for (int64_t i = 0 ; i < in.numel (); ++i) {
171- inverse_data[i] = inverse_map[in_data[i]];
203+ for (int64_t j = 0 ; j < out->numel (); ++j) {
204+ if (NaNSafeEqual (in_data[i], out_data[j])) {
205+ inverse_data[i] = j;
206+ break ;
207+ }
208+ }
172209 }
173210 }
174211
175212 if (return_counts) {
176213 count->Resize (common::make_ddim ({out->numel ()}));
177214 auto count_data = dev_ctx.template Alloc <IndexT>(count);
178- std::unordered_map<InT, IndexT> counts_map;
179- counts_map.reserve (out->numel ());
180215 for (int64_t i = 0 ; i < out->numel (); ++i) {
181- counts_map[out_data[i]] = 0 ;
182- }
183- for ( int64_t i = 0 ; i < in. numel (); i++ ) {
184- counts_map[in_data[i]] += 1 ;
185- }
186- for ( int64_t i = 0 ; i < out-> numel (); i++) {
187- count_data[i] = counts_map[out_data[i]] ;
216+ IndexT cnt = 0 ;
217+ for ( int64_t j = 0 ; j < in. numel (); ++j) {
218+ if ( NaNSafeEqual (out_data[i], in_data[j]) ) {
219+ cnt++ ;
220+ }
221+ }
222+ count_data[i] = cnt ;
188223 }
189224 }
190225}
0 commit comments