Skip to content

Commit

Permalink
Require EAs to implement _from_sequence_of_strings to be used in parsers
Browse files Browse the repository at this point in the history
  • Loading branch information
kprestel committed Dec 21, 2018
1 parent cfd55fe commit 4937b22
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 7 deletions.
8 changes: 6 additions & 2 deletions pandas/_libs/parsers.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1223,8 +1223,12 @@ cdef class TextReader:
result = dtype.construct_array_type() \
._from_sequence_of_strings(result, dtype=dtype)
except NotImplementedError:
result = dtype.construct_array_type() \
._from_sequence(result, dtype=dtype)
raise NotImplementedError(
"Extension Array: {ea} must implement "
"_from_sequence_of_strings in order "
"to be used in parser methods".format(
ea=dtype.construct_array_type()))

return result, na_count

elif is_integer_dtype(dtype):
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):

@classmethod
def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
"""Construct a new ExtensionArray from a sequence of scalars.
"""Construct a new ExtensionArray from a sequence of strings.
.. versionadded:: 0.24.0
Expand All @@ -145,7 +145,7 @@ def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
ExtensionArray
"""
raise AbstractMethodError(cls)
raise NotImplementedError(cls)

@classmethod
def _from_factorized(cls, values, original):
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):

@classmethod
def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
scalars = to_numeric(strings, errors='raise')
scalars = to_numeric(strings, errors="raise")
return cls._from_sequence(scalars, dtype, copy)

@classmethod
Expand Down
1 change: 0 additions & 1 deletion pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import numpy as np

from pandas.errors import AbstractMethodError
from pandas._libs import lib, tslib, tslibs
from pandas._libs.tslibs import OutOfBoundsDatetime, Period, iNaT
from pandas.compat import PY3, string_types, text_type, to_str
Expand Down
1 change: 0 additions & 1 deletion pandas/tests/extension/base/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,3 @@ def test_EA_types(self, engine, data):
result = pd.read_csv(StringIO(data), dtype={'Int': str(data.dtype)},
engine=engine)
assert result is not None

40 changes: 40 additions & 0 deletions pandas/tests/extension/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,27 @@
import pandas as pd
from pandas.core.arrays import ExtensionArray
import pandas.util.testing as tm
from pandas.compat import StringIO
from pandas.core.arrays.integer import (
Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype, UInt8Dtype, UInt16Dtype,
UInt32Dtype, UInt64Dtype, integer_array,
)


def make_data():
return (list(range(1, 9)) + [np.nan] + list(range(10, 98))
+ [np.nan] + [99, 100])


@pytest.fixture(params=[Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype,
UInt8Dtype, UInt16Dtype, UInt32Dtype, UInt64Dtype])
def dtype(request):
return request.param()


@pytest.fixture
def data(dtype):
return integer_array(make_data(), dtype=dtype)


class DummyDtype(dtypes.ExtensionDtype):
Expand Down Expand Up @@ -92,3 +113,22 @@ def test_is_not_extension_array_dtype(dtype):
def test_is_extension_array_dtype(dtype):
assert isinstance(dtype, dtypes.ExtensionDtype)
assert is_extension_array_dtype(dtype)


@pytest.mark.parametrize('engine', ['c', 'python'])
def test_EA_types(engine):
df = pd.DataFrame({'Int': pd.Series([1, 2, 3], dtype='Int64'),
'A': [1, 2, 1]})
data = df.to_csv(index=False)
result = pd.read_csv(StringIO(data), dtype={'Int': Int64Dtype},
engine=engine)
assert result is not None
tm.assert_frame_equal(result, df)

df = pd.DataFrame({'Int': pd.Series([1, 2, 3], dtype='Int8'),
'A': [1, 2, 1]})
data = df.to_csv(index=False)
result = pd.read_csv(StringIO(data), dtype={'Int': 'Int8'},
engine=engine)
assert result is not None
tm.assert_frame_equal(result, df)

0 comments on commit 4937b22

Please sign in to comment.