From 969ddce9f8da429071257a15b91ce3d9fe5f5439 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Tue, 12 Jul 2022 21:41:16 +0200 Subject: [PATCH] Cast gutils to correct subclass for applyChainRule (#733) --- enzyme/Enzyme/AdjointGenerator.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 7bb735125d87b..fa5c712c812aa 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -1732,7 +1732,7 @@ class AdjointGenerator template Value *applyChainRule(Type *diffType, IRBuilder<> &Builder, Func rule, Args... args) { - return ((DiffeGradientUtils *)gutils) + return ((GradientUtils *)gutils) ->applyChainRule(diffType, Builder, rule, args...); } @@ -1740,7 +1740,7 @@ class AdjointGenerator /// function f to each element. template void applyChainRule(IRBuilder<> &Builder, Func rule, Args... args) { - ((DiffeGradientUtils *)gutils)->applyChainRule(Builder, rule, args...); + ((GradientUtils *)gutils)->applyChainRule(Builder, rule, args...); } /// Unwraps an collection of constant vector derivatives from their internal @@ -1748,7 +1748,7 @@ class AdjointGenerator template void applyChainRule(ArrayRef diffs, IRBuilder<> &Builder, Func rule) { - ((DiffeGradientUtils *)gutils)->applyChainRule(diffs, Builder, rule); + ((GradientUtils *)gutils)->applyChainRule(diffs, Builder, rule); } bool shouldFree() {