[Fix] Fix i32/i64 issues after rebase (apache#25)
MasterJH5574 authored Nov 20, 2022
1 parent bd21748 commit eee21b8
Showing 26 changed files with 727 additions and 406 deletions.
2 changes: 1 addition & 1 deletion include/tvm/tir/buffer.h
@@ -135,7 +135,7 @@ class BufferNode : public Object {

   /*! \return preferred index type for this buffer node */
   DataType DefaultIndexType() const {
-    return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
+    return shape.size() != 0 ? shape[0].dtype() : DataType::Int(64);
   }

   /*! \brief Determine the offset in the buffer of the given index.
1 change: 1 addition & 0 deletions include/tvm/tir/data_type_rewriter.h
@@ -145,6 +145,7 @@ class IndexDataTypeNormalizer : public IndexDataTypeRewriter {
   using Parent::VisitStmt_;
   PrimExpr VisitExpr_(const IntImmNode* op) final;
   PrimExpr VisitExpr_(const VarNode* op) final;
+  PrimExpr VisitExpr_(const CastNode* op) final;

   DataType target_data_type_ = DataType::Int(64);
 };
2 changes: 1 addition & 1 deletion python/tvm/relax/block_builder.py
@@ -357,7 +357,7 @@ def call_te(self, func: Callable, *args: Any, **kwargs: Any) -> Expr:
         unbound_tir_vars = self._get_unbound_tir_vars(te_args + outs)

         inputs = [*te_args] + outs
-        tir_func = tvm.te.create_prim_func(inputs, unbound_tir_vars)
+        tir_func = tvm.te.create_prim_func(inputs, unbound_tir_vars, "int64")

         if primfunc_name_hint:
            gvar = self.add_func(tir_func, primfunc_name_hint)
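For context: `call_te` now requests int64 index variables when lowering TE computations. A minimal sketch of the underlying call, assuming the three-argument `create_prim_func(inputs, unbound_tir_vars, index_dtype)` signature this branch uses (the upstream TVM signature differs):

    import tvm
    from tvm import te

    # Illustrative compute only; the names here are hypothetical.
    A = te.placeholder((128,), dtype="float32", name="A")
    B = te.compute(A.shape, lambda i: A[i] + 1.0, name="B")

    # Passing "int64", as call_te now does, makes the generated PrimFunc
    # create its loop and index variables as int64 instead of int32.
    f = tvm.te.create_prim_func([A, B], [], "int64")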
5 changes: 4 additions & 1 deletion python/tvm/relax/expr.py
@@ -53,7 +53,10 @@ def __init__(
         values: Union[List[PrimExpr], typing.Tuple[PrimExpr, ...], tvm.ir.Array],
         span: Span = None,
     ) -> None:
-        self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values, span)  # type: ignore
+        values_int64 = []
+        for value in values:
+            values_int64.append(tvm.tir.IntImm("int64", value) if isinstance(value, int) else value)
+        self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values_int64, span)  # type: ignore

     def __getitem__(self, index):
         if index >= len(self):
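The effect is easy to observe when constructing a `ShapeExpr` from plain Python ints; a quick check, consistent with the updated `test_ast_printer.py` expectations further down:

    from tvm import relax as rx

    s = rx.ShapeExpr([10, 20])
    # Ints are wrapped as int64 IntImms before reaching the FFI constructor.
    assert all(v.dtype == "int64" for v in s.values)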
2 changes: 2 additions & 0 deletions python/tvm/relax/op/base.py
@@ -68,6 +68,8 @@ def _create_shape(shape: List[Union[int, PrimExpr]]) -> ShapeExpr:
     for x in shape:
         if isinstance(x, int):
             shape_array.append(tvm.tir.IntImm("int64", x))
+        elif isinstance(x, tvm.tir.IntImm):
+            shape_array.append(tvm.tir.IntImm("int64", x.value))
         elif isinstance(x, PrimExpr):
             # TODO: enforce all shapes are i64
             # if x.dtype != "int64":
2 changes: 1 addition & 1 deletion python/tvm/relax/op/image.py
@@ -107,7 +107,7 @@ def resize2d(
         if isinstance(shape, PrimExpr):
             temp_size.append(shape)
         elif isinstance(shape, int):
-            temp_size.append(tvm.tir.const(shape, "int32"))
+            temp_size.append(tvm.tir.const(shape, "int64"))
         else:
             raise RuntimeError(
                 f"The input new shape of reshape operator contains unrecognized dimension {shape}"
2 changes: 1 addition & 1 deletion python/tvm/relax/op/nn/nn.py
@@ -618,7 +618,7 @@ def adaptive_avg_pool2d(
         if isinstance(shape, PrimExpr):
             temp_size.append(shape)
         elif isinstance(shape, int):
-            temp_size.append(tvm.tir.const(shape, "int32"))
+            temp_size.append(tvm.tir.const(shape, "int64"))
         else:
             raise RuntimeError(
                 f"The input new shape of reshape operator contains unrecognized dimension {shape}"
12 changes: 6 additions & 6 deletions python/tvm/relax/op/transform.py
@@ -84,7 +84,7 @@ def reshape(
         if isinstance(shape, PrimExpr):
             temp_shape.append(shape)
         elif isinstance(shape, int):
-            temp_shape.append(tvm.tir.const(shape, "int32"))
+            temp_shape.append(tvm.tir.const(shape, "int64"))
         else:
             raise RuntimeError(
                 f"The input new shape of reshape operator contains unrecognized dimension {shape}"
@@ -326,7 +326,7 @@ def full(
         if isinstance(shape, PrimExpr):
             temp_shape.append(shape)
         elif isinstance(shape, int):
-            temp_shape.append(tvm.tir.const(shape, "int32"))
+            temp_shape.append(tvm.tir.const(shape, "int64"))
         else:
             raise RuntimeError(
                 f"The input new shape of reshape operator contains unrecognized dimension {shape}"
@@ -375,14 +375,14 @@ def split(
             if isinstance(idx, PrimExpr):
                 indices.append(idx)
             elif isinstance(idx, int):
-                indices.append(tvm.tir.const(idx, "int32"))
+                indices.append(tvm.tir.const(idx, "int64"))
             else:
                 raise RuntimeError(
                     f'The input indices of split operator contains unrecognized index "{idx}"'
                 )
         indices_or_sections = indices
     elif isinstance(indices_or_sections, int):
-        indices_or_sections = tvm.tir.IntImm("int32", indices_or_sections)
+        indices_or_sections = tvm.tir.IntImm("int64", indices_or_sections)
     else:
         raise RuntimeError(
             f"The input `indices_or_sections` has unrecognized type {type(indices_or_sections)}"
@@ -417,7 +417,7 @@ def broadcast_to(
         if isinstance(shape, PrimExpr):
             temp_shape.append(shape)
         elif isinstance(shape, int):
-            temp_shape.append(tvm.tir.const(shape, "int32"))
+            temp_shape.append(tvm.tir.const(shape, "int64"))
         else:
             raise RuntimeError(
                 f"The input new shape of reshape operator contains unrecognized dimension {shape}"
@@ -475,7 +475,7 @@ def convert_int(arr):
         if isinstance(x, PrimExpr):
             res.append(x)
         elif isinstance(x, int):
-            res.append(tvm.tir.const(x, "int32"))
+            res.append(tvm.tir.const(x, "int64"))
         else:
             raise RuntimeError(
                 f"The input of strided_slice operator contains unrecognized value {x}"
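All of the wrappers above share one normalization step: a plain Python int becomes an int64 TIR constant. The primitive in isolation:

    import tvm

    d = tvm.tir.const(7, "int64")
    assert isinstance(d, tvm.tir.IntImm) and d.dtype == "int64"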
6 changes: 3 additions & 3 deletions python/tvm/relax/utils.py
@@ -70,11 +70,11 @@ def convert_to_expr(value: Union[PrimExpr, Expr, Tuple[PrimExpr, Expr]]) -> Expr
     if not isinstance(value, tuple):
         return convert_to_object(value)
     value = list(value)
+    if all([isinstance(f, (PrimExpr, int)) for f in value]):
+        return ShapeExpr(value)
     for i, v in enumerate(value):
         value[i] = convert_to_expr(v)
-    if all([isinstance(f, PrimExpr) for f in value]):
-        return ShapeExpr(value)
-    elif all([isinstance(f, Expr) for f in value]):  # type: ignore
+    if all([isinstance(f, Expr) for f in value]):  # type: ignore
         return rx_Tuple(value)
     else:
         raise TypeError("Return types, with mixed PrimExpr and Relax Expr, is not supported.")
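With this reordering, a tuple whose fields are all PrimExprs or plain ints becomes a `ShapeExpr` before any per-element conversion runs, so the ints benefit from the int64 wrapping in `ShapeExpr.__init__`. A usage sketch, assuming `convert_to_expr` is importable from `tvm.relax.utils`:

    from tvm import relax as rx, tir
    from tvm.relax.utils import convert_to_expr

    n = tir.Var("n", "int64")
    e = convert_to_expr((4, n))
    # A mixed int/PrimExpr tuple now short-circuits into a single ShapeExpr.
    assert isinstance(e, rx.ShapeExpr)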
2 changes: 1 addition & 1 deletion src/relax/ir/emit_te.cc
@@ -52,7 +52,7 @@ te::Tensor TETensor(Expr value, std::string name) {
   Array<PrimExpr> shape;
   shape.reserve(ndim);
   for (int i = 0; i < ndim; ++i) {
-    shape.push_back(IntImm(DataType::Int(32), shape_tuple[i]));
+    shape.push_back(IntImm(DataType::Int(64), shape_tuple[i]));
   }
   n->shape = std::move(shape);
   return te::PlaceholderOp(n).output(0);
2 changes: 1 addition & 1 deletion src/relax/ir/expr_functor.cc
@@ -469,7 +469,7 @@ Expr ExprMutator::VisitExpr_(const SeqExprNode* op) {
 }

 void ExprMutator::VisitBinding_(const VarBindingNode* binding) {
-  Expr new_value = this->VisitExpr(binding->value);
+  Expr new_value = this->builder_->Normalize(this->VisitExpr(binding->value));
   Var new_var = this->VisitVarDef(binding->var);

   auto emit = [this](VarBinding b) {
4 changes: 2 additions & 2 deletions src/relax/op/nn/nn.cc
@@ -473,12 +473,12 @@ Optional<Expr> InferShapeMatmul(const Call& call, DiagnosticContext diag_ctx) {
   bool a_prepended = false;
   bool b_appended = false;
   if (a_ndim == 1) {
-    a_shape.insert(a_shape.begin(), tir::make_const(DataType::Int(32), 1));
+    a_shape.insert(a_shape.begin(), tir::make_const(DataType::Int(64), 1));
     a_ndim = 2;
     a_prepended = true;
   }
   if (b_ndim == 1) {
-    b_shape.insert(b_shape.end(), tir::make_const(DataType::Int(32), 1));
+    b_shape.insert(b_shape.end(), tir::make_const(DataType::Int(64), 1));
     b_ndim = 2;
     b_appended = true;
   }
2 changes: 1 addition & 1 deletion src/relax/op/tensor/reduce.cc
@@ -84,7 +84,7 @@ Optional<Expr> InferShapeReduction(const Call& call, DiagnosticContext diag_ctx) {
     if (!appeared_axes[i]) {
       output_shape.push_back(shape->values[i]);
     } else if (attrs->keepdims) {
-      output_shape.push_back(tir::make_const(DataType::Int(32), 1));
+      output_shape.push_back(tir::make_const(DataType::Int(64), 1));
     }
   }
   return ShapeExpr(std::move(output_shape));
38 changes: 18 additions & 20 deletions src/relax/op/tensor/transform.cc
@@ -140,26 +140,26 @@ Optional<Expr> InferShapeReshape(const Call& call, DiagnosticContext diag_ctx) {
   }

   int ndim = shape->values.size();
-  PrimExpr shape_prod = tir::make_const(tvm::DataType::Int(32), 1);
+  PrimExpr shape_prod = tir::make_const(tvm::DataType::Int(64), 1);
   for (int i = 0; i < ndim; ++i) {
     shape_prod = shape_prod * shape->values[i];
   }

   int dim_to_infer = -1;
   int new_ndim = new_shape->values.size();
-  PrimExpr new_shape_prod = tir::make_const(tvm::DataType::Int(32), 1);
+  PrimExpr new_shape_prod = tir::make_const(tvm::DataType::Int(64), 1);
   arith::Analyzer ana;
   for (int i = 0; i < new_ndim; ++i) {
     PrimExpr dim_len = new_shape->values[i];
-    if (ana.CanProveEqual(dim_len, tir::make_const(tvm::DataType::Int(32), -1))) {
+    if (ana.CanProveEqual(dim_len, tir::make_const(tvm::DataType::Int(64), -1))) {
       if (dim_to_infer != -1) {
         diag_ctx.EmitFatal(Diagnostic::Error(call->span)
                            << "Reshape op accepts at most one \"-1\" in the new shape. However, "
                               "the new shape on dimension "
                            << dim_to_infer << " and " << i << " are both \"-1\"");
       }
       dim_to_infer = i;
-    } else if (ana.CanProveEqual(dim_len, tir::make_const(tvm::DataType::Int(32), 0))) {
+    } else if (ana.CanProveEqual(dim_len, tir::make_const(tvm::DataType::Int(64), 0))) {
       diag_ctx.EmitFatal(Diagnostic::Error(call->span)
                          << "Reshape op does not accept \"0\" in the new shape. However, the new "
                             "shape on dimension "
@@ -264,7 +264,7 @@ Optional<Expr> InferShapeExpandDims(const Call& call, DiagnosticContext diag_ctx) {
                            "- at least two indices refers to dim "
                         << dim << ". Please make sure the indices do not duplicate.");
     }
-    output_shape.Set(dim, tvm::tir::make_const(tvm::DataType::Int(32), 1));
+    output_shape.Set(dim, tvm::tir::make_const(tvm::DataType::Int(64), 1));
   }

   for (int i = 0, p = 0; i < output_ndim; ++i) {
@@ -327,7 +327,7 @@ Optional<Expr> InferShapeSqueeze(const Call& call, DiagnosticContext diag_ctx) {
   const auto* shape = call->args[0]->shape().as<ShapeExprNode>();
   const auto* attrs = call->attrs.as<SqueezeAttrs>();
   if (shape == nullptr) {
-    return NullOpt;
+    return RuntimeDepShape();
   }

   int ndim = shape->values.size();
@@ -348,7 +348,7 @@ Optional<Expr> InferShapeSqueeze(const Call& call, DiagnosticContext diag_ctx) {
                            << "\" at the position " << i
                            << " of the axis indices of operator squeeze is out of range.");
       }
-      if (ana.CanProve(shape->values[dim] != tir::make_const(DataType::Int(32), 1))) {
+      if (ana.CanProve(shape->values[dim] != tir::make_const(DataType::Int(64), 1))) {
         diag_ctx.EmitFatal(Diagnostic::Error(call->span)
                            << "Squeeze expects all axis indices to correspond to axes with "
                               "dimension length 1. However, the input data on given axis index "
 ...
     }
   } else {
     for (int i = 0; i < ndim; ++i) {
-      bool is_one = ana.CanProveEqual(shape->values[i], tir::make_const(DataType::Int(32), 1));
-      bool isnt_one = ana.CanProve(shape->values[i] != tir::make_const(DataType::Int(32), 1));
+      bool is_one = ana.CanProveEqual(shape->values[i], tir::make_const(DataType::Int(64), 1));
+      bool isnt_one = ana.CanProve(shape->values[i] != tir::make_const(DataType::Int(64), 1));
       if (!is_one && !isnt_one) {
-        return NullOpt;
+        return RuntimeDepShape();
       } else if (is_one) {
         removed_axis.insert(i);
       }
@@ -400,9 +400,7 @@ Type InferTypeSqueeze(const Call& call, DiagnosticContext diag_ctx) {
     }
   } else {
     Optional<Expr> out_shape = InferShapeSqueeze(call, diag_ctx);
-    if (out_shape.defined()) {
-      const auto* shape = out_shape.value().as<ShapeExprNode>();
-      ICHECK_NOTNULL(shape);
+    if (const auto* shape = out_shape.value().as<ShapeExprNode>()) {
       return DynTensorType(shape->values.size(), input_type->dtype);
     } else {
       return DynTensorType(-1, input_type->dtype);
@@ -507,7 +505,7 @@ Optional<Expr> InferShapeConcatenate(const Call& call, DiagnosticContext diag_ctx) {

   for (int dim = 0; dim < output_ndim; ++dim) {
     if (dim == concat_axis) {
-      PrimExpr concat_dim_len = tir::make_const(DataType::Int(32), 0);
+      PrimExpr concat_dim_len = tir::make_const(DataType::Int(64), 0);
       for (int i = 0; i < n_tensor; ++i) {
         PrimExpr dim_len = ana.Simplify(Downcast<ShapeExpr>(tuple_shape->fields[i])->values[dim]);
         concat_dim_len = concat_dim_len + dim_len;
@@ -539,7 +537,7 @@ Optional<Expr> InferShapeConcatenate(const Call& call, DiagnosticContext diag_ctx) {
       }
     }
     if (static_len != -1) {
-      output_shape.push_back(tir::make_const(DataType::Int(32), static_len));
+      output_shape.push_back(tir::make_const(DataType::Int(64), static_len));
     } else if (!runtime_dep_dim) {
       ICHECK(symbolic_len.defined());
       output_shape.push_back(symbolic_len);
@@ -651,7 +649,7 @@ Optional<Expr> InferShapeCumsum(const Call& call, DiagnosticContext diag_ctx) {
     return GetRef<ShapeExpr>(shape);
   }

-  PrimExpr prod = tir::make_const(DataType::Int(32), 1);
+  PrimExpr prod = tir::make_const(DataType::Int(64), 1);
   for (const PrimExpr& shape_dim : shape->values) {
     prod = prod * shape_dim;
   }
@@ -992,7 +990,7 @@ Optional<Expr> InferShapeSplit(const Call& call, DiagnosticContext diag_ctx) {
   PrimExpr len_axis = input_shape->values[axis];
   if (const auto* p_indices = attrs->indices_or_sections.as<ArrayNode>()) {
     Array<PrimExpr> indices = GetRef<Array<PrimExpr>>(p_indices);
-    PrimExpr zero = tir::make_const(DataType::Int(32), 0);
+    PrimExpr zero = tir::make_const(DataType::Int(64), 0);

     output_shape.reserve(indices.size() + 1);
     indices.insert(indices.begin(), zero);
@@ -1028,7 +1026,7 @@ Optional<Expr> InferShapeSplit(const Call& call, DiagnosticContext diag_ctx) {
   }
   // Todo(relax-team): need runtime divisibility check for the cases where `len_axis` is symbolic

-  PrimExpr n_section_expr = tir::make_const(DataType::Int(32), n_section);
+  PrimExpr n_section_expr = tir::make_const(DataType::Int(64), n_section);
   Array<PrimExpr> shape = input_shape->values;
   shape.erase(shape.begin() + axis);
   shape.insert(shape.begin() + axis, tvm::floordiv(len_axis, n_section_expr));
@@ -1217,7 +1215,7 @@ Optional<Expr> InferShapeStridedSlice(const Call& call, DiagnosticContext diag_ctx) {
   } else {
     strides.reserve(n_axis);
     for (int i = 0; i < n_axis; ++i) {
-      strides.push_back(tir::make_const(DataType::Int(32), 1));
+      strides.push_back(tir::make_const(DataType::Int(64), 1));
     }
   }

@@ -1271,7 +1269,7 @@ Optional<Expr> InferShapeStridedSlice(const Call& call, DiagnosticContext diag_ctx) {
     PrimExpr stride = strides[i];

     if (attrs->slice_mode == "size") {
-      stride = tir::make_const(DataType::Int(32), 1);
+      stride = tir::make_const(DataType::Int(64), 1);
       end = begin + ends[i];
     } else {
       if (attrs->slice_mode != "end") {
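Most of the hunks above only switch constants to int64, but the reshape `-1` inference they touch is worth spelling out. A pure-Python sketch of that logic for fully static shapes (the C++ version additionally handles symbolic dims through `arith::Analyzer`):

    def infer_reshape_shape(shape, new_shape):
        """Resolve at most one -1 entry in new_shape; 0 entries are rejected."""
        prod = 1
        for dim in shape:
            prod *= dim
        dim_to_infer = -1
        new_prod = 1
        for i, dim in enumerate(new_shape):
            if dim == -1:
                if dim_to_infer != -1:
                    raise ValueError("reshape accepts at most one -1")
                dim_to_infer = i
            elif dim == 0:
                raise ValueError("reshape does not accept 0 in the new shape")
            else:
                new_prod *= dim
        out = list(new_shape)
        if dim_to_infer != -1:
            out[dim_to_infer] = prod // new_prod  # infer the remaining extent
        return out

    assert infer_reshape_shape([2, 3, 4], [4, -1]) == [4, 6]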
16 changes: 15 additions & 1 deletion src/script/ir_builder/tir/ir.cc
@@ -43,7 +43,7 @@ Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Opt
     buffer_data = data.value();
   }
   if (!elem_offset.defined() && offset_factor) {
-    DataType shape_dtype = shape.empty() ? DataType::Int(32) : shape[0]->dtype;
+    DataType shape_dtype = shape.empty() ? DataType::Int(64) : shape[0]->dtype;
     elem_offset = tvm::tir::Var("elem_offset", shape_dtype);
   }
   return Buffer(buffer_data, dtype, shape, strides.value_or(Array<PrimExpr>()),
@@ -110,6 +110,20 @@ tvm::Type FuncRet(tvm::Type ret_type) {
 Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
                    Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align,
                    int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) {
+  if (!elem_offset.defined()) {
+    DataType shape_dtype;
+    if (param->IsInstance<tvm::tir::VarNode>()) {
+      shape_dtype = shape.empty() ? DataType::Int(64) : shape[0]->dtype;
+    } else if (const auto* buffer_load = param.as<tvm::tir::BufferLoadNode>()) {
+      shape_dtype = buffer_load->buffer->elem_offset->dtype;
+    } else if (const auto* buffer_region = param.as<tvm::tir::BufferRegionNode>()) {
+      shape_dtype = buffer_region->buffer->elem_offset->dtype;
+    } else {
+      LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer.";
+    }
+    elem_offset = offset_factor ? tvm::tir::Var("elem_offset", shape_dtype)
+                                : tvm::tir::make_const(shape_dtype, 0);
+  }
   Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
                              offset_factor, buffer_type_str, axis_separators);
   if (const auto* var = param.as<tvm::tir::VarNode>()) {
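A hedged TVMScript illustration of where the new `MatchBuffer` default applies (assuming the standard `T.match_buffer` signature): when no explicit `elem_offset` is given and `offset_factor` is nonzero, the builder now creates the implicit `elem_offset` var with a dtype taken from the shape, falling back to int64 for empty shapes:

    from tvm.script import tir as T

    @T.prim_func
    def func(a: T.handle) -> None:
        # No explicit elem_offset and offset_factor != 0: the ir_builder
        # creates an "elem_offset" Var whose dtype follows shape[0]'s dtype.
        A = T.match_buffer(a, (128,), "float32", offset_factor=1)
        T.evaluate(0)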
10 changes: 9 additions & 1 deletion src/tir/ir/data_type_rewriter.cc
@@ -534,13 +534,21 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) {
   if (auto it = var_remap_.find(GetRef<Var>(op)); it != var_remap_.end()) {
     return (*it).second;
   }
-  if (is_enabled_) {
+  if (is_enabled_ && op->dtype != target_data_type_) {
     Var new_var = GetRef<Var>(op).copy_with_dtype(target_data_type_);
     var_remap_.Set(GetRef<Var>(op), new_var);
     return std::move(new_var);
   }
   return GetRef<PrimExpr>(op);
 }

+PrimExpr IndexDataTypeNormalizer::VisitExpr_(const CastNode* op) {
+  if (is_enabled_) {
+    PrimExpr value = IndexDataTypeNormalizer::VisitExpr(op->value);
+    return value->dtype == target_data_type_ ? value : Cast(target_data_type_, value);
+  }
+  return IndexDataTypeRewriter::VisitExpr_(op);
+}
+
 }  // namespace tir
 }  // namespace tvm
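Schematically, the normalizer now retypes a var only when its dtype actually differs, and folds casts that become redundant once their operand is normalized. A toy Python model of the two rules (not TVM code):

    from dataclasses import dataclass

    @dataclass
    class Var:
        name: str
        dtype: str

    @dataclass
    class Cast:
        dtype: str
        value: object

    def normalize(e, target="int64"):
        if isinstance(e, Var):
            # VisitExpr_(VarNode): rewrite only if the dtype differs.
            return e if e.dtype == target else Var(e.name, target)
        if isinstance(e, Cast):
            # VisitExpr_(CastNode): normalize the operand first, then drop
            # the cast when the operand already carries the target dtype.
            v = normalize(e.value, target)
            return v if v.dtype == target else Cast(target, v)
        return e

    assert normalize(Cast("int64", Var("i", "int32"))) == Var("i", "int64")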
8 changes: 4 additions & 4 deletions tests/python/relax/test_ast_printer.py
@@ -218,17 +218,17 @@ def test_shape_of():
     s1_str = dump_ast(s1)
     assert s1_str.startswith("ShapeExpr("), s1_str
     assert "values=" in s1_str
-    assert "PrimExpr(value=`96`)" in s1_str
-    assert "PrimExpr(value=`54`)" in s1_str
+    assert "PrimExpr(value=`96i64`)" in s1_str, s1_str
+    assert "PrimExpr(value=`54i64`)" in s1_str


 def test_shape_expr():
     shape_expr = rx.ShapeExpr([10, 20])
     shape_expr_str = dump_ast(shape_expr)
     assert shape_expr_str.startswith("ShapeExpr(")
     assert "values" in shape_expr_str
-    assert "PrimExpr(value=`10`)" in shape_expr_str
-    assert "PrimExpr(value=`20`)" in shape_expr_str
+    assert "PrimExpr(value=`10i64`)" in shape_expr_str
+    assert "PrimExpr(value=`20i64`)" in shape_expr_str


 def test_call_packed():
(The remaining 9 changed files were not loaded on this page.)