Skip to content

Commit

Permalink
feat: Add BigInt64Array and BigUint64Array (#49)
Browse files Browse the repository at this point in the history
* feat: Add `BigInt64Array` and `BigUint64Array`

* Update Podfile.lock
  • Loading branch information
mrousavy authored Feb 5, 2024
1 parent 7bb0b1e commit a533fd4
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 1 deletion.
18 changes: 18 additions & 0 deletions cpp/TensorHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ TfLiteType getTFLDataTypeForTypedArrayKind(TypedArrayKind kind) {
return kTfLiteFloat32;
case TypedArrayKind::Float64Array:
return kTfLiteFloat64;
case TypedArrayKind::BigInt64Array:
return kTfLiteInt64;
case TypedArrayKind::BigUint64Array:
return kTfLiteUInt64;
}
}

Expand Down Expand Up @@ -151,6 +155,10 @@ TypedArrayBase TensorHelpers::createJSBufferForTensor(jsi::Runtime& runtime,
return TypedArray<TypedArrayKind::Uint16Array>(runtime, size);
case kTfLiteUInt32:
return TypedArray<TypedArrayKind::Uint32Array>(runtime, size);
case kTfLiteInt64:
return TypedArray<TypedArrayKind::BigInt64Array>(runtime, size);
case kTfLiteUInt64:
return TypedArray<TypedArrayKind::BigUint64Array>(runtime, size);
default:
[[unlikely]];
throw std::runtime_error("TFLite: Unsupported tensor data type! " +
Expand Down Expand Up @@ -213,6 +221,16 @@ void TensorHelpers::updateJSBufferFromTensor(jsi::Runtime& runtime, TypedArrayBa
.as<TypedArrayKind::Uint32Array>(runtime)
.updateUnsafe(runtime, (uint32_t*)data, size);
break;
case kTfLiteInt64:
getTypedArray(runtime, jsBuffer)
.as<TypedArrayKind::BigInt64Array>(runtime)
.updateUnsafe(runtime, (int64_t*)data, size);
break;
case kTfLiteUInt64:
getTypedArray(runtime, jsBuffer)
.as<TypedArrayKind::BigUint64Array>(runtime)
.updateUnsafe(runtime, (uint64_t*)data, size);
break;
default:
[[unlikely]];
throw jsi::JSError(runtime,
Expand Down
14 changes: 14 additions & 0 deletions cpp/jsi/TypedArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ enum class Prop {
Uint32Array, // "Uint32Array"
Float32Array, // "Float32Array"
Float64Array, // "Float64Array"
BigInt64Array, // "BigInt64Array"
BigUint64Array, // "BigUint64Array"
};

class PropNameIDCache {
Expand Down Expand Up @@ -260,6 +262,10 @@ const jsi::PropNameID& PropNameIDCache::getConstructorNameProp(jsi::Runtime& run
return get(runtime, Prop::Float32Array);
case TypedArrayKind::Float64Array:
return get(runtime, Prop::Float64Array);
case TypedArrayKind::BigInt64Array:
return get(runtime, Prop::BigInt64Array);
case TypedArrayKind::BigUint64Array:
return get(runtime, Prop::BigUint64Array);
}
}

Expand Down Expand Up @@ -304,6 +310,10 @@ jsi::PropNameID PropNameIDCache::createProp(jsi::Runtime& runtime, Prop prop) {
return create("Float32Array");
case Prop::Float64Array:
return create("Float64Array");
case Prop::BigInt64Array:
return create("BigInt64Array");
case Prop::BigUint64Array:
return create("BigUint64Array");
}
}

Expand All @@ -317,6 +327,8 @@ std::unordered_map<std::string, TypedArrayKind> nameToKindMap = {
{"Uint32Array", TypedArrayKind::Uint32Array},
{"Float32Array", TypedArrayKind::Float32Array},
{"Float64Array", TypedArrayKind::Float64Array},
{"BigInt64Array", TypedArrayKind::BigInt64Array},
{"BigUint64Array", TypedArrayKind::BigUint64Array},
};

TypedArrayKind getTypedArrayKindForName(const std::string& name) {
Expand All @@ -332,5 +344,7 @@ template class TypedArray<TypedArrayKind::Uint16Array>;
template class TypedArray<TypedArrayKind::Uint32Array>;
template class TypedArray<TypedArrayKind::Float32Array>;
template class TypedArray<TypedArrayKind::Float64Array>;
template class TypedArray<TypedArrayKind::BigInt64Array>;
template class TypedArray<TypedArrayKind::BigUint64Array>;

} // namespace mrousavy
8 changes: 8 additions & 0 deletions cpp/jsi/TypedArray.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ enum class TypedArrayKind {
Uint32Array,
Float32Array,
Float64Array,
BigInt64Array,
BigUint64Array,
};

template <TypedArrayKind T> class TypedArray;
Expand Down Expand Up @@ -61,6 +63,12 @@ template <> struct typedArrayTypeMap<TypedArrayKind::Float32Array> {
template <> struct typedArrayTypeMap<TypedArrayKind::Float64Array> {
typedef double type;
};
template <> struct typedArrayTypeMap<TypedArrayKind::BigInt64Array> {
typedef int64_t type;
};
template <> struct typedArrayTypeMap<TypedArrayKind::BigUint64Array> {
typedef uint64_t type;
};

// Instance of this class will invalidate PropNameIDCache when destructor is called.
// Attach this object to global in specific jsi::Runtime to make sure lifecycle of
Expand Down
2 changes: 1 addition & 1 deletion example/ios/Podfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -1299,7 +1299,7 @@ SPEC CHECKSUMS:
SocketRocket: f32cd54efbe0f095c4d7594881e52619cfe80b17
vision-camera-resize-plugin: 536345c29f42e04438d34d0cada7b4992a8fc104
VisionCamera: edbcd00e27a438b2228f67823e2b8d15a189065f
Yoga: 13c8ef87792450193e117976337b8527b49e8c03
Yoga: e64aa65de36c0832d04e8c7bd614396c77a80047

PODFILE CHECKSUM: 33c2f84b68c18fe6787cbe3e723675d15fcc7e66

Expand Down
2 changes: 2 additions & 0 deletions src/TensorflowLite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ type TypedArray =
| Uint8Array
| Uint16Array
| Uint32Array
| BigInt64Array
| BigUint64Array

declare global {
/**
Expand Down

0 comments on commit a533fd4

Please sign in to comment.