Skip to content

Commit

Permalink
fix cpplint and revert float64 change
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 3, 2021
1 parent df7688c commit 54a534c
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
# The following algorithm performs parallel exclusive scan
# Up Sweep of exclusive scan
lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float32"))), "int64"
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int64"
)
with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << l2_width
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def compare(a, b):

# Sort the lower levels of the merge using odd-even sort, it's fast for small inputs
lower_lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(block_size, "float32"))), "int64"
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(block_size, "float64"))), "int64"
)

_odd_even_sort(
Expand All @@ -255,7 +255,7 @@ def compare(a, b):
)

upper_lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float32"))), "int64"
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64"
)

def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count):
Expand Down
2 changes: 1 addition & 1 deletion src/target/spirv/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ class IRBuilder {
*/
Value GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index);

// TODO doc
// TODO(masahi): doc
Value DeclareUniformBuffer(const std::vector<SType>& value_types, uint32_t binding);
Value GetUniform(Value ptr_ubo, const SType& v_type, uint32_t index);
/*!
Expand Down

0 comments on commit 54a534c

Please sign in to comment.