diff --git a/engine/src/main/java/io/seldon/engine/service/InternalPredictionService.java b/engine/src/main/java/io/seldon/engine/service/InternalPredictionService.java index b7fd1130ad..4747d9f9a0 100644 --- a/engine/src/main/java/io/seldon/engine/service/InternalPredictionService.java +++ b/engine/src/main/java/io/seldon/engine/service/InternalPredictionService.java @@ -101,7 +101,7 @@ public class InternalPredictionService { private final GrpcChannelHandler grpcChannelHandler; private final Map headersCache = new ConcurrentHashMap<>(); - private final Map uriCache = new ConcurrentHashMap<>(); + private final Map uriCache = new ConcurrentHashMap<>(); @Autowired public InternalPredictionService(RestTemplateBuilder restTemplateBuilder,AnnotationsConfig annotations,GrpcChannelHandler grpcChannelHandler,TracingProvider tracingProvider){ @@ -344,6 +344,12 @@ private boolean isDefaultData(SeldonMessage message){ return true; return false; } + + public static String getUriKey(Endpoint endpoint,String path) + { + StringBuilder sb = new StringBuilder(); + return sb.append(endpoint.getServiceHost()).append(":").append(endpoint.getServicePort()).append(path).toString(); + } private SeldonMessage queryREST(String path, String dataString, PredictiveUnitState state, Endpoint endpoint, boolean isDefault) { @@ -351,8 +357,9 @@ private SeldonMessage queryREST(String path, String dataString, PredictiveUnitSt URI uri; try { - if (uriCache.containsKey(endpoint)) - uri = uriCache.get(endpoint); + final String uriKey = getUriKey(endpoint, path); + if (uriCache.containsKey(uriKey)) + uri = uriCache.get(uriKey); else { URIBuilder builder = new URIBuilder().setScheme("http") @@ -360,7 +367,7 @@ private SeldonMessage queryREST(String path, String dataString, PredictiveUnitSt .setPort(endpoint.getServicePort()) .setPath("/"+path); uri = builder.build(); - uriCache.put(endpoint, uri); + uriCache.put(uriKey, uri); } } catch (URISyntaxException e) { diff --git a/engine/src/test/java/io/seldon/engine/service/UriCacheTest.java b/engine/src/test/java/io/seldon/engine/service/UriCacheTest.java new file mode 100644 index 0000000000..a491c59841 --- /dev/null +++ b/engine/src/test/java/io/seldon/engine/service/UriCacheTest.java @@ -0,0 +1,36 @@ +package io.seldon.engine.service; + +import org.junit.Test; +import org.junit.Assert; +import io.seldon.protos.DeploymentProtos.Endpoint; + + +public class UriCacheTest { + + @Test + public void testUri() + { + Endpoint endpointA = Endpoint.newBuilder().setServiceHost("hostA").setServicePort(1000).build(); + Endpoint endpointA2 = Endpoint.newBuilder().setServiceHost("hostA").setServicePort(1000).build(); + Endpoint endpointB = Endpoint.newBuilder().setServiceHost("hostB").setServicePort(1000).build(); + final String predictPath = "/predict"; + final String predictPath2 = "/predict"; + final String feedbackPath = "/feedback"; + + final String key1 = InternalPredictionService.getUriKey(endpointA, predictPath); + final String key2 = InternalPredictionService.getUriKey(endpointB, predictPath); + final String key3 = InternalPredictionService.getUriKey(endpointA, feedbackPath); + + Assert.assertNotEquals(key1, key2); + Assert.assertNotEquals(key1, key3); + + final String key4 = InternalPredictionService.getUriKey(endpointA2, predictPath); + + Assert.assertEquals(key1, key4); + + final String key5 = InternalPredictionService.getUriKey(endpointA, predictPath2); + + Assert.assertEquals(key1, key5); + } + +}