Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.security.oauth2.client;

import java.nio.charset.StandardCharsets;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
Expand All @@ -35,6 +36,9 @@
import org.springframework.jdbc.core.PreparedStatementSetter;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.SqlParameterValue;
import org.springframework.jdbc.support.lob.DefaultLobHandler;
import org.springframework.jdbc.support.lob.LobCreator;
import org.springframework.jdbc.support.lob.LobHandler;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
Expand All @@ -56,6 +60,7 @@
*
* @author Joe Grandja
* @author Stav Shamir
* @author Craig Andrews
* @since 5.3
* @see OAuth2AuthorizedClientService
* @see OAuth2AuthorizedClient
Expand Down Expand Up @@ -107,18 +112,38 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient

protected Function<OAuth2AuthorizedClientHolder, List<SqlParameterValue>> authorizedClientParametersMapper;

protected final LobHandler lobHandler;

/**
* Constructs a {@code JdbcOAuth2AuthorizedClientService} using the provided
* parameters.
* @param jdbcOperations the JDBC operations
* @param clientRegistrationRepository the repository of client registrations
* @since 5.5
*/
public JdbcOAuth2AuthorizedClientService(JdbcOperations jdbcOperations,
ClientRegistrationRepository clientRegistrationRepository) {
this(jdbcOperations, clientRegistrationRepository, new DefaultLobHandler());
}

/**
* Constructs a {@code JdbcOAuth2AuthorizedClientService} using the provided
* parameters.
* @param jdbcOperations the JDBC operations
* @param clientRegistrationRepository the repository of client registrations
* @param lobHandler the handler for large binary fields and large text fields
*/
public JdbcOAuth2AuthorizedClientService(JdbcOperations jdbcOperations,
ClientRegistrationRepository clientRegistrationRepository, LobHandler lobHandler) {
Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
Assert.notNull(lobHandler, "LobHandler must not be null");
this.jdbcOperations = jdbcOperations;
this.authorizedClientRowMapper = new OAuth2AuthorizedClientRowMapper(clientRegistrationRepository);
this.lobHandler = lobHandler;
OAuth2AuthorizedClientRowMapper authorizedClientRowMapper = new OAuth2AuthorizedClientRowMapper(
clientRegistrationRepository);
authorizedClientRowMapper.setLobHandler(lobHandler);
this.authorizedClientRowMapper = authorizedClientRowMapper;
this.authorizedClientParametersMapper = new OAuth2AuthorizedClientParametersMapper();
}

Expand All @@ -131,6 +156,7 @@ public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRe
SqlParameterValue[] parameters = new SqlParameterValue[] {
new SqlParameterValue(Types.VARCHAR, clientRegistrationId),
new SqlParameterValue(Types.VARCHAR, principalName) };

PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
List<OAuth2AuthorizedClient> result = this.jdbcOperations.query(LOAD_AUTHORIZED_CLIENT_SQL, pss,
this.authorizedClientRowMapper);
Expand Down Expand Up @@ -163,15 +189,22 @@ private void updateAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Aut
SqlParameterValue principalNameParameter = parameters.remove(0);
parameters.add(clientRegistrationIdParameter);
parameters.add(principalNameParameter);
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
this.jdbcOperations.update(UPDATE_AUTHORIZED_CLIENT_SQL, pss);

try (LobCreator lobCreator = this.lobHandler.getLobCreator()) {
PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator,
parameters.toArray());
this.jdbcOperations.update(UPDATE_AUTHORIZED_CLIENT_SQL, pss);
}
}

private void insertAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
List<SqlParameterValue> parameters = this.authorizedClientParametersMapper
.apply(new OAuth2AuthorizedClientHolder(authorizedClient, principal));
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
this.jdbcOperations.update(SAVE_AUTHORIZED_CLIENT_SQL, pss);
try (LobCreator lobCreator = this.lobHandler.getLobCreator()) {
PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator,
parameters.toArray());
this.jdbcOperations.update(SAVE_AUTHORIZED_CLIENT_SQL, pss);
}
}

