Skip to content

Commit

Permalink
Merge branch 'main' into agoose77/refactor-typetracer
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 authored Nov 17, 2022
2 parents 2904f59 + 97840b9 commit 56837d9
Show file tree
Hide file tree
Showing 50 changed files with 130 additions and 254 deletions.
129 changes: 5 additions & 124 deletions src/awkward/_connect/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,11 @@ def action(inputs, **ignore):
return _array_ufunc_adjust(custom, inputs, kwargs, behavior)

if ufunc is numpy.matmul:
custom_matmul = action_for_matmul(inputs)
if custom_matmul is not None:
return custom_matmul()
raise ak._errors.wrap_error(
NotImplementedError(
"matrix multiplication (`@` or `np.matmul`) is not yet implemented for Awkward Arrays"
)
)

if all(
isinstance(x, NumpyArray) or not isinstance(x, ak.contents.Content)
Expand Down Expand Up @@ -256,131 +258,10 @@ def unary_action(layout, **ignore):
return ak._util.wrap(out, behavior)


# def matmul_for_numba(lefts, rights, dtype):
# total_outer = 0
# total_inner = 0
# total_content = 0

# for A, B in zip(lefts, rights):
# first = -1
# for Ai in A:
# if first == -1:
# first = len(Ai)
# elif first != len(Ai):
# raise ak._errors.wrap_error(ValueError(
# "one of the left matrices in np.matmul is not rectangular"
# ))
# if first == -1:
# first = 0
# rowsA = len(A)
# colsA = first

# first = -1
# for Bi in B:
# if first == -1:
# first = len(Bi)
# elif first != len(Bi):
# raise ak._errors.wrap_error(ValueError(
# "one of the right matrices in np.matmul is not rectangular"
# ))
# if first == -1:
# first = 0
# rowsB = len(B)
# colsB = first

# if colsA != rowsB:
# raise ak._errors.wrap_error(ValueError(
# u"one of the pairs of matrices in np.matmul do not match shape: "
# u"(n \u00d7 k) @ (k \u00d7 m)"
# ))

# total_outer += 1
# total_inner += rowsA
# total_content += rowsA * colsB

# outer = numpy.empty(total_outer + 1, numpy.int64)
# inner = numpy.empty(total_inner + 1, numpy.int64)
# content = numpy.zeros(total_content, dtype)

# outer[0] = 0
# inner[0] = 0
# outer_i = 1
# inner_i = 1
# content_i = 0
# for A, B in zip(lefts, rights):
# rows = len(A)
# cols = 0
# if len(B) > 0:
# cols = len(B[0])
# mids = 0
# if len(A) > 0:
# mids = len(A[0])

# for i in range(rows):
# for j in range(cols):
# for v in range(mids):
# pos = content_i + i * cols + j
# content[pos] += A[i][v] * B[v][j]

# outer[outer_i] = outer[outer_i - 1] + rows
# outer_i += 1
# for _ in range(rows):
# inner[inner_i] = inner[inner_i - 1] + cols
# inner_i += 1
# content_i += rows * cols

# return outer, inner, content


# matmul_for_numba.numbafied = None


def action_for_matmul(inputs):
raise ak._errors.wrap_error(NotImplementedError)


# def action_for_matmul(inputs):
# inputs = [
# ak._util.recursively_apply(
# x, (lambda _: _), pass_depth=False, numpy_to_regular=True
# )
# if isinstance(x, (ak.contents.Content, ak.record.Record))
# else x
# for x in inputs
# ]

# if len(inputs) == 2 and all(
# isinstance(x, ak._util.listtypes)
# and isinstance(x.content, ak._util.listtypes)
# and isinstance(x.content.content, NumpyArray)
# for x in inputs
# ):
# ak._connect.numba.register_and_check()
# import numba

# if matmul_for_numba.numbafied is None:
# matmul_for_numba.numbafied = numba.njit(matmul_for_numba)

# lefts = ak.highlevel.Array(inputs[0])
# rights = ak.highlevel.Array(inputs[1])
# dtype = numpy.asarray(lefts[0:0, 0:0, 0:0] + rights[0:0, 0:0, 0:0]).dtype

# outer, inner, content = matmul_for_numba.numbafied(lefts, rights, dtype)

# return lambda: (
# ak.contents.ListOffsetArray64(
# ak.index.Index64(outer),
# ak.contents.ListOffsetArray64(
# ak.index.Index64(inner),
# NumpyArray(content),
# ),
# ),
# )

# else:
# return None


try:
NDArrayOperatorsMixin = numpy.lib.mixins.NDArrayOperatorsMixin

Expand Down
4 changes: 2 additions & 2 deletions src/awkward/_connect/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ def to_awkwardarrow_type(


def remove_optiontype(akarray):
assert type(akarray).is_OptionType
assert type(akarray).is_option
if isinstance(akarray, ak.contents.IndexedOptionArray):
return ak.contents.IndexedArray(
akarray.index, akarray.content, akarray.parameters
Expand All @@ -883,7 +883,7 @@ def remove_optiontype(akarray):


def form_remove_optiontype(akform):
assert type(akform).is_OptionType
assert type(akform).is_option
if isinstance(akform, ak.forms.IndexedOptionForm):
return ak.forms.IndexedForm(
akform.index,
Expand Down
6 changes: 3 additions & 3 deletions src/awkward/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,11 @@ def union_to_record(unionarray, anonymous):

contents = []
for layout in unionarray.contents:
if layout.is_IndexedType and not layout.is_OptionType:
if layout.is_indexed and not layout.is_option:
contents.append(layout.project())
elif layout.is_UnionType:
elif layout.is_union:
contents.append(union_to_record(layout, anonymous))
elif layout.is_OptionType:
elif layout.is_option:
contents.append(
ak.operations.fill_none(layout, np.nan, axis=0, highlevel=False)
)
Expand Down
6 changes: 3 additions & 3 deletions src/awkward/behaviors/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def _categorical_equal(one, two):

one, two = one.layout, two.layout

assert one.is_IndexedType or (one.is_OptionType and one.is_IndexedType)
assert two.is_IndexedType or (two.is_OptionType and two.is_IndexedType)
assert one.is_indexed or (one.is_option and one.is_indexed)
assert two.is_indexed or (two.is_option and two.is_indexed)
assert one.parameter("__array__") == "categorical"
assert two.parameter("__array__") == "categorical"

Expand Down Expand Up @@ -93,7 +93,7 @@ def _categorical_equal(one, two):
def _apply_ufunc(ufunc, method, inputs, kwargs):
nextinputs = []
for x in inputs:
if isinstance(x, ak.highlevel.Array) and x.layout.is_IndexedType:
if isinstance(x, ak.highlevel.Array) and x.layout.is_indexed:
nextinputs.append(
ak.highlevel.Array(x.layout.project(), behavior=ak._util.behavior_of(x))
)
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/contents/bitmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


class BitMaskedArray(Content):
is_OptionType = True
is_option = True

def copy(
self,
Expand Down Expand Up @@ -656,7 +656,7 @@ def continuation():
raise ak._errors.wrap_error(AssertionError(result))

def packed(self):
if self._content.is_RecordType:
if self._content.is_record:
next = self.toIndexedOptionArray64()

content = next._content.packed()
Expand Down
8 changes: 4 additions & 4 deletions src/awkward/contents/bytemaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


class ByteMaskedArray(Content):
is_OptionType = True
is_option = True

def copy(
self,
Expand Down Expand Up @@ -877,9 +877,9 @@ def _reduce_next(
if not branch and negaxis == depth:
return out
else:
if out.is_ListType:
if out.is_list:
out_content = out.content[out.starts[0] :]
elif out.is_RegularType:
elif out.is_regular:
out_content = out.content
else:
raise ak._errors.wrap_error(
Expand Down Expand Up @@ -1052,7 +1052,7 @@ def continuation():
raise ak._errors.wrap_error(AssertionError(result))

def packed(self):
if self._content.is_RecordType:
if self._content.is_record:
next = self.toIndexedOptionArray64()
content = next._content.packed()
if content.length > self._mask.length:
Expand Down
20 changes: 10 additions & 10 deletions src/awkward/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ def __repr__(self):


class Content:
is_NumpyType = False
is_UnknownType = False
is_ListType = False
is_RegularType = False
is_OptionType = False
is_IndexedType = False
is_RecordType = False
is_UnionType = False
is_numpy = False
is_unknown = False
is_list = False
is_regular = False
is_option = False
is_indexed = False
is_record = False
is_union = False

def _init(self, parameters: dict[str, Any] | None, nplike: NumpyLike | None):
if parameters is not None and not isinstance(parameters, dict):
Expand Down Expand Up @@ -739,7 +739,7 @@ def merge(self, other: Content) -> Content:

def mergeable(self, other: Content, mergebool: bool = True) -> bool:
# Is the other content is an identity, or a union?
if other.is_identity_like or other.is_UnionType:
if other.is_identity_like or other.is_union:
return True
# Otherwise, do the parameters match? If not, we can't merge.
elif not (
Expand Down Expand Up @@ -1663,7 +1663,7 @@ def _to_list(
def _to_list_custom(
self, behavior: dict | None, json_conversions: dict[str, Any] | None
):
if self.is_RecordType:
if self.is_record:
getitem = ak._util.recordclass(self, behavior).__getitem__
overloaded = getitem is not ak.highlevel.Record.__getitem__ and not getattr(
getitem, "ignore_in_to_list", False
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/contents/emptyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@


class EmptyArray(Content):
is_NumpyType = True
is_UnknownType = True
is_numpy = True
is_unknown = True

def copy(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/contents/indexedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


class IndexedArray(Content):
is_IndexedType = True
is_indexed = True

def copy(
self,
Expand Down
8 changes: 4 additions & 4 deletions src/awkward/contents/indexedoptionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@


class IndexedOptionArray(Content):
is_OptionType = True
is_IndexedType = True
is_option = True
is_indexed = True

def copy(
self,
Expand Down Expand Up @@ -1353,9 +1353,9 @@ def _reduce_next(
if not branch and negaxis == depth:
return out
else:
if out.is_ListType:
if out.is_list:
out_content = out.content[out.starts[0] :]
elif out.is_RegularType:
elif out.is_regular:
out_content = out.content
else:
raise ak._errors.wrap_error(
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/contents/listarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


class ListArray(Content):
is_ListType = True
is_list = True

def copy(
self,
Expand Down
8 changes: 4 additions & 4 deletions src/awkward/contents/listoffsetarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


class ListOffsetArray(Content):
is_ListType = True
is_list = True

def copy(
self,
Expand Down Expand Up @@ -1605,9 +1605,9 @@ def _reduce_next(
if keepdims and depth == negaxis + 1:
# Don't convert the `RegularArray()` to a `ListOffsetArray`,
# means this will be broadcastable
assert outcontent.is_RegularType
assert outcontent.is_regular
elif depth >= negaxis + 2:
assert outcontent.is_ListType or outcontent.is_RegularType
assert outcontent.is_list or outcontent.is_regular
outcontent = outcontent.toListOffsetArray64(False)

return ak.contents.ListOffsetArray(
Expand Down Expand Up @@ -1931,7 +1931,7 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
)

content_type = pyarrow.list_(paarray.type).value_field.with_nullable(
akcontent.is_OptionType
akcontent.is_option
)

if issubclass(npoffsets.dtype.type, np.int32):
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/contents/numpyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


class NumpyArray(Content):
is_NumpyType = True
is_numpy = True

def copy(
self,
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/contents/recordarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


class RecordArray(Content):
is_RecordType = True
is_record = True

def copy(
self,
Expand Down Expand Up @@ -860,7 +860,7 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
types = pyarrow.struct(
[
pyarrow.field(self.index_to_field(i), values[i].type).with_nullable(
x.is_OptionType
x.is_option
)
for i, x in enumerate(self._contents)
]
Expand Down
Loading

0 comments on commit 56837d9

Please sign in to comment.