Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bi-di subscription to support dapr-api-token #1142

Merged
merged 5 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions sdk-actors/src/main/java/io/dapr/actors/client/ActorClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public ActorClient(ResiliencyOptions resiliencyOptions) {
* @param overrideProperties Override properties.
*/
public ActorClient(Properties overrideProperties) {
this(buildManagedChannel(overrideProperties), null);
this(buildManagedChannel(overrideProperties), null, overrideProperties.getValue(Properties.API_TOKEN));
Copy link
Member Author

@artursouza artursouza Oct 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also needed to fix how properties override work in some cases because it was not handling overrides correctly for Dapr API Token. Now, it works too and all ITs require sidecar calls to have dapr-api-token header.

}

/**
Expand All @@ -69,7 +69,7 @@ public ActorClient(Properties overrideProperties) {
* @param resiliencyOptions Client resiliency options.
*/
public ActorClient(Properties overrideProperties, ResiliencyOptions resiliencyOptions) {
this(buildManagedChannel(overrideProperties), resiliencyOptions);
this(buildManagedChannel(overrideProperties), resiliencyOptions, overrideProperties.getValue(Properties.API_TOKEN));
}

/**
Expand All @@ -80,9 +80,10 @@ public ActorClient(Properties overrideProperties, ResiliencyOptions resiliencyOp
*/
private ActorClient(
ManagedChannel grpcManagedChannel,
ResiliencyOptions resiliencyOptions) {
ResiliencyOptions resiliencyOptions,
String daprApiToken) {
this.grpcManagedChannel = grpcManagedChannel;
this.daprClient = buildDaprClient(grpcManagedChannel, resiliencyOptions);
this.daprClient = buildDaprClient(grpcManagedChannel, resiliencyOptions, daprApiToken);
}

/**
Expand Down Expand Up @@ -136,7 +137,11 @@ private static ManagedChannel buildManagedChannel(Properties overrideProperties)
*/
private static DaprClient buildDaprClient(
Channel grpcManagedChannel,
ResiliencyOptions resiliencyOptions) {
return new DaprClientImpl(DaprGrpc.newStub(grpcManagedChannel), resiliencyOptions);
ResiliencyOptions resiliencyOptions,
String daprApiToken) {
return new DaprClientImpl(
DaprGrpc.newStub(grpcManagedChannel),
resiliencyOptions,
daprApiToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,6 @@
*/
class DaprClientImpl implements DaprClient {

/**
* Timeout policy for SDK calls to Dapr API.
*/
private final TimeoutPolicy timeoutPolicy;

/**
* Retry policy for SDK calls to Dapr API.
*/
Expand All @@ -57,16 +52,22 @@ class DaprClientImpl implements DaprClient {
*/
private final DaprGrpc.DaprStub client;

/**
* gRPC client interceptors.
*/
private final DaprClientGrpcInterceptors grpcInterceptors;

/**
* Internal constructor.
*
* @param grpcClient Dapr's GRPC client.
* @param resiliencyOptions Client resiliency options (optional)
* @param resiliencyOptions Client resiliency options (optional).
* @param daprApiToken Dapr API token (optional).
*/
DaprClientImpl(DaprGrpc.DaprStub grpcClient, ResiliencyOptions resiliencyOptions) {
this.client = intercept(grpcClient);
this.timeoutPolicy = new TimeoutPolicy(
resiliencyOptions == null ? null : resiliencyOptions.getTimeout());
DaprClientImpl(DaprGrpc.DaprStub grpcClient, ResiliencyOptions resiliencyOptions, String daprApiToken) {
this.client = grpcClient;
this.grpcInterceptors = new DaprClientGrpcInterceptors(daprApiToken,
new TimeoutPolicy(resiliencyOptions == null ? null : resiliencyOptions.getTimeout()));
this.retryPolicy = new RetryPolicy(
resiliencyOptions == null ? null : resiliencyOptions.getMaxRetries());
}
Expand All @@ -85,54 +86,11 @@ public Mono<byte[]> invoke(String actorType, String actorId, String methodName,
.build();
return Mono.deferContextual(
context -> this.<DaprProtos.InvokeActorResponse>createMono(
it -> intercept(context, this.timeoutPolicy, client).invokeActor(req, it)
it -> this.grpcInterceptors.intercept(client, context).invokeActor(req, it)
)
).map(r -> r.getData().toByteArray());
}

/**
* Populates GRPC client with interceptors.
*
* @param client GRPC client for Dapr.
* @return Client after adding interceptors.
*/
private DaprGrpc.DaprStub intercept(DaprGrpc.DaprStub client) {
ClientInterceptor interceptor = new ClientInterceptor() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> methodDescriptor,
CallOptions options,
Channel channel) {
ClientCall<ReqT, RespT> clientCall = channel.newCall(methodDescriptor, timeoutPolicy.apply(options));
return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(clientCall) {
@Override
public void start(final Listener<RespT> responseListener, final Metadata metadata) {
String daprApiToken = Properties.API_TOKEN.get();
if (daprApiToken != null) {
metadata.put(Metadata.Key.of("dapr-api-token", Metadata.ASCII_STRING_MARSHALLER), daprApiToken);
}

super.start(responseListener, metadata);
}
};
}
};
return client.withInterceptors(interceptor);
}

/**
* Populates GRPC client with interceptors for telemetry.
*
* @param context Reactor's context.
* @param timeoutPolicy Timeout policy for gRPC call.
* @param client GRPC client for Dapr.
* @return Client after adding interceptors.
*/
private static DaprGrpc.DaprStub intercept(
ContextView context, TimeoutPolicy timeoutPolicy, DaprGrpc.DaprStub client) {
return DaprClientGrpcInterceptors.intercept(client, timeoutPolicy, context);
}

private <T> Mono<T> createMono(Consumer<StreamObserver<T>> consumer) {
return retryPolicy.apply(
Mono.create(sink -> DaprException.wrap(() -> consumer.accept(createStreamObserver(sink))).run()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public void setup() throws IOException {
InProcessChannelBuilder.forName(serverName).directExecutor().build());

// Create a HelloWorldClient using the in-process channel;
client = new DaprClientImpl(DaprGrpc.newStub(channel), null);
client = new DaprClientImpl(DaprGrpc.newStub(channel), null, null);
}

@Test
Expand Down
64 changes: 55 additions & 9 deletions sdk-tests/src/test/java/io/dapr/it/DaprRun.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
import org.apache.commons.lang3.tuple.ImmutablePair;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
Expand All @@ -40,6 +43,7 @@

public class DaprRun implements Stoppable {

private static final String DEFAULT_DAPR_API_TOKEN = UUID.randomUUID().toString();
private static final String DAPR_SUCCESS_MESSAGE = "You're up and running!";

private static final String DAPR_RUN = "dapr run --app-id %s --app-protocol %s " +
Expand Down Expand Up @@ -68,19 +72,41 @@ public class DaprRun implements Stoppable {

private final boolean hasAppHealthCheck;

private final Map<Property<?>, String> propertyOverrides;

private DaprRun(String testName,
DaprPorts ports,
String successMessage,
Class serviceClass,
int maxWaitMilliseconds,
AppRun.AppProtocol appProtocol) {
this(
testName,
ports,
successMessage,
serviceClass,
maxWaitMilliseconds,
appProtocol,
resolveDaprApiToken(serviceClass));
}

private DaprRun(String testName,
DaprPorts ports,
String successMessage,
Class serviceClass,
int maxWaitMilliseconds,
AppRun.AppProtocol appProtocol,
String daprApiToken) {
// The app name needs to be deterministic since we depend on it to kill previous runs.
this.appName = serviceClass == null ?
testName.toLowerCase() :
String.format("%s-%s", testName, serviceClass.getSimpleName()).toLowerCase();
this.appProtocol = appProtocol;
this.startCommand =
new Command(successMessage, buildDaprCommand(this.appName, serviceClass, ports, appProtocol));
new Command(
successMessage,
buildDaprCommand(this.appName, serviceClass, ports, appProtocol),
daprApiToken == null ? null : Map.of("DAPR_API_TOKEN", daprApiToken));
this.listCommand = new Command(
this.appName,
"dapr list");
Expand All @@ -91,6 +117,10 @@ private DaprRun(String testName,
this.maxWaitMilliseconds = maxWaitMilliseconds;
this.started = new AtomicBoolean(false);
this.hasAppHealthCheck = isAppHealthCheckEnabled(serviceClass);
this.propertyOverrides = daprApiToken == null ? ports.getPropertyOverrides() :
Collections.unmodifiableMap(new HashMap<>(ports.getPropertyOverrides()) {{
put(Properties.API_TOKEN, daprApiToken);
}});
}

public void start() throws InterruptedException, IOException {
Expand Down Expand Up @@ -149,7 +179,7 @@ public void stop() throws InterruptedException, IOException {
}

public Map<Property<?>, String> getPropertyOverrides() {
return this.ports.getPropertyOverrides();
return this.propertyOverrides;
}

public DaprClientBuilder newDaprClientBuilder() {
Expand Down Expand Up @@ -239,17 +269,13 @@ public String getAppName() {

public DaprClient newDaprClient() {
return new DaprClientBuilder()
.withPropertyOverride(Properties.GRPC_PORT, ports.getGrpcPort().toString())
.withPropertyOverride(Properties.HTTP_PORT, ports.getHttpPort().toString())
.withPropertyOverride(Properties.SIDECAR_IP, "127.0.0.1")
.withPropertyOverrides(this.getPropertyOverrides())
.build();
}

public DaprPreviewClient newDaprPreviewClient() {
return new DaprClientBuilder()
.withPropertyOverride(Properties.GRPC_PORT, ports.getGrpcPort().toString())
.withPropertyOverride(Properties.HTTP_PORT, ports.getHttpPort().toString())
.withPropertyOverride(Properties.SIDECAR_IP, "127.0.0.1")
.withPropertyOverrides(this.getPropertyOverrides())
.buildPreviewClient();
}

Expand Down Expand Up @@ -298,6 +324,22 @@ private static boolean isAppHealthCheckEnabled(Class serviceClass) {
return false;
}

private static String resolveDaprApiToken(Class serviceClass) {
if (serviceClass != null) {
DaprRunConfig daprRunConfig = (DaprRunConfig) serviceClass.getAnnotation(DaprRunConfig.class);
if (daprRunConfig != null) {
if (!daprRunConfig.enableDaprApiToken()) {
return null;
}
// We use the clas name itself as the token. Just needs to be deterministic.
return serviceClass.getCanonicalName();
}
}

// By default, we use a token.
return DEFAULT_DAPR_API_TOKEN;
}

private static void assertListeningOnPort(int port) {
System.out.printf("Checking port %d ...\n", port);

Expand Down Expand Up @@ -325,6 +367,8 @@ static class Builder {

private AppRun.AppProtocol appProtocol;

private String daprApiToken;

Builder(
String testName,
Supplier<DaprPorts> portsSupplier,
Expand All @@ -336,6 +380,7 @@ static class Builder {
this.successMessage = successMessage;
this.maxWaitMilliseconds = maxWaitMilliseconds;
this.appProtocol = appProtocol;
this.daprApiToken = UUID.randomUUID().toString();
}

public Builder withServiceClass(Class serviceClass) {
Expand Down Expand Up @@ -371,7 +416,8 @@ ImmutablePair<AppRun, DaprRun> splitBuild() {
DAPR_SUCCESS_MESSAGE,
null,
this.maxWaitMilliseconds,
this.appProtocol);
this.appProtocol,
resolveDaprApiToken(serviceClass));

return new ImmutablePair<>(appRun, daprRun);
}
Expand Down
2 changes: 2 additions & 0 deletions sdk-tests/src/test/java/io/dapr/it/DaprRunConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,6 @@
public @interface DaprRunConfig {

boolean enableAppHealthCheck() default false;

boolean enableDaprApiToken() default true;
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
package io.dapr.it.actors.app;

import io.dapr.actors.runtime.ActorRuntime;
import io.dapr.it.DaprRunConfig;

// Enable dapr-api-token once runtime supports it in standalone mode.
@DaprRunConfig(enableDaprApiToken = false)
public class MyActorService {
public static final String SUCCESS_MESSAGE = "dapr initialized. Status: Running";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
package io.dapr.it.actors.services.springboot;

import io.dapr.actors.runtime.ActorRuntime;
import io.dapr.it.DaprRunConfig;
import io.dapr.serializer.DefaultObjectSerializer;

import java.time.Duration;

@DaprRunConfig(enableDaprApiToken = false)
public class StatefulActorService {

public static final String SUCCESS_MESSAGE = "dapr initialized. Status: Running";
Expand Down
Loading
Loading