Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

deconv support auto squeeze and fix autotest bug #9740

Merged
merged 12 commits into from
Jan 16, 2023
27 changes: 23 additions & 4 deletions oneflow/core/functional/impl/nn_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class ConvBaseFunctor {
std::shared_ptr<one::Tensor> squeezed_conv_output = conv_out;
if (!is_batched) {
squeezed_conv_output = JUST(functional::Squeeze(conv_out, std::vector<int32_t>{0}));
channel_idx -= 1;
}
if (bias) {
return functional::BiasAdd(squeezed_conv_output, JUST(bias), channel_idx);
Expand Down Expand Up @@ -171,6 +172,18 @@ class DeConvBaseFunctor {
const std::vector<int32_t>& output_padding, const int32_t& groups,
const std::vector<int32_t>& dilation,
const std::string& data_format) const {
std::shared_ptr<one::Tensor> unsqueezed_input;
bool is_batched = true;
std::string func_name;
if (num_spatial_dims_ == 1) {
func_name = "deconv1d";
} else if (num_spatial_dims_ == 2) {
func_name = "deconv2d";
} else {
func_name = "deconv3d";
}
std::tie(unsqueezed_input, is_batched) = *JUST(batchify(input, num_spatial_dims_, func_name));
int32_t channel_idx = 1;
std::vector<int32_t> kernel_size_vec(num_spatial_dims_);
int32_t kernel_idx_offset = 2;
if (data_format == "channels_last") { kernel_idx_offset = 1; }
Expand All @@ -184,13 +197,19 @@ class DeConvBaseFunctor {
deconv_attrs.SetAllAttrs(static_cast<int32_t>(weight->shape()->At(1) * groups), kernel_size_vec,
padding, output_padding, stride, dilation, groups, data_format);
std::shared_ptr<one::Tensor> deconv_out =
JUST(OpInterpUtil::Dispatch<Tensor>(*deconv_op_, {input, weight}, deconv_attrs));
JUST(OpInterpUtil::Dispatch<Tensor>(*deconv_op_, {unsqueezed_input, weight}, deconv_attrs));
std::shared_ptr<one::Tensor> squeezed_deconv_output = deconv_out;
if (!is_batched) {
squeezed_deconv_output = JUST(functional::Squeeze(deconv_out, std::vector<int32_t>{0}));
channel_idx -= 1;
}
if (bias) {
auto& bias_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("axis");
bias_attrs.SetAllAttrs(static_cast<int32_t>(1));
return OpInterpUtil::Dispatch<Tensor>(*bias_op_, {deconv_out, JUST(bias)}, bias_attrs);
bias_attrs.SetAllAttrs(static_cast<int32_t>(channel_idx));
return OpInterpUtil::Dispatch<Tensor>(*bias_op_, {squeezed_deconv_output, JUST(bias)},
bias_attrs);
} else {
return deconv_out;
return squeezed_deconv_output;
}
}

Expand Down
10 changes: 0 additions & 10 deletions python/oneflow/nn/modules/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,14 +433,6 @@ def reset_parameters(self) -> None:
init.uniform_(self.bias, -bound, bound)

def _conv_forward(self, x, weight, bias):
if self.channel_pos == "channels_first":
in_channel_axis = 1
else:
in_channel_axis = 3
if x.shape[in_channel_axis] != self.in_channels:
raise ValueError(
f"The input channels {x.shape[in_channel_axis]} should be equal to self.in_channels {self.in_channels}."
)
return flow._C.conv2d(
x,
weight,
Expand Down Expand Up @@ -612,8 +604,6 @@ def reset_parameters(self) -> None:
init.uniform_(self.bias, -bound, bound)

def _conv_forward(self, x, weight, bias):
if x.shape[1] != self.in_channels:
raise ValueError("The input channels should be equal to self.in_channels")
return flow._C.conv3d(
x,
weight,
Expand Down
25 changes: 25 additions & 0 deletions python/oneflow/test/modules/test_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1865,6 +1865,31 @@ def test_conv2d_with_random_data(test_case):
y = m(x)
return y

@unittest.skipIf(
    version.parse(torch_original.__version__) <= version.parse("1.13.0"),
    "conv module don't support unbatched input in PyTorch before '1.13.0'",
)
# NOTE(review): the ``<=`` above also skips exactly 1.13.0, although the
# message says "before '1.13.0'" — confirm which version first accepted
# unbatched conv input.
@autotest(n=5)
def test_conv2d_auto_squeeze_with_random_data(test_case):
    # Exercise Conv2d with an UNBATCHED input: ``random_tensor(ndim=3, ...)``
    # below produces a 3-D tensor with no leading batch dimension, so the
    # module must transparently add a batch dim of 1 and squeeze it back off
    # the output ("auto squeeze").
    channels = random(1, 6)  # in_channels; reused as the input's dim0 so they match
    # Constructor arguments are drawn through the autotest DSL; each
    # ``| nothing()`` alternative randomly omits the argument so the layer's
    # default is exercised as well.
    m = torch.nn.Conv2d(
        in_channels=channels,
        out_channels=random(1, 20),
        kernel_size=random(1, 4),
        stride=random() | nothing(),
        padding=random(1, 3).to(int) | nothing(),
        dilation=random(1, 5) | nothing(),
        groups=random(1, 5) | nothing(),
        padding_mode=constant("zeros") | nothing(),
        bias=random_bool(),  # cover both the bias and no-bias code paths
    )
    m.train(random())  # randomly run in train or eval mode
    device = random_device()
    m.to(device)
    # ndim=3 (no batch dim) is the point of this test; dim0 is pinned to
    # ``channels`` so the unbatched input's channel count matches the module.
    x = random_tensor(ndim=3, dim0=channels).to(device)
    y = m(x)
    return y

@autotest(n=5, check_graph=False)
def test_conv2d_0size_with_random_data(test_case):
channels = random(1, 6)
Expand Down
30 changes: 29 additions & 1 deletion python/oneflow/test/modules/test_deconv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import oneflow as flow
import oneflow.nn as nn
import oneflow.unittest
import torch as torch_original
from packaging import version


def _test_deconv_bias_false(test_case, device):
Expand Down Expand Up @@ -871,7 +873,7 @@ def test_deconv2d(test_case):
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])

@autotest()
@autotest(n=5)
def test_deconv2d_with_random_data(test_case):
channels = random(1, 6)
m = torch.nn.ConvTranspose2d(
Expand All @@ -883,6 +885,7 @@ def test_deconv2d_with_random_data(test_case):
dilation=random(1, 5) | nothing(),
groups=random(1, 5) | nothing(),
padding_mode=constant("zeros") | nothing(),
bias=random_bool(),
)
m.train(random())
device = random_device()
Expand All @@ -891,6 +894,31 @@ def test_deconv2d_with_random_data(test_case):
y = m(x)
return y

@unittest.skipIf(
    version.parse(torch_original.__version__) <= version.parse("1.13.0"),
    "deconv module don't support unbatched input in PyTorch before '1.13.0'",
)
# NOTE(review): the ``<=`` above also skips exactly 1.13.0, although the
# message says "before '1.13.0'" — confirm which version first accepted
# unbatched deconv input.
@autotest(n=5)
def test_deconv2d_auto_squeeze_with_random_data(test_case):
    # Exercise ConvTranspose2d with an UNBATCHED input:
    # ``random_tensor(ndim=3, ...)`` below produces a 3-D tensor with no
    # leading batch dimension, so the module must transparently add a batch
    # dim of 1 and squeeze it back off the output ("auto squeeze").
    channels = random(1, 6)  # in_channels; reused as the input's dim0 so they match
    # Constructor arguments are drawn through the autotest DSL; each
    # ``| nothing()`` alternative randomly omits the argument so the layer's
    # default is exercised as well.
    m = torch.nn.ConvTranspose2d(
        in_channels=channels,
        out_channels=random(1, 20),
        kernel_size=random(1, 4),
        stride=random() | nothing(),
        padding=random(1, 3).to(int) | nothing(),
        dilation=random(1, 5) | nothing(),
        groups=random(1, 5) | nothing(),
        padding_mode=constant("zeros") | nothing(),
        bias=random_bool(),  # cover both the bias and no-bias code paths
    )
    m.train(random())  # randomly run in train or eval mode
    device = random_device()
    m.to(device)
    # ndim=3 (no batch dim) is the point of this test; dim0 is pinned to
    # ``channels`` so the unbatched input's channel count matches the module.
    x = random_tensor(ndim=3, dim0=channels).to(device)
    y = m(x)
    return y

@autotest(check_graph=False)
def test_deconv2d_0size_with_random_data(test_case):
channels = random(1, 6)
Expand Down