Skip to content

Commit

Permalink
Update support for unsigned integers in aesara.tensor.subtensor
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Aug 28, 2022
1 parent 5b935bc commit 3500fec
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 4 deletions.
21 changes: 19 additions & 2 deletions aesara/tensor/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
iscalar,
lscalar,
tensor,
ubscalar,
uiscalar,
ulscalar,
uwscalar,
wscalar,
zscalar,
)
Expand All @@ -50,12 +54,25 @@
_logger = logging.getLogger("aesara.tensor.subtensor")

invalid_scal_types = (aes.float64, aes.float32, aes.float16)
scal_types = (aes.int64, aes.int32, aes.int16, aes.int8)
scal_types = (
aes.int64,
aes.int32,
aes.int16,
aes.int8,
aes.uint64,
aes.uint32,
aes.uint16,
aes.uint8,
)
tensor_types = (
lscalar,
iscalar,
wscalar,
bscalar,
ulscalar,
uiscalar,
uwscalar,
ubscalar,
)
invalid_tensor_types = (
fscalar,
Expand Down Expand Up @@ -376,7 +393,7 @@ def slice_len(slc, n):
def is_basic_idx(idx):
"""Determine if an index is of the NumPy basic type.
XXX: This only checks a single index, so an integers is *not* considered a
XXX: This only checks a single index, so an integer is *not* considered a
basic index, because--depending on the other indices its used with--an
integer can indicate advanced indexing.
Expand Down
4 changes: 4 additions & 0 deletions aesara/tensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,10 @@ def tensor(*args, **kwargs):
wscalar = TensorType("int16", ())
iscalar = TensorType("int32", ())
lscalar = TensorType("int64", ())
ubscalar = TensorType("uint8", ())
uwscalar = TensorType("uint16", ())
uiscalar = TensorType("uint32", ())
ulscalar = TensorType("uint64", ())


def scalar(name=None, dtype=None):
Expand Down
4 changes: 2 additions & 2 deletions aesara/tensor/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,13 +515,13 @@ def is_empty_array(val):
isinstance(val, np.ndarray) and val.size == 0
)

# Force input to be int64 datatype if input is an empty list or tuple
# Force input to be an int datatype if input is an empty list or tuple
# Else leave it as is if it is a real number
# Convert python literals to aesara constants
args = tuple(
[
at.subtensor.as_index_constant(
np.array(inp, dtype=np.int64) if is_empty_array(inp) else inp
np.array(inp, dtype=np.uint8) if is_empty_array(inp) else inp
)
for inp in args
]
Expand Down
5 changes: 5 additions & 0 deletions tests/tensor/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2615,3 +2615,8 @@ def test_index_vars_to_types():

res = index_vars_to_types(iscalar)
assert isinstance(res, scal.ScalarType)

x = scal.constant(1, dtype=np.uint8)
assert isinstance(x.type, scal.ScalarType)
res = index_vars_to_types(x)
assert res == x.type

0 comments on commit 3500fec

Please sign in to comment.