From 6e155bf76d113043118b3ac037ad00da0aac3df8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 29 Nov 2024 20:56:54 +0100 Subject: [PATCH 1/6] try fix fft rev diff rule --- .../jax/Implementations/HLODerivatives.td | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td index 6d816fcec..1058716b8 100644 --- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td @@ -901,6 +901,11 @@ def FftMultiplier : GlobalExpr(op.getLoc(), op->getResult(0).getType(), ret_broadcast); }]>; +def FftIsIRFFT : GlobalExpr(op.getLoc(), builder.getDenseBoolArrayAttr(ArrayRef({cond}))); +}]>; + // Derivative rules def : HLODerivative<"AddOp", (Op $x, $y), [ @@ -997,12 +1002,17 @@ def : HLODerivative<"ExpOp", (Op $x), [(CheckedMul (DiffeRet), (Exp $x))]>; def : HLODerivative<"Expm1Op", (Op $x), [(CheckedMul (DiffeRet), (Exp $x))]>; -// TODO fix `rfft` and `irfft` derivatives: -// - `rfft` => divide `DiffeRet` elems by 2 except 1st elem, and last elem if `FftLength` is even -// - `irfft` => def : HLODerivative<"FftOp", (Op $x), [ - (Fft (DiffeRet), (FftType), (FftLength)) // TODO maybe we need to conjugate? or inverse fft + multiply by N? + (Multiply + (FftMultiplier), // TODO fix this + (Fft + (Select + (FftIsIRFFT), // if IRFFT + (Real (DiffeRet)), // call real(diff) + (DiffeRet), + (RevFftType), + (FftLength)))) ], (Fft (Shadow $x), (FftType), (FftLength)) >; From 0c6b832104687d468f55ed5964e6d7c1f84ec945 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 29 Nov 2024 21:18:32 +0100 Subject: [PATCH 2/6] fix typo --- src/enzyme_ad/jax/Implementations/HLODerivatives.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td index 1058716b8..8da5d660d 100644 --- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td @@ -1004,7 +1004,7 @@ def : HLODerivative<"Expm1Op", (Op $x), [(CheckedMul (DiffeRet), (Exp $x))]>; def : HLODerivative<"FftOp", (Op $x), [ - (Multiply + (Mul (FftMultiplier), // TODO fix this (Fft (Select From 61ea4414f1ad17bedb491db211421b45fa34b191 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 29 Nov 2024 21:52:19 +0100 Subject: [PATCH 3/6] fix typo --- src/enzyme_ad/jax/Implementations/HLODerivatives.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td index 8da5d660d..d48ab47cf 100644 --- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td @@ -1011,7 +1011,7 @@ def : HLODerivative<"FftOp", (Op $x), (FftIsIRFFT), // if IRFFT (Real (DiffeRet)), // call real(diff) (DiffeRet), - (RevFftType), + (FftTypeInverse), (FftLength)))) ], (Fft (Shadow $x), (FftType), (FftLength)) From 08a05ee92fd466e8051825caafc0245c63c5557a Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Mon, 16 Dec 2024 14:50:55 -0600 Subject: [PATCH 4/6] correct impl --- .../jax/Implementations/HLODerivatives.td | 40 +++++++++++++++++-- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td index d48ab47cf..da8ac71ed 100644 --- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td @@ -896,9 +896,43 @@ def FftMultiplier : GlobalExpr(op.getLoc(), builder.getDenseI64ArrayAttr(ArrayRef({N}))); - auto ret_broadcast = builder.create(op.getLoc(), op_type.clone(op_type.getShape(), builder.getI64Type()), ret_constant, builder.getI64VectorAttr(op_type.getShape())); - builder.create(op.getLoc(), op->getResult(0).getType(), ret_broadcast); + double value = N; + switch (op.getFftType()) { + case FftType::FFT: + break; + case FftType::IFFT: + value = 1 / value; + break; + case FftType::RFFT: + value /= 2; + break; + case FftType::IRFFT: + value = 2 / value; + break; + } + auto resTy = op->getResult(0).getType().cast(); + mlir::Value ret_constant = builder.create(op.getLoc(), SplatElementsAttr::get( + resTy, FloatAttr::get(resTy.getElementType(), value))); + + if (op.getFftType() == FftType::RFFT || op.getFftType() == FftType::IRFFT) { + auto RT = RankedTensorType::get({1}, resTy.getElementType()); + auto zero_constant = builder.create(op.getLoc(), SplatElementsAttr::get( + RT, FloatAttr::get(resTy.getElementType(), 0))); + auto end_constant = builder.create(op.getLoc(), SplatElementsAttr::get( + RT, FloatAttr::get(resTy.getElementType(), lengths.back()-1))); + + auto RT64 = RankedTensorType::get({1}, builder.getIntegerType(64)); + + Value start[] = { + builder.create(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(0))) + }; + Value end[] = { + builder.create(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(lengths.size()-1))) + }; + ret_constant = builder.create(op.getLoc(), resTy, ret_constant, zero_constant, start); + ret_constant = builder.create(op.getLoc(), resTy, ret_constant, end_constant, end); + } + ret_constant; }]>; def FftIsIRFFT : GlobalExpr Date: Mon, 16 Dec 2024 18:19:00 -0600 Subject: [PATCH 5/6] fixup --- src/enzyme_ad/jax/Implementations/HLODerivatives.td | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td index da8ac71ed..9040add94 100644 --- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td @@ -893,7 +893,7 @@ def FftLength : GlobalExprgetResult(0).getType().cast(); - auto lengths = op.getFftLength(); + auto lengths = op.getFftLengthAttr().getValues(); auto N = std::accumulate(lengths.begin(), lengths.end(), llvm::APInt(64, 1, true), std::multiplies{}).getSExtValue(); double value = N; @@ -919,15 +919,15 @@ def FftMultiplier : GlobalExpr(op.getLoc(), SplatElementsAttr::get( RT, FloatAttr::get(resTy.getElementType(), 0))); auto end_constant = builder.create(op.getLoc(), SplatElementsAttr::get( - RT, FloatAttr::get(resTy.getElementType(), lengths.back()-1))); + RT, FloatAttr::get(resTy.getElementType(), lengths[lengths.size()-1]-1))); auto RT64 = RankedTensorType::get({1}, builder.getIntegerType(64)); Value start[] = { - builder.create(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(0))) + builder.create(op.getLoc(), SplatElementsAttr::get(RT64, rewriter.getI64IntegerAttr(0))) }; Value end[] = { - builder.create(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(lengths.size()-1))) + builder.create(op.getLoc(), SplatElementsAttr::get(RT64, rewriter.getI64IntegerAttr(lengths.size()-1))) }; ret_constant = builder.create(op.getLoc(), resTy, ret_constant, zero_constant, start); ret_constant = builder.create(op.getLoc(), resTy, ret_constant, end_constant, end); From 03a79ae4c7a9f0513ff4983432b43f4d5def73a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 10 Jan 2025 11:28:39 +0100 Subject: [PATCH 6/6] refactor deriv of FFT currently multiplier set to 1 to start testing stuff --- .../jax/Implementations/HLODerivatives.td | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td index 9040add94..ea1c17e5f 100644 --- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td @@ -935,10 +935,7 @@ def FftMultiplier : GlobalExpr; -def FftIsIRFFT : GlobalExpr(op.getLoc(), builder.getDenseBoolArrayAttr(ArrayRef({cond}))); -}]>; +def SelectIfIRFFT : StaticSelect<"op.getFftType() == FftType::IRFFT">; // Derivative rules def : HLODerivative<"AddOp", (Op $x, $y), @@ -1039,14 +1036,19 @@ def : HLODerivative<"Expm1Op", (Op $x), [(CheckedMul (DiffeRet), (Exp $x))]>; def : HLODerivative<"FftOp", (Op $x), [ (Mul - (FftMultiplier), // TODO fix this + // multiplier + // (FftMultiplier), // TODO fix this + (HLOConstantFP<"1">), + // inverse fft (Fft - (Select - (FftIsIRFFT), // if IRFFT - (Real (DiffeRet)), // call real(diff) - (DiffeRet), + (SelectIfIRFFT + (Real (DiffeRet)), // IRFFT is complex to real, so reverse-mode needs to pass a real diff + (DiffeRet) + ), (FftTypeInverse), - (FftLength)))) + (FftLength) // TODO revise + ) + ) ], (Fft (Shadow $x), (FftType), (FftLength)) >;