Skip to content

Commit

Permalink
feat(revoke-tokens): revoke access and refresh tokens when consent is…
Browse files Browse the repository at this point in the history
… revoked

Closes gravitee-io/issues#4039
  • Loading branch information
jhaeyaert committed Jul 6, 2020
1 parent 632fe04 commit d020b51
Show file tree
Hide file tree
Showing 10 changed files with 4,361 additions and 997 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,11 @@ public interface AccessTokenRepository {
* @return acknowledge of the operation
*/
Completable deleteByUserId(String userId);

/**
* Delete access token by domainId, clientId and userId.
*/
Completable deleteByDomainIdClientIdAndUserId(String domainId, String clientId, String userId);

Completable deleteByDomainIdAndUserId(String domainId, String userId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,8 @@ public interface RefreshTokenRepository {
Completable delete(String token);

Completable deleteByUserId(String userId);

Completable deleteByDomainIdClientIdAndUserId(String domainId, String clientId, String userId);

Completable deleteByDomainIdAndUserId(String domainId, String userId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public class MongoAccessTokenRepository extends AbstractOAuth2MongoRepository im
private static final String FIELD_ID = "_id";
private static final String FIELD_TOKEN = "token";
private static final String FIELD_RESET_TIME = "expire_at";
private static final String FIELD_DOMAIN_ID = "domain";
private static final String FIELD_CLIENT_ID = "client";
private static final String FIELD_SUBJECT = "subject";
private static final String FIELD_AUTHORIZATION_CODE = "authorization_code";
Expand All @@ -60,8 +61,8 @@ public void init() {
super.createIndex(accessTokenCollection, new Document(FIELD_AUTHORIZATION_CODE, 1));
super.createIndex(accessTokenCollection, new Document(FIELD_SUBJECT, 1));

// two fields index
super.createIndex(accessTokenCollection, new Document(FIELD_CLIENT_ID, 1).append(FIELD_SUBJECT, 1));
// three fields index
super.createIndex(accessTokenCollection, new Document(FIELD_DOMAIN_ID, 1).append(FIELD_CLIENT_ID, 1).append(FIELD_SUBJECT, 1));

// expire after index
super.createIndex(accessTokenCollection, new Document(FIELD_RESET_TIME, 1), new IndexOptions().expireAfter(0L, TimeUnit.SECONDS));
Expand Down Expand Up @@ -130,6 +131,16 @@ public Completable deleteByUserId(String userId) {
return Completable.fromPublisher(accessTokenCollection.deleteMany(eq(FIELD_SUBJECT, userId)));
}

@Override
public Completable deleteByDomainIdClientIdAndUserId(String domainId, String clientId, String userId) {
return Completable.fromPublisher(accessTokenCollection.deleteMany(and(eq(FIELD_DOMAIN_ID, domainId), eq(FIELD_CLIENT_ID, clientId), eq(FIELD_SUBJECT, userId))));
}

@Override
public Completable deleteByDomainIdAndUserId(String domainId, String userId) {
return Completable.fromPublisher(accessTokenCollection.deleteMany(and(eq(FIELD_DOMAIN_ID, domainId), eq(FIELD_SUBJECT, userId))));
}

private List<WriteModel<AccessTokenMongo>> convert(List<AccessToken> accessTokens) {
return accessTokens.stream().map(accessToken -> new InsertOneModel<>(convert(accessToken))).collect(Collectors.toList());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static com.mongodb.client.model.Filters.and;
import static com.mongodb.client.model.Filters.eq;

/**
Expand All @@ -49,13 +50,18 @@ public class MongoRefreshTokenRepository extends AbstractOAuth2MongoRepository i
private static final String FIELD_RESET_TIME = "expire_at";
private static final String FIELD_TOKEN = "token";
private static final String FIELD_SUBJECT = "subject";
private static final String FIELD_DOMAIN_ID = "domain";
private static final String FIELD_CLIENT_ID = "client";

@PostConstruct
public void init() {
refreshTokenCollection = mongoOperations.getCollection("refresh_tokens", RefreshTokenMongo.class);
super.createIndex(refreshTokenCollection, new Document(FIELD_TOKEN, 1));
super.createIndex(refreshTokenCollection, new Document(FIELD_SUBJECT, 1));
super.createIndex(refreshTokenCollection, new Document(FIELD_RESET_TIME, 1), new IndexOptions().expireAfter(0L, TimeUnit.SECONDS));

// three fields index
super.createIndex(refreshTokenCollection, new Document(FIELD_DOMAIN_ID, 1).append(FIELD_CLIENT_ID, 1).append(FIELD_SUBJECT, 1));
}

private Maybe<RefreshToken> findById(String id) {
Expand All @@ -65,7 +71,6 @@ private Maybe<RefreshToken> findById(String id) {
.map(this::convert);
}


@Override
public Maybe<RefreshToken> findByToken(String token) {
return Observable
Expand Down Expand Up @@ -100,6 +105,16 @@ public Completable deleteByUserId(String userId) {
return Completable.fromPublisher(refreshTokenCollection.deleteMany(eq(FIELD_SUBJECT, userId)));
}

@Override
public Completable deleteByDomainIdClientIdAndUserId(String domainId, String clientId, String userId) {
return Completable.fromPublisher(refreshTokenCollection.deleteMany(and(eq(FIELD_DOMAIN_ID, domainId), eq(FIELD_CLIENT_ID, clientId), eq(FIELD_SUBJECT, userId))));
}

@Override
public Completable deleteByDomainIdAndUserId(String domainId, String userId) {
return Completable.fromPublisher(refreshTokenCollection.deleteMany(and(eq(FIELD_DOMAIN_ID, domainId), eq(FIELD_SUBJECT, userId))));
}

private List<WriteModel<RefreshTokenMongo>> convert(List<RefreshToken> refreshTokens) {
return refreshTokens.stream().map(refreshToken -> new InsertOneModel<>(convert(refreshToken))).collect(Collectors.toList());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import org.junit.Test;
import org.springframework.beans.factory.annotation.Autowired;

import java.util.UUID;
import java.util.Arrays;

/**
* @author David BRASSELY (david.brassely at graviteesource.com)
Expand Down Expand Up @@ -124,4 +124,54 @@ public void shouldCountByClientId() {
observer.assertNoErrors();
observer.assertValue(new Long(1));
}

@Test
public void shouldDeleteByDomainIdClientIdAndUserId() {
AccessToken token1 = new AccessToken();
token1.setId("my-token");
token1.setToken("my-token");
token1.setClient("client-id");
token1.setDomain("domain-id");
token1.setSubject("user-id");

AccessToken token2 = new AccessToken();
token2.setId("my-token2");
token2.setToken("my-token2");
token2.setClient("client-id2");
token2.setDomain("domain-id2");
token2.setSubject("user-id2");

assertEquals(0, accessTokenRepository
.bulkWrite(Arrays.asList(token1, token2))
.andThen(accessTokenRepository.deleteByDomainIdClientIdAndUserId("domain-id", "client-id", "user-id"))
.andThen(accessTokenRepository.findByToken("my-token"))
.test().valueCount());

assertNotNull(accessTokenRepository.findByToken("my-token2").blockingGet());
}

@Test
public void shouldDeleteByDomainIdAndUserId() {
AccessToken token1 = new AccessToken();
token1.setId("my-token");
token1.setToken("my-token");
token1.setClient("client-id");
token1.setDomain("domain-id");
token1.setSubject("user-id");

AccessToken token2 = new AccessToken();
token2.setId("my-token2");
token2.setToken("my-token2");
token2.setClient("client-id2");
token2.setDomain("domain-id2");
token2.setSubject("user-id2");

assertEquals(0, accessTokenRepository
.bulkWrite(Arrays.asList(token1, token2))
.andThen(accessTokenRepository.deleteByDomainIdAndUserId("domain-id", "user-id"))
.andThen(accessTokenRepository.findByToken("my-token"))
.test().valueCount());

assertNotNull(accessTokenRepository.findByToken("my-token2").blockingGet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.junit.Test;
import org.springframework.beans.factory.annotation.Autowired;

import java.util.Arrays;
import java.util.UUID;

/**
Expand Down Expand Up @@ -81,4 +82,54 @@ public void shouldDelete() {
.andThen(refreshTokenRepository.findByToken("my-token"))
.test().assertEmpty();
}

@Test
public void shouldDeleteByDomainIdClientIdAndUserId() {
RefreshToken token1 = new RefreshToken();
token1.setId("my-token");
token1.setToken("my-token");
token1.setClient("client-id");
token1.setDomain("domain-id");
token1.setSubject("user-id");

RefreshToken token2 = new RefreshToken();
token2.setId("my-token2");
token2.setToken("my-token2");
token2.setClient("client-id2");
token2.setDomain("domain-id2");
token2.setSubject("user-id2");

assertEquals(0, refreshTokenRepository
.bulkWrite(Arrays.asList(token1, token2))
.andThen(refreshTokenRepository.deleteByDomainIdClientIdAndUserId("domain-id", "client-id", "user-id"))
.andThen(refreshTokenRepository.findByToken("my-token"))
.test().valueCount());

assertNotNull(refreshTokenRepository.findByToken("my-token2").blockingGet());
}

@Test
public void shouldDeleteByDomainIdAndUserId() {
RefreshToken token1 = new RefreshToken();
token1.setId("my-token");
token1.setToken("my-token");
token1.setClient("client-id");
token1.setDomain("domain-id");
token1.setSubject("user-id");

RefreshToken token2 = new RefreshToken();
token2.setId("my-token2");
token2.setToken("my-token2");
token2.setClient("client-id2");
token2.setDomain("domain-id2");
token2.setSubject("user-id2");

assertEquals(0, refreshTokenRepository
.bulkWrite(Arrays.asList(token1, token2))
.andThen(refreshTokenRepository.deleteByDomainIdAndUserId("domain-id", "user-id"))
.andThen(refreshTokenRepository.findByToken("my-token"))
.test().valueCount());

assertNotNull(refreshTokenRepository.findByToken("my-token2").blockingGet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import io.gravitee.am.identityprovider.api.User;
import io.gravitee.am.model.oauth2.ScopeApproval;
import io.gravitee.am.model.oidc.Client;
import io.gravitee.am.repository.oauth2.api.AccessTokenRepository;
import io.gravitee.am.repository.oauth2.api.RefreshTokenRepository;
import io.gravitee.am.repository.oauth2.api.ScopeApprovalRepository;
import io.gravitee.am.service.AuditService;
import io.gravitee.am.service.ScopeApprovalService;
Expand Down Expand Up @@ -56,6 +58,14 @@ public class ScopeApprovalServiceImpl implements ScopeApprovalService {
@Autowired
private ScopeApprovalRepository scopeApprovalRepository;

@Lazy
@Autowired
private AccessTokenRepository accessTokenRepository;

@Lazy
@Autowired
private RefreshTokenRepository refreshTokenRepository;

@Autowired
private UserService userService;

Expand Down Expand Up @@ -120,8 +130,9 @@ public Completable revokeByConsent(String domain, String userId, String consentI
.switchIfEmpty(Maybe.error(new ScopeApprovalNotFoundException(consentId)))
.flatMapCompletable(scopeApproval -> scopeApprovalRepository.delete(consentId)
.doOnComplete(() -> auditService.report(AuditBuilder.builder(UserConsentAuditBuilder.class).type(EventType.USER_CONSENT_REVOKED).domain(domain).principal(principal).user(user).approvals(Collections.singleton(scopeApproval))))
.doOnError(throwable -> auditService.report(AuditBuilder.builder(UserConsentAuditBuilder.class).type(EventType.USER_CONSENT_REVOKED).domain(domain).principal(principal).user(user).throwable(throwable))))
)
.doOnError(throwable -> auditService.report(AuditBuilder.builder(UserConsentAuditBuilder.class).type(EventType.USER_CONSENT_REVOKED).domain(domain).principal(principal).user(user).throwable(throwable)))
.andThen(Completable.mergeArrayDelayError(accessTokenRepository.deleteByDomainIdClientIdAndUserId(scopeApproval.getDomain(), scopeApproval.getClientId(), scopeApproval.getUserId()),
refreshTokenRepository.deleteByDomainIdClientIdAndUserId(scopeApproval.getDomain(), scopeApproval.getClientId(), scopeApproval.getUserId())))))
.onErrorResumeNext(ex -> {
if (ex instanceof AbstractManagementException) {
return Completable.error(ex);
Expand All @@ -142,7 +153,8 @@ public Completable revokeByUser(String domain, String user, User principal) {
.flatMapCompletable(scopeApprovals -> scopeApprovalRepository.deleteByDomainAndUser(domain, user)
.doOnComplete(() -> auditService.report(AuditBuilder.builder(UserConsentAuditBuilder.class).type(EventType.USER_CONSENT_REVOKED).domain(domain).principal(principal).user(user1).approvals(scopeApprovals)))
.doOnError(throwable -> auditService.report(AuditBuilder.builder(UserConsentAuditBuilder.class).type(EventType.USER_CONSENT_REVOKED).domain(domain).principal(principal).user(user1).throwable(throwable))))
)
.andThen(Completable.mergeArrayDelayError(accessTokenRepository.deleteByDomainIdAndUserId(domain, user),
refreshTokenRepository.deleteByDomainIdAndUserId(domain, user))))
.onErrorResumeNext(ex -> {
if (ex instanceof AbstractManagementException) {
return Completable.error(ex);
Expand All @@ -163,7 +175,8 @@ public Completable revokeByUserAndClient(String domain, String user, String clie
.flatMapCompletable(scopeApprovals -> scopeApprovalRepository.deleteByDomainAndUserAndClient(domain, user, clientId)
.doOnComplete(() -> auditService.report(AuditBuilder.builder(UserConsentAuditBuilder.class).type(EventType.USER_CONSENT_REVOKED).domain(domain).principal(principal).user(user1).approvals(scopeApprovals)))
.doOnError(throwable -> auditService.report(AuditBuilder.builder(UserConsentAuditBuilder.class).type(EventType.USER_CONSENT_REVOKED).domain(domain).principal(principal).user(user1).throwable(throwable))))
)
.andThen(Completable.mergeArrayDelayError(accessTokenRepository.deleteByDomainIdClientIdAndUserId(domain, clientId, user),
refreshTokenRepository.deleteByDomainIdClientIdAndUserId(domain, clientId, user))))
.onErrorResumeNext(ex -> {
if (ex instanceof AbstractManagementException) {
return Completable.error(ex);
Expand Down
Loading

0 comments on commit d020b51

Please sign in to comment.