Skip to content

Commit

Permalink
1
Browse files Browse the repository at this point in the history
  • Loading branch information
xinyiZzz committed Dec 25, 2024
1 parent 8e09e1b commit 55f2bcd
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
Expand Down Expand Up @@ -131,21 +133,19 @@ private void getStreamStatementResult(String handle, ServerStreamListener listen
String[] handleParts = handle.split(":");
String executedPeerIdentity = handleParts[0];
String queryId = handleParts[1];
// The tokens used for authentication between getStreamStatement and getFlightInfoStatement are different.
ConnectContext connectContext = flightSessionsManager.getConnectContext(executedPeerIdentity);
try {
// The tokens used for authentication between getStreamStatement and getFlightInfoStatement are different.
final FlightSqlResultCacheEntry flightSqlResultCacheEntry = Objects.requireNonNull(
connectContext.getFlightSqlChannel().getResult(queryId));
final VectorSchemaRoot vectorSchemaRoot = flightSqlResultCacheEntry.getVectorSchemaRoot();
listener.start(vectorSchemaRoot);
listener.putNext();
} catch (Exception e) {
listener.error(e);
String errMsg = "get stream statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(e)
+ ", error code: " + connectContext.getState().getErrorCode() + ", error msg: "
+ connectContext.getState().getErrorMessage();
LOG.warn(errMsg, e);
throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();
handleStreamException(e, errMsg, listener);
} finally {
listener.completed();
// The result has been sent or sent failed, delete it.
Expand Down Expand Up @@ -280,7 +280,7 @@ private FlightInfo executeQueryStatement(String peerIdentity, ConnectContext con
String errMsg = "get flight info statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(e)
+ ", error code: " + connectContext.getState().getErrorCode() + ", error msg: "
+ connectContext.getState().getErrorMessage();
LOG.warn(errMsg, e);
LOG.error(errMsg, e);
throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();
} finally {
connectContext.setCommand(MysqlCommand.COM_SLEEP);
Expand Down Expand Up @@ -361,7 +361,7 @@ public void createPreparedStatement(final ActionCreatePreparedStatementRequest r
String errMsg = "create prepared statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(
e) + ", error code: " + connectContext.getState().getErrorCode() + ", error msg: "
+ connectContext.getState().getErrorMessage();
LOG.warn(errMsg, e);
LOG.error(errMsg, e);
listener.onError(CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException());
return;
} catch (final Throwable t) {
Expand Down Expand Up @@ -407,7 +407,7 @@ public Runnable acceptPutPreparedStatementUpdate(CommandPreparedStatementUpdate
} catch (Exception e) {
String errMsg = "acceptPutPreparedStatementUpdate failed, " + e.getMessage() + ", "
+ Util.getRootCauseMessage(e);
LOG.warn(errMsg, e);
LOG.error(errMsg, e);
throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();
}
};
Expand Down Expand Up @@ -479,7 +479,82 @@ public FlightInfo getFlightInfoTables(final CommandGetTables request, final Call
@Override
public void getStreamTables(final CommandGetTables command, final CallContext context,
final ServerStreamListener listener) {
throw CallStatus.UNIMPLEMENTED.withDescription("getStreamTables unimplemented").toRuntimeException();
// throw CallStatus.UNIMPLEMENTED.withDescription("getStreamTables unimplemented").toRuntimeException();
try {
// if (commandGetTables.hasDbSchemaFilterPattern()) {
// builder.setSchemaNameFilter(
// UserProtos.LikeFilter.newBuilder()
// .setPattern(command.getDbSchemaFilterPattern())
// .build());
// }
//
// if (commandGetTables.hasTableNameFilterPattern()) {
// builder.setTableNameFilter(
// UserProtos.LikeFilter.newBuilder()
// .setPattern(command.getTableNameFilterPattern())
// .build());
// }
//
// if (!commandGetTables.getTableTypesList().isEmpty()) {
// builder.addAllTableTypeFilter(command.getTableTypesList());
// }

final boolean includeSchema = command.getIncludeSchema();

// final Map<UserProtos.TableMetadata, List<Field>> tableToFields;
// if (includeSchema) {
// tableToFields = runGetColumns(isRequestCancelled, userSession, runExternalId, getTablesReq);
// } else {
// tableToFields = null;
// }

final Schema schema =
includeSchema
? FlightSqlProducer.Schemas.GET_TABLES_SCHEMA
: FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA;

try (VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) {
listener.start(vectorSchemaRoot);

vectorSchemaRoot.allocateNew();
VarCharVector catalogNameVector = (VarCharVector) vectorSchemaRoot.getVector("catalog_name");
VarCharVector schemaNameVector = (VarCharVector) vectorSchemaRoot.getVector("db_schema_name");
VarCharVector tableNameVector = (VarCharVector) vectorSchemaRoot.getVector("table_name");
VarCharVector tableTypeVector = (VarCharVector) vectorSchemaRoot.getVector("table_type");
VarBinaryVector schemaVector = (VarBinaryVector) vectorSchemaRoot.getVector("table_schema");

// final int tablesCount = getTablesResp.getTablesCount();
// final IntStream range = range(0, tablesCount);
//
// range.forEach(
// i -> {
// final UserProtos.TableMetadata table = getTablesResp.getTables(i);
// catalogNameVector.setNull(i);
// schemaNameVector.setSafe(i, new Text(table.getSchemaName()));
// tableTypeVector.setSafe(i, new Text(table.getType()));
//
// final String tableName = table.getTableName();
// tableNameVector.setSafe(i, new Text(tableName));
//
// if (includeSchema) {
// List<Field> fields =
// tableToFields.get(
// UserProtos.TableMetadata.newBuilder()
// .setSchemaName(table.getSchemaName())
// .setTableName(table.getTableName())
// .build());
// schemaVector.setSafe(i, getSerializedSchema(fields));
// }
// });

// vectorSchemaRoot.setRowCount(tablesCount);
vectorSchemaRoot.setRowCount(0);
listener.putNext();
listener.completed();
}
} catch (final Exception e) {
handleStreamException(e, "", listener);
}
}

@Override
Expand Down Expand Up @@ -550,4 +625,10 @@ private <T extends Message> FlightInfo getFlightInfoForSchema(final T request, f

return new FlightInfo(schema, descriptor, endpoints, -1, -1);
}

private static void handleStreamException(Exception e, String errMsg, ServerStreamListener listener) {
LOG.error(errMsg, e);
listener.error(CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException());
throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ public static FlightAuthResult authenticateCredentials(String username, String p
Logger logger) {
try {
List<UserIdentity> currentUserIdentity = Lists.newArrayList();

// If the password is empty, DBeaver will pass "null" string for authentication.
// This behavior of DBeaver is strange, but we have to be compatible with it, of course,
// it may be a problem with Arrow Flight Jdbc driver.
// Here, "null" is converted to null, if user's password is really the string "null",
// authentication will fail. Usually, the user's password will not be "null", let's hope so.
password = (password.equals("null")) ? null : password;
Env.getCurrentEnv().getAuth().checkPlainPassword(username, remoteIp, password, currentUserIdentity);
Preconditions.checkState(currentUserIdentity.size() == 1);
return FlightAuthResult.of(username, currentUserIdentity.get(0), remoteIp);
Expand Down

0 comments on commit 55f2bcd

Please sign in to comment.