Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: replace default spring x509 authentication in zaas #3971

Merged
merged 13 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ private boolean checkPermission(String userId, String resourceType, String resou
@Override
public boolean hasSafResourceAccess(Authentication authentication, String resourceClass, String resourceName, String accessLevel) {
String userid = authentication.getName();
if (StringUtils.isEmpty(userid)) {
if (StringUtils.isEmpty(userid) || userid.length() > 8) {
log.debug("UserId {} is not valid for SAF permissions check", userid);
return false;
}
AccessLevel level = AccessLevel.valueOf(accessLevel);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ public class CategorizeCertsFilter extends OncePerRequestFilter {
*/
private void categorizeCerts(ServletRequest request) {
X509Certificate[] certs = (X509Certificate[]) request.getAttribute(ATTRNAME_JAKARTA_SERVLET_REQUEST_X509_CERTIFICATE);
if (certs != null) {
if (certs != null && certs.length > 0 && certs[0] != null) {
Optional<Certificate> clientCert = getClientCertFromHeader((HttpServletRequest) request);
if (certificateValidator.isForwardingEnabled() && certificateValidator.isTrusted(certs) && clientCert.isPresent()) {
if (certificateValidator.isForwardingEnabled() && certificateValidator.hasGatewayChain(certs) && clientCert.isPresent()) {
certificateValidator.updateAPIMLPublicKeyCertificates(certs);
// add the client certificate to the certs array
String subjectDN = ((X509Certificate) clientCert.get()).getSubjectX500Principal().getName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@

package org.zowe.apiml.security.common.login;

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;

import jakarta.servlet.*;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;

/**
Expand All @@ -41,9 +44,7 @@ public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)

try {
authResult = attemptAuthentication(request, response);
}

catch (AuthenticationException failed) {
} catch (AuthenticationException failed) {
// Authentication failed
unsuccessfulAuthentication(request, response, failed);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,21 @@

package org.zowe.apiml.security.common.login;

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken;

import java.io.IOException;

public class X509AuthAwareFilter extends X509AuthenticationFilter {
public class X509AuthAwareFilter extends X509ForwardingAwareAuthenticationFilter {
private final AuthenticationFailureHandler failureHandler;

public X509AuthAwareFilter(String endpoint, AuthenticationFailureHandler failureHandler, AuthenticationProvider authenticationProvider) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,28 @@

package org.zowe.apiml.security.common.login;

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.zowe.apiml.security.common.token.X509AuthenticationToken;

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.security.cert.X509Certificate;

@Slf4j
public class X509AuthenticationFilter extends NonCompulsoryAuthenticationProcessingFilter {
public class X509ForwardingAwareAuthenticationFilter extends NonCompulsoryAuthenticationProcessingFilter {

private final AuthenticationProvider authenticationProvider;
private final AuthenticationSuccessHandler successHandler;

public X509AuthenticationFilter(String endpoint,
AuthenticationSuccessHandler successHandler,
AuthenticationProvider authenticationProvider) {
public X509ForwardingAwareAuthenticationFilter(String endpoint,
AuthenticationSuccessHandler successHandler,
AuthenticationProvider authenticationProvider) {
super(endpoint);
this.authenticationProvider = authenticationProvider;
this.successHandler = successHandler;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public CertificateValidator(TrustedCertificatesProvider trustedCertificatesProvi
* @param certs Certificates to compare with known trusted ones
* @return true if all given certificates are known false otherwise
*/
public boolean isTrusted(X509Certificate[] certs) {
public boolean hasGatewayChain(X509Certificate[] certs) {
if ((proxyCertificatesEndpoints == null) || (proxyCertificatesEndpoints.length == 0)) {
log.debug("No endpoint configured to retrieve trusted certificates. Provide URL via apiml.security.x509.certificatesUrls");
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.NullSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
class SafResourceAccessSafTest {
Expand Down Expand Up @@ -92,6 +93,15 @@ void testHasSafResourceAccess_whenNoResponse_thenTrue() {
void testHasSafResourceAccess_whenUseridEmpty_thenFalse() {
assertFalse(safResourceAccessVerifying.hasSafResourceAccess(new UsernamePasswordAuthenticationToken("", "token"), CLASS, RESOURCE, LEVEL.name()));
}

@ParameterizedTest
@NullSource
@ValueSource(strings = {"", "tooLongUserId"})
void testInvalidUserIds_thenSkipped(String userId) {
var auth = new UsernamePasswordAuthenticationToken(userId, "");
assertFalse(assertDoesNotThrow(() -> safResourceAccessVerifying.hasSafResourceAccess(auth, CLASS, RESOURCE, LEVEL.name())));
verify(checkPermissionMock, never()).checkPermission(any(), any(), any(), anyInt());
}

@Builder
public static class TestPlatformReturned {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

package org.zowe.apiml.security.common.filter;

import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
Expand All @@ -20,8 +22,6 @@
import org.zowe.apiml.security.common.utils.X509Utils;
import org.zowe.apiml.security.common.verify.CertificateValidator;

import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
Expand Down Expand Up @@ -86,7 +86,7 @@ public void setUp() {
chain = new MockFilterChain();
certificateValidator = mock(CertificateValidator.class);
when(certificateValidator.isForwardingEnabled()).thenReturn(false);
when(certificateValidator.isTrusted(any())).thenReturn(false);
when(certificateValidator.hasGatewayChain(any())).thenReturn(false);
}

@Nested
Expand Down Expand Up @@ -118,7 +118,7 @@ class WhenForwardingEnabled {
@BeforeEach
void setUp() {
when(certificateValidator.isForwardingEnabled()).thenReturn(true);
when(certificateValidator.isTrusted(any())).thenReturn(true);
when(certificateValidator.hasGatewayChain(any())).thenReturn(true);
}

@Test
Expand Down Expand Up @@ -200,7 +200,7 @@ public void setUp() {

@Test
void givenTrustedCerts_thenClientCertHeaderAccepted() throws ServletException, IOException {
when(certificateValidator.isTrusted(certificates)).thenReturn(true);
when(certificateValidator.hasGatewayChain(certificates)).thenReturn(true);
// when incoming certs are all trusted means that all their public keys are added to the filter
filter.getPublicKeyCertificatesBase64().add(X509Utils.correctBase64("apimlCert1"));
filter.getPublicKeyCertificatesBase64().add(X509Utils.correctBase64("apimlCertCA"));
Expand All @@ -225,7 +225,7 @@ void givenTrustedCerts_thenClientCertHeaderAccepted() throws ServletException, I

@Test
void givenNotTrustedCerts_thenClientCertHeaderIgnored() throws ServletException, IOException {
when(certificateValidator.isTrusted(certificates)).thenReturn(false);
when(certificateValidator.hasGatewayChain(certificates)).thenReturn(false);
filter.doFilter(request, response, chain);
HttpServletRequest nextRequest = (HttpServletRequest) chain.getRequest();
assertNotNull(nextRequest);
Expand Down Expand Up @@ -280,7 +280,7 @@ class WhenInvalidCertificateInHeaderAndForwardingEnabled {
public void setUp() {
request.addHeader(CLIENT_CERT_HEADER, "invalid_cert");
when(certificateValidator.isForwardingEnabled()).thenReturn(true);
when(certificateValidator.isTrusted(certificates)).thenReturn(true);
when(certificateValidator.hasGatewayChain(certificates)).thenReturn(true);
}

@Test
Expand Down Expand Up @@ -410,7 +410,7 @@ public void setUp() {

@Test
void givenTrustedCerts_thenClientCertHeaderAccepted() throws ServletException, IOException {
when(certificateValidator.isTrusted(certificates)).thenReturn(true);
when(certificateValidator.hasGatewayChain(certificates)).thenReturn(true);
// when incoming certs are all trusted means that all their public keys are added to the filter
filter.getPublicKeyCertificatesBase64().add(X509Utils.correctBase64("apimlCert1"));
filter.getPublicKeyCertificatesBase64().add(X509Utils.correctBase64("apimlCertCA"));
Expand All @@ -435,7 +435,7 @@ void givenTrustedCerts_thenClientCertHeaderAccepted() throws ServletException, I

@Test
void givenNotTrustedCerts_thenClientCertHeaderIgnored() throws ServletException, IOException {
when(certificateValidator.isTrusted(certificates)).thenReturn(false);
when(certificateValidator.hasGatewayChain(certificates)).thenReturn(false);
filter.doFilter(request, response, chain);
HttpServletRequest nextRequest = (HttpServletRequest) chain.getRequest();
assertNotNull(nextRequest);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,20 @@ class WhenTrustedCertsProvided {

@BeforeEach
void setUp() {
ReflectionTestUtils.setField(certificateValidator, "proxyCertificatesEndpoints", new String[] {URL_PROVIDE_TWO_TRUSTED_CERTS});
ReflectionTestUtils.setField(certificateValidator, "proxyCertificatesEndpoints", new String[]{URL_PROVIDE_TWO_TRUSTED_CERTS});
}
@Test
void whenAllCertificatesFoundThenTheyAreTrusted() {
assertTrue(certificateValidator.isTrusted(new X509Certificate[]{cert1}));
assertTrue(certificateValidator.isTrusted(new X509Certificate[]{cert2}));
assertTrue(certificateValidator.isTrusted(new X509Certificate[]{cert1, cert2}));
assertTrue(certificateValidator.hasGatewayChain(new X509Certificate[]{cert1}));
assertTrue(certificateValidator.hasGatewayChain(new X509Certificate[]{cert2}));
assertTrue(certificateValidator.hasGatewayChain(new X509Certificate[]{cert1, cert2}));
}

@Test
void whenSomeCertificateNotFoundThenAllUntrusted() {
assertFalse(certificateValidator.isTrusted(new X509Certificate[]{cert3}));
assertFalse(certificateValidator.isTrusted(new X509Certificate[]{cert1, cert3}));
assertFalse(certificateValidator.isTrusted(new X509Certificate[]{cert2, cert3}));
assertFalse(certificateValidator.hasGatewayChain(new X509Certificate[]{cert3}));
assertFalse(certificateValidator.hasGatewayChain(new X509Certificate[]{cert1, cert3}));
assertFalse(certificateValidator.hasGatewayChain(new X509Certificate[]{cert2, cert3}));
}
}

Expand All @@ -78,13 +78,13 @@ class WhenNoTrustedCertsProvided {

@BeforeEach
void setUp() {
ReflectionTestUtils.setField(certificateValidator, "proxyCertificatesEndpoints", new String[] {URL_WITH_NO_TRUSTED_CERTS});
ReflectionTestUtils.setField(certificateValidator, "proxyCertificatesEndpoints", new String[]{URL_WITH_NO_TRUSTED_CERTS});
}
@Test
void thenAnyCertificateIsNotTrusted() {
assertFalse(certificateValidator.isTrusted(new X509Certificate[]{cert1}));
assertFalse(certificateValidator.isTrusted(new X509Certificate[]{cert2}));
assertFalse(certificateValidator.isTrusted(new X509Certificate[]{cert3}));
assertFalse(certificateValidator.hasGatewayChain(new X509Certificate[]{cert1}));
assertFalse(certificateValidator.hasGatewayChain(new X509Certificate[]{cert2}));
assertFalse(certificateValidator.hasGatewayChain(new X509Certificate[]{cert3}));
}
}

Expand Down Expand Up @@ -114,15 +114,15 @@ class WhenMultipleSources {

@BeforeEach
void setUp() {
ReflectionTestUtils.setField(certificateValidator, "proxyCertificatesEndpoints", new String[] {URL_PROVIDE_TWO_TRUSTED_CERTS, URL_PROVIDE_THIRD_TRUSTED_CERT});
ReflectionTestUtils.setField(certificateValidator, "proxyCertificatesEndpoints", new String[]{URL_PROVIDE_TWO_TRUSTED_CERTS, URL_PROVIDE_THIRD_TRUSTED_CERT});
}

@Test
void whenAllCertificatesFoundThenTheyAreTrusted() {
assertTrue(certificateValidator.isTrusted(new X509Certificate[]{cert1}));
assertTrue(certificateValidator.isTrusted(new X509Certificate[]{cert2}));
assertTrue(certificateValidator.isTrusted(new X509Certificate[]{cert3}));
assertTrue(certificateValidator.isTrusted(new X509Certificate[]{cert1, cert3}));
assertTrue(certificateValidator.hasGatewayChain(new X509Certificate[]{cert1}));
assertTrue(certificateValidator.hasGatewayChain(new X509Certificate[]{cert2}));
assertTrue(certificateValidator.hasGatewayChain(new X509Certificate[]{cert3}));
assertTrue(certificateValidator.hasGatewayChain(new X509Certificate[]{cert1, cert3}));
}

}
Expand All @@ -146,7 +146,7 @@ class OldPropertyValue {

@Test
void thenUrlIsSetAsListCorrectly() {
assertArrayEquals(new String[] {"url1"}, (String[]) ReflectionTestUtils.getField(certificateValidator, "proxyCertificatesEndpoints"));
assertArrayEquals(new String[]{"url1"}, (String[]) ReflectionTestUtils.getField(certificateValidator, "proxyCertificatesEndpoints"));
}

}
Expand All @@ -163,7 +163,7 @@ class NewPropertyValue {

@Test
void thenUrlsAreSetCorrectly() {
assertArrayEquals(new String[] {"url1", "url2"}, (String[]) ReflectionTestUtils.getField(certificateValidator, "proxyCertificatesEndpoints"));
assertArrayEquals(new String[]{"url1", "url2"}, (String[]) ReflectionTestUtils.getField(certificateValidator, "proxyCertificatesEndpoints"));
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public GatewayFilter apply(AcceptForwardedClientCertFilterFactory.Config config)
return (exchange, chain) -> {
SslInfo sslInfo = exchange.getRequest().getSslInfo();
X509Certificate[] x509Certificates = sslInfo == null ? null : sslInfo.getPeerCertificates();
if ((x509Certificates != null) && (x509Certificates.length > 0) && certificateValidator.isTrusted(x509Certificates)) {
if ((x509Certificates != null) && (x509Certificates.length > 0) && certificateValidator.hasGatewayChain(x509Certificates)) {
X509Certificate[] forwardedClientCertificate = getClientCertificateFromHeader(exchange.getRequest());
if (forwardedClientCertificate.length > 0) {
log.debug("Accepting forwarded client certificate {}", forwardedClientCertificate[0].getSubjectX500Principal().getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class AcceptForwardedClientCertFilterFactoryTest {
@Test
void givenClientCertificateInHeader_andTrustedClientCertificateInHandshake_thenUpdateWebExchange() {
var validator = mock(CertificateValidator.class);
when(validator.isTrusted(any())).thenReturn(true);
when(validator.hasGatewayChain(any())).thenReturn(true);
var factory = new AcceptForwardedClientCertFilterFactory(validator);
var exchange = mock(ServerWebExchange.class);
var req = mock(ServerHttpRequest.class);
Expand All @@ -89,7 +89,7 @@ void givenClientCertificateInHeader_andTrustedClientCertificateInHandshake_thenU
@Test
void givenClientCertificateInHeader_andInvalidClientCertificateInHandshake_thenDoNothing() {
var validator = mock(CertificateValidator.class);
when(validator.isTrusted(any())).thenReturn(false);
when(validator.hasGatewayChain(any())).thenReturn(false);
var factory = new AcceptForwardedClientCertFilterFactory(validator);
var exchange = mock(ServerWebExchange.class);
var req = mock(ServerHttpRequest.class);
Expand Down
Loading
Loading