Skip to content

Commit

Permalink
ENH: use _constructor properties to get factory function instead of t…
Browse files Browse the repository at this point in the history
…ype(self), for easier subclassing
  • Loading branch information
wesm committed Jul 20, 2011
1 parent 5ecc475 commit 8021575
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 33 deletions.
60 changes: 31 additions & 29 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,19 +217,19 @@ def astype(self, dtype):
-------
casted : DataFrame
"""
return type(self)(self._data, dtype=dtype)
return self._constructor(self._data, dtype=dtype)

def _wrap_array(self, arr, axes, copy=False):
index, columns = axes
return type(self)(arr, index=index, columns=columns, copy=copy)
return self._constructor(arr, index=index, columns=columns, copy=copy)

@property
def axes(self):
return [self.index, self.columns]

@property
def _constructor(self):
return type(self)
return DataFrame

#----------------------------------------------------------------------
# Class behavior
Expand Down Expand Up @@ -277,7 +277,7 @@ def copy(self):
"""
Make a copy of this DataFrame
"""
return type(self)(self._data.copy())
return self._constructor(self._data.copy())

#----------------------------------------------------------------------
# Arithmetic methods
Expand Down Expand Up @@ -574,8 +574,8 @@ def transpose(self):
Returns a DataFrame with the rows/columns switched. Copy of data is not
made by default
"""
return type(self)(data=self.values.T, index=self.columns,
columns=self.index, copy=False)
return self._constructor(data=self.values.T, index=self.columns,
columns=self.index, copy=False)
T = property(transpose)

#----------------------------------------------------------------------
Expand Down Expand Up @@ -651,8 +651,8 @@ def __array__(self):
return self.values

def __array_wrap__(self, result):
return type(self)(result, index=self.index, columns=self.columns,
copy=False)
return self._constructor(result, index=self.index, columns=self.columns,
copy=False)

#----------------------------------------------------------------------
# getitem/setitem related
Expand Down Expand Up @@ -682,7 +682,7 @@ def __getitem__(self, item):
"""
if isinstance(item, slice):
new_data = self._data.get_slice(item, axis=1)
return type(self)(new_data)
return self._constructor(new_data)
elif isinstance(item, np.ndarray):
if len(item) != len(self.index):
raise ValueError('Item wrong length %d instead of %d!' %
Expand Down Expand Up @@ -846,11 +846,11 @@ def _reindex_index(self, new_index, method):
if new_index is self.index:
return self.copy()
new_data = self._data.reindex_axis(new_index, method, axis=1)
return type(self)(new_data)
return self._constructor(new_data)

def _reindex_columns(self, new_columns):
new_data = self._data.reindex_items(new_columns)
return type(self)(new_data)
return self._constructor(new_data)

def reindex_like(self, other, method=None):
"""
Expand Down Expand Up @@ -1048,14 +1048,16 @@ def fillna(self, value=None, method='pad'):
series = self._series
for col, s in series.iteritems():
result[col] = s.fillna(method=method, value=value)
return type(self)(result, index=self.index, columns=self.columns)
return self._constructor(result, index=self.index,
columns=self.columns)
else:
# Float type values
if len(self.columns) == 0:
return self

new_data = self._data.fillna(value)
return type(self)(new_data, index=self.index, columns=self.columns)
return self._constructor(new_data, index=self.index,
columns=self.columns)

#----------------------------------------------------------------------
# Rename
Expand Down Expand Up @@ -1131,7 +1133,7 @@ def _combine_frame(self, other, func, fill_value=None):
# some shortcuts
if fill_value is None:
if not self and not other:
return type(self)(index=new_index)
return self._constructor(index=new_index)
elif not self:
return other * nan
elif not other:
Expand Down Expand Up @@ -1164,8 +1166,8 @@ def _combine_frame(self, other, func, fill_value=None):
other_vals[other_mask & mask] = fill_value

result = func(this_vals, other_vals)
return type(self)(result, index=new_index, columns=new_columns,
copy=False)
return self._constructor(result, index=new_index, columns=new_columns,
copy=False)

def _indexed_same(self, other):
same_index = self.index.equals(other.index)
Expand Down Expand Up @@ -1202,8 +1204,8 @@ def _combine_match_index(self, other, func, fill_value=None):
if fill_value is not None:
raise NotImplementedError

return type(self)(func(values.T, other_vals).T, index=new_index,
columns=self.columns, copy=False)
return self._constructor(func(values.T, other_vals).T, index=new_index,
columns=self.columns, copy=False)

