-
Notifications
You must be signed in to change notification settings - Fork 5
/
app.py
188 lines (156 loc) · 6.41 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
"""TiTiler+PgSTAC FastAPI application."""
import logging
import uuid

try:
    from importlib.resources import files as resources_files  # type: ignore
except ImportError:
    # Try backported to PY<39 `importlib_resources`.
    from importlib_resources import files as resources_files  # type: ignore

from aws_lambda_powertools.metrics import MetricUnit
from fastapi import APIRouter, Depends, FastAPI, Query
from rio_cogeo.cogeo import cog_info as rio_cogeo_info
from rio_cogeo.models import Info
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse
from starlette.templating import Jinja2Templates
from starlette_cramjam.middleware import CompressionMiddleware
from titiler.core.dependencies import DatasetPathParams
from titiler.core.errors import DEFAULT_STATUS_CODES, add_exception_handlers
from titiler.core.factory import TilerFactory
from titiler.core.middleware import CacheControlMiddleware
from titiler.core.resources.enums import OptionalHeader
from titiler.mosaic.errors import MOSAIC_STATUS_CODES
from titiler.pgstac.db import close_db_connection, connect_to_db
from titiler.pgstac.dependencies import ItemPathParams
from titiler.pgstac.factory import MosaicTilerFactory
from titiler.pgstac.reader import PgSTACReader

from src.config import ApiSettings
from src.datasetparams import DatasetParams
from src.factory import MultiBaseTilerFactory
from src.version import __version__ as veda_raster_version

from .monitoring import LoggerRouteHandler, logger, metrics, tracer
# Quiet down third-party loggers.
logging.getLogger("botocore.credentials").disabled = True
logging.getLogger("botocore.utils").disabled = True
logging.getLogger("rio-tiler").setLevel(logging.ERROR)

settings = ApiSettings()

# Jinja templates live in the ``templates`` directory of this package.
templates = Jinja2Templates(directory=str(resources_files(__package__) / "templates"))  # type: ignore

# The optional Server-Timing / X-Assets response headers are only enabled in debug mode.
optional_headers = (
    [OptionalHeader.server_timing, OptionalHeader.x_assets] if settings.debug else []
)

app = FastAPI(title=settings.name, version=veda_raster_version)

# Router using LoggerRouteHandler (adds FastAPI context to logs).
# NOTE(review): this module-level router is not referenced in this file; each
# factory below constructs its own APIRouter(route_class=LoggerRouteHandler).
# Left in place because other modules may import it.
router = APIRouter(route_class=LoggerRouteHandler)

# Register titiler's generic and mosaic-specific exception -> status-code handlers.
add_exception_handlers(app, DEFAULT_STATUS_CODES)
add_exception_handlers(app, MOSAIC_STATUS_CODES)
# PgSTAC mosaic tiler, served under /mosaic. Whether the mosaic-list endpoint
# is added is controlled by settings.enable_mosaic_search.
mosaic = MosaicTilerFactory(
    router=APIRouter(route_class=LoggerRouteHandler),
    router_prefix="/mosaic",
    dataset_dependency=DatasetParams,
    environment_dependency=settings.get_gdal_config,
    optional_headers=optional_headers,
    add_mosaic_list=settings.enable_mosaic_search,
)
app.include_router(mosaic.router, prefix="/mosaic", tags=["Mosaic"])
# PgSTAC-backed STAC item tiler, served under /stac.
# NOTE(review): include_router is called without include_in_schema=False, so
# these routes are listed in the OpenAPI docs — confirm that is intended.
stac = MultiBaseTilerFactory(
    router=APIRouter(route_class=LoggerRouteHandler),
    router_prefix="/stac",
    reader=PgSTACReader,
    path_dependency=ItemPathParams,
    environment_dependency=settings.get_gdal_config,
    optional_headers=optional_headers,
)
app.include_router(stac.router, prefix="/stac", tags=["Items"])
# Single Cloud Optimized GeoTIFF tiler; extra routes are attached to its router
# below before it is mounted under /cog.
cog = TilerFactory(
    router=APIRouter(route_class=LoggerRouteHandler),
    router_prefix="/cog",
    environment_dependency=settings.get_gdal_config,
    optional_headers=optional_headers,
)
@cog.router.get("/validate", response_model=Info)
def cog_validate(
    src_path: str = Depends(DatasetPathParams),
    strict: bool = Query(False, description="Treat warnings as errors"),
):
    """Validate a COG.

    Runs rio-cogeo's ``cog_info`` on the dataset with the GDAL configuration
    returned by ``settings.get_gdal_config()``.

    Args:
        src_path: Dataset path resolved by ``DatasetPathParams``.
        strict: Treat warnings as errors.

    Returns:
        The rio-cogeo ``Info`` model describing the dataset.
    """
    gdal_config = settings.get_gdal_config()
    return rio_cogeo_info(src_path, strict=strict, config=gdal_config)
