diff --git a/docs/source/traits_api_reference/trait_numeric.rst b/docs/source/traits_api_reference/trait_numeric.rst index a2a4b2818..f7ecb41b3 100644 --- a/docs/source/traits_api_reference/trait_numeric.rst +++ b/docs/source/traits_api_reference/trait_numeric.rst @@ -11,6 +11,8 @@ Classes .. autoclass:: Array +.. autoclass:: ArrayOrNone + .. autoclass:: CArray Function diff --git a/docs/source/traits_user_manual/defining.rst b/docs/source/traits_user_manual/defining.rst index 793fcd51d..a9534184f 100644 --- a/docs/source/traits_user_manual/defining.rst +++ b/docs/source/traits_user_manual/defining.rst @@ -258,6 +258,9 @@ the table. | Array | Array( [*dtype* = None, *shape* = None, *value* = None, | | | *typecode* = None, \*\*\ *metadata*] ) | +------------------+----------------------------------------------------------+ +| ArrayOrNone | ArrayOrNone( [*dtype* = None, *shape* = None, | +| | *value* = None, *typecode* = None, \*\*\ *metadata*] ) | ++------------------+----------------------------------------------------------+ | Button | Button( [*label* = '', *image* = None, *style* = | | | 'button', *orientation* = 'vertical', *width_padding* = | | | 7, *height_padding* = 5, \*\*\ *metadata*] ) | diff --git a/traits/api.py b/traits/api.py index 211127a13..bc87a06f7 100644 --- a/traits/api.py +++ b/traits/api.py @@ -91,7 +91,7 @@ from .adaptation.adaptation_manager import adapt, register_factory, \ register_provides -from .trait_numeric import Array, CArray +from .trait_numeric import Array, ArrayOrNone, CArray try: from . import has_traits as has_traits diff --git a/traits/tests/test_array_or_none.py b/traits/tests/test_array_or_none.py new file mode 100644 index 000000000..b1cfeb92f --- /dev/null +++ b/traits/tests/test_array_or_none.py @@ -0,0 +1,176 @@ +#------------------------------------------------------------------------------ +# +# Copyright (c) 2014, Enthought, Inc. +# All rights reserved. +# +# This software is provided without warranty under the terms of the BSD +# license included in /LICENSE.txt and may be redistributed only +# under the conditions described in the aforementioned license. The license +# is also available online at http://www.enthought.com/licenses/BSD.txt +# +# Thanks for using Enthought open source! +# +#------------------------------------------------------------------------------ +""" +Tests for the ArrayOrNone TraitType. + +""" + +from __future__ import absolute_import + +from traits.testing.unittest_tools import unittest + +try: + import numpy +except ImportError: + numpy_available = False +else: + numpy_available = True + +from traits.testing.unittest_tools import UnittestTools +from ..api import ArrayOrNone, HasTraits, NO_COMPARE, TraitError + + +if numpy_available: + # Use of `ArrayOrNone` requires NumPy to be installed. + + class Foo(HasTraits): + maybe_array = ArrayOrNone + + maybe_float_array = ArrayOrNone(dtype=float) + + maybe_two_d_array = ArrayOrNone(shape=(None, None)) + + maybe_array_with_default = ArrayOrNone(value=[1, 2, 3]) + + maybe_array_no_compare = ArrayOrNone(comparison_mode=NO_COMPARE) + + +@unittest.skipUnless(numpy_available, "numpy not available") +class TestArrayOrNone(unittest.TestCase, UnittestTools): + """ + Tests for the ArrayOrNone TraitType. + + """ + def test_default(self): + foo = Foo() + self.assertIsNone(foo.maybe_array) + + def test_explicit_default(self): + foo = Foo() + self.assertIsInstance(foo.maybe_array_with_default, numpy.ndarray) + + def test_default_validation(self): + # CArray and Array validate the default at class creation time; + # we do the same for ArrayOrNone. + with self.assertRaises(TraitError): + class Bar(HasTraits): + bad_array = ArrayOrNone(shape=(None, None), value=[1, 2, 3]) + + def test_setting_array_from_array(self): + foo = Foo() + test_array = numpy.arange(5) + foo.maybe_array = test_array + output_array = foo.maybe_array + self.assertIsInstance(output_array, numpy.ndarray) + self.assertEqual(output_array.dtype, test_array.dtype) + self.assertEqual(output_array.shape, test_array.shape) + self.assertTrue((output_array == test_array).all()) + + def test_setting_array_from_list(self): + foo = Foo() + test_list = [5, 6, 7, 8, 9] + foo.maybe_array = test_list + output_array = foo.maybe_array + self.assertIsInstance(output_array, numpy.ndarray) + self.assertEqual(output_array.dtype, numpy.dtype(int)) + self.assertEqual(output_array.shape, (5,)) + self.assertTrue((output_array == test_list).all()) + + def test_setting_array_from_none(self): + foo = Foo() + test_array = numpy.arange(5) + + self.assertIsNone(foo.maybe_array) + foo.maybe_array = test_array + self.assertIsInstance(foo.maybe_array, numpy.ndarray) + foo.maybe_array = None + self.assertIsNone(foo.maybe_array) + + def test_dtype(self): + foo = Foo() + foo.maybe_float_array = [1, 2, 3] + + array_value = foo.maybe_float_array + self.assertIsInstance(array_value, numpy.ndarray) + self.assertEqual(array_value.dtype, numpy.dtype(float)) + + def test_shape(self): + foo = Foo() + with self.assertRaises(TraitError): + foo.maybe_two_d_array = [1, 2, 3] + + def test_change_notifications(self): + foo = Foo() + test_array = numpy.arange(-7, -2) + different_test_array = numpy.arange(10) + + # Assigning None to something that's already None shouldn't fire. + with self.assertTraitDoesNotChange(foo, 'maybe_array'): + foo.maybe_array = None + + # Changing from None to an array: expect an event. + with self.assertTraitChanges(foo, 'maybe_array'): + foo.maybe_array = test_array + + # No event from assigning the same array again. + with self.assertTraitDoesNotChange(foo, 'maybe_array'): + foo.maybe_array = test_array + + # But assigning a new array fires an event. + with self.assertTraitChanges(foo, 'maybe_array'): + foo.maybe_array = different_test_array + + # No event even if the array is modified in place. + different_test_array += 2 + with self.assertTraitDoesNotChange(foo, 'maybe_array'): + foo.maybe_array = different_test_array + + # Set back to None; we should get an event. + with self.assertTraitChanges(foo, 'maybe_array'): + foo.maybe_array = None + + def test_comparison_mode_override(self): + foo = Foo() + test_array = numpy.arange(-7, 2) + + with self.assertTraitChanges(foo, 'maybe_array_no_compare'): + foo.maybe_array_no_compare = None + + with self.assertTraitChanges(foo, 'maybe_array_no_compare'): + foo.maybe_array_no_compare = test_array + + with self.assertTraitChanges(foo, 'maybe_array_no_compare'): + foo.maybe_array_no_compare = test_array + + def test_default_value_copied(self): + # Check that we don't share defaults. + test_default = numpy.arange(100.0, 110.0) + + class FooBar(HasTraits): + foo = ArrayOrNone(value=test_default) + + bar = ArrayOrNone(value=test_default) + + foo_bar = FooBar() + + self.assertTrue((foo_bar.foo == test_default).all()) + self.assertTrue((foo_bar.bar == test_default).all()) + + test_default += 2.0 + self.assertFalse((foo_bar.foo == test_default).all()) + self.assertFalse((foo_bar.bar == test_default).all()) + + foo = foo_bar.foo + foo += 1729.0 + self.assertFalse((foo_bar.foo == foo_bar.bar).all()) diff --git a/traits/trait_numeric.py b/traits/trait_numeric.py index 9bbde32c3..577697e4b 100644 --- a/traits/trait_numeric.py +++ b/traits/trait_numeric.py @@ -80,7 +80,7 @@ def __init__ ( self, dtype = None, shape = None, value = None, raise TraitError( "Using Array or CArray trait types requires the " "numpy package to be installed." ) - from numpy import array, asarray, ndarray, zeros + from numpy import asarray, ndarray # Mark this as being an 'array' trait: metadata[ 'array' ] = True @@ -122,23 +122,7 @@ def __init__ ( self, dtype = None, shape = None, value = None, raise TraitError, "shape should be a list or tuple" if value is None: - if dtype is None: - # Compatibility with the default of Traits 2.0 - dt = int - else: - dt = dtype - if shape is None: - value = zeros( ( 0, ), dt ) - else: - size = [] - for item in shape: - if item is None: - item = 1 - elif type( item ) in SequenceTypes: - # XXX: what is this supposed to do? - item = item[0] - size.append( item ) - value = zeros( size, dt ) + value = self._default_for_dtype_and_shape( dtype, shape ) self.dtype = dtype self.shape = shape @@ -259,6 +243,31 @@ def copy_default_value ( self, value ): """ return value.copy() + def _default_for_dtype_and_shape ( self, dtype, shape ): + """ Invent a suitable default value for a given dtype and shape. """ + from numpy import zeros + + if dtype is None: + # Compatibility with the default of Traits 2.0 + dt = int + else: + dt = dtype + if shape is None: + value = zeros( ( 0, ), dt ) + else: + size = [] + for item in shape: + if item is None: + item = 1 + elif type( item ) in SequenceTypes: + # Given a (minimum-allowed-length, maximum-allowed_length) + # pair for a particular axis, use the minimum. + item = item[0] + size.append( item ) + value = zeros( size, dt ) + return value + + #------------------------------------------------------------------------------- # 'Array' trait: #------------------------------------------------------------------------------- @@ -354,3 +363,34 @@ def __init__ ( self, dtype = None, shape = None, value = None, super( CArray, self ).__init__( dtype, shape, value, True, typecode = typecode, **metadata ) + +#------------------------------------------------------------------------------- +# 'ArrayOrNone' trait +#------------------------------------------------------------------------------- + +class ArrayOrNone ( CArray ): + """ A trait whose value may be either a NumPy array or None, with + casting allowed. The default is None. + """ + def __init__ ( self, *args, **metadata ): + # Normally use object identity to detect array values changing: + metadata.setdefault( 'comparison_mode', OBJECT_IDENTITY_COMPARE ) + super( ArrayOrNone, self ).__init__( *args, **metadata ) + + def validate (self, object, name, value ): + if value is None: + return value + return super( ArrayOrNone, self ).validate( object, name, value ) + + def get_default_value ( self ): + dv = self.default_value + if dv is None: + return ( 0, dv ) + else: + return ( 7, ( self.copy_default_value, + ( self.validate( None, None, dv ), ), None ) ) + + def _default_for_dtype_and_shape ( self, dtype, shape ): + # For ArrayOrNone, if no default is explicitly specified, we + # always default to `None`. + return None