[Inference] Optimized some scattered optimization points in the framework #5544
base: feature/colossal-infer
Conversation
The branch was force-pushed from 93db9fc to 57a9574, and later from f63a248 to f7d4f6f.
    if end_indexes.numel() > 0:
        # contiguous cache exists
        end_idx = end_indexes[0].item() + 1  # open interval
        start_idx = end_idx - num_blocks_required  # closed interval
-       alloc_block_ids = torch.arange(start_idx, end_idx)
+       alloc_block_ids = torch.arange(start_idx, end_idx, device=block_tables.device)
Assigning alloc_block_ids the device of block_tables might trigger an error at L259, self._block_states[alloc_block_ids] = 0. Notice that self._block_states is on the host. If the passed-in block-tables tensor were on a device, you would get the runtime error "Expected all tensors to be on the same device, but found ...". At the moment, adding device=block_tables.device here makes no difference, since in the batch bucket class the block-tables tensor is on the host, so it causes no error but adds no functionality either.
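A minimal standalone sketch of the failure mode described above (tensor sizes and names are hypothetical):

    import torch

    block_states = torch.zeros(16, dtype=torch.int64)    # host-side block state table
    alloc_block_ids = torch.arange(0, 4, device="cuda")  # allocation indices created on the GPU

    # Indexing a CPU tensor with CUDA indices raises a device-mismatch
    # RuntimeError (the exact message depends on the PyTorch version).
    block_states[alloc_block_ids] = 0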
OK, I will fix it.
    @@ -34,18 +34,8 @@ __global__ void act_and_mul_kernel(

      // Note(LiuYang): This func is designed for calculation mode like
      // silu(x[:half_1stdim]) * (x[half_1stdim:])
    - torch::Tensor silu_and_mul(const torch::Tensor& ins)
    + void silu_and_mul(const torch::Tensor& ins, torch::Tensor& outs)
This doesn't handle the case where outs is None.
If a None value is passed in, it is an illegal operation and the C++ side will report an error.
I mean the case should still be considered, whether or not you dispatch to a different kernel. The modifications here lose the ability to handle the regular way of calling the kernel (inputs only).
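One way to keep both calling conventions would be a thin Python-side wrapper around the two-argument extension; this is only a sketch under that assumption, not the PR's actual API:

    import torch

    def silu_and_mul(ins: torch.Tensor, outs: torch.Tensor = None) -> torch.Tensor:
        # Preserve the old inputs-only call while still accepting a pre-allocated output.
        if outs is None:
            half = ins.size(0) // 2
            outs = torch.empty_like(ins[:half])
        # `inference_ops` is the compiled extension module used elsewhere in this PR.
        inference_ops.silu_and_mul(ins, outs)
        return outs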
Okay, let me think about how to fix it.
    @@ -20,7 +20,8 @@ def test_silu_and_mul(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype):
          act_out = torch.nn.functional.silu(ref_input[0], inplace=True)
          ref_out = act_out * ref_input[1]

    -     origin_out = inference_ops.silu_and_mul(origin_input)
    +     origin_out = torch.empty_like(ref_out)
    +     inference_ops.silu_and_mul(origin_input, origin_out)
Same as above: there is no test for None as the output tensor.
same as above.
See above reply
    @@ -167,6 +171,7 @@ def llama_decoder_layer_forward(
          kv_seq_len: int = 0,
          output_tensor: torch.Tensor = None,
          norm_output: torch.Tensor = None,
    +     silu_and_mul_output: torch.Tensor = None,
Not sure if it's a good idea to just add the silu_and_mul output tensor as an argument and pass it module by module down to the MLP layer.
Yes, I also feel that passing this many parameters is too much. In the future we could put all of these temporary outputs into a struct for unified management.
Then we would only need to pass this struct each time.
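A minimal sketch of that idea, with hypothetical names (not code from this PR):

    from dataclasses import dataclass
    import torch

    @dataclass
    class LayerTempBuffers:
        # Pre-allocated temporaries shared by every decoder layer (hypothetical names).
        output_tensor: torch.Tensor
        norm_output: torch.Tensor
        silu_and_mul_output: torch.Tensor

    def llama_decoder_layer_forward(hidden_states: torch.Tensor, buffers: LayerTempBuffers, **kwargs):
        # Each layer reads from / writes into the shared buffers instead of
        # receiving every temporary tensor as a separate keyword argument.
        ...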
Just some advice :) First, it's not a good idea to design an activation API that requires adding an output_tensor as an argument; if you really want to do that, you'd better make it an inplace API. Second, I don't think it's a good idea to take over torch's memory management on your own before you really understand it or have designed a proper memory management system; meanwhile, the performance gain seems small and may just be normal fluctuation, so this optimization point may not work well. Finally, it's probably not a good idea to write tricky code for such a small performance gain.
During testing we obtain a stable performance benefit, and compared to the other optimizations, this benefit already seems quite considerable. Also, this does not involve managing memory on torch's behalf; rather, it fixes our own unreasonable use of memory. Of course, I also agree that this operator should be implemented as an inplace operator, which would avoid the redundant memory allocations.
I feel this is only a temporary optimization; the optimal solution would be to implement this operator as an inplace one. We can put a TODO here.
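A minimal sketch of what an inplace variant could look like at the Python level, assuming the silu(x[:half_1stdim]) * x[half_1stdim:] layout noted in the kernel comment above (names hypothetical):

    import torch

    def silu_and_mul_(x: torch.Tensor) -> torch.Tensor:
        # Writes silu(x[:half]) * x[half:] into the first half of x,
        # so no fresh output tensor needs to be allocated per call.
        half = x.size(0) // 2
        out = x[:half]
        torch.nn.functional.silu(out, inplace=True)
        out.mul_(x[half:])
        return out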