Skip to content

Commit

Permalink
✨ Test using the TestClient
Browse files Browse the repository at this point in the history
  • Loading branch information
tiangolo committed Apr 20, 2020
1 parent 41a2f15 commit 307fffa
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 103 deletions.
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from typing import Dict

import requests
from fastapi.testclient import TestClient

from app.core.config import settings
from app.tests.utils.utils import get_server_api


def test_celery_worker_test(superuser_token_headers: Dict[str, str]) -> None:
server_api = get_server_api()
def test_celery_worker_test(
client: TestClient, superuser_token_headers: Dict[str, str]
) -> None:
data = {"msg": "test"}
r = requests.post(
f"{server_api}{settings.API_V1_STR}/utils/test-celery/",
r = client.post(
f"{settings.API_V1_STR}/utils/test-celery/",
json=data,
headers=superuser_token_headers,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import requests
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from app.core.config import settings
from app.tests.utils.item import create_random_item
from app.tests.utils.utils import get_server_api


def test_create_item(superuser_token_headers: dict, db: Session) -> None:
server_api = get_server_api()
def test_create_item(
client: TestClient, superuser_token_headers: dict, db: Session
) -> None:
data = {"title": "Foo", "description": "Fighters"}
response = requests.post(
f"{server_api}{settings.API_V1_STR}/items/",
headers=superuser_token_headers,
json=data,
response = client.post(
f"{settings.API_V1_STR}/items/", headers=superuser_token_headers, json=data,
)
assert response.status_code == 200
content = response.json()
Expand All @@ -22,12 +20,12 @@ def test_create_item(superuser_token_headers: dict, db: Session) -> None:
assert "owner_id" in content


def test_read_item(superuser_token_headers: dict, db: Session) -> None:
def test_read_item(
client: TestClient, superuser_token_headers: dict, db: Session
) -> None:
item = create_random_item(db)
server_api = get_server_api()
response = requests.get(
f"{server_api}{settings.API_V1_STR}/items/{item.id}",
headers=superuser_token_headers,
response = client.get(
f"{settings.API_V1_STR}/items/{item.id}", headers=superuser_token_headers,
)
assert response.status_code == 200
content = response.json()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
from typing import Dict

import requests
from fastapi.testclient import TestClient

from app.core.config import settings
from app.tests.utils.utils import get_server_api


def test_get_access_token() -> None:
server_api = get_server_api()
def test_get_access_token(client: TestClient) -> None:
login_data = {
"username": settings.FIRST_SUPERUSER,
"password": settings.FIRST_SUPERUSER_PASSWORD,
}
r = requests.post(
f"{server_api}{settings.API_V1_STR}/login/access-token", data=login_data
)
r = client.post(f"{settings.API_V1_STR}/login/access-token", data=login_data)
tokens = r.json()
assert r.status_code == 200
assert "access_token" in tokens
assert tokens["access_token"]


def test_use_access_token(superuser_token_headers: Dict[str, str]) -> None:
server_api = get_server_api()
r = requests.post(
f"{server_api}{settings.API_V1_STR}/login/test-token",
headers=superuser_token_headers,
def test_use_access_token(
client: TestClient, superuser_token_headers: Dict[str, str]
) -> None:
r = client.post(
f"{settings.API_V1_STR}/login/test-token", headers=superuser_token_headers,
)
result = r.json()
assert r.status_code == 200
Expand Down
Original file line number Diff line number Diff line change
@@ -1,47 +1,44 @@
from typing import Dict

import requests
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from app import crud
from app.core.config import settings
from app.schemas.user import UserCreate
from app.tests.utils.utils import get_server_api, random_email, random_lower_string
from app.tests.utils.utils import random_email, random_lower_string


def test_get_users_superuser_me(superuser_token_headers: Dict[str, str]) -> None:
server_api = get_server_api()
r = requests.get(
f"{server_api}{settings.API_V1_STR}/users/me", headers=superuser_token_headers
)
def test_get_users_superuser_me(
client: TestClient, superuser_token_headers: Dict[str, str]
) -> None:
r = client.get(f"{settings.API_V1_STR}/users/me", headers=superuser_token_headers)
current_user = r.json()
assert current_user
assert current_user["is_active"] is True
assert current_user["is_superuser"]
assert current_user["email"] == settings.FIRST_SUPERUSER


def test_get_users_normal_user_me(normal_user_token_headers: Dict[str, str]) -> None:
server_api = get_server_api()
r = requests.get(
f"{server_api}{settings.API_V1_STR}/users/me", headers=normal_user_token_headers
)
def test_get_users_normal_user_me(
client: TestClient, normal_user_token_headers: Dict[str, str]
) -> None:
r = client.get(f"{settings.API_V1_STR}/users/me", headers=normal_user_token_headers)
current_user = r.json()
assert current_user
assert current_user["is_active"] is True
assert current_user["is_superuser"] is False
assert current_user["email"] == settings.EMAIL_TEST_USER


def test_create_user_new_email(superuser_token_headers: dict, db: Session) -> None:
server_api = get_server_api()
def test_create_user_new_email(
client: TestClient, superuser_token_headers: dict, db: Session
) -> None:
username = random_email()
password = random_lower_string()
data = {"email": username, "password": password}
r = requests.post(
f"{server_api}{settings.API_V1_STR}/users/",
headers=superuser_token_headers,
json=data,
r = client.post(
f"{settings.API_V1_STR}/users/", headers=superuser_token_headers, json=data,
)
assert 200 <= r.status_code < 300
created_user = r.json()
Expand All @@ -50,16 +47,16 @@ def test_create_user_new_email(superuser_token_headers: dict, db: Session) -> No
assert user.email == created_user["email"]


def test_get_existing_user(superuser_token_headers: dict, db: Session) -> None:
server_api = get_server_api()
def test_get_existing_user(
client: TestClient, superuser_token_headers: dict, db: Session
) -> None:
username = random_email()
password = random_lower_string()
user_in = UserCreate(email=username, password=password)
user = crud.user.create(db, obj_in=user_in)
user_id = user.id
r = requests.get(
f"{server_api}{settings.API_V1_STR}/users/{user_id}",
headers=superuser_token_headers,
r = client.get(
f"{settings.API_V1_STR}/users/{user_id}", headers=superuser_token_headers,
)
assert 200 <= r.status_code < 300
api_user = r.json()
Expand All @@ -69,40 +66,37 @@ def test_get_existing_user(superuser_token_headers: dict, db: Session) -> None:


def test_create_user_existing_username(
superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict, db: Session
) -> None:
server_api = get_server_api()
username = random_email()
# username = email
password = random_lower_string()
user_in = UserCreate(email=username, password=password)
crud.user.create(db, obj_in=user_in)
data = {"email": username, "password": password}
r = requests.post(
f"{server_api}{settings.API_V1_STR}/users/",
headers=superuser_token_headers,
json=data,
r = client.post(
f"{settings.API_V1_STR}/users/", headers=superuser_token_headers, json=data,
)
created_user = r.json()
assert r.status_code == 400
assert "_id" not in created_user


def test_create_user_by_normal_user(normal_user_token_headers: Dict[str, str]) -> None:
server_api = get_server_api()
def test_create_user_by_normal_user(
client: TestClient, normal_user_token_headers: Dict[str, str]
) -> None:
username = random_email()
password = random_lower_string()
data = {"email": username, "password": password}
r = requests.post(
f"{server_api}{settings.API_V1_STR}/users/",
headers=normal_user_token_headers,
json=data,
r = client.post(
f"{settings.API_V1_STR}/users/", headers=normal_user_token_headers, json=data,
)
assert r.status_code == 400


def test_retrieve_users(superuser_token_headers: dict, db: Session) -> None:
server_api = get_server_api()
def test_retrieve_users(
client: TestClient, superuser_token_headers: dict, db: Session
) -> None:
username = random_email()
password = random_lower_string()
user_in = UserCreate(email=username, password=password)
Expand All @@ -113,9 +107,7 @@ def test_retrieve_users(superuser_token_headers: dict, db: Session) -> None:
user_in2 = UserCreate(email=username2, password=password2)
crud.user.create(db, obj_in=user_in2)

r = requests.get(
f"{server_api}{settings.API_V1_STR}/users/", headers=superuser_token_headers
)
r = client.get(f"{settings.API_V1_STR}/users/", headers=superuser_token_headers)
all_users = r.json()

assert len(all_users) > 1
Expand Down
23 changes: 14 additions & 9 deletions {{cookiecutter.project_slug}}/backend/app/app/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
from typing import Dict, Iterator
from typing import Dict, Generator

import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from app.core.config import settings
from app.db.session import SessionLocal
from app.main import app
from app.tests.utils.user import authentication_token_from_email
from app.tests.utils.utils import get_server_api, get_superuser_token_headers
from app.tests.utils.utils import get_superuser_token_headers


@pytest.fixture(scope="session")
def db() -> Iterator[Session]:
def db() -> Generator:
yield SessionLocal()


@pytest.fixture(scope="module")
def server_api() -> str:
return get_server_api()
def client() -> Generator:
with TestClient(app) as c:
yield c


@pytest.fixture(scope="module")
def superuser_token_headers() -> Dict[str, str]:
return get_superuser_token_headers()
def superuser_token_headers(client: TestClient) -> Dict[str, str]:
return get_superuser_token_headers(client)


@pytest.fixture(scope="module")
def normal_user_token_headers(db: Session) -> Dict[str, str]:
return authentication_token_from_email(email=settings.EMAIL_TEST_USER, db=db)
def normal_user_token_headers(client: TestClient, db: Session) -> Dict[str, str]:
return authentication_token_from_email(
client=client, email=settings.EMAIL_TEST_USER, db=db
)
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
from typing import Dict

import requests
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from app import crud
from app.core.config import settings
from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate
from app.tests.utils.utils import get_server_api, random_email, random_lower_string
from app.tests.utils.utils import random_email, random_lower_string


def user_authentication_headers(
server_api: str, email: str, password: str
*, client: TestClient, email: str, password: str
) -> Dict[str, str]:
data = {"username": email, "password": password}

r = requests.post(
f"{server_api}{settings.API_V1_STR}/login/access-token", data=data
)
r = client.post(f"{settings.API_V1_STR}/login/access-token", data=data)
response = r.json()
auth_token = response["access_token"]
headers = {"Authorization": f"Bearer {auth_token}"}
Expand All @@ -32,7 +30,9 @@ def create_random_user(db: Session) -> User:
return user


def authentication_token_from_email(*, email: str, db: Session) -> Dict[str, str]:
def authentication_token_from_email(
*, client: TestClient, email: str, db: Session
) -> Dict[str, str]:
"""
Return a valid token for the user with given email.
Expand All @@ -47,4 +47,4 @@ def authentication_token_from_email(*, email: str, db: Session) -> Dict[str, str
user_in_update = UserUpdate(password=password)
user = crud.user.update(db, db_obj=user, obj_in=user_in_update)

return user_authentication_headers(get_server_api(), email, password)
return user_authentication_headers(client=client, email=email, password=password)
15 changes: 3 additions & 12 deletions {{cookiecutter.project_slug}}/backend/app/app/tests/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import string
from typing import Dict

import requests
from fastapi.testclient import TestClient

from app.core.config import settings

Expand All @@ -15,22 +15,13 @@ def random_email() -> str:
return f"{random_lower_string()}@{random_lower_string()}.com"


def get_server_api() -> str:
server_name = f"http://{settings.SERVER_NAME}"
return server_name


def get_superuser_token_headers() -> Dict[str, str]:
server_api = get_server_api()
def get_superuser_token_headers(client: TestClient) -> Dict[str, str]:
login_data = {
"username": settings.FIRST_SUPERUSER,
"password": settings.FIRST_SUPERUSER_PASSWORD,
}
r = requests.post(
f"{server_api}{settings.API_V1_STR}/login/access-token", data=login_data
)
r = client.post(f"{settings.API_V1_STR}/login/access-token", data=login_data)
tokens = r.json()
a_token = tokens["access_token"]
headers = {"Authorization": f"Bearer {a_token}"}
# superuser_token_headers = headers
return headers
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed

from app.db.session import SessionLocal
from app.tests.api.api_v1.test_login import test_get_access_token

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand All @@ -23,8 +22,6 @@ def init() -> None:
# Try to create session to check if DB is awake
db = SessionLocal()
db.execute("SELECT 1")
# Wait for API to be awake, run one simple tests to authenticate
test_get_access_token()
except Exception as e:
logger.error(e)
raise e
Expand Down

0 comments on commit 307fffa

Please sign in to comment.