diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index c6623603c7..d4542e1ab4 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -49,3 +49,58 @@ def test_key_query_all_ones(device, q_len, kv_len, batch_size, k_len):
     ref = value.mean(1, keepdim=True).expand_as(query)
     assert torch.allclose(out, ref, atol=1e-5)
+
+
+@pytest.mark.parametrize("k_len", [5, 6, 32])
+@pytest.mark.parametrize("batch_size", [1, 4])
+@pytest.mark.parametrize("kv_len", [3, 15, 32, 33])
+@pytest.mark.parametrize("q_len", [2, 3, 5])
+@pytest.mark.parametrize("device", _devices)
+def test_logsumexp(device, q_len, kv_len, batch_size, k_len):
+    scale = 3
+    query = torch.randn((batch_size, q_len, k_len), device=device) * scale
+    key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
+    value = torch.randn((batch_size, kv_len, k_len), device=device) * scale
+
+    _, lse = torch.ops.xformers.efficient_attention(query, key, value, True)
+    ref_lse = ((query / k_len ** 0.5) @ key.transpose(-2, -1)).logsumexp(-1)
+
+    assert torch.allclose(lse, ref_lse, atol=2e-4)
+
+
+@pytest.mark.parametrize("k_len", [5, 6, 32])
+@pytest.mark.parametrize("batch_size", [1, 4])
+@pytest.mark.parametrize("kv_len", [3, 15, 32, 33])
+@pytest.mark.parametrize("q_len", [2, 3, 5])
+@pytest.mark.parametrize("device", _devices)
+def test_memory_efficient_attention_backward(device, q_len, kv_len, batch_size, k_len):
+    scale = 3
+    query = torch.randn((batch_size, q_len, k_len), device=device) * scale
+    key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
+    value = torch.randn((batch_size, kv_len, k_len), device=device) * scale
+
+    query.requires_grad_(True)
+    key.requires_grad_(True)
+    value.requires_grad_(True)
+
+    out = xformers.ops.memory_efficient_attention(query, key, value)
+    out.backward(torch.ones_like(query))
+
+    grad_q = query.grad
+    grad_k = key.grad
+    grad_v = value.grad
+
+    query.grad = None
+    key.grad = None
+    value.grad = None
+
+    ref = ref_attention(query, key, value)
+    ref.backward(torch.ones_like(query))
+
+    # there is some extra precision loss in the CPU implementation due to an
+    # extra accumulation step in grad_q, which is not present in the CUDA
+    # implementation
+    atol = 3e-4 if device == "cuda" else 4e-4
+    assert torch.allclose(grad_q, query.grad, atol=atol), "grad_q doesn't match"
+    assert torch.allclose(grad_k, key.grad, atol=atol), "grad_k doesn't match"
+    assert torch.allclose(grad_v, value.grad, atol=atol), "grad_v doesn't match"
diff --git a/xformers/benchmarks/benchmark_mem_eff_attention.py b/xformers/benchmarks/benchmark_mem_eff_attention.py
index 88968c907d..9f0662c820 100644
--- a/xformers/benchmarks/benchmark_mem_eff_attention.py
+++ b/xformers/benchmarks/benchmark_mem_eff_attention.py
@@ -30,66 +30,147 @@ def ref_attention(q, k, v):
 results = []
 mem_use: Dict[str, Dict[str, float]] = dict(optimized={}, vanilla={})
-print(f"Processing {len(SHAPES)} cases")
-for num_threads in NUM_THREADS:
-    for shape in SHAPES:
-        print(f"===== {shape} =====")
-        B, M, K = shape
-        q = torch.rand(shape, device=device)
-        sub_label = f"B={B}, M={M}, K={K}"
-
-        if True:
-            r = xformers.ops.memory_efficient_attention(q, q, q)
-
-            rr = ref_attention(q, q, q)
-            assert (r - rr).abs().max() < 1e-5
-
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.synchronize()
-        results.append(
-            benchmark.Timer(
-                stmt="fn(q, q, q)",
-                globals={
-                    "q": q,
-                    "fn": torch.ops.xformers.efficient_attention,
-                },
-                label="attention",
-                description="optimized",
-                sub_label=sub_label,
-                num_threads=num_threads,
-            ).blocked_autorange(min_run_time=min_run_time)
-        )
-        torch.cuda.synchronize()
-        memory = torch.cuda.max_memory_allocated() / 2 ** 20
-        mem_use["optimized"][sub_label] = memory
-        memory_str = f"Memory used: {memory} MB"
-
-        print("Optimized", memory_str)
-
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.synchronize()
-        results.append(
-            benchmark.Timer(
-                stmt="fn(q, q, q)",
-                globals={
-                    "q": q,
-                    "fn": ref_attention,
-                },
-                label="attention",
-                description="vanilla",
-                sub_label=sub_label,
-                num_threads=num_threads,
-            ).blocked_autorange(min_run_time=min_run_time)
-        )
-
-        torch.cuda.synchronize()
-        memory = torch.cuda.max_memory_allocated() / 2 ** 20
-        mem_use["vanilla"][sub_label] = memory
-        memory_str = f"Memory used: {memory} MB"
-        print("Vanilla", memory_str)
-
-
-compare = benchmark.Compare(results)
-compare.print()
-
-pprint.pprint(mem_use)
+
+def benchmark_forward():
+    print(f"Processing {len(SHAPES)} cases")
+    print("Forward")
+    for num_threads in NUM_THREADS:
+        for shape in SHAPES:
+            print(f"===== {shape} =====")
+            B, M, K = shape
+            q = torch.rand(shape, device=device)
+            sub_label = f"B={B}, M={M}, K={K}"
+
+            if True:
+                r = xformers.ops.memory_efficient_attention(q, q, q)
+
+                rr = ref_attention(q, q, q)
+                assert (r - rr).abs().max() < 1e-5
+
+            torch.cuda.reset_peak_memory_stats()
+            torch.cuda.synchronize()
+            results.append(
+                benchmark.Timer(
+                    stmt="fn(q, q, q)",
+                    globals={
+                        "q": q,
+                        "fn": xformers.ops.memory_efficient_attention,
+                    },
+                    label="attention",
+                    description="optimized",
+                    sub_label=sub_label,
+                    num_threads=num_threads,
+                ).blocked_autorange(min_run_time=min_run_time)
+            )
+            torch.cuda.synchronize()
+            memory = torch.cuda.max_memory_allocated() / 2 ** 20
+            mem_use["optimized"][sub_label] = memory
+            memory_str = f"Memory used: {memory} MB"
+
+            print("Optimized", memory_str)
+
+            torch.cuda.reset_peak_memory_stats()
+            torch.cuda.synchronize()
+            results.append(
+                benchmark.Timer(
+                    stmt="fn(q, q, q)",
+                    globals={
+                        "q": q,
+                        "fn": ref_attention,
+                    },
+                    label="attention",
+                    description="vanilla",
+                    sub_label=sub_label,
+                    num_threads=num_threads,
+                ).blocked_autorange(min_run_time=min_run_time)
+            )
+
+            torch.cuda.synchronize()
+            memory = torch.cuda.max_memory_allocated() / 2 ** 20
+            mem_use["vanilla"][sub_label] = memory
+            memory_str = f"Memory used: {memory} MB"
+            print("Vanilla", memory_str)
+
+    compare = benchmark.Compare(results)
+    compare.print()
+
+    pprint.pprint(mem_use)
+
+
+def benchmark_backward():
+    print(f"Processing {len(SHAPES)} cases")
+    print("Backward")
+    for num_threads in NUM_THREADS:
+        for shape in SHAPES:
+            print(f"===== {shape} =====")
+            B, M, K = shape
+            q = torch.rand(shape, device=device, requires_grad=True)
+            sub_label = f"B={B}, M={M}, K={K}"
+
+            if True:
+                r = xformers.ops.memory_efficient_attention(q, q, q)
+                r.backward(torch.ones_like(q))
+
+                grad = q.grad
+                q.grad = None
+
+                rr = ref_attention(q, q, q)
+                rr.backward(torch.ones_like(q))
+                assert (grad - q.grad).abs().max() < 1e-5
+
+            out = xformers.ops.memory_efficient_attention(q, q, q)
+            grad = torch.ones_like(q)
+
+            torch.cuda.reset_peak_memory_stats()
+            torch.cuda.synchronize()
+            results.append(
+                benchmark.Timer(
+                    stmt="out.backward(grad, retain_graph=True)",
+                    globals={
+                        "out": out,
+                        "grad": grad,
+                    },
+                    label="attention",
+                    description="optimized",
+                    sub_label=sub_label,
+                    num_threads=num_threads,
+                ).blocked_autorange(min_run_time=min_run_time)
+            )
+            torch.cuda.synchronize()
+            memory = torch.cuda.max_memory_allocated() / 2 ** 20
+ mem_use["optimized"][sub_label] = memory + memory_str = f"Memory used: {memory} MB" + + print("Optimized", memory_str) + + out = ref_attention(q, q, q) + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + results.append( + benchmark.Timer( + stmt="out.backward(grad, retain_graph=True)", + globals={ + "out": out, + "grad": grad, + }, + label="attention", + description="vanilla", + sub_label=sub_label, + num_threads=num_threads, + ).blocked_autorange(min_run_time=min_run_time) + ) + + torch.cuda.synchronize() + memory = torch.cuda.max_memory_allocated() / 2 ** 20 + mem_use["vanilla"][sub_label] = memory + memory_str = f"Memory used: {memory} MB" + print("Vanilla", memory_str) + + compare = benchmark.Compare(results) + compare.print() + + pprint.pprint(mem_use) + + +benchmark_forward() +benchmark_backward() diff --git a/xformers/components/attention/csrc/attention.cpp b/xformers/components/attention/csrc/attention.cpp index 931b8f9689..eb35b95c04 100644 --- a/xformers/components/attention/csrc/attention.cpp +++ b/xformers/components/attention/csrc/attention.cpp @@ -2,5 +2,7 @@ TORCH_LIBRARY_FRAGMENT(xformers, m) { m.def(TORCH_SELECTIVE_SCHEMA( - "xformers::efficient_attention(Tensor query, Tensor key, Tensor value) -> Tensor")); + "xformers::efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_logsumexp) -> (Tensor, Tensor)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "xformers::efficient_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp) -> (Tensor, Tensor, Tensor)")); } diff --git a/xformers/components/attention/csrc/cpu/attention.cpp b/xformers/components/attention/csrc/cpu/attention.cpp index d79eb36c4c..ee207cdf75 100644 --- a/xformers/components/attention/csrc/cpu/attention.cpp +++ b/xformers/components/attention/csrc/cpu/attention.cpp @@ -28,10 +28,12 @@ scalar_t max(scalar_t* buf) { template void attention_kernel( at::TensorAccessor output, + at::TensorAccessor logsumexp, at::TensorAccessor query, at::TensorAccessor key, at::TensorAccessor value, - at::TensorAccessor buffer //, + at::TensorAccessor buffer, + bool compute_logsumexp // at::TensorAccessor mask ) { // TODO: optimize the code by adding blocking @@ -90,15 +92,18 @@ void attention_kernel( for (int64_t k = 0; k < K; k++) { oo[k] = buf[k] / s_prime; } + if (compute_logsumexp) + logsumexp[i][j] = m_prime + std::log(s_prime); } } }); } -at::Tensor attention( +std::tuple attention( const at::Tensor& query, const at::Tensor& key, - const at::Tensor& value + const at::Tensor& value, + bool compute_logsumexp // const at::Tensor& mask ) { TORCH_CHECK(query.dim() == key.dim()); @@ -132,19 +137,174 @@ at::Tensor attention( int64_t K = query.size(2); at::Tensor res = at::empty({B, M, K}, query.options()); + at::Tensor logsumexp = at::empty({B, M}, query.options()); at::Tensor buffer = at::empty({at::get_num_threads(), 1, K}, query.options()); AT_DISPATCH_FLOATING_TYPES(query.scalar_type(), "attention_kernel", [&] { attention_kernel( res.accessor(), + logsumexp.accessor(), query.accessor(), key.accessor(), value.accessor(), - buffer.accessor()); + buffer.accessor(), + compute_logsumexp); }); - return res; + return std::make_tuple(res, logsumexp); +} + +template +void attention_backward_kernel( + at::TensorAccessor grad_q, + at::TensorAccessor grad_k, + at::TensorAccessor grad_v, + at::TensorAccessor grad_out, + at::TensorAccessor q, + at::TensorAccessor k, + at::TensorAccessor v, + at::TensorAccessor logsumexp_normalizer, + at::TensorAccessor buffer, + 
at::TensorAccessor buffer2 //, + // at::TensorAccessor mask +) { + int64_t K = q.size(2); + int64_t B = q.size(0); + int64_t M = q.size(1); + int64_t N = k.size(1); + int64_t grain_size = 1; // buffer.size(1); + scalar_t scale = 1.0 / std::sqrt(scalar_t(K)); + at::parallel_for(0, B, grain_size, [&](int64_t start, int64_t end) { + auto buf = buffer[at::get_thread_num()][0]; + auto buf2 = buffer2[at::get_thread_num()][0]; + for (int64_t i = start; i < end; i++) { + for (int64_t j = 0; j < M; j++) { + for (int64_t k = 0; k < K; k++) { + buf[k] = 0; + } + auto query_i = q[i][j]; + auto normalizer = logsumexp_normalizer[i][j]; + scalar_t tmp_sum = 0; + for (int64_t l = 0; l < N; l++) { + auto key_j = k[i][l]; + scalar_t si = 0; + for (int64_t k = 0; k < K; k++) { + si += query_i[k] * key_j[k]; + } + scalar_t attn_v = std::exp(si * scale - normalizer); + + for (int64_t k = 0; k < K; k++) { + grad_v[i][l][k] += attn_v * grad_out[i][j][k]; + } + + // now compute grad_q and grad_k + // first compute the gradient for the self-attention + // after softmax + scalar_t grad_attn_v = 0; + for (int64_t k = 0; k < K; k++) { + grad_attn_v += grad_out[i][j][k] * v[i][l][k]; + // grad_attn_v[i][j][l] += grad_out[i][j][k] * v[i][l][k]; + } + + // those are temporaries for the gradient of the softmax + scalar_t tmp = attn_v * grad_attn_v * scale; + tmp_sum += tmp; + + // grad_q is easy + for (int64_t k = 0; k < K; k++) { + grad_q[i][j][k] += tmp * key_j[k]; + buf[k] += attn_v * key_j[k]; + } + + // but grad_k is a bit trickier + buf2[l] = attn_v; + for (int64_t k = 0; k < K; k++) { + grad_k[i][l][k] += tmp * query_i[k]; + } + } + for (int64_t l = 0; l < N; l++) { + for (int64_t k = 0; k < K; k++) { + grad_k[i][l][k] -= buf2[l] * query_i[k] * tmp_sum; + } + } + for (int64_t k = 0; k < K; k++) { + grad_q[i][j][k] -= buf[k] * tmp_sum; + } + } + } + }); +} + +std::tuple attention_backward( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& logsumexp + // const at::Tensor& mask +) { + TORCH_CHECK(query.dim() == grad_out.dim()); + TORCH_CHECK(query.dim() == key.dim()); + TORCH_CHECK(query.dim() == value.dim()); + // TORCH_CHECK(query.dim() == mask.dim()); + TORCH_CHECK(query.dim() == 3); + + TORCH_CHECK(query.size(0) == grad_out.size(0)); + TORCH_CHECK(query.size(1) == grad_out.size(1)); + TORCH_CHECK(query.size(2) == grad_out.size(2)); + + TORCH_CHECK(query.size(2) == key.size(2)); + TORCH_CHECK(query.size(0) == key.size(0)); + + TORCH_CHECK(query.size(0) == value.size(0)); + TORCH_CHECK(key.size(1) == value.size(1)); + TORCH_CHECK( + query.size(2) == + value.size(2)); // TODO: drop this limitation in the future + + TORCH_CHECK(!query.is_cuda(), "query must be a CPU tensor"); + TORCH_CHECK(!key.is_cuda(), "key must be a CPU tensor"); + TORCH_CHECK(!value.is_cuda(), "value must be a CPU tensor"); + TORCH_CHECK(!grad_out.is_cuda(), "grad_out must be a CPU tensor"); + + TORCH_CHECK(!query.is_sparse(), "query must be a dense tensor"); + TORCH_CHECK(!key.is_sparse(), "key must be a dense tensor"); + TORCH_CHECK(!value.is_sparse(), "value must be a dense tensor"); + TORCH_CHECK(!grad_out.is_sparse(), "grad_out must be a dense tensor"); + + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + int64_t K = query.size(2); + + at::Tensor res = at::empty({B, M, K}, query.options()); + at::Tensor grad_q = at::zeros_like(query); + at::Tensor grad_k = at::zeros_like(key); + at::Tensor grad_v = at::zeros_like(value); + + 
at::Tensor buffer = at::empty({at::get_num_threads(), 1, K}, query.options()); + at::Tensor buffer2 = + at::zeros({at::get_num_threads(), 1, N}, query.options()); + + AT_DISPATCH_FLOATING_TYPES( + query.scalar_type(), "attention_backward_kernel", [&] { + attention_backward_kernel( + grad_q.accessor(), + grad_k.accessor(), + grad_v.accessor(), + grad_out.accessor(), + query.accessor(), + key.accessor(), + value.accessor(), + logsumexp.accessor(), + buffer.accessor(), + buffer2.accessor() + // idxs.accessor() + ); + }); + + return std::make_tuple(grad_q, grad_k, grad_v); } } // namespace @@ -153,4 +313,7 @@ TORCH_LIBRARY_IMPL(xformers, CPU, m) { m.impl( TORCH_SELECTIVE_NAME("xformers::efficient_attention"), TORCH_FN(attention)); + m.impl( + TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward"), + TORCH_FN(attention_backward)); } diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu index d9880548ca..b1fff91619 100644 --- a/xformers/components/attention/csrc/cuda/attention.cu +++ b/xformers/components/attention/csrc/cuda/attention.cu @@ -5,6 +5,7 @@ #include #include +#include #include "sputnik/vector_utils.h" @@ -58,6 +59,25 @@ __device__ __forceinline__ void iDiv(scalar_t x1, float* out) { out[0] /= x1; } +template +__device__ __forceinline__ void myGpuAtomicAdd(scalar_t* address, float4 val) { + gpuAtomicAdd(address + 0, val.x); + gpuAtomicAdd(address + 1, val.y); + gpuAtomicAdd(address + 2, val.z); + gpuAtomicAdd(address + 3, val.w); +} + +template +__device__ __forceinline__ void myGpuAtomicAdd(scalar_t* address, float2 val) { + gpuAtomicAdd(address + 0, val.x); + gpuAtomicAdd(address + 1, val.y); +} + +template +__device__ __forceinline__ void myGpuAtomicAdd(scalar_t* address, float val) { + gpuAtomicAdd(address, val); +} + template __device__ __forceinline__ scalar_t warpSum(scalar_t val) { for (int stride = WARP_SIZE / 2; stride > 0; stride >>= 1) { @@ -355,9 +375,11 @@ template < int kBlockSizeK, int kBlockSizeQ, int WARP_SIZE, - int BUFFER_SIZE> + int BUFFER_SIZE, + bool compute_logsumexp> __global__ void attention_kernel( at::PackedTensorAccessor output, + at::PackedTensorAccessor logsumexp, at::PackedTensorAccessor query, at::PackedTensorAccessor key, at::PackedTensorAccessor value) { @@ -379,6 +401,7 @@ __global__ void attention_kernel( vec_t* query_block[kBlockSizeQ]; vec_t* output_block[kBlockSizeQ]; + scalar_t* logsumexp_block[kBlockSizeQ]; // TODO [BUFFER_SIZE limitation]: the current strategy assumes a // statically-known size for K. 
Ideally we would like to remove this // limitation in the future, so that any K is supported @@ -393,6 +416,7 @@ __global__ void attention_kernel( output_block[q_item_idx] = reinterpret_cast(output[batch_idx][index].data()); m_prime[q_item_idx] = -std::numeric_limits::infinity(); + logsumexp_block[q_item_idx] = &logsumexp[batch_idx][index]; } #if 0 // this for now makes things slower @@ -473,56 +497,30 @@ __global__ void attention_kernel( output_block[q_item_idx][k] = tmp; } } + + if (compute_logsumexp) { +#pragma unroll + for (int64_t q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + *logsumexp_block[q_item_idx] = + m_prime[q_item_idx] + std::log(s_prime[q_item_idx]); + } + } } -at::Tensor attention( +template +void launch_attention( + at::Tensor& res, + at::Tensor& logsumexp, const at::Tensor& query, const at::Tensor& key, - const at::Tensor& value - // const at::Tensor& mask -) { - TORCH_CHECK(query.dim() == key.dim()); - TORCH_CHECK(query.dim() == value.dim()); - // TORCH_CHECK(query.dim() == mask.dim()); - TORCH_CHECK(query.dim() == 3); - TORCH_CHECK(query.size(2) == key.size(2)); - TORCH_CHECK(query.size(0) == key.size(0)); - - TORCH_CHECK(query.size(0) == value.size(0)); - TORCH_CHECK(key.size(1) == value.size(1)); - TORCH_CHECK( - query.size(2) == - value.size(2)); // TODO: drop this limitation in the future - - TORCH_CHECK(query.is_cuda(), "query must be a CUDA tensor"); - TORCH_CHECK(key.is_cuda(), "key must be a CUDA tensor"); - TORCH_CHECK(value.is_cuda(), "value must be a CUDA tensor"); - - TORCH_CHECK(!query.is_sparse(), "query must be a dense tensor"); - TORCH_CHECK(!key.is_sparse(), "key must be a dense tensor"); - TORCH_CHECK(!value.is_sparse(), "value must be a dense tensor"); - - // TODO drop this limitation in the future - TORCH_CHECK(query.is_contiguous()); - TORCH_CHECK(key.is_contiguous()); - TORCH_CHECK(value.is_contiguous()); - - // TODO: support other dtypes in the future - TORCH_CHECK( - query.scalar_type() == at::ScalarType::Float, - "Only float32 type is supported for now"); - - at::cuda::CUDAGuard device_guard(query.device()); + const at::Tensor& value) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int64_t B = query.size(0); int64_t M = query.size(1); int64_t N = key.size(1); int64_t K = query.size(2); - at::Tensor res = at::zeros({B, M, K}, query.options()); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - constexpr int WARP_SIZE = 4; constexpr int kBlockSizeK = 32; @@ -546,8 +544,10 @@ at::Tensor attention( kBlockSizeK, kBlockSizeQ, WARP_SIZE, - BUFFER_SIZE><<>>( + BUFFER_SIZE, + compute_logsumexp><<>>( res.packed_accessor(), + logsumexp.packed_accessor(), query.packed_accessor(), key.packed_accessor(), value.packed_accessor()); @@ -561,8 +561,10 @@ at::Tensor attention( kBlockSizeK, kBlockSizeQ, WARP_SIZE, - BUFFER_SIZE><<>>( + BUFFER_SIZE, + compute_logsumexp><<>>( res.packed_accessor(), + logsumexp.packed_accessor(), query.packed_accessor(), key.packed_accessor(), value.packed_accessor()); @@ -577,14 +579,611 @@ at::Tensor attention( kBlockSizeK, kBlockSizeQ, WARP_SIZE, - BUFFER_SIZE><<>>( + BUFFER_SIZE, + compute_logsumexp><<>>( res.packed_accessor(), + logsumexp.packed_accessor(), query.packed_accessor(), key.packed_accessor(), value.packed_accessor()); } +} + +std::tuple attention( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + bool compute_logsumexp + // const at::Tensor& mask +) { + TORCH_CHECK(query.dim() == key.dim()); + TORCH_CHECK(query.dim() == value.dim()); + // 
TORCH_CHECK(query.dim() == mask.dim()); + TORCH_CHECK(query.dim() == 3); + TORCH_CHECK(query.size(2) == key.size(2)); + TORCH_CHECK(query.size(0) == key.size(0)); + + TORCH_CHECK(query.size(0) == value.size(0)); + TORCH_CHECK(key.size(1) == value.size(1)); + TORCH_CHECK( + query.size(2) == + value.size(2)); // TODO: drop this limitation in the future + + TORCH_CHECK(query.is_cuda(), "query must be a CUDA tensor"); + TORCH_CHECK(key.is_cuda(), "key must be a CUDA tensor"); + TORCH_CHECK(value.is_cuda(), "value must be a CUDA tensor"); + + TORCH_CHECK(!query.is_sparse(), "query must be a dense tensor"); + TORCH_CHECK(!key.is_sparse(), "key must be a dense tensor"); + TORCH_CHECK(!value.is_sparse(), "value must be a dense tensor"); + + // TODO drop this limitation in the future + TORCH_CHECK(query.is_contiguous()); + TORCH_CHECK(key.is_contiguous()); + TORCH_CHECK(value.is_contiguous()); + + // TODO: support other dtypes in the future + TORCH_CHECK( + query.scalar_type() == at::ScalarType::Float, + "Only float32 type is supported for now"); + + at::cuda::CUDAGuard device_guard(query.device()); + + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + int64_t K = query.size(2); + + at::Tensor res = at::zeros({B, M, K}, query.options()); + at::Tensor logsumexp = at::empty({B, M}, query.options()); + + // have to pass compute_logsumexp as a template parameter + // otherwise there is a slowdown in the kernel... + if (compute_logsumexp) { + launch_attention(res, logsumexp, query, key, value); + } else { + launch_attention(res, logsumexp, query, key, value); + } + + AT_CUDA_CHECK(cudaGetLastError()); + + return std::make_tuple(res, logsumexp); +} + +template < + typename scalar_t, + typename vec_t, + int kBlockSizeQ, + int kBlockSizeK, + int TILE_SIZEQ, + int TILE_SIZEK, + bool check_bounds> +__global__ void attention_backward_grad_v_kernel( + at::PackedTensorAccessor grad_v, + at::PackedTensorAccessor grad_out, + at::PackedTensorAccessor query, + at::PackedTensorAccessor key, + at::PackedTensorAccessor value, + at::PackedTensorAccessor tmp_sum_i, + at::PackedTensorAccessor logsumexp_normalizer) { + int64_t K = query.size(2); + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + + constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); + + int64_t batch_idx = blockIdx.z; + int64_t query_idx = + blockIdx.x * blockDim.x * kBlockSizeQ + threadIdx.x * kBlockSizeQ; + int64_t l = blockIdx.y * blockDim.y * kBlockSizeK + threadIdx.y * kBlockSizeK; + + __shared__ scalar_t fact[TILE_SIZEQ][TILE_SIZEK + 1]; + +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { +#pragma unroll + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + fact[kBlockSizeQ * threadIdx.x + q_item_idx] + [kBlockSizeK * threadIdx.y + k_item_idx] = 0; + } + } + + scalar_t normalizer[kBlockSizeQ]; + scalar_t tmp_sum[kBlockSizeQ] = {0}; + + vec_t *qb[kBlockSizeQ], *kb[kBlockSizeK], *vb[kBlockSizeK], *gb[kBlockSizeQ], + *gbb[TILE_SIZEQ]; + scalar_t maskQ[kBlockSizeQ], maskK[kBlockSizeK]; + + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + int64_t index = l + k_item_idx; + maskK[k_item_idx] = index >= N ? 
scalar_t(0) : scalar_t(1); + if (check_bounds) + index = min(index, N - 1); + kb[k_item_idx] = reinterpret_cast(key[batch_idx][index].data()); + vb[k_item_idx] = reinterpret_cast(value[batch_idx][index].data()); + } + + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + int64_t index = query_idx + q_item_idx; + maskQ[q_item_idx] = index >= M ? scalar_t(0) : scalar_t(1); + if (check_bounds) + index = min(index, M - 1); + qb[q_item_idx] = reinterpret_cast(query[batch_idx][index].data()); + gb[q_item_idx] = + reinterpret_cast(grad_out[batch_idx][index].data()); + } + + for (int64_t i = 0; i < TILE_SIZEQ; i++) { + int64_t index = query_idx + i - kBlockSizeQ * threadIdx.x; + if (check_bounds) + index = min(index, M - 1); + gbb[i] = reinterpret_cast(grad_out[batch_idx][index].data()); + } + + for (int i = 0; i < kBlockSizeQ; i++) { + int64_t index = query_idx + i; + if (check_bounds) + index = min(index, M - 1); + normalizer[i] = logsumexp_normalizer[batch_idx][index]; + } + + scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; + scalar_t grad_attn_v[kBlockSizeQ][kBlockSizeK] = {0}; + scalar_t scale = 1.0 / std::sqrt(scalar_t(K)); + + for (int64_t k = 0; k < K / kVecSize; k += 1) { +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + vec_t kk = __ldg(kb[k_item_idx] + k); + iMul(scale, &kk); + vec_t tt = __ldg(vb[k_item_idx] + k); +#pragma unroll + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + sputnik::VectorCompute::Dot( + __ldg(qb[q_item_idx] + k), kk, &attn_v[q_item_idx][k_item_idx]); + sputnik::VectorCompute::Dot( + __ldg(gb[q_item_idx] + k), + tt, + &grad_attn_v[q_item_idx][k_item_idx]); + } + } + } +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { +#pragma unroll + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + attn_v[q_item_idx][k_item_idx] = + std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]) * + maskQ[q_item_idx] * maskK[k_item_idx]; + } + } + +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { +#pragma unroll + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + fact[kBlockSizeQ * threadIdx.x + q_item_idx] + [kBlockSizeK * threadIdx.y + k_item_idx] = + attn_v[q_item_idx][k_item_idx]; + tmp_sum[q_item_idx] += + attn_v[q_item_idx][k_item_idx] * grad_attn_v[q_item_idx][k_item_idx]; + } + } + __syncthreads(); + + for (int64_t k = threadIdx.x; k < K / kVecSize; k += blockDim.x) { + vec_t res[kBlockSizeK] = {0}; +#pragma unroll + for (int64_t i = 0; i < TILE_SIZEQ; i++) { + vec_t kk = __ldg(gbb[i] + k); +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + sputnik::VectorCompute::FMA( + fact[i][kBlockSizeK * threadIdx.y + k_item_idx], + kk, + &res[k_item_idx]); + } + } +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + int64_t index = l + k_item_idx; + if (check_bounds) + index = min(index, N - 1); + myGpuAtomicAdd(&grad_v[batch_idx][index][k * kVecSize], res[k_item_idx]); + } + } + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + int64_t index = query_idx + q_item_idx; + if (check_bounds) + index = min(index, M - 1); + myGpuAtomicAdd(&tmp_sum_i[batch_idx][index], tmp_sum[q_item_idx]); + } +} + +template < + typename scalar_t, + typename vec_t, + int kBlockSizeQ, + int kBlockSizeK, + int TILE_SIZEQ, + int TILE_SIZEK, + bool check_bounds> +__global__ void attention_backward_grad_qk_kernel( + 
at::PackedTensorAccessor grad_q, + at::PackedTensorAccessor grad_k, + at::PackedTensorAccessor grad_out, + at::PackedTensorAccessor query, + at::PackedTensorAccessor key, + at::PackedTensorAccessor value, + at::PackedTensorAccessor tmp_sum_i, + at::PackedTensorAccessor logsumexp_normalizer) { + int64_t K = query.size(2); + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + + constexpr int kVecSize = sizeof(vec_t) / sizeof(scalar_t); + + int64_t batch_idx = blockIdx.z; + int64_t query_idx = + blockIdx.x * blockDim.x * kBlockSizeQ + threadIdx.x * kBlockSizeQ; + int64_t l = blockIdx.y * blockDim.y * kBlockSizeK + threadIdx.y * kBlockSizeK; + + __shared__ scalar_t fact[TILE_SIZEQ][TILE_SIZEK + 1]; + +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { +#pragma unroll + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + fact[kBlockSizeQ * threadIdx.x + q_item_idx] + [kBlockSizeK * threadIdx.y + k_item_idx] = 0; + } + } + + scalar_t normalizer[kBlockSizeQ]; + scalar_t tmp_sum[kBlockSizeQ]; + + vec_t *qb[kBlockSizeQ], *kb[kBlockSizeK], *vb[kBlockSizeK], *gb[kBlockSizeQ], + *qbb[TILE_SIZEQ], *kbb[TILE_SIZEK]; + scalar_t maskQ[kBlockSizeQ], maskK[kBlockSizeK]; + + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + int64_t index = l + k_item_idx; + maskK[k_item_idx] = index >= N ? scalar_t(0) : scalar_t(1); + if (check_bounds) + index = min(index, N - 1); + kb[k_item_idx] = reinterpret_cast(key[batch_idx][index].data()); + vb[k_item_idx] = reinterpret_cast(value[batch_idx][index].data()); + } + + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + int64_t index = query_idx + q_item_idx; + maskQ[q_item_idx] = index >= M ? scalar_t(0) : scalar_t(1); + if (check_bounds) + index = min(index, M - 1); + qb[q_item_idx] = reinterpret_cast(query[batch_idx][index].data()); + gb[q_item_idx] = + reinterpret_cast(grad_out[batch_idx][index].data()); + } + for (int64_t i = 0; i < TILE_SIZEQ; i++) { + int64_t index = query_idx + i - kBlockSizeQ * threadIdx.x; + if (check_bounds) + index = min(index, M - 1); + qbb[i] = reinterpret_cast(query[batch_idx][index].data()); + } + + for (int64_t i = 0; i < TILE_SIZEK; i++) { + int64_t index = l + i - kBlockSizeK * threadIdx.y; + if (check_bounds) + index = min(index, N - 1); + kbb[i] = reinterpret_cast(key[batch_idx][index].data()); + } + + for (int i = 0; i < kBlockSizeQ; i++) { + int64_t index = query_idx + i; + if (check_bounds) + index = min(index, M - 1); + normalizer[i] = logsumexp_normalizer[batch_idx][index]; + tmp_sum[i] = tmp_sum_i[batch_idx][index]; + } + + scalar_t attn_v[kBlockSizeQ][kBlockSizeK] = {0}; + scalar_t grad_attn_v[kBlockSizeQ][kBlockSizeK] = {0}; + scalar_t scale = 1.0 / std::sqrt(scalar_t(K)); + + for (int64_t k = 0; k < K / kVecSize; k += 1) { +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + vec_t kk = __ldg(kb[k_item_idx] + k); + iMul(scale, &kk); + vec_t tt = __ldg(vb[k_item_idx] + k); +#pragma unroll + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + sputnik::VectorCompute::Dot( + __ldg(qb[q_item_idx] + k), kk, &attn_v[q_item_idx][k_item_idx]); + sputnik::VectorCompute::Dot( + __ldg(gb[q_item_idx] + k), + tt, + &grad_attn_v[q_item_idx][k_item_idx]); + } + } + } +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { +#pragma unroll + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + attn_v[q_item_idx][k_item_idx] = + 
std::exp(attn_v[q_item_idx][k_item_idx] - normalizer[q_item_idx]) * + maskQ[q_item_idx] * maskK[k_item_idx]; + } + } + +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { +#pragma unroll + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + fact[kBlockSizeQ * threadIdx.x + q_item_idx] + [kBlockSizeK * threadIdx.y + k_item_idx] = + attn_v[q_item_idx][k_item_idx] * scale * + (grad_attn_v[q_item_idx][k_item_idx] - tmp_sum[q_item_idx]); + } + } + __syncthreads(); + + for (int64_t k = threadIdx.y; k < K / kVecSize; k += blockDim.y) { + vec_t res[kBlockSizeQ] = {0}; +#pragma unroll + for (int64_t i = 0; i < TILE_SIZEK; i++) { + vec_t kk = __ldg(kbb[i] + k); +#pragma unroll + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + sputnik::VectorCompute::FMA( + fact[kBlockSizeQ * threadIdx.x + q_item_idx][i], + kk, + &res[q_item_idx]); + } + } +#pragma unroll + for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { + int64_t index = query_idx + q_item_idx; + if (check_bounds) + index = min(index, M - 1); + myGpuAtomicAdd(&grad_q[batch_idx][index][k * kVecSize], res[q_item_idx]); + } + } + + for (int64_t k = threadIdx.x; k < K / kVecSize; k += blockDim.x) { + vec_t res[kBlockSizeK] = {0}; +#pragma unroll + for (int64_t i = 0; i < TILE_SIZEQ; i++) { + vec_t kk = __ldg(qbb[i] + k); +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + sputnik::VectorCompute::FMA( + fact[i][kBlockSizeK * threadIdx.y + k_item_idx], + kk, + &res[k_item_idx]); + } + } +#pragma unroll + for (int k_item_idx = 0; k_item_idx < kBlockSizeK; k_item_idx++) { + int64_t index = l + k_item_idx; + if (check_bounds) + index = min(index, N - 1); + myGpuAtomicAdd(&grad_k[batch_idx][index][k * kVecSize], res[k_item_idx]); + } + } +} + +template +void launch_attention_backward( + at::Tensor& grad_q, + at::Tensor& grad_k, + at::Tensor& grad_v, + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& logsumexp, + at::Tensor& tmp_sum_i) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + + constexpr int TILE_SIZEQ = 32; + constexpr int TILE_SIZEK = 32; + + constexpr int64_t kBlockSizeQ = 4; + constexpr int64_t kBlockSizeK = 8; + + dim3 grid( + ceil_div(M, int64_t(TILE_SIZEQ)), ceil_div(N, int64_t(TILE_SIZEK)), B); + dim3 block(TILE_SIZEQ / kBlockSizeQ, TILE_SIZEK / kBlockSizeK); + + constexpr int TILE_SIZEQ2 = 32; + constexpr int TILE_SIZEK2 = 32; + + constexpr int64_t kBlockSizeQ2 = 4; + constexpr int64_t kBlockSizeK2 = 4; + + dim3 grid2( + ceil_div(M, int64_t(TILE_SIZEQ2)), ceil_div(N, int64_t(TILE_SIZEK2)), B); + dim3 block2(TILE_SIZEQ2 / kBlockSizeQ2, TILE_SIZEK2 / kBlockSizeK2); + + // the bounds checking in device code is very expensive, making the code + // around 25% slower. So let's skip those checks if possible. 
+ if ((M % TILE_SIZEQ == 0) && (N % TILE_SIZEK == 0)) { + attention_backward_grad_v_kernel< + scalar_t, + vec_t, + kBlockSizeQ, + kBlockSizeK, + TILE_SIZEQ, + TILE_SIZEK, + false><<>>( + grad_v.packed_accessor(), + grad_out.packed_accessor(), + query.packed_accessor(), + key.packed_accessor(), + value.packed_accessor(), + tmp_sum_i.packed_accessor(), + logsumexp.packed_accessor()); + } else { + attention_backward_grad_v_kernel< + scalar_t, + vec_t, + kBlockSizeQ, + kBlockSizeK, + TILE_SIZEQ, + TILE_SIZEK, + true><<>>( + grad_v.packed_accessor(), + grad_out.packed_accessor(), + query.packed_accessor(), + key.packed_accessor(), + value.packed_accessor(), + tmp_sum_i.packed_accessor(), + logsumexp.packed_accessor()); + } - return res; + if ((M % TILE_SIZEQ2 == 0) && (N % TILE_SIZEK2 == 0)) { + attention_backward_grad_qk_kernel< + scalar_t, + vec_t, + kBlockSizeQ2, + kBlockSizeK2, + TILE_SIZEQ2, + TILE_SIZEK2, + false><<>>( + grad_q.packed_accessor(), + grad_k.packed_accessor(), + grad_out.packed_accessor(), + query.packed_accessor(), + key.packed_accessor(), + value.packed_accessor(), + tmp_sum_i.packed_accessor(), + logsumexp.packed_accessor()); + } else { + attention_backward_grad_qk_kernel< + scalar_t, + vec_t, + kBlockSizeQ2, + kBlockSizeK2, + TILE_SIZEQ2, + TILE_SIZEK2, + true><<>>( + grad_q.packed_accessor(), + grad_k.packed_accessor(), + grad_out.packed_accessor(), + query.packed_accessor(), + key.packed_accessor(), + value.packed_accessor(), + tmp_sum_i.packed_accessor(), + logsumexp.packed_accessor()); + } +} + +std::tuple attention_backward( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& logsumexp + // const at::Tensor& mask +) { + TORCH_CHECK(query.dim() == grad_out.dim()); + TORCH_CHECK(query.dim() == key.dim()); + TORCH_CHECK(query.dim() == value.dim()); + // TORCH_CHECK(query.dim() == mask.dim()); + TORCH_CHECK(query.dim() == 3); + + TORCH_CHECK(query.size(0) == grad_out.size(0)); + TORCH_CHECK(query.size(1) == grad_out.size(1)); + TORCH_CHECK(query.size(2) == grad_out.size(2)); + + TORCH_CHECK(query.size(2) == key.size(2)); + TORCH_CHECK(query.size(0) == key.size(0)); + + TORCH_CHECK(query.size(0) == value.size(0)); + TORCH_CHECK(key.size(1) == value.size(1)); + TORCH_CHECK( + query.size(2) == + value.size(2)); // TODO: drop this limitation in the future + + TORCH_CHECK(query.is_cuda(), "query must be a CUDA tensor"); + TORCH_CHECK(key.is_cuda(), "key must be a CUDA tensor"); + TORCH_CHECK(value.is_cuda(), "value must be a CUDA tensor"); + TORCH_CHECK(grad_out.is_cuda(), "grad_out must be a CUDA tensor"); + + TORCH_CHECK(!query.is_sparse(), "query must be a dense tensor"); + TORCH_CHECK(!key.is_sparse(), "key must be a dense tensor"); + TORCH_CHECK(!value.is_sparse(), "value must be a dense tensor"); + TORCH_CHECK(!grad_out.is_sparse(), "grad_out must be a dense tensor"); + + // TODO: support other dtypes in the future + TORCH_CHECK( + query.scalar_type() == at::ScalarType::Float, + "Only float32 type is supported for now"); + + at::cuda::CUDAGuard device_guard(query.device()); + + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + int64_t K = query.size(2); + + at::Tensor grad_q = at::zeros_like(query); + at::Tensor grad_k = at::zeros_like(key); + at::Tensor grad_v = at::zeros_like(value); + + at::Tensor tmp_sum_i = at::zeros({B, M}, query.options()); + + // using scalar_t = float; + // using vec_t = float4; + // using vec_t = float; + + if ((K % 4) == 0) { + 
launch_attention_backward( + grad_q, + grad_k, + grad_v, + grad_out, + query, + key, + value, + logsumexp, + tmp_sum_i); + } else if ((K % 2) == 0) { + launch_attention_backward( + grad_q, + grad_k, + grad_v, + grad_out, + query, + key, + value, + logsumexp, + tmp_sum_i); + } else { + launch_attention_backward( + grad_q, + grad_k, + grad_v, + grad_out, + query, + key, + value, + logsumexp, + tmp_sum_i); + } + + AT_CUDA_CHECK(cudaGetLastError()); + + return std::make_tuple(grad_q, grad_k, grad_v); } } // namespace @@ -593,4 +1192,7 @@ TORCH_LIBRARY_IMPL(xformers, CUDA, m) { m.impl( TORCH_SELECTIVE_NAME("xformers::efficient_attention"), TORCH_FN(attention)); + m.impl( + TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward"), + TORCH_FN(attention_backward)); } diff --git a/xformers/ops.py b/xformers/ops.py index 2eb4412051..355c170ca2 100644 --- a/xformers/ops.py +++ b/xformers/ops.py @@ -29,6 +29,22 @@ def masked_matmul(a, b, mask=None): return att +class _MemoryEfficientAttentionOp(torch.autograd.Function): + @staticmethod + def forward(ctx, query, key, value): + out, lse = torch.ops.xformers.efficient_attention(query, key, value, True) + ctx.save_for_backward(query, key, value, lse) + return out + + @staticmethod + def backward(ctx, grad): + query, key, value, lse = ctx.saved_tensors + grad_q, grad_k, grad_v = torch.ops.xformers.efficient_attention_backward( + grad, query, key, value, lse + ) + return grad_q, grad_k, grad_v + + def memory_efficient_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor ): @@ -36,11 +52,8 @@ def memory_efficient_attention( Implements the memory-efficient attention mechanism following `"Self-Attention Does Not Need O(n^2) Memory" `_. - For now, only forward in inference-mode is supported. """ - # don't support backwards for now - assert query.requires_grad is False - assert key.requires_grad is False - assert value.requires_grad is False - - return torch.ops.xformers.efficient_attention(query, key, value) + # fast-path that doesn't require computing the logsumexp for backward computation + if all(x.requires_grad is False for x in [query, key, value]): + return torch.ops.xformers.efficient_attention(query, key, value, False)[0] + return _MemoryEfficientAttentionOp.apply(query, key, value)
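Usage note: with this patch, xformers.ops.memory_efficient_attention becomes differentiable. When any input requires grad, the forward op is invoked with compute_logsumexp=True and additionally returns a per-query tensor lse[b][i] = log(sum_j exp(q_i . k_j / sqrt(K))); the backward op uses it to recompute the attention weights as exp(q_i . k_j / sqrt(K) - lse[b][i]) instead of materializing the full attention matrix. When no input requires grad, the Python wrapper keeps the fast path and calls the forward op with compute_logsumexp=False, so inference does not pay for the extra (B, M) tensor. The sketch below is a minimal example of the differentiable path, assuming a build of xformers that includes these kernels; the shapes, tolerance, and variable names are illustrative and not taken from the patch.

import torch
import xformers.ops

device = "cuda" if torch.cuda.is_available() else "cpu"
B, M, K = 2, 128, 32  # batch, sequence length, head dim (illustrative; float32 only)

query = torch.randn(B, M, K, device=device, requires_grad=True)
key = torch.randn(B, M, K, device=device, requires_grad=True)
value = torch.randn(B, M, K, device=device, requires_grad=True)

# Differentiable path: the forward saves (query, key, value, logsumexp), and the
# backward dispatches to xformers::efficient_attention_backward to produce
# grad_q, grad_k and grad_v.
out = xformers.ops.memory_efficient_attention(query, key, value)
out.sum().backward()

# Sanity check against naive softmax attention (same 1/sqrt(K) scaling as the kernels).
attn = torch.softmax(query @ key.transpose(-2, -1) / K ** 0.5, dim=-1)
ref = attn @ value
assert torch.allclose(out, ref, atol=1e-4)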