diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java index 7d88ba72361..36e8822a888 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java @@ -17,6 +17,7 @@ package org.springframework.security.oauth2.client; import java.nio.charset.StandardCharsets; +import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Timestamp; @@ -35,6 +36,9 @@ import org.springframework.jdbc.core.PreparedStatementSetter; import org.springframework.jdbc.core.RowMapper; import org.springframework.jdbc.core.SqlParameterValue; +import org.springframework.jdbc.support.lob.DefaultLobHandler; +import org.springframework.jdbc.support.lob.LobCreator; +import org.springframework.jdbc.support.lob.LobHandler; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; @@ -56,6 +60,7 @@ * * @author Joe Grandja * @author Stav Shamir + * @author Craig Andrews * @since 5.3 * @see OAuth2AuthorizedClientService * @see OAuth2AuthorizedClient @@ -107,18 +112,38 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient protected Function> authorizedClientParametersMapper; + protected final LobHandler lobHandler; + /** * Constructs a {@code JdbcOAuth2AuthorizedClientService} using the provided * parameters. * @param jdbcOperations the JDBC operations * @param clientRegistrationRepository the repository of client registrations + * @since 5.5 */ public JdbcOAuth2AuthorizedClientService(JdbcOperations jdbcOperations, ClientRegistrationRepository clientRegistrationRepository) { + this(jdbcOperations, clientRegistrationRepository, new DefaultLobHandler()); + } + + /** + * Constructs a {@code JdbcOAuth2AuthorizedClientService} using the provided + * parameters. + * @param jdbcOperations the JDBC operations + * @param clientRegistrationRepository the repository of client registrations + * @param lobHandler the handler for large binary fields and large text fields + */ + public JdbcOAuth2AuthorizedClientService(JdbcOperations jdbcOperations, + ClientRegistrationRepository clientRegistrationRepository, LobHandler lobHandler) { Assert.notNull(jdbcOperations, "jdbcOperations cannot be null"); Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(lobHandler, "LobHandler must not be null"); this.jdbcOperations = jdbcOperations; - this.authorizedClientRowMapper = new OAuth2AuthorizedClientRowMapper(clientRegistrationRepository); + this.lobHandler = lobHandler; + OAuth2AuthorizedClientRowMapper authorizedClientRowMapper = new OAuth2AuthorizedClientRowMapper( + clientRegistrationRepository); + authorizedClientRowMapper.setLobHandler(lobHandler); + this.authorizedClientRowMapper = authorizedClientRowMapper; this.authorizedClientParametersMapper = new OAuth2AuthorizedClientParametersMapper(); } @@ -131,6 +156,7 @@ public T loadAuthorizedClient(String clientRe SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, clientRegistrationId), new SqlParameterValue(Types.VARCHAR, principalName) }; + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); List result = this.jdbcOperations.query(LOAD_AUTHORIZED_CLIENT_SQL, pss, this.authorizedClientRowMapper); @@ -163,15 +189,22 @@ private void updateAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Aut SqlParameterValue principalNameParameter = parameters.remove(0); parameters.add(clientRegistrationIdParameter); parameters.add(principalNameParameter); - PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); - this.jdbcOperations.update(UPDATE_AUTHORIZED_CLIENT_SQL, pss); + + try (LobCreator lobCreator = this.lobHandler.getLobCreator()) { + PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator, + parameters.toArray()); + this.jdbcOperations.update(UPDATE_AUTHORIZED_CLIENT_SQL, pss); + } } private void insertAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { List parameters = this.authorizedClientParametersMapper .apply(new OAuth2AuthorizedClientHolder(authorizedClient, principal)); - PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); - this.jdbcOperations.update(SAVE_AUTHORIZED_CLIENT_SQL, pss); + try (LobCreator lobCreator = this.lobHandler.getLobCreator()) { + PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator, + parameters.toArray()); + this.jdbcOperations.update(SAVE_AUTHORIZED_CLIENT_SQL, pss); + } } @Override @@ -218,11 +251,18 @@ public static class OAuth2AuthorizedClientRowMapper implements RowMapper scopes = Collections.emptySet(); @@ -247,7 +288,7 @@ public OAuth2AuthorizedClient mapRow(ResultSet rs, int rowNum) throws SQLExcepti } OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, tokenValue, issuedAt, expiresAt, scopes); OAuth2RefreshToken refreshToken = null; - byte[] refreshTokenValue = rs.getBytes("refresh_token_value"); + byte[] refreshTokenValue = this.lobHandler.getBlobAsBytes(rs, "refresh_token_value"); if (refreshTokenValue != null) { tokenValue = new String(refreshTokenValue, StandardCharsets.UTF_8); issuedAt = null; @@ -346,4 +387,32 @@ public Authentication getPrincipal() { } + private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter { + + protected final LobCreator lobCreator; + + private LobCreatorArgumentPreparedStatementSetter(LobCreator lobCreator, Object[] args) { + super(args); + this.lobCreator = lobCreator; + } + + @Override + protected void doSetValue(PreparedStatement ps, int parameterPosition, Object argValue) throws SQLException { + if (argValue instanceof SqlParameterValue) { + SqlParameterValue paramValue = (SqlParameterValue) argValue; + if (paramValue.getSqlType() == Types.BLOB) { + if (paramValue.getValue() != null) { + Assert.isInstanceOf(byte[].class, paramValue.getValue(), + "Value of blob parameter must be byte[]"); + } + byte[] valueBytes = (byte[]) paramValue.getValue(); + this.lobCreator.setBlobAsBytes(ps, parameterPosition, valueBytes); + return; + } + } + super.doSetValue(ps, parameterPosition, argValue); + } + + } + } diff --git a/oauth2/oauth2-client/src/main/resources/org/springframework/security/oauth2/client/oauth2-client-schema-postgres.sql b/oauth2/oauth2-client/src/main/resources/org/springframework/security/oauth2/client/oauth2-client-schema-postgres.sql new file mode 100644 index 00000000000..ff238822475 --- /dev/null +++ b/oauth2/oauth2-client/src/main/resources/org/springframework/security/oauth2/client/oauth2-client-schema-postgres.sql @@ -0,0 +1,13 @@ +CREATE TABLE oauth2_authorized_client ( + client_registration_id varchar(100) NOT NULL, + principal_name varchar(200) NOT NULL, + access_token_type varchar(100) NOT NULL, + access_token_value bytea NOT NULL, + access_token_issued_at timestamp NOT NULL, + access_token_expires_at timestamp NOT NULL, + access_token_scopes varchar(1000) DEFAULT NULL, + refresh_token_value bytea DEFAULT NULL, + refresh_token_issued_at timestamp DEFAULT NULL, + created_at timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL, + PRIMARY KEY (client_registration_id, principal_name) +);