Skip to content

Commit

Permalink
Merge pull request #12 from olincollege/SAN-63-bronze-db
Browse files Browse the repository at this point in the history
SAN-63-bronze-db
  • Loading branch information
crane919 authored Nov 20, 2024
2 parents 83dc3cc + 1693f7a commit c661c0e
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ env/
.env
.idea/
*.pt
*.DS_Store
illuminance_results/
.vscode/

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
129 changes: 129 additions & 0 deletions src/night_light/db/bronze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import os
import duckdb
import geopandas as gpd
from geopandas import GeoDataFrame
from typing import List, Tuple, Union


def connect_to_duckdb(db_path: str) -> duckdb.DuckDBPyConnection:
"""
Establish a connection to a DuckDB database and load the spatial extension.
Args:
db_path (str): Path to the DuckDB database file.
Returns:
duckdb.DuckDBPyConnection: Connection to the DuckDB database.
"""
con = duckdb.connect(db_path)
is_spatial_installed = con.execute(
"""SELECT EXISTS (
SELECT 1
FROM duckdb_extensions()
WHERE extension_name = 'spatial' AND installed
)
"""
).fetchone()[0]
if not is_spatial_installed:
con.install_extension("spatial")
con.load_extension("spatial")
return con


def load_data_to_table(
con: duckdb.DuckDBPyConnection,
data_source: Union[str, GeoDataFrame],
table_name: str,
) -> None:
"""
Load GeoJSON or GeoDataFrame into DuckDB.
Args:
con (duckdb.DuckDBPyConnection): Connection to the DuckDB database.
data_source (Union[str, GeoDataFrame]): Path to a GeoJSON file or a
GeoDataFrame.
table_name (str): Name of the target table.
"""
if con.execute(
f"SELECT 1 FROM information_schema.tables WHERE table_name = '{table_name}'"
).fetchone():
print(f"Table '{table_name}' already exists. Skipping data load.")
return

gdf = (
gpd.read_file(data_source)
if isinstance(data_source, str) and os.path.isfile(data_source)
else data_source
)
if not isinstance(gdf, gpd.GeoDataFrame):
raise ValueError(
"data_source must be a valid GeoJSON file path or GeoDataFrame."
)

# Convert geometry to WKT and ensure compatibility with DuckDB
if "geometry" in gdf:
gdf["geometry"] = gdf["geometry"].to_wkt()
gdf = gdf.astype({col: "string" for col in gdf.select_dtypes("object").columns})

con.register("temp_gdf", gdf)
con.execute(f"CREATE TABLE {table_name} AS SELECT * FROM temp_gdf")


def load_multiple_datasets(
con: duckdb.DuckDBPyConnection, datasets: List[Tuple[Union[str, GeoDataFrame], str]]
) -> None:
"""
Load multiple GeoJSON files or GeoDataFrames into DuckDB.
Args:
con (duckdb.DuckDBPyConnection): Connection to the DuckDB database.
datasets (List[Tuple[Union[str, GeoDataFrame], str]]): List of tuples where
each tuple contains:
- data_source (Union[str, GeoDataFrame]): Path to a GeoJSON file or a
GeoDataFrame.
- table_name (str): Name of the target table.
"""
for data_source, table_name in datasets:
load_data_to_table(con, data_source, table_name)


def query_table_to_gdf(
con: duckdb.DuckDBPyConnection,
table_name: str,
query: str = None,
) -> GeoDataFrame:
"""
Query a DuckDB table and return the results as a GeoPandas DataFrame.
Args:
con (duckdb.DuckDBPyConnection): Connection to the DuckDB database.
table_name (str): Name of the table to query.
query (Optional[str]): SQL query to execute. Default is to fetch the first 10 rows.
Returns:
GeoDataFrame: Results of the query.
"""
if query is None:
query = "SELECT * FROM {table_name} LIMIT 10".format(table_name=table_name)
df = con.execute(query).fetchdf()
df["geometry"] = gpd.GeoSeries.from_wkt(df["geometry"])
gdf = gpd.GeoDataFrame(df)
return gdf


if __name__ == "__main__":
db_path = "bronze.db"
# Run the scripts in tests to generate the GeoJSON files
datasets = [
("../../../tests/test_boston_crosswalk.geojson", "crosswalks"),
("../../../tests/test_boston_streetlights.geojson", "streetlights"),
("../../../tests/test_all_population_density.geojson", "population_density"),
("../../../tests/test_boston_traffic.geojson", "traffic"),
("../../../tests/test_boston_vision_zero.geojson", "accidents"),
("../../../tests/test_ma_median_household_income.geojson", "median_income"),
]

conn = connect_to_duckdb(db_path)
load_multiple_datasets(conn, datasets)
gdf = query_table_to_gdf(conn, "crosswalks")
print(gdf)

0 comments on commit c661c0e

Please sign in to comment.