Skip to content

Commit

Permalink
Merge pull request #100 from code-sam/more-generic-iterators
Browse files Browse the repository at this point in the history
More generic iterators
  • Loading branch information
code-sam authored Dec 26, 2024
2 parents ba016e1 + f324321 commit a58597f
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 93 deletions.
2 changes: 1 addition & 1 deletion graphblas_sparse_linear_algebra/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "graphblas_sparse_linear_algebra"
version = "0.54.1"
version = "0.54.2"
authors = ["code_sam <mail@samdekker.nl>"]
description = "Wrapper for SuiteSparse:GraphBLAS"
edition = "2021"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,74 +1,79 @@
use std::mem::MaybeUninit;
use std::sync::Arc;

use once_cell::sync::Lazy;
use suitesparse_graphblas_sys::{
GxB_Iterator, GxB_Iterator_free, GxB_Matrix_Iterator_attach, GxB_Matrix_Iterator_getIndex,
GxB_Matrix_Iterator_next, GxB_Matrix_Iterator_seek,
GrB_Matrix, GxB_Iterator, GxB_Iterator_free, GxB_Matrix_Iterator_attach,
GxB_Matrix_Iterator_getIndex, GxB_Matrix_Iterator_next, GxB_Matrix_Iterator_seek,
};

use crate::collections::new_graphblas_iterator;
use crate::collections::sparse_matrix::{Coordinate, GetGraphblasSparseMatrix, SparseMatrix};
use crate::context::CallGraphBlasContext;
use crate::collections::sparse_matrix::{Coordinate, GetGraphblasSparseMatrix};
use crate::context::GetContext;
use crate::context::{CallGraphBlasContext, Context};
use crate::error::SparseLinearAlgebraError;
use crate::error::{GraphblasErrorType, LogicErrorType, SparseLinearAlgebraErrorType};
use crate::index::ElementIndex;
use crate::index::IndexConversion;
use crate::operators::options::GetGraphblasDescriptor;
use crate::operators::options::OperatorOptions;
use crate::value_type::ValueType;

static DEFAULT_GRAPHBLAS_OPERATOR_OPTIONS: Lazy<OperatorOptions> =
Lazy::new(|| OperatorOptions::new_default());

pub struct MatrixElementCoordinateIterator<'a, T: ValueType> {
matrix: &'a SparseMatrix<T>,
pub struct MatrixElementCoordinateIterator<'a> {
graphblas_context: Arc<Context>,
graphblas_matrix: &'a GrB_Matrix,
graphblas_iterator: GxB_Iterator,
next_element: fn(&SparseMatrix<T>, GxB_Iterator) -> Option<Coordinate>,
next_element: fn(&Arc<Context>, &GrB_Matrix, GxB_Iterator) -> Option<Coordinate>,
}

