diff --git a/core/src/ops/matmul/lir_unary.rs b/core/src/ops/matmul/lir_unary.rs index 7f6925437d..e3e3721b86 100644 --- a/core/src/ops/matmul/lir_unary.rs +++ b/core/src/ops/matmul/lir_unary.rs @@ -45,21 +45,19 @@ impl ProtoFusedSpec { output: &Tensor, ) -> FusedSpec<'t> { let fs = match self { - ProtoFusedSpec::AddMatMul { geo, a, b, packing } => { - let mut a = inputs[*a].view(); - unsafe { - geo.c_to_a_axis_mapping.translate_view(output_coords, &mut a); - } - let a = - a.as_slice::().unwrap()[0].downcast_ref::>().unwrap(); - let mut b = inputs[*b].view(); - unsafe { - geo.c_to_b_axis_mapping.translate_view(output_coords, &mut b); - } - let b = - b.as_slice::().unwrap()[0].downcast_ref::>().unwrap(); + ProtoFusedSpec::AddMatMul { geo, a, b, packing } => unsafe { + let mut a = inputs.get_unchecked(*a).view(); + geo.c_to_a_axis_mapping.translate_view(output_coords, &mut a); + let a = a.as_slice_unchecked::()[0] + .downcast_ref::>() + .unwrap_unchecked(); + let mut b = inputs.get_unchecked(*b).view(); + geo.c_to_b_axis_mapping.translate_view(output_coords, &mut b); + let b = b.as_slice_unchecked::()[0] + .downcast_ref::>() + .unwrap_unchecked(); FusedSpec::AddMatMul { a: &**a, b: &**b, packing: *packing } - } + }, ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op), ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]), ProtoFusedSpec::BinPerRow(v, op, map) => { @@ -96,21 +94,26 @@ impl ProtoFusedSpec { } } + #[inline] pub fn resolve_trivial<'t>( &'t self, inputs: &'t [TValue], output: &mut Tensor, ) -> FusedSpec<'t> { let fs = match self { - ProtoFusedSpec::AddMatMul { a, b, packing, .. } => { - let a = &inputs[*a]; - let b = &inputs[*b]; - let a = - a.to_scalar::().unwrap().downcast_ref::>().unwrap(); - let b = - b.to_scalar::().unwrap().downcast_ref::>().unwrap(); + ProtoFusedSpec::AddMatMul { a, b, packing, .. 
} => unsafe { + let a = &inputs.get_unchecked(*a); + let b = &inputs.get_unchecked(*b); + let a = a + .to_scalar_unchecked::() + .downcast_ref::>() + .unwrap_unchecked(); + let b = b + .to_scalar_unchecked::() + .downcast_ref::>() + .unwrap_unchecked(); FusedSpec::AddMatMul { a: &**a, b: &**b, packing: *packing } - } + }, ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op), ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]), ProtoFusedSpec::BinPerRow(v, op, _) => { @@ -282,59 +285,91 @@ impl Op for LirMatMulUnary { impl EvalOp for LirMatMulUnary { fn is_stateless(&self) -> bool { - true + false } - fn eval_with_session( + fn state( &self, - session: &SessionState, + _session: &mut SessionState, + _node_id: usize, + ) -> TractResult>> { + Ok(Some(Box::new(LirMatMulUnaryState::default()))) + } +} + +#[derive(Clone, Debug, Default)] +struct LirMatMulUnaryState(Vec>); + +impl OpState for LirMatMulUnaryState { + fn eval( + &mut self, + session: &mut SessionState, + op: &dyn Op, inputs: TVec, ) -> TractResult> { + let op = op.downcast_ref::().unwrap(); unsafe { let mut cell = session.cached_mmm_scratch_space.borrow_mut(); if !cell .as_ref() - .map(|scratch| self.mmm.can_use_scratch_space(&**scratch)) + .map(|scratch| op.mmm.can_use_scratch_space(&**scratch)) .unwrap_or(false) { *cell = None } - let scratch = cell.get_or_insert_with(|| self.mmm.allocate_scratch_space()); - - if self.trivial_path { - let c_shape = self.c_fact.shape.as_concrete().unwrap_unchecked(); - let geometry = self.geometry.as_concrete().unwrap_unchecked(); - let mut c = Tensor::uninitialized_dt(self.c_fact.datum_type, c_shape)?; - let uops: Vec = - self.micro_ops.iter().map(|o| o.resolve_trivial(&inputs, &mut c)).collect(); - self.mmm.run_with_scratch_space(geometry.m, geometry.n, scratch.as_mut(), &uops)?; - Ok(tvec!(c.into_tvalue())) + self.0.reserve(op.micro_ops.len().saturating_sub(self.0.capacity())); + #[allow(clippy::uninit_vec)] + 
self.0.set_len(op.micro_ops.len()); + // kill static lifetime! + let fused_spec: &mut Vec = std::mem::transmute(&mut self.0); + let scratch = cell.get_or_insert_with(|| op.mmm.allocate_scratch_space()); + + let c = if op.trivial_path { + let c_shape = op.c_fact.shape.as_concrete().unwrap_unchecked(); + let geometry = op.geometry.as_concrete().unwrap_unchecked(); + let mut c = Tensor::uninitialized_dt(op.c_fact.datum_type, c_shape)?; + for i in 0..op.micro_ops.len() { + *fused_spec.get_unchecked_mut(i) = + op.micro_ops.get_unchecked(i).resolve_trivial(&inputs, &mut c); + } + op.mmm.run_with_scratch_space( + geometry.m, + geometry.n, + scratch.as_mut(), + fused_spec, + )?; + c } else { - let geometry = self.geometry.to_concrete(&session.resolved_symbols)?; - let c_shape = self.c_fact.shape.eval_to_usize(&session.resolved_symbols)?; - let c = Tensor::uninitialized_dt(self.c_fact.datum_type, &c_shape)?; - let mut uops = vec![FusedSpec::ShiftLeft(0); self.micro_ops.len()]; + let geometry = op.geometry.to_concrete(&session.resolved_symbols)?; + let c_shape = op.c_fact.shape.eval_to_usize(&session.resolved_symbols)?; + let c = Tensor::uninitialized_dt(op.c_fact.datum_type, &c_shape)?; let mut looping_shape: TVec = c_shape.to_smallvec(); - looping_shape[self.c_m_axis] = 1; - looping_shape[self.c_n_axis] = 1; + looping_shape[op.c_m_axis] = 1; + looping_shape[op.c_n_axis] = 1; for c_coords in indices(&*looping_shape) { - for ix in 0..self.micro_ops.len() { - *uops.get_unchecked_mut(ix) = - self.micro_ops.get_unchecked(ix).resolve(&inputs, c_coords.slice(), &c); + for i in 0..op.micro_ops.len() { + *fused_spec.get_unchecked_mut(i) = + op.micro_ops.get_unchecked(i).resolve(&inputs, c_coords.slice(), &c) } - self.mmm.run_with_scratch_space( - geometry.m, - geometry.n, - scratch.as_mut(), - &uops, - ).context("In mmm.run_with_scratch_space")?; + op.mmm + .run_with_scratch_space( + geometry.m, + geometry.n, + scratch.as_mut(), + fused_spec, + ) + .context("In 
mmm.run_with_scratch_space")?; } - Ok(tvec!(c.into_tvalue())) - } + c + }; + fused_spec.clear(); + Ok(tvec!(c.into_tvalue())) } } } +trivial_op_state_freeeze!(LirMatMulUnaryState); + impl TypedOp for LirMatMulUnary { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { ensure!(self.c_m_axis < self.c_fact.rank());