switch between max_width = 200 and max_width = 100
AlexErrant committed Nov 29, 2023
1 parent: 9c8a40c · commit: d864317
Showing 7 changed files with 262 additions and 231 deletions.
324 changes: 154 additions & 170 deletions burn-autodiff/src/ops/module.rs
@@ -103,49 +103,45 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
         }
 
         match bias {
-            Some(bias) => {
-                match Conv2DWithBias
-                    .prepare(
-                        [x.node, weight.node, bias.node],
-                        [x.graph, weight.graph, bias.graph],
-                    )
-                    .stateful()
-                {
-                    OpsKind::Tracked(prep) => prep.finish(
-                        (
-                            x.primitive.clone(),
-                            weight.primitive.clone(),
-                            bias.primitive.clone(),
-                            options.clone(),
-                        ),
-                        B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options),
-                    ),
-                    OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
-                        x.primitive,
-                        weight.primitive,
-                        Some(bias.primitive),
-                        options,
-                    )),
-                }
-            }
-            None => {
-                match Conv2DNoBias
-                    .prepare([x.node, weight.node], [x.graph, weight.graph])
-                    .stateful()
-                {
-                    OpsKind::Tracked(prep) => prep.finish(
-                        (
-                            x.primitive.clone(),
-                            weight.primitive.clone(),
-                            options.clone(),
-                        ),
-                        B::conv2d(x.primitive, weight.primitive, None, options),
-                    ),
-                    OpsKind::UnTracked(prep) => {
-                        prep.finish(B::conv2d(x.primitive, weight.primitive, None, options))
-                    }
-                }
-            }
+            Some(bias) => match Conv2DWithBias
+                .prepare(
+                    [x.node, weight.node, bias.node],
+                    [x.graph, weight.graph, bias.graph],
+                )
+                .stateful()
+            {
+                OpsKind::Tracked(prep) => prep.finish(
+                    (
+                        x.primitive.clone(),
+                        weight.primitive.clone(),
+                        bias.primitive.clone(),
+                        options.clone(),
+                    ),
+                    B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options),
+                ),
+                OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
+                    x.primitive,
+                    weight.primitive,
+                    Some(bias.primitive),
+                    options,
+                )),
+            },
+            None => match Conv2DNoBias
+                .prepare([x.node, weight.node], [x.graph, weight.graph])
+                .stateful()
+            {
+                OpsKind::Tracked(prep) => prep.finish(
+                    (
+                        x.primitive.clone(),
+                        weight.primitive.clone(),
+                        options.clone(),
+                    ),
+                    B::conv2d(x.primitive, weight.primitive, None, options),
+                ),
+                OpsKind::UnTracked(prep) => {
+                    prep.finish(B::conv2d(x.primitive, weight.primitive, None, options))
+                }
+            },
         }
     }
 
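Every hunk in this file makes the same mechanical change: a match arm whose body is nothing but another `match` loses its surrounding braces and one level of indentation, so `Some(bias) => { match ... { ... } }` becomes `Some(bias) => match ... { ... },`. The following is a minimal sketch of the two shapes with an invented function (`describe_old`/`describe_new` are not part of the repository), assuming the switch is driven by a rustfmt `max_width` setting as the commit title suggests; the configuration file itself is not among the hunks shown here.

// Old shape (wider max_width): the arm body is a braced block.
fn describe_old(opt: Option<i32>) -> &'static str {
    match opt {
        Some(x) => {
            match x {
                0 => "zero",
                _ => "non-zero",
            }
        }
        None => "none",
    }
}

// New shape (max_width = 100): the arm is a bare `match` expression ending in a comma.
fn describe_new(opt: Option<i32>) -> &'static str {
    match opt {
        Some(x) => match x {
            0 => "zero",
            _ => "non-zero",
        },
        None => "none",
    }
}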

@@ -211,57 +207,53 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
         }
 
         match bias {
-            Some(bias) => {
-                match ConvTranspose2DWithBias
-                    .prepare(
-                        [x.node, weight.node, bias.node],
-                        [x.graph, weight.graph, bias.graph],
-                    )
-                    .stateful()
-                {
-                    OpsKind::Tracked(prep) => prep.finish(
-                        (
-                            x.primitive.clone(),
-                            weight.primitive.clone(),
-                            bias.primitive.clone(),
-                            options.clone(),
-                        ),
-                        B::conv_transpose2d(
-                            x.primitive,
-                            weight.primitive,
-                            Some(bias.primitive),
-                            options,
-                        ),
-                    ),
-                    OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
-                        x.primitive,
-                        weight.primitive,
-                        Some(bias.primitive),
-                        options,
-                    )),
-                }
-            }
-            None => {
-                match ConvTranspose2DNoBias
-                    .prepare([x.node, weight.node], [x.graph, weight.graph])
-                    .stateful()
-                {
-                    OpsKind::Tracked(prep) => prep.finish(
-                        (
-                            x.primitive.clone(),
-                            weight.primitive.clone(),
-                            options.clone(),
-                        ),
-                        B::conv_transpose2d(x.primitive, weight.primitive, None, options),
-                    ),
-                    OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
-                        x.primitive,
-                        weight.primitive,
-                        None,
-                        options,
-                    )),
-                }
-            }
+            Some(bias) => match ConvTranspose2DWithBias
+                .prepare(
+                    [x.node, weight.node, bias.node],
+                    [x.graph, weight.graph, bias.graph],
+                )
+                .stateful()
+            {
+                OpsKind::Tracked(prep) => prep.finish(
+                    (
+                        x.primitive.clone(),
+                        weight.primitive.clone(),
+                        bias.primitive.clone(),
+                        options.clone(),
+                    ),
+                    B::conv_transpose2d(
+                        x.primitive,
+                        weight.primitive,
+                        Some(bias.primitive),
+                        options,
+                    ),
+                ),
+                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
+                    x.primitive,
+                    weight.primitive,
+                    Some(bias.primitive),
+                    options,
+                )),
+            },
+            None => match ConvTranspose2DNoBias
+                .prepare([x.node, weight.node], [x.graph, weight.graph])
+                .stateful()
+            {
+                OpsKind::Tracked(prep) => prep.finish(
+                    (
+                        x.primitive.clone(),
+                        weight.primitive.clone(),
+                        options.clone(),
+                    ),
+                    B::conv_transpose2d(x.primitive, weight.primitive, None, options),
+                ),
+                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
+                    x.primitive,
+                    weight.primitive,
+                    None,
+                    options,
+                )),
+            },
         }
     }
 

@@ -322,49 +314,45 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
             }
         }
         match bias {
-            Some(bias) => {
-                match Conv1DWithBias
-                    .prepare(
-                        [x.node, weight.node, bias.node],
-                        [x.graph, weight.graph, bias.graph],
-                    )
-                    .stateful()
-                {
-                    OpsKind::Tracked(prep) => prep.finish(
-                        (
-                            x.primitive.clone(),
-                            weight.primitive.clone(),
-                            bias.primitive.clone(),
-                            options.clone(),
-                        ),
-                        B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options),
-                    ),
-                    OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
-                        x.primitive,
-                        weight.primitive,
-                        Some(bias.primitive),
-                        options,
-                    )),
-                }
-            }
-            None => {
-                match Conv1DNoBias
-                    .prepare([x.node, weight.node], [x.graph, weight.graph])
-                    .stateful()
-                {
-                    OpsKind::Tracked(prep) => prep.finish(
-                        (
-                            x.primitive.clone(),
-                            weight.primitive.clone(),
-                            options.clone(),
-                        ),
-                        B::conv1d(x.primitive, weight.primitive, None, options),
-                    ),
-                    OpsKind::UnTracked(prep) => {
-                        prep.finish(B::conv1d(x.primitive, weight.primitive, None, options))
-                    }
-                }
-            }
+            Some(bias) => match Conv1DWithBias
+                .prepare(
+                    [x.node, weight.node, bias.node],
+                    [x.graph, weight.graph, bias.graph],
+                )
+                .stateful()
+            {
+                OpsKind::Tracked(prep) => prep.finish(
+                    (
+                        x.primitive.clone(),
+                        weight.primitive.clone(),
+                        bias.primitive.clone(),
+                        options.clone(),
+                    ),
+                    B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options),
+                ),
+                OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
+                    x.primitive,
+                    weight.primitive,
+                    Some(bias.primitive),
+                    options,
+                )),
+            },
+            None => match Conv1DNoBias
+                .prepare([x.node, weight.node], [x.graph, weight.graph])
+                .stateful()
+            {
+                OpsKind::Tracked(prep) => prep.finish(
+                    (
+                        x.primitive.clone(),
+                        weight.primitive.clone(),
+                        options.clone(),
+                    ),
+                    B::conv1d(x.primitive, weight.primitive, None, options),
+                ),
+                OpsKind::UnTracked(prep) => {
+                    prep.finish(B::conv1d(x.primitive, weight.primitive, None, options))
+                }
+            },
         }
     }
 

@@ -430,57 +418,53 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
         }
 
         match bias {
-            Some(bias) => {
-                match ConvTranspose1DWithBias
-                    .prepare(
-                        [x.node, weight.node, bias.node],
-                        [x.graph, weight.graph, bias.graph],
-                    )
-                    .stateful()
-                {
-                    OpsKind::Tracked(prep) => prep.finish(
-                        (
-                            x.primitive.clone(),
-                            weight.primitive.clone(),
-                            bias.primitive.clone(),
-                            options.clone(),
-                        ),
-                        B::conv_transpose1d(
-                            x.primitive,
-                            weight.primitive,
-                            Some(bias.primitive),
-                            options,
-                        ),
-                    ),
-                    OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
-                        x.primitive,
-                        weight.primitive,
-                        Some(bias.primitive),
-                        options,
-                    )),
-                }
-            }
-            None => {
-                match ConvTranspose1DNoBias
-                    .prepare([x.node, weight.node], [x.graph, weight.graph])
-                    .stateful()
-                {
-                    OpsKind::Tracked(prep) => prep.finish(
-                        (
-                            x.primitive.clone(),
-                            weight.primitive.clone(),
-                            options.clone(),
-                        ),
-                        B::conv_transpose1d(x.primitive, weight.primitive, None, options),
-                    ),
-                    OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
-                        x.primitive,
-                        weight.primitive,
-                        None,
-                        options,
-                    )),
-                }
-            }
+            Some(bias) => match ConvTranspose1DWithBias
+                .prepare(
+                    [x.node, weight.node, bias.node],
+                    [x.graph, weight.graph, bias.graph],
+                )
+                .stateful()
+            {
+                OpsKind::Tracked(prep) => prep.finish(
+                    (
+                        x.primitive.clone(),
+                        weight.primitive.clone(),
+                        bias.primitive.clone(),
+                        options.clone(),
+                    ),
+                    B::conv_transpose1d(
+                        x.primitive,
+                        weight.primitive,
+                        Some(bias.primitive),
+                        options,
+                    ),
+                ),
+                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
+                    x.primitive,
+                    weight.primitive,
+                    Some(bias.primitive),
+                    options,
+                )),
+            },
+            None => match ConvTranspose1DNoBias
+                .prepare([x.node, weight.node], [x.graph, weight.graph])
+                .stateful()
+            {
+                OpsKind::Tracked(prep) => prep.finish(
+                    (
+                        x.primitive.clone(),
+                        weight.primitive.clone(),
+                        options.clone(),
+                    ),
+                    B::conv_transpose1d(x.primitive, weight.primitive, None, options),
+                ),
+                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
+                    x.primitive,
+                    weight.primitive,
+                    None,
+                    options,
+                )),
+            },
         }
     }
 
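Setting the formatting aside, every arm in the four hunks above follows the same autodiff pattern: when the operation is tracked, `prep.finish` receives both the state needed for the backward pass and the forward output, while an untracked operation only returns the forward output. Below is a simplified, self-contained sketch of that shape with invented types; it only mimics the structure seen in the diff and is not Burn's actual `OpsKind`/prepare API.

// Hypothetical stand-ins that only mirror the tracked/untracked pattern above.
struct TrackedPrep;
struct UnTrackedPrep;

enum OpsKind {
    Tracked(TrackedPrep),
    UnTracked(UnTrackedPrep),
}

impl TrackedPrep {
    // A tracked op records `state` so the backward pass can compute gradients,
    // then returns the forward output.
    fn finish<S, O>(self, state: S, output: O) -> O {
        let _saved_for_backward = state; // placeholder for pushing state onto the graph
        output
    }
}

impl UnTrackedPrep {
    // An untracked op needs no gradient state; it just returns the forward output.
    fn finish<O>(self, output: O) -> O {
        output
    }
}

fn forward_like_conv<S, O>(kind: OpsKind, state: S, output: O) -> O {
    match kind {
        OpsKind::Tracked(prep) => prep.finish(state, output),
        OpsKind::UnTracked(prep) => prep.finish(output),
    }
}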

7 changes: 6 additions & 1 deletion burn-core/src/nn/norm/batch.rs
@@ -78,7 +78,12 @@ impl<const D: usize, B: Backend> BatchNorm<B, D> {
         // Should be move to a compilation error when const generic support that kind of
         // validation. https://github.com/rust-lang/rust/issues/76560
         if D + 2 != DI {
-            panic!("BatchNorm{}D can only be applied on tensors of size {} with the following shape [batch_size, channels, ...], received {}D tensor", D, D+2, DI);
+            panic!(
+                "BatchNorm{}D can only be applied on tensors of size {} with the following shape [batch_size, channels, ...], received {}D tensor",
+                D,
+                D + 2,
+                DI
+            );
         }
 
         match B::ad_enabled() {
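For reference, the check being rewrapped here enforces that `BatchNorm<B, D>` only accepts input of rank `D + 2`, that is, a batch dimension, a channel dimension, and `D` spatial dimensions. A standalone sketch of the same rank check, with made-up names and no Burn dependency:

// Hypothetical stand-in for the validation above: BatchNorm2d (D = 2) only
// accepts a 4-D tensor shaped [batch_size, channels, height, width].
fn check_rank<const D: usize, const DI: usize>() {
    if D + 2 != DI {
        panic!(
            "BatchNorm{}D can only be applied on tensors of size {} with the following shape \
             [batch_size, channels, ...], received {}D tensor",
            D,
            D + 2,
            DI
        );
    }
}

fn main() {
    check_rank::<2, 4>(); // ok: rank 4 == D + 2
    // check_rank::<2, 3>(); // would panic: a 3-D input is missing a spatial dimension
}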
