Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
package com.amazonaws.serverless.proxy.internal.servlet;

import com.amazonaws.serverless.proxy.internal.SecurityUtils;
import jakarta.servlet.http.Cookie;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.time.Instant;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.*;

/**
* Implementation of the CookieProcessor interface that provides cookie parsing and generation functionality.
*/
public class AwsCookieProcessor implements CookieProcessor {

// Cookie attribute constants
static final String COOKIE_COMMENT_ATTR = "Comment";
static final String COOKIE_DOMAIN_ATTR = "Domain";
static final String COOKIE_MAX_AGE_ATTR = "Max-Age";
static final String COOKIE_PATH_ATTR = "Path";
static final String COOKIE_SECURE_ATTR = "Secure";
static final String COOKIE_HTTP_ONLY_ATTR = "HttpOnly";
static final String COOKIE_SAME_SITE_ATTR = "SameSite";
static final String COOKIE_PARTITIONED_ATTR = "Partitioned";
static final String EMPTY_STRING = "";

// BitSet to store valid token characters as defined in RFC 2616
static final BitSet tokenValid = createTokenValidSet();

// BitSet to validate domain characters
static final BitSet domainValid = createDomainValidSet();

static final DateTimeFormatter COOKIE_DATE_FORMATTER = DateTimeFormatter.RFC_1123_DATE_TIME.withZone(ZoneId.of("GMT"));

static final String ANCIENT_DATE = COOKIE_DATE_FORMATTER.format(Instant.ofEpochMilli(10000));

static BitSet createTokenValidSet() {
BitSet tokenSet = new BitSet(128);
for (char c = '0'; c <= '9'; c++) tokenSet.set(c);
for (char c = 'a'; c <= 'z'; c++) tokenSet.set(c);
for (char c = 'A'; c <= 'Z'; c++) tokenSet.set(c);
for (char c : "!#$%&'*+-.^_`|~".toCharArray()) tokenSet.set(c);
return tokenSet;
}

static BitSet createDomainValidSet() {
BitSet domainValid = new BitSet(128);
for (char c = '0'; c <= '9'; c++) domainValid.set(c);
for (char c = 'a'; c <= 'z'; c++) domainValid.set(c);
for (char c = 'A'; c <= 'Z'; c++) domainValid.set(c);
domainValid.set('.');
domainValid.set('-');
return domainValid;
}

private final Logger log = LoggerFactory.getLogger(AwsCookieProcessor.class);

@Override
public Cookie[] parseCookieHeader(String cookieHeader) {
// Return an empty array if the input is null or empty after trimming
if (cookieHeader == null || cookieHeader.trim().isEmpty()) {
return new Cookie[0];
}

// Parse cookie header and convert to Cookie array
return Arrays.stream(cookieHeader.split("\\s*;\\s*"))
.map(this::parseCookiePair)
.filter(Objects::nonNull) // Filter out invalid pairs
.toArray(Cookie[]::new);
}

/**
* Parse a single cookie pair (name=value).
*
* @param cookiePair The cookie pair string.
* @return A valid Cookie object or null if the pair is invalid.
*/
private Cookie parseCookiePair(String cookiePair) {
String[] kv = cookiePair.split("=", 2);

if (kv.length != 2) {
log.warn("Ignoring invalid cookie: {}", cookiePair);
return null; // Skip malformed cookie pairs
}

String cookieName = kv[0];
String cookieValue = kv[1];

// Validate name and value
if (!isToken(cookieName)){
log.warn("Ignoring cookie with invalid name: {}={}", cookieName, cookieValue);
return null; // Skip invalid cookie names
}

if (!isValidCookieValue(cookieValue)) {
log.warn("Ignoring cookie with invalid value: {}={}", cookieName, cookieValue);
return null; // Skip invalid cookie values
}

// Return a new Cookie object after security processing
return new Cookie(SecurityUtils.crlf(cookieName), SecurityUtils.crlf(cookieValue));
}

@Override
public String generateHeader(Cookie cookie) {
StringBuffer header = new StringBuffer();
header.append(cookie.getName()).append('=');

String value = cookie.getValue();
if (value != null && value.length() > 0) {
validateCookieValue(value);
header.append(value);
}

int maxAge = cookie.getMaxAge();
if (maxAge > -1) {
header.append("; Expires=");
if (maxAge == 0) {
header.append(ANCIENT_DATE);
} else {
Instant expiresAt = Instant.now().plusSeconds(maxAge);
header.append(COOKIE_DATE_FORMATTER.format(expiresAt));
header.append("; Max-Age=").append(maxAge);
}
}

String domain = cookie.getDomain();
if (domain != null && !domain.isEmpty()) {
validateDomain(domain);
header.append("; Domain=").append(domain);
}

String path = cookie.getPath();
if (path != null && !path.isEmpty()) {
validatePath(path);
header.append("; Path=").append(path);
}

if (cookie.getSecure()) {
header.append("; Secure");
}

if (cookie.isHttpOnly()) {
header.append("; HttpOnly");
}

String sameSite = cookie.getAttribute(COOKIE_SAME_SITE_ATTR);
if (sameSite != null) {
header.append("; SameSite=").append(sameSite);
}

String partitioned = cookie.getAttribute(COOKIE_PARTITIONED_ATTR);
if (EMPTY_STRING.equals(partitioned)) {
header.append("; Partitioned");
}

addAdditionalAttributes(cookie, header);

return header.toString();
}

private void addAdditionalAttributes(Cookie cookie, StringBuffer header) {
for (Map.Entry<String, String> entry : cookie.getAttributes().entrySet()) {
switch (entry.getKey()) {
case COOKIE_COMMENT_ATTR:
case COOKIE_DOMAIN_ATTR:
case COOKIE_MAX_AGE_ATTR:
case COOKIE_PATH_ATTR:
case COOKIE_SECURE_ATTR:
case COOKIE_HTTP_ONLY_ATTR:
case COOKIE_SAME_SITE_ATTR:
case COOKIE_PARTITIONED_ATTR:
// Already handled attributes are ignored
break;
default:
validateAttribute(entry.getKey(), entry.getValue());
header.append("; ").append(entry.getKey());
if (!EMPTY_STRING.equals(entry.getValue())) {
header.append('=').append(entry.getValue());
}
break;
}
}
}

private void validateCookieValue(String value) {
if (!isValidCookieValue(value)) {
throw new IllegalArgumentException("Invalid cookie value: " + value);
}
}

private void validateDomain(String domain) {
if (!isValidDomain(domain)) {
throw new IllegalArgumentException("Invalid cookie domain: " + domain);
}
}

private void validatePath(String path) {
for (char ch : path.toCharArray()) {
if (ch < 0x20 || ch > 0x7E || ch == ';') {
throw new IllegalArgumentException("Invalid cookie path: " + path);
}
}
}

private void validateAttribute(String name, String value) {
if (!isToken(name)) {
throw new IllegalArgumentException("Invalid cookie attribute name: " + name);
}

for (char ch : value.toCharArray()) {
if (ch < 0x20 || ch > 0x7E || ch == ';') {
throw new IllegalArgumentException("Invalid cookie attribute value: " + ch);
}
}
}

private boolean isValidCookieValue(String value) {
int start = 0;
int end = value.length();
boolean quoted = end > 1 && value.charAt(0) == '"' && value.charAt(end - 1) == '"';

char[] chars = value.toCharArray();
for (int i = start; i < end; i++) {
if (quoted && (i == start || i == end - 1)) {
continue;
}
char c = chars[i];
if (!isValidCookieChar(c)) return false;
}
return true;
}

private boolean isValidDomain(String domain) {
if (domain.isEmpty()) {
return false;
}
int prev = -1;
for (char c : domain.toCharArray()) {
if (!domainValid.get(c) || isInvalidLabelStartOrEnd(prev, c)) {
return false;
}
prev = c;
}
return prev != '.' && prev != '-';
}

private boolean isInvalidLabelStartOrEnd(int prev, char current) {
return (prev == '.' || prev == -1) && (current == '.' || current == '-') ||
(prev == '-' && current == '.');
}

private boolean isToken(String s) {
if (s.isEmpty()) return false;
for (char c : s.toCharArray()) {
if (!tokenValid.get(c)) {
return false;
}
}
return true;
}

private boolean isValidCookieChar(char c) {
return !(c < 0x21 || c > 0x7E || c == 0x22 || c == 0x2c || c == 0x3b || c == 0x5c);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import java.time.ZonedDateTime;
import java.time.format.DateTimeParseException;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class AwsHttpApiV2ProxyHttpServletRequest extends AwsHttpServletRequest {
Expand Down Expand Up @@ -81,26 +80,14 @@ public Cookie[] getCookies() {
if (headers == null || !headers.containsKey(HttpHeaders.COOKIE)) {
rhc = new Cookie[0];
} else {
rhc = parseCookieHeaderValue(headers.getFirst(HttpHeaders.COOKIE));
rhc = getCookieProcessor().parseCookieHeader(headers.getFirst(HttpHeaders.COOKIE));
}

Cookie[] rc;
if (request.getCookies() == null) {
rc = new Cookie[0];
} else {
rc = request.getCookies().stream()
.map(c -> {
int i = c.indexOf('=');
if (i == -1) {
return null;
} else {
String k = SecurityUtils.crlf(c.substring(0, i)).trim();
String v = SecurityUtils.crlf(c.substring(i+1));
return new Cookie(k, v);
}
})
.filter(c -> c != null)
.toArray(Cookie[]::new);
rc = getCookieProcessor().parseCookieHeader(String.join("; ", request.getCookies()));
}

return Stream.concat(Arrays.stream(rhc), Arrays.stream(rc)).toArray(Cookie[]::new);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ public abstract class AwsHttpServletRequest implements HttpServletRequest {
private String queryString;
private Map<String, List<Part>> multipartFormParameters;
private Map<String, List<String>> urlEncodedFormParameters;
private CookieProcessor cookieProcessor;

protected AwsHttpServletResponse response;
protected AwsLambdaServletContainerHandler containerHandler;
Expand Down Expand Up @@ -295,12 +296,7 @@ public void setServletContext(ServletContext context) {
* @return An array of Cookie objects from the header
*/
protected Cookie[] parseCookieHeaderValue(String headerValue) {
List<HeaderValue> parsedHeaders = this.parseHeaderValue(headerValue, ";", ",");

return parsedHeaders.stream()
.filter(e -> e.getKey() != null)
.map(e -> new Cookie(SecurityUtils.crlf(e.getKey()), SecurityUtils.crlf(e.getValue())))
.toArray(Cookie[]::new);
return getCookieProcessor().parseCookieHeader(headerValue);
}


Expand Down Expand Up @@ -512,6 +508,13 @@ protected Map<String, List<String>> getFormUrlEncodedParametersMap() {
return urlEncodedFormParameters;
}

protected CookieProcessor getCookieProcessor(){
if (cookieProcessor == null) {
cookieProcessor = new AwsCookieProcessor();
}
return cookieProcessor;
}

@Override
public Collection<Part> getParts()
throws IOException, ServletException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public class AwsHttpServletResponse
private CountDownLatch writersCountDownLatch;
private HttpServletRequest request;
private boolean isCommitted = false;
private CookieProcessor cookieProcessor;

private Logger log = LoggerFactory.getLogger(AwsHttpServletResponse.class);

Expand Down Expand Up @@ -102,33 +103,7 @@ public void addCookie(Cookie cookie) {
if (request != null && request.getDispatcherType() == DispatcherType.INCLUDE && isCommitted()) {
throw new IllegalStateException("Cannot add Cookies for include request when response is committed");
}
String cookieData = cookie.getName() + "=" + cookie.getValue();
if (cookie.getPath() != null) {
cookieData += "; Path=" + cookie.getPath();
}
if (cookie.getSecure()) {
cookieData += "; Secure";
}
if (cookie.isHttpOnly()) {
cookieData += "; HttpOnly";
}
if (cookie.getDomain() != null && !"".equals(cookie.getDomain().trim())) {
cookieData += "; Domain=" + cookie.getDomain();
}

if (cookie.getMaxAge() > 0) {
cookieData += "; Max-Age=" + cookie.getMaxAge();

// we always set the timezone to GMT
TimeZone gmtTimeZone = TimeZone.getTimeZone(COOKIE_DEFAULT_TIME_ZONE);
Calendar currentTimestamp = Calendar.getInstance(gmtTimeZone);
currentTimestamp.add(Calendar.SECOND, cookie.getMaxAge());
SimpleDateFormat cookieDateFormatter = new SimpleDateFormat(HEADER_DATE_PATTERN);
cookieDateFormatter.setTimeZone(gmtTimeZone);
cookieData += "; Expires=" + cookieDateFormatter.format(currentTimestamp.getTime());
}

setHeader(HttpHeaders.SET_COOKIE, cookieData, false);
setHeader(HttpHeaders.SET_COOKIE, getCookieProcessor().generateHeader(cookie), false);
}


Expand Down Expand Up @@ -500,6 +475,12 @@ AwsProxyRequest getAwsProxyRequest() {
return (AwsProxyRequest)request.getAttribute(API_GATEWAY_EVENT_PROPERTY);
}

CookieProcessor getCookieProcessor(){
if (cookieProcessor == null) {
cookieProcessor = new AwsCookieProcessor();
}
return cookieProcessor;
}

//-------------------------------------------------------------
// Methods - Private
Expand Down
Loading