From c6dfe0c722e25d130529a8bf939388080e7a6ca4 Mon Sep 17 00:00:00 2001
From: Pritham Marupaka <pmarupaka@palantir.com>
Date: Mon, 6 Jan 2025 11:55:42 -0500
Subject: [PATCH] wip

---
 .../annotations/ConjureErrorDecoder.java      |   1 +
 .../java/dialogue/serde/ConjureBodySerDe.java | 151 ++++--------------
 .../dialogue/serde/EndpointErrorDecoder.java  |  42 +++--
 .../dialogue/serde/ConjureBodySerDeTest.java  |  16 +-
 .../EndpointErrorsConjureBodySerDeTest.java   |   1 +
 5 files changed, 70 insertions(+), 141 deletions(-)

diff --git a/dialogue-annotations/src/main/java/com/palantir/dialogue/annotations/ConjureErrorDecoder.java b/dialogue-annotations/src/main/java/com/palantir/dialogue/annotations/ConjureErrorDecoder.java
index 382c97560..1bf3a01eb 100644
--- a/dialogue-annotations/src/main/java/com/palantir/dialogue/annotations/ConjureErrorDecoder.java
+++ b/dialogue-annotations/src/main/java/com/palantir/dialogue/annotations/ConjureErrorDecoder.java
@@ -18,6 +18,7 @@
 
 import com.palantir.dialogue.Response;
 
