Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Fix segfaults from ordering of Let/Assert in MakePackedAPI #16543

Merged
merged 3 commits into from
Apr 4, 2024

Conversation

Lunderberg
Copy link
Contributor

Prior to this commit, the MakePackedAPI pass would output steps in the following order:

  1. Check the number of arguments.
  2. All LetStmt produced by the ArgBinder
  3. AssertStmt for the Type code checks for each argument.
  4. Additional AssertStmt produced by the ArgBinder.

This order can cause segfaults if a function was provided incorrect arguments. For example, an integer argument passed to a function expecting a DLTensor* would be dereferenced to find the tensor's data pointer (step (2)) before checking if it is valid to perform that dereference (step (3)). The same would occur when reading the size of a tensor's axes (step (2)) before checking whether the tensor is the correct dimensionality (step (4)).

This commit updates the steps to the following order.

  1. Check the number of arguments.
  2. Check the type code of each argument.
  3. All LetStmt and AssertStmt produced by the ArgBinder, in the order in which they are generated.

@Lunderberg
Copy link
Contributor Author

This came about while debugging the implementation of #16542, but is otherwise unrelated.

@Lunderberg
Copy link
Contributor Author

Merged from main to PR branch, to resolve CI breakage that was fixed in #16546.

@Lunderberg Lunderberg force-pushed the tir_reorder_dlpack_asserts branch 2 times, most recently from a7f822a to 4a064de Compare February 29, 2024 23:56
Prior to this commit, the `MakePackedAPI` pass would output steps in
the following order:

1. Check the number of arguments.
2. All `LetStmt` produced by the `ArgBinder`
3. `AssertStmt` for the Type code checks for each argument.
4. Additional `AssertStmt` produced by the `ArgBinder`.

This order can cause segfaults if a function was provided incorrect
arguments.  For example, an integer argument passed to a function
expecting a `DLTensor*` would be dereferenced to find the tensor's
data pointer (step (2)) before checking if it is valid to perform that
dereference (step (3)).  The same would occur when reading the size of
a tensor's axes (step (2)) before checking whether the tensor is the
correct dimensionality (step (4)).

This commit updates the steps to the following order.

1. Check the number of arguments.
2. Check the type code of each argument.
3. All `LetStmt` and `AssertStmt` produced by the `ArgBinder`, in the
   order in which they are generated.
@Lunderberg Lunderberg force-pushed the tir_reorder_dlpack_asserts branch from 4a064de to e2b7871 Compare March 13, 2024 22:02
Copy link
Contributor

@slyubomirsky slyubomirsky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely good to have clean errors instead of segfaults and more validation. I'm also glad to see tests for these error cases.

Comment on lines 51 to 54
.trim()
.split("\n")
.last()
.unwrap_or("")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wonder how this came up. Just for readability?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, that's really weird. I'm guessing it was from bouncing over to the PR branch of #16183, which touched a number of the FFI bindings. I've removed this delta from the PR.

@@ -186,18 +191,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
if (!(buffer->dtype == DataType::Int(1) || buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4))) {
auto type_msg = tvm::tir::StringImm(type_err_msg.str());
asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this just a duplicate?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup. The buffer's dimensionality is checked earlier, so this is entirely a duplicate check on the dimensionality.

@Lunderberg
Copy link
Contributor Author

Last CI failure is due to a flaky hexagon test. I've added a @pytest.mark.skip to this PR, now just waiting on CI.

@Lunderberg Lunderberg merged commit cd08356 into apache:main Apr 4, 2024
18 checks passed
@Lunderberg Lunderberg deleted the tir_reorder_dlpack_asserts branch April 4, 2024 23:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants