Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Pritham Marupaka committed Dec 6, 2024
1 parent 6d14c96 commit 6081541
Show file tree
Hide file tree
Showing 8 changed files with 749 additions and 3 deletions.
6 changes: 6 additions & 0 deletions .palantir/revapi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,9 @@ acceptedBreaks:
new: "method com.palantir.dialogue.clients.DialogueClients.StickyChannelSession\
\ com.palantir.dialogue.clients.DialogueClients.StickyChannelFactory2::session()"
justification: "interface for consumption, not extension"
"4.6.0":
com.palantir.dialogue:dialogue-target:
- code: "java.method.addedToInterface"
new: "method <T> com.palantir.dialogue.Deserializer<T> com.palantir.dialogue.BodySerDe::deserializer(com.palantir.dialogue.DeserializerArgs<T>)"
justification: "Adding a new method to create deserializers in support of endpoint\
\ associated error deserialization"
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.palantir.dialogue.BinaryRequestBody;
import com.palantir.dialogue.BodySerDe;
import com.palantir.dialogue.Deserializer;
import com.palantir.dialogue.DeserializerArgs;
import com.palantir.dialogue.RequestBody;
import com.palantir.dialogue.Response;
import com.palantir.dialogue.Serializer;
Expand All @@ -48,6 +49,12 @@
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
* items:
* - we don't want to use `String` for the error identifier. Let's create an `ErrorName` class.
* - re-consider using a map for the deserializersForEndpointBaseType field. is there a more direct way to get this info
*/

