Skip to content

Commit

Permalink
override all the stuff (apache#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame authored and jroesch committed Aug 16, 2018
1 parent 3fe4cfa commit 2c79399
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 51 deletions.
32 changes: 16 additions & 16 deletions relay/include/relay/evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,22 @@ class Evaluator : public ExprFunctor<Value(const Expr& n)> {
Evaluator();
Evaluator(Environment env) : env(env) {}
Value Eval(const Expr& expr);
Value VisitExpr_(const LocalIdNode* op);
Value VisitExpr_(const GlobalIdNode* op);
Value VisitExpr_(const IntrinsicIdNode* op);
Value VisitExpr_(const FloatLitNode* op);
Value VisitExpr_(const BoolLitNode* op);
Value VisitExpr_(const IntLitNode* op);
Value VisitExpr_(const TensorLitNode* op);
Value VisitExpr_(const ProductLitNode* op);
Value VisitExpr_(const CastNode* op);
Value VisitExpr_(const ParamNode* op);
Value VisitExpr_(const FunctionNode* op);
Value VisitExpr_(const CallNode* op);
Value VisitExpr_(const DebugNode* op);
Value VisitExpr_(const UnaryOpNode* op);
Value VisitExpr_(const BinaryOpNode* op);
Value VisitExpr_(const AssignmentNode* op);
Value VisitExpr_(const LocalIdNode* op) override;
Value VisitExpr_(const GlobalIdNode* op) override;
Value VisitExpr_(const IntrinsicIdNode* op) override;
Value VisitExpr_(const FloatLitNode* op) override;
Value VisitExpr_(const BoolLitNode* op) override;
Value VisitExpr_(const IntLitNode* op) override;
Value VisitExpr_(const TensorLitNode* op) override;
Value VisitExpr_(const ProductLitNode* op) override;
Value VisitExpr_(const CastNode* op) override;
Value VisitExpr_(const ParamNode* op) override;
Value VisitExpr_(const FunctionNode* op) override;
Value VisitExpr_(const CallNode* op) override;
Value VisitExpr_(const DebugNode* op) override;
Value VisitExpr_(const UnaryOpNode* op) override;
Value VisitExpr_(const BinaryOpNode* op) override;
Value VisitExpr_(const AssignmentNode* op) override;
};

} // namespace relay
Expand Down
32 changes: 16 additions & 16 deletions relay/include/relay/typechecker.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,22 @@ class Typechecker : public ExprFunctor<Type(const Expr & n)> {
Typechecker();
Typechecker(Environment env) : env(env) {}
Type Check(const Expr & expr);
Type VisitExpr_(const LocalIdNode* op);
Type VisitExpr_(const GlobalIdNode* op);
Type VisitExpr_(const IntrinsicIdNode* op);
Type VisitExpr_(const FloatLitNode* op);
Type VisitExpr_(const BoolLitNode* op);
Type VisitExpr_(const IntLitNode* op);
Type VisitExpr_(const TensorLitNode* op);
Type VisitExpr_(const ProductLitNode* op);
Type VisitExpr_(const CastNode* op);
Type VisitExpr_(const ParamNode* op);
Type VisitExpr_(const FunctionNode* op);
Type VisitExpr_(const CallNode* op);
Type VisitExpr_(const DebugNode* op);
Type VisitExpr_(const UnaryOpNode* op);
Type VisitExpr_(const BinaryOpNode* op);
Type VisitExpr_(const AssignmentNode* op);
Type VisitExpr_(const LocalIdNode* op) override;
Type VisitExpr_(const GlobalIdNode* op) override;
Type VisitExpr_(const IntrinsicIdNode* op) override;
Type VisitExpr_(const FloatLitNode* op) override;
Type VisitExpr_(const BoolLitNode* op) override;
Type VisitExpr_(const IntLitNode* op) override;
Type VisitExpr_(const TensorLitNode* op) override;
Type VisitExpr_(const ProductLitNode* op) override;
Type VisitExpr_(const CastNode* op) override;
Type VisitExpr_(const ParamNode* op) override;
Type VisitExpr_(const FunctionNode* op) override;
Type VisitExpr_(const CallNode* op) override;
Type VisitExpr_(const DebugNode* op) override;
Type VisitExpr_(const UnaryOpNode* op) override;
Type VisitExpr_(const BinaryOpNode* op) override;
Type VisitExpr_(const AssignmentNode* op) override;
};

} // namespace relay
Expand Down
38 changes: 19 additions & 19 deletions relay/include/relay/visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ class FunctorNode

