diff --git a/docs/modules/ROOT/assets/images/servlet/architecture/filterchain.gif b/docs/modules/ROOT/assets/images/servlet/architecture/filterchain.gif index 7c82a4fe1ed..5ac6af1c853 100644 Binary files a/docs/modules/ROOT/assets/images/servlet/architecture/filterchain.gif and b/docs/modules/ROOT/assets/images/servlet/architecture/filterchain.gif differ diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java new file mode 100644 index 00000000000..86013f84467 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java @@ -0,0 +1,324 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication; + +import java.nio.charset.StandardCharsets; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.time.Clock; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import net.shibboleth.utilities.java.support.resolver.CriteriaSet; +import net.shibboleth.utilities.java.support.xml.SerializeSupport; +import org.joda.time.DateTime; +import org.opensaml.core.config.ConfigurationService; +import org.opensaml.core.xml.config.XMLObjectProviderRegistry; +import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.saml.common.SAMLObjectBuilder; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.saml2.core.AuthnRequest; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.NameIDPolicy; +import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder; +import org.opensaml.saml.saml2.core.impl.AuthnRequestMarshaller; +import org.opensaml.saml.saml2.core.impl.IssuerBuilder; +import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver; +import org.opensaml.security.SecurityException; +import org.opensaml.security.credential.BasicCredential; +import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.CredentialSupport; +import org.opensaml.security.credential.UsageType; +import org.opensaml.xmlsec.SignatureSigningParameters; +import org.opensaml.xmlsec.SignatureSigningParametersResolver; +import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion; +import org.opensaml.xmlsec.crypto.XMLSigningUtil; +import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration; +import org.opensaml.xmlsec.signature.support.SignatureConstants; +import org.opensaml.xmlsec.signature.support.SignatureSupport; +import org.w3c.dom.Element; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.OpenSamlInitializationService; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest.Builder; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; + +/** + * @since 5.2 + */ +public class OpenSamlAuthenticationRequestFactory implements Saml2AuthenticationRequestFactory { + + static { + OpenSamlInitializationService.initialize(); + } + + private Clock clock = Clock.systemUTC(); + + private AuthnRequestMarshaller marshaller; + + private AuthnRequestBuilder authnRequestBuilder; + + private IssuerBuilder issuerBuilder; + + private SAMLObjectBuilder nameIDBuilder; + + private Converter protocolBindingResolver = (context) -> { + if (context == null) { + return SAMLConstants.SAML2_POST_BINDING_URI; + } + return context.getRelyingPartyRegistration().getAssertionConsumerServiceBinding().getUrn(); + }; + + private Converter authenticationRequestContextConverter = this::createAuthnRequest; + + /** + * Creates an {@link OpenSamlAuthenticationRequestFactory} + */ + public OpenSamlAuthenticationRequestFactory() { + XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class); + this.marshaller = (AuthnRequestMarshaller) registry.getMarshallerFactory() + .getMarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME); + this.authnRequestBuilder = (AuthnRequestBuilder) registry.getBuilderFactory() + .getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME); + this.issuerBuilder = (IssuerBuilder) registry.getBuilderFactory().getBuilder(Issuer.DEFAULT_ELEMENT_NAME); + this.nameIDBuilder = (SAMLObjectBuilder) registry.getBuilderFactory() + .getBuilder(NameIDPolicy.DEFAULT_ELEMENT_NAME); + } + + @Override + @Deprecated + public String createAuthenticationRequest(Saml2AuthenticationRequest request) { + AuthnRequest authnRequest = createAuthnRequest(request.getIssuer(), request.getDestination(), + request.getAssertionConsumerServiceUrl(), this.protocolBindingResolver.convert(null), null); + for (org.springframework.security.saml2.credentials.Saml2X509Credential credential : request.getCredentials()) { + if (credential.isSigningCredential()) { + X509Certificate certificate = credential.getCertificate(); + PrivateKey privateKey = credential.getPrivateKey(); + BasicCredential cred = CredentialSupport.getSimpleCredential(certificate, privateKey); + cred.setEntityId(request.getIssuer()); + cred.setUsageType(UsageType.SIGNING); + SignatureSigningParameters parameters = new SignatureSigningParameters(); + parameters.setSigningCredential(cred); + parameters.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + parameters.setSignatureReferenceDigestMethod(SignatureConstants.ALGO_ID_DIGEST_SHA256); + parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS); + return serialize(sign(authnRequest, parameters)); + } + } + throw new IllegalArgumentException("No signing credential provided"); + } + + @Override + public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) { + AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context); + String xml = context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned() + ? serialize(sign(authnRequest, context.getRelyingPartyRegistration())) : serialize(authnRequest); + + return Saml2PostAuthenticationRequest.withAuthenticationRequestContext(context) + .samlRequest(Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8))).build(); + } + + @Override + public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest( + Saml2AuthenticationRequestContext context) { + AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context); + String xml = serialize(authnRequest); + Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context); + String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml)); + result.samlRequest(deflatedAndEncoded).relayState(context.getRelayState()); + if (context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned()) { + Map parameters = new LinkedHashMap<>(); + parameters.put("SAMLRequest", deflatedAndEncoded); + if (StringUtils.hasText(context.getRelayState())) { + parameters.put("RelayState", context.getRelayState()); + } + sign(parameters, context.getRelyingPartyRegistration()); + return result.sigAlg(parameters.get("SigAlg")).signature(parameters.get("Signature")).build(); + } + return result.build(); + } + + private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) { + return createAuthnRequest(context.getIssuer(), context.getDestination(), + context.getAssertionConsumerServiceUrl(), this.protocolBindingResolver.convert(context), + context.getRelyingPartyRegistration().getNameIdFormat()); + } + + private AuthnRequest createAuthnRequest(String issuer, String destination, String assertionConsumerServiceUrl, + String protocolBinding, String nameIDFormat) { + AuthnRequest auth = this.authnRequestBuilder.buildObject(); + auth.setID("ARQ" + UUID.randomUUID().toString().substring(1)); + auth.setIssueInstant(new DateTime(this.clock.millis())); + auth.setForceAuthn(Boolean.FALSE); + auth.setIsPassive(Boolean.FALSE); + auth.setProtocolBinding(protocolBinding); + Issuer iss = this.issuerBuilder.buildObject(); + iss.setValue(issuer); + auth.setIssuer(iss); + auth.setDestination(destination); + auth.setAssertionConsumerServiceURL(assertionConsumerServiceUrl); + + if (nameIDFormat != null) { + NameIDPolicy nameId = this.nameIDBuilder.buildObject(); + nameId.setFormat(nameIDFormat); + auth.setNameIDPolicy(nameId); + } + return auth; + } + + /** + * Set the {@link AuthnRequest} post-processor resolver + * @param authenticationRequestContextConverter + * @since 5.4 + */ + public void setAuthenticationRequestContextConverter( + Converter authenticationRequestContextConverter) { + Assert.notNull(authenticationRequestContextConverter, "authenticationRequestContextConverter cannot be null"); + this.authenticationRequestContextConverter = authenticationRequestContextConverter; + } + + /** + * ' Use this {@link Clock} with {@link Instant#now()} for generating timestamps + * @param clock + */ + public void setClock(Clock clock) { + Assert.notNull(clock, "clock cannot be null"); + this.clock = clock; + } + + /** + * Sets the {@code protocolBinding} to use when generating authentication requests. + * Acceptable values are {@link SAMLConstants#SAML2_POST_BINDING_URI} and + * {@link SAMLConstants#SAML2_REDIRECT_BINDING_URI} The IDP will be reading this value + * in the {@code AuthNRequest} to determine how to send the Response/Assertion to the + * ACS URL, assertion consumer service URL. + * @param protocolBinding either {@link SAMLConstants#SAML2_POST_BINDING_URI} or + * {@link SAMLConstants#SAML2_REDIRECT_BINDING_URI} + * @throws IllegalArgumentException if the protocolBinding is not valid + * @deprecated Use + * {@link org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.Builder#assertionConsumerServiceBinding(Saml2MessageBinding)} + * instead + */ + @Deprecated + public void setProtocolBinding(String protocolBinding) { + boolean isAllowedBinding = SAMLConstants.SAML2_POST_BINDING_URI.equals(protocolBinding) + || SAMLConstants.SAML2_REDIRECT_BINDING_URI.equals(protocolBinding); + if (!isAllowedBinding) { + throw new IllegalArgumentException("Invalid protocol binding: " + protocolBinding); + } + this.protocolBindingResolver = (context) -> protocolBinding; + } + + private AuthnRequest sign(AuthnRequest authnRequest, RelyingPartyRegistration relyingPartyRegistration) { + SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration); + return sign(authnRequest, parameters); + } + + private AuthnRequest sign(AuthnRequest authnRequest, SignatureSigningParameters parameters) { + try { + SignatureSupport.signObject(authnRequest, parameters); + return authnRequest; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + private void sign(Map components, RelyingPartyRegistration relyingPartyRegistration) { + SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration); + sign(components, parameters); + } + + private void sign(Map components, SignatureSigningParameters parameters) { + Credential credential = parameters.getSigningCredential(); + String algorithmUri = parameters.getSignatureAlgorithm(); + components.put("SigAlg", algorithmUri); + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + for (Map.Entry component : components.entrySet()) { + builder.queryParam(component.getKey(), UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1)); + } + String queryString = builder.build(true).toString().substring(1); + try { + byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, + queryString.getBytes(StandardCharsets.UTF_8)); + String b64Signature = Saml2Utils.samlEncode(rawSignature); + components.put("Signature", b64Signature); + } + catch (SecurityException ex) { + throw new Saml2Exception(ex); + } + } + + private String serialize(AuthnRequest authnRequest) { + try { + Element element = this.marshaller.marshall(authnRequest); + return SerializeSupport.nodeToString(element); + } + catch (MarshallingException ex) { + throw new Saml2Exception(ex); + } + } + + private SignatureSigningParameters resolveSigningParameters(RelyingPartyRegistration relyingPartyRegistration) { + List credentials = resolveSigningCredentials(relyingPartyRegistration); + List algorithms = relyingPartyRegistration.getAssertingPartyDetails().getSigningAlgorithms(); + List digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256); + String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS; + SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver(); + CriteriaSet criteria = new CriteriaSet(); + BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration(); + signingConfiguration.setSigningCredentials(credentials); + signingConfiguration.setSignatureAlgorithms(algorithms); + signingConfiguration.setSignatureReferenceDigestMethods(digests); + signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization); + criteria.add(new SignatureSigningConfigurationCriterion(signingConfiguration)); + try { + SignatureSigningParameters parameters = resolver.resolveSingle(criteria); + Assert.notNull(parameters, "Failed to resolve any signing credential"); + return parameters; + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + + private List resolveSigningCredentials(RelyingPartyRegistration relyingPartyRegistration) { + List credentials = new ArrayList<>(); + for (Saml2X509Credential x509Credential : relyingPartyRegistration.getSigningX509Credentials()) { + X509Certificate certificate = x509Credential.getCertificate(); + PrivateKey privateKey = x509Credential.getPrivateKey(); + BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey); + credential.setEntityId(relyingPartyRegistration.getEntityId()); + credential.setUsageType(UsageType.SIGNING); + credentials.add(credential); + } + return credentials; + } + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java index 1f0d5c19af1..4a18e8cdd6f 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java @@ -31,6 +31,7 @@ import org.opensaml.saml.saml2.metadata.AssertionConsumerService; import org.opensaml.saml.saml2.metadata.EntityDescriptor; import org.opensaml.saml.saml2.metadata.KeyDescriptor; +import org.opensaml.saml.saml2.metadata.NameIDFormat; import org.opensaml.saml.saml2.metadata.SPSSODescriptor; import org.opensaml.saml.saml2.metadata.SingleLogoutService; import org.opensaml.saml.saml2.metadata.impl.EntityDescriptorMarshaller; @@ -87,6 +88,9 @@ private SPSSODescriptor buildSpSsoDescriptor(RelyingPartyRegistration registrati .addAll(buildKeys(registration.getDecryptionX509Credentials(), UsageType.ENCRYPTION)); spSsoDescriptor.getAssertionConsumerServices().add(buildAssertionConsumerService(registration)); spSsoDescriptor.getSingleLogoutServices().add(buildSingleLogoutService(registration)); + if (registration.getNameIdFormat() != null) { + spSsoDescriptor.getNameIDFormats().add(buildNameIDFormat(registration)); + } return spSsoDescriptor; } @@ -133,6 +137,12 @@ private SingleLogoutService buildSingleLogoutService(RelyingPartyRegistration re return singleLogoutService; } + private NameIDFormat buildNameIDFormat(RelyingPartyRegistration registration) { + NameIDFormat nameIdFormat = build(NameIDFormat.DEFAULT_ELEMENT_NAME); + nameIdFormat.setFormat(registration.getNameIdFormat()); + return nameIdFormat; + } + @SuppressWarnings("unchecked") private T build(QName elementName) { XMLObjectBuilder builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName); diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java index d07a3664f8a..43e61b11e18 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java @@ -87,6 +87,8 @@ public final class RelyingPartyRegistration { private final Saml2MessageBinding singleLogoutServiceBinding; + private final String nameIdFormat; + private final ProviderDetails providerDetails; private final List credentials; @@ -98,7 +100,7 @@ public final class RelyingPartyRegistration { private RelyingPartyRegistration(String registrationId, String entityId, String assertionConsumerServiceLocation, Saml2MessageBinding assertionConsumerServiceBinding, String singleLogoutServiceLocation, String singleLogoutServiceResponseLocation, Saml2MessageBinding singleLogoutServiceBinding, - ProviderDetails providerDetails, + ProviderDetails providerDetails, String nameIdFormat, Collection credentials, Collection decryptionX509Credentials, Collection signingX509Credentials) { @@ -129,6 +131,7 @@ private RelyingPartyRegistration(String registrationId, String entityId, String this.singleLogoutServiceLocation = singleLogoutServiceLocation; this.singleLogoutServiceResponseLocation = singleLogoutServiceResponseLocation; this.singleLogoutServiceBinding = singleLogoutServiceBinding; + this.nameIdFormat = nameIdFormat; this.providerDetails = providerDetails; this.credentials = Collections.unmodifiableList(new LinkedList<>(credentials)); this.decryptionX509Credentials = Collections.unmodifiableList(new LinkedList<>(decryptionX509Credentials)); @@ -234,6 +237,15 @@ public String getSingleLogoutServiceResponseLocation() { return this.singleLogoutServiceResponseLocation; } + /** + * Get the NameID format. + * @return the NameID format + * @since 5.7 + */ + public String getNameIdFormat() { + return this.nameIdFormat; + } + /** * Get the {@link Collection} of decryption {@link Saml2X509Credential}s associated * with this relying party @@ -424,6 +436,7 @@ public static Builder withRelyingPartyRegistration(RelyingPartyRegistration regi .singleLogoutServiceLocation(registration.getSingleLogoutServiceLocation()) .singleLogoutServiceResponseLocation(registration.getSingleLogoutServiceResponseLocation()) .singleLogoutServiceBinding(registration.getSingleLogoutServiceBinding()) + .nameIdFormat(registration.getNameIdFormat()) .assertingPartyDetails((assertingParty) -> assertingParty .entityId(registration.getAssertingPartyDetails().getEntityId()) .wantAuthnRequestsSigned(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) @@ -1018,6 +1031,8 @@ public static final class Builder { private Saml2MessageBinding singleLogoutServiceBinding = Saml2MessageBinding.POST; + private String nameIdFormat = null; + private ProviderDetails.Builder providerDetails = new ProviderDetails.Builder(); private Collection credentials = new HashSet<>(); @@ -1173,6 +1188,17 @@ public Builder singleLogoutServiceResponseLocation(String singleLogoutServiceRes return this; } + /** + * Set the NameID format + * @param nameIdFormat + * @return the {@link Builder} for further configuration + * @since 5.7 + */ + public Builder nameIdFormat(String nameIdFormat) { + this.nameIdFormat = nameIdFormat; + return this; + } + /** * Apply this {@link Consumer} to further configure the Asserting Party details * @param assertingPartyDetails The {@link Consumer} to apply @@ -1321,7 +1347,7 @@ public RelyingPartyRegistration build() { return new RelyingPartyRegistration(this.registrationId, this.entityId, this.assertionConsumerServiceLocation, this.assertionConsumerServiceBinding, this.singleLogoutServiceLocation, this.singleLogoutServiceResponseLocation, - this.singleLogoutServiceBinding, this.providerDetails.build(), this.credentials, + this.singleLogoutServiceBinding, this.providerDetails.build(), this.nameIdFormat, this.credentials, this.decryptionX509Credentials, this.signingX509Credentials); } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java new file mode 100644 index 00000000000..9a66653b67d --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java @@ -0,0 +1,294 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.authentication; + +import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.saml2.core.AuthnRequest; +import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller; +import org.opensaml.xmlsec.signature.support.SignatureConstants; +import org.w3c.dom.Document; +import org.w3c.dom.Element; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.security.saml2.credentials.TestSaml2X509Credentials; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link OpenSamlAuthenticationRequestFactory} + */ +public class OpenSamlAuthenticationRequestFactoryTests { + + private OpenSamlAuthenticationRequestFactory factory; + + private Saml2AuthenticationRequestContext.Builder contextBuilder; + + private Saml2AuthenticationRequestContext context; + + private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder; + + private RelyingPartyRegistration relyingPartyRegistration; + + private AuthnRequestUnmarshaller unmarshaller; + + @Before + public void setUp() { + this.relyingPartyRegistrationBuilder = RelyingPartyRegistration.withRegistrationId("id") + .assertionConsumerServiceLocation("template") + .providerDetails((c) -> c.webSsoUrl("https://destination/sso")) + .providerDetails((c) -> c.entityId("remote-entity-id")).localEntityIdTemplate("local-entity-id") + .credentials((c) -> c.add(TestSaml2X509Credentials.relyingPartySigningCredential())); + this.relyingPartyRegistration = this.relyingPartyRegistrationBuilder.build(); + this.contextBuilder = Saml2AuthenticationRequestContext.builder().issuer("https://issuer") + .relyingPartyRegistration(this.relyingPartyRegistration) + .assertionConsumerServiceUrl("https://issuer/sso"); + this.context = this.contextBuilder.build(); + this.factory = new OpenSamlAuthenticationRequestFactory(); + this.unmarshaller = (AuthnRequestUnmarshaller) XMLObjectProviderRegistrySupport.getUnmarshallerFactory() + .getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME); + } + + @Test + public void createAuthenticationRequestWhenInvokingDeprecatedMethodThenReturnsXML() { + Saml2AuthenticationRequest request = Saml2AuthenticationRequest.withAuthenticationRequestContext(this.context) + .build(); + String result = this.factory.createAuthenticationRequest(request); + assertThat(result.replace("\n", "")) + .startsWith(" c.signAuthNRequest(false)).build()) + .build(); + Saml2RedirectAuthenticationRequest result = this.factory.createRedirectAuthenticationRequest(this.context); + assertThat(result.getSamlRequest()).isNotEmpty(); + assertThat(result.getRelayState()).isEqualTo("Relay State Value"); + assertThat(result.getSigAlg()).isNull(); + assertThat(result.getSignature()).isNull(); + assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); + } + + @Test + public void createRedirectAuthenticationRequestWhenSignRequestThenSignatureIsPresent() { + this.context = this.contextBuilder.relayState("Relay State Value") + .relyingPartyRegistration(this.relyingPartyRegistration).build(); + Saml2RedirectAuthenticationRequest request = this.factory.createRedirectAuthenticationRequest(this.context); + assertThat(request.getRelayState()).isEqualTo("Relay State Value"); + assertThat(request.getSigAlg()).isEqualTo(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + assertThat(request.getSignature()).isNotNull(); + } + + @Test + public void createRedirectAuthenticationRequestWhenSignRequestThenCredentialIsRequired() { + Saml2X509Credential credential = org.springframework.security.saml2.core.TestSaml2X509Credentials + .relyingPartyVerifyingCredential(); + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.noCredentials() + .assertingPartyDetails((party) -> party.verificationX509Credentials((c) -> c.add(credential))).build(); + this.context = this.contextBuilder.relayState("Relay State Value").relyingPartyRegistration(registration) + .build(); + assertThatExceptionOfType(Saml2Exception.class) + .isThrownBy(() -> this.factory.createPostAuthenticationRequest(this.context)); + } + + @Test + public void createPostAuthenticationRequestWhenNotSignRequestThenNoSignatureIsPresent() { + this.context = this.contextBuilder.relayState("Relay State Value") + .relyingPartyRegistration( + RelyingPartyRegistration.withRelyingPartyRegistration(this.relyingPartyRegistration) + .providerDetails((c) -> c.signAuthNRequest(false)).build()) + .build(); + Saml2PostAuthenticationRequest result = this.factory.createPostAuthenticationRequest(this.context); + assertThat(result.getSamlRequest()).isNotEmpty(); + assertThat(result.getRelayState()).isEqualTo("Relay State Value"); + assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.POST); + assertThat(new String(Saml2Utils.samlDecode(result.getSamlRequest()), StandardCharsets.UTF_8)) + .doesNotContain("ds:Signature"); + } + + @Test + public void createPostAuthenticationRequestWhenSignRequestThenSignatureIsPresent() { + this.context = this.contextBuilder.relayState("Relay State Value") + .relyingPartyRegistration( + RelyingPartyRegistration.withRelyingPartyRegistration(this.relyingPartyRegistration).build()) + .build(); + Saml2PostAuthenticationRequest result = this.factory.createPostAuthenticationRequest(this.context); + assertThat(result.getSamlRequest()).isNotEmpty(); + assertThat(result.getRelayState()).isEqualTo("Relay State Value"); + assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.POST); + assertThat(new String(Saml2Utils.samlDecode(result.getSamlRequest()), StandardCharsets.UTF_8)) + .contains("ds:Signature"); + } + + @Test + public void createPostAuthenticationRequestWhenSignRequestThenCredentialIsRequired() { + Saml2X509Credential credential = org.springframework.security.saml2.core.TestSaml2X509Credentials + .relyingPartyVerifyingCredential(); + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.noCredentials() + .assertingPartyDetails((party) -> party.verificationX509Credentials((c) -> c.add(credential))).build(); + this.context = this.contextBuilder.relayState("Relay State Value").relyingPartyRegistration(registration) + .build(); + assertThatExceptionOfType(Saml2Exception.class) + .isThrownBy(() -> this.factory.createPostAuthenticationRequest(this.context)); + } + + @Test + public void createAuthenticationRequestWhenDefaultThenReturnsPostBinding() { + AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST); + Assert.assertEquals(SAMLConstants.SAML2_POST_BINDING_URI, authn.getProtocolBinding()); + } + + @Test + public void createAuthenticationRequestWhenSetUriThenReturnsCorrectBinding() { + this.factory.setProtocolBinding(SAMLConstants.SAML2_REDIRECT_BINDING_URI); + AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST); + Assert.assertEquals(SAMLConstants.SAML2_REDIRECT_BINDING_URI, authn.getProtocolBinding()); + } + + @Test + public void createAuthenticationRequestWhenSetNameIDPolicyThenReturnsCorrectNameIDPolicy() { + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().nameIdFormat("format").build(); + this.context = this.contextBuilder.relayState("Relay State Value").relyingPartyRegistration(registration) + .build(); + AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST); + assertThat(authn.getNameIDPolicy()).isNotNull(); + assertThat(authn.getNameIDPolicy().getAllowCreate()).isFalse(); + assertThat(authn.getNameIDPolicy().getFormat()).isEqualTo("format"); + assertThat(authn.getNameIDPolicy().getSPNameQualifier()).isNull(); + } + + @Test + public void createAuthenticationRequestWhenSetUnsupportredUriThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.factory.setProtocolBinding("my-invalid-binding")) + .withMessageContaining("my-invalid-binding"); + } + + @Test + public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() { + Converter authenticationRequestContextConverter = mock( + Converter.class); + given(authenticationRequestContextConverter.convert(this.context)) + .willReturn(TestOpenSamlObjects.authnRequest()); + this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter); + + this.factory.createPostAuthenticationRequest(this.context); + verify(authenticationRequestContextConverter).convert(this.context); + } + + @Test + public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() { + Converter authenticationRequestContextConverter = mock( + Converter.class); + given(authenticationRequestContextConverter.convert(this.context)) + .willReturn(TestOpenSamlObjects.authnRequest()); + this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter); + + this.factory.createRedirectAuthenticationRequest(this.context); + verify(authenticationRequestContextConverter).convert(this.context); + } + + @Test + public void setAuthenticationRequestContextConverterWhenNullThenException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.factory.setAuthenticationRequestContextConverter(null)); + // @formatter:on + } + + @Test + public void createPostAuthenticationRequestWhenAssertionConsumerServiceBindingThenUses() { + RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationBuilder + .assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT).build(); + Saml2AuthenticationRequestContext context = this.contextBuilder + .relyingPartyRegistration(relyingPartyRegistration).build(); + Saml2PostAuthenticationRequest request = this.factory.createPostAuthenticationRequest(context); + String samlRequest = request.getSamlRequest(); + String inflated = new String(Saml2Utils.samlDecode(samlRequest)); + assertThat(inflated).contains("ProtocolBinding=\"" + SAMLConstants.SAML2_REDIRECT_BINDING_URI + "\""); + } + + @Test + public void createRedirectAuthenticationRequestWhenSHA1SignRequestThenSignatureIsPresent() { + RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationBuilder + .assertingPartyDetails( + (a) -> a.signingAlgorithms((algs) -> algs.add(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA1))) + .build(); + Saml2AuthenticationRequestContext context = this.contextBuilder.relayState("Relay State Value") + .relyingPartyRegistration(relyingPartyRegistration).build(); + Saml2RedirectAuthenticationRequest result = this.factory.createRedirectAuthenticationRequest(context); + assertThat(result.getSamlRequest()).isNotEmpty(); + assertThat(result.getRelayState()).isEqualTo("Relay State Value"); + assertThat(result.getSigAlg()).isEqualTo(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA1); + assertThat(result.getSignature()).isNotNull(); + assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); + } + + private AuthnRequest getAuthNRequest(Saml2MessageBinding binding) { + AbstractSaml2AuthenticationRequest result = (binding == Saml2MessageBinding.REDIRECT) + ? this.factory.createRedirectAuthenticationRequest(this.context) + : this.factory.createPostAuthenticationRequest(this.context); + String samlRequest = result.getSamlRequest(); + assertThat(samlRequest).isNotEmpty(); + if (result.getBinding() == Saml2MessageBinding.REDIRECT) { + samlRequest = Saml2Utils.samlInflate(Saml2Utils.samlDecode(samlRequest)); + } + else { + samlRequest = new String(Saml2Utils.samlDecode(samlRequest), StandardCharsets.UTF_8); + } + try { + Document document = XMLObjectProviderRegistrySupport.getParserPool() + .parse(new ByteArrayInputStream(samlRequest.getBytes(StandardCharsets.UTF_8))); + Element element = document.getDocumentElement(); + return (AuthnRequest) this.unmarshaller.unmarshall(element); + } + catch (Exception ex) { + throw new Saml2Exception(ex); + } + } + +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java index d42fc875be6..1dfaa3b854d 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java @@ -61,4 +61,13 @@ public void resolveWhenRelyingPartyNoCredentialsThenMetadataMatches() { .contains("ResponseLocation=\"https://rp.example.org/logout/saml2/response\""); } + @Test + public void resolveWhenRelyingPartyNameIDFormatThenMetadataMatches() { + RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.full().nameIdFormat("format") + .build(); + OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver(); + String metadata = openSamlMetadataResolver.resolve(relyingPartyRegistration); + assertThat(metadata).contains("format"); + } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java index d25d4b981c0..63e9d58505a 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java @@ -28,6 +28,7 @@ public class RelyingPartyRegistrationTests { @Test public void withRelyingPartyRegistrationWorks() { RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration() + .nameIdFormat("format") .assertingPartyDetails((a) -> a.singleSignOnServiceBinding(Saml2MessageBinding.POST)) .assertingPartyDetails((a) -> a.wantAuthnRequestsSigned(false)) .assertingPartyDetails((a) -> a.signingAlgorithms((algs) -> algs.add("alg"))) @@ -74,6 +75,7 @@ private void compareRegistrations(RelyingPartyRegistration registration, Relying .isEqualTo(registration.getAssertingPartyDetails().getVerificationX509Credentials()); assertThat(copy.getAssertingPartyDetails().getSigningAlgorithms()) .isEqualTo(registration.getAssertingPartyDetails().getSigningAlgorithms()); + assertThat(copy.getNameIdFormat()).isEqualTo(registration.getNameIdFormat()); } @Test