diff --git a/aten/src/ATen/native/mps/operations/UpSample.mm b/aten/src/ATen/native/mps/operations/UpSample.mm index f346536318764..14bf161e7b62e 100644 --- a/aten/src/ATen/native/mps/operations/UpSample.mm +++ b/aten/src/ATen/native/mps/operations/UpSample.mm @@ -75,7 +75,7 @@ void upsample_out_template(const Tensor& input, const bool is_macOS_13_0_or_newer = is_macos_13_or_newer(); const int64_t output_width = output_size.size() > 1 ? output_size[1] : output_size[0]; - const int64_t output_height = output_size.size() > 1 ? output_size[0] : 1; + const int64_t output_height = output_size.size() > 1 ? output_size[0] : (output.dim() > 2 ? output.size(-2) : 1); const float scale_w = (scale_w_opt.value_or(0.) > 0.) ? static_cast(scale_w_opt.value()) : 0.; const float scale_h = (scale_h_opt.value_or(0.) > 0.) ? static_cast(scale_h_opt.value()) : 1.; const float offset_y = centerResults ? (scale_h - 1.0f) / 2.0f : 0.0f;