From a24725e4eb3bd421e4b3fbb91892c8af8a087622 Mon Sep 17 00:00:00 2001 From: Alexandria Barghi Date: Tue, 23 May 2023 15:41:06 +0000 Subject: [PATCH 1/2] pull in bulk sampler fix --- python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py b/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py index fd7366cbe40..d7f1c136484 100644 --- a/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py +++ b/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py @@ -43,7 +43,7 @@ def _write_samples_to_parquet( """ # Required by dask; need to skip dummy partitions. - if partition_info is None: + if partition_info is None or len(results) == 0: return if partition_info != "sg" and (not isinstance(partition_info, dict)): raise ValueError("Invalid value of partition_info") @@ -69,7 +69,7 @@ def _write_samples_to_parquet( results_p["batch_id"] = offsets_p.batch_id.repeat( cupy.diff(offsets_p.offsets.values, append=end_ix) ).values - results_p.to_parquet(full_output_path) + results_p.to_parquet(full_output_path, compression=None, index=False) def write_samples( From 0433cbd26b1fa2601e661b8d7871ee79f3ca2a3a Mon Sep 17 00:00:00 2001 From: Alexandria Barghi Date: Tue, 23 May 2023 15:45:08 +0000 Subject: [PATCH 2/2] revert compression change --- python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py b/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py index d7f1c136484..673b53838c5 100644 --- a/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py +++ b/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py @@ -69,7 +69,7 @@ def _write_samples_to_parquet( results_p["batch_id"] = offsets_p.batch_id.repeat( cupy.diff(offsets_p.offsets.values, append=end_ix) ).values - results_p.to_parquet(full_output_path, compression=None, index=False) + results_p.to_parquet(full_output_path) def write_samples(