Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix struct multiarg #2048

Merged
merged 1 commit into from
Aug 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 41 additions & 17 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,12 +905,30 @@ class EnzymeBase {
constants.push_back(DIFFE_TYPE::DUP_ARG);
}

ssize_t interleaved = -1;

size_t maxsize;
#if LLVM_VERSION_MAJOR >= 14
for (unsigned i = 1 + sret; i < CI->arg_size(); ++i)
maxsize = CI->arg_size();
#else
for (unsigned i = 1 + sret; i < CI->getNumArgOperands(); ++i)
maxsize = CI->getNumArgOperands();
#endif
{
size_t num_args = maxsize;
for (unsigned i = 1 + sret; i < maxsize; ++i) {
Value *res = CI->getArgOperand(i);
auto metaString = getMetadataName(res);
if (metaString && startsWith(*metaString, "enzyme_")) {
if (*metaString == "enzyme_interleave") {
maxsize = i;
interleaved = i + 1;
break;
}
}
}

DIFFE_TYPE last_ty = DIFFE_TYPE::DUP_ARG;

for (ssize_t i = 1 + sret; (size_t)i < maxsize; ++i) {
Value *res = CI->getArgOperand(i);
auto metaString = getMetadataName(res);
#if LLVM_VERSION_MAJOR > 16
Expand Down Expand Up @@ -1186,7 +1204,10 @@ class EnzymeBase {
overwritten_args[truei] = overwritten;

auto PTy = FT->getParamType(truei);
DIFFE_TYPE ty = opt_ty ? *opt_ty : whatType(PTy, mode);
DIFFE_TYPE ty =
opt_ty ? *opt_ty
: ((interleaved == -1) ? whatType(PTy, mode) : last_ty);
last_ty = ty;

constants.push_back(ty);

Expand Down Expand Up @@ -1250,7 +1271,8 @@ class EnzymeBase {

args.push_back(res);
if (ty == DIFFE_TYPE::DUP_ARG || ty == DIFFE_TYPE::DUP_NONEED) {
++i;
if (interleaved == -1)
++i;

Value *res = nullptr;
#if LLVM_VERSION_MAJOR >= 16
Expand All @@ -1260,22 +1282,19 @@ class EnzymeBase {
#endif

for (unsigned v = 0; v < width; ++v) {
#if LLVM_VERSION_MAJOR >= 14
if (i >= CI->arg_size())
#else
if (i >= CI->getNumArgOperands())
#endif
{
if ((size_t)((interleaved == -1) ? i : interleaved) >= num_args) {
EmitFailure("MissingArgShadow", CI->getDebugLoc(), CI,
"__enzyme_autodiff missing argument shadow at index ",
i, ", need shadow of type ", *PTy,
*((interleaved == -1) ? &i : &interleaved),
", need shadow of type ", *PTy,
" to shadow primal argument ", *args.back(),
" at call ", *CI);
return {};
}

// cast diffe
Value *element = CI->getArgOperand(i);
Value *element =
CI->getArgOperand((interleaved == -1) ? i : interleaved);
if (batch) {
if (auto elementPtrTy = dyn_cast<PointerType>(element->getType())) {
element = Builder.CreateBitCast(
Expand All @@ -1290,14 +1309,16 @@ class EnzymeBase {
} else {
EmitFailure(
"NonPointerBatch", CI->getDebugLoc(), CI,
"Batched argument at index ", i,
"Batched argument at index ",
*((interleaved == -1) ? &i : &interleaved),
" must be of pointer type, found: ", *element->getType());
return {};
}
}
if (PTy != element->getType()) {
element = castToDiffeFunctionArgType(Builder, CI, FT, PTy, i, mode,
element, truei);
element = castToDiffeFunctionArgType(
Builder, CI, FT, PTy, (interleaved == -1) ? i : interleaved,
mode, element, truei);
if (!element) {
return {};
}
Expand All @@ -1310,13 +1331,16 @@ class EnzymeBase {
element->getType(), width)),
element, {v});

if (v < width - 1 && !batch) {
if (v < width - 1 && !batch && (interleaved == -1)) {
++i;
}

} else {
res = element;
}

if (interleaved != -1)
interleaved++;
}

args.push_back(res);
Expand Down
41 changes: 26 additions & 15 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2718,22 +2718,33 @@ getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset, size_t valSz,
// all sub uses
if (auto MTI = dyn_cast<MemTransferInst>(U)) {
if (auto CI = dyn_cast<ConstantInt>(MTI->getLength())) {
if (MTI->getOperand(0) == ptr && suboff == 0 &&
CI->getValue().uge(offset + valSz)) {
size_t midoffset = 0;
auto AI2 = getBaseAndOffset(MTI->getOperand(1), midoffset);
if (!AI2) {
legal = false;
return options;
}
if (midoffset != 0) {
legal = false;
return options;
}
for (const auto &pair3 : findAllUsersOf(AI2)) {
todo.emplace_back(std::move(pair3));
if (MTI->getOperand(0) == ptr) {
auto storeSz = CI->getValue();

// If store is before the load would start
if ((storeSz + suboff).ule(offset))
continue;

// if store starts after load would start
if (offset + valSz <= suboff)
continue;

if (suboff == 0 && CI->getValue().uge(offset + valSz)) {
size_t midoffset = 0;
auto AI2 = getBaseAndOffset(MTI->getOperand(1), midoffset);
if (!AI2) {
legal = false;
return options;
}
if (midoffset != 0) {
legal = false;
return options;
}
for (const auto &pair3 : findAllUsersOf(AI2)) {
todo.emplace_back(std::move(pair3));
}
continue;
}
continue;
}
}
}
Expand Down
60 changes: 55 additions & 5 deletions enzyme/include/enzyme/utils
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

extern int enzyme_interleave;

extern int enzyme_dup;
extern int enzyme_dupnoneed;
extern int enzyme_out;
Expand Down Expand Up @@ -238,6 +240,54 @@ namespace enzyme {
return enzyme::tuple{enzyme_const, arg.value};
}

template < typename T >
__attribute__((always_inline))
auto expand_primals(const enzyme::Duplicated<T> & arg) {
return enzyme::tuple{enzyme_dup, arg.value};
}

template < typename T >
__attribute__((always_inline))
auto expand_primals(const enzyme::DuplicatedNoNeed<T> & arg) {
return enzyme::tuple{enzyme_dupnoneed, arg.value};
}

template < typename T >
__attribute__((always_inline))
auto expand_primals(const enzyme::Active<T> & arg) {
return enzyme::tuple<int, T>{enzyme_out, arg.value};
}

template < typename T >
__attribute__((always_inline))
auto expand_primals(const enzyme::Const<T> & arg) {
return enzyme::tuple{enzyme_const, arg.value};
}

template < typename T >
__attribute__((always_inline))
auto expand_shadows(const enzyme::Duplicated<T> & arg) {
return enzyme::tuple{arg.shadow};
}

template < typename T >
__attribute__((always_inline))
auto expand_shadows(const enzyme::DuplicatedNoNeed<T> & arg) {
return enzyme::tuple{arg.shadow};
}

template < typename T >
__attribute__((always_inline))
auto expand_shadows(const enzyme::Active<T> & arg) {
return enzyme::tuple{};
}

template < typename T >
__attribute__((always_inline))
auto expand_shadows(const enzyme::Const<T> & arg) {
return enzyme::tuple{};
}

template < typename T >
__attribute__((always_inline))
auto primal_args(const enzyme::Duplicated<T> & arg) {
Expand Down Expand Up @@ -279,7 +329,7 @@ namespace enzyme {

template<typename function, typename RT, typename ...T>
struct templated_call<function, RT(T...)> {
static RT wrap(T... args, function* __restrict__ f) {
static RT wrap(function* __restrict__ f, T... args) {
return (*f)(args...);
}
};
Expand Down Expand Up @@ -311,7 +361,7 @@ namespace enzyme {
template <class return_type, class Tuple, std::size_t... I, typename... ExtraArgs>
__attribute__((always_inline))
static constexpr decltype(auto) impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>, ExtraArgs... args) {
return push_return_last(__enzyme_autodiff<return_type>(f, ret_attr, enzyme::get<I>(impl::forward<Tuple>(t))..., args...));
return push_return_last(__enzyme_autodiff<return_type>(f, ret_attr, args..., enzyme::get<I>(impl::forward<Tuple>(t))...));
}
};

Expand All @@ -320,7 +370,7 @@ namespace enzyme {
template <class return_type, class Tuple, std::size_t... I, typename... ExtraArgs>
__attribute__((always_inline))
static constexpr return_type impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>, ExtraArgs... args) {
return __enzyme_fwddiff<return_type>(f, ret_attr, enzyme::get<I>(impl::forward<Tuple>(t))..., args...);
return __enzyme_fwddiff<return_type>(f, ret_attr, args..., enzyme::get<I>(impl::forward<Tuple>(t))...);
}
};

Expand Down Expand Up @@ -466,7 +516,7 @@ namespace enzyme {
using primal_return_type = decltype(f(arg_type(args)...));
using functy = typename detail::function_type<primal_return_type, decltype(arg_type(args))...>::type;
using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type;
return autodiff_impl<return_type, DiffMode, function, functy, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
return autodiff_impl<return_type, DiffMode, function, functy, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_primals(args)..., enzyme::tuple{enzyme_interleave}, expand_shadows(args)...));
}

template < typename DiffMode, typename function, typename ... arg_types>
Expand All @@ -481,7 +531,7 @@ namespace enzyme {
using functy = typename detail::function_type<primal_return_type, decltype(arg_type(args))...>::type;
using RetActivity = typename detail::default_ret_activity<DiffMode, primal_return_type>::type;
using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type;
return autodiff_impl<return_type, DiffMode, function, functy, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
return autodiff_impl<return_type, DiffMode, function, functy, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_primals(args)..., enzyme::tuple{enzyme_interleave}, expand_shadows(args)...));
}
#pragma clang diagnostic pop

Expand Down
6 changes: 3 additions & 3 deletions enzyme/test/Integration/CppSugar/ptrref_args_fails.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ int test_failures() {
*/
const float dfdx = enzyme::get<0>(
enzyme::get<0>(
enzyme::autodiff<enzyme::Reverse>( // expected-error@/enzymeroot/enzyme/utils:233 {{no member named 'value' in 'enzyme::Active<const float &>'}} expected-note {{}}
g, enzyme::Active<const float&>{ x } // expected-error@/enzymeroot/enzyme/utils:46 {{static assertion failed due to requirement '!std::is_reference_v<const float &>': Reference/pointer active arguments don't make sense for AD!}} expected-note {{}} expected-note@/enzymeroot/enzyme/utils:485 {{}}
enzyme::autodiff<enzyme::Reverse>( // expected-error@/enzymeroot/enzyme/utils:259 {{no member named 'value' in 'enzyme::Active<const float &>'}} expected-note {{}}
g, enzyme::Active<const float&>{ x } // expected-error@/enzymeroot/enzyme/utils:48 {{static assertion failed due to requirement '!std::is_reference_v<const float &>': Reference/pointer active arguments don't make sense for AD!}} expected-note {{}} expected-note@/enzymeroot/enzyme/utils:535 {{}}
)
)
);
Expand All @@ -38,7 +38,7 @@ int test_failures() {
* mode
*/
float dfdx = 0;
enzyme::autodiff<enzyme::Reverse>( // expected-error@/enzymeroot/enzyme/utils:477 {{static assertion failed due to requirement 'detail::verify_dup_args<enzyme::ReverseMode<false>, enzyme::Duplicated<float>>::value': Non-reference/pointer Duplicated/DuplicatedNoNeed args don't make sense for Reverse mode AD}} expected-note {{}}
enzyme::autodiff<enzyme::Reverse>( // expected-error@/enzymeroot/enzyme/utils:527 {{static assertion failed due to requirement 'detail::verify_dup_args<enzyme::ReverseMode<false>, enzyme::Duplicated<float>>::value': Non-reference/pointer Duplicated/DuplicatedNoNeed args don't make sense for Reverse mode AD}} expected-note {{}}
f, enzyme::Duplicated<float>{ x, dfdx }
);
}
Expand Down
29 changes: 29 additions & 0 deletions enzyme/test/Integration/CppSugar/structmulitarg.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o
// - %loadClangEnzyme | %lli - ; fi RUN: if [ %llvmver -ge 11 ]; then %clang++
// -std=c++17 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi RUN: if [
// %llvmver -ge 11 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o -
// %loadClangEnzyme | %lli - ; fi RUN: if [ %llvmver -ge 11 ]; then %clang++
// -std=c++17 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi RUN: if [
// %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o -
// %newLoadClangEnzyme | %lli - ; fi RUN: if [ %llvmver -ge 12 ]; then %clang++
// -std=c++17 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi RUN:
// if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o -
// %newLoadClangEnzyme | %lli - ; fi RUN: if [ %llvmver -ge 12 ]; then %clang++
// -std=c++17 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi

#include "../test_utils.h"

#include <enzyme/enzyme>

struct pair {
double x;
double y;
};

double square(struct pair p) { return p.x * p.y; }

int main() {
double res = enzyme::get<0>(enzyme::autodiff<enzyme::Forward>(
square, enzyme::Duplicated{pair{2, 3}, pair{70, 110}}));
APPROX_EQ(res, 2 * 110 + 3 * 70, 1e-10);
}
Loading