Commit

add group_norm_GB
Yin Hongyun committed Nov 27, 2024
1 parent f09e78e commit a6dbbb6
Showing 5 changed files with 168 additions and 4 deletions.
29 changes: 29 additions & 0 deletions diopi_test/python/configs/diopi_configs.py
@@ -7202,6 +7202,35 @@
            ]
        ),
    ),

    'group_norm_GB': dict(
        name=['group_norm_GB'],
        interface=['CustomizedTest'],
        atol=1e-4,
        rtol=1e-5,
        para=dict(
            num_groups=[32, 4, 5, 1],
            eps=[1e-05, 1e-05, 1e-05, 1e-05],
            reduced_axes=[[2, 3], [1, 3], [0, 3], [2, 3]],
            channel_axis=[1, 2, 1, 0],
        ),
        tensor_para=dict(
            args=[
                {
                    "ins": ["input"],
                    "shape": ((2, 256, 7, 10), (2, 256, 12, 12),
                              (12, 15, 8, 9), (3, 6, 9, 0)),
                    "dtype": [np.float32, np.float64, np.float16],
                },
                {
                    "ins": ["weight", "bias"],
                    "shape": ((256,), (12,),
                              (15,), (3,)),
                    "dtype": [np.float32, np.float64, np.float16],
                },
            ]
        ),
    ),

    'unique': dict(
        name=['unique'],
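The four group_norm_GB cases above exercise, in order: the standard NCHW layout, a non-default channel axis, reduction over the batch axis, and a zero-element tensor. A sketch of how one case's shapes relate (illustrative only, not part of the config):

# Derived sizes for the second group_norm_GB case:
# input (2, 256, 12, 12), reduced_axes=[1, 3], channel_axis=2, num_groups=4.
shape, reduced_axes, channel_axis, num_groups = (2, 256, 12, 12), [1, 3], 2, 4
C = shape[channel_axis]           # 12 -> matches the (12,) weight/bias shape
N = 1
for i, s in enumerate(shape):
    if i != channel_axis and i not in reduced_axes:
        N *= s                    # N = 2 (only axis 0 survives)
HxW = 1
for i in reduced_axes:
    HxW *= shape[i]               # HxW = 256 * 12 = 3072
assert C % num_groups == 0        # each group spans C // num_groups = 3 channels
print(N, C, HxW)                  # 2 12 3072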
40 changes: 39 additions & 1 deletion diopi_test/python/conformance/customized_test.py
@@ -908,4 +908,42 @@ def batch_norm_GB(input, running_mean, running_var, weight, bias, training=False
        eps=eps,
    )
    out = out.permute(dims)
    return out


def group_norm_GB(input, num_groups, weight=None, bias=None, eps=1e-05, reduced_axes=[2, 3], channel_axis=1):
    input_dims = list(input.size())
    reduced_axes_set = set(reduced_axes)

    # Axes that are neither the channel axis nor reduced over act as the
    # batch dimension.
    non_reduced_dims = []
    for i in range(len(input_dims)):
        if i != channel_axis and i not in reduced_axes_set:
            non_reduced_dims.append(i)

    N = 1
    for i in non_reduced_dims:
        N *= input.size(i)
    HxW = 1
    for i in reduced_axes:
        HxW *= input.size(i)
    C = input.size(channel_axis)

    # Permute to (non-reduced..., channel, reduced...) and flatten into the
    # canonical (N, C, HxW, 1) layout before calling the stock group_norm.
    dims = non_reduced_dims + [channel_axis] + reduced_axes
    permuted_input = input.permute(dims)
    reshaped_input = permuted_input.reshape([N, C, HxW, 1]).contiguous()
    out = torch.nn.functional.group_norm(
        reshaped_input,
        num_groups,
        weight=weight,
        bias=bias,
        eps=eps,
    )

    # Invert the permutation: axis dims[i] was moved to position i.
    reversed_order = [0] * len(dims)
    for i, d in enumerate(dims):
        reversed_order[d] = i
    return out.reshape(permuted_input.shape).permute(reversed_order)
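
A quick sanity check of the reference implementation (a sketch, not part of the committed tests; it assumes group_norm_GB above is in scope): with the default channel_axis=1 and reduced_axes=[2, 3] the permutation is the identity, so the result should match torch.nn.functional.group_norm directly.

import torch
import torch.nn.functional as F

x = torch.randn(2, 256, 7, 10)
w, b = torch.randn(256), torch.randn(256)
ref = F.group_norm(x, 32, weight=w, bias=b, eps=1e-5)
got = group_norm_GB(x, 32, weight=w, bias=b, eps=1e-5, reduced_axes=[2, 3], channel_axis=1)
assert torch.allclose(ref, got, atol=1e-4, rtol=1e-5)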

36 changes: 34 additions & 2 deletions diopi_test/python/conformance/diopi_functions.py
@@ -2864,8 +2864,8 @@ def batch_norm_GB(
    )

    check_returncode(ret)
-    GLOBAL_STATE["batch_norm_save_mean"] = save_mean
-    GLOBAL_STATE["batch_norm_save_invstd"] = save_invstd
+    GLOBAL_STATE["batch_norm_GB_save_mean"] = save_mean
+    GLOBAL_STATE["batch_norm_GB_save_invstd"] = save_invstd
    return out


@@ -5242,6 +5242,38 @@ def norm_backward(grad_outputs, input, p, dim, keepdim=False, dtype=None):

    return {k: v for k, v in out.items() if v.requires_grad}


def group_norm_GB(input, num_groups, weight=None, bias=None, eps=1e-05, reduced_axes=[2, 3], channel_axis=1):
    dim = list(input.size().data)
    # N is the product of the axes that are neither reduced nor the channel axis.
    N = 1
    for i in range(len(dim)):
        if i not in reduced_axes and i != channel_axis:
            N *= dim[i]
    # The statistics are computed once per (sample, group) pair.
    save_mean = Tensor((N, num_groups), input.get_dtype())
    save_invstd = raw_like(save_mean)

    reduced_axes = Sizes(reduced_axes)
    out = raw_like(input)
    func = check_function("diopiGroupNormGB")
    ret = func(
        input.context(),
        out,
        save_mean,
        save_invstd,
        input,
        weight,
        bias,
        num_groups,
        eps,
        reduced_axes,
        channel_axis,
    )
    check_returncode(ret)
    GLOBAL_STATE["group_norm_GB_save_mean"] = save_mean
    GLOBAL_STATE["group_norm_GB_save_invstd"] = save_invstd
    return out
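
The (N, num_groups) allocation for save_mean and save_invstd mirrors what PyTorch's native kernel produces for the reshaped input; a sketch (assuming torch exposes native_group_norm at the top level, as recent versions do, using the sizes of the third config case):

import torch

N, C, HxW, G = 8, 15, 108, 5  # third case: (12, 15, 8, 9), reduced_axes=[0, 3], 5 groups
x = torch.randn(N, C, HxW, 1)
out, mean, rstd = torch.native_group_norm(x, None, None, N, C, HxW, G, 1e-5)
print(mean.shape, rstd.shape)  # torch.Size([8, 5]) torch.Size([8, 5])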

def group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
    dim = list(input.size().data)
Expand Down
60 changes: 59 additions & 1 deletion impl/torch/functions/functions.cpp
@@ -4183,8 +4183,66 @@ diopiError_t diopiForeachnormScalar(diopiContextHandle_t ctx, diopiTensorHandle_
    return diopiSuccess;
}

diopiError_t diopiGroupNormGB(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
                              diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
                              double eps, diopiSize_t reduced_axes, const int64_t channel_axis) {
    impl::aten::setCurStream(ctx);
    auto atInput = impl::aten::buildATen(input);

    // Gather the non-reduced, non-channel axes first; together they act as the
    // batch dimension N of the canonical layout.
    std::vector<int64_t> dims;
    int64_t N = 1;
    for (int64_t i = 0; i < atInput.dim(); i++) {
        if (i == channel_axis) {
            continue;
        }
        bool isReducedAxis = false;
        for (int64_t m = 0; m < reduced_axes.len; m++) {
            if (i == reduced_axes.data[m]) {
                isReducedAxis = true;
                break;
            }
        }
        if (!isReducedAxis) {
            dims.push_back(i);
            N *= atInput.size(i);
        }
    }
    dims.push_back(channel_axis);
    int64_t HxW = 1;
    for (int64_t i = 0; i < reduced_axes.len; i++) {
        dims.push_back(reduced_axes.data[i]);
        HxW *= atInput.size(reduced_axes.data[i]);
    }
    auto C = atInput.size(channel_axis);

    // Permute to (non-reduced..., channel, reduced...) and flatten into the
    // (N, C, HxW, 1) layout expected by native_group_norm.
    auto permutedInput = atInput.permute(dims);
    auto permutedShape = permutedInput.sizes();
    auto reshapedInput = permutedInput.reshape({N, C, HxW, 1}).contiguous();

    auto atWeight = impl::aten::buildATen(weight);
    auto atBias = impl::aten::buildATen(bias);
    auto atOut = impl::aten::buildATen(out);
    auto atSaveMean = impl::aten::buildATen(save_mean);
    auto atSaveInvstd = impl::aten::buildATen(save_invstd);

    // Invert the permutation: axis dims[i] was moved to position i.
    std::vector<int64_t> reverseOrder(dims.size());
    for (size_t i = 0; i < dims.size(); i++) {
        reverseOrder[dims[i]] = static_cast<int64_t>(i);
    }
    auto tempOut = CALL_ATEN_CUDA_FUNC(native_group_norm, reshapedInput, atWeight, atBias, N, C, HxW, num_groups, eps);
    at::native::copy_(atOut, std::get<0>(tempOut).reshape(permutedShape).permute(reverseOrder), true);
    at::native::copy_(atSaveMean, std::get<1>(tempOut), true);
    at::native::copy_(atSaveInvstd, std::get<2>(tempOut), true);
    return diopiSuccess;
}
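
Both the Python and C++ paths rely on the same inverse-permutation identity to restore the original layout; a standalone sketch of it (shapes here are arbitrary):

import torch

x = torch.randn(3, 5, 7, 2)
dims = [2, 0, 3, 1]                # an arbitrary permutation
inverse = [0] * len(dims)
for i, d in enumerate(dims):
    inverse[d] = i                 # permute(dims) moved source axis d to position i
assert torch.equal(x.permute(dims).permute(inverse), x)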

diopiError_t diopiGroupNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
-                           diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups, double eps) {
+                           diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
+                           double eps) {
    impl::aten::setCurStream(ctx);
    auto atInput = impl::aten::buildATen(input);
    auto atWeight = impl::aten::buildATen(weight);
7 changes: 7 additions & 0 deletions proto/include/diopi/functions.h
@@ -3600,6 +3600,13 @@ DIOPI_API diopiError_t diopiGroupNorm(diopiContextHandle_t ctx, diopiTensorHandl
                                      diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
                                      double eps);

/**
 * @brief Applies Group Normalization over a mini-batch of inputs, generalized to an arbitrary channel axis and an
 * arbitrary set of reduced axes instead of the fixed NCHW layout.
 * @param[in] ctx Context environment.
 * @param[out] out the output tensor.
 * @param[out] save_mean the mean computed per (sample, group) pair.
 * @param[out] save_invstd the inverse standard deviation computed per (sample, group) pair.
 * @param[in] input the input tensor.
 * @param[in] weight the learnable per-channel weight. Optional.
 * @param[in] bias the learnable per-channel bias. Optional.
 * @param[in] num_groups number of groups to separate the channels into.
 * @param[in] eps value added to the denominator for numerical stability.
 * @param[in] reduced_axes the input axes that are reduced over during normalization.
 * @param[in] channel_axis the input axis holding the channels.
 */
DIOPI_API diopiError_t diopiGroupNormGB(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
                                        diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
                                        double eps, diopiSize_t reduced_axes, const int64_t channel_axis);

/**
* @brief Compute the backward pass of diopiGroupNorm().
* @param[in] ctx Context environment.
