Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lazy Lookahead Relabeling #230

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions rustfst-ffi/src/algorithms/compose.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ pub extern "C" fn fst_matcher_config_destroy(ptr: *mut CMatcherConfig) -> RUSTFS
}

unsafe {
Box::from_raw(ptr);
drop(Box::from_raw(ptr));
}
Ok(())
})
Expand All @@ -283,7 +283,7 @@ pub extern "C" fn fst_compose_config_destroy(ptr: *mut CComposeConfig) -> RUSTFS
}

unsafe {
Box::from_raw(ptr);
drop(Box::from_raw(ptr));
}
Ok(())
})
Expand Down
2 changes: 1 addition & 1 deletion rustfst-ffi/src/fst/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ pub fn fst_destroy(fst_ptr: *mut CFst) -> RUSTFST_FFI_RESULT {
}

unsafe {
Box::from_raw(fst_ptr);
drop(Box::from_raw(fst_ptr));
}
Ok(())
})
Expand Down
6 changes: 3 additions & 3 deletions rustfst-ffi/src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ pub extern "C" fn trs_iterator_destroy(iter_ptr: *mut CTrsIterator) -> RUSTFST_F
}

unsafe {
Box::from_raw(iter_ptr);
drop(Box::from_raw(iter_ptr));
}
Ok(())
})
Expand Down Expand Up @@ -274,7 +274,7 @@ pub extern "C" fn mut_trs_iterator_destroy(iter_ptr: *mut CMutTrsIterator) -> RU
}

unsafe {
Box::from_raw(iter_ptr);
drop(Box::from_raw(iter_ptr));
}
Ok(())
})
Expand Down Expand Up @@ -333,7 +333,7 @@ pub extern "C" fn state_iterator_destroy(iter_ptr: *mut CStateIterator) -> RUSTF
}

unsafe {
Box::from_raw(iter_ptr);
drop(Box::from_raw(iter_ptr));
}
Ok(())
})
Expand Down
2 changes: 1 addition & 1 deletion rustfst-ffi/src/string_path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub extern "C" fn string_path_destroy(iter_ptr: *mut CStringPath) -> RUSTFST_FFI
}

unsafe {
Box::from_raw(iter_ptr);
drop(Box::from_raw(iter_ptr));
}
Ok(())
})
Expand Down
2 changes: 1 addition & 1 deletion rustfst-ffi/src/string_paths_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ pub extern "C" fn string_paths_iterator_destroy(
}

unsafe {
Box::from_raw(iter_ptr);
drop(Box::from_raw(iter_ptr));
}
Ok(())
})
Expand Down
2 changes: 1 addition & 1 deletion rustfst-ffi/src/symbol_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ pub extern "C" fn symt_destroy(symt_ptr: *mut CSymbolTable) -> RUSTFST_FFI_RESUL
}

unsafe {
Box::from_raw(symt_ptr);
drop(Box::from_raw(symt_ptr));
}
Ok(())
})
Expand Down
2 changes: 1 addition & 1 deletion rustfst-ffi/src/tr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ pub extern "C" fn tr_delete(tr_ptr: *mut CTr) -> RUSTFST_FFI_RESULT {
}

unsafe {
Box::from_raw(tr_ptr);
drop(Box::from_raw(tr_ptr));
}
Ok(())
})
Expand Down
2 changes: 1 addition & 1 deletion rustfst-ffi/src/trs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ pub extern "C" fn trs_vec_delete(trs_ptr: *mut CTrs) -> RUSTFST_FFI_RESULT {
}

unsafe {
Box::from_raw(trs_ptr);
drop(Box::from_raw(trs_ptr));
}
Ok(())
})
Expand Down
11 changes: 8 additions & 3 deletions rustfst/src/algorithms/compose/compose_fst.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use anyhow::Result;
use std::borrow::Borrow;
use std::fmt::Debug;
use std::path::Path;
use std::sync::Arc;

use anyhow::Result;

