Skip to content

Commit 3e4adc4

Browse files
authored
Feat/RMSProp-optimizer (#607)
1 parent 3264b10 commit 3e4adc4

File tree

4 files changed

+541
-2
lines changed

4 files changed

+541
-2
lines changed

burn-core/src/module/base.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
176176

177177
/// Map each tensor in the module with a [mapper](ModuleMapper).
178178
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self;
179-
/// Load the module state from a record.
180179

180+
/// Load the module state from a record.
181181
fn load_record(self, record: Self::Record) -> Self;
182182

183183
/// Convert the module into a record containing the state.

burn-core/src/optim/decay.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pub struct WeightDecayConfig {
1616
/// State of [WeightDecay](WeightDecay).
1717
#[derive(Record, Clone, new)]
1818
pub struct WeightDecayState<B: Backend, const D: usize> {
19-
grad_last_step: Tensor<B, D>,
19+
pub(crate) grad_last_step: Tensor<B, D>,
2020
}
2121

2222
/// Weight decay implementation that transforms gradients.
@@ -57,6 +57,15 @@ impl<B: Backend> WeightDecay<B> {
5757

5858
(grad, WeightDecayState::new(grad_last_step))
5959
}
60+
61+
/// Temporary fix for Transform: returns `tensor * penalty + grad` without taking or producing a `WeightDecayState`.
62+
pub fn transform_temp_fix<const D: usize>(
63+
&self,
64+
grad: Tensor<B, D>,
65+
tensor: Tensor<B, D>,
66+
) -> Tensor<B, D> {
67+
tensor.mul_scalar(self.penalty).add(grad)
68+
}
6069
}
6170

6271
impl<B: Backend, const D: usize> WeightDecayState<B, D> {

burn-core/src/optim/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ mod adamw;
1010
mod base;
1111
mod grad_accum;
1212
mod grads;
13+
mod rmsprop;
1314
mod sgd;
1415
mod simple;
1516
mod visitor;
@@ -20,5 +21,6 @@ pub use adamw::*;
2021
pub use base::*;
2122
pub use grad_accum::*;
2223
pub use grads::*;
24+
pub use rmsprop::*;
2325
pub use sgd::*;
2426
pub use simple::*;

0 commit comments

Comments
 (0)