@cog.router.get("/viewer", response_class=HTMLResponse)
def cog_demo(request: Request):
    """COG Viewer.

    Renders ``viewer.html`` with the request plus the URLs of the ``cog``
    factory's tilejson, info and statistics endpoints.
    """
    template_context = {"request": request}
    template_context["tilejson_endpoint"] = cog.url_for(request, "tilejson")
    template_context["info_endpoint"] = cog.url_for(request, "info")
    template_context["statistics_endpoint"] = cog.url_for(request, "statistics")
    return templates.TemplateResponse(
        name="viewer.html",
        context=template_context,
        media_type="text/html",
    )
# Mounted after /validate and /viewer are registered on cog.router so they are included.
app.include_router(cog.router, prefix="/cog", tags=["Cloud Optimized GeoTIFF"])
@app.get("/healthz", description="Health Check", tags=["Health Check"])
def ping():
    """Health check: always returns the same static payload."""
    payload = {"ping": "pong!!"}
    return payload
# CORS middleware is only installed when allowed origins are configured.
if settings.cors_origins:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.cors_origins,
        allow_methods=["GET", "POST", "OPTIONS"],
        allow_headers=["*"],
        allow_credentials=True,
    )

# Apply settings.cachecontrol via CacheControlMiddleware, excluding the health check.
app.add_middleware(
    CacheControlMiddleware,
    cachecontrol=settings.cachecontrol,
    exclude_path={r"/healthz"},
)

# Compress responses, except for the listed image media types
# (presumably because those formats are already compressed).
app.add_middleware(
    CompressionMiddleware,
    exclude_mediatype={"image/jpeg", "image/jpg", "image/png", "image/jp2", "image/webp"},
)
# If the correlation header is used in the UI, we can analyze traces that originate from a given user or client
@app.middleware("http")
async def add_correlation_id(request: Request, call_next):
    """Add correlation ids to all requests and subsequent logs/traces.

    The correlation id is chosen, in order of preference, from:

    1. the ``X-Correlation-Id`` request header;
    2. ``aws_request_id`` of the object stored under ``request.scope["aws.context"]``;
    3. a freshly generated UUID4 string (when no AWS context is present,
       e.g. when running locally).

    The id is set on the logger, added to the trace as the ``correlation_id``
    annotation, and echoed back in the ``X-Correlation-Id`` response header.
    """
    corr_id = request.headers.get("x-correlation-id")
    if not corr_id:
        try:
            # No header: fall back to the request id from the aws context.
            corr_id = request.scope["aws.context"].aws_request_id
        except KeyError:
            # No aws context either: generate a unique id so that unrelated
            # requests are not all grouped under one shared constant value.
            corr_id = str(uuid.uuid4())

    # Add correlation id to logs and traces.
    logger.set_correlation_id(corr_id)
    tracer.put_annotation(key="correlation_id", value=corr_id)

    response = await tracer.capture_method(call_next)(request)

    # Return correlation header in response.
    response.headers["X-Correlation-Id"] = corr_id
    logger.info("Request completed")
    return response
@app.exception_handler(Exception)
async def validation_exception_handler(request: Request, err: Exception) -> JSONResponse:
    """Handle exceptions that aren't caught elsewhere.

    NOTE: despite its name, this is registered for ``Exception`` and so acts as
    the catch-all handler, not just for validation errors. It increments the
    ``UnhandledExceptions`` metric, logs via ``logger.exception`` and returns a
    generic 500 JSON body.
    """
    metrics.add_metric(name="UnhandledExceptions", unit=MetricUnit.Count, value=1)
    logger.exception("Unhandled exception")
    error_body = {"detail": "Internal Server Error"}
    return JSONResponse(content=error_body, status_code=500)
@app.on_event("startup")
async def startup_event() -> None:
    """Open the database connection when the application starts.

    Connection settings come from ``settings.load_postgres_settings()``.
    """
    postgres_settings = settings.load_postgres_settings()
    await connect_to_db(app, settings=postgres_settings)
@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Close the database connection when the application shuts down."""
    await close_db_connection(app)