Skip to content

Commit

Permalink
Cast gutils to correct subclass for applyChainRule (rust-lang#733)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich authored Jul 12, 2022
1 parent 95a6666 commit 969ddce
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -1732,23 +1732,23 @@ class AdjointGenerator
template <typename Func, typename... Args>
Value *applyChainRule(Type *diffType, IRBuilder<> &Builder, Func rule,
Args... args) {
return ((DiffeGradientUtils *)gutils)
return ((GradientUtils *)gutils)
->applyChainRule(diffType, Builder, rule, args...);
}

/// Unwraps a vector derivative from its internal representation and applies a
/// function f to each element.
template <typename Func, typename... Args>
void applyChainRule(IRBuilder<> &Builder, Func rule, Args... args) {
((DiffeGradientUtils *)gutils)->applyChainRule(Builder, rule, args...);
((GradientUtils *)gutils)->applyChainRule(Builder, rule, args...);
}

/// Unwraps an collection of constant vector derivatives from their internal
/// representations and applies a function f to each element.
template <typename Func>
void applyChainRule(ArrayRef<Value *> diffs, IRBuilder<> &Builder,
Func rule) {
((DiffeGradientUtils *)gutils)->applyChainRule(diffs, Builder, rule);
((GradientUtils *)gutils)->applyChainRule(diffs, Builder, rule);
}

bool shouldFree() {
Expand Down

0 comments on commit 969ddce

Please sign in to comment.