Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ public final class OpenSamlRelyingPartyRegistration extends RelyingPartyRegistra
registration.getAssertionConsumerServiceLocation(), registration.getAssertionConsumerServiceBinding(),
registration.getSingleLogoutServiceLocation(), registration.getSingleLogoutServiceResponseLocation(),
registration.getSingleLogoutServiceBindings(), registration.getAssertingPartyDetails(),
registration.getNameIdFormat(), registration.getDecryptionX509Credentials(),
registration.getSigningX509Credentials());
registration.getNameIdFormat(), registration.isAuthnRequestsSigned(),
registration.getDecryptionX509Credentials(), registration.getSigningX509Credentials());
}

/**
Expand All @@ -55,7 +55,7 @@ public OpenSamlRelyingPartyRegistration.Builder mutate() {
.singleLogoutServiceLocation(getSingleLogoutServiceLocation())
.singleLogoutServiceResponseLocation(getSingleLogoutServiceResponseLocation())
.singleLogoutServiceBindings((c) -> c.addAll(getSingleLogoutServiceBindings()))
.nameIdFormat(getNameIdFormat())
.nameIdFormat(getNameIdFormat()).authnRequestsSigned(isAuthnRequestsSigned())
.assertingPartyDetails((assertingParty) -> ((OpenSamlAssertingPartyDetails.Builder) assertingParty)
.entityId(party.getEntityId()).wantAuthnRequestsSigned(party.getWantAuthnRequestsSigned())
.signingAlgorithms((algorithms) -> algorithms.addAll(party.getSigningAlgorithms()))
Expand Down Expand Up @@ -152,6 +152,11 @@ public Builder nameIdFormat(String nameIdFormat) {
return (Builder) super.nameIdFormat(nameIdFormat);
}

@Override
public Builder authnRequestsSigned(Boolean authnRequestsSigned) {
return (Builder) super.authnRequestsSigned(authnRequestsSigned);
}

@Override
public Builder assertingPartyDetails(Consumer<AssertingPartyDetails.Builder> assertingPartyDetails) {
return (Builder) super.assertingPartyDetails(assertingPartyDetails);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ public class RelyingPartyRegistration {

private final String nameIdFormat;

private final boolean authnRequestsSigned;

private final AssertingPartyDetails assertingPartyDetails;

private final Collection<Saml2X509Credential> decryptionX509Credentials;
Expand All @@ -95,7 +97,7 @@ public class RelyingPartyRegistration {
protected RelyingPartyRegistration(String registrationId, String entityId, String assertionConsumerServiceLocation,
Saml2MessageBinding assertionConsumerServiceBinding, String singleLogoutServiceLocation,
String singleLogoutServiceResponseLocation, Collection<Saml2MessageBinding> singleLogoutServiceBindings,
AssertingPartyDetails assertingPartyDetails, String nameIdFormat,
AssertingPartyDetails assertingPartyDetails, String nameIdFormat, boolean authnRequestsSigned,
Collection<Saml2X509Credential> decryptionX509Credentials,
Collection<Saml2X509Credential> signingX509Credentials) {
Assert.hasText(registrationId, "registrationId cannot be empty");
Expand Down Expand Up @@ -124,6 +126,7 @@ protected RelyingPartyRegistration(String registrationId, String entityId, Strin
this.singleLogoutServiceResponseLocation = singleLogoutServiceResponseLocation;
this.singleLogoutServiceBindings = Collections.unmodifiableList(new LinkedList<>(singleLogoutServiceBindings));
this.nameIdFormat = nameIdFormat;
this.authnRequestsSigned = authnRequestsSigned;
this.assertingPartyDetails = assertingPartyDetails;
this.decryptionX509Credentials = Collections.unmodifiableList(new LinkedList<>(decryptionX509Credentials));
this.signingX509Credentials = Collections.unmodifiableList(new LinkedList<>(signingX509Credentials));
Expand All @@ -145,7 +148,7 @@ public Builder mutate() {
.singleLogoutServiceLocation(this.singleLogoutServiceLocation)
.singleLogoutServiceResponseLocation(this.singleLogoutServiceResponseLocation)
.singleLogoutServiceBindings((c) -> c.addAll(this.singleLogoutServiceBindings))
.nameIdFormat(this.nameIdFormat)
.nameIdFormat(this.nameIdFormat).authnRequestsSigned(this.authnRequestsSigned)
.assertingPartyDetails((assertingParty) -> assertingParty.entityId(party.getEntityId())
.wantAuthnRequestsSigned(party.getWantAuthnRequestsSigned())
.signingAlgorithms((algorithms) -> algorithms.addAll(party.getSigningAlgorithms()))
Expand Down Expand Up @@ -281,6 +284,23 @@ public String getNameIdFormat() {
return this.nameIdFormat;
}

/**
* Get the <a href=
* "https://docs.oasis-open.org/security/saml/v2.0/saml-metadata-2.0-os.pdf#page=18">
* AuthnRequestsSigned</a> setting. If {@code true}, the relying party will sign all
* AuthnRequests, regardless of asserting party preference.
*
* <p>
* Note that Spring Security will sign the request if either
* {@link #isAuthnRequestsSigned()} is {@code true} or
* {@link AssertingPartyDetails#getWantAuthnRequestsSigned()} is {@code true}.
* @return the relying-party preference
* @since 6.1
*/
public boolean isAuthnRequestsSigned() {
return this.authnRequestsSigned;
}

