Add GPU filter kernel #4298
Conversation
!build
CI MESSAGE: [6043632]: BUILD STARTED
CI MESSAGE: [6043632]: BUILD FAILED
!build
CI MESSAGE: [6043975]: BUILD STARTED
CI MESSAGE: [6043975]: BUILD PASSED
return "wrap"; | ||
case BoundaryType::TRANSPARENT: | ||
return "transparent"; | ||
case BoundaryType::ISOLATED: |
This is just a flag that can be combined with other border modes.
Good point, it is no longer used by the kernel. The "valid" mode will be part of the operator; at the kernel level it can be achieved with the implicit ROI handling (output smaller than the input, plus an anchor).
    case BoundaryType::ISOLATED:
      RunKernelBorderRemap<filter::ROIOnlyValid>(ctx, std::move(launch_kernel));
      break;
ROI handling (including this one) can be achieved by specifying the output as smaller than the input and appropriate anchors. To use only the pixels that have full halos present, you'd go with output_size[d] = input_size[d] - kernel_size[d] + 1, anchored at (0, 0). Note that this mechanism allows you to use a totally arbitrary ROI.
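A minimal sketch of that arithmetic (the helper name valid_output_shape is illustrative, not part of the kernel's API):

```cpp
#include <vector>

// Sketch: the "valid" output shape, i.e. only positions where the whole filter
// window fits inside the input. With anchor (0, 0), out[d] = in[d] - k[d] + 1.
std::vector<int> valid_output_shape(const std::vector<int>& input_size,
                                    const std::vector<int>& kernel_size) {
  std::vector<int> out(input_size.size());
  for (size_t d = 0; d < input_size.size(); d++)
    out[d] = input_size[d] - kernel_size[d] + 1;  // assumes input >= kernel in every dim
  return out;
}
```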
Done.
 * @{
 */

template <bool degenerated_extents>
I don't see where degenerated_extents is used.
Oh, good catch. I first implemented the idx reflect by hand, then realized we already have that, but without this degenerated_extents optimization.
Reflecting the input with 101 goes into an infinite loop if the width is 1: you need to explicitly check that before the loop, which is a visible cost for smaller filters. I added this as a hint and changed that condition to if (degenerated_extents && width <= 1). Now I wonder if I should get rid of the optimization or add it to the existing utility. I am in favour of adding it to the utility.
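For context, a minimal sketch of such a guarded reflect-101 remap (illustrative, not the actual DALI utility; in DALI it would be marked DALI_HOST_DEV):

```cpp
// Sketch: reflect-101 index remapping with a compile-time hint that the extent
// may degenerate to 1. When the caller guarantees size > 1, the guard compiles out.
template <bool degenerated_extents>
int reflect_101(int idx, int size) {
  if (degenerated_extents && size <= 1)
    return 0;  // without this, the loop below would spin forever for size == 1
  while (idx < 0 || idx >= size) {
    if (idx < 0) idx = -idx;
    if (idx >= size) idx = 2 * size - 2 - idx;
  }
  return idx;
}
```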
After all the changes it seemed to no longer make a visible difference, so I dropped the idea.
template <typename Out, typename In, typename W, bool has_channel_dim, bool has_sequence_dim>
struct Filter2dGpu {
  /* It computes a corellation of the input and the filter.
/* It computes a corellation of the input and the filter.
/* It computes a correlation of the input and the filter.
Isn't it a convolution?
I think that, strictly speaking, the filter is "visited" in the reverse order relative to the signal in a convolution. Anyway, I saw similar remarks in some of the Python libs, so I thought it makes sense to put such a clarifying remark. On the other hand, maybe it only adds noise?
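For the record, the distinction under discussion, as a sketch (not code from the PR):

```cpp
#include <vector>

// 1D illustration: correlation visits the filter in the same order as the signal;
// convolution would visit it reversed (w[k - 1 - j] instead of w[j]).
std::vector<float> correlate1d(const std::vector<float>& in, const std::vector<float>& w) {
  int n = in.size(), k = w.size();
  std::vector<float> out(n - k + 1, 0.f);  // "valid" output, no border handling
  for (int i = 0; i < n - k + 1; i++)
    for (int j = 0; j < k; j++)
      out[i] += in[i + j] * w[j];
  return out;
}
```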
done
   */

  /**
   * @brief First loads the input roi neceesary to compute the output of shape ``block_width x
* @brief First loads the input roi neceesary to compute the output of shape ``block_width x
* @brief First loads the input roi necessary to compute the output of shape ``block_width x
// over an image is costly, so try to avoid it.
BOOL_SWITCH(
    has_degenerated_extents, HasDegeneratedExtents,
    (using Loader = filter::InLoaderBorderRemap<filter::Reflect101<HasDegeneratedExtents>, In>;
Seems the template parameter is ignored in the implementation
                                  int inner_stride,
                                  int total_stride) const {
    if (idx < 0) {
      int reflect_dim_idx = (idx + 1) / inner_stride - 1;
I'd add some simple explanation here for future readers, perhaps with some example:
// First, shift by 1 towards 0, so that indices belonging to the same pixel translate to the right pixel index when dividing by the stride - (-6, -5, -4, -3, -2, -1) + 1 -> (-5, -4, -3, -2, -1, 0)
// (-5, -4, -3, -2, -1, 0) / 3 -> (-1, -1, -1, 0, 0, 0)
// Now, shift away from 0 to revert: (-1, -1, -1, 0, 0, 0) - 1 -> (-2, -2, -2, -1, -1, -1)
// We don't want to reflect the order of the channels, so we revert the order
// (-2, -1, 0, -2, -1, 0) + 3 - 1 -> (0, 1, 2, 0, 1, 2)
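A standalone sketch of the arithmetic in that comment (the helper and struct names are illustrative; in the kernel the pixel index then goes through border_remap):

```cpp
// Sketch: split a negative flattened (pixel * channels + channel) index back into
// a pixel index (to be border-remapped) and a channel index (kept in original order).
struct RemappedIdx { int pixel, channel; };

inline RemappedIdx remap_negative(int idx, int num_channels) {
  // e.g. num_channels = 3, idx in (-6..-1):
  // (+1) -> (-5, -4, -3, -2, -1, 0); (/3) -> (-1, -1, -1, 0, 0, 0); (-1) -> (-2, -2, -2, -1, -1, -1)
  int pixel = (idx + 1) / num_channels - 1;
  // channels keep their original, non-reflected order -> (0, 1, 2, 0, 1, 2)
  int channel = (idx + 1) % num_channels + num_channels - 1;
  return {pixel, channel};
}
```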
Mostly complaining about naming and comments, so nothing major, I guess. I need to take a second look tomorrow; for now it looks good.
struct ShapeDesc {
  int64_t hwc;
  int wc, f, h, w, c;
  int filter_vol, r, s;
  int filter_top_anchor, filter_left_anchor;
  int in_workspace_width;
};
Could you provide some documentation for this?
int h = h_end - h_begin;
int wc = wc_end - wc_begin;
It is interesting that the ROI description doesn't carry the strides of the original image. I assume they are later looked up in the shape_desc?
      int inner_dim_idx = (idx + 1) % inner_stride + inner_stride - 1;
      return this->border_remap(reflect_dim_idx, reflect_dim_size) * inner_stride + inner_dim_idx;
    }
    if (idx >= total_stride) {
Just putting a comment to check how it works with the ROI: why do we only reflect when we cross the stride? Or maybe it's just a bit of weird naming.
    return border_remap_strided(idx, sample_shape.w, sample_shape.c, sample_shape.wc);
  }

  DALI_HOST_DEV DALI_FORCEINLINE In load(const In* __restrict__ in, int y, int x,
I could imagine load working in "absolute" coordinates and remapping them internally in this function, unless you want to make the access explicit.
It is slower unfortunately.
    return idx;
  }

  DALI_HOST_DEV DALI_FORCEINLINE In load(const In* __restrict__ in, int y, int x,
It looks like handling the borders inside the load would make sense here as well, as now we rely on the pass-through of the functions above and detect the case in load either way. We could do it in both cases.
        static_cast<int>(s),
        static_cast<int>(filter_top_anchor),
        static_cast<int>(filter_left_anchor),
        static_cast<int>(in_workspace_width)};
I would love designated initializers, but it's in C++20.
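For illustration, what that could look like once C++20 is available (a sketch; the struct is a trimmed stand-in for the ShapeDesc above and the helper is hypothetical):

```cpp
// Sketch: C++20 designated initializers make the field-by-field casts self-documenting.
struct ShapeDescLite {  // illustrative subset of ShapeDesc
  int r, s;
  int filter_top_anchor, filter_left_anchor;
  int in_workspace_width;
};

ShapeDescLite make_desc(int64_t r, int64_t s, int64_t top, int64_t left, int64_t ws_width) {
  return {
      .r = static_cast<int>(r),
      .s = static_cast<int>(s),
      .filter_top_anchor = static_cast<int>(top),
      .filter_left_anchor = static_cast<int>(left),
      .in_workspace_width = static_cast<int>(ws_width),
  };
}
```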
int w = out_shape[num_sequence_dim + 1];
int c = has_channel_dim ? out_shape[num_sequence_dim + 2] : 1;
max_height = std::max(max_height, h);
max_width = std::max(max_width, w * c);
I might be highly nitpicky, but it might be clearer if we didn't mix the width and w * c, and, for example, used max_row for the former?
for (int s = 0; s < sample_desc.shape.s; s++) {
  int x = threadIdx.x + s * sample_desc.shape.c;
  for (int r = 0; r < sample_desc.shape.r; r++) {
    auto filter_coef = __ldg(filter + r * sample_desc.shape.s + s);
Just a random thought, did you entertain loading the filter into shared memory?
I did try it. The differences were not big, with no clear winner: which variant was faster depended on the filter size.
for (int y = 0; y < SampleDescT::lanes; y++) {
  load_row(y);
}
for (int y = SampleDescT::lanes; y < SampleDescT::lanes + sample_desc.shape.r - 1; y++) {
  load_row(y);
}
How big do we expect r to be compared to lanes? Would it make sense (and be feasible) to do something similar to:
for (int y = 0; y < SampleDescT::lanes; y++) {
  load_row(y);
}
for (int y = SampleDescT::lanes; y < SampleDescT::lanes + sample_desc.shape.r - 1; y++) {
  load_row(y);
}
for (int i = 0; i < r / lanes; i++) {
#pragma unroll
  for (int y = i * lanes; y < (i + 1) * lanes; y++) {
    load_row(y);
  }
}
for (int y = r / lanes * lanes; y < sample_desc.shape.r - 1; y++) {
  load_row(y);
}
I think #pragma unroll lanes should do something along those lines, but I saw no speed-ups here.
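For reference, a sketch of the counted form of that hint (a literal count stands in for lanes here, since the pragma wants a compile-time constant; the kernel is purely illustrative):

```cpp
// Sketch: ask the compiler to unroll the row loop in fixed-size chunks.
__global__ void load_rows_example(float* __restrict__ rows, int num_rows) {
#pragma unroll 4  // 4 stands in for `lanes`
  for (int y = 0; y < num_rows; y++) {
    rows[y] = static_cast<float>(y);  // stands in for load_row(y)
  }
}
```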
!build
CI MESSAGE: [6583354]: BUILD STARTED
CI MESSAGE: [6583354]: BUILD PASSED
!build
CI MESSAGE: [6584316]: BUILD STARTED
CI MESSAGE: [6584316]: BUILD FAILED
CI MESSAGE: [6584316]: BUILD PASSED
template <int N, typename T>
vec<N, T> rev(const vec<N, T>& v) {
  vec<N, T> out;
  for (int d = 0; d < N; d++) {
    out[N - d - 1] = v[d];
  }
  return out;
}
This could go to vec.h. Alternatively you can simply use std::reverse(v.begin(), v.end());
(Checked it, it works)
As to the std::reverse, I'd like to get a copy - I use it only "for presentation purposes", i.e. in error messages.
This isn't even used. Please remove.
It will be in the operator, so let me move it to the utils. Maybe as reversed rather than rev.
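A minimal sketch of that copy-returning helper (the name reversed follows the suggestion above; whether it lives next to vec or in a generic util is open):

```cpp
#include <algorithm>

// Sketch: return a reversed copy, leaving the argument untouched - handy when the
// reversed order is only needed "for presentation purposes", e.g. error messages.
// Works for any container (vec<N, T> included, assuming it exposes begin()/end()).
template <typename Container>
Container reversed(Container v) {  // taken by value, so this is already a copy
  std::reverse(v.begin(), v.end());
  return v;
}
```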
template <int N, typename T, typename U>
void strides(vec<N, U>& out, U& total_stride, const vec<N, T>& v) {
  total_stride = 1;
  for (int d = 0; d < N; d++) {
    out[d] = total_stride;
    total_stride *= v[d];
  }
}

template <int N, typename T, typename U>
void strides(vec<N, U>& out, const vec<N, T>& v) {
  U total_strides;
  strides(out, total_strides, v);
}
We already have CalcStrides in dali/kernels/common/utils.h. I think those should work with vec.
It works with a shape and reverses the order of the extents, which is already reversed by shape2vec.
Then perhaps move those to some common utils? I suggest the header I mentioned above.
moved to geom_utils
// (-6, -5, -4, -3, -2, -1) + 1 -> (-5, -4, -3, -2, -1, 0)
// (-5, -4, -3, -2, -1, 0) / 3 -> (-1, -1, -1, 0, 0, 0)
// Then shift back away from 0 (-1, -1, -1, 0, 0, 0) - 1 -> (-2, -2, -2, -1, -1, -1)
// Finally, with (num_channels - 1) we get the positive channels indecies
// Finally, with (num_channels - 1) we get the positive channels indecies
// Finally, with (num_channels - 1) we get the positive channel indices
const auto& block_dim = block_setup.block_dim();
const auto& thread_idx = block_setup.thread_idx();
const auto& block_dim = block_setup.block_dim();
const auto& thread_idx = block_setup.thread_idx();
What's the reason behind using references here?
No particularly strong reason for that. It's just that I expected some of the methods to return by value and some to return a reference to the member, so const auto& seemed in place.
const auto& block_dim = block_setup.block_dim();
const auto& thread_idx = block_setup.thread_idx();
auto block_dim = block_setup.block_dim();
auto thread_idx = block_setup.thread_idx();
In device code, you should rather avoid references for such small objects.
done
}

/*
 * The ``lanes`` paramter impacts perf in number of ways:
* The ``lanes`` paramter impacts perf in number of ways:
* The ``lanes`` parameter impacts perf in a number of ways:
/*
 * The ``lanes`` paramter impacts perf in number of ways:
 * 1. It reduces overhead of the for loops arithmetic when iterating over the filter extents:
 * We do not know the filter extents in compile time so those loops cannot be unrolled.
* We do not know the filter extents in compile time so those loops cannot be unrolled.
* We do not know the filter extents at compile time so those loops cannot be unrolled.
baseline_out_ = baseline_output_.cpu();
out_view_ = output_.gpu();
if (!T::valid_only_mode) {
  kernel_gpu.Run(ctx_gpu, out_view_, in_view_, filters_view_, make_cspan(anchors_),
seems like you can put the kernel run outside of the if, as it's the same on both branches
Good catch, thanks
// Then shift back away from 0 (-1, -1, -1, 0, 0, 0) - 1 -> (-2, -2, -2, -1, -1, -1)
// Finally, with (num_channels - 1) we get the positive channels indecies
// (-2, -1, 0, -2, -1, 0) + 3 - 1 -> (0, 1, 2, 0, 1, 2)
int reflect_dim_idx = (idx + 1) / num_channels - 1;
A more general question: have we considered/measured using fast_div here?
I was planning on checking that, but it seems there is already big register pressure in this kernel, so my first thought was that it probably won't help. But I didn't try it; I'll give it a go.
fast_div is only for unsigned numbers. It should still be possible to use it, though not through div_mod.
DALI_HOST_DEV DALI_FORCEINLINE In load(const In* __restrict__ in, const ivec<axes>& coords,
                                       const ShapeDesc<axes>& sample_shape) const {
A general thought: I see you are passing arguments by reference in those functions. While optimizing other kernels, we found that it generally works better to pass those by value. It'd be worth trying to see if you can squeeze out better performance by doing that.
That's interesting, let me check that out. My intuition was that constness + reference should perform no worse than passing by value.
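For context, the two signatures being compared, as a sketch (illustrative struct and functions, not the PR's code; the comments reflect the usual reasoning for small PODs in device code):

```cpp
// Sketch: small-struct parameter passing in device code.
struct SmallDesc { int w, c; };  // illustrative small POD

// By const reference: the object must stay addressable so the reference can be
// formed, which may force it out of registers into local memory.
__device__ int flatten_ref(const SmallDesc& s, int y, int x, int ch) {
  return (y * s.w + x) * s.c + ch;
}

// By value: the members can live entirely in registers.
__device__ int flatten_val(SmallDesc s, int y, int x, int ch) {
  return (y * s.w + x) * s.c + ch;
}
```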
template <int axes_>
struct ShapeDesc {
  static constexpr int axes = axes_;
Alternatively, you could hold Surface<axes, T> instances; that would abstract away the multiplication by strides, etc.
 protected:
  template <typename T>
  vec<axes, T> ShapeAsVec(const TensorShape<ndim>& shape) {
static
vec<axes, T> ShapeAsVec(const TensorShape<ndim>& shape) {
static vec<axes, T> ShapeAsVec(const TensorShape<ndim>& shape) {
  }

  DALI_DEVICE DALI_FORCEINLINE void load_input_to_shm(const In* __restrict__ in,
                                                      const ivec2& anchored_start) const {
const ivec2& anchored_start) const {
ivec2 anchored_start) const {
const SampleDescT& sample_desc;
const Inloader& in_loader;
const BlockSetupT& block_setup;
With all those references it's likely to end up reaching to memory more than you'd wish.
  with_shm_conv(
      sample_desc, in_loader, block_setup, idx_workspace, in_workspace, [&](auto&& conv) {
        stride_grid(block_start, grid_size, out_extents, out_strides, out_frame_stride, conv);
      });
} else {
  with_direct_conv(sample_desc, in_loader, block_setup, [&](auto&& conv) {
    stride_grid(block_start, grid_size, out_extents, out_strides, out_frame_stride, conv);
  });
Based on our experience with optimizing stuff for H100, I would recommend splitting this kernel into two variants (and possibly launching both, if the decision is made per-sample). The use of SHM will prevent the shared memory from being used as L1 cache, which will cripple the direct variant. There should be two launches - the samples eligible for shm variant should go to shm kernel and the others should go to the direct kernel, with reduced SHM consumption.
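A rough sketch of that host-side split (RunShmKernel/RunDirectKernel and the partitioning are illustrative, not the PR's API):

```cpp
#include <cstddef>
#include <vector>

// Sketch: partition samples by whether their workspace fits in shared memory and
// launch a specialized kernel per group, so the direct variant keeps shm free for L1.
struct SampleInfo { int id; size_t shm_needed; };

void launch_filter(const std::vector<SampleInfo>& samples, size_t shm_budget) {
  std::vector<int> shm_ids, direct_ids;
  for (const auto& s : samples)
    (s.shm_needed <= shm_budget ? shm_ids : direct_ids).push_back(s.id);
  // if (!shm_ids.empty())    RunShmKernel(shm_ids);       // hypothetical launches
  // if (!direct_ids.empty()) RunDirectKernel(direct_ids);
}
```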
done
DALI_DEVICE DALI_FORCEINLINE void for_each_output_point_in_log_block(const ivec2& block_start,
                                                                     const ivec2& out_extents,
                                                                     const Conv& conv,
                                                                     const Cb&& cb) {
const Cb&& cb) {
Cb&& cb) {
or
const Cb&& cb) {
const Cb& cb) {
?
done
!build
CI MESSAGE: [6789928]: BUILD STARTED
CI MESSAGE: [6789928]: BUILD PASSED
// the strides for the in_extents
ivec<axes> in_strides;
// The offset in shared memory that should be left
// for precomuted indices. At that offset, the
// for precomuted indices. At that offset, the
// for precomputed indices. At that offset, the
done
 * @brief Computes the 2D or 3D (``axes_``) convolution of the input (``In``) and the
 * filter (``W``). The input must have the same number of spatial dims as the filter, but
 * can have extra sequence dim at the beginning (``has_sequence_dim``) and channels
 * dim at the end (``has_channel_dim``). If the ``enable_roi_`` is true, the outputs and inputs
Isn't it the opposite?
* dim at the end (``has_channel_dim``). If the ``enable_roi_`` is true, the outputs and inputs
* dim at the end (``has_channel_dim``). If the ``enable_roi_`` is false, the outputs and inputs
done
/**
 * @brief The innermost extent consists of the width and channel extents flattened.
 * Thus, handling border condition for innermost extent requires extra step of computing back
 * the scurrent channel and spatial position.
* the scurrent channel and spatial position.
* the current channel and spatial position.
done
    ivec<SampleDescT::axes> start) const {
  auto anchored_start = start - sample_desc_.in_shape.anchor_shift;
  __syncthreads();
  precompute_indices(in, anchored_start);
in seems unused?
Spot-on, I guess it is an artifact of splitting the precomputing and loading into two functions.
done
Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
!build
CI MESSAGE: [6800727]: BUILD STARTED
CI MESSAGE: [6800727]: BUILD PASSED
};

/**
 * @brief Computes the convolution of logical block size size (logical_block_extents) accessing
Nitpick - this is not a "brief" description.
I split those comments to have a brief intro, a blank line, and the rest to follow.
    return in[dot(coords, sample_shape.in_strides)];
  }

  In fill_value;
A scalar fill value is a bit underwhelming.
I agree it may be useful. I just thought it is not a big priority, as we already have operators that have scalar-only padding. As this PR grew considerably, I'd prefer making it a follow-up.
It should not be a big change on its own - if we encounter an OOB access, we can retrieve the channel index (as we do for remapping in border reflect etc.) in the InLoader and take an input value accordingly.
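A possible shape for that follow-up (a sketch; the per-channel fill_values pointer and the loader struct are assumptions, not the PR's interface):

```cpp
// Sketch: constant padding with per-channel fill values instead of a single scalar.
// On an out-of-bounds access, recover the channel from the flattened innermost index
// and return the matching fill value.
template <typename In>
struct ConstPadLoader {
  const In* fill_values;  // hypothetical: one value per channel
  int num_channels;

  In load_oob(int flat_idx) const {
    int channel = ((flat_idx % num_channels) + num_channels) % num_channels;  // non-negative mod
    return fill_values[channel];
  }
};
```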
int global_x = in_loader_.border_remap_innermost(anchored_start.x + x, sample_desc_.in_shape);
for (int y = thread_idx.y; y < sample_desc_.workspace_desc.in_extents.y; y += block_dim.y) {
  int global_y = precomputed_idx_[y];
  in_workspace_[dot(ivec2{x, y}, sample_desc_.workspace_desc.in_strides)] =
      in_loader_.load(in, ivec2{global_x, global_y}, sample_desc_.in_shape);
The fact that border_remap is separated from load, but load still does some border handling (padding), seems a bit off...
For smaller kernels, it seems that the way border handling is done makes a visible difference in the perf. Even without the precompute_idx thingy, it is faster when decoupled from loading than when it is done in the load method in the innermost loop (sadly, as I had hoped for the compiler to hoist what can be hoisted). And to actually handle each extent once (with the precompute helper), I don't see another way.
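Roughly the pattern being described, as a host-side sketch (clamp stands in for whatever border mode is in use; none of this is the PR's code):

```cpp
#include <algorithm>
#include <vector>

// Sketch: remap each input row index once per block and cache it, instead of
// re-evaluating the border condition for every element in the inner loop.
std::vector<int> precompute_rows(int start_y, int rows_needed, int in_height) {
  std::vector<int> remapped(rows_needed);
  for (int y = 0; y < rows_needed; y++)
    remapped[y] = std::clamp(start_y + y, 0, in_height - 1);  // once per row, not per element
  return remapped;
}
```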
include/dali/core/geom/geom_utils.h
template <int N, typename T, typename U>
void strides(vec<N, U>& out, U& total_stride, const vec<N, T>& v) {
  total_stride = 1;
  for (int d = 0; d < N; d++) {
    out[d] = total_stride;
    total_stride *= v[d];
  }
}

template <int N, typename T, typename U>
void strides(vec<N, U>& out, const vec<N, T>& v) {
  U total_strides;
  strides(out, total_strides, v);
}
How about:
template <int N, typename T, typename U>
void strides(vec<N, U>& out, U& total_stride, const vec<N, T>& v) {
  total_stride = 1;
  for (int d = 0; d < N; d++) {
    out[d] = total_stride;
    total_stride *= v[d];
  }
}
template <int N, typename T, typename U>
void strides(vec<N, U>& out, const vec<N, T>& v) {
  U total_strides;
  strides(out, total_strides, v);
}
template <int N, typename T, typename U>
U strides(vec<N, U>& out, const vec<N, T>& v) {
  U total_stride = 1;
  for (int d = 0; d < N; d++) {
    out[d] = total_stride;
    total_stride *= v[d];
  }
  return total_stride;
}
Also: I'm not convinced this is the correct file to put this utility in. This could land in the same place that we have CalcStrides - and, indeed, we could extend CalcStrides to support different traversal directions and to produce the total volume, e.g.:
template <bool outer_first = true, typename Strides, typename Shape>
auto CalcStrides(Strides &strides, const Shape& shape) {
using stride_t = std::remove_reference_t<decltype(strides[0])>;
int ndim = dali::size(shape);
resize_if_possible(strides, ndim); // no-op if strides is a plain array or std::array
stride_t ret = 1;
if (outer_first) {
for (int d = ndim - 1; d > 0; d--) {
strides[d] = ret;
ret *= shape[d];
}
} else {
for (int d = 0; d < ndim; d++) {
strides[d] = ret;
ret *= shape[d];
}
}
return ret;
}
Then you could use it like:
total_stride = CalcStrides<false>(strides, shape);
Sounds good, I'll move it.
CalcStrides uses int64_t for strides accumulation and it likely needs to stay that way, as strides[0], for example in transpose_gpu_impl, can be fast_div<uint64_t>, which is not happy about the *= operator. I left it that way:
template <bool outer_first = true, typename Strides, typename Shape>
DALI_HOST_DEV std::remove_reference_t<decltype(std::declval<Strides>()[0])> CalcStrides(
Strides &strides, const Shape &shape) {
int ndim = dali::size(shape);
resize_if_possible(strides, ndim); // no-op if strides is a plain array or std::array
int64_t ret = 1;
if (outer_first) {
for (int d = ndim - 1; d >= 0; d--) {
strides[d] = ret;
ret *= shape[d];
}
} else {
for (int d = 0; d < ndim; d++) {
strides[d] = ret;
ret *= shape[d];
}
}
return ret;
}
// limitations under the License.

#ifndef DALI_KERNELS_IMGPROC_CONVOLUTION_FILTER_GPU_CUH_
#define DALI_KERNELS_IMGPROC_CONVOLUTION_FILTER_GPU_CUH_
I'd recommend splitting this file into filter_gpu.h, containing the kernel class, parameters, etc., and filter_gpu_impl.cuh with the GPU kernels.
Depending on build times, we might even decide to instantiate just some input/output combinations in dali_kernels to avoid compiling this file twice - in kernel tests and operators.
Done
Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
!build
CI MESSAGE: [6826816]: BUILD STARTED
CI MESSAGE: [6826816]: BUILD FAILED
Oh, the recent numpy issue.
filter::filter<<<grid_setup.kernel_setup(), StaticConfigT::threadblock_size,
                 required_shm_size_, ctx.gpu.stream>>>(
    samples_desc_dev, block_setup_provider, in_loader_provider, grid_setup,
    out_shape_provider, conv_factory);
CUDA_CALL(cudaGetLastError());
Well, I guess it would be better to split the launch instead of redirecting all samples to the inefficient global-load use case. This can be left for a follow-up, though.
void Run(KernelContext& ctx, const TensorListView<StorageGPU, Out, ndim>& out,
         const TensorListView<StorageGPU, const In, ndim>& in,
         const TensorListView<StorageGPU, const W, axes>& filters,
         const span<const ivec<axes>> anchors, boundary::BoundaryType border_type,
const span<const ivec<axes>> anchors, boundary::BoundaryType border_type,
const span<const ivec<axes>> anchors,
boundary::BoundaryType border_type,
I honestly thought that boundary_type could be per-sample, seeing border_type at the end and span at the beginning of the line.
Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
!build
CI MESSAGE: [6831529]: BUILD STARTED
CI MESSAGE: [6831529]: BUILD PASSED
* Add GPU filter (convolution) kernel
* Support 2D and 3D
* Specialize to use shm for smaller kernels

Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
Adds GPU filter kernel (for 2d and 3d convolutions).
Signed-off-by: Kamil Tokarski ktokarski@nvidia.com
Category:
New feature (non-breaking change which adds functionality)
Description:
Adds a GPU filter kernel with tests against a straightforward CPU baseline.
Supports the following border modes: reflect 101, reflect 1001, wrap, clamp, constant.
Passing outputs smaller than the inputs and manipulating the anchors can be used to get ROI handling. It will be used by the operator to implement the "valid" mode, where the filter lies fully within the input for all products computed.
For reasonably sized filters and numbers of channels it provides an implementation that first transfers the input data to GPU shm; otherwise it uses a direct implementation.
Additional information:
(20 batches of random images from ImageNet, decoded once and repeated "synthetically" in the loop, batch size = 128)
For smaller kernels it outperforms our separable-conv ops.
The 2D separable convolution, run simply as two launches of the new convolution, outperforms the current impl for all window sizes supported by the current impl.
It supports 2D and 3D convolutions.
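To make the "two launches" remark concrete, a minimal host-side sketch of a separable 2D window applied as two 1D "valid" correlations (illustration only, not the DALI kernel):

```cpp
#include <vector>

using Image = std::vector<std::vector<float>>;

// Pass 1: correlate each row with the horizontal factor wx ("valid" extent).
Image correlate_rows(const Image& in, const std::vector<float>& wx) {
  Image out(in.size(), std::vector<float>(in[0].size() - wx.size() + 1, 0.f));
  for (size_t y = 0; y < out.size(); y++)
    for (size_t x = 0; x < out[y].size(); x++)
      for (size_t j = 0; j < wx.size(); j++)
        out[y][x] += in[y][x + j] * wx[j];
  return out;
}

// Pass 2: correlate each column of the intermediate with the vertical factor wy.
Image correlate_cols(const Image& in, const std::vector<float>& wy) {
  Image out(in.size() - wy.size() + 1, std::vector<float>(in[0].size(), 0.f));
  for (size_t y = 0; y < out.size(); y++)
    for (size_t x = 0; x < out[y].size(); x++)
      for (size_t j = 0; j < wy.size(); j++)
        out[y][x] += in[y + j][x] * wy[j];
  return out;
}

// Usage: Image smoothed = correlate_cols(correlate_rows(img, wx), wy);
```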
Affected modules and functionalities:
Key points relevant for the review:
Tests:
Checklist
Documentation
DALI team only
Requirements
REQ IDs: FILTER.01-FILTER.04, FILTER.06-FILTER.12
JIRA TASK: DALI-2949