diff --git a/python/lsst/dax/obscore/obscore_exporter.py b/python/lsst/dax/obscore/obscore_exporter.py
index dceb9dc..79e8aa7 100644
--- a/python/lsst/dax/obscore/obscore_exporter.py
+++ b/python/lsst/dax/obscore/obscore_exporter.py
@@ -296,7 +296,6 @@ def __init__(self, butler: Butler, config: ExporterConfig):
         self.record_factory = RecordFactory(
             config, schema, universe, spatial_plugins, exposure_region_factory
         )
-        self.overflow = False
 
     def to_parquet(self, output: str) -> None:
         """Export Butler datasets as ObsCore Data Model in parquet format.
@@ -308,7 +307,7 @@ def to_parquet(self, output: str) -> None:
         """
         compression = self.config.parquet_compression
         with ParquetWriter(output, self.schema, compression=compression) as writer:
-            for record_batch in self._make_record_batches(self.config.batch_size):
+            for record_batch, _ in self._make_record_batches(self.config.batch_size):
                 writer.write_batch(record_batch)
 
     def to_csv(self, output: str) -> None:
@@ -325,7 +324,7 @@ def to_csv(self, output: str) -> None:
         null_string = self.config.csv_null_string.encode()
         with contextlib.closing(_CSVFile(output, null_string, sep_in=b"\x1f", sep_out=b",")) as file:
             with CSVWriter(file, self.schema, write_options=options) as writer:
-                for record_batch in self._make_record_batches(self.config.batch_size):
+                for record_batch, _ in self._make_record_batches(self.config.batch_size):
                     writer.write_batch(record_batch)
 
     def to_votable(self, limit: int | None = None) -> astropy.io.votable.tree.VOTableFile:
@@ -408,14 +407,15 @@ def to_votable(self, limit: int | None = None) -> astropy.io.votable.tree.VOTabl
 
         chunks = []
         n_rows = 0
-        for record_batch in self._make_record_batches(self.config.batch_size, limit=limit):
+        overflow = False
+        for record_batch, overflow in self._make_record_batches(self.config.batch_size, limit=limit):
             table = ArrowTable.from_batches([record_batch])
             chunk = arrow_to_numpy(table)
             n_rows += len(chunk)
             chunks.append(chunk)
 
         # Report any overflow.
-        query_status = "OVERFLOW" if self.overflow else "OK"
+        query_status = "OVERFLOW" if overflow else "OK"
         info = astropy.io.votable.tree.Info(name="QUERY_STATUS", value=query_status)
         resource.infos.append(info)
 
@@ -424,9 +424,7 @@ def to_votable(self, limit: int | None = None) -> astropy.io.votable.tree.VOTabl
         table0.array = ma.hstack(chunks)
 
         # Write the output file.
-        _LOG.info(
-            "Got %d result%s%s", n_rows, "" if n_rows == 1 else "s", " (overflow)" if self.overflow else ""
-        )
+        _LOG.info("Got %d result%s%s", n_rows, "" if n_rows == 1 else "s", " (overflow)" if overflow else "")
         return votable
 
     def to_votable_file(self, output: str, limit: int | None = None) -> None:
@@ -468,12 +466,16 @@ def _make_schema(self, table_spec: ddl.TableSpec) -> Schema:
 
     def _make_record_batches(
         self, batch_size: int = 10_000, limit: int | None = None
-    ) -> Iterator[RecordBatch]:
-        """Generate batches of records to save to a file."""
+    ) -> Iterator[tuple[RecordBatch, bool]]:
+        """Generate batches of records to save to a file.
+
+        Yields the batches and a flag indicating whether an overflow condition
+        was hit.
+        """
         batch = _BatchCollector(self.schema)
 
-        # Reset overflow flag.
-        self.overflow = False
+        # Set overflow flag.
+        overflow = False
 
         collections: Any = self.config.collections
         if not collections:
@@ -525,33 +527,33 @@
                     if not unlimited and count == query_limit:
                         # Hit the +1 so should not add this to the batch.
                         _LOG.debug("Got one more than requested limit so dropping final record.")
-                        self.overflow = True
+                        overflow = True
                         break
                     batch.add_to_batch(record)
                     if batch.size >= batch_size:
                         _LOG.debug("Saving next record batch, size=%s", batch.size)
-                        yield batch.make_record_batch()
+                        yield (batch.make_record_batch(), overflow)
 
                 if not unlimited:
                     query_limit -= count
-                    if self.overflow:
+                    if overflow:
                         # We counted one too many so adjust for the log
                         # message.
                         count -= 1
 
                 _LOG.info("Copied %d records from dataset type %s", count, dataset_type_name)
-                if self.overflow:
+                if overflow:
                     # No more queries need to run.
                     # This breaks out one level of nesting.
                     break
 
-            if self.overflow:
+            if overflow:
                 # Stop further dataset type queries.
                 break
 
         # Final batch if anything is there
         if batch.size > 0:
             _LOG.debug("Saving final record batch, size=%s", batch.size)
-            yield batch.make_record_batch()
+            yield (batch.make_record_batch(), overflow)