Skip to content

Commit

Permalink
Refactor asyncpg support (#26)
Browse files Browse the repository at this point in the history
* fix: Add support for the relations table, add getKeys() to get columns used for SELECT

* fix: Add support for a multipolygon boundary, and converting a Record to a proper Feature

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
rsavoye and pre-commit-ci[bot] authored Jul 25, 2024
1 parent 8b4ea6e commit ad8de2d
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 40 deletions.
21 changes: 21 additions & 0 deletions osm_rawdata/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ def __init__(self, boundary: Polygon = None):
"nodes": [],
"ways_poly": [],
"ways_line": [],
"relations": [],
},
"tables": [],
"where": {
"nodes": [],
"ways_poly": [],
"ways_line": [],
"relations": [],
},
"keep": [],
}
Expand Down Expand Up @@ -280,6 +282,25 @@ def convert_geometry(geom_type):

return self.config

def getKeys(self):
    """Return the column names used for a SELECT built from this config.

    The list always starts with "geometry", followed by every entry found
    in the per-table lists under self.config["select"], in the order the
    tables and their entries are stored. A string entry is used as-is,
    any other entry is treated as a mapping and contributes its keys.
    Values under "select" that are not lists are ignored.

    Returns:
        (list): The column names, with the geometry first
    """
    # The first column returned is always the geometry
    keys = ["geometry"]
    for value in self.config["select"].values():
        if not isinstance(value, list):
            continue
        for entry in value:
            if isinstance(entry, str):
                keys.append(entry)
            else:
                # Only the tag names are wanted, not what they map to
                keys.extend(entry.keys())
    return keys

def dump(self):
"""Dump the contents of the internal data structure for debugging purposes."""
print("Dumping QueryConfig class")
Expand Down
163 changes: 123 additions & 40 deletions osm_rawdata/pgasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self):
self.pg = None
self.dburi = None
self.qc = None
self.clipped = False

