Skip to content

Commit

Permalink
[TOPI] Support symbolic shape in einsum (apache#14521)
Browse files Browse the repository at this point in the history
* [TOPI] Support symbolic shape in einsum

* Update test_topi_einsum.py
  • Loading branch information
vinx13 authored Apr 7, 2023
1 parent b228037 commit 460374f
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 21 deletions.
31 changes: 20 additions & 11 deletions src/topi/einsum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,24 +98,33 @@ EinsumEquation EinsumEquation::FromString(const std::string& equation) {
}

PrimExpr GetBroadcastedExtent(const PrimExpr& extent1, const PrimExpr& extent2) {
int64_t extent1_value = GetConstInt(extent1);
int64_t extent2_value = GetConstInt(extent2);
if (extent1_value == extent2_value) {
const IntImmNode* extent1_imm = extent1.as<IntImmNode>();
const IntImmNode* extent2_imm = extent2.as<IntImmNode>();
if (extent1_imm != nullptr && extent2_imm != nullptr) {
if (extent1_imm->value == extent2_imm->value) {
return extent1;
} else if (extent1_imm->value == 1 || extent2_imm->value == 1) {
return Integer(std::max(extent1_imm->value, extent2_imm->value));
}
LOG(FATAL) << "Cannot broadcast extents " << extent1 << " and " << extent2;
throw;
} else if (extent1_imm != nullptr) {
return extent2;
} else if (extent2_imm != nullptr) {
return extent1;
} else if (extent1_value == 1 || extent2_value == 1) {
return Integer(std::max(extent1_value, extent2_value));
} else {
return max(extent1, extent2);
}
LOG(FATAL) << "Cannot broadcast extents " << extent1 << " and " << extent2;
throw;
}

PrimExpr GetIndexForBroadcastedDim(const Var& index, const PrimExpr& extent,
const PrimExpr& broadcasted_extent) {
if (GetConstInt(extent) == GetConstInt(broadcasted_extent)) {
return index;
} else {
return Integer(0);
// Check if current dimension is being broadcasted to `broadcasted_extent` (symbolic shape is
// handled)
if (is_one(extent) && !is_one(broadcasted_extent)) {
return make_zero(index.dtype());
}
return index;
}

/*! \brief The compute builder for Einsum */
Expand Down
52 changes: 42 additions & 10 deletions tests/python/topi/python/test_topi_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,39 +23,59 @@
from tvm.topi.utils import get_const_tuple


def with_tvm(lam, *args):
def with_tvm(lam, shapes, ops, out_shape):
"""Take numpy arrays as args, convert them to TVM tensors and call `lam`.
Result of lambda is converted back to numpy array and returned.
"""
dev = tvm.cpu(0)
pls = [] # placeholders
vals_nd = [] # initial values
for i, arg in enumerate(args):
pls.append(te.placeholder(arg.shape, name="pl" + str(i)))
for i, (shape, arg) in enumerate(zip(shapes, ops)):
pls.append(te.placeholder(shape, name="pl" + str(i)))
vals_nd.append(tvm.nd.array(arg, dev))

out = lam(*pls)
out_nd = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), dev)
out_nd = tvm.nd.array(np.zeros(out_shape).astype(out.dtype), device=dev)
s = te.create_schedule([out.op])
m = tvm.build(s, pls + [out], "llvm")
m(*(vals_nd + [out_nd]))
return out_nd.numpy()


def verify_einsum(subscripts, shapes):
ops = []
def verify_einsum(subscripts, shapes, shape_dict={}):
ops = [] # ndarrays to be used as inputs
symbolic_shapes = [] # shapes to declare the placeholders
name_to_var = {}

def get_concrete_shape(shape):
return [shape_dict[s] if isinstance(s, str) else s for s in shape]

def get_symblic_shape_var(name, dtype="int32"):
if name not in name_to_var:
name_to_var[name] = te.var(name, dtype=dtype)
return name_to_var[name]

def get_symbolic_shape(shape):
return [get_symblic_shape_var(s) if isinstance(s, str) else s for s in shape]

for shape in shapes:
tmp = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(np.float32)
concrete_shape = get_concrete_shape(shape)
tmp = np.random.uniform(low=-1.0, high=1.0, size=concrete_shape).astype(np.float32)
ops.append(tmp)
symbolic_shape = get_symbolic_shape(shape)
symbolic_shapes.append(symbolic_shape)

c1 = np.einsum(subscripts, *ops)
out_shape = c1.shape

if len(ops) == 1:
c2 = with_tvm(lambda A: topi.einsum(subscripts, A), *ops)
c2 = with_tvm(lambda A: topi.einsum(subscripts, A), symbolic_shapes, ops, out_shape)
elif len(ops) == 2:
c2 = with_tvm(lambda A, B: topi.einsum(subscripts, A, B), *ops)
c2 = with_tvm(lambda A, B: topi.einsum(subscripts, A, B), symbolic_shapes, ops, out_shape)
elif len(ops) == 3:
c2 = with_tvm(lambda A, B, C: topi.einsum(subscripts, A, B, C), *ops)
c2 = with_tvm(
lambda A, B, C: topi.einsum(subscripts, A, B, C), symbolic_shapes, ops, out_shape
)

tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5)

Expand All @@ -82,5 +102,17 @@ def test_einsum(equation, inputs):
verify_einsum(equation, inputs)


@pytest.mark.parametrize(
"equation,inputs,shape_dict",
[
("ij,jk->ik", [(2, "K"), (1, "N")], {"K": 3, "N": 4}),
("ij,jk->ik", [(2, "K"), ("K2", "N")], {"K": 3, "N": 4, "K2": 3}),
("ij,jk->ik", [(2, "K"), ("K2", "N")], {"K": 3, "N": 4, "K2": 1}),
],
)
def test_einsum_symblic_shape(equation, inputs, shape_dict):
verify_einsum(equation, inputs, shape_dict)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 460374f

Please sign in to comment.