Skip to content

Commit

Permalink
1
Browse files Browse the repository at this point in the history
  • Loading branch information
xinyiZzz committed Dec 31, 2024
1 parent 8e09e1b commit de57672
Show file tree
Hide file tree
Showing 3 changed files with 456 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,12 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.Text;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

Expand Down Expand Up @@ -131,21 +133,19 @@ private void getStreamStatementResult(String handle, ServerStreamListener listen
String[] handleParts = handle.split(":");
String executedPeerIdentity = handleParts[0];
String queryId = handleParts[1];
// The tokens used for authentication between getStreamStatement and getFlightInfoStatement are different.
ConnectContext connectContext = flightSessionsManager.getConnectContext(executedPeerIdentity);
try {
// The tokens used for authentication between getStreamStatement and getFlightInfoStatement are different.
final FlightSqlResultCacheEntry flightSqlResultCacheEntry = Objects.requireNonNull(
connectContext.getFlightSqlChannel().getResult(queryId));
final VectorSchemaRoot vectorSchemaRoot = flightSqlResultCacheEntry.getVectorSchemaRoot();
listener.start(vectorSchemaRoot);
listener.putNext();
} catch (Exception e) {
listener.error(e);
String errMsg = "get stream statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(e)
+ ", error code: " + connectContext.getState().getErrorCode() + ", error msg: "
+ connectContext.getState().getErrorMessage();
LOG.warn(errMsg, e);
throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();
handleStreamException(e, errMsg, listener);
} finally {
listener.completed();
// The result has been sent or sent failed, delete it.
Expand Down Expand Up @@ -280,7 +280,7 @@ private FlightInfo executeQueryStatement(String peerIdentity, ConnectContext con
String errMsg = "get flight info statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(e)
+ ", error code: " + connectContext.getState().getErrorCode() + ", error msg: "
+ connectContext.getState().getErrorMessage();
LOG.warn(errMsg, e);
LOG.error(errMsg, e);
throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();
} finally {
connectContext.setCommand(MysqlCommand.COM_SLEEP);
Expand Down Expand Up @@ -361,7 +361,7 @@ public void createPreparedStatement(final ActionCreatePreparedStatementRequest r
String errMsg = "create prepared statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(
e) + ", error code: " + connectContext.getState().getErrorCode() + ", error msg: "
+ connectContext.getState().getErrorMessage();
LOG.warn(errMsg, e);
LOG.error(errMsg, e);
listener.onError(CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException());
return;
} catch (final Throwable t) {
Expand Down Expand Up @@ -407,7 +407,7 @@ public Runnable acceptPutPreparedStatementUpdate(CommandPreparedStatementUpdate
} catch (Exception e) {
String errMsg = "acceptPutPreparedStatementUpdate failed, " + e.getMessage() + ", "
+ Util.getRootCauseMessage(e);
LOG.warn(errMsg, e);
LOG.error(errMsg, e);
throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();
}
};
Expand Down Expand Up @@ -451,7 +451,21 @@ public FlightInfo getFlightInfoCatalogs(final CommandGetCatalogs request, final

@Override
public void getStreamCatalogs(final CallContext context, final ServerStreamListener listener) {
throw CallStatus.UNIMPLEMENTED.withDescription("getStreamCatalogs unimplemented").toRuntimeException();
try (final VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(Schemas.GET_CATALOGS_SCHEMA,
rootAllocator)) {
listener.start(vectorSchemaRoot);
vectorSchemaRoot.allocateNew();
VarCharVector catalogNameVector = (VarCharVector) vectorSchemaRoot.getVector("catalog_name");
// Only show Internal Catalog, which is consistent with `jdbc:mysql`.
// Otherwise, if the configured ExternalCatalog cannot be connected,
// `catalog.getAllDbs()` will be stuck and wait until the timeout period ends.
catalogNameVector.setSafe(0, new Text("internal"));
vectorSchemaRoot.setRowCount(1);
listener.putNext();
listener.completed();
} catch (final Exception ex) {
handleStreamException(ex, "", listener);
}
}

@Override
Expand All @@ -463,7 +477,22 @@ public FlightInfo getFlightInfoSchemas(final CommandGetDbSchemas request, final
@Override
public void getStreamSchemas(final CommandGetDbSchemas command, final CallContext context,
final ServerStreamListener listener) {
throw CallStatus.UNIMPLEMENTED.withDescription("getStreamSchemas unimplemented").toRuntimeException();
try {
ConnectContext connectContext = flightSessionsManager.getConnectContext(context.peerIdentity());
FlightSqlSchemaHelper flightSqlSchemaHelper = new FlightSqlSchemaHelper(connectContext);
flightSqlSchemaHelper.setParameterForGetDbSchemas(command);
final Schema schema = Schemas.GET_SCHEMAS_SCHEMA;

try (VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) {
listener.start(vectorSchemaRoot);
vectorSchemaRoot.allocateNew();
flightSqlSchemaHelper.getSchemas(vectorSchemaRoot);
listener.putNext();
listener.completed();
}
} catch (final Exception e) {
handleStreamException(e, "", listener);
}
}

@Override
Expand All @@ -479,7 +508,23 @@ public FlightInfo getFlightInfoTables(final CommandGetTables request, final Call
@Override
public void getStreamTables(final CommandGetTables command, final CallContext context,
final ServerStreamListener listener) {
throw CallStatus.UNIMPLEMENTED.withDescription("getStreamTables unimplemented").toRuntimeException();
try {
ConnectContext connectContext = flightSessionsManager.getConnectContext(context.peerIdentity());
FlightSqlSchemaHelper flightSqlSchemaHelper = new FlightSqlSchemaHelper(connectContext);
flightSqlSchemaHelper.setParameterForGetTables(command);
final Schema schema = command.getIncludeSchema() ? Schemas.GET_TABLES_SCHEMA
: Schemas.GET_TABLES_SCHEMA_NO_SCHEMA;

try (VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) {
listener.start(vectorSchemaRoot);
vectorSchemaRoot.allocateNew();
flightSqlSchemaHelper.getTables(vectorSchemaRoot);
listener.putNext();
listener.completed();
}
} catch (final Exception e) {
handleStreamException(e, "", listener);
}
}

@Override
Expand Down Expand Up @@ -545,9 +590,14 @@ public void getStreamCrossReference(CommandGetCrossReference command, CallContex
private <T extends Message> FlightInfo getFlightInfoForSchema(final T request, final FlightDescriptor descriptor,
final Schema schema) {
final Ticket ticket = new Ticket(Any.pack(request).toByteArray());
// TODO Support multiple endpoints.
final List<FlightEndpoint> endpoints = Collections.singletonList(new FlightEndpoint(ticket, location));

return new FlightInfo(schema, descriptor, endpoints, -1, -1);
}

private static void handleStreamException(Exception e, String errMsg, ServerStreamListener listener) {
LOG.error(errMsg, e);
listener.error(CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException());
throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();
}
}
Loading

0 comments on commit de57672

Please sign in to comment.