async def connect(
self,
Expand Down Expand Up @@ -199,6 +200,35 @@ async def createJson(
feature["centroid"] = true
return json.dumps(feature)

async def recordsToFeatures(
    self,
    records: list,
) -> list:
    """Convert the records from an SQL query into GeoJson Features.

    The values in each record are matched by position against the names
    returned by self.qc.getKeys(). The "geometry" column is parsed from
    WKT, and every other column that is not None becomes a property of
    the Feature.

    Args:
        records (list): The records from an SQL query

    Returns:
        (list): The converted data, one Feature per record
    """
    data = []
    # The keys come from the query config, so they are the same for
    # every record.
    keys = self.qc.getKeys()
    for entry in records:
        geom = None
        props = {}
        for index, value in enumerate(entry):
            # Looked up by position on purpose, a record with more
            # columns than keys still raises an IndexError.
            key = keys[index]
            if key == "geometry":
                geom = wkt.loads(value)
            elif value is not None:
                props[key] = value
        data.append(Feature(geometry=geom, properties=props))

    return data

async def createSQL(
self,
config: QueryConfig,
Expand All @@ -218,12 +248,18 @@ async def createSQL(
for table in config.config["tables"]:
select = "SELECT "
if allgeom:
select += "ST_AsText(geom)"
select += "ST_AsText(geom) AS geometry"
else:
select += "ST_AsText(ST_Centroid(geom))"
select += "ST_AsText(ST_Centroid(geom)) AS geometry"
# FIXME: This part is OSM specific, and should be made more
# general. These two columns are OSM attributes, so each
# has its own column in the database. All the other
# values are in a single JSON column.
select += ", osm_id, version, "
for entry in config.config["select"][table]:
for k1, v1 in entry.items():
if k1 == "osm_id" or k1 == "version":
continue
select += f"tags->>'{k1}', "
select = select[:-2]

Expand Down Expand Up @@ -368,7 +404,7 @@ async def getPage(
async def execute(
self,
sql: str,
):
) -> list:
"""Execute a raw SQL query and return the results.
Args:
Expand All @@ -378,13 +414,31 @@ async def execute(
(list): The results of the query
"""
# print(sql)
data = list()
if sql.find(";") <= 0:
queries = [sql]

async with self.pg.transaction():
try:
result = await self.pg.fetch(sql)
return result
except Exception as e:
log.error(f"Couldn't execute query! {e}\n{sql}")
return list()
queries = list()
# If using an SRID, we have to hide the semicolon so the string
# doesn't split in the wrong place.
cmds = sql.replace("SRID=4326;P", "SRID=4326@P").split(";")
for sub in cmds:
queries.append(sub.replace("@", ";"))
continue

for query in queries:
try:
# print(query)
result = await self.pg.fetch(query)
if len(result) > 0:
data += result

except Exception as e:
log.error(f"Couldn't execute query! {e}\n{query}")
return list()

return data

async def queryLocal(
self,
Expand All @@ -404,15 +458,16 @@ async def queryLocal(
"""
features = list()
# if no boundary, it's already been setup
# if boundary and not self.clipped:
if boundary:
sql = f"DROP VIEW IF EXISTS ways_view;CREATE VIEW ways_view AS SELECT * FROM ways_poly WHERE ST_CONTAINS(ST_GeomFromEWKT('SRID=4326;{boundary.wkt}'), geom)"
await self.execute(sql)
await self.pg.execute(sql)
sql = f"DROP VIEW IF EXISTS nodes_view;CREATE VIEW nodes_view AS SELECT * FROM nodes WHERE ST_CONTAINS(ST_GeomFromEWKT('SRID=4326;{boundary.wkt}'), geom)"
await self.execute(sql)
await self.pg.execute(sql)
sql = f"DROP VIEW IF EXISTS lines_view;CREATE VIEW lines_view AS SELECT * FROM ways_line WHERE ST_CONTAINS(ST_GeomFromEWKT('SRID=4326;{boundary.wkt}'), geom)"
await self.execute(sql)
sql = f"DROP VIEW IF EXISTS relations_view;CREATE TEMP VIEW relations_view AS SELECT * FROM nodes WHERE ST_CONTAINS(ST_GeomFromEWKT('SRID=4326;{boundary.wkt}'), geom)"
await self.execute(sql)
await self.pg.execute(sql)
sql = f"DROP VIEW IF EXISTS relations_view;CREATE VIEW relations_view AS SELECT * FROM nodes WHERE ST_CONTAINS(ST_GeomFromEWKT('SRID=4326;{boundary.wkt}'), geom)"
await self.pg.execute(sql)

if query.find(" ways_poly ") > 0:
query = query.replace("ways_poly", "ways_view")
Expand Down Expand Up @@ -467,6 +522,7 @@ async def queryLocal(
# This should be the version
tags[res[3][:-1]] = item[2]
features.append(Feature(geometry=geom, properties=tags))

return FeatureCollection(features)
# return features

Expand Down Expand Up @@ -588,28 +644,47 @@ async def execQuery(
"""
log.info("Extracting features from Postgres...")

if "features" in boundary:
# FIXME: ideally this should support multipolygons
poly = boundary["features"][0]["geometry"]
else:
poly = boundary["geometry"]
wkt = shape(poly)

if not self.pg.is_closed():
if not customsql:
sql = await self.createSQL(self.qc, allgeom)
polygons = list()
if "geometry" in boundary:
polygons.append(boundary["geometry"])

if boundary["type"] == "MultiPolygon":
# poly = boundary["features"][0]["geometry"]
points = list()
for coords in boundary["features"]["coordinates"]:
for pt in coords:
for xy in pt:
points.append([float(xy[0]), float(xy[1])])
poly = Polygon(points)
polygons.append(poly)

for poly in polygons:
wkt = shape(poly)

if not self.pg.is_closed():
if not customsql:
sql = await self.createSQL(self.qc, allgeom)
else:
sql = [customsql]
alldata = list()
queries = list()
if type(sql) != list:
queries = sql.split(";")
else:
queries = sql

for query in queries:
# print(query)
result = await self.queryLocal(query, allgeom, wkt)
if len(result) > 0:
# Some queries don't return any data, for example
# when creating a VIEW.
alldata += await self.recordsToFeatures(result)
collection = FeatureCollection(alldata)
else:
sql = [customsql]
alldata = list()
for query in sql:
# print(query)
result = await self.queryLocal(query, allgeom, wkt)
if len(result) > 0:
alldata += result["features"]
collection = FeatureCollection(alldata)
else:
request = await self.createJson(self.qc, poly, allgeom)
collection = await self.queryRemote(request)
request = await self.createJson(self.qc, poly, allgeom)
collection = await self.queryRemote(request)

return collection


Expand Down Expand Up @@ -664,15 +739,21 @@ async def main():
stream=sys.stdout,
)

infile = open(args.boundary, "r")
poly = geojson.load(infile)
if args.boundary:
infile = open(args.boundary, "r")
inpoly = geojson.load(infile)
if inpoly["type"] == "MultiPolygon":
poly = FeatureCollection(inpoly)
else:
log.error("A boundary file is needed!")

if args.uri is not None:
log.info("Using a Postgres database for the data source")
db = DatabaseAccess()
await db.connect(args.uri)
result = await db.execute("SELECT * FROM nodes LIMIT 10;")
print(result)
quit()
# result = await db.execute("SELECT * FROM nodes LIMIT 10;")
# print(result)
# quit()
# await db.connect(args.uri)
# data = await db.pg.fetch("SELECT * FROM schemas LIMIT 10;")
# print(data)
Expand Down Expand Up @@ -700,3 +781,5 @@ async def main():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(main())

# ./pgasync.py -u localhost/colorado -b /play/MapData/States/Colorado/Boundaries/NationalForest/MedicineBowNationalForest.geojson -c /usr/local/lib/python3.12/site-packages/osm_fieldwork/data_models/highways.yaml

0 comments on commit ad8de2d

Please sign in to comment.