From df605156811945e5d5b586ee86596dfa64955cde Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 13 Apr 2022 02:27:46 -0700 Subject: [PATCH 01/45] Add naive CPU implementation for memory-efficient attention backward Needs cleanup --- .../components/attention/csrc/attention.cpp | 2 + .../attention/csrc/cpu/attention.cpp | 179 ++++++++++++++++++ 2 files changed, 181 insertions(+) diff --git a/xformers/components/attention/csrc/attention.cpp b/xformers/components/attention/csrc/attention.cpp index 931b8f9689..90c905f8b8 100644 --- a/xformers/components/attention/csrc/attention.cpp +++ b/xformers/components/attention/csrc/attention.cpp @@ -3,4 +3,6 @@ TORCH_LIBRARY_FRAGMENT(xformers, m) { m.def(TORCH_SELECTIVE_SCHEMA( "xformers::efficient_attention(Tensor query, Tensor key, Tensor value) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "xformers::efficient_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value) -> (Tensor, Tensor, Tensor)")); } diff --git a/xformers/components/attention/csrc/cpu/attention.cpp b/xformers/components/attention/csrc/cpu/attention.cpp index d79eb36c4c..214097fa55 100644 --- a/xformers/components/attention/csrc/cpu/attention.cpp +++ b/xformers/components/attention/csrc/cpu/attention.cpp @@ -147,10 +147,189 @@ at::Tensor attention( return res; } +template +void attention_backward_kernel( + at::TensorAccessor grad_q, + at::TensorAccessor grad_k, + at::TensorAccessor grad_v, + at::TensorAccessor grad_out, + at::TensorAccessor q, + at::TensorAccessor k, + at::TensorAccessor v, + at::TensorAccessor logsumexp_normalizer, + at::TensorAccessor buffer, + at::TensorAccessor buffer2, + at::TensorAccessor buffer3 //, + // at::TensorAccessor mask +) { + int64_t K = q.size(2); + int64_t B = q.size(0); + int64_t M = q.size(1); + int64_t N = k.size(1); + int64_t grain_size = 1; // buffer.size(1); + at::parallel_for(0, B, grain_size, [&](int64_t start, int64_t end) { + auto buf = buffer[at::get_thread_num()][0]; + auto buf2 = buffer2[at::get_thread_num()]; + auto buf3 = buffer3[at::get_thread_num()][0]; + for (int64_t i = start; i < end; i++) { + for (int64_t l = 0; l < N; l++) { + for (int64_t k = 0; k < K; k++) { + buf2[l][k] = 0; + } + } + + for (int64_t j = 0; j < M; j++) { + for (int64_t k = 0; k < K; k++) { + buf[k] = 0; + } + auto aar = q[i][j]; + auto normalizer = logsumexp_normalizer[i][j]; + scalar_t tmp_sum = 0; + for (int64_t l = 0; l < N; l++) { + auto bar = k[i][l]; + scalar_t si = 0; + for (int64_t k = 0; k < K; k++) { + si += aar[k] * bar[k]; + } + scalar_t attn_v = std::exp(si - normalizer); + + for (int64_t k = 0; k < K; k++) { + grad_v[i][l][k] += attn_v * grad_out[i][j][k]; + } + + // now compute grad_q and grad_k + scalar_t g_attn = 0; + for (int64_t k = 0; k < K; k++) { + g_attn += grad_out[i][j][k] * v[i][l][k]; + // g_attn[i][j][l] += grad_out[i][j][k] * v[i][l][k]; + } + scalar_t tmp = attn_v * g_attn; + tmp_sum += tmp; + + // grad_q is easy + for (int64_t k = 0; k < K; k++) { + grad_q[i][j][k] += tmp * bar[k]; // bar == key + buf[k] += attn_v * bar[k]; + } + + // scalar_t factor = tmp_sum / (tmp_sum - tmp); + // if (tmp_sum == tmp) + // factor = 1; + + // but grad_k is trickier + for (int64_t k = 0; k < K; k++) { + grad_k[i][l][k] += tmp * aar[k]; + // buf2[l][k] = buf2[l][k] * factor + attn_v * aar[k] * tmp_sum; + } + } + buf3[j] = tmp_sum; + for (int64_t k = 0; k < K; k++) { + grad_q[i][j][k] -= buf[k] * tmp_sum; + } + } + + // TODO: try to make this folded in the previous loop + for (int64_t j = 0; j < M; j++) { + auto aar = 
q[i][j]; + auto normalizer = logsumexp_normalizer[i][j]; + scalar_t tmp = buf3[j]; + for (int64_t l = 0; l < N; l++) { + auto bar = k[i][l]; + scalar_t si = 0; + for (int64_t k = 0; k < K; k++) { + si += aar[k] * bar[k]; + } + scalar_t attn_v = std::exp(si - normalizer); + + for (int64_t k = 0; k < K; k++) { + buf2[l][k] += attn_v * aar[k] * tmp; + } + } + } + + for (int64_t l = 0; l < N; l++) { + for (int64_t k = 0; k < K; k++) { + grad_k[i][l][k] -= buf2[l][k]; + } + } + } + }); +} + +std::tuple attention_backward( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value + // const at::Tensor& mask +) { + TORCH_CHECK(query.dim() == key.dim()); + // TORCH_CHECK(query.dim() == mask.dim()); + TORCH_CHECK(query.dim() == 3); + TORCH_CHECK(query.size(2) == key.size(2)); + TORCH_CHECK(query.size(0) == key.size(0)); + // TORCH_CHECK(query.size(1) == mask.size(1)); + // TORCH_CHECK(query.size(2) == mask.size(2)); + // TORCH_CHECK(query.size(0) == mask.size(0)); + + /* + TORCH_CHECK(!a.is_cuda(), "a must be a CPU tensor"); + TORCH_CHECK(!b.is_cuda(), "b must be a CPU tensor"); + TORCH_CHECK(!mask.is_cuda(), "mask must be a CPU tensor"); + TORCH_CHECK(!a.is_sparse(), "a must be a dense tensor"); + TORCH_CHECK(!b.is_sparse(), "b must be a dense tensor"); + //TORCH_CHECK(mask.is_sparse(), "mask must be a sparse tensor"); + */ + + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + int64_t K = query.size(2); + + at::Tensor res = at::empty({B, M, K}, query.options()); + at::Tensor grad_q = at::zeros_like(query); + at::Tensor grad_k = at::zeros_like(key); + at::Tensor grad_v = at::zeros_like(value); + + int64_t grain_size = 32; // TODO: tune this + // at::Tensor buffer = at::empty({B, grain_size, K}, query.options()); + at::Tensor buffer = at::empty({at::get_num_threads(), 1, K}, query.options()); + at::Tensor buffer2 = + at::zeros({at::get_num_threads(), N, K + 1}, query.options()); + at::Tensor buffer3 = + at::zeros({at::get_num_threads(), 1, M}, query.options()); + + // TODO this should be an argument from the function + at::Tensor logsumexp = query.bmm(key.transpose(-2, -1)).logsumexp(-1); + + AT_DISPATCH_FLOATING_TYPES( + query.scalar_type(), "attention_backward_kernel", [&] { + attention_backward_kernel( + grad_q.accessor(), + grad_k.accessor(), + grad_v.accessor(), + grad_out.accessor(), + query.accessor(), + key.accessor(), + value.accessor(), + logsumexp.accessor(), + buffer.accessor(), + buffer2.accessor(), + buffer3.accessor() + // idxs.accessor() + ); + }); + + return std::make_tuple(grad_q, grad_k, grad_v); +} + } // namespace TORCH_LIBRARY_IMPL(xformers, CPU, m) { m.impl( TORCH_SELECTIVE_NAME("xformers::efficient_attention"), TORCH_FN(attention)); + m.impl( + TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward"), + TORCH_FN(attention_backward)); } From 77535fe04b1e04112b10ee257a033489756a85a3 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 13 Apr 2022 04:22:37 -0700 Subject: [PATCH 02/45] Optimize (at least) by a factor 2 --- .../attention/csrc/cpu/attention.cpp | 55 +++++-------------- 1 file changed, 14 insertions(+), 41 deletions(-) diff --git a/xformers/components/attention/csrc/cpu/attention.cpp b/xformers/components/attention/csrc/cpu/attention.cpp index 214097fa55..1f18856a15 100644 --- a/xformers/components/attention/csrc/cpu/attention.cpp +++ b/xformers/components/attention/csrc/cpu/attention.cpp @@ -158,8 +158,7 @@ void attention_backward_kernel( at::TensorAccessor v, 
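// Backward of out = softmax(q @ k^T) @ v, recomputed row by row so the full
// attention matrix is never materialized. With
//   p_jl  = exp(<q_j, k_l> - logsumexp_j)        (softmax probabilities)
//   g_jl  = <grad_out_j, v_l>                    (gradient w.r.t. p_jl)
//   gs_jl = p_jl * (g_jl - sum_l p_jl * g_jl)    (softmax gradient)
// the parameter gradients computed below are
//   grad_v_l += p_jl * grad_out_j
//   grad_q_j  = sum_l gs_jl * k_l
//   grad_k_l += gs_jl * q_j
// logsumexp_normalizer carries log(sum_l exp(<q_j, k_l>)) per query row,
// which is what makes the cheap recomputation of p_jl possible.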
at::TensorAccessor logsumexp_normalizer, at::TensorAccessor buffer, - at::TensorAccessor buffer2, - at::TensorAccessor buffer3 //, + at::TensorAccessor buffer2 //, // at::TensorAccessor mask ) { int64_t K = q.size(2); @@ -169,14 +168,8 @@ void attention_backward_kernel( int64_t grain_size = 1; // buffer.size(1); at::parallel_for(0, B, grain_size, [&](int64_t start, int64_t end) { auto buf = buffer[at::get_thread_num()][0]; - auto buf2 = buffer2[at::get_thread_num()]; - auto buf3 = buffer3[at::get_thread_num()][0]; + auto buf2 = buffer2[at::get_thread_num()][0]; for (int64_t i = start; i < end; i++) { - for (int64_t l = 0; l < N; l++) { - for (int64_t k = 0; k < K; k++) { - buf2[l][k] = 0; - } - } for (int64_t j = 0; j < M; j++) { for (int64_t k = 0; k < K; k++) { @@ -212,44 +205,27 @@ void attention_backward_kernel( buf[k] += attn_v * bar[k]; } - // scalar_t factor = tmp_sum / (tmp_sum - tmp); - // if (tmp_sum == tmp) - // factor = 1; + scalar_t factor = tmp_sum / (tmp_sum - tmp); + if (tmp_sum == tmp) + factor = 1; // but grad_k is trickier + buf2[l] = attn_v; for (int64_t k = 0; k < K; k++) { grad_k[i][l][k] += tmp * aar[k]; - // buf2[l][k] = buf2[l][k] * factor + attn_v * aar[k] * tmp_sum; + //buf3[l][k] = attn_v * aar[k]; + //for (int64_t ll = 0; ll < N; ll++) { + // buf2[l][k] = buf2[l][k] * factor + attn_v * aar[k] * tmp_sum; + //} } } - buf3[j] = tmp_sum; - for (int64_t k = 0; k < K; k++) { - grad_q[i][j][k] -= buf[k] * tmp_sum; - } - } - - // TODO: try to make this folded in the previous loop - for (int64_t j = 0; j < M; j++) { - auto aar = q[i][j]; - auto normalizer = logsumexp_normalizer[i][j]; - scalar_t tmp = buf3[j]; for (int64_t l = 0; l < N; l++) { - auto bar = k[i][l]; - scalar_t si = 0; - for (int64_t k = 0; k < K; k++) { - si += aar[k] * bar[k]; - } - scalar_t attn_v = std::exp(si - normalizer); - for (int64_t k = 0; k < K; k++) { - buf2[l][k] += attn_v * aar[k] * tmp; + grad_k[i][l][k] -= buf2[l] * aar[k] * tmp_sum; } } - } - - for (int64_t l = 0; l < N; l++) { for (int64_t k = 0; k < K; k++) { - grad_k[i][l][k] -= buf2[l][k]; + grad_q[i][j][k] -= buf[k] * tmp_sum; } } } @@ -295,9 +271,7 @@ std::tuple attention_backward( // at::Tensor buffer = at::empty({B, grain_size, K}, query.options()); at::Tensor buffer = at::empty({at::get_num_threads(), 1, K}, query.options()); at::Tensor buffer2 = - at::zeros({at::get_num_threads(), N, K + 1}, query.options()); - at::Tensor buffer3 = - at::zeros({at::get_num_threads(), 1, M}, query.options()); + at::zeros({at::get_num_threads(), 1, N}, query.options()); // TODO this should be an argument from the function at::Tensor logsumexp = query.bmm(key.transpose(-2, -1)).logsumexp(-1); @@ -314,8 +288,7 @@ std::tuple attention_backward( value.accessor(), logsumexp.accessor(), buffer.accessor(), - buffer2.accessor(), - buffer3.accessor() + buffer2.accessor() // idxs.accessor() ); }); From bd6d1f2ae371ef6b68c1bfa4d21c931203fd279b Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 13 Apr 2022 04:35:17 -0700 Subject: [PATCH 03/45] More cleanups --- .../attention/csrc/cpu/attention.cpp | 61 ++++++++++--------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/xformers/components/attention/csrc/cpu/attention.cpp b/xformers/components/attention/csrc/cpu/attention.cpp index 1f18856a15..f4e7c02c55 100644 --- a/xformers/components/attention/csrc/cpu/attention.cpp +++ b/xformers/components/attention/csrc/cpu/attention.cpp @@ -175,14 +175,14 @@ void attention_backward_kernel( for (int64_t k = 0; k < K; k++) { buf[k] = 0; } - auto aar 
= q[i][j]; + auto query_i = q[i][j]; auto normalizer = logsumexp_normalizer[i][j]; scalar_t tmp_sum = 0; for (int64_t l = 0; l < N; l++) { - auto bar = k[i][l]; + auto key_j = k[i][l]; scalar_t si = 0; for (int64_t k = 0; k < K; k++) { - si += aar[k] * bar[k]; + si += query_i[k] * key_j[k]; } scalar_t attn_v = std::exp(si - normalizer); @@ -201,27 +201,19 @@ void attention_backward_kernel( // grad_q is easy for (int64_t k = 0; k < K; k++) { - grad_q[i][j][k] += tmp * bar[k]; // bar == key - buf[k] += attn_v * bar[k]; + grad_q[i][j][k] += tmp * key_j[k]; // key_j == key + buf[k] += attn_v * key_j[k]; } - scalar_t factor = tmp_sum / (tmp_sum - tmp); - if (tmp_sum == tmp) - factor = 1; - // but grad_k is trickier buf2[l] = attn_v; for (int64_t k = 0; k < K; k++) { - grad_k[i][l][k] += tmp * aar[k]; - //buf3[l][k] = attn_v * aar[k]; - //for (int64_t ll = 0; ll < N; ll++) { - // buf2[l][k] = buf2[l][k] * factor + attn_v * aar[k] * tmp_sum; - //} + grad_k[i][l][k] += tmp * query_i[k]; } } for (int64_t l = 0; l < N; l++) { for (int64_t k = 0; k < K; k++) { - grad_k[i][l][k] -= buf2[l] * aar[k] * tmp_sum; + grad_k[i][l][k] -= buf2[l] * query_i[k] * tmp_sum; } } for (int64_t k = 0; k < K; k++) { @@ -239,23 +231,36 @@ std::tuple attention_backward( const at::Tensor& value // const at::Tensor& mask ) { + + + TORCH_CHECK(query.dim() == grad_out.dim()); TORCH_CHECK(query.dim() == key.dim()); + TORCH_CHECK(query.dim() == value.dim()); // TORCH_CHECK(query.dim() == mask.dim()); TORCH_CHECK(query.dim() == 3); + + TORCH_CHECK(query.size(0) == grad_out.size(0)); + TORCH_CHECK(query.size(1) == grad_out.size(1)); + TORCH_CHECK(query.size(2) == grad_out.size(2)); + TORCH_CHECK(query.size(2) == key.size(2)); TORCH_CHECK(query.size(0) == key.size(0)); - // TORCH_CHECK(query.size(1) == mask.size(1)); - // TORCH_CHECK(query.size(2) == mask.size(2)); - // TORCH_CHECK(query.size(0) == mask.size(0)); - - /* - TORCH_CHECK(!a.is_cuda(), "a must be a CPU tensor"); - TORCH_CHECK(!b.is_cuda(), "b must be a CPU tensor"); - TORCH_CHECK(!mask.is_cuda(), "mask must be a CPU tensor"); - TORCH_CHECK(!a.is_sparse(), "a must be a dense tensor"); - TORCH_CHECK(!b.is_sparse(), "b must be a dense tensor"); - //TORCH_CHECK(mask.is_sparse(), "mask must be a sparse tensor"); - */ + + TORCH_CHECK(query.size(0) == value.size(0)); + TORCH_CHECK(key.size(1) == value.size(1)); + TORCH_CHECK( + query.size(2) == + value.size(2)); // TODO: drop this limitation in the future + + TORCH_CHECK(!query.is_cuda(), "query must be a CPU tensor"); + TORCH_CHECK(!key.is_cuda(), "key must be a CPU tensor"); + TORCH_CHECK(!value.is_cuda(), "value must be a CPU tensor"); + TORCH_CHECK(!grad_out.is_cuda(), "grad_out must be a CPU tensor"); + + TORCH_CHECK(!query.is_sparse(), "query must be a dense tensor"); + TORCH_CHECK(!key.is_sparse(), "key must be a dense tensor"); + TORCH_CHECK(!value.is_sparse(), "value must be a dense tensor"); + TORCH_CHECK(!grad_out.is_sparse(), "grad_out must be a dense tensor"); int64_t B = query.size(0); int64_t M = query.size(1); @@ -267,8 +272,6 @@ std::tuple attention_backward( at::Tensor grad_k = at::zeros_like(key); at::Tensor grad_v = at::zeros_like(value); - int64_t grain_size = 32; // TODO: tune this - // at::Tensor buffer = at::empty({B, grain_size, K}, query.options()); at::Tensor buffer = at::empty({at::get_num_threads(), 1, K}, query.options()); at::Tensor buffer2 = at::zeros({at::get_num_threads(), 1, N}, query.options()); From 54988fe1160edc70aa3f415b083ed4a4f086c059 Mon Sep 17 00:00:00 2001 From: Francisco Massa 
Date: Wed, 13 Apr 2022 04:43:01 -0700 Subject: [PATCH 04/45] A few more comments --- .../components/attention/csrc/cpu/attention.cpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/xformers/components/attention/csrc/cpu/attention.cpp b/xformers/components/attention/csrc/cpu/attention.cpp index f4e7c02c55..962ee72a38 100644 --- a/xformers/components/attention/csrc/cpu/attention.cpp +++ b/xformers/components/attention/csrc/cpu/attention.cpp @@ -191,21 +191,25 @@ void attention_backward_kernel( } // now compute grad_q and grad_k - scalar_t g_attn = 0; + // first compute the gradient for the self-attention + // after softmax + scalar_t grad_attn_v = 0; for (int64_t k = 0; k < K; k++) { - g_attn += grad_out[i][j][k] * v[i][l][k]; - // g_attn[i][j][l] += grad_out[i][j][k] * v[i][l][k]; + grad_attn_v += grad_out[i][j][k] * v[i][l][k]; + // grad_attn_v[i][j][l] += grad_out[i][j][k] * v[i][l][k]; } - scalar_t tmp = attn_v * g_attn; + + // those are temporaries for the gradient of the softmax + scalar_t tmp = attn_v * grad_attn_v; tmp_sum += tmp; // grad_q is easy for (int64_t k = 0; k < K; k++) { - grad_q[i][j][k] += tmp * key_j[k]; // key_j == key + grad_q[i][j][k] += tmp * key_j[k]; buf[k] += attn_v * key_j[k]; } - // but grad_k is trickier + // but grad_k is a bit trickier buf2[l] = attn_v; for (int64_t k = 0; k < K; k++) { grad_k[i][l][k] += tmp * query_i[k]; From 9ab94a2ca58adc531c58c9df4175319ec59d0805 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 13 Apr 2022 05:35:41 -0700 Subject: [PATCH 05/45] Add very naive CUDA implementation It's super slow! --- .../attention/csrc/cuda/attention.cu | 164 ++++++++++++++++++ 1 file changed, 164 insertions(+) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index d9880548ca..40d26396e0 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -5,6 +5,7 @@ #include #include +#include #include "sputnik/vector_utils.h" @@ -587,10 +588,173 @@ at::Tensor attention( return res; } +template +__global__ void attention_backward_kernel( + at::PackedTensorAccessor grad_q, + at::PackedTensorAccessor grad_k, + at::PackedTensorAccessor grad_v, + at::PackedTensorAccessor grad_out, + at::PackedTensorAccessor query, + at::PackedTensorAccessor key, + at::PackedTensorAccessor value, + at::PackedTensorAccessor logsumexp_normalizer, + at::PackedTensorAccessor buffer, + at::PackedTensorAccessor buffer2 +) { + int64_t K = query.size(2); + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + int64_t batch_idx = blockIdx.y; + int64_t query_idx = blockIdx.x; + + auto buf = buffer[batch_idx][query_idx]; + auto buf2 = buffer2[batch_idx][query_idx]; + + for (int64_t k = 0; k < K; k++) { + buf[k] = 0; + } + auto query_i = query[batch_idx][query_idx]; + auto normalizer = logsumexp_normalizer[batch_idx][query_idx]; + scalar_t tmp_sum = 0; + for (int64_t l = 0; l < N; l++) { + auto key_j = key[batch_idx][l]; + scalar_t si = 0; + for (int64_t k = 0; k < K; k++) { + si += query_i[k] * key_j[k]; + } + scalar_t attn_v = std::exp(si - normalizer); + + for (int64_t k = 0; k < K; k++) { + // grad_v[batch_idx][l][k] += attn_v * grad_out[batch_idx][query_idx][k]; + gpuAtomicAdd( + &grad_v[batch_idx][l][k], attn_v * grad_out[batch_idx][query_idx][k]); + } + + // now compute grad_q and grad_k + // first compute the gradient for the self-attention + // after softmax + scalar_t 
grad_attn_v = 0; + for (int64_t k = 0; k < K; k++) { + grad_attn_v += grad_out[batch_idx][query_idx][k] * value[batch_idx][l][k]; + // grad_attn_v[i][j][l] += grad_out[i][j][k] * v[i][l][k]; + } + + // those are temporaries for the gradient of the softmax + scalar_t tmp = attn_v * grad_attn_v; + tmp_sum += tmp; + + // grad_q is easy + for (int64_t k = 0; k < K; k++) { + // grad_q[batch_idx][query_idx][k] += tmp * key_j[k]; + gpuAtomicAdd(&grad_q[batch_idx][query_idx][k], tmp * key_j[k]); + buf[k] += attn_v * key_j[k]; + } + + // but grad_k is a bit trickier + buf2[l] = attn_v; + for (int64_t k = 0; k < K; k++) { + // grad_k[batch_idx][l][k] += tmp * query_i[k]; + gpuAtomicAdd(&grad_k[batch_idx][l][k], tmp * query_i[k]); + } + } + for (int64_t l = 0; l < N; l++) { + for (int64_t k = 0; k < K; k++) { + // grad_k[batch_idx][l][k] -= buf2[l] * query_i[k] * tmp_sum; + gpuAtomicAdd(&grad_k[batch_idx][l][k], -buf2[l] * query_i[k] * tmp_sum); + } + } + for (int64_t k = 0; k < K; k++) { + // grad_q[batch_idx][query_idx][k] -= buf[k] * tmp_sum; + gpuAtomicAdd(&grad_q[batch_idx][query_idx][k], -buf[k] * tmp_sum); + } +} + +std::tuple attention_backward( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value + // const at::Tensor& mask +) { + TORCH_CHECK(query.dim() == grad_out.dim()); + TORCH_CHECK(query.dim() == key.dim()); + TORCH_CHECK(query.dim() == value.dim()); + // TORCH_CHECK(query.dim() == mask.dim()); + TORCH_CHECK(query.dim() == 3); + + TORCH_CHECK(query.size(0) == grad_out.size(0)); + TORCH_CHECK(query.size(1) == grad_out.size(1)); + TORCH_CHECK(query.size(2) == grad_out.size(2)); + + TORCH_CHECK(query.size(2) == key.size(2)); + TORCH_CHECK(query.size(0) == key.size(0)); + + TORCH_CHECK(query.size(0) == value.size(0)); + TORCH_CHECK(key.size(1) == value.size(1)); + TORCH_CHECK( + query.size(2) == + value.size(2)); // TODO: drop this limitation in the future + + TORCH_CHECK(query.is_cuda(), "query must be a CUDA tensor"); + TORCH_CHECK(key.is_cuda(), "key must be a CUDA tensor"); + TORCH_CHECK(value.is_cuda(), "value must be a CUDA tensor"); + TORCH_CHECK(grad_out.is_cuda(), "grad_out must be a CUDA tensor"); + + TORCH_CHECK(!query.is_sparse(), "query must be a dense tensor"); + TORCH_CHECK(!key.is_sparse(), "key must be a dense tensor"); + TORCH_CHECK(!value.is_sparse(), "value must be a dense tensor"); + TORCH_CHECK(!grad_out.is_sparse(), "grad_out must be a dense tensor"); + + at::cuda::CUDAGuard device_guard(query.device()); + + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + int64_t K = query.size(2); + + at::Tensor res = at::empty({B, M, K}, query.options()); + at::Tensor grad_q = at::zeros_like(query); + at::Tensor grad_k = at::zeros_like(key); + at::Tensor grad_v = at::zeros_like(value); + + at::Tensor buffer = at::empty({B, M, K}, query.options()); + at::Tensor buffer2 = at::zeros({B, M, N}, query.options()); + + // TODO this should be an argument from the function + at::Tensor logsumexp = query.bmm(key.transpose(-2, -1)).logsumexp(-1); + + dim3 grid(M, B); + dim3 block(1, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES( + query.scalar_type(), "attention_backward_kernel", [&] { + attention_backward_kernel<<>>( + grad_q.packed_accessor(), + grad_k.packed_accessor(), + grad_v.packed_accessor(), + grad_out.packed_accessor(), + query.packed_accessor(), + key.packed_accessor(), + value.packed_accessor(), + logsumexp.packed_accessor(), + 
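// buffer is (B, M, K) scratch holding the running sum of attn * key for each
// query row, and buffer2 is (B, M, N) scratch holding the recomputed attention
// probabilities; both are indexed by (batch_idx, query_idx) inside the kernel,
// so each block writes a disjoint slice and needs no extra synchronization
// for them.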
buffer.packed_accessor(), + buffer2.packed_accessor() + ); + }); + + return std::make_tuple(grad_q, grad_k, grad_v); +} + } // namespace TORCH_LIBRARY_IMPL(xformers, CUDA, m) { m.impl( TORCH_SELECTIVE_NAME("xformers::efficient_attention"), TORCH_FN(attention)); + m.impl( + TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward"), + TORCH_FN(attention_backward)); } From aeb49e81c96a1f123518a1a7ed9e7998fb6c96f2 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 13 Apr 2022 06:25:35 -0700 Subject: [PATCH 06/45] Speedup CUDA kernel by 5x But we still have a long way to go --- .../attention/csrc/cuda/attention.cu | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 40d26396e0..4217b88649 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -617,7 +617,7 @@ __global__ void attention_backward_kernel( auto query_i = query[batch_idx][query_idx]; auto normalizer = logsumexp_normalizer[batch_idx][query_idx]; scalar_t tmp_sum = 0; - for (int64_t l = 0; l < N; l++) { + for (int64_t l = threadIdx.x; l < N; l += blockDim.x) { auto key_j = key[batch_idx][l]; scalar_t si = 0; for (int64_t k = 0; k < K; k++) { @@ -648,7 +648,8 @@ __global__ void attention_backward_kernel( for (int64_t k = 0; k < K; k++) { // grad_q[batch_idx][query_idx][k] += tmp * key_j[k]; gpuAtomicAdd(&grad_q[batch_idx][query_idx][k], tmp * key_j[k]); - buf[k] += attn_v * key_j[k]; + //buf[k] += attn_v * key_j[k]; + gpuAtomicAdd(&buf[k], attn_v * key_j[k]); } // but grad_k is a bit trickier @@ -658,16 +659,19 @@ __global__ void attention_backward_kernel( gpuAtomicAdd(&grad_k[batch_idx][l][k], tmp * query_i[k]); } } - for (int64_t l = 0; l < N; l++) { + tmp_sum = warpSum(tmp_sum); + for (int64_t l = threadIdx.x; l < N; l += blockDim.x) { for (int64_t k = 0; k < K; k++) { // grad_k[batch_idx][l][k] -= buf2[l] * query_i[k] * tmp_sum; gpuAtomicAdd(&grad_k[batch_idx][l][k], -buf2[l] * query_i[k] * tmp_sum); } } + if (threadIdx.x == 0) { for (int64_t k = 0; k < K; k++) { // grad_q[batch_idx][query_idx][k] -= buf[k] * tmp_sum; gpuAtomicAdd(&grad_q[batch_idx][query_idx][k], -buf[k] * tmp_sum); } + } } std::tuple attention_backward( @@ -725,12 +729,13 @@ std::tuple attention_backward( at::Tensor logsumexp = query.bmm(key.transpose(-2, -1)).logsumexp(-1); dim3 grid(M, B); - dim3 block(1, 1); + dim3 block(32, 1); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES( - query.scalar_type(), "attention_backward_kernel", [&] { + using scalar_t = float; + //AT_DISPATCH_FLOATING_TYPES( + // query.scalar_type(), "attention_backward_kernel", [&] { attention_backward_kernel<<>>( grad_q.packed_accessor(), grad_k.packed_accessor(), @@ -743,7 +748,7 @@ std::tuple attention_backward( buffer.packed_accessor(), buffer2.packed_accessor() ); - }); + // }); return std::make_tuple(grad_q, grad_k, grad_v); } From 643baaa1797cbee31a326b5653561897934901af Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 13 Apr 2022 08:16:16 -0700 Subject: [PATCH 07/45] Make logsumexp an argument --- xformers/components/attention/csrc/attention.cpp | 2 +- .../components/attention/csrc/cpu/attention.cpp | 5 +++-- .../components/attention/csrc/cuda/attention.cu | 13 ++++++++----- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/xformers/components/attention/csrc/attention.cpp 
b/xformers/components/attention/csrc/attention.cpp index 90c905f8b8..3b0d6e71b0 100644 --- a/xformers/components/attention/csrc/attention.cpp +++ b/xformers/components/attention/csrc/attention.cpp @@ -4,5 +4,5 @@ TORCH_LIBRARY_FRAGMENT(xformers, m) { m.def(TORCH_SELECTIVE_SCHEMA( "xformers::efficient_attention(Tensor query, Tensor key, Tensor value) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA( - "xformers::efficient_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value) -> (Tensor, Tensor, Tensor)")); + "xformers::efficient_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp) -> (Tensor, Tensor, Tensor)")); } diff --git a/xformers/components/attention/csrc/cpu/attention.cpp b/xformers/components/attention/csrc/cpu/attention.cpp index 962ee72a38..dd5d4280e5 100644 --- a/xformers/components/attention/csrc/cpu/attention.cpp +++ b/xformers/components/attention/csrc/cpu/attention.cpp @@ -232,7 +232,8 @@ std::tuple attention_backward( const at::Tensor& grad_out, const at::Tensor& query, const at::Tensor& key, - const at::Tensor& value + const at::Tensor& value, + const at::Tensor& logsumexp // const at::Tensor& mask ) { @@ -281,7 +282,7 @@ std::tuple attention_backward( at::zeros({at::get_num_threads(), 1, N}, query.options()); // TODO this should be an argument from the function - at::Tensor logsumexp = query.bmm(key.transpose(-2, -1)).logsumexp(-1); + //at::Tensor logsumexp = query.bmm(key.transpose(-2, -1)).logsumexp(-1); AT_DISPATCH_FLOATING_TYPES( query.scalar_type(), "attention_backward_kernel", [&] { diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 4217b88649..d4e3dc1964 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -606,7 +606,9 @@ __global__ void attention_backward_kernel( int64_t M = query.size(1); int64_t N = key.size(1); int64_t batch_idx = blockIdx.y; - int64_t query_idx = blockIdx.x; + int64_t query_idx = blockIdx.x * blockDim.y + threadIdx.y; + + if (query_idx >= M) return; auto buf = buffer[batch_idx][query_idx]; auto buf2 = buffer2[batch_idx][query_idx]; @@ -678,7 +680,8 @@ std::tuple attention_backward( const at::Tensor& grad_out, const at::Tensor& query, const at::Tensor& key, - const at::Tensor& value + const at::Tensor& value, + const at::Tensor& logsumexp // const at::Tensor& mask ) { TORCH_CHECK(query.dim() == grad_out.dim()); @@ -726,10 +729,10 @@ std::tuple attention_backward( at::Tensor buffer2 = at::zeros({B, M, N}, query.options()); // TODO this should be an argument from the function - at::Tensor logsumexp = query.bmm(key.transpose(-2, -1)).logsumexp(-1); + //at::Tensor logsumexp = query.bmm(key.transpose(-2, -1)).logsumexp(-1); - dim3 grid(M, B); - dim3 block(32, 1); + dim3 grid(ceil_div(M, int64_t(16)), B); + dim3 block(32, 16); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); From 788e756762736a485751e18697ce69db8cdadb24 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 14 Apr 2022 01:42:13 -0700 Subject: [PATCH 08/45] Make it 30% faster Merge two loops together and use local buffers for accumulation and grad_q. 
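Concretely, each thread now accumulates temp_grad_q and temp_buffer in
fixed-size register arrays (BUFFER_SIZE = 32 entries) and writes grad_q once
after a warp reduction, instead of issuing a gpuAtomicAdd per key.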
The use of local buffers as is currently introduces limitations on the sizes of dimension K --- .../attention/csrc/cuda/attention.cu | 45 ++++++++++++++----- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index d4e3dc1964..0731912e86 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -613,11 +613,22 @@ __global__ void attention_backward_kernel( auto buf = buffer[batch_idx][query_idx]; auto buf2 = buffer2[batch_idx][query_idx]; + constexpr int64_t BUFFER_SIZE = 32; + constexpr int64_t kBlockSizeK = 8; + + //scalar_t temp_grad_q[kBlockSizeK][BUFFER_SIZE] = {0}; + //scalar_t temp_grad_k[kBlockSizeK][BUFFER_SIZE] = {0}; + //scalar_t temp_grad_v[kBlockSizeK][BUFFER_SIZE] = {0}; + scalar_t temp_buffer[BUFFER_SIZE] = {0}; + scalar_t temp_grad_q[BUFFER_SIZE] = {0}; + + /* for (int64_t k = 0; k < K; k++) { buf[k] = 0; - } + }*/ auto query_i = query[batch_idx][query_idx]; auto normalizer = logsumexp_normalizer[batch_idx][query_idx]; + auto grad_q_i = grad_q[batch_idx][query_idx]; scalar_t tmp_sum = 0; for (int64_t l = threadIdx.x; l < N; l += blockDim.x) { auto key_j = key[batch_idx][l]; @@ -627,19 +638,19 @@ __global__ void attention_backward_kernel( } scalar_t attn_v = std::exp(si - normalizer); - for (int64_t k = 0; k < K; k++) { - // grad_v[batch_idx][l][k] += attn_v * grad_out[batch_idx][query_idx][k]; - gpuAtomicAdd( - &grad_v[batch_idx][l][k], attn_v * grad_out[batch_idx][query_idx][k]); - } - // now compute grad_q and grad_k // first compute the gradient for the self-attention // after softmax scalar_t grad_attn_v = 0; for (int64_t k = 0; k < K; k++) { - grad_attn_v += grad_out[batch_idx][query_idx][k] * value[batch_idx][l][k]; - // grad_attn_v[i][j][l] += grad_out[i][j][k] * v[i][l][k]; + scalar_t temp = grad_out[batch_idx][query_idx][k]; + // grad_v[batch_idx][l][k] += attn_v * grad_out[batch_idx][query_idx][k]; + gpuAtomicAdd( + &grad_v[batch_idx][l][k], attn_v * temp); + + //temp_grad_v[l][k] += attn_v * grad_out[batch_idx][query_idx][k]; + + grad_attn_v += temp * value[batch_idx][l][k]; } // those are temporaries for the gradient of the softmax @@ -649,9 +660,11 @@ __global__ void attention_backward_kernel( // grad_q is easy for (int64_t k = 0; k < K; k++) { // grad_q[batch_idx][query_idx][k] += tmp * key_j[k]; - gpuAtomicAdd(&grad_q[batch_idx][query_idx][k], tmp * key_j[k]); + //gpuAtomicAdd(&grad_q_i[k], tmp * key_j[k]); + temp_grad_q[k] += tmp * key_j[k]; //buf[k] += attn_v * key_j[k]; - gpuAtomicAdd(&buf[k], attn_v * key_j[k]); + //gpuAtomicAdd(&buf[k], attn_v * key_j[k]); + temp_buffer[k] += attn_v * key_j[k]; } // but grad_k is a bit trickier @@ -662,16 +675,24 @@ __global__ void attention_backward_kernel( } } tmp_sum = warpSum(tmp_sum); + for (int64_t l = threadIdx.x; l < N; l += blockDim.x) { for (int64_t k = 0; k < K; k++) { // grad_k[batch_idx][l][k] -= buf2[l] * query_i[k] * tmp_sum; gpuAtomicAdd(&grad_k[batch_idx][l][k], -buf2[l] * query_i[k] * tmp_sum); } } + for (int64_t k = 0; k < K; k++) { + temp_grad_q[k] = warpSum(temp_grad_q[k]); + temp_buffer[k] = warpSum(temp_buffer[k]); + } if (threadIdx.x == 0) { for (int64_t k = 0; k < K; k++) { // grad_q[batch_idx][query_idx][k] -= buf[k] * tmp_sum; - gpuAtomicAdd(&grad_q[batch_idx][query_idx][k], -buf[k] * tmp_sum); + + //gpuAtomicAdd(&grad_q_i[k], -buf[k] * tmp_sum); + //gpuAtomicAdd(&grad_q_i[k], -temp_buffer[k] * tmp_sum); + 
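// temp_grad_q and temp_buffer were warp-reduced above, so lane 0 holds the
// complete result for this query row and can store grad_q directly; atomics
// remain only for grad_k and grad_v, which several query rows still update.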
grad_q_i[k] = temp_grad_q[k] - temp_buffer[k] * tmp_sum; } } } From ade14e77e79a39965c173c39574529d72e51cf71 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 14 Apr 2022 10:48:45 -0700 Subject: [PATCH 09/45] 3.5x speedup by blocking strategy --- .../attention/csrc/cuda/attention.cu | 214 +++++++++++++++--- 1 file changed, 183 insertions(+), 31 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 0731912e86..8cd2078d18 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -103,7 +103,7 @@ __device__ void compute_dot( scalar_t out[kBlockSizeQ][kBlockSizeK], int64_t K) { constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); - scalar_t scale = 1.0 / std::sqrt(scalar_t(K)); + scalar_t scale = 1.0;// / std::sqrt(scalar_t(K)); vec_t q_i[kBlockSizeQ]; for (int64_t k = 0; k < K / kVecSize; k += 1) { #pragma unroll @@ -605,44 +605,135 @@ __global__ void attention_backward_kernel( int64_t B = query.size(0); int64_t M = query.size(1); int64_t N = key.size(1); + + + constexpr int64_t BUFFER_SIZE = 32; + constexpr int64_t kBlockSizeQ = 4; + constexpr int64_t kBlockSizeK = 4; + int64_t batch_idx = blockIdx.y; - int64_t query_idx = blockIdx.x * blockDim.y + threadIdx.y; + //int64_t query_idx = blockIdx.x * blockDim.y + threadIdx.y; + int64_t query_idx = blockIdx.x * blockDim.y * kBlockSizeQ + threadIdx.y * kBlockSizeQ; if (query_idx >= M) return; auto buf = buffer[batch_idx][query_idx]; auto buf2 = buffer2[batch_idx][query_idx]; - constexpr int64_t BUFFER_SIZE = 32; - constexpr int64_t kBlockSizeK = 8; - //scalar_t temp_grad_q[kBlockSizeK][BUFFER_SIZE] = {0}; //scalar_t temp_grad_k[kBlockSizeK][BUFFER_SIZE] = {0}; //scalar_t temp_grad_v[kBlockSizeK][BUFFER_SIZE] = {0}; - scalar_t temp_buffer[BUFFER_SIZE] = {0}; - scalar_t temp_grad_q[BUFFER_SIZE] = {0}; + scalar_t temp_buffer[kBlockSizeQ][BUFFER_SIZE] = {0}; + //scalar_t temp_grad_q[BUFFER_SIZE] = {0}; + scalar_t temp_grad_q[kBlockSizeQ][BUFFER_SIZE] = {0}; + //using vec_t = float4; + using vec_t = float; + + constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); + + vec_t* query_block[kBlockSizeQ]; + vec_t* grad_out_block[kBlockSizeQ]; + scalar_t normalizer[kBlockSizeQ]; + + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + int64_t index = query_idx + q_item_idx; + index = index >= M ? 
M - 1 : index; + query_block[q_item_idx] = + reinterpret_cast(query[batch_idx][index].data()); + grad_out_block[q_item_idx] = + reinterpret_cast(grad_out[batch_idx][index].data()); + normalizer[q_item_idx] = logsumexp_normalizer[batch_idx][index]; + } /* for (int64_t k = 0; k < K; k++) { buf[k] = 0; }*/ - auto query_i = query[batch_idx][query_idx]; - auto normalizer = logsumexp_normalizer[batch_idx][query_idx]; + //auto query_i = query[batch_idx][query_idx + //auto normalizer = logsumexp_normalizer[batch_idx][query_idx]; auto grad_q_i = grad_q[batch_idx][query_idx]; - scalar_t tmp_sum = 0; - for (int64_t l = threadIdx.x; l < N; l += blockDim.x) { + //scalar_t tmp_sum = 0; + scalar_t tmp_sum[kBlockSizeQ] = {0}; + for (int64_t l = threadIdx.x * kBlockSizeK; l < N; l += blockDim.x * kBlockSizeK) { + /* auto key_j = key[batch_idx][l]; scalar_t si = 0; for (int64_t k = 0; k < K; k++) { si += query_i[k] * key_j[k]; + }*/ + + auto key_j = reinterpret_cast(key[batch_idx][l].data()); + + scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; + compute_dot( + query_block, key_j, attn_v, K); + + //for (int64_t k = 0; k < K; k++) { + // attn_v[0][0] += query_block[0][k] * key_j[k]; + // } + +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { +#pragma unroll + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + attn_v[q_item_idx][k_item_idx] = std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]); + } } - scalar_t attn_v = std::exp(si - normalizer); + + //scalar_t attn_v = std::exp(si - normalizer); // now compute grad_q and grad_k // first compute the gradient for the self-attention // after softmax - scalar_t grad_attn_v = 0; - for (int64_t k = 0; k < K; k++) { + //scalar_t grad_attn_v = 0; + scalar_t grad_attn_v[kBlockSizeQ][kBlockSizeK] = {0}; + auto value_j = reinterpret_cast(value[batch_idx][l].data()); + + for (int64_t k = 0; k < K / kVecSize; k++) { +#if 1 + vec_t temp_i[kBlockSizeQ]; +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + temp_i[q_item_idx] = __ldg(grad_out_block[q_item_idx] + k); + } + +#pragma unroll + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + vec_t v = value_j[k + K / kVecSize * k_item_idx]; + //vec_t tt; tt.x = 0; tt.y = 0; tt.z = 0; tt.w = 0; + vec_t tt = {0}; +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + sputnik::VectorCompute::Dot(temp_i[q_item_idx], v, &grad_attn_v[q_item_idx][k_item_idx]); + //iMul(v[q_item_idx][k_item_idx], temp_i[q_item_idx]); + tt += attn_v[q_item_idx][k_item_idx] * temp_i[q_item_idx]; + /* + tt.x += attn_v[q_item_idx][k_item_idx] * temp_i[q_item_idx].x; + tt.y += attn_v[q_item_idx][k_item_idx] * temp_i[q_item_idx].y; + tt.z += attn_v[q_item_idx][k_item_idx] * temp_i[q_item_idx].z; + tt.w += attn_v[q_item_idx][k_item_idx] * temp_i[q_item_idx].w;*/ + + } + gpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k], tt); + //gpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k * 4 + 0], tt.x); + //gpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k * 4 + 1], tt.y); + //gpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k * 4 + 2], tt.z); + //gpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k * 4 + 3], tt.w); + } + +#else + + scalar_t temp = grad_out[batch_idx][query_idx][k]; + // grad_v[batch_idx][l][k] += attn_v * grad_out[batch_idx][query_idx][k]; + gpuAtomicAdd( + &grad_v[batch_idx][l][k], attn_v[0][0] * temp); + + //temp_grad_v[l][k] += attn_v * grad_out[batch_idx][query_idx][k]; + + 
grad_attn_v[0][0] += temp * value[batch_idx][l][k]; + +#endif + /* scalar_t temp = grad_out[batch_idx][query_idx][k]; // grad_v[batch_idx][l][k] += attn_v * grad_out[batch_idx][query_idx][k]; gpuAtomicAdd( @@ -650,41 +741,93 @@ __global__ void attention_backward_kernel( //temp_grad_v[l][k] += attn_v * grad_out[batch_idx][query_idx][k]; - grad_attn_v += temp * value[batch_idx][l][k]; + grad_attn_v += temp * value[batch_idx][l][k];*/ } // those are temporaries for the gradient of the softmax - scalar_t tmp = attn_v * grad_attn_v; - tmp_sum += tmp; + scalar_t tmp[kBlockSizeQ][kBlockSizeK]; +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { +#pragma unroll + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + tmp[q_item_idx][k_item_idx] = attn_v[q_item_idx][k_item_idx] * grad_attn_v[q_item_idx][k_item_idx]; + tmp_sum[q_item_idx] += tmp[q_item_idx][k_item_idx]; + } + } + //scalar_t tmp = attn_v * grad_attn_v; + //tmp_sum += tmp; // grad_q is easy for (int64_t k = 0; k < K; k++) { // grad_q[batch_idx][query_idx][k] += tmp * key_j[k]; //gpuAtomicAdd(&grad_q_i[k], tmp * key_j[k]); - temp_grad_q[k] += tmp * key_j[k]; - //buf[k] += attn_v * key_j[k]; - //gpuAtomicAdd(&buf[k], attn_v * key_j[k]); - temp_buffer[k] += attn_v * key_j[k]; +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { +#pragma unroll + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + temp_grad_q[q_item_idx][k] += tmp[q_item_idx][k_item_idx] * key_j[k + K * k_item_idx]; + //buf[k] += attn_v * key_j[k]; + //gpuAtomicAdd(&buf[k], attn_v * key_j[k]); + temp_buffer[q_item_idx][k] += attn_v[q_item_idx][k_item_idx] * key_j[k + K * k_item_idx]; + } + } } // but grad_k is a bit trickier - buf2[l] = attn_v; + //buf2[l] = attn_v; + +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { +#pragma unroll + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + buffer2[batch_idx][query_idx + q_item_idx][l + k_item_idx] = attn_v[q_item_idx][k_item_idx]; + } + } for (int64_t k = 0; k < K; k++) { // grad_k[batch_idx][l][k] += tmp * query_i[k]; - gpuAtomicAdd(&grad_k[batch_idx][l][k], tmp * query_i[k]); + //gpuAtomicAdd(&grad_k[batch_idx][l][k], tmp * query_i[k]); +#pragma unroll + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + scalar_t res = 0; +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + res += tmp[q_item_idx][k_item_idx] * query_block[q_item_idx][k]; + } + gpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k], res); + + } } } - tmp_sum = warpSum(tmp_sum); + //tmp_sum = warpSum(tmp_sum); +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + tmp_sum[q_item_idx] = warpSum(tmp_sum[q_item_idx]); + } - for (int64_t l = threadIdx.x; l < N; l += blockDim.x) { + for (int64_t l = threadIdx.x * kBlockSizeK; l < N; l += blockDim.x * kBlockSizeK) { for (int64_t k = 0; k < K; k++) { // grad_k[batch_idx][l][k] -= buf2[l] * query_i[k] * tmp_sum; - gpuAtomicAdd(&grad_k[batch_idx][l][k], -buf2[l] * query_i[k] * tmp_sum); + //gpuAtomicAdd(&grad_k[batch_idx][l][k], -buf2[l] * query_i[k] * tmp_sum); +#pragma unroll + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + scalar_t res = 0; +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + res += buffer2[batch_idx][query_idx + q_item_idx][l + k_item_idx] * query_block[q_item_idx][k] * 
tmp_sum[q_item_idx]; + } + gpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k], -res); + } } } for (int64_t k = 0; k < K; k++) { - temp_grad_q[k] = warpSum(temp_grad_q[k]); - temp_buffer[k] = warpSum(temp_buffer[k]); + //temp_grad_q[k] = warpSum(temp_grad_q[k]); + //temp_buffer[k] = warpSum(temp_buffer[k]); + +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + temp_grad_q[q_item_idx][k] = warpSum(temp_grad_q[q_item_idx][k]); + temp_buffer[q_item_idx][k] = warpSum(temp_buffer[q_item_idx][k]); + } } if (threadIdx.x == 0) { for (int64_t k = 0; k < K; k++) { @@ -692,7 +835,11 @@ __global__ void attention_backward_kernel( //gpuAtomicAdd(&grad_q_i[k], -buf[k] * tmp_sum); //gpuAtomicAdd(&grad_q_i[k], -temp_buffer[k] * tmp_sum); - grad_q_i[k] = temp_grad_q[k] - temp_buffer[k] * tmp_sum; + //grad_q_i[k] = temp_grad_q[k] - temp_buffer[k] * tmp_sum; +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + grad_q[batch_idx][query_idx + q_item_idx][k] = temp_grad_q[q_item_idx][k] - temp_buffer[q_item_idx][k] * tmp_sum[q_item_idx]; + } } } } @@ -752,8 +899,13 @@ std::tuple attention_backward( // TODO this should be an argument from the function //at::Tensor logsumexp = query.bmm(key.transpose(-2, -1)).logsumexp(-1); - dim3 grid(ceil_div(M, int64_t(16)), B); - dim3 block(32, 16); + //dim3 grid(ceil_div(M, int64_t(16)), B); + //dim3 block(32, 16); + constexpr int TILE_SIZE = 16 * 4; + constexpr int kBlockSizeQ = 4; + + dim3 grid(ceil_div(M, int64_t(TILE_SIZE)), B); + dim3 block(32, TILE_SIZE / kBlockSizeQ); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); From f6427d7c51b5c34ae47f0f9176991e2035c7832d Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 14 Apr 2022 12:17:05 -0700 Subject: [PATCH 10/45] Use vector loads and improve tile selection Brings an extra 2x improvement --- .../attention/csrc/cuda/attention.cu | 139 ++++++++++++------ 1 file changed, 91 insertions(+), 48 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 8cd2078d18..ff9ee08491 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -59,6 +59,45 @@ __device__ __forceinline__ void iDiv(scalar_t x1, float* out) { out[0] /= x1; } + +template +__device__ __forceinline__ void axpy(scalar_t a, float4 in, float4* out) { + out[0].x += a * in.x; + out[0].y += a * in.y; + out[0].z += a * in.z; + out[0].w += a * in.w; +} + +template +__device__ __forceinline__ void axpy(scalar_t a, float2 in, float2* out) { + out[0].x += a * in.x; + out[0].y += a * in.y; +} + +template +__device__ __forceinline__ void axpy(scalar_t a, float in, float* out) { + out[0] += a * in; +} + +template +__device__ __forceinline__ void myGpuAtomicAdd(scalar_t* address, float4 val) { + gpuAtomicAdd(address + 0, val.x); + gpuAtomicAdd(address + 1, val.y); + gpuAtomicAdd(address + 2, val.z); + gpuAtomicAdd(address + 3, val.w); +} + +template +__device__ __forceinline__ void myGpuAtomicAdd(scalar_t* address, float2 val) { + gpuAtomicAdd(address + 0, val.x); + gpuAtomicAdd(address + 1, val.y); +} + +template +__device__ __forceinline__ void myGpuAtomicAdd(scalar_t* address, float val) { + gpuAtomicAdd(address, val); +} + template __device__ __forceinline__ scalar_t warpSum(scalar_t val) { for (int stride = WARP_SIZE / 2; stride > 0; stride >>= 1) { @@ -584,6 +623,7 @@ at::Tensor attention( key.packed_accessor(), value.packed_accessor()); 
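// The kernel launch above is asynchronous; the AT_CUDA_CHECK(cudaGetLastError())
// added below the dispatch makes launch failures surface here rather than on a
// later, unrelated synchronization.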
} + AT_CUDA_CHECK(cudaGetLastError()); return res; } @@ -606,10 +646,14 @@ __global__ void attention_backward_kernel( int64_t M = query.size(1); int64_t N = key.size(1); + using vec_t = float4; + //using vec_t = float; + constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); + + constexpr int64_t BUFFER_SIZE = 32 / kVecSize; + constexpr int64_t kBlockSizeQ = 8; + constexpr int64_t kBlockSizeK = 8; - constexpr int64_t BUFFER_SIZE = 32; - constexpr int64_t kBlockSizeQ = 4; - constexpr int64_t kBlockSizeK = 4; int64_t batch_idx = blockIdx.y; //int64_t query_idx = blockIdx.x * blockDim.y + threadIdx.y; @@ -623,16 +667,16 @@ __global__ void attention_backward_kernel( //scalar_t temp_grad_q[kBlockSizeK][BUFFER_SIZE] = {0}; //scalar_t temp_grad_k[kBlockSizeK][BUFFER_SIZE] = {0}; //scalar_t temp_grad_v[kBlockSizeK][BUFFER_SIZE] = {0}; - scalar_t temp_buffer[kBlockSizeQ][BUFFER_SIZE] = {0}; + vec_t temp_buffer[kBlockSizeQ][BUFFER_SIZE] = {0}; //scalar_t temp_grad_q[BUFFER_SIZE] = {0}; - scalar_t temp_grad_q[kBlockSizeQ][BUFFER_SIZE] = {0}; - //using vec_t = float4; - using vec_t = float; + vec_t temp_grad_q[kBlockSizeQ][BUFFER_SIZE] = {0}; + + - constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); vec_t* query_block[kBlockSizeQ]; vec_t* grad_out_block[kBlockSizeQ]; + vec_t* grad_q_block[kBlockSizeQ]; scalar_t normalizer[kBlockSizeQ]; for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { @@ -642,6 +686,8 @@ __global__ void attention_backward_kernel( reinterpret_cast(query[batch_idx][index].data()); grad_out_block[q_item_idx] = reinterpret_cast(grad_out[batch_idx][index].data()); + grad_q_block[q_item_idx] = + reinterpret_cast(grad_q[batch_idx][index].data()); normalizer[q_item_idx] = logsumexp_normalizer[batch_idx][index]; } @@ -651,16 +697,10 @@ __global__ void attention_backward_kernel( }*/ //auto query_i = query[batch_idx][query_idx //auto normalizer = logsumexp_normalizer[batch_idx][query_idx]; - auto grad_q_i = grad_q[batch_idx][query_idx]; + //auto grad_q_i = grad_q[batch_idx][query_idx]; //scalar_t tmp_sum = 0; scalar_t tmp_sum[kBlockSizeQ] = {0}; for (int64_t l = threadIdx.x * kBlockSizeK; l < N; l += blockDim.x * kBlockSizeK) { - /* - auto key_j = key[batch_idx][l]; - scalar_t si = 0; - for (int64_t k = 0; k < K; k++) { - si += query_i[k] * key_j[k]; - }*/ auto key_j = reinterpret_cast(key[batch_idx][l].data()); @@ -668,9 +708,6 @@ __global__ void attention_backward_kernel( compute_dot( query_block, key_j, attn_v, K); - //for (int64_t k = 0; k < K; k++) { - // attn_v[0][0] += query_block[0][k] * key_j[k]; - // } #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { @@ -706,19 +743,10 @@ __global__ void attention_backward_kernel( for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { sputnik::VectorCompute::Dot(temp_i[q_item_idx], v, &grad_attn_v[q_item_idx][k_item_idx]); //iMul(v[q_item_idx][k_item_idx], temp_i[q_item_idx]); - tt += attn_v[q_item_idx][k_item_idx] * temp_i[q_item_idx]; - /* - tt.x += attn_v[q_item_idx][k_item_idx] * temp_i[q_item_idx].x; - tt.y += attn_v[q_item_idx][k_item_idx] * temp_i[q_item_idx].y; - tt.z += attn_v[q_item_idx][k_item_idx] * temp_i[q_item_idx].z; - tt.w += attn_v[q_item_idx][k_item_idx] * temp_i[q_item_idx].w;*/ - + axpy(attn_v[q_item_idx][k_item_idx], temp_i[q_item_idx], &tt); } - gpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k], tt); - //gpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k * 4 + 0], tt.x); - //gpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k * 4 + 1], tt.y); - 
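// grad_v (and grad_k) rows receive contributions from every query row, and
// different query rows live in different warps and blocks, so those updates
// must go through atomics; only grad_q is owned exclusively by the warp
// handling its block of query rows.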
//gpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k * 4 + 2], tt.z); - //gpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k * 4 + 3], tt.w); + myGpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k * kVecSize], tt); + //gpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k], tt); } #else @@ -758,17 +786,18 @@ __global__ void attention_backward_kernel( //tmp_sum += tmp; // grad_q is easy - for (int64_t k = 0; k < K; k++) { + for (int64_t k = 0; k < K / kVecSize; k++) { // grad_q[batch_idx][query_idx][k] += tmp * key_j[k]; //gpuAtomicAdd(&grad_q_i[k], tmp * key_j[k]); #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { #pragma unroll for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - temp_grad_q[q_item_idx][k] += tmp[q_item_idx][k_item_idx] * key_j[k + K * k_item_idx]; + vec_t ttt = key_j[k + K / kVecSize * k_item_idx]; + axpy(tmp[q_item_idx][k_item_idx], ttt, &temp_grad_q[q_item_idx][k]); //buf[k] += attn_v * key_j[k]; //gpuAtomicAdd(&buf[k], attn_v * key_j[k]); - temp_buffer[q_item_idx][k] += attn_v[q_item_idx][k_item_idx] * key_j[k + K * k_item_idx]; + axpy(attn_v[q_item_idx][k_item_idx], ttt, &temp_buffer[q_item_idx][k]); } } } @@ -783,18 +812,21 @@ __global__ void attention_backward_kernel( buffer2[batch_idx][query_idx + q_item_idx][l + k_item_idx] = attn_v[q_item_idx][k_item_idx]; } } - for (int64_t k = 0; k < K; k++) { + for (int64_t k = 0; k < K / kVecSize; k++) { // grad_k[batch_idx][l][k] += tmp * query_i[k]; //gpuAtomicAdd(&grad_k[batch_idx][l][k], tmp * query_i[k]); #pragma unroll for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - scalar_t res = 0; + //scalar_t res = 0; + vec_t res = {0}; #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - res += tmp[q_item_idx][k_item_idx] * query_block[q_item_idx][k]; + //res += tmp[q_item_idx][k_item_idx] * query_block[q_item_idx][k]; + vec_t qqq = query_block[q_item_idx][k]; + axpy(tmp[q_item_idx][k_item_idx], qqq, &res); } - gpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k], res); - + //gpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k], res); + myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res); } } } @@ -805,21 +837,26 @@ __global__ void attention_backward_kernel( } for (int64_t l = threadIdx.x * kBlockSizeK; l < N; l += blockDim.x * kBlockSizeK) { - for (int64_t k = 0; k < K; k++) { + for (int64_t k = 0; k < K / kVecSize; k++) { // grad_k[batch_idx][l][k] -= buf2[l] * query_i[k] * tmp_sum; //gpuAtomicAdd(&grad_k[batch_idx][l][k], -buf2[l] * query_i[k] * tmp_sum); #pragma unroll for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - scalar_t res = 0; + //scalar_t res = 0; + vec_t res = {0}; #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - res += buffer2[batch_idx][query_idx + q_item_idx][l + k_item_idx] * query_block[q_item_idx][k] * tmp_sum[q_item_idx]; + //res += buffer2[batch_idx][query_idx + q_item_idx][l + k_item_idx] * query_block[q_item_idx][k] * tmp_sum[q_item_idx]; + scalar_t ttt = - buffer2[batch_idx][query_idx + q_item_idx][l + k_item_idx] * tmp_sum[q_item_idx]; + vec_t qqq = query_block[q_item_idx][k]; + axpy(ttt, qqq, &res); + } + //gpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k], -res); + myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res); } - gpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k], -res); - } } } - for (int64_t k = 0; k < K; k++) { + for (int64_t k = 0; k < K / kVecSize; k++) { //temp_grad_q[k] = 
warpSum(temp_grad_q[k]); //temp_buffer[k] = warpSum(temp_buffer[k]); @@ -830,7 +867,7 @@ __global__ void attention_backward_kernel( } } if (threadIdx.x == 0) { - for (int64_t k = 0; k < K; k++) { + for (int64_t k = 0; k < K / kVecSize; k++) { // grad_q[batch_idx][query_idx][k] -= buf[k] * tmp_sum; //gpuAtomicAdd(&grad_q_i[k], -buf[k] * tmp_sum); @@ -838,7 +875,11 @@ __global__ void attention_backward_kernel( //grad_q_i[k] = temp_grad_q[k] - temp_buffer[k] * tmp_sum; #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - grad_q[batch_idx][query_idx + q_item_idx][k] = temp_grad_q[q_item_idx][k] - temp_buffer[q_item_idx][k] * tmp_sum[q_item_idx]; + //grad_q[batch_idx][query_idx + q_item_idx][k] = temp_grad_q[q_item_idx][k] - temp_buffer[q_item_idx][k] * tmp_sum[q_item_idx]; + grad_q_block[q_item_idx][k].x = temp_grad_q[q_item_idx][k].x - temp_buffer[q_item_idx][k].x * tmp_sum[q_item_idx]; + grad_q_block[q_item_idx][k].y = temp_grad_q[q_item_idx][k].y - temp_buffer[q_item_idx][k].y * tmp_sum[q_item_idx]; + grad_q_block[q_item_idx][k].z = temp_grad_q[q_item_idx][k].z - temp_buffer[q_item_idx][k].z * tmp_sum[q_item_idx]; + grad_q_block[q_item_idx][k].w = temp_grad_q[q_item_idx][k].w - temp_buffer[q_item_idx][k].w * tmp_sum[q_item_idx]; } } } @@ -901,8 +942,8 @@ std::tuple attention_backward( //dim3 grid(ceil_div(M, int64_t(16)), B); //dim3 block(32, 16); - constexpr int TILE_SIZE = 16 * 4; - constexpr int kBlockSizeQ = 4; + constexpr int TILE_SIZE = 16 * 2; + constexpr int kBlockSizeQ = 8; dim3 grid(ceil_div(M, int64_t(TILE_SIZE)), B); dim3 block(32, TILE_SIZE / kBlockSizeQ); @@ -926,6 +967,8 @@ std::tuple attention_backward( ); // }); + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_q, grad_k, grad_v); } From 70bfedaa694db5553939fce681eea920bb711465 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 15 Apr 2022 04:16:35 -0700 Subject: [PATCH 11/45] Recompute attention for grad_q computation Makes it another 20% faster, and doesn't use extra memory --- .../attention/csrc/cuda/attention.cu | 56 ++++++++----------- 1 file changed, 22 insertions(+), 34 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index ff9ee08491..1305d25e1e 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -637,9 +637,7 @@ __global__ void attention_backward_kernel( at::PackedTensorAccessor query, at::PackedTensorAccessor key, at::PackedTensorAccessor value, - at::PackedTensorAccessor logsumexp_normalizer, - at::PackedTensorAccessor buffer, - at::PackedTensorAccessor buffer2 + at::PackedTensorAccessor logsumexp_normalizer ) { int64_t K = query.size(2); int64_t B = query.size(0); @@ -661,9 +659,6 @@ __global__ void attention_backward_kernel( if (query_idx >= M) return; - auto buf = buffer[batch_idx][query_idx]; - auto buf2 = buffer2[batch_idx][query_idx]; - //scalar_t temp_grad_q[kBlockSizeK][BUFFER_SIZE] = {0}; //scalar_t temp_grad_k[kBlockSizeK][BUFFER_SIZE] = {0}; //scalar_t temp_grad_v[kBlockSizeK][BUFFER_SIZE] = {0}; @@ -691,10 +686,6 @@ __global__ void attention_backward_kernel( normalizer[q_item_idx] = logsumexp_normalizer[batch_idx][index]; } - /* - for (int64_t k = 0; k < K; k++) { - buf[k] = 0; - }*/ //auto query_i = query[batch_idx][query_idx //auto normalizer = logsumexp_normalizer[batch_idx][query_idx]; //auto grad_q_i = grad_q[batch_idx][query_idx]; @@ -795,23 +786,13 @@ __global__ void 
attention_backward_kernel( for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { vec_t ttt = key_j[k + K / kVecSize * k_item_idx]; axpy(tmp[q_item_idx][k_item_idx], ttt, &temp_grad_q[q_item_idx][k]); - //buf[k] += attn_v * key_j[k]; - //gpuAtomicAdd(&buf[k], attn_v * key_j[k]); axpy(attn_v[q_item_idx][k_item_idx], ttt, &temp_buffer[q_item_idx][k]); } } } // but grad_k is a bit trickier - //buf2[l] = attn_v; -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { -#pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - buffer2[batch_idx][query_idx + q_item_idx][l + k_item_idx] = attn_v[q_item_idx][k_item_idx]; - } - } for (int64_t k = 0; k < K / kVecSize; k++) { // grad_k[batch_idx][l][k] += tmp * query_i[k]; //gpuAtomicAdd(&grad_k[batch_idx][l][k], tmp * query_i[k]); @@ -837,17 +818,30 @@ __global__ void attention_backward_kernel( } for (int64_t l = threadIdx.x * kBlockSizeK; l < N; l += blockDim.x * kBlockSizeK) { + + + auto key_j = reinterpret_cast(key[batch_idx][l].data()); + scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; + compute_dot( + query_block, key_j, attn_v, K); + + +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { +#pragma unroll + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + attn_v[q_item_idx][k_item_idx] = std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]); + } + } + + for (int64_t k = 0; k < K / kVecSize; k++) { - // grad_k[batch_idx][l][k] -= buf2[l] * query_i[k] * tmp_sum; - //gpuAtomicAdd(&grad_k[batch_idx][l][k], -buf2[l] * query_i[k] * tmp_sum); #pragma unroll for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - //scalar_t res = 0; vec_t res = {0}; #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - //res += buffer2[batch_idx][query_idx + q_item_idx][l + k_item_idx] * query_block[q_item_idx][k] * tmp_sum[q_item_idx]; - scalar_t ttt = - buffer2[batch_idx][query_idx + q_item_idx][l + k_item_idx] * tmp_sum[q_item_idx]; + scalar_t ttt = - attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; vec_t qqq = query_block[q_item_idx][k]; axpy(ttt, qqq, &res); } @@ -868,14 +862,13 @@ __global__ void attention_backward_kernel( } if (threadIdx.x == 0) { for (int64_t k = 0; k < K / kVecSize; k++) { - // grad_q[batch_idx][query_idx][k] -= buf[k] * tmp_sum; - - //gpuAtomicAdd(&grad_q_i[k], -buf[k] * tmp_sum); //gpuAtomicAdd(&grad_q_i[k], -temp_buffer[k] * tmp_sum); //grad_q_i[k] = temp_grad_q[k] - temp_buffer[k] * tmp_sum; #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { //grad_q[batch_idx][query_idx + q_item_idx][k] = temp_grad_q[q_item_idx][k] - temp_buffer[q_item_idx][k] * tmp_sum[q_item_idx]; + //axpy(-tmp_sum[q_item_idx], temp_buffer[q_item_idx][k], &temp_grad_q[q_item_idx][k]); + //grad_q_block[q_item_idx][k] = temp_grad_q[q_item_idx][k]; grad_q_block[q_item_idx][k].x = temp_grad_q[q_item_idx][k].x - temp_buffer[q_item_idx][k].x * tmp_sum[q_item_idx]; grad_q_block[q_item_idx][k].y = temp_grad_q[q_item_idx][k].y - temp_buffer[q_item_idx][k].y * tmp_sum[q_item_idx]; grad_q_block[q_item_idx][k].z = temp_grad_q[q_item_idx][k].z - temp_buffer[q_item_idx][k].z * tmp_sum[q_item_idx]; @@ -934,9 +927,6 @@ std::tuple attention_backward( at::Tensor grad_k = at::zeros_like(key); at::Tensor grad_v = at::zeros_like(value); - at::Tensor buffer = at::empty({B, M, K}, query.options()); - at::Tensor buffer2 = at::zeros({B, M, N}, 
query.options()); - // TODO this should be an argument from the function //at::Tensor logsumexp = query.bmm(key.transpose(-2, -1)).logsumexp(-1); @@ -961,9 +951,7 @@ std::tuple attention_backward( query.packed_accessor(), key.packed_accessor(), value.packed_accessor(), - logsumexp.packed_accessor(), - buffer.packed_accessor(), - buffer2.packed_accessor() + logsumexp.packed_accessor() ); // }); From 7628805016af92970f7cb04a80b972159e064d82 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 15 Apr 2022 04:25:32 -0700 Subject: [PATCH 12/45] Smal cleanups --- .../attention/csrc/cuda/attention.cu | 74 +++---------------- 1 file changed, 11 insertions(+), 63 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 1305d25e1e..5b10c26c8c 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -628,7 +628,7 @@ at::Tensor attention( return res; } -template +template __global__ void attention_backward_kernel( at::PackedTensorAccessor grad_q, at::PackedTensorAccessor grad_k, @@ -644,31 +644,17 @@ __global__ void attention_backward_kernel( int64_t M = query.size(1); int64_t N = key.size(1); - using vec_t = float4; - //using vec_t = float; constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); - constexpr int64_t BUFFER_SIZE = 32 / kVecSize; - constexpr int64_t kBlockSizeQ = 8; - constexpr int64_t kBlockSizeK = 8; - int64_t batch_idx = blockIdx.y; - //int64_t query_idx = blockIdx.x * blockDim.y + threadIdx.y; int64_t query_idx = blockIdx.x * blockDim.y * kBlockSizeQ + threadIdx.y * kBlockSizeQ; if (query_idx >= M) return; - //scalar_t temp_grad_q[kBlockSizeK][BUFFER_SIZE] = {0}; - //scalar_t temp_grad_k[kBlockSizeK][BUFFER_SIZE] = {0}; - //scalar_t temp_grad_v[kBlockSizeK][BUFFER_SIZE] = {0}; vec_t temp_buffer[kBlockSizeQ][BUFFER_SIZE] = {0}; - //scalar_t temp_grad_q[BUFFER_SIZE] = {0}; vec_t temp_grad_q[kBlockSizeQ][BUFFER_SIZE] = {0}; - - - vec_t* query_block[kBlockSizeQ]; vec_t* grad_out_block[kBlockSizeQ]; vec_t* grad_q_block[kBlockSizeQ]; @@ -686,10 +672,6 @@ __global__ void attention_backward_kernel( normalizer[q_item_idx] = logsumexp_normalizer[batch_idx][index]; } - //auto query_i = query[batch_idx][query_idx - //auto normalizer = logsumexp_normalizer[batch_idx][query_idx]; - //auto grad_q_i = grad_q[batch_idx][query_idx]; - //scalar_t tmp_sum = 0; scalar_t tmp_sum[kBlockSizeQ] = {0}; for (int64_t l = threadIdx.x * kBlockSizeK; l < N; l += blockDim.x * kBlockSizeK) { @@ -708,8 +690,6 @@ __global__ void attention_backward_kernel( } } - //scalar_t attn_v = std::exp(si - normalizer); - // now compute grad_q and grad_k // first compute the gradient for the self-attention // after softmax @@ -718,7 +698,6 @@ __global__ void attention_backward_kernel( auto value_j = reinterpret_cast(value[batch_idx][l].data()); for (int64_t k = 0; k < K / kVecSize; k++) { -#if 1 vec_t temp_i[kBlockSizeQ]; #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { @@ -728,39 +707,14 @@ __global__ void attention_backward_kernel( #pragma unroll for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { vec_t v = value_j[k + K / kVecSize * k_item_idx]; - //vec_t tt; tt.x = 0; tt.y = 0; tt.z = 0; tt.w = 0; vec_t tt = {0}; #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { sputnik::VectorCompute::Dot(temp_i[q_item_idx], v, &grad_attn_v[q_item_idx][k_item_idx]); - 
//iMul(v[q_item_idx][k_item_idx], temp_i[q_item_idx]); axpy(attn_v[q_item_idx][k_item_idx], temp_i[q_item_idx], &tt); } myGpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k * kVecSize], tt); - //gpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k], tt); } - -#else - - scalar_t temp = grad_out[batch_idx][query_idx][k]; - // grad_v[batch_idx][l][k] += attn_v * grad_out[batch_idx][query_idx][k]; - gpuAtomicAdd( - &grad_v[batch_idx][l][k], attn_v[0][0] * temp); - - //temp_grad_v[l][k] += attn_v * grad_out[batch_idx][query_idx][k]; - - grad_attn_v[0][0] += temp * value[batch_idx][l][k]; - -#endif - /* - scalar_t temp = grad_out[batch_idx][query_idx][k]; - // grad_v[batch_idx][l][k] += attn_v * grad_out[batch_idx][query_idx][k]; - gpuAtomicAdd( - &grad_v[batch_idx][l][k], attn_v * temp); - - //temp_grad_v[l][k] += attn_v * grad_out[batch_idx][query_idx][k]; - - grad_attn_v += temp * value[batch_idx][l][k];*/ } // those are temporaries for the gradient of the softmax @@ -773,13 +727,9 @@ __global__ void attention_backward_kernel( tmp_sum[q_item_idx] += tmp[q_item_idx][k_item_idx]; } } - //scalar_t tmp = attn_v * grad_attn_v; - //tmp_sum += tmp; // grad_q is easy for (int64_t k = 0; k < K / kVecSize; k++) { - // grad_q[batch_idx][query_idx][k] += tmp * key_j[k]; - //gpuAtomicAdd(&grad_q_i[k], tmp * key_j[k]); #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { #pragma unroll @@ -794,11 +744,8 @@ __global__ void attention_backward_kernel( // but grad_k is a bit trickier for (int64_t k = 0; k < K / kVecSize; k++) { - // grad_k[batch_idx][l][k] += tmp * query_i[k]; - //gpuAtomicAdd(&grad_k[batch_idx][l][k], tmp * query_i[k]); #pragma unroll for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - //scalar_t res = 0; vec_t res = {0}; #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { @@ -806,12 +753,10 @@ __global__ void attention_backward_kernel( vec_t qqq = query_block[q_item_idx][k]; axpy(tmp[q_item_idx][k_item_idx], qqq, &res); } - //gpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k], res); myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res); } } } - //tmp_sum = warpSum(tmp_sum); #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { tmp_sum[q_item_idx] = warpSum(tmp_sum[q_item_idx]); @@ -845,15 +790,11 @@ __global__ void attention_backward_kernel( vec_t qqq = query_block[q_item_idx][k]; axpy(ttt, qqq, &res); } - //gpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k], -res); myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res); } } } for (int64_t k = 0; k < K / kVecSize; k++) { - //temp_grad_q[k] = warpSum(temp_grad_q[k]); - //temp_buffer[k] = warpSum(temp_buffer[k]); - #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { temp_grad_q[q_item_idx][k] = warpSum(temp_grad_q[q_item_idx][k]); @@ -932,18 +873,25 @@ std::tuple attention_backward( //dim3 grid(ceil_div(M, int64_t(16)), B); //dim3 block(32, 16); + using scalar_t = float; + using vec_t = float4; + //using vec_t = float; constexpr int TILE_SIZE = 16 * 2; - constexpr int kBlockSizeQ = 8; + constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); + + constexpr int64_t BUFFER_SIZE = 32 / kVecSize; + constexpr int64_t kBlockSizeQ = 8; + constexpr int64_t kBlockSizeK = 8; dim3 grid(ceil_div(M, int64_t(TILE_SIZE)), B); dim3 block(32, TILE_SIZE / kBlockSizeQ); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - using scalar_t = float; + //AT_DISPATCH_FLOATING_TYPES( 
// query.scalar_type(), "attention_backward_kernel", [&] { - attention_backward_kernel<<>>( + attention_backward_kernel<<>>( grad_q.packed_accessor(), grad_k.packed_accessor(), grad_v.packed_accessor(), From bd749c99a36bbe6fdf2befbee18cee0fb303d54d Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 15 Apr 2022 04:28:06 -0700 Subject: [PATCH 13/45] clang-format --- .../attention/csrc/cuda/attention.cu | 135 ++++++++++-------- 1 file changed, 73 insertions(+), 62 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 5b10c26c8c..9f33915bf2 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -59,7 +59,6 @@ __device__ __forceinline__ void iDiv(scalar_t x1, float* out) { out[0] /= x1; } - template __device__ __forceinline__ void axpy(scalar_t a, float4 in, float4* out) { out[0].x += a * in.x; @@ -142,7 +141,7 @@ __device__ void compute_dot( scalar_t out[kBlockSizeQ][kBlockSizeK], int64_t K) { constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); - scalar_t scale = 1.0;// / std::sqrt(scalar_t(K)); + scalar_t scale = 1.0; // / std::sqrt(scalar_t(K)); vec_t q_i[kBlockSizeQ]; for (int64_t k = 0; k < K / kVecSize; k += 1) { #pragma unroll @@ -628,7 +627,12 @@ at::Tensor attention( return res; } -template +template < + typename scalar_t, + typename vec_t, + int kBlockSizeQ, + int kBlockSizeK, + int BUFFER_SIZE> __global__ void attention_backward_kernel( at::PackedTensorAccessor grad_q, at::PackedTensorAccessor grad_k, @@ -637,8 +641,7 @@ __global__ void attention_backward_kernel( at::PackedTensorAccessor query, at::PackedTensorAccessor key, at::PackedTensorAccessor value, - at::PackedTensorAccessor logsumexp_normalizer -) { + at::PackedTensorAccessor logsumexp_normalizer) { int64_t K = query.size(2); int64_t B = query.size(0); int64_t M = query.size(1); @@ -646,11 +649,12 @@ __global__ void attention_backward_kernel( constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); - int64_t batch_idx = blockIdx.y; - int64_t query_idx = blockIdx.x * blockDim.y * kBlockSizeQ + threadIdx.y * kBlockSizeQ; + int64_t query_idx = + blockIdx.x * blockDim.y * kBlockSizeQ + threadIdx.y * kBlockSizeQ; - if (query_idx >= M) return; + if (query_idx >= M) + return; vec_t temp_buffer[kBlockSizeQ][BUFFER_SIZE] = {0}; vec_t temp_grad_q[kBlockSizeQ][BUFFER_SIZE] = {0}; @@ -673,27 +677,27 @@ __global__ void attention_backward_kernel( } scalar_t tmp_sum[kBlockSizeQ] = {0}; - for (int64_t l = threadIdx.x * kBlockSizeK; l < N; l += blockDim.x * kBlockSizeK) { - + for (int64_t l = threadIdx.x * kBlockSizeK; l < N; + l += blockDim.x * kBlockSizeK) { auto key_j = reinterpret_cast(key[batch_idx][l].data()); scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; compute_dot( query_block, key_j, attn_v, K); - #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { #pragma unroll for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - attn_v[q_item_idx][k_item_idx] = std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]); + attn_v[q_item_idx][k_item_idx] = + std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]); } } // now compute grad_q and grad_k // first compute the gradient for the self-attention // after softmax - //scalar_t grad_attn_v = 0; + // scalar_t grad_attn_v = 0; scalar_t grad_attn_v[kBlockSizeQ][kBlockSizeK] = {0}; auto value_j = reinterpret_cast(value[batch_idx][l].data()); @@ -710,7 
+714,8 @@ __global__ void attention_backward_kernel( vec_t tt = {0}; #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - sputnik::VectorCompute::Dot(temp_i[q_item_idx], v, &grad_attn_v[q_item_idx][k_item_idx]); + sputnik::VectorCompute::Dot( + temp_i[q_item_idx], v, &grad_attn_v[q_item_idx][k_item_idx]); axpy(attn_v[q_item_idx][k_item_idx], temp_i[q_item_idx], &tt); } myGpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k * kVecSize], tt); @@ -723,7 +728,8 @@ __global__ void attention_backward_kernel( for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { #pragma unroll for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - tmp[q_item_idx][k_item_idx] = attn_v[q_item_idx][k_item_idx] * grad_attn_v[q_item_idx][k_item_idx]; + tmp[q_item_idx][k_item_idx] = attn_v[q_item_idx][k_item_idx] * + grad_attn_v[q_item_idx][k_item_idx]; tmp_sum[q_item_idx] += tmp[q_item_idx][k_item_idx]; } } @@ -736,7 +742,8 @@ __global__ void attention_backward_kernel( for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { vec_t ttt = key_j[k + K / kVecSize * k_item_idx]; axpy(tmp[q_item_idx][k_item_idx], ttt, &temp_grad_q[q_item_idx][k]); - axpy(attn_v[q_item_idx][k_item_idx], ttt, &temp_buffer[q_item_idx][k]); + axpy( + attn_v[q_item_idx][k_item_idx], ttt, &temp_buffer[q_item_idx][k]); } } } @@ -749,7 +756,7 @@ __global__ void attention_backward_kernel( vec_t res = {0}; #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - //res += tmp[q_item_idx][k_item_idx] * query_block[q_item_idx][k]; + // res += tmp[q_item_idx][k_item_idx] * query_block[q_item_idx][k]; vec_t qqq = query_block[q_item_idx][k]; axpy(tmp[q_item_idx][k_item_idx], qqq, &res); } @@ -762,61 +769,62 @@ __global__ void attention_backward_kernel( tmp_sum[q_item_idx] = warpSum(tmp_sum[q_item_idx]); } - for (int64_t l = threadIdx.x * kBlockSizeK; l < N; l += blockDim.x * kBlockSizeK) { - - + for (int64_t l = threadIdx.x * kBlockSizeK; l < N; + l += blockDim.x * kBlockSizeK) { auto key_j = reinterpret_cast(key[batch_idx][l].data()); scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; compute_dot( query_block, key_j, attn_v, K); - #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { #pragma unroll for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - attn_v[q_item_idx][k_item_idx] = std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]); + attn_v[q_item_idx][k_item_idx] = + std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]); } } - for (int64_t k = 0; k < K / kVecSize; k++) { #pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - vec_t res = {0}; + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + vec_t res = {0}; #pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - scalar_t ttt = - attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; - vec_t qqq = query_block[q_item_idx][k]; - axpy(ttt, qqq, &res); - } - myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res); + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + scalar_t ttt = -attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; + vec_t qqq = query_block[q_item_idx][k]; + axpy(ttt, qqq, &res); + } + myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res); } } } for (int64_t k = 0; k < K / kVecSize; k++) { #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; 
q_item_idx++) { - temp_grad_q[q_item_idx][k] = warpSum(temp_grad_q[q_item_idx][k]); - temp_buffer[q_item_idx][k] = warpSum(temp_buffer[q_item_idx][k]); + temp_grad_q[q_item_idx][k] = + warpSum(temp_grad_q[q_item_idx][k]); + temp_buffer[q_item_idx][k] = + warpSum(temp_buffer[q_item_idx][k]); } } if (threadIdx.x == 0) { - for (int64_t k = 0; k < K / kVecSize; k++) { - //gpuAtomicAdd(&grad_q_i[k], -temp_buffer[k] * tmp_sum); - //grad_q_i[k] = temp_grad_q[k] - temp_buffer[k] * tmp_sum; + for (int64_t k = 0; k < K / kVecSize; k++) { #pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - //grad_q[batch_idx][query_idx + q_item_idx][k] = temp_grad_q[q_item_idx][k] - temp_buffer[q_item_idx][k] * tmp_sum[q_item_idx]; - //axpy(-tmp_sum[q_item_idx], temp_buffer[q_item_idx][k], &temp_grad_q[q_item_idx][k]); - //grad_q_block[q_item_idx][k] = temp_grad_q[q_item_idx][k]; - grad_q_block[q_item_idx][k].x = temp_grad_q[q_item_idx][k].x - temp_buffer[q_item_idx][k].x * tmp_sum[q_item_idx]; - grad_q_block[q_item_idx][k].y = temp_grad_q[q_item_idx][k].y - temp_buffer[q_item_idx][k].y * tmp_sum[q_item_idx]; - grad_q_block[q_item_idx][k].z = temp_grad_q[q_item_idx][k].z - temp_buffer[q_item_idx][k].z * tmp_sum[q_item_idx]; - grad_q_block[q_item_idx][k].w = temp_grad_q[q_item_idx][k].w - temp_buffer[q_item_idx][k].w * tmp_sum[q_item_idx]; + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + // axpy(-tmp_sum[q_item_idx], temp_buffer[q_item_idx][k], &temp_grad_q[q_item_idx][k]); + // grad_q_block[q_item_idx][k] = temp_grad_q[q_item_idx][k]; + grad_q_block[q_item_idx][k].x = temp_grad_q[q_item_idx][k].x - + temp_buffer[q_item_idx][k].x * tmp_sum[q_item_idx]; + grad_q_block[q_item_idx][k].y = temp_grad_q[q_item_idx][k].y - + temp_buffer[q_item_idx][k].y * tmp_sum[q_item_idx]; + grad_q_block[q_item_idx][k].z = temp_grad_q[q_item_idx][k].z - + temp_buffer[q_item_idx][k].z * tmp_sum[q_item_idx]; + grad_q_block[q_item_idx][k].w = temp_grad_q[q_item_idx][k].w - + temp_buffer[q_item_idx][k].w * tmp_sum[q_item_idx]; + } } } - } } std::tuple attention_backward( @@ -869,13 +877,13 @@ std::tuple attention_backward( at::Tensor grad_v = at::zeros_like(value); // TODO this should be an argument from the function - //at::Tensor logsumexp = query.bmm(key.transpose(-2, -1)).logsumexp(-1); + // at::Tensor logsumexp = query.bmm(key.transpose(-2, -1)).logsumexp(-1); - //dim3 grid(ceil_div(M, int64_t(16)), B); - //dim3 block(32, 16); + // dim3 grid(ceil_div(M, int64_t(16)), B); + // dim3 block(32, 16); using scalar_t = float; using vec_t = float4; - //using vec_t = float; + // using vec_t = float; constexpr int TILE_SIZE = 16 * 2; constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); @@ -888,20 +896,23 @@ std::tuple attention_backward( cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - //AT_DISPATCH_FLOATING_TYPES( + // AT_DISPATCH_FLOATING_TYPES( // query.scalar_type(), "attention_backward_kernel", [&] { - attention_backward_kernel<<>>( - grad_q.packed_accessor(), - grad_k.packed_accessor(), - grad_v.packed_accessor(), - grad_out.packed_accessor(), - query.packed_accessor(), - key.packed_accessor(), - value.packed_accessor(), - logsumexp.packed_accessor() - ); - // }); + attention_backward_kernel< + scalar_t, + vec_t, + kBlockSizeQ, + kBlockSizeK, + BUFFER_SIZE><<>>( + grad_q.packed_accessor(), + grad_k.packed_accessor(), + grad_v.packed_accessor(), + grad_out.packed_accessor(), + query.packed_accessor(), + key.packed_accessor(), + value.packed_accessor(), + 
logsumexp.packed_accessor()); + // }); AT_CUDA_CHECK(cudaGetLastError()); From 86e87f947ecd548ed6b3611094bfa770b683aba4 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 15 Apr 2022 05:06:57 -0700 Subject: [PATCH 14/45] Make it 0.5% faster Use all threads to compute grad_q --- xformers/components/attention/csrc/cuda/attention.cu | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 9f33915bf2..94cf10f380 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -808,8 +808,7 @@ __global__ void attention_backward_kernel( warpSum(temp_buffer[q_item_idx][k]); } } - if (threadIdx.x == 0) { - for (int64_t k = 0; k < K / kVecSize; k++) { + for (int64_t k = threadIdx.x; k < K / kVecSize; k += blockDim.x) { #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { // axpy(-tmp_sum[q_item_idx], temp_buffer[q_item_idx][k], &temp_grad_q[q_item_idx][k]); @@ -824,7 +823,6 @@ __global__ void attention_backward_kernel( temp_buffer[q_item_idx][k].w * tmp_sum[q_item_idx]; } } - } } std::tuple attention_backward( @@ -876,14 +874,10 @@ std::tuple attention_backward( at::Tensor grad_k = at::zeros_like(key); at::Tensor grad_v = at::zeros_like(value); - // TODO this should be an argument from the function - // at::Tensor logsumexp = query.bmm(key.transpose(-2, -1)).logsumexp(-1); - - // dim3 grid(ceil_div(M, int64_t(16)), B); - // dim3 block(32, 16); using scalar_t = float; using vec_t = float4; // using vec_t = float; + constexpr int TILE_SIZE = 16 * 2; constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); From d3e2140eeefa620483453206c6e65b2534c3f272 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 15 Apr 2022 06:26:31 -0700 Subject: [PATCH 15/45] Make it 1% faster by caching the loads --- .../components/attention/csrc/cuda/attention.cu | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 94cf10f380..c6cf62e4d0 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -664,6 +664,8 @@ __global__ void attention_backward_kernel( vec_t* grad_q_block[kBlockSizeQ]; scalar_t normalizer[kBlockSizeQ]; + //__shared__ vec_t query_cache[kBlockSizeQ][BUFFER_SIZE]; + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { int64_t index = query_idx + q_item_idx; index = index >= M ? 
M - 1 : index; @@ -674,7 +676,11 @@ __global__ void attention_backward_kernel( grad_q_block[q_item_idx] = reinterpret_cast(grad_q[batch_idx][index].data()); normalizer[q_item_idx] = logsumexp_normalizer[batch_idx][index]; + //for (int64_t k = threadIdx.x; k < K / kVecSize; k += blockDim.x) { + // query_cache[q_item_idx][k] = query_block[q_item_idx][k]; + // } } + //__syncthreads(); scalar_t tmp_sum[kBlockSizeQ] = {0}; for (int64_t l = threadIdx.x * kBlockSizeK; l < N; @@ -683,6 +689,7 @@ __global__ void attention_backward_kernel( scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; compute_dot( + //query_cache, key_j, attn_v, K); query_block, key_j, attn_v, K); #pragma unroll @@ -757,7 +764,8 @@ __global__ void attention_backward_kernel( #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { // res += tmp[q_item_idx][k_item_idx] * query_block[q_item_idx][k]; - vec_t qqq = query_block[q_item_idx][k]; + vec_t qqq = __ldg(query_block[q_item_idx] + k); + //vec_t qqq = query_cache[q_item_idx][k]; axpy(tmp[q_item_idx][k_item_idx], qqq, &res); } myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res); @@ -792,7 +800,8 @@ __global__ void attention_backward_kernel( #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { scalar_t ttt = -attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; - vec_t qqq = query_block[q_item_idx][k]; + vec_t qqq = __ldg(query_block[q_item_idx] + k); + //vec_t qqq = query_cache[q_item_idx][k]; axpy(ttt, qqq, &res); } myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res); From 6cc57688a58b3b0738a9f5c629710809af5e1f56 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 15 Apr 2022 06:41:31 -0700 Subject: [PATCH 16/45] Make it 6% faster with better hyperparameters --- xformers/components/attention/csrc/cuda/attention.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index c6cf62e4d0..dc5d2ea60f 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -887,12 +887,12 @@ std::tuple attention_backward( using vec_t = float4; // using vec_t = float; - constexpr int TILE_SIZE = 16 * 2; + constexpr int TILE_SIZE = 16 * 4; constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); constexpr int64_t BUFFER_SIZE = 32 / kVecSize; - constexpr int64_t kBlockSizeQ = 8; - constexpr int64_t kBlockSizeK = 8; + constexpr int64_t kBlockSizeQ = 16; + constexpr int64_t kBlockSizeK = 4; dim3 grid(ceil_div(M, int64_t(TILE_SIZE)), B); dim3 block(32, TILE_SIZE / kBlockSizeQ); From 8d493a7ea0cbe7ce7311b82f346d428455cf6055 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 15 Apr 2022 07:07:25 -0700 Subject: [PATCH 17/45] Slightly better hyperparameter --- xformers/components/attention/csrc/cuda/attention.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index dc5d2ea60f..82472bd824 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -887,7 +887,7 @@ std::tuple attention_backward( using vec_t = float4; // using vec_t = float; - constexpr int TILE_SIZE = 16 * 4; + constexpr int TILE_SIZE = 16 * 8; constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); constexpr int64_t BUFFER_SIZE = 32 / kVecSize; 
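
Note on the two hyperparameter patches above (PATCH 16 and 17): they only change TILE_SIZE, kBlockSizeQ and kBlockSizeK; the launch shape follows from those constants through the existing code in attention_backward, i.e. grid(ceil_div(M, int64_t(TILE_SIZE)), B) and block(32, TILE_SIZE / kBlockSizeQ). The standalone C++ sketch below just recomputes that launch arithmetic so the effect of the tuning is easy to see; the M and B values are illustrative assumptions, not taken from the patch series.

// Illustrative sketch only: mirrors the grid/block computation used by the
// backward kernel launch in these patches. M and B are assumed example sizes.
#include <cstdint>
#include <cstdio>

static int64_t ceil_div(int64_t n, int64_t m) {
  return (n + m - 1) / m;
}

int main() {
  const int64_t B = 32;          // assumed batch size (example value)
  const int64_t M = 1024;        // assumed number of queries (example value)
  const int TILE_SIZE = 16 * 8;  // queries handled per block, after PATCH 17
  const int kBlockSizeQ = 16;    // queries per thread along blockDim.y, after PATCH 16

  // Mirrors: dim3 grid(ceil_div(M, int64_t(TILE_SIZE)), B);
  //          dim3 block(32, TILE_SIZE / kBlockSizeQ);
  const int64_t grid_x = ceil_div(M, (int64_t)TILE_SIZE);
  const int64_t grid_y = B;
  const int block_x = 32;                       // one warp along x
  const int block_y = TILE_SIZE / kBlockSizeQ;  // 128 / 16 = 8 warp rows per block

  std::printf("grid=(%lld,%lld) block=(%d,%d) -> %lld warps total\n",
              (long long)grid_x, (long long)grid_y, block_x, block_y,
              (long long)(grid_x * grid_y * block_y));
  return 0;
}

With TILE_SIZE = 128 and kBlockSizeQ = 16 this gives 256 threads (8 warps) per block; one way to read PATCH 16/17 is as trading work per thread (kBlockSizeQ, kBlockSizeK) against warps per block.
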
From 40b9f4308d065447afc2c1931e26601222342bc8 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sun, 17 Apr 2022 03:04:49 -0700 Subject: [PATCH 18/45] axpy == FMA --- .../attention/csrc/cuda/attention.cu | 31 ++++--------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 82472bd824..9c3a7974f8 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -59,25 +59,6 @@ __device__ __forceinline__ void iDiv(scalar_t x1, float* out) { out[0] /= x1; } -template -__device__ __forceinline__ void axpy(scalar_t a, float4 in, float4* out) { - out[0].x += a * in.x; - out[0].y += a * in.y; - out[0].z += a * in.z; - out[0].w += a * in.w; -} - -template -__device__ __forceinline__ void axpy(scalar_t a, float2 in, float2* out) { - out[0].x += a * in.x; - out[0].y += a * in.y; -} - -template -__device__ __forceinline__ void axpy(scalar_t a, float in, float* out) { - out[0] += a * in; -} - template __device__ __forceinline__ void myGpuAtomicAdd(scalar_t* address, float4 val) { gpuAtomicAdd(address + 0, val.x); @@ -723,7 +704,7 @@ __global__ void attention_backward_kernel( for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { sputnik::VectorCompute::Dot( temp_i[q_item_idx], v, &grad_attn_v[q_item_idx][k_item_idx]); - axpy(attn_v[q_item_idx][k_item_idx], temp_i[q_item_idx], &tt); + sputnik::VectorCompute::FMA(attn_v[q_item_idx][k_item_idx], temp_i[q_item_idx], &tt); } myGpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k * kVecSize], tt); } @@ -748,8 +729,8 @@ __global__ void attention_backward_kernel( #pragma unroll for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { vec_t ttt = key_j[k + K / kVecSize * k_item_idx]; - axpy(tmp[q_item_idx][k_item_idx], ttt, &temp_grad_q[q_item_idx][k]); - axpy( + sputnik::VectorCompute::FMA(tmp[q_item_idx][k_item_idx], ttt, &temp_grad_q[q_item_idx][k]); + sputnik::VectorCompute::FMA( attn_v[q_item_idx][k_item_idx], ttt, &temp_buffer[q_item_idx][k]); } } @@ -766,7 +747,7 @@ __global__ void attention_backward_kernel( // res += tmp[q_item_idx][k_item_idx] * query_block[q_item_idx][k]; vec_t qqq = __ldg(query_block[q_item_idx] + k); //vec_t qqq = query_cache[q_item_idx][k]; - axpy(tmp[q_item_idx][k_item_idx], qqq, &res); + sputnik::VectorCompute::FMA(tmp[q_item_idx][k_item_idx], qqq, &res); } myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res); } @@ -802,7 +783,7 @@ __global__ void attention_backward_kernel( scalar_t ttt = -attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; vec_t qqq = __ldg(query_block[q_item_idx] + k); //vec_t qqq = query_cache[q_item_idx][k]; - axpy(ttt, qqq, &res); + sputnik::VectorCompute::FMA(ttt, qqq, &res); } myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res); } @@ -820,7 +801,7 @@ __global__ void attention_backward_kernel( for (int64_t k = threadIdx.x; k < K / kVecSize; k += blockDim.x) { #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - // axpy(-tmp_sum[q_item_idx], temp_buffer[q_item_idx][k], &temp_grad_q[q_item_idx][k]); + // sputnik::VectorCompute::FMA(-tmp_sum[q_item_idx], temp_buffer[q_item_idx][k], &temp_grad_q[q_item_idx][k]); // grad_q_block[q_item_idx][k] = temp_grad_q[q_item_idx][k]; grad_q_block[q_item_idx][k].x = temp_grad_q[q_item_idx][k].x - temp_buffer[q_item_idx][k].x * tmp_sum[q_item_idx]; From 
ca2bb7247b8f3228d04ad6553a7bf118dda8fbb6 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 19 Apr 2022 04:21:14 -0700 Subject: [PATCH 19/45] Separate grad_q into its own kernel This brings 50% speedup compared to the previous approach, despite redundant computation. The benefit comes from the fact that we are using better block sizes for the matmul computation of grad_q, which doesnt involve the transpose of the attention matrix --- .../attention/csrc/cuda/attention.cu | 163 +++++++++++++++++- 1 file changed, 155 insertions(+), 8 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 9c3a7974f8..f0bac81d04 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -622,6 +622,7 @@ __global__ void attention_backward_kernel( at::PackedTensorAccessor query, at::PackedTensorAccessor key, at::PackedTensorAccessor value, + at::PackedTensorAccessor tmp_sum_i, at::PackedTensorAccessor logsumexp_normalizer) { int64_t K = query.size(2); int64_t B = query.size(0); @@ -637,12 +638,12 @@ __global__ void attention_backward_kernel( if (query_idx >= M) return; - vec_t temp_buffer[kBlockSizeQ][BUFFER_SIZE] = {0}; - vec_t temp_grad_q[kBlockSizeQ][BUFFER_SIZE] = {0}; + //vec_t temp_buffer[kBlockSizeQ][BUFFER_SIZE] = {0}; + //vec_t temp_grad_q[kBlockSizeQ][BUFFER_SIZE] = {0}; vec_t* query_block[kBlockSizeQ]; vec_t* grad_out_block[kBlockSizeQ]; - vec_t* grad_q_block[kBlockSizeQ]; + //vec_t* grad_q_block[kBlockSizeQ]; scalar_t normalizer[kBlockSizeQ]; //__shared__ vec_t query_cache[kBlockSizeQ][BUFFER_SIZE]; @@ -654,8 +655,8 @@ __global__ void attention_backward_kernel( reinterpret_cast(query[batch_idx][index].data()); grad_out_block[q_item_idx] = reinterpret_cast(grad_out[batch_idx][index].data()); - grad_q_block[q_item_idx] = - reinterpret_cast(grad_q[batch_idx][index].data()); + //grad_q_block[q_item_idx] = + // reinterpret_cast(grad_q[batch_idx][index].data()); normalizer[q_item_idx] = logsumexp_normalizer[batch_idx][index]; //for (int64_t k = threadIdx.x; k < K / kVecSize; k += blockDim.x) { // query_cache[q_item_idx][k] = query_block[q_item_idx][k]; @@ -721,7 +722,7 @@ __global__ void attention_backward_kernel( tmp_sum[q_item_idx] += tmp[q_item_idx][k_item_idx]; } } - +/* // grad_q is easy for (int64_t k = 0; k < K / kVecSize; k++) { #pragma unroll @@ -735,7 +736,7 @@ __global__ void attention_backward_kernel( } } } - +*/ // but grad_k is a bit trickier for (int64_t k = 0; k < K / kVecSize; k++) { @@ -756,6 +757,7 @@ __global__ void attention_backward_kernel( #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { tmp_sum[q_item_idx] = warpSum(tmp_sum[q_item_idx]); + tmp_sum_i[batch_idx][query_idx + q_item_idx] = tmp_sum[q_item_idx]; } for (int64_t l = threadIdx.x * kBlockSizeK; l < N; @@ -789,6 +791,7 @@ __global__ void attention_backward_kernel( } } } + /* for (int64_t k = 0; k < K / kVecSize; k++) { #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { @@ -812,9 +815,123 @@ __global__ void attention_backward_kernel( grad_q_block[q_item_idx][k].w = temp_grad_q[q_item_idx][k].w - temp_buffer[q_item_idx][k].w * tmp_sum[q_item_idx]; } + }*/ +} + + +template < + typename scalar_t, + typename vec_t, + int kBlockSizeQ, + int kBlockSizeK, + int BUFFER_SIZE> +__global__ void attention_backward_kernel2( + at::PackedTensorAccessor grad_q, + at::PackedTensorAccessor grad_k, + 
at::PackedTensorAccessor grad_v, + at::PackedTensorAccessor grad_out, + at::PackedTensorAccessor query, + at::PackedTensorAccessor key, + at::PackedTensorAccessor value, + at::PackedTensorAccessor tmp_sum_i, + at::PackedTensorAccessor logsumexp_normalizer) { + int64_t K = query.size(2); + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + + constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); + + int64_t batch_idx = blockIdx.y; + int64_t query_idx = + blockIdx.x * blockDim.y * kBlockSizeQ + threadIdx.y * kBlockSizeQ; + + if (query_idx >= M) + return; + + vec_t* query_block[kBlockSizeQ]; + vec_t* grad_out_block[kBlockSizeQ]; + scalar_t normalizer[kBlockSizeQ]; + scalar_t tmp_sum[kBlockSizeQ]; + + //__shared__ vec_t query_cache[kBlockSizeQ][BUFFER_SIZE]; + + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + int64_t index = query_idx + q_item_idx; + index = index >= M ? M - 1 : index; + query_block[q_item_idx] = + reinterpret_cast(query[batch_idx][index].data()); + grad_out_block[q_item_idx] = + reinterpret_cast(grad_out[batch_idx][index].data()); + normalizer[q_item_idx] = logsumexp_normalizer[batch_idx][index]; + tmp_sum[q_item_idx] = tmp_sum_i[batch_idx][index]; + //for (int64_t k = threadIdx.x; k < K / kVecSize; k += blockDim.x) { + // query_cache[q_item_idx][k] = query_block[q_item_idx][k]; + // } + } + //__syncthreads(); + + //scalar_t tmp_sum[kBlockSizeQ] = {0}; + + for (int64_t l = threadIdx.x * kBlockSizeK; l < N; + l += blockDim.x * kBlockSizeK) { + auto key_j = reinterpret_cast(key[batch_idx][l].data()); + scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; + compute_dot( + query_block, key_j, attn_v, K); + +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { +#pragma unroll + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + attn_v[q_item_idx][k_item_idx] = + std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]); + } + } + + scalar_t grad_attn_v[kBlockSizeQ][kBlockSizeK] = {0}; + auto value_j = reinterpret_cast(value[batch_idx][l].data()); + + for (int64_t k = 0; k < K / kVecSize; k++) { + vec_t temp_i[kBlockSizeQ]; +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + temp_i[q_item_idx] = __ldg(grad_out_block[q_item_idx] + k); + } + +#pragma unroll + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + vec_t v = value_j[k + K / kVecSize * k_item_idx]; +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + sputnik::VectorCompute::Dot( + temp_i[q_item_idx], v, &grad_attn_v[q_item_idx][k_item_idx]); + } + } + } + + for (int64_t k = 0; k < K / kVecSize; k++) { + vec_t res[kBlockSizeQ] = {0}; +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { +#pragma unroll + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + vec_t ttt = key_j[k + K / kVecSize * k_item_idx]; + scalar_t ttmp = attn_v[q_item_idx][k_item_idx] * grad_attn_v[q_item_idx][k_item_idx] - attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; + sputnik::VectorCompute::FMA(ttmp, ttt, &res[q_item_idx]); + } + } +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + myGpuAtomicAdd(&grad_q[batch_idx][query_idx + q_item_idx][k * kVecSize], res[q_item_idx]); + + } } + + } } + std::tuple attention_backward( const at::Tensor& grad_out, const at::Tensor& query, @@ -859,11 +976,13 @@ std::tuple attention_backward( int64_t N 
= key.size(1); int64_t K = query.size(2); - at::Tensor res = at::empty({B, M, K}, query.options()); + //at::Tensor res = at::empty({B, M, K}, query.options()); at::Tensor grad_q = at::zeros_like(query); at::Tensor grad_k = at::zeros_like(key); at::Tensor grad_v = at::zeros_like(value); + at::Tensor tmp_sum_i = at::empty({B, M}, query.options()); + using scalar_t = float; using vec_t = float4; // using vec_t = float; @@ -895,9 +1014,37 @@ std::tuple attention_backward( query.packed_accessor(), key.packed_accessor(), value.packed_accessor(), + tmp_sum_i.packed_accessor(), logsumexp.packed_accessor()); // }); + + constexpr int TILE_SIZE2 = 32; + + constexpr int64_t kBlockSizeQ2 = 2; + constexpr int64_t kBlockSizeK2 = 32; + + dim3 grid2(ceil_div(M, int64_t(TILE_SIZE2)), B); + dim3 block2(4, TILE_SIZE2 / kBlockSizeQ2); + + attention_backward_kernel2< + scalar_t, + vec_t, + kBlockSizeQ2, + kBlockSizeK2, + BUFFER_SIZE><<>>( + grad_q.packed_accessor(), + grad_k.packed_accessor(), + grad_v.packed_accessor(), + grad_out.packed_accessor(), + query.packed_accessor(), + key.packed_accessor(), + value.packed_accessor(), + tmp_sum_i.packed_accessor(), + logsumexp.packed_accessor()); + // }); + + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(grad_q, grad_k, grad_v); From 63bd28696c99fd4d717d5209cb83c99afb57dec5 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 19 Apr 2022 07:33:29 -0700 Subject: [PATCH 20/45] Avoid additional global writes by recomputing grad_aatn_v in grad_k Brings an additional 12% speedup despite duplicate computation --- .../attention/csrc/cuda/attention.cu | 50 ++++++++++++++++--- 1 file changed, 43 insertions(+), 7 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index f0bac81d04..6c898bfacf 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -712,14 +712,15 @@ __global__ void attention_backward_kernel( } // those are temporaries for the gradient of the softmax - scalar_t tmp[kBlockSizeQ][kBlockSizeK]; + //scalar_t tmp[kBlockSizeQ][kBlockSizeK]; #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { #pragma unroll for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - tmp[q_item_idx][k_item_idx] = attn_v[q_item_idx][k_item_idx] * - grad_attn_v[q_item_idx][k_item_idx]; - tmp_sum[q_item_idx] += tmp[q_item_idx][k_item_idx]; + //tmp[q_item_idx][k_item_idx] = attn_v[q_item_idx][k_item_idx] * + // grad_attn_v[q_item_idx][k_item_idx]; + //tmp_sum[q_item_idx] += tmp[q_item_idx][k_item_idx]; + tmp_sum[q_item_idx] += attn_v[q_item_idx][k_item_idx] * grad_attn_v[q_item_idx][k_item_idx]; } } /* @@ -738,7 +739,7 @@ __global__ void attention_backward_kernel( } */ // but grad_k is a bit trickier - +/* for (int64_t k = 0; k < K / kVecSize; k++) { #pragma unroll for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { @@ -752,7 +753,7 @@ __global__ void attention_backward_kernel( } myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res); } - } + }*/ } #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { @@ -776,13 +777,48 @@ __global__ void attention_backward_kernel( } } + scalar_t grad_attn_v[kBlockSizeQ][kBlockSizeK] = {0}; + auto value_j = reinterpret_cast(value[batch_idx][l].data()); + + for (int64_t k = 0; k < K / kVecSize; k++) { + vec_t temp_i[kBlockSizeQ]; +#pragma unroll + for (int64_t q_item_idx = 0; 
q_item_idx < kBlockSizeQ; q_item_idx++) { + temp_i[q_item_idx] = __ldg(grad_out_block[q_item_idx] + k); + } + +#pragma unroll + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + vec_t v = value_j[k + K / kVecSize * k_item_idx]; + vec_t tt = {0}; +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + sputnik::VectorCompute::Dot( + temp_i[q_item_idx], v, &grad_attn_v[q_item_idx][k_item_idx]); + } + } + } + + // those are temporaries for the gradient of the softmax + scalar_t tmp[kBlockSizeQ][kBlockSizeK]; +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { +#pragma unroll + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + tmp[q_item_idx][k_item_idx] = attn_v[q_item_idx][k_item_idx] * + grad_attn_v[q_item_idx][k_item_idx]; + } + } + + for (int64_t k = 0; k < K / kVecSize; k++) { #pragma unroll for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { vec_t res = {0}; #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - scalar_t ttt = -attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; + //scalar_t ttt = -attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; + scalar_t ttt = tmp[q_item_idx][k_item_idx] - attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; vec_t qqq = __ldg(query_block[q_item_idx] + k); //vec_t qqq = query_cache[q_item_idx][k]; sputnik::VectorCompute::FMA(ttt, qqq, &res); From f1e7c7cfb227bb014e066d62da53bcc3c21191f0 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 20 Apr 2022 00:58:15 -0700 Subject: [PATCH 21/45] Trying out new idea --- .../attention/csrc/cuda/attention.cu | 239 +++++++++++++++++- 1 file changed, 225 insertions(+), 14 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 6c898bfacf..55daf1c8cf 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -760,7 +760,7 @@ __global__ void attention_backward_kernel( tmp_sum[q_item_idx] = warpSum(tmp_sum[q_item_idx]); tmp_sum_i[batch_idx][query_idx + q_item_idx] = tmp_sum[q_item_idx]; } - +/* for (int64_t l = threadIdx.x * kBlockSizeK; l < N; l += blockDim.x * kBlockSizeK) { auto key_j = reinterpret_cast(key[batch_idx][l].data()); @@ -826,7 +826,7 @@ __global__ void attention_backward_kernel( myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res); } } - } + }*/ /* for (int64_t k = 0; k < K / kVecSize; k++) { #pragma unroll @@ -861,7 +861,7 @@ template < int kBlockSizeQ, int kBlockSizeK, int BUFFER_SIZE> -__global__ void attention_backward_kernel2( +__global__ void attention_backward_kernel3( at::PackedTensorAccessor grad_q, at::PackedTensorAccessor grad_k, at::PackedTensorAccessor grad_v, @@ -878,13 +878,19 @@ __global__ void attention_backward_kernel2( constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); - int64_t batch_idx = blockIdx.y; + int64_t batch_idx = blockIdx.z; int64_t query_idx = - blockIdx.x * blockDim.y * kBlockSizeQ + threadIdx.y * kBlockSizeQ; + blockIdx.x * blockDim.x * kBlockSizeQ + threadIdx.x * kBlockSizeQ; + + int64_t l = + blockIdx.y * blockDim.y * kBlockSizeK + threadIdx.y * kBlockSizeK; if (query_idx >= M) return; + if (l >= N) + return; + vec_t* query_block[kBlockSizeQ]; vec_t* grad_out_block[kBlockSizeQ]; scalar_t normalizer[kBlockSizeQ]; @@ -909,12 +915,31 @@ __global__ void attention_backward_kernel2( //scalar_t tmp_sum[kBlockSizeQ] = 
{0}; - for (int64_t l = threadIdx.x * kBlockSizeK; l < N; - l += blockDim.x * kBlockSizeK) { + //for (int64_t l = threadIdx.x * kBlockSizeK; l < N; + // l += blockDim.x * kBlockSizeK) + { auto key_j = reinterpret_cast(key[batch_idx][l].data()); scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; - compute_dot( - query_block, key_j, attn_v, K); + //compute_dot( + // query_block, key_j, attn_v, K); + + vec_t q_i[kBlockSizeQ]; + for (int64_t k = 0; k < K / kVecSize; k += 1) { +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + q_i[q_item_idx] = __ldg(query_block[q_item_idx] + k); + } +#pragma unroll + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + vec_t k_i = key_j[k + K / kVecSize * k_item_idx]; +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + sputnik::VectorCompute::Dot( + q_i[q_item_idx], k_i, &attn_v[q_item_idx][k_item_idx]); + } + } + } + #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { @@ -964,10 +989,191 @@ __global__ void attention_backward_kernel2( } } + + for (int64_t k = 0; k < K / kVecSize; k++) { +#pragma unroll + for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + vec_t res = {0}; +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + //scalar_t ttt = -attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; + scalar_t ttt = attn_v[q_item_idx][k_item_idx] * grad_attn_v[q_item_idx][k_item_idx] - attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; + vec_t qqq = __ldg(query_block[q_item_idx] + k); + //vec_t qqq = query_cache[q_item_idx][k]; + sputnik::VectorCompute::FMA(ttt, qqq, &res); + } + myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res); + } + } + } + +} + + +template < + typename scalar_t, + typename vec_t, + int kBlockSizeQ, + int kBlockSizeK, + int BUFFER_SIZE> +__global__ void attention_backward_kernel2( + at::PackedTensorAccessor grad_q, + at::PackedTensorAccessor grad_k, + at::PackedTensorAccessor grad_v, + at::PackedTensorAccessor grad_out, + at::PackedTensorAccessor query, + at::PackedTensorAccessor key, + at::PackedTensorAccessor value, + at::PackedTensorAccessor tmp_sum_i, + at::PackedTensorAccessor logsumexp_normalizer) { + int64_t K = query.size(2); + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + + constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); + + int64_t batch_idx = blockIdx.z; + int64_t query_idx = + blockIdx.x * blockDim.x + threadIdx.x; + + int64_t l = + blockIdx.y * blockDim.y + threadIdx.y; + + if (query_idx >= M) + return; + + if (l >= N) + return; + + scalar_t normalizer; + scalar_t tmp_sum; + + constexpr int KS1 = 16; + constexpr int KS2 = 16; + + __shared__ vec_t query_cache[KS1][BUFFER_SIZE]; + __shared__ vec_t key_cache[KS2][BUFFER_SIZE]; + //__shared__ vec_t value_cache[KS2][BUFFER_SIZE]; + //__shared__ vec_t grad_out_cache[KS1][BUFFER_SIZE]; + __shared__ scalar_t fact[KS1][KS2]; + + auto qb = reinterpret_cast(query[batch_idx][query_idx].data()); + auto kb = reinterpret_cast(key[batch_idx][l].data()); + auto vb = reinterpret_cast(value[batch_idx][l].data()); + auto gb = reinterpret_cast(grad_out[batch_idx][query_idx].data()); + + //__shared__ vec_t tmp_grad1[KS1][BUFFER_SIZE]; + //__shared__ vec_t tmp_grad2[KS2][BUFFER_SIZE]; + //vec_t query_cache[KS1][BUFFER_SIZE]; + //vec_t key_cache[KS2][BUFFER_SIZE]; + + //vec_t tmp_grad1[KS1][BUFFER_SIZE] = {0}; + //vec_t tmp_grad2[KS2][BUFFER_SIZE] 
= {0}; + + vec_t zero = {0}; + for (int64_t k = 0; k < K / kVecSize; k++) { + query_cache[threadIdx.x][k] = qb[k]; + key_cache[threadIdx.y][k] = kb[k]; + //value_cache[threadIdx.y][k] = vb[k]; + //grad_out_cache[threadIdx.x][k] = gb[k]; + //tmp_grad1[threadIdx.x][k] = zero; + //tmp_grad2[threadIdx.y][k] = zero; + } + //__syncwarp(); + //__syncthreads(); + + + normalizer = logsumexp_normalizer[batch_idx][query_idx]; + tmp_sum = tmp_sum_i[batch_idx][query_idx]; + + auto key_j = reinterpret_cast(key[batch_idx][l].data()); + scalar_t attn_v = 0; + scalar_t grad_attn_v = 0; + + for (int64_t k = 0; k < K / kVecSize; k += 1) { + //sputnik::VectorCompute::Dot(query_block[k], key_j[k], &attn_v); + //sputnik::VectorCompute::Dot(__ldg(qb + k), __ldg(kb + k), &attn_v); + sputnik::VectorCompute::Dot(query_cache[threadIdx.x][k], key_cache[threadIdx.y][k], &attn_v); + sputnik::VectorCompute::Dot(__ldg(gb + k), __ldg(vb + k), &grad_attn_v); + } + attn_v = std::exp(attn_v - normalizer); + + + /* + for (int64_t k = 0; k < K / kVecSize; k++) { + sputnik::VectorCompute::Dot(__ldg(gb + k), __ldg(vb + k), &grad_attn_v); + //sputnik::VectorCompute::Dot(grad_out_cache[threadIdx.x][k], value_cache[threadIdx.y][k], &grad_attn_v); + }*/ + + fact[threadIdx.x][threadIdx.y] = attn_v * grad_attn_v - attn_v * tmp_sum; + __syncthreads(); + + + + for (int64_t k = threadIdx.y; k < K / kVecSize; k+= blockDim.y) { + vec_t res = {0}; + for (int64_t i = 0; i < KS2; i++) { + sputnik::VectorCompute::FMA(fact[threadIdx.x][i], key_cache[i][k], &res); + //sputnik::VectorCompute::FMA(fact[threadIdx.x][i], __ldg(kb + i * K / kVecSize + k), &res); + } + //if (threadIdx.y == 0) + myGpuAtomicAdd(&grad_q[batch_idx][query_idx][k * kVecSize], res); + } + + for (int64_t k = threadIdx.x; k < K / kVecSize; k+= blockDim.x) { + vec_t res = {0}; + for (int64_t i = 0; i < KS1; i++) { + sputnik::VectorCompute::FMA(fact[i][threadIdx.y], query_cache[i][k], &res); + //sputnik::VectorCompute::FMA(fact[i][threadIdx.y], __ldg(qb + i * K / kVecSize + k), &res); + } + //if (threadIdx.x == 0) + myGpuAtomicAdd(&grad_k[batch_idx][l][k * kVecSize], res); + } +/* + for (int64_t k = 0; k < K / kVecSize; k++) { + vec_t res = {0}; + scalar_t ttmp = attn_v * grad_attn_v - attn_v * tmp_sum; + //sputnik::VectorCompute::FMA(ttmp, key_j[k], &res); + sputnik::VectorCompute::FMA(ttmp, key_cache[threadIdx.y][k], &res); + //sputnik::VectorCompute::FMA(ttmp, key_cache[threadIdx.y][k], &tmp_grad1[threadIdx.x][k]); + //myGpuAtomicAdd(&grad_q[batch_idx][query_idx][k * kVecSize], res); } + + + for (int64_t k = 0; k < K / kVecSize; k++) { + //vec_t res = {0}; + scalar_t ttmp = attn_v * grad_attn_v - attn_v * tmp_sum; + //sputnik::VectorCompute::FMA(ttmp, query_block[k], &res); + sputnik::VectorCompute::FMA(ttmp, query_cache[threadIdx.x][k], &res); + //sputnik::VectorCompute::FMA(ttmp, query_cache[threadIdx.x][k], &tmp_grad2[threadIdx.y][k]); + //myGpuAtomicAdd(&grad_k[batch_idx][l][k * kVecSize], res); + } +*/ +/* + if ((threadIdx.x == 0) && (threadIdx.y == 0)) { + for (int64_t k = 0; k < K / kVecSize; k++) { + vec_t res0 = {0}; + for (int i = 0; i < KS1; i++) + { + sputnik::VectorCompute::FMA(scalar_t(1), tmp_grad1[i][k], &res0); + } + myGpuAtomicAdd(&grad_q[batch_idx][query_idx][k * kVecSize], res0); + + vec_t res1 = {0}; + for (int i = 0; i < KS2; i++) + { + sputnik::VectorCompute::FMA(scalar_t(1), tmp_grad2[i][k], &res1); + } + myGpuAtomicAdd(&grad_k[batch_idx][l][k * kVecSize], res1); + } + }*/ + + } + std::tuple attention_backward( const at::Tensor& grad_out, const 
at::Tensor& query, @@ -1055,13 +1261,18 @@ std::tuple attention_backward( // }); - constexpr int TILE_SIZE2 = 32; + constexpr int TILE_SIZE2Q = 32 * 4; + constexpr int TILE_SIZE2K = 32 * 8; + + constexpr int64_t kBlockSizeQ2 = 8;//2; + constexpr int64_t kBlockSizeK2 = 16;//32; - constexpr int64_t kBlockSizeQ2 = 2; - constexpr int64_t kBlockSizeK2 = 32; + //dim3 grid2(ceil_div(M, int64_t(TILE_SIZE2Q)), ceil_div(N, int64_t(TILE_SIZE2K)), B); + //dim3 block2(TILE_SIZE2Q / kBlockSizeQ2, TILE_SIZE2K / kBlockSizeK2); - dim3 grid2(ceil_div(M, int64_t(TILE_SIZE2)), B); - dim3 block2(4, TILE_SIZE2 / kBlockSizeQ2); + dim3 grid2(ceil_div(M, int64_t(16)), ceil_div(N, int64_t(16)), B); + dim3 block2(16, 16); + // TODO: try adding a blockDim.x to iterate over k attention_backward_kernel2< scalar_t, From b6b0cfc462c9ee8f05b007162b119374ad80ea21 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 20 Apr 2022 02:08:38 -0700 Subject: [PATCH 22/45] Almost on par with my previous best implementation --- .../attention/csrc/cuda/attention.cu | 143 ++++++++---------- 1 file changed, 64 insertions(+), 79 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 55daf1c8cf..0c51c88b35 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -1032,13 +1032,15 @@ __global__ void attention_backward_kernel2( int64_t N = key.size(1); constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); + constexpr int BLOCK = 4; + constexpr int BLOCK2 = 4; int64_t batch_idx = blockIdx.z; int64_t query_idx = - blockIdx.x * blockDim.x + threadIdx.x; + blockIdx.x * blockDim.x * BLOCK + threadIdx.x * BLOCK; int64_t l = - blockIdx.y * blockDim.y + threadIdx.y; + blockIdx.y * blockDim.y * BLOCK2 + threadIdx.y * BLOCK2; if (query_idx >= M) return; @@ -1046,8 +1048,8 @@ __global__ void attention_backward_kernel2( if (l >= N) return; - scalar_t normalizer; - scalar_t tmp_sum; + scalar_t normalizer[BLOCK]; + scalar_t tmp_sum[BLOCK]; constexpr int KS1 = 16; constexpr int KS2 = 16; @@ -1068,107 +1070,90 @@ __global__ void attention_backward_kernel2( //vec_t query_cache[KS1][BUFFER_SIZE]; //vec_t key_cache[KS2][BUFFER_SIZE]; - //vec_t tmp_grad1[KS1][BUFFER_SIZE] = {0}; - //vec_t tmp_grad2[KS2][BUFFER_SIZE] = {0}; - vec_t zero = {0}; for (int64_t k = 0; k < K / kVecSize; k++) { - query_cache[threadIdx.x][k] = qb[k]; - key_cache[threadIdx.y][k] = kb[k]; + for (int i = 0; i < BLOCK; i++) + query_cache[BLOCK * threadIdx.x + i][k] = qb[k + K / kVecSize * i]; + for (int i = 0; i < BLOCK; i++) + key_cache[BLOCK2 * threadIdx.y + i][k] = kb[k + K / kVecSize * i]; //value_cache[threadIdx.y][k] = vb[k]; //grad_out_cache[threadIdx.x][k] = gb[k]; - //tmp_grad1[threadIdx.x][k] = zero; - //tmp_grad2[threadIdx.y][k] = zero; } //__syncwarp(); - //__syncthreads(); + __syncthreads(); - normalizer = logsumexp_normalizer[batch_idx][query_idx]; - tmp_sum = tmp_sum_i[batch_idx][query_idx]; + for (int i = 0; i < BLOCK; i++) { + normalizer[i] = logsumexp_normalizer[batch_idx][query_idx + i]; + tmp_sum[i] = tmp_sum_i[batch_idx][query_idx + i]; + } auto key_j = reinterpret_cast(key[batch_idx][l].data()); - scalar_t attn_v = 0; - scalar_t grad_attn_v = 0; + scalar_t attn_v[BLOCK][BLOCK2] = {0}; + scalar_t grad_attn_v[BLOCK][BLOCK2] = {0}; for (int64_t k = 0; k < K / kVecSize; k += 1) { - //sputnik::VectorCompute::Dot(query_block[k], key_j[k], &attn_v); - //sputnik::VectorCompute::Dot(__ldg(qb + k), __ldg(kb + k), 
&attn_v); - sputnik::VectorCompute::Dot(query_cache[threadIdx.x][k], key_cache[threadIdx.y][k], &attn_v); - sputnik::VectorCompute::Dot(__ldg(gb + k), __ldg(vb + k), &grad_attn_v); +#pragma unroll + for (int ii = 0; ii < BLOCK2; ii++) { + vec_t kk = key_cache[BLOCK2 * threadIdx.y + ii][k]; + vec_t tt = __ldg(vb + k + K / kVecSize * ii); +#pragma unroll + for (int i = 0; i < BLOCK; i++) { + sputnik::VectorCompute::Dot(query_cache[BLOCK * threadIdx.x + i][k], kk, &attn_v[i][ii]); + sputnik::VectorCompute::Dot(__ldg(gb + k + K / kVecSize * i), tt, &grad_attn_v[i][ii]); + } + } + } +#pragma unroll + for (int ii = 0; ii < BLOCK2; ii++) { +#pragma unroll + for (int i = 0; i < BLOCK; i++) { + attn_v[i][ii] = std::exp(attn_v[i][ii] - normalizer[i]); + } } - attn_v = std::exp(attn_v - normalizer); - - - /* - for (int64_t k = 0; k < K / kVecSize; k++) { - sputnik::VectorCompute::Dot(__ldg(gb + k), __ldg(vb + k), &grad_attn_v); - //sputnik::VectorCompute::Dot(grad_out_cache[threadIdx.x][k], value_cache[threadIdx.y][k], &grad_attn_v); - }*/ - fact[threadIdx.x][threadIdx.y] = attn_v * grad_attn_v - attn_v * tmp_sum; +#pragma unroll + for (int ii = 0; ii < BLOCK2; ii++) { +#pragma unroll + for (int i = 0; i < BLOCK; i++) { + fact[BLOCK * threadIdx.x + i][BLOCK2 * threadIdx.y + ii] = attn_v[i][ii] * grad_attn_v[i][ii] - attn_v[i][ii] * tmp_sum[i]; + } + } __syncthreads(); for (int64_t k = threadIdx.y; k < K / kVecSize; k+= blockDim.y) { - vec_t res = {0}; + vec_t res[BLOCK] = {0}; +#pragma unroll for (int64_t i = 0; i < KS2; i++) { - sputnik::VectorCompute::FMA(fact[threadIdx.x][i], key_cache[i][k], &res); - //sputnik::VectorCompute::FMA(fact[threadIdx.x][i], __ldg(kb + i * K / kVecSize + k), &res); + vec_t kk = key_cache[i][k]; +#pragma unroll + for (int ii = 0; ii < BLOCK; ii++) { + sputnik::VectorCompute::FMA(fact[BLOCK * threadIdx.x + ii][i], kk, &res[ii]); + } + } +#pragma unroll + for (int ii = 0; ii < BLOCK; ii++) { + myGpuAtomicAdd(&grad_q[batch_idx][query_idx + ii][k * kVecSize], res[ii]); } - //if (threadIdx.y == 0) - myGpuAtomicAdd(&grad_q[batch_idx][query_idx][k * kVecSize], res); } for (int64_t k = threadIdx.x; k < K / kVecSize; k+= blockDim.x) { - vec_t res = {0}; + vec_t res[BLOCK2] = {0}; +#pragma unroll for (int64_t i = 0; i < KS1; i++) { - sputnik::VectorCompute::FMA(fact[i][threadIdx.y], query_cache[i][k], &res); - //sputnik::VectorCompute::FMA(fact[i][threadIdx.y], __ldg(qb + i * K / kVecSize + k), &res); - } - //if (threadIdx.x == 0) - myGpuAtomicAdd(&grad_k[batch_idx][l][k * kVecSize], res); - } -/* - for (int64_t k = 0; k < K / kVecSize; k++) { - vec_t res = {0}; - scalar_t ttmp = attn_v * grad_attn_v - attn_v * tmp_sum; - //sputnik::VectorCompute::FMA(ttmp, key_j[k], &res); - sputnik::VectorCompute::FMA(ttmp, key_cache[threadIdx.y][k], &res); - //sputnik::VectorCompute::FMA(ttmp, key_cache[threadIdx.y][k], &tmp_grad1[threadIdx.x][k]); - //myGpuAtomicAdd(&grad_q[batch_idx][query_idx][k * kVecSize], res); - } - - - for (int64_t k = 0; k < K / kVecSize; k++) { - //vec_t res = {0}; - scalar_t ttmp = attn_v * grad_attn_v - attn_v * tmp_sum; - //sputnik::VectorCompute::FMA(ttmp, query_block[k], &res); - sputnik::VectorCompute::FMA(ttmp, query_cache[threadIdx.x][k], &res); - //sputnik::VectorCompute::FMA(ttmp, query_cache[threadIdx.x][k], &tmp_grad2[threadIdx.y][k]); - //myGpuAtomicAdd(&grad_k[batch_idx][l][k * kVecSize], res); - } -*/ -/* - if ((threadIdx.x == 0) && (threadIdx.y == 0)) { - for (int64_t k = 0; k < K / kVecSize; k++) { - vec_t res0 = {0}; - for (int i = 0; i < KS1; i++) - 
{ - sputnik::VectorCompute::FMA(scalar_t(1), tmp_grad1[i][k], &res0); + vec_t kk = query_cache[i][k]; +#pragma unroll + for (int ii = 0; ii < BLOCK2; ii++) { + sputnik::VectorCompute::FMA(fact[i][BLOCK2 * threadIdx.y + ii], kk, &res[ii]); + } } - myGpuAtomicAdd(&grad_q[batch_idx][query_idx][k * kVecSize], res0); - - vec_t res1 = {0}; - for (int i = 0; i < KS2; i++) - { - sputnik::VectorCompute::FMA(scalar_t(1), tmp_grad2[i][k], &res1); +#pragma unroll + for (int ii = 0; ii < BLOCK2; ii++) { + myGpuAtomicAdd(&grad_k[batch_idx][l + ii][k * kVecSize], res[ii]); } - myGpuAtomicAdd(&grad_k[batch_idx][l][k * kVecSize], res1); - } - }*/ - + } } @@ -1271,7 +1256,7 @@ std::tuple attention_backward( //dim3 block2(TILE_SIZE2Q / kBlockSizeQ2, TILE_SIZE2K / kBlockSizeK2); dim3 grid2(ceil_div(M, int64_t(16)), ceil_div(N, int64_t(16)), B); - dim3 block2(16, 16); + dim3 block2(4, 4); // TODO: try adding a blockDim.x to iterate over k attention_backward_kernel2< From c83ebb37c1d2b1b3b8f47e0af4b22cfe353d389b Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 20 Apr 2022 02:33:39 -0700 Subject: [PATCH 23/45] Improve perf by 5% Potentially due to avoiding bank conflicts? --- .../components/attention/csrc/cuda/attention.cu | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 0c51c88b35..f10dfd9728 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -1054,30 +1054,21 @@ __global__ void attention_backward_kernel2( constexpr int KS1 = 16; constexpr int KS2 = 16; - __shared__ vec_t query_cache[KS1][BUFFER_SIZE]; - __shared__ vec_t key_cache[KS2][BUFFER_SIZE]; - //__shared__ vec_t value_cache[KS2][BUFFER_SIZE]; - //__shared__ vec_t grad_out_cache[KS1][BUFFER_SIZE]; - __shared__ scalar_t fact[KS1][KS2]; + __shared__ vec_t query_cache[KS1][BUFFER_SIZE+1]; + __shared__ vec_t key_cache[KS2][BUFFER_SIZE+1]; + __shared__ scalar_t fact[KS1][KS2+1]; auto qb = reinterpret_cast(query[batch_idx][query_idx].data()); auto kb = reinterpret_cast(key[batch_idx][l].data()); auto vb = reinterpret_cast(value[batch_idx][l].data()); auto gb = reinterpret_cast(grad_out[batch_idx][query_idx].data()); - //__shared__ vec_t tmp_grad1[KS1][BUFFER_SIZE]; - //__shared__ vec_t tmp_grad2[KS2][BUFFER_SIZE]; - //vec_t query_cache[KS1][BUFFER_SIZE]; - //vec_t key_cache[KS2][BUFFER_SIZE]; - vec_t zero = {0}; for (int64_t k = 0; k < K / kVecSize; k++) { for (int i = 0; i < BLOCK; i++) query_cache[BLOCK * threadIdx.x + i][k] = qb[k + K / kVecSize * i]; for (int i = 0; i < BLOCK; i++) key_cache[BLOCK2 * threadIdx.y + i][k] = kb[k + K / kVecSize * i]; - //value_cache[threadIdx.y][k] = vb[k]; - //grad_out_cache[threadIdx.x][k] = gb[k]; } //__syncwarp(); __syncthreads(); From 2497f967b4b7483559733d46d1e6471945166474 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 20 Apr 2022 06:18:06 -0700 Subject: [PATCH 24/45] Remove query-key from shared memory and increase tile size Brings 10% improvement, being better than my previous best version --- .../attention/csrc/cuda/attention.cu | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index f10dfd9728..ef3d4ae7f8 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ 
-1032,7 +1032,7 @@ __global__ void attention_backward_kernel2( int64_t N = key.size(1); constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); - constexpr int BLOCK = 4; + constexpr int BLOCK = 8; // KS1 / blockDim.x constexpr int BLOCK2 = 4; int64_t batch_idx = blockIdx.z; @@ -1051,11 +1051,11 @@ __global__ void attention_backward_kernel2( scalar_t normalizer[BLOCK]; scalar_t tmp_sum[BLOCK]; - constexpr int KS1 = 16; + constexpr int KS1 = 32; constexpr int KS2 = 16; - __shared__ vec_t query_cache[KS1][BUFFER_SIZE+1]; - __shared__ vec_t key_cache[KS2][BUFFER_SIZE+1]; + //__shared__ vec_t query_cache[KS1][BUFFER_SIZE+1]; + //__shared__ vec_t key_cache[KS2][BUFFER_SIZE+1]; __shared__ scalar_t fact[KS1][KS2+1]; auto qb = reinterpret_cast(query[batch_idx][query_idx].data()); @@ -1063,16 +1063,17 @@ __global__ void attention_backward_kernel2( auto vb = reinterpret_cast(value[batch_idx][l].data()); auto gb = reinterpret_cast(grad_out[batch_idx][query_idx].data()); + /* vec_t zero = {0}; for (int64_t k = 0; k < K / kVecSize; k++) { for (int i = 0; i < BLOCK; i++) query_cache[BLOCK * threadIdx.x + i][k] = qb[k + K / kVecSize * i]; - for (int i = 0; i < BLOCK; i++) + for (int i = 0; i < BLOCK2; i++) key_cache[BLOCK2 * threadIdx.y + i][k] = kb[k + K / kVecSize * i]; } //__syncwarp(); __syncthreads(); - + */ for (int i = 0; i < BLOCK; i++) { normalizer[i] = logsumexp_normalizer[batch_idx][query_idx + i]; @@ -1086,11 +1087,13 @@ __global__ void attention_backward_kernel2( for (int64_t k = 0; k < K / kVecSize; k += 1) { #pragma unroll for (int ii = 0; ii < BLOCK2; ii++) { - vec_t kk = key_cache[BLOCK2 * threadIdx.y + ii][k]; + //vec_t kk = key_cache[BLOCK2 * threadIdx.y + ii][k]; + vec_t kk = __ldg(kb + k + K / kVecSize * ii); vec_t tt = __ldg(vb + k + K / kVecSize * ii); #pragma unroll for (int i = 0; i < BLOCK; i++) { - sputnik::VectorCompute::Dot(query_cache[BLOCK * threadIdx.x + i][k], kk, &attn_v[i][ii]); + //sputnik::VectorCompute::Dot(query_cache[BLOCK * threadIdx.x + i][k], kk, &attn_v[i][ii]); + sputnik::VectorCompute::Dot(__ldg(qb + k + K / kVecSize * i), kk, &attn_v[i][ii]); sputnik::VectorCompute::Dot(__ldg(gb + k + K / kVecSize * i), tt, &grad_attn_v[i][ii]); } } @@ -1118,7 +1121,8 @@ __global__ void attention_backward_kernel2( vec_t res[BLOCK] = {0}; #pragma unroll for (int64_t i = 0; i < KS2; i++) { - vec_t kk = key_cache[i][k]; + //vec_t kk = key_cache[i][k]; + vec_t kk = __ldg(kb + k + K / kVecSize * (i - BLOCK2 * threadIdx.y)); #pragma unroll for (int ii = 0; ii < BLOCK; ii++) { sputnik::VectorCompute::FMA(fact[BLOCK * threadIdx.x + ii][i], kk, &res[ii]); @@ -1134,7 +1138,8 @@ __global__ void attention_backward_kernel2( vec_t res[BLOCK2] = {0}; #pragma unroll for (int64_t i = 0; i < KS1; i++) { - vec_t kk = query_cache[i][k]; + //vec_t kk = query_cache[i][k]; + vec_t kk = __ldg(qb + k + K / kVecSize * (i - BLOCK * threadIdx.x)); #pragma unroll for (int ii = 0; ii < BLOCK2; ii++) { sputnik::VectorCompute::FMA(fact[i][BLOCK2 * threadIdx.y + ii], kk, &res[ii]); @@ -1246,7 +1251,7 @@ std::tuple attention_backward( //dim3 grid2(ceil_div(M, int64_t(TILE_SIZE2Q)), ceil_div(N, int64_t(TILE_SIZE2K)), B); //dim3 block2(TILE_SIZE2Q / kBlockSizeQ2, TILE_SIZE2K / kBlockSizeK2); - dim3 grid2(ceil_div(M, int64_t(16)), ceil_div(N, int64_t(16)), B); + dim3 grid2(ceil_div(M, int64_t(32)), ceil_div(N, int64_t(16)), B); dim3 block2(4, 4); // TODO: try adding a blockDim.x to iterate over k From 24ed9bb10fbbb372076999681f2b8737be12f984 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 20 Apr 
2022 06:26:39 -0700 Subject: [PATCH 25/45] Make it 20% faster with better hyperparameters This is now significantly faster than what we had before, and is even faster than the vanilla implementation --- xformers/components/attention/csrc/cuda/attention.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index ef3d4ae7f8..24218d82ea 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -1032,7 +1032,7 @@ __global__ void attention_backward_kernel2( int64_t N = key.size(1); constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); - constexpr int BLOCK = 8; // KS1 / blockDim.x + constexpr int BLOCK = 4; // KS1 / blockDim.x constexpr int BLOCK2 = 4; int64_t batch_idx = blockIdx.z; @@ -1252,7 +1252,7 @@ std::tuple attention_backward( //dim3 block2(TILE_SIZE2Q / kBlockSizeQ2, TILE_SIZE2K / kBlockSizeK2); dim3 grid2(ceil_div(M, int64_t(32)), ceil_div(N, int64_t(16)), B); - dim3 block2(4, 4); + dim3 block2(8, 4); // TODO: try adding a blockDim.x to iterate over k attention_backward_kernel2< From 33f0c71454186e241ab0ddc965684508b8cc42cd Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 20 Apr 2022 06:30:29 -0700 Subject: [PATCH 26/45] Make it another 12% faster This is now 18% faster than the vanilla implementation --- xformers/components/attention/csrc/cuda/attention.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 24218d82ea..2a35431260 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -1052,7 +1052,7 @@ __global__ void attention_backward_kernel2( scalar_t tmp_sum[BLOCK]; constexpr int KS1 = 32; - constexpr int KS2 = 16; + constexpr int KS2 = 32; //__shared__ vec_t query_cache[KS1][BUFFER_SIZE+1]; //__shared__ vec_t key_cache[KS2][BUFFER_SIZE+1]; @@ -1251,8 +1251,8 @@ std::tuple attention_backward( //dim3 grid2(ceil_div(M, int64_t(TILE_SIZE2Q)), ceil_div(N, int64_t(TILE_SIZE2K)), B); //dim3 block2(TILE_SIZE2Q / kBlockSizeQ2, TILE_SIZE2K / kBlockSizeK2); - dim3 grid2(ceil_div(M, int64_t(32)), ceil_div(N, int64_t(16)), B); - dim3 block2(8, 4); + dim3 grid2(ceil_div(M, int64_t(32)), ceil_div(N, int64_t(32)), B); + dim3 block2(8, 8); // TODO: try adding a blockDim.x to iterate over k attention_backward_kernel2< From 253b3ebd7862016befcc0ead66e39ca283fc4b67 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 20 Apr 2022 08:06:48 -0700 Subject: [PATCH 27/45] Code cleanup --- .../attention/csrc/cuda/attention.cu | 102 +++++++----------- 1 file changed, 37 insertions(+), 65 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 2a35431260..604b6cabb3 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -1015,7 +1015,8 @@ template < typename vec_t, int kBlockSizeQ, int kBlockSizeK, - int BUFFER_SIZE> + int TILE_SIZEQ, + int TILE_SIZEK> __global__ void attention_backward_kernel2( at::PackedTensorAccessor grad_q, at::PackedTensorAccessor grad_k, @@ -1032,15 +1033,12 @@ __global__ void attention_backward_kernel2( int64_t N = key.size(1); constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); - constexpr int BLOCK = 4; 
// KS1 / blockDim.x - constexpr int BLOCK2 = 4; int64_t batch_idx = blockIdx.z; int64_t query_idx = - blockIdx.x * blockDim.x * BLOCK + threadIdx.x * BLOCK; - + blockIdx.x * blockDim.x * kBlockSizeQ + threadIdx.x * kBlockSizeQ; int64_t l = - blockIdx.y * blockDim.y * BLOCK2 + threadIdx.y * BLOCK2; + blockIdx.y * blockDim.y * kBlockSizeK + threadIdx.y * kBlockSizeK; if (query_idx >= M) return; @@ -1048,109 +1046,85 @@ __global__ void attention_backward_kernel2( if (l >= N) return; - scalar_t normalizer[BLOCK]; - scalar_t tmp_sum[BLOCK]; - - constexpr int KS1 = 32; - constexpr int KS2 = 32; + scalar_t normalizer[kBlockSizeQ]; + scalar_t tmp_sum[kBlockSizeQ]; - //__shared__ vec_t query_cache[KS1][BUFFER_SIZE+1]; - //__shared__ vec_t key_cache[KS2][BUFFER_SIZE+1]; - __shared__ scalar_t fact[KS1][KS2+1]; + __shared__ scalar_t fact[TILE_SIZEQ][TILE_SIZEK + 1]; auto qb = reinterpret_cast(query[batch_idx][query_idx].data()); auto kb = reinterpret_cast(key[batch_idx][l].data()); auto vb = reinterpret_cast(value[batch_idx][l].data()); auto gb = reinterpret_cast(grad_out[batch_idx][query_idx].data()); - /* - vec_t zero = {0}; - for (int64_t k = 0; k < K / kVecSize; k++) { - for (int i = 0; i < BLOCK; i++) - query_cache[BLOCK * threadIdx.x + i][k] = qb[k + K / kVecSize * i]; - for (int i = 0; i < BLOCK2; i++) - key_cache[BLOCK2 * threadIdx.y + i][k] = kb[k + K / kVecSize * i]; - } - //__syncwarp(); - __syncthreads(); - */ - - for (int i = 0; i < BLOCK; i++) { + for (int i = 0; i < kBlockSizeQ; i++) { normalizer[i] = logsumexp_normalizer[batch_idx][query_idx + i]; tmp_sum[i] = tmp_sum_i[batch_idx][query_idx + i]; } auto key_j = reinterpret_cast(key[batch_idx][l].data()); - scalar_t attn_v[BLOCK][BLOCK2] = {0}; - scalar_t grad_attn_v[BLOCK][BLOCK2] = {0}; + scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; + scalar_t grad_attn_v[kBlockSizeQ][kBlockSizeK] = {0}; for (int64_t k = 0; k < K / kVecSize; k += 1) { #pragma unroll - for (int ii = 0; ii < BLOCK2; ii++) { - //vec_t kk = key_cache[BLOCK2 * threadIdx.y + ii][k]; + for (int ii = 0; ii < kBlockSizeK; ii++) { vec_t kk = __ldg(kb + k + K / kVecSize * ii); vec_t tt = __ldg(vb + k + K / kVecSize * ii); #pragma unroll - for (int i = 0; i < BLOCK; i++) { - //sputnik::VectorCompute::Dot(query_cache[BLOCK * threadIdx.x + i][k], kk, &attn_v[i][ii]); + for (int i = 0; i < kBlockSizeQ; i++) { sputnik::VectorCompute::Dot(__ldg(qb + k + K / kVecSize * i), kk, &attn_v[i][ii]); sputnik::VectorCompute::Dot(__ldg(gb + k + K / kVecSize * i), tt, &grad_attn_v[i][ii]); } } } #pragma unroll - for (int ii = 0; ii < BLOCK2; ii++) { + for (int ii = 0; ii < kBlockSizeK; ii++) { #pragma unroll - for (int i = 0; i < BLOCK; i++) { + for (int i = 0; i < kBlockSizeQ; i++) { attn_v[i][ii] = std::exp(attn_v[i][ii] - normalizer[i]); } } #pragma unroll - for (int ii = 0; ii < BLOCK2; ii++) { + for (int ii = 0; ii < kBlockSizeK; ii++) { #pragma unroll - for (int i = 0; i < BLOCK; i++) { - fact[BLOCK * threadIdx.x + i][BLOCK2 * threadIdx.y + ii] = attn_v[i][ii] * grad_attn_v[i][ii] - attn_v[i][ii] * tmp_sum[i]; + for (int i = 0; i < kBlockSizeQ; i++) { + fact[kBlockSizeQ * threadIdx.x + i][kBlockSizeK * threadIdx.y + ii] = attn_v[i][ii] * grad_attn_v[i][ii] - attn_v[i][ii] * tmp_sum[i]; } } __syncthreads(); - - for (int64_t k = threadIdx.y; k < K / kVecSize; k+= blockDim.y) { - vec_t res[BLOCK] = {0}; + vec_t res[kBlockSizeQ] = {0}; #pragma unroll - for (int64_t i = 0; i < KS2; i++) { - //vec_t kk = key_cache[i][k]; - vec_t kk = __ldg(kb + k + K / kVecSize * (i - BLOCK2 * threadIdx.y)); 
+ for (int64_t i = 0; i < TILE_SIZEK; i++) { + vec_t kk = __ldg(kb + k + K / kVecSize * (i - kBlockSizeK * threadIdx.y)); #pragma unroll - for (int ii = 0; ii < BLOCK; ii++) { - sputnik::VectorCompute::FMA(fact[BLOCK * threadIdx.x + ii][i], kk, &res[ii]); + for (int ii = 0; ii < kBlockSizeQ; ii++) { + sputnik::VectorCompute::FMA(fact[kBlockSizeQ * threadIdx.x + ii][i], kk, &res[ii]); } } #pragma unroll - for (int ii = 0; ii < BLOCK; ii++) { + for (int ii = 0; ii < kBlockSizeQ; ii++) { myGpuAtomicAdd(&grad_q[batch_idx][query_idx + ii][k * kVecSize], res[ii]); } } for (int64_t k = threadIdx.x; k < K / kVecSize; k+= blockDim.x) { - vec_t res[BLOCK2] = {0}; + vec_t res[kBlockSizeK] = {0}; #pragma unroll - for (int64_t i = 0; i < KS1; i++) { - //vec_t kk = query_cache[i][k]; - vec_t kk = __ldg(qb + k + K / kVecSize * (i - BLOCK * threadIdx.x)); + for (int64_t i = 0; i < TILE_SIZEQ; i++) { + vec_t kk = __ldg(qb + k + K / kVecSize * (i - kBlockSizeQ * threadIdx.x)); #pragma unroll - for (int ii = 0; ii < BLOCK2; ii++) { - sputnik::VectorCompute::FMA(fact[i][BLOCK2 * threadIdx.y + ii], kk, &res[ii]); + for (int ii = 0; ii < kBlockSizeK; ii++) { + sputnik::VectorCompute::FMA(fact[i][kBlockSizeK * threadIdx.y + ii], kk, &res[ii]); } } #pragma unroll - for (int ii = 0; ii < BLOCK2; ii++) { + for (int ii = 0; ii < kBlockSizeK; ii++) { myGpuAtomicAdd(&grad_k[batch_idx][l + ii][k * kVecSize], res[ii]); } } - } @@ -1242,17 +1216,14 @@ std::tuple attention_backward( // }); - constexpr int TILE_SIZE2Q = 32 * 4; - constexpr int TILE_SIZE2K = 32 * 8; - - constexpr int64_t kBlockSizeQ2 = 8;//2; - constexpr int64_t kBlockSizeK2 = 16;//32; + constexpr int TILE_SIZEQ2 = 32; + constexpr int TILE_SIZEK2 = 32; - //dim3 grid2(ceil_div(M, int64_t(TILE_SIZE2Q)), ceil_div(N, int64_t(TILE_SIZE2K)), B); - //dim3 block2(TILE_SIZE2Q / kBlockSizeQ2, TILE_SIZE2K / kBlockSizeK2); + constexpr int64_t kBlockSizeQ2 = 4; + constexpr int64_t kBlockSizeK2 = 4; - dim3 grid2(ceil_div(M, int64_t(32)), ceil_div(N, int64_t(32)), B); - dim3 block2(8, 8); + dim3 grid2(ceil_div(M, int64_t(TILE_SIZEQ2)), ceil_div(N, int64_t(TILE_SIZEK2)), B); + dim3 block2(TILE_SIZEQ2 / kBlockSizeQ2, TILE_SIZEK2 / kBlockSizeK2); // TODO: try adding a blockDim.x to iterate over k attention_backward_kernel2< @@ -1260,7 +1231,8 @@ std::tuple attention_backward( vec_t, kBlockSizeQ2, kBlockSizeK2, - BUFFER_SIZE><<>>( + TILE_SIZEQ2, + TILE_SIZEK2><<>>( grad_q.packed_accessor(), grad_k.packed_accessor(), grad_v.packed_accessor(), From e94d0cdb80b656521c2a1e8ed2336c9ae46476a0 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 20 Apr 2022 08:19:33 -0700 Subject: [PATCH 28/45] Further cleanups Remove previous implementation --- .../attention/csrc/cuda/attention.cu | 324 +----------------- 1 file changed, 6 insertions(+), 318 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 604b6cabb3..8db2ed675e 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -612,11 +612,8 @@ template < typename scalar_t, typename vec_t, int kBlockSizeQ, - int kBlockSizeK, - int BUFFER_SIZE> -__global__ void attention_backward_kernel( - at::PackedTensorAccessor grad_q, - at::PackedTensorAccessor grad_k, + int kBlockSizeK> +__global__ void attention_backward_grad_v_kernel( at::PackedTensorAccessor grad_v, at::PackedTensorAccessor grad_out, at::PackedTensorAccessor query, @@ -638,16 +635,10 @@ __global__ void 
attention_backward_kernel( if (query_idx >= M) return; - //vec_t temp_buffer[kBlockSizeQ][BUFFER_SIZE] = {0}; - //vec_t temp_grad_q[kBlockSizeQ][BUFFER_SIZE] = {0}; - vec_t* query_block[kBlockSizeQ]; vec_t* grad_out_block[kBlockSizeQ]; - //vec_t* grad_q_block[kBlockSizeQ]; scalar_t normalizer[kBlockSizeQ]; - //__shared__ vec_t query_cache[kBlockSizeQ][BUFFER_SIZE]; - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { int64_t index = query_idx + q_item_idx; index = index >= M ? M - 1 : index; @@ -655,14 +646,8 @@ __global__ void attention_backward_kernel( reinterpret_cast(query[batch_idx][index].data()); grad_out_block[q_item_idx] = reinterpret_cast(grad_out[batch_idx][index].data()); - //grad_q_block[q_item_idx] = - // reinterpret_cast(grad_q[batch_idx][index].data()); normalizer[q_item_idx] = logsumexp_normalizer[batch_idx][index]; - //for (int64_t k = threadIdx.x; k < K / kVecSize; k += blockDim.x) { - // query_cache[q_item_idx][k] = query_block[q_item_idx][k]; - // } } - //__syncthreads(); scalar_t tmp_sum[kBlockSizeQ] = {0}; for (int64_t l = threadIdx.x * kBlockSizeK; l < N; @@ -671,7 +656,6 @@ __global__ void attention_backward_kernel( scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; compute_dot( - //query_cache, key_j, attn_v, K); query_block, key_j, attn_v, K); #pragma unroll @@ -683,10 +667,8 @@ __global__ void attention_backward_kernel( } } - // now compute grad_q and grad_k // first compute the gradient for the self-attention // after softmax - // scalar_t grad_attn_v = 0; scalar_t grad_attn_v[kBlockSizeQ][kBlockSizeK] = {0}; auto value_j = reinterpret_cast(value[batch_idx][l].data()); @@ -712,304 +694,21 @@ __global__ void attention_backward_kernel( } // those are temporaries for the gradient of the softmax - //scalar_t tmp[kBlockSizeQ][kBlockSizeK]; #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { #pragma unroll for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - //tmp[q_item_idx][k_item_idx] = attn_v[q_item_idx][k_item_idx] * - // grad_attn_v[q_item_idx][k_item_idx]; - //tmp_sum[q_item_idx] += tmp[q_item_idx][k_item_idx]; tmp_sum[q_item_idx] += attn_v[q_item_idx][k_item_idx] * grad_attn_v[q_item_idx][k_item_idx]; } } -/* - // grad_q is easy - for (int64_t k = 0; k < K / kVecSize; k++) { -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { -#pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - vec_t ttt = key_j[k + K / kVecSize * k_item_idx]; - sputnik::VectorCompute::FMA(tmp[q_item_idx][k_item_idx], ttt, &temp_grad_q[q_item_idx][k]); - sputnik::VectorCompute::FMA( - attn_v[q_item_idx][k_item_idx], ttt, &temp_buffer[q_item_idx][k]); - } - } - } -*/ - // but grad_k is a bit trickier -/* - for (int64_t k = 0; k < K / kVecSize; k++) { -#pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - vec_t res = {0}; -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - // res += tmp[q_item_idx][k_item_idx] * query_block[q_item_idx][k]; - vec_t qqq = __ldg(query_block[q_item_idx] + k); - //vec_t qqq = query_cache[q_item_idx][k]; - sputnik::VectorCompute::FMA(tmp[q_item_idx][k_item_idx], qqq, &res); - } - myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res); - } - }*/ } #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { tmp_sum[q_item_idx] = warpSum(tmp_sum[q_item_idx]); tmp_sum_i[batch_idx][query_idx + q_item_idx] = 
tmp_sum[q_item_idx]; } -/* - for (int64_t l = threadIdx.x * kBlockSizeK; l < N; - l += blockDim.x * kBlockSizeK) { - auto key_j = reinterpret_cast(key[batch_idx][l].data()); - scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; - compute_dot( - query_block, key_j, attn_v, K); - -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { -#pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - attn_v[q_item_idx][k_item_idx] = - std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]); - } - } - - scalar_t grad_attn_v[kBlockSizeQ][kBlockSizeK] = {0}; - auto value_j = reinterpret_cast(value[batch_idx][l].data()); - - for (int64_t k = 0; k < K / kVecSize; k++) { - vec_t temp_i[kBlockSizeQ]; -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - temp_i[q_item_idx] = __ldg(grad_out_block[q_item_idx] + k); - } - -#pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - vec_t v = value_j[k + K / kVecSize * k_item_idx]; - vec_t tt = {0}; -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - sputnik::VectorCompute::Dot( - temp_i[q_item_idx], v, &grad_attn_v[q_item_idx][k_item_idx]); - } - } - } - - // those are temporaries for the gradient of the softmax - scalar_t tmp[kBlockSizeQ][kBlockSizeK]; -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { -#pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - tmp[q_item_idx][k_item_idx] = attn_v[q_item_idx][k_item_idx] * - grad_attn_v[q_item_idx][k_item_idx]; - } - } - - - for (int64_t k = 0; k < K / kVecSize; k++) { -#pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - vec_t res = {0}; -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - //scalar_t ttt = -attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; - scalar_t ttt = tmp[q_item_idx][k_item_idx] - attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; - vec_t qqq = __ldg(query_block[q_item_idx] + k); - //vec_t qqq = query_cache[q_item_idx][k]; - sputnik::VectorCompute::FMA(ttt, qqq, &res); - } - myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res); - } - } - }*/ - /* - for (int64_t k = 0; k < K / kVecSize; k++) { -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - temp_grad_q[q_item_idx][k] = - warpSum(temp_grad_q[q_item_idx][k]); - temp_buffer[q_item_idx][k] = - warpSum(temp_buffer[q_item_idx][k]); - } - } - for (int64_t k = threadIdx.x; k < K / kVecSize; k += blockDim.x) { -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - // sputnik::VectorCompute::FMA(-tmp_sum[q_item_idx], temp_buffer[q_item_idx][k], &temp_grad_q[q_item_idx][k]); - // grad_q_block[q_item_idx][k] = temp_grad_q[q_item_idx][k]; - grad_q_block[q_item_idx][k].x = temp_grad_q[q_item_idx][k].x - - temp_buffer[q_item_idx][k].x * tmp_sum[q_item_idx]; - grad_q_block[q_item_idx][k].y = temp_grad_q[q_item_idx][k].y - - temp_buffer[q_item_idx][k].y * tmp_sum[q_item_idx]; - grad_q_block[q_item_idx][k].z = temp_grad_q[q_item_idx][k].z - - temp_buffer[q_item_idx][k].z * tmp_sum[q_item_idx]; - grad_q_block[q_item_idx][k].w = temp_grad_q[q_item_idx][k].w - - temp_buffer[q_item_idx][k].w * tmp_sum[q_item_idx]; - } - }*/ } - -template < - typename scalar_t, - typename vec_t, - int kBlockSizeQ, - int kBlockSizeK, - int BUFFER_SIZE> 
-__global__ void attention_backward_kernel3( - at::PackedTensorAccessor grad_q, - at::PackedTensorAccessor grad_k, - at::PackedTensorAccessor grad_v, - at::PackedTensorAccessor grad_out, - at::PackedTensorAccessor query, - at::PackedTensorAccessor key, - at::PackedTensorAccessor value, - at::PackedTensorAccessor tmp_sum_i, - at::PackedTensorAccessor logsumexp_normalizer) { - int64_t K = query.size(2); - int64_t B = query.size(0); - int64_t M = query.size(1); - int64_t N = key.size(1); - - constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); - - int64_t batch_idx = blockIdx.z; - int64_t query_idx = - blockIdx.x * blockDim.x * kBlockSizeQ + threadIdx.x * kBlockSizeQ; - - int64_t l = - blockIdx.y * blockDim.y * kBlockSizeK + threadIdx.y * kBlockSizeK; - - if (query_idx >= M) - return; - - if (l >= N) - return; - - vec_t* query_block[kBlockSizeQ]; - vec_t* grad_out_block[kBlockSizeQ]; - scalar_t normalizer[kBlockSizeQ]; - scalar_t tmp_sum[kBlockSizeQ]; - - //__shared__ vec_t query_cache[kBlockSizeQ][BUFFER_SIZE]; - - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - int64_t index = query_idx + q_item_idx; - index = index >= M ? M - 1 : index; - query_block[q_item_idx] = - reinterpret_cast(query[batch_idx][index].data()); - grad_out_block[q_item_idx] = - reinterpret_cast(grad_out[batch_idx][index].data()); - normalizer[q_item_idx] = logsumexp_normalizer[batch_idx][index]; - tmp_sum[q_item_idx] = tmp_sum_i[batch_idx][index]; - //for (int64_t k = threadIdx.x; k < K / kVecSize; k += blockDim.x) { - // query_cache[q_item_idx][k] = query_block[q_item_idx][k]; - // } - } - //__syncthreads(); - - //scalar_t tmp_sum[kBlockSizeQ] = {0}; - - //for (int64_t l = threadIdx.x * kBlockSizeK; l < N; - // l += blockDim.x * kBlockSizeK) - { - auto key_j = reinterpret_cast(key[batch_idx][l].data()); - scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; - //compute_dot( - // query_block, key_j, attn_v, K); - - vec_t q_i[kBlockSizeQ]; - for (int64_t k = 0; k < K / kVecSize; k += 1) { -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - q_i[q_item_idx] = __ldg(query_block[q_item_idx] + k); - } -#pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - vec_t k_i = key_j[k + K / kVecSize * k_item_idx]; -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - sputnik::VectorCompute::Dot( - q_i[q_item_idx], k_i, &attn_v[q_item_idx][k_item_idx]); - } - } - } - - -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { -#pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - attn_v[q_item_idx][k_item_idx] = - std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]); - } - } - - scalar_t grad_attn_v[kBlockSizeQ][kBlockSizeK] = {0}; - auto value_j = reinterpret_cast(value[batch_idx][l].data()); - - for (int64_t k = 0; k < K / kVecSize; k++) { - vec_t temp_i[kBlockSizeQ]; -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - temp_i[q_item_idx] = __ldg(grad_out_block[q_item_idx] + k); - } - -#pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - vec_t v = value_j[k + K / kVecSize * k_item_idx]; -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - sputnik::VectorCompute::Dot( - temp_i[q_item_idx], v, &grad_attn_v[q_item_idx][k_item_idx]); - } - } - } - - for (int64_t k = 0; k < K / kVecSize; k++) { - vec_t 
res[kBlockSizeQ] = {0}; -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { -#pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - vec_t ttt = key_j[k + K / kVecSize * k_item_idx]; - scalar_t ttmp = attn_v[q_item_idx][k_item_idx] * grad_attn_v[q_item_idx][k_item_idx] - attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; - sputnik::VectorCompute::FMA(ttmp, ttt, &res[q_item_idx]); - } - } -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - myGpuAtomicAdd(&grad_q[batch_idx][query_idx + q_item_idx][k * kVecSize], res[q_item_idx]); - - } - } - - - for (int64_t k = 0; k < K / kVecSize; k++) { -#pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - vec_t res = {0}; -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - //scalar_t ttt = -attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; - scalar_t ttt = attn_v[q_item_idx][k_item_idx] * grad_attn_v[q_item_idx][k_item_idx] - attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; - vec_t qqq = __ldg(query_block[q_item_idx] + k); - //vec_t qqq = query_cache[q_item_idx][k]; - sputnik::VectorCompute::FMA(ttt, qqq, &res); - } - myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res); - } - } - } - -} - - template < typename scalar_t, typename vec_t, @@ -1017,10 +716,9 @@ template < int kBlockSizeK, int TILE_SIZEQ, int TILE_SIZEK> -__global__ void attention_backward_kernel2( +__global__ void attention_backward_grad_qk_kernel( at::PackedTensorAccessor grad_q, at::PackedTensorAccessor grad_k, - at::PackedTensorAccessor grad_v, at::PackedTensorAccessor grad_out, at::PackedTensorAccessor query, at::PackedTensorAccessor key, @@ -1173,7 +871,6 @@ std::tuple attention_backward( int64_t N = key.size(1); int64_t K = query.size(2); - //at::Tensor res = at::empty({B, M, K}, query.options()); at::Tensor grad_q = at::zeros_like(query); at::Tensor grad_k = at::zeros_like(key); at::Tensor grad_v = at::zeros_like(value); @@ -1187,7 +884,6 @@ std::tuple attention_backward( constexpr int TILE_SIZE = 16 * 8; constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); - constexpr int64_t BUFFER_SIZE = 32 / kVecSize; constexpr int64_t kBlockSizeQ = 16; constexpr int64_t kBlockSizeK = 4; @@ -1196,16 +892,11 @@ std::tuple attention_backward( cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - // AT_DISPATCH_FLOATING_TYPES( - // query.scalar_type(), "attention_backward_kernel", [&] { - attention_backward_kernel< + attention_backward_grad_v_kernel< scalar_t, vec_t, kBlockSizeQ, - kBlockSizeK, - BUFFER_SIZE><<>>( - grad_q.packed_accessor(), - grad_k.packed_accessor(), + kBlockSizeK><<>>( grad_v.packed_accessor(), grad_out.packed_accessor(), query.packed_accessor(), @@ -1213,7 +904,6 @@ std::tuple attention_backward( value.packed_accessor(), tmp_sum_i.packed_accessor(), logsumexp.packed_accessor()); - // }); constexpr int TILE_SIZEQ2 = 32; @@ -1226,7 +916,7 @@ std::tuple attention_backward( dim3 block2(TILE_SIZEQ2 / kBlockSizeQ2, TILE_SIZEK2 / kBlockSizeK2); // TODO: try adding a blockDim.x to iterate over k - attention_backward_kernel2< + attention_backward_grad_qk_kernel< scalar_t, vec_t, kBlockSizeQ2, @@ -1235,14 +925,12 @@ std::tuple attention_backward( TILE_SIZEK2><<>>( grad_q.packed_accessor(), grad_k.packed_accessor(), - grad_v.packed_accessor(), grad_out.packed_accessor(), query.packed_accessor(), key.packed_accessor(), value.packed_accessor(), 
tmp_sum_i.packed_accessor(), logsumexp.packed_accessor()); - // }); AT_CUDA_CHECK(cudaGetLastError()); From 570677727bfbc9a6b3b4c5b595d68c67ab92a11b Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 20 Apr 2022 08:31:54 -0700 Subject: [PATCH 29/45] Variable rename --- .../attention/csrc/cuda/attention.cu | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 8db2ed675e..476f5dceb0 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -765,29 +765,29 @@ __global__ void attention_backward_grad_qk_kernel( for (int64_t k = 0; k < K / kVecSize; k += 1) { #pragma unroll - for (int ii = 0; ii < kBlockSizeK; ii++) { - vec_t kk = __ldg(kb + k + K / kVecSize * ii); - vec_t tt = __ldg(vb + k + K / kVecSize * ii); + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + vec_t kk = __ldg(kb + k + K / kVecSize * k_item_idx); + vec_t tt = __ldg(vb + k + K / kVecSize * k_item_idx); #pragma unroll - for (int i = 0; i < kBlockSizeQ; i++) { - sputnik::VectorCompute::Dot(__ldg(qb + k + K / kVecSize * i), kk, &attn_v[i][ii]); - sputnik::VectorCompute::Dot(__ldg(gb + k + K / kVecSize * i), tt, &grad_attn_v[i][ii]); + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + sputnik::VectorCompute::Dot(__ldg(qb + k + K / kVecSize * q_item_idx), kk, &attn_v[q_item_idx][k_item_idx]); + sputnik::VectorCompute::Dot(__ldg(gb + k + K / kVecSize * q_item_idx), tt, &grad_attn_v[q_item_idx][k_item_idx]); } } } #pragma unroll - for (int ii = 0; ii < kBlockSizeK; ii++) { + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { #pragma unroll - for (int i = 0; i < kBlockSizeQ; i++) { - attn_v[i][ii] = std::exp(attn_v[i][ii] - normalizer[i]); + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + attn_v[q_item_idx][k_item_idx] = std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]); } } #pragma unroll - for (int ii = 0; ii < kBlockSizeK; ii++) { + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { #pragma unroll - for (int i = 0; i < kBlockSizeQ; i++) { - fact[kBlockSizeQ * threadIdx.x + i][kBlockSizeK * threadIdx.y + ii] = attn_v[i][ii] * grad_attn_v[i][ii] - attn_v[i][ii] * tmp_sum[i]; + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + fact[kBlockSizeQ * threadIdx.x + q_item_idx][kBlockSizeK * threadIdx.y + k_item_idx] = attn_v[q_item_idx][k_item_idx] * grad_attn_v[q_item_idx][k_item_idx] - attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; } } __syncthreads(); @@ -798,13 +798,13 @@ __global__ void attention_backward_grad_qk_kernel( for (int64_t i = 0; i < TILE_SIZEK; i++) { vec_t kk = __ldg(kb + k + K / kVecSize * (i - kBlockSizeK * threadIdx.y)); #pragma unroll - for (int ii = 0; ii < kBlockSizeQ; ii++) { - sputnik::VectorCompute::FMA(fact[kBlockSizeQ * threadIdx.x + ii][i], kk, &res[ii]); + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + sputnik::VectorCompute::FMA(fact[kBlockSizeQ * threadIdx.x + q_item_idx][i], kk, &res[q_item_idx]); } } #pragma unroll - for (int ii = 0; ii < kBlockSizeQ; ii++) { - myGpuAtomicAdd(&grad_q[batch_idx][query_idx + ii][k * kVecSize], res[ii]); + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + myGpuAtomicAdd(&grad_q[batch_idx][query_idx + q_item_idx][k * kVecSize], res[q_item_idx]); } } @@ -814,13 +814,13 @@ __global__ 
void attention_backward_grad_qk_kernel( for (int64_t i = 0; i < TILE_SIZEQ; i++) { vec_t kk = __ldg(qb + k + K / kVecSize * (i - kBlockSizeQ * threadIdx.x)); #pragma unroll - for (int ii = 0; ii < kBlockSizeK; ii++) { - sputnik::VectorCompute::FMA(fact[i][kBlockSizeK * threadIdx.y + ii], kk, &res[ii]); + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + sputnik::VectorCompute::FMA(fact[i][kBlockSizeK * threadIdx.y + k_item_idx], kk, &res[k_item_idx]); } } #pragma unroll - for (int ii = 0; ii < kBlockSizeK; ii++) { - myGpuAtomicAdd(&grad_k[batch_idx][l + ii][k * kVecSize], res[ii]); + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res[k_item_idx]); } } } From 69d1aa8da3aacdea3d4bf32f914a1c6024af9abf Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 20 Apr 2022 08:32:31 -0700 Subject: [PATCH 30/45] clang-format --- .../attention/csrc/cuda/attention.cu | 111 ++++++++++-------- 1 file changed, 61 insertions(+), 50 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 476f5dceb0..04dca78436 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -608,11 +608,7 @@ at::Tensor attention( return res; } -template < - typename scalar_t, - typename vec_t, - int kBlockSizeQ, - int kBlockSizeK> +template __global__ void attention_backward_grad_v_kernel( at::PackedTensorAccessor grad_v, at::PackedTensorAccessor grad_out, @@ -687,7 +683,8 @@ __global__ void attention_backward_grad_v_kernel( for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { sputnik::VectorCompute::Dot( temp_i[q_item_idx], v, &grad_attn_v[q_item_idx][k_item_idx]); - sputnik::VectorCompute::FMA(attn_v[q_item_idx][k_item_idx], temp_i[q_item_idx], &tt); + sputnik::VectorCompute::FMA( + attn_v[q_item_idx][k_item_idx], temp_i[q_item_idx], &tt); } myGpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k * kVecSize], tt); } @@ -698,7 +695,8 @@ __global__ void attention_backward_grad_v_kernel( for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { #pragma unroll for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - tmp_sum[q_item_idx] += attn_v[q_item_idx][k_item_idx] * grad_attn_v[q_item_idx][k_item_idx]; + tmp_sum[q_item_idx] += attn_v[q_item_idx][k_item_idx] * + grad_attn_v[q_item_idx][k_item_idx]; } } } @@ -735,8 +733,7 @@ __global__ void attention_backward_grad_qk_kernel( int64_t batch_idx = blockIdx.z; int64_t query_idx = blockIdx.x * blockDim.x * kBlockSizeQ + threadIdx.x * kBlockSizeQ; - int64_t l = - blockIdx.y * blockDim.y * kBlockSizeK + threadIdx.y * kBlockSizeK; + int64_t l = blockIdx.y * blockDim.y * kBlockSizeK + threadIdx.y * kBlockSizeK; if (query_idx >= M) return; @@ -770,8 +767,14 @@ __global__ void attention_backward_grad_qk_kernel( vec_t tt = __ldg(vb + k + K / kVecSize * k_item_idx); #pragma unroll for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - sputnik::VectorCompute::Dot(__ldg(qb + k + K / kVecSize * q_item_idx), kk, &attn_v[q_item_idx][k_item_idx]); - sputnik::VectorCompute::Dot(__ldg(gb + k + K / kVecSize * q_item_idx), tt, &grad_attn_v[q_item_idx][k_item_idx]); + sputnik::VectorCompute::Dot( + __ldg(qb + k + K / kVecSize * q_item_idx), + kk, + &attn_v[q_item_idx][k_item_idx]); + sputnik::VectorCompute::Dot( + __ldg(gb + k + K / kVecSize * q_item_idx), + tt, + 
&grad_attn_v[q_item_idx][k_item_idx]); } } } @@ -779,7 +782,8 @@ __global__ void attention_backward_grad_qk_kernel( for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { #pragma unroll for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - attn_v[q_item_idx][k_item_idx] = std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]); + attn_v[q_item_idx][k_item_idx] = + std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]); } } @@ -787,46 +791,57 @@ __global__ void attention_backward_grad_qk_kernel( for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { #pragma unroll for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - fact[kBlockSizeQ * threadIdx.x + q_item_idx][kBlockSizeK * threadIdx.y + k_item_idx] = attn_v[q_item_idx][k_item_idx] * grad_attn_v[q_item_idx][k_item_idx] - attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; + fact[kBlockSizeQ * threadIdx.x + q_item_idx] + [kBlockSizeK * threadIdx.y + k_item_idx] = + attn_v[q_item_idx][k_item_idx] * + grad_attn_v[q_item_idx][k_item_idx] - + attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; } } __syncthreads(); - for (int64_t k = threadIdx.y; k < K / kVecSize; k+= blockDim.y) { - vec_t res[kBlockSizeQ] = {0}; -#pragma unroll - for (int64_t i = 0; i < TILE_SIZEK; i++) { - vec_t kk = __ldg(kb + k + K / kVecSize * (i - kBlockSizeK * threadIdx.y)); + for (int64_t k = threadIdx.y; k < K / kVecSize; k += blockDim.y) { + vec_t res[kBlockSizeQ] = {0}; #pragma unroll - for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - sputnik::VectorCompute::FMA(fact[kBlockSizeQ * threadIdx.x + q_item_idx][i], kk, &res[q_item_idx]); - } - } + for (int64_t i = 0; i < TILE_SIZEK; i++) { + vec_t kk = __ldg(kb + k + K / kVecSize * (i - kBlockSizeK * threadIdx.y)); #pragma unroll for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - myGpuAtomicAdd(&grad_q[batch_idx][query_idx + q_item_idx][k * kVecSize], res[q_item_idx]); + sputnik::VectorCompute::FMA( + fact[kBlockSizeQ * threadIdx.x + q_item_idx][i], + kk, + &res[q_item_idx]); } + } +#pragma unroll + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + myGpuAtomicAdd( + &grad_q[batch_idx][query_idx + q_item_idx][k * kVecSize], + res[q_item_idx]); + } } - for (int64_t k = threadIdx.x; k < K / kVecSize; k+= blockDim.x) { - vec_t res[kBlockSizeK] = {0}; -#pragma unroll - for (int64_t i = 0; i < TILE_SIZEQ; i++) { - vec_t kk = __ldg(qb + k + K / kVecSize * (i - kBlockSizeQ * threadIdx.x)); + for (int64_t k = threadIdx.x; k < K / kVecSize; k += blockDim.x) { + vec_t res[kBlockSizeK] = {0}; #pragma unroll - for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - sputnik::VectorCompute::FMA(fact[i][kBlockSizeK * threadIdx.y + k_item_idx], kk, &res[k_item_idx]); - } - } + for (int64_t i = 0; i < TILE_SIZEQ; i++) { + vec_t kk = __ldg(qb + k + K / kVecSize * (i - kBlockSizeQ * threadIdx.x)); #pragma unroll for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - myGpuAtomicAdd(&grad_k[batch_idx][l + k_item_idx][k * kVecSize], res[k_item_idx]); + sputnik::VectorCompute::FMA( + fact[i][kBlockSizeK * threadIdx.y + k_item_idx], + kk, + &res[k_item_idx]); } + } +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + myGpuAtomicAdd( + &grad_k[batch_idx][l + k_item_idx][k * kVecSize], res[k_item_idx]); + } } } - - std::tuple attention_backward( const at::Tensor& grad_out, const at::Tensor& query, @@ -892,19 +907,15 @@ std::tuple attention_backward( 
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - attention_backward_grad_v_kernel< - scalar_t, - vec_t, - kBlockSizeQ, - kBlockSizeK><<>>( - grad_v.packed_accessor(), - grad_out.packed_accessor(), - query.packed_accessor(), - key.packed_accessor(), - value.packed_accessor(), - tmp_sum_i.packed_accessor(), - logsumexp.packed_accessor()); - + attention_backward_grad_v_kernel + <<>>( + grad_v.packed_accessor(), + grad_out.packed_accessor(), + query.packed_accessor(), + key.packed_accessor(), + value.packed_accessor(), + tmp_sum_i.packed_accessor(), + logsumexp.packed_accessor()); constexpr int TILE_SIZEQ2 = 32; constexpr int TILE_SIZEK2 = 32; @@ -912,7 +923,8 @@ std::tuple attention_backward( constexpr int64_t kBlockSizeQ2 = 4; constexpr int64_t kBlockSizeK2 = 4; - dim3 grid2(ceil_div(M, int64_t(TILE_SIZEQ2)), ceil_div(N, int64_t(TILE_SIZEK2)), B); + dim3 grid2( + ceil_div(M, int64_t(TILE_SIZEQ2)), ceil_div(N, int64_t(TILE_SIZEK2)), B); dim3 block2(TILE_SIZEQ2 / kBlockSizeQ2, TILE_SIZEK2 / kBlockSizeK2); // TODO: try adding a blockDim.x to iterate over k @@ -932,7 +944,6 @@ std::tuple attention_backward( tmp_sum_i.packed_accessor(), logsumexp.packed_accessor()); - AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(grad_q, grad_k, grad_v); From 220e046f8c68fbb8a78593fa0e8b3759cc24fae4 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 20 Apr 2022 09:41:51 -0700 Subject: [PATCH 31/45] Add alternative implementation for grad_v So far it has exactly the same speed as the previous kernel, but is much more similar to the grad_q and grad_k kernels --- .../attention/csrc/cuda/attention.cu | 149 +++++++++++++++++- 1 file changed, 148 insertions(+), 1 deletion(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 04dca78436..4d3cd44413 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -608,6 +608,137 @@ at::Tensor attention( return res; } +template < + typename scalar_t, + typename vec_t, + int kBlockSizeQ, + int kBlockSizeK, + int TILE_SIZEQ, + int TILE_SIZEK> +__global__ void attention_backward_grad_v2_kernel( + at::PackedTensorAccessor grad_v, + at::PackedTensorAccessor grad_out, + at::PackedTensorAccessor query, + at::PackedTensorAccessor key, + at::PackedTensorAccessor value, + at::PackedTensorAccessor tmp_sum_i, + at::PackedTensorAccessor logsumexp_normalizer) { + int64_t K = query.size(2); + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + + constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); + + int64_t batch_idx = blockIdx.z; + int64_t query_idx = + blockIdx.x * blockDim.x * kBlockSizeQ + threadIdx.x * kBlockSizeQ; + int64_t l = blockIdx.y * blockDim.y * kBlockSizeK + threadIdx.y * kBlockSizeK; + + if (query_idx >= M) + return; + + if (l >= N) + return; + + scalar_t normalizer[kBlockSizeQ]; + scalar_t tmp_sum[kBlockSizeQ] = {0}; + + __shared__ scalar_t fact[TILE_SIZEQ][TILE_SIZEK + 1]; + + //__shared__ scalar_t tmp_sum_shared[kBlockSizeQ][TILE_SIZEK]; + + auto qb = reinterpret_cast(query[batch_idx][query_idx].data()); + auto kb = reinterpret_cast(key[batch_idx][l].data()); + auto vb = reinterpret_cast(value[batch_idx][l].data()); + auto gb = reinterpret_cast(grad_out[batch_idx][query_idx].data()); + + for (int i = 0; i < kBlockSizeQ; i++) { + normalizer[i] = logsumexp_normalizer[batch_idx][query_idx + i]; + } + + auto key_j = reinterpret_cast(key[batch_idx][l].data()); + 
scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; + scalar_t grad_attn_v[kBlockSizeQ][kBlockSizeK] = {0}; + + for (int64_t k = 0; k < K / kVecSize; k += 1) { +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + vec_t kk = __ldg(kb + k + K / kVecSize * k_item_idx); + vec_t tt = __ldg(vb + k + K / kVecSize * k_item_idx); +#pragma unroll + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + sputnik::VectorCompute::Dot( + __ldg(qb + k + K / kVecSize * q_item_idx), + kk, + &attn_v[q_item_idx][k_item_idx]); + sputnik::VectorCompute::Dot( + __ldg(gb + k + K / kVecSize * q_item_idx), + tt, + &grad_attn_v[q_item_idx][k_item_idx]); + } + } + } +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { +#pragma unroll + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + attn_v[q_item_idx][k_item_idx] = + std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]); + } + } + +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { +#pragma unroll + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + fact[kBlockSizeQ * threadIdx.x + q_item_idx] + [kBlockSizeK * threadIdx.y + k_item_idx] = + attn_v[q_item_idx][k_item_idx]; + tmp_sum[q_item_idx] += attn_v[q_item_idx][k_item_idx] * grad_attn_v[q_item_idx][k_item_idx]; + } + } + //for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + // tmp_sum_shared[kBlockSizeQ * threadIdx.x + q_item_idx][threadIdx.y] = tmp_sum[q_item_idx]; + //} + __syncthreads(); + + for (int64_t k = threadIdx.x; k < K / kVecSize; k += blockDim.x) { + vec_t res[kBlockSizeK] = {0}; +#pragma unroll + for (int64_t i = 0; i < TILE_SIZEQ; i++) { + vec_t kk = __ldg(gb + k + K / kVecSize * (i - kBlockSizeQ * threadIdx.x)); +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + sputnik::VectorCompute::FMA( + fact[i][kBlockSizeK * threadIdx.y + k_item_idx], + kk, + &res[k_item_idx]); + } + } +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + myGpuAtomicAdd( + &grad_v[batch_idx][l + k_item_idx][k * kVecSize], res[k_item_idx]); + } + } + /* + if (threadIdx.y == 0) { + for (int i = threadIdx.x; i < TILE_SIZEQ; i += blockDim.x) { + scalar_t tmp = 0; + for (int k = 0; k < blockDim.y; k++) + { + tmp += tmp_sum_shared[i][k]; + } + myGpuAtomicAdd(&tmp_sum_i[batch_idx][query_idx + i - kBlockSizeQ * threadIdx.x], tmp); + } + } + */ + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + myGpuAtomicAdd(&tmp_sum_i[batch_idx][query_idx + q_item_idx], tmp_sum[q_item_idx]); + } +} + template __global__ void attention_backward_grad_v_kernel( at::PackedTensorAccessor grad_v, @@ -907,6 +1038,7 @@ std::tuple attention_backward( cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + /* attention_backward_grad_v_kernel <<>>( grad_v.packed_accessor(), @@ -915,7 +1047,7 @@ std::tuple attention_backward( key.packed_accessor(), value.packed_accessor(), tmp_sum_i.packed_accessor(), - logsumexp.packed_accessor()); + logsumexp.packed_accessor());*/ constexpr int TILE_SIZEQ2 = 32; constexpr int TILE_SIZEK2 = 32; @@ -928,6 +1060,21 @@ std::tuple attention_backward( dim3 block2(TILE_SIZEQ2 / kBlockSizeQ2, TILE_SIZEK2 / kBlockSizeK2); // TODO: try adding a blockDim.x to iterate over k + attention_backward_grad_v2_kernel< + scalar_t, + vec_t, + kBlockSizeQ2, + kBlockSizeK2, + TILE_SIZEQ2, + TILE_SIZEK2><<>>( + grad_v.packed_accessor(), + grad_out.packed_accessor(), + 
query.packed_accessor(), + key.packed_accessor(), + value.packed_accessor(), + tmp_sum_i.packed_accessor(), + logsumexp.packed_accessor()); + attention_backward_grad_qk_kernel< scalar_t, vec_t, From 0b38bf172de5fd983ad1c2055a1e1c295fbed416 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 20 Apr 2022 10:21:23 -0700 Subject: [PATCH 32/45] Speed it up by 10% with better hyperparameters --- .../attention/csrc/cuda/attention.cu | 45 ++++++++++--------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 4d3cd44413..795d130c37 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -1027,14 +1027,19 @@ std::tuple attention_backward( using vec_t = float4; // using vec_t = float; - constexpr int TILE_SIZE = 16 * 8; + constexpr int TILE_SIZEQ = 32; + constexpr int TILE_SIZEK = 32; constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); - constexpr int64_t kBlockSizeQ = 16; - constexpr int64_t kBlockSizeK = 4; + constexpr int64_t kBlockSizeQ = 4; + constexpr int64_t kBlockSizeK = 8; - dim3 grid(ceil_div(M, int64_t(TILE_SIZE)), B); - dim3 block(32, TILE_SIZE / kBlockSizeQ); + //dim3 grid(ceil_div(M, int64_t(TILE_SIZE)), B); + //dim3 block(32, TILE_SIZE / kBlockSizeQ); + + dim3 grid( + ceil_div(M, int64_t(TILE_SIZEQ)), ceil_div(N, int64_t(TILE_SIZEK)), B); + dim3 block(TILE_SIZEQ / kBlockSizeQ, TILE_SIZEK / kBlockSizeK); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -1049,24 +1054,13 @@ std::tuple attention_backward( tmp_sum_i.packed_accessor(), logsumexp.packed_accessor());*/ - constexpr int TILE_SIZEQ2 = 32; - constexpr int TILE_SIZEK2 = 32; - - constexpr int64_t kBlockSizeQ2 = 4; - constexpr int64_t kBlockSizeK2 = 4; - - dim3 grid2( - ceil_div(M, int64_t(TILE_SIZEQ2)), ceil_div(N, int64_t(TILE_SIZEK2)), B); - dim3 block2(TILE_SIZEQ2 / kBlockSizeQ2, TILE_SIZEK2 / kBlockSizeK2); - // TODO: try adding a blockDim.x to iterate over k - attention_backward_grad_v2_kernel< scalar_t, vec_t, - kBlockSizeQ2, - kBlockSizeK2, - TILE_SIZEQ2, - TILE_SIZEK2><<>>( + kBlockSizeQ, + kBlockSizeK, + TILE_SIZEQ, + TILE_SIZEK><<>>( grad_v.packed_accessor(), grad_out.packed_accessor(), query.packed_accessor(), @@ -1075,6 +1069,17 @@ std::tuple attention_backward( tmp_sum_i.packed_accessor(), logsumexp.packed_accessor()); + constexpr int TILE_SIZEQ2 = 32; + constexpr int TILE_SIZEK2 = 32; + + constexpr int64_t kBlockSizeQ2 = 4; + constexpr int64_t kBlockSizeK2 = 4; + + dim3 grid2( + ceil_div(M, int64_t(TILE_SIZEQ2)), ceil_div(N, int64_t(TILE_SIZEK2)), B); + dim3 block2(TILE_SIZEQ2 / kBlockSizeQ2, TILE_SIZEK2 / kBlockSizeK2); + // TODO: try adding a blockDim.x to iterate over k + attention_backward_grad_qk_kernel< scalar_t, vec_t, From a7d1eac21fe5a815c482bc60debcad23c2af4a98 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 20 Apr 2022 11:34:27 -0700 Subject: [PATCH 33/45] Delete old implementation --- .../attention/csrc/cuda/attention.cu | 134 +----------------- 1 file changed, 2 insertions(+), 132 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 795d130c37..5ba02baa83 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -615,7 +615,7 @@ template < int kBlockSizeK, int TILE_SIZEQ, int TILE_SIZEK> -__global__ void 
attention_backward_grad_v2_kernel( +__global__ void attention_backward_grad_v_kernel( at::PackedTensorAccessor grad_v, at::PackedTensorAccessor grad_out, at::PackedTensorAccessor query, @@ -646,8 +646,6 @@ __global__ void attention_backward_grad_v2_kernel( __shared__ scalar_t fact[TILE_SIZEQ][TILE_SIZEK + 1]; - //__shared__ scalar_t tmp_sum_shared[kBlockSizeQ][TILE_SIZEK]; - auto qb = reinterpret_cast(query[batch_idx][query_idx].data()); auto kb = reinterpret_cast(key[batch_idx][l].data()); auto vb = reinterpret_cast(value[batch_idx][l].data()); @@ -698,9 +696,6 @@ __global__ void attention_backward_grad_v2_kernel( tmp_sum[q_item_idx] += attn_v[q_item_idx][k_item_idx] * grad_attn_v[q_item_idx][k_item_idx]; } } - //for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - // tmp_sum_shared[kBlockSizeQ * threadIdx.x + q_item_idx][threadIdx.y] = tmp_sum[q_item_idx]; - //} __syncthreads(); for (int64_t k = threadIdx.x; k < K / kVecSize; k += blockDim.x) { @@ -722,122 +717,11 @@ __global__ void attention_backward_grad_v2_kernel( &grad_v[batch_idx][l + k_item_idx][k * kVecSize], res[k_item_idx]); } } - /* - if (threadIdx.y == 0) { - for (int i = threadIdx.x; i < TILE_SIZEQ; i += blockDim.x) { - scalar_t tmp = 0; - for (int k = 0; k < blockDim.y; k++) - { - tmp += tmp_sum_shared[i][k]; - } - myGpuAtomicAdd(&tmp_sum_i[batch_idx][query_idx + i - kBlockSizeQ * threadIdx.x], tmp); - } - } - */ for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { myGpuAtomicAdd(&tmp_sum_i[batch_idx][query_idx + q_item_idx], tmp_sum[q_item_idx]); } } -template -__global__ void attention_backward_grad_v_kernel( - at::PackedTensorAccessor grad_v, - at::PackedTensorAccessor grad_out, - at::PackedTensorAccessor query, - at::PackedTensorAccessor key, - at::PackedTensorAccessor value, - at::PackedTensorAccessor tmp_sum_i, - at::PackedTensorAccessor logsumexp_normalizer) { - int64_t K = query.size(2); - int64_t B = query.size(0); - int64_t M = query.size(1); - int64_t N = key.size(1); - - constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); - - int64_t batch_idx = blockIdx.y; - int64_t query_idx = - blockIdx.x * blockDim.y * kBlockSizeQ + threadIdx.y * kBlockSizeQ; - - if (query_idx >= M) - return; - - vec_t* query_block[kBlockSizeQ]; - vec_t* grad_out_block[kBlockSizeQ]; - scalar_t normalizer[kBlockSizeQ]; - - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - int64_t index = query_idx + q_item_idx; - index = index >= M ? 
M - 1 : index; - query_block[q_item_idx] = - reinterpret_cast(query[batch_idx][index].data()); - grad_out_block[q_item_idx] = - reinterpret_cast(grad_out[batch_idx][index].data()); - normalizer[q_item_idx] = logsumexp_normalizer[batch_idx][index]; - } - - scalar_t tmp_sum[kBlockSizeQ] = {0}; - for (int64_t l = threadIdx.x * kBlockSizeK; l < N; - l += blockDim.x * kBlockSizeK) { - auto key_j = reinterpret_cast(key[batch_idx][l].data()); - - scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; - compute_dot( - query_block, key_j, attn_v, K); - -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { -#pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - attn_v[q_item_idx][k_item_idx] = - std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]); - } - } - - // first compute the gradient for the self-attention - // after softmax - scalar_t grad_attn_v[kBlockSizeQ][kBlockSizeK] = {0}; - auto value_j = reinterpret_cast(value[batch_idx][l].data()); - - for (int64_t k = 0; k < K / kVecSize; k++) { - vec_t temp_i[kBlockSizeQ]; -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - temp_i[q_item_idx] = __ldg(grad_out_block[q_item_idx] + k); - } - -#pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - vec_t v = value_j[k + K / kVecSize * k_item_idx]; - vec_t tt = {0}; -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - sputnik::VectorCompute::Dot( - temp_i[q_item_idx], v, &grad_attn_v[q_item_idx][k_item_idx]); - sputnik::VectorCompute::FMA( - attn_v[q_item_idx][k_item_idx], temp_i[q_item_idx], &tt); - } - myGpuAtomicAdd(&grad_v[batch_idx][l + k_item_idx][k * kVecSize], tt); - } - } - - // those are temporaries for the gradient of the softmax -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { -#pragma unroll - for (int64_t k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - tmp_sum[q_item_idx] += attn_v[q_item_idx][k_item_idx] * - grad_attn_v[q_item_idx][k_item_idx]; - } - } - } -#pragma unroll - for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - tmp_sum[q_item_idx] = warpSum(tmp_sum[q_item_idx]); - tmp_sum_i[batch_idx][query_idx + q_item_idx] = tmp_sum[q_item_idx]; - } -} - template < typename scalar_t, typename vec_t, @@ -1034,27 +918,13 @@ std::tuple attention_backward( constexpr int64_t kBlockSizeQ = 4; constexpr int64_t kBlockSizeK = 8; - //dim3 grid(ceil_div(M, int64_t(TILE_SIZE)), B); - //dim3 block(32, TILE_SIZE / kBlockSizeQ); - dim3 grid( ceil_div(M, int64_t(TILE_SIZEQ)), ceil_div(N, int64_t(TILE_SIZEK)), B); dim3 block(TILE_SIZEQ / kBlockSizeQ, TILE_SIZEK / kBlockSizeK); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - /* - attention_backward_grad_v_kernel - <<>>( - grad_v.packed_accessor(), - grad_out.packed_accessor(), - query.packed_accessor(), - key.packed_accessor(), - value.packed_accessor(), - tmp_sum_i.packed_accessor(), - logsumexp.packed_accessor());*/ - - attention_backward_grad_v2_kernel< + attention_backward_grad_v_kernel< scalar_t, vec_t, kBlockSizeQ, From 99a6418e221ddb92925edac0ccda45e78061bf12 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 20 Apr 2022 23:40:53 -0700 Subject: [PATCH 34/45] Centralize all input accesses in the beginning This will make it easier to support inputs which are not multiple of 32. 
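For illustration, a minimal sketch of the pattern (assumed, simplified names -- not the
actual kernel code): resolve every row pointer a thread will touch once in the kernel
preamble, so bounds handling for ragged sizes can later live in one place and the hot
loops over K only dereference precomputed pointers.

    // Sketch: gather the row pointers for the kRows rows owned by this thread.
    // Clamping/masking of rows past the end of the tensor would also go here,
    // once, instead of inside the inner loops.
    template <typename vec_t, typename accessor_t, int kRows>
    __device__ void gather_row_pointers(
        accessor_t rows,      // e.g. query[batch_idx], shape [M][K]
        int64_t first_row,    // first row owned by this thread
        vec_t* ptrs[kRows]) {
      for (int i = 0; i < kRows; i++) {
        ptrs[i] = reinterpret_cast<vec_t*>(rows[first_row + i].data());
      }
    }
    // The inner loops then just do __ldg(ptrs[q_item_idx] + k), with no index math.
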
Plus, this seems to give a small performance improvement, in the order of 1% --- .../attention/csrc/cuda/attention.cu | 70 +++++++++++++------ 1 file changed, 49 insertions(+), 21 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 5ba02baa83..da6acf215c 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -646,32 +646,43 @@ __global__ void attention_backward_grad_v_kernel( __shared__ scalar_t fact[TILE_SIZEQ][TILE_SIZEK + 1]; - auto qb = reinterpret_cast(query[batch_idx][query_idx].data()); - auto kb = reinterpret_cast(key[batch_idx][l].data()); - auto vb = reinterpret_cast(value[batch_idx][l].data()); - auto gb = reinterpret_cast(grad_out[batch_idx][query_idx].data()); + vec_t* qb[kBlockSizeQ], *kb[kBlockSizeK], *vb[kBlockSizeK], *gb[kBlockSizeQ], *gbb[TILE_SIZEQ]; + + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + kb[k_item_idx] = reinterpret_cast(key[batch_idx][l + k_item_idx].data()); + vb[k_item_idx] = reinterpret_cast(value[batch_idx][l + k_item_idx].data()); + } + + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + qb[q_item_idx] = reinterpret_cast(query[batch_idx][query_idx + q_item_idx].data()); + gb[q_item_idx] = reinterpret_cast(grad_out[batch_idx][query_idx + q_item_idx].data()); + + } + for (int64_t i = 0; i < TILE_SIZEQ; i++) { + int64_t idx = i - kBlockSizeQ * threadIdx.x; + gbb[i] = reinterpret_cast(grad_out[batch_idx][query_idx + idx].data()); + } for (int i = 0; i < kBlockSizeQ; i++) { normalizer[i] = logsumexp_normalizer[batch_idx][query_idx + i]; } - auto key_j = reinterpret_cast(key[batch_idx][l].data()); scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; scalar_t grad_attn_v[kBlockSizeQ][kBlockSizeK] = {0}; for (int64_t k = 0; k < K / kVecSize; k += 1) { #pragma unroll for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - vec_t kk = __ldg(kb + k + K / kVecSize * k_item_idx); - vec_t tt = __ldg(vb + k + K / kVecSize * k_item_idx); + vec_t kk = __ldg(kb[k_item_idx] + k); + vec_t tt = __ldg(vb[k_item_idx] + k); #pragma unroll for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { sputnik::VectorCompute::Dot( - __ldg(qb + k + K / kVecSize * q_item_idx), + __ldg(qb[q_item_idx] + k), kk, &attn_v[q_item_idx][k_item_idx]); sputnik::VectorCompute::Dot( - __ldg(gb + k + K / kVecSize * q_item_idx), + __ldg(gb[q_item_idx] + k), tt, &grad_attn_v[q_item_idx][k_item_idx]); } @@ -702,7 +713,7 @@ __global__ void attention_backward_grad_v_kernel( vec_t res[kBlockSizeK] = {0}; #pragma unroll for (int64_t i = 0; i < TILE_SIZEQ; i++) { - vec_t kk = __ldg(gb + k + K / kVecSize * (i - kBlockSizeQ * threadIdx.x)); + vec_t kk = __ldg(gbb[i] + k); #pragma unroll for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { sputnik::VectorCompute::FMA( @@ -761,33 +772,50 @@ __global__ void attention_backward_grad_qk_kernel( __shared__ scalar_t fact[TILE_SIZEQ][TILE_SIZEK + 1]; - auto qb = reinterpret_cast(query[batch_idx][query_idx].data()); - auto kb = reinterpret_cast(key[batch_idx][l].data()); - auto vb = reinterpret_cast(value[batch_idx][l].data()); - auto gb = reinterpret_cast(grad_out[batch_idx][query_idx].data()); + + vec_t* qb[kBlockSizeQ], *kb[kBlockSizeK], *vb[kBlockSizeK], *gb[kBlockSizeQ], *qbb[TILE_SIZEQ], *kbb[TILE_SIZEK]; + + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + kb[k_item_idx] = reinterpret_cast(key[batch_idx][l + 
k_item_idx].data()); + vb[k_item_idx] = reinterpret_cast(value[batch_idx][l + k_item_idx].data()); + } + + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + qb[q_item_idx] = reinterpret_cast(query[batch_idx][query_idx + q_item_idx].data()); + gb[q_item_idx] = reinterpret_cast(grad_out[batch_idx][query_idx + q_item_idx].data()); + + } + for (int64_t i = 0; i < TILE_SIZEQ; i++) { + int64_t idx = i - kBlockSizeQ * threadIdx.x; + qbb[i] = reinterpret_cast(query[batch_idx][query_idx + idx].data()); + } + + for (int64_t i = 0; i < TILE_SIZEK; i++) { + int64_t idx = i - kBlockSizeK * threadIdx.y; + kbb[i] = reinterpret_cast(key[batch_idx][l + idx].data()); + } for (int i = 0; i < kBlockSizeQ; i++) { normalizer[i] = logsumexp_normalizer[batch_idx][query_idx + i]; tmp_sum[i] = tmp_sum_i[batch_idx][query_idx + i]; } - auto key_j = reinterpret_cast(key[batch_idx][l].data()); scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; scalar_t grad_attn_v[kBlockSizeQ][kBlockSizeK] = {0}; for (int64_t k = 0; k < K / kVecSize; k += 1) { #pragma unroll for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - vec_t kk = __ldg(kb + k + K / kVecSize * k_item_idx); - vec_t tt = __ldg(vb + k + K / kVecSize * k_item_idx); + vec_t kk = __ldg(kb[k_item_idx] + k); + vec_t tt = __ldg(vb[k_item_idx] + k); #pragma unroll for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { sputnik::VectorCompute::Dot( - __ldg(qb + k + K / kVecSize * q_item_idx), + __ldg(qb[q_item_idx] + k), kk, &attn_v[q_item_idx][k_item_idx]); sputnik::VectorCompute::Dot( - __ldg(gb + k + K / kVecSize * q_item_idx), + __ldg(gb[q_item_idx] + k), tt, &grad_attn_v[q_item_idx][k_item_idx]); } @@ -819,7 +847,7 @@ __global__ void attention_backward_grad_qk_kernel( vec_t res[kBlockSizeQ] = {0}; #pragma unroll for (int64_t i = 0; i < TILE_SIZEK; i++) { - vec_t kk = __ldg(kb + k + K / kVecSize * (i - kBlockSizeK * threadIdx.y)); + vec_t kk = __ldg(kbb[i] + k); #pragma unroll for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { sputnik::VectorCompute::FMA( @@ -840,7 +868,7 @@ __global__ void attention_backward_grad_qk_kernel( vec_t res[kBlockSizeK] = {0}; #pragma unroll for (int64_t i = 0; i < TILE_SIZEQ; i++) { - vec_t kk = __ldg(qb + k + K / kVecSize * (i - kBlockSizeQ * threadIdx.x)); + vec_t kk = __ldg(qbb[i] + k); #pragma unroll for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { sputnik::VectorCompute::FMA( From 5bf4431d59d8813e08a27592e0cfb33d874e067b Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 21 Apr 2022 08:58:29 -0700 Subject: [PATCH 35/45] Bugfix Only shows up for certain sizes for some reason --- xformers/components/attention/csrc/cuda/attention.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index da6acf215c..dee5a42699 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -933,7 +933,7 @@ std::tuple attention_backward( at::Tensor grad_k = at::zeros_like(key); at::Tensor grad_v = at::zeros_like(value); - at::Tensor tmp_sum_i = at::empty({B, M}, query.options()); + at::Tensor tmp_sum_i = at::zeros({B, M}, query.options()); using scalar_t = float; using vec_t = float4; From 222c13664d02cab9acccdc964ad56edb6a9d8006 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 22 Apr 2022 02:09:53 -0700 Subject: [PATCH 36/45] Make kernels generic wrt sequence length This 
introduces a slowdown of 25%, mostly due to the index computation in the preamble of each kernel. In a next commit I'll try to optimize this out --- .../attention/csrc/cuda/attention.cu | 108 ++++++++++++------ 1 file changed, 71 insertions(+), 37 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index dee5a42699..2bcdd52c94 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -635,36 +635,49 @@ __global__ void attention_backward_grad_v_kernel( blockIdx.x * blockDim.x * kBlockSizeQ + threadIdx.x * kBlockSizeQ; int64_t l = blockIdx.y * blockDim.y * kBlockSizeK + threadIdx.y * kBlockSizeK; - if (query_idx >= M) - return; + __shared__ scalar_t fact[TILE_SIZEQ][TILE_SIZEK + 1]; - if (l >= N) - return; +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { +#pragma unroll + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + fact[kBlockSizeQ * threadIdx.x + q_item_idx] + [kBlockSizeK * threadIdx.y + k_item_idx] = 0; + } + } scalar_t normalizer[kBlockSizeQ]; scalar_t tmp_sum[kBlockSizeQ] = {0}; - __shared__ scalar_t fact[TILE_SIZEQ][TILE_SIZEK + 1]; - vec_t* qb[kBlockSizeQ], *kb[kBlockSizeK], *vb[kBlockSizeK], *gb[kBlockSizeQ], *gbb[TILE_SIZEQ]; + scalar_t maskQ[kBlockSizeQ], maskK[kBlockSizeK]; for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - kb[k_item_idx] = reinterpret_cast(key[batch_idx][l + k_item_idx].data()); - vb[k_item_idx] = reinterpret_cast(value[batch_idx][l + k_item_idx].data()); + int64_t index = l + k_item_idx; + maskK[k_item_idx] = index >= N ? scalar_t(0) : scalar_t(1); + index = index >= N ? N - 1 : index; + kb[k_item_idx] = reinterpret_cast(key[batch_idx][index].data()); + vb[k_item_idx] = reinterpret_cast(value[batch_idx][index].data()); } for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - qb[q_item_idx] = reinterpret_cast(query[batch_idx][query_idx + q_item_idx].data()); - gb[q_item_idx] = reinterpret_cast(grad_out[batch_idx][query_idx + q_item_idx].data()); - + int64_t index = query_idx + q_item_idx; + maskQ[q_item_idx] = index >= M ? scalar_t(0) : scalar_t(1); + index = index >= M ? M - 1 : index; + qb[q_item_idx] = reinterpret_cast(query[batch_idx][index].data()); + gb[q_item_idx] = reinterpret_cast(grad_out[batch_idx][index].data()); } + for (int64_t i = 0; i < TILE_SIZEQ; i++) { - int64_t idx = i - kBlockSizeQ * threadIdx.x; - gbb[i] = reinterpret_cast(grad_out[batch_idx][query_idx + idx].data()); + int64_t index = query_idx + i - kBlockSizeQ * threadIdx.x; + index = index >= M ? M - 1 : index; + gbb[i] = reinterpret_cast(grad_out[batch_idx][index].data()); } for (int i = 0; i < kBlockSizeQ; i++) { - normalizer[i] = logsumexp_normalizer[batch_idx][query_idx + i]; + int64_t index = query_idx + i; + index = index >= M ? 
M - 1 : index; + normalizer[i] = logsumexp_normalizer[batch_idx][index]; } scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; @@ -693,7 +706,7 @@ __global__ void attention_backward_grad_v_kernel( #pragma unroll for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { attn_v[q_item_idx][k_item_idx] = - std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]); + std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]) * maskQ[q_item_idx] * maskK[k_item_idx]; } } @@ -724,12 +737,16 @@ __global__ void attention_backward_grad_v_kernel( } #pragma unroll for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + int64_t index = l + k_item_idx; + index = index >= N ? N - 1 : index; myGpuAtomicAdd( - &grad_v[batch_idx][l + k_item_idx][k * kVecSize], res[k_item_idx]); + &grad_v[batch_idx][index][k * kVecSize], res[k_item_idx]); } } for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - myGpuAtomicAdd(&tmp_sum_i[batch_idx][query_idx + q_item_idx], tmp_sum[q_item_idx]); + int64_t index = query_idx + q_item_idx; + index = index >= M ? M - 1 : index; + myGpuAtomicAdd(&tmp_sum_i[batch_idx][index], tmp_sum[q_item_idx]); } } @@ -761,43 +778,56 @@ __global__ void attention_backward_grad_qk_kernel( blockIdx.x * blockDim.x * kBlockSizeQ + threadIdx.x * kBlockSizeQ; int64_t l = blockIdx.y * blockDim.y * kBlockSizeK + threadIdx.y * kBlockSizeK; - if (query_idx >= M) - return; + __shared__ scalar_t fact[TILE_SIZEQ][TILE_SIZEK + 1]; - if (l >= N) - return; +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { +#pragma unroll + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + fact[kBlockSizeQ * threadIdx.x + q_item_idx] + [kBlockSizeK * threadIdx.y + k_item_idx] = 0; + } + } scalar_t normalizer[kBlockSizeQ]; scalar_t tmp_sum[kBlockSizeQ]; - __shared__ scalar_t fact[TILE_SIZEQ][TILE_SIZEK + 1]; - - vec_t* qb[kBlockSizeQ], *kb[kBlockSizeK], *vb[kBlockSizeK], *gb[kBlockSizeQ], *qbb[TILE_SIZEQ], *kbb[TILE_SIZEK]; + scalar_t maskQ[kBlockSizeQ], maskK[kBlockSizeK]; for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { - kb[k_item_idx] = reinterpret_cast(key[batch_idx][l + k_item_idx].data()); - vb[k_item_idx] = reinterpret_cast(value[batch_idx][l + k_item_idx].data()); + int64_t index = l + k_item_idx; + maskK[k_item_idx] = index >= N ? scalar_t(0) : scalar_t(1); + index = index >= N ? N - 1 : index; + kb[k_item_idx] = reinterpret_cast(key[batch_idx][index].data()); + vb[k_item_idx] = reinterpret_cast(value[batch_idx][index].data()); } for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - qb[q_item_idx] = reinterpret_cast(query[batch_idx][query_idx + q_item_idx].data()); - gb[q_item_idx] = reinterpret_cast(grad_out[batch_idx][query_idx + q_item_idx].data()); + int64_t index = query_idx + q_item_idx; + maskQ[q_item_idx] = index >= M ? scalar_t(0) : scalar_t(1); + index = index >= M ? M - 1 : index; + qb[q_item_idx] = reinterpret_cast(query[batch_idx][index].data()); + gb[q_item_idx] = reinterpret_cast(grad_out[batch_idx][index].data()); } for (int64_t i = 0; i < TILE_SIZEQ; i++) { - int64_t idx = i - kBlockSizeQ * threadIdx.x; - qbb[i] = reinterpret_cast(query[batch_idx][query_idx + idx].data()); + int64_t index = query_idx + i - kBlockSizeQ * threadIdx.x; + index = index >= M ? 
M - 1 : index; + qbb[i] = reinterpret_cast(query[batch_idx][index].data()); } for (int64_t i = 0; i < TILE_SIZEK; i++) { - int64_t idx = i - kBlockSizeK * threadIdx.y; - kbb[i] = reinterpret_cast(key[batch_idx][l + idx].data()); + int64_t index = l + i - kBlockSizeK * threadIdx.y; + index = index >= N ? N - 1 : index; + kbb[i] = reinterpret_cast(key[batch_idx][index].data()); } for (int i = 0; i < kBlockSizeQ; i++) { - normalizer[i] = logsumexp_normalizer[batch_idx][query_idx + i]; - tmp_sum[i] = tmp_sum_i[batch_idx][query_idx + i]; + int64_t index = query_idx + i; + index = index >= M ? M - 1 : index; + normalizer[i] = logsumexp_normalizer[batch_idx][index]; + tmp_sum[i] = tmp_sum_i[batch_idx][index]; } scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; @@ -826,7 +856,7 @@ __global__ void attention_backward_grad_qk_kernel( #pragma unroll for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { attn_v[q_item_idx][k_item_idx] = - std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]); + std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]) * maskQ[q_item_idx] * maskK[k_item_idx]; } } @@ -858,8 +888,10 @@ __global__ void attention_backward_grad_qk_kernel( } #pragma unroll for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + int64_t index = query_idx + q_item_idx; + index = index >= M ? M - 1 : index; myGpuAtomicAdd( - &grad_q[batch_idx][query_idx + q_item_idx][k * kVecSize], + &grad_q[batch_idx][index][k * kVecSize], res[q_item_idx]); } } @@ -879,8 +911,10 @@ __global__ void attention_backward_grad_qk_kernel( } #pragma unroll for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + int64_t index = l + k_item_idx; + index = index >= N ? N - 1 : index; myGpuAtomicAdd( - &grad_k[batch_idx][l + k_item_idx][k * kVecSize], res[k_item_idx]); + &grad_k[batch_idx][index][k * kVecSize], res[k_item_idx]); } } } From 011b2cd55c9d243cc4d312ba918e533f2692720b Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 22 Apr 2022 05:20:48 -0700 Subject: [PATCH 37/45] Add template argument to skip bound checking Brings back speed to where it was, for the cases where we can safely skip this --- .../attention/csrc/cuda/attention.cu | 49 ++++++++++++------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 2bcdd52c94..6f18c4381b 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -614,7 +614,8 @@ template < int kBlockSizeQ, int kBlockSizeK, int TILE_SIZEQ, - int TILE_SIZEK> + int TILE_SIZEK, + bool check_bounds> __global__ void attention_backward_grad_v_kernel( at::PackedTensorAccessor grad_v, at::PackedTensorAccessor grad_out, @@ -655,7 +656,8 @@ __global__ void attention_backward_grad_v_kernel( for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { int64_t index = l + k_item_idx; maskK[k_item_idx] = index >= N ? scalar_t(0) : scalar_t(1); - index = index >= N ? N - 1 : index; + if (check_bounds) + index = min(index, N - 1); kb[k_item_idx] = reinterpret_cast(key[batch_idx][index].data()); vb[k_item_idx] = reinterpret_cast(value[batch_idx][index].data()); } @@ -663,20 +665,23 @@ __global__ void attention_backward_grad_v_kernel( for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { int64_t index = query_idx + q_item_idx; maskQ[q_item_idx] = index >= M ? scalar_t(0) : scalar_t(1); - index = index >= M ? 
M - 1 : index; + if (check_bounds) + index = min(index, M - 1); qb[q_item_idx] = reinterpret_cast(query[batch_idx][index].data()); gb[q_item_idx] = reinterpret_cast(grad_out[batch_idx][index].data()); } for (int64_t i = 0; i < TILE_SIZEQ; i++) { int64_t index = query_idx + i - kBlockSizeQ * threadIdx.x; - index = index >= M ? M - 1 : index; + if (check_bounds) + index = min(index, M - 1); gbb[i] = reinterpret_cast(grad_out[batch_idx][index].data()); } for (int i = 0; i < kBlockSizeQ; i++) { int64_t index = query_idx + i; - index = index >= M ? M - 1 : index; + if (check_bounds) + index = min(index, M - 1); normalizer[i] = logsumexp_normalizer[batch_idx][index]; } @@ -738,14 +743,16 @@ __global__ void attention_backward_grad_v_kernel( #pragma unroll for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { int64_t index = l + k_item_idx; - index = index >= N ? N - 1 : index; + if (check_bounds) + index = min(index, N - 1); myGpuAtomicAdd( &grad_v[batch_idx][index][k * kVecSize], res[k_item_idx]); } } for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { int64_t index = query_idx + q_item_idx; - index = index >= M ? M - 1 : index; + if (check_bounds) + index = min(index, M - 1); myGpuAtomicAdd(&tmp_sum_i[batch_idx][index], tmp_sum[q_item_idx]); } } @@ -756,7 +763,8 @@ template < int kBlockSizeQ, int kBlockSizeK, int TILE_SIZEQ, - int TILE_SIZEK> + int TILE_SIZEK, + bool check_bounds> __global__ void attention_backward_grad_qk_kernel( at::PackedTensorAccessor grad_q, at::PackedTensorAccessor grad_k, @@ -798,7 +806,8 @@ __global__ void attention_backward_grad_qk_kernel( for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { int64_t index = l + k_item_idx; maskK[k_item_idx] = index >= N ? scalar_t(0) : scalar_t(1); - index = index >= N ? N - 1 : index; + if (check_bounds) + index = min(index, N - 1); kb[k_item_idx] = reinterpret_cast(key[batch_idx][index].data()); vb[k_item_idx] = reinterpret_cast(value[batch_idx][index].data()); } @@ -806,26 +815,30 @@ __global__ void attention_backward_grad_qk_kernel( for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { int64_t index = query_idx + q_item_idx; maskQ[q_item_idx] = index >= M ? scalar_t(0) : scalar_t(1); - index = index >= M ? M - 1 : index; + if (check_bounds) + index = min(index, M - 1); qb[q_item_idx] = reinterpret_cast(query[batch_idx][index].data()); gb[q_item_idx] = reinterpret_cast(grad_out[batch_idx][index].data()); } for (int64_t i = 0; i < TILE_SIZEQ; i++) { int64_t index = query_idx + i - kBlockSizeQ * threadIdx.x; - index = index >= M ? M - 1 : index; + if (check_bounds) + index = min(index, M - 1); qbb[i] = reinterpret_cast(query[batch_idx][index].data()); } for (int64_t i = 0; i < TILE_SIZEK; i++) { int64_t index = l + i - kBlockSizeK * threadIdx.y; - index = index >= N ? N - 1 : index; + if (check_bounds) + index = min(index, N - 1); kbb[i] = reinterpret_cast(key[batch_idx][index].data()); } for (int i = 0; i < kBlockSizeQ; i++) { int64_t index = query_idx + i; - index = index >= M ? M - 1 : index; + if (check_bounds) + index = min(index, M - 1); normalizer[i] = logsumexp_normalizer[batch_idx][index]; tmp_sum[i] = tmp_sum_i[batch_idx][index]; } @@ -889,7 +902,8 @@ __global__ void attention_backward_grad_qk_kernel( #pragma unroll for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { int64_t index = query_idx + q_item_idx; - index = index >= M ? 
M - 1 : index; + if (check_bounds) + index = min(index, M - 1); myGpuAtomicAdd( &grad_q[batch_idx][index][k * kVecSize], res[q_item_idx]); @@ -912,7 +926,8 @@ __global__ void attention_backward_grad_qk_kernel( #pragma unroll for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { int64_t index = l + k_item_idx; - index = index >= N ? N - 1 : index; + if (check_bounds) + index = min(index, N - 1); myGpuAtomicAdd( &grad_k[batch_idx][index][k * kVecSize], res[k_item_idx]); } @@ -992,7 +1007,7 @@ std::tuple attention_backward( kBlockSizeQ, kBlockSizeK, TILE_SIZEQ, - TILE_SIZEK><<>>( + TILE_SIZEK, true><<>>( grad_v.packed_accessor(), grad_out.packed_accessor(), query.packed_accessor(), @@ -1018,7 +1033,7 @@ std::tuple attention_backward( kBlockSizeQ2, kBlockSizeK2, TILE_SIZEQ2, - TILE_SIZEK2><<>>( + TILE_SIZEK2, true><<>>( grad_q.packed_accessor(), grad_k.packed_accessor(), grad_out.packed_accessor(), From 2552da098c19255991d371488b9ea785080e7917 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 22 Apr 2022 05:46:14 -0700 Subject: [PATCH 38/45] Make it support all use-cases --- .../attention/csrc/cuda/attention.cu | 193 +++++++++++++----- 1 file changed, 137 insertions(+), 56 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 6f18c4381b..bed355628c 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -934,6 +934,114 @@ __global__ void attention_backward_grad_qk_kernel( } } +template +void launch_attention_backward( + at::Tensor& grad_q, + at::Tensor& grad_k, + at::Tensor& grad_v, + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& logsumexp, + at::Tensor& tmp_sum_i +) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + + constexpr int TILE_SIZEQ = 32; + constexpr int TILE_SIZEK = 32; + + constexpr int64_t kBlockSizeQ = 4; + constexpr int64_t kBlockSizeK = 8; + + dim3 grid( + ceil_div(M, int64_t(TILE_SIZEQ)), ceil_div(N, int64_t(TILE_SIZEK)), B); + dim3 block(TILE_SIZEQ / kBlockSizeQ, TILE_SIZEK / kBlockSizeK); + + constexpr int TILE_SIZEQ2 = 32; + constexpr int TILE_SIZEK2 = 32; + + constexpr int64_t kBlockSizeQ2 = 4; + constexpr int64_t kBlockSizeK2 = 4; + + dim3 grid2( + ceil_div(M, int64_t(TILE_SIZEQ2)), ceil_div(N, int64_t(TILE_SIZEK2)), B); + dim3 block2(TILE_SIZEQ2 / kBlockSizeQ2, TILE_SIZEK2 / kBlockSizeK2); + + // the bounds checking in device code is very expensive, making the code + // around 25% slower. So let's skip those checks if possible. 
+ if ((M % TILE_SIZEQ == 0) && (N % TILE_SIZEK == 0)) { + attention_backward_grad_v_kernel< + scalar_t, + vec_t, + kBlockSizeQ, + kBlockSizeK, + TILE_SIZEQ, + TILE_SIZEK, false><<>>( + grad_v.packed_accessor(), + grad_out.packed_accessor(), + query.packed_accessor(), + key.packed_accessor(), + value.packed_accessor(), + tmp_sum_i.packed_accessor(), + logsumexp.packed_accessor()); + } else { + attention_backward_grad_v_kernel< + scalar_t, + vec_t, + kBlockSizeQ, + kBlockSizeK, + TILE_SIZEQ, + TILE_SIZEK, true><<>>( + grad_v.packed_accessor(), + grad_out.packed_accessor(), + query.packed_accessor(), + key.packed_accessor(), + value.packed_accessor(), + tmp_sum_i.packed_accessor(), + logsumexp.packed_accessor()); + } + + if ((M % TILE_SIZEQ2 == 0) && (N % TILE_SIZEK2 == 0)) { + attention_backward_grad_qk_kernel< + scalar_t, + vec_t, + kBlockSizeQ2, + kBlockSizeK2, + TILE_SIZEQ2, + TILE_SIZEK2, false><<>>( + grad_q.packed_accessor(), + grad_k.packed_accessor(), + grad_out.packed_accessor(), + query.packed_accessor(), + key.packed_accessor(), + value.packed_accessor(), + tmp_sum_i.packed_accessor(), + logsumexp.packed_accessor()); + } else { + attention_backward_grad_qk_kernel< + scalar_t, + vec_t, + kBlockSizeQ2, + kBlockSizeK2, + TILE_SIZEQ2, + TILE_SIZEK2, true><<>>( + grad_q.packed_accessor(), + grad_k.packed_accessor(), + grad_out.packed_accessor(), + query.packed_accessor(), + key.packed_accessor(), + value.packed_accessor(), + tmp_sum_i.packed_accessor(), + logsumexp.packed_accessor()); + } +} + + std::tuple attention_backward( const at::Tensor& grad_out, const at::Tensor& query, @@ -971,6 +1079,11 @@ std::tuple attention_backward( TORCH_CHECK(!value.is_sparse(), "value must be a dense tensor"); TORCH_CHECK(!grad_out.is_sparse(), "grad_out must be a dense tensor"); + // TODO: support other dtypes in the future + TORCH_CHECK( + query.scalar_type() == at::ScalarType::Float, + "Only float32 type is supported for now"); + at::cuda::CUDAGuard device_guard(query.device()); int64_t B = query.size(0); @@ -984,64 +1097,32 @@ std::tuple attention_backward( at::Tensor tmp_sum_i = at::zeros({B, M}, query.options()); - using scalar_t = float; - using vec_t = float4; + //using scalar_t = float; + //using vec_t = float4; // using vec_t = float; - constexpr int TILE_SIZEQ = 32; - constexpr int TILE_SIZEK = 32; - constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); - - constexpr int64_t kBlockSizeQ = 4; - constexpr int64_t kBlockSizeK = 8; - - dim3 grid( - ceil_div(M, int64_t(TILE_SIZEQ)), ceil_div(N, int64_t(TILE_SIZEK)), B); - dim3 block(TILE_SIZEQ / kBlockSizeQ, TILE_SIZEK / kBlockSizeK); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - attention_backward_grad_v_kernel< - scalar_t, - vec_t, - kBlockSizeQ, - kBlockSizeK, - TILE_SIZEQ, - TILE_SIZEK, true><<>>( - grad_v.packed_accessor(), - grad_out.packed_accessor(), - query.packed_accessor(), - key.packed_accessor(), - value.packed_accessor(), - tmp_sum_i.packed_accessor(), - logsumexp.packed_accessor()); - - constexpr int TILE_SIZEQ2 = 32; - constexpr int TILE_SIZEK2 = 32; - - constexpr int64_t kBlockSizeQ2 = 4; - constexpr int64_t kBlockSizeK2 = 4; - - dim3 grid2( - ceil_div(M, int64_t(TILE_SIZEQ2)), ceil_div(N, int64_t(TILE_SIZEK2)), B); - dim3 block2(TILE_SIZEQ2 / kBlockSizeQ2, TILE_SIZEK2 / kBlockSizeK2); - // TODO: try adding a blockDim.x to iterate over k - - attention_backward_grad_qk_kernel< - scalar_t, - vec_t, - kBlockSizeQ2, - kBlockSizeK2, - TILE_SIZEQ2, - TILE_SIZEK2, true><<>>( - grad_q.packed_accessor(), - 
grad_k.packed_accessor(), - grad_out.packed_accessor(), - query.packed_accessor(), - key.packed_accessor(), - value.packed_accessor(), - tmp_sum_i.packed_accessor(), - logsumexp.packed_accessor()); + if ((K % 4) == 0) { + launch_attention_backward(grad_q, grad_k, grad_v, + grad_out, + query, + key, + value, + logsumexp, tmp_sum_i); + } else if ((K % 2) == 0) { + launch_attention_backward(grad_q, grad_k, grad_v, + grad_out, + query, + key, + value, + logsumexp, tmp_sum_i); + } else { + launch_attention_backward(grad_q, grad_k, grad_v, + grad_out, + query, + key, + value, + logsumexp, tmp_sum_i); + } AT_CUDA_CHECK(cudaGetLastError()); From a8c1f4b5e9ac43e7557f8dba8b5a2f3b36ee027c Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 22 Apr 2022 09:08:36 -0700 Subject: [PATCH 39/45] Let logsumexp be returned by forward Also add an autograd Function for backward --- .../benchmarks/benchmark_mem_eff_attention.py | 5 +- .../components/attention/csrc/attention.cpp | 2 +- .../attention/csrc/cpu/attention.cpp | 21 +-- .../attention/csrc/cuda/attention.cu | 133 ++++++++++++------ xformers/ops.py | 26 +++- 5 files changed, 124 insertions(+), 63 deletions(-) diff --git a/xformers/benchmarks/benchmark_mem_eff_attention.py b/xformers/benchmarks/benchmark_mem_eff_attention.py index 88968c907d..72b4dbe0f3 100644 --- a/xformers/benchmarks/benchmark_mem_eff_attention.py +++ b/xformers/benchmarks/benchmark_mem_eff_attention.py @@ -26,6 +26,7 @@ def ref_attention(q, k, v): SHAPES = list( itertools.product([1, 8, 32, 256], [127, 128, 512, 513, 1023, 1024], [16, 32]) ) +SHAPES = [(256, 1024, 32)] results = [] mem_use: Dict[str, Dict[str, float]] = dict(optimized={}, vanilla={}) @@ -38,7 +39,7 @@ def ref_attention(q, k, v): q = torch.rand(shape, device=device) sub_label = f"B={B}, M={M}, K={K}" - if True: + if False: r = xformers.ops.memory_efficient_attention(q, q, q) rr = ref_attention(q, q, q) @@ -51,7 +52,7 @@ def ref_attention(q, k, v): stmt="fn(q, q, q)", globals={ "q": q, - "fn": torch.ops.xformers.efficient_attention, + "fn": xformers.ops.memory_efficient_attention, }, label="attention", description="optimized", diff --git a/xformers/components/attention/csrc/attention.cpp b/xformers/components/attention/csrc/attention.cpp index 3b0d6e71b0..eb35b95c04 100644 --- a/xformers/components/attention/csrc/attention.cpp +++ b/xformers/components/attention/csrc/attention.cpp @@ -2,7 +2,7 @@ TORCH_LIBRARY_FRAGMENT(xformers, m) { m.def(TORCH_SELECTIVE_SCHEMA( - "xformers::efficient_attention(Tensor query, Tensor key, Tensor value) -> Tensor")); + "xformers::efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_logsumexp) -> (Tensor, Tensor)")); m.def(TORCH_SELECTIVE_SCHEMA( "xformers::efficient_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp) -> (Tensor, Tensor, Tensor)")); } diff --git a/xformers/components/attention/csrc/cpu/attention.cpp b/xformers/components/attention/csrc/cpu/attention.cpp index dd5d4280e5..04b534a2d1 100644 --- a/xformers/components/attention/csrc/cpu/attention.cpp +++ b/xformers/components/attention/csrc/cpu/attention.cpp @@ -28,10 +28,12 @@ scalar_t max(scalar_t* buf) { template void attention_kernel( at::TensorAccessor output, + at::TensorAccessor logsumexp, at::TensorAccessor query, at::TensorAccessor key, at::TensorAccessor value, - at::TensorAccessor buffer //, + at::TensorAccessor buffer, + bool compute_logsumexp // at::TensorAccessor mask ) { // TODO: optimize the code by adding blocking @@ -90,15 +92,18 @@ void 
attention_kernel( for (int64_t k = 0; k < K; k++) { oo[k] = buf[k] / s_prime; } + if (compute_logsumexp) + logsumexp[i][j] = m_prime + std::log(s_prime); } } }); } -at::Tensor attention( +std::tuple attention( const at::Tensor& query, const at::Tensor& key, - const at::Tensor& value + const at::Tensor& value, + bool compute_logsumexp // const at::Tensor& mask ) { TORCH_CHECK(query.dim() == key.dim()); @@ -132,19 +137,22 @@ at::Tensor attention( int64_t K = query.size(2); at::Tensor res = at::empty({B, M, K}, query.options()); + at::Tensor logsumexp = at::empty({B, M}, query.options()); at::Tensor buffer = at::empty({at::get_num_threads(), 1, K}, query.options()); AT_DISPATCH_FLOATING_TYPES(query.scalar_type(), "attention_kernel", [&] { attention_kernel( res.accessor(), + logsumexp.accessor(), query.accessor(), key.accessor(), value.accessor(), - buffer.accessor()); + buffer.accessor(), + compute_logsumexp); }); - return res; + return std::make_tuple(res, logsumexp); } template @@ -281,9 +289,6 @@ std::tuple attention_backward( at::Tensor buffer2 = at::zeros({at::get_num_threads(), 1, N}, query.options()); - // TODO this should be an argument from the function - //at::Tensor logsumexp = query.bmm(key.transpose(-2, -1)).logsumexp(-1); - AT_DISPATCH_FLOATING_TYPES( query.scalar_type(), "attention_backward_kernel", [&] { attention_backward_kernel( diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index bed355628c..25f18a5c0e 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -375,12 +375,15 @@ template < int kBlockSizeK, int kBlockSizeQ, int WARP_SIZE, - int BUFFER_SIZE> + int BUFFER_SIZE, + bool compute_logsumexp> __global__ void attention_kernel( at::PackedTensorAccessor output, + at::PackedTensorAccessor logsumexp, at::PackedTensorAccessor query, at::PackedTensorAccessor key, - at::PackedTensorAccessor value) { + at::PackedTensorAccessor value + ) { constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); static_assert( integerIsPowerOf2(kBlockSizeK * WARP_SIZE), @@ -399,6 +402,7 @@ __global__ void attention_kernel( vec_t* query_block[kBlockSizeQ]; vec_t* output_block[kBlockSizeQ]; + scalar_t* logsumexp_block[kBlockSizeQ]; // TODO [BUFFER_SIZE limitation]: the current strategy assumes a // statically-known size for K. 
Ideally we would like to remove this // limitation in the future, so that any K is supported @@ -413,6 +417,7 @@ __global__ void attention_kernel( output_block[q_item_idx] = reinterpret_cast(output[batch_idx][index].data()); m_prime[q_item_idx] = -std::numeric_limits::infinity(); + logsumexp_block[q_item_idx] = &logsumexp[batch_idx][index]; } #if 0 // this for now makes things slower @@ -493,56 +498,30 @@ __global__ void attention_kernel( output_block[q_item_idx][k] = tmp; } } + + if (compute_logsumexp) { +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + *logsumexp_block[q_item_idx] = m_prime[q_item_idx] + std::log(s_prime[q_item_idx]); + } + } } -at::Tensor attention( +template +void launch_attention( + at::Tensor& res, + at::Tensor& logsumexp, const at::Tensor& query, const at::Tensor& key, const at::Tensor& value - // const at::Tensor& mask -) { - TORCH_CHECK(query.dim() == key.dim()); - TORCH_CHECK(query.dim() == value.dim()); - // TORCH_CHECK(query.dim() == mask.dim()); - TORCH_CHECK(query.dim() == 3); - TORCH_CHECK(query.size(2) == key.size(2)); - TORCH_CHECK(query.size(0) == key.size(0)); - - TORCH_CHECK(query.size(0) == value.size(0)); - TORCH_CHECK(key.size(1) == value.size(1)); - TORCH_CHECK( - query.size(2) == - value.size(2)); // TODO: drop this limitation in the future - - TORCH_CHECK(query.is_cuda(), "query must be a CUDA tensor"); - TORCH_CHECK(key.is_cuda(), "key must be a CUDA tensor"); - TORCH_CHECK(value.is_cuda(), "value must be a CUDA tensor"); - - TORCH_CHECK(!query.is_sparse(), "query must be a dense tensor"); - TORCH_CHECK(!key.is_sparse(), "key must be a dense tensor"); - TORCH_CHECK(!value.is_sparse(), "value must be a dense tensor"); - - // TODO drop this limitation in the future - TORCH_CHECK(query.is_contiguous()); - TORCH_CHECK(key.is_contiguous()); - TORCH_CHECK(value.is_contiguous()); - - // TODO: support other dtypes in the future - TORCH_CHECK( - query.scalar_type() == at::ScalarType::Float, - "Only float32 type is supported for now"); - - at::cuda::CUDAGuard device_guard(query.device()); + ) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int64_t B = query.size(0); int64_t M = query.size(1); int64_t N = key.size(1); int64_t K = query.size(2); - at::Tensor res = at::zeros({B, M, K}, query.options()); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - constexpr int WARP_SIZE = 4; constexpr int kBlockSizeK = 32; @@ -566,8 +545,10 @@ at::Tensor attention( kBlockSizeK, kBlockSizeQ, WARP_SIZE, - BUFFER_SIZE><<>>( + BUFFER_SIZE, + compute_logsumexp><<>>( res.packed_accessor(), + logsumexp.packed_accessor(), query.packed_accessor(), key.packed_accessor(), value.packed_accessor()); @@ -581,8 +562,10 @@ at::Tensor attention( kBlockSizeK, kBlockSizeQ, WARP_SIZE, - BUFFER_SIZE><<>>( + BUFFER_SIZE, + compute_logsumexp><<>>( res.packed_accessor(), + logsumexp.packed_accessor(), query.packed_accessor(), key.packed_accessor(), value.packed_accessor()); @@ -597,15 +580,75 @@ at::Tensor attention( kBlockSizeK, kBlockSizeQ, WARP_SIZE, - BUFFER_SIZE><<>>( + BUFFER_SIZE, + compute_logsumexp><<>>( res.packed_accessor(), + logsumexp.packed_accessor(), query.packed_accessor(), key.packed_accessor(), value.packed_accessor()); } +} + +std::tuple attention( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + bool compute_logsumexp + // const at::Tensor& mask +) { + TORCH_CHECK(query.dim() == key.dim()); + TORCH_CHECK(query.dim() == value.dim()); + // TORCH_CHECK(query.dim() == mask.dim()); 
+ TORCH_CHECK(query.dim() == 3); + TORCH_CHECK(query.size(2) == key.size(2)); + TORCH_CHECK(query.size(0) == key.size(0)); + + TORCH_CHECK(query.size(0) == value.size(0)); + TORCH_CHECK(key.size(1) == value.size(1)); + TORCH_CHECK( + query.size(2) == + value.size(2)); // TODO: drop this limitation in the future + + TORCH_CHECK(query.is_cuda(), "query must be a CUDA tensor"); + TORCH_CHECK(key.is_cuda(), "key must be a CUDA tensor"); + TORCH_CHECK(value.is_cuda(), "value must be a CUDA tensor"); + + TORCH_CHECK(!query.is_sparse(), "query must be a dense tensor"); + TORCH_CHECK(!key.is_sparse(), "key must be a dense tensor"); + TORCH_CHECK(!value.is_sparse(), "value must be a dense tensor"); + + // TODO drop this limitation in the future + TORCH_CHECK(query.is_contiguous()); + TORCH_CHECK(key.is_contiguous()); + TORCH_CHECK(value.is_contiguous()); + + // TODO: support other dtypes in the future + TORCH_CHECK( + query.scalar_type() == at::ScalarType::Float, + "Only float32 type is supported for now"); + + at::cuda::CUDAGuard device_guard(query.device()); + + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + int64_t K = query.size(2); + + at::Tensor res = at::zeros({B, M, K}, query.options()); + at::Tensor logsumexp = at::empty({B, M}, query.options()); + + // have to pass compute_logsumexp as a template parameter + // otherwise there is a slowdown in the kernel... + if (compute_logsumexp) { + launch_attention(res, logsumexp, query, key, value); + } else { + launch_attention(res, logsumexp, query, key, value); + } + AT_CUDA_CHECK(cudaGetLastError()); - return res; + return std::make_tuple(res, logsumexp); } template < diff --git a/xformers/ops.py b/xformers/ops.py index 2eb4412051..b2f18afeba 100644 --- a/xformers/ops.py +++ b/xformers/ops.py @@ -29,6 +29,22 @@ def masked_matmul(a, b, mask=None): return att +class _MemoryEfficientAttentionOp(torch.autograd.Function): + @staticmethod + def forward(ctx, query, key, value): + out, lse = torch.ops.xformers.efficient_attention(query, key, value, True) + ctx.save_for_backward(query, key, value, lse) + return out + + @staticmethod + def backward(ctx, grad): + query, key, value, lse = ctx.saved_tensors + grad_q, grad_k, grad_v = torch.ops.xformers.efficient_attention( + grad, query, key, value, lse + ) + return grad_q, grad_k, grad_v + + def memory_efficient_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor ): @@ -36,11 +52,7 @@ def memory_efficient_attention( Implements the memory-efficient attention mechanism following `"Self-Attention Does Not Need O(n^2) Memory" `_. - For now, only forward in inference-mode is supported. 
""" - # don't support backwards for now - assert query.requires_grad is False - assert key.requires_grad is False - assert value.requires_grad is False - - return torch.ops.xformers.efficient_attention(query, key, value) + if all(x.requires_grad is False for x in [query, key, value]): + return torch.ops.xformers.efficient_attention(query, key, value, False)[0] + return _MemoryEfficientAttentionOp.apply(query, key, value) From 726d3c59ee759df44785be9cecb5f07cc9842aee Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 22 Apr 2022 09:09:26 -0700 Subject: [PATCH 40/45] clang-format --- .../attention/csrc/cpu/attention.cpp | 3 - .../attention/csrc/cuda/attention.cu | 117 ++++++++++-------- 2 files changed, 64 insertions(+), 56 deletions(-) diff --git a/xformers/components/attention/csrc/cpu/attention.cpp b/xformers/components/attention/csrc/cpu/attention.cpp index 04b534a2d1..9fbeae1bb2 100644 --- a/xformers/components/attention/csrc/cpu/attention.cpp +++ b/xformers/components/attention/csrc/cpu/attention.cpp @@ -178,7 +178,6 @@ void attention_backward_kernel( auto buf = buffer[at::get_thread_num()][0]; auto buf2 = buffer2[at::get_thread_num()][0]; for (int64_t i = start; i < end; i++) { - for (int64_t j = 0; j < M; j++) { for (int64_t k = 0; k < K; k++) { buf[k] = 0; @@ -244,8 +243,6 @@ std::tuple attention_backward( const at::Tensor& logsumexp // const at::Tensor& mask ) { - - TORCH_CHECK(query.dim() == grad_out.dim()); TORCH_CHECK(query.dim() == key.dim()); TORCH_CHECK(query.dim() == value.dim()); diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index 25f18a5c0e..f211da0d44 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -382,8 +382,7 @@ __global__ void attention_kernel( at::PackedTensorAccessor logsumexp, at::PackedTensorAccessor query, at::PackedTensorAccessor key, - at::PackedTensorAccessor value - ) { + at::PackedTensorAccessor value) { constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); static_assert( integerIsPowerOf2(kBlockSizeK * WARP_SIZE), @@ -502,7 +501,8 @@ __global__ void attention_kernel( if (compute_logsumexp) { #pragma unroll for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { - *logsumexp_block[q_item_idx] = m_prime[q_item_idx] + std::log(s_prime[q_item_idx]); + *logsumexp_block[q_item_idx] = + m_prime[q_item_idx] + std::log(s_prime[q_item_idx]); } } } @@ -513,8 +513,7 @@ void launch_attention( at::Tensor& logsumexp, const at::Tensor& query, const at::Tensor& key, - const at::Tensor& value - ) { + const at::Tensor& value) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int64_t B = query.size(0); @@ -693,7 +692,8 @@ __global__ void attention_backward_grad_v_kernel( scalar_t normalizer[kBlockSizeQ]; scalar_t tmp_sum[kBlockSizeQ] = {0}; - vec_t* qb[kBlockSizeQ], *kb[kBlockSizeK], *vb[kBlockSizeK], *gb[kBlockSizeQ], *gbb[TILE_SIZEQ]; + vec_t *qb[kBlockSizeQ], *kb[kBlockSizeK], *vb[kBlockSizeK], *gb[kBlockSizeQ], + *gbb[TILE_SIZEQ]; scalar_t maskQ[kBlockSizeQ], maskK[kBlockSizeK]; for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { @@ -711,7 +711,8 @@ __global__ void attention_backward_grad_v_kernel( if (check_bounds) index = min(index, M - 1); qb[q_item_idx] = reinterpret_cast(query[batch_idx][index].data()); - gb[q_item_idx] = reinterpret_cast(grad_out[batch_idx][index].data()); + gb[q_item_idx] = + reinterpret_cast(grad_out[batch_idx][index].data()); } for (int64_t i = 
0; i < TILE_SIZEQ; i++) { @@ -739,9 +740,7 @@ __global__ void attention_backward_grad_v_kernel( #pragma unroll for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { sputnik::VectorCompute::Dot( - __ldg(qb[q_item_idx] + k), - kk, - &attn_v[q_item_idx][k_item_idx]); + __ldg(qb[q_item_idx] + k), kk, &attn_v[q_item_idx][k_item_idx]); sputnik::VectorCompute::Dot( __ldg(gb[q_item_idx] + k), tt, @@ -754,7 +753,8 @@ __global__ void attention_backward_grad_v_kernel( #pragma unroll for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { attn_v[q_item_idx][k_item_idx] = - std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]) * maskQ[q_item_idx] * maskK[k_item_idx]; + std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]) * + maskQ[q_item_idx] * maskK[k_item_idx]; } } @@ -765,7 +765,8 @@ __global__ void attention_backward_grad_v_kernel( fact[kBlockSizeQ * threadIdx.x + q_item_idx] [kBlockSizeK * threadIdx.y + k_item_idx] = attn_v[q_item_idx][k_item_idx]; - tmp_sum[q_item_idx] += attn_v[q_item_idx][k_item_idx] * grad_attn_v[q_item_idx][k_item_idx]; + tmp_sum[q_item_idx] += + attn_v[q_item_idx][k_item_idx] * grad_attn_v[q_item_idx][k_item_idx]; } } __syncthreads(); @@ -788,8 +789,7 @@ __global__ void attention_backward_grad_v_kernel( int64_t index = l + k_item_idx; if (check_bounds) index = min(index, N - 1); - myGpuAtomicAdd( - &grad_v[batch_idx][index][k * kVecSize], res[k_item_idx]); + myGpuAtomicAdd(&grad_v[batch_idx][index][k * kVecSize], res[k_item_idx]); } } for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { @@ -843,7 +843,8 @@ __global__ void attention_backward_grad_qk_kernel( scalar_t normalizer[kBlockSizeQ]; scalar_t tmp_sum[kBlockSizeQ]; - vec_t* qb[kBlockSizeQ], *kb[kBlockSizeK], *vb[kBlockSizeK], *gb[kBlockSizeQ], *qbb[TILE_SIZEQ], *kbb[TILE_SIZEK]; + vec_t *qb[kBlockSizeQ], *kb[kBlockSizeK], *vb[kBlockSizeK], *gb[kBlockSizeQ], + *qbb[TILE_SIZEQ], *kbb[TILE_SIZEK]; scalar_t maskQ[kBlockSizeQ], maskK[kBlockSizeK]; for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { @@ -861,8 +862,8 @@ __global__ void attention_backward_grad_qk_kernel( if (check_bounds) index = min(index, M - 1); qb[q_item_idx] = reinterpret_cast(query[batch_idx][index].data()); - gb[q_item_idx] = reinterpret_cast(grad_out[batch_idx][index].data()); - + gb[q_item_idx] = + reinterpret_cast(grad_out[batch_idx][index].data()); } for (int64_t i = 0; i < TILE_SIZEQ; i++) { int64_t index = query_idx + i - kBlockSizeQ * threadIdx.x; @@ -897,9 +898,7 @@ __global__ void attention_backward_grad_qk_kernel( #pragma unroll for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { sputnik::VectorCompute::Dot( - __ldg(qb[q_item_idx] + k), - kk, - &attn_v[q_item_idx][k_item_idx]); + __ldg(qb[q_item_idx] + k), kk, &attn_v[q_item_idx][k_item_idx]); sputnik::VectorCompute::Dot( __ldg(gb[q_item_idx] + k), tt, @@ -912,7 +911,8 @@ __global__ void attention_backward_grad_qk_kernel( #pragma unroll for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { attn_v[q_item_idx][k_item_idx] = - std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]) * maskQ[q_item_idx] * maskK[k_item_idx]; + std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]) * + maskQ[q_item_idx] * maskK[k_item_idx]; } } @@ -947,9 +947,7 @@ __global__ void attention_backward_grad_qk_kernel( int64_t index = query_idx + q_item_idx; if (check_bounds) index = min(index, M - 1); - myGpuAtomicAdd( - &grad_q[batch_idx][index][k * kVecSize], - res[q_item_idx]); + 
myGpuAtomicAdd(&grad_q[batch_idx][index][k * kVecSize], res[q_item_idx]); } } @@ -971,8 +969,7 @@ __global__ void attention_backward_grad_qk_kernel( int64_t index = l + k_item_idx; if (check_bounds) index = min(index, N - 1); - myGpuAtomicAdd( - &grad_k[batch_idx][index][k * kVecSize], res[k_item_idx]); + myGpuAtomicAdd(&grad_k[batch_idx][index][k * kVecSize], res[k_item_idx]); } } } @@ -987,8 +984,7 @@ void launch_attention_backward( const at::Tensor& key, const at::Tensor& value, const at::Tensor& logsumexp, - at::Tensor& tmp_sum_i -) { + at::Tensor& tmp_sum_i) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int64_t B = query.size(0); @@ -1024,7 +1020,8 @@ void launch_attention_backward( kBlockSizeQ, kBlockSizeK, TILE_SIZEQ, - TILE_SIZEK, false><<>>( + TILE_SIZEK, + false><<>>( grad_v.packed_accessor(), grad_out.packed_accessor(), query.packed_accessor(), @@ -1039,7 +1036,8 @@ void launch_attention_backward( kBlockSizeQ, kBlockSizeK, TILE_SIZEQ, - TILE_SIZEK, true><<>>( + TILE_SIZEK, + true><<>>( grad_v.packed_accessor(), grad_out.packed_accessor(), query.packed_accessor(), @@ -1056,7 +1054,8 @@ void launch_attention_backward( kBlockSizeQ2, kBlockSizeK2, TILE_SIZEQ2, - TILE_SIZEK2, false><<>>( + TILE_SIZEK2, + false><<>>( grad_q.packed_accessor(), grad_k.packed_accessor(), grad_out.packed_accessor(), @@ -1072,7 +1071,8 @@ void launch_attention_backward( kBlockSizeQ2, kBlockSizeK2, TILE_SIZEQ2, - TILE_SIZEK2, true><<>>( + TILE_SIZEK2, + true><<>>( grad_q.packed_accessor(), grad_k.packed_accessor(), grad_out.packed_accessor(), @@ -1084,7 +1084,6 @@ void launch_attention_backward( } } - std::tuple attention_backward( const at::Tensor& grad_out, const at::Tensor& query, @@ -1140,31 +1139,43 @@ std::tuple attention_backward( at::Tensor tmp_sum_i = at::zeros({B, M}, query.options()); - //using scalar_t = float; - //using vec_t = float4; + // using scalar_t = float; + // using vec_t = float4; // using vec_t = float; if ((K % 4) == 0) { - launch_attention_backward(grad_q, grad_k, grad_v, - grad_out, - query, - key, - value, - logsumexp, tmp_sum_i); + launch_attention_backward( + grad_q, + grad_k, + grad_v, + grad_out, + query, + key, + value, + logsumexp, + tmp_sum_i); } else if ((K % 2) == 0) { - launch_attention_backward(grad_q, grad_k, grad_v, - grad_out, - query, - key, - value, - logsumexp, tmp_sum_i); + launch_attention_backward( + grad_q, + grad_k, + grad_v, + grad_out, + query, + key, + value, + logsumexp, + tmp_sum_i); } else { - launch_attention_backward(grad_q, grad_k, grad_v, - grad_out, - query, - key, - value, - logsumexp, tmp_sum_i); + launch_attention_backward( + grad_q, + grad_k, + grad_v, + grad_out, + query, + key, + value, + logsumexp, + tmp_sum_i); } AT_CUDA_CHECK(cudaGetLastError()); From 3f8f9547808eb58769808467bc159cb450cc3ee4 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 22 Apr 2022 09:32:20 -0700 Subject: [PATCH 41/45] Add scaling factor --- xformers/components/attention/csrc/cpu/attention.cpp | 5 +++-- xformers/components/attention/csrc/cuda/attention.cu | 10 +++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/xformers/components/attention/csrc/cpu/attention.cpp b/xformers/components/attention/csrc/cpu/attention.cpp index 9fbeae1bb2..ee207cdf75 100644 --- a/xformers/components/attention/csrc/cpu/attention.cpp +++ b/xformers/components/attention/csrc/cpu/attention.cpp @@ -174,6 +174,7 @@ void attention_backward_kernel( int64_t M = q.size(1); int64_t N = k.size(1); int64_t grain_size = 1; // buffer.size(1); + scalar_t scale = 
1.0 / std::sqrt(scalar_t(K)); at::parallel_for(0, B, grain_size, [&](int64_t start, int64_t end) { auto buf = buffer[at::get_thread_num()][0]; auto buf2 = buffer2[at::get_thread_num()][0]; @@ -191,7 +192,7 @@ void attention_backward_kernel( for (int64_t k = 0; k < K; k++) { si += query_i[k] * key_j[k]; } - scalar_t attn_v = std::exp(si - normalizer); + scalar_t attn_v = std::exp(si * scale - normalizer); for (int64_t k = 0; k < K; k++) { grad_v[i][l][k] += attn_v * grad_out[i][j][k]; @@ -207,7 +208,7 @@ void attention_backward_kernel( } // those are temporaries for the gradient of the softmax - scalar_t tmp = attn_v * grad_attn_v; + scalar_t tmp = attn_v * grad_attn_v * scale; tmp_sum += tmp; // grad_q is easy diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index f211da0d44..b5efefd337 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -122,7 +122,7 @@ __device__ void compute_dot( scalar_t out[kBlockSizeQ][kBlockSizeK], int64_t K) { constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); - scalar_t scale = 1.0; // / std::sqrt(scalar_t(K)); + scalar_t scale = 1.0 / std::sqrt(scalar_t(K)); vec_t q_i[kBlockSizeQ]; for (int64_t k = 0; k < K / kVecSize; k += 1) { #pragma unroll @@ -731,11 +731,13 @@ __global__ void attention_backward_grad_v_kernel( scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; scalar_t grad_attn_v[kBlockSizeQ][kBlockSizeK] = {0}; + scalar_t scale = 1.0 / std::sqrt(scalar_t(K)); for (int64_t k = 0; k < K / kVecSize; k += 1) { #pragma unroll for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { vec_t kk = __ldg(kb[k_item_idx] + k); + iMul(scale, &kk); vec_t tt = __ldg(vb[k_item_idx] + k); #pragma unroll for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { @@ -889,11 +891,13 @@ __global__ void attention_backward_grad_qk_kernel( scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; scalar_t grad_attn_v[kBlockSizeQ][kBlockSizeK] = {0}; + scalar_t scale = 1.0 / std::sqrt(scalar_t(K)); for (int64_t k = 0; k < K / kVecSize; k += 1) { #pragma unroll for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { vec_t kk = __ldg(kb[k_item_idx] + k); + iMul(scale, &kk); vec_t tt = __ldg(vb[k_item_idx] + k); #pragma unroll for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { @@ -922,9 +926,9 @@ __global__ void attention_backward_grad_qk_kernel( for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { fact[kBlockSizeQ * threadIdx.x + q_item_idx] [kBlockSizeK * threadIdx.y + k_item_idx] = - attn_v[q_item_idx][k_item_idx] * + attn_v[q_item_idx][k_item_idx] * scale * ( grad_attn_v[q_item_idx][k_item_idx] - - attn_v[q_item_idx][k_item_idx] * tmp_sum[q_item_idx]; + tmp_sum[q_item_idx]); } } __syncthreads(); From a43d72b97da6e883dbd73cdf2c2a2a360666002a Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 22 Apr 2022 10:10:08 -0700 Subject: [PATCH 42/45] Add tests + silly bugfix --- tests/test_mem_eff_attention.py | 55 +++++++++++++++++++++++++++++++++ xformers/ops.py | 2 +- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py index c6623603c7..d4542e1ab4 100644 --- a/tests/test_mem_eff_attention.py +++ b/tests/test_mem_eff_attention.py @@ -49,3 +49,58 @@ def test_key_query_all_ones(device, q_len, kv_len, batch_size, k_len): ref = value.mean(1, keepdim=True).expand_as(query) assert torch.allclose(out, ref, 
atol=1e-5) + + +@pytest.mark.parametrize("k_len", [5, 6, 32]) +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("kv_len", [3, 15, 32, 33]) +@pytest.mark.parametrize("q_len", [2, 3, 5]) +@pytest.mark.parametrize("device", _devices) +def test_logsumexp(device, q_len, kv_len, batch_size, k_len): + scale = 3 + query = torch.randn((batch_size, q_len, k_len), device=device) * scale + key = torch.randn((batch_size, kv_len, k_len), device=device) * scale + value = torch.randn((batch_size, kv_len, k_len), device=device) * scale + + _, lse = torch.ops.xformers.efficient_attention(query, key, value, True) + ref_lse = ((query / k_len ** 0.5) @ key.transpose(-2, -1)).logsumexp(-1) + + assert torch.allclose(lse, ref_lse, atol=2e-4) + + +@pytest.mark.parametrize("k_len", [5, 6, 32]) +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("kv_len", [3, 15, 32, 33]) +@pytest.mark.parametrize("q_len", [2, 3, 5]) +@pytest.mark.parametrize("device", _devices) +def test_memory_efficient_attention_backward(device, q_len, kv_len, batch_size, k_len): + scale = 3 + query = torch.randn((batch_size, q_len, k_len), device=device) * scale + key = torch.randn((batch_size, kv_len, k_len), device=device) * scale + value = torch.randn((batch_size, kv_len, k_len), device=device) * scale + + query.requires_grad_(True) + key.requires_grad_(True) + value.requires_grad_(True) + + out = xformers.ops.memory_efficient_attention(query, key, value) + out.backward(torch.ones_like(query)) + + grad_q = query.grad + grad_k = key.grad + grad_v = value.grad + + query.grad = None + key.grad = None + value.grad = None + + ref = ref_attention(query, key, value) + ref.backward(torch.ones_like(query)) + + # there is some extra precision loss in the CPU implementation due to an + # extra accumulation step in grad_q, which is not present in the CUDA + # implementation + atol = 3e-4 if device == "cuda" else 4e-4 + assert torch.allclose(grad_q, query.grad, atol=atol), "grad_q doesn't match" + assert torch.allclose(grad_k, key.grad, atol=atol), "grad_k doesn't match" + assert torch.allclose(grad_v, value.grad, atol=atol), "grad_v doesn't match" diff --git a/xformers/ops.py b/xformers/ops.py index b2f18afeba..8812746176 100644 --- a/xformers/ops.py +++ b/xformers/ops.py @@ -39,7 +39,7 @@ def forward(ctx, query, key, value): @staticmethod def backward(ctx, grad): query, key, value, lse = ctx.saved_tensors - grad_q, grad_k, grad_v = torch.ops.xformers.efficient_attention( + grad_q, grad_k, grad_v = torch.ops.xformers.efficient_attention_backward( grad, query, key, value, lse ) return grad_q, grad_k, grad_v From 45ed14ccaee2eff1625f2430a750854c2e7ab693 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 22 Apr 2022 10:22:09 -0700 Subject: [PATCH 43/45] Add benchmark function for backward --- .../benchmarks/benchmark_mem_eff_attention.py | 208 ++++++++++++------ 1 file changed, 144 insertions(+), 64 deletions(-) diff --git a/xformers/benchmarks/benchmark_mem_eff_attention.py b/xformers/benchmarks/benchmark_mem_eff_attention.py index 72b4dbe0f3..9f0662c820 100644 --- a/xformers/benchmarks/benchmark_mem_eff_attention.py +++ b/xformers/benchmarks/benchmark_mem_eff_attention.py @@ -26,71 +26,151 @@ def ref_attention(q, k, v): SHAPES = list( itertools.product([1, 8, 32, 256], [127, 128, 512, 513, 1023, 1024], [16, 32]) ) -SHAPES = [(256, 1024, 32)] results = [] mem_use: Dict[str, Dict[str, float]] = dict(optimized={}, vanilla={}) -print(f"Processing {len(SHAPES)} cases") -for num_threads in NUM_THREADS: - for 
shape in SHAPES: - print(f"===== {shape} =====") - B, M, K = shape - q = torch.rand(shape, device=device) - sub_label = f"B={B}, M={M}, K={K}" - - if False: - r = xformers.ops.memory_efficient_attention(q, q, q) - - rr = ref_attention(q, q, q) - assert (r - rr).abs().max() < 1e-5 - - torch.cuda.reset_peak_memory_stats() - torch.cuda.synchronize() - results.append( - benchmark.Timer( - stmt="fn(q, q, q)", - globals={ - "q": q, - "fn": xformers.ops.memory_efficient_attention, - }, - label="attention", - description="optimized", - sub_label=sub_label, - num_threads=num_threads, - ).blocked_autorange(min_run_time=min_run_time) - ) - torch.cuda.synchronize() - memory = torch.cuda.max_memory_allocated() / 2 ** 20 - mem_use["optimized"][sub_label] = memory - memory_str = f"Memory used: {memory} MB" - - print("Optimized", memory_str) - - torch.cuda.reset_peak_memory_stats() - torch.cuda.synchronize() - results.append( - benchmark.Timer( - stmt="fn(q, q, q)", - globals={ - "q": q, - "fn": ref_attention, - }, - label="attention", - description="vanilla", - sub_label=sub_label, - num_threads=num_threads, - ).blocked_autorange(min_run_time=min_run_time) - ) - - torch.cuda.synchronize() - memory = torch.cuda.max_memory_allocated() / 2 ** 20 - mem_use["vanilla"][sub_label] = memory - memory_str = f"Memory used: {memory} MB" - print("Vanilla", memory_str) - - -compare = benchmark.Compare(results) -compare.print() - -pprint.pprint(mem_use) + +def benchmark_forward(): + print(f"Processing {len(SHAPES)} cases") + print("Forward") + for num_threads in NUM_THREADS: + for shape in SHAPES: + print(f"===== {shape} =====") + B, M, K = shape + q = torch.rand(shape, device=device) + sub_label = f"B={B}, M={M}, K={K}" + + if True: + r = xformers.ops.memory_efficient_attention(q, q, q) + + rr = ref_attention(q, q, q) + assert (r - rr).abs().max() < 1e-5 + + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + results.append( + benchmark.Timer( + stmt="fn(q, q, q)", + globals={ + "q": q, + "fn": xformers.ops.memory_efficient_attention, + }, + label="attention", + description="optimized", + sub_label=sub_label, + num_threads=num_threads, + ).blocked_autorange(min_run_time=min_run_time) + ) + torch.cuda.synchronize() + memory = torch.cuda.max_memory_allocated() / 2 ** 20 + mem_use["optimized"][sub_label] = memory + memory_str = f"Memory used: {memory} MB" + + print("Optimized", memory_str) + + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + results.append( + benchmark.Timer( + stmt="fn(q, q, q)", + globals={ + "q": q, + "fn": ref_attention, + }, + label="attention", + description="vanilla", + sub_label=sub_label, + num_threads=num_threads, + ).blocked_autorange(min_run_time=min_run_time) + ) + + torch.cuda.synchronize() + memory = torch.cuda.max_memory_allocated() / 2 ** 20 + mem_use["vanilla"][sub_label] = memory + memory_str = f"Memory used: {memory} MB" + print("Vanilla", memory_str) + + compare = benchmark.Compare(results) + compare.print() + + pprint.pprint(mem_use) + + +def benchmark_backward(): + print(f"Processing {len(SHAPES)} cases") + print("Backward") + for num_threads in NUM_THREADS: + for shape in SHAPES: + print(f"===== {shape} =====") + B, M, K = shape + q = torch.rand(shape, device=device, requires_grad=True) + sub_label = f"B={B}, M={M}, K={K}" + + if True: + r = xformers.ops.memory_efficient_attention(q, q, q) + r.backward(torch.ones_like(q)) + + grad = q.grad + q.grad = None + + rr = ref_attention(q, q, q) + rr.backward(torch.ones_like(q)) + assert (grad - 
q.grad).abs().max() < 1e-5 + + out = xformers.ops.memory_efficient_attention(q, q, q) + grad = torch.ones_like(q) + + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + results.append( + benchmark.Timer( + stmt="out.backward(grad, retain_graph=True)", + globals={ + "out": out, + "grad": grad, + }, + label="attention", + description="optimized", + sub_label=sub_label, + num_threads=num_threads, + ).blocked_autorange(min_run_time=min_run_time) + ) + torch.cuda.synchronize() + memory = torch.cuda.max_memory_allocated() / 2 ** 20 + mem_use["optimized"][sub_label] = memory + memory_str = f"Memory used: {memory} MB" + + print("Optimized", memory_str) + + out = ref_attention(q, q, q) + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + results.append( + benchmark.Timer( + stmt="out.backward(grad, retain_graph=True)", + globals={ + "out": out, + "grad": grad, + }, + label="attention", + description="vanilla", + sub_label=sub_label, + num_threads=num_threads, + ).blocked_autorange(min_run_time=min_run_time) + ) + + torch.cuda.synchronize() + memory = torch.cuda.max_memory_allocated() / 2 ** 20 + mem_use["vanilla"][sub_label] = memory + memory_str = f"Memory used: {memory} MB" + print("Vanilla", memory_str) + + compare = benchmark.Compare(results) + compare.print() + + pprint.pprint(mem_use) + + +benchmark_forward() +benchmark_backward() From 05ea687770da822ce5d5d98fb19f4f2f85ad5648 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 22 Apr 2022 10:24:12 -0700 Subject: [PATCH 44/45] Add comment --- xformers/ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xformers/ops.py b/xformers/ops.py index 8812746176..355c170ca2 100644 --- a/xformers/ops.py +++ b/xformers/ops.py @@ -53,6 +53,7 @@ def memory_efficient_attention( `"Self-Attention Does Not Need O(n^2) Memory" `_. """ + # fast-path that doesn't require computing the logsumexp for backward computation if all(x.requires_grad is False for x in [query, key, value]): return torch.ops.xformers.efficient_attention(query, key, value, False)[0] return _MemoryEfficientAttentionOp.apply(query, key, value) From 8e7bbb9a028ed612f1c2a7ec0d923c9359c78906 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 22 Apr 2022 10:25:12 -0700 Subject: [PATCH 45/45] clang-format --- xformers/components/attention/csrc/cuda/attention.cu | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index b5efefd337..b1fff91619 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -926,9 +926,8 @@ __global__ void attention_backward_grad_qk_kernel( for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { fact[kBlockSizeQ * threadIdx.x + q_item_idx] [kBlockSizeK * threadIdx.y + k_item_idx] = - attn_v[q_item_idx][k_item_idx] * scale * ( - grad_attn_v[q_item_idx][k_item_idx] - - tmp_sum[q_item_idx]); + attn_v[q_item_idx][k_item_idx] * scale * + (grad_attn_v[q_item_idx][k_item_idx] - tmp_sum[q_item_idx]); } } __syncthreads();
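
A minimal PyTorch sketch of the math the backward kernels in this series implement; it is not part of the patches and not the library API, and every name below (q, k, v, tmp_sum, lse) is illustrative only. The attention probabilities are recomputed from the saved logsumexp, tmp_sum plays the role of the tmp_sum_i buffer (up to the scaling factor), and the 1/sqrt(K) factor matches the scaling introduced in PATCH 41. The result is cross-checked against autograd on a naive attention implementation.

import torch

torch.manual_seed(0)
B, M, N, K = 2, 5, 7, 8
q = torch.randn(B, M, K, requires_grad=True)
k = torch.randn(B, N, K, requires_grad=True)
v = torch.randn(B, N, K, requires_grad=True)
grad_out = torch.randn(B, M, K)

scale = 1.0 / K ** 0.5
s = q @ k.transpose(-2, -1) * scale           # scaled attention scores
lse = s.logsumexp(-1, keepdim=True)           # what the forward pass saves
p = (s - lse).exp()                           # attention probs, recomputed in backward

grad_v = p.transpose(-2, -1) @ grad_out       # dV = P^T dO
grad_p = grad_out @ v.transpose(-2, -1)       # dP = dO V^T
tmp_sum = (p * grad_p).sum(-1, keepdim=True)  # row-wise sum, analogous to tmp_sum_i
grad_s = p * (grad_p - tmp_sum) * scale       # softmax backward, then the 1/sqrt(K) factor
grad_q = grad_s @ k                           # dQ = dS K
grad_k = grad_s.transpose(-2, -1) @ q         # dK = dS^T Q

# Cross-check against autograd on the naive implementation.
out = torch.softmax(q @ k.transpose(-2, -1) * scale, -1) @ v
out.backward(grad_out)
assert torch.allclose(grad_q, q.grad, atol=1e-5)
assert torch.allclose(grad_k, k.grad, atol=1e-5)
assert torch.allclose(grad_v, v.grad, atol=1e-5)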