@Override
Expand Down Expand Up @@ -218,11 +251,18 @@ public static class OAuth2AuthorizedClientRowMapper implements RowMapper<OAuth2A

protected final ClientRegistrationRepository clientRegistrationRepository;

protected LobHandler lobHandler = new DefaultLobHandler();

public OAuth2AuthorizedClientRowMapper(ClientRegistrationRepository clientRegistrationRepository) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
this.clientRegistrationRepository = clientRegistrationRepository;
}

public final void setLobHandler(LobHandler lobHandler) {
Assert.notNull(lobHandler, "LobHandler must not be null");
this.lobHandler = lobHandler;
}

@Override
public OAuth2AuthorizedClient mapRow(ResultSet rs, int rowNum) throws SQLException {
String clientRegistrationId = rs.getString("client_registration_id");
Expand All @@ -237,7 +277,8 @@ public OAuth2AuthorizedClient mapRow(ResultSet rs, int rowNum) throws SQLExcepti
if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(rs.getString("access_token_type"))) {
tokenType = OAuth2AccessToken.TokenType.BEARER;
}
String tokenValue = new String(rs.getBytes("access_token_value"), StandardCharsets.UTF_8);
String tokenValue = new String(this.lobHandler.getBlobAsBytes(rs, "access_token_value"),
StandardCharsets.UTF_8);
Instant issuedAt = rs.getTimestamp("access_token_issued_at").toInstant();
Instant expiresAt = rs.getTimestamp("access_token_expires_at").toInstant();
Set<String> scopes = Collections.emptySet();
Expand All @@ -247,7 +288,7 @@ public OAuth2AuthorizedClient mapRow(ResultSet rs, int rowNum) throws SQLExcepti
}
OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, tokenValue, issuedAt, expiresAt, scopes);
OAuth2RefreshToken refreshToken = null;
byte[] refreshTokenValue = rs.getBytes("refresh_token_value");
byte[] refreshTokenValue = this.lobHandler.getBlobAsBytes(rs, "refresh_token_value");
if (refreshTokenValue != null) {
tokenValue = new String(refreshTokenValue, StandardCharsets.UTF_8);
issuedAt = null;
Expand Down Expand Up @@ -346,4 +387,32 @@ public Authentication getPrincipal() {

}

private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter {

protected final LobCreator lobCreator;

private LobCreatorArgumentPreparedStatementSetter(LobCreator lobCreator, Object[] args) {
super(args);
this.lobCreator = lobCreator;
}

@Override
protected void doSetValue(PreparedStatement ps, int parameterPosition, Object argValue) throws SQLException {
if (argValue instanceof SqlParameterValue) {
SqlParameterValue paramValue = (SqlParameterValue) argValue;
if (paramValue.getSqlType() == Types.BLOB) {
if (paramValue.getValue() != null) {
Assert.isInstanceOf(byte[].class, paramValue.getValue(),
"Value of blob parameter must be byte[]");
}
byte[] valueBytes = (byte[]) paramValue.getValue();
this.lobCreator.setBlobAsBytes(ps, parameterPosition, valueBytes);
return;
}
}
super.doSetValue(ps, parameterPosition, argValue);
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
CREATE TABLE oauth2_authorized_client (
client_registration_id varchar(100) NOT NULL,
principal_name varchar(200) NOT NULL,
access_token_type varchar(100) NOT NULL,
access_token_value bytea NOT NULL,
access_token_issued_at timestamp NOT NULL,
access_token_expires_at timestamp NOT NULL,
access_token_scopes varchar(1000) DEFAULT NULL,
refresh_token_value bytea DEFAULT NULL,
refresh_token_issued_at timestamp DEFAULT NULL,
created_at timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
PRIMARY KEY (client_registration_id, principal_name)
);