Skip to content
This repository has been archived by the owner on Jul 25, 2022. It is now read-only.

Commit

Permalink
use __getitem__ for df column selection (#41)
Browse files Browse the repository at this point in the history
* use __getitem__ for df column selection

* add python test
  • Loading branch information
jimexist authored Mar 15, 2022
1 parent e24d59c commit cbf6840
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ rust-version = "1.57"
[dependencies]
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
rand = "0.7"
pyo3 = { version = "0.15", features = ["extension-module", "abi3", "abi3-py36"] }
pyo3 = { version = "~0.15", features = ["extension-module", "abi3", "abi3-py36"] }
datafusion = { version = "^7.0.0", features = ["pyarrow"] }
datafusion-expr = { version = "^7.0.0" }
datafusion-common = { version = "^7.0.0", features = ["pyarrow"] }
Expand Down
55 changes: 55 additions & 0 deletions datafusion/tests/test_indexing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pyarrow as pa
import pytest

from datafusion import ExecutionContext


@pytest.fixture
def df():
ctx = ExecutionContext()

# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array([4, 4, 6])],
names=["a", "b"],
)
return ctx.create_dataframe([[batch]])


def test_indexing(df):
assert df["a"] is not None
assert df["a", "b"] is not None
assert df[("a", "b")] is not None
assert df[["a"]] is not None


def test_err(df):
with pytest.raises(Exception) as e_info:
df["c"]

assert "No field with unqualified name" in e_info.value.args[0]

with pytest.raises(Exception) as e_info:
df[1]

assert (
"DataFrame can only be indexed by string index or indices"
in e_info.value.args[0]
)
36 changes: 29 additions & 7 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use pyo3::prelude::*;

use crate::utils::wait_for_future;
use crate::{errors::DataFusionError, expression::PyExpr};
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::pyarrow::PyArrowConvert;
use datafusion::arrow::util::pretty;
use datafusion::dataframe::DataFrame;
use datafusion::logical_plan::JoinType;

use crate::utils::wait_for_future;
use crate::{errors::DataFusionError, expression::PyExpr};
use pyo3::exceptions::PyTypeError;
use pyo3::mapping::PyMappingProtocol;
use pyo3::prelude::*;
use pyo3::types::PyTuple;
use std::sync::Arc;

/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
Expand Down Expand Up @@ -142,3 +142,25 @@ impl PyDataFrame {
Ok(pretty::print_batches(&batches)?)
}
}

#[pyproto]
impl PyMappingProtocol<'_> for PyDataFrame {
fn __getitem__(&self, key: PyObject) -> PyResult<Self> {
Python::with_gil(|py| {
if let Ok(key) = key.extract::<&str>(py) {
self.select_columns(vec![key])
} else if let Ok(tuple) = key.extract::<&PyTuple>(py) {
let keys = tuple
.iter()
.map(|item| item.extract::<&str>())
.collect::<PyResult<Vec<&str>>>()?;
self.select_columns(keys)
} else if let Ok(keys) = key.extract::<Vec<&str>>(py) {
self.select_columns(keys)
} else {
let message = "DataFrame can only be indexed by string index or indices";
Err(PyTypeError::new_err(message))
}
})
}
}

0 comments on commit cbf6840

Please sign in to comment.