
Commit 1f5f6b8
fix/solvers: remove CUDA build flag
Removes the CUDA build flag requirement from the solvers, so the solver and solvers modules are compiled even when the cuda feature is not enabled.
hobofan committed Mar 9, 2016
1 parent 3e78189 commit 1f5f6b8
Showing 2 changed files with 12 additions and 29 deletions.
src/lib.rs (2 changes: 0 additions & 2 deletions)

@@ -118,9 +118,7 @@ extern crate collenchyma_blas as coblas;
 extern crate collenchyma_nn as conn;
 pub mod layer;
 pub mod layers;
-#[cfg(feature="cuda")]
 pub mod solver;
-#[cfg(feature="cuda")]
 pub mod solvers;
 pub mod weight;
 
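For context on the two attributes deleted above: #[cfg(feature = "cuda")] tells the compiler to include the item that follows it only when that Cargo feature is enabled (e.g. cargo build --features cuda), so before this commit the solver and solvers modules simply did not exist in builds without the CUDA feature. A minimal stand-alone illustration of the mechanism; the gpu feature name and both functions are made up for this sketch and are not code from this repository:

// Illustration only; the "gpu" feature and these functions are hypothetical.
// This definition is compiled only when building with `--features gpu`.
#[cfg(feature = "gpu")]
pub fn backend_name() -> &'static str {
    "gpu"
}

// Fallback definition, compiled only when the "gpu" feature is disabled.
#[cfg(not(feature = "gpu"))]
pub fn backend_name() -> &'static str {
    "native"
}

fn main() {
    // Which definition exists is decided at compile time by the enabled features.
    println!("backend: {}", backend_name());
}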
src/solvers/mod.rs (39 changes: 12 additions & 27 deletions)

@@ -35,7 +35,7 @@ use co::{IBackend, MemoryType, SharedTensor};
 use conn::NN;
 use solver::*;
 use layer::*;
-use util::{ArcLock, native_backend, LayerOps, SolverOps};
+use util::*;
 
 trait SGDSolver<SolverB: IBackend + SolverOps<f32>, NetB: IBackend + LayerOps<f32>> : ISolver<SolverB, NetB> {
     fn compute_update_value(&mut self,
@@ -74,11 +74,13 @@ trait SGDSolver<SolverB: IBackend + SolverOps<f32>, NetB: IBackend + LayerOps<f32>> : ISolver<SolverB, NetB> {
 
         let mut result = SharedTensor::<f32>::new(IBackend::device(backend), &1).unwrap();
         match result.add_device(native.device()) { _ => result.sync(native.device()).unwrap() }
-        if let &MemoryType::Native(ref sumsq_result) = result.get(native.device()).unwrap() {
-            let sumsq_diff_slice = sumsq_result.as_slice::<f32>();
-            sumsq_diff += sumsq_diff_slice[0];
-        } else {
-            panic!();
+        match result.get(native.device()).unwrap() {
+            &MemoryType::Native(ref sumsq_result) => {
+                let sumsq_diff_slice = sumsq_result.as_slice::<f32>();
+                sumsq_diff += sumsq_diff_slice[0];
+            },
+            #[cfg(any(feature = "opencl", feature = "cuda"))]
+            _ => {}
         }
     }
     let l2norm_diff = sumsq_diff.sqrt();
@@ -90,13 +92,7 @@ trait SGDSolver<SolverB: IBackend + SolverOps<f32>, NetB: IBackend + LayerOps<f32>> : ISolver<SolverB, NetB> {
               clip_threshold,
               scale_factor);
 
-        let mut scale_shared = SharedTensor::<f32>::new(native.device(), &1).unwrap();
-        if let &mut MemoryType::Native(ref mut scale) = scale_shared.get_mut(native.device()).unwrap() {
-            let scale_slice = scale.as_mut_slice::<f32>();
-            scale_slice[0] = scale_factor;
-        } else {
-            panic!();
-        }
+        let mut scale_shared = native_scalar(scale_factor);
 
         for weight_gradient in net_gradients {
             let mut gradient = weight_gradient.write().unwrap();
@@ -117,13 +113,8 @@ trait SGDSolver<SolverB: IBackend + SolverOps<f32>, NetB: IBackend + LayerOps<f32>> : ISolver<SolverB, NetB> {
         let scale_factor = 1f32 / config.minibatch_size as f32;
         let mut gradient = weight_blob.write().unwrap();
         let native = native_backend();
-        let mut scale_factor_shared = SharedTensor::<f32>::new(native.device(), &1).unwrap();
-        if let &mut MemoryType::Native(ref mut scale) = scale_factor_shared.get_mut(native.device()).unwrap() {
-            let scale_slice = scale.as_mut_slice::<f32>();
-            scale_slice[0] = scale_factor;
-        } else {
-            panic!();
-        }
+
+        let mut scale_factor_shared = native_scalar(scale_factor);
         // self.backend().scal_plain(&scale_factor_shared, &mut gradient).unwrap();
         self.backend().scal(&mut scale_factor_shared, &mut gradient).unwrap();
     }
@@ -141,13 +132,7 @@ trait SGDSolver<SolverB: IBackend + SolverOps<f32>, NetB: IBackend + LayerOps<f32>> : ISolver<SolverB, NetB> {
         match regularization_method {
             RegularizationMethod::L2 => {
                 let native = native_backend();
-                let mut decay_shared = SharedTensor::<f32>::new(native.device(), &1).unwrap();
-                if let &mut MemoryType::Native(ref mut decay) = decay_shared.get_mut(native.device()).unwrap() {
-                    let decay_slice = decay.as_mut_slice::<f32>();
-                    decay_slice[0] = local_decay;
-                } else {
-                    panic!();
-                }
+                let decay_shared = native_scalar(local_decay);
                 let gradient = &mut weight_gradient.write().unwrap();
                 // gradient.regularize_l2(self.backend(), &decay_shared);
                 // backend.axpy_plain(&decay_shared, &self.data, &mut self.diff).unwrap();
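The native_scalar helper that the new lines call is not shown in this diff. Judging from the three blocks it replaces, it wraps the repeated boilerplate of allocating a one-element tensor on the native backend and writing a single f32 into it, which is why most of the deleted lines collapse into single calls. A rough sketch reconstructed from the removed code, assuming the same co::{MemoryType, SharedTensor} and native_backend imports as this module; the real helper in util may differ, for example by being generic over the element type:

// Sketch only, reconstructed from the code removed in this commit; not the
// actual implementation of util::native_scalar.
fn native_scalar(value: f32) -> SharedTensor<f32> {
    let native = native_backend();
    // Allocate a one-element tensor on the native (CPU) device.
    let mut shared = SharedTensor::<f32>::new(native.device(), &1).unwrap();
    // Write the scalar into native memory, mirroring the removed if-let blocks.
    if let &mut MemoryType::Native(ref mut mem) = shared.get_mut(native.device()).unwrap() {
        mem.as_mut_slice::<f32>()[0] = value;
    } else {
        panic!("freshly allocated native tensor should expose native memory");
    }
    shared
}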
