Skip to content

Commit

Permalink
Update DefaultWebInvocationPrivilegeEvaluator to use current ServletC…
Browse files Browse the repository at this point in the history
…ontext

Closes gh-10208
  • Loading branch information
marcusdacoregio committed Oct 14, 2021
1 parent 6db58cb commit faec20b
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,6 +29,7 @@
import java.util.Map;

import javax.servlet.FilterChain;
import javax.servlet.ServletContext;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
Expand Down Expand Up @@ -78,10 +79,19 @@ public FilterInvocation(String servletPath, String method) {
}

public FilterInvocation(String contextPath, String servletPath, String method) {
this(contextPath, servletPath, null, null, method);
this(contextPath, servletPath, method, null);
}

public FilterInvocation(String contextPath, String servletPath, String method, ServletContext servletContext) {
this(contextPath, servletPath, null, null, method, servletContext);
}

public FilterInvocation(String contextPath, String servletPath, String pathInfo, String query, String method) {
this(contextPath, servletPath, pathInfo, query, method, null);
}

public FilterInvocation(String contextPath, String servletPath, String pathInfo, String query, String method,
ServletContext servletContext) {
DummyRequest request = new DummyRequest();
contextPath = (contextPath != null) ? contextPath : "/cp";
request.setContextPath(contextPath);
Expand All @@ -90,6 +100,7 @@ public FilterInvocation(String contextPath, String servletPath, String pathInfo,
request.setPathInfo(pathInfo);
request.setQueryString(query);
request.setMethod(method);
request.setServletContext(servletContext);
this.request = request;
}

Expand Down Expand Up @@ -160,6 +171,8 @@ static class DummyRequest extends HttpServletRequestWrapper {

private String method;

private ServletContext servletContext;

private final HttpHeaders headers = new HttpHeaders();

private final Map<String, String[]> parameters = new LinkedHashMap<>();
Expand Down Expand Up @@ -290,6 +303,15 @@ void setParameter(String name, String... values) {
this.parameters.put(name, values);
}

@Override
public ServletContext getServletContext() {
return this.servletContext;
}

void setServletContext(ServletContext servletContext) {
this.servletContext = servletContext;
}

}

static final class UnsupportedOperationExceptionInvocationHandler implements InvocationHandler {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,6 +18,8 @@

import java.util.Collection;

import javax.servlet.ServletContext;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

Expand All @@ -28,6 +30,7 @@
import org.springframework.security.core.Authentication;
import org.springframework.security.web.FilterInvocation;
import org.springframework.util.Assert;
import org.springframework.web.context.ServletContextAware;

/**
* Allows users to determine whether they have privileges for a given web URI.
Expand All @@ -36,12 +39,14 @@
* @author Luke Taylor
* @since 3.0
*/
public class DefaultWebInvocationPrivilegeEvaluator implements WebInvocationPrivilegeEvaluator {
public class DefaultWebInvocationPrivilegeEvaluator implements WebInvocationPrivilegeEvaluator, ServletContextAware {

protected static final Log logger = LogFactory.getLog(DefaultWebInvocationPrivilegeEvaluator.class);

private final AbstractSecurityInterceptor securityInterceptor;

private ServletContext servletContext;

public DefaultWebInvocationPrivilegeEvaluator(AbstractSecurityInterceptor securityInterceptor) {
Assert.notNull(securityInterceptor, "SecurityInterceptor cannot be null");
Assert.isTrue(FilterInvocation.class.equals(securityInterceptor.getSecureObjectClass()),
Expand Down Expand Up @@ -82,7 +87,7 @@ public boolean isAllowed(String uri, Authentication authentication) {
@Override
public boolean isAllowed(String contextPath, String uri, String method, Authentication authentication) {
Assert.notNull(uri, "uri parameter is required");
FilterInvocation filterInvocation = new FilterInvocation(contextPath, uri, method);
FilterInvocation filterInvocation = new FilterInvocation(contextPath, uri, method, this.servletContext);
Collection<ConfigAttribute> attributes = this.securityInterceptor.obtainSecurityMetadataSource()
.getAttributes(filterInvocation);
if (attributes == null) {
Expand All @@ -101,4 +106,9 @@ public boolean isAllowed(String contextPath, String uri, String method, Authenti
}
}

@Override
public void setServletContext(ServletContext servletContext) {
this.servletContext = servletContext;
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,6 +24,7 @@

import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.web.FilterInvocation.DummyRequest;
import org.springframework.security.web.util.UrlUtils;

Expand Down Expand Up @@ -131,4 +132,14 @@ public void dummyRequestIsSupportedByUrlUtils() {
UrlUtils.buildRequestUrl(request);
}

@Test
public void constructorWhenServletContextProvidedThenSetServletContextInRequest() {
String contextPath = "";
String servletPath = "/path";
String method = "";
MockServletContext mockServletContext = new MockServletContext();
FilterInvocation filterInvocation = new FilterInvocation(contextPath, servletPath, method, mockServletContext);
assertThat(filterInvocation.getRequest().getServletContext()).isSameAs(mockServletContext);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,25 +18,30 @@

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;

import org.springframework.context.ApplicationEventPublisher;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.access.AccessDecisionManager;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.access.intercept.RunAsManager;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.FilterInvocation;
import org.springframework.security.web.access.intercept.FilterInvocationSecurityMetadataSource;
import org.springframework.security.web.access.intercept.FilterSecurityInterceptor;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyObject;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.willThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

/**
* Tests
Expand Down Expand Up @@ -106,4 +111,17 @@ public void deniesAccessIfAccessDecisionManagerDoes() {
assertThat(wipe.isAllowed("/foo/index.jsp", token)).isFalse();
}

@Test
public void isAllowedWhenServletContextIsSetThenPassedFilterInvocationHasServletContext() {
Authentication token = new TestingAuthenticationToken("test", "Password", "MOCK_INDEX");
MockServletContext servletContext = new MockServletContext();
ArgumentCaptor<FilterInvocation> filterInvocationArgumentCaptor = ArgumentCaptor
.forClass(FilterInvocation.class);
DefaultWebInvocationPrivilegeEvaluator wipe = new DefaultWebInvocationPrivilegeEvaluator(this.interceptor);
wipe.setServletContext(servletContext);
wipe.isAllowed("/foo/index.jsp", token);
verify(this.adm).decide(eq(token), filterInvocationArgumentCaptor.capture(), any());
assertThat(filterInvocationArgumentCaptor.getValue().getRequest().getServletContext()).isNotNull();
}

}

0 comments on commit faec20b

Please sign in to comment.