Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Nms_ir data_race solved #2600

Merged
merged 3 commits into from
Feb 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions topi/python/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,6 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):

max_threads = int(math.sqrt(
tvm.target.current_target(allow_none=False).max_num_threads))
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib = tvm.ir_builder.create()
p_data = ib.buffer_ptr(data)
p_sort_result = ib.buffer_ptr(sort_result)
Expand All @@ -126,6 +124,8 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
num_anchors = out.shape[1]
nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
i = bx * max_threads + tx
Expand All @@ -151,8 +151,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[b])):
with ib.for_range(0, p_valid_count[b] - nkeep) as l:
with ib.if_scope(i < 6):
p_out[(base_idx + (l + nkeep) * 6 + i)] = \
p_data[(base_idx + (l + nkeep) * 6 + i)]
p_out[(base_idx + (l + nkeep) * 6 + i)] = -1.0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the meaning of -1? Are the last (valid_count - nkeep) bboxes dropped by this line?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's based on the implementation of the GluonCV SSD, plz reference PR #2353 .

# Apply nms
with ib.for_range(0, p_valid_count[b]) as l:
offset_l = l * 6
Expand All @@ -169,6 +168,9 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
base_idx + offset_i + 2)
with ib.if_scope(iou >= nms_threshold):
p_out[base_idx + offset_i] = -1.0
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this makes sense to me, could you also add this to proposal op?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I’ll do that.

tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))
with ib.else_scope():
with ib.for_range(0, p_valid_count[b]) as c:
with ib.if_scope(i < 6):
Expand Down
3 changes: 3 additions & 0 deletions topi/python/topi/cuda/rcnn/proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
iou = calculate_overlap(p_data, (base_idx + l) * 5, (base_idx + i) * 5)
with ib.if_scope(iou > nms_threshold):
p_out[base_idx + i] = True
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))
return ib.get()


Expand Down
2 changes: 1 addition & 1 deletion topi/tests/python/test_topi_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def check_device(device):
f(tvm_data, tvm_valid_count, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)

for device in ['llvm', 'opencl', 'cuda']:
masahi marked this conversation as resolved.
Show resolved Hide resolved
for device in ['llvm']:
check_device(device)


Expand Down