/**
* Get the {@link Collection} of decryption {@link Saml2X509Credential}s associated
* with this relying party
Expand Down Expand Up @@ -356,7 +376,7 @@ public static Builder withRelyingPartyRegistration(RelyingPartyRegistration regi
.singleLogoutServiceLocation(registration.getSingleLogoutServiceLocation())
.singleLogoutServiceResponseLocation(registration.getSingleLogoutServiceResponseLocation())
.singleLogoutServiceBindings((c) -> c.addAll(registration.getSingleLogoutServiceBindings()))
.nameIdFormat(registration.getNameIdFormat())
.nameIdFormat(registration.getNameIdFormat()).authnRequestsSigned(registration.isAuthnRequestsSigned())
.assertingPartyDetails((assertingParty) -> assertingParty
.entityId(registration.getAssertingPartyDetails().getEntityId())
.wantAuthnRequestsSigned(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned())
Expand Down Expand Up @@ -788,6 +808,8 @@ public static class Builder {

private String nameIdFormat = null;

private boolean authnRequestsSigned = false;

private AssertingPartyDetails.Builder assertingPartyDetailsBuilder;

protected Builder(String registrationId, AssertingPartyDetails.Builder assertingPartyDetailsBuilder) {
Expand Down Expand Up @@ -974,6 +996,24 @@ public Builder nameIdFormat(String nameIdFormat) {
return this;
}

/**
* Set the <a href=
* "https://docs.oasis-open.org/security/saml/v2.0/saml-metadata-2.0-os.pdf#page=18">
* AuthnRequestsSigned</a> setting. If {@code true}, the relying party will sign
* all AuthnRequests, 301 asserting party preference.
*
* <p>
* Note that Spring Security will sign the request if either
* {@link #isAuthnRequestsSigned()} is {@code true} or
* {@link AssertingPartyDetails#getWantAuthnRequestsSigned()} is {@code true}.
* @return the {@link Builder} for further configuration
* @since 6.1
*/
public Builder authnRequestsSigned(Boolean authnRequestsSigned) {
this.authnRequestsSigned = authnRequestsSigned;
return this;
}