use crate::algorithms::compose::compose_filters::{
ComposeFilter, ComposeFilterBuilder, SequenceComposeFilterBuilder,
Expand All @@ -18,7 +20,6 @@ use crate::fst_traits::{AllocableFst, CoreFst, Fst, FstIterator, MutableFst, Sta
use crate::parsers::SerializeBinary;
use crate::semirings::{Semiring, SerializableSemiring};
use crate::{StateId, SymbolTable, TrsVec};
use std::sync::Arc;

type InnerLazyFst<W, F1, F2, B1, B2, M1, M2, CFB, Cache> =
LazyFst<W, ComposeFstOp<W, F1, F2, B1, B2, M1, M2, CFB>, Cache>;
Expand Down Expand Up @@ -160,6 +161,9 @@ where

/// Turns the Lazy FST into a static one.
pub fn compute<F: MutableFst<W> + AllocableFst<W>>(&self) -> Result<F> {
// Small trick to make sure that both FSTs are fully expanded.
// iterate_lazy(self.0.op.fst1.borrow())?;
// iterate_lazy(self.0.op.fst2.borrow())?;
self.0.compute()
}
}
Expand Down Expand Up @@ -356,11 +360,12 @@ where

#[cfg(test)]
mod test {
use super::*;
use crate::algorithms::compose::matchers::SortedMatcher;
use crate::fst_impls::VectorFst;
use crate::semirings::TropicalWeight;

use super::*;

#[test]
fn test_compose_fst_sync() {
fn is_sync<T: Sync>() {}
Expand Down
7 changes: 4 additions & 3 deletions rustfst/src/algorithms/compose/compose_fst_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ where
>,
match_type: MatchType,
properties: FstProperties,
fst1: B1,
fst2: B2,
pub(crate) fst1: B1,
pub(crate) fst2: B2,
}

impl<W, F1, F2, B1, B2, M1, M2, CFB> Clone for ComposeFstOp<W, F1, F2, B1, B2, M1, M2, CFB>
Expand Down Expand Up @@ -389,11 +389,12 @@ where
fn compute_start(&self) -> Result<Option<StateId>> {
let compose_filter = self.compose_filter_builder.build()?;
let s1 = self.fst1.borrow().start();
// Let's put it here to force the fst2 to have its start state computed in the case of a Lazy Fst.
let s2 = self.fst2.borrow().start();
if s1.is_none() {
return Ok(None);
}
let s1 = s1.unwrap();
let s2 = self.fst2.borrow().start();
if s2.is_none() {
return Ok(None);
}
Expand Down
13 changes: 8 additions & 5 deletions rustfst/src/algorithms/compose/interval_reach_visitor.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use std::cmp::Ordering;

use anyhow::Result;

use crate::algorithms::compose::{IntInterval, IntervalSet};
use crate::algorithms::dfs_visit::Visitor;
use crate::fst_traits::Fst;
use crate::semirings::Semiring;
use crate::{StateId, Tr};
use std::cmp::Ordering;

static UNASSIGNED: usize = std::usize::MAX;

Expand All @@ -30,18 +33,18 @@ impl<'a, W: Semiring, F: Fst<W>> Visitor<'a, W, F> for IntervalReachVisitor<'a,
fn init_visit(&mut self, _fst: &'a F) {}

/// Invoked when state discovered (2nd arg is DFS tree root).
fn init_state(&mut self, s: StateId, _root: StateId) -> bool {
fn init_state(&mut self, s: StateId, _root: StateId) -> Result<bool> {
while self.isets.len() <= (s as usize) {
self.isets.push(IntervalSet::default());
}
while self.state2index.len() <= (s as usize) {
self.state2index.push(UNASSIGNED);
}
if let Some(final_weight) = self.fst.final_weight(s).unwrap() {
if let Some(final_weight) = self.fst.final_weight(s)? {
if !final_weight.is_zero() {
let interval_set = &mut self.isets[s as usize];
if self.index == UNASSIGNED {
if self.fst.num_trs(s).unwrap() > 0 {
if self.fst.num_trs(s)? > 0 {
panic!("IntervalReachVisitor: state2index map must be empty for this FST")
}
let index = self.state2index[s as usize];
Expand All @@ -56,7 +59,7 @@ impl<'a, W: Semiring, F: Fst<W>> Visitor<'a, W, F> for IntervalReachVisitor<'a,
}
}
}
true
Ok(true)
}

/// Invoked when tree transition to white/undiscovered state examined.
Expand Down
42 changes: 42 additions & 0 deletions rustfst/src/algorithms/compose/label_reachable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::algorithms::{fst_convert_from_ref, tr_sort};
use crate::fst_impls::VectorFst;
use crate::fst_properties::FstProperties;
use crate::fst_traits::{CoreFst, ExpandedFst, Fst, MutableFst};
use crate::prelude::compose::LookaheadRelabelFst;
use crate::semirings::Semiring;
use crate::{Label, StateId, Tr, Trs, EPS_LABEL, NO_LABEL, UNASSIGNED};

Expand Down Expand Up @@ -60,6 +61,19 @@ impl LabelReachableData {
.or_insert_with(|| n as Label + 1)
}

pub fn relabel_unmut(&self, label: Label) -> Result<Label> {
if label == EPS_LABEL {
return Ok(EPS_LABEL);
}
Ok(*self.label2index.get(&label).ok_or_else(|| {
format_err!(
"Missing label {:?} from relabelling table {:?}",
label,
self.label2index
)
})?)
}

pub fn relabel_fst<W: Semiring, F: MutableFst<W>>(
&mut self,
fst: &mut F,
Expand Down Expand Up @@ -92,6 +106,34 @@ impl LabelReachableData {
Ok(())
}

pub fn relabel_fst_lazy<W: Semiring, F: Fst<W> + 'static>(
&mut self,
fst: Arc<F>,
relabel_input: bool,
) -> Result<impl Fst<W>> {
// 1. Retrieve the relevant SymbolTable
let symt = if relabel_input {
fst.input_symbols()
.ok_or_else(|| format_err!("Input SymbolTable not attached to the fst which is mandatory for the relabel fst lazy to work"))?
} else {
fst.output_symbols()
.ok_or_else(|| format_err!("Output SymbolTable not attached to the fst which is mandatory for the relabel fst lazy to work"))?
};

// 2. Use the SymbolTable as a proxy to retrieve all the labels that are going to be ralabelled.
// By calling `relabel` on each of those we guarantee we won't see other labels at inference.
for (label, _) in symt.iter() {
self.relabel(label);
}

// 3. Create a custom Lazy Fst that will take as input the current LabelReachableData object
// with the whole relabeling mapping.
// This Lazy fst, will 1. use the mapping to relabel 2. sort the transitions
let la_relabel_fst = LookaheadRelabelFst::<W, F, _>::new(fst, self.clone(), relabel_input)?;

Ok(la_relabel_fst)
}

// Returns relabeling pairs (cf. relabel.h::Relabel()). If avoid_collisions is
// true, extra pairs are added to ensure no collisions when relabeling
// automata that have labels unseen here.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::algorithms::compose::lookahead_filters::LookAheadComposeFilterTrait;
use crate::algorithms::compose::lookahead_matchers::{LookAheadMatcherData, LookaheadMatcher};
use crate::algorithms::compose::matchers::{MatchType, MatcherFlags};
use crate::fst_properties::FstProperties;
use crate::fst_traits::{ExpandedFst, Fst};
use crate::fst_traits::Fst;
use crate::semirings::{
DivideType, Semiring, SerializableSemiring, WeaklyDivisibleSemiring, WeightQuantize,
};
Expand Down Expand Up @@ -80,8 +80,8 @@ impl<W, F1, F2, B1, B2, M1, M2, CFB, SMT> ComposeFilterBuilder<W, F1, F2, B1, B2
for PushWeightsComposeFilterBuilder<W, F1, F2, B1, B2, M1, M2, CFB, SMT>
where
W: SerializableSemiring + WeaklyDivisibleSemiring + WeightQuantize,
F1: ExpandedFst<W>,
F2: ExpandedFst<W>,
F1: Fst<W>,
F2: Fst<W>,
B1: Borrow<F1> + Debug,
B2: Borrow<F2> + Debug,
M1: LookaheadMatcher<W, F1, B1>,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use anyhow::Result;

use crate::algorithms::compose::LabelReachableData;
use crate::fst_traits::MutableFst;
use crate::fst_traits::{Fst, MutableFst};
use crate::semirings::Semiring;
use std::sync::Arc;

pub struct LabelLookAheadRelabeler {}

Expand Down Expand Up @@ -39,4 +40,22 @@ impl LabelLookAheadRelabeler {
}
bail!("Addon contains only None elements")
}

pub fn relabel_lazy<W: Semiring, F: Fst<W> + 'static>(
fst: Arc<F>,
addon: &mut (Option<LabelReachableData>, Option<LabelReachableData>),
relabel_input: bool,
) -> Result<impl Fst<W>> {
if let Some(reachable_data) = &mut addon.0 {
let lazy_fst = reachable_data.relabel_fst_lazy(fst, relabel_input)?;
return Ok(lazy_fst);
}

if let Some(reachable_data) = &mut addon.1 {
let lazy_fst = reachable_data.relabel_fst_lazy(fst, relabel_input)?;
return Ok(lazy_fst);
}

bail!("Addon contains only None elements")
}
}
Loading