@@ -748,6 +748,138 @@ kernel void ps_roi_align<DTYPE>( \
 REGISTER_PS_ROI_ALIGN_OP(float);
 REGISTER_PS_ROI_ALIGN_OP(half);
 
+template<typename T>
+kernel void ps_roi_align_backward(
+    constant T * grad_output [[buffer(0)]],
+    constant T * rois [[buffer(1)]],
+    constant int64_t * channel_mapping [[buffer(2)]],
+    device T * grad_input [[buffer(3)]],
+    constant int64_t & output_size [[buffer(4)]],
+    constant int64_t & channels [[buffer(5)]],
+    constant int64_t & height [[buffer(6)]],
+    constant int64_t & width [[buffer(7)]],
+    constant int64_t & pooled_height [[buffer(8)]],
+    constant int64_t & pooled_width [[buffer(9)]],
+    constant int64_t & sampling_ratio [[buffer(10)]],
+    constant int64_t & channels_out [[buffer(11)]],
+    constant float & spatial_scale [[buffer(12)]],
+    uint2 tgid [[threadgroup_position_in_grid]],
+    uint2 tptg [[threads_per_threadgroup]],
+    uint2 tid2 [[thread_position_in_threadgroup]]) {
+
+  MPS_1D_KERNEL_LOOP(index, output_size, 1) {
+    // (n, *, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int n = index / pooled_width / pooled_height / channels_out;
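+    // The output channel is not decoded here; channel_mapping (read below)
+    // records, for every output element, the input channel it was pooled from.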
+
+    constant T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+
+    // Do not use rounding; this implementation detail is critical
+    T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
+    T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
+    T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
+    T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
+
+    // Force too small ROIs to be 1x1
+    T roi_width = roi_end_w - roi_start_w;
+    T roi_height = roi_end_h - roi_start_h;
+    T bin_size_h = roi_height / static_cast<T>(pooled_height);
+    T bin_size_w = roi_width / static_cast<T>(pooled_width);
+
+    int c_in = channel_mapping[index];
+
+    // Do not use floor/ceil; this implementation detail is critical
+    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
+    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
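+    // (hstart, wstart) is the top-left corner of this pooling bin in
+    // input feature-map coordinates.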
+
+    const T grad_output_this_bin = grad_output[index];
+
+    // We use roi_bin_grid to sample the grid and mimic integral
+    int roi_bin_grid_h = (sampling_ratio > 0)
+        ? sampling_ratio
+        : ceil(roi_height / pooled_height); // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+    const T count = roi_bin_grid_h * roi_bin_grid_w;
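+    // The forward pass averages over count sampling points per bin, so each
+    // sampling point receives a 1/count share of the incoming gradient below.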
+
+    const int offset = (roi_batch_ind * channels + c_in) * height * width;
+
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+      const T y = hstart +
+          static_cast<T>(iy + .5f) * bin_size_h /
+              static_cast<T>(roi_bin_grid_h);
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const T x = wstart +
+            static_cast<T>(ix + .5f) * bin_size_w /
+                static_cast<T>(roi_bin_grid_w);
+
+        T w1, w2, w3, w4;
+        int x_low, x_high, y_low, y_high;
+
+        bilinear_interpolate_gradient(
+            height,
+            width,
+            y,
+            x,
+            w1,
+            w2,
+            w3,
+            w4,
+            x_low,
+            x_high,
+            y_low,
+            y_high,
+            index);
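+        // w1..w4 are the bilinear weights of the four surrounding input
+        // pixels (y_low, x_low) .. (y_high, x_high) for the sample at (y, x).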
+
+        T g1 = grad_output_this_bin * w1 / count;
+        T g2 = grad_output_this_bin * w2 / count;
+        T g3 = grad_output_this_bin * w3 / count;
+        T g4 = grad_output_this_bin * w4 / count;
+
+        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+          device atomic_uint* xAtomic = (device atomic_uint*)(grad_input + offset + y_low * width + x_low);
+          device atomic_uint* yAtomic = (device atomic_uint*)(grad_input + offset + y_low * width + x_high);
+          device atomic_uint* zAtomic = (device atomic_uint*)(grad_input + offset + y_high * width + x_low);
+          device atomic_uint* wAtomic = (device atomic_uint*)(grad_input + offset + y_high * width + x_high);
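+          // Sampling points from different bins/ROIs may touch the same
+          // grad_input element, so the accumulation below must be atomic.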
+
+          // The atomic_float data type is supported from Metal 3 onward.
+          // TODO: Use native atomic_fetch_add_explicit for Metal 3.
+          atomic_add_float(xAtomic, static_cast<T>(g1));
+          atomic_add_float(yAtomic, static_cast<T>(g2));
+          atomic_add_float(zAtomic, static_cast<T>(g3));
+          atomic_add_float(wAtomic, static_cast<T>(g4));
+        } // if
+      } // ix
+    } // iy
+  }
+}
+
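+// Explicit instantiations: host_name gives each dtype specialization a
+// stable name that the host code can look up in the compiled Metal library.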
+#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE) \
+template \
+[[host_name("ps_roi_align_backward_" #DTYPE)]] \
+kernel void ps_roi_align_backward<DTYPE>( \
+    constant DTYPE * grad_output [[buffer(0)]], \
+    constant DTYPE * rois [[buffer(1)]], \
+    constant int64_t * channel_mapping [[buffer(2)]], \
+    device DTYPE * grad_input [[buffer(3)]], \
+    constant int64_t & output_size [[buffer(4)]], \
+    constant int64_t & channels [[buffer(5)]], \
+    constant int64_t & height [[buffer(6)]], \
+    constant int64_t & width [[buffer(7)]], \
+    constant int64_t & pooled_height [[buffer(8)]], \
+    constant int64_t & pooled_width [[buffer(9)]], \
+    constant int64_t & sampling_ratio [[buffer(10)]], \
+    constant int64_t & channels_out [[buffer(11)]], \
+    constant float & spatial_scale [[buffer(12)]], \
+    uint2 tgid [[threadgroup_position_in_grid]], \
+    uint2 tptg [[threads_per_threadgroup]], \
+    uint2 tid2 [[thread_position_in_threadgroup]]);
+
+REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float);
+REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half);
+
 )VISION_METAL";
 
 static id<MTLLibrary> compileBinaryOpsLibrary(id<MTLDevice> device) {