/** Package private internal API. */
final class ConjureBodySerDe implements BodySerDe {

Expand All @@ -58,7 +65,8 @@ final class ConjureBodySerDe implements BodySerDe {
private final Deserializer<Optional<InputStream>> optionalBinaryInputStreamDeserializer;
private final Deserializer<Void> emptyBodyDeserializer;
private final LoadingCache<Type, Serializer<?>> serializers;
private final LoadingCache<Type, Deserializer<?>> deserializers;
private final LoadingCache<Type, EncodingDeserializerRegistry<?>> deserializers;
private final EmptyContainerDeserializer emptyContainerDeserializer;

/**
* Selects the first (based on input order) of the provided encodings that
Expand All @@ -74,6 +82,7 @@ final class ConjureBodySerDe implements BodySerDe {
this.encodingsSortedByWeight = sortByWeight(encodings);
Preconditions.checkArgument(encodings.size() > 0, "At least one Encoding is required");
this.defaultEncoding = encodings.get(0).encoding();
this.emptyContainerDeserializer = emptyContainerDeserializer;
this.binaryInputStreamDeserializer = new EncodingDeserializerRegistry<>(
ImmutableList.of(BinaryEncoding.INSTANCE),
errorDecoder,
Expand Down Expand Up @@ -122,6 +131,16 @@ public <T> Deserializer<T> deserializer(TypeMarker<T> token) {
return (Deserializer<T>) deserializers.get(token.getType());
}

@Override
@SuppressWarnings("unchecked")
public <T> Deserializer<T> deserializer(DeserializerArgs<T> deserializerArgs) {
return new EncodingDeserializerForEndpointRegistry<>(
encodingsSortedByWeight,
emptyContainerDeserializer,
(TypeMarker<T>) deserializerArgs.baseType(),
deserializerArgs);
}

@Override
public Deserializer<Void> emptyBodyDeserializer() {
return emptyBodyDeserializer;
Expand Down Expand Up @@ -301,6 +320,105 @@ Encoding.Deserializer<T> getResponseDeserializer(String contentType) {
return throwingDeserializer(contentType);
}

private Encoding.Deserializer<T> throwingDeserializer(String contentType) {
return input -> {
try {
input.close();
} catch (RuntimeException | IOException e) {
log.warn("Failed to close InputStream", e);
}
throw new SafeRuntimeException(
"Unsupported Content-Type",
SafeArg.of("received", contentType),
SafeArg.of("supportedEncodings", encodings));
};
}
}

private static final class EncodingDeserializerForEndpointRegistry<T> implements Deserializer<T> {

private static final SafeLogger log = SafeLoggerFactory.get(EncodingDeserializerForEndpointRegistry.class);
private final ImmutableList<EncodingDeserializerContainer<? extends T>> encodings;
private final EndpointErrorDecoder<T> endpointErrorDecoder;
private final Optional<String> acceptValue;
private final Supplier<Optional<T>> emptyInstance;
private final TypeMarker<T> token;

EncodingDeserializerForEndpointRegistry(
List<Encoding> encodings,
EmptyContainerDeserializer empty,
TypeMarker<T> token,
DeserializerArgs<T> deserializersForEndpoint) {
this.encodings = encodings.stream()
.map(encoding -> new EncodingDeserializerContainer<>(
encoding, deserializersForEndpoint.expectedResultType()))
.collect(ImmutableList.toImmutableList());
this.endpointErrorDecoder =
new EndpointErrorDecoder<>(deserializersForEndpoint.errorNameToTypeMarker(), encodings);
this.token = token;
this.emptyInstance = Suppliers.memoize(() -> empty.tryGetEmptyInstance(token));
// Encodings are applied to the accept header in the order of preference based on the provided list.
this.acceptValue =
Optional.of(encodings.stream().map(Encoding::getContentType).collect(Collectors.joining(", ")));
}

@Override
public T deserialize(Response response) {
boolean closeResponse = true;
try {
if (endpointErrorDecoder.isError(response)) {
return endpointErrorDecoder.decode(response);
} else if (response.code() == 204) {
Optional<T> maybeEmptyInstance = emptyInstance.get();
if (maybeEmptyInstance.isPresent()) {
return maybeEmptyInstance.get();
}
throw new SafeRuntimeException(
"Unable to deserialize non-optional response type from 204", SafeArg.of("type", token));
}

Optional<String> contentType = response.getFirstHeader(HttpHeaders.CONTENT_TYPE);
if (!contentType.isPresent()) {
throw new SafeIllegalArgumentException(
"Response is missing Content-Type header",
SafeArg.of("received", response.headers().keySet()));
}
Encoding.Deserializer<? extends T> deserializer = getResponseDeserializer(contentType.get());
T deserialized = deserializer.deserialize(response.body());
// deserializer has taken on responsibility for closing the response body
closeResponse = false;
return deserialized;
} catch (IOException e) {
throw new SafeRuntimeException(
"Failed to deserialize response stream",
e,
SafeArg.of("contentType", response.getFirstHeader(HttpHeaders.CONTENT_TYPE)),
SafeArg.of("type", token));
} finally {
if (closeResponse) {
response.close();
}
}
}

@Override
public Optional<String> accepts() {
return acceptValue;
}

/** Returns the {@link EncodingDeserializerContainer} to use to deserialize the request body. */
@SuppressWarnings("ForLoopReplaceableByForEach")
// performance sensitive code avoids iterator allocation
Encoding.Deserializer<? extends T> getResponseDeserializer(String contentType) {
for (int i = 0; i < encodings.size(); i++) {
EncodingDeserializerContainer<? extends T> container = encodings.get(i);
if (container.encoding.supportsContentType(contentType)) {
return container.deserializer;
}
}
return throwingDeserializer(contentType);
}

private Encoding.Deserializer<T> throwingDeserializer(String contentType) {
return new Encoding.Deserializer<T>() {
@Override
Expand All @@ -320,7 +438,8 @@ public T deserialize(InputStream input) {
}

/** Effectively just a pair. */
private static final class EncodingDeserializerContainer<T> {
// TODO(pm): what does saving the deserializer do for us?
static final class EncodingDeserializerContainer<T> {

private final Encoding encoding;
private final Encoding.Deserializer<T> deserializer;
Expand Down
Loading

0 comments on commit 6081541

Please sign in to comment.