/**
* Apply this {@link Consumer} to further configure the Asserting Party details
* @param assertingPartyDetails The {@link Consumer} to apply
Expand Down Expand Up @@ -1003,8 +1043,8 @@ public RelyingPartyRegistration build() {
return new RelyingPartyRegistration(this.registrationId, this.entityId,
this.assertionConsumerServiceLocation, this.assertionConsumerServiceBinding,
this.singleLogoutServiceLocation, this.singleLogoutServiceResponseLocation,
this.singleLogoutServiceBindings, party, this.nameIdFormat, this.decryptionX509Credentials,
this.signingX509Credentials);
this.singleLogoutServiceBindings, party, this.nameIdFormat, this.authnRequestsSigned,
this.decryptionX509Credentials, this.signingX509Credentials);
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -142,7 +142,8 @@ <T extends AbstractSaml2AuthenticationRequest> T resolve(HttpServletRequest requ
String relayState = this.relayStateResolver.convert(request);
Saml2MessageBinding binding = registration.getAssertingPartyDetails().getSingleSignOnServiceBinding();
if (binding == Saml2MessageBinding.POST) {
if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()
|| registration.isAuthnRequestsSigned()) {
OpenSamlSigningUtils.sign(authnRequest, registration);
}
String xml = serialize(authnRequest);
Expand All @@ -156,7 +157,8 @@ <T extends AbstractSaml2AuthenticationRequest> T resolve(HttpServletRequest requ
Saml2RedirectAuthenticationRequest.Builder builder = Saml2RedirectAuthenticationRequest
.withRelyingPartyRegistration(registration).samlRequest(deflatedAndEncoded).relayState(relayState)
.id(authnRequest.getID());
if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()
|| registration.isAuthnRequestsSigned()) {
Map<String, String> parameters = OpenSamlSigningUtils.sign(registration)
.param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded)
.param(Saml2ParameterNames.RELAY_STATE, relayState).parameters();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,7 +29,7 @@ public class RelyingPartyRegistrationTests {
@Test
public void withRelyingPartyRegistrationWorks() {
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration()
.nameIdFormat("format")
.nameIdFormat("format").authnRequestsSigned(true)
.assertingPartyDetails((a) -> a.singleSignOnServiceBinding(Saml2MessageBinding.POST))
.assertingPartyDetails((a) -> a.wantAuthnRequestsSigned(false))
.assertingPartyDetails((a) -> a.signingAlgorithms((algs) -> algs.add("alg")))
Expand Down Expand Up @@ -82,6 +82,7 @@ private void compareRegistrations(RelyingPartyRegistration registration, Relying
assertThat(copy.getAssertingPartyDetails().getSigningAlgorithms())
.isEqualTo(registration.getAssertingPartyDetails().getSigningAlgorithms());
assertThat(copy.getNameIdFormat()).isEqualTo(registration.getNameIdFormat());
assertThat(copy.isAuthnRequestsSigned()).isEqualTo(registration.isAuthnRequestsSigned());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,8 +16,13 @@

package org.springframework.security.saml2.provider.service.web.authentication;

import java.util.stream.Stream;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.opensaml.xmlsec.signature.support.SignatureConstants;

import org.springframework.mock.web.MockHttpServletRequest;
Expand Down Expand Up @@ -47,11 +52,15 @@ public void setUp() {
this.relyingPartyRegistrationBuilder = TestRelyingPartyRegistrations.relyingPartyRegistration();
}

@Test
public void resolveAuthenticationRequestWhenSignedRedirectThenSignsAndRedirects() {
@ParameterizedTest
@MethodSource("provideSignRequestFlags")
public void resolveAuthenticationRequestWhenSignedRedirectThenSignsAndRedirects(boolean wantAuthRequestsSigned,
boolean authnRequestsSigned) {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setPathInfo("/saml2/authenticate/registration-id");
RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder.build();
RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder
.authnRequestsSigned(authnRequestsSigned)
.assertingPartyDetails((party) -> party.wantAuthnRequestsSigned(wantAuthRequestsSigned)).build();
OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
Expand Down Expand Up @@ -115,7 +124,7 @@ public void resolveAuthenticationRequestWhenUnsignedPostThenOnlyPosts() {
request.setPathInfo("/saml2/authenticate/registration-id");
RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder.assertingPartyDetails(
(party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST).wantAuthnRequestsSigned(false))
.build();
.authnRequestsSigned(false).build();
OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
Saml2PostAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
Expand All @@ -134,12 +143,17 @@ public void resolveAuthenticationRequestWhenUnsignedPostThenOnlyPosts() {
assertThat(result.getId()).isNotEmpty();
}

@Test
public void resolveAuthenticationRequestWhenSignedPostThenSignsAndPosts() {
@ParameterizedTest
@MethodSource("provideSignRequestFlags")
public void resolveAuthenticationRequestWhenSignedPostThenSignsAndPosts(boolean wantAuthRequestsSigned,
boolean authnRequestsSigned) {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setPathInfo("/saml2/authenticate/registration-id");
RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder
.assertingPartyDetails((party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST)).build();
.authnRequestsSigned(authnRequestsSigned)
.assertingPartyDetails((party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST)
.wantAuthnRequestsSigned(wantAuthRequestsSigned))
.build();
OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
Saml2PostAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
Expand Down Expand Up @@ -180,4 +194,8 @@ private OpenSamlAuthenticationRequestResolver authenticationRequestResolver(Rely
return new OpenSamlAuthenticationRequestResolver((request, id) -> registration);
}

private static Stream<Arguments> provideSignRequestFlags() {
return Stream.of(Arguments.of(true, true), Arguments.of(true, false), Arguments.of(false, true));
}

}