-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.6】 为 Paddle 增强put_along_axis API -part #59674
Changes from all commits
b6cedc3
36a2405
3818298
18dc8c3
61461f4
013dfb4
1155d4c
efc488c
675b641
2c968e3
d32db6f
564de93
d0c14de
28aadd6
c2fdb6f
d3e33a2
1903b91
5a88f79
8f03013
d215bf8
8c13743
a6b7d16
db3168f
2525f2a
7709066
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -395,6 +395,182 @@ CUDA_ATOMIC_WRAPPER(Add, complex<double>) { | |
CudaAtomicAdd(imag, val.imag)); | ||
} | ||
|
||
// For atomicMul. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这些atomicMul的计算算法,能提供下参考吗 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的atomicMul都是参考的前面的atomicAdd以及后面的atomicMin这些,只是把加改成了乘 |
||
CUDA_ATOMIC_WRAPPER(Mul, int) { | ||
int res = *address, old = res; // NOLINT | ||
do { | ||
old = res; | ||
res = atomicCAS(address, // NOLINT | ||
old, // NOLINT | ||
val * old); // NOLINT | ||
} while (old != res); | ||
return res; | ||
} | ||
|
||
CUDA_ATOMIC_WRAPPER(Mul, unsigned int) { | ||
unsigned int res = *address, old = res; // NOLINT | ||
do { | ||
old = res; | ||
res = atomicCAS(address, // NOLINT | ||
old, // NOLINT | ||
val * old); // NOLINT | ||
} while (old != res); | ||
return res; | ||
} | ||
// CUDA API uses unsigned long long int, we cannot use uint64_t here. | ||
// It because unsigned long long int is not necessarily uint64_t | ||
CUDA_ATOMIC_WRAPPER(Mul, unsigned long long int) { // NOLINT | ||
unsigned long long int old = *address, assumed; // NOLINT | ||
|
||
do { | ||
assumed = old; | ||
old = atomicCAS(address, assumed, val * assumed); | ||
} while (assumed != old); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里是否也应该有一个返回值 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
return old; | ||
} | ||
|
||
CUDA_ATOMIC_WRAPPER(Mul, int64_t) { | ||
// Here, we check long long int must be int64_t. | ||
static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT | ||
"long long should be int64"); | ||
long long int res = *address, old = res; // NOLINT | ||
do { | ||
old = res; | ||
res = (long long int)atomicCAS( // NOLINT | ||
(unsigned long long int *)address, // NOLINT | ||
(unsigned long long int)old, // NOLINT | ||
(unsigned long long int)val * (unsigned long long int)old); // NOLINT | ||
} while (old != res); | ||
return res; | ||
} | ||
|
||
CUDA_ATOMIC_WRAPPER(Mul, float) { | ||
int *const address_as_i = reinterpret_cast<int *>(address); | ||
int old = *address_as_i, assumed; | ||
|
||
do { | ||
assumed = old; | ||
old = atomicCAS( | ||
address_as_i, assumed, __float_as_int(val * __int_as_float(assumed))); | ||
} while (assumed != old); | ||
|
||
return __int_as_float(old); | ||
} | ||
|
||
CUDA_ATOMIC_WRAPPER(Mul, double) { | ||
unsigned long long int *const address_as_ull = // NOLINT | ||
reinterpret_cast<unsigned long long int *>(address); // NOLINT | ||
unsigned long long int old = *address_as_ull, assumed; // NOLINT | ||
|
||
do { | ||
assumed = old; | ||
|
||
old = atomicCAS(address_as_ull, | ||
assumed, | ||
__double_as_longlong(val * __longlong_as_double(assumed))); | ||
} while (assumed != old); | ||
|
||
return __longlong_as_double(old); | ||
} | ||
|
||
#ifdef PADDLE_CUDA_FP16 | ||
inline static __device__ uint32_t mul_to_low_half(uint32_t val, float x) { | ||
phi::dtype::float16 low_half; | ||
// The float16 in lower 16bits | ||
low_half.x = static_cast<uint16_t>(val & 0xFFFFu); | ||
low_half = static_cast<phi::dtype::float16>(static_cast<float>(low_half) * x); | ||
return (val & 0xFFFF0000u) | low_half.x; | ||
} | ||
|
||
inline static __device__ uint32_t mul_to_high_half(uint32_t val, float x) { | ||
phi::dtype::float16 high_half; | ||
// The float16 in higher 16bits | ||
high_half.x = static_cast<uint16_t>(val >> 16); | ||
high_half = | ||
static_cast<phi::dtype::float16>(static_cast<float>(high_half) * x); | ||
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16); | ||
} | ||
|
||
CUDA_ATOMIC_WRAPPER(Mul, phi::dtype::float16) { | ||
if (*address >= val) { | ||
return *address; | ||
} | ||
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>( | ||
reinterpret_cast<char *>(address) - | ||
(reinterpret_cast<uintptr_t>(address) & 0x02)); | ||
float val_f = static_cast<float>(val); | ||
uint32_t old = *address_as_ui; | ||
uint32_t assumed; | ||
if (((uintptr_t)address & 0x02) == 0) { | ||
// The float16 value stay at lower 16 bits of the address. | ||
do { | ||
assumed = old; | ||
old = atomicCAS(address_as_ui, assumed, mul_to_low_half(assumed, val_f)); | ||
} while (old != assumed); | ||
phi::dtype::float16 ret; | ||
ret.x = old & 0xFFFFu; | ||
return ret; | ||
} else { | ||
// The float16 value stay at higher 16 bits of the address. | ||
do { | ||
assumed = old; | ||
old = atomicCAS(address_as_ui, assumed, mul_to_high_half(assumed, val_f)); | ||
} while (old != assumed); | ||
phi::dtype::float16 ret; | ||
ret.x = old >> 16; | ||
return ret; | ||
} | ||
} | ||
#endif | ||
|
||
inline static __device__ uint32_t bf16_mul_to_low_half(uint32_t val, float x) { | ||
phi::dtype::bfloat16 low_half; | ||
// The bfloat16 in lower 16bits | ||
low_half.x = static_cast<uint16_t>(val & 0xFFFFu); | ||
low_half = | ||
static_cast<phi::dtype::bfloat16>(static_cast<float>(low_half) * x); | ||
return (val & 0xFFFF0000u) | low_half.x; | ||
} | ||
|
||
inline static __device__ uint32_t bf16_mul_to_high_half(uint32_t val, float x) { | ||
phi::dtype::bfloat16 high_half; | ||
// The bfloat16 in higher 16bits | ||
high_half.x = static_cast<uint16_t>(val >> 16); | ||
high_half = | ||
static_cast<phi::dtype::bfloat16>(static_cast<float>(high_half) * x); | ||
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16); | ||
} | ||
|
||
CUDA_ATOMIC_WRAPPER(Mul, phi::dtype::bfloat16) { | ||
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>( | ||
reinterpret_cast<char *>(address) - | ||
(reinterpret_cast<uintptr_t>(address) & 0x02)); | ||
float val_f = static_cast<float>(val); | ||
uint32_t old = *address_as_ui; | ||
uint32_t assumed; | ||
if (((uintptr_t)address & 0x02) == 0) { | ||
// The bfloat16 value stay at lower 16 bits of the address. | ||
do { | ||
assumed = old; | ||
old = atomicCAS( | ||
address_as_ui, assumed, bf16_mul_to_low_half(assumed, val_f)); | ||
} while (old != assumed); | ||
phi::dtype::bfloat16 ret; | ||
ret.x = old & 0xFFFFu; | ||
return ret; | ||
} else { | ||
// The bfloat16 value stay at higher 16 bits of the address. | ||
do { | ||
assumed = old; | ||
old = atomicCAS( | ||
address_as_ui, assumed, bf16_mul_to_high_half(assumed, val_f)); | ||
} while (old != assumed); | ||
phi::dtype::bfloat16 ret; | ||
ret.x = old >> 16; | ||
return ret; | ||
} | ||
} | ||
|
||
// For atomicMax | ||
USE_CUDA_ATOMIC(Max, int); | ||
USE_CUDA_ATOMIC(Max, unsigned int); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is a parameter of
broadcast
in Python API, shall we also add it here asinclude_self
or delete it from API?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
broadcast
is processed in the Python interface, so there is no need to pass it into the C interface again