Merge branch 'layer-ops' into this-main
swfsql committed Mar 3, 2024
2 parents c4a2995 + 973a2f1 commit 8fc72c5
Showing 64 changed files with 914 additions and 118 deletions.
4 changes: 2 additions & 2 deletions dfdx/examples/09-module-sequential.rs
@@ -17,10 +17,10 @@ use dfdx::prelude::*;
struct MlpConfig {
    // Linear with compile time input size & runtime known output size
    linear1: LinearConfig<Const<784>, usize>,
-    act1: ReLU,
+    act1: ops::ReLU,
    // Linear with runtime input & output size
    linear2: LinearConfig<usize, usize>,
-    act2: Tanh,
+    act2: ops::Tanh,
    // Linear with runtime input & compile time output size.
    linear3: LinearConfig<usize, Const<10>>,
}
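Note: across these example diffs only the path to the parameter-free activation layers changes (they are now referenced through the ops namespace); the build/forward flow stays the same. A minimal sketch of using the updated config, not part of this commit; it assumes the prelude re-exports ops and the build_module/forward API shown in the lib.rs docs further down, with layer sizes chosen arbitrarily:

use dfdx::prelude::*;

fn demo() {
    let dev: Cpu = Default::default();
    // Same Sequential-style tuple as before, only the activation path changed.
    type Arch = (LinearConstConfig<784, 128>, ops::ReLU, LinearConstConfig<128, 10>);
    let model = dev.build_module::<f32>(Arch::default());
    let x: Tensor<Rank1<784>, f32, _> = dev.sample_normal();
    let _logits = model.forward(x);
}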
4 changes: 2 additions & 2 deletions dfdx/examples/10-module-gradients.rs
@@ -3,9 +3,9 @@ use dfdx::prelude::*;
#[derive(Clone, Default, Debug, Sequential)]
struct MlpConfig<const I: usize, const O: usize> {
    linear1: LinearConstConfig<I, 64>,
-    act1: ReLU,
+    act1: ops::ReLU,
    linear2: LinearConstConfig<64, 64>,
-    act2: ReLU,
+    act2: ops::ReLU,
    linear3: LinearConstConfig<64, O>,
}

6 changes: 3 additions & 3 deletions dfdx/examples/11-module-optim.rs
@@ -7,11 +7,11 @@ use dfdx::prelude::*;
#[built(Mlp)]
struct MlpConfig {
    l1: LinearConstConfig<5, 32>,
-    act1: ReLU,
+    act1: ops::ReLU,
    l2: LinearConstConfig<32, 32>,
-    act2: ReLU,
+    act2: ops::ReLU,
    l3: LinearConstConfig<32, 2>,
-    act3: Tanh,
+    act3: ops::Tanh,
}

fn main() {
6 changes: 3 additions & 3 deletions dfdx/examples/12-mnist.rs
@@ -52,9 +52,9 @@ impl ExactSizeDataset for MnistTrainSet {

// our network structure
type Mlp = (
-    (LinearConstConfig<784, 512>, ReLU),
-    (LinearConstConfig<512, 128>, ReLU),
-    (LinearConstConfig<128, 32>, ReLU),
+    (LinearConstConfig<784, 512>, ops::ReLU),
+    (LinearConstConfig<512, 128>, ops::ReLU),
+    (LinearConstConfig<128, 32>, ops::ReLU),
    LinearConstConfig<32, 10>,
);

4 changes: 2 additions & 2 deletions dfdx/examples/advanced-gradient-accum.rs
@@ -5,9 +5,9 @@ fn main() {

type Model = (
    LinearConstConfig<2, 5>,
-    ReLU,
+    ops::ReLU,
    LinearConstConfig<5, 10>,
-    Tanh,
+    ops::Tanh,
    LinearConstConfig<10, 20>,
);
let model = dev.build_module::<f32>(Model::default());
17 changes: 9 additions & 8 deletions dfdx/examples/advanced-resnet18.rs
@@ -1,4 +1,5 @@
#![cfg_attr(feature = "nightly", feature(generic_const_exprs))]
+#![allow(incomplete_features)]

#[cfg(not(feature = "nightly"))]
fn main() {
@@ -15,7 +16,7 @@ fn main() {
pub struct BasicBlockInternal<const C: usize> {
    conv1: Conv2DConstConfig<C, C, 3, 1, 1>,
    bn1: BatchNorm2DConstConfig<C>,
-    relu: ReLU,
+    relu: ops::ReLU,
    conv2: Conv2DConstConfig<C, C, 3, 1, 1>,
    bn2: BatchNorm2DConstConfig<C>,
}
@@ -24,7 +25,7 @@ fn main() {
pub struct DownsampleA<const C: usize, const D: usize> {
    conv1: Conv2DConstConfig<C, D, 3, 2, 1>,
    bn1: BatchNorm2DConstConfig<D>,
-    relu: ReLU,
+    relu: ops::ReLU,
    conv2: Conv2DConstConfig<D, D, 3, 1, 1>,
    bn2: BatchNorm2DConstConfig<D>,
}
@@ -44,18 +45,18 @@ fn main() {
pub struct Head {
    conv: Conv2DConstConfig<3, 64, 7, 2, 3>,
    bn: BatchNorm2DConstConfig<64>,
-    relu: ReLU,
-    pool: MaxPool2DConst<3, 2, 1>,
+    relu: ops::ReLU,
+    pool: ops::MaxPool2DConst<3, 2, 1>,
}

#[derive(Default, Clone, Sequential)]
#[built(Resnet18)]
pub struct Resnet18Config<const NUM_CLASSES: usize> {
    head: Head,
-    l1: (BasicBlock<64>, ReLU, BasicBlock<64>, ReLU),
-    l2: (Downsample<64, 128>, ReLU, BasicBlock<128>, ReLU),
-    l3: (Downsample<128, 256>, ReLU, BasicBlock<256>, ReLU),
-    l4: (Downsample<256, 512>, ReLU, BasicBlock<512>, ReLU),
+    l1: (BasicBlock<64>, ops::ReLU, BasicBlock<64>, ops::ReLU),
+    l2: (Downsample<64, 128>, ops::ReLU, BasicBlock<128>, ops::ReLU),
+    l3: (Downsample<128, 256>, ops::ReLU, BasicBlock<256>, ops::ReLU),
+    l4: (Downsample<256, 512>, ops::ReLU, BasicBlock<512>, ops::ReLU),
    l5: (AvgPoolGlobal, LinearConstConfig<512, NUM_CLASSES>),
}

4 changes: 2 additions & 2 deletions dfdx/examples/advanced-rl-dqn.rs
@@ -10,8 +10,8 @@ const ACTION: usize = 2;

// our simple 2 layer feedforward network with ReLU activations
type QNetwork = (
-    (LinearConstConfig<STATE, 32>, ReLU),
-    (LinearConstConfig<32, 32>, ReLU),
+    (LinearConstConfig<STATE, 32>, ops::ReLU),
+    (LinearConstConfig<32, 32>, ops::ReLU),
    LinearConstConfig<32, ACTION>,
);

4 changes: 2 additions & 2 deletions dfdx/examples/advanced-rl-ppo.rs
@@ -9,8 +9,8 @@ const STATE: usize = 4;
const ACTION: usize = 2;

type PolicyNetwork = (
-    (LinearConstConfig<STATE, 32>, ReLU),
-    (LinearConstConfig<32, 32>, ReLU),
+    (LinearConstConfig<STATE, 32>, ops::ReLU),
+    (LinearConstConfig<32, 32>, ops::ReLU),
    LinearConstConfig<32, ACTION>,
);

8 changes: 4 additions & 4 deletions dfdx/src/lib.rs
@@ -178,10 +178,10 @@
//! struct MlpConfig {
//!     // Linear with compile time input size & runtime known output size
//!     linear1: LinearConfig<Const<784>, usize>,
-//!     act1: ReLU,
+//!     act1: ops::ReLU,
//!     // Linear with runtime input & output size
//!     linear2: LinearConfig<usize, usize>,
-//!     act2: Tanh,
+//!     act2: ops::Tanh,
//!     // Linear with runtime input & compile time output size.
//!     linear3: LinearConfig<usize, Const<10>>,
//! }
@@ -208,7 +208,7 @@
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
-//! type Arch = (LinearConstConfig<3, 5>, ReLU, LinearConstConfig<5, 10>);
+//! type Arch = (LinearConstConfig<3, 5>, ops::ReLU, LinearConstConfig<5, 10>);
//! let mut model = dev.build_module::<f32>(Arch::default());
//! let x: Tensor<(usize, Const<3>), f32, _> = dev.sample_uniform_like(&(100, Const));
//! let y = model.forward_mut(x);
@@ -233,7 +233,7 @@
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
-//! type Arch = (LinearConstConfig<3, 5>, ReLU, LinearConstConfig<5, 10>);
+//! type Arch = (LinearConstConfig<3, 5>, ops::ReLU, LinearConstConfig<5, 10>);
//! let arch = Arch::default();
//! let mut model = dev.build_module::<f32>(arch);
//! // 1. allocate gradients for the model
2 changes: 1 addition & 1 deletion dfdx/src/nn/layers/add_into.rs
@@ -89,7 +89,7 @@ mod tests {
// check if it works in a longer neural net
type Model = (
    AddInto<(LinearConstConfig<5, 3>, LinearConstConfig<5, 3>)>,
-    ReLU,
+    ops::ReLU,
    LinearConstConfig<3, 1>,
);
let mut model = dev.build_module::<TestDtype>(Model::default());
2 changes: 1 addition & 1 deletion dfdx/src/nn/layers/generalized_add.rs
@@ -12,7 +12,7 @@ use crate::prelude::*;
/// # use dfdx::prelude::*;
/// # use dfdx::*;
/// # let dev: Cpu = Default::default();
-/// type Model = GeneralizedAdd<ReLU, Square>;
+/// type Model = GeneralizedAdd<ops::ReLU, ops::Square>;
/// let model = dev.build_module::<f32>(Model::default());
/// let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);
/// let y = model.forward(x);
2 changes: 1 addition & 1 deletion dfdx/src/nn/layers/generalized_mul.rs
@@ -11,7 +11,7 @@ use crate::prelude::*;
/// # use dfdx::prelude::*;
/// # use dfdx::*;
/// # let dev: Cpu = Default::default();
-/// type Model = GeneralizedMul<ReLU, Square>;
+/// type Model = GeneralizedMul<ops::ReLU, ops::Square>;
/// let model = dev.build_module::<f32>(Model::default());
/// let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);
/// let y = model.forward(x);
49 changes: 1 addition & 48 deletions dfdx/src/nn/layers/mod.rs
@@ -1,4 +1,3 @@
-mod abs;
mod add_into;
mod batch_norm1d;
mod batch_norm2d;
@@ -10,49 +9,26 @@ mod conv1d;
mod conv2d;
#[cfg(feature = "nightly")]
mod conv_trans2d;
-mod cos;
-mod dropout;
mod embedding;
-mod exp;
-#[cfg(feature = "nightly")]
-mod flatten2d;
-mod gelu;
mod generalized_add;
mod generalized_mul;
mod layer_norm1d;
mod layer_rms_norm1d;
-mod leaky_relu;
mod linear;
-mod ln;
-mod log_softmax;
mod matmul;
mod multi_head_attention;
-#[cfg(feature = "nightly")]
-mod pool_2d_avg;
-#[cfg(feature = "nightly")]
-mod pool_2d_max;
-#[cfg(feature = "nightly")]
-mod pool_2d_min;
+pub mod ops;
mod pool_global_avg;
mod pool_global_max;
mod pool_global_min;
mod prelu;
mod prelu1d;
-mod relu;
-mod reshape;
mod residual_add;
mod residual_mul;
-mod sigmoid;
-mod sin;
-mod softmax;
mod split_into;
-mod sqrt;
-mod square;
-mod tanh;
mod transformer;
mod upscale2d;

-pub use abs::Abs;
pub use add_into::AddInto;
pub use batch_norm1d::{BatchNorm1D, BatchNorm1DConfig, BatchNorm1DConstConfig};
pub use batch_norm2d::{BatchNorm2D, BatchNorm2DConfig, BatchNorm2DConstConfig};
@@ -64,45 +40,22 @@ pub use conv1d::{Conv1D, Conv1DConfig, Conv1DConstConfig};
pub use conv2d::{Conv2D, Conv2DConfig, Conv2DConstConfig};
#[cfg(feature = "nightly")]
pub use conv_trans2d::{ConvTrans2D, ConvTrans2DConfig, ConvTrans2DConstConfig};
-pub use cos::Cos;
-pub use dropout::{Dropout, DropoutOneIn};
pub use embedding::{Embedding, EmbeddingConfig, EmbeddingConstConfig};
-pub use exp::Exp;
-#[cfg(feature = "nightly")]
-pub use flatten2d::Flatten2D;
-pub use gelu::{AccurateGeLU, FastGeLU};
pub use generalized_add::GeneralizedAdd;
pub use generalized_mul::GeneralizedMul;
pub use layer_norm1d::{LayerNorm1D, LayerNorm1DConfig, LayerNorm1DConstConfig};
pub use layer_rms_norm1d::{LayerRMSNorm1D, LayerRMSNorm1DConfig, LayerRMSNorm1DConstConfig};
-pub use leaky_relu::LeakyReLU;
pub use linear::{Linear, LinearConfig, LinearConstConfig};
-pub use ln::Ln;
-pub use log_softmax::LogSoftmax;
pub use matmul::{MatMul, MatMulConfig, MatMulConstConfig};
pub use multi_head_attention::{MultiHeadAttention, MultiHeadAttentionConfig};
-#[cfg(feature = "nightly")]
-pub use pool_2d_avg::{AvgPool2D, AvgPool2DConst};
-#[cfg(feature = "nightly")]
-pub use pool_2d_max::{MaxPool2D, MaxPool2DConst};
-#[cfg(feature = "nightly")]
-pub use pool_2d_min::{MinPool2D, MinPool2DConst};
pub use pool_global_avg::AvgPoolGlobal;
pub use pool_global_max::MaxPoolGlobal;
pub use pool_global_min::MinPoolGlobal;
pub use prelu::{PReLU, PReLUConfig};
pub use prelu1d::{PReLU1D, PReLU1DConfig};
-pub use relu::ReLU;
-pub use reshape::Reshape;
pub use residual_add::ResidualAdd;
pub use residual_mul::ResidualMul;
-pub use sigmoid::Sigmoid;
-pub use sin::Sin;
-pub use softmax::Softmax;
pub use split_into::SplitInto;
-pub use sqrt::Sqrt;
-pub use square::Square;
-pub use tanh::Tanh;
pub use transformer::{
    DecoderBlock, DecoderBlockConfig, EncoderBlock, EncoderBlockConfig, Transformer,
    TransformerConfig,
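The net effect of this mod.rs change is that the parameter-free layers (ReLU, Tanh, the fixed-size pools, and similar) are no longer re-exported directly from dfdx::nn::layers; they now live under the new public ops module, which the examples above reach through the prelude. A hedged before/after sketch of the import change (the exact re-export path is assumed here, it is not shown in this diff):

// Before this merge (old re-export, removed above):
// use dfdx::nn::layers::ReLU;
// After: the same unit struct is reached through the ops module, e.g. via the prelude.
use dfdx::prelude::*;
type Act = ops::ReLU;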
File renamed without changes.
15 changes: 15 additions & 0 deletions dfdx/src/nn/layers/ops/add.rs
@@ -0,0 +1,15 @@
use crate::prelude::*;

/// Calls on [crate::tensor_ops::TryAdd], which for tensors is [crate::tensor_ops::add()].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct Add;
impl<Lhs, Rhs> Module<(Lhs, Rhs)> for Add
where
    Lhs: TryAdd<Rhs>,
{
    type Output = <Lhs as TryAdd<Rhs>>::Output;

    fn try_forward(&self, x: (Lhs, Rhs)) -> Result<Self::Output, Error> {
        x.0.try_add(x.1)
    }
}
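Since Add is a zero-sized module, it can be applied directly to a tuple of addable values, or dropped into a Sequential tuple. A rough usage sketch, assuming ops re-exports Add the same way the examples reach ops::ReLU and that forward is the unwrapping counterpart of try_forward; the tensor values are invented for illustration:

use dfdx::prelude::*;

fn demo_add() {
    let dev: Cpu = Default::default();
    let a = dev.tensor([1.0f32, 2.0, 3.0]);
    let b = dev.tensor([10.0f32, 20.0, 30.0]);
    // Forwards the pair through TryAdd, i.e. elementwise a + b for same-shaped tensors.
    let sum = ops::Add.forward((a, b));
    assert_eq!(sum.array(), [11.0, 22.0, 33.0]);
}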
22 changes: 22 additions & 0 deletions dfdx/src/nn/layers/ops/bce.rs
@@ -0,0 +1,22 @@
use crate::prelude::*;

/// Calls [crate::tensor_ops::bce_with_logits].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct Bce;
type Logits<S, E, D, T> = Tensor<S, E, D, T>;
type Probs<S, E, D, T> = Tensor<S, E, D, T>;

impl<S: Shape, E: Dtype, D: Device<E>, LTape: Tape<E, D>, RTape: Tape<E, D>>
    Module<(Logits<S, E, D, LTape>, Probs<S, E, D, RTape>)> for Bce
where
    LTape: Merge<RTape>,
{
    type Output = Logits<S, E, D, LTape>;

    fn try_forward(
        &self,
        x: (Logits<S, E, D, LTape>, Probs<S, E, D, RTape>),
    ) -> Result<Self::Output, Error> {
        x.0.try_bce_with_logits(x.1)
    }
}
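Bce pairs a logits tensor with a probabilities tensor and defers to bce_with_logits, producing the elementwise binary cross entropy with the same shape as its inputs and carrying the logits' tape. A rough sketch with invented values; both inputs are untaped here, which trivially satisfies the Merge bound:

use dfdx::prelude::*;

fn demo_bce() {
    let dev: Cpu = Default::default();
    let logits = dev.tensor([-1.0f32, 0.0, 2.0]);
    let target_probs = dev.tensor([0.0f32, 0.5, 1.0]);
    // Elementwise bce-with-logits; reduce it afterwards (e.g. a mean) to get a scalar loss.
    let loss = ops::Bce.forward((logits, target_probs));
    println!("per-element loss: {:?}", loss.array());
}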
50 changes: 50 additions & 0 deletions dfdx/src/nn/layers/ops/boolean.rs
@@ -0,0 +1,50 @@
use crate::prelude::*;
use std::ops::{BitAnd, BitOr, BitXor, Not as BitNot};

/// Calls on [std::ops::BitAnd], which for booleans is [crate::tensor_ops::bool_and].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct And;

/// Calls on [std::ops::Not], which for booleans is [crate::tensor_ops::bool_not].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct Not;

/// Calls on [std::ops::BitOr], which for booleans is [crate::tensor_ops::bool_or].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct Or;

/// Calls on [std::ops::BitXor], which for booleans is [crate::tensor_ops::bool_xor].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct Xor;

impl<Lhs: BitAnd<Rhs>, Rhs> Module<(Lhs, Rhs)> for And {
    type Output = <Lhs as BitAnd<Rhs>>::Output;

    fn try_forward(&self, x: (Lhs, Rhs)) -> Result<Self::Output, Error> {
        Ok(x.0 & x.1)
    }
}

impl<Input: BitNot> Module<Input> for Not {
    type Output = <Input as BitNot>::Output;

    fn try_forward(&self, x: Input) -> Result<Self::Output, Error> {
        Ok(!x)
    }
}

impl<Lhs: BitOr<Rhs>, Rhs> Module<(Lhs, Rhs)> for Or {
    type Output = <Lhs as BitOr<Rhs>>::Output;

    fn try_forward(&self, x: (Lhs, Rhs)) -> Result<Self::Output, Error> {
        Ok(x.0 | x.1)
    }
}

impl<Lhs: BitXor<Rhs>, Rhs> Module<(Lhs, Rhs)> for Xor {
    type Output = <Lhs as BitXor<Rhs>>::Output;

    fn try_forward(&self, x: (Lhs, Rhs)) -> Result<Self::Output, Error> {
        Ok(x.0 ^ x.1)
    }
}
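These wrappers just defer to the inputs' BitAnd / BitOr / BitXor / Not implementations, which for boolean tensors are the bool_and / bool_or / bool_xor / bool_not tensor ops named in the doc comments. A small sketch with invented values, assuming these structs are re-exported from ops like the other layers:

use dfdx::prelude::*;

fn demo_bool_ops() {
    let dev: Cpu = Default::default();
    let a = dev.tensor([true, true, false, false]);
    let b = dev.tensor([true, false, true, false]);
    // The binary ops take a tuple; Not takes a single tensor.
    let and = ops::And.forward((a.clone(), b.clone()));
    let or = ops::Or.forward((a.clone(), b.clone()));
    let xor = ops::Xor.forward((a.clone(), b));
    let not_a = ops::Not.forward(a);
    assert_eq!(and.array(), [true, false, false, false]);
    assert_eq!(or.array(), [true, true, true, false]);
    assert_eq!(xor.array(), [false, true, true, false]);
    assert_eq!(not_a.array(), [false, false, true, true]);
}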