diff --git a/tests/unit/schema/test_schema.py b/tests/unit/schema/test_schema.py index 1a648c257..239f89441 100644 --- a/tests/unit/schema/test_schema.py +++ b/tests/unit/schema/test_schema.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import dataclasses + import pytest from merlin.dag import ColumnSelector @@ -157,8 +159,11 @@ def test_schema_to_pandas(): schema_set = Schema(["a", "b", "c"]) df = schema_set.to_pandas() + expected_columns = [field.name for field in dataclasses.fields(ColumnSchema)] + expected_columns.remove("properties") + assert isinstance(df, pd.DataFrame) - assert list(df.columns) == ["name", "tags", "dtype", "is_list", "is_ragged"] + assert list(df.columns) == expected_columns def test_construct_schema_with_column_names():