Skip to content

Commit c86ab58

Browse files
jnthntatumcopybara-github
authored andcommitted
Add arbitrary count version for function adapter helpers.
PiperOrigin-RevId: 822644997
1 parent 38f6e94 commit c86ab58

File tree

2 files changed

+235
-73
lines changed

2 files changed

+235
-73
lines changed

runtime/function_adapter.h

Lines changed: 149 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,21 @@
1818
#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_
1919
#define THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_
2020

21+
#include <cstddef>
2122
#include <functional>
2223
#include <memory>
24+
#include <tuple>
25+
#include <utility>
2326
#include <vector>
2427

2528
#include "absl/base/nullability.h"
2629
#include "absl/functional/any_invocable.h"
27-
#include "absl/functional/bind_front.h"
2830
#include "absl/status/status.h"
2931
#include "absl/status/statusor.h"
3032
#include "absl/strings/str_cat.h"
3133
#include "absl/strings/string_view.h"
3234
#include "absl/types/span.h"
3335
#include "common/function_descriptor.h"
34-
#include "common/kind.h"
3536
#include "common/value.h"
3637
#include "internal/status_macros.h"
3738
#include "runtime/function.h"
@@ -94,79 +95,73 @@ struct AdaptedTypeTraits<const T&> {
9495
static T ToArg(AssignableType v) { return v; }
9596
};
9697

