Commit 3e4adc4 1 parent 3264b10 commit 3e4adc4 Copy full SHA for 3e4adc4
File tree 4 files changed +541
-2
lines changed
4 files changed +541
-2
lines changed Original file line number Diff line number Diff line change @@ -176,8 +176,8 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
176
176
177
177
/// Map each tensor in the module with a [mapper](ModuleMapper).
178
178
fn map < M : ModuleMapper < B > > ( self , mapper : & mut M ) -> Self ;
179
- /// Load the module state from a record.
180
179
180
+ /// Load the module state from a record.
181
181
fn load_record ( self , record : Self :: Record ) -> Self ;
182
182
183
183
/// Convert the module into a record containing the state.
Original file line number Diff line number Diff line change @@ -16,7 +16,7 @@ pub struct WeightDecayConfig {
16
16
/// State of [WeightDecay](WeightDecay).
17
17
#[ derive( Record , Clone , new) ]
18
18
pub struct WeightDecayState < B : Backend , const D : usize > {
19
- grad_last_step : Tensor < B , D > ,
19
+ pub ( crate ) grad_last_step : Tensor < B , D > ,
20
20
}
21
21
22
22
/// Weight decay implementation that transforms gradients.
@@ -57,6 +57,15 @@ impl<B: Backend> WeightDecay<B> {
57
57
58
58
( grad, WeightDecayState :: new ( grad_last_step) )
59
59
}
60
+
61
+ /// temp fix for Transform.
62
+ pub fn transform_temp_fix < const D : usize > (
63
+ & self ,
64
+ grad : Tensor < B , D > ,
65
+ tensor : Tensor < B , D > ,
66
+ ) -> Tensor < B , D > {
67
+ tensor. mul_scalar ( self . penalty ) . add ( grad)
68
+ }
60
69
}
61
70
62
71
impl < B : Backend , const D : usize > WeightDecayState < B , D > {
Original file line number Diff line number Diff line change @@ -10,6 +10,7 @@ mod adamw;
10
10
mod base;
11
11
mod grad_accum;
12
12
mod grads;
13
+ mod rmsprop;
13
14
mod sgd;
14
15
mod simple;
15
16
mod visitor;
@@ -20,5 +21,6 @@ pub use adamw::*;
20
21
pub use base:: * ;
21
22
pub use grad_accum:: * ;
22
23
pub use grads:: * ;
24
+ pub use rmsprop:: * ;
23
25
pub use sgd:: * ;
24
26
pub use simple:: * ;
You can’t perform that action at this time.
0 commit comments