Skip to content

Commit

Permalink
ak.concatenate should preserve regular-type for axis>0, too. (#1609)
Browse files: browse the repository at this point in the history
  • Loading branch information
jpivarski authored Aug 19, 2022
1 parent 8af6e2c commit b2ee4ab
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 7 deletions.
4 changes: 4 additions & 0 deletions src/awkward/_v2/_typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,10 @@ def repeat(self, *args, **kwargs):
# array1, array2
raise ak._v2._util.error(NotImplementedError)

def tile(self, *args, **kwargs):
    """Stub for ``tile(array, int)``; unconditionally raises.

    The raised object is whatever ``ak._v2._util.error`` returns when it is
    given ``NotImplementedError``. No argument is inspected.
    """
    # array, int
    failure = ak._v2._util.error(NotImplementedError)
    raise failure

def stack(self, *args, **kwargs):
# arrays
raise ak._v2._util.error(NotImplementedError)
Expand Down
65 changes: 59 additions & 6 deletions src/awkward/_v2/operations/ak_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def _impl(arrays, axis, merge, mergebool, highlevel, behavior):
else:

def action(inputs, depth, **kwargs):

if depth == posaxis and any(
isinstance(x, ak._v2.contents.Content) and x.is_OptionType
for x in inputs
Expand All @@ -124,19 +123,72 @@ def action(inputs, depth, **kwargs):
nextinputs.append(x)
inputs = nextinputs

if depth == posaxis:
nplike = ak.nplike.of(*inputs)

length = ak._v2._typetracer.UnknownLength
for x in inputs:
if isinstance(x, ak._v2.contents.Content):
if not ak._v2._util.isint(length):
length = x.length
elif length != x.length and ak._v2._util.isint(x.length):
raise ak._v2._util.error(
ValueError(
"all arrays must have the same length for "
"axis={}".format(axis)
)
)

if depth == posaxis and all(
isinstance(x, ak._v2.contents.Content)
and x.is_ListType
and x.is_RegularType
or (isinstance(x, ak._v2.contents.NumpyArray) and x.data.ndim > 1)
or not isinstance(x, ak._v2.contents.Content)
for x in inputs
):
regulararrays = []
sizes = []
for x in inputs:
if isinstance(x, ak._v2.contents.RegularArray):
regulararrays.append(x)
elif isinstance(x, ak._v2.contents.NumpyArray):
regulararrays.append(x.toRegularArray())
else:
regulararrays.append(
ak._v2.contents.RegularArray(
ak._v2.contents.NumpyArray(
nplike.broadcast_to(nplike.array([x]), (length,))
),
1,
)
)
sizes.append(regulararrays[-1].size)

nplike = ak.nplike.of(*inputs)
prototype = nplike.empty(sum(sizes), np.int8)
start = 0
for tag, size in enumerate(sizes):
prototype[start : start + size] = tag
start += size

length = max(
len(x) for x in inputs if isinstance(x, ak._v2.contents.Content)
tags = ak._v2.index.Index8(nplike.tile(prototype, length))
index = ak._v2.contents.UnionArray.regular_index(tags)
inner = ak._v2.contents.UnionArray(
tags, index, [x._content for x in regulararrays]
)

out = ak._v2.contents.RegularArray(
inner.simplify_uniontype(merge=merge, mergebool=mergebool),
len(prototype),
)
return (out,)

elif depth == posaxis and all(
isinstance(x, ak._v2.contents.Content)
and x.is_ListType
or (isinstance(x, ak._v2.contents.NumpyArray) and x.data.ndim > 1)
or not isinstance(x, ak._v2.contents.Content)
for x in inputs
):
nextinputs = []
for x in inputs:
if isinstance(x, ak._v2.contents.Content):
Expand Down Expand Up @@ -190,7 +242,8 @@ def action(inputs, depth, **kwargs):
inner = ak._v2.contents.UnionArray(tags, index, all_flatten)

out = ak._v2.contents.ListOffsetArray(
offsets, inner.simplify_uniontype(merge=merge, mergebool=mergebool)
offsets,
inner.simplify_uniontype(merge=merge, mergebool=mergebool),
)

return (out,)
Expand Down
4 changes: 4 additions & 0 deletions src/awkward/nplike.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,10 @@ def repeat(self, *args, **kwargs):
# array1, array2
return self._module.repeat(*args, **kwargs)

def tile(self, *args, **kwargs):
    """Forward every argument to ``tile`` of the wrapped ``self._module``.

    Positional and keyword arguments are passed through untouched, and the
    wrapped function's return value is returned as-is.
    """
    # array, int
    backend_tile = self._module.tile
    return backend_tile(*args, **kwargs)

def stack(self, *args, **kwargs):
# arrays
return self._module.stack(*args, **kwargs)
Expand Down
86 changes: 85 additions & 1 deletion tests/v2/test_1586-concatenate-should-preserve-regulararray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np # noqa: F401
import awkward as ak # noqa: F401

from awkward._v2.types import ArrayType, RegularType, OptionType, NumpyType
from awkward._v2.types import ArrayType, ListType, RegularType, OptionType, NumpyType


def test_simple():
Expand Down Expand Up @@ -76,3 +76,87 @@ def test_option_option():
[8.8, 9.9],
]
assert c.type == ArrayType(OptionType(RegularType(NumpyType("float64"), 2)), 7)


def test_regular_numpy():
    """A to_regular'd JSON array followed by a 2-D NumpyArray, default axis.

    The concatenation is expected to stay regular: five rows of size 2.
    """
    regular = ak._v2.to_regular(
        ak._v2.from_json("[[0.0, 1.1], [2.2, 3.3]]"), axis=1
    )
    from_numpy = ak._v2.Array(np.array([[4.4, 5.5], [6.6, 7.7], [8.8, 9.9]]))
    # The second operand must really be backed by a NumpyArray layout.
    assert isinstance(from_numpy.layout, ak._v2.contents.NumpyArray)

    joined = ak._v2.concatenate([regular, from_numpy])

    expected_rows = [[0.0, 1.1], [2.2, 3.3], [4.4, 5.5], [6.6, 7.7], [8.8, 9.9]]
    assert joined.tolist() == expected_rows
    assert joined.type == ArrayType(RegularType(NumpyType("float64"), 2), 5)


def test_numpy_regular():
    """A 2-D NumpyArray followed by a to_regular'd JSON array, default axis.

    Mirror image of the regular-then-numpy case; the result is expected to be
    five regular rows of size 2.
    """
    from_numpy = ak._v2.Array(np.array([[0.0, 1.1], [2.2, 3.3]]))
    # The first operand must really be backed by a NumpyArray layout.
    assert isinstance(from_numpy.layout, ak._v2.contents.NumpyArray)
    regular = ak._v2.to_regular(
        ak._v2.from_json("[[4.4, 5.5], [6.6, 7.7], [8.8, 9.9]]"), axis=1
    )

    joined = ak._v2.concatenate([from_numpy, regular])

    expected_rows = [[0.0, 1.1], [2.2, 3.3], [4.4, 5.5], [6.6, 7.7], [8.8, 9.9]]
    assert joined.tolist() == expected_rows
    assert joined.type == ArrayType(RegularType(NumpyType("float64"), 2), 5)


def test_regular_regular_axis1():
    """Two to_regular'd arrays (sizes 2 and 3) concatenated along axis=1.

    The result is expected to be regular with size 2 + 3 = 5 and length 2.
    """
    left, right = (
        ak._v2.to_regular(ak._v2.from_json(text), axis=1)
        for text in (
            "[[0.0, 1.1], [2.2, 3.3]]",
            "[[4.4, 5.5, 6.6], [7.7, 8.8, 9.9]]",
        )
    )

    joined = ak._v2.concatenate([left, right], axis=1)

    expected_rows = [[0.0, 1.1, 4.4, 5.5, 6.6], [2.2, 3.3, 7.7, 8.8, 9.9]]
    assert joined.tolist() == expected_rows
    assert joined.type == ArrayType(RegularType(NumpyType("float64"), 5), 2)


def test_option_regular_axis1():
    """axis=1 concatenation where the first operand has a null row.

    The null row is expected to contribute no elements, so that row holds only
    the second operand's values and the result type is a variable-length
    ListType rather than a RegularType.
    """
    with_null = ak._v2.to_regular(
        ak._v2.from_json("[[0.0, 1.1], null, [2.2, 3.3]]"), axis=1
    )
    complete = ak._v2.to_regular(
        ak._v2.from_json("[[4.4, 5.5, 6.6], [7, 8, 9], [7.7, 8.8, 9.9]]"), axis=1
    )

    joined = ak._v2.concatenate([with_null, complete], axis=1)

    expected_rows = [
        [0.0, 1.1, 4.4, 5.5, 6.6],
        [7, 8, 9],
        [2.2, 3.3, 7.7, 8.8, 9.9],
    ]
    assert joined.tolist() == expected_rows
    assert joined.type == ArrayType(ListType(NumpyType("float64")), 3)


def test_regular_option_axis1():
    """axis=1 concatenation where the second operand has a null row.

    The null row is expected to contribute no elements, so that row holds only
    the first operand's values and the result type is a variable-length
    ListType rather than a RegularType.
    """
    complete = ak._v2.to_regular(
        ak._v2.from_json("[[0.0, 1.1], [7, 8], [2.2, 3.3]]"), axis=1
    )
    with_null = ak._v2.to_regular(
        ak._v2.from_json("[[4.4, 5.5, 6.6], null, [7.7, 8.8, 9.9]]"), axis=1
    )

    joined = ak._v2.concatenate([complete, with_null], axis=1)

    expected_rows = [
        [0.0, 1.1, 4.4, 5.5, 6.6],
        [7, 8],
        [2.2, 3.3, 7.7, 8.8, 9.9],
    ]
    assert joined.tolist() == expected_rows
    assert joined.type == ArrayType(ListType(NumpyType("float64")), 3)


def test_option_option_axis1():
    """axis=1 concatenation where both operands have a null in the same row.

    That row is expected to come out as an empty list, and the result type is
    a variable-length ListType.
    """
    first, second = (
        ak._v2.to_regular(ak._v2.from_json(text), axis=1)
        for text in (
            "[[0.0, 1.1], null, [2.2, 3.3]]",
            "[[4.4, 5.5, 6.6], null, [7.7, 8.8, 9.9]]",
        )
    )

    joined = ak._v2.concatenate([first, second], axis=1)

    expected_rows = [
        [0.0, 1.1, 4.4, 5.5, 6.6],
        [],
        [2.2, 3.3, 7.7, 8.8, 9.9],
    ]
    assert joined.tolist() == expected_rows
    assert joined.type == ArrayType(ListType(NumpyType("float64")), 3)


def test_regular_numpy_axis1():
    """A to_regular'd array (size 2) and a 2-D NumpyArray (3 columns), axis=1.

    The result is expected to be regular with size 5 and length 2.
    """
    regular = ak._v2.to_regular(
        ak._v2.from_json("[[0.0, 1.1], [2.2, 3.3]]"), axis=1
    )
    from_numpy = ak._v2.Array(np.array([[4.4, 5.5, 6.6], [7.7, 8.8, 9.9]]))
    # The second operand must really be backed by a NumpyArray layout.
    assert isinstance(from_numpy.layout, ak._v2.contents.NumpyArray)

    joined = ak._v2.concatenate([regular, from_numpy], axis=1)

    expected_rows = [[0.0, 1.1, 4.4, 5.5, 6.6], [2.2, 3.3, 7.7, 8.8, 9.9]]
    assert joined.tolist() == expected_rows
    assert joined.type == ArrayType(RegularType(NumpyType("float64"), 5), 2)


def test_numpy_regular_axis1():
    """A 2-D NumpyArray (2 columns) and a to_regular'd array (size 3), axis=1.

    The result is expected to be regular with size 5 and length 2.
    """
    from_numpy = ak._v2.Array(np.array([[0.0, 1.1], [2.2, 3.3]]))
    # The first operand must really be backed by a NumpyArray layout.
    assert isinstance(from_numpy.layout, ak._v2.contents.NumpyArray)
    regular = ak._v2.to_regular(
        ak._v2.from_json("[[4.4, 5.5, 6.6], [7.7, 8.8, 9.9]]"), axis=1
    )

    joined = ak._v2.concatenate([from_numpy, regular], axis=1)

    expected_rows = [[0.0, 1.1, 4.4, 5.5, 6.6], [2.2, 3.3, 7.7, 8.8, 9.9]]
    assert joined.tolist() == expected_rows
    assert joined.type == ArrayType(RegularType(NumpyType("float64"), 5), 2)

0 comments on commit b2ee4ab

Please sign in to comment.