Skip to content

Commit

Permalink
Fixing example
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Oct 25, 2023
1 parent 45305a3 commit ca9548b
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions dfdx/examples/advanced-train-loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ fn classification_train<
Lbl,
// Our model just needs to implement these two things! ModuleMut for forward
// and TensorCollection for optimizer/alloc_grads/zero_grads
Model: Module<Inp::Traced, Error = crate::tensor::Error> + ZeroGrads<E, D> + UpdateParams<E, D>,
Model: Module<Inp::Traced> + ZeroGrads<E, D> + UpdateParams<E, D>,
// optimizer, pretty straightforward
Opt: Optimizer<Model, E, D>,
// our data will just be any iterator over these items. easy!
Expand All @@ -22,7 +22,7 @@ fn classification_train<
Criterion: FnMut(Model::Output, Lbl) -> Loss,
// the Loss needs to be able to call backward, and we also use
// this generic as an output
Loss: Backward<E, D, Err = crate::tensor::Error> + AsArray<Array = E>,
Loss: Backward<E, D> + AsArray<Array = E>,
// Dtype & Device to tie everything together
E: Dtype,
D: Device<E>,
Expand All @@ -32,7 +32,7 @@ fn classification_train<
mut criterion: Criterion,
data: Data,
batch_accum: usize,
) -> Result<(), crate::tensor::Error> {
) -> Result<(), Error> {
let mut grads = model.try_alloc_grads()?;
for (i, (inp, lbl)) in data.enumerate() {
let y = model.try_forward_mut(inp.traced(grads))?;
Expand Down

0 comments on commit ca9548b

Please sign in to comment.