Skip to content

Commit

Permalink
Merge pull request #507 from drdavella/ndarray-broadcast
Browse files Browse the repository at this point in the history
Allow serialization of broadcasted ndarrays
  • Loading branch information
drdavella authored May 24, 2018
2 parents e435a56 + 7bf5a1f commit a5a4670
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 7 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@

- Add API function for retrieving history entries. [#501]

2.0.2 (unreleased)
------------------

- Allow serialization of broadcasted ``numpy`` arrays. [#507]

2.0.1 (2018-05-08)
------------------

Expand Down
4 changes: 2 additions & 2 deletions asdf/generic_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def write(self, content):
"""

def write_array(self, array):
_array_tofile(None, self.write, array)
_array_tofile(None, self.write, np.ascontiguousarray(array))

def seek(self, offset, whence=0):
"""
Expand Down Expand Up @@ -751,7 +751,7 @@ def write_array(self, arr):
arr.flush()
self.fast_forward(len(arr.data))
else:
_array_tofile(self._fd, self._fd.write, arr)
_array_tofile(self._fd, self._fd.write, np.ascontiguousarray(arr))

def can_memmap(self):
return True
Expand Down
15 changes: 10 additions & 5 deletions asdf/tags/core/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,14 +386,19 @@ def reserve_blocks(cls, data, ctx):
@classmethod
def to_tree(cls, data, ctx):
base = util.get_array_base(data)
block = ctx.blocks.find_or_create_block_for_array(data, ctx)
shape = data.shape
dtype = data.dtype
offset = data.ctypes.data - base.ctypes.data
if data.flags[b'C_CONTIGUOUS']:
strides = None
else:
strides = data.strides
strides = None

if not data.flags.c_contiguous:
# We do not want to encode strides for broadcasted arrays
if not all(data.strides):
data = np.ascontiguousarray(data)
else:
strides = data.strides

block = ctx.blocks.find_or_create_block_for_array(data, ctx)

result = {}

Expand Down
6 changes: 6 additions & 0 deletions asdf/tags/core/tests/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,3 +803,9 @@ def test_tagged_object_array(tmpdir):
objdata.flat[i] = Quantity(i, 'angstrom')

helpers.assert_roundtrip_tree({'bizbaz': objdata}, tmpdir)


def test_broadcasted_array(tmpdir):
attrs = np.broadcast_arrays(np.array([10,20]), np.array(10), np.array(10))
tree = {'one': attrs[1] }#, 'two': attrs[1], 'three': attrs[2]}
helpers.assert_roundtrip_tree(tree, tmpdir)

0 comments on commit a5a4670

Please sign in to comment.