
Commit

fix some compilation error
cksac committed Jun 7, 2024
1 parent b5f77d0 commit 5111c5a
Showing 22 changed files with 155 additions and 126 deletions.
10 changes: 6 additions & 4 deletions rai-core/src/error.rs
@@ -24,10 +24,6 @@ impl<T> From<StdResult<T>> for RaiResult<T> {
 }
 
 impl<T> RaiResult<T> {
-    pub fn from_val(val: T) -> Self {
-        RaiResult(StdResult::Ok(val))
-    }
-
     pub fn as_ref(&self) -> RaiResult<&T> {
         RaiResult(self.0.as_ref().map_err(|e| e.clone()))
     }
@@ -91,3 +87,9 @@ impl<A, V: FromIterator<A>> FromIterator<RaiResult<A>> for RaiResult<V> {
         RaiResult(Ok(vec.into_iter().collect()))
     }
 }
+
+impl<V: FromIterator<Tensor>> FromIterator<Tensor> for RaiResult<V> {
+    fn from_iter<I: IntoIterator<Item = Tensor>>(iter: I) -> Self {
+        RaiResult(Ok(iter.into_iter().collect()))
+    }
+}
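
The `FromIterator<Tensor>` impl added here is what lets the op files below turn a `Vec<Tensor>` into a `RaiResult` via `.into_iter().collect()`. A minimal self-contained analogue of the pattern — `Wrapper` and `Item` are stand-in names, not rai-core types:

```rust
// Stand-in analogue of the impl above: a result wrapper that can be
// collected directly from an iterator of plain items.
#[derive(Debug)]
struct Wrapper<T>(Result<T, String>);

#[derive(Debug, Clone)]
struct Item(u32);

impl<V: FromIterator<Item>> FromIterator<Item> for Wrapper<V> {
    fn from_iter<I: IntoIterator<Item = Item>>(iter: I) -> Self {
        // Infallible path: every yielded element lands inside Ok.
        Wrapper(Ok(iter.into_iter().collect()))
    }
}

fn main() {
    // Mirrors `vec![cotangent_x].into_iter().collect()` in the op files below.
    let w: Wrapper<Vec<Item>> = vec![Item(1), Item(2)].into_iter().collect();
    println!("{:?}", w);
}
```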
2 changes: 1 addition & 1 deletion rai-core/src/hlops/chunk.rs
@@ -22,7 +22,7 @@ pub fn chunk(x: impl TryAsTensor, chunks: usize, dim: impl Dim) -> RaiResult<Vec
             tensors.push(tensor);
             sum_chunk_size += chunk_size
         }
-        RaiResult::from_val(tensors)
+        Ok(tensors).into()
     }
 }
 
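With `RaiResult::from_val` deleted in error.rs above, call sites like this one switch to `Ok(v).into()`, which goes through the `From<StdResult<T>>` impl visible in error.rs's first hunk header. A self-contained sketch of that conversion, with a simplified `String` error type standing in for rai-core's:

```rust
// Simplified wrapper + From impl; rai-core's actual error type differs.
#[derive(Debug)]
struct RaiResult<T>(Result<T, String>);

impl<T> From<Result<T, String>> for RaiResult<T> {
    fn from(r: Result<T, String>) -> Self {
        RaiResult(r)
    }
}

fn chunk_stub() -> RaiResult<Vec<u32>> {
    let tensors = vec![1, 2, 3];
    Ok(tensors).into() // the pattern replacing RaiResult::from_val(tensors)
}

fn main() {
    println!("{:?}", chunk_stub());
}
```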
9 changes: 5 additions & 4 deletions rai-core/src/hlops/dropout.rs
@@ -1,9 +1,10 @@
-use crate::Tensor;
+use crate::{RaiResult, Tensor, TryAsTensor};
 
 #[track_caller]
-pub fn dropout(input: &Tensor, p: f32) -> Tensor {
+pub fn dropout(input: impl TryAsTensor, p: f32) -> RaiResult<Tensor> {
+    let input = crate::try_get! { input.try_as_tensor() };
     assert!((0.0..1.0).contains(&p));
-    let r = input.rand_like();
+    let r = crate::try_get! { input.rand_like() };
     let scale = 1.0 / (1.0 - p);
     let mask = r.ge(r.full_like(p)).to_dtype(r) * scale;
     input * mask
@@ -12,7 +13,7 @@ pub fn dropout(input: &Tensor, p: f32) -> Tensor {
 crate::impl_op! {
     #[inline]
     #[track_caller]
-    pub fn dropout(&self, p: f32) -> Tensor {
+    pub fn dropout(&self, p: f32) -> RaiResult<Tensor> {
         dropout(self, p)
     }
 }
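
This file shows the edit pattern repeated across the hlops below: parameters widen from `&Tensor` to `impl TryAsTensor`, fallible intermediates are unwrapped with `crate::try_get!`, and returns become `RaiResult<Tensor>`. A rough, self-contained guess at how such a macro works (not rai-core's actual definition):

```rust
// Hypothetical sketch of the `try_get!` pattern: unwrap a fallible
// intermediate, or early-return its error inside the result wrapper.
#[derive(Debug)]
struct RaiResult<T>(Result<T, String>);

macro_rules! try_get {
    ($e:expr) => {
        match $e.0 {
            Ok(v) => v,
            Err(err) => return RaiResult(Err(err)),
        }
    };
}

fn rand_like(x: f32) -> RaiResult<f32> {
    RaiResult(Ok(x * 0.5)) // stand-in for a fallible tensor op
}

fn dropout(input: RaiResult<f32>, p: f32) -> RaiResult<f32> {
    let input = try_get! { input };
    assert!((0.0..1.0).contains(&p));
    let r = try_get! { rand_like(input) };
    let scale = 1.0 / (1.0 - p);
    // Scalar stand-in for `input * mask`; the real op builds a random mask.
    RaiResult(Ok(input * r * scale))
}

fn main() {
    println!("{:?}", dropout(RaiResult(Ok(1.0)), 0.5));
}
```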
9 changes: 5 additions & 4 deletions rai-core/src/hlops/flatten.rs
@@ -1,4 +1,4 @@
-use crate::{Dim, Shape, Tensor};
+use crate::{Dim, RaiResult, Shape, Tensor, TryAsTensor};
 use std::{
     fmt::Debug,
     ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive},
@@ -162,7 +162,8 @@ impl FlattenArgs for RangeInclusive<i32> {
 }
 
 #[track_caller]
-pub fn flatten<T: FlattenArgs>(x: &Tensor, args: T) -> Tensor {
+pub fn flatten<T: FlattenArgs>(x: impl TryAsTensor, args: T) -> RaiResult<Tensor> {
+    let x = crate::try_get! { x.try_as_tensor() };
     if x.rank() == 0 {
         return x.reshape([1]);
     }
@@ -176,14 +177,14 @@ pub fn flatten<T: FlattenArgs>(x: &Tensor, args: T) -> Tensor {
         }
         x.reshape(dst_dim)
     } else {
-        x.clone()
+        x.clone().into()
     }
 }
 
 impl Tensor {
     #[inline]
     #[track_caller]
-    pub fn flatten<T: FlattenArgs>(&self, args: T) -> Tensor {
+    pub fn flatten<T: FlattenArgs>(&self, args: T) -> RaiResult<Tensor> {
         flatten(self, args)
     }
 }
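
Shape-wise, `flatten` collapses the dims selected by `FlattenArgs` into their product and reshapes. A stand-alone illustration of that computation on plain shape vectors:

```rust
// Collapse dims in [start, end] into their product, as the reshape above does.
fn flatten_shape(shape: &[usize], start: usize, end: usize) -> Vec<usize> {
    let mut dst = shape[..start].to_vec();
    dst.push(shape[start..=end].iter().product()); // the collapsed span
    dst.extend_from_slice(&shape[end + 1..]);
    dst
}

fn main() {
    // Flattening dims 1..=2 of [2, 3, 4, 5] gives [2, 12, 5].
    assert_eq!(flatten_shape(&[2, 3, 4, 5], 1, 2), vec![2, 12, 5]);
}
```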
6 changes: 3 additions & 3 deletions rai-core/src/hlops/from_safetensor.rs
@@ -1,4 +1,4 @@
-use crate::{AsDevice, Tensor};
+use crate::{AsDevice, RaiResult, Tensor};
 use half::{bf16, f16};
 use safetensors::tensor::TensorView;
 use std::slice::from_raw_parts;
@@ -48,7 +48,7 @@ fn convert_slice<T: Clone>(data: &[u8]) -> Vec<T> {
 ///
 /// A `Tensor` created from the `safetensors::TensorView`.
 #[track_caller]
-pub fn from_safetensor(view: &TensorView, device: impl AsDevice) -> Tensor {
+pub fn from_safetensor(view: &TensorView, device: impl AsDevice) -> RaiResult<Tensor> {
     let shape = view.shape();
     let data = view.data();
     match view.dtype() {
@@ -94,7 +94,7 @@ impl Tensor {
     /// see [`hlops::from_safetensor`](hlops::from_safetensor)
     #[inline]
     #[track_caller]
-    pub fn from_safetensor(view: &TensorView, device: impl AsDevice) -> Tensor {
+    pub fn from_safetensor(view: &TensorView, device: impl AsDevice) -> RaiResult<Tensor> {
         from_safetensor(view, device)
     }
 }
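
The `convert_slice` helper referenced in the hunk header reinterprets the view's raw bytes as typed elements (via `from_raw_parts` in the real code). A safe, self-contained equivalent for `f32`, assuming little-endian safetensors data:

```rust
// Safe sketch of the byte -> element conversion: chunk the raw buffer
// and decode each element from little-endian bytes.
fn convert_slice_f32(data: &[u8]) -> Vec<f32> {
    data.chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}

fn main() {
    let bytes = 1.5f32.to_le_bytes();
    assert_eq!(convert_slice_f32(&bytes), vec![1.5]);
}
```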
9 changes: 5 additions & 4 deletions rai-core/src/hlops/mean.rs
@@ -1,15 +1,16 @@
-use crate::{ops::ReduceArgs, Shape, Tensor};
+use crate::{ops::ReduceArgs, RaiResult, Shape, Tensor, TryAsTensor};
 
 #[track_caller]
-pub fn mean<T: ReduceArgs>(x: &Tensor, args: T) -> Tensor {
+pub fn mean<T: ReduceArgs>(x: impl TryAsTensor, args: T) -> RaiResult<Tensor> {
+    let x = crate::try_get! { x.try_as_tensor() };
     let elem_count = x.dims_elem_count(args.dims()) as f64;
     x.sum(args) / elem_count
 }
 
-impl Tensor {
+crate::impl_op! {
     #[inline]
     #[track_caller]
-    pub fn mean<T: ReduceArgs>(&self, args: T) -> Tensor {
+    pub fn mean<T: ReduceArgs>(&self, args: T) -> RaiResult<Tensor> {
         mean(self, args)
     }
 }
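
Numerically, `mean` is just the sum over the reduced dims divided by their element count; a scalar-slice sketch:

```rust
// mean = sum / elem_count over the reduced dims, here a flat slice.
fn mean(xs: &[f64]) -> f64 {
    let elem_count = xs.len() as f64;
    xs.iter().sum::<f64>() / elem_count
}

fn main() {
    assert_eq!(mean(&[1.0, 2.0, 3.0]), 2.0);
}
```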
9 changes: 5 additions & 4 deletions rai-core/src/hlops/squeeze.rs
@@ -1,7 +1,8 @@
-use crate::{Dims, Shape, Tensor};
+use crate::{Dims, RaiResult, Shape, Tensor, TryAsTensor};
 
 #[track_caller]
-pub fn squeeze(x: &Tensor, dims: impl Dims<Vec<usize>>) -> Tensor {
+pub fn squeeze(x: impl TryAsTensor, dims: impl Dims<Vec<usize>>) -> RaiResult<Tensor> {
+    let x = crate::try_get! { x.try_as_tensor() };
     let dims = x.dims(dims);
     let mut out_shape = Vec::new();
     for (i, s) in x.shape().iter().enumerate() {
@@ -12,10 +13,10 @@ pub fn squeeze(x: &Tensor, dims: impl Dims<Vec<usize>>) -> Tensor {
     x.reshape(out_shape)
 }
 
-impl Tensor {
+crate::impl_op! {
     #[inline]
     #[track_caller]
-    pub fn squeeze(&self, dims: impl Dims<Vec<usize>>) -> Tensor {
+    pub fn squeeze(&self, dims: impl Dims<Vec<usize>>) -> RaiResult<Tensor> {
         squeeze(self, dims)
     }
 }
9 changes: 5 additions & 4 deletions rai-core/src/hlops/unsqueeze.rs
@@ -1,7 +1,8 @@
-use crate::{Dim, Shape, Tensor};
+use crate::{Dim, RaiResult, Shape, Tensor, TryAsTensor};
 
 #[track_caller]
-pub fn unsqueeze(x: &Tensor, d: impl Dim) -> Tensor {
+pub fn unsqueeze(x: impl TryAsTensor, d: impl Dim) -> RaiResult<Tensor> {
+    let x = crate::try_get! { x.try_as_tensor() };
     let is_negative = d.is_negative();
     let dim = x.dim(d);
     let mut shape = x.shape().to_vec();
@@ -13,10 +14,10 @@ pub fn unsqueeze(x: &Tensor, d: impl Dim) -> Tensor {
     x.reshape(shape)
 }
 
-impl Tensor {
+crate::impl_op! {
     #[inline]
     #[track_caller]
-    pub fn unsqueeze(&self, d: impl Dim) -> Tensor {
+    pub fn unsqueeze(&self, d: impl Dim) -> RaiResult<Tensor> {
         unsqueeze(self, d)
     }
 }
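
Both `squeeze` and `unsqueeze` above reduce to a reshape computed from the current shape: `unsqueeze` inserts a size-1 axis at the given dim, `squeeze` drops the listed axes whose size is 1. A stand-alone illustration on plain shape vectors:

```rust
// unsqueeze: insert a size-1 axis, then reshape to the new shape.
fn unsqueeze(shape: &[usize], dim: usize) -> Vec<usize> {
    let mut out = shape.to_vec();
    out.insert(dim, 1);
    out
}

// squeeze: drop axes that are listed AND have size 1, keep everything else.
fn squeeze(shape: &[usize], dims: &[usize]) -> Vec<usize> {
    shape
        .iter()
        .enumerate()
        .filter(|(i, s)| !(dims.contains(i) && **s == 1))
        .map(|(_, s)| *s)
        .collect()
}

fn main() {
    assert_eq!(unsqueeze(&[2, 3], 0), vec![1, 2, 3]);
    assert_eq!(squeeze(&[1, 2, 1, 3], &[0, 2]), vec![2, 3]);
}
```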
9 changes: 5 additions & 4 deletions rai-core/src/hlops/var.rs
@@ -1,5 +1,5 @@
 use crate::ops::ReduceArgs;
-use crate::{Dims, Shape, Tensor};
+use crate::{Dims, RaiResult, Shape, Tensor, TryAsTensor};
 
 pub trait VarArgs: ReduceArgs {
     fn ddof(&self) -> usize {
@@ -30,17 +30,18 @@ where
 }
 
 #[track_caller]
-pub fn var<T: VarArgs>(x: &Tensor, args: T) -> Tensor {
+pub fn var<T: VarArgs>(x: impl TryAsTensor, args: T) -> RaiResult<Tensor> {
+    let x = crate::try_get! { x.try_as_tensor() };
     let elem_count = x.dims_elem_count(args.dims());
     let m = x.mean((args.dims(), args.keep_dim()));
     let s = (x - m).square().sum((args.dims(), args.keep_dim()));
     s / (elem_count - args.ddof()) as f32
 }
 
-impl Tensor {
+crate::impl_op! {
     #[inline]
     #[track_caller]
-    pub fn var<T: VarArgs>(&self, args: T) -> Tensor {
+    pub fn var<T: VarArgs>(&self, args: T) -> RaiResult<Tensor> {
         var(self, args)
     }
 }
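
The formula implemented above is `var = sum((x - mean)^2) / (n - ddof)`; a scalar sketch that checks both population (`ddof = 0`) and sample (`ddof = 1`) variance:

```rust
// s / (n - ddof), where s = sum((x - mean)^2) over the reduced dims.
fn var(xs: &[f32], ddof: usize) -> f32 {
    let n = xs.len() as f32;
    let mean = xs.iter().sum::<f32>() / n;
    let s: f32 = xs.iter().map(|x| (x - mean).powi(2)).sum();
    s / (n - ddof as f32)
}

fn main() {
    let xs = [1.0, 2.0, 3.0, 4.0];
    assert_eq!(var(&xs, 0), 1.25); // population variance
    assert!((var(&xs, 1) - 5.0 / 3.0).abs() < 1e-6); // sample variance
}
```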
12 changes: 8 additions & 4 deletions rai-core/src/ops/abs.rs
@@ -15,18 +15,22 @@ impl Op for Abs {
     }
 
     #[tracing::instrument(ret(level = Level::TRACE))]
-    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
+    fn jvp(&self, _output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> RaiResult<Tensor> {
         let x = &primals[0];
         let tangent_x = &tangents[0];
-        let t = x.sign();
         tangent_x * x.sign()
     }
 
     #[tracing::instrument(ret(level = Level::TRACE))]
-    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
+    fn vjp(
+        &self,
+        _output: &Tensor,
+        primals: &[Tensor],
+        cotangent: &Tensor,
+    ) -> RaiResult<Vec<Tensor>> {
         let x = &primals[0];
         let cotangent_x = cotangent * x.sign();
-        vec![cotangent_x]
+        vec![cotangent_x].into_iter().collect()
     }
 }
 
11 changes: 8 additions & 3 deletions rai-core/src/ops/add.rs
@@ -15,17 +15,22 @@ impl Op for Add {
     }
 
     #[tracing::instrument(ret(level = Level::TRACE))]
-    fn jvp(&self, _output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
+    fn jvp(&self, _output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> RaiResult<Tensor> {
         let tangent_lhs = &tangents[0];
         let tangent_rhs = &tangents[1];
         tangent_lhs + tangent_rhs
     }
 
     #[tracing::instrument(ret(level = Level::TRACE))]
-    fn vjp(&self, _output: &Tensor, _primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
+    fn vjp(
+        &self,
+        _output: &Tensor,
+        _primals: &[Tensor],
+        cotangent: &Tensor,
+    ) -> RaiResult<Vec<Tensor>> {
         let cotangent_lhs = cotangent.clone();
         let cotangent_rhs = cotangent.clone();
-        vec![cotangent_lhs, cotangent_rhs]
+        Ok(vec![cotangent_lhs, cotangent_rhs]).into()
     }
 }
 
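From the impls in this commit, the `Op` trait's differentiation hooks now evidently return `RaiResult` instead of bare values. A sketch of the implied signatures, reconstructed from these impls rather than copied from rai-core's trait definition:

```rust
// Stand-in stubs so the sketch compiles on its own; in rai-core these are
// the real `Tensor` and `RaiResult` types.
#[derive(Debug, Clone)]
struct Tensor;
struct RaiResult<T>(Result<T, String>);

// Reconstructed from the impls in this diff; the real trait likely carries
// more methods (shape/dtype inference, eval hooks, etc.).
trait Op {
    // Forward mode: push tangents through, producing the output's tangent.
    fn jvp(&self, output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> RaiResult<Tensor>;

    // Reverse mode: pull the output cotangent back to one cotangent per primal.
    fn vjp(
        &self,
        output: &Tensor,
        primals: &[Tensor],
        cotangent: &Tensor,
    ) -> RaiResult<Vec<Tensor>>;
}

fn main() {}
```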
11 changes: 8 additions & 3 deletions rai-core/src/ops/arange.rs
@@ -32,14 +32,19 @@ impl<D: Type> Op for Arange<D> {
 
     #[tracing::instrument(ret(level = Level::TRACE))]
     #[inline]
-    fn jvp(&self, output: &Tensor, _primals: &[Tensor], _tangents: &[Tensor]) -> Tensor {
+    fn jvp(&self, output: &Tensor, _primals: &[Tensor], _tangents: &[Tensor]) -> RaiResult<Tensor> {
         output.ones_like()
     }
 
     #[tracing::instrument(ret(level = Level::TRACE))]
     #[inline]
-    fn vjp(&self, _output: &Tensor, _primals: &[Tensor], _cotangent: &Tensor) -> Vec<Tensor> {
-        vec![]
+    fn vjp(
+        &self,
+        _output: &Tensor,
+        _primals: &[Tensor],
+        _cotangent: &Tensor,
+    ) -> RaiResult<Vec<Tensor>> {
+        Ok(vec![]).into()
     }
 }
 
11 changes: 8 additions & 3 deletions rai-core/src/ops/argmax.rs
@@ -33,13 +33,18 @@ impl Op for ArgMax {
     }
 
     #[tracing::instrument(ret(level = Level::TRACE))]
-    fn jvp(&self, output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
+    fn jvp(&self, output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> RaiResult<Tensor> {
         output.zeros_like()
     }
 
     #[tracing::instrument(ret(level = Level::TRACE))]
-    fn vjp(&self, output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
-        vec![]
+    fn vjp(
+        &self,
+        output: &Tensor,
+        primals: &[Tensor],
+        cotangent: &Tensor,
+    ) -> RaiResult<Vec<Tensor>> {
+        Ok(vec![]).into()
     }
 }
 
11 changes: 8 additions & 3 deletions rai-core/src/ops/argmin.rs
@@ -34,13 +34,18 @@ impl Op for ArgMin {
     }
 
     #[tracing::instrument(ret(level = Level::TRACE))]
-    fn jvp(&self, output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
+    fn jvp(&self, output: &Tensor, primals: &[Tensor], tangents: &[Tensor]) -> RaiResult<Tensor> {
        output.zeros_like()
     }
 
     #[tracing::instrument(ret(level = Level::TRACE))]
-    fn vjp(&self, output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
-        vec![]
+    fn vjp(
+        &self,
+        output: &Tensor,
+        primals: &[Tensor],
+        cotangent: &Tensor,
+    ) -> RaiResult<Vec<Tensor>> {
+        Ok(vec![]).into()
     }
 }
 
11 changes: 8 additions & 3 deletions rai-core/src/ops/avg_pool1d.rs
@@ -36,12 +36,17 @@ impl Op for AvgPool1d {
     }
 
     #[tracing::instrument(ret(level = Level::TRACE))]
-    fn jvp(&self, _output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
+    fn jvp(&self, _output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> RaiResult<Tensor> {
         todo!("jvp for AvgPool1d")
     }
 
     #[tracing::instrument(ret(level = Level::TRACE))]
-    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
+    fn vjp(
+        &self,
+        _output: &Tensor,
+        primals: &[Tensor],
+        cotangent: &Tensor,
+    ) -> RaiResult<Vec<Tensor>> {
         assert_eq!(
             self.kernel_size, self.stride,
             "vjp not supported for avgPool1d if kernel_size != stride"
@@ -50,7 +55,7 @@ impl Op for AvgPool1d {
         let [_n, _c, l] = x.sizes(Before::<3>);
         let cotan_upsampled = cotangent.upsample_nearest1d(l);
         let cotan_x = cotan_upsampled * (1.0f32 / self.kernel_size as f32);
-        vec![cotan_x]
+        vec![cotan_x].into_iter().collect()
     }
 }
 
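The vjp above relies on `kernel_size == stride`: each input element then belongs to exactly one window, and averaging gives it a derivative of `1/kernel_size`, which is what the nearest-neighbor upsample plus scaling computes. A minimal numeric check in plain Rust:

```rust
// Element-wise version of the AvgPool1d vjp (stride == kernel_size):
// upsample the cotangent, scale by 1/kernel_size.
fn avg_pool1d_vjp(cotangent: &[f32], kernel_size: usize, input_len: usize) -> Vec<f32> {
    let mut cotan_x = vec![0.0; input_len];
    for (i, c) in cotangent.iter().enumerate() {
        for j in 0..kernel_size {
            // upsample_nearest1d + (1/k) scaling, written out per element
            cotan_x[i * kernel_size + j] = c / kernel_size as f32;
        }
    }
    cotan_x
}

fn main() {
    // avg_pool1d([1,2,3,4], k=2) = [1.5, 3.5]; d(out)/d(in) = 1/2 each
    assert_eq!(avg_pool1d_vjp(&[1.0, 1.0], 2, 4), vec![0.5; 4]);
    println!("ok");
}
```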
11 changes: 8 additions & 3 deletions rai-core/src/ops/avg_pool2d.rs
@@ -43,12 +43,17 @@ impl Op for AvgPool2d {
     }
 
     #[tracing::instrument(ret(level = Level::TRACE))]
-    fn jvp(&self, _output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
+    fn jvp(&self, _output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> RaiResult<Tensor> {
         todo!("jvp for AvgPool2d")
     }
 
     #[tracing::instrument(ret(level = Level::TRACE))]
-    fn vjp(&self, output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
+    fn vjp(
+        &self,
+        output: &Tensor,
+        primals: &[Tensor],
+        cotangent: &Tensor,
+    ) -> RaiResult<Vec<Tensor>> {
         assert_eq!(
             self.kernel_size, self.stride,
             "vjp not supported for avgPool2d if kernel_size != stride"
@@ -57,7 +62,7 @@ impl Op for AvgPool2d {
         let [_n, _c, h, w] = x.sizes(Before::<4>);
         let cotan_upsampled = cotangent.upsample_nearest2d([h, w]);
         let cotan_x = cotan_upsampled * (1.0f32 / (self.kernel_size.0 * self.kernel_size.1) as f32);
-        vec![cotan_x]
+        vec![cotan_x].into_iter().collect()
     }
 }
 
(The remaining 6 changed files are not shown.)
