Skip to content

Commit

Permalink
Refresh command via Spice gRPC API
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrebnov committed Aug 25, 2024
1 parent 65bb360 commit 8a2e715
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 85 deletions.
50 changes: 50 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
<properties>
<maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target>
<protobuf.version>3.25.4</protobuf.version>
<protoc.version>3.25.4</protoc.version>
<grpc.version>1.66.0</grpc.version>
</properties>
<dependencies>
<dependency>
Expand Down Expand Up @@ -59,8 +62,36 @@
<version>4.13.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>${protobuf.version}</version>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-netty-shaded</artifactId>
<version>${grpc.version}</version>
</dependency>

<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-protobuf</artifactId>
<version>${grpc.version}</version>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-stub</artifactId>
<version>${grpc.version}</version>
</dependency>
</dependencies>
<build>
<extensions>
<extension>
<groupId>kr.motd.maven</groupId>
<artifactId>os-maven-plugin</artifactId>
<version>1.7.1</version>
</extension>
</extensions>
<plugins>
<!-- https://arrow.apache.org/docs/java/install.html#java-compatibility -->
<plugin>
Expand Down Expand Up @@ -127,6 +158,25 @@
<waitUntil>published</waitUntil> -->
</configuration>
</plugin>
<!-- compile proto file into java files. -->
<plugin>
<groupId>org.xolstice.maven.plugins</groupId>
<artifactId>protobuf-maven-plugin</artifactId>
<version>0.6.1</version>
<configuration>
<protocArtifact>com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier}</protocArtifact>
<pluginId>grpc-java</pluginId>
<pluginArtifact>io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier}</pluginArtifact>
</configuration>
<executions>
<execution>
<goals>
<goal>compile</goal>
<goal>compile-custom</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
30 changes: 0 additions & 30 deletions src/main/java/ai/spice/Config.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,13 @@ public class Config {
public static final String CLOUD_FLIGHT_ADDRESS;
/** Local flight address */
public static final String LOCAL_FLIGHT_ADDRESS;
/** Cloud HTTP address */
public static final String CLOUD_HTTP_ADDRESS;
/** Local HTTP address */
public static final String LOCAL_HTTP_ADDRESS;

static {
CLOUD_FLIGHT_ADDRESS = System.getenv("SPICE_FLIGHT_URL") != null ? System.getenv("SPICE_FLIGHT_URL")
: "https://flight.spiceai.io:443";

LOCAL_FLIGHT_ADDRESS = System.getenv("SPICE_FLIGHT_URL") != null ? System.getenv("SPICE_FLIGHT_URL")
: "http://localhost:50051";

CLOUD_HTTP_ADDRESS = System.getenv("SPICE_HTTP_URL") != null ? System.getenv("SPICE_HTTP_URL")
: "https://data.spiceai.io";

LOCAL_HTTP_ADDRESS = System.getenv("SPICE_HTTP_URL") != null ? System.getenv("SPICE_HTTP_URL")
: "http://localhost:8090";
}

/**
Expand All @@ -72,24 +62,4 @@ public static URI getLocalFlightAddressUri() throws URISyntaxException {
public static URI getCloudFlightAddressUri() throws URISyntaxException {
return new URI(CLOUD_FLIGHT_ADDRESS);
}

/**
* Returns the local HTTP address
*
* @return URI of the local HTTP address.
* @throws URISyntaxException if the string could not be parsed as a URI.
*/
public static URI getLocalHttpAddressUri() throws URISyntaxException {
return new URI(LOCAL_HTTP_ADDRESS);
}

/**
* Returns the cloud HTTP address
*
* @return URI of the cloud HTTP address.
* @throws URISyntaxException if the string could not be parsed as a URI.
*/
public static URI getCloudHttpAddressUri() throws URISyntaxException {
return new URI(CLOUD_HTTP_ADDRESS);
}
}
69 changes: 36 additions & 33 deletions src/main/java/ai/spice/SpiceClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,11 @@ of this software and associated documentation files (the "Software"), to deal

package ai.spice;

import java.net.ConnectException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.concurrent.ExecutionException;

