-
Notifications
You must be signed in to change notification settings - Fork 793
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat lazy tensor indexing #9334
Merged
Merged
Changes from 25 commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
6d40284
feat(boxing): collective_boxing slice_boxing support 0size tensor
wyg1997 e1b02ae
test(Indexing): add lazy tensor basic indexing
wyg1997 8272026
add MaskTensor judgement
wyg1997 e636732
Merge remote-tracking branch 'origin/master' into feat-lazy_tensor_in…
wyg1997 3abb999
Merge remote-tracking branch 'origin/master' into feat-lazy_tensor_in…
wyg1997 e8e441b
Merge remote-tracking branch 'origin/master' into feat-lazy_tensor_in…
wyg1997 dcb2065
Merge branch 'master' into feat-lazy_tensor_indexing
wyg1997 a52eddb
format code
wyg1997 e03a98c
feat(TensorIndexing): support lazy advance getitem indexing
wyg1997 a4fab01
feat(Indexing): support lazy indexing for lazy_tensor and free_tensor
wyg1997 d15ed76
fix(Indexing): fix indexing test bug
wyg1997 4890857
test(Indexing): test all advance indexing
wyg1997 17fd5ff
test(GlobalIndexing): fix eager global indexing bug
wyg1997 01bbd25
test(Indexing): support combined indexing
wyg1997 8f95248
add last test cases
wyg1997 27c6961
Merge remote-tracking branch 'origin/master' into feat-lazy_tensor_in…
wyg1997 1cb928e
Merge branch 'master' into feat-lazy_tensor_indexing
wyg1997 3b0be93
fix merge bug
wyg1997 dd4b56a
Merge remote-tracking branch 'origin/master' into feat-lazy_tensor_in…
wyg1997 acf65d0
Merge remote-tracking branch 'origin/master' into feat-lazy_tensor_in…
wyg1997 a71de68
Merge remote-tracking branch 'origin/master' into feat-lazy_tensor_in…
wyg1997 d2634c8
fix lazy mode guard
wyg1997 6ec3ec4
Merge remote-tracking branch 'origin/master' into feat-lazy_tensor_in…
wyg1997 7eff650
test(Indexing): refine set scalar value test
wyg1997 37c5adc
test(Indexing): enable all bool tensor index setitem
wyg1997 4affec4
decrease test time
wyg1997 fdf4f1f
Merge branch 'master' into feat-lazy_tensor_indexing
wyg1997 baccecd
Merge branch 'master' into feat-lazy_tensor_indexing
wyg1997 ed18408
Merge branch 'master' into feat-lazy_tensor_indexing
wyg1997 60f2549
Merge branch 'master' into feat-lazy_tensor_indexing
wyg1997 40f30d5
refine 0size shape judgement
wyg1997 8c1785d
add comment
wyg1997 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,7 +42,7 @@ int64_t CountSpecifiedDims(const TensorIndex& index) { | |
specified_ndims++; | ||
} else if (index_item.IsTensor()) { | ||
const auto& tensor = index_item.tensor(); | ||
if (tensor->dtype() == DType::Int8() || tensor->dtype() == DType::UInt8()) { | ||
if (IsMaskTensor(tensor)) { | ||
specified_ndims += tensor->ndim(); | ||
} else { | ||
specified_ndims++; | ||
|
@@ -247,6 +247,11 @@ Maybe<Tensor> PermuteBackForGlobalTensor(const std::shared_ptr<Tensor>& result, | |
|
||
} // namespace | ||
|
||
bool IsMaskTensor(const std::shared_ptr<Tensor>& tensor) { | ||
return tensor->dtype() == DType::Int8() || tensor->dtype() == DType::UInt8() | ||
|| tensor->dtype() == DType::Bool(); | ||
} | ||
|
||
Maybe<void> PrepareSliceIndices(const TensorIndex& index, const Shape& shape, | ||
std::vector<detail::Slice>* slice_indices, | ||
TensorTuple* tensor_indices, std::vector<int64_t>* expand_dims, | ||
|
@@ -558,6 +563,8 @@ Maybe<void> UnifyInputAndIndicesOnDevice(const std::shared_ptr<Tensor>& x, | |
const auto tensor_index = tensor_indices[i]; | ||
if (tensor_index == nullptr) { continue; } | ||
if (tensor_index->is_local()) { | ||
// NOTE: LocalToGlobal should be called in eager mode | ||
LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false); | ||
Comment on lines
+566
to
+567
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LocalToGlobal 只能在 eager 模式下调用 |
||
tensor_indices[i] = JUST(ToGlobal(tensor_index, placement, | ||
std::vector<Symbol<SbpParallel>>(n, broadcast_sbp), | ||
grad_sbp_tuple, /*check_meta=*/false, /*copy=*/false)); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -104,7 +104,11 @@ bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) { | |
return Maybe<void>::Ok(); | ||
} | ||
/*static*/ Maybe<void> SliceUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { | ||
return InferLogicalTensorDesc(ctx); | ||
const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0); | ||
auto* y_desc = ctx->MutOutputTensorDesc("y", 0); | ||
y_desc->set_shape(ref_desc.shape()); | ||
y_desc->set_is_dynamic(ref_desc.is_dynamic()); | ||
return Maybe<void>::Ok(); | ||
} | ||
Comment on lines
106
to
112
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 之前 SliceUpdate 的物理 Tensor 推导是错误的,它支持 S + B -> S,是不能和逻辑 shape 推导共用推导函数(逻辑推导函数中有一些 shape 的检察,物理 tensor shape 推导不需要) |
||
/*static*/ Maybe<void> SliceUpdateOp::InferDataType(user_op::InferContext* ctx) { | ||
const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里代码太长,只是加了一个作用域