Skip to content

Commit

Permalink
Merge 55f34eb into 1cc9b8a
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Jul 1, 2024
2 parents 1cc9b8a + 55f34eb commit 2e334d4
Showing 1 changed file with 88 additions and 53 deletions.
141 changes: 88 additions & 53 deletions core/src/ops/matmul/lir_unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,19 @@ impl ProtoFusedSpec {
output: &Tensor,
) -> FusedSpec<'t> {
let fs = match self {
ProtoFusedSpec::AddMatMul { geo, a, b, packing } => {
let mut a = inputs[*a].view();
unsafe {
geo.c_to_a_axis_mapping.translate_view(output_coords, &mut a);
}
let a =
a.as_slice::<Opaque>().unwrap()[0].downcast_ref::<Box<dyn MMMInput>>().unwrap();
let mut b = inputs[*b].view();
unsafe {
geo.c_to_b_axis_mapping.translate_view(output_coords, &mut b);
}
let b =
b.as_slice::<Opaque>().unwrap()[0].downcast_ref::<Box<dyn MMMInput>>().unwrap();
ProtoFusedSpec::AddMatMul { geo, a, b, packing } => unsafe {
let mut a = inputs.get_unchecked(*a).view();
geo.c_to_a_axis_mapping.translate_view(output_coords, &mut a);
let a = a.as_slice_unchecked::<Opaque>()[0]
.downcast_ref::<Box<dyn MMMInput>>()
.unwrap_unchecked();
let mut b = inputs.get_unchecked(*b).view();
geo.c_to_b_axis_mapping.translate_view(output_coords, &mut b);
let b = b.as_slice_unchecked::<Opaque>()[0]
.downcast_ref::<Box<dyn MMMInput>>()
.unwrap_unchecked();
FusedSpec::AddMatMul { a: &**a, b: &**b, packing: *packing }
}
},
ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
ProtoFusedSpec::BinPerRow(v, op, map) => {
Expand Down Expand Up @@ -96,21 +94,26 @@ impl ProtoFusedSpec {
}
}

#[inline]
pub fn resolve_trivial<'t>(
&'t self,
inputs: &'t [TValue],
output: &mut Tensor,
) -> FusedSpec<'t> {
let fs = match self {
ProtoFusedSpec::AddMatMul { a, b, packing, .. } => {
let a = &inputs[*a];
let b = &inputs[*b];
let a =
a.to_scalar::<Opaque>().unwrap().downcast_ref::<Box<dyn MMMInput>>().unwrap();
let b =
b.to_scalar::<Opaque>().unwrap().downcast_ref::<Box<dyn MMMInput>>().unwrap();
ProtoFusedSpec::AddMatMul { a, b, packing, .. } => unsafe {
let a = &inputs.get_unchecked(*a);
let b = &inputs.get_unchecked(*b);
let a = a
.to_scalar_unchecked::<Opaque>()
.downcast_ref::<Box<dyn MMMInput>>()
.unwrap_unchecked();
let b = b
.to_scalar_unchecked::<Opaque>()
.downcast_ref::<Box<dyn MMMInput>>()
.unwrap_unchecked();
FusedSpec::AddMatMul { a: &**a, b: &**b, packing: *packing }
}
},
ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
ProtoFusedSpec::BinPerRow(v, op, _) => {
Expand Down Expand Up @@ -282,59 +285,91 @@ impl Op for LirMatMulUnary {

impl EvalOp for LirMatMulUnary {
fn is_stateless(&self) -> bool {
true
false
}

fn eval_with_session(
fn state(
&self,
session: &SessionState,
_session: &mut SessionState,
_node_id: usize,
) -> TractResult<Option<Box<dyn OpState>>> {
Ok(Some(Box::new(LirMatMulUnaryState::default())))
}
}

#[derive(Clone, Debug, Default)]
struct LirMatMulUnaryState(Vec<FusedSpec<'static>>);

impl OpState for LirMatMulUnaryState {
fn eval(
&mut self,
session: &mut SessionState,
op: &dyn Op,
inputs: TVec<TValue>,
) -> TractResult<TVec<TValue>> {
let op = op.downcast_ref::<LirMatMulUnary>().unwrap();
unsafe {
let mut cell = session.cached_mmm_scratch_space.borrow_mut();
if !cell
.as_ref()
.map(|scratch| self.mmm.can_use_scratch_space(&**scratch))
.map(|scratch| op.mmm.can_use_scratch_space(&**scratch))
.unwrap_or(false)
{
*cell = None
}
let scratch = cell.get_or_insert_with(|| self.mmm.allocate_scratch_space());

if self.trivial_path {
let c_shape = self.c_fact.shape.as_concrete().unwrap_unchecked();
let geometry = self.geometry.as_concrete().unwrap_unchecked();
let mut c = Tensor::uninitialized_dt(self.c_fact.datum_type, c_shape)?;
let uops: Vec<FusedSpec> =
self.micro_ops.iter().map(|o| o.resolve_trivial(&inputs, &mut c)).collect();
self.mmm.run_with_scratch_space(geometry.m, geometry.n, scratch.as_mut(), &uops)?;
Ok(tvec!(c.into_tvalue()))
self.0.reserve(op.micro_ops.len().saturating_sub(self.0.capacity()));
#[allow(clippy::uninit_vec)]
self.0.set_len(op.micro_ops.len());
// kill static lifefime!
let fused_spec: &mut Vec<FusedSpec> = std::mem::transmute(&mut self.0);
let scratch = cell.get_or_insert_with(|| op.mmm.allocate_scratch_space());

let c = if op.trivial_path {
let c_shape = op.c_fact.shape.as_concrete().unwrap_unchecked();
let geometry = op.geometry.as_concrete().unwrap_unchecked();
let mut c = Tensor::uninitialized_dt(op.c_fact.datum_type, c_shape)?;
for i in 0..op.micro_ops.len() {
*fused_spec.get_unchecked_mut(i) =
op.micro_ops.get_unchecked(i).resolve_trivial(&inputs, &mut c);
}
op.mmm.run_with_scratch_space(
geometry.m,
geometry.n,
scratch.as_mut(),
fused_spec,
)?;
c
} else {
let geometry = self.geometry.to_concrete(&session.resolved_symbols)?;
let c_shape = self.c_fact.shape.eval_to_usize(&session.resolved_symbols)?;
let c = Tensor::uninitialized_dt(self.c_fact.datum_type, &c_shape)?;
let mut uops = vec![FusedSpec::ShiftLeft(0); self.micro_ops.len()];
let geometry = op.geometry.to_concrete(&session.resolved_symbols)?;
let c_shape = op.c_fact.shape.eval_to_usize(&session.resolved_symbols)?;
let c = Tensor::uninitialized_dt(op.c_fact.datum_type, &c_shape)?;
let mut looping_shape: TVec<usize> = c_shape.to_smallvec();
looping_shape[self.c_m_axis] = 1;
looping_shape[self.c_n_axis] = 1;
looping_shape[op.c_m_axis] = 1;
looping_shape[op.c_n_axis] = 1;
for c_coords in indices(&*looping_shape) {
for ix in 0..self.micro_ops.len() {
*uops.get_unchecked_mut(ix) =
self.micro_ops.get_unchecked(ix).resolve(&inputs, c_coords.slice(), &c);
for i in 0..op.micro_ops.len() {
*fused_spec.get_unchecked_mut(i) =
op.micro_ops.get_unchecked(i).resolve(&inputs, c_coords.slice(), &c)
}
self.mmm.run_with_scratch_space(
geometry.m,
geometry.n,
scratch.as_mut(),
&uops,
).context("In mmm.run_with_scratch_space")?;
op.mmm
.run_with_scratch_space(
geometry.m,
geometry.n,
scratch.as_mut(),
fused_spec,
)
.context("In mmm.run_with_scratch_space")?;
}
Ok(tvec!(c.into_tvalue()))
}
c
};
fused_spec.clear();
Ok(tvec!(c.into_tvalue()))
}
}
}

trivial_op_state_freeeze!(LirMatMulUnaryState);

impl TypedOp for LirMatMulUnary {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
ensure!(self.c_m_axis < self.c_fact.rank());
Expand Down

0 comments on commit 2e334d4

Please sign in to comment.