Skip to content

Restrict type of Arrays that can be used for indexing #259

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions src/core/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::array::Array;
use super::defines::AfError;
use super::error::HANDLE_ERROR;
use super::seq::Seq;
use super::util::{af_array, af_index_t, dim_t, HasAfEnum};
use super::util::{af_array, af_index_t, dim_t, HasAfEnum, IndexableType};

use libc::{c_double, c_int, c_uint};
use std::default::Default;
Expand Down Expand Up @@ -142,7 +142,10 @@ pub trait Indexable {
///
/// This is used in functions [index_gen](./fn.index_gen.html) and
/// [assign_gen](./fn.assign_gen.html)
impl<T: HasAfEnum> Indexable for Array<T> {
impl<T> Indexable for Array<T>
where
T: HasAfEnum + IndexableType,
{
fn set(&self, idxr: &mut Indexer, dim: u32, _is_batch: Option<bool>) {
unsafe {
let err_val = af_set_array_indexer(idxr.get(), self.get(), dim as dim_t);
Expand All @@ -155,9 +158,10 @@ impl<T: HasAfEnum> Indexable for Array<T> {
///
/// This is used in functions [index_gen](./fn.index_gen.html) and
/// [assign_gen](./fn.assign_gen.html)
impl<T: Copy> Indexable for Seq<T>
impl<T> Indexable for Seq<T>
where
c_double: From<T>,
T: Copy + IndexableType,
{
fn set(&self, idxr: &mut Indexer, dim: u32, is_batch: Option<bool>) {
unsafe {
Expand Down Expand Up @@ -256,10 +260,11 @@ impl<'object> Drop for Indexer<'object> {
/// println!("a(seq(1, 3, 1), span)");
/// print(&sub);
/// ```
pub fn index<IO, T: Copy>(input: &Array<IO>, seqs: &[Seq<T>]) -> Array<IO>
pub fn index<IO, T>(input: &Array<IO>, seqs: &[Seq<T>]) -> Array<IO>
where
c_double: From<T>,
IO: HasAfEnum,
T: Copy + HasAfEnum + IndexableType,
{
let seqs: Vec<SeqInternal> = seqs.iter().map(|s| SeqInternal::from_seq(s)).collect();
unsafe {
Expand Down Expand Up @@ -462,7 +467,7 @@ where
pub fn lookup<T, I>(input: &Array<T>, indices: &Array<I>, seq_dim: i32) -> Array<T>
where
T: HasAfEnum,
I: HasAfEnum,
I: HasAfEnum + IndexableType,
{
unsafe {
let mut temp: af_array = std::ptr::null_mut();
Expand Down Expand Up @@ -504,10 +509,11 @@ where
/// // 1.0 1.0 1.0
/// // 2.0 2.0 2.0
/// ```
pub fn assign_seq<T: Copy, I>(lhs: &mut Array<I>, seqs: &[Seq<T>], rhs: &Array<I>)
pub fn assign_seq<T, I>(lhs: &mut Array<I>, seqs: &[Seq<T>], rhs: &Array<I>)
where
c_double: From<T>,
I: HasAfEnum,
T: Copy + IndexableType,
{
let seqs: Vec<SeqInternal> = seqs.iter().map(|s| SeqInternal::from_seq(s)).collect();
unsafe {
Expand Down Expand Up @@ -632,9 +638,10 @@ struct SeqInternal {
}

impl SeqInternal {
fn from_seq<T: Copy>(s: &Seq<T>) -> Self
fn from_seq<T>(s: &Seq<T>) -> Self
where
c_double: From<T>,
T: Copy + IndexableType,
{
Self {
begin: From::from(s.begin()),
Expand Down
19 changes: 15 additions & 4 deletions src/core/seq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@ use serde::{Deserialize, Serialize};
use std::default::Default;
use std::fmt;

use super::util::IndexableType;

/// Sequences are used for indexing Arrays
#[derive(Copy, Clone, Debug, PartialEq)]
#[cfg_attr(feature = "afserde", derive(Serialize, Deserialize))]
#[repr(C)]
pub struct Seq<T> {
pub struct Seq<T: IndexableType> {
begin: T,
end: T,
step: T,
}

/// Default `Seq` spans all the elements along a dimension
impl<T: One + Zero> Default for Seq<T> {
impl<T> Default for Seq<T>
where
T: One + Zero + IndexableType,
{
fn default() -> Self {
Self {
begin: One::one(),
Expand All @@ -27,7 +32,10 @@ impl<T: One + Zero> Default for Seq<T> {
}

/// Enables use of `Seq` with `{}` format in print statements
impl<T: fmt::Display> fmt::Display for Seq<T> {
impl<T> fmt::Display for Seq<T>
where
T: fmt::Display + IndexableType,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
Expand All @@ -37,7 +45,10 @@ impl<T: fmt::Display> fmt::Display for Seq<T> {
}
}

impl<T: Copy> Seq<T> {
impl<T> Seq<T>
where
T: Copy + IndexableType,
{
/// Create a `Seq` that goes from `begin` to `end` at a step size of `step`
pub fn new(begin: T, end: T, step: T) -> Self {
Self { begin, end, step }
Expand Down
13 changes: 13 additions & 0 deletions src/core/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -827,3 +827,16 @@ impl Fromf64 for i16 { fn fromf64(value: f64) -> Self { value as Self }}
impl Fromf64 for u8 { fn fromf64(value: f64) -> Self { value as Self }}
#[rustfmt::skip]
impl Fromf64 for bool { fn fromf64(value: f64) -> Self { value > 0.0 }}

///Trait qualifier for the type of Arrays accepted by scan operations
pub trait IndexableType {}

impl IndexableType for f64 {}
impl IndexableType for i64 {}
impl IndexableType for u64 {}
impl IndexableType for f32 {}
impl IndexableType for i32 {}
impl IndexableType for u32 {}
impl IndexableType for i16 {}
impl IndexableType for u16 {}
impl IndexableType for u8 {}