+// TODO(pm): use the new EndpointErrorDecoder
 public final class ConjureErrorDecoder implements ErrorDecoder {
 
     @Override
diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java
index 8e67cddd2..ffa8bfd51 100644
--- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java
+++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java
@@ -45,6 +45,7 @@
 import java.util.ArrayList;
 import java.util.Comparator;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
@@ -52,7 +53,6 @@
 /**
  * items:
  * - we don't want to use `String` for the error identifier. Let's create an `ErrorName` class.
- * - re-consider using a map for the deserializersForEndpointBaseType field. is there a more direct way to get this info
  */
 
 /** Package private internal API. */
@@ -65,7 +65,7 @@ final class ConjureBodySerDe implements BodySerDe {
     private final Deserializer<Optional<InputStream>> optionalBinaryInputStreamDeserializer;
     private final Deserializer<Void> emptyBodyDeserializer;
     private final LoadingCache<Type, Serializer<?>> serializers;
-    private final LoadingCache<Type, EncodingDeserializerRegistry<?>> deserializers;
+    private final LoadingCache<Type, EncodingDeserializerForEndpointRegistry<?>> deserializers;
     private final EmptyContainerDeserializer emptyContainerDeserializer;
 
     /**
@@ -75,32 +75,49 @@ final class ConjureBodySerDe implements BodySerDe {
      */
     ConjureBodySerDe(
             List<WeightedEncoding> rawEncodings,
-            ErrorDecoder errorDecoder,
+            ErrorDecoder _errorDecoder,
             EmptyContainerDeserializer emptyContainerDeserializer,
             CaffeineSpec cacheSpec) {
         List<WeightedEncoding> encodings = decorateEncodings(rawEncodings);
         this.encodingsSortedByWeight = sortByWeight(encodings);
         Preconditions.checkArgument(encodings.size() > 0, "At least one Encoding is required");
+        // note(pm): why do the weighted encoding thing? can we just pass in the default encoding?
         this.defaultEncoding = encodings.get(0).encoding();
         this.emptyContainerDeserializer = emptyContainerDeserializer;
-        this.binaryInputStreamDeserializer = new EncodingDeserializerRegistry<>(
+        this.binaryInputStreamDeserializer = new EncodingDeserializerForEndpointRegistry<>(
                 ImmutableList.of(BinaryEncoding.INSTANCE),
-                errorDecoder,
                 emptyContainerDeserializer,
-                BinaryEncoding.MARKER);
-        this.optionalBinaryInputStreamDeserializer = new EncodingDeserializerRegistry<>(
+                BinaryEncoding.MARKER,
+                DeserializerArgs.<InputStream>builder()
+                        .withBaseType(BinaryEncoding.MARKER)
+                        .withExpectedResult(BinaryEncoding.MARKER)
+                        .build());
+        this.optionalBinaryInputStreamDeserializer = new EncodingDeserializerForEndpointRegistry<>(
                 ImmutableList.of(BinaryEncoding.INSTANCE),
-                errorDecoder,
                 emptyContainerDeserializer,
-                BinaryEncoding.OPTIONAL_MARKER);
-        this.emptyBodyDeserializer = new EmptyBodyDeserializer(errorDecoder);
+                BinaryEncoding.OPTIONAL_MARKER,
+                DeserializerArgs.<Optional<InputStream>>builder()
+                        .withBaseType(BinaryEncoding.OPTIONAL_MARKER)
+                        .withExpectedResult(BinaryEncoding.OPTIONAL_MARKER)
+                        .build());
+        this.emptyBodyDeserializer =
+                new EmptyBodyDeserializer(new EndpointErrorDecoder<>(Map.of(), encodingsSortedByWeight));
         // Class unloading: Not supported, Jackson keeps strong references to the types
         // it sees: https://github.com/FasterXML/jackson-databind/issues/489
         this.serializers = Caffeine.from(cacheSpec)
                 .build(type -> new EncodingSerializerRegistry<>(defaultEncoding, TypeMarker.of(type)));
-        this.deserializers = Caffeine.from(cacheSpec)
-                .build(type -> new EncodingDeserializerRegistry<>(
-                        encodingsSortedByWeight, errorDecoder, emptyContainerDeserializer, TypeMarker.of(type)));
+        this.deserializers = Caffeine.from(cacheSpec).build(type -> buildCacheEntry(TypeMarker.of(type)));
+    }
+
+    private <T> EncodingDeserializerForEndpointRegistry<?> buildCacheEntry(TypeMarker<T> typeMarker) {
+        return new EncodingDeserializerForEndpointRegistry<>(
+                encodingsSortedByWeight,
+                emptyContainerDeserializer,
+                typeMarker,
+                DeserializerArgs.<T>builder()
+                        .withBaseType(typeMarker)
+                        .withExpectedResult(typeMarker)
+                        .build());
     }
 
     private static List<WeightedEncoding> decorateEncodings(List<WeightedEncoding> input) {
@@ -235,108 +252,7 @@ private static final class EncodingSerializerContainer<T> {
         }
     }
 
-    private static final class EncodingDeserializerRegistry<T> implements Deserializer<T> {
-
-        private static final SafeLogger log = SafeLoggerFactory.get(EncodingDeserializerRegistry.class);
-        private final ImmutableList<EncodingDeserializerContainer<T>> encodings;
-        private final ErrorDecoder errorDecoder;
-        private final Optional<String> acceptValue;
-        private final Supplier<Optional<T>> emptyInstance;
-        private final TypeMarker<T> token;
-
-        EncodingDeserializerRegistry(
-                List<Encoding> encodings,
-                ErrorDecoder errorDecoder,
-                EmptyContainerDeserializer empty,
-                TypeMarker<T> token) {
-            this.encodings = encodings.stream()
-                    .map(encoding -> new EncodingDeserializerContainer<>(encoding, token))
-                    .collect(ImmutableList.toImmutableList());
-            this.errorDecoder = errorDecoder;
-            this.token = token;
-            this.emptyInstance = Suppliers.memoize(() -> empty.tryGetEmptyInstance(token));
-            // Encodings are applied to the accept header in the order of preference based on the provided list.
-            this.acceptValue =
-                    Optional.of(encodings.stream().map(Encoding::getContentType).collect(Collectors.joining(", ")));
-        }
-
-        @Override
-        public T deserialize(Response response) {
-            boolean closeResponse = true;
-            try {
-                if (errorDecoder.isError(response)) {
-                    throw errorDecoder.decode(response);
-                } else if (response.code() == 204) {
-                    // TODO(dfox): what if we get a 204 for a non-optional type???
-                    // TODO(dfox): support http200 & body=null
-                    // TODO(dfox): what if we were expecting an empty list but got {}?
-                    Optional<T> maybeEmptyInstance = emptyInstance.get();
-                    if (maybeEmptyInstance.isPresent()) {
-                        return maybeEmptyInstance.get();
-                    }
-                    throw new SafeRuntimeException(
-                            "Unable to deserialize non-optional response type from 204", SafeArg.of("type", token));
-                }
-
-                Optional<String> contentType = response.getFirstHeader(HttpHeaders.CONTENT_TYPE);
-                if (!contentType.isPresent()) {
-                    throw new SafeIllegalArgumentException(
-                            "Response is missing Content-Type header",
-                            SafeArg.of("received", response.headers().keySet()));
-                }
-                Encoding.Deserializer<T> deserializer = getResponseDeserializer(contentType.get());
-                T deserialized = deserializer.deserialize(response.body());
-                // deserializer has taken on responsibility for closing the response body
-                closeResponse = false;
-                return deserialized;
-            } catch (IOException e) {
-                throw new SafeRuntimeException(
-                        "Failed to deserialize response stream",
-                        e,
-                        SafeArg.of("contentType", response.getFirstHeader(HttpHeaders.CONTENT_TYPE)),
-                        SafeArg.of("type", token));
-            } finally {
-                if (closeResponse) {
-                    response.close();
-                }
-            }
-        }
-
-        @Override
-        public Optional<String> accepts() {
-            return acceptValue;
-        }
-
-        /** Returns the {@link EncodingDeserializerContainer} to use to deserialize the request body. */
-        @SuppressWarnings("ForLoopReplaceableByForEach")
-        // performance sensitive code avoids iterator allocation
-        Encoding.Deserializer<T> getResponseDeserializer(String contentType) {
-            for (int i = 0; i < encodings.size(); i++) {
-                EncodingDeserializerContainer<T> container = encodings.get(i);
-                if (container.encoding.supportsContentType(contentType)) {
-                    return container.deserializer;
-                }
-            }
-            return throwingDeserializer(contentType);
-        }
-
-        private Encoding.Deserializer<T> throwingDeserializer(String contentType) {
-            return input -> {
-                try {
-                    input.close();
-                } catch (RuntimeException | IOException e) {
-                    log.warn("Failed to close InputStream", e);
-                }
-                throw new SafeRuntimeException(
-                        "Unsupported Content-Type",
-                        SafeArg.of("received", contentType),
-                        SafeArg.of("supportedEncodings", encodings));
-            };
-        }
-    }
-
     private static final class EncodingDeserializerForEndpointRegistry<T> implements Deserializer<T> {
-
         private static final SafeLogger log = SafeLoggerFactory.get(EncodingDeserializerForEndpointRegistry.class);
         private final ImmutableList<EncodingDeserializerContainer<? extends T>> encodings;
         private final EndpointErrorDecoder<T> endpointErrorDecoder;
@@ -367,7 +283,6 @@ public T deserialize(Response response) {
             boolean closeResponse = true;
             try {
                 if (endpointErrorDecoder.isError(response)) {
-                    // TODO(pm): This needs to return T for the new deserializer API, but throw an exception for the old
                     return endpointErrorDecoder.decode(response);
                 } else if (response.code() == 204) {
                     Optional<T> maybeEmptyInstance = emptyInstance.get();
@@ -457,9 +372,9 @@ public String toString() {
     }
 
     private static final class EmptyBodyDeserializer implements Deserializer<Void> {
-        private final ErrorDecoder errorDecoder;
+        private final EndpointErrorDecoder<?> errorDecoder;
 
-        EmptyBodyDeserializer(ErrorDecoder errorDecoder) {
+        EmptyBodyDeserializer(EndpointErrorDecoder<?> errorDecoder) {
             this.errorDecoder = errorDecoder;
         }
 
@@ -469,7 +384,7 @@ public Void deserialize(Response response) {
             // We should not fail if a server that previously returned nothing starts returning a response
             try (Response unused = response) {
                 if (errorDecoder.isError(response)) {
-                    throw errorDecoder.decode(response);
+                    errorDecoder.decode(response);
                 }
                 return null;
             }
diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java
index f2ecfdfcc..dff70a07b 100644
--- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java
+++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java
@@ -127,27 +127,32 @@ private T decodeInternal(Response response) {
         }
 
         Optional<String> contentType = response.getFirstHeader(HttpHeaders.CONTENT_TYPE);
-        // Use a factory: given contentType, create the deserailizer.
+        // Use a factory: given contentType, create the deserializer.
         // We need Encoding.Deserializer here. That depends on the encoding.
-        if (contentType.isPresent() && Encodings.matchesContentType("application/json", contentType.get())) {
+        String jsonContentType = "application/json";
+        if (contentType.isPresent() && Encodings.matchesContentType(jsonContentType, contentType.get())) {
             try {
                 JsonNode node = MAPPER.readTree(body);
-                if (node.get("errorName") != null) {
-                    // TODO(pm): Update this to use some struct instead of errorName.
-                    TypeMarker<? extends T> container = Optional.ofNullable(
-                                    errorNameToTypeMap.get(node.get("errorName").asText()))
-                            .orElseThrow();
-                    for (int i = 0; i < encodings.size(); i++) {
-                        Encoding encoding = encodings.get(i);
-                        if (encoding.supportsContentType(contentType.get())) {
-                            return encoding.deserializer(container)
-                                    .deserialize(new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8)));
-                        }
+                if (node.get("errorName") == null) {
+                    throwSerializableError(body, code);
+                }
+                // TODO(pm): Update this to use some struct instead of errorName.
+                Optional<TypeMarker<? extends T>> maybeContainer = Optional.ofNullable(
+                        errorNameToTypeMap.get(node.get("errorName").asText()));
+                if (maybeContainer.isEmpty()) {
+                    // This thrown exception will be caught below. Refactor.
+                    throwSerializableError(body, code);
+                }
+                for (int i = 0; i < encodings.size(); i++) {
+                    Encoding encoding = encodings.get(i);
+                    if (encoding.supportsContentType(jsonContentType)) {
+                        return encoding.deserializer(maybeContainer.get())
+                                .deserialize(new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8)));
                     }
-                } else {
-                    SerializableError serializableError = MAPPER.readValue(body, SerializableError.class);
-                    throw new RemoteException(serializableError, code);
                 }
+            } catch (RemoteException remoteException) {
+                // rethrow the created remote exception
+                throw remoteException;
             } catch (Exception e) {
                 throw new UnknownRemoteException(code, body);
             }
@@ -156,6 +161,11 @@ private T decodeInternal(Response response) {
         throw new UnknownRemoteException(code, body);
     }
 
+    private static void throwSerializableError(String body, int code) throws IOException {
+        SerializableError serializableError = MAPPER.readValue(body, SerializableError.class);
+        throw new RemoteException(serializableError, code);
+    }
+
     private static String toString(InputStream body) throws IOException {
         try (Reader reader = new InputStreamReader(body, StandardCharsets.UTF_8)) {
             return CharStreams.toString(reader);
diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java
index da7ea260c..38ef424de 100644
--- a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java
+++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java
@@ -22,11 +22,14 @@
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.collect.ImmutableList;
 import com.palantir.conjure.java.api.errors.ErrorType;
 import com.palantir.conjure.java.api.errors.RemoteException;
 import com.palantir.conjure.java.api.errors.SerializableError;
 import com.palantir.conjure.java.api.errors.ServiceException;
+import com.palantir.conjure.java.serialization.ObjectMappers;
 import com.palantir.dialogue.BinaryRequestBody;
 import com.palantir.dialogue.BodySerDe;
 import com.palantir.dialogue.RequestBody;
@@ -47,6 +50,7 @@
 @ExtendWith(MockitoExtension.class)
 public class ConjureBodySerDeTest {
 
+    private static final ObjectMapper SERVER_MAPPER = ObjectMappers.newServerObjectMapper();
     private static final TypeMarker<String> TYPE = new TypeMarker<String>() {};
     private static final TypeMarker<Optional<String>> OPTIONAL_TYPE = new TypeMarker<Optional<String>>() {};
 
@@ -137,14 +141,12 @@ public void testRequestUnknownContentType() throws IOException {
     }
 
     @Test
-    public void testErrorsDecoded() {
-        TestResponse response = new TestResponse().code(400);
-
+    public void testErrorsDecoded() throws JsonProcessingException {
         ServiceException serviceException = new ServiceException(ErrorType.INVALID_ARGUMENT);
-        SerializableError serialized = SerializableError.forException(serviceException);
-        errorDecoder = mock(ErrorDecoder.class);
-        when(errorDecoder.isError(response)).thenReturn(true);
-        when(errorDecoder.decode(response)).thenReturn(new RemoteException(serialized, 400));
+        TestResponse response = TestResponse.withBody(
+                        SERVER_MAPPER.writeValueAsString(SerializableError.forException(serviceException)))
+                .code(400)
+                .contentType("application/json");
 
         BodySerDe serializers = conjureBodySerDe("text/plain");
 
diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java
index 296bb193d..eb06ec457 100644
--- a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java
+++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java
@@ -144,6 +144,7 @@ public void testDeserializeCustomErrors() throws IOException {
         EndpointErrorsConjureBodySerDeTest.EndpointReturnBaseType value =
                 serializers.deserializer(deserializerArgs).deserialize(response);
 
+        assertThat(value).isInstanceOf(ErrorForEndpoint.class);
         assertThat(value)
                 .extracting("errorCode", "errorName", "errorInstanceId", "args")
                 .containsExactly(