
feature/solver: implement solver and sgd
Set up structure for solvers; Documentation is still lacking
hobofan committed Nov 10, 2015
1 parent 5dc879e commit 83db20d
Showing 10 changed files with 408 additions and 100 deletions.
17 changes: 9 additions & 8 deletions src/layer.rs
@@ -68,9 +68,9 @@ pub type WriteBlob<'_> = RwLockWriteGuard<'_, HeapBlob>;
 
 #[derive(Debug)]
 /// The generic Layer
-pub struct Layer<'a> {
+pub struct Layer {
     /// The configuration of the Layer
-    pub config: Box<&'a LayerConfig>,
+    pub config: Box<LayerConfig>,
     /// The [implementation][1] of the Layer.
     /// [1]: ../layers/index.html
     ///
@@ -97,16 +97,16 @@ pub struct Layer<'a> {
     weight_propagate_down: Vec<bool>,
 }
 
-impl<'a> Layer<'a> {
+impl Layer {
     /// Creates a new Layer from a [LayerConfig][1].
     /// [1]: ./struct.LayerConfig.html
     ///
     /// Used during [Network][2] initialization.
     ///
     /// [2]: ../network/struct.Network.html
-    pub fn from_config(config: &'a LayerConfig) -> Layer {
+    pub fn from_config(config: &LayerConfig) -> Layer {
         let cl = config.clone();
-        let cfg = Box::<&'a LayerConfig>::new(cl);
+        let cfg = Box::<LayerConfig>::new(cl);
         Layer {
             loss: Vec::new(),
             blobs: Vec::new(),
@@ -171,14 +171,15 @@ pub trait ILayer {
     /// [2]: ./type.ReadBlob.html
     /// [3]: ./type.WriteBlob.html
     /// [3]: #method.forward_cpu
+    #[allow(map_clone)]
     fn forward(&self, bottom: &[ArcLock<HeapBlob>], top: &mut Vec<ArcLock<HeapBlob>>) -> f32 {
         // Lock();
         // Reshape(bottom, top); // Reshape the layer to fit top & bottom blob
         let mut loss = 0f32;
 
         let btm: Vec<_> = bottom.iter().map(|b| b.read().unwrap()).collect();
         // let tp: Vec<_> = top.iter().map(|b| b.write().unwrap()).collect();
-        let tp_ref = top.iter().map(|t| t.clone()).collect::<Vec<_>>();
+        let tp_ref = top.iter().cloned().collect::<Vec<_>>();
         let mut tp = &mut tp_ref.iter().map(|b| b.write().unwrap()).collect::<Vec<_>>();
         let mut tpo = &mut tp.iter_mut().map(|a| a).collect::<Vec<_>>();
         self.forward_cpu(&btm, tpo);
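
The cloning dance in this hunk exists because top holds ArcLock<HeapBlob> handles: the Arc handles are cloned into a locally owned Vec (tp_ref) so that the write guards collected next can borrow from it rather than from the caller's vector. The same pattern in miniature, with plain std types standing in for Leaf's aliases:

    use std::sync::{Arc, RwLock};

    // Miniature of the forward() locking pattern: clone the Arc handles,
    // then hold every write guard at once so all top blobs can be filled.
    let top: Vec<Arc<RwLock<Vec<f32>>>> = vec![Arc::new(RwLock::new(vec![0.0; 4]))];
    let handles: Vec<_> = top.iter().cloned().collect();
    let mut guards: Vec<_> = handles.iter().map(|b| b.write().unwrap()).collect();
    for guard in guards.iter_mut() {
        guard[0] = 1.0; // write through the RwLockWriteGuard
    }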
@@ -249,7 +250,7 @@ impl fmt::Debug for ILayer {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 /// Layer Configuration Struct
 pub struct LayerConfig {
     /// The name of the Layer
@@ -331,7 +332,7 @@ impl LayerConfig {
 }
 
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 /// Specifies training configuration for a weight blob.
 pub struct WeightConfig {
     /// The name of the weight blob -- useful for sharing weights among
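
Taken together, the layer.rs changes swap a borrowed configuration (Box<&'a LayerConfig>, a boxed reference that added indirection without ownership) for an owned clone, which is what the new Clone derive on LayerConfig enables. The pattern in isolation, as a standalone sketch rather than Leaf code:

    // Before: struct Layer<'a> { config: Box<&'a Config> } tied the Layer
    // to the lifetime of the caller's config. After: clone and own it.
    #[derive(Clone)]
    struct Config {
        name: String,
    }

    struct OwnedLayer {
        config: Box<Config>,
    }

    fn from_config(config: &Config) -> OwnedLayer {
        // The clone severs the borrow, so no lifetime parameter is needed.
        OwnedLayer { config: Box::new(config.clone()) }
    }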
4 changes: 4 additions & 0 deletions src/layers/mod.rs
@@ -51,6 +51,10 @@
 //!
 //! [2]: https://en.wikipedia.org/wiki/Activation_function
 //! [3]: ../layer/index.html
+
+/// Implement [ILayer][1] for [activation layers][2].
+/// [1]: ./layer/trait.ILayer.html
+/// [2]: ./layers/activation/index.html
 macro_rules! impl_neuron_layer {
     () => (
         fn exact_num_top_blobs(&self) -> usize { 1 }
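
The macro is invoked from inside an ILayer implementation, where it stamps out the shared defaults for neuron (activation) layers; only exact_num_top_blobs is visible in this hunk. A sketch of the intended shape, with Sigmoid as a placeholder type that is not part of this diff and the remaining required ILayer methods elided:

    struct Sigmoid;

    impl ILayer for Sigmoid {
        // Generates the common neuron-layer methods,
        // e.g. exact_num_top_blobs() -> 1.
        impl_neuron_layer!();

        // Layer-specific methods (such as the forward pass)
        // are still written out by hand.
    }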
16 changes: 12 additions & 4 deletions src/math.rs
@@ -1,10 +1,18 @@
-use rblas::Axpy;
-use rblas::Dot;
+use rblas::*;
 
+pub fn leaf_cpu_axpy(alpha: &f32, x: &[f32], y: &mut Vec<f32>) {
+    Axpy::axpy(alpha, x, y);
+}
+
+pub fn leaf_cpu_axpby(alpha: &f32, x: &[f32], beta: &f32, y: &mut Vec<f32>) {
+    leaf_cpu_scal(beta, y);
+    leaf_cpu_axpy(alpha, x, y);
+}
+
 pub fn leaf_cpu_dot(x: &[f32], y: &[f32]) -> f32 {
     Dot::dot(x, y)
 }
 
-pub fn leaf_cpu_axpy(alpha: &f32, x: &[f32], y: &mut Vec<f32>) {
-    Axpy::axpy(alpha, x, y);
+pub fn leaf_cpu_scal(alpha: &f32, x: &mut Vec<f32>) {
+    Scal::scal(alpha, x)
 }
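
The new leaf_cpu_axpby composes the two BLAS primitives: leaf_cpu_scal first scales y in place by beta, then leaf_cpu_axpy adds alpha * x, giving y <- alpha * x + beta * y. A small worked example of the new helper (assuming rblas is linked as above):

    // y <- alpha * x + beta * y
    let x = vec![1.0f32, 2.0, 3.0];
    let mut y = vec![10.0f32, 10.0, 10.0];
    leaf_cpu_axpby(&2.0, &x, &0.5, &mut y);
    // scal: y = [5.0, 5.0, 5.0]; then axpy: y = [7.0, 9.0, 11.0]
    assert_eq!(y, vec![7.0, 9.0, 11.0]);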
32 changes: 21 additions & 11 deletions src/network.rs
@@ -62,12 +62,12 @@ use phloem::Blob;
 /// A Network is usually used together with a [Solver][6] to optimize the network's weights.
 ///
 /// [6]: ../solver/struct.Solver.html
-pub struct Network<'a> {
+pub struct Network {
     /// Identifies the Network
     ///
     /// The name is mainly used for logging purposes.
     pub name: String,
-    layers: Vec<Layer<'a>>,
+    layers: Vec<Layer>,
     layer_names: Vec<String>,
     layer_names_index: HashMap<String, usize>,
     layer_need_backwards: Vec<bool>,
@@ -114,8 +114,8 @@ pub struct Network<'a> {
     weights_weight_decay: Vec<Option<f32>>,
 }
 
-impl<'a> Default for Network<'a> {
-    fn default() -> Network<'a> {
+impl Default for Network {
+    fn default() -> Network {
         Network {
             name: "".to_owned(),
             layers: vec![],
@@ -159,7 +159,7 @@ impl<'a> Default for Network<'a> {
     }
 }
 
-impl<'a> Network<'a> {
+impl Network {
     /// Creates a Network from a [NetworkConfig][1].
     /// [1]: ./struct.NetworkConfig.html
     ///
@@ -183,12 +183,12 @@ impl<'a> Network<'a> {
     /// to be executed for each blob and layer.
     ///
     /// [1]: ./struct.NetworkConfig.html
-    fn init(&mut self, in_config: &'a NetworkConfig) {
+    fn init(&mut self, in_config: &NetworkConfig) {
         let config = in_config.clone();
         let available_blobs = &mut HashSet::new();
         let blob_name_to_idx = &mut HashMap::<String, usize>::new();
         for (input_id, _) in config.inputs.iter().enumerate() {
-            self.append_top(config,
+            self.append_top(&config,
                             None,
                             input_id,
                             Some(available_blobs),
@@ -198,7 +198,7 @@ impl<'a> Network<'a> {
         self.resize_vecs(config.layers.len());
 
         for (layer_id, _) in config.inputs.iter().enumerate() {
-            self.init_layer(layer_id, config, available_blobs, blob_name_to_idx);
+            self.init_layer(layer_id, &config, available_blobs, blob_name_to_idx);
         }
 
         // Go through the net backwards to determine which blobs contribute to the
@@ -259,7 +259,7 @@ impl<'a> Network<'a> {
     /// [4]: ../layers/index.html
     fn init_layer(&mut self,
                   layer_id: usize,
-                  config: &'a NetworkConfig,
+                  config: &NetworkConfig,
                   available_blobs: &mut HashSet<String>,
                   blob_name_to_idx: &mut HashMap<String, usize>) {
         // Caffe
@@ -868,9 +868,19 @@ impl<'a> Network<'a> {
     pub fn learnable_weights(&self) -> &Vec<ArcLock<HeapBlob>> {
         &self.learnable_weights
     }
+
+    #[allow(missing_docs)]
+    pub fn weights_weight_decay(&self) -> &Vec<Option<f32>> {
+        &self.weights_weight_decay
+    }
+
+    #[allow(missing_docs)]
+    pub fn weights_lr(&self) -> &Vec<Option<f32>> {
+        &self.weights_lr
+    }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 /// Defines the configuration of a network.
 ///
 /// TODO: [DOC] When and why would you use this?
@@ -959,7 +969,7 @@ impl NetworkConfig {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 /// Defines the state of a network.
 pub struct NetworkState {
     /// Defines the current mode of the network.
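
The two new accessors exist so that a Solver can scale its global learning rate and weight decay per weight blob. The solver files themselves are not rendered on this page, but the arithmetic these getters enable is standard SGD with L2 decay: for blob i, diff <- diff + lambda * data followed by data <- data - eta * diff, with eta = base_lr * weights_lr()[i] and lambda = base_weight_decay * weights_weight_decay()[i]. An illustrative sketch on plain buffers, using the math.rs helpers and skipping the ArcLock plumbing (names below are for illustration only, not the committed solver API):

    // One SGD step for a single weight buffer.
    fn sgd_update(data: &mut Vec<f32>, diff: &mut Vec<f32>, local_lr: f32, local_decay: f32) {
        // L2 weight decay: diff <- diff + local_decay * data
        let data_snapshot = data.clone();
        leaf_cpu_axpy(&local_decay, &data_snapshot, diff);
        // Gradient step: data <- data - local_lr * diff
        let grad = diff.clone();
        leaf_cpu_axpy(&(-local_lr), &grad, data);
    }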
[Diff for the remaining 6 changed files not shown.]
