From ade4e1578ac5b7fa47255dfd790c682b3fcb2997 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Sun, 5 Nov 2023 23:14:00 -0300 Subject: [PATCH] fix test for f64 --- dfdx/src/nn/layers/on.rs | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/dfdx/src/nn/layers/on.rs b/dfdx/src/nn/layers/on.rs index 522d1e49..d899f6fc 100644 --- a/dfdx/src/nn/layers/on.rs +++ b/dfdx/src/nn/layers/on.rs @@ -38,12 +38,6 @@ mod tests { use super::*; use crate::tests::*; - #[input_wrapper] - pub struct MyWrapper { - pub a: A, - pub b: B, - } - #[input_wrapper] pub struct Split1 { pub forward: Forward, @@ -73,8 +67,20 @@ mod tests { fn test_residual_add_backward() { let dev: TestDevice = Default::default(); - let model = - dev.build_module::(>>::default()); + let model = dev.build_module::(>>::default()); + let model = DeviceResidualAdd1::, TestDtype, TestDevice> { + t: On { + t: Linear { + weight: model.t.t.weight.to_dtype::(), + bias: model.t.t.bias.to_dtype::(), + }, + _n: Default::default(), + }, + add: Default::default(), + input_to_tuple: Default::default(), + input_to_wrapper: Default::default(), + split: Default::default(), + }; let x: Tensor, f32, _> = dev.sample_normal(); let x = x.to_dtype::(); @@ -115,8 +121,20 @@ mod tests { fn test_residual_add_backward2() { let dev: TestDevice = Default::default(); - let model = - dev.build_module::(>>::default()); + let model = dev.build_module::(>>::default()); + let model = DeviceResidualAdd2::, TestDtype, TestDevice> { + t: On { + t: Linear { + weight: model.t.t.weight.to_dtype::(), + bias: model.t.t.bias.to_dtype::(), + }, + _n: Default::default(), + }, + add: Default::default(), + input_to_tuple: Default::default(), + input_to_wrapper: Default::default(), + split: Default::default(), + }; let x: Tensor, f32, _> = dev.sample_normal(); let x = x.to_dtype::();