Infer missing arrow schema (#233)
* Refactor: pass the whole metadata object to result handlers

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

* Infer Arrow schema when it is not available

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

---------

Signed-off-by: Levko Kravets <levko.ne@gmail.com>
kravets-levko authored Mar 11, 2024
1 parent ff9fc0d commit b92b10c
Showing 12 changed files with 186 additions and 52 deletions.
10 changes: 5 additions & 5 deletions lib/DBSQLOperation.ts
@@ -372,20 +372,20 @@ export default class DBSQLOperation implements IOperation {

switch (resultFormat) {
case TSparkRowSetType.COLUMN_BASED_SET:
- resultSource = new JsonResultHandler(this.context, this._data, metadata.schema);
+ resultSource = new JsonResultHandler(this.context, this._data, metadata);
break;
case TSparkRowSetType.ARROW_BASED_SET:
resultSource = new ArrowResultConverter(
this.context,
- new ArrowResultHandler(this.context, this._data, metadata.arrowSchema, metadata.lz4Compressed),
- metadata.schema,
+ new ArrowResultHandler(this.context, this._data, metadata),
+ metadata,
);
break;
case TSparkRowSetType.URL_BASED_SET:
resultSource = new ArrowResultConverter(
this.context,
- new CloudFetchResultHandler(this.context, this._data, metadata.lz4Compressed),
- metadata.schema,
+ new CloudFetchResultHandler(this.context, this._data, metadata),
+ metadata,
);
break;
// no default
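Note: all three handlers now receive the same metadata object instead of individually picked fields. For orientation, the fields of TGetResultSetMetadataResp that the handlers consume are sketched below (an assumed excerpt, not the generated Thrift code verbatim; see thrift/TCLIService_types for the authoritative definition):

    interface TGetResultSetMetadataResp {
      schema?: TTableSchema; // Thrift/Hive schema, used by the JSON handler and the Arrow converter
      arrowSchema?: Buffer; // serialized Arrow schema, may be absent on old DBR versions
      lz4Compressed?: boolean; // whether result batches are LZ4-compressed
      resultFormat?: TSparkRowSetType; // drives the switch above
    }

Passing the whole object keeps the call sites uniform and lets each handler pick only what it needs.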
4 changes: 2 additions & 2 deletions lib/result/ArrowResultConverter.ts
@@ -13,7 +13,7 @@ import {
RecordBatchReader,
util as arrowUtils,
} from 'apache-arrow';
- import { TTableSchema, TColumnDesc } from '../../thrift/TCLIService_types';
+ import { TGetResultSetMetadataResp, TColumnDesc } from '../../thrift/TCLIService_types';
import IClientContext from '../contracts/IClientContext';
import IResultsProvider, { ResultsProviderFetchNextOptions } from './IResultsProvider';
import { getSchemaColumns, convertThriftValue } from './utils';
@@ -34,7 +34,7 @@ export default class ArrowResultConverter implements IResultsProvider<Array<any>

private pendingRecordBatch?: RecordBatch<TypeMap>;

- constructor(context: IClientContext, source: IResultsProvider<Array<Buffer>>, schema?: TTableSchema) {
+ constructor(context: IClientContext, source: IResultsProvider<Array<Buffer>>, { schema }: TGetResultSetMetadataResp) {
this.context = context;
this.source = source;
this.schema = getSchemaColumns(schema);
12 changes: 7 additions & 5 deletions lib/result/ArrowResultHandler.ts
@@ -1,7 +1,8 @@
import LZ4 from 'lz4';
- import { TRowSet } from '../../thrift/TCLIService_types';
+ import { TGetResultSetMetadataResp, TRowSet } from '../../thrift/TCLIService_types';
import IClientContext from '../contracts/IClientContext';
import IResultsProvider, { ResultsProviderFetchNextOptions } from './IResultsProvider';
+ import { hiveSchemaToArrowSchema } from './utils';

export default class ArrowResultHandler implements IResultsProvider<Array<Buffer>> {
protected readonly context: IClientContext;
@@ -15,13 +16,14 @@ export default class ArrowResultHandler implements IResultsProvider<Array<Buffer
constructor(
context: IClientContext,
source: IResultsProvider<TRowSet | undefined>,
- arrowSchema?: Buffer,
- isLZ4Compressed?: boolean,
+ { schema, arrowSchema, lz4Compressed }: TGetResultSetMetadataResp,
) {
this.context = context;
this.source = source;
- this.arrowSchema = arrowSchema;
- this.isLZ4Compressed = isLZ4Compressed ?? false;
+ // Arrow schema is not available in old DBR versions, which also don't support native Arrow types,
+ // so it's possible to infer Arrow schema from Hive schema ignoring `useArrowNativeTypes` option
+ this.arrowSchema = arrowSchema ?? hiveSchemaToArrowSchema(schema);
+ this.isLZ4Compressed = lz4Compressed ?? false;
}

public async hasMore() {
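With the fallback in place, the handler covers both server generations. A condensed sketch of the resulting behavior, mirroring the unit tests further below (context, rowSetProvider, and the sample schemas stand in for the test doubles used there):

    // arrowSchema present: used as-is, exactly as before.
    const explicit = new ArrowResultHandler(context, rowSetProvider, { arrowSchema: sampleArrowSchema });

    // Only the Thrift schema present (old DBR): an Arrow schema is inferred from it.
    const inferred = new ArrowResultHandler(context, rowSetProvider, { schema: sampleThriftSchema });
    // inferred.arrowSchema now holds a serialized Arrow schema buffer.

    // Neither schema present: the handler stays schema-less and yields no batches.
    const bare = new ArrowResultHandler(context, rowSetProvider, {});
    // await bare.fetchNext({ limit: 10000 }) resolves to [], and hasMore() to false.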
10 changes: 7 additions & 3 deletions lib/result/CloudFetchResultHandler.ts
@@ -1,6 +1,6 @@
import LZ4 from 'lz4';
import fetch, { RequestInfo, RequestInit } from 'node-fetch';
- import { TRowSet, TSparkArrowResultLink } from '../../thrift/TCLIService_types';
+ import { TGetResultSetMetadataResp, TRowSet, TSparkArrowResultLink } from '../../thrift/TCLIService_types';
import IClientContext from '../contracts/IClientContext';
import IResultsProvider, { ResultsProviderFetchNextOptions } from './IResultsProvider';

@@ -15,10 +15,14 @@ export default class CloudFetchResultHandler implements IResultsProvider<Array<B

private downloadTasks: Array<Promise<Buffer>> = [];

- constructor(context: IClientContext, source: IResultsProvider<TRowSet | undefined>, isLZ4Compressed?: boolean) {
+ constructor(
+ context: IClientContext,
+ source: IResultsProvider<TRowSet | undefined>,
+ { lz4Compressed }: TGetResultSetMetadataResp,
+ ) {
this.context = context;
this.source = source;
- this.isLZ4Compressed = isLZ4Compressed ?? false;
+ this.isLZ4Compressed = lz4Compressed ?? false;
}

public async hasMore() {
8 changes: 6 additions & 2 deletions lib/result/JsonResultHandler.ts
@@ -1,5 +1,5 @@
import { ColumnCode } from '../hive/Types';
- import { TRowSet, TTableSchema, TColumn, TColumnDesc } from '../../thrift/TCLIService_types';
+ import { TGetResultSetMetadataResp, TRowSet, TColumn, TColumnDesc } from '../../thrift/TCLIService_types';
import IClientContext from '../contracts/IClientContext';
import IResultsProvider, { ResultsProviderFetchNextOptions } from './IResultsProvider';
import { getSchemaColumns, convertThriftValue } from './utils';
@@ -11,7 +11,11 @@ export default class JsonResultHandler implements IResultsProvider<Array<any>> {

private readonly schema: Array<TColumnDesc>;

- constructor(context: IClientContext, source: IResultsProvider<TRowSet | undefined>, schema?: TTableSchema) {
+ constructor(
+ context: IClientContext,
+ source: IResultsProvider<TRowSet | undefined>,
+ { schema }: TGetResultSetMetadataResp,
+ ) {
this.context = context;
this.source = source;
this.schema = getSchemaColumns(schema);
6 changes: 4 additions & 2 deletions lib/result/ResultSlicer.ts
@@ -52,11 +52,13 @@ export default class ResultSlicer<T> implements IResultsProvider<Array<T>> {
// Fetch items from source results provider until we reach a requested count
while (resultsCount < options.limit) {
// eslint-disable-next-line no-await-in-loop
- const chunk = await this.source.fetchNext(options);
- if (chunk.length === 0) {
+ const hasMore = await this.source.hasMore();
+ if (!hasMore) {
break;
}

+ // eslint-disable-next-line no-await-in-loop
+ const chunk = await this.source.fetchNext(options);
result.push(chunk);
resultsCount += chunk.length;
}
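The reorder fixes a potential early exit: fetchNext() can legitimately return an empty chunk while the source still has more data (for example, a row set whose batches decode to zero rows), and the old loop treated any empty chunk as end-of-stream. An illustrative stub, assuming IResultsProvider exposes just hasMore/fetchNext (this is not library code):

    // A source that delivers a zero-row chunk in the middle of the stream.
    const chunks: Array<Array<number>> = [[1, 2], [], [3, 4]];

    const source: IResultsProvider<Array<number>> = {
      async hasMore() {
        return chunks.length > 0;
      },
      async fetchNext(_options: ResultsProviderFetchNextOptions) {
        return chunks.shift() ?? [];
      },
    };

    // Old loop: breaks on the empty middle chunk and never reaches [3, 4].
    // New loop: hasMore() is still true after the empty chunk, so fetching continues.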
67 changes: 67 additions & 0 deletions lib/result/utils.ts
@@ -1,5 +1,23 @@
import Int64 from 'node-int64';
+ import {
+ Schema,
+ Field,
+ DataType,
+ Bool as ArrowBool,
+ Int8 as ArrowInt8,
+ Int16 as ArrowInt16,
+ Int32 as ArrowInt32,
+ Int64 as ArrowInt64,
+ Float32 as ArrowFloat32,
+ Float64 as ArrowFloat64,
+ Utf8 as ArrowString,
+ Date_ as ArrowDate,
+ Binary as ArrowBinary,
+ DateUnit,
+ RecordBatchWriter,
+ } from 'apache-arrow';
import { TTableSchema, TColumnDesc, TPrimitiveTypeEntry, TTypeId } from '../../thrift/TCLIService_types';
+ import HiveDriverError from '../errors/HiveDriverError';

export function getSchemaColumns(schema?: TTableSchema): Array<TColumnDesc> {
if (!schema) {
@@ -73,3 +91,52 @@ export function convertThriftValue(typeDescriptor: TPrimitiveTypeEntry | undefin
return value;
}
}

+ // This type map corresponds to Arrow without native types support (most complex types are serialized as strings)
+ const hiveTypeToArrowType: Record<TTypeId, DataType | null> = {
+ [TTypeId.BOOLEAN_TYPE]: new ArrowBool(),
+ [TTypeId.TINYINT_TYPE]: new ArrowInt8(),
+ [TTypeId.SMALLINT_TYPE]: new ArrowInt16(),
+ [TTypeId.INT_TYPE]: new ArrowInt32(),
+ [TTypeId.BIGINT_TYPE]: new ArrowInt64(),
+ [TTypeId.FLOAT_TYPE]: new ArrowFloat32(),
+ [TTypeId.DOUBLE_TYPE]: new ArrowFloat64(),
+ [TTypeId.STRING_TYPE]: new ArrowString(),
+ [TTypeId.TIMESTAMP_TYPE]: new ArrowString(),
+ [TTypeId.BINARY_TYPE]: new ArrowBinary(),
+ [TTypeId.ARRAY_TYPE]: new ArrowString(),
+ [TTypeId.MAP_TYPE]: new ArrowString(),
+ [TTypeId.STRUCT_TYPE]: new ArrowString(),
+ [TTypeId.UNION_TYPE]: new ArrowString(),
+ [TTypeId.USER_DEFINED_TYPE]: new ArrowString(),
+ [TTypeId.DECIMAL_TYPE]: new ArrowString(),
+ [TTypeId.NULL_TYPE]: null,
+ [TTypeId.DATE_TYPE]: new ArrowDate(DateUnit.DAY),
+ [TTypeId.VARCHAR_TYPE]: new ArrowString(),
+ [TTypeId.CHAR_TYPE]: new ArrowString(),
+ [TTypeId.INTERVAL_YEAR_MONTH_TYPE]: new ArrowString(),
+ [TTypeId.INTERVAL_DAY_TIME_TYPE]: new ArrowString(),
+ };
+
+ export function hiveSchemaToArrowSchema(schema?: TTableSchema): Buffer | undefined {
+ if (!schema) {
+ return undefined;
+ }
+
+ const columns = getSchemaColumns(schema);
+
+ const arrowFields = columns.map((column) => {
+ const hiveType = column.typeDesc.types[0].primitiveEntry?.type ?? undefined;
+ const arrowType = hiveType !== undefined ? hiveTypeToArrowType[hiveType] : undefined;
+ if (!arrowType) {
+ throw new HiveDriverError(`Unsupported column type: ${hiveType ? TTypeId[hiveType] : 'undefined'}`);
+ }
+ return new Field(column.columnName, arrowType, true);
+ });
+
+ const arrowSchema = new Schema(arrowFields);
+ const writer = new RecordBatchWriter();
+ writer.reset(undefined, arrowSchema);
+ writer.finish();
+ return Buffer.from(writer.toUint8Array(true));
+ }
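A quick way to sanity-check the buffer this helper produces is to read it back (a minimal sketch, assuming apache-arrow's RecordBatchReader can open a schema-only IPC stream and that sampleThriftSchema is any TTableSchema value; the import path is illustrative):

    import { RecordBatchReader } from 'apache-arrow';
    import { hiveSchemaToArrowSchema } from './lib/result/utils';

    const schemaBuffer = hiveSchemaToArrowSchema(sampleThriftSchema);
    if (schemaBuffer) {
      const reader = RecordBatchReader.from(schemaBuffer);
      reader.open(); // reads the IPC stream header, which carries the schema
      console.log(reader.schema.fields.map((f) => `${f.name}: ${f.type}`));
    }

Note that most complex Hive types (timestamps, decimals, intervals, arrays, maps, structs) are deliberately mapped to Utf8, matching servers without native Arrow type support, as the comment above the type map says.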
8 changes: 4 additions & 4 deletions tests/unit/result/ArrowResultConverter.test.js
@@ -57,30 +57,30 @@ describe('ArrowResultHandler', () => {
it('should convert data', async () => {
const context = {};
const rowSetProvider = new ResultsProviderMock([sampleArrowBatch]);
- const result = new ArrowResultConverter(context, rowSetProvider, sampleThriftSchema);
+ const result = new ArrowResultConverter(context, rowSetProvider, { schema: sampleThriftSchema });
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([{ 1: 1 }]);
});

it('should return empty array if no data to process', async () => {
const context = {};
const rowSetProvider = new ResultsProviderMock([], []);
- const result = new ArrowResultConverter(context, rowSetProvider, sampleThriftSchema);
+ const result = new ArrowResultConverter(context, rowSetProvider, { schema: sampleThriftSchema });
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([]);
expect(await result.hasMore()).to.be.false;
});

it('should return empty array if no schema available', async () => {
const context = {};
const rowSetProvider = new ResultsProviderMock([sampleArrowBatch]);
- const result = new ArrowResultConverter(context, rowSetProvider);
+ const result = new ArrowResultConverter(context, rowSetProvider, {});
expect(await result.hasMore()).to.be.false;
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([]);
});

it('should detect nulls', async () => {
const context = {};
const rowSetProvider = new ResultsProviderMock([arrowBatchAllNulls]);
- const result = new ArrowResultConverter(context, rowSetProvider, thriftSchemaAllNulls);
+ const result = new ArrowResultConverter(context, rowSetProvider, { schema: thriftSchemaAllNulls });
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([
{
boolean_field: null,
46 changes: 38 additions & 8 deletions tests/unit/result/ArrowResultHandler.test.js
@@ -61,7 +61,7 @@ describe('ArrowResultHandler', () => {
it('should return data', async () => {
const context = {};
const rowSetProvider = new ResultsProviderMock([sampleRowSet1]);
- const result = new ArrowResultHandler(context, rowSetProvider, sampleArrowSchema);
+ const result = new ArrowResultHandler(context, rowSetProvider, { arrowSchema: sampleArrowSchema });

const batches = await result.fetchNext({ limit: 10000 });
expect(await rowSetProvider.hasMore()).to.be.false;
@@ -74,7 +74,10 @@ describe('ArrowResultHandler', () => {
it('should handle LZ4 compressed data', async () => {
const context = {};
const rowSetProvider = new ResultsProviderMock([sampleRowSet1LZ4Compressed]);
- const result = new ArrowResultHandler(context, rowSetProvider, sampleArrowSchema, true);
+ const result = new ArrowResultHandler(context, rowSetProvider, {
+ arrowSchema: sampleArrowSchema,
+ lz4Compressed: true,
+ });

const batches = await result.fetchNext({ limit: 10000 });
expect(await rowSetProvider.hasMore()).to.be.false;
@@ -87,7 +90,7 @@ describe('ArrowResultHandler', () => {
it('should not buffer any data', async () => {
const context = {};
const rowSetProvider = new ResultsProviderMock([sampleRowSet1]);
- const result = new ArrowResultHandler(context, rowSetProvider, sampleArrowSchema);
+ const result = new ArrowResultHandler(context, rowSetProvider, { arrowSchema: sampleArrowSchema });
expect(await rowSetProvider.hasMore()).to.be.true;
expect(await result.hasMore()).to.be.true;

@@ -100,34 +103,61 @@ describe('ArrowResultHandler', () => {
const context = {};
case1: {
const rowSetProvider = new ResultsProviderMock();
- const result = new ArrowResultHandler(context, rowSetProvider, sampleArrowSchema);
+ const result = new ArrowResultHandler(context, rowSetProvider, { arrowSchema: sampleArrowSchema });
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([]);
expect(await result.hasMore()).to.be.false;
}
case2: {
const rowSetProvider = new ResultsProviderMock([sampleRowSet2]);
- const result = new ArrowResultHandler(context, rowSetProvider, sampleArrowSchema);
+ const result = new ArrowResultHandler(context, rowSetProvider, { arrowSchema: sampleArrowSchema });
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([]);
expect(await result.hasMore()).to.be.false;
}
case3: {
const rowSetProvider = new ResultsProviderMock([sampleRowSet3]);
- const result = new ArrowResultHandler(context, rowSetProvider, sampleArrowSchema);
+ const result = new ArrowResultHandler(context, rowSetProvider, { arrowSchema: sampleArrowSchema });
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([]);
expect(await result.hasMore()).to.be.false;
}
case4: {
const rowSetProvider = new ResultsProviderMock([sampleRowSet4]);
- const result = new ArrowResultHandler(context, rowSetProvider, sampleArrowSchema);
+ const result = new ArrowResultHandler(context, rowSetProvider, { arrowSchema: sampleArrowSchema });
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([]);
expect(await result.hasMore()).to.be.false;
}
});

+ it('should infer arrow schema from thrift schema', async () => {
+ const context = {};
+ const rowSetProvider = new ResultsProviderMock([sampleRowSet2]);
+
+ const sampleThriftSchema = {
+ columns: [
+ {
+ columnName: '1',
+ typeDesc: {
+ types: [
+ {
+ primitiveEntry: {
+ type: 3,
+ typeQualifiers: null,
+ },
+ },
+ ],
+ },
+ position: 1,
+ },
+ ],
+ };
+
+ const result = new ArrowResultHandler(context, rowSetProvider, { schema: sampleThriftSchema });
+ expect(result.arrowSchema).to.not.be.undefined;
+ });
+
it('should return empty array if no schema available', async () => {
const context = {};
const rowSetProvider = new ResultsProviderMock([sampleRowSet2]);
- const result = new ArrowResultHandler(context, rowSetProvider);
+ const result = new ArrowResultHandler(context, rowSetProvider, {});
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([]);
expect(await result.hasMore()).to.be.false;
});
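In the new inference test above, type: 3 is the numeric value of TTypeId.INT_TYPE in the Thrift IDL, which the hiveTypeToArrowType table in lib/result/utils.ts maps to an Arrow Int32 field.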
12 changes: 6 additions & 6 deletions tests/unit/result/CloudFetchResultHandler.test.js
@@ -86,7 +86,7 @@ describe('CloudFetchResultHandler', () => {
getConfig: () => clientConfig,
};

- const result = new CloudFetchResultHandler(context, rowSetProvider);
+ const result = new CloudFetchResultHandler(context, rowSetProvider, {});

case1: {
result.pendingLinks = [];
@@ -119,7 +119,7 @@ describe('CloudFetchResultHandler', () => {
getConfig: () => clientConfig,
};

- const result = new CloudFetchResultHandler(context, rowSetProvider);
+ const result = new CloudFetchResultHandler(context, rowSetProvider, {});

sinon.stub(result, 'fetch').returns(
Promise.resolve({
@@ -153,7 +153,7 @@ describe('CloudFetchResultHandler', () => {
getConfig: () => clientConfig,
};

- const result = new CloudFetchResultHandler(context, rowSetProvider);
+ const result = new CloudFetchResultHandler(context, rowSetProvider, {});

sinon.stub(result, 'fetch').returns(
Promise.resolve({
@@ -213,7 +213,7 @@ describe('CloudFetchResultHandler', () => {
getConfig: () => clientConfig,
};

- const result = new CloudFetchResultHandler(context, rowSetProvider, true);
+ const result = new CloudFetchResultHandler(context, rowSetProvider, { lz4Compressed: true });

const expectedBatch = Buffer.concat([sampleArrowSchema, sampleArrowBatch]);

@@ -244,7 +244,7 @@ describe('CloudFetchResultHandler', () => {
getConfig: () => clientConfig,
};

- const result = new CloudFetchResultHandler(context, rowSetProvider);
+ const result = new CloudFetchResultHandler(context, rowSetProvider, {});

sinon.stub(result, 'fetch').returns(
Promise.resolve({
@@ -275,7 +275,7 @@ describe('CloudFetchResultHandler', () => {
getConfig: () => clientConfig,
};

- const result = new CloudFetchResultHandler(context, rowSetProvider);
+ const result = new CloudFetchResultHandler(context, rowSetProvider, {});

sinon.stub(result, 'fetch').returns(
Promise.resolve({
