Skip to content

Commit

Permalink
feat: better type generality for matrix element iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
code-sam committed Dec 25, 2024
1 parent ed01827 commit 681a4e3
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ static DEFAULT_GRAPHBLAS_OPERATOR_OPTIONS: Lazy<OperatorOptions> =

pub struct MatrixElementCoordinateIterator<'a> {
graphblas_context: Arc<Context>,
matrix: &'a GrB_Matrix,
graphblas_matrix: &'a GrB_Matrix,
graphblas_iterator: GxB_Iterator,
next_element: fn(&Arc<Context>, &GrB_Matrix, GxB_Iterator) -> Option<Coordinate>,
}
Expand All @@ -33,19 +33,19 @@ impl<'a> MatrixElementCoordinateIterator<'a> {

Ok(Self {
graphblas_context: matrix.context(),
matrix: unsafe { matrix.graphblas_matrix_ref() },
graphblas_matrix: unsafe { matrix.graphblas_matrix_ref() },
graphblas_iterator,
next_element: initial_matrix_element_coordinate,
})
}
}

fn initial_matrix_element_coordinate(
context: &Arc<Context>,
graphblas_context: &Arc<Context>,
matrix: &GrB_Matrix,
graphblas_iterator: GxB_Iterator,
) -> Option<Coordinate> {
match context.call(
match graphblas_context.call(
|| unsafe {
GxB_Matrix_Iterator_attach(
graphblas_iterator,
Expand All @@ -59,7 +59,7 @@ fn initial_matrix_element_coordinate(
Err(error) => return match_iterator_error(error),
}

match context.call(
match graphblas_context.call(
|| unsafe { GxB_Matrix_Iterator_seek(graphblas_iterator, 0) },
matrix, // TODO: check that error indeed link to the matrix the iterator was attached to
) {
Expand All @@ -68,7 +68,7 @@ fn initial_matrix_element_coordinate(
Err(error) => return match_iterator_error(error),
}

let next_index = match context.call(
let next_index = match graphblas_context.call(
|| unsafe { GxB_Matrix_Iterator_seek(graphblas_iterator, 0) },
matrix, // TODO: check that error indeed link to the matrix the iterator was attached to
) {
Expand All @@ -92,7 +92,7 @@ impl<'a> Iterator for MatrixElementCoordinateIterator<'a> {

fn next(&mut self) -> Option<Coordinate> {
let next_matrix_element_coordinate =
(self.next_element)(&self.graphblas_context, self.matrix, self.graphblas_iterator);
(self.next_element)(&self.graphblas_context, self.graphblas_matrix, self.graphblas_iterator);

self.next_element = next_element_coordinate;

Expand All @@ -102,12 +102,12 @@ impl<'a> Iterator for MatrixElementCoordinateIterator<'a> {

fn next_element_coordinate(
context: &Arc<Context>,
matrix: &GrB_Matrix,
graphblas_matrix: &GrB_Matrix,
graphblas_iterator: GxB_Iterator,
) -> Option<Coordinate> {
match context.call(
|| unsafe { GxB_Matrix_Iterator_next(graphblas_iterator) },
matrix, // TODO: check that error indeed link to the matrix the iterator was attached to
graphblas_matrix, // TODO: check that error indeed link to the matrix the iterator was attached to
) {
Ok(_) => matrix_element_coordinate_at_iterator_position(graphblas_iterator),
Err(error) => match_iterator_error(error),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::sync::Arc;

use once_cell::sync::Lazy;
use suitesparse_graphblas_sys::{
GxB_Iterator, GxB_Iterator_free, GxB_Matrix_Iterator_attach, GxB_Matrix_Iterator_next,
GxB_Matrix_Iterator_seek,
GrB_Matrix, GxB_Iterator, GxB_Iterator_free, GxB_Matrix_Iterator_attach, GxB_Matrix_Iterator_next, GxB_Matrix_Iterator_seek
};

use crate::collections::sparse_matrix::MatrixElement;
use crate::collections::sparse_matrix::{GetGraphblasSparseMatrix, SparseMatrix};
use crate::collections::sparse_matrix::GetGraphblasSparseMatrix;
use crate::collections::{new_graphblas_iterator, GetElementValueAtIteratorPosition};
use crate::context::CallGraphBlasContext;
use crate::context::{CallGraphBlasContext, Context};
use crate::context::GetContext;
use crate::error::SparseLinearAlgebraError;
use crate::error::{GraphblasErrorType, LogicErrorType, SparseLinearAlgebraErrorType};
Expand All @@ -21,53 +22,56 @@ static DEFAULT_GRAPHBLAS_OPERATOR_OPTIONS: Lazy<OperatorOptions> =
Lazy::new(|| OperatorOptions::new_default());

pub struct MatrixElementIterator<'a, T: ValueType + GetElementValueAtIteratorPosition<T>> {
matrix: &'a SparseMatrix<T>,
graphblas_context: Arc<Context>,
graphblas_matrix: &'a GrB_Matrix,
graphblas_iterator: GxB_Iterator,
next_element: fn(&SparseMatrix<T>, GxB_Iterator) -> Option<MatrixElement<T>>,
next_element: fn(&Arc<Context>, &GrB_Matrix, GxB_Iterator) -> Option<MatrixElement<T>>,
}

impl<'a, T: ValueType + GetElementValueAtIteratorPosition<T>> MatrixElementIterator<'a, T> {
pub fn new(matrix: &'a SparseMatrix<T>) -> Result<Self, SparseLinearAlgebraError> {
let graphblas_iterator = unsafe { new_graphblas_iterator(matrix.context_ref()) }?;
pub fn new(graphblas_matrix: &'a (impl GetGraphblasSparseMatrix + GetContext)) -> Result<Self, SparseLinearAlgebraError> {
let graphblas_iterator = unsafe { new_graphblas_iterator(graphblas_matrix.context_ref()) }?;

Ok(Self {
matrix,
graphblas_context: graphblas_matrix.context(),
graphblas_matrix: unsafe { graphblas_matrix.graphblas_matrix_ref() },
graphblas_iterator,
next_element: initial_matrix_element,
})
}
}

fn initial_matrix_element<T: ValueType + GetElementValueAtIteratorPosition<T>>(
matrix: &SparseMatrix<T>,
graphblas_context: &Arc<Context>,
graphblas_matrix: &GrB_Matrix,
graphblas_iterator: GxB_Iterator,
) -> Option<MatrixElement<T>> {
match matrix.context_ref().call(
match graphblas_context.call(
|| unsafe {
GxB_Matrix_Iterator_attach(
graphblas_iterator,
matrix.graphblas_matrix(),
graphblas_matrix.to_owned(),
DEFAULT_GRAPHBLAS_OPERATOR_OPTIONS.graphblas_descriptor(),
)
},
unsafe { &matrix.graphblas_matrix() }, // TODO: check that error indeed link to the matrix the iterator was attached to
graphblas_matrix, // TODO: check that error indeed link to the matrix the iterator was attached to
) {
Ok(_) => {}
Err(error) => return match_iterator_error(error),
}

match matrix.context_ref().call(
match graphblas_context.call(
|| unsafe { GxB_Matrix_Iterator_seek(graphblas_iterator, 0) },
unsafe { &matrix.graphblas_matrix() }, // TODO: check that error indeed link to the matrix the iterator was attached to
graphblas_matrix, // TODO: check that error indeed link to the matrix the iterator was attached to
) {
Ok(_) => {}
// TODO: attaching may actually fail, this will cause a panic, which is not desired
Err(error) => return match_iterator_error(error),
}

let next_value = match matrix.context_ref().call(
let next_value = match graphblas_context.call(
|| unsafe { GxB_Matrix_Iterator_seek(graphblas_iterator, 0) },
unsafe { &matrix.graphblas_matrix() }, // TODO: check that error indeed link to the matrix the iterator was attached to
graphblas_matrix, // TODO: check that error indeed link to the matrix the iterator was attached to
) {
Ok(_) => matrix_element_at_iterator_position(graphblas_iterator),
Err(error) => match_iterator_error(error),
Expand All @@ -80,8 +84,7 @@ impl<'a, T: ValueType + GetElementValueAtIteratorPosition<T>> Drop
for MatrixElementIterator<'a, T>
{
fn drop(&mut self) {
let context = self.matrix.context_ref();
let _ = context.call_without_detailed_error_information(|| unsafe {
let _ = self.graphblas_context.call_without_detailed_error_information(|| unsafe {
GxB_Iterator_free(&mut self.graphblas_iterator)
});
}
Expand All @@ -93,7 +96,7 @@ impl<'a, T: ValueType + GetElementValueAtIteratorPosition<T>> Iterator
type Item = MatrixElement<T>;

fn next(&mut self) -> Option<MatrixElement<T>> {
let next_matrix_element = (self.next_element)(self.matrix, self.graphblas_iterator);
let next_matrix_element = (self.next_element)(&self.graphblas_context, self.graphblas_matrix, self.graphblas_iterator);

self.next_element = next_element;

Expand All @@ -102,12 +105,13 @@ impl<'a, T: ValueType + GetElementValueAtIteratorPosition<T>> Iterator
}

fn next_element<T: ValueType + GetElementValueAtIteratorPosition<T>>(
matrix: &SparseMatrix<T>,
graphblas_context: &Arc<Context>,
graphblas_matrix: &GrB_Matrix,
graphblas_iterator: GxB_Iterator,
) -> Option<MatrixElement<T>> {
match matrix.context_ref().call(
match graphblas_context.call(
|| unsafe { GxB_Matrix_Iterator_next(graphblas_iterator) },
unsafe { &matrix.graphblas_matrix() }, // TODO: check that error indeed link to the matrix the iterator was attached to
graphblas_matrix, // TODO: check that error indeed link to the matrix the iterator was attached to
) {
Ok(_) => matrix_element_at_iterator_position::<T>(graphblas_iterator),
Err(error) => match_iterator_error(error),
Expand Down Expand Up @@ -181,7 +185,7 @@ mod tests {
)
.unwrap();

let matrix_element_iterator = MatrixElementIterator::new(&matrix).unwrap();
let matrix_element_iterator = MatrixElementIterator::<u8>::new(&matrix).unwrap();

for (element, expected_element) in matrix_element_iterator
.into_iter()
Expand All @@ -207,7 +211,7 @@ mod tests {
)
.unwrap();

let matrix_element_iterator = MatrixElementIterator::new(&matrix).unwrap();
let matrix_element_iterator = MatrixElementIterator::<u8>::new(&matrix).unwrap();

for (element, expected_element) in matrix_element_iterator
.into_iter()
Expand Down

0 comments on commit 681a4e3

Please sign in to comment.