Skip to content

Commit

Permalink
Update TVM registry API calls and node->object
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky committed Jan 9, 2020
1 parent 06068e7 commit 984f9c8
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions aot/to_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def mk_register_api(self, name: str, func) -> str:
args += ", "

source += f"""
TVM_REGISTER_API("{name}")
TVM_REGISTER_GLOBAL("{name}")
.set_body([](TVMArgs args, TVMRetValue* ret) {{
{init}
std::initializer_list<Value> ilist = {{{args}}};
Expand All @@ -385,7 +385,9 @@ def inter(strs, sep=", "):

def mk_file(body, ctx):
return f"""
#include <tvm/api_registry.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/node/env_func.h>
#include <tvm/relay/interpreter.h>
#include <iostream>
Expand Down Expand Up @@ -416,8 +418,8 @@ def mk_file(body, ctx):
}}
static ConstructorValue TagToCV(size_t tag, const tvm::Array<Value>& fields) {{
NodePtr<ConstructorValueNode> n = make_node<ConstructorValueNode>();
NodePtr<ConstructorNode> con = make_node<ConstructorNode>();
ObjectPtr<ConstructorValueNode> n = make_object<ConstructorValueNode>();
ObjectPtr<ConstructorNode> con = make_object<ConstructorNode>();
con->tag = tag;
n->tag = tag;
n->constructor = Constructor(con);
Expand All @@ -439,13 +441,16 @@ class FunctionValue;
TVM_DLL static FunctionValue make(const function_value_t& f);
static constexpr const char* _type_key = "relay.FunctionValue";
TVM_DECLARE_NODE_TYPE_INFO(FunctionValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionValueNode, ValueNode);
}};
RELAY_DEFINE_NODE_REF(FunctionValue, FunctionValueNode, Value);
class FunctionValue : public Value {{
public:
TVM_DEFINE_OBJECT_REF_METHODS(FunctionValue, Value, FunctionValueNode);
}};
FunctionValue FunctionValueNode::make(const function_value_t& f) {{
NodePtr<FunctionValueNode> n = make_node<FunctionValueNode>();
ObjectPtr<FunctionValueNode> n = make_object<FunctionValueNode>();
n->f = f;
return FunctionValue(n);
}}
Expand Down

0 comments on commit 984f9c8

Please sign in to comment.