def _combine_match_columns(self, other, func, fill_value=None):
newCols = self.columns.union(other.index)
Expand All @@ -1215,15 +1217,15 @@ def _combine_match_columns(self, other, func, fill_value=None):
if fill_value is not None:
raise NotImplementedError

return type(self)(func(this.values, other), index=self.index,
columns=newCols, copy=False)
return self._constructor(func(this.values, other), index=self.index,
columns=newCols, copy=False)

def _combine_const(self, other, func):
if not self:
return self

return type(self)(func(self.values, other), index=self.index,
columns=self.columns, copy=False)
return self._constructor(func(self.values, other), index=self.index,
columns=self.columns, copy=False)

def _compare_frame(self, other, func):
if not self._indexed_same(other):
Expand Down Expand Up @@ -1488,7 +1490,7 @@ def _shift_block(blk, indexer):
new_data = self._data.copy()
new_data.axes[1] = self.index.shift(periods, offset)

return type(self)(new_data)
return self._constructor(new_data)

def _shift_indexer(self, periods):
# small reusable utility
Expand Down Expand Up @@ -1535,8 +1537,8 @@ def apply(self, func, axis=0, broadcast=False):

if isinstance(func, np.ufunc):
results = func(self.values)
return type(self)(data=results, index=self.index,
columns=self.columns, copy=False)
return self._constructor(data=results, index=self.index,
columns=self.columns, copy=False)
else:
if not broadcast:
return self._apply_standard(func, axis)
Expand Down Expand Up @@ -1692,7 +1694,7 @@ def _join_on(self, other, on):
raise Exception('%s column not contained in this frame!' % on)

new_data = self._data.join_on(other._data, self[on], axis=1)
return type(self)(new_data)
return self._constructor(new_data)

def _join_index(self, other, how):
join_index = self._get_join_index(other, how)
Expand All @@ -1702,7 +1704,7 @@ def _join_index(self, other, how):
# merge blocks
merged_data = this_data.merge(other_data)
assert(merged_data.axes[1] is join_index) # maybe unnecessary
return type(self)(merged_data)
return self._constructor(merged_data)

def _get_join_index(self, other, how):
if how == 'left':
Expand Down Expand Up @@ -1797,7 +1799,7 @@ def corr(self):
correl[i, j] = c
correl[j, i] = c

return type(self)(correl, index=cols, columns=cols)
return self._constructor(correl, index=cols, columns=cols)

def corrwith(self, other, axis=0, drop=False):
"""
Expand Down Expand Up @@ -1867,7 +1869,7 @@ def describe(self):
tmp.quantile(.1), tmp.median(),
tmp.quantile(.9), tmp.max()]

return type(self)(data, index=cols_destat, columns=cols)
return self._constructor(data, index=cols_destat, columns=cols)

#----------------------------------------------------------------------
# ndarray-like stats methods
Expand Down
8 changes: 6 additions & 2 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ class NDFrame(Picklable):
def __init__(self, data, axes=None, copy=False):
self._data = data

@property
def _constructor(self):
return NDFrame

@property
def axes(self):
return self._data.axes
Expand Down Expand Up @@ -89,7 +93,7 @@ def consolidate(self):
cons_data = self._data.consolidate()
if cons_data is self._data:
cons_data = cons_data.copy()
return type(self)(cons_data)
return self._constructor(cons_data)

@property
def _is_mixed_type(self):
Expand Down Expand Up @@ -174,7 +178,7 @@ def _reindex_axis(self, new_index, fill_method, axis):
else:
new_data = self._data.reindex_axis(new_index, axis=axis,
method=fill_method)
return type(self)(new_data)
return self._constructor(new_data)

def truncate(self, before=None, after=None):
"""Function truncate a sorted DataFrame / Series before and/or after
Expand Down
8 changes: 6 additions & 2 deletions pandas/core/panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,14 @@ def _get_plane_axes(self, axis):

return index, columns

@property
def _constructor(self):
return WidePanel

def _wrap_array(self, arr, axes, copy=False):
items, major, minor = axes
return type(self)(arr, items=items, major_axis=major,
minor_axis=minor, copy=copy)
return self._constructor(arr, items=items, major_axis=major,
minor_axis=minor, copy=copy)

def copy(self):
"""
Expand Down

0 comments on commit 8021575

Please sign in to comment.