Stochastic Rounding Optimizers #17

Open: wants to merge 18 commits into base: master
Changes from all commits
63 changes: 63 additions & 0 deletions aten/src/ATen/native/cuda/StochasticRounding.cu
@@ -0,0 +1,63 @@
#include <ATen/ATen.h>
#include <ATen/native/cuda/stochastic_rounding.cuh>


namespace at {
namespace native {

template <typename input_t, typename output_t>
__global__ void stochastic_rounding_kernel(
const input_t* input,
output_t* output,
const int64_t numel,
std::pair<uint64_t, uint64_t> seed_and_offset) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed_and_offset.first, tid, seed_and_offset.second, &state);

round_stochastically<output_t, input_t, at::Half> rounder;

for (int64_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
output[i] = rounder(input[i], curand_uniform(&state));
}
}

Tensor stochastic_rounding_cuda(const Tensor& input, c10::optional<Generator> gen_) {

TORCH_CHECK(input.is_contiguous());

if (input.scalar_type() == kHalf) {
return input;
}

Tensor output = at::empty_like(input, input.options().dtype(kHalf), input.suggest_memory_format());
const int64_t numel = input.numel();
if (numel == 0) {
return output;
}

const int block = 256;
const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block;
@mcarilli (May 5, 2020):
This is only correct if the kernel's number of registers per thread is <= 32, otherwise register pressure limits your occupancy. You can recompile kernels with -ptxas-options=-v as an nvcc option and nvcc will print how many registers they use (this is easiest to do with the kernels in an extension, I'm not sure how you would pass that option to nvcc in a pytorch build).
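
One way to account for this (an illustrative sketch, not part of this patch) is to query the occupancy of the concrete kernel instantiation instead of assuming at most 32 registers per thread:

// Sketch: let the CUDA runtime account for register/shared-memory pressure.
static unsigned int occupancy_limited_grid(int64_t numel, int block) {
  int blocks_per_sm = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &blocks_per_sm, stochastic_rounding_kernel<float, at::Half>, block,
      /*dynamicSMemSize=*/0);
  const unsigned int max_grid =
      at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm;
  const unsigned int grid = (numel + block - 1) / block;
  return std::min(max_grid, grid);
}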

unsigned int grid = (numel + block - 1) / block;
grid = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid);

auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_engine_inputs((numel + block * grid - 1) / (block * grid));
}

AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "stochastic_rounding_cuda", [&] {
stochastic_rounding_kernel<scalar_t, at::Half><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
Review comment:
My biggest concern is, upstream will probably ask you to rewrite this with TensorIterator in some form, as @zasdfgbnm hinted.

input.data_ptr<scalar_t>(),
output.data_ptr<at::Half>(),
numel, rng_engine_inputs);
});

return output;
}

} // namespace native
} // namespace at
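
Following up on the TensorIterator suggestion in the review comment above, a rough sketch (not this PR's code; names and dtype handling vary across PyTorch versions, and the per-thread Philox plumbing is deliberately omitted) of the overall shape such a rewrite could take:

#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

// Hypothetical structure only: gpu_kernel takes care of indexing and vectorization;
// a real version would thread curand state through ATen's distribution templates
// and apply round_stochastically inside the lambda.
Tensor stochastic_rounding_cuda_iter(const Tensor& input) {
  Tensor output = at::empty_like(input, input.options().dtype(kHalf));
  auto iter = TensorIterator::unary_op(output, input);
  gpu_kernel(iter, [] GPU_LAMBDA (float x) -> at::Half {
    return static_cast<at::Half>(x);  // placeholder for the stochastic rounding rule
  });
  return output;
}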
125 changes: 125 additions & 0 deletions aten/src/ATen/native/cuda/StochasticRoundingAdam.cu
@@ -0,0 +1,125 @@
#include <ATen/ATen.h>
#include <ATen/native/cuda/stochastic_rounding.cuh>


namespace at {
namespace native {

template <typename scalar_t>
__global__ void stochastic_rounding_adam_step_kernel(
scalar_t *weights, scalar_t *gradients,
scalar_t *exp_avg, scalar_t *exp_avg_sq, scalar_t *max_exp_avg_sq,
float *inv_scale, float *found_inf,
float lr, float beta1, float beta2,
float weight_decay, float eps, int step,
bool is_decoupled, bool is_amsgrad,
int numel, std::pair<uint64_t, uint64_t> seeds) {

if (*found_inf) return;

int tid = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seeds.first, tid, seeds.second, &state);

round_stochastically<scalar_t, float, at::Half> rounder;

float m_correction = 1.0 - powf(beta1, step);
float v_correction = 1.0 - powf(beta2, step);

for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
float weight = static_cast<float>(weights[i]);
float gradient = static_cast<float>(gradients[i]) * (*inv_scale);
float m = static_cast<float>(exp_avg[i]);
// exp_avg_sq stores the square root of the exponential average of squared gradients
// (so the value stays within fp16 range); square it on load.
float v = static_cast<float>(exp_avg_sq[i]);
v = v * v;
float4 random_values = curand_uniform4(&state);

if (weight_decay != 0.0f) {
if (is_decoupled)
weight *= (1 - lr * weight_decay);
else
gradient += weight_decay * weight;
}

// Update m and v.
m = beta1 * m + (1.0 - beta1) * gradient;
v = beta2 * v + (1.0 - beta2) * (gradient * gradient);

// For AMSGrad, track the running maximum of v; bias correction is applied in the update below.
float max_v = v;
if (is_amsgrad) {
float prev_max_v = static_cast<float>(max_exp_avg_sq[i]);
prev_max_v = prev_max_v * prev_max_v;
max_v = fmaxf(prev_max_v, v);
}

weight -= (lr / m_correction) * m / (sqrtf(max_v / v_correction) + eps);

weights[i] = rounder(weight, random_values.x);
exp_avg[i] = rounder(m, random_values.y);
exp_avg_sq[i] = rounder(sqrtf(v), random_values.z);
if (is_amsgrad) {
max_exp_avg_sq[i] = rounder(sqrtf(max_v), random_values.w);
}
}
}


Tensor stochastic_rounding_adam_step_cuda(
Tensor& param,
const Tensor& grad,
Tensor& exp_avg,
Tensor& exp_avg_sq,
Tensor& max_exp_avg_sq,
const Tensor& inv_scale,
const Tensor& found_inf,
double lr, double beta1, double beta2,
double weight_decay, double eps, int64_t step,
bool is_decoupled, bool is_amsgrad, c10::optional<Generator> gen_) {

if (param.numel() == 0) return param;

TORCH_CHECK(param.is_contiguous());
TORCH_CHECK(grad.is_contiguous());
TORCH_CHECK(exp_avg.is_contiguous());
TORCH_CHECK(exp_avg_sq.is_contiguous());
TORCH_CHECK(max_exp_avg_sq.is_contiguous());

const int64_t numel = param.numel();
const int block_size = 256;
const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
dim3 dim_block(block_size);
dim3 grid((numel + block_size - 1) / block_size);
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);

auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());

uint64_t counter_offset = ((numel + dim_block.x * grid.x - 1) / (dim_block.x * grid.x)) * 4;
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
}

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
param.scalar_type(), "stochastic_rounding_adam_step_cuda", [&] {
stochastic_rounding_adam_step_kernel<scalar_t><<<grid, dim_block, 0, c10::cuda::getCurrentCUDAStream()>>>(
param.data_ptr<scalar_t>(),
grad.data_ptr<scalar_t>(),
exp_avg.data_ptr<scalar_t>(),
exp_avg_sq.data_ptr<scalar_t>(),
max_exp_avg_sq.data_ptr<scalar_t>(),
inv_scale.data_ptr<float>(),
found_inf.data_ptr<float>(),
lr, beta1, beta2, weight_decay, eps, step,
is_decoupled, is_amsgrad,
numel, rng_engine_inputs);
}
);
AT_CUDA_CHECK(cudaGetLastError());
return param;
}

} // namespace native
} // namespace at
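
A note on the sqrt(v) trick in the kernel above: exp_avg_sq is stored as the square root of the second-moment estimate because squared gradients quickly fall outside fp16's usable range. A small host-side illustration (not part of the patch):

#include <cmath>
#include <cstdio>

int main() {
  // fp16's smallest normal value is ~6.1e-5, so squaring small gradients
  // pushes the statistic into the subnormal range and loses precision.
  float g = 1e-3f;                // a typical small gradient
  float v = g * g;                // 1e-6: below fp16's normal range
  float stored = std::sqrt(v);    // 1e-3: comfortably representable in fp16
  std::printf("v=%g  stored=%g  restored=%g\n", v, stored, stored * stored);
  return 0;
}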
95 changes: 95 additions & 0 deletions aten/src/ATen/native/cuda/StochasticRoundingSGD.cu
@@ -0,0 +1,95 @@
#include <ATen/ATen.h>
#include <ATen/native/cuda/stochastic_rounding.cuh>


namespace at {
namespace native {

// SGD update math with Stochastic Rounding
template <typename scalar_t>
__global__ void stochastic_rounding_sgd_step_kernel(
scalar_t *weights, scalar_t *gradients, scalar_t *momentum_buffer,
float* inv_scale, float* found_inf,
float weight_decay, float momentum, float dampening, float lr,
bool nesterov, bool first_run, int numel, std::pair<uint64_t, uint64_t> seeds) {

if (*found_inf) return;

int tid = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seeds.first, tid, seeds.second, &state);

round_stochastically<scalar_t, float, at::Half> rounder;

for (int i = tid; i < numel; i += blockDim.x * gridDim.x) {
float weight = static_cast<float>(weights[i]);
float gradient = static_cast<float>(gradients[i]) * (*inv_scale);
float velocity = static_cast<float>(momentum_buffer[i]);
float4 random_values = curand_uniform4(&state);
Review comment:
you generate 4 rng and only use 2. I don't think that's a big problem though.


if (weight_decay != 0.0f)
gradient += weight_decay * weight;

if (momentum != 0.0f) {
if (!first_run)
velocity = velocity * momentum + (1.0f - dampening) * gradient;
else
velocity = gradient;

if (nesterov)
gradient += momentum * velocity;
else
gradient = velocity;
}

weight -= lr * gradient;

weights[i] = rounder(weight, random_values.x);
if (momentum != 0.0f)
momentum_buffer[i] = rounder(velocity, random_values.y);
}
}

Tensor stochastic_rounding_sgd_step_cuda(
Tensor& param, const Tensor& grad, Tensor& momentum_buffer,
const Tensor& inv_scale, const Tensor& found_inf,
double lr, double momentum, double weight_decay, double dampening,
bool nesterov, bool first_run, c10::optional<Generator> gen_) {

if (param.numel() == 0) return param;

TORCH_CHECK(param.is_contiguous());
TORCH_CHECK(grad.is_contiguous());
TORCH_CHECK(momentum_buffer.is_contiguous());

const int64_t numel = param.numel();
const int block_size = 256;
const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
dim3 dim_block(block_size);
dim3 grid((numel + block_size - 1) / block_size);
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);

auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
uint64_t counter_offset = ((numel + dim_block.x * grid.x - 1) / (dim_block.x * grid.x)) * 4;
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
}

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
param.scalar_type(), "stochastic_rounding_sgd_step_cuda", [&] {
stochastic_rounding_sgd_step_kernel<scalar_t><<<grid, dim_block, 0, c10::cuda::getCurrentCUDAStream()>>>(
param.data_ptr<scalar_t>(),
grad.data_ptr<scalar_t>(),
momentum_buffer.data_ptr<scalar_t>(),
inv_scale.data_ptr<float>(), found_inf.data_ptr<float>(),
static_cast<float>(weight_decay), static_cast<float>(momentum), static_cast<float>(dampening), static_cast<float>(lr),
nesterov, first_run, numel, rng_engine_inputs);
});
AT_CUDA_CHECK(cudaGetLastError());
return param;
}

} // namespace native
} // namespace at
67 changes: 67 additions & 0 deletions aten/src/ATen/native/cuda/stochastic_rounding.cuh
@@ -0,0 +1,67 @@
#pragma once

#include <math.h>
#include <utility>

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <curand.h>
#include <curand_kernel.h>

#include <ATen/Utils.h>
#include <ATen/Generator.h>
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAFunctions.h>

// 2^-10 is the spacing between consecutive normal fp16 values with exponent 0 (i.e. in [1, 2)).
// 2^-24 is the spacing of fp16 subnormals, the unit in the last place (ULP) at the precision limit.
// The 24 here is **NOT** the 24-bit significand of single precision; it comes from fp16's
// 10 mantissa bits plus its minimum exponent of -14.
__device__ const float TWO_10 = 0.0009765625;
__device__ const float TWO_24 = 0.000000059604644775390625;


template<typename T>
__device__ __forceinline__ T maybe_upcast(__half x){
return T(__half2float(x));
}

template<>
__device__ __forceinline__ __half maybe_upcast<__half>(__half x){
return x;
}

__device__ __forceinline__ float get_delta_fp16(float x) {
int exponent;
frexpf(x, &exponent);
exponent -= 1;
if (exponent >= -14)
return TWO_10 * std::pow(2, exponent);
else
return TWO_24;
}

// Natalia magic

Review comment: Keep this comment.

template <typename out_type, typename in_type, typename round_to_prec=at::Half>
struct round_stochastically {
static_assert(std::is_same<round_to_prec, at::Half>::value, "round_stochastically only supports round_to_prec=at::Half");
};

template <typename out_type, typename in_type>
struct round_stochastically<out_type, in_type, at::Half> {
__device__ __forceinline__ out_type operator()(in_type x, float random_value) {
if (x == 0.0) {
return out_type(0.0);
}
float delta = get_delta_fp16(static_cast<float>(x));
@mcarilli (May 6, 2020):
Hardcoding float here is probably fine IMO, but natalia may ask you to change this to in_type (which might require making get_delta_fp16 a template), and replace the __float2half_rz call with a wrapper function that has several overloads and the float overload calls __float2half_rz.

float val;
if (x < 0.0) {
val = x - random_value * delta;
} else {
val = x + random_value * delta;
}
return maybe_upcast<out_type>(__float2half_rz(val));
}
};
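
For intuition, a self-contained host-side sketch (not part of the patch) that mirrors get_delta_fp16 and round_stochastically above and checks that the rounding is unbiased in expectation; std::trunc stands in for __float2half_rz:

#include <cmath>
#include <cstdio>
#include <random>

// fp16 spacing at x, mirroring get_delta_fp16 above.
float fp16_spacing(float x) {
  int exponent;
  std::frexp(x, &exponent);
  exponent -= 1;
  return (exponent >= -14) ? std::ldexp(0.0009765625f, exponent)
                           : 5.9604644775390625e-8f;
}

// Nudge x by u * delta away from zero, then truncate toward zero onto the fp16 grid.
// P(round away from zero) equals the fractional position of x between grid points,
// so the expected value of the result is x.
float round_stochastically_host(float x, float u) {
  if (x == 0.0f) return 0.0f;
  float delta = fp16_spacing(x);
  float val = (x < 0.0f) ? x - u * delta : x + u * delta;
  return std::trunc(val / delta) * delta;  // stand-in for __float2half_rz
}

int main() {
  std::mt19937 rng(0);
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  const float x = 1.00048828125f;  // halfway between the fp16 values 1.0 and 1.0009765625
  double sum = 0.0;
  const int trials = 100000;
  for (int i = 0; i < trials; ++i) sum += round_stochastically_host(x, uniform(rng));
  std::printf("input %.10f, mean rounded value %.10f\n", x, sum / trials);
  return 0;
}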
12 changes: 12 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -6719,3 +6719,15 @@
# It is undocumented and should not be used outside of tests.
- func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor
use_c10_dispatcher: full

- func: stochastic_rounding(Tensor input, Generator? gen_=None) -> Tensor
dispatch:
CUDA: stochastic_rounding_cuda

- func: stochastic_rounding_adam_step(Tensor(a!) param, Tensor grad, Tensor(b!) exp_avg, Tensor(c!) exp_avg_sq, Tensor(d!) max_exp_avg_sq, Tensor inv_scale, Tensor found_inf, float lr, float beta1, float beta2, float weight_decay, float eps, int step, bool is_decoupled, bool is_amsgrad, Generator? gen_=None) -> Tensor(a!)
dispatch:
CUDA: stochastic_rounding_adam_step_cuda

- func: stochastic_rounding_sgd_step(Tensor(a!) param, Tensor grad, Tensor(b!) momentum_buffer, Tensor inv_scale, Tensor found_inf, float lr, float momentum, float weight_decay, float dampening, bool nesterov, bool first_run, Generator? gen_=None) -> Tensor(a!)
dispatch:
CUDA: stochastic_rounding_sgd_step_cuda
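
Once built, the schemas registered above become callable from C++ through the at:: namespace. A usage sketch, assuming a CUDA build of this branch:

#include <ATen/ATen.h>

void example() {
  at::Tensor x = at::randn({1024}, at::device(at::kCUDA).dtype(at::kFloat));
  // fp32 -> fp16 with stochastic rounding; the default CUDA generator supplies the randomness.
  at::Tensor y = at::stochastic_rounding(x);
}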
6 changes: 6 additions & 0 deletions docs/source/optim.rst
@@ -129,6 +129,12 @@ Algorithms
:members:
.. autoclass:: SGD
:members:
.. autoclass:: SRAdam
:members:
.. autoclass:: SRAdamW
:members:
.. autoclass:: SRSGD
:members:

How to adjust learning rate
---------------------------
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -249,6 +249,7 @@ Pointwise Ops
.. autofunction:: sinh
.. autofunction:: sqrt
.. autofunction:: square
.. autofunction:: stochastic_rounding
.. autofunction:: tan
.. autofunction:: tanh
.. autofunction:: true_divide
1 change: 1 addition & 0 deletions test/run_test.py
@@ -79,6 +79,7 @@
'test_overrides',
'test_jit_fuser_te',
'test_tensorexpr',
'test_stochastic_rounding',
]

# skip < 3.3 because mock is added in 3.3 and is used in rpc_spawn