diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td index 6d816fcec..ea1c17e5f 100644 --- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td @@ -893,14 +893,50 @@ def FftLength : GlobalExprgetResult(0).getType().cast(); - auto lengths = op.getFftLength(); + auto lengths = op.getFftLengthAttr().getValues(); auto N = std::accumulate(lengths.begin(), lengths.end(), llvm::APInt(64, 1, true), std::multiplies{}).getSExtValue(); - auto ret_constant = builder.create(op.getLoc(), builder.getDenseI64ArrayAttr(ArrayRef({N}))); - auto ret_broadcast = builder.create(op.getLoc(), op_type.clone(op_type.getShape(), builder.getI64Type()), ret_constant, builder.getI64VectorAttr(op_type.getShape())); - builder.create(op.getLoc(), op->getResult(0).getType(), ret_broadcast); + double value = N; + switch (op.getFftType()) { + case FftType::FFT: + break; + case FftType::IFFT: + value = 1 / value; + break; + case FftType::RFFT: + value /= 2; + break; + case FftType::IRFFT: + value = 2 / value; + break; + } + auto resTy = op->getResult(0).getType().cast(); + mlir::Value ret_constant = builder.create(op.getLoc(), SplatElementsAttr::get( + resTy, FloatAttr::get(resTy.getElementType(), value))); + + if (op.getFftType() == FftType::RFFT || op.getFftType() == FftType::IRFFT) { + auto RT = RankedTensorType::get({1}, resTy.getElementType()); + auto zero_constant = builder.create(op.getLoc(), SplatElementsAttr::get( + RT, FloatAttr::get(resTy.getElementType(), 0))); + auto end_constant = builder.create(op.getLoc(), SplatElementsAttr::get( + RT, FloatAttr::get(resTy.getElementType(), lengths[lengths.size()-1]-1))); + + auto RT64 = RankedTensorType::get({1}, builder.getIntegerType(64)); + + Value start[] = { + builder.create(op.getLoc(), SplatElementsAttr::get(RT64, rewriter.getI64IntegerAttr(0))) + }; + Value end[] = { + builder.create(op.getLoc(), SplatElementsAttr::get(RT64, rewriter.getI64IntegerAttr(lengths.size()-1))) + }; + ret_constant = builder.create(op.getLoc(), resTy, ret_constant, zero_constant, start); + ret_constant = builder.create(op.getLoc(), resTy, ret_constant, end_constant, end); + } + ret_constant; }]>; +def SelectIfIRFFT : StaticSelect<"op.getFftType() == FftType::IRFFT">; + // Derivative rules def : HLODerivative<"AddOp", (Op $x, $y), [ @@ -997,12 +1033,22 @@ def : HLODerivative<"ExpOp", (Op $x), [(CheckedMul (DiffeRet), (Exp $x))]>; def : HLODerivative<"Expm1Op", (Op $x), [(CheckedMul (DiffeRet), (Exp $x))]>; -// TODO fix `rfft` and `irfft` derivatives: -// - `rfft` => divide `DiffeRet` elems by 2 except 1st elem, and last elem if `FftLength` is even -// - `irfft` => def : HLODerivative<"FftOp", (Op $x), [ - (Fft (DiffeRet), (FftType), (FftLength)) // TODO maybe we need to conjugate? or inverse fft + multiply by N? + (Mul + // multiplier + // (FftMultiplier), // TODO fix this + (HLOConstantFP<"1">), + // inverse fft + (Fft + (SelectIfIRFFT + (Real (DiffeRet)), // IRFFT is complex to real, so reverse-mode needs to pass a real diff + (DiffeRet) + ), + (FftTypeInverse), + (FftLength) // TODO revise + ) + ) ], (Fft (Shadow $x), (FftType), (FftLength)) >;