Skip to content

Commit

Permalink
✨ Use Pydantic BaseSettings for config settings (#87)
Browse files Browse the repository at this point in the history
* Use Pydantic BaseSettings for config settings

* Update fastapi dep to >=0.47.0 and email_validator to email-validator

* Fix deprecation warning for Pydantic >=1.0

* Properly support old-format comma separated strings for BACKEND_CORS_ORIGINS

Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
  • Loading branch information
StephenBrown2 and tiangolo authored Apr 17, 2020
1 parent cd875e5 commit 79631c7
Show file tree
Hide file tree
Showing 24 changed files with 173 additions and 140 deletions.
2 changes: 1 addition & 1 deletion cookiecutter.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"secret_key": "changethis",
"first_superuser": "admin@{{cookiecutter.domain_main}}",
"first_superuser_password": "changethis",
"backend_cors_origins": "http://localhost, http://localhost:4200, http://localhost:3000, http://localhost:8080, https://localhost, https://localhost:4200, https://localhost:3000, https://localhost:8080, http://dev.{{cookiecutter.domain_main}}, https://{{cookiecutter.domain_staging}}, https://{{cookiecutter.domain_main}}, http://local.dockertoolbox.tiangolo.com, http://localhost.tiangolo.com",
"backend_cors_origins": "[\"http://localhost\", \"http://localhost:4200\", \"http://localhost:3000\", \"http://localhost:8080\", \"https://localhost\", \"https://localhost:4200\", \"https://localhost:3000\", \"https://localhost:8080\", \"http://dev.{{cookiecutter.domain_main}}\", \"https://{{cookiecutter.domain_staging}}\", \"https://{{cookiecutter.domain_main}}\", \"http://local.dockertoolbox.tiangolo.com\", \"http://localhost.tiangolo.com\"]",
"smtp_port": "587",
"smtp_host": "",
"smtp_user": "",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from app import crud
from app.api.utils.db import get_db
from app.api.utils.security import get_current_user
from app.core import config
from app.core.config import settings
from app.core.jwt import create_access_token
from app.core.security import get_password_hash
from app.models.user import User as DBUser
Expand Down Expand Up @@ -37,7 +37,7 @@ def login_access_token(
raise HTTPException(status_code=400, detail="Incorrect email or password")
elif not crud.user.is_active(user):
raise HTTPException(status_code=400, detail="Inactive user")
access_token_expires = timedelta(minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES)
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
return {
"access_token": create_access_token(
data={"user_id": user.id}, expires_delta=access_token_expires
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from app import crud
from app.api.utils.db import get_db
from app.api.utils.security import get_current_active_superuser, get_current_active_user
from app.core import config
from app.core.config import settings
from app.models.user import User as DBUser
from app.schemas.user import User, UserCreate, UserUpdate
from app.utils import send_new_account_email
Expand Down Expand Up @@ -47,7 +47,7 @@ def create_user(
detail="The user with this username already exists in the system.",
)
user = crud.user.create(db, obj_in=user_in)
if config.EMAILS_ENABLED and user_in.email:
if settings.EMAILS_ENABLED and user_in.email:
send_new_account_email(
email_to=user_in.email, username=user_in.email, password=user_in.password
)
Expand Down Expand Up @@ -100,7 +100,7 @@ def create_user_open(
"""
Create new user without the need to be logged in.
"""
if not config.USERS_OPEN_REGISTRATION:
if not settings.USERS_OPEN_REGISTRATION:
raise HTTPException(
status_code=403,
detail="Open user registration is forbidden on this server",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@

from app import crud
from app.api.utils.db import get_db
from app.core import config
from app.core.config import settings
from app.core.jwt import ALGORITHM
from app.models.user import User
from app.schemas.token import TokenPayload

reusable_oauth2 = OAuth2PasswordBearer(tokenUrl="/api/v1/login/access-token")
reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token")


def get_current_user(
db: Session = Depends(get_db), token: str = Security(reusable_oauth2)
):
try:
payload = jwt.decode(token, config.SECRET_KEY, algorithms=[ALGORITHM])
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
token_data = TokenPayload(**payload)
except PyJWTError:
raise HTTPException(
Expand Down
141 changes: 89 additions & 52 deletions {{cookiecutter.project_slug}}/backend/app/app/core/config.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,92 @@
import os
import secrets
from typing import List

from pydantic import AnyHttpUrl, BaseSettings, EmailStr, HttpUrl, PostgresDsn, validator

def getenv_boolean(var_name, default_value=False):
result = default_value
env_value = os.getenv(var_name)
if env_value is not None:
result = env_value.upper() in ("TRUE", "1")
return result


API_V1_STR = "/api/v1"

SECRET_KEY = os.getenvb(b"SECRET_KEY")
if not SECRET_KEY:
SECRET_KEY = os.urandom(32)

ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 8 # 60 minutes * 24 hours * 8 days = 8 days

SERVER_NAME = os.getenv("SERVER_NAME")
SERVER_HOST = os.getenv("SERVER_HOST")
BACKEND_CORS_ORIGINS = os.getenv(
"BACKEND_CORS_ORIGINS"
) # a string of origins separated by commas, e.g: "http://localhost, http://localhost:4200, http://localhost:3000, http://localhost:8080, http://local.dockertoolbox.tiangolo.com"
PROJECT_NAME = os.getenv("PROJECT_NAME")
SENTRY_DSN = os.getenv("SENTRY_DSN")

POSTGRES_SERVER = os.getenv("POSTGRES_SERVER")
POSTGRES_USER = os.getenv("POSTGRES_USER")
POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD")
POSTGRES_DB = os.getenv("POSTGRES_DB")
SQLALCHEMY_DATABASE_URI = (
f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_SERVER}/{POSTGRES_DB}"
)

SMTP_TLS = getenv_boolean("SMTP_TLS", True)
SMTP_PORT = None
_SMTP_PORT = os.getenv("SMTP_PORT")
if _SMTP_PORT is not None:
SMTP_PORT = int(_SMTP_PORT)
SMTP_HOST = os.getenv("SMTP_HOST")
SMTP_USER = os.getenv("SMTP_USER")
SMTP_PASSWORD = os.getenv("SMTP_PASSWORD")
EMAILS_FROM_EMAIL = os.getenv("EMAILS_FROM_EMAIL")
EMAILS_FROM_NAME = PROJECT_NAME
EMAIL_RESET_TOKEN_EXPIRE_HOURS = 48
EMAIL_TEMPLATES_DIR = "/app/app/email-templates/build"
EMAILS_ENABLED = SMTP_HOST and SMTP_PORT and EMAILS_FROM_EMAIL

FIRST_SUPERUSER = os.getenv("FIRST_SUPERUSER")
FIRST_SUPERUSER_PASSWORD = os.getenv("FIRST_SUPERUSER_PASSWORD")

USERS_OPEN_REGISTRATION = getenv_boolean("USERS_OPEN_REGISTRATION")

EMAIL_TEST_USER = "test@example.com"

class Settings(BaseSettings):

API_V1_STR: str = "/api/v1"

SECRET_KEY: str = secrets.token_urlsafe(32)

ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 60 minutes * 24 hours * 8 days = 8 days

SERVER_NAME: str
SERVER_HOST: AnyHttpUrl
# BACKEND_CORS_ORIGINS is a JSON-formatted list of origins
# e.g: '["http://localhost", "http://localhost:4200", "http://localhost:3000", \
# "http://localhost:8080", "http://local.dockertoolbox.tiangolo.com"]'
BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = []

@validator("BACKEND_CORS_ORIGINS", pre=True)
def assemble_cors_origins(cls, v):
if isinstance(v, str) and not v.startswith("["):
return [i.strip() for i in v.split(",")]
return v

PROJECT_NAME: str
SENTRY_DSN: HttpUrl = None

@validator("SENTRY_DSN", pre=True)
def sentry_dsn_can_be_blank(cls, v):
if len(v) == 0:
return None
return v

POSTGRES_SERVER: str
POSTGRES_USER: str
POSTGRES_PASSWORD: str
POSTGRES_DB: str
SQLALCHEMY_DATABASE_URI: PostgresDsn = None

@validator("SQLALCHEMY_DATABASE_URI", pre=True)
def assemble_db_connection(cls, v, values):
if isinstance(v, str):
return v
return PostgresDsn.build(
scheme="postgresql",
user=values.get("POSTGRES_USER"),
password=values.get("POSTGRES_PASSWORD"),
host=values.get("POSTGRES_SERVER"),
path=f"/{values.get('POSTGRES_DB') or ''}",
)

SMTP_TLS: bool = True
SMTP_PORT: int = None
SMTP_HOST: str = None
SMTP_USER: str = None
SMTP_PASSWORD: str = None
EMAILS_FROM_EMAIL: EmailStr = None
EMAILS_FROM_NAME: str = None

@validator("EMAILS_FROM_NAME")
def get_project_name(cls, v, values):
if not v:
return values["PROJECT_NAME"]
return v

EMAIL_RESET_TOKEN_EXPIRE_HOURS: int = 48
EMAIL_TEMPLATES_DIR: str = "/app/app/email-templates/build"
EMAILS_ENABLED: bool = False

@validator("EMAILS_ENABLED", pre=True)
def get_emails_enabled(cls, v, values):
return bool(
values.get("SMTP_HOST")
and values.get("SMTP_PORT")
and values.get("EMAILS_FROM_EMAIL")
)

EMAIL_TEST_USER: EmailStr = "test@example.com"

FIRST_SUPERUSER: EmailStr
FIRST_SUPERUSER_PASSWORD: str

USERS_OPEN_REGISTRATION: bool = False

class Config:
case_sensitive = True

settings = Settings()
4 changes: 2 additions & 2 deletions {{cookiecutter.project_slug}}/backend/app/app/core/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import jwt

from app.core import config
from app.core.config import settings

ALGORITHM = "HS256"
access_token_jwt_subject = "access"
Expand All @@ -15,5 +15,5 @@ def create_access_token(*, data: dict, expires_delta: timedelta = None):
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire, "sub": access_token_jwt_subject})
encoded_jwt = jwt.encode(to_encode, config.SECRET_KEY, algorithm=ALGORITHM)
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
2 changes: 1 addition & 1 deletion {{cookiecutter.project_slug}}/backend/app/app/crud/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def update(
self, db_session: Session, *, db_obj: ModelType, obj_in: UpdateSchemaType
) -> ModelType:
obj_data = jsonable_encoder(db_obj)
update_data = obj_in.dict(skip_defaults=True)
update_data = obj_in.dict(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
Expand Down
8 changes: 4 additions & 4 deletions {{cookiecutter.project_slug}}/backend/app/app/db/init_db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from app import crud
from app.core import config
from app.core.config import settings
from app.schemas.user import UserCreate

# make sure all SQL Alchemy models are imported before initializing DB
Expand All @@ -14,11 +14,11 @@ def init_db(db_session):
# the tables un-commenting the next line
# Base.metadata.create_all(bind=engine)

user = crud.user.get_by_email(db_session, email=config.FIRST_SUPERUSER)
user = crud.user.get_by_email(db_session, email=settings.FIRST_SUPERUSER)
if not user:
user_in = UserCreate(
email=config.FIRST_SUPERUSER,
password=config.FIRST_SUPERUSER_PASSWORD,
email=settings.FIRST_SUPERUSER,
password=settings.FIRST_SUPERUSER_PASSWORD,
is_superuser=True,
)
user = crud.user.create(db_session, obj_in=user_in)
4 changes: 2 additions & 2 deletions {{cookiecutter.project_slug}}/backend/app/app/db/session.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

from app.core import config
from app.core.config import settings

engine = create_engine(config.SQLALCHEMY_DATABASE_URI, pool_pre_ping=True)
engine = create_engine(settings.SQLALCHEMY_DATABASE_URI, pool_pre_ping=True)
db_session = scoped_session(
sessionmaker(autocommit=False, autoflush=False, bind=engine)
)
Expand Down
17 changes: 5 additions & 12 deletions {{cookiecutter.project_slug}}/backend/app/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,22 @@
from starlette.requests import Request

from app.api.api_v1.api import api_router
from app.core import config
from app.core.config import settings
from app.db.session import Session

app = FastAPI(title=config.PROJECT_NAME, openapi_url="/api/v1/openapi.json")

# CORS
origins = []
app = FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json")

# Set all CORS enabled origins
if config.BACKEND_CORS_ORIGINS:
origins_raw = config.BACKEND_CORS_ORIGINS.split(",")
for origin in origins_raw:
use_origin = origin.strip()
origins.append(use_origin)
if settings.BACKEND_CORS_ORIGINS:
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
),

app.include_router(api_router, prefix=config.API_V1_STR)
app.include_router(api_router, prefix=settings.API_V1_STR)


@app.middleware("http")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import requests

from app.core import config
from app.core.config import settings
from app.tests.utils.utils import get_server_api


def test_celery_worker_test(superuser_token_headers):
server_api = get_server_api()
data = {"msg": "test"}
r = requests.post(
f"{server_api}{config.API_V1_STR}/utils/test-celery/",
f"{server_api}{settings.API_V1_STR}/utils/test-celery/",
json=data,
headers=superuser_token_headers,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import requests

from app.core import config
from app.core.config import settings
from app.tests.utils.item import create_random_item
from app.tests.utils.utils import get_server_api
from app.tests.utils.user import create_random_user
Expand All @@ -10,7 +10,7 @@ def test_create_item(superuser_token_headers):
server_api = get_server_api()
data = {"title": "Foo", "description": "Fighters"}
response = requests.post(
f"{server_api}{config.API_V1_STR}/items/",
f"{server_api}{settings.API_V1_STR}/items/",
headers=superuser_token_headers,
json=data,
)
Expand All @@ -26,7 +26,7 @@ def test_read_item(superuser_token_headers):
item = create_random_item()
server_api = get_server_api()
response = requests.get(
f"{server_api}{config.API_V1_STR}/items/{item.id}",
f"{server_api}{settings.API_V1_STR}/items/{item.id}",
headers=superuser_token_headers,
)
assert response.status_code == 200
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import requests

from app.core import config
from app.core.config import settings
from app.tests.utils.utils import get_server_api


def test_get_access_token():
server_api = get_server_api()
login_data = {
"username": config.FIRST_SUPERUSER,
"password": config.FIRST_SUPERUSER_PASSWORD,
"username": settings.FIRST_SUPERUSER,
"password": settings.FIRST_SUPERUSER_PASSWORD,
}
r = requests.post(
f"{server_api}{config.API_V1_STR}/login/access-token", data=login_data
f"{server_api}{settings.API_V1_STR}/login/access-token", data=login_data
)
tokens = r.json()
assert r.status_code == 200
Expand All @@ -22,7 +22,7 @@ def test_get_access_token():
def test_use_access_token(superuser_token_headers):
server_api = get_server_api()
r = requests.post(
f"{server_api}{config.API_V1_STR}/login/test-token",
f"{server_api}{settings.API_V1_STR}/login/test-token",
headers=superuser_token_headers,
)
result = r.json()
Expand Down
Loading

0 comments on commit 79631c7

Please sign in to comment.