Add experimental.tensor_resize operator #4492

Merged
merged 24 commits into from
Jan 25, 2023

@jantonguirao jantonguirao commented Dec 1, 2022

Signed-off-by: Joaquin Anton <janton@nvidia.com>

Category:

New feature

Description:

  • Introduces a layout-agnostic frontend for Resize
  • New operator experimental.tensor_resize
  • All dimensions can be resized (including channels). The implementation detects unchanged dimensions at the beginning and the end of the shape and treats them as non-spatial dimensions.
  • Up to 3 spatial dimensions supported
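
As a rough illustration of the "unchanged dimensions" detection described above, the sketch below compares an input shape with the requested output shape and reports the span of dimensions that would be treated as spatial. `ResizedRange` and `FindResizedRange` are hypothetical names, not taken from the DALI sources, and the real operator may handle corner cases differently.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical result type: the span of dimensions treated as spatial.
struct ResizedRange {
  int first;  // index of the first dimension whose extent changes
  int count;  // number of dimensions from the first to the last changed one
};

// Illustrative only; not DALI's actual code.
// Leading and trailing dimensions with equal input/output extents are skipped;
// everything in between (even an unchanged dimension) stays inside the span.
ResizedRange FindResizedRange(const std::vector<int64_t> &in_shape,
                              const std::vector<int64_t> &out_shape) {
  assert(in_shape.size() == out_shape.size());
  int ndim = static_cast<int>(in_shape.size());
  int first = 0;
  while (first < ndim && in_shape[first] == out_shape[first]) ++first;
  if (first == ndim) return {0, 0};  // nothing is resized
  int last = ndim - 1;
  while (last > first && in_shape[last] == out_shape[last]) --last;
  return {first, last - first + 1};
}
```

For example, resizing a shape of {8, 480, 640, 3} to {8, 240, 320, 3} gives `first == 1` and `count == 2`, so only the two middle dimensions would be resized. A caller would presumably reject `count > 3`, given the limit of 3 spatial dimensions.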

Additional information:

Affected modules and functionalities:

  • New op (related to resize)

Key points relevant for the review:

Resize parameter logic

Tests:

  • Existing tests apply
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A
    test_tensor_resize.test_resize_upsample_scales_nearest
    test_tensor_resize.test_resize_downsample_scales_nearest
    test_tensor_resize.test_resize_upsample_sizes_nearest
    test_tensor_resize.test_resize_upsample_scales_linear
    test_tensor_resize.test_resize_downsample_scales_linear
    test_tensor_resize.test_resize_alignment
    test_tensor_resize.test_resize_upsample_scales_cubic
    test_tensor_resize.test_resize_downsample_scales_cubic
    test_tensor_resize.test_resize_upsample_sizes_cubic
    test_tensor_resize.test_resize_downsample_sizes_cubic

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: DALI-3148

@jantonguirao jantonguirao force-pushed the tensor_resize branch 3 times, most recently from 909a166 to 282f150 Compare December 1, 2022 14:32
@jantonguirao jantonguirao mentioned this pull request Dec 2, 2022
18 tasks
@jantonguirao jantonguirao changed the title [WIP] tensor resize Add experimental.tensor_resize operator Dec 21, 2022
@jantonguirao jantonguirao marked this pull request as ready for review December 21, 2022 15:19
Comment on lines 110 to 119
    } else if (alignment[d] == -1) {
      // keep start of the ROI
      center = params.src_lo[d];
    } else if (alignment[d] == 1) {
      // keep end of the ROI
      center = params.src_hi[d];
    } else {
      DALI_FAIL(make_string("Unsupported alignment value ", alignment[d],
                            ". Supported values are 0, -1, 1"));
    }
Contributor

@mzient mzient Jan 4, 2023

To consider:

Suggested change
-    } else if (alignment[d] == -1) {
-      // keep start of the ROI
-      center = params.src_lo[d];
-    } else if (alignment[d] == 1) {
-      // keep end of the ROI
-      center = params.src_hi[d];
-    } else {
-      DALI_FAIL(make_string("Unsupported alignment value ", alignment[d],
-                            ". Supported values are 0, -1, 1"));
-    }
+    } else if (alignment[d] < 0) {
+      // keep start of the ROI
+      center = params.src_lo[d];
+    } else if (alignment[d] > 0) {
+      // keep end of the ROI
+      center = params.src_hi[d];
+    }

Having a sign (as opposed to exact value) define the behavior may simplify the client code. The usage of sign is not uncommon (see strcmp or the "spaceship" operator <=>).
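
A minimal standalone sketch of the sign-based variant, for illustration. `AnchorFromSign` is a hypothetical free function, and the centered case for `alignment == 0` is an assumption based on the docstring, since that branch is not visible in the snippet under review.

```cpp
#include <cassert>

// Hypothetical helper: only the sign of `alignment` selects the behavior,
// so -1, -7 and INT_MIN all mean "keep the start of the ROI".
float AnchorFromSign(int alignment, float src_lo, float src_hi) {
  assert(src_lo <= src_hi);  // assumption: a non-flipped ROI
  if (alignment < 0) return src_lo;   // keep start of the ROI
  if (alignment > 0) return src_hi;   // keep end of the ROI
  return 0.5f * (src_lo + src_hi);    // assumed: 0 means centered
}
```

With this shape, client code can pass the result of any three-way comparison directly, and no error path is needed for "unsupported" values.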

Contributor

@mzient mzient Jan 4, 2023

A random thought: how about having an alignment as a float 0..1? 0.5 would denote the center.

center = (1 - alignment[d]) * params.src_lo[d] + alignment[d] * params.src_hi[d]

We could even skip the range check and have the invariant point anywhere - possibly outside the ROI.
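
The interpolation above can be tried out in isolation. The function name `AnchorFromAlignment` is made up for this sketch; the formula itself is the one quoted in the comment.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper: linear interpolation between the ROI bounds.
// 0.0 -> start of the ROI, 0.5 -> center, 1.0 -> end of the ROI.
// There is deliberately no range check on `alignment`: values outside
// [0, 1] place the invariant point outside the ROI.
double AnchorFromAlignment(double alignment, double roi_lo, double roi_hi) {
  assert(std::isfinite(alignment));  // NaN or infinity would silently propagate
  return (1.0 - alignment) * roi_lo + alignment * roi_hi;
}
```

For an ROI of [10, 50], alignments of 0, 0.5 and 1 give 10, 30 and 50, while 1.5 gives 70, a point 20 units past the end of the ROI.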

Comment on lines 43 to 45
Accepted values are -1 (align with top-left corner), 0 (centered), 1 (align with bottom-right corner).
By default, 0 (centered) is assumed. Contains as many elements as dimensions provided for sizes/scales.
If only one value is provided, it is applied to all dimensions.)code",
Contributor

If you go with the sign, adjust the comment accordingly.

Comment on lines 40 to 42
.AddOptionalArg<int>("alignment", R"code(Determines the position of the ROI
when using scales (provided or calculated).
Contributor

@mzient mzient Jan 4, 2023

I'd add some more explanation.


The real output size must be integer and may differ from "ideal" output size calculated as
input (or ROI) size multiplied by the scale factor. In that case, the output size is rounded
(according to `size_rounding` policy) and the input ROI needs to be adjusted to maintain
the scale factor. This parameter defines which point of the ROI should remain constant.
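
One possible reading of this adjustment, sketched for a single dimension. This is not DALI's implementation: `AdjustedRoi` and `AdjustRoi` are hypothetical names, the "round" policy is hard-coded, the clamp to a minimum size of 1 is an assumption, and the alignment is taken as a float where 0 is the ROI start and 1 is the ROI end.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical result: integer output size plus the adjusted input ROI.
struct AdjustedRoi {
  int64_t out_size;
  double lo;
  double hi;
};

// Rounds the ideal output size, then resizes the ROI around the anchor point
// so that out_size / (hi - lo) still equals the requested scale.
AdjustedRoi AdjustRoi(double roi_lo, double roi_hi, double scale, double alignment) {
  assert(scale > 0 && roi_hi > roi_lo);
  double ideal_size = (roi_hi - roi_lo) * scale;
  // Assumption: at least one output element is always produced.
  int64_t out_size = std::max<int64_t>(1, std::llround(ideal_size));
  double extent = out_size / scale;  // ROI extent that preserves the scale
  double anchor = (1 - alignment) * roi_lo + alignment * roi_hi;
  double lo = anchor - alignment * extent;  // the anchor stays where it was
  return {out_size, lo, lo + extent};
}
```

For an ROI of [0, 5] and a scale of 0.5, the ideal size 2.5 rounds to 3, so the ROI must grow to an extent of 6. Alignment 0.5 yields [-0.5, 5.5], alignment 0 yields [0, 6], and alignment 1 yields [-1, 5].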

* | ``"round"`` - Rounds the resulting size to the nearest integer value, with halfway cases rounded away from zero.
* | ``"truncate"`` - Discards the fractional part of the resulting size.
* | ``"ceil"`` - Rounds up the resulting size to the next integer value.)code",
"truncate")
Contributor

why not "round"? Don't we round in normal Resize?

Contributor Author

true, by default we round. I'll change it
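
To make the difference between the three `size_rounding` policies in this thread concrete, here is a small sketch using standard library rounding functions. `RoundSize` is a hypothetical helper, not the operator's code; it only mirrors the behavior described in the docstring.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical helper: converts an "ideal" fractional size to an integer
// according to the named policy.
int64_t RoundSize(double ideal_size, const std::string &policy) {
  assert(ideal_size >= 0);
  if (policy == "round")     // nearest integer, halfway cases away from zero
    return std::llround(ideal_size);
  if (policy == "truncate")  // discard the fractional part
    return static_cast<int64_t>(std::trunc(ideal_size));
  if (policy == "ceil")      // next integer value upwards
    return static_cast<int64_t>(std::ceil(ideal_size));
  throw std::invalid_argument("Unsupported size_rounding policy: " + policy);
}
```

An ideal size of 2.5 becomes 3 with "round", 2 with "truncate" and 3 with "ceil"; an ideal size of 2.4 becomes 2, 2 and 3 respectively.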

Comment on lines 55 to 56
Contains as many elements as dimensions provided for sizes/scales. If only one value is provided, it is
applied to all dimensions.)code",
Contributor

@mzient mzient Jan 4, 2023

Now, with multiple paragraphs above, it's hard to infer the subject.

Suggested change
- Contains as many elements as dimensions provided for sizes/scales. If only one value is provided, it is
- applied to all dimensions.)code",
+ The value of this argument contains as many elements as dimensions provided for
+ sizes/scales. If only one value is provided, it is applied to all dimensions.)code",
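
The "one value applies to all dimensions" rule in this docstring amounts to a simple broadcast. The sketch below shows one way to express it; `ExpandPerDimArg` is a hypothetical name and the error handling is an assumption.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical helper: expands a per-dimension argument to exactly `ndim`
// values. A single value is repeated for every dimension; otherwise the
// number of values must already match.
std::vector<float> ExpandPerDimArg(const std::vector<float> &values, int ndim) {
  assert(ndim > 0);
  if (values.size() == 1)
    return std::vector<float>(ndim, values[0]);
  if (static_cast<int>(values.size()) == ndim)
    return values;
  throw std::invalid_argument("Expected 1 or ndim values for a per-dimension argument");
}
```

Passing a single 0.5 with `ndim == 3` yields three copies of 0.5, whereas passing two values for three dimensions is rejected.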


This point is calculated as ``center = (1 - alignment) * roi_start + alignment * roi_end``.
Alignment 0.0 denotes alignment with the start of the ROI, 0.5 with the center of the region, and 1.0 with the end.
Note that when ROI is not specified, roi_start=0 and roi_end=size is assumed.
Contributor

Suggested change
- Note that when ROI is not specified, roi_start=0 and roi_end=size is assumed.
+ Note that when ROI is not specified, roi_start=0 and roi_end=input_size is assumed.

@jantonguirao jantonguirao mentioned this pull request Jan 5, 2023
18 tasks
Signed-off-by: Joaquin Anton <janton@nvidia.com>
Signed-off-by: Joaquin Anton <janton@nvidia.com>
@jantonguirao
Contributor Author

!build

@dali-automaton
Collaborator

CI MESSAGE: [7083752]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [7083752]: BUILD FAILED

first_spatial_dim_ = 0;
spatial_ndim_ = orig_spatial_ndim + add_leading_spatial_ndim_;
expand_dims(expanded_input_shape_, input_shape, first_spatial_dim_, spatial_ndim_);
(void) input_shape; // should use expanded_input_shape_ from now on
Contributor

What is this line for? Does the compiler issue a warning when a variable is used after a void cast or sth? If not, leave just the comment.

Contributor Author

@jantonguirao jantonguirao Jan 24, 2023

it was more for "documentation". Removed.

return true;
};

int can_trim_n = std::max(0, spatial_ndim_ - min_ndim); // at least 2 spatial dims should remain
Contributor

Suggested change
- int can_trim_n = std::max(0, spatial_ndim_ - min_ndim); // at least 2 spatial dims should remain
+ int can_trim_n = std::max(0, spatial_ndim_ - min_ndim); // at least min_ndim spatial dims should remain

};

// at least min_ndim should remain
int can_trim_n = std::max(0, spatial_ndim_ - min_ndim);
Contributor

Nitpick:
can_trim_n reads like a boolean (can I trim n dimensions?)

Suggested change
- int can_trim_n = std::max(0, spatial_ndim_ - min_ndim);
+ int ndim_to_trim = std::max(0, spatial_ndim_ - min_ndim);

Signed-off-by: Joaquin Anton <janton@nvidia.com>
@jantonguirao
Contributor Author

!build

@dali-automaton
Collaborator

CI MESSAGE: [7093218]: BUILD STARTED

@jantonguirao
Contributor Author

!build

@dali-automaton
Collaborator

CI MESSAGE: [7093767]: BUILD STARTED

Signed-off-by: Joaquin Anton <janton@nvidia.com>
@dali-automaton
Collaborator

CI MESSAGE: [7095259]: BUILD FAILED

@dali-automaton
Collaborator

CI MESSAGE: [7095769]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [7095769]: BUILD FAILED

Signed-off-by: Joaquin Anton <janton@nvidia.com>
@jantonguirao
Contributor Author

!build

@dali-automaton
Collaborator

CI MESSAGE: [7101791]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [7101791]: BUILD PASSED

@jantonguirao jantonguirao merged commit 8882a79 into NVIDIA:main Jan 25, 2023
aderylo pushed a commit to zpp-dali-2022/DALI that referenced this pull request Mar 17, 2023
    Introduces a layout-agnostic frontend for Resize
    New operator experimental.tensor_resize
    All dimensions can be resized (including channels), the implementation is able to detect unchanged dimensions at the beginning and the end, and treat those as non-spatial dimensions.
    Up to 3 spatial dimensions supported

Signed-off-by: Joaquin Anton <janton@nvidia.com>
@JanuszL JanuszL mentioned this pull request Sep 6, 2023