Stochastic Rounding Optimizers #17
base: master
Conversation
    return TWO_24;
}

// Natalia magic
- add docstring to `torch.stochastic_rounding`
- refer `torch.stochastic_rounding` from stochastic rounding optimizers for details
Force-pushed from efbfbf7 to 8ea3245.
from .sgd import SGD
from .sradam import SRAdam
from .sradamw import SRAdamW
from .srsgd import SRSGD
It's unclear whether we want to expose stochastic rounding as separate optimizer classes or as constructor options on the existing optimizers. I don't think we need to make this choice now, though; the Facebook people should look at what you've got, and then we can decide together.
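To make the trade-off concrete, here is a minimal sketch of what the "constructor option" alternative could look like. The names (`stochastic_rounding=`, the `step` signature, the grid spacing) are illustrative assumptions, not the PR's actual API.

```python
# Hypothetical sketch: instead of a separate SRSGD class, plain SGD takes a
# flag that selects the rounding mode. Not the actual PR API.
import random

def round_stochastically(x, step=2**-10):
    """Round x onto a grid of spacing `step`, choosing the lower or upper
    neighbor with probability proportional to the distance to each."""
    lo = (x // step) * step
    frac = (x - lo) / step
    return lo + step if random.random() < frac else lo

class SGD:
    def __init__(self, params, lr=0.01, stochastic_rounding=False):
        self.params = list(params)
        self.lr = lr
        self.stochastic_rounding = stochastic_rounding

    def step(self, grads):
        for i, (p, g) in enumerate(zip(self.params, grads)):
            new_p = p - self.lr * g
            if self.stochastic_rounding:
                new_p = round_stochastically(new_p)
            self.params[i] = new_p
```

With this shape, `SRSGD(...)` would instead be spelled `SGD(..., stochastic_rounding=True)`, and the choice of rounding mode stays out of the class hierarchy.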
float weight = static_cast<float>(weights[i]);
float gradient = static_cast<float>(gradients[i]) * (*inv_scale);
float velocity = static_cast<float>(momentum_buffer[i]);
float4 random_values = curand_uniform4(&state);
You generate four random numbers and only use two. I don't think that's a big problem, though.
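For illustration, one way to consume all four values is to process four elements per generator call. This is a conceptual Python sketch, not the kernel: `random.random()` stands in for the Philox generator, and `draw4` stands in for `curand_uniform4`.

```python
# Sketch: curand_uniform4 yields four uniforms per call. If each loop
# iteration handles four elements instead of one, nothing generated is wasted.
import random

def draw4():
    # stand-in for curand_uniform4: one generator call, four uniforms
    return [random.random() for _ in range(4)]

def round_all(values, step=2**-10):
    out = []
    for base in range(0, len(values), 4):
        r4 = draw4()  # one "curand_uniform4" amortized over 4 elements
        for x, r in zip(values[base:base + 4], r4):
            lo = (x // step) * step
            out.append(lo + step if r < (x - lo) / step else lo)
    return out
```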
}

const int block = 256;
const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block;
This is only correct if the kernel uses at most 32 registers per thread; otherwise register pressure limits your occupancy. You can compile with `--ptxas-options=-v` as an nvcc option and nvcc will print how many registers each kernel uses (this is easiest to do with the kernels in an extension; I'm not sure how you would pass that option to nvcc in a PyTorch build).
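The arithmetic behind that caveat can be sketched as follows. The hardware numbers (65536 registers per SM, 2048 threads per SM) are typical recent-GPU values used as illustrative assumptions; real code would query the device properties.

```python
# Back-of-the-envelope occupancy: blocks per SM is capped by whichever runs
# out first, the thread budget or the register file.
def blocks_per_sm(block=256, max_threads_per_sm=2048,
                  regs_per_sm=65536, regs_per_thread=32):
    by_threads = max_threads_per_sm // block
    by_registers = regs_per_sm // (regs_per_thread * block)
    return min(by_threads, by_registers)
```

At 32 registers per thread both limits allow 8 blocks of 256 threads, so the thread-based formula is exact; at 64 registers per thread the register file caps it at 4 blocks and the formula overestimates occupancy by 2x.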
AT_DISPATCH_FLOATING_TYPES(
    input.scalar_type(), "stochastic_rounding_cuda", [&] {
      stochastic_rounding_kernel<scalar_t, at::Half><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
My biggest concern is that upstream will probably ask you to rewrite this with TensorIterator in some form, as @zasdfgbnm hinted.
// Natalia magic
template <typename scalar_t>
__device__ __forceinline__ scalar_t round_stochastically(float x, float random_value) {
Right now, this function only supports rounding to fp16, so its signature is misleading. To make the usage clearer, and to establish an API that supports stochastic rounding to other types in the future, I think you should define it as follows:

template <typename round_to_prec, typename out_type, typename in_type = float>
struct round_stochastically {
  // a dependent expression keeps the assertion from firing until this
  // primary template is actually instantiated
  static_assert(sizeof(round_to_prec) == 0, "round_stochastically only supports round_to_prec=at::Half");
};

template <typename out_type, typename in_type>
struct round_stochastically<at::Half, out_type, in_type> {
  __device__ __forceinline__ out_type operator()(in_type x, float random_value) {
    // what we have now
  }
};

Then the caller should say

weights[i] = round_stochastically<at::Half, scalar_t>()(weight, random_values.x);
L.59's `maybe_upcast` casts from at::Half to float/double if necessary, and the stochastic rounding SGD & Adam kernels use this functionality.
We might want to support stochastic rounding to other formats (like bfloat16) later. The API should allow the caller to set a type that determines the rounding precision, even if the actual rounding code for that precision isn't implemented yet.
Understood. Updated.
curand_init(seed_and_offset.first, tid, seed_and_offset.second, &state);

for (int64_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
  float inp = static_cast<float>(input[i]);
If you make these changes (https://github.com/csarofeen/pytorch/pull/17/files#r420422970), the cast to float won't be necessary.
Force-pushed from 9b7caed to 31bd573.
if (x == 0.0) {
  return out_type(0.0);
}
float delta = get_delta_fp16(static_cast<float>(x));
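For reference, the rounding rule this kernel implements can be sketched in Python: compute the fp16 ulp spacing ("delta") at x, round down in units of delta, then bump up by delta with probability equal to the discarded fraction. The `get_delta_fp16` definition below is an assumed reconstruction (normal fp16 values only), and plain floats stand in for `__half`; the real kernel's handling of negative values via `__float2half_rz` may differ from the `floor`-based form used here.

```python
# Sketch of delta-based stochastic rounding to fp16 (assumed semantics,
# normal range only; subnormals and overflow are ignored).
import math

def get_delta_fp16(x):
    # ulp spacing of fp16 at |x|: 10 mantissa bits -> ulp = 2**(e-1-10)
    m, e = math.frexp(abs(x))       # x = m * 2**e with 0.5 <= m < 1
    return 2.0 ** (e - 11)

def round_stochastically_fp16(x, random_value):
    if x == 0.0:
        return 0.0
    delta = get_delta_fp16(x)
    down = math.floor(x / delta) * delta   # lower fp16 neighbor
    frac = (x - down) / delta              # discarded fraction in [0, 1)
    return down + delta if random_value < frac else down
```

Averaged over many draws, the expected result equals x, which is the property that makes stochastic rounding useful for low-precision weight updates.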
Hardcoding float here is probably fine IMO, but Natalia may ask you to change this to in_type (which might require making get_delta_fp16 a template) and to replace the __float2half_rz call with a wrapper function that has several overloads, where the float overload calls __float2half_rz.