Skip to content

Commit

Permalink
feat: add support for 'array' datatype (#224)
Browse files Browse the repository at this point in the history
Example usage:

```ts
const s = pl.Series(
  "a",
  [
    [1, 2],
    [3, 4],
  ],
  pl.FixedSizeList(pl.Int32, 2),
);
```

Output:

```
shape: (2,)
Series: 'a' [array[i32, 2]]
[
	[1, 2]
	[3, 4]
]
```
  • Loading branch information
universalmind303 authored Jun 12, 2024
1 parent e107597 commit 0707fb9
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 147 deletions.
8 changes: 1 addition & 7 deletions polars/dataframe.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1292,13 +1292,7 @@ export interface DataFrame
* ```
*/
shiftAndFill(n: number, fillValue: number): DataFrame;
shiftAndFill({
n,
fillValue,
}: {
n: number;
fillValue: number;
}): DataFrame;
shiftAndFill({ n, fillValue }: { n: number; fillValue: number }): DataFrame;
/**
* Shrink memory usage of this DataFrame to fit the exact capacity needed to hold the data.
*/
Expand Down
213 changes: 126 additions & 87 deletions polars/datatypes/datatype.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { Field } from "./field";

export abstract class DataType {
get variant() {
return this.constructor.name.slice(1);
return this.constructor.name;
}
protected identity = "DataType";
protected get inner(): null | any[] {
Expand All @@ -18,67 +18,67 @@ export abstract class DataType {

/** Null type */
public static get Null(): DataType {
return new _Null();
return new Null();
}
/** `true` and `false`. */
public static get Bool(): DataType {
return new _Bool();
return new Bool();
}
/** An `i8` */
public static get Int8(): DataType {
return new _Int8();
return new Int8();
}
/** An `i16` */
public static get Int16(): DataType {
return new _Int16();
return new Int16();
}
/** An `i32` */
public static get Int32(): DataType {
return new _Int32();
return new Int32();
}
/** An `i64` */
public static get Int64(): DataType {
return new _Int64();
return new Int64();
}
/** An `u8` */
public static get UInt8(): DataType {
return new _UInt8();
return new UInt8();
}
/** An `u16` */
public static get UInt16(): DataType {
return new _UInt16();
return new UInt16();
}
/** An `u32` */
public static get UInt32(): DataType {
return new _UInt32();
return new UInt32();
}
/** An `u64` */
public static get UInt64(): DataType {
return new _UInt64();
return new UInt64();
}

/** A `f32` */
public static get Float32(): DataType {
return new _Float32();
return new Float32();
}
/** A `f64` */
public static get Float64(): DataType {
return new _Float64();
return new Float64();
}
public static get Date(): DataType {
return new _Date();
return new Date();
}
/** Time of day type */
public static get Time(): DataType {
return new _Time();
return new Time();
}
/** Type for wrapping arbitrary JS objects */
public static get Object(): DataType {
return new _Object();
return new Object_();
}
/** A categorical encoding of a set of strings */
public static get Categorical(): DataType {
return new _Categorical();
return new Categorical();
}

/**
Expand All @@ -93,7 +93,7 @@ export abstract class DataType {
timeUnit,
timeZone: string | null | undefined = null,
): DataType {
return new _Datetime(timeUnit, timeZone as any);
return new Datetime(timeUnit, timeZone as any);
}
/**
* Nested list/array type
Expand All @@ -102,7 +102,15 @@ export abstract class DataType {
*
*/
public static List(inner: DataType): DataType {
return new _List(inner);
return new List(inner);
}
/**
* List of fixed length
* This is called `Array` in other polars implementations, but `Array` is widely used in JS, so we use `FixedSizeList` instead.
*
*/
public static FixedSizeList(inner: DataType, listSize: number): DataType {
return new FixedSizeList(inner, listSize);
}
/**
* Struct type
Expand All @@ -112,15 +120,15 @@ export abstract class DataType {
public static Struct(
fields: Field[] | { [key: string]: DataType },
): DataType {
return new _Struct(fields);
return new Struct(fields);
}
/** A variable-length UTF-8 encoded string whose offsets are represented as `i64`. */
public static get Utf8(): DataType {
return new _Utf8();
return new Utf8();
}

public static get String(): DataType {
return new _String();
return new String();
}

toString() {
Expand All @@ -131,7 +139,6 @@ export abstract class DataType {
}
toJSON() {
const inner = (this as any).inner;

if (inner) {
return {
[this.identity]: {
Expand All @@ -149,32 +156,40 @@ export abstract class DataType {
static from(obj): DataType {
return null as any;
}
/**
 * Narrowing helper for fixed-size list datatypes.
 *
 * @returns this instance when it is a `FixedSizeList`, otherwise `null`.
 */
asFixedSizeList() {
  return this instanceof FixedSizeList ? this : null;
}
}

class _Null extends DataType {}
class _Bool extends DataType {}
class _Int8 extends DataType {}
class _Int16 extends DataType {}
class _Int32 extends DataType {}
class _Int64 extends DataType {}
class _UInt8 extends DataType {}
class _UInt16 extends DataType {}
class _UInt32 extends DataType {}
class _UInt64 extends DataType {}
class _Float32 extends DataType {}
class _Float64 extends DataType {}
class _Date extends DataType {}
class _Time extends DataType {}
class _Object extends DataType {}
class _Utf8 extends DataType {}
class _String extends DataType {}
/*
 * Concrete leaf datatypes.
 *
 * Each class has an empty body and inherits all behaviour from `DataType`;
 * the class itself is what distinguishes one datatype from another.
 */
export class Null extends DataType {}
export class Bool extends DataType {}
export class Int8 extends DataType {}
export class Int16 extends DataType {}
export class Int32 extends DataType {}
export class Int64 extends DataType {}
export class UInt8 extends DataType {}
export class UInt16 extends DataType {}
export class UInt32 extends DataType {}
export class UInt64 extends DataType {}
export class Float32 extends DataType {}
export class Float64 extends DataType {}
// biome-ignore lint/suspicious/noShadowRestrictedNames: intentionally shadows the global `Date` inside this module so the class carries the datatype's own name
export class Date extends DataType {}
export class Time extends DataType {}
// NOTE(review): the trailing underscore presumably avoids shadowing the global `Object` — confirm.
export class Object_ extends DataType {}
export class Utf8 extends DataType {}
// biome-ignore lint/suspicious/noShadowRestrictedNames: intentionally shadows the global `String` inside this module so the class carries the datatype's own name
export class String extends DataType {}

class _Categorical extends DataType {}
export class Categorical extends DataType {}

/**
* Datetime type
*/
class _Datetime extends DataType {
export class Datetime extends DataType {
constructor(
private timeUnit: TimeUnit,
private timeZone?: string,
Expand All @@ -188,15 +203,15 @@ class _Datetime extends DataType {
override equals(other: DataType): boolean {
if (other.variant === this.variant) {
return (
this.timeUnit === (other as _Datetime).timeUnit &&
this.timeZone === (other as _Datetime).timeZone
this.timeUnit === (other as Datetime).timeUnit &&
this.timeZone === (other as Datetime).timeZone
);
}
return false;
}
}

class _List extends DataType {
export class List extends DataType {
constructor(protected __inner: DataType) {
super();
}
Expand All @@ -205,13 +220,50 @@ class _List extends DataType {
}
override equals(other: DataType): boolean {
if (other.variant === this.variant) {
return this.inner[0].equals((other as _List).inner[0]);
return this.inner[0].equals((other as List).inner[0]);
}
return false;
}
}

/**
 * List datatype whose rows all hold the same number of elements.
 *
 * Other polars implementations call this `Array`; the name `FixedSizeList`
 * is used here because `Array` is already widely used in JS.
 */
export class FixedSizeList extends DataType {
  /**
   * @param __inner datatype of the list elements
   * @param listSize number of elements in every list
   */
  constructor(
    protected __inner: DataType,
    protected listSize: number,
  ) {
    super();
  }

  /** Fixed variant name, independent of the constructor name. */
  override get variant() {
    return "FixedSizeList";
  }

  /** The pair `[element datatype, list size]`. */
  override get inner(): [DataType, number] {
    return [this.__inner, this.listSize];
  }

  /**
   * Two fixed-size lists are equal when their variants match, their element
   * datatypes are equal, and their sizes are identical.
   */
  override equals(other: DataType): boolean {
    if (other.variant !== this.variant) {
      return false;
    }
    const [ownType, ownSize] = this.inner;
    const [theirType, theirSize] = (other as FixedSizeList).inner;
    return ownType.equals(theirType) && ownSize === theirSize;
  }

  /**
   * Serializes to a nested object keyed by `identity`, then `variant`,
   * holding the serialized element datatype under `type` and the list
   * length under `size`.
   */
  override toJSON() {
    const [elementType, size] = this.inner;
    return {
      [this.identity]: {
        [this.variant]: {
          type: elementType.toJSON(),
          size,
        },
      },
    };
  }
}

class _Struct extends DataType {
export class Struct extends DataType {
private fields: Field[];

constructor(
Expand All @@ -235,7 +287,7 @@ class _Struct extends DataType {
if (other.variant === this.variant) {
return this.inner
.map((fld, idx) => {
const otherfld = (other as _Struct).fields[idx];
const otherfld = (other as Struct).fields[idx];

return otherfld.name === fld.name && otherfld.dtype.equals(fld.dtype);
})
Expand Down Expand Up @@ -275,45 +327,28 @@ export namespace TimeUnit {
* Datatype namespace
*/
export namespace DataType {
/** Null */
export type Null = _Null;
/** Boolean */
export type Bool = _Bool;
/** Int8 */
export type Int8 = _Int8;
/** Int16 */
export type Int16 = _Int16;
/** Int32 */
export type Int32 = _Int32;
/** Int64 */
export type Int64 = _Int64;
/** UInt8 */
export type UInt8 = _UInt8;
/** UInt16 */
export type UInt16 = _UInt16;
/** UInt32 */
export type UInt32 = _UInt32;
/** UInt64 */
export type UInt64 = _UInt64;
/** Float32 */
export type Float32 = _Float32;
/** Float64 */
export type Float64 = _Float64;
/** Date dtype */
export type Date = _Date;
/** Datetime */
export type Datetime = _Datetime;
/** Utf8 */
export type Utf8 = _Utf8;
/** Utf8 */
export type String = _String;
/** Categorical */
export type Categorical = _Categorical;
/** List */
export type List = _List;
/** Struct */
export type Struct = _Struct;

export type Categorical = import(".").Categorical;
export type Int8 = import(".").Int8;
export type Int16 = import(".").Int16;
export type Int32 = import(".").Int32;
export type Int64 = import(".").Int64;
export type UInt8 = import(".").UInt8;
export type UInt16 = import(".").UInt16;
export type UInt32 = import(".").UInt32;
export type UInt64 = import(".").UInt64;
export type Float32 = import(".").Float32;
export type Float64 = import(".").Float64;
export type Bool = import(".").Bool;
export type Utf8 = import(".").Utf8;
export type String = import(".").String;
export type List = import(".").List;
export type FixedSizeList = import(".").FixedSizeList;
export type Date = import(".").Date;
export type Datetime = import(".").Datetime;
export type Time = import(".").Time;
export type Object = import(".").Object_;
export type Null = import(".").Null;
export type Struct = import(".").Struct;
/**
* deserializes a datatype from the serde output of rust polars `DataType`
* @param dtype dtype object
Expand All @@ -333,6 +368,10 @@ export namespace DataType {
inner = [deserialize(inner[0])];
}

if (variant === "FixedSizeList") {
inner = [deserialize(inner[0]), inner[1]];
}

return DataType[variant](...inner);
}
}
5 changes: 3 additions & 2 deletions polars/datatypes/index.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { DataType, TimeUnit } from "./datatype";
export { DataType, TimeUnit };
export * from "./datatype";
export { Field } from "./field";

import pli from "../internals/polars_internal";
// biome-ignore lint/style/useImportType: <explanation>
import { type DataType } from "./datatype";

/** @ignore */
export type TypedArray =
Expand Down
Loading

0 comments on commit 0707fb9

Please sign in to comment.