[Relax] Implement R.ensure_zero_offset and update memory planning for R.view #17145

Merged
merged 12 commits into from
Aug 6, 2024

Conversation

Member

@vinx13 vinx13 commented Jul 9, 2024

Previously, R.view was legalized to an extern call to runtime.TVMArrayCreateView during LegalizeOps. This extern call can't be properly handled by StaticPlanBlockMemory, because the pass assumes that an extern func does not retain the input buffer. An extern func that returns a view of its input would break the ref count of the buffer. This PR defers the legalization of R.view so that it can be explicitly handled by memory planning.

A new op R.ensure_aligned is added as discussed in #16955

cc @tqchen @yongwww @Lunderberg

@github-actions github-actions bot requested review from Lunderberg, tqchen and yongwww July 9, 2024 23:04
@vinx13 vinx13 force-pushed the feat/view-align branch 2 times, most recently from e2098bc to 094428d Compare July 10, 2024 04:22
Contributor

@Lunderberg Lunderberg left a comment

Thank you for making this follow-up. I have a couple of straight

It looks like the removal of the legalization for R.memory.view is to avoid a phase-order issue, where StaticPlanBlockMemory must be able to identify operators that may alias. Is that understanding correct?

Rather than moving some of the legalization steps into LowerVMBuiltin, I propose we instead add a legalization_level for each operator, and to LegalizeOps. That way, we can distinguish between higher-abstraction operators (legalize before StaticPlanBlockMemory) and lower-abstraction operators (legalize after StaticPlanBlockMemory).

  • If not specified, an operator would have legalization level of 10. The R.memory.view and R.memory.ensure_aligned operators would have legalization level of 0.
  • LegalizeOps would default to a legalization level of 10. Any operator whose legalization level is less than the LegalizeOps level would be skipped.
  • An additional pass of LegalizeOps would occur at the end of the Relax lowering pipeline, with legalization level of zero.
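The three rules proposed above can be sketched in plain Python. This is only a toy model of the proposal, not TVM code: `OpInfo`, `REGISTRY`, `legalize_ops`, and the lowered names are hypothetical stand-ins, and the real pass operates on Relax IR rather than on lists of operator names.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List

# Hypothetical default taken from the proposal; not an existing TVM constant.
DEFAULT_LEVEL = 10


@dataclass
class OpInfo:
    """Toy stand-in for an operator's registry entry."""
    legalize: Callable[[str], str]
    legalization_level: int = DEFAULT_LEVEL


# Toy registry. The lowered names are placeholders, not real TVM symbols.
REGISTRY: Dict[str, OpInfo] = {
    "relax.add": OpInfo(lambda op: "tir.add_prim_func"),
    "relax.memory.view": OpInfo(
        lambda op: "runtime.view_builtin", legalization_level=0
    ),
}


def legalize_ops(program: List[str], level: int = DEFAULT_LEVEL) -> List[str]:
    """Legalize ops whose level is at least the pass level; skip the rest."""
    out = []
    for op in program:
        info = REGISTRY.get(op)
        if info is None or info.legalization_level < level:
            out.append(op)  # unknown op or lower-level op: leave untouched
        else:
            out.append(info.legalize(op))
    return out


program = ["relax.add", "relax.memory.view"]
early = legalize_ops(program)        # default level, before memory planning
late = legalize_ops(early, level=0)  # final pass at the end of the pipeline
```

In this model the view operator survives the first pass, so a memory planner running between the two calls could still see it.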

python/tvm/relax/op/memory/view.py (outdated, resolved)
src/relax/backend/vm/vm_builtin_lower.cc (outdated, resolved)

return Call(runtime_view_func, {data, shape, dtype, relative_byte_offset});
return Call(call->op, {data, shape, dtype, relative_byte_offset});
Contributor

This change means that R.memory.view is still present in the output of LegalizeOps, but a legalization function should replace the operator with a lowered form.

Member Author

This only infers the void-type args and leaves the lowering to the later pass.

Contributor

I don't think we want to do the inference of the shape/dtype prior to lowering, because it could result in unexpected StructInfo inference later on.

Suppose we have R.memory.view(arg, shape=[16]). This returns a view into the first 16 elements of arg, without changing the dtype. If an IRModule pass updates the datatype of arg, then that new datatype should also propagate to the view. However, legalizing it to R.memory.view(arg, shape=[16], dtype="float16") would return a view into arg, interpreting the first 32 bytes as if they were "float16". Now, if an IRModule pass updates the datatype of arg, the view would still be "float16". To avoid this issue, the unknown arguments shouldn't be filled in until the lowering is about to occur.

What if we were to remove .set_attr<FLegalize> altogether, and only have .set_attr<FLowerBuiltin>? That way, we preserve the R.memory.view as-is until it is ready to be lowered. The LegalizeOps pass would then be a no-op for R.memory.view, and only the LowerRuntimeBuiltin pass would change it at all.
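The dtype-pinning concern in this comment can be illustrated with a small stand-alone model. The helpers `view_dtype` and `view_nbytes` are hypothetical and exist only for this sketch; they mimic the rule that a view without an explicit dtype inherits the dtype of its argument.

```python
from typing import Optional

# Element sizes in bytes for the dtypes used in this illustration.
DTYPE_BYTES = {"float16": 2, "float32": 4}


def view_dtype(arg_dtype: str, explicit_dtype: Optional[str]) -> str:
    """A view with no explicit dtype inherits the dtype of its argument."""
    return explicit_dtype if explicit_dtype is not None else arg_dtype


def view_nbytes(
    num_elements: int, arg_dtype: str, explicit_dtype: Optional[str]
) -> int:
    """Number of bytes of the argument that the view covers."""
    return num_elements * DTYPE_BYTES[view_dtype(arg_dtype, explicit_dtype)]


# Deferred inference: the view tracks a later dtype change on the argument.
deferred_before = view_dtype("float16", None)
deferred_after = view_dtype("float32", None)

# Eager inference pinned dtype="float16" while the argument was float16,
# so the view stays float16 even after the argument becomes float32.
pinned_after = view_dtype("float32", "float16")
```

With 16 elements, the pinned view still covers 32 bytes after the argument changes to float32, whereas the deferred view would cover 64 bytes.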

src/relax/op/memory/view.cc (outdated, resolved)
src/relax/transform/static_plan_block_memory.cc (outdated, resolved)
src/runtime/relax_vm/builtin.cc (outdated, resolved)
src/runtime/relax_vm/builtin.cc (outdated, resolved)
tests/python/relax/test_op_view.py (outdated, resolved)
@Lunderberg
Contributor

Also, if you're interested, I have a partial implementation in this dev branch that includes the device-type validation and a TIR legalization. If you'd like to pull any of it over, you're welcome to it, as I've had it on the back-burner for far too long.

@tqchen
Member

tqchen commented Jul 10, 2024

Based on the current grouping, it seems quite a bit of the runtime function dispatching happens in LowerBuiltin, while LegalizeOps is primarily focused on lowering to TIR-related functions.

I think such a distinction is still helpful, so that can be a factor when considering moving the view legalization into the VMBuiltin.

@Lunderberg
Contributor

Lunderberg commented Jul 10, 2024

Based on the current grouping, it seems quite a bit of the runtime function dispatching happens in LowerBuiltin, while LegalizeOps is primarily focused on lowering to TIR-related functions.

I don't think distinguishing between the style of implementation is a useful distinction to make. The important distinction is what functionality must still be observable outside of the operator, not the functional form of the legalized expression.

My understanding is that LegalizeOps is for anything that can be lowered independent of the context in which it appears, and VMBuiltinLower is for operators that require some non-local context (e.g. the VM context pointer) in order to be lowered.

@tqchen
Member

tqchen commented Jul 10, 2024

There are different ways to look at this particular case. Given that the view was lowered to a runtime function, it was primarily focused on the VM itself. One can also envision that in the future we have a codegen approach that produces an inlined view function, which is not needed in the VM approach.

Introducing legalize ops with different levels can also cause extra issues, as we need to run default scheduling for some of the ops, but not for ops at certain legalization levels. In some sense, we are creating different groupings here.

Perhaps one way to make it clearer is to rename LowerVMBuiltin to LowerRuntimeBuiltin, which clearly indicates that it is the pass in charge of lowering all implementations of runtime builtin functions.

@Lunderberg
Contributor

Perhaps one way to make it clearer is to rename LowerVMBuiltin to LowerRuntimeBuiltin, which clearly indicates that it is the pass in charge of lowering all implementations of runtime builtin functions.

I like this idea, but I don't think we should move the definition of the legalized form into LowerRuntimeBuiltin. What if we were to instead add a new attribute, which has the same signature as FLegalize but would be applied at a later point? This would allow LowerRuntimeBuiltin to replace anything that has the FLowerBuiltin attribute, and wouldn't require a distinction between different levels of FLegalize.

That would also allow FLowerBuiltin to run only after ToNonDataflow, and to be implemented in terms of impure functions. By contrast, since FLegalize may replace a call within a dataflow block, its implementation cannot be in terms of an impure call.
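A rough way to picture the two-attribute design is two separate lookup tables with the same rewrite signature, each consulted by its own pass. The sketch below is a toy model in Python, not TVM code: `F_LEGALIZE`, `F_LOWER_BUILTIN`, `apply_attr`, and the lowered names are hypothetical placeholders.

```python
from typing import Callable, Dict, List

Rewrite = Callable[[str], str]

# Two toy attribute maps with the same rewrite signature.
# The lowered names are placeholders, not real TVM symbols.
F_LEGALIZE: Dict[str, Rewrite] = {
    "relax.add": lambda op: "call_tir(add)",
}
F_LOWER_BUILTIN: Dict[str, Rewrite] = {
    "relax.memory.view": lambda op: "call_packed(view_builtin)",
}


def apply_attr(program: List[str], attr_map: Dict[str, Rewrite]) -> List[str]:
    """Rewrite each op that has an entry in attr_map; leave the others as-is."""
    return [attr_map[op](op) if op in attr_map else op for op in program]


def legalize_ops(program: List[str]) -> List[str]:
    """Early pass: consults only the FLegalize-style map."""
    return apply_attr(program, F_LEGALIZE)


def lower_runtime_builtin(program: List[str]) -> List[str]:
    """Late pass: consults only the FLowerBuiltin-style map."""
    return apply_attr(program, F_LOWER_BUILTIN)


program = ["relax.add", "relax.memory.view"]
after_legalize = legalize_ops(program)
after_lowering = lower_runtime_builtin(after_legalize)
```

Because the view operator has no entry in the early map, the early pass is a no-op for it, and no notion of legalization levels is needed.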

One can also envision that in the future we have a codegen approach that produces an inlined view function, which is not needed in the VM approach.

I like this idea, and have been toying around with some TIR implementations. The key limitation at the moment is the inability to construct and return a new NDArray if required. (Similar to the difficulties in returning a string that are blocking #16836 and #17103.)

@tqchen
Member

tqchen commented Jul 11, 2024

I like this idea, but I don't think we should move the definition of the legalized form into LowerRuntimeBuiltin. What if we were to instead add a new attribute, which has the same signature as FLegalize but would be applied at a later point? This would allow LowerRuntimeBuiltin to replace anything that has the FLowerBuiltin attribute, and wouldn't require a distinction between different levels of FLegalize.

I think having an FLowerBuiltin attribute is great. Let's go with that.

python/tvm/relax/transform/transform.py (resolved)
src/relax/op/memory/view.cc (outdated, resolved)

@@ -286,8 +286,13 @@ class TokenAllocator1D {
std::vector<StorageToken> full_pool_;
};

/*! \brief Check if the input op is "relax.reshape". */
bool IsReshape(const Expr& op) { return op.same_as(Op::Get("relax.reshape")); }
/*! \brief Check if the input op is a memory op that may return the same buffer. */
Contributor

Thank you for the updated docstring. Looking at it, we may want to add this as another operator attribute (e.g. .set_attr<Bool>("ReturnMayAliasArgument", Bool(true))), but that could be a follow-up PR instead.

@@ -240,6 +240,10 @@ class TVM_DLL DeviceAPI {
return device_type != kDLCPU && device_type != kDLMicroDev;
}

static bool SupportsPointerArithmetics(int device_type) {
Contributor

Since we already have a vtable for DeviceAPI, this should be a virtual function instead of a static boolean. That would allow individual DeviceAPI implementations to independently mark that they support the pointer-arithmetic. (It would also allow checking for driver-dependent support, such as vulkan support for the optional VK_KHR_buffer_device_address feature.)

Since host-side pointer arithmetic is not the default behavior for DLTensor::data, the default implementation in DeviceAPI would return false, and it could be overridden in CPUDeviceAPI and CUDADeviceAPI to return true.
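The suggested shape, a conservative default that individual device APIs override, can be modeled in Python for brevity (the real change would be a C++ virtual method). The method name `supports_host_pointer_arithmetic` and the `has_buffer_device_address` flag are hypothetical, and the Vulkan class only illustrates how a driver-dependent query might feed the answer.

```python
class DeviceAPI:
    """Toy base class standing in for the C++ DeviceAPI vtable."""

    def supports_host_pointer_arithmetic(self) -> bool:
        # Conservative default: treat DLTensor::data as an opaque handle.
        return False


class CPUDeviceAPI(DeviceAPI):
    def supports_host_pointer_arithmetic(self) -> bool:
        return True


class CUDADeviceAPI(DeviceAPI):
    def supports_host_pointer_arithmetic(self) -> bool:
        return True


class VulkanDeviceAPI(DeviceAPI):
    """Answer depends on a driver feature, modeled here as a plain flag."""

    def __init__(self, has_buffer_device_address: bool = False):
        self.has_buffer_device_address = has_buffer_device_address

    def supports_host_pointer_arithmetic(self) -> bool:
        return self.has_buffer_device_address


def can_offset_on_host(api: DeviceAPI) -> bool:
    """Callers ask the device API instead of switching on a device-type enum."""
    return api.supports_host_pointer_arithmetic()
```

Compared with a static function keyed on the device type, this keeps the decision next to each backend and lets a backend answer differently per driver or per instance.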

Contributor

Also, a nitpick: This isn't whether the device supports pointer arithmetic, but whether pointer arithmetic of a device-owned void* DLTensor::data may be performed on the host. The TVM backends for both Vulkan and OpenCL support pointer arithmetic, but only within the generated kernels. Neither supports pointer arithmetic being performed on the host.

src/runtime/relax_vm/builtin.cc (outdated, resolved)
R.dtype("float32"),
R.prim_value(0),
)
B = R.memory.view(A, shape=R.shape([64, 64]), dtype="float32", relative_byte_offset=0)
return B

After = tvm.relax.transform.LegalizeOps()(Before)
Contributor

After replacing the .set_attr<FLegalize> for R.memory.view with .set_attr<FLowerBuiltin>, the changes to these unit tests can be reverted. Any use of LegalizeOps in the unit tests would instead call LowerRuntimeBuiltin.

@vinx13 vinx13 changed the title [Relax] Implement R.ensure_aligned and update memory planning for R.view [Relax] Implement R.ensure_zero_offset and update memory planning for R.view Jul 17, 2024
@Lunderberg
Contributor

@vinx13 I took a look at the current CI failures, and it looks like it is pretty close to passing. If you'd like, applying the diff below should resolve the last 4 failing tests in CI.

pr_17145_diff.txt

@tqchen
Member

tqchen commented Aug 6, 2024

cc @Lunderberg let us merge this in

Contributor

@Lunderberg Lunderberg left a comment

Thank you for making the changes (and for pinging me on it). Looks good!

@Lunderberg Lunderberg merged commit 05e2bc3 into apache:main Aug 6, 2024
19 checks passed
3 participants