void VisitAttrs(tvm::AttrVisitor* v) final {}

NodeRef VisitExpr_(const LocalIdNode* local, tvm::Array<NodeRef> args) {
NodeRef VisitExpr_(const LocalIdNode* local, tvm::Array<NodeRef> args) override {
if (visit_local_id != nullptr) {
return visit_local_id(local->name, args);
} else {
return LocalIdNode::make(local->name);
}
}

NodeRef VisitExpr_(const GlobalIdNode* global, tvm::Array<NodeRef> args) {
NodeRef VisitExpr_(const GlobalIdNode* global, tvm::Array<NodeRef> args) override {
if (visit_global_id != nullptr) {
return visit_global_id(global->name, args);
} else {
Expand All @@ -62,31 +62,31 @@ class FunctorNode
}

NodeRef VisitExpr_(const IntrinsicIdNode* intrinsic,
tvm::Array<NodeRef> args) {
tvm::Array<NodeRef> args) override {
if (visit_intrinsic_id != nullptr) {
return visit_intrinsic_id(intrinsic->name, args);
} else {
return IntrinsicIdNode::make(intrinsic->name);
}
}

NodeRef VisitExpr_(const FloatLitNode* float_lit, tvm::Array<NodeRef> args) {
NodeRef VisitExpr_(const FloatLitNode* float_lit, tvm::Array<NodeRef> args) override {
if (visit_float_lit != nullptr) {
return visit_float_lit(float_lit->value, args);
} else {
return FloatLitNode::make(float_lit->value);
}
}

NodeRef VisitExpr_(const BoolLitNode* bool_lit, tvm::Array<NodeRef> args) {
NodeRef VisitExpr_(const BoolLitNode* bool_lit, tvm::Array<NodeRef> args) override {
if (visit_bool_lit != nullptr) {
return visit_bool_lit(bool_lit->value, args);
} else {
return BoolLitNode::make(bool_lit->value);
}
}

NodeRef VisitExpr_(const IntLitNode* int_lit, tvm::Array<NodeRef> args) {
NodeRef VisitExpr_(const IntLitNode* int_lit, tvm::Array<NodeRef> args) override {
if (visit_int_lit != nullptr) {
return visit_int_lit(int_lit->value, args);
} else {
Expand All @@ -95,7 +95,7 @@ class FunctorNode
}

NodeRef VisitExpr_(const TensorLitNode* tensor_lit,
tvm::Array<NodeRef> args) {
tvm::Array<NodeRef> args) override {
if (visit_tensor_lit != nullptr) {
return visit_tensor_lit(tensor_lit->data, args);
} else {
Expand All @@ -104,63 +104,63 @@ class FunctorNode
}

NodeRef VisitExpr_(const ProductLitNode* product_lit,
tvm::Array<NodeRef> args) {
tvm::Array<NodeRef> args) override {
if (visit_product_lit != nullptr) {
return visit_product_lit(product_lit->fields, args);
} else {
return ProductLitNode::make(product_lit->fields);
}
}

NodeRef VisitExpr_(const CastNode* cast, tvm::Array<NodeRef> args) {
NodeRef VisitExpr_(const CastNode* cast, tvm::Array<NodeRef> args) override {
if (visit_cast != nullptr) {
return visit_cast(cast->target, cast->node, args);
} else {
return CastNode::make(cast->target, cast->node);
}
}

NodeRef VisitExpr_(const ParamNode* param, tvm::Array<NodeRef> args) {
NodeRef VisitExpr_(const ParamNode* param, tvm::Array<NodeRef> args) override {
if (visit_param != nullptr) {
return visit_param(param->id, param->type, args);
} else {
return ParamNode::make(param->id, param->type);
}
}

NodeRef VisitExpr_(const FunctionNode* fn, tvm::Array<NodeRef> args) {
NodeRef VisitExpr_(const FunctionNode* fn, tvm::Array<NodeRef> args) override {
if (visit_function != nullptr) {
return visit_function(fn->params, fn->body, args);
} else {
return FunctionNode::make(fn->params, fn->body);
}
}

NodeRef VisitExpr_(const CallNode* call, tvm::Array<NodeRef> args) {
NodeRef VisitExpr_(const CallNode* call, tvm::Array<NodeRef> args) override {
if (visit_call != nullptr) {
return visit_call(call->fn, call->args, args);
} else {
return CallNode::make(call->fn, call->args);
}
}

NodeRef VisitExpr_(const DebugNode* debug, tvm::Array<NodeRef> args) {
NodeRef VisitExpr_(const DebugNode* debug, tvm::Array<NodeRef> args) override {
if (visit_debug != nullptr) {
return visit_debug(debug->node, args);
} else {
return DebugNode::make(debug->node);
}
}

NodeRef VisitExpr_(const UnaryOpNode* uop, tvm::Array<NodeRef> args) {
NodeRef VisitExpr_(const UnaryOpNode* uop, tvm::Array<NodeRef> args) override {
if (visit_unary_op != nullptr) {
return visit_unary_op(static_cast<int>(uop->op), uop->node, args);
} else {
return UnaryOpNode::make(uop->op, uop->node);
}
}

NodeRef VisitExpr_(const BinaryOpNode* bop, tvm::Array<NodeRef> args) {
NodeRef VisitExpr_(const BinaryOpNode* bop, tvm::Array<NodeRef> args) override {
if (visit_binary_op != nullptr) {
return visit_binary_op(static_cast<int>(bop->op), bop->left, bop->right,
args);
Expand All @@ -169,7 +169,7 @@ class FunctorNode
}
}

NodeRef VisitExpr_(const AssignmentNode* bop, tvm::Array<NodeRef> args) {
NodeRef VisitExpr_(const AssignmentNode* bop, tvm::Array<NodeRef> args) override {
throw "foo";
// if (visit_assignment != nullptr) {
// return visit_assignment(intrinsic->name, args);
Expand All @@ -178,23 +178,23 @@ class FunctorNode
// }
}

NodeRef VisitExpr_(const ReverseNode* rev, tvm::Array<NodeRef> args) {
NodeRef VisitExpr_(const ReverseNode* rev, tvm::Array<NodeRef> args) override {
if (visit_reverse != nullptr) {
return visit_reverse(rev->node, args);
} else {
return ReverseNode::make(rev->node);
}
}

NodeRef VisitExpr_(const AccumulateNode* acc, tvm::Array<NodeRef> args) {
NodeRef VisitExpr_(const AccumulateNode* acc, tvm::Array<NodeRef> args) override {
if (visit_accumulate != nullptr) {
return visit_accumulate(acc->update_binders, acc->value, args);
} else {
return AccumulateNode::make(acc->update_binders, acc->value);
}
}

NodeRef VisitExpr_(const ZeroNode* zero, tvm::Array<NodeRef> args) {
NodeRef VisitExpr_(const ZeroNode* zero, tvm::Array<NodeRef> args) override {
if (visit_zero != nullptr) {
return visit_zero(zero->type, args);
} else {
Expand Down

0 comments on commit 2c79399

Please sign in to comment.