Skip to content

Commit 3264b10

Browse files
authored
Changes and fixes required for squeezenet1.1-opset7 onnx (#639)
1 parent d659f11 commit 3264b10

26 files changed

+729
-93
lines changed

burn-core/src/nn/padding.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ impl PaddingConfig1d {
3434
}
3535

3636
/// Padding configuration for 2D operators.
37-
#[derive(Module, Config, Debug)]
37+
#[derive(Module, Config, Debug, PartialEq)]
3838
pub enum PaddingConfig2d {
3939
/// Dynamically calculate the amount of padding necessary to ensure that the output size will be
4040
/// the same as the input.

burn-core/src/nn/pool/avg_pool1d.rs

+14
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,20 @@ pub struct AvgPool1dConfig {
2121
}
2222

2323
/// Applies a 1D avg pooling over input tensors.
24+
///
25+
/// See [AvgPool1dConfig](AvgPool1dConfig) for details.
26+
///
27+
/// # Remarks
28+
///
29+
/// The zero-padding values will be included in the calculation
30+
/// of the average. This means that the zeros are counted as
31+
/// legitimate values, and they contribute to the denominator
32+
/// when calculating the average. This is equivalent to
33+
/// `torch.nn.AvgPool2d` with `count_include_pad=True`.
34+
///
35+
/// TODO: Add support for `count_include_pad=False`, see
36+
/// [Issue 636](https://github.com/burn-rs/burn/issues/636)
37+
2438
#[derive(Module, Debug, Clone)]
2539
pub struct AvgPool1d {
2640
stride: usize,

burn-core/src/nn/pool/avg_pool2d.rs

+14-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::tensor::Tensor;
88
use burn_tensor::module::avg_pool2d;
99

1010
/// Configuration to create a [2D avg pooling](AvgPool2d) layer.
11-
#[derive(Config)]
11+
#[derive(Config, Debug)]
1212
pub struct AvgPool2dConfig {
1313
/// The size of the kernel.
1414
pub kernel_size: [usize; 2],
@@ -21,6 +21,19 @@ pub struct AvgPool2dConfig {
2121
}
2222

2323
/// Applies a 2D avg pooling over input tensors.
24+
///
25+
/// See [AvgPool2dConfig](AvgPool2dConfig) for details.
26+
///
27+
/// # Remarks
28+
///
29+
/// The zero-padding values will be included in the calculation
30+
/// of the average. This means that the zeros are counted as
31+
/// legitimate values, and they contribute to the denominator
32+
/// when calculating the average. This is equivalent to
33+
/// `torch.nn.AvgPool2d` with `count_include_pad=True`.
34+
///
35+
/// TODO: Add support for `count_include_pad=False`, see
36+
/// [Issue 636](https://github.com/burn-rs/burn/issues/636)
2437
#[derive(Module, Debug, Clone)]
2538
pub struct AvgPool2d {
2639
stride: [usize; 2],

burn-import/README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ List taken from [here](https://github.com/onnx/onnx/blob/main/docs/Operators.md)
2323
- [ ] Asinh
2424
- [ ] Atan
2525
- [ ] Atanh
26-
- [ ] AveragePool
26+
- [ ] AveragePool1d
27+
- [x] AveragePool2d
2728
- [x] BatchNormalization
2829
- [ ] Bernoulli
2930
- [ ] BitShift
@@ -107,7 +108,6 @@ List taken from [here](https://github.com/onnx/onnx/blob/main/docs/Operators.md)
107108
- [ ] MatMul
108109
- [ ] MatMulInteger
109110
- [ ] Max
110-
- [ ] MaxPool
111111
- [ ] MaxPool1d
112112
- [x] MaxPool2d
113113
- [ ] MaxRoiPool

burn-import/onnx-tests/build.rs

+8-5
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,20 @@ fn main() {
77
// Add onnx models.
88
ModelGen::new()
99
.input("tests/add/add.onnx")
10-
.input("tests/sub/sub.onnx")
11-
.input("tests/mul/mul.onnx")
12-
.input("tests/div/div.onnx")
10+
.input("tests/avg_pool2d/avg_pool2d.onnx")
1311
.input("tests/concat/concat.onnx")
1412
.input("tests/conv1d/conv1d.onnx")
1513
.input("tests/conv2d/conv2d.onnx")
16-
.input("tests/dropout/dropout.onnx")
14+
.input("tests/div/div.onnx")
15+
.input("tests/dropout/dropout_opset16.onnx")
16+
.input("tests/dropout/dropout_opset7.onnx")
1717
.input("tests/global_avr_pool/global_avr_pool.onnx")
18-
.input("tests/softmax/softmax.onnx")
1918
.input("tests/log_softmax/log_softmax.onnx")
2019
.input("tests/maxpool2d/maxpool2d.onnx")
20+
.input("tests/mul/mul.onnx")
21+
.input("tests/reshape/reshape.onnx")
22+
.input("tests/softmax/softmax.onnx")
23+
.input("tests/sub/sub.onnx")
2124
.out_dir("model/")
2225
.run_from_script();
2326

Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#!/usr/bin/env python3
2+
3+
# used to generate model: avg_pool2d.onnx
4+
5+
import torch
6+
import torch.nn as nn
7+
8+
9+
class Model(nn.Module):
10+
def __init__(self):
11+
super(Model, self).__init__()
12+
13+
# TODO when https://github.com/burn-rs/burn/issues/636 is resolved, test this with a model
14+
# that uses `count_include_pad=False` and padding=(2, 1)
15+
self.pool2d = nn.AvgPool2d((4, 2), stride=(
16+
2, 1), padding=(0, 0), count_include_pad=False)
17+
18+
def forward(self, x):
19+
x = self.pool2d(x)
20+
return x
21+
22+
23+
def main():
24+
# Set seed for reproducibility
25+
torch.manual_seed(3)
26+
27+
# Print options
28+
torch.set_printoptions(precision=3)
29+
30+
# Export to onnx
31+
model = Model()
32+
model.eval()
33+
device = torch.device("cpu")
34+
35+
file_name = "avg_pool2d.onnx"
36+
test_input = torch.randn(1, 1, 5, 5, device=device)
37+
torch.onnx.export(model, test_input, file_name,
38+
verbose=False, opset_version=16)
39+
40+
print("Finished exporting model to {}".format(file_name))
41+
42+
# Output some test data for use in the test
43+
print("Test input data shape of ones: {}".format(test_input.shape))
44+
print("Test input data of ones: {}".format(test_input))
45+
output = model.forward(test_input)
46+
print("Test output data shape: {}".format(output.shape))
47+
print("Test output: {}".format(output))
48+
49+
50+
if __name__ == '__main__':
51+
main()

burn-import/onnx-tests/tests/dropout/dropout.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def main():
2323
model.eval()
2424
device = torch.device("cpu")
2525

26-
file_name = "dropout.onnx"
26+
file_name = "dropout_opset16.onnx"
2727
test_input = torch.ones(2, 4, 10, 15, device=device)
2828
torch.onnx.export(model, test_input, file_name,
2929
training=torch.onnx.TrainingMode.TRAINING,
Binary file not shown.

burn-import/onnx-tests/tests/onnx_tests.rs

+62-8
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,20 @@ macro_rules! include_models {
1212
// ATTENTION: Modify this macro to include all models in the `model` directory.
1313
include_models!(
1414
add,
15-
sub,
16-
mul,
17-
div,
15+
avg_pool2d,
1816
concat,
1917
conv1d,
2018
conv2d,
21-
dropout,
19+
div,
20+
dropout_opset16,
21+
dropout_opset7,
2222
global_avr_pool,
23-
softmax,
2423
log_softmax,
25-
maxpool2d
24+
maxpool2d,
25+
mul,
26+
reshape,
27+
softmax,
28+
sub
2629
);
2730

2831
#[cfg(test)]
@@ -151,8 +154,27 @@ mod tests {
151154
}
152155

153156
#[test]
154-
fn dropout() {
155-
let model: dropout::Model<Backend> = dropout::Model::default();
157+
fn dropout_opset16() {
158+
let model: dropout_opset16::Model<Backend> = dropout_opset16::Model::default();
159+
160+
// Run the model with ones as input for easier testing
161+
let input = Tensor::<Backend, 4>::ones([2, 4, 10, 15]);
162+
163+
let output = model.forward(input);
164+
165+
let expected_shape = Shape::from([2, 4, 10, 15]);
166+
assert_eq!(output.shape(), expected_shape);
167+
168+
let output_sum = output.sum().into_scalar();
169+
170+
let expected_sum = 1200.0; // from pytorch
171+
172+
assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2)));
173+
}
174+
175+
#[test]
176+
fn dropout_opset7() {
177+
let model: dropout_opset7::Model<Backend> = dropout_opset7::Model::default();
156178

157179
// Run the model with ones as input for easier testing
158180
let input = Tensor::<Backend, 4>::ones([2, 4, 10, 15]);
@@ -255,4 +277,36 @@ mod tests {
255277

256278
assert_eq!(output.to_data(), expected);
257279
}
280+
281+
#[test]
282+
fn avg_pool2d() {
283+
// Initialize the model without weights (because the exported file does not contain them)
284+
let model: avg_pool2d::Model<Backend> = avg_pool2d::Model::new();
285+
286+
// Run the model
287+
let input = Tensor::<Backend, 4>::from_floats([[[
288+
[-0.077, 0.360, -0.782, 0.072, 0.665],
289+
[-0.287, 1.621, -1.597, -0.052, 0.611],
290+
[0.760, -0.034, -0.345, 0.494, -0.078],
291+
[-1.805, -0.476, 0.205, 0.338, 1.353],
292+
[0.374, 0.013, 0.774, -0.109, -0.271],
293+
]]]);
294+
let output = model.forward(input);
295+
let expected = Data::from([[[[0.008, -0.131, -0.208, 0.425]]]]);
296+
297+
output.to_data().assert_approx_eq(&expected, 3);
298+
}
299+
300+
#[test]
301+
fn reshape() {
302+
// Initialize the model without weights (because the exported file does not contain them)
303+
let model: reshape::Model<Backend> = reshape::Model::new();
304+
305+
// Run the model
306+
let input = Tensor::<Backend, 1>::from_floats([0., 1., 2., 3.]);
307+
let output = model.forward(input);
308+
let expected = Data::from([[0., 1., 2., 3.]]);
309+
310+
assert_eq!(output.to_data(), expected);
311+
}
258312
}
430 Bytes
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#!/usr/bin/env python3
2+
3+
# used to generate model: reshape.onnx
4+
5+
import torch
6+
import torch.nn as nn
7+
8+
9+
class Model(nn.Module):
10+
def __init__(self):
11+
super(Model, self).__init__()
12+
13+
def forward(self, x):
14+
x = x.reshape(2, 2)
15+
x = x.reshape(1, -1) # -1 means infer from other dimensions
16+
return x
17+
18+
19+
def main():
20+
21+
# Set seed for reproducibility
22+
torch.manual_seed(42)
23+
24+
torch.set_printoptions(precision=8)
25+
26+
# Export to onnx
27+
model = Model()
28+
model.eval()
29+
device = torch.device("cpu")
30+
31+
file_name = "reshape.onnx"
32+
test_input = torch.arange(4., device=device)
33+
torch.onnx.export(model, test_input, file_name,
34+
verbose=False, opset_version=16)
35+
36+
print("Finished exporting model to {}".format(file_name))
37+
38+
# Output some test data for use in the test
39+
print("Test input data of ones: {}".format(test_input))
40+
print("Test input data shape of ones: {}".format(test_input.shape))
41+
output = model.forward(test_input)
42+
print("Test output data shape: {}".format(output.shape))
43+
44+
print("Test output: {}".format(output))
45+
46+
47+
if __name__ == '__main__':
48+
main()

0 commit comments

Comments
 (0)