diff --git a/pandas/core/base.py b/pandas/core/base.py index f55d9f905945d..40ca70ec7973a 100644 --- a/pandas/core/base.py +++ b/pandas/core/base.py @@ -93,7 +93,7 @@ def _ensure_type(self: T, obj) -> T: Used by type checkers. """ - assert isinstance(obj, type(self)), type(obj) + assert issubclass(type(obj), type(self)), type(obj) return obj diff --git a/pandas/tests/base/test_base.py b/pandas/tests/base/test_base.py new file mode 100644 index 0000000000000..22eb4f03653f6 --- /dev/null +++ b/pandas/tests/base/test_base.py @@ -0,0 +1,38 @@ +import pytest + +from pandas.core.base import PandasObject + +pandas_object = PandasObject() + + +class SubclassPandasObject(PandasObject): + pass + + +subclass_pandas_object = SubclassPandasObject() + + +@pytest.mark.parametrize("other_object", [pandas_object, subclass_pandas_object]) +def test_pandas_object_ensure_type(other_object): + pandas_object = PandasObject() + assert pandas_object._ensure_type(other_object) + + +def test_pandas_object_ensure_type_for_same_object(): + pandas_object_a = PandasObject() + pandas_object_b = pandas_object_a + assert pandas_object_a._ensure_type(pandas_object_b) + + +class OtherClass: + pass + + +other_class = OtherClass() + + +@pytest.mark.parametrize("other_object", [other_class]) +def test_pandas_object_ensure_type_for_false(other_object): + pandas_object = PandasObject() + with pytest.raises(AssertionError): + assert pandas_object._ensure_type(other_object)