import org.apache.arrow.flight.Action;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightClient.Builder;
Expand All @@ -42,6 +39,7 @@ of this software and associated documentation files (the "Software"), to deal
import org.apache.arrow.flight.grpc.CredentialCallOption;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.memory.RootAllocator;

import com.github.rholder.retry.RetryException;
Expand All @@ -52,6 +50,7 @@ of this software and associated documentation files (the "Software"), to deal
import com.google.common.base.Strings;

import org.apache.arrow.flight.sql.FlightSqlClient;
import ai.spice.proto.Spice.AcceleratedDatasetRefreshRequest;

/**
* Client to execute SQL queries against Spice.ai Cloud and Spice.ai OSS
Expand All @@ -61,9 +60,9 @@ public class SpiceClient implements AutoCloseable {
private String appId;
private String apiKey;
private URI flightAddress;
private URI httpAddress;
private int maxRetries;
private FlightSqlClient flightClient;
private FlightSqlClient flightSqlClient;
private FlightClient flightClient;
private CredentialCallOption authCallOptions = null;

/**
Expand All @@ -85,15 +84,12 @@ public static SpiceClientBuilder builder() throws URISyntaxException {
* services
* @param flightAddress the URI of the flight address for connecting to Spice.ai
* services
* @param httpAddress the URI of the Spice.ai runtime HTTP address
*
* @param maxRetries the maximum number of connection retries for the client
*/
public SpiceClient(String appId, String apiKey, URI flightAddress, URI httpAddress, int maxRetries) {
public SpiceClient(String appId, String apiKey, URI flightAddress, int maxRetries) {
this.appId = appId;
this.apiKey = apiKey;
this.maxRetries = maxRetries;
this.httpAddress = httpAddress;

// Arrow Flight requires URI to be grpc protocol, convert http/https for
// convinience
Expand All @@ -108,7 +104,8 @@ public SpiceClient(String appId, String apiKey, URI flightAddress, URI httpAddre
Builder builder = FlightClient.builder(new RootAllocator(Long.MAX_VALUE), new Location(this.flightAddress));

if (Strings.isNullOrEmpty(apiKey)) {
this.flightClient = new FlightSqlClient(builder.build());
this.flightClient = builder.build();
this.flightSqlClient = new FlightSqlClient(this.flightClient);
return;
}

Expand All @@ -118,7 +115,8 @@ public SpiceClient(String appId, String apiKey, URI flightAddress, URI httpAddre
final FlightClient client = builder.intercept(factory).build();
client.handshake(new CredentialCallOption(new BasicAuthCredentialWriter(this.appId, this.apiKey)));
this.authCallOptions = factory.getCredentialCallOption();
this.flightClient = new FlightSqlClient(client);
this.flightClient = client;
this.flightSqlClient = new FlightSqlClient(client);
}

/**
Expand All @@ -137,6 +135,10 @@ public FlightStream query(String sql) throws ExecutionException {
return this.queryInternalWithRetry(sql);
} catch (RetryException e) {
Throwable err = e.getLastFailedAttempt().getExceptionCause();
if (err instanceof FlightRuntimeException) {
maybeSpiceConnectionError((FlightRuntimeException) err);
}

throw new ExecutionException("Failed to execute query due to error: " + err.toString(), err);
}
}
Expand All @@ -147,37 +149,31 @@ public void refresh(String dataset) throws ExecutionException {
}

try {
HttpClient client = HttpClient.newHttpClient();
HttpRequest request = HttpRequest.newBuilder()
.uri(new URI(String.format("%s/v1/datasets/%s/acceleration/refresh", this.httpAddress, dataset)))
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.noBody())
.build();

HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());

if (response.statusCode() != 201) {
throw new ExecutionException(
String.format("Failed to trigger dataset refresh. Status Code: %d, Response: %s",
response.statusCode(),
response.body()),
null);

AcceleratedDatasetRefreshRequest request = AcceleratedDatasetRefreshRequest.newBuilder()
.setDatasetName(dataset).build();

Action action = new Action("AcceleratedDatasetRefresh", request.toByteArray());

if (!this.flightClient.doAction(action).hasNext()) {
throw new ExecutionException("Failed to trigger dataset refresh: No response from the server.", null);
}

} catch (ExecutionException e) {
// no need to wrap ExecutionException
throw e;
} catch (ConnectException err) {
throw new ExecutionException(
String.format("The Spice runtime is unavailable at %s. Is it running?", this.httpAddress), err);
} catch (FlightRuntimeException err) {
maybeSpiceConnectionError(err);
throw new ExecutionException("Failed to trigger dataset refresh due to error: " + err.toString(), err);
} catch (Exception err) {
throw new ExecutionException("Failed to trigger dataset refresh due to error: " + err.toString(), err);
}
}

private FlightStream queryInternal(String sql) {
FlightInfo flightInfo = this.flightClient.execute(sql, authCallOptions);
FlightInfo flightInfo = this.flightSqlClient.execute(sql, authCallOptions);
Ticket ticket = flightInfo.getEndpoints().get(0).getTicket();
return this.flightClient.getStream(ticket, authCallOptions);
return this.flightSqlClient.getStream(ticket, authCallOptions);
}

private FlightStream queryInternalWithRetry(String sql) throws ExecutionException, RetryException {
Expand Down Expand Up @@ -209,8 +205,15 @@ private boolean shouldRetry(CallStatus status) {
}
}

private void maybeSpiceConnectionError(FlightRuntimeException err) throws ExecutionException {
if (err.status().code() == FlightStatusCode.UNAVAILABLE) {
throw new ExecutionException(
String.format("The Spice runtime is unavailable at %s. Is it running?", this.flightAddress), err);
}
}

@Override
public void close() throws Exception {
this.flightClient.close();
this.flightSqlClient.close();
}
}
19 changes: 1 addition & 18 deletions src/main/java/ai/spice/SpiceClientBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ public class SpiceClientBuilder {
private String appId;
private String apiKey;
private URI flightAddress;
private URI httpAddress;
private int maxRetries = 3;

/**
Expand All @@ -45,7 +44,6 @@ public class SpiceClientBuilder {
*/
SpiceClientBuilder() throws URISyntaxException {
this.flightAddress = Config.getLocalFlightAddressUri();
this.httpAddress = Config.getLocalHttpAddressUri();
}

/**
Expand All @@ -62,20 +60,6 @@ public SpiceClientBuilder withFlightAddress(URI flightAddress) {
return this;
}

/**
* Sets the client's HTTP address
*
* @param httpAddress The URI of the HTTP address
* @return The current instance of SpiceClientBuilder for method chaining.
*/
public SpiceClientBuilder withHttpAddress(URI httpAddress) {
if (httpAddress == null) {
throw new IllegalArgumentException("httpAddress can't be null");
}
this.httpAddress = httpAddress;
return this;
}

/**
* Sets the client's Api Key.
*
Expand Down Expand Up @@ -106,7 +90,6 @@ public SpiceClientBuilder withApiKey(String apiKey) {
*/
public SpiceClientBuilder withSpiceCloud() throws URISyntaxException {
this.flightAddress = Config.getCloudFlightAddressUri();
this.httpAddress = Config.getCloudHttpAddressUri();
return this;
}

Expand All @@ -130,6 +113,6 @@ public SpiceClientBuilder withMaxRetries(int maxRetries) {
* @return The SpiceClient instance
*/
public SpiceClient build() {
return new SpiceClient(appId, apiKey, flightAddress, httpAddress, maxRetries);
return new SpiceClient(appId, apiKey, flightAddress, maxRetries);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ of this software and associated documentation files (the "Software"), to deal

package ai.spice.example;

import java.net.URI;

import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.vector.VectorSchemaRoot;

Expand All @@ -40,8 +38,6 @@ public class ExampleDatasetRefreshSpiceOSS {

public static void main(String[] args) {
try (SpiceClient client = SpiceClient.builder()
.withFlightAddress(URI.create("http://localhost:50051"))
.withHttpAddress(URI.create("http://localhost:8090"))
.build()) {

client.refresh("taxi_trips");
Expand Down
7 changes: 7 additions & 0 deletions src/main/proto/spice.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
syntax = "proto3";

option java_package = "ai.spice.proto";

message AcceleratedDatasetRefreshRequest {
string dataset_name = 1;
}

0 comments on commit 8a2e715

Please sign in to comment.