diff --git a/pom.xml b/pom.xml index a039be2..a534cc9 100644 --- a/pom.xml +++ b/pom.xml @@ -31,6 +31,9 @@ 17 17 + 3.25.4 + 3.25.4 + 1.66.0 @@ -59,8 +62,36 @@ 4.13.2 test + + com.google.protobuf + protobuf-java + ${protobuf.version} + + + io.grpc + grpc-netty-shaded + ${grpc.version} + + + + io.grpc + grpc-protobuf + ${grpc.version} + + + io.grpc + grpc-stub + ${grpc.version} + + + + kr.motd.maven + os-maven-plugin + 1.7.1 + + @@ -127,6 +158,25 @@ published --> + + + org.xolstice.maven.plugins + protobuf-maven-plugin + 0.6.1 + + com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier} + grpc-java + io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier} + + + + + compile + compile-custom + + + + diff --git a/src/main/java/ai/spice/Config.java b/src/main/java/ai/spice/Config.java index fc24d85..db50ada 100644 --- a/src/main/java/ai/spice/Config.java +++ b/src/main/java/ai/spice/Config.java @@ -34,10 +34,6 @@ public class Config { public static final String CLOUD_FLIGHT_ADDRESS; /** Local flight address */ public static final String LOCAL_FLIGHT_ADDRESS; - /** Cloud HTTP address */ - public static final String CLOUD_HTTP_ADDRESS; - /** Local HTTP address */ - public static final String LOCAL_HTTP_ADDRESS; static { CLOUD_FLIGHT_ADDRESS = System.getenv("SPICE_FLIGHT_URL") != null ? System.getenv("SPICE_FLIGHT_URL") @@ -45,12 +41,6 @@ public class Config { LOCAL_FLIGHT_ADDRESS = System.getenv("SPICE_FLIGHT_URL") != null ? System.getenv("SPICE_FLIGHT_URL") : "http://localhost:50051"; - - CLOUD_HTTP_ADDRESS = System.getenv("SPICE_HTTP_URL") != null ? System.getenv("SPICE_HTTP_URL") - : "https://data.spiceai.io"; - - LOCAL_HTTP_ADDRESS = System.getenv("SPICE_HTTP_URL") != null ? System.getenv("SPICE_HTTP_URL") - : "http://localhost:8090"; } /** @@ -72,24 +62,4 @@ public static URI getLocalFlightAddressUri() throws URISyntaxException { public static URI getCloudFlightAddressUri() throws URISyntaxException { return new URI(CLOUD_FLIGHT_ADDRESS); } - - /** - * Returns the local HTTP address - * - * @return URI of the local HTTP address. - * @throws URISyntaxException if the string could not be parsed as a URI. - */ - public static URI getLocalHttpAddressUri() throws URISyntaxException { - return new URI(LOCAL_HTTP_ADDRESS); - } - - /** - * Returns the cloud HTTP address - * - * @return URI of the cloud HTTP address. - * @throws URISyntaxException if the string could not be parsed as a URI. - */ - public static URI getCloudHttpAddressUri() throws URISyntaxException { - return new URI(CLOUD_HTTP_ADDRESS); - } } \ No newline at end of file diff --git a/src/main/java/ai/spice/SpiceClient.java b/src/main/java/ai/spice/SpiceClient.java index 07842d0..f2532e0 100644 --- a/src/main/java/ai/spice/SpiceClient.java +++ b/src/main/java/ai/spice/SpiceClient.java @@ -22,14 +22,11 @@ of this software and associated documentation files (the "Software"), to deal package ai.spice; -import java.net.ConnectException; import java.net.URI; import java.net.URISyntaxException; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; import java.util.concurrent.ExecutionException; +import org.apache.arrow.flight.Action; import org.apache.arrow.flight.CallStatus; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightClient.Builder; @@ -42,6 +39,7 @@ of this software and associated documentation files (the "Software"), to deal import org.apache.arrow.flight.grpc.CredentialCallOption; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightStatusCode; import org.apache.arrow.memory.RootAllocator; import com.github.rholder.retry.RetryException; @@ -52,6 +50,7 @@ of this software and associated documentation files (the "Software"), to deal import com.google.common.base.Strings; import org.apache.arrow.flight.sql.FlightSqlClient; +import ai.spice.proto.Spice.AcceleratedDatasetRefreshRequest; /** * Client to execute SQL queries against Spice.ai Cloud and Spice.ai OSS @@ -61,9 +60,9 @@ public class SpiceClient implements AutoCloseable { private String appId; private String apiKey; private URI flightAddress; - private URI httpAddress; private int maxRetries; - private FlightSqlClient flightClient; + private FlightSqlClient flightSqlClient; + private FlightClient flightClient; private CredentialCallOption authCallOptions = null; /** @@ -85,15 +84,12 @@ public static SpiceClientBuilder builder() throws URISyntaxException { * services * @param flightAddress the URI of the flight address for connecting to Spice.ai * services - * @param httpAddress the URI of the Spice.ai runtime HTTP address - * * @param maxRetries the maximum number of connection retries for the client */ - public SpiceClient(String appId, String apiKey, URI flightAddress, URI httpAddress, int maxRetries) { + public SpiceClient(String appId, String apiKey, URI flightAddress, int maxRetries) { this.appId = appId; this.apiKey = apiKey; this.maxRetries = maxRetries; - this.httpAddress = httpAddress; // Arrow Flight requires URI to be grpc protocol, convert http/https for // convinience @@ -108,7 +104,8 @@ public SpiceClient(String appId, String apiKey, URI flightAddress, URI httpAddre Builder builder = FlightClient.builder(new RootAllocator(Long.MAX_VALUE), new Location(this.flightAddress)); if (Strings.isNullOrEmpty(apiKey)) { - this.flightClient = new FlightSqlClient(builder.build()); + this.flightClient = builder.build(); + this.flightSqlClient = new FlightSqlClient(this.flightClient); return; } @@ -118,7 +115,8 @@ public SpiceClient(String appId, String apiKey, URI flightAddress, URI httpAddre final FlightClient client = builder.intercept(factory).build(); client.handshake(new CredentialCallOption(new BasicAuthCredentialWriter(this.appId, this.apiKey))); this.authCallOptions = factory.getCredentialCallOption(); - this.flightClient = new FlightSqlClient(client); + this.flightClient = client; + this.flightSqlClient = new FlightSqlClient(client); } /** @@ -137,6 +135,10 @@ public FlightStream query(String sql) throws ExecutionException { return this.queryInternalWithRetry(sql); } catch (RetryException e) { Throwable err = e.getLastFailedAttempt().getExceptionCause(); + if (err instanceof FlightRuntimeException) { + maybeSpiceConnectionError((FlightRuntimeException) err); + } + throw new ExecutionException("Failed to execute query due to error: " + err.toString(), err); } } @@ -147,37 +149,31 @@ public void refresh(String dataset) throws ExecutionException { } try { - HttpClient client = HttpClient.newHttpClient(); - HttpRequest request = HttpRequest.newBuilder() - .uri(new URI(String.format("%s/v1/datasets/%s/acceleration/refresh", this.httpAddress, dataset))) - .header("Content-Type", "application/json") - .POST(HttpRequest.BodyPublishers.noBody()) - .build(); - - HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); - - if (response.statusCode() != 201) { - throw new ExecutionException( - String.format("Failed to trigger dataset refresh. Status Code: %d, Response: %s", - response.statusCode(), - response.body()), - null); + + AcceleratedDatasetRefreshRequest request = AcceleratedDatasetRefreshRequest.newBuilder() + .setDatasetName(dataset).build(); + + Action action = new Action("AcceleratedDatasetRefresh", request.toByteArray()); + + if (!this.flightClient.doAction(action).hasNext()) { + throw new ExecutionException("Failed to trigger dataset refresh: No response from the server.", null); } + } catch (ExecutionException e) { // no need to wrap ExecutionException throw e; - } catch (ConnectException err) { - throw new ExecutionException( - String.format("The Spice runtime is unavailable at %s. Is it running?", this.httpAddress), err); + } catch (FlightRuntimeException err) { + maybeSpiceConnectionError(err); + throw new ExecutionException("Failed to trigger dataset refresh due to error: " + err.toString(), err); } catch (Exception err) { throw new ExecutionException("Failed to trigger dataset refresh due to error: " + err.toString(), err); } } private FlightStream queryInternal(String sql) { - FlightInfo flightInfo = this.flightClient.execute(sql, authCallOptions); + FlightInfo flightInfo = this.flightSqlClient.execute(sql, authCallOptions); Ticket ticket = flightInfo.getEndpoints().get(0).getTicket(); - return this.flightClient.getStream(ticket, authCallOptions); + return this.flightSqlClient.getStream(ticket, authCallOptions); } private FlightStream queryInternalWithRetry(String sql) throws ExecutionException, RetryException { @@ -209,8 +205,15 @@ private boolean shouldRetry(CallStatus status) { } } + private void maybeSpiceConnectionError(FlightRuntimeException err) throws ExecutionException { + if (err.status().code() == FlightStatusCode.UNAVAILABLE) { + throw new ExecutionException( + String.format("The Spice runtime is unavailable at %s. Is it running?", this.flightAddress), err); + } + } + @Override public void close() throws Exception { - this.flightClient.close(); + this.flightSqlClient.close(); } } diff --git a/src/main/java/ai/spice/SpiceClientBuilder.java b/src/main/java/ai/spice/SpiceClientBuilder.java index 08209ac..25b7ecc 100644 --- a/src/main/java/ai/spice/SpiceClientBuilder.java +++ b/src/main/java/ai/spice/SpiceClientBuilder.java @@ -35,7 +35,6 @@ public class SpiceClientBuilder { private String appId; private String apiKey; private URI flightAddress; - private URI httpAddress; private int maxRetries = 3; /** @@ -45,7 +44,6 @@ public class SpiceClientBuilder { */ SpiceClientBuilder() throws URISyntaxException { this.flightAddress = Config.getLocalFlightAddressUri(); - this.httpAddress = Config.getLocalHttpAddressUri(); } /** @@ -62,20 +60,6 @@ public SpiceClientBuilder withFlightAddress(URI flightAddress) { return this; } - /** - * Sets the client's HTTP address - * - * @param httpAddress The URI of the HTTP address - * @return The current instance of SpiceClientBuilder for method chaining. - */ - public SpiceClientBuilder withHttpAddress(URI httpAddress) { - if (httpAddress == null) { - throw new IllegalArgumentException("httpAddress can't be null"); - } - this.httpAddress = httpAddress; - return this; - } - /** * Sets the client's Api Key. * @@ -106,7 +90,6 @@ public SpiceClientBuilder withApiKey(String apiKey) { */ public SpiceClientBuilder withSpiceCloud() throws URISyntaxException { this.flightAddress = Config.getCloudFlightAddressUri(); - this.httpAddress = Config.getCloudHttpAddressUri(); return this; } @@ -130,6 +113,6 @@ public SpiceClientBuilder withMaxRetries(int maxRetries) { * @return The SpiceClient instance */ public SpiceClient build() { - return new SpiceClient(appId, apiKey, flightAddress, httpAddress, maxRetries); + return new SpiceClient(appId, apiKey, flightAddress, maxRetries); } } \ No newline at end of file diff --git a/src/main/java/ai/spice/example/ExampleDatasetRefreshSpiceOSS.java b/src/main/java/ai/spice/example/ExampleDatasetRefreshSpiceOSS.java index ae40dfc..e10808d 100644 --- a/src/main/java/ai/spice/example/ExampleDatasetRefreshSpiceOSS.java +++ b/src/main/java/ai/spice/example/ExampleDatasetRefreshSpiceOSS.java @@ -22,8 +22,6 @@ of this software and associated documentation files (the "Software"), to deal package ai.spice.example; -import java.net.URI; - import org.apache.arrow.flight.FlightStream; import org.apache.arrow.vector.VectorSchemaRoot; @@ -40,8 +38,6 @@ public class ExampleDatasetRefreshSpiceOSS { public static void main(String[] args) { try (SpiceClient client = SpiceClient.builder() - .withFlightAddress(URI.create("http://localhost:50051")) - .withHttpAddress(URI.create("http://localhost:8090")) .build()) { client.refresh("taxi_trips"); diff --git a/src/main/proto/spice.proto b/src/main/proto/spice.proto new file mode 100644 index 0000000..81e0fa6 --- /dev/null +++ b/src/main/proto/spice.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +option java_package = "ai.spice.proto"; + +message AcceleratedDatasetRefreshRequest { + string dataset_name = 1; +} \ No newline at end of file