Skip to content

Commit 4f971ac

Browse files
committed
Added endpoint to generate csrf tokens
1 parent 609ab39 commit 4f971ac

File tree

7 files changed

+69
-20
lines changed

7 files changed

+69
-20
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from django.http import JsonResponse
2+
from django.middleware.csrf import get_token
3+
4+
def get_csrf_token(request):
5+
"""Returns a response with the CSRF token to set it in cookies."""
6+
return JsonResponse({"csrftoken": get_token(request)})
7+

backend/settings/urls.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
from django.conf.urls.static import static
99
from django.contrib import admin
1010
from django.urls import include, path
11-
from django.views.decorators.csrf import csrf_exempt
11+
from django.views.decorators.csrf import csrf_protect
1212
from graphene_django.views import GraphQLView
1313
from rest_framework import routers
1414

1515
from apps.core.api.algolia import algolia_search
16+
from apps.core.api.csrf_token import get_csrf_token
1617
from apps.github.api.urls import router as github_router
1718
from apps.owasp.api.urls import router as owasp_router
1819
from apps.slack.apps import SlackConfig
@@ -22,8 +23,9 @@
2223
router.registry.extend(owasp_router.registry)
2324

2425
urlpatterns = [
25-
path("idx/", csrf_exempt(algolia_search)),
26-
path("graphql/", csrf_exempt(GraphQLView.as_view(graphiql=settings.DEBUG))),
26+
path("idx/", csrf_protect(algolia_search)),
27+
path("graphql/", csrf_protect(GraphQLView.as_view(graphiql=settings.DEBUG))),
28+
path("csrf", get_csrf_token),
2729
path("api/v1/", include(router.urls)),
2830
path("a/", admin.site.urls),
2931
]
Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import { getCsrfToken } from 'utils/utility'
22

3+
jest.mock('api/getCsrfToken', () => ({
4+
getInitialCsrfToken: jest.fn(() => Promise.resolve('abc123')),
5+
}))
6+
37
describe('utility tests', () => {
48
beforeEach(() => {
59
jest.clearAllMocks()
@@ -9,33 +13,39 @@ describe('utility tests', () => {
913
})
1014
})
1115

12-
test('returns CSRF token when it exists in cookies', () => {
16+
test('returns CSRF token when it exists in cookies', async () => {
1317
document.cookie = 'csrftoken=abc123; otherkey=xyz789'
14-
expect(getCsrfToken()).toBe('abc123')
18+
const result = await getCsrfToken()
19+
expect(result).toBe('abc123')
1520
})
1621

17-
test('returns undefined when no cookies are present', () => {
22+
test('returns new token when no cookies are present', async () => {
1823
document.cookie = ''
19-
expect(getCsrfToken()).toBeUndefined()
24+
const result = await getCsrfToken()
25+
expect(result).toBe('abc123')
2026
})
2127

22-
test('returns undefined when csrftoken cookie is not present', () => {
28+
test('returns new csrftoken when csrftoken cookie is not present', async () => {
2329
document.cookie = 'someid=xyz789; othercookie=123'
24-
expect(getCsrfToken()).toBeUndefined()
30+
const result = await getCsrfToken()
31+
expect(result).toBe('abc123')
2532
})
2633

27-
test('returns first csrftoken value when multiple cookies exist', () => {
34+
test('returns first csrftoken value when multiple cookies exist', async () => {
2835
document.cookie = 'csrftoken=first; csrftoken=second; otherid=xyz789'
29-
expect(getCsrfToken()).toBe('first')
36+
const result = await getCsrfToken()
37+
expect(result).toBe('first')
3038
})
3139

32-
test('handles cookie with no value', () => {
40+
test('handles cookie with no value', async () => {
3341
document.cookie = 'csrftoken=; otherid=xyz789'
34-
expect(getCsrfToken()).toBe('')
42+
const result = await getCsrfToken()
43+
expect(result).toBe('abc123')
3544
})
3645

37-
test('handles malformed cookie string', () => {
46+
test('handles malformed cookie string', async () => {
3847
document.cookie = 'csrftoken; otherid=xyz789'
39-
expect(getCsrfToken()).toBeUndefined()
48+
const result = await getCsrfToken()
49+
expect(result).toBe('abc123')
4050
})
4151
})

frontend/src/api/fetchAlgoliaData.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ export const fetchAlgoliaData = async <T>(
2020
method: 'POST',
2121
headers: {
2222
'Content-Type': 'application/json',
23-
'X-CSRFToken': getCsrfToken() || '',
23+
'X-CSRFToken': (await getCsrfToken()) || '',
2424
},
2525
credentials: 'include',
2626
body: JSON.stringify({

frontend/src/api/getCsrfToken.ts

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import { AppError } from 'wrappers/ErrorWrapper'
2+
3+
export const getInitialCsrfToken = async () => {
4+
try {
5+
const response = await fetch('http://localhost:8000/csrf', {
6+
method: 'GET',
7+
})
8+
9+
if (!response.ok) {
10+
throw new AppError(response.status, 'Failed to fetch CSRF token')
11+
}
12+
13+
const data = await response.json()
14+
document.cookie = `csrftoken=${data.csrftoken}; path=/; SameSite=Lax`
15+
return data.csrftoken
16+
} catch (error) {
17+
if (error instanceof AppError) {
18+
throw error
19+
}
20+
throw new AppError(500, 'Internal server error')
21+
}
22+
}

frontend/src/utils/helpers/apolloClient.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ const createApolloClient = () => {
1616
uri: GRAPHQL_URL,
1717
})
1818

19-
const authLink = setContext((_, { headers }) => {
19+
const authLink = setContext(async (_, { headers }) => {
2020
return {
2121
headers: {
2222
...headers,
23-
'X-CSRFToken': getCsrfToken() || '',
23+
'X-CSRFToken': (await getCsrfToken()) || '',
2424
},
2525
}
2626
})

frontend/src/utils/utility.ts

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { getInitialCsrfToken } from 'api/getCsrfToken'
12
import { type ClassValue, clsx } from 'clsx'
23
import dayjs from 'dayjs'
34
import relativeTime from 'dayjs/plugin/relativeTime'
@@ -65,10 +66,17 @@ export type IndexedObject = {
6566
[key: string]: unknown
6667
}
6768

68-
export const getCsrfToken = (): string | undefined => {
69-
return document.cookie
69+
export const getCsrfToken = async () => {
70+
const csrfToken = document.cookie
7071
.split(';')
7172
.map((cookie) => cookie.split('='))
7273
.find(([key]) => key.trim() === 'csrftoken')?.[1]
7374
?.trim()
75+
76+
if (csrfToken) {
77+
return csrfToken
78+
}
79+
80+
const res = await getInitialCsrfToken()
81+
return res
7482
}

0 commit comments

Comments
 (0)