diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index aab0a9b2d49..5fd68bfb26c 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -502,6 +502,7 @@ add_library(
src/reductions/product.cu
src/reductions/reductions.cpp
src/reductions/scan/rank_scan.cu
+ src/reductions/scan/ewm.cu
src/reductions/scan/scan.cpp
src/reductions/scan/scan_exclusive.cu
src/reductions/scan/scan_inclusive.cu
diff --git a/cpp/benchmarks/io/text/multibyte_split.cpp b/cpp/benchmarks/io/text/multibyte_split.cpp
index 67705863d41..4bfef9767ca 100644
--- a/cpp/benchmarks/io/text/multibyte_split.cpp
+++ b/cpp/benchmarks/io/text/multibyte_split.cpp
@@ -85,8 +85,7 @@ static cudf::string_scalar create_random_input(int32_t num_chars,
// extract the chars from the returned strings column.
auto input_column_contents = input_column->release();
- auto chars_column_contents = input_column_contents.children[1]->release();
- auto chars_buffer = chars_column_contents.data.release();
+ auto chars_buffer = input_column_contents.data.release();
// turn the chars in to a string scalar.
return cudf::string_scalar(std::move(*chars_buffer));
@@ -218,7 +217,7 @@ NVBENCH_BENCH_TYPES(bench_multibyte_split,
NVBENCH_BENCH_TYPES(bench_multibyte_split, NVBENCH_TYPE_AXES(source_type_list))
.set_name("multibyte_split_source")
.set_min_samples(4)
- .add_int64_axis("strip_delimiters", {1})
+ .add_int64_axis("strip_delimiters", {0, 1})
.add_int64_axis("delim_size", {1})
.add_int64_axis("delim_percent", {1})
.add_int64_power_of_two_axis("size_approx", {15, 30})
diff --git a/cpp/include/cudf/aggregation.hpp b/cpp/include/cudf/aggregation.hpp
index d458c831f19..3c1023017be 100644
--- a/cpp/include/cudf/aggregation.hpp
+++ b/cpp/include/cudf/aggregation.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -103,6 +103,7 @@ class aggregation {
NUNIQUE, ///< count number of unique elements
NTH_ELEMENT, ///< get the nth element
ROW_NUMBER, ///< get row-number of current index (relative to rolling window)
+ EWMA, ///< get exponential weighted moving average at current index
RANK, ///< get rank of current index
COLLECT_LIST, ///< collect values into a list
COLLECT_SET, ///< collect values into a list without duplicate entries
@@ -250,6 +251,8 @@ class segmented_reduce_aggregation : public virtual aggregation {
enum class udf_type : bool { CUDA, PTX };
/// Type of correlation method.
enum class correlation_type : int32_t { PEARSON, KENDALL, SPEARMAN };
+/// Type of treatment of EWM input values' first value
+enum class ewm_history : int32_t { INFINITE, FINITE };
/// Factory to create a SUM aggregation
/// @return A SUM aggregation object
@@ -411,6 +414,42 @@ std::unique_ptr make_nth_element_aggregation(
template
std::unique_ptr make_row_number_aggregation();
+/**
+ * @brief Factory to create an EWMA aggregation
+ *
+ * `EWMA` returns a non-nullable column with the same type as the input,
+ * whose values are the exponentially weighted moving average of the input
+ * sequence. Let these values be known as the y_i.
+ *
+ * EWMA aggregations are parameterized by a center of mass (`com`) which
+ * affects the contribution of the previous values (y_0 ... y_{i-1}) in
+ * computing the y_i.
+ *
+ * EWMA aggregations are also parameterized by a history `cudf::ewm_history`.
+ * Special considerations have to be given to the mathematical treatment of
+ * the first value of the input sequence. There are two approaches to this,
+ * one which considers the first value of the sequence to be the exponential
+ * weighted moving average of some infinite history of data, and one which
+ * takes the first value to be the only datapoint known. These assumptions
+ * lead to two different formulas for the y_i. `ewm_history` selects which.
+ *
+ * EWMA aggregations have special null handling. Nulls have two effects. The
+ * first is to propagate forward the last valid value as far as it has been
+ * computed. This could be thought of as the nulls not affecting the average
+ * in any way. The second effect changes the way the y_i are computed. Since
+ * a moving average is conceptually designed to weight contributing values by
+ * their recency, nulls ought to count as valid periods even though they do
+ * not change the average. For example, if the input sequence is {1, NULL, 3}
+ * then when computing y_2 one should weigh y_0 as if it occurs two periods
+ * before y_2 rather than just one.
+ *
+ * @param center_of_mass the center of mass.
+ * @param history which assumption to make about the first value
+ * @return A EWM aggregation object
+ */
+template
+std::unique_ptr make_ewma_aggregation(double const center_of_mass, ewm_history history);
+
/**
* @brief Factory to create a RANK aggregation
*
diff --git a/cpp/include/cudf/detail/aggregation/aggregation.hpp b/cpp/include/cudf/detail/aggregation/aggregation.hpp
index edee83783b8..843414817e3 100644
--- a/cpp/include/cudf/detail/aggregation/aggregation.hpp
+++ b/cpp/include/cudf/detail/aggregation/aggregation.hpp
@@ -76,6 +76,8 @@ class simple_aggregations_collector { // Declares the interface for the simple
class nth_element_aggregation const& agg);
virtual std::vector> visit(data_type col_type,
class row_number_aggregation const& agg);
+ virtual std::vector> visit(data_type col_type,
+ class ewma_aggregation const& agg);
virtual std::vector> visit(data_type col_type,
class rank_aggregation const& agg);
virtual std::vector> visit(
@@ -141,6 +143,7 @@ class aggregation_finalizer { // Declares the interface for the finalizer
virtual void visit(class correlation_aggregation const& agg);
virtual void visit(class tdigest_aggregation const& agg);
virtual void visit(class merge_tdigest_aggregation const& agg);
+ virtual void visit(class ewma_aggregation const& agg);
};
/**
@@ -667,6 +670,40 @@ class row_number_aggregation final : public rolling_aggregation {
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};
+/**
+ * @brief Derived class for specifying an ewma aggregation
+ */
+class ewma_aggregation final : public scan_aggregation {
+ public:
+ double const center_of_mass;
+ cudf::ewm_history history;
+
+ ewma_aggregation(double const center_of_mass, cudf::ewm_history history)
+ : aggregation{EWMA}, center_of_mass{center_of_mass}, history{history}
+ {
+ }
+
+ std::unique_ptr clone() const override
+ {
+ return std::make_unique(*this);
+ }
+
+ std::vector> get_simple_aggregations(
+ data_type col_type, simple_aggregations_collector& collector) const override
+ {
+ return collector.visit(col_type, *this);
+ }
+
+ bool is_equal(aggregation const& _other) const override
+ {
+ if (!this->aggregation::is_equal(_other)) { return false; }
+ auto const& other = dynamic_cast(_other);
+ return this->center_of_mass == other.center_of_mass and this->history == other.history;
+ }
+
+ void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
+};
+
/**
* @brief Derived class for specifying a rank aggregation
*/
@@ -1336,6 +1373,11 @@ struct target_type_impl {
using type = size_type;
};
+template
+struct target_type_impl {
+ using type = double;
+};
+
// Always use size_type accumulator for RANK
template
struct target_type_impl {
@@ -1536,6 +1578,8 @@ CUDF_HOST_DEVICE inline decltype(auto) aggregation_dispatcher(aggregation::Kind
return f.template operator()(std::forward(args)...);
case aggregation::MERGE_TDIGEST:
return f.template operator()(std::forward(args)...);
+ case aggregation::EWMA:
+ return f.template operator()(std::forward(args)...);
default: {
#ifndef __CUDA_ARCH__
CUDF_FAIL("Unsupported aggregation.");
diff --git a/cpp/src/aggregation/aggregation.cpp b/cpp/src/aggregation/aggregation.cpp
index adee9147740..5422304c5cb 100644
--- a/cpp/src/aggregation/aggregation.cpp
+++ b/cpp/src/aggregation/aggregation.cpp
@@ -154,6 +154,12 @@ std::vector> simple_aggregations_collector::visit(
return visit(col_type, static_cast(agg));
}
+std::vector> simple_aggregations_collector::visit(
+ data_type col_type, ewma_aggregation const& agg)
+{
+ return visit(col_type, static_cast(agg));
+}
+
std::vector> simple_aggregations_collector::visit(
data_type col_type, rank_aggregation const& agg)
{
@@ -333,6 +339,11 @@ void aggregation_finalizer::visit(row_number_aggregation const& agg)
visit(static_cast(agg));
}
+void aggregation_finalizer::visit(ewma_aggregation const& agg)
+{
+ visit(static_cast(agg));
+}
+
void aggregation_finalizer::visit(rank_aggregation const& agg)
{
visit(static_cast(agg));
@@ -665,6 +676,17 @@ std::unique_ptr make_row_number_aggregation()
template std::unique_ptr make_row_number_aggregation();
template std::unique_ptr make_row_number_aggregation();
+/// Factory to create an EWMA aggregation
+template
+std::unique_ptr make_ewma_aggregation(double const com, cudf::ewm_history history)
+{
+ return std::make_unique(com, history);
+}
+template std::unique_ptr make_ewma_aggregation(double const com,
+ cudf::ewm_history history);
+template std::unique_ptr make_ewma_aggregation(
+ double const com, cudf::ewm_history history);
+
/// Factory to create a RANK aggregation
template
std::unique_ptr make_rank_aggregation(rank_method method,
diff --git a/cpp/src/io/parquet/writer_impl.cu b/cpp/src/io/parquet/writer_impl.cu
index ca15b532d07..bed4dbc5a66 100644
--- a/cpp/src/io/parquet/writer_impl.cu
+++ b/cpp/src/io/parquet/writer_impl.cu
@@ -296,19 +296,6 @@ size_t column_size(column_view const& column, rmm::cuda_stream_view stream)
CUDF_FAIL("Unexpected compound type");
}
-// checks to see if the given column has a fixed size. This doesn't
-// check every row, so assumes string and list columns are not fixed, even
-// if each row is the same width.
-// TODO: update this if FIXED_LEN_BYTE_ARRAY is ever supported for writes.
-bool is_col_fixed_width(column_view const& column)
-{
- if (column.type().id() == type_id::STRUCT) {
- return std::all_of(column.child_begin(), column.child_end(), is_col_fixed_width);
- }
-
- return is_fixed_width(column.type());
-}
-
/**
* @brief Extends SchemaElement to add members required in constructing parquet_column_view
*
@@ -946,6 +933,15 @@ struct parquet_column_view {
return schema_node.converted_type.value_or(UNKNOWN);
}
+ // Checks to see if the given column has a fixed-width data type. This doesn't
+ // check every value, so it assumes string and list columns are not fixed-width, even
+ // if each value has the same size.
+ [[nodiscard]] bool is_fixed_width() const
+ {
+ // lists and strings are not fixed width
+ return max_rep_level() == 0 and physical_type() != Type::BYTE_ARRAY;
+ }
+
std::vector const& get_path_in_schema() { return path_in_schema; }
// LIST related member functions
@@ -1764,7 +1760,7 @@ auto convert_table_to_parquet_data(table_input_metadata& table_meta,
// unbalanced in final page sizes, so using 4 which seems to be a good
// compromise at smoothing things out without getting fragment sizes too small.
auto frag_size_fn = [&](auto const& col, size_t col_size) {
- int const target_frags_per_page = is_col_fixed_width(col) ? 1 : 4;
+ int const target_frags_per_page = col.is_fixed_width() ? 1 : 4;
auto const avg_len =
target_frags_per_page * util::div_rounding_up_safe(col_size, input.num_rows());
if (avg_len > 0) {
@@ -1775,8 +1771,8 @@ auto convert_table_to_parquet_data(table_input_metadata& table_meta,
}
};
- std::transform(single_streams_table.begin(),
- single_streams_table.end(),
+ std::transform(parquet_columns.begin(),
+ parquet_columns.end(),
column_sizes.begin(),
column_frag_size.begin(),
frag_size_fn);
diff --git a/cpp/src/io/text/byte_range_info.cpp b/cpp/src/io/text/byte_range_info.cpp
index 290e0451839..6a7836ed4e1 100644
--- a/cpp/src/io/text/byte_range_info.cpp
+++ b/cpp/src/io/text/byte_range_info.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ std::vector create_byte_range_infos_consecutive(int64_t total_b
auto range_size = util::div_rounding_up_safe(total_bytes, range_count);
auto ranges = std::vector();
- ranges.reserve(range_size);
+ ranges.reserve(range_count);
for (int64_t i = 0; i < range_count; i++) {
auto offset = i * range_size;
diff --git a/cpp/src/io/text/data_chunk_source_factories.cpp b/cpp/src/io/text/data_chunk_source_factories.cpp
index 596ca3458c8..58faa0ebfe4 100644
--- a/cpp/src/io/text/data_chunk_source_factories.cpp
+++ b/cpp/src/io/text/data_chunk_source_factories.cpp
@@ -120,7 +120,11 @@ class istream_data_chunk_reader : public data_chunk_reader {
{
}
- void skip_bytes(std::size_t size) override { _datastream->ignore(size); };
+ void skip_bytes(std::size_t size) override
+ {
+ // 20% faster than _datastream->ignore(size) for large files
+ _datastream->seekg(_datastream->tellg() + static_cast(size));
+ };
std::unique_ptr get_next_chunk(std::size_t read_size,
rmm::cuda_stream_view stream) override
@@ -265,7 +269,7 @@ class file_data_chunk_source : public data_chunk_source {
[[nodiscard]] std::unique_ptr create_reader() const override
{
return std::make_unique(
- std::make_unique(_filename, std::ifstream::in));
+ std::make_unique(_filename, std::ifstream::in | std::ifstream::binary));
}
private:
diff --git a/cpp/src/reductions/scan/ewm.cu b/cpp/src/reductions/scan/ewm.cu
new file mode 100644
index 00000000000..3fa2de450ad
--- /dev/null
+++ b/cpp/src/reductions/scan/ewm.cu
@@ -0,0 +1,330 @@
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "scan.cuh"
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+namespace cudf {
+namespace detail {
+
+template
+using pair_type = thrust::pair;
+
+/**
+ * @brief functor to be summed over in a prefix sum such that
+ * the recurrence in question is solved. See
+ * G. E. Blelloch. Prefix sums and their applications. Technical Report
+ * CMU-CS-90-190, Nov. 1990. S. 1.4
+ * for details
+ */
+template
+class recurrence_functor {
+ public:
+ __device__ pair_type operator()(pair_type ci, pair_type cj)
+ {
+ return {ci.first * cj.first, ci.second * cj.first + cj.second};
+ }
+};
+
+template
+struct ewma_functor_base {
+ T beta;
+ const pair_type IDENTITY{1.0, 0.0};
+};
+
+template
+struct ewma_adjust_nulls_functor : public ewma_functor_base {
+ __device__ pair_type operator()(thrust::tuple const data)
+ {
+ // Not const to allow for updating the input value
+ auto [valid, exp, input] = data;
+ if (!valid) { return this->IDENTITY; }
+ if constexpr (not is_numerator) { input = 1; }
+
+ // The value is non-null, but nulls preceding it
+ // must adjust the second element of the pair
+ T const beta = this->beta;
+ return {beta * ((exp != 0) ? pow(beta, exp) : 1), input};
+ }
+};
+
+template
+struct ewma_adjust_no_nulls_functor : public ewma_functor_base {
+ __device__ pair_type operator()(T const data)
+ {
+ T const beta = this->beta;
+ if constexpr (is_numerator) {
+ return {beta, data};
+ } else {
+ return {beta, 1.0};
+ }
+ }
+};
+
+template
+struct ewma_noadjust_nulls_functor : public ewma_functor_base {
+ /*
+ In the null case, a denominator actually has to be computed. The formula is
+ y_{i+1} = (1 - alpha)x_{i-1} + alpha x_i, but really there is a "denominator"
+ which is the sum of the weights: alpha + (1 - alpha) == 1. If a null is
+ encountered, that means that the "previous" value is downweighted by a
+ factor (for each missing value). For example with a single null:
+ data = {x_0, NULL, x_1},
+ y_2 = (1 - alpha)**2 x_0 + alpha * x_2 / (alpha + (1-alpha)**2)
+
+ As such, the pairs must be updated before summing like the adjusted case to
+ properly downweight the previous values. But now but we also need to compute
+ the normalization factors and divide the results into them at the end.
+ */
+ __device__ pair_type operator()(thrust::tuple const data)
+ {
+ T const beta = this->beta;
+ auto const [input, index, valid, nullcnt] = data;
+ if (index == 0) {
+ return {beta, input};
+ } else {
+ if (!valid) { return this->IDENTITY; }
+ // preceding value is valid, return normal pair
+ if (nullcnt == 0) { return {beta, (1.0 - beta) * input}; }
+ // one or more preceding values is null, adjust by how many
+ T const factor = (1.0 - beta) + pow(beta, nullcnt + 1);
+ return {(beta * (pow(beta, nullcnt)) / factor), ((1.0 - beta) * input) / factor};
+ }
+ }
+};
+
+template
+struct ewma_noadjust_no_nulls_functor : public ewma_functor_base {
+ __device__ pair_type operator()(thrust::tuple const data)
+ {
+ T const beta = this->beta;
+ auto const [input, index] = data;
+ if (index == 0) {
+ return {beta, input};
+ } else {
+ return {beta, (1.0 - beta) * input};
+ }
+ }
+};
+
+/**
+* @brief Return an array whose values y_i are the number of null entries
+* in between the last valid entry of the input and the current index.
+* Example: {1, NULL, 3, 4, NULL, NULL, 7}
+ -> {0, 0 1, 0, 0, 1, 2}
+*/
+rmm::device_uvector null_roll_up(column_view const& input,
+ rmm::cuda_stream_view stream)
+{
+ rmm::device_uvector output(input.size(), stream);
+
+ auto device_view = column_device_view::create(input);
+ auto invalid_it = thrust::make_transform_iterator(
+ cudf::detail::make_validity_iterator(*device_view),
+ cuda::proclaim_return_type([] __device__(int valid) -> int { return 1 - valid; }));
+
+ // valid mask {1, 0, 1, 0, 0, 1} leads to output array {0, 0, 1, 0, 1, 2}
+ thrust::inclusive_scan_by_key(rmm::exec_policy(stream),
+ invalid_it,
+ invalid_it + input.size() - 1,
+ invalid_it,
+ std::next(output.begin()));
+ return output;
+}
+
+template
+rmm::device_uvector compute_ewma_adjust(column_view const& input,
+ T const beta,
+ rmm::cuda_stream_view stream,
+ rmm::device_async_resource_ref mr)
+{
+ rmm::device_uvector output(input.size(), stream);
+ rmm::device_uvector> pairs(input.size(), stream);
+
+ if (input.has_nulls()) {
+ rmm::device_uvector nullcnt = null_roll_up(input, stream);
+ auto device_view = column_device_view::create(input);
+ auto valid_it = cudf::detail::make_validity_iterator(*device_view);
+ auto data =
+ thrust::make_zip_iterator(thrust::make_tuple(valid_it, nullcnt.begin(), input.begin()));
+
+ thrust::transform_inclusive_scan(rmm::exec_policy(stream),
+ data,
+ data + input.size(),
+ pairs.begin(),
+ ewma_adjust_nulls_functor{beta},
+ recurrence_functor{});
+ thrust::transform(rmm::exec_policy(stream),
+ pairs.begin(),
+ pairs.end(),
+ output.begin(),
+ [] __device__(pair_type pair) -> T { return pair.second; });
+
+ thrust::transform_inclusive_scan(rmm::exec_policy(stream),
+ data,
+ data + input.size(),
+ pairs.begin(),
+ ewma_adjust_nulls_functor{beta},
+ recurrence_functor{});
+
+ } else {
+ thrust::transform_inclusive_scan(rmm::exec_policy(stream),
+ input.begin(),
+ input.end(),
+ pairs.begin(),
+ ewma_adjust_no_nulls_functor{beta},
+ recurrence_functor{});
+ thrust::transform(rmm::exec_policy(stream),
+ pairs.begin(),
+ pairs.end(),
+ output.begin(),
+ [] __device__(pair_type pair) -> T { return pair.second; });
+ auto itr = thrust::make_counting_iterator(0);
+
+ thrust::transform_inclusive_scan(rmm::exec_policy(stream),
+ itr,
+ itr + input.size(),
+ pairs.begin(),
+ ewma_adjust_no_nulls_functor{beta},
+ recurrence_functor{});
+ }
+
+ thrust::transform(
+ rmm::exec_policy(stream),
+ pairs.begin(),
+ pairs.end(),
+ output.begin(),
+ output.begin(),
+ [] __device__(pair_type pair, T numerator) -> T { return numerator / pair.second; });
+
+ return output;
+}
+
+template
+rmm::device_uvector compute_ewma_noadjust(column_view const& input,
+ T const beta,
+ rmm::cuda_stream_view stream,
+ rmm::device_async_resource_ref mr)
+{
+ rmm::device_uvector output(input.size(), stream);
+ rmm::device_uvector> pairs(input.size(), stream);
+ rmm::device_uvector nullcnt =
+ [&input, stream]() -> rmm::device_uvector {
+ if (input.has_nulls()) {
+ return null_roll_up(input, stream);
+ } else {
+ return rmm::device_uvector(input.size(), stream);
+ }
+ }();
+ // denominators are all 1 and do not need to be computed
+ // pairs are all (beta, 1-beta x_i) except for the first one
+
+ if (!input.has_nulls()) {
+ auto data = thrust::make_zip_iterator(
+ thrust::make_tuple(input.begin(), thrust::make_counting_iterator(0)));
+ thrust::transform_inclusive_scan(rmm::exec_policy(stream),
+ data,
+ data + input.size(),
+ pairs.begin(),
+ ewma_noadjust_no_nulls_functor{beta},
+ recurrence_functor{});
+
+ } else {
+ auto device_view = column_device_view::create(input);
+ auto valid_it = detail::make_validity_iterator(*device_view);
+
+ auto data = thrust::make_zip_iterator(thrust::make_tuple(
+ input.begin(), thrust::make_counting_iterator(0), valid_it, nullcnt.begin()));
+
+ thrust::transform_inclusive_scan(rmm::exec_policy(stream),
+ data,
+ data + input.size(),
+ pairs.begin(),
+ ewma_noadjust_nulls_functor{beta},
+ recurrence_functor());
+ }
+
+ // copy the second elements to the output for now
+ thrust::transform(rmm::exec_policy(stream),
+ pairs.begin(),
+ pairs.end(),
+ output.begin(),
+ [] __device__(pair_type pair) -> T { return pair.second; });
+ return output;
+}
+
+struct ewma_functor {
+ template ::value)>
+ std::unique_ptr operator()(scan_aggregation const& agg,
+ column_view const& input,
+ rmm::cuda_stream_view stream,
+ rmm::device_async_resource_ref mr)
+ {
+ CUDF_FAIL("Unsupported type for EWMA.");
+ }
+
+ template ::value)>
+ std::unique_ptr operator()(scan_aggregation const& agg,
+ column_view const& input,
+ rmm::cuda_stream_view stream,
+ rmm::device_async_resource_ref mr)
+ {
+ auto const ewma_agg = dynamic_cast(&agg);
+ auto const history = ewma_agg->history;
+ auto const center_of_mass = ewma_agg->center_of_mass;
+
+ // center of mass is easier for the user, but the recurrences are
+ // better expressed in terms of the derived parameter `beta`
+ T const beta = center_of_mass / (center_of_mass + 1.0);
+
+ auto result = [&]() {
+ if (history == cudf::ewm_history::INFINITE) {
+ return compute_ewma_adjust(input, beta, stream, mr);
+ } else {
+ return compute_ewma_noadjust(input, beta, stream, mr);
+ }
+ }();
+ return std::make_unique(cudf::data_type(cudf::type_to_id()),
+ input.size(),
+ result.release(),
+ rmm::device_buffer{},
+ 0);
+ }
+};
+
+std::unique_ptr exponentially_weighted_moving_average(column_view const& input,
+ scan_aggregation const& agg,
+ rmm::cuda_stream_view stream,
+ rmm::device_async_resource_ref mr)
+{
+ return type_dispatcher(input.type(), ewma_functor{}, agg, input, stream, mr);
+}
+
+} // namespace detail
+} // namespace cudf
diff --git a/cpp/src/reductions/scan/scan.cuh b/cpp/src/reductions/scan/scan.cuh
index aeb9e516cd4..6c237741ac3 100644
--- a/cpp/src/reductions/scan/scan.cuh
+++ b/cpp/src/reductions/scan/scan.cuh
@@ -36,6 +36,12 @@ std::pair mask_scan(column_view const& input_view
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr);
+// exponentially weighted moving average of the input
+std::unique_ptr exponentially_weighted_moving_average(column_view const& input,
+ scan_aggregation const& agg,
+ rmm::cuda_stream_view stream,
+ rmm::device_async_resource_ref mr);
+
template typename DispatchFn>
std::unique_ptr scan_agg_dispatch(column_view const& input,
scan_aggregation const& agg,
@@ -59,6 +65,7 @@ std::unique_ptr scan_agg_dispatch(column_view const& input,
if (is_fixed_point(input.type())) CUDF_FAIL("decimal32/64/128 cannot support product scan");
return type_dispatcher(
input.type(), DispatchFn(), input, output_mask, stream, mr);
+ case aggregation::EWMA: return exponentially_weighted_moving_average(input, agg, stream, mr);
default: CUDF_FAIL("Unsupported aggregation operator for scan");
}
}
diff --git a/cpp/src/reductions/scan/scan_inclusive.cu b/cpp/src/reductions/scan/scan_inclusive.cu
index ad2eaa6a471..7c02a8d1b99 100644
--- a/cpp/src/reductions/scan/scan_inclusive.cu
+++ b/cpp/src/reductions/scan/scan_inclusive.cu
@@ -182,7 +182,8 @@ std::unique_ptr scan_inclusive(column_view const& input,
auto output = scan_agg_dispatch(
input, agg, static_cast(mask.data()), stream, mr);
- output->set_null_mask(std::move(mask), null_count);
+ // Use the null mask produced by the op for EWM
+ if (agg.kind != aggregation::EWMA) { output->set_null_mask(std::move(mask), null_count); }
// If the input is a structs column, we also need to push down nulls from the parent output column
// into the children columns.
diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt
index eda470d2309..9f14455f42d 100644
--- a/cpp/tests/CMakeLists.txt
+++ b/cpp/tests/CMakeLists.txt
@@ -205,6 +205,7 @@ ConfigureTest(
ConfigureTest(
REDUCTIONS_TEST
reductions/collect_ops_tests.cpp
+ reductions/ewm_tests.cpp
reductions/rank_tests.cpp
reductions/reduction_tests.cpp
reductions/scan_tests.cpp
diff --git a/cpp/tests/ast/transform_tests.cpp b/cpp/tests/ast/transform_tests.cpp
index ef1d09e5652..6b350c137d0 100644
--- a/cpp/tests/ast/transform_tests.cpp
+++ b/cpp/tests/ast/transform_tests.cpp
@@ -65,6 +65,22 @@ TEST_F(TransformTest, ColumnReference)
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view(), verbosity);
}
+TEST_F(TransformTest, BasicAdditionDoubleCast)
+{
+ auto c_0 = column_wrapper{3, 20, 1, 50};
+ std::vector<__int128_t> data1{10, 7, 20, 0};
+ auto c_1 = cudf::test::fixed_point_column_wrapper<__int128_t>(
+ data1.begin(), data1.end(), numeric::scale_type{0});
+ auto table = cudf::table_view{{c_0, c_1}};
+ auto col_ref_0 = cudf::ast::column_reference(0);
+ auto col_ref_1 = cudf::ast::column_reference(1);
+ auto cast = cudf::ast::operation(cudf::ast::ast_operator::CAST_TO_FLOAT64, col_ref_1);
+ auto expression = cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_0, cast);
+ auto expected = column_wrapper{13, 27, 21, 50};
+ auto result = cudf::compute_column(table, expression);
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view(), verbosity);
+}
+
TEST_F(TransformTest, Literal)
{
auto c_0 = column_wrapper{3, 20, 1, 50};
diff --git a/cpp/tests/reductions/ewm_tests.cpp b/cpp/tests/reductions/ewm_tests.cpp
new file mode 100644
index 00000000000..09cec688509
--- /dev/null
+++ b/cpp/tests/reductions/ewm_tests.cpp
@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) 2021-2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "scan_tests.hpp"
+
+#include
+#include
+#include
+
+#include
+#include
+
+template
+struct TypedEwmScanTest : BaseScanTest {
+ inline void test_ungrouped_ewma_scan(cudf::column_view const& input,
+ cudf::column_view const& expect_vals,
+ cudf::scan_aggregation const& agg,
+ cudf::null_policy null_handling)
+ {
+ auto col_out = cudf::scan(input, agg, cudf::scan_type::INCLUSIVE, null_handling);
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expect_vals, col_out->view());
+ }
+};
+
+TYPED_TEST_SUITE(TypedEwmScanTest, cudf::test::FloatingPointTypes);
+
+TYPED_TEST(TypedEwmScanTest, Ewm)
+{
+ auto const v = make_vector({1.0, 2.0, 3.0, 4.0, 5.0});
+ auto col = this->make_column(v);
+
+ auto const expected_ewma_vals_adjust = cudf::test::fixed_width_column_wrapper{
+ {1.0, 1.75, 2.61538461538461497469, 3.54999999999999982236, 4.52066115702479365268}};
+
+ auto const expected_ewma_vals_noadjust =
+ cudf::test::fixed_width_column_wrapper{{1.0,
+ 1.66666666666666651864,
+ 2.55555555555555535818,
+ 3.51851851851851815667,
+ 4.50617283950617242283}};
+
+ this->test_ungrouped_ewma_scan(
+ *col,
+ expected_ewma_vals_adjust,
+ *cudf::make_ewma_aggregation(0.5, cudf::ewm_history::INFINITE),
+ cudf::null_policy::INCLUDE);
+ this->test_ungrouped_ewma_scan(
+ *col,
+ expected_ewma_vals_noadjust,
+ *cudf::make_ewma_aggregation(0.5, cudf::ewm_history::FINITE),
+ cudf::null_policy::INCLUDE);
+}
+
+TYPED_TEST(TypedEwmScanTest, EwmWithNulls)
+{
+ auto const v = make_vector({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0});
+ auto const b = thrust::host_vector(std::vector{1, 0, 1, 0, 0, 1, 1});
+ auto col = this->make_column(v, b);
+
+ auto const expected_ewma_vals_adjust =
+ cudf::test::fixed_width_column_wrapper{{1.0,
+ 1.0,
+ 2.79999999999999982236,
+ 2.79999999999999982236,
+ 2.79999999999999982236,
+ 5.87351778656126466416,
+ 6.70977596741344139986}};
+
+ auto const expected_ewma_vals_noadjust =
+ cudf::test::fixed_width_column_wrapper{{1.0,
+ 1.0,
+ 2.71428571428571441260,
+ 2.71428571428571441260,
+ 2.71428571428571441260,
+ 5.82706766917293172980,
+ 6.60902255639097724327}};
+
+ this->test_ungrouped_ewma_scan(
+ *col,
+ expected_ewma_vals_adjust,
+ *cudf::make_ewma_aggregation(0.5, cudf::ewm_history::INFINITE),
+ cudf::null_policy::INCLUDE);
+ this->test_ungrouped_ewma_scan(
+ *col,
+ expected_ewma_vals_noadjust,
+ *cudf::make_ewma_aggregation(0.5, cudf::ewm_history::FINITE),
+ cudf::null_policy::INCLUDE);
+}
diff --git a/docs/cudf/source/user_guide/api_docs/dataframe.rst b/docs/cudf/source/user_guide/api_docs/dataframe.rst
index 70e4bd060ca..02fd9f7b396 100644
--- a/docs/cudf/source/user_guide/api_docs/dataframe.rst
+++ b/docs/cudf/source/user_guide/api_docs/dataframe.rst
@@ -137,6 +137,7 @@ Computations / descriptive stats
DataFrame.describe
DataFrame.diff
DataFrame.eval
+ DataFrame.ewm
DataFrame.kurt
DataFrame.kurtosis
DataFrame.max
diff --git a/docs/cudf/source/user_guide/api_docs/general_utilities.rst b/docs/cudf/source/user_guide/api_docs/general_utilities.rst
index d9c53c3fbbd..8d0edc0b100 100644
--- a/docs/cudf/source/user_guide/api_docs/general_utilities.rst
+++ b/docs/cudf/source/user_guide/api_docs/general_utilities.rst
@@ -8,6 +8,8 @@ Testing functions
:toctree: api/
cudf.testing.testing.assert_column_equal
+ cudf.testing.testing.assert_eq
cudf.testing.testing.assert_frame_equal
cudf.testing.testing.assert_index_equal
+ cudf.testing.testing.assert_neq
cudf.testing.testing.assert_series_equal
diff --git a/docs/cudf/source/user_guide/api_docs/series.rst b/docs/cudf/source/user_guide/api_docs/series.rst
index 5dc87a97337..48a7dc8ff87 100644
--- a/docs/cudf/source/user_guide/api_docs/series.rst
+++ b/docs/cudf/source/user_guide/api_docs/series.rst
@@ -138,6 +138,7 @@ Computations / descriptive stats
Series.describe
Series.diff
Series.digitize
+ Series.ewm
Series.factorize
Series.kurt
Series.max
diff --git a/python/cudf/cudf/_fuzz_testing/tests/fuzz_test_csv.py b/python/cudf/cudf/_fuzz_testing/tests/fuzz_test_csv.py
index f8f674fecec..d90f3ea1aca 100644
--- a/python/cudf/cudf/_fuzz_testing/tests/fuzz_test_csv.py
+++ b/python/cudf/cudf/_fuzz_testing/tests/fuzz_test_csv.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
import sys
from io import StringIO
@@ -13,7 +13,7 @@
compare_content,
run_test,
)
-from cudf.testing._utils import assert_eq
+from cudf.testing import assert_eq
@pythonfuzz(data_handle=CSVReader)
diff --git a/python/cudf/cudf/_fuzz_testing/tests/fuzz_test_json.py b/python/cudf/cudf/_fuzz_testing/tests/fuzz_test_json.py
index 2f5e6204f7c..69e9437be93 100644
--- a/python/cudf/cudf/_fuzz_testing/tests/fuzz_test_json.py
+++ b/python/cudf/cudf/_fuzz_testing/tests/fuzz_test_json.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
import io
import sys
@@ -9,7 +9,7 @@
from cudf._fuzz_testing.json import JSONReader, JSONWriter
from cudf._fuzz_testing.main import pythonfuzz
from cudf._fuzz_testing.utils import ALL_POSSIBLE_VALUES, run_test
-from cudf.testing._utils import assert_eq
+from cudf.testing import assert_eq
@pythonfuzz(data_handle=JSONReader)
diff --git a/python/cudf/cudf/_fuzz_testing/utils.py b/python/cudf/cudf/_fuzz_testing/utils.py
index d685174f3c2..e6dfe2eae62 100644
--- a/python/cudf/cudf/_fuzz_testing/utils.py
+++ b/python/cudf/cudf/_fuzz_testing/utils.py
@@ -8,7 +8,7 @@
import pyarrow as pa
import cudf
-from cudf.testing._utils import assert_eq
+from cudf.testing import assert_eq
from cudf.utils.dtypes import (
pandas_dtypes_to_np_dtypes,
pyarrow_dtypes_to_pandas_dtypes,
diff --git a/python/cudf/cudf/_lib/aggregation.pyx b/python/cudf/cudf/_lib/aggregation.pyx
index 11f801ba772..1616c24eec2 100644
--- a/python/cudf/cudf/_lib/aggregation.pyx
+++ b/python/cudf/cudf/_lib/aggregation.pyx
@@ -58,6 +58,14 @@ class Aggregation:
if dropna else pylibcudf.types.NullPolicy.INCLUDE
))
+ @classmethod
+ def ewma(cls, com=1.0, adjust=True):
+ return cls(pylibcudf.aggregation.ewma(
+ com,
+ pylibcudf.aggregation.EWMHistory.INFINITE
+ if adjust else pylibcudf.aggregation.EWMHistory.FINITE
+ ))
+
@classmethod
def size(cls):
return cls(pylibcudf.aggregation.count(pylibcudf.types.NullPolicy.INCLUDE))
diff --git a/python/cudf/cudf/_lib/csv.pyx b/python/cudf/cudf/_lib/csv.pyx
index 0b0bbdb2589..c706351a683 100644
--- a/python/cudf/cudf/_lib/csv.pyx
+++ b/python/cudf/cudf/_lib/csv.pyx
@@ -8,7 +8,7 @@ from libcpp.utility cimport move
from libcpp.vector cimport vector
cimport cudf._lib.pylibcudf.libcudf.types as libcudf_types
-from cudf._lib.io.datasource cimport Datasource, NativeFileDatasource
+from cudf._lib.pylibcudf.io.datasource cimport Datasource, NativeFileDatasource
from cudf._lib.pylibcudf.libcudf.types cimport data_type
from cudf._lib.types cimport dtype_to_data_type
diff --git a/python/cudf/cudf/_lib/io/CMakeLists.txt b/python/cudf/cudf/_lib/io/CMakeLists.txt
index 2408fa1c12f..620229a1275 100644
--- a/python/cudf/cudf/_lib/io/CMakeLists.txt
+++ b/python/cudf/cudf/_lib/io/CMakeLists.txt
@@ -1,5 +1,5 @@
# =============================================================================
-# Copyright (c) 2022-2023, NVIDIA CORPORATION.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
@@ -12,7 +12,7 @@
# the License.
# =============================================================================
-set(cython_sources datasource.pyx utils.pyx)
+set(cython_sources utils.pyx)
set(linked_libraries cudf::cudf)
rapids_cython_create_modules(
CXX
diff --git a/python/cudf/cudf/_lib/io/utils.pyx b/python/cudf/cudf/_lib/io/utils.pyx
index 3c14ec46122..1d7c56888d9 100644
--- a/python/cudf/cudf/_lib/io/utils.pyx
+++ b/python/cudf/cudf/_lib/io/utils.pyx
@@ -8,7 +8,7 @@ from libcpp.utility cimport move
from libcpp.vector cimport vector
from cudf._lib.column cimport Column
-from cudf._lib.io.datasource cimport Datasource
+from cudf._lib.pylibcudf.io.datasource cimport Datasource
from cudf._lib.pylibcudf.libcudf.io.data_sink cimport data_sink
from cudf._lib.pylibcudf.libcudf.io.datasource cimport datasource
from cudf._lib.pylibcudf.libcudf.io.types cimport (
diff --git a/python/cudf/cudf/_lib/orc.pyx b/python/cudf/cudf/_lib/orc.pyx
index d3e6053ef4b..9609e3131b4 100644
--- a/python/cudf/cudf/_lib/orc.pyx
+++ b/python/cudf/cudf/_lib/orc.pyx
@@ -23,12 +23,12 @@ except ImportError:
cimport cudf._lib.pylibcudf.libcudf.io.types as cudf_io_types
from cudf._lib.column cimport Column
-from cudf._lib.io.datasource cimport NativeFileDatasource
from cudf._lib.io.utils cimport (
make_sink_info,
make_source_info,
update_column_struct_field_names,
)
+from cudf._lib.pylibcudf.io.datasource cimport NativeFileDatasource
from cudf._lib.pylibcudf.libcudf.io.data_sink cimport data_sink
from cudf._lib.pylibcudf.libcudf.io.orc cimport (
chunked_orc_writer_options,
diff --git a/python/cudf/cudf/_lib/parquet.pyx b/python/cudf/cudf/_lib/parquet.pyx
index f6f9cfa9a7c..7914ed7e9d9 100644
--- a/python/cudf/cudf/_lib/parquet.pyx
+++ b/python/cudf/cudf/_lib/parquet.pyx
@@ -37,12 +37,12 @@ cimport cudf._lib.pylibcudf.libcudf.io.types as cudf_io_types
cimport cudf._lib.pylibcudf.libcudf.types as cudf_types
from cudf._lib.column cimport Column
from cudf._lib.expressions cimport Expression
-from cudf._lib.io.datasource cimport NativeFileDatasource
from cudf._lib.io.utils cimport (
make_sinks_info,
make_source_info,
update_struct_field_names,
)
+from cudf._lib.pylibcudf.io.datasource cimport NativeFileDatasource
from cudf._lib.pylibcudf.libcudf.expressions cimport expression
from cudf._lib.pylibcudf.libcudf.io.parquet cimport (
chunked_parquet_reader as cpp_chunked_parquet_reader,
diff --git a/python/cudf/cudf/_lib/pylibcudf/aggregation.pxd b/python/cudf/cudf/_lib/pylibcudf/aggregation.pxd
index 8526728656b..0981d0e855a 100644
--- a/python/cudf/cudf/_lib/pylibcudf/aggregation.pxd
+++ b/python/cudf/cudf/_lib/pylibcudf/aggregation.pxd
@@ -6,6 +6,7 @@ from cudf._lib.pylibcudf.libcudf.aggregation cimport (
Kind as kind_t,
aggregation,
correlation_type,
+ ewm_history,
groupby_aggregation,
groupby_scan_aggregation,
rank_method,
@@ -80,6 +81,8 @@ cpdef Aggregation argmax()
cpdef Aggregation argmin()
+cpdef Aggregation ewma(float center_of_mass, ewm_history history)
+
cpdef Aggregation nunique(null_policy null_handling = *)
cpdef Aggregation nth_element(size_type n, null_policy null_handling = *)
diff --git a/python/cudf/cudf/_lib/pylibcudf/aggregation.pyx b/python/cudf/cudf/_lib/pylibcudf/aggregation.pyx
index 7bb64e32a1b..eed2f6de585 100644
--- a/python/cudf/cudf/_lib/pylibcudf/aggregation.pyx
+++ b/python/cudf/cudf/_lib/pylibcudf/aggregation.pyx
@@ -8,6 +8,7 @@ from libcpp.utility cimport move
from cudf._lib.pylibcudf.libcudf.aggregation cimport (
aggregation,
correlation_type,
+ ewm_history,
groupby_aggregation,
groupby_scan_aggregation,
make_all_aggregation,
@@ -19,6 +20,7 @@ from cudf._lib.pylibcudf.libcudf.aggregation cimport (
make_correlation_aggregation,
make_count_aggregation,
make_covariance_aggregation,
+ make_ewma_aggregation,
make_max_aggregation,
make_mean_aggregation,
make_median_aggregation,
@@ -52,6 +54,8 @@ from cudf._lib.pylibcudf.libcudf.types cimport (
from cudf._lib.pylibcudf.libcudf.aggregation import Kind # no-cython-lint
from cudf._lib.pylibcudf.libcudf.aggregation import \
correlation_type as CorrelationType # no-cython-lint
+from cudf._lib.pylibcudf.libcudf.aggregation import \
+ ewm_history as EWMHistory # no-cython-lint
from cudf._lib.pylibcudf.libcudf.aggregation import \
rank_method as RankMethod # no-cython-lint
from cudf._lib.pylibcudf.libcudf.aggregation import \
@@ -202,6 +206,28 @@ cpdef Aggregation max():
return Aggregation.from_libcudf(move(make_max_aggregation[aggregation]()))
+cpdef Aggregation ewma(float center_of_mass, ewm_history history):
+ """Create a EWMA aggregation.
+
+ For details, see :cpp:func:`make_ewma_aggregation`.
+
+ Parameters
+ ----------
+ center_of_mass : float
+ The decay in terms of the center of mass
+ history : ewm_history
+ Whether or not to treat the history as infinite.
+
+ Returns
+ -------
+ Aggregation
+ The EWMA aggregation.
+ """
+ return Aggregation.from_libcudf(
+ move(make_ewma_aggregation[aggregation](center_of_mass, history))
+ )
+
+
cpdef Aggregation count(null_policy null_handling = null_policy.EXCLUDE):
"""Create a count aggregation.
diff --git a/python/cudf/cudf/_lib/pylibcudf/io/CMakeLists.txt b/python/cudf/cudf/_lib/pylibcudf/io/CMakeLists.txt
index 2cfec101bab..32f0f5543e4 100644
--- a/python/cudf/cudf/_lib/pylibcudf/io/CMakeLists.txt
+++ b/python/cudf/cudf/_lib/pylibcudf/io/CMakeLists.txt
@@ -12,7 +12,7 @@
# the License.
# =============================================================================
-set(cython_sources avro.pyx types.pyx)
+set(cython_sources avro.pyx datasource.pyx types.pyx)
set(linked_libraries cudf::cudf)
rapids_cython_create_modules(
@@ -21,5 +21,5 @@ rapids_cython_create_modules(
LINKED_LIBRARIES "${linked_libraries}" MODULE_PREFIX pylibcudf_io_ ASSOCIATED_TARGETS cudf
)
-set(targets_using_arrow_headers pylibcudf_io_avro pylibcudf_io_types)
+set(targets_using_arrow_headers pylibcudf_io_avro pylibcudf_io_datasource pylibcudf_io_types)
link_to_pyarrow_headers("${targets_using_arrow_headers}")
diff --git a/python/cudf/cudf/_lib/pylibcudf/io/__init__.pxd b/python/cudf/cudf/_lib/pylibcudf/io/__init__.pxd
index 250292746c1..cfd6d2cd281 100644
--- a/python/cudf/cudf/_lib/pylibcudf/io/__init__.pxd
+++ b/python/cudf/cudf/_lib/pylibcudf/io/__init__.pxd
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
-from . cimport avro, types
+from . cimport avro, datasource, types
from .types cimport SourceInfo, TableWithMetadata
diff --git a/python/cudf/cudf/_lib/pylibcudf/io/__init__.py b/python/cudf/cudf/_lib/pylibcudf/io/__init__.py
index 5242c741911..a54ba1834dc 100644
--- a/python/cudf/cudf/_lib/pylibcudf/io/__init__.py
+++ b/python/cudf/cudf/_lib/pylibcudf/io/__init__.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
-from . import avro, types
+from . import avro, datasource, types
from .types import SourceInfo, TableWithMetadata
diff --git a/python/cudf/cudf/_lib/io/datasource.pxd b/python/cudf/cudf/_lib/pylibcudf/io/datasource.pxd
similarity index 100%
rename from python/cudf/cudf/_lib/io/datasource.pxd
rename to python/cudf/cudf/_lib/pylibcudf/io/datasource.pxd
diff --git a/python/cudf/cudf/_lib/io/datasource.pyx b/python/cudf/cudf/_lib/pylibcudf/io/datasource.pyx
similarity index 100%
rename from python/cudf/cudf/_lib/io/datasource.pyx
rename to python/cudf/cudf/_lib/pylibcudf/io/datasource.pyx
diff --git a/python/cudf/cudf/_lib/pylibcudf/io/types.pyx b/python/cudf/cudf/_lib/pylibcudf/io/types.pyx
index cd777232b33..ab3375da662 100644
--- a/python/cudf/cudf/_lib/pylibcudf/io/types.pyx
+++ b/python/cudf/cudf/_lib/pylibcudf/io/types.pyx
@@ -4,6 +4,8 @@ from libcpp.string cimport string
from libcpp.utility cimport move
from libcpp.vector cimport vector
+from cudf._lib.pylibcudf.io.datasource cimport Datasource
+from cudf._lib.pylibcudf.libcudf.io.datasource cimport datasource
from cudf._lib.pylibcudf.libcudf.io.types cimport (
host_buffer,
source_info,
@@ -56,9 +58,8 @@ cdef class SourceInfo:
Parameters
----------
- sources : List[Union[str, os.PathLike, bytes, io.BytesIO]]
- A homogeneous list of sources (this can be a string filename,
- an os.PathLike, bytes, or an io.BytesIO) to read from.
+ sources : List[Union[str, os.PathLike, bytes, io.BytesIO, DataSource]]
+ A homogeneous list of sources to read from.
Mixing different types of sources will raise a `ValueError`.
"""
@@ -68,6 +69,7 @@ cdef class SourceInfo:
raise ValueError("Need to pass at least one source")
cdef vector[string] c_files
+ cdef vector[datasource*] c_datasources
if isinstance(sources[0], (os.PathLike, str)):
c_files.reserve(len(sources))
@@ -84,6 +86,13 @@ cdef class SourceInfo:
self.c_obj = move(source_info(c_files))
return
+ elif isinstance(sources[0], Datasource):
+ for csrc in sources:
+ if not isinstance(csrc, Datasource):
+ raise ValueError("All sources must be of the same type!")
+ c_datasources.push_back((csrc).get_datasource())
+ self.c_obj = move(source_info(c_datasources))
+ return
# TODO: host_buffer is deprecated API, use host_span instead
cdef vector[host_buffer] c_host_buffers
@@ -106,5 +115,11 @@ cdef class SourceInfo:
c_buffer = bio.getbuffer() # check if empty?
c_host_buffers.push_back(host_buffer(&c_buffer[0],
c_buffer.shape[0]))
+ else:
+ raise ValueError("Sources must be a list of str/paths, "
+ "bytes, io.BytesIO, or a Datasource")
+
+ if empty_buffer is True:
+ c_host_buffers.push_back(host_buffer(NULL, 0))
- self.c_obj = source_info(c_host_buffers)
+ self.c_obj = move(source_info(c_host_buffers))
diff --git a/python/cudf/cudf/_lib/pylibcudf/libcudf/aggregation.pxd b/python/cudf/cudf/_lib/pylibcudf/libcudf/aggregation.pxd
index 8c14bc45723..fe04db52094 100644
--- a/python/cudf/cudf/_lib/pylibcudf/libcudf/aggregation.pxd
+++ b/python/cudf/cudf/_lib/pylibcudf/libcudf/aggregation.pxd
@@ -79,6 +79,10 @@ cdef extern from "cudf/aggregation.hpp" namespace "cudf" nogil:
KENDALL
SPEARMAN
+ cpdef enum class ewm_history(int32_t):
+ INFINITE
+ FINITE
+
cpdef enum class rank_method(int32_t):
FIRST
AVERAGE
@@ -143,6 +147,10 @@ cdef extern from "cudf/aggregation.hpp" namespace "cudf" nogil:
string user_defined_aggregator,
data_type output_type) except +
+ cdef unique_ptr[T] make_ewma_aggregation[T](
+ double com, ewm_history adjust
+ ) except +
+
cdef unique_ptr[T] make_correlation_aggregation[T](
correlation_type type, size_type min_periods) except +
diff --git a/python/cudf/cudf/core/_internals/timezones.py b/python/cudf/cudf/core/_internals/timezones.py
index 269fcf3e37f..29cb9d7bd12 100644
--- a/python/cudf/cudf/core/_internals/timezones.py
+++ b/python/cudf/cudf/core/_internals/timezones.py
@@ -1,21 +1,50 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
from __future__ import annotations
+import datetime
import os
import zoneinfo
from functools import lru_cache
from typing import TYPE_CHECKING, Literal
import numpy as np
+import pandas as pd
+import cudf
from cudf._lib.timezone import make_timezone_transition_table
-from cudf.core.column.column import as_column
if TYPE_CHECKING:
from cudf.core.column.datetime import DatetimeColumn
from cudf.core.column.timedelta import TimeDeltaColumn
+def get_compatible_timezone(dtype: pd.DatetimeTZDtype) -> pd.DatetimeTZDtype:
+ """Convert dtype.tz object to zoneinfo object if possible."""
+ tz = dtype.tz
+ if isinstance(tz, zoneinfo.ZoneInfo):
+ return dtype
+ if cudf.get_option("mode.pandas_compatible"):
+ raise NotImplementedError(
+ f"{tz} must be a zoneinfo.ZoneInfo object in pandas_compatible mode."
+ )
+ elif (tzname := getattr(tz, "zone", None)) is not None:
+ # pytz-like
+ key = tzname
+ elif (tz_file := getattr(tz, "_filename", None)) is not None:
+ # dateutil-like
+ key = tz_file.split("zoneinfo/")[-1]
+ elif isinstance(tz, datetime.tzinfo):
+ # Try to get UTC-like tzinfos
+ reference = datetime.datetime.now()
+ key = tz.tzname(reference)
+ if not (isinstance(key, str) and key.lower() == "utc"):
+ raise NotImplementedError(f"cudf does not support {tz}")
+ else:
+ raise NotImplementedError(f"cudf does not support {tz}")
+ new_tz = zoneinfo.ZoneInfo(key)
+ return pd.DatetimeTZDtype(dtype.unit, new_tz)
+
+
@lru_cache(maxsize=20)
def get_tz_data(zone_name: str) -> tuple[DatetimeColumn, TimeDeltaColumn]:
"""
@@ -87,6 +116,8 @@ def _read_tzfile_as_columns(
)
if not transition_times_and_offsets:
+ from cudf.core.column.column import as_column
+
# this happens for UTC-like zones
min_date = np.int64(np.iinfo("int64").min + 1).astype("M8[s]")
return (as_column([min_date]), as_column([np.timedelta64(0, "s")]))
diff --git a/python/cudf/cudf/core/column/categorical.py b/python/cudf/cudf/core/column/categorical.py
index f538180805b..231af30c06d 100644
--- a/python/cudf/cudf/core/column/categorical.py
+++ b/python/cudf/cudf/core/column/categorical.py
@@ -1068,51 +1068,34 @@ def notnull(self) -> ColumnBase:
return result
- def fillna(
- self,
- fill_value: Any = None,
- method: str | None = None,
- ) -> Self:
- """
- Fill null values with *fill_value*
- """
- if fill_value is not None:
- fill_is_scalar = np.isscalar(fill_value)
-
- if fill_is_scalar:
- if fill_value == _DEFAULT_CATEGORICAL_VALUE:
- fill_value = self.codes.dtype.type(fill_value)
- else:
- try:
- fill_value = self._encode(fill_value)
- fill_value = self.codes.dtype.type(fill_value)
- except ValueError as err:
- err_msg = "fill value must be in categories"
- raise ValueError(err_msg) from err
+ def _validate_fillna_value(
+ self, fill_value: ScalarLike | ColumnLike
+ ) -> cudf.Scalar | ColumnBase:
+ """Align fill_value for .fillna based on column type."""
+ if cudf.api.types.is_scalar(fill_value):
+ if fill_value != _DEFAULT_CATEGORICAL_VALUE:
+ try:
+ fill_value = self._encode(fill_value)
+ except ValueError as err:
+ raise ValueError(
+ f"{fill_value=} must be in categories"
+ ) from err
+ return cudf.Scalar(fill_value, dtype=self.codes.dtype)
+ else:
+ fill_value = column.as_column(fill_value, nan_as_null=False)
+ if isinstance(fill_value.dtype, CategoricalDtype):
+ if self.dtype != fill_value.dtype:
+ raise TypeError(
+ "Cannot set a categorical with another without identical categories"
+ )
else:
- fill_value = column.as_column(fill_value, nan_as_null=False)
- if isinstance(fill_value, CategoricalColumn):
- if self.dtype != fill_value.dtype:
- raise TypeError(
- "Cannot set a Categorical with another, "
- "without identical categories"
- )
- # TODO: only required if fill_value has a subset of the
- # categories:
- fill_value = fill_value._set_categories(
- self.categories,
- is_unique=True,
- )
- fill_value = column.as_column(fill_value.codes).astype(
- self.codes.dtype
+ raise TypeError(
+ "Cannot set a categorical with non-categorical data"
)
-
- # Validation of `fill_value` will have to be performed
- # before returning self.
- if not self.nullable:
- return self
-
- return super().fillna(fill_value, method=method)
+ fill_value = fill_value._set_categories(
+ self.categories,
+ )
+ return fill_value.codes.astype(self.codes.dtype)
def indices_of(
self, value: ScalarLike
@@ -1372,11 +1355,13 @@ def _set_categories(
if not (is_unique or new_cats.is_unique):
new_cats = cudf.Series(new_cats)._column.unique()
+ if cur_cats.equals(new_cats, check_dtypes=True):
+ # TODO: Internal usages don't always need a copy; add a copy keyword
+ # as_ordered shallow copies
+ return self.copy().as_ordered(ordered=ordered)
+
cur_codes = self.codes
- max_cat_size = (
- len(cur_cats) if len(cur_cats) > len(new_cats) else len(new_cats)
- )
- out_code_dtype = min_unsigned_type(max_cat_size)
+ out_code_dtype = min_unsigned_type(max(len(cur_cats), len(new_cats)))
cur_order = column.as_column(range(len(cur_codes)))
old_codes = column.as_column(
diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py
index c4e715aeb45..dfcdfbb9d91 100644
--- a/python/cudf/cudf/core/column/column.py
+++ b/python/cudf/cudf/core/column/column.py
@@ -47,6 +47,7 @@
is_string_dtype,
)
from cudf.core._compat import PANDAS_GE_210
+from cudf.core._internals.timezones import get_compatible_timezone
from cudf.core.abc import Serializable
from cudf.core.buffer import (
Buffer,
@@ -665,15 +666,32 @@ def _check_scatter_key_length(
f"{num_keys}"
)
+ def _validate_fillna_value(
+ self, fill_value: ScalarLike | ColumnLike
+ ) -> cudf.Scalar | ColumnBase:
+ """Align fill_value for .fillna based on column type."""
+ if is_scalar(fill_value):
+ return cudf.Scalar(fill_value, dtype=self.dtype)
+ return as_column(fill_value)
+
def fillna(
self,
- fill_value: Any = None,
- method: str | None = None,
+ fill_value: ScalarLike | ColumnLike,
+ method: Literal["ffill", "bfill", None] = None,
) -> Self:
"""Fill null values with ``value``.
Returns a copy with null filled.
"""
+ if not self.has_nulls(include_nan=True):
+ return self.copy()
+ elif method is None:
+ if is_scalar(fill_value) and libcudf.scalar._is_null_host_scalar(
+ fill_value
+ ):
+ return self.copy()
+ else:
+ fill_value = self._validate_fillna_value(fill_value)
return libcudf.replace.replace_nulls(
input_col=self.nans_to_nulls(),
replacement=fill_value,
@@ -1854,6 +1872,21 @@ def as_column(
arbitrary.dtype,
(pd.CategoricalDtype, pd.IntervalDtype, pd.DatetimeTZDtype),
):
+ if isinstance(arbitrary.dtype, pd.DatetimeTZDtype):
+ new_tz = get_compatible_timezone(arbitrary.dtype)
+ arbitrary = arbitrary.astype(new_tz)
+ if isinstance(arbitrary.dtype, pd.CategoricalDtype) and isinstance(
+ arbitrary.dtype.categories.dtype, pd.DatetimeTZDtype
+ ):
+ new_tz = get_compatible_timezone(
+ arbitrary.dtype.categories.dtype
+ )
+ new_cats = arbitrary.dtype.categories.astype(new_tz)
+ new_dtype = pd.CategoricalDtype(
+ categories=new_cats, ordered=arbitrary.dtype.ordered
+ )
+ arbitrary = arbitrary.astype(new_dtype)
+
return as_column(
pa.array(arbitrary, from_pandas=True),
nan_as_null=nan_as_null,
diff --git a/python/cudf/cudf/core/column/datetime.py b/python/cudf/cudf/core/column/datetime.py
index 9ac761b6be1..121076b69ce 100644
--- a/python/cudf/cudf/core/column/datetime.py
+++ b/python/cudf/cudf/core/column/datetime.py
@@ -8,19 +8,23 @@
import locale
import re
from locale import nl_langinfo
-from typing import TYPE_CHECKING, Any, Literal, Sequence, cast
+from typing import TYPE_CHECKING, Literal, Sequence, cast
import numpy as np
import pandas as pd
import pyarrow as pa
-from typing_extensions import Self
import cudf
from cudf import _lib as libcudf
from cudf._lib.labeling import label_bins
from cudf._lib.search import search_sorted
-from cudf.api.types import is_datetime64_dtype, is_scalar, is_timedelta64_dtype
+from cudf.api.types import is_datetime64_dtype, is_timedelta64_dtype
from cudf.core._compat import PANDAS_GE_220
+from cudf.core._internals.timezones import (
+ check_ambiguous_and_nonexistent,
+ get_compatible_timezone,
+ get_tz_data,
+)
from cudf.core.column import ColumnBase, as_column, column, string
from cudf.core.column.timedelta import _unit_to_nanoseconds_conversion
from cudf.utils.dtypes import _get_base_dtype
@@ -282,8 +286,6 @@ def __contains__(self, item: ScalarLike) -> bool:
@functools.cached_property
def time_unit(self) -> str:
- if isinstance(self.dtype, pd.DatetimeTZDtype):
- return self.dtype.unit
return np.datetime_data(self.dtype)[0]
@property
@@ -638,22 +640,6 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
else:
return result_col
- def fillna(
- self,
- fill_value: Any = None,
- method: str | None = None,
- ) -> Self:
- if fill_value is not None:
- if cudf.utils.utils._isnat(fill_value):
- return self.copy(deep=True)
- if is_scalar(fill_value):
- if not isinstance(fill_value, cudf.Scalar):
- fill_value = cudf.Scalar(fill_value, dtype=self.dtype)
- else:
- fill_value = column.as_column(fill_value, nan_as_null=False)
-
- return super().fillna(fill_value, method)
-
def indices_of(
self, value: ScalarLike
) -> cudf.core.column.NumericalColumn:
@@ -725,8 +711,6 @@ def _find_ambiguous_and_nonexistent(
transitions occur in the time zone database for the given timezone.
If no transitions occur, the tuple `(False, False)` is returned.
"""
- from cudf.core._internals.timezones import get_tz_data
-
transition_times, offsets = get_tz_data(zone_name)
offsets = offsets.astype(f"timedelta64[{self.time_unit}]") # type: ignore[assignment]
@@ -785,26 +769,22 @@ def tz_localize(
ambiguous: Literal["NaT"] = "NaT",
nonexistent: Literal["NaT"] = "NaT",
):
- from cudf.core._internals.timezones import (
- check_ambiguous_and_nonexistent,
- get_tz_data,
- )
-
if tz is None:
return self.copy()
ambiguous, nonexistent = check_ambiguous_and_nonexistent(
ambiguous, nonexistent
)
- dtype = pd.DatetimeTZDtype(self.time_unit, tz)
+ dtype = get_compatible_timezone(pd.DatetimeTZDtype(self.time_unit, tz))
+ tzname = dtype.tz.key
ambiguous_col, nonexistent_col = self._find_ambiguous_and_nonexistent(
- tz
+ tzname
)
localized = self._scatter_by_column(
self.isnull() | (ambiguous_col | nonexistent_col),
cudf.Scalar(cudf.NaT, dtype=self.dtype),
)
- transition_times, offsets = get_tz_data(tz)
+ transition_times, offsets = get_tz_data(tzname)
transition_times_local = (transition_times + offsets).astype(
localized.dtype
)
@@ -845,7 +825,7 @@ def __init__(
offset=offset,
null_count=null_count,
)
- self._dtype = dtype
+ self._dtype = get_compatible_timezone(dtype)
def to_pandas(
self,
@@ -865,6 +845,10 @@ def to_arrow(self):
self._local_time.to_arrow(), str(self.dtype.tz)
)
+ @functools.cached_property
+ def time_unit(self) -> str:
+ return self.dtype.unit
+
@property
def _utc_time(self):
"""Return UTC time as naive timestamps."""
@@ -880,8 +864,6 @@ def _utc_time(self):
@property
def _local_time(self):
"""Return the local time as naive timestamps."""
- from cudf.core._internals.timezones import get_tz_data
-
transition_times, offsets = get_tz_data(str(self.dtype.tz))
transition_times = transition_times.astype(_get_base_dtype(self.dtype))
indices = search_sorted([transition_times], [self], "right") - 1
@@ -911,10 +893,6 @@ def __repr__(self):
)
def tz_localize(self, tz: str | None, ambiguous="NaT", nonexistent="NaT"):
- from cudf.core._internals.timezones import (
- check_ambiguous_and_nonexistent,
- )
-
if tz is None:
return self._local_time
ambiguous, nonexistent = check_ambiguous_and_nonexistent(
diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py
index e9d9b4933e5..d66908b5f94 100644
--- a/python/cudf/cudf/core/column/decimal.py
+++ b/python/cudf/cudf/core/column/decimal.py
@@ -4,12 +4,11 @@
import warnings
from decimal import Decimal
-from typing import TYPE_CHECKING, Any, Sequence, cast
+from typing import TYPE_CHECKING, Sequence, cast
import cupy as cp
import numpy as np
import pyarrow as pa
-from typing_extensions import Self
import cudf
from cudf import _lib as libcudf
@@ -31,7 +30,7 @@
from .numerical_base import NumericalBaseColumn
if TYPE_CHECKING:
- from cudf._typing import ColumnBinaryOperand, Dtype
+ from cudf._typing import ColumnBinaryOperand, ColumnLike, Dtype, ScalarLike
class DecimalBaseColumn(NumericalBaseColumn):
@@ -135,30 +134,20 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str):
return result
- def fillna(
- self,
- fill_value: Any = None,
- method: str | None = None,
- ) -> Self:
- """Fill null values with ``value``.
-
- Returns a copy with null filled.
- """
+ def _validate_fillna_value(
+ self, fill_value: ScalarLike | ColumnLike
+ ) -> cudf.Scalar | ColumnBase:
+ """Align fill_value for .fillna based on column type."""
if isinstance(fill_value, (int, Decimal)):
- fill_value = cudf.Scalar(fill_value, dtype=self.dtype)
- elif (
- isinstance(fill_value, DecimalBaseColumn)
- or isinstance(fill_value, cudf.core.column.NumericalColumn)
- and is_integer_dtype(fill_value.dtype)
+ return cudf.Scalar(fill_value, dtype=self.dtype)
+ elif isinstance(fill_value, ColumnBase) and (
+ isinstance(self.dtype, DecimalDtype) or self.dtype.kind in "iu"
):
- fill_value = fill_value.astype(self.dtype)
- else:
- raise TypeError(
- "Decimal columns only support using fillna with decimal and "
- "integer values"
- )
-
- return super().fillna(fill_value, method=method)
+ return fill_value.astype(self.dtype)
+ raise TypeError(
+ "Decimal columns only support using fillna with decimal and "
+ "integer values"
+ )
def normalize_binop_value(self, other):
if isinstance(other, ColumnBase):
diff --git a/python/cudf/cudf/core/column/numerical.py b/python/cudf/cudf/core/column/numerical.py
index 098cf43421b..76c64e1aea0 100644
--- a/python/cudf/cudf/core/column/numerical.py
+++ b/python/cudf/cudf/core/column/numerical.py
@@ -532,57 +532,26 @@ def find_and_replace(
replaced, df._data["old"], df._data["new"]
)
- def fillna(
- self,
- fill_value: Any = None,
- method: str | None = None,
- ) -> Self:
- """
- Fill null values with *fill_value*
- """
- col = self.nans_to_nulls()
-
- if col.null_count == 0:
- return col
-
- if method is not None:
- return super().fillna(fill_value, method)
-
- if fill_value is None:
- raise ValueError("Must specify either 'fill_value' or 'method'")
-
- if (
- isinstance(fill_value, cudf.Scalar)
- and fill_value.dtype == col.dtype
- ):
- return super().fillna(fill_value, method)
-
- if np.isscalar(fill_value):
- # cast safely to the same dtype as self
- fill_value_casted = col.dtype.type(fill_value)
- if not np.isnan(fill_value) and (fill_value_casted != fill_value):
+ def _validate_fillna_value(
+ self, fill_value: ScalarLike | ColumnLike
+ ) -> cudf.Scalar | ColumnBase:
+ """Align fill_value for .fillna based on column type."""
+ if is_scalar(fill_value):
+ cudf_obj = cudf.Scalar(fill_value)
+ if not as_column(cudf_obj).can_cast_safely(self.dtype):
raise TypeError(
f"Cannot safely cast non-equivalent "
- f"{type(fill_value).__name__} to {col.dtype.name}"
+ f"{type(fill_value).__name__} to {self.dtype.name}"
)
- fill_value = cudf.Scalar(fill_value_casted)
else:
- fill_value = column.as_column(fill_value, nan_as_null=False)
- if is_integer_dtype(col.dtype):
- # cast safely to the same dtype as self
- if fill_value.dtype != col.dtype:
- new_fill_value = fill_value.astype(col.dtype)
- if not (new_fill_value == fill_value).all():
- raise TypeError(
- f"Cannot safely cast non-equivalent "
- f"{fill_value.dtype.type.__name__} to "
- f"{col.dtype.type.__name__}"
- )
- fill_value = new_fill_value
- else:
- fill_value = fill_value.astype(col.dtype)
-
- return super().fillna(fill_value, method)
+ cudf_obj = as_column(fill_value, nan_as_null=False)
+ if not cudf_obj.can_cast_safely(self.dtype): # type: ignore[attr-defined]
+ raise TypeError(
+ f"Cannot safely cast non-equivalent "
+ f"{cudf_obj.dtype.type.__name__} to "
+ f"{self.dtype.type.__name__}"
+ )
+ return cudf_obj.astype(self.dtype)
def can_cast_safely(self, to_dtype: DtypeObj) -> bool:
"""
diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py
index 2451a9cc0af..936cd1eccb0 100644
--- a/python/cudf/cudf/core/column/string.py
+++ b/python/cudf/cudf/core/column/string.py
@@ -5,12 +5,11 @@
import re
import warnings
from functools import cached_property
-from typing import TYPE_CHECKING, Any, Sequence, cast, overload
+from typing import TYPE_CHECKING, Sequence, cast, overload
import numpy as np
import pandas as pd
import pyarrow as pa
-from typing_extensions import Self
import cudf
import cudf.api.types
@@ -5838,21 +5837,6 @@ def find_and_replace(
res = self
return libcudf.replace.replace(res, df._data["old"], df._data["new"])
- def fillna(
- self,
- fill_value: Any = None,
- method: str | None = None,
- ) -> Self:
- if fill_value is not None:
- if not is_scalar(fill_value):
- fill_value = column.as_column(fill_value, dtype=self.dtype)
- elif cudf._lib.scalar._is_null_host_scalar(fill_value):
- # Trying to fill with value? Return copy.
- return self.copy(deep=True)
- else:
- fill_value = cudf.Scalar(fill_value, dtype=self.dtype)
- return super().fillna(fill_value, method=method)
-
def normalize_binop_value(self, other) -> column.ColumnBase | cudf.Scalar:
if (
isinstance(other, (column.ColumnBase, cudf.Scalar))
diff --git a/python/cudf/cudf/core/column/timedelta.py b/python/cudf/cudf/core/column/timedelta.py
index 26b449f1863..8f41bcb6422 100644
--- a/python/cudf/cudf/core/column/timedelta.py
+++ b/python/cudf/cudf/core/column/timedelta.py
@@ -4,12 +4,11 @@
import datetime
import functools
-from typing import TYPE_CHECKING, Any, Sequence, cast
+from typing import TYPE_CHECKING, Sequence, cast
import numpy as np
import pandas as pd
import pyarrow as pa
-from typing_extensions import Self
import cudf
from cudf import _lib as libcudf
@@ -252,22 +251,6 @@ def normalize_binop_value(self, other) -> ColumnBinaryOperand:
def time_unit(self) -> str:
return np.datetime_data(self.dtype)[0]
- def fillna(
- self,
- fill_value: Any = None,
- method: str | None = None,
- ) -> Self:
- if fill_value is not None:
- if cudf.utils.utils._isnat(fill_value):
- return self.copy(deep=True)
- if is_scalar(fill_value):
- fill_value = cudf.Scalar(fill_value)
- dtype = self.dtype
- fill_value = fill_value.astype(dtype)
- else:
- fill_value = column.as_column(fill_value, nan_as_null=False)
- return super().fillna(fill_value, method)
-
def as_numerical_column(
self, dtype: Dtype
) -> "cudf.core.column.NumericalColumn":
diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py
index 76bb9d2a8ed..f0d8157011d 100644
--- a/python/cudf/cudf/core/dataframe.py
+++ b/python/cudf/cudf/core/dataframe.py
@@ -2980,6 +2980,32 @@ def set_index(
df.index = idx
return df if not inplace else None
+ @_cudf_nvtx_annotate
+ def fillna(
+ self, value=None, method=None, axis=None, inplace=False, limit=None
+ ): # noqa: D102
+ if isinstance(value, (pd.Series, pd.DataFrame)):
+ value = cudf.from_pandas(value)
+ if isinstance(value, cudf.Series):
+ # Align value.index to self.columns
+ value = value.reindex(self._column_names)
+ elif isinstance(value, cudf.DataFrame):
+ if not self.index.equals(value.index):
+ # Align value.index to self.index
+ value = value.reindex(self.index)
+ value = dict(value.items())
+ elif isinstance(value, abc.Mapping):
+ # Align value.indexes to self.index
+ value = {
+ key: value.reindex(self.index)
+ if isinstance(value, cudf.Series)
+ else value
+ for key, value in value.items()
+ }
+ return super().fillna(
+ value=value, method=method, axis=axis, inplace=inplace, limit=limit
+ )
+
@_cudf_nvtx_annotate
def where(self, cond, other=None, inplace=False):
from cudf.core._internals.where import (
diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py
index 38bff3946d6..8ca71180c00 100644
--- a/python/cudf/cudf/core/frame.py
+++ b/python/cudf/cudf/core/frame.py
@@ -2,7 +2,6 @@
from __future__ import annotations
-import copy
import operator
import pickle
import warnings
@@ -20,6 +19,7 @@
import cudf
from cudf import _lib as libcudf
from cudf.api.types import is_dtype_equal, is_scalar
+from cudf.core._compat import PANDAS_LT_300
from cudf.core.buffer import acquire_spill_lock
from cudf.core.column import (
ColumnBase,
@@ -38,7 +38,7 @@
if TYPE_CHECKING:
from types import ModuleType
- from cudf._typing import Dtype
+ from cudf._typing import Dtype, ScalarLike
# TODO: It looks like Frame is missing a declaration of `copy`, need to add
@@ -613,8 +613,8 @@ def where(self, cond, other=None, inplace: bool = False) -> Self | None:
@_cudf_nvtx_annotate
def fillna(
self,
- value=None,
- method: Literal["ffill", "bfill", "pad", "backfill"] | None = None,
+ value: None | ScalarLike | cudf.Series = None,
+ method: Literal["ffill", "bfill", "pad", "backfill", None] = None,
axis=None,
inplace: bool = False,
limit=None,
@@ -725,6 +725,16 @@ def fillna(
raise ValueError("Cannot specify both 'value' and 'method'.")
if method:
+ # Do not remove until pandas 3.0 support is added.
+ assert (
+ PANDAS_LT_300
+ ), "Need to drop after pandas-3.0 support is added."
+ warnings.warn(
+ f"{type(self).__name__}.fillna with 'method' is "
+ "deprecated and will raise in a future version. "
+ "Use obj.ffill() or obj.bfill() instead.",
+ FutureWarning,
+ )
if method not in {"ffill", "bfill", "pad", "backfill"}:
raise NotImplementedError(
f"Fill method {method} is not supported"
@@ -734,57 +744,24 @@ def fillna(
elif method == "backfill":
method = "bfill"
- # TODO: This logic should be handled in different subclasses since
- # different Frames support different types of values.
- if isinstance(value, cudf.Series):
- value = value.reindex(self._data.names)
- elif isinstance(value, cudf.DataFrame):
- if not self.index.equals(value.index): # type: ignore[attr-defined]
- value = value.reindex(self.index) # type: ignore[attr-defined]
- else:
- value = value
- elif not isinstance(value, abc.Mapping):
- value = {name: copy.deepcopy(value) for name in self._data.names}
- else:
- value = {
- key: value.reindex(self.index) # type: ignore[attr-defined]
- if isinstance(value, cudf.Series)
- else value
- for key, value in value.items()
- }
-
- filled_data = {}
- for col_name, col in self._data.items():
- if col_name in value and method is None:
- replace_val = value[col_name]
- else:
- replace_val = None
- should_fill = (
- (
- col_name in value
- and col.has_nulls(include_nan=True)
- and not libcudf.scalar._is_null_host_scalar(replace_val)
- )
- or method is not None
- or (
- isinstance(col, cudf.core.column.CategoricalColumn)
- and not libcudf.scalar._is_null_host_scalar(replace_val)
- )
+ if is_scalar(value):
+ value = {name: value for name in self._column_names}
+ elif not isinstance(value, (abc.Mapping, cudf.Series)):
+ raise TypeError(
+ f'"value" parameter must be a scalar, dict '
+ f"or Series, but you passed a "
+ f'"{type(value).__name__}"'
)
- if should_fill:
- filled_data[col_name] = col.fillna(replace_val, method)
- else:
- filled_data[col_name] = col.copy(deep=True)
+
+ filled_columns = [
+ col.fillna(value[name], method) if name in value else col.copy()
+ for name, col in self._data.items()
+ ]
return self._mimic_inplace(
- self._from_data(
- data=ColumnAccessor(
- data=filled_data,
- multiindex=self._data.multiindex,
- level_names=self._data.level_names,
- rangeindex=self._data.rangeindex,
- label_dtype=self._data.label_dtype,
- verify=False,
+ self._from_data_like_self(
+ self._data._from_columns_like_self(
+ filled_columns, verify=False
)
),
inplace=inplace,
diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py
index f1b74adefed..280a6e92eab 100644
--- a/python/cudf/cudf/core/indexed_frame.py
+++ b/python/cudf/cudf/core/indexed_frame.py
@@ -52,7 +52,7 @@
_post_process_output_col,
_return_arr_from_dtype,
)
-from cudf.core.window import Rolling
+from cudf.core.window import ExponentialMovingWindow, Rolling
from cudf.utils import docutils, ioutils
from cudf.utils._numba import _CUDFNumbaConfig
from cudf.utils.docutils import copy_docstring
@@ -1853,6 +1853,32 @@ def rolling(
win_type=win_type,
)
+ @copy_docstring(ExponentialMovingWindow)
+ def ewm(
+ self,
+ com: float | None = None,
+ span: float | None = None,
+ halflife: float | None = None,
+ alpha: float | None = None,
+ min_periods: int | None = 0,
+ adjust: bool = True,
+ ignore_na: bool = False,
+ axis: int = 0,
+ times: str | np.ndarray | None = None,
+ ):
+ return ExponentialMovingWindow(
+ self,
+ com=com,
+ span=span,
+ halflife=halflife,
+ alpha=alpha,
+ min_periods=min_periods,
+ adjust=adjust,
+ ignore_na=ignore_na,
+ axis=axis,
+ times=times,
+ )
+
@_cudf_nvtx_annotate
def nans_to_nulls(self):
"""
@@ -2701,11 +2727,24 @@ def sort_index(
if ignore_index:
out = out.reset_index(drop=True)
else:
- labels = sorted(self._data.names, reverse=not ascending)
- out = self[labels]
+ labels = sorted(self._column_names, reverse=not ascending)
+ result_columns = (self._data[label] for label in labels)
if ignore_index:
- out._data.rangeindex = True
- out._data.names = list(range(self._num_columns))
+ ca = ColumnAccessor(
+ dict(enumerate(result_columns)),
+ rangeindex=True,
+ verify=False,
+ )
+ else:
+ ca = ColumnAccessor(
+ dict(zip(labels, result_columns)),
+ rangeindex=self._data.rangeindex,
+ multiindex=self._data.multiindex,
+ level_names=self._data.level_names,
+ label_dtype=self._data.label_dtype,
+ verify=False,
+ )
+ out = self._from_data_like_self(ca)
return self._mimic_inplace(out, inplace=inplace)
@@ -3178,29 +3217,6 @@ def _split(self, splits, keep_index=True):
for i in range(len(splits) + 1)
]
- @_cudf_nvtx_annotate
- def fillna(
- self, value=None, method=None, axis=None, inplace=False, limit=None
- ): # noqa: D102
- if method is not None:
- # Do not remove until pandas 3.0 support is added.
- assert (
- PANDAS_LT_300
- ), "Need to drop after pandas-3.0 support is added."
- warnings.warn(
- f"{type(self).__name__}.fillna with 'method' is "
- "deprecated and will raise in a future version. "
- "Use obj.ffill() or obj.bfill() instead.",
- FutureWarning,
- )
- old_index = self.index
- ret = super().fillna(value, method, axis, inplace, limit)
- if inplace:
- self.index = old_index
- else:
- ret.index = old_index
- return ret
-
@_cudf_nvtx_annotate
def bfill(self, value=None, axis=None, inplace=None, limit=None):
"""
diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py
index c0716d7709a..15ad0813601 100644
--- a/python/cudf/cudf/core/series.py
+++ b/python/cudf/cudf/core/series.py
@@ -1797,20 +1797,12 @@ def fillna(
):
if isinstance(value, pd.Series):
value = Series.from_pandas(value)
-
- if not (is_scalar(value) or isinstance(value, (abc.Mapping, Series))):
- raise TypeError(
- f'"value" parameter must be a scalar, dict '
- f"or Series, but you passed a "
- f'"{type(value).__name__}"'
- )
-
- if isinstance(value, (abc.Mapping, Series)):
+ elif isinstance(value, abc.Mapping):
value = Series(value)
+ if isinstance(value, cudf.Series):
if not self.index.equals(value.index):
value = value.reindex(self.index)
- value = value._column
-
+ value = {self.name: value._column}
return super().fillna(
value=value, method=method, axis=axis, inplace=inplace, limit=limit
)
diff --git a/python/cudf/cudf/core/window/__init__.py b/python/cudf/cudf/core/window/__init__.py
index 8ea3eb0179b..23522588d33 100644
--- a/python/cudf/cudf/core/window/__init__.py
+++ b/python/cudf/cudf/core/window/__init__.py
@@ -1,3 +1,3 @@
-# Copyright (c) 2019-2022, NVIDIA CORPORATION
-
+# Copyright (c) 2019-2024, NVIDIA CORPORATION
+from cudf.core.window.ewm import ExponentialMovingWindow
from cudf.core.window.rolling import Rolling
diff --git a/python/cudf/cudf/core/window/ewm.py b/python/cudf/cudf/core/window/ewm.py
new file mode 100644
index 00000000000..21693e106bd
--- /dev/null
+++ b/python/cudf/cudf/core/window/ewm.py
@@ -0,0 +1,200 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION.
+
+from __future__ import annotations
+
+import numpy as np
+
+from cudf._lib.reduce import scan
+from cudf.api.types import is_numeric_dtype
+from cudf.core.window.rolling import _RollingBase
+
+
+class ExponentialMovingWindow(_RollingBase):
+ r"""
+ Provide exponential weighted (EW) functions.
+ Available EW functions: ``mean()``
+ Exactly one parameter: ``com``, ``span``, ``halflife``, or ``alpha``
+ must be provided.
+
+ Parameters
+ ----------
+ com : float, optional
+ Specify decay in terms of center of mass,
+ :math:`\alpha = 1 / (1 + com)`, for :math:`com \geq 0`.
+ span : float, optional
+ Specify decay in terms of span,
+ :math:`\alpha = 2 / (span + 1)`, for :math:`span \geq 1`.
+ halflife : float, str, timedelta, optional
+ Specify decay in terms of half-life,
+ :math:`\alpha = 1 - \exp\left(-\ln(2) / halflife\right)`, for
+ :math:`halflife > 0`.
+ alpha : float, optional
+ Specify smoothing factor :math:`\alpha` directly,
+ :math:`0 < \alpha \leq 1`.
+ min_periods : int, default 0
+ Not Supported
+ adjust : bool, default True
+ Controls assumptions about the first value in the sequence.
+ https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.ewm.html
+ for details.
+ ignore_na : bool, default False
+ Not Supported
+ axis : {0, 1}, default 0
+ Not Supported
+ times : str, np.ndarray, Series, default None
+ Not Supported
+
+ Returns
+ -------
+ ``ExponentialMovingWindow`` object
+
+ Notes
+ -----
+ cuDF input data may contain both nulls and nan values. For the purposes
+ of this method, they are taken to have the same meaning, meaning nulls
+ in cuDF will affect the result the same way that nan values would using
+ the equivalent pandas method.
+
+ .. pandas-compat::
+ **cudf.core.window.ExponentialMovingWindow**
+
+ The parameters ``min_periods``, ``ignore_na``, ``axis``, and ``times``
+ are not yet supported. Behavior is defined only for data that begins
+ with a valid (non-null) element.
+
+ Currently, only ``mean`` is a supported method.
+
+ Examples
+ --------
+ >>> df = cudf.DataFrame({'B': [0, 1, 2, cudf.NA, 4]})
+ >>> df
+ B
+ 0 0
+ 1 1
+ 2 2
+ 3
+ 4 4
+ >>> df.ewm(com=0.5).mean()
+ B
+ 0 0.000000
+ 1 0.750000
+ 2 1.615385
+ 3 1.615385
+ 4 3.670213
+
+ >>> df.ewm(com=0.5, adjust=False).mean()
+ B
+ 0 0.000000
+ 1 0.666667
+ 2 1.555556
+ 3 1.555556
+ 4 3.650794
+ """
+
+ def __init__(
+ self,
+ obj,
+ com: float | None = None,
+ span: float | None = None,
+ halflife: float | None = None,
+ alpha: float | None = None,
+ min_periods: int | None = 0,
+ adjust: bool = True,
+ ignore_na: bool = False,
+ axis: int = 0,
+ times: str | np.ndarray | None = None,
+ ):
+ if (min_periods, ignore_na, axis, times) != (0, False, 0, None):
+ raise NotImplementedError(
+ "The parameters `min_periods`, `ignore_na`, "
+ "`axis`, and `times` are not yet supported."
+ )
+
+ self.obj = obj
+ self.adjust = adjust
+ self.com = get_center_of_mass(com, span, halflife, alpha)
+
+ def mean(self):
+ """
+ Calculate the ewm (exponential weighted moment) mean.
+ """
+ return self._apply_agg("ewma")
+
+ def var(self, bias):
+ raise NotImplementedError("ewmvar not yet supported.")
+
+ def std(self, bias):
+ raise NotImplementedError("ewmstd not yet supported.")
+
+ def corr(self, other):
+ raise NotImplementedError("ewmcorr not yet supported.")
+
+ def cov(self, other):
+ raise NotImplementedError("ewmcov not yet supported.")
+
+ def _apply_agg_series(self, sr, agg_name):
+ if not is_numeric_dtype(sr.dtype):
+ raise TypeError("No numeric types to aggregate")
+
+ # libcudf ewm has special casing for nulls only
+ # and come what may with nans. It treats those nulls like
+ # pandas does nans in the same positions mathematically.
+ # as such we need to convert the nans to nulls before
+ # passing them in.
+ to_libcudf_column = sr._column.astype("float64").nans_to_nulls()
+
+ return self.obj._from_data_like_self(
+ self.obj._data._from_columns_like_self(
+ [
+ scan(
+ agg_name,
+ to_libcudf_column,
+ True,
+ com=self.com,
+ adjust=self.adjust,
+ )
+ ]
+ )
+ )
+
+
+def get_center_of_mass(
+ comass: float | None,
+ span: float | None,
+ halflife: float | None,
+ alpha: float | None,
+) -> float:
+ valid_count = count_not_none(comass, span, halflife, alpha)
+ if valid_count > 1:
+ raise ValueError(
+ "comass, span, halflife, and alpha are mutually exclusive"
+ )
+
+ # Convert to center of mass; domain checks ensure 0 < alpha <= 1
+ if comass is not None:
+ if comass < 0:
+ raise ValueError("comass must satisfy: comass >= 0")
+ elif span is not None:
+ if span < 1:
+ raise ValueError("span must satisfy: span >= 1")
+ comass = (span - 1) / 2
+ elif halflife is not None:
+ if halflife <= 0:
+ raise ValueError("halflife must satisfy: halflife > 0")
+ decay = 1 - np.exp(np.log(0.5) / halflife)
+ comass = 1 / decay - 1
+ elif alpha is not None:
+ if alpha <= 0 or alpha > 1:
+ raise ValueError("alpha must satisfy: 0 < alpha <= 1")
+ comass = (1 - alpha) / alpha
+ else:
+ raise ValueError("Must pass one of comass, span, halflife, or alpha")
+
+ return float(comass)
+
+
+def count_not_none(*args) -> int:
+ """
+ Returns the count of arguments that are not None.
+ """
+ return sum(x is not None for x in args)
diff --git a/python/cudf/cudf/core/window/rolling.py b/python/cudf/cudf/core/window/rolling.py
index 7d140a1ffa5..29391c68471 100644
--- a/python/cudf/cudf/core/window/rolling.py
+++ b/python/cudf/cudf/core/window/rolling.py
@@ -14,7 +14,27 @@
from cudf.utils.utils import GetAttrGetItemMixin
-class Rolling(GetAttrGetItemMixin, Reducible):
+class _RollingBase:
+ """
+ Contains methods common to all kinds of rolling
+ """
+
+ def _apply_agg_dataframe(self, df, agg_name):
+ result_df = cudf.DataFrame({})
+ for i, col_name in enumerate(df.columns):
+ result_col = self._apply_agg_series(df[col_name], agg_name)
+ result_df.insert(i, col_name, result_col)
+ result_df.index = df.index
+ return result_df
+
+ def _apply_agg(self, agg_name):
+ if isinstance(self.obj, cudf.Series):
+ return self._apply_agg_series(self.obj, agg_name)
+ else:
+ return self._apply_agg_dataframe(self.obj, agg_name)
+
+
+class Rolling(GetAttrGetItemMixin, _RollingBase, Reducible):
"""
Rolling window calculations.
diff --git a/python/cudf/cudf/pandas/_wrappers/pandas.py b/python/cudf/cudf/pandas/_wrappers/pandas.py
index 698dd946022..0ba432d6d0e 100644
--- a/python/cudf/cudf/pandas/_wrappers/pandas.py
+++ b/python/cudf/cudf/pandas/_wrappers/pandas.py
@@ -789,7 +789,7 @@ def Index__new__(cls, *args, **kwargs):
ExponentialMovingWindow = make_intermediate_proxy_type(
"ExponentialMovingWindow",
- _Unusable,
+ cudf.core.window.ewm.ExponentialMovingWindow,
pd.core.window.ewm.ExponentialMovingWindow,
)
diff --git a/python/cudf/cudf/pandas/fast_slow_proxy.py b/python/cudf/cudf/pandas/fast_slow_proxy.py
index 1540c6850e7..dfb729cae6b 100644
--- a/python/cudf/cudf/pandas/fast_slow_proxy.py
+++ b/python/cudf/cudf/pandas/fast_slow_proxy.py
@@ -17,7 +17,7 @@
import numpy as np
from ..options import _env_get_bool
-from ..testing._utils import assert_eq
+from ..testing import assert_eq
from .annotation import nvtx
diff --git a/python/cudf/cudf/pylibcudf_tests/test_source_info.py b/python/cudf/cudf/pylibcudf_tests/test_source_info.py
index 71a3ecbcc30..019321b7259 100644
--- a/python/cudf/cudf/pylibcudf_tests/test_source_info.py
+++ b/python/cudf/cudf/pylibcudf_tests/test_source_info.py
@@ -2,13 +2,21 @@
import io
+import pyarrow as pa
import pytest
import cudf._lib.pylibcudf as plc
+from cudf._lib.pylibcudf.io.datasource import NativeFileDatasource
@pytest.mark.parametrize(
- "source", ["a.txt", b"hello world", io.BytesIO(b"hello world")]
+ "source",
+ [
+ "a.txt",
+ b"hello world",
+ io.BytesIO(b"hello world"),
+ NativeFileDatasource(pa.PythonFile(io.BytesIO(), mode="r")),
+ ],
)
def test_source_info_ctor(source, tmp_path):
if isinstance(source, str):
@@ -28,6 +36,10 @@ def test_source_info_ctor(source, tmp_path):
["a.txt", "a.txt"],
[b"hello world", b"hello there"],
[io.BytesIO(b"hello world"), io.BytesIO(b"hello there")],
+ [
+ NativeFileDatasource(pa.PythonFile(io.BytesIO(), mode="r")),
+ NativeFileDatasource(pa.PythonFile(io.BytesIO(), mode="r")),
+ ],
],
)
def test_source_info_ctor_multiple(sources, tmp_path):
@@ -54,6 +66,11 @@ def test_source_info_ctor_multiple(sources, tmp_path):
io.BytesIO(b"hello there"),
b"hello world",
],
+ [
+ NativeFileDatasource(pa.PythonFile(io.BytesIO(), mode="r")),
+ "awef.txt",
+ b"hello world",
+ ],
],
)
def test_source_info_ctor_mixing_invalid(sources, tmp_path):
@@ -67,3 +84,8 @@ def test_source_info_ctor_mixing_invalid(sources, tmp_path):
sources[i] = str(file)
with pytest.raises(ValueError):
plc.io.SourceInfo(sources)
+
+
+def test_source_info_invalid():
+ with pytest.raises(ValueError):
+ plc.io.SourceInfo([123])
diff --git a/python/cudf/cudf/testing/__init__.py b/python/cudf/cudf/testing/__init__.py
index 1843344bc81..4e92b43b9f9 100644
--- a/python/cudf/cudf/testing/__init__.py
+++ b/python/cudf/cudf/testing/__init__.py
@@ -1,7 +1,9 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
from cudf.testing.testing import (
+ assert_eq,
assert_frame_equal,
assert_index_equal,
+ assert_neq,
assert_series_equal,
)
diff --git a/python/cudf/cudf/testing/_utils.py b/python/cudf/cudf/testing/_utils.py
index e067d15af4c..a6a2d4eea00 100644
--- a/python/cudf/cudf/testing/_utils.py
+++ b/python/cudf/cudf/testing/_utils.py
@@ -2,12 +2,10 @@
import itertools
import string
-import warnings
from collections import abc
from contextlib import contextmanager
from decimal import Decimal
-import cupy
import numpy as np
import pandas as pd
import pytest
@@ -15,7 +13,6 @@
from numba.core.typing.templates import AbstractTemplate
from numba.cuda.cudadecl import registry as cuda_decl_registry
from numba.cuda.cudaimpl import lower as cuda_lower
-from pandas import testing as tm
import cudf
from cudf._lib.null_mask import bitmask_allocation_size_bytes
@@ -113,81 +110,6 @@ def count_zero(arr):
return np.count_nonzero(arr == 0)
-def assert_eq(left, right, **kwargs):
- """Assert that two cudf-like things are equivalent
-
- This equality test works for pandas/cudf dataframes/series/indexes/scalars
- in the same way, and so makes it easier to perform parametrized testing
- without switching between assert_frame_equal/assert_series_equal/...
- functions.
- """
- # dtypes that we support but Pandas doesn't will convert to
- # `object`. Check equality before that happens:
- if kwargs.get("check_dtype", True):
- if hasattr(left, "dtype") and hasattr(right, "dtype"):
- if isinstance(
- left.dtype, cudf.core.dtypes._BaseDtype
- ) and not isinstance(
- left.dtype, cudf.CategoricalDtype
- ): # leave categorical comparison to Pandas
- assert_eq(left.dtype, right.dtype)
-
- if hasattr(left, "to_pandas"):
- left = left.to_pandas()
- if hasattr(right, "to_pandas"):
- right = right.to_pandas()
- if isinstance(left, cupy.ndarray):
- left = cupy.asnumpy(left)
- if isinstance(right, cupy.ndarray):
- right = cupy.asnumpy(right)
-
- if isinstance(left, (pd.DataFrame, pd.Series, pd.Index)):
- # TODO: A warning is emitted from the function
- # pandas.testing.assert_[series, frame, index]_equal for some inputs:
- # "DeprecationWarning: elementwise comparison failed; this will raise
- # an error in the future."
- # or "FutureWarning: elementwise ..."
- # This warning comes from a call from pandas to numpy. It is ignored
- # here because it cannot be fixed within cudf.
- with warnings.catch_warnings():
- warnings.simplefilter(
- "ignore", (DeprecationWarning, FutureWarning)
- )
- if isinstance(left, pd.DataFrame):
- tm.assert_frame_equal(left, right, **kwargs)
- elif isinstance(left, pd.Series):
- tm.assert_series_equal(left, right, **kwargs)
- else:
- tm.assert_index_equal(left, right, **kwargs)
-
- elif isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
- if np.issubdtype(left.dtype, np.floating) and np.issubdtype(
- right.dtype, np.floating
- ):
- assert np.allclose(left, right, equal_nan=True)
- else:
- assert np.array_equal(left, right)
- else:
- # Use the overloaded __eq__ of the operands
- if left == right:
- return True
- elif any(np.issubdtype(type(x), np.floating) for x in (left, right)):
- np.testing.assert_almost_equal(left, right)
- else:
- np.testing.assert_equal(left, right)
- return True
-
-
-def assert_neq(left, right, **kwargs):
- __tracebackhide__ = True
- try:
- assert_eq(left, right, **kwargs)
- except AssertionError:
- pass
- else:
- raise AssertionError
-
-
def assert_exceptions_equal(
lfunc,
rfunc,
diff --git a/python/cudf/cudf/testing/testing.py b/python/cudf/cudf/testing/testing.py
index dffbbe92fc1..e56c8d867cb 100644
--- a/python/cudf/cudf/testing/testing.py
+++ b/python/cudf/cudf/testing/testing.py
@@ -2,9 +2,12 @@
from __future__ import annotations
+import warnings
+
import cupy as cp
import numpy as np
import pandas as pd
+from pandas import testing as tm
import cudf
from cudf._lib.unary import is_nan
@@ -708,3 +711,100 @@ def assert_frame_equal(
atol=atol,
obj=f'Column name="{col}"',
)
+
+
+def assert_eq(left, right, **kwargs):
+ """Assert that two cudf-like things are equivalent
+
+ Parameters
+ ----------
+ left
+ Object to compare
+ right
+ Object to compare
+ kwargs
+ Keyword arguments to control behaviour of comparisons. See
+ :func:`assert_frame_equal`, :func:`assert_series_equal`, and
+ :func:`assert_index_equal`.
+
+ Notes
+ -----
+ This equality test works for pandas/cudf dataframes/series/indexes/scalars
+ in the same way, and so makes it easier to perform parametrized testing
+ without switching between assert_frame_equal/assert_series_equal/...
+ functions.
+
+ Raises
+ ------
+ AssertionError
+ If the two objects do not compare equal.
+ """
+ # dtypes that we support but Pandas doesn't will convert to
+ # `object`. Check equality before that happens:
+ if kwargs.get("check_dtype", True):
+ if hasattr(left, "dtype") and hasattr(right, "dtype"):
+ if isinstance(
+ left.dtype, cudf.core.dtypes._BaseDtype
+ ) and not isinstance(
+ left.dtype, cudf.CategoricalDtype
+ ): # leave categorical comparison to Pandas
+ assert_eq(left.dtype, right.dtype)
+
+ if hasattr(left, "to_pandas"):
+ left = left.to_pandas()
+ if hasattr(right, "to_pandas"):
+ right = right.to_pandas()
+ if isinstance(left, cp.ndarray):
+ left = cp.asnumpy(left)
+ if isinstance(right, cp.ndarray):
+ right = cp.asnumpy(right)
+
+ if isinstance(left, (pd.DataFrame, pd.Series, pd.Index)):
+ # TODO: A warning is emitted from the function
+ # pandas.testing.assert_[series, frame, index]_equal for some inputs:
+ # "DeprecationWarning: elementwise comparison failed; this will raise
+ # an error in the future."
+ # or "FutureWarning: elementwise ..."
+ # This warning comes from a call from pandas to numpy. It is ignored
+ # here because it cannot be fixed within cudf.
+ with warnings.catch_warnings():
+ warnings.simplefilter(
+ "ignore", (DeprecationWarning, FutureWarning)
+ )
+ if isinstance(left, pd.DataFrame):
+ tm.assert_frame_equal(left, right, **kwargs)
+ elif isinstance(left, pd.Series):
+ tm.assert_series_equal(left, right, **kwargs)
+ else:
+ tm.assert_index_equal(left, right, **kwargs)
+
+ elif isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
+ if np.issubdtype(left.dtype, np.floating) and np.issubdtype(
+ right.dtype, np.floating
+ ):
+ assert np.allclose(left, right, equal_nan=True)
+ else:
+ assert np.array_equal(left, right)
+ else:
+ # Use the overloaded __eq__ of the operands
+ if left == right:
+ return True
+ elif any(np.issubdtype(type(x), np.floating) for x in (left, right)):
+ np.testing.assert_almost_equal(left, right)
+ else:
+ np.testing.assert_equal(left, right)
+ return True
+
+
+def assert_neq(left, right, **kwargs):
+ """Assert that two cudf-like things are not equal.
+
+ Provides the negation of the meaning of :func:`assert_eq`.
+ """
+ __tracebackhide__ = True
+ try:
+ assert_eq(left, right, **kwargs)
+ except AssertionError:
+ pass
+ else:
+ raise AssertionError
diff --git a/python/cudf/cudf/tests/conftest.py b/python/cudf/cudf/tests/conftest.py
index 30d8f1c8422..437bc4cba67 100644
--- a/python/cudf/cudf/tests/conftest.py
+++ b/python/cudf/cudf/tests/conftest.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019-2022, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
import itertools
import os
@@ -11,7 +11,7 @@
import rmm # noqa: F401
import cudf
-from cudf.testing._utils import assert_eq
+from cudf.testing import assert_eq
_CURRENT_DIRECTORY = str(pathlib.Path(__file__).resolve().parent)
diff --git a/python/cudf/cudf/tests/dataframe/test_conversion.py b/python/cudf/cudf/tests/dataframe/test_conversion.py
index fa7e5ec1d4c..d1de7245634 100644
--- a/python/cudf/cudf/tests/dataframe/test_conversion.py
+++ b/python/cudf/cudf/tests/dataframe/test_conversion.py
@@ -1,9 +1,9 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.
import pandas as pd
import pytest
import cudf
-from cudf.testing._utils import assert_eq
+from cudf.testing import assert_eq
def test_convert_dtypes():
diff --git a/python/cudf/cudf/tests/dataframe/test_io_serialization.py b/python/cudf/cudf/tests/dataframe/test_io_serialization.py
index ad81609470c..57948afe1d8 100644
--- a/python/cudf/cudf/tests/dataframe/test_io_serialization.py
+++ b/python/cudf/cudf/tests/dataframe/test_io_serialization.py
@@ -8,7 +8,7 @@
import pytest
import cudf
-from cudf.testing._utils import assert_eq
+from cudf.testing import assert_eq
@pytest.mark.parametrize(
diff --git a/python/cudf/cudf/tests/groupby/test_computation.py b/python/cudf/cudf/tests/groupby/test_computation.py
index 04c56ef7462..630fcdc4dce 100644
--- a/python/cudf/cudf/tests/groupby/test_computation.py
+++ b/python/cudf/cudf/tests/groupby/test_computation.py
@@ -1,9 +1,9 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.
import pandas as pd
import pytest
import cudf
-from cudf.testing._utils import assert_eq
+from cudf.testing import assert_eq
@pytest.mark.parametrize("method", ["average", "min", "max", "first", "dense"])
diff --git a/python/cudf/cudf/tests/groupby/test_groupby_obj.py b/python/cudf/cudf/tests/groupby/test_groupby_obj.py
index 04b483e08dc..ab2b16d263c 100644
--- a/python/cudf/cudf/tests/groupby/test_groupby_obj.py
+++ b/python/cudf/cudf/tests/groupby/test_groupby_obj.py
@@ -2,7 +2,7 @@
from numpy.testing import assert_array_equal
import cudf
-from cudf.testing._utils import assert_eq
+from cudf.testing import assert_eq
def test_groupby_14955():
diff --git a/python/cudf/cudf/tests/groupby/test_indexing.py b/python/cudf/cudf/tests/groupby/test_indexing.py
index 57e8bc1c2d8..43b6183fca5 100644
--- a/python/cudf/cudf/tests/groupby/test_indexing.py
+++ b/python/cudf/cudf/tests/groupby/test_indexing.py
@@ -1,6 +1,6 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
import cudf
-from cudf.testing._utils import assert_eq
+from cudf.testing import assert_eq
def test_rank_return_type_compatible_mode():
diff --git a/python/cudf/cudf/tests/groupby/test_transform.py b/python/cudf/cudf/tests/groupby/test_transform.py
index 78d7fbfd879..f7138036ddf 100644
--- a/python/cudf/cudf/tests/groupby/test_transform.py
+++ b/python/cudf/cudf/tests/groupby/test_transform.py
@@ -4,7 +4,7 @@
import pytest
import cudf
-from cudf.testing._utils import assert_eq
+from cudf.testing import assert_eq
@pytest.fixture(params=[False, True], ids=["no-null-keys", "null-keys"])
diff --git a/python/cudf/cudf/tests/indexes/datetime/test_indexing.py b/python/cudf/cudf/tests/indexes/datetime/test_indexing.py
index f2c2d9a263b..4c0ce2ed191 100644
--- a/python/cudf/cudf/tests/indexes/datetime/test_indexing.py
+++ b/python/cudf/cudf/tests/indexes/datetime/test_indexing.py
@@ -1,19 +1,17 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.
+import zoneinfo
import pandas as pd
import cudf
-from cudf.testing._utils import assert_eq
+from cudf.testing import assert_eq
def test_slice_datetimetz_index():
+ tz = zoneinfo.ZoneInfo("US/Eastern")
data = ["2001-01-01", "2001-01-02", None, None, "2001-01-03"]
- pidx = pd.DatetimeIndex(data, dtype="datetime64[ns]").tz_localize(
- "US/Eastern"
- )
- idx = cudf.DatetimeIndex(data, dtype="datetime64[ns]").tz_localize(
- "US/Eastern"
- )
+ pidx = pd.DatetimeIndex(data, dtype="datetime64[ns]").tz_localize(tz)
+ idx = cudf.DatetimeIndex(data, dtype="datetime64[ns]").tz_localize(tz)
expected = pidx[1:4]
got = idx[1:4]
assert_eq(expected, got)
diff --git a/python/cudf/cudf/tests/indexes/datetime/test_time_specific.py b/python/cudf/cudf/tests/indexes/datetime/test_time_specific.py
index b28ef131025..7cc629270b1 100644
--- a/python/cudf/cudf/tests/indexes/datetime/test_time_specific.py
+++ b/python/cudf/cudf/tests/indexes/datetime/test_time_specific.py
@@ -1,29 +1,28 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
+import zoneinfo
+
import pandas as pd
import cudf
-from cudf.testing._utils import assert_eq
+from cudf.testing import assert_eq
def test_tz_localize():
+ tz = zoneinfo.ZoneInfo("America/New_York")
pidx = pd.date_range("2001-01-01", "2001-01-02", freq="1s")
pidx = pidx.astype(" None:
super().__init__(dtype)
self.options = options
self.children = (agg,)
+ raise NotImplementedError("Rolling window not implemented")
class GroupedRollingWindow(Expr):
@@ -909,6 +910,7 @@ def __init__(self, dtype: plc.DataType, options: Any, agg: Expr, *by: Expr) -> N
super().__init__(dtype)
self.options = options
self.children = (agg, *by)
+ raise NotImplementedError("Grouped rolling window not implemented")
class Cast(Expr):
@@ -952,7 +954,9 @@ def __init__(
self.options = options
self.children = (value,)
if name not in Agg._SUPPORTED:
- raise NotImplementedError(f"Unsupported aggregation {name=}")
+ raise NotImplementedError(
+ f"Unsupported aggregation {name=}"
+ ) # pragma: no cover; all valid aggs are supported
# TODO: nan handling in groupby case
if name == "min":
req = plc.aggregation.min()
@@ -978,7 +982,9 @@ def __init__(
elif name == "count":
req = plc.aggregation.count(null_handling=plc.types.NullPolicy.EXCLUDE)
else:
- raise NotImplementedError
+ raise NotImplementedError(
+ f"Unreachable, {name=} is incorrectly listed in _SUPPORTED"
+ ) # pragma: no cover
self.request = req
op = getattr(self, f"_{name}", None)
if op is None:
@@ -988,7 +994,9 @@ def __init__(
elif name in {"count", "first", "last"}:
pass
else:
- raise AssertionError
+ raise NotImplementedError(
+ f"Unreachable, supported agg {name=} has no implementation"
+ ) # pragma: no cover
self.op = op
_SUPPORTED: ClassVar[frozenset[str]] = frozenset(
@@ -1010,11 +1018,15 @@ def __init__(
def collect_agg(self, *, depth: int) -> AggInfo:
"""Collect information about aggregations in groupbys."""
if depth >= 1:
- raise NotImplementedError("Nested aggregations in groupby")
+ raise NotImplementedError(
+ "Nested aggregations in groupby"
+ ) # pragma: no cover; check_agg trips first
(child,) = self.children
((expr, _, _),) = child.collect_agg(depth=depth + 1).requests
if self.request is None:
- raise NotImplementedError(f"Aggregation {self.name} in groupby")
+ raise NotImplementedError(
+ f"Aggregation {self.name} in groupby"
+ ) # pragma: no cover; __init__ trips first
return AggInfo([(expr, self.request, self)])
def _reduce(
@@ -1024,10 +1036,7 @@ def _reduce(
plc.Column.from_scalar(
plc.reduce.reduce(column.obj, request, self.dtype),
1,
- ),
- is_sorted=plc.types.Sorted.YES,
- order=plc.types.Order.ASCENDING,
- null_order=plc.types.NullOrder.BEFORE,
+ )
)
def _count(self, column: Column) -> Column:
@@ -1040,10 +1049,7 @@ def _count(self, column: Column) -> Column:
),
),
1,
- ),
- is_sorted=plc.types.Sorted.YES,
- order=plc.types.Order.ASCENDING,
- null_order=plc.types.NullOrder.BEFORE,
+ )
)
def _min(self, column: Column, *, propagate_nans: bool) -> Column:
@@ -1054,10 +1060,7 @@ def _min(self, column: Column, *, propagate_nans: bool) -> Column:
pa.scalar(float("nan"), type=plc.interop.to_arrow(self.dtype))
),
1,
- ),
- is_sorted=plc.types.Sorted.YES,
- order=plc.types.Order.ASCENDING,
- null_order=plc.types.NullOrder.BEFORE,
+ )
)
if column.nan_count > 0:
column = column.mask_nans()
@@ -1071,31 +1074,18 @@ def _max(self, column: Column, *, propagate_nans: bool) -> Column:
pa.scalar(float("nan"), type=plc.interop.to_arrow(self.dtype))
),
1,
- ),
- is_sorted=plc.types.Sorted.YES,
- order=plc.types.Order.ASCENDING,
- null_order=plc.types.NullOrder.BEFORE,
+ )
)
if column.nan_count > 0:
column = column.mask_nans()
return self._reduce(column, request=plc.aggregation.max())
def _first(self, column: Column) -> Column:
- return Column(
- plc.copying.slice(column.obj, [0, 1])[0],
- is_sorted=plc.types.Sorted.YES,
- order=plc.types.Order.ASCENDING,
- null_order=plc.types.NullOrder.BEFORE,
- )
+ return Column(plc.copying.slice(column.obj, [0, 1])[0])
def _last(self, column: Column) -> Column:
n = column.obj.size()
- return Column(
- plc.copying.slice(column.obj, [n - 1, n])[0],
- is_sorted=plc.types.Sorted.YES,
- order=plc.types.Order.ASCENDING,
- null_order=plc.types.NullOrder.BEFORE,
- )
+ return Column(plc.copying.slice(column.obj, [n - 1, n])[0])
def do_evaluate(
self,
@@ -1106,7 +1096,9 @@ def do_evaluate(
) -> Column:
"""Evaluate this expression given a dataframe for context."""
if context is not ExecutionContext.FRAME:
- raise NotImplementedError(f"Agg in context {context}")
+ raise NotImplementedError(
+ f"Agg in context {context}"
+ ) # pragma: no cover; unreachable
(child,) = self.children
return self.op(child.evaluate(df, context=context, mapping=mapping))
diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py
index 3ccefac6b0a..b3dd6ae7cc3 100644
--- a/python/cudf_polars/cudf_polars/dsl/ir.py
+++ b/python/cudf_polars/cudf_polars/dsl/ir.py
@@ -427,8 +427,6 @@ def check_agg(agg: expr.Expr) -> int:
if isinstance(agg, (expr.BinOp, expr.Cast)):
return max(GroupBy.check_agg(child) for child in agg.children)
elif isinstance(agg, expr.Agg):
- if agg.name == "implode":
- raise NotImplementedError("implode in groupby")
return 1 + max(GroupBy.check_agg(child) for child in agg.children)
elif isinstance(agg, (expr.Len, expr.Col, expr.Literal)):
return 0
@@ -440,7 +438,9 @@ def __post_init__(self) -> None:
if self.options.rolling is None and self.maintain_order:
raise NotImplementedError("Maintaining order in groupby")
if self.options.rolling:
- raise NotImplementedError("rolling window/groupby")
+ raise NotImplementedError(
+ "rolling window/groupby"
+ ) # pragma: no cover; rollingwindow constructor has already raised
if any(GroupBy.check_agg(a.value) > 1 for a in self.agg_requests):
raise NotImplementedError("Nested aggregations in groupby")
self.agg_infos = [req.collect_agg(depth=0) for req in self.agg_requests]
diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py
index 41bc3032bc5..5d289885f47 100644
--- a/python/cudf_polars/cudf_polars/dsl/translate.py
+++ b/python/cudf_polars/cudf_polars/dsl/translate.py
@@ -10,6 +10,7 @@
from typing import Any
import pyarrow as pa
+from typing_extensions import assert_never
from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir
@@ -354,17 +355,20 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex
@_translate_expr.register
def _(node: pl_expr.Window, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
# TODO: raise in groupby?
- if node.partition_by is None:
+ if isinstance(node.options, pl_expr.RollingGroupOptions):
+ # pl.col("a").rolling(...)
return expr.RollingWindow(
dtype, node.options, translate_expr(visitor, n=node.function)
)
- else:
+ elif isinstance(node.options, pl_expr.WindowMapping):
+ # pl.col("a").over(...)
return expr.GroupedRollingWindow(
dtype,
node.options,
translate_expr(visitor, n=node.function),
*(translate_expr(visitor, n=n) for n in node.partition_by),
)
+ assert_never(node.options)
@_translate_expr.register
diff --git a/python/cudf_polars/cudf_polars/utils/dtypes.py b/python/cudf_polars/cudf_polars/utils/dtypes.py
index 7b0049daf11..3d4a643e1fc 100644
--- a/python/cudf_polars/cudf_polars/utils/dtypes.py
+++ b/python/cudf_polars/cudf_polars/utils/dtypes.py
@@ -70,7 +70,7 @@ def from_polars(dtype: pl.DataType) -> plc.DataType:
return plc.DataType(plc.TypeId.TIMESTAMP_MICROSECONDS)
elif dtype.time_unit == "ns":
return plc.DataType(plc.TypeId.TIMESTAMP_NANOSECONDS)
- assert dtype.time_unit is not None
+ assert dtype.time_unit is not None # pragma: no cover
assert_never(dtype.time_unit)
elif isinstance(dtype, pl.Duration):
if dtype.time_unit == "ms":
@@ -79,7 +79,7 @@ def from_polars(dtype: pl.DataType) -> plc.DataType:
return plc.DataType(plc.TypeId.DURATION_MICROSECONDS)
elif dtype.time_unit == "ns":
return plc.DataType(plc.TypeId.DURATION_NANOSECONDS)
- assert dtype.time_unit is not None
+ assert dtype.time_unit is not None # pragma: no cover
assert_never(dtype.time_unit)
elif isinstance(dtype, pl.String):
return plc.DataType(plc.TypeId.STRING)
diff --git a/python/cudf_polars/cudf_polars/utils/sorting.py b/python/cudf_polars/cudf_polars/utils/sorting.py
index 24fd449dd88..57f94c4ec4c 100644
--- a/python/cudf_polars/cudf_polars/utils/sorting.py
+++ b/python/cudf_polars/cudf_polars/utils/sorting.py
@@ -43,8 +43,8 @@ def sort_order(
for d in descending
]
null_precedence = []
- # TODO: use strict=True when we drop py39
- assert len(descending) == len(nulls_last)
+ if len(descending) != len(nulls_last) or len(descending) != num_keys:
+ raise ValueError("Mismatching length of arguments in sort_order")
for asc, null_last in zip(column_order, nulls_last):
if (asc == plc.types.Order.ASCENDING) ^ (not null_last):
null_precedence.append(plc.types.NullOrder.AFTER)
diff --git a/python/cudf_polars/pyproject.toml b/python/cudf_polars/pyproject.toml
index face04b9bd8..effa4861e0c 100644
--- a/python/cudf_polars/pyproject.toml
+++ b/python/cudf_polars/pyproject.toml
@@ -52,6 +52,13 @@ version = {file = "cudf_polars/VERSION"}
[tool.pytest.ini_options]
xfail_strict = true
+[tool.coverage.report]
+exclude_also = [
+ "if TYPE_CHECKING:",
+ "class .*\\bProtocol\\):",
+ "assert_never\\("
+]
+
[tool.ruff]
line-length = 88
indent-width = 4
diff --git a/python/cudf_polars/tests/expressions/test_agg.py b/python/cudf_polars/tests/expressions/test_agg.py
index b044bbb2885..2ffa1c4af6d 100644
--- a/python/cudf_polars/tests/expressions/test_agg.py
+++ b/python/cudf_polars/tests/expressions/test_agg.py
@@ -56,3 +56,17 @@ def test_agg(df, agg):
with pytest.raises(AssertionError):
assert_gpu_result_equal(q)
assert_gpu_result_equal(q, check_dtypes=check_dtypes, check_exact=False)
+
+
+@pytest.mark.parametrize(
+ "propagate_nans",
+ [pytest.param(False, marks=pytest.mark.xfail(reason="Need to mask nans")), True],
+ ids=["mask_nans", "propagate_nans"],
+)
+@pytest.mark.parametrize("op", ["min", "max"])
+def test_agg_float_with_nans(propagate_nans, op):
+ df = pl.LazyFrame({"a": [1, 2, float("nan")]})
+ op = getattr(pl.Expr, f"nan_{op}" if propagate_nans else op)
+ q = df.select(op(pl.col("a")))
+
+ assert_gpu_result_equal(q)
diff --git a/python/cudf_polars/tests/expressions/test_datetime_basic.py b/python/cudf_polars/tests/expressions/test_datetime_basic.py
new file mode 100644
index 00000000000..6ba2a1dce1e
--- /dev/null
+++ b/python/cudf_polars/tests/expressions/test_datetime_basic.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+import pytest
+
+import polars as pl
+
+from cudf_polars.testing.asserts import assert_gpu_result_equal
+
+
+@pytest.mark.parametrize(
+ "dtype",
+ [
+ pl.Date(),
+ pl.Datetime("ms"),
+ pl.Datetime("us"),
+ pl.Datetime("ns"),
+ pl.Duration("ms"),
+ pl.Duration("us"),
+ pl.Duration("ns"),
+ ],
+ ids=repr,
+)
+def test_datetime_dataframe_scan(dtype):
+ ldf = pl.DataFrame(
+ {
+ "a": pl.Series([1, 2, 3, 4, 5, 6, 7], dtype=dtype),
+ "b": pl.Series([3, 4, 5, 6, 7, 8, 9], dtype=pl.UInt16),
+ }
+ ).lazy()
+
+ query = ldf.select(pl.col("b"), pl.col("a"))
+ assert_gpu_result_equal(query)
diff --git a/python/cudf_polars/tests/expressions/test_filter.py b/python/cudf_polars/tests/expressions/test_filter.py
index 783403d764c..1a8e994e3aa 100644
--- a/python/cudf_polars/tests/expressions/test_filter.py
+++ b/python/cudf_polars/tests/expressions/test_filter.py
@@ -2,19 +2,35 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
+import pytest
+
import polars as pl
from cudf_polars.testing.asserts import assert_gpu_result_equal
-def test_filter():
- ldf = pl.DataFrame(
+@pytest.mark.parametrize(
+ "expr",
+ [
+ pytest.param(
+ pl.lit(value=False),
+ marks=pytest.mark.xfail(reason="Expression filter does not handle scalars"),
+ ),
+ pl.col("c"),
+ pl.col("b") > 2,
+ ],
+)
+@pytest.mark.parametrize("predicate_pushdown", [False, True])
+def test_filter_expression(expr, predicate_pushdown):
+ ldf = pl.LazyFrame(
{
"a": [1, 2, 3, 4, 5, 6, 7],
- "b": [1, 1, 1, 1, 1, 1, 1],
+ "b": [0, 3, 1, 5, 6, 1, 0],
+ "c": [None, True, False, False, True, True, False],
}
- ).lazy()
+ )
- # group-by is just to avoid the filter being pushed into the scan.
- query = ldf.group_by(pl.col("a")).agg(pl.col("b").sum()).filter(pl.col("b") < 1)
- assert_gpu_result_equal(query)
+ query = ldf.select(pl.col("a").filter(expr))
+ assert_gpu_result_equal(
+ query, collect_kwargs={"predicate_pushdown": predicate_pushdown}
+ )
diff --git a/python/cudf_polars/tests/expressions/test_rolling.py b/python/cudf_polars/tests/expressions/test_rolling.py
new file mode 100644
index 00000000000..d4920d35f14
--- /dev/null
+++ b/python/cudf_polars/tests/expressions/test_rolling.py
@@ -0,0 +1,41 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import pytest
+
+import polars as pl
+
+from cudf_polars import translate_ir
+
+
+def test_rolling():
+ dates = [
+ "2020-01-01 13:45:48",
+ "2020-01-01 16:42:13",
+ "2020-01-01 16:45:09",
+ "2020-01-02 18:12:48",
+ "2020-01-03 19:45:32",
+ "2020-01-08 23:16:43",
+ ]
+ df = (
+ pl.DataFrame({"dt": dates, "a": [3, 7, 5, 9, 2, 1]})
+ .with_columns(pl.col("dt").str.strptime(pl.Datetime))
+ .lazy()
+ )
+ q = df.with_columns(
+ sum_a=pl.sum("a").rolling(index_column="dt", period="2d"),
+ min_a=pl.min("a").rolling(index_column="dt", period="2d"),
+ max_a=pl.max("a").rolling(index_column="dt", period="2d"),
+ )
+ with pytest.raises(NotImplementedError):
+ _ = translate_ir(q._ldf.visit())
+
+
+def test_grouped_rolling():
+ df = pl.LazyFrame({"a": [1, 2, 3, 4, 5, 6], "b": [1, 2, 1, 3, 1, 2]})
+
+ q = df.select(pl.col("a").min().over("b"))
+ with pytest.raises(NotImplementedError):
+ _ = translate_ir(q._ldf.visit())
diff --git a/python/cudf_polars/tests/expressions/test_sort.py b/python/cudf_polars/tests/expressions/test_sort.py
new file mode 100644
index 00000000000..0195266f5c6
--- /dev/null
+++ b/python/cudf_polars/tests/expressions/test_sort.py
@@ -0,0 +1,53 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+import itertools
+
+import pytest
+
+import polars as pl
+
+from cudf_polars.testing.asserts import assert_gpu_result_equal
+
+
+@pytest.mark.parametrize("descending", [False, True])
+@pytest.mark.parametrize("nulls_last", [False, True])
+def test_sort_expression(descending, nulls_last):
+ ldf = pl.LazyFrame(
+ {
+ "a": [5, -1, 3, 4, None, 8, 6, 7, None],
+ }
+ )
+
+ query = ldf.select(pl.col("a").sort(descending=descending, nulls_last=nulls_last))
+ assert_gpu_result_equal(query)
+
+
+@pytest.mark.parametrize(
+ "descending", itertools.combinations_with_replacement([False, True], 3)
+)
+@pytest.mark.parametrize(
+ "nulls_last", itertools.combinations_with_replacement([False, True], 3)
+)
+@pytest.mark.parametrize("maintain_order", [False, True], ids=["unstable", "stable"])
+def test_sort_by_expression(descending, nulls_last, maintain_order):
+ ldf = pl.LazyFrame(
+ {
+ "a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ "b": [1, 2, 2, 3, 9, 5, -1, 2, -2, 16],
+ "c": ["a", "A", "b", "b", "c", "d", "A", "Z", "ä", "̈Ä"],
+ }
+ )
+
+ query = ldf.select(
+ pl.col("a").sort_by(
+ pl.col("b"),
+ pl.col("c"),
+ pl.col("b") + pl.col("a"),
+ descending=descending,
+ nulls_last=nulls_last,
+ maintain_order=maintain_order,
+ )
+ )
+ assert_gpu_result_equal(query, check_row_order=maintain_order)
diff --git a/python/cudf_polars/tests/test_filter.py b/python/cudf_polars/tests/test_filter.py
new file mode 100644
index 00000000000..f39b348144b
--- /dev/null
+++ b/python/cudf_polars/tests/test_filter.py
@@ -0,0 +1,26 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+import pytest
+
+import polars as pl
+
+from cudf_polars.testing.asserts import assert_gpu_result_equal
+
+
+@pytest.mark.parametrize("expr", [pl.col("c"), pl.col("b") < 1, pl.lit(value=True)])
+@pytest.mark.parametrize("predicate_pushdown", [False, True])
+def test_filter(expr, predicate_pushdown):
+ ldf = pl.DataFrame(
+ {
+ "a": [1, 2, 3, 4, 5, 6, 7],
+ "b": [1, 1, 1, 1, 1, 1, 1],
+ "c": [True, False, False, True, True, True, None],
+ }
+ ).lazy()
+
+ query = ldf.filter(expr)
+ assert_gpu_result_equal(
+ query, collect_kwargs={"predicate_pushdown": predicate_pushdown}
+ )
diff --git a/python/cudf_polars/tests/test_groupby.py b/python/cudf_polars/tests/test_groupby.py
index d06a7ecf105..e70f923b097 100644
--- a/python/cudf_polars/tests/test_groupby.py
+++ b/python/cudf_polars/tests/test_groupby.py
@@ -6,6 +6,7 @@
import polars as pl
+from cudf_polars import translate_ir
from cudf_polars.testing.asserts import assert_gpu_result_equal
@@ -43,6 +44,7 @@ def keys(request):
[pl.col("float") + pl.col("int")],
[pl.col("float").max() - pl.col("int").min()],
[pl.col("float").mean(), pl.col("int").std()],
+ [(pl.col("float") - pl.lit(2)).max()],
],
ids=lambda aggs: "-".join(map(str, aggs)),
)
@@ -72,7 +74,28 @@ def test_groupby(df: pl.LazyFrame, maintain_order, keys, exprs):
if not maintain_order:
sort_keys = list(q.schema.keys())[: len(keys)]
q = q.sort(*sort_keys)
- # from cudf_polars.dsl.translate import translate_ir
- # ir = translate_ir(q._ldf.visit())
- # from IPython import embed; embed()
+
assert_gpu_result_equal(q, check_exact=False)
+
+
+def test_groupby_len(df, keys):
+ q = df.group_by(*keys).agg(pl.len())
+
+ # TODO: polars returns UInt32, libcudf returns Int32
+ with pytest.raises(AssertionError):
+ assert_gpu_result_equal(q, check_row_order=False)
+ assert_gpu_result_equal(q, check_dtypes=False, check_row_order=False)
+
+
+@pytest.mark.parametrize(
+ "expr",
+ [
+ pl.col("float").is_not_null(),
+ (pl.col("int").max() + pl.col("float").min()).max(),
+ ],
+)
+def test_groupby_unsupported(df, expr):
+ q = df.group_by("key1").agg(expr)
+
+ with pytest.raises(NotImplementedError):
+ _ = translate_ir(q._ldf.visit())
diff --git a/python/cudf_polars/tests/utils/test_dtypes.py b/python/cudf_polars/tests/utils/test_dtypes.py
new file mode 100644
index 00000000000..535fdd846a0
--- /dev/null
+++ b/python/cudf_polars/tests/utils/test_dtypes.py
@@ -0,0 +1,31 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import pytest
+
+import polars as pl
+
+from cudf_polars.utils.dtypes import from_polars
+
+
+@pytest.mark.parametrize(
+ "pltype",
+ [
+ pl.Time(),
+ pl.Struct({"a": pl.Int8, "b": pl.Float32}),
+ pl.Datetime("ms", time_zone="US/Pacific"),
+ pl.Array(pl.Int8, 2),
+ pl.Binary(),
+ pl.Categorical(),
+ pl.Enum(["a", "b"]),
+ pl.Field("a", pl.Int8),
+ pl.Object(),
+ pl.Unknown(),
+ ],
+ ids=repr,
+)
+def test_unhandled_dtype_conversion_raises(pltype):
+ with pytest.raises(NotImplementedError):
+ _ = from_polars(pltype)
diff --git a/python/cudf_polars/tests/utils/test_sorting.py b/python/cudf_polars/tests/utils/test_sorting.py
new file mode 100644
index 00000000000..4e98a3a7ce7
--- /dev/null
+++ b/python/cudf_polars/tests/utils/test_sorting.py
@@ -0,0 +1,21 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import pytest
+
+from cudf_polars.utils.sorting import sort_order
+
+
+@pytest.mark.parametrize(
+ "descending,nulls_last,num_keys",
+ [
+ ([True], [False, True], 3),
+ ([True, True], [False, True, False], 3),
+ ([False, True], [True], 3),
+ ],
+)
+def test_sort_order_raises_mismatch(descending, nulls_last, num_keys):
+ with pytest.raises(ValueError):
+ _ = sort_order(descending, nulls_last=nulls_last, num_keys=num_keys)
diff --git a/python/custreamz/custreamz/tests/test_kafka.py b/python/custreamz/custreamz/tests/test_kafka.py
index ad3b829544b..3a3c4e994d0 100644
--- a/python/custreamz/custreamz/tests/test_kafka.py
+++ b/python/custreamz/custreamz/tests/test_kafka.py
@@ -1,8 +1,8 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
import confluent_kafka as ck
import pytest
-from cudf.testing._utils import assert_eq
+from cudf.testing import assert_eq
@pytest.mark.parametrize("commit_offset", [1, 45, 100, 22, 1000, 10])
diff --git a/python/dask_cudf/dask_cudf/tests/test_accessor.py b/python/dask_cudf/dask_cudf/tests/test_accessor.py
index 58d28f0597e..6f04b5737da 100644
--- a/python/dask_cudf/dask_cudf/tests/test_accessor.py
+++ b/python/dask_cudf/dask_cudf/tests/test_accessor.py
@@ -9,7 +9,8 @@
from dask import dataframe as dd
from cudf import DataFrame, Series, date_range
-from cudf.testing._utils import assert_eq, does_not_raise
+from cudf.testing import assert_eq
+from cudf.testing._utils import does_not_raise
import dask_cudf
from dask_cudf.tests.utils import xfail_dask_expr
diff --git a/python/dask_cudf/dask_cudf/tests/test_core.py b/python/dask_cudf/dask_cudf/tests/test_core.py
index 7f8a619ae22..174923c2c7e 100644
--- a/python/dask_cudf/dask_cudf/tests/test_core.py
+++ b/python/dask_cudf/dask_cudf/tests/test_core.py
@@ -795,7 +795,7 @@ def test_dataframe_set_index():
pddf = dd.from_pandas(pdf, npartitions=4)
pddf = pddf.set_index("str")
- from cudf.testing._utils import assert_eq
+ from cudf.testing import assert_eq
assert_eq(ddf.compute(), pddf.compute())
diff --git a/python/dask_cudf/dask_cudf/tests/test_distributed.py b/python/dask_cudf/dask_cudf/tests/test_distributed.py
index 07fdb25dff9..be10b0d4843 100644
--- a/python/dask_cudf/dask_cudf/tests/test_distributed.py
+++ b/python/dask_cudf/dask_cudf/tests/test_distributed.py
@@ -9,7 +9,7 @@
from distributed.utils_test import cleanup, loop, loop_in_thread # noqa: F401
import cudf
-from cudf.testing._utils import assert_eq
+from cudf.testing import assert_eq
import dask_cudf