Skip to content

Commit

Permalink
[SPIR-V] Fix pushconstants offset calculation for 32 bit values (apac…
Browse files Browse the repository at this point in the history
…he#7620)

* Fix push constant offset for 32 bit value

* add test

* remove unused function from test

* add dynamic cumsum test

* skip if vulkan is not enabled

* replace dynamic cumsum test with dynamic argsort for now

Co-authored-by: Masahiro Masuda <masahi@129@gmail.com>
  • Loading branch information
2 people authored and trevor-m committed May 11, 2021
1 parent 8b89ca7 commit a1f4a6f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/target/spirv/ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,14 @@ Value IRBuilder::DeclarePushConstant(const std::vector<SType>& value_types) {
DataType t = value_types[i].type;
uint32_t nbits = t.bits() * t.lanes();
ICHECK_EQ(nbits % 8, 0);
offset += nbits / 8;
uint32_t bytes = (nbits / 8);
if (t.bits() == 32) {
// In our Vulkan runtime, each push constant always occupies 64 bit.
offset += bytes * 2;
} else {
ICHECK_EQ(t.bits(), 64);
offset += bytes;
}
}
// Decorate push constants as UBO
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
Expand Down
34 changes: 34 additions & 0 deletions tests/python/unittest/test_target_codegen_spirv.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tvm
import tvm.testing
from tvm import te
from tvm import relay
from tvm.topi.math import cast
import numpy as np

Expand Down Expand Up @@ -71,5 +72,38 @@ def do_copy(A, B, n):
tvm.testing.assert_allclose(b.asnumpy(), ref)


def test_pushconstants():
if not tvm.testing.device_enabled("vulkan"):
return

def check_mod(mod, x_np, res_np):
target = "vulkan"
ctx = tvm.context(target, 0)
ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
res = ex.evaluate()(x_np).asnumpy()
tvm.testing.assert_allclose(res, res_np, atol=1e-5)

# Three 32 bit pushconstants: any_dim, stride, stride
dtype = "float32"
x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
mod = tvm.IRModule()
mod["main"] = relay.Function([x], relay.sqrt(x))
x_np = np.random.uniform(size=(10,)).astype(dtype)
res_np = np.sqrt(x_np)

check_mod(mod, x_np, res_np)

# One 64 bit and one 32 bit constants
dtype = "int32"
x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
mod = tvm.IRModule()
mod["main"] = relay.Function([x], relay.argsort(x))
x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
res_np = np.argsort(x_np)

check_mod(mod, x_np, res_np)


if __name__ == "__main__":
test_bool_load()
test_pushconstants()

0 comments on commit a1f4a6f

Please sign in to comment.