diff --git a/backends/cadence/fusion_g3/operators/op_mean.cpp b/backends/cadence/fusion_g3/operators/op_mean.cpp
index ae0cfd1e27b..48f691a145a 100644
--- a/backends/cadence/fusion_g3/operators/op_mean.cpp
+++ b/backends/cadence/fusion_g3/operators/op_mean.cpp
@@ -118,7 +118,7 @@ Tensor& mean_out(
     for (int i = 0; i < kNnlibMaxDim; i++) {
       out_shape[i] = 1;
       inp_shape[i] = 1;
-      p_axis[i] = 1;
+      p_axis[i] = -1;
     }
 
     int num_axis_dims = prepare_data(
@@ -135,20 +135,10 @@ Tensor& mean_out(
       num_out_dims = 1;
     }
 
-    int inp_shape_max = inp_shape[p_axis[0]];
-    for (int i = 1; i < num_axis_dims; i++) {
-      if (inp_shape[p_axis[i]] > inp_shape_max) {
-        inp_shape_max = inp_shape[p_axis[i]];
-      }
+    if ((out.dim() == 0) && (out.numel())) {
+      num_out_dims = 1;
     }
 
-    int scratch_size = in.numel() / inp_shape_max;
-
-    executorch::runtime::Result<void*> temp_mem =
-        ctx.allocate_temp(scratch_size * sizeof(float));
-
-    void* __restrict__ p_scratch_in = (void* __restrict__)(temp_mem.get());
-
     XT_KERNEL_CHECK(
         ctx,
         out,
@@ -160,8 +150,7 @@ Tensor& mean_out(
         inp_shape,
         num_inp_dims,
         p_axis,
-        num_axis_dims,
-        p_scratch_in);
+        num_axis_dims);
   } else {
     ET_KERNEL_CHECK(
         ctx,