-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat/native: Add support for softmax w/ test and benches.
- Loading branch information
Ewan Higgs
committed
Jan 21, 2016
1 parent
892ce8f
commit 14d6d1b
Showing
4 changed files
with
301 additions
and
197 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
#![feature(test)] | ||
#![feature(clone_from_slice)] | ||
|
||
extern crate test; | ||
extern crate collenchyma as co; | ||
extern crate collenchyma_nn as co_nn; | ||
extern crate rand; | ||
|
||
use test::Bencher; | ||
use co::backend::{Backend, BackendConfig}; | ||
use co::frameworks::Native; | ||
use co::framework::IFramework; | ||
use co::tensor::SharedTensor; | ||
use co_nn::*; | ||
|
||
use rand::{thread_rng, Rng}; | ||
|
||
fn backend() -> Backend<Native> { | ||
let framework = Native::new(); | ||
let hardwares = framework.hardwares(); | ||
let backend_config = BackendConfig::new(framework, hardwares); | ||
Backend::new(backend_config).unwrap() | ||
} | ||
|
||
fn arguments<T: IFramework + Clone>(backend: &Backend<T>, size: usize) -> (SharedTensor<f32>, SharedTensor<f32>) { | ||
let mut rng = thread_rng(); | ||
let slice_x = rng.gen_iter::<f32>().take(size).collect::<Vec<f32>>(); | ||
|
||
let mut x = SharedTensor::<f32>::new(backend.device(), &size).unwrap(); | ||
let out = SharedTensor::<f32>::new(backend.device(), &size).unwrap(); | ||
x.get_mut(backend.device()).unwrap().as_mut_native().unwrap().as_mut_slice().clone_from_slice(&slice_x); | ||
(x, out) | ||
} | ||
|
||
fn arguments_grad<T: IFramework + Clone>(backend: &Backend<T>, size: usize) -> (SharedTensor<f32>, SharedTensor<f32>, SharedTensor<f32>) { | ||
let mut rng = thread_rng(); | ||
let slice_x = rng.gen_iter::<f32>().take(size).collect::<Vec<f32>>(); | ||
|
||
let mut x = SharedTensor::<f32>::new(backend.device(), &size).unwrap(); | ||
let mut dx = SharedTensor::<f32>::new(backend.device(), &size).unwrap(); | ||
let dout = SharedTensor::<f32>::new(backend.device(), &size).unwrap(); | ||
x.get_mut(backend.device()).unwrap().as_mut_native().unwrap().as_mut_slice().clone_from_slice(&slice_x); | ||
dx.get_mut(backend.device()).unwrap().as_mut_native().unwrap().as_mut_slice().clone_from_slice(&slice_x); | ||
(x, dx, dout) | ||
} | ||
|
||
#[inline(never)] | ||
fn bench_profile<F: FnMut() -> ()>( | ||
b: &mut Bencher, | ||
mut bench_func: F, | ||
times: usize | ||
) { | ||
b.iter(|| { for _ in 0..times { bench_func(); } }); | ||
} | ||
|
||
#[bench] | ||
fn bench_1000_softmax_100_native(b: &mut Bencher) { | ||
let backend = backend(); | ||
let (mut x, mut out) = arguments(&backend, 100); | ||
let mut func = || { let _ = backend.softmax_plain(&mut x, &mut out); }; | ||
{ func(); bench_profile(b, func, 1000); } | ||
} | ||
|
||
#[bench] | ||
fn bench_10_softmax_10000_native(b: &mut Bencher) { | ||
let backend = backend(); | ||
let (mut x, mut out) = arguments(&backend, 10000); | ||
let mut func = || { let _ = backend.softmax_plain(&mut x, &mut out); }; | ||
{ func(); bench_profile(b, func, 10); } | ||
} | ||
|
||
#[bench] | ||
fn bench_1000_softmax_grad_100_native(b: &mut Bencher) { | ||
let backend = backend(); | ||
let (mut x, mut dx, mut dout) = arguments_grad(&backend, 100); | ||
let mut func = || { let _ = backend.softmax_grad_plain(&mut x, &mut dx, &mut dout); }; | ||
{ func(); bench_profile(b, func, 1000); } | ||
} | ||
|
||
#[bench] | ||
fn bench_10_softmax_grad_10000_native(b: &mut Bencher) { | ||
let backend = backend(); | ||
let (mut x, mut dx, mut dout) = arguments_grad(&backend, 10000); | ||
let mut func = || { let _ = backend.softmax_grad_plain(&mut x, &mut dx, &mut dout); }; | ||
{ func(); bench_profile(b, func, 10); } | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.