Skip to content

Commit

Permalink
GPKG: make GetArrowStream() honour SetNextByIndex()
Browse files Browse the repository at this point in the history
  • Loading branch information
rouault committed Sep 2, 2023
1 parent 248cf60 commit ecf2f4d
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 27 deletions.
13 changes: 13 additions & 0 deletions autotest/ogr/ogr_gpkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7687,6 +7687,19 @@ def test_ogr_gpkg_arrow_stream_numpy():
assert batch["int16"][0] == 123
assert len(batch["fid"]) == 1

assert lyr.SetNextByIndex(1) == ogr.OGRERR_NONE
stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
batches = [batch for batch in stream]
assert len(batches) == 1
assert list(batches[0]["fid"]) == [2, 3]

with ds.ExecuteSQL("SELECT * FROM test") as sql_lyr:
assert sql_lyr.SetNextByIndex(1) == ogr.OGRERR_NONE
stream = sql_lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
batches = [batch for batch in stream]
assert len(batches) == 1
assert list(batches[0]["fid"]) == [2, 3]

with lyr.GetArrowStreamAsNumPy(options=["MAX_FEATURES_IN_BATCH=1"]) as stream:
batches = [batch for batch in stream]
assert len(batches) == 3
Expand Down
1 change: 1 addition & 0 deletions ogr/ogrsf_frmts/gpkg/ogr_geopackage.h
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,7 @@ class OGRGeoPackageTableLayer final : public OGRGeoPackageLayer
std::set<OGRwkbGeometryType> m_eSetBadGeomTypeWarned{};

int m_nIsCompatOfOptimizedGetNextArrowArray = -1;
bool m_bGetNextArrowArrayCalledSinceResetReading = false;

int m_nCountInsertInTransactionThreshold = -1;
GIntBig m_nCountInsertInTransaction = 0;
Expand Down
81 changes: 54 additions & 27 deletions ogr/ogrsf_frmts/gpkg/ogrgeopackagetablelayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3174,6 +3174,8 @@ void OGRGeoPackageTableLayer::ResetReading()

CancelAsyncNextArrowArray();

m_bGetNextArrowArrayCalledSinceResetReading = false;

BuildColumns();
}

Expand Down Expand Up @@ -3276,6 +3278,7 @@ OGRErr OGRGeoPackageTableLayer::ResetStatementInternal(GIntBig nStartIndex)
}

m_iNextShapeId = nStartIndex;
m_bGetNextArrowArrayCalledSinceResetReading = false;

return OGRERR_NONE;
}
Expand Down Expand Up @@ -7616,6 +7619,8 @@ int OGRGeoPackageTableLayer::GetNextArrowArrayAsynchronous(
{
memset(out_array, 0, sizeof(*out_array));

m_bGetNextArrowArrayCalledSinceResetReading = true;

if (m_poFillArrowArray && m_poFillArrowArray->bIsFinished)
{
return 0;
Expand Down Expand Up @@ -7713,39 +7718,53 @@ void OGRGeoPackageTableLayer::GetNextArrowArrayAsynchronousWorker()
std::string osSQL;
osSQL = "SELECT OGR_GPKG_FillArrowArray_INTERNAL(";

if (m_pszFidColumn)
const auto AddFields = [this, &osSQL]()
{
osSQL += "m.\"";
osSQL += SQLEscapeName(m_pszFidColumn);
osSQL += '"';
}
else
{
osSQL += "NULL";
}
if (m_pszFidColumn)
{
osSQL += "m.\"";
osSQL += SQLEscapeName(m_pszFidColumn);
osSQL += '"';
}
else
{
osSQL += "NULL";
}

if (!m_poFillArrowArray->psHelper->mapOGRGeomFieldToArrowField.empty() &&
m_poFillArrowArray->psHelper->mapOGRGeomFieldToArrowField[0] >= 0)
{
osSQL += ",m.\"";
osSQL += SQLEscapeName(GetGeometryColumn());
osSQL += '"';
}
for (int iField = 0; iField < m_poFillArrowArray->psHelper->nFieldCount;
iField++)
{
const int iArrowField =
m_poFillArrowArray->psHelper->mapOGRFieldToArrowField[iField];
if (iArrowField >= 0)
if (!m_poFillArrowArray->psHelper->mapOGRGeomFieldToArrowField
.empty() &&
m_poFillArrowArray->psHelper->mapOGRGeomFieldToArrowField[0] >= 0)
{
const OGRFieldDefn *poFieldDefn =
m_poFeatureDefn->GetFieldDefnUnsafe(iField);
osSQL += ",m.\"";
osSQL += SQLEscapeName(poFieldDefn->GetNameRef());
osSQL += SQLEscapeName(GetGeometryColumn());
osSQL += '"';
}
for (int iField = 0; iField < m_poFillArrowArray->psHelper->nFieldCount;
iField++)
{
const int iArrowField =
m_poFillArrowArray->psHelper->mapOGRFieldToArrowField[iField];
if (iArrowField >= 0)
{
const OGRFieldDefn *poFieldDefn =
m_poFeatureDefn->GetFieldDefnUnsafe(iField);
osSQL += ",m.\"";
osSQL += SQLEscapeName(poFieldDefn->GetNameRef());
osSQL += '"';
}
}
};

AddFields();

osSQL += ") FROM ";
if (m_iNextShapeId > 0)
{
osSQL += "(SELECT ";
AddFields();
osSQL += " FROM ";
}
osSQL += ") FROM \"";
osSQL += '\"';
osSQL += SQLEscapeName(m_pszTableName);
osSQL += "\" m";
if (!m_soFilter.empty())
Expand Down Expand Up @@ -7792,6 +7811,10 @@ void OGRGeoPackageTableLayer::GetNextArrowArrayAsynchronousWorker()
}
}

if (m_iNextShapeId > 0)
osSQL +=
CPLSPrintf(" LIMIT -1 OFFSET " CPL_FRMT_GIB ") m", m_iNextShapeId);

// CPLDebug("GPKG", "%s", osSQL.c_str());

char *pszErrMsg = nullptr;
Expand Down Expand Up @@ -7831,7 +7854,9 @@ int OGRGeoPackageTableLayer::GetNextArrowArray(struct ArrowArrayStream *stream,
}

if (m_nIsCompatOfOptimizedGetNextArrowArray == FALSE ||
m_pszFidColumn == nullptr || !m_soFilter.empty())
m_pszFidColumn == nullptr || !m_soFilter.empty() ||
m_poFillArrowArray ||
(!m_bGetNextArrowArrayCalledSinceResetReading && m_iNextShapeId > 0))
{
return GetNextArrowArrayAsynchronous(out_array);
}
Expand Down Expand Up @@ -7865,6 +7890,8 @@ int OGRGeoPackageTableLayer::GetNextArrowArray(struct ArrowArrayStream *stream,
m_nIsCompatOfOptimizedGetNextArrowArray = TRUE;
}

m_bGetNextArrowArrayCalledSinceResetReading = true;

// CPLDebug("GPKG", "m_iNextShapeId = " CPL_FRMT_GIB, m_iNextShapeId);

const int nMaxBatchSize = OGRArrowArrayHelper::GetMaxFeaturesInBatch(
Expand Down

0 comments on commit ecf2f4d

Please sign in to comment.