Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
hokamsingh committed Aug 27, 2024
2 parents 859a0dd + d1c9ff0 commit f8383f6
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 16 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,7 @@ node_modules/

# Windows specific files
Thumbs.db

# Docker files
Dockerfile
docker-compose.yml
29 changes: 23 additions & 6 deletions internal/core/middleware/CSRF.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto/rand"
"encoding/base64"
"io"
"log"
"net/http"
)

Expand All @@ -16,13 +17,17 @@ func NewCSRFProtection() *CSRFProtection {
func (csrf *CSRFProtection) Handle(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet {
// Generate and set CSRF token for GET requests
token, err := GenerateCSRFToken()
// Retrieve or set CSRF token for GET requests
_, err := getCSRFCookie(r)
if err != nil {
http.Error(w, "Failed to generate CSRF token", http.StatusInternalServerError)
return
// Generate and set a new CSRF token if not present
token, err := GenerateCSRFToken()
if err != nil {
http.Error(w, "Failed to generate CSRF token", http.StatusInternalServerError)
return
}
SetCSRFCookie(w, token)
}
SetCSRFCookie(w, token)
} else if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodDelete {
// Validate CSRF token for state-changing requests
if !ValidateCSRFToken(r) {
Expand All @@ -34,6 +39,7 @@ func (csrf *CSRFProtection) Handle(next http.Handler) http.Handler {
})
}

// GenerateCSRFToken generates a new CSRF token.
func GenerateCSRFToken() (string, error) {
token := make([]byte, 32) // 32 bytes = 256 bits
if _, err := io.ReadFull(rand.Reader, token); err != nil {
Expand All @@ -53,11 +59,22 @@ func SetCSRFCookie(w http.ResponseWriter, token string) {
})
}

// getCSRFCookie retrieves the CSRF token from the cookie, if present.
func getCSRFCookie(r *http.Request) (string, error) {
cookie, err := r.Cookie("csrf_token")
if err != nil {
return "", err
}
return cookie.Value, nil
}

// ValidateCSRFToken validates the CSRF token from the request header or form data.
func ValidateCSRFToken(r *http.Request) bool {
cookie, err := r.Cookie("csrf_token")
if err != nil {
log.Printf("Error retrieving CSRF cookie: %v", err)
return false
}
csrfToken := r.Header.Get("X-CSRF-Token") // Or retrieve from form data
csrfToken := r.Header.Get("X-CSRF-Token") // Retrieve from request header
return csrfToken == cookie.Value
}
42 changes: 33 additions & 9 deletions internal/core/middleware/cacher.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package middleware

import (
"context"
"log"
"net/http"
"time"

Expand All @@ -14,9 +15,14 @@ type Caching struct {
}

func NewCaching(redisAddr string, ttl time.Duration) *Caching {
ctx := context.Background()
client := redis.NewClient(&redis.Options{
Addr: redisAddr, // e.g., "localhost:6379"
})
_, err := client.Ping(ctx).Result()
if err != nil {
log.Fatalf("Could not connect to Redis: %v", err)
}
return &Caching{
client: client,
ttl: ttl,
Expand All @@ -27,20 +33,31 @@ func (c *Caching) Handle(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()

// Try to get the cached response from Redis
data, err := c.client.Get(ctx, r.RequestURI).Result()
if err == nil {
// If found in cache, write it directly to the response
w.Write([]byte(data))
return
if r.Method == http.MethodGet {
// Try to get the cached response from Redis
data, err := c.client.Get(ctx, r.RequestURI).Result()
if err == nil {
// If found in cache, write it directly to the response
w.Header().Set("X-Cache-Hit", "true")
w.Write([]byte(data))
return
} else if err != redis.Nil {
// Log any errors retrieving from Redis
log.Printf("Error retrieving from cache: %v", err)
}
}

// If not cached, create a response writer to capture the response
// Create a response writer to capture the response
rec := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK}
next.ServeHTTP(rec, r)

// Cache the response in Redis
c.client.Set(ctx, r.RequestURI, rec.body, c.ttl)
if r.Method == http.MethodGet {
// Cache the response in Redis
err := c.client.Set(ctx, r.RequestURI, rec.body, c.ttl).Err()
if err != nil {
log.Printf("Error setting cache: %v", err)
}
}
})
}

Expand All @@ -54,3 +71,10 @@ func (rec *responseRecorder) Write(p []byte) (int, error) {
rec.body = append(rec.body, p...)
return rec.ResponseWriter.Write(p)
}

// Implement the Flush method
func (rec *responseRecorder) Flush() {
if flusher, ok := rec.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
4 changes: 3 additions & 1 deletion internal/core/middleware/json_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ func NewJsonParser(options ParserOptions) *JSONParser {
}
}

type JsonKey string

func (jp *JSONParser) Handle(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Content-Type") == "application/json" {
Expand All @@ -70,7 +72,7 @@ func (jp *JSONParser) Handle(next http.Handler) http.Handler {
}

// Store the parsed JSON in the context
key := "jsonBody"
key := JsonKey("jsonBody")
r = r.WithContext(context.WithValue(r.Context(), key, body))
}
next.ServeHTTP(w, r)
Expand Down

0 comments on commit f8383f6

Please sign in to comment.