97-
template <typename... Args>
98-
struct KindAdderImpl;
99-
100-
template <typename Arg, typename... Args>
101-
struct KindAdderImpl<Arg, Args...> {
102-
static void AddTo(std::vector<cel::Kind>& args) {
103-
args.push_back(AdaptedKind<Arg>());
104-
KindAdderImpl<Args...>::AddTo(args);
98+
template <size_t I, typename... Args>
99+
struct AdaptHelperImpl {
100+
template <typename T>
101+
static absl::Status Apply(absl::Span<const Value> input, T& output) {
102+
static_assert(sizeof...(Args) > 0);
103+
static_assert(std::tuple_size_v<T> == sizeof...(Args));
104+
CEL_RETURN_IF_ERROR(HandleToAdaptedVisitor{input[I]}(&std::get<I>(output)));
105+
if constexpr (I == sizeof...(Args) - 1) {
106+
return absl::OkStatus();
107+
} else {
108+
CEL_RETURN_IF_ERROR(
109+
(AdaptHelperImpl<I + 1, Args...>::template Apply<T>(input, output)));
110+
}
111+
return absl::OkStatus();
105112
}
106113
};
107114

108-
template <>
109-
struct KindAdderImpl<> {
110-
static void AddTo(std::vector<cel::Kind>& args) {}
111-
};
112-
113115
template <typename... Args>
114-
struct KindAdder {
115-
static std::vector<cel::Kind> Kinds() {
116-
std::vector<cel::Kind> args;
117-
KindAdderImpl<Args...>::AddTo(args);
118-
return args;
116+
struct AdaptHelper {
117+
template <typename T>
118+
static absl::Status Apply(absl::Span<const Value> input, T& output) {
119+
return AdaptHelperImpl<0, Args...>::template Apply<T>(input, output);
119120
}
120121
};
121122

122-
template <typename T>
123-
struct ApplyReturnType {
124-
using type = absl::StatusOr<T>;
125-
};
126-
127-
template <typename T>
128-
struct ApplyReturnType<absl::StatusOr<T>> {
129-
using type = absl::StatusOr<T>;
130-
};
131-
132-
template <int N, typename Arg, typename... Args>
133-
struct IndexerImpl {
134-
using type = typename IndexerImpl<N - 1, Args...>::type;
135-
};
136-
137-
template <typename Arg, typename... Args>
138-
struct IndexerImpl<0, Arg, Args...> {
139-
using type = Arg;
140-
};
123+
template <typename... Args>
124+
struct ToArgsImpl {
125+
template <int I, typename T>
126+
struct El {
127+
using type = T;
128+
constexpr static size_t index = I;
129+
};
141130

142-
template <int N, typename... Args>
143-
struct Indexer {
144-
static_assert(N < sizeof...(Args) && N >= 0);
145-
using type = typename IndexerImpl<N, Args...>::type;
146-
};
131+
template <typename... Es>
132+
struct ZipHolder {
133+
template <typename ResultType, typename TupleType, typename Op>
134+
static ResultType ToArgs(
135+
Op&& op, const TupleType& argbuffer,
136+
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
137+
google::protobuf::MessageFactory* absl_nonnull message_factory,
138+
google::protobuf::Arena* absl_nonnull arena) {
139+
return std::forward<Op>(op)(
140+
runtime_internal::AdaptedTypeTraits<typename Es::type>::ToArg(
141+
std::get<Es::index>(argbuffer))...,
142+
descriptor_pool, message_factory, arena);
143+
}
144+
};
147145

148-
template <int N, typename... Args>
149-
struct ApplyHelper {
150-
template <typename T, typename Op>
151-
static typename ApplyReturnType<T>::type Apply(
152-
Op&& op, absl::Span<const Value> input) {
153-
constexpr int idx = sizeof...(Args) - N;
154-
using Arg = typename Indexer<idx, Args...>::type;
155-
using ArgTraits = AdaptedTypeTraits<Arg>;
156-
typename ArgTraits::AssignableType arg_i;
157-
CEL_RETURN_IF_ERROR(HandleToAdaptedVisitor{input[idx]}(&arg_i));
158-
159-
return ApplyHelper<N - 1, Args...>::template Apply<T>(
160-
absl::bind_front(std::forward<Op>(op), ArgTraits::ToArg(arg_i)), input);
146+
template <size_t... Is>
147+
static ZipHolder<El<Is, Args>...> MakeZip(const std::index_sequence<Is...>&) {
148+
return ZipHolder<El<Is, Args>...>{};
161149
}
162150
};
163151

164152
template <typename... Args>
165-
struct ApplyHelper<0, Args...> {
166-
template <typename T, typename Op>
167-
static typename ApplyReturnType<T>::type Apply(
168-
Op&& op, absl::Span<const Value> input) {
169-
return op();
153+
struct ToArgsHelper {
154+
template <typename ResultType, typename TupleType, typename Op>
155+
static ResultType Apply(
156+
Op&& op, const TupleType& argbuffer,
157+
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
158+
google::protobuf::MessageFactory* absl_nonnull message_factory,
159+
google::protobuf::Arena* absl_nonnull arena) {
160+
using Impl = ToArgsImpl<Args...>;
161+
using Zip = decltype(Impl::MakeZip(std::index_sequence_for<Args...>{}));
162+
return Zip::template ToArgs<ResultType>(std::forward<Op>(op), argbuffer,
163+
descriptor_pool, message_factory,
164+
arena);
170165
}
171166
};
172167

@@ -629,6 +624,98 @@ class QuaternaryFunctionAdapter
629624
};
630625
};
631626

627+
// Primary template for n-ary adapter.
628+
template <typename T, typename... Args>
629+
class NaryFunctionAdapter;
630+
631+
template <typename T>
632+
class NaryFunctionAdapter<T> : public NullaryFunctionAdapter<T> {};
633+
634+
template <typename T, typename U>
635+
class NaryFunctionAdapter<T, U> : public UnaryFunctionAdapter<T, U> {};
636+
637+
template <typename T, typename U, typename V>
638+
class NaryFunctionAdapter<T, U, V> : public BinaryFunctionAdapter<T, U, V> {};
639+
640+
template <typename T, typename U, typename V, typename W>
641+
class NaryFunctionAdapter<T, U, V, W>
642+
: public TernaryFunctionAdapter<T, U, V, W> {};
643+
644+
template <typename T, typename U, typename V, typename W, typename X>
645+
class NaryFunctionAdapter<T, U, V, W, X>
646+
: public QuaternaryFunctionAdapter<T, U, V, W, X> {};
647+
648+
// N-ary function adapter.
649+
//
650+
// Prefer using one of the specific count adapters above for readability and
651+
// better error messages.
652+
template <typename T, typename... Args>
653+
class NaryFunctionAdapter
654+
: public RegisterHelper<NaryFunctionAdapter<T, Args...>> {
655+
public:
656+
using FunctionType = absl::AnyInvocable<T(
657+
Args..., const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
658+
google::protobuf::MessageFactory* absl_nonnull message_factory,
659+
google::protobuf::Arena* absl_nonnull arena) const>;
660+
661+
static FunctionDescriptor CreateDescriptor(absl::string_view name,
662+
bool receiver_style,
663+
bool is_strict = true) {
664+
return FunctionDescriptor(name, receiver_style,
665+
{runtime_internal::AdaptedKind<Args>()...},
666+
is_strict);
667+
}
668+
669+
static std::unique_ptr<cel::Function> WrapFunction(FunctionType fn) {
670+
return std::make_unique<NaryFunctionImpl>(std::move(fn));
671+
}
672+
673+
static std::unique_ptr<cel::Function> WrapFunction(
674+
absl::AnyInvocable<T(Args...) const> function) {
675+
return WrapFunction(
676+
[function = std::move(function)](
677+
Args... args, const google::protobuf::DescriptorPool* absl_nonnull,
678+
google::protobuf::MessageFactory* absl_nonnull,
679+
google::protobuf::Arena* absl_nonnull) -> T { return function(args...); });
680+
}
681+
682+
private:
683+
class NaryFunctionImpl : public cel::Function {
684+
private:
685+
using ArgBuffer = std::tuple<
686+
typename runtime_internal::AdaptedTypeTraits<Args>::AssignableType...>;
687+
688+
public:
689+
explicit NaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {}
690+
absl::StatusOr<Value> Invoke(
691+
absl::Span<const Value> args,
692+
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
693+
google::protobuf::MessageFactory* absl_nonnull message_factory,
694+
google::protobuf::Arena* absl_nonnull arena) const override {
695+
if (args.size() != sizeof...(Args)) {
696+
return absl::InvalidArgumentError(
697+
absl::StrCat("unexpected number of arguments for ", sizeof...(Args),
698+
"-ary function"));
699+
}
700+
ArgBuffer arg_buffer;
701+
CEL_RETURN_IF_ERROR(
702+
runtime_internal::AdaptHelper<Args...>::Apply(args, arg_buffer));
703+
if constexpr (std::is_same_v<T, Value> ||
704+
std::is_same_v<T, absl::StatusOr<Value>>) {
705+
return runtime_internal::ToArgsHelper<Args...>::template Apply<T>(
706+
fn_, arg_buffer, descriptor_pool, message_factory, arena);
707+
} else {
708+
T result = runtime_internal::ToArgsHelper<Args...>::template Apply<T>(
709+
fn_, arg_buffer, descriptor_pool, message_factory, arena);
710+
return runtime_internal::AdaptedToHandleVisitor{}(std::move(result));
711+
}
712+
}
713+
714+
private:
715+
FunctionType fn_;
716+
};
717+
};
718+
632719
} // namespace cel
633720

634721
#endif // THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_

0 commit comments

Comments
 (0)