55
66from distutils .version import LooseVersion
77import numpy as np
8+ import pytz
89import pandas as pd
910
1011from xray import Variable , Dataset , DataArray
@@ -153,7 +154,7 @@ def test_0d_time_data(self):
153154 def test_datetime64_conversion (self ):
154155 times = pd .date_range ('2000-01-01' , periods = 3 )
155156 for values , preserve_source in [
156- (times , False ),
157+ (times , True ),
157158 (times .values , True ),
158159 (times .values .astype ('datetime64[s]' ), False ),
159160 (times .to_pydatetime (), False ),
@@ -163,15 +164,12 @@ def test_datetime64_conversion(self):
163164 self .assertArrayEqual (v .values , times .values )
164165 self .assertEqual (v .values .dtype , np .dtype ('datetime64[ns]' ))
165166 same_source = source_ndarray (v .values ) is source_ndarray (values )
166- if preserve_source and self .cls is Variable :
167- self .assertTrue (same_source )
168- else :
169- self .assertFalse (same_source )
167+ assert preserve_source == same_source
170168
171169 def test_timedelta64_conversion (self ):
172170 times = pd .timedelta_range (start = 0 , periods = 3 )
173171 for values , preserve_source in [
174- (times , False ),
172+ (times , True ),
175173 (times .values , True ),
176174 (times .values .astype ('timedelta64[s]' ), False ),
177175 (times .to_pytimedelta (), False ),
@@ -181,10 +179,7 @@ def test_timedelta64_conversion(self):
181179 self .assertArrayEqual (v .values , times .values )
182180 self .assertEqual (v .values .dtype , np .dtype ('timedelta64[ns]' ))
183181 same_source = source_ndarray (v .values ) is source_ndarray (values )
184- if preserve_source and self .cls is Variable :
185- self .assertTrue (same_source )
186- else :
187- self .assertFalse (same_source )
182+ assert preserve_source == same_source
188183
189184 def test_object_conversion (self ):
190185 data = np .arange (5 ).astype (str ).astype (object )
@@ -405,6 +400,22 @@ def test_aggregate_complex(self):
405400 expected = Variable ((), 0.5 + 1j )
406401 self .assertVariableAllClose (v .mean (), expected )
407402
403+ def test_pandas_cateogrical_dtype (self ):
404+ data = pd .Categorical (np .arange (10 , dtype = 'int64' ))
405+ v = self .cls ('x' , data )
406+ print (v ) # should not error
407+ assert v .dtype == 'int64'
408+
409+ def test_pandas_datetime64_with_tz (self ):
410+ data = pd .date_range (start = '2000-01-01' ,
411+ tz = pytz .timezone ('America/New_York' ),
412+ periods = 10 , freq = '1h' )
413+ v = self .cls ('x' , data )
414+ print (v ) # should not error
415+ if 'America/New_York' in data .dtype :
416+ # pandas is new enough that it has datetime64 with timezone dtype
417+ assert v .dtype == 'object'
418+
408419
409420class TestVariable (TestCase , VariableSubclassTestCases ):
410421 cls = staticmethod (Variable )
0 commit comments