impl<'a, T: ValueType> MatrixElementCoordinateIterator<'a, T> {
pub fn new(matrix: &'a SparseMatrix<T>) -> Result<Self, SparseLinearAlgebraError> {
impl<'a> MatrixElementCoordinateIterator<'a> {
pub fn new(
matrix: &'a (impl GetGraphblasSparseMatrix + GetContext),
) -> Result<Self, SparseLinearAlgebraError> {
let graphblas_iterator = unsafe { new_graphblas_iterator(matrix.context_ref()) }?;

Ok(Self {
matrix,
graphblas_context: matrix.context(),
graphblas_matrix: unsafe { matrix.graphblas_matrix_ref() },
graphblas_iterator,
next_element: initial_matrix_element_coordinate,
})
}
}

fn initial_matrix_element_coordinate<T: ValueType>(
matrix: &SparseMatrix<T>,
fn initial_matrix_element_coordinate(
graphblas_context: &Arc<Context>,
matrix: &GrB_Matrix,
graphblas_iterator: GxB_Iterator,
) -> Option<Coordinate> {
match matrix.context_ref().call(
match graphblas_context.call(
|| unsafe {
GxB_Matrix_Iterator_attach(
graphblas_iterator,
matrix.graphblas_matrix(),
matrix.to_owned(),
DEFAULT_GRAPHBLAS_OPERATOR_OPTIONS.graphblas_descriptor(),
)
},
unsafe { &matrix.graphblas_matrix() }, // TODO: check that error indeed link to the matrix the iterator was attached to
matrix, // TODO: check that error indeed link to the matrix the iterator was attached to
) {
Ok(_) => {}
Err(error) => return match_iterator_error(error),
}

match matrix.context_ref().call(
match graphblas_context.call(
|| unsafe { GxB_Matrix_Iterator_seek(graphblas_iterator, 0) },
unsafe { &matrix.graphblas_matrix() }, // TODO: check that error indeed link to the matrix the iterator was attached to
matrix, // TODO: check that error indeed link to the matrix the iterator was attached to
) {
Ok(_) => {}
// TODO: attaching may actually fail, this will cause a panic, which is not desired
Err(error) => return match_iterator_error(error),
}

let next_index = match matrix.context_ref().call(
let next_index = match graphblas_context.call(
|| unsafe { GxB_Matrix_Iterator_seek(graphblas_iterator, 0) },
unsafe { &matrix.graphblas_matrix() }, // TODO: check that error indeed link to the matrix the iterator was attached to
matrix, // TODO: check that error indeed link to the matrix the iterator was attached to
) {
Ok(_) => matrix_element_coordinate_at_iterator_position(graphblas_iterator),
Err(error) => match_iterator_error(error),
Expand All @@ -77,35 +82,40 @@ fn initial_matrix_element_coordinate<T: ValueType>(
return next_index;
}

impl<'a, T: ValueType> Drop for MatrixElementCoordinateIterator<'a, T> {
impl<'a> Drop for MatrixElementCoordinateIterator<'a> {
fn drop(&mut self) {
let context = self.matrix.context_ref();
let _ = context.call_without_detailed_error_information(|| unsafe {
GxB_Iterator_free(&mut self.graphblas_iterator)
});
let _ = self
.graphblas_context
.call_without_detailed_error_information(|| unsafe {
GxB_Iterator_free(&mut self.graphblas_iterator)
});
}
}

impl<'a, T: ValueType> Iterator for MatrixElementCoordinateIterator<'a, T> {
impl<'a> Iterator for MatrixElementCoordinateIterator<'a> {
type Item = Coordinate;

fn next(&mut self) -> Option<Coordinate> {
let next_matrix_element_coordinate =
(self.next_element)(self.matrix, self.graphblas_iterator);
let next_matrix_element_coordinate = (self.next_element)(
&self.graphblas_context,
self.graphblas_matrix,
self.graphblas_iterator,
);

self.next_element = next_element_coordinate;

return next_matrix_element_coordinate;
}
}

fn next_element_coordinate<T: ValueType>(
matrix: &SparseMatrix<T>,
fn next_element_coordinate(
context: &Arc<Context>,
graphblas_matrix: &GrB_Matrix,
graphblas_iterator: GxB_Iterator,
) -> Option<Coordinate> {
match matrix.context_ref().call(
match context.call(
|| unsafe { GxB_Matrix_Iterator_next(graphblas_iterator) },
unsafe { &matrix.graphblas_matrix() }, // TODO: check that error indeed link to the matrix the iterator was attached to
graphblas_matrix, // TODO: check that error indeed link to the matrix the iterator was attached to
) {
Ok(_) => matrix_element_coordinate_at_iterator_position(graphblas_iterator),
Err(error) => match_iterator_error(error),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use std::sync::Arc;

use once_cell::sync::Lazy;
use suitesparse_graphblas_sys::{
GxB_Iterator, GxB_Iterator_free, GxB_Matrix_Iterator_attach, GxB_Matrix_Iterator_next,
GxB_Matrix_Iterator_seek,
GrB_Matrix, GxB_Iterator, GxB_Iterator_free, GxB_Matrix_Iterator_attach,
GxB_Matrix_Iterator_next, GxB_Matrix_Iterator_seek,
};

use crate::collections::sparse_matrix::GetGraphblasSparseMatrix;
use crate::collections::sparse_matrix::MatrixElement;
use crate::collections::sparse_matrix::{GetGraphblasSparseMatrix, SparseMatrix};
use crate::collections::{new_graphblas_iterator, GetElementValueAtIteratorPosition};
use crate::context::CallGraphBlasContext;
use crate::context::GetContext;
use crate::context::{CallGraphBlasContext, Context};
use crate::error::SparseLinearAlgebraError;
use crate::error::{GraphblasErrorType, LogicErrorType, SparseLinearAlgebraErrorType};
use crate::operators::options::GetGraphblasDescriptor;
Expand All @@ -21,53 +23,58 @@ static DEFAULT_GRAPHBLAS_OPERATOR_OPTIONS: Lazy<OperatorOptions> =
Lazy::new(|| OperatorOptions::new_default());

pub struct MatrixElementIterator<'a, T: ValueType + GetElementValueAtIteratorPosition<T>> {
matrix: &'a SparseMatrix<T>,
graphblas_context: Arc<Context>,
graphblas_matrix: &'a GrB_Matrix,
graphblas_iterator: GxB_Iterator,
next_element: fn(&SparseMatrix<T>, GxB_Iterator) -> Option<MatrixElement<T>>,
next_element: fn(&Arc<Context>, &GrB_Matrix, GxB_Iterator) -> Option<MatrixElement<T>>,
}

impl<'a, T: ValueType + GetElementValueAtIteratorPosition<T>> MatrixElementIterator<'a, T> {
pub fn new(matrix: &'a SparseMatrix<T>) -> Result<Self, SparseLinearAlgebraError> {
let graphblas_iterator = unsafe { new_graphblas_iterator(matrix.context_ref()) }?;
pub fn new(
graphblas_matrix: &'a (impl GetGraphblasSparseMatrix + GetContext),
) -> Result<Self, SparseLinearAlgebraError> {
let graphblas_iterator = unsafe { new_graphblas_iterator(graphblas_matrix.context_ref()) }?;

Ok(Self {
matrix,
graphblas_context: graphblas_matrix.context(),
graphblas_matrix: unsafe { graphblas_matrix.graphblas_matrix_ref() },
graphblas_iterator,
next_element: initial_matrix_element,
})
}
}

fn initial_matrix_element<T: ValueType + GetElementValueAtIteratorPosition<T>>(
matrix: &SparseMatrix<T>,
graphblas_context: &Arc<Context>,
graphblas_matrix: &GrB_Matrix,
graphblas_iterator: GxB_Iterator,
) -> Option<MatrixElement<T>> {
match matrix.context_ref().call(
match graphblas_context.call(
|| unsafe {
GxB_Matrix_Iterator_attach(
graphblas_iterator,
matrix.graphblas_matrix(),
graphblas_matrix.to_owned(),
DEFAULT_GRAPHBLAS_OPERATOR_OPTIONS.graphblas_descriptor(),
)
},
unsafe { &matrix.graphblas_matrix() }, // TODO: check that error indeed link to the matrix the iterator was attached to
graphblas_matrix, // TODO: check that error indeed link to the matrix the iterator was attached to
) {
Ok(_) => {}
Err(error) => return match_iterator_error(error),
}

match matrix.context_ref().call(
match graphblas_context.call(
|| unsafe { GxB_Matrix_Iterator_seek(graphblas_iterator, 0) },
unsafe { &matrix.graphblas_matrix() }, // TODO: check that error indeed link to the matrix the iterator was attached to
graphblas_matrix, // TODO: check that error indeed link to the matrix the iterator was attached to
) {
Ok(_) => {}
// TODO: attaching may actually fail, this will cause a panic, which is not desired
Err(error) => return match_iterator_error(error),
}

let next_value = match matrix.context_ref().call(
let next_value = match graphblas_context.call(
|| unsafe { GxB_Matrix_Iterator_seek(graphblas_iterator, 0) },
unsafe { &matrix.graphblas_matrix() }, // TODO: check that error indeed link to the matrix the iterator was attached to
graphblas_matrix, // TODO: check that error indeed link to the matrix the iterator was attached to
) {
Ok(_) => matrix_element_at_iterator_position(graphblas_iterator),
Err(error) => match_iterator_error(error),
Expand All @@ -80,10 +87,11 @@ impl<'a, T: ValueType + GetElementValueAtIteratorPosition<T>> Drop
for MatrixElementIterator<'a, T>
{
fn drop(&mut self) {
let context = self.matrix.context_ref();
let _ = context.call_without_detailed_error_information(|| unsafe {
GxB_Iterator_free(&mut self.graphblas_iterator)
});
let _ = self
.graphblas_context
.call_without_detailed_error_information(|| unsafe {
GxB_Iterator_free(&mut self.graphblas_iterator)
});
}
}

Expand All @@ -93,7 +101,11 @@ impl<'a, T: ValueType + GetElementValueAtIteratorPosition<T>> Iterator
type Item = MatrixElement<T>;

fn next(&mut self) -> Option<MatrixElement<T>> {
let next_matrix_element = (self.next_element)(self.matrix, self.graphblas_iterator);
let next_matrix_element = (self.next_element)(
&self.graphblas_context,
self.graphblas_matrix,
self.graphblas_iterator,
);

self.next_element = next_element;

Expand All @@ -102,12 +114,13 @@ impl<'a, T: ValueType + GetElementValueAtIteratorPosition<T>> Iterator
}

fn next_element<T: ValueType + GetElementValueAtIteratorPosition<T>>(
matrix: &SparseMatrix<T>,
graphblas_context: &Arc<Context>,
graphblas_matrix: &GrB_Matrix,
graphblas_iterator: GxB_Iterator,
) -> Option<MatrixElement<T>> {
match matrix.context_ref().call(
match graphblas_context.call(
|| unsafe { GxB_Matrix_Iterator_next(graphblas_iterator) },
unsafe { &matrix.graphblas_matrix() }, // TODO: check that error indeed link to the matrix the iterator was attached to
graphblas_matrix, // TODO: check that error indeed link to the matrix the iterator was attached to
) {
Ok(_) => matrix_element_at_iterator_position::<T>(graphblas_iterator),
Err(error) => match_iterator_error(error),
Expand Down Expand Up @@ -181,7 +194,7 @@ mod tests {
)
.unwrap();

let matrix_element_iterator = MatrixElementIterator::new(&matrix).unwrap();
let matrix_element_iterator = MatrixElementIterator::<u8>::new(&matrix).unwrap();

for (element, expected_element) in matrix_element_iterator
.into_iter()
Expand All @@ -207,7 +220,7 @@ mod tests {
)
.unwrap();

let matrix_element_iterator = MatrixElementIterator::new(&matrix).unwrap();
let matrix_element_iterator = MatrixElementIterator::<u8>::new(&matrix).unwrap();

for (element, expected_element) in matrix_element_iterator
.into_iter()
Expand Down
Loading

0 comments on commit a58597f

Please sign in to comment.