1919import java .util .LinkedHashMap ;
2020import java .util .Map ;
2121
22- import jakarta .servlet .Filter ;
23-
2422import org .opensaml .core .Version ;
2523
2624import org .springframework .beans .factory .NoSuchBeanDefinitionException ;
5048import org .springframework .security .saml2 .provider .service .web .Saml2AuthenticationRequestContextResolver ;
5149import org .springframework .security .saml2 .provider .service .web .Saml2AuthenticationRequestRepository ;
5250import org .springframework .security .saml2 .provider .service .web .Saml2AuthenticationTokenConverter ;
51+ import org .springframework .security .saml2 .provider .service .web .authentication .Saml2AuthenticationRequestResolver ;
5352import org .springframework .security .web .authentication .AuthenticationConverter ;
5453import org .springframework .security .web .authentication .LoginUrlAuthenticationEntryPoint ;
5554import org .springframework .security .web .authentication .ui .DefaultLoginPageGeneratingFilter ;
@@ -115,9 +114,11 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
115114
116115 private String loginPage ;
117116
118- private String loginProcessingUrl = Saml2WebSsoAuthenticationFilter .DEFAULT_FILTER_PROCESSES_URI ;
117+ private String authenticationRequestUri = "/saml2/authenticate/{registrationId}" ;
118+
119+ private Saml2AuthenticationRequestResolver authenticationRequestResolver ;
119120
120- private AuthenticationRequestEndpointConfig authenticationRequestEndpoint = new AuthenticationRequestEndpointConfig () ;
121+ private String loginProcessingUrl = Saml2WebSsoAuthenticationFilter . DEFAULT_FILTER_PROCESSES_URI ;
121122
122123 private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository ;
123124
@@ -176,6 +177,20 @@ public Saml2LoginConfigurer<B> loginPage(String loginPage) {
176177 return this ;
177178 }
178179
180+ /**
181+ * Use this {@link Saml2AuthenticationRequestResolver} for generating SAML 2.0
182+ * Authentication Requests.
183+ * @param authenticationRequestResolver
184+ * @return the {@link Saml2LoginConfigurer} for further configuration
185+ * @since 5.7
186+ */
187+ public Saml2LoginConfigurer <B > authenticationRequestResolver (
188+ Saml2AuthenticationRequestResolver authenticationRequestResolver ) {
189+ Assert .notNull (authenticationRequestResolver , "authenticationRequestResolver cannot be null" );
190+ this .authenticationRequestResolver = authenticationRequestResolver ;
191+ return this ;
192+ }
193+
179194 /**
180195 * Specifies the URL to validate the credentials. If specified a custom URL, consider
181196 * specifying a custom {@link AuthenticationConverter} via
@@ -200,7 +215,7 @@ protected RequestMatcher createLoginProcessingUrlMatcher(String loginProcessingU
200215
201216 /**
202217 * {@inheritDoc}
203- *
218+ * <p>
204219 * Initializes this filter chain for SAML 2 Login. The following actions are taken:
205220 * <ul>
206221 * <li>The WebSSO endpoint has CSRF disabled, typically {@code /login/saml2/sso}</li>
@@ -226,8 +241,8 @@ public void init(B http) throws Exception {
226241 super .init (http );
227242 }
228243 else {
229- Map <String , String > providerUrlMap = getIdentityProviderUrlMap (
230- this .authenticationRequestEndpoint . filterProcessingUrl , this . relyingPartyRegistrationRepository );
244+ Map <String , String > providerUrlMap = getIdentityProviderUrlMap (this . authenticationRequestUri ,
245+ this .relyingPartyRegistrationRepository );
231246 boolean singleProvider = providerUrlMap .size () == 1 ;
232247 if (singleProvider ) {
233248 // Setup auto-redirect to provider login page
@@ -247,14 +262,16 @@ public void init(B http) throws Exception {
247262
248263 /**
249264 * {@inheritDoc}
250- *
265+ * <p>
251266 * During the {@code configure} phase, a
252267 * {@link Saml2WebSsoAuthenticationRequestFilter} is added to handle SAML 2.0
253268 * AuthNRequest redirects
254269 */
255270 @ Override
256271 public void configure (B http ) throws Exception {
257- http .addFilter (this .authenticationRequestEndpoint .build (http ));
272+ Saml2WebSsoAuthenticationRequestFilter filter = getAuthenticationRequestFilter (http );
273+ filter .setAuthenticationRequestRepository (getAuthenticationRequestRepository (http ));
274+ http .addFilter (postProcess (filter ));
258275 super .configure (http );
259276 if (this .authenticationManager == null ) {
260277 registerDefaultAuthenticationProvider (http );
@@ -264,6 +281,11 @@ public void configure(B http) throws Exception {
264281 }
265282 }
266283
284+ private RelyingPartyRegistrationResolver relyingPartyRegistrationResolver (B http ) {
285+ RelyingPartyRegistrationRepository registrations = relyingPartyRegistrationRepository (http );
286+ return new DefaultRelyingPartyRegistrationResolver (registrations );
287+ }
288+
267289 RelyingPartyRegistrationRepository relyingPartyRegistrationRepository (B http ) {
268290 if (this .relyingPartyRegistrationRepository == null ) {
269291 this .relyingPartyRegistrationRepository = getSharedOrBean (http , RelyingPartyRegistrationRepository .class );
@@ -276,6 +298,46 @@ private void setAuthenticationRequestRepository(B http,
276298 saml2WebSsoAuthenticationFilter .setAuthenticationRequestRepository (getAuthenticationRequestRepository (http ));
277299 }
278300
301+ private Saml2WebSsoAuthenticationRequestFilter getAuthenticationRequestFilter (B http ) {
302+ Saml2AuthenticationRequestResolver authenticationRequestResolver = getAuthenticationRequestResolver (http );
303+ if (authenticationRequestResolver != null ) {
304+ return new Saml2WebSsoAuthenticationRequestFilter (authenticationRequestResolver );
305+ }
306+ return new Saml2WebSsoAuthenticationRequestFilter (getAuthenticationRequestContextResolver (http ),
307+ getAuthenticationRequestFactory (http ));
308+ }
309+
310+ private Saml2AuthenticationRequestResolver getAuthenticationRequestResolver (B http ) {
311+ if (this .authenticationRequestResolver != null ) {
312+ return this .authenticationRequestResolver ;
313+ }
314+ return getBeanOrNull (http , Saml2AuthenticationRequestResolver .class );
315+ }
316+
317+ private Saml2AuthenticationRequestFactory getAuthenticationRequestFactory (B http ) {
318+ Saml2AuthenticationRequestFactory resolver = getSharedOrBean (http , Saml2AuthenticationRequestFactory .class );
319+ if (resolver != null ) {
320+ return resolver ;
321+ }
322+ if (version ().startsWith ("4" )) {
323+ return new OpenSaml4AuthenticationRequestFactory ();
324+ }
325+ else {
326+ return new OpenSamlAuthenticationRequestFactory ();
327+ }
328+ }
329+
330+ private Saml2AuthenticationRequestContextResolver getAuthenticationRequestContextResolver (B http ) {
331+ Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull (http ,
332+ Saml2AuthenticationRequestContextResolver .class );
333+ if (resolver != null ) {
334+ return resolver ;
335+ }
336+ RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver (
337+ this .relyingPartyRegistrationRepository );
338+ return new DefaultSaml2AuthenticationRequestContextResolver (registrationResolver );
339+ }
340+
279341 private AuthenticationConverter getAuthenticationConverter (B http ) {
280342 if (this .authenticationConverter != null ) {
281343 return this .authenticationConverter ;
@@ -325,8 +387,8 @@ private void initDefaultLoginFilter(B http) {
325387 return ;
326388 }
327389 loginPageGeneratingFilter .setSaml2LoginEnabled (true );
328- loginPageGeneratingFilter .setSaml2AuthenticationUrlToProviderName (this . getIdentityProviderUrlMap (
329- this .authenticationRequestEndpoint . filterProcessingUrl , this .relyingPartyRegistrationRepository ));
390+ loginPageGeneratingFilter .setSaml2AuthenticationUrlToProviderName (
391+ this .getIdentityProviderUrlMap ( this . authenticationRequestUri , this .relyingPartyRegistrationRepository ));
330392 loginPageGeneratingFilter .setLoginPageUrl (this .getLoginPage ());
331393 loginPageGeneratingFilter .setFailureUrl (this .getFailureUrl ());
332394 }
@@ -380,46 +442,4 @@ private <C> void setSharedObject(B http, Class<C> clazz, C object) {
380442 }
381443 }
382444
383- private final class AuthenticationRequestEndpointConfig {
384-
385- private String filterProcessingUrl = "/saml2/authenticate/{registrationId}" ;
386-
387- private AuthenticationRequestEndpointConfig () {
388- }
389-
390- private Filter build (B http ) {
391- Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver (http );
392- Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver (http );
393- Saml2AuthenticationRequestRepository <AbstractSaml2AuthenticationRequest > repository = getAuthenticationRequestRepository (
394- http );
395- Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter (contextResolver ,
396- authenticationRequestResolver );
397- filter .setAuthenticationRequestRepository (repository );
398- return postProcess (filter );
399- }
400-
401- private Saml2AuthenticationRequestFactory getResolver (B http ) {
402- Saml2AuthenticationRequestFactory resolver = getSharedOrBean (http , Saml2AuthenticationRequestFactory .class );
403- if (resolver == null ) {
404- if (version ().startsWith ("4" )) {
405- return new OpenSaml4AuthenticationRequestFactory ();
406- }
407- return new OpenSamlAuthenticationRequestFactory ();
408- }
409- return resolver ;
410- }
411-
412- private Saml2AuthenticationRequestContextResolver getContextResolver (B http ) {
413- Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull (http ,
414- Saml2AuthenticationRequestContextResolver .class );
415- if (resolver == null ) {
416- RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver (
417- Saml2LoginConfigurer .this .relyingPartyRegistrationRepository );
418- return new DefaultSaml2AuthenticationRequestContextResolver (relyingPartyRegistrationResolver );
419- }
420- return resolver ;
421- }
422-
423- }
424-
425445}
0 commit comments