Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assert that kernels are called with the right signature #40251

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions aten/src/ATen/core/boxing/impl/test_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,12 @@ inline std::vector<c10::IValue> callOp(const c10::OperatorHandle& op, Args... ar

template<class Result, class... Args>
inline Result callOpUnboxed(const c10::OperatorHandle& op, Args... args) {
return c10::Dispatcher::singleton()
.template call<Result, Args...>(op, std::forward<Args>(args)...);
return op.typed<Result(Args...)>().call(std::forward<Args>(args)...);
}

template<class Result, class... Args>
inline Result callOpUnboxedWithDispatchKey(const c10::OperatorHandle& op, c10::DispatchKey dispatchKey, Args... args) {
return c10::Dispatcher::singleton()
.template callWithDispatchKey<Result, Args...>(op, dispatchKey, std::forward<Args>(args)...);
return op.typed<Result(Args...)>().callWithDispatchKey(dispatchKey, std::forward<Args>(args)...);
}

inline void expectDoesntFindKernel(const char* op_name, c10::DispatchKey dispatch_key) {
Expand Down
60 changes: 60 additions & 0 deletions aten/src/ATen/core/dispatch/CppSignature.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#pragma once

#include <typeindex>
#include <c10/macros/Macros.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/Type.h>

namespace c10 {
namespace impl {

// A CppSignature object holds RTTI information about a C++ function signature at runtime
// and can compare them or get a debug-printable name.
class CAFFE2_API CppSignature final {
public:
CppSignature(const CppSignature&) = default;
CppSignature(CppSignature&&) noexcept = default;
CppSignature& operator=(const CppSignature&) = default;
CppSignature& operator=(CppSignature&&) noexcept = default;

template<class FuncType>
static CppSignature make() {
// Normalize functors, lambdas, function pointers, etc. into the plain function type
using decayed_function_type = typename guts::infer_function_traits_t<std::decay_t<FuncType>>::func_type;

return CppSignature(std::type_index(typeid(decayed_function_type)));
}

std::string name() const {
return c10::demangle(signature_.name());
}

friend bool operator==(const CppSignature& lhs, const CppSignature& rhs) {
if (lhs.signature_ == rhs.signature_) {
return true;
}
// Without RTLD_GLOBAL, the type_index comparison could yield false because
// they point to different instances of the RTTI data, but the types would
// still be the same. Let's check for that case too.
// Note that there still is a case where this might not work, i.e. when
// linking libraries of different compilers together, they might have
// different ways to serialize a type name. That, together with a missing
// RTLD_GLOBAL, would still fail this.
if (lhs.name() == rhs.name()) {
return true;
}

return false;
}

private:
explicit CppSignature(std::type_index signature): signature_(std::move(signature)) {}
std::type_index signature_;
};

inline bool operator!=(const CppSignature& lhs, const CppSignature& rhs) {
return !(lhs == rhs );
}

}
}
38 changes: 38 additions & 0 deletions aten/src/ATen/core/dispatch/CppSignature_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include <ATen/core/dispatch/CppSignature.h>
#include <gtest/gtest.h>
#include <string>

using c10::impl::CppSignature;

namespace {

TEST(CppSignatureTest, given_equalSignature_then_areEqual) {
EXPECT_EQ(CppSignature::make<void()>(), CppSignature::make<void()>());
EXPECT_EQ(CppSignature::make<int64_t(std::string, int64_t)>(), CppSignature::make<int64_t(std::string, int64_t)>());
}

TEST(CppSignatureTest, given_differentSignature_then_areDifferent) {
EXPECT_NE(CppSignature::make<void()>(), CppSignature::make<int64_t()>());
EXPECT_NE(CppSignature::make<int64_t(std::string)>(), CppSignature::make<int64_t(std::string, int64_t)>());
EXPECT_NE(CppSignature::make<std::string(std::string)>(), CppSignature::make<int64_t(std::string)>());
}

TEST(CppSignatureTest, given_equalFunctorAndFunction_then_areEqual) {
struct Functor final {
int64_t operator()(std::string) {return 0;}
};
EXPECT_EQ(CppSignature::make<Functor>(), CppSignature::make<int64_t(std::string)>());
}

TEST(CppSignatureTest, given_differentFunctorAndFunction_then_areDifferent) {
struct Functor final {
int64_t operator()(std::string) {return 0;}
};
EXPECT_NE(CppSignature::make<Functor>(), CppSignature::make<int64_t(std::string, int64_t)>());
}

TEST(CppSignatureTest, given_cppSignature_then_canQueryNameWithoutCrashing) {
CppSignature::make<void(int64_t, const int64_t&)>().name();
}

}
3 changes: 2 additions & 1 deletion aten/src/ATen/core/dispatch/Dispatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,15 @@ RegistrationHandleRAII Dispatcher::registerImpl(
OperatorName op_name,
c10::optional<DispatchKey> dispatch_key,
KernelFunction kernel,
c10::optional<impl::CppSignature> cpp_signature,
std::unique_ptr<FunctionSchema> inferred_function_schema,
std::string debug
) {
std::lock_guard<std::mutex> lock(mutex_);

auto op = findOrRegisterName_(op_name);

auto handle = op.operatorIterator_->op.registerKernel(dispatch_key, std::move(kernel), std::move(inferred_function_schema), std::move(debug));
auto handle = op.operatorIterator_->op.registerKernel(dispatch_key, std::move(kernel), std::move(cpp_signature), std::move(inferred_function_schema), std::move(debug));

++op.operatorIterator_->def_and_impl_count;

Expand Down
68 changes: 50 additions & 18 deletions aten/src/ATen/core/dispatch/Dispatcher.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <ATen/core/dispatch/OperatorEntry.h>
#include <ATen/core/dispatch/CppSignature.h>
#include <ATen/core/dispatch/RegistrationHandleRAII.h>
#include <c10/util/Exception.h>
#include <c10/util/LeftRight.h>
Expand All @@ -10,6 +11,7 @@
namespace c10 {

class CAFFE2_API OperatorHandle;
template<class FuncType> class TypedOperatorHandle;

/**
* Implement this interface and register your instance with the dispatcher
Expand Down Expand Up @@ -59,6 +61,7 @@ class CAFFE2_API Dispatcher final {
size_t def_and_impl_count = 0;
};
friend class OperatorHandle;
template<class> friend class TypedOperatorHandle;

public:
~Dispatcher();
Expand Down Expand Up @@ -107,20 +110,20 @@ class CAFFE2_API Dispatcher final {
// ------------------------------------------------------------------------

template<class Return, class... Args>
Return call(const OperatorHandle& op, Args... args) const;
Return call(const TypedOperatorHandle<Return (Args...)>& op, Args... args) const;

// Like call, but override the default DispatchKey calculation code,
// instead dispatching straight to the provided DispatchKey
template<class Return, class... Args>
Return callWithDispatchKey(const OperatorHandle& op, DispatchKey dispatchKey, Args... args) const;
Return callWithDispatchKey(const TypedOperatorHandle<Return (Args...)>& op, DispatchKey dispatchKey, Args... args) const;

// Like call, but intended for use in a redispatch: you are currently
// in some currentDispatchKey, you have finished processing the key and
// you now want to redispatch to the next dispatch key in the chain.
// This will mask out the current key *and all previous keys* from the
// eligible set, and reinvoke the dispatcher.
template<class Return, class... Args>
Return redispatch(const OperatorHandle& op, DispatchKey currentDispatchKey, Args... args) const;
Return redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKey currentDispatchKey, Args... args) const;

// Invoke an operator via the boxed calling convention using an IValue stack
void callBoxed(const OperatorHandle& op, Stack* stack) const;
Expand Down Expand Up @@ -148,7 +151,7 @@ class CAFFE2_API Dispatcher final {
*/
// NB: steals the inferred function schema, as we may need to hold on to
// it for a bit until the real schema turns up
RegistrationHandleRAII registerImpl(OperatorName op_name, c10::optional<DispatchKey> dispatch_key, KernelFunction kernel, std::unique_ptr<FunctionSchema> inferred_function_schema, std::string debug);
RegistrationHandleRAII registerImpl(OperatorName op_name, c10::optional<DispatchKey> dispatch_key, KernelFunction kernel, c10::optional<impl::CppSignature> cpp_signature, std::unique_ptr<FunctionSchema> inferred_function_schema, std::string debug);

/**
* Register a new operator by name.
Expand Down Expand Up @@ -232,7 +235,7 @@ class CAFFE2_API Dispatcher final {
* This handle can be used to register kernels with the dispatcher or
* to lookup a kernel for a certain set of arguments.
*/
class CAFFE2_API OperatorHandle final {
class CAFFE2_API OperatorHandle {
public:
OperatorHandle(OperatorHandle&&) noexcept = default;
OperatorHandle& operator=(OperatorHandle&&) noexcept = default;
Expand Down Expand Up @@ -263,14 +266,10 @@ class CAFFE2_API OperatorHandle final {
return operatorIterator_->op.checkInvariants();
}

template<class Return, class... Args>
Return call(Args... args) const {
return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
}

template<class Return, class... Args>
Return callWithDispatchKey(DispatchKey dispatchKey, Args... args) const {
return c10::Dispatcher::singleton().callWithDispatchKey<Return, Args...>(*this, dispatchKey, std::forward<Args>(args)...);
template<class FuncType>
TypedOperatorHandle<FuncType> typed() const {
operatorIterator_->op.assertSignatureIsCorrect<FuncType>();
return TypedOperatorHandle<FuncType>(operatorIterator_);
}

void callBoxed(Stack* stack) const {
Expand All @@ -281,35 +280,68 @@ class CAFFE2_API OperatorHandle final {
explicit OperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
: operatorIterator_(std::move(operatorIterator)) {}
friend class Dispatcher;
template<class> friend class TypedOperatorHandle;

std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
};

/**
* This is a handle to an operator schema registered with the dispatcher.
* It holds the same information as an OperatorHandle, but it is templated
* on the operator arguments and allows calling the operator in an
* unboxed way.
*/
template<class FuncType>
class TypedOperatorHandle final {
static_assert(guts::false_t<FuncType>(), "FuncType in OperatorHandle::typed<FuncType> was not a valid function type");
};
template<class Return, class... Args>
class TypedOperatorHandle<Return (Args...)> final : public OperatorHandle {
public:
TypedOperatorHandle(TypedOperatorHandle&&) noexcept = default;
TypedOperatorHandle& operator=(TypedOperatorHandle&&) noexcept = default;
TypedOperatorHandle(const TypedOperatorHandle&) = default;
TypedOperatorHandle& operator=(const TypedOperatorHandle&) = default;

Return call(Args... args) const {
return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
}

Return callWithDispatchKey(DispatchKey dispatchKey, Args... args) const {
return c10::Dispatcher::singleton().callWithDispatchKey<Return, Args...>(*this, dispatchKey, std::forward<Args>(args)...);
}

private:
explicit TypedOperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
: OperatorHandle(std::move(operatorIterator)) {}
friend class OperatorHandle;
};

namespace detail {
template<class... Args> inline void unused_arg_(const Args&...) {}
}

template<class Return, class... Args>
inline Return Dispatcher::callWithDispatchKey(const OperatorHandle& op, DispatchKey dispatchKey, Args... args) const {
inline Return Dispatcher::callWithDispatchKey(const TypedOperatorHandle<Return(Args...)>& op, DispatchKey dispatchKey, Args... args) const {
detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
const auto& dispatchTable = op.operatorIterator_->op.dispatch_table();
const KernelFunction& kernel = dispatch_(dispatchTable, dispatchKey);
return kernel.template call<Return, Args...>(op, std::forward<Args>(args)...);
}

template<class Return, class... Args>
inline Return Dispatcher::call(const OperatorHandle& op, Args... args) const {
inline Return Dispatcher::call(const TypedOperatorHandle<Return(Args...)>& op, Args... args) const {
detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
const auto& dispatchTable = op.operatorIterator_->op.dispatch_table();
auto dispatchKey = dispatchTable.dispatchKeyExtractor().getDispatchKeyUnboxed<Args...>(backendsWithoutFallthrough_, DispatchKeySet::FULL, args...);
auto dispatchKey = dispatchTable.dispatchKeyExtractor().template getDispatchKeyUnboxed<Args...>(backendsWithoutFallthrough_, DispatchKeySet::FULL, args...);
return callWithDispatchKey<Return, Args...>(op, dispatchKey, args...);
}

template<class Return, class... Args>
inline Return Dispatcher::redispatch(const OperatorHandle& op, DispatchKey currentDispatchKey, Args... args) const {
inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKey currentDispatchKey, Args... args) const {
detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
const auto& dispatchTable = op.operatorIterator_->op.dispatch_table();
auto dispatchKey = dispatchTable.dispatchKeyExtractor().getDispatchKeyUnboxed<Args...>(
auto dispatchKey = dispatchTable.dispatchKeyExtractor().template getDispatchKeyUnboxed<Args...>(
backendsWithoutFallthrough_,
DispatchKeySet(DispatchKeySet::FULL_AFTER, currentDispatchKey),
args...);
Expand Down
13 changes: 13 additions & 0 deletions aten/src/ATen/core/dispatch/OperatorEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,24 @@ void OperatorEntry::deregisterSchema() {
std::list<OperatorEntry::KernelEntry>::iterator OperatorEntry::registerKernel(
c10::optional<DispatchKey> dispatch_key,
KernelFunction kernel,
c10::optional<CppSignature> cpp_signature,
std::unique_ptr<FunctionSchema> inferred_function_schema,
std::string debug
) {
std::unique_lock<std::mutex> lock(kernelsMutex_);

if (cpp_signature.has_value()) {
if (cpp_signature_.has_value()) {
TORCH_INTERNAL_ASSERT(*cpp_signature == *cpp_signature_,
"Tried to register a kernel (", debug, ") for operator ", name_," for dispatch key ", toString(dispatch_key),
", but the C++ function signature ", cpp_signature->name(), " mismatched with a previous kernel that had the signature ",
cpp_signature_->name()
);
} else {
cpp_signature_ = *cpp_signature;
}
}

if (schema_ && inferred_function_schema) {
checkSchema(name_, *schema_, *debug_, *inferred_function_schema, debug);
}
Expand Down
24 changes: 23 additions & 1 deletion aten/src/ATen/core/dispatch/OperatorEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <ATen/core/dispatch/DispatchTable.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <ATen/core/dispatch/CppSignature.h>
#include <ATen/core/dispatch/RegistrationHandleRAII.h>
#include <list>

Expand Down Expand Up @@ -82,7 +83,7 @@ class CAFFE2_API OperatorEntry final {
void prepareForDeregistration();

// Postcondition: caller is responsible for disposing of the kernel
std::list<KernelEntry>::iterator registerKernel(c10::optional<DispatchKey> dispatch_key, KernelFunction kernel, std::unique_ptr<FunctionSchema> inferred_function_schema, std::string debug);
std::list<KernelEntry>::iterator registerKernel(c10::optional<DispatchKey> dispatch_key, KernelFunction kernel, c10::optional<CppSignature> cpp_signature, std::unique_ptr<FunctionSchema> inferred_function_schema, std::string debug);
void deregisterKernel_(c10::optional<DispatchKey> dispatch_key, std::list<KernelEntry>::iterator kernel);

void updateSchemaAliasAnalysis(AliasAnalysisKind a) {
Expand All @@ -101,6 +102,20 @@ class CAFFE2_API OperatorEntry final {
dispatchTable_.setManuallyBoxedKernel_(func);
}

// Asserts that the given FuncType is correct for calling this operator in an unboxed way.
template<class FuncType>
void assertSignatureIsCorrect() {
TORCH_INTERNAL_ASSERT(!cpp_signature_.has_value() || (CppSignature::make<FuncType>() == *cpp_signature_),
"Tried to access operator ", name_, " with a wrong signature. Accessed with ",
CppSignature::make<FuncType>().name(),
" but the operator was registered with ",
cpp_signature_->name(),
" (",
debug_.value(),
") This likely happened in a call to OperatorHandle::typed<Return (Args...)>(). Please make sure that the function signature matches the signature in the operator registration call."
);
}

private:

OperatorName name_;
Expand Down Expand Up @@ -146,6 +161,13 @@ class CAFFE2_API OperatorEntry final {

std::mutex kernelsMutex_; // protects kernels_

// signature_hash_ is set to the hash of the function signature if any of
// the kernels was created in a way that allowed us to know the function
// signature (i.e. by supplying an unboxed C++ kernel function).
// If this is set, it will be used in unboxed function calls
// to verify their arguments against the known function signature.
c10::optional<CppSignature> cpp_signature_;

// This function re-establishes the invariant that dispatchTable
// contains the front element from the kernels list for a given dispatch key.
void updateDispatchTable_(c10::optional<DispatchKey> dispatch_key);
Expand Down
5 changes: 4 additions & 1 deletion aten/src/ATen/core/library.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ namespace {
}
}

CppFunction::CppFunction(c10::KernelFunction func, std::unique_ptr<c10::FunctionSchema> schema)
CppFunction::CppFunction(c10::KernelFunction func, c10::optional<c10::impl::CppSignature> cpp_signature, std::unique_ptr<c10::FunctionSchema> schema)
: func_(std::move(func))
, cpp_signature_(std::move(cpp_signature))
, schema_(std::move(schema))
, debug_()
{}
Expand Down Expand Up @@ -153,6 +154,7 @@ Library& Library::_def(c10::either<c10::OperatorName, c10::FunctionSchema>&& nam
std::move(name),
dispatch_key,
std::move(f.func_),
std::move(f.cpp_signature_),
std::move(f.schema_),
debugString(std::move(f.debug_), file_, line_)
)
Expand Down Expand Up @@ -197,6 +199,7 @@ Library& Library::_impl(const char* name_str, CppFunction&& f) & {
std::move(name),
dispatch_key,
std::move(f.func_),
std::move(f.cpp_signature_),
std::move(f.schema_),
debugString(std::move(f.debug_), file_, line_)
)
Expand Down
Loading