-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relax] Implement R.ensure_zero_offset and update memory planning for R.view #17145
Conversation
e2098bc
to
094428d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for making this follow-up. I have a couple of straight
It looks like the removal of the legalization for R.memory.view
is to avoid a phase-order issue, where StaticPlanBlockMemory
must be able to identify operators that may alias. Is that understanding correct?
Rather than moving some of the legalization steps into LowerVMBuiltin
, I propose we instead add a legalization_level
for each operator, and to LegalizeOps
. That way, we can distinguish between higher-abstraction operators (legalize before StaticPlanBlockMemory
) and lower-abstraction operators (legalize after StaticPlanBlockMemory
).
- If not specified, an operator would have legalization level of 10. The
R.memory.view
andR.memory.ensure_aligned
operators would have legalization level of 0. LegalizeOps
would default to a legalization level of 10. Any operator whose legalization level is less than theLegalizeOps
level would be skipped.- An additional pass of
LegalizeOps
would occur at the end of the Relax lowering pipeline, with legalization level of zero.
src/relax/op/memory/view.cc
Outdated
|
||
return Call(runtime_view_func, {data, shape, dtype, relative_byte_offset}); | ||
return Call(call->op, {data, shape, dtype, relative_byte_offset}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change means that R.memory.view
is still present in the output of LegalizeOps
, but a legalization function should replace the operator with a lowered form.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This only does the inference of void
type args and leave the lowering to the later pass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we want to do the inference of the shape/dtype prior to lowering, because it could result in unexpected StructInfo inference later on.
Suppose we have R.memory.view(arg, shape=[16])
. This returns a view into the first 16 elements of arg
, without changing the dtype. If an IRModule
pass updates the datatype of arg
, then that new datatype should also propagate to the view. However, legalizing it to R.memory.view(arg, shape=[16], dtype="float16")
would return a view into arg
, interpreting the first 32 bytes as if they were "float16"
. Now, if an IRModule
pass updates the datatype of arg
, the view would still be "float16"
. To avoid this issue, the unknown arguments shouldn't be filled in until the lowering is about to occur.
What if we were to remove .set_attr<FLegalize>
altogether, and only have .set_attr<FLowerBuiltin>
? That way, we preserve the R.memory.view
as-is until it is ready to be lowered. The LegalizeOps
pass would then be a no-op for R.memory.view
, and only the LowerRuntimeBuiltin
pass would change it at all.
Also, if you're interested, I have a partial implementation in this dev branch that includes the device-type validation and a TIR legalization. If you'd like to pull any of it over, you're welcome to it, as I've had it on the back-burner for far too long. |
Based on the current grouping, seems quite a bit of the runtime function dispatchings happens in LowerBuiltin, while legalizeOps primarily focused on lowering to TIR related functions. I think such distinction is still helpful, so that can be a factor considering moving the view legalization into the VMBuiltin. |
I don't think distinguishing between the style of implementation is a useful distinction to make. The important distinction is what functionality must still be observable outside of the operator, not the functional form of the legalized expression. My understanding is that |
There are different ways to look at this particular case. For this particular case, given the view was lowered to runtime function, it was primarily focused for the VM itself. One can also envision in future we have a codegen approach to get a view function that get inlined which is not needed in the VM approach. Introducing legalize ops with different levels can also have extra issues, as we need to run default scheduling of some of the ops. But for the certain legalization level we do not have to. In some sense, we are creating different grouping here. Perhaps one way to make it more clear is to rename LowerVMBuiltin to LowerRuntimeBuiltin, which have clear indication that that is a pass which takes charge of lowering all implementaitons of runtime builtin functions. |
I like this idea, but I don't think we should move the definition of the legalized form into the That would also allow
I like this idea, and have been toying around with some TIR implementations. The key limitation at the moment is the inability to construct and return a new |
I think having |
src/relax/op/memory/view.cc
Outdated
|
||
return Call(runtime_view_func, {data, shape, dtype, relative_byte_offset}); | ||
return Call(call->op, {data, shape, dtype, relative_byte_offset}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we want to do the inference of the shape/dtype prior to lowering, because it could result in unexpected StructInfo inference later on.
Suppose we have R.memory.view(arg, shape=[16])
. This returns a view into the first 16 elements of arg
, without changing the dtype. If an IRModule
pass updates the datatype of arg
, then that new datatype should also propagate to the view. However, legalizing it to R.memory.view(arg, shape=[16], dtype="float16")
would return a view into arg
, interpreting the first 32 bytes as if they were "float16"
. Now, if an IRModule
pass updates the datatype of arg
, the view would still be "float16"
. To avoid this issue, the unknown arguments shouldn't be filled in until the lowering is about to occur.
What if we were to remove .set_attr<FLegalize>
altogether, and only have .set_attr<FLowerBuiltin>
? That way, we preserve the R.memory.view
as-is until it is ready to be lowered. The LegalizeOps
pass would then be a no-op for R.memory.view
, and only the LowerRuntimeBuiltin
pass would change it at all.
@@ -286,8 +286,13 @@ class TokenAllocator1D { | |||
std::vector<StorageToken> full_pool_; | |||
}; | |||
|
|||
/*! \brief Check if the input op is "relax.reshape". */ | |||
bool IsReshape(const Expr& op) { return op.same_as(Op::Get("relax.reshape")); } | |||
/*! \brief Check if the input op is a memory op that may return the same buffer. */ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you on the updated docstring. As I'm looking at it, we may want to add this as another operator attribute (e.g. .set_attr<Bool>("ReturnMayAliasArgument", Bool(true))
), but that could be a follow-up PR instead.
include/tvm/runtime/device_api.h
Outdated
@@ -240,6 +240,10 @@ class TVM_DLL DeviceAPI { | |||
return device_type != kDLCPU && device_type != kDLMicroDev; | |||
} | |||
|
|||
static bool SupportsPointerArithmetics(int device_type) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we already have a vtable for DeviceAPI
, this should be a virtual function instead of a static boolean. That would allow individual DeviceAPI
implementations to independently mark that they support the pointer-arithmetic. (It would also allow checking for driver-dependent support, such as vulkan support for the optional VK_KHR_buffer_device_address
feature.)
Since host-side pointer arithmetic is not the default behavior for DLTensor::data
, the default implementation in DeviceAPI
would return false, and it could be overridden in CPUDeviceAPI
and CUDADeviceAPI
to return true.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, a nitpick: This isn't whether the device supports pointer arithmetic, but whether pointer arithmetic of a device-owned void* DLTensor::data
may be performed on the host. The TVM backends for both Vulkan and OpenCL support pointer arithmetic, but only within the generated kernels. Neither support pointer arithmetic being performed on the host.
tests/python/relax/test_op_view.py
Outdated
R.dtype("float32"), | ||
R.prim_value(0), | ||
) | ||
B = R.memory.view(A, shape=R.shape([64, 64]), dtype="float32", relative_byte_offset=0) | ||
return B | ||
|
||
After = tvm.relax.transform.LegalizeOps()(Before) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After replacing the .set_attr<FLegalize>
for R.memory.view
with .set_attr<FLowerBuiltin>
, the changes to these unit tests can be reverted. Instead, any use of LegalizeOps
in the unit tests would instead call LowerRuntimeBuiltin
.
@vinx13 I took a look at the current CI failures, and it looks like it pretty close to passing. If you'd like, applying the diff below should resolve the last 4 failing tests in CI. |
cc @Lunderberg let us merge this in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for making the changes (and for pinging me on it). Looks good!
Previously,
R.view
was legalized to extern call toruntime.TVMArrayCreateView
duringLegalizeOps
. This call to extern func can't be properly handled byStaticBlockPlanMemory
because it assumes the extern func does not retain the input buffer. Extern func returning a view of the input would break the ref count of the buffer. This PR defers the legalization ofR.view
so that it can be explicitly handled by memory planning.A new op
R.ensure_aligned
is added as discussed in #16955cc @tqchen @yongwww @Lunderberg