Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix affine_grid and add global unittest #7578

Merged
merged 22 commits into from
Feb 28, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
46c1b49
fix affine_grid and add global unittest
hjchen2 Feb 23, 2022
6be297c
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
hjchen2 Feb 23, 2022
02c619c
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
hjchen2 Feb 24, 2022
882ac34
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 24, 2022
efc2617
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 24, 2022
97f2363
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 24, 2022
ae3b91f
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 24, 2022
ad355cb
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 24, 2022
2b5be77
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 24, 2022
8499491
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 24, 2022
77a6bd9
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 24, 2022
4435fa7
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 24, 2022
8542c8a
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 25, 2022
3110781
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 25, 2022
3e40fd4
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 25, 2022
833c3a2
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 25, 2022
1e20663
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 25, 2022
6d4ed33
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 25, 2022
36b465d
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 25, 2022
b3d921d
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 25, 2022
f159ee8
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
hjchen2 Feb 28, 2022
7e84b85
Merge branch 'master' into dev_fix_affine_grid_and_add_global_unittest
oneflow-ci-bot Feb 28, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions oneflow/user/kernels/affine_grid_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ class AffineGridKernel final : public user_op::OpKernel {
bool is_2d_grid = true;
if (size.NumAxes() == 5) { is_2d_grid = false; }

int64_t N = theta->shape().At(0);
int64_t theta_h = theta->shape().At(1);
int64_t theta_w = theta->shape().At(2);

if (is_2d_grid) {
int64_t N = size.At(0);
int64_t H = size.At(2);
int64_t W = size.At(3);
// generate base grid
Expand All @@ -56,7 +56,6 @@ class AffineGridKernel final : public user_op::OpKernel {
grid->mut_dptr<data_type>() + n * theta_h * H * W);
}
} else {
int64_t N = size.At(0);
int64_t D = size.At(2);
int64_t H = size.At(3);
int64_t W = size.At(4);
Expand Down Expand Up @@ -109,11 +108,11 @@ class AffineGridGradKernel final : public user_op::OpKernel {
bool is_2d_grid = true;
if (size.NumAxes() == 5) { is_2d_grid = false; }

int64_t N = dtheta->shape().At(0);
int64_t dtheta_h = dtheta->shape().At(1);
int64_t dtheta_w = dtheta->shape().At(2);

if (is_2d_grid) {
int64_t N = size.At(0);
int64_t H = size.At(2);
int64_t W = size.At(3);
// generate base grid
Expand All @@ -127,7 +126,6 @@ class AffineGridGradKernel final : public user_op::OpKernel {
dtheta->mut_dptr<data_type>() + n * dtheta_h * dtheta_w);
}
} else {
int64_t N = size.At(0);
int64_t D = size.At(2);
int64_t H = size.At(3);
int64_t W = size.At(4);
Expand Down
53 changes: 49 additions & 4 deletions oneflow/user/ops/affine_grid_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"
#include "oneflow/core/operator/operator.h"

namespace oneflow {

Expand Down Expand Up @@ -83,7 +84,51 @@ Maybe<void> CheckAttr_(const user_op::UserOpDefWrapper& def,
}

/*static*/ Maybe<void> AffineGridOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
return InferLogicalTensorDesc(ctx);
const user_op::TensorDesc& theta = ctx->InputTensorDesc("theta", 0);
user_op::TensorDesc* grid = ctx->OutputTensorDesc("grid", 0);
const Shape& size = ctx->Attr<Shape>("size");
// Only support 2D or 3D affine grid with NCHW layout
// For 2D grid: theta = { N, 2, 3 },
// size = { N, C, H, W }
// grid = { N, H, W, 2 }
// For 3D grid: theta = { N, 3, 4 },
// size = { N, C, D, H, W }
// grid = { N, D, H, W, 3 }
const Shape& theta_shape = theta.shape();
bool is_2d_grid = true;
if (theta_shape.At(1) == 2) {
CHECK_EQ_OR_RETURN(theta_shape.At(2), 3) << "Theta shape MUST be (N, 2, 3) or (N, 3, 4)";
CHECK_EQ_OR_RETURN(size.NumAxes(), 4) << "Dimension of size MUST be 4, when 2d affine grid";
is_2d_grid = true;
} else if (theta_shape.At(1) == 3) {
CHECK_EQ_OR_RETURN(theta_shape.At(2), 4) << "Theta shape MUST be (N, 2, 3) or (N, 3, 4)";
CHECK_EQ_OR_RETURN(size.NumAxes(), 5) "Dimension of size MUST be 4, when 3d affine grid";
is_2d_grid = false;
} else {
CHECK_OR_RETURN(false) << "Theta MUST be 2D or 3D grid";
}

int64_t N = size.At(0);
const int64_t& parallel_num = ctx->parallel_ctx().parallel_num();
if (parallel_num > 1) {
const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("theta", 0);
Shape logical_shape = theta_shape;
logical_shape.Set(0, size.At(0));
const auto& physical_shape =
JUST(GetPhysicalShape(logical_shape, nd_sbp, ctx->parallel_desc(), ctx->parallel_ctx()));
N = physical_shape->At(0);
}
CHECK_EQ_OR_RETURN(theta_shape.At(0), size.At(0))
<< "The dimension 0 size of theta shape should be " << N << ", but got " << theta_shape.At(0);

*grid->mut_is_dynamic() = theta.is_dynamic();
Shape& grid_shape = *grid->mut_shape();
if (is_2d_grid) {
grid_shape = {N, size.At(2), size.At(3), 2};
} else {
grid_shape = {N, size.At(2), size.At(3), size.At(4), 3};
}
return Maybe<void>::Ok();
}

/* static */ Maybe<void> AffineGridOp::GetSbp(user_op::SbpContext* ctx) {
Expand All @@ -105,12 +150,12 @@ Maybe<void> CheckAttr_(const user_op::UserOpDefWrapper& def,
}

/* static */ Maybe<void> AffineGridGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
const user_op::TensorDesc& dgrid = ctx->InputTensorDesc("dgrid", 0);
const Shape& size = ctx->Attr<Shape>("size");

if (size.NumAxes() == 4) {
*(ctx->OutputTensorDesc("dtheta", 0)->mut_shape()) = {size.At(0), 2, 3};
*(ctx->OutputTensorDesc("dtheta", 0)->mut_shape()) = {dgrid.shape().At(0), 2, 3};
} else if (size.NumAxes() == 5) {
*(ctx->OutputTensorDesc("dtheta", 0)->mut_shape()) = {size.At(0), 3, 4};
*(ctx->OutputTensorDesc("dtheta", 0)->mut_shape()) = {dgrid.shape().At(0), 3, 4};
} else {
CHECK_OR_RETURN(false) << "size MUST be 4D or 5D";
}
Expand Down
68 changes: 68 additions & 0 deletions python/oneflow/test/modules/test_consistent_affine_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest
import numpy as np

import oneflow as flow
import oneflow.unittest
from oneflow.test_utils.automated_test_util import *


@autotest(n=1, rtol=1e-03, atol=1e-04, check_graph=False)
def _test_affine_grid_2d_with_random_data(test_case, placement, sbp):
N = random(1, 3).to(int).value() * 8
C = random(1, 8).to(int).value()
H = random(1, 8).to(int).value()
W = random(1, 8).to(int).value()
align_corners = oneof(True, False).value()
dims = [N, 2, 3]

theta = random_tensor(3, *dims).to_global(placement=placement, sbp=sbp)
output = torch.nn.functional.affine_grid(
theta, (N, C, H, W), align_corners=align_corners
)
return output


@autotest(n=1, rtol=1e-03, atol=1e-04, check_graph=False)
def _test_affine_grid_3d_with_random_data(test_case, placement, sbp):
N = random(1, 3).to(int) * 8
C = random(1, 8).to(int)
D = random(1, 8).to(int)
H = random(1, 8).to(int)
W = random(1, 8).to(int)
align_corners = oneof(True, False)
dims = [N, 3, 4]

theta = random_tensor(3, *dims).to_global(placement=placement, sbp=sbp)
output = torch.nn.functional.affine_grid(
theta, (N, C, D, H, W), align_corners=align_corners
)
return output


class TestAffineGrid(flow.unittest.TestCase):
@globaltest
def test_affine_grid(test_case):
for placement in all_placement():
for sbp in all_sbp(placement, max_dim=1):
_test_affine_grid_2d_with_random_data(test_case, placement, sbp)
_test_affine_grid_3d_with_random_data(test_case, placement, sbp)


if __name__ == "__main__":
unittest.main()