Skip to content

Commit

Permalink
feat: improve audience handling (#3786)
Browse files Browse the repository at this point in the history
* feat: improve audience handling

* chore: fix tests

* pr suggestions
  • Loading branch information
wolf4ood authored Jan 22, 2024
1 parent 0732049 commit 185b633
Show file tree
Hide file tree
Showing 42 changed files with 402 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ protected boolean processTerminating(ContractNegotiation negotiation) {
protected AsyncStatusResultRetryProcess<ContractNegotiation, Object, ?> dispatch(ProcessRemoteMessage.Builder<?, ?> messageBuilder,
ContractNegotiation negotiation) {
messageBuilder.counterPartyAddress(negotiation.getCounterPartyAddress())
.counterPartyId(negotiation.getCounterPartyId())
.protocol(negotiation.getProtocol())
.processId(Optional.ofNullable(negotiation.getCorrelationId()).orElse(negotiation.getId()));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,27 @@ public PolicyArchiveImpl(ContractNegotiationStore contractNegotiationStore) {
public Policy findPolicyForContract(String contractId) {
return Optional.ofNullable(contractId)
.map(contractNegotiationStore::findContractAgreement)
.map(ContractAgreement::getPolicy)
.map(this::mapAgreementPolicy)
.orElse(null);
}

// TODO assignee and assigner should end up stored in the Agreement's policy as outlined here
// https://github.com/International-Data-Spaces-Association/ids-specification/issues/195
// As fallback we fill the assignee and the assigner from the consumer and provider id in
// the contract agreement

private Policy mapAgreementPolicy(ContractAgreement contractAgreement) {
var policy = contractAgreement.getPolicy();
var assignee = Optional.ofNullable(policy.getAssignee())
.orElseGet(contractAgreement::getConsumerId);

var assigner = Optional.ofNullable(policy.getAssigner())
.orElseGet(contractAgreement::getProviderId);

return policy.toBuilder()
.assignee(assignee)
.assigner(assigner)
.build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,33 @@ class PolicyArchiveImplTest {

@Test
void shouldGetPolicyFromAgreement() {
var policy = Policy.Builder.newInstance().assigner("assigner").assignee("assignee").build();
var contractAgreement = createContractAgreement(policy);
when(contractNegotiationStore.findContractAgreement("contractId")).thenReturn(contractAgreement);

var result = policyArchive.findPolicyForContract("contractId");

assertThat(result).usingRecursiveComparison().ignoringFields().isEqualTo(policy);

assertThat(result.getAssigner()).isNotEqualTo(contractAgreement.getProviderId());
assertThat(result.getAssignee()).isNotEqualTo(contractAgreement.getConsumerId());
}

@Test
void shouldGetPolicyFromAgreement_WithAssigneeAndAssignedInferred() {
var policy = Policy.Builder.newInstance().build();
var contractAgreement = createContractAgreement(policy);
when(contractNegotiationStore.findContractAgreement("contractId")).thenReturn(contractAgreement);

var result = policyArchive.findPolicyForContract("contractId");

assertThat(result).usingRecursiveComparison().isEqualTo(policy);
assertThat(result).usingRecursiveComparison().ignoringFields("assignee", "assigner").isEqualTo(policy);

assertThat(result.getAssigner()).isEqualTo(contractAgreement.getProviderId());
assertThat(result.getAssignee()).isEqualTo(contractAgreement.getConsumerId());
}


@Test
void shouldReturnNullIfContractDoesNotExist() {
when(contractNegotiationStore.findContractAgreement("contractId")).thenReturn(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,23 @@ public CatalogServiceImpl(RemoteMessageDispatcherRegistry dispatcher) {
}

@Override
public CompletableFuture<StatusResult<byte[]>> requestCatalog(String providerUrl, String protocol, QuerySpec querySpec) {
public CompletableFuture<StatusResult<byte[]>> requestCatalog(String counterPartyId, String counterPartyAddress, String protocol, QuerySpec querySpec) {
var request = CatalogRequestMessage.Builder.newInstance()
.protocol(protocol)
.counterPartyAddress(providerUrl)
.counterPartyId(counterPartyId)
.counterPartyAddress(counterPartyAddress)
.querySpec(querySpec)
.build();

return dispatcher.dispatch(byte[].class, request);
}

@Override
public CompletableFuture<StatusResult<byte[]>> requestDataset(String id, String counterPartyAddress, String protocol) {
public CompletableFuture<StatusResult<byte[]>> requestDataset(String id, String counterPartyId, String counterPartyAddress, String protocol) {
var request = DatasetRequestMessage.Builder.newInstance()
.datasetId(id)
.protocol(protocol)
.counterPartyId(counterPartyId)
.counterPartyAddress(counterPartyAddress)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class CatalogServiceImplTest {
void requestCatalog_shouldDispatchRequestAndReturnResult() {
when(dispatcher.dispatch(eq(byte[].class), any())).thenReturn(completedFuture(StatusResult.success("content".getBytes())));

var result = service.requestCatalog("http://provider/url", "protocol", QuerySpec.none());
var result = service.requestCatalog("counterPartyId", "http://provider/url", "protocol", QuerySpec.none());

assertThat(result).succeedsWithin(5, SECONDS).satisfies(statusResult -> {
assertThat(statusResult).isSucceeded().isEqualTo("content".getBytes());
Expand All @@ -54,7 +54,7 @@ void requestCatalog_shouldDispatchRequestAndReturnResult() {
void requestDataset_shouldDispatchRequestAndReturnResult() {
when(dispatcher.dispatch(eq(byte[].class), any())).thenReturn(completedFuture(StatusResult.success("content".getBytes())));

var result = service.requestDataset("datasetId", "http://provider/url", "protocol");
var result = service.requestDataset("datasetId", "counterPartyId", "http://provider/url", "protocol");

assertThat(result).succeedsWithin(5, SECONDS).satisfies(statusResult -> {
assertThat(statusResult).isSucceeded().isEqualTo("content".getBytes());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,9 @@ private boolean processRequesting(TransferProcess process) {
.callbackAddress(protocolWebhook.url())
.dataDestination(dataDestination)
.transferType(process.getTransferType())
.contractId(process.getContractId())
.policy(policyArchive.findPolicyForContract(process.getContractId()));
.contractId(process.getContractId());

return dispatch(messageBuilder, process)
return dispatch(messageBuilder, process, policyArchive.findPolicyForContract(process.getContractId()))
.onSuccess((t, content) -> transitionToRequested(t))
.onRetryExhausted(this::transitionToTerminated)
.onFailure((t, throwable) -> transitionToRequesting(t))
Expand Down Expand Up @@ -323,9 +322,9 @@ private boolean processStarting(TransferProcess process) {
*/
@WithSpan
private boolean processCompleting(TransferProcess process) {
var builder = TransferCompletionMessage.Builder.newInstance().policy(policyArchive.findPolicyForContract(process.getContractId()));
var builder = TransferCompletionMessage.Builder.newInstance();

return dispatch(builder, process)
return dispatch(builder, process, policyArchive.findPolicyForContract(process.getContractId()))
.onSuccess((t, content) -> {
transitionToCompleted(t);
if (t.getType() == PROVIDER) {
Expand Down Expand Up @@ -385,10 +384,9 @@ private boolean processDeprovisioning(TransferProcess process) {
@WithSpan
private void sendTransferStartMessage(TransferProcess process, DataFlowResponse dataFlowResponse, Policy policy) {
var messageBuilder = TransferStartMessage.Builder.newInstance()
.dataAddress(dataFlowResponse.getDataAddress())
.policy(policy);
.dataAddress(dataFlowResponse.getDataAddress());

dispatch(messageBuilder, process)
dispatch(messageBuilder, process, policy)
.onSuccess((t, content) -> transitionToStarted(t))
.onFailure((t, throwable) -> transitionToStarting(t))
.onFatalError((n, failure) -> transitionToTerminated(n, failure.getFailureDetail()))
Expand All @@ -407,10 +405,9 @@ private StatusResult<Void> terminateDataFlow(TransferProcess process) {

private boolean sendTransferTerminationMessage(TransferProcess process) {
var builder = TransferTerminationMessage.Builder.newInstance()
.policy(policyArchive.findPolicyForContract(process.getContractId()))
.reason(process.getErrorDetail());

return dispatch(builder, process)
return dispatch(builder, process, policyArchive.findPolicyForContract(process.getContractId()))
.onSuccess((t, content) -> {
transitionToTerminated(t);
if (t.getType() == PROVIDER) {
Expand All @@ -423,19 +420,25 @@ private boolean sendTransferTerminationMessage(TransferProcess process) {
.execute("send transfer termination to " + process.getConnectorAddress());
}

private <M extends TransferRemoteMessage, B extends TransferRemoteMessage.Builder<M, B>> AsyncStatusResultRetryProcess<TransferProcess, Object, ?> dispatch(B messageBuilder, TransferProcess process) {
private <M extends TransferRemoteMessage, B extends TransferRemoteMessage.Builder<M, B>> AsyncStatusResultRetryProcess<TransferProcess, Object, ?> dispatch(B messageBuilder, TransferProcess process, Policy policy) {

messageBuilder.protocol(process.getProtocol())
.counterPartyAddress(process.getConnectorAddress())
.processId(process.getCorrelationId());
.processId(process.getCorrelationId())
.policy(policy);

if (process.lastSentProtocolMessage() != null) {
messageBuilder.id(process.lastSentProtocolMessage());
}

if (process.getType() == PROVIDER) {
messageBuilder.consumerPid(process.getCorrelationId()).providerPid(process.getId());
messageBuilder.consumerPid(process.getCorrelationId())
.providerPid(process.getId())
.counterPartyId(policy.getAssignee());
} else {
messageBuilder.consumerPid(process.getId()).providerPid(process.getCorrelationId());
messageBuilder.consumerPid(process.getId())
.providerPid(process.getCorrelationId())
.counterPartyId(policy.getAssigner());
}

var message = messageBuilder.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class TransferProcessManagerImplTest {
void setup() {
when(protocolWebhook.url()).thenReturn(protocolWebhookUrl);
when(dataFlowManager.initiate(any(), any())).thenReturn(StatusResult.success(createDataFlowResponse()));
when(policyArchive.findPolicyForContract(any())).thenReturn(Policy.Builder.newInstance().build());
var observable = new TransferProcessObservableImpl();
observable.registerListener(listener);
var entityRetryProcessConfiguration = new EntityRetryProcessConfiguration(RETRY_LIMIT, () -> new ExponentialWaitStrategy(0L));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.eclipse.edc.runtime.metamodel.annotation.Inject;
import org.eclipse.edc.runtime.metamodel.annotation.Provider;
import org.eclipse.edc.spi.http.EdcHttpClient;
import org.eclipse.edc.spi.iam.AudienceResolver;
import org.eclipse.edc.spi.iam.IdentityService;
import org.eclipse.edc.spi.message.RemoteMessageDispatcherRegistry;
import org.eclipse.edc.spi.monitor.Monitor;
Expand Down Expand Up @@ -98,6 +99,8 @@ public class DspHttpCoreExtension implements ServiceExtension {
@Inject
private PolicyEngine policyEngine;

@Inject
private AudienceResolver audienceResolver;
@Inject
private Monitor monitor;

Expand All @@ -119,7 +122,7 @@ public DspHttpRemoteMessageDispatcher dspHttpRemoteMessageDispatcher(ServiceExte
td = bldr -> bldr;
}

var dispatcher = new DspHttpRemoteMessageDispatcherImpl(httpClient, identityService, td, policyEngine);
var dispatcher = new DspHttpRemoteMessageDispatcherImpl(httpClient, identityService, td, policyEngine, audienceResolver);
registerNegotiationPolicyScopes(dispatcher);
registerTransferProcessPolicyScopes(dispatcher);
registerCatalogPolicyScopes(dispatcher);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.eclipse.edc.protocol.dsp.spi.types.HttpMessageProtocol;
import org.eclipse.edc.spi.EdcException;
import org.eclipse.edc.spi.http.EdcHttpClient;
import org.eclipse.edc.spi.iam.AudienceResolver;
import org.eclipse.edc.spi.iam.IdentityService;
import org.eclipse.edc.spi.iam.TokenParameters;
import org.eclipse.edc.spi.response.StatusResult;
Expand Down Expand Up @@ -58,14 +59,20 @@ public class DspHttpRemoteMessageDispatcherImpl implements DspHttpRemoteMessageD
private final PolicyEngine policyEngine;
private final TokenDecorator tokenDecorator;

private final AudienceResolver audienceResolver;

private static final String AUDIENCE_CLAIM = "aud";

public DspHttpRemoteMessageDispatcherImpl(EdcHttpClient httpClient,
IdentityService identityService,
TokenDecorator decorator,
PolicyEngine policyEngine) {
PolicyEngine policyEngine,
AudienceResolver audienceResolver) {
this.httpClient = httpClient;
this.identityService = identityService;
this.policyEngine = policyEngine;
this.tokenDecorator = decorator;
this.audienceResolver = audienceResolver;
}

@Override
Expand Down Expand Up @@ -94,7 +101,7 @@ public <T, M extends RemoteMessage> CompletableFuture<StatusResult<T>> dispatch(
}

var tokenParameters = tokenParametersBuilder
.claims("aud", message.getCounterPartyAddress()) // enforce the audience, ignore anything a decorator might have set
.claims(AUDIENCE_CLAIM, audienceResolver.resolve(message)) // enforce the audience, ignore anything a decorator might have set
.build();

return identityService.obtainClientCredentials(tokenParameters)
Expand Down Expand Up @@ -150,11 +157,13 @@ private String asString(ResponseBody it) {
private record MessageHandler<M extends RemoteMessage, R>(
DspHttpRequestFactory<M> requestFactory,
DspHttpResponseBodyExtractor<R> bodyExtractor
) { }
) {
}

private record PolicyScope<M extends RemoteMessage>(
Class<M> messageClass, String scope,
Function<M, Policy> policyProvider
) { }
) {
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void createDispatcher_noTokenDecorator_shouldUseNoop(ServiceExtensionContext con
extension = factory.constructInstance(DspHttpCoreExtension.class);
var dispatcher = extension.dspHttpRemoteMessageDispatcher(context);
dispatcher.registerMessage(TestMessage.class, mock(), mock());
dispatcher.dispatch(String.class, new TestMessage("protocol", "address"));
dispatcher.dispatch(String.class, new TestMessage("protocol", "address", "counterPartyId"));

verify(identityService).obtainClientCredentials(argThat(tokenParams -> tokenParams.getStringClaim(SCOPE_CLAIM) == null));
}
Expand All @@ -68,7 +68,7 @@ void createDispatcher_withTokenDecorator_shouldUse(ServiceExtensionContext conte
extension = factory.constructInstance(DspHttpCoreExtension.class);
var dispatcher = extension.dspHttpRemoteMessageDispatcher(context);
dispatcher.registerMessage(TestMessage.class, mock(), mock());
dispatcher.dispatch(String.class, new TestMessage("protocol", "address"));
dispatcher.dispatch(String.class, new TestMessage("protocol", "address", "counterPartyId"));

verify(identityService).obtainClientCredentials(argThat(tokenParams -> tokenParams.getStringClaim(SCOPE_CLAIM).equals("test-scope")));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import org.eclipse.edc.spi.types.domain.message.RemoteMessage;

public record TestMessage(String protocol, String counterPartyAddress) implements RemoteMessage {
public record TestMessage(String protocol, String counterPartyAddress, String counterPartyId) implements RemoteMessage {
@Override
public String getProtocol() {
return protocol;
Expand All @@ -26,4 +26,9 @@ public String getProtocol() {
public String getCounterPartyAddress() {
return counterPartyAddress;
}

@Override
public String getCounterPartyId() {
return counterPartyId;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.eclipse.edc.protocol.dsp.spi.dispatcher.response.DspHttpResponseBodyExtractor;
import org.eclipse.edc.spi.EdcException;
import org.eclipse.edc.spi.http.EdcHttpClient;
import org.eclipse.edc.spi.iam.AudienceResolver;
import org.eclipse.edc.spi.iam.IdentityService;
import org.eclipse.edc.spi.iam.TokenParameters;
import org.eclipse.edc.spi.iam.TokenRepresentation;
Expand Down Expand Up @@ -64,15 +65,17 @@ class DspHttpRemoteMessageDispatcherImplTest {

private static final String SCOPE_CLAIM = "scope";
private static final String AUDIENCE_CLAIM = "aud";
private static final String AUDIENCE_VALUE = "audValue";
private final EdcHttpClient httpClient = mock();
private final IdentityService identityService = mock();
private final PolicyEngine policyEngine = mock();
private final TokenDecorator tokenDecorator = mock();
private final DspHttpRequestFactory<TestMessage> requestFactory = mock();
private final AudienceResolver audienceResolver = mock();
private final Duration timeout = Duration.of(5, SECONDS);

private final DspHttpRemoteMessageDispatcher dispatcher =
new DspHttpRemoteMessageDispatcherImpl(httpClient, identityService, tokenDecorator, policyEngine);
new DspHttpRemoteMessageDispatcherImpl(httpClient, identityService, tokenDecorator, policyEngine, audienceResolver);

private static okhttp3.Response dummyResponse(int code) {
return dummyResponseBuilder(code)
Expand All @@ -91,6 +94,7 @@ private static okhttp3.Response.Builder dummyResponseBuilder(int code) {

@BeforeEach
void setUp() {
when(audienceResolver.resolve(any())).thenReturn(AUDIENCE_VALUE);
when(tokenDecorator.decorate(any())).thenAnswer(a -> a.getArgument(0));
}

Expand Down Expand Up @@ -122,7 +126,7 @@ void dispatch_ensureTokenDecoratorScope() {
verify(requestFactory).createRequest(message);
assertThat(captor.getValue()).satisfies(tr -> {
assertThat(tr.getStringClaim(SCOPE_CLAIM)).isEqualTo("test-scope");
assertThat(tr.getStringClaim(AUDIENCE_CLAIM)).isEqualTo(message.getCounterPartyAddress());
assertThat(tr.getStringClaim(AUDIENCE_CLAIM)).isEqualTo(AUDIENCE_VALUE);
assertThat(tr.getClaims()).containsAllEntriesOf(additional);
});

Expand Down Expand Up @@ -188,6 +192,11 @@ public String getProtocol() {
public String getCounterPartyAddress() {
return "http://connector";
}

@Override
public String getCounterPartyId() {
return null;
}
}

@Nested
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class GetDspHttpRequestFactoryTest {
void shouldCreateProperHttpRequest() {
when(pathProvider.providePath(any())).thenReturn("/message/request/path");

var message = new TestMessage("protocol", "http://counter-party");
var message = new TestMessage("protocol", "http://counter-party", "counterPartyId");
var request = factory.createRequest(message);

assertThat(request.url().url().toString()).isEqualTo("http://counter-party/message/request/path");
Expand Down
Loading

0 comments on commit 185b633

Please sign in to comment.