From cb66e6c3ae4488036fcd027da7ee55c45900bfd1 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 22 Feb 2023 15:20:21 -0500 Subject: [PATCH] [WEB] Reduce memleak in web runtime (#14086) This PR robustifies the web runtime to reduce memory leak and enhances the runtime with object support. Specifically we introduce scoping and auto-release mechanism when we exit the scope. The improvements are helpful to deal with memory leak in wasm and webgpu settings --- web/.eslintignore | 1 + web/README.md | 2 +- web/apps/node/example.js | 2 + web/emcc/wasm_runtime.cc | 2 +- web/src/ctypes.ts | 22 + web/src/index.ts | 5 +- web/src/rpc_server.ts | 19 +- web/src/runtime.ts | 642 +++++++++++++++++++++------ web/tests/node/test_module_load.js | 15 +- web/tests/node/test_ndarray.js | 16 +- web/tests/node/test_object.js | 45 ++ web/tests/node/test_packed_func.js | 59 ++- web/tests/python/websock_rpc_test.py | 1 - 13 files changed, 670 insertions(+), 161 deletions(-) create mode 100644 web/tests/node/test_object.js diff --git a/web/.eslintignore b/web/.eslintignore index 1521c8b7652b..f71ee79871c4 100644 --- a/web/.eslintignore +++ b/web/.eslintignore @@ -1 +1,2 @@ dist +debug diff --git a/web/README.md b/web/README.md index 4154300e62e4..64f507579e94 100644 --- a/web/README.md +++ b/web/README.md @@ -81,7 +81,7 @@ The following is an example to reproduce this. - Start the WebSocket RPC - Browswer version: open https://localhost:8888, click connect to proxy - NodeJS version: `npm run rpc` -- run `python tests/node/websock_rpc_test.py` to run the rpc client. +- run `python tests/python/websock_rpc_test.py` to run the rpc test. ## WebGPU Experiments diff --git a/web/apps/node/example.js b/web/apps/node/example.js index cff76d8a067e..0cd6b532011b 100644 --- a/web/apps/node/example.js +++ b/web/apps/node/example.js @@ -31,8 +31,10 @@ const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); // the async version of the API. tvmjs.instantiate(wasmSource, new EmccWASI()) .then((tvm) => { + tvm.beginScope(); const log_info = tvm.getGlobalFunc("testing.log_info_str"); log_info("hello world"); // List all the global functions from the runtime. console.log("Runtime functions using EmccWASI\n", tvm.listGlobalFuncNames()); + tvm.endScope(); }); diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index 2b0ee49d7edd..00d2a8c579f1 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -32,10 +32,10 @@ #include #include "src/runtime/c_runtime_api.cc" +#include "src/runtime/container.cc" #include "src/runtime/contrib/sort/sort.cc" #include "src/runtime/cpu_device_api.cc" #include "src/runtime/file_utils.cc" -#include "src/runtime/graph_executor/graph_executor.cc" #include "src/runtime/library_module.cc" #include "src/runtime/logging.cc" #include "src/runtime/module.cc" diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts index 4a6d25ae6270..282679fc02e5 100644 --- a/web/src/ctypes.ts +++ b/web/src/ctypes.ts @@ -46,6 +46,7 @@ export type FTVMModGetFunction = ( * TVMModuleHandle dep); */ export type FTVMModImport = (mod: Pointer, dep: Pointer) => number; + /** * int TVMModFree(TVMModuleHandle mod); */ @@ -161,6 +162,27 @@ export type FTVMBackendPackedCFunc = ( argValues: Pointer, argCodes: Pointer, nargs: number, outValue: Pointer, outCode: Pointer) => number; + +/** + * int TVMObjectFree(TVMObjectHandle obj); + */ + export type FTVMObjectFree = (obj: Pointer) => number; + +/** + * int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); + */ +export type FTVMObjectGetTypeIndex = (obj: Pointer, out_tindex: Pointer) => number; + +/** + * int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key); + */ +export type FTVMObjectTypeIndex2Key = (type_index: number, out_type_key: Pointer) => number; + +/** + * int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); + */ +export type FTVMObjectTypeKey2Index = (type_key: Pointer, out_tindex: Pointer) => number; + // -- TVM Wasm Auxiliary C API -- /** void* TVMWasmAllocSpace(int size); */ diff --git a/web/src/index.ts b/web/src/index.ts index ac82e5967f48..bf2d982e21f3 100644 --- a/web/src/index.ts +++ b/web/src/index.ts @@ -19,8 +19,9 @@ export { Scalar, DLDevice, DLDataType, - PackedFunc, Module, NDArray, Instance, - instantiate + PackedFunc, Module, NDArray, + TVMArray, + Instance, instantiate } from "./runtime"; export { Disposable, LibraryProvider } from "./types"; export { RPCServer } from "./rpc_server"; diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts index c63dcf3a9ae3..e37d1838d604 100644 --- a/web/src/rpc_server.ts +++ b/web/src/rpc_server.ts @@ -22,6 +22,8 @@ import { assert, StringToUint8Array, Uint8ArrayToString } from "./support"; import { detectGPUDevice } from "./webgpu"; import * as compact from "./compact"; import * as runtime from "./runtime"; +import { timeStamp } from "console"; +import { Disposable } from "./types"; enum RPCServerState { InitHeader, @@ -83,6 +85,7 @@ export class RPCServer { private pendingSend: Promise = Promise.resolve(); private name: string; private inst?: runtime.Instance = undefined; + private globalObjects: Array = []; private serverRecvData?: (header: Uint8Array, body: Uint8Array) => void; private currPacketHeader?: Uint8Array; private currPacketLength = 0; @@ -121,6 +124,9 @@ export class RPCServer { // eslint-disable-next-line @typescript-eslint/no-unused-vars private onClose(_event: CloseEvent): void { if (this.inst !== undefined) { + this.globalObjects.forEach(obj => { + obj.dispose(); + }); this.inst.dispose(); } if (this.state == RPCServerState.ReceivePacketHeader) { @@ -263,6 +269,9 @@ export class RPCServer { } this.inst = inst; + // begin scope to allow handling of objects + // the object should stay alive during all sessions. + this.inst.beginScope(); const fcreate = this.inst.getGlobalFunc("rpc.CreateEventDrivenServer"); const messageHandler = fcreate( @@ -301,8 +310,10 @@ export class RPCServer { this.name, this.key ); - - fcreate.dispose(); + // message handler should persist across RPC runs + this.globalObjects.push( + this.inst.detachFromCurrentScope(messageHandler) + ); const writeFlag = this.inst.scalar(3, "int32"); this.serverRecvData = (header: Uint8Array, body: Uint8Array): void => { @@ -320,7 +331,6 @@ export class RPCServer { // register the callback to redirect the session to local. const flocal = this.inst.getGlobalFunc("wasm.LocalSession"); const localSession = flocal(); - flocal.dispose(); assert(localSession instanceof runtime.Module); // eslint-disable-next-line @typescript-eslint/no-unused-vars @@ -333,13 +343,14 @@ export class RPCServer { ); messageHandler(header, writeFlag); messageHandler(body, writeFlag); - localSession.dispose(); this.log("Finish initializing the Wasm Server.."); this.requestBytes(SizeOf.I64); this.state = RPCServerState.ReceivePacketHeader; // call process events in case there are bufferred data. this.processEvents(); + // recycle all values. + this.inst.endScope(); }; this.state = RPCServerState.WaitForCallback; diff --git a/web/src/runtime.ts b/web/src/runtime.ts index b341a7d4b1a4..a24459ca29a0 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -29,6 +29,7 @@ import { WebGPUContext } from "./webgpu"; import * as compact from "./compact"; import * as ctypes from "./ctypes"; +import { tsImportEqualsDeclaration } from "@babel/types"; /** * Type for PackedFunc inthe TVMRuntime. @@ -134,6 +135,95 @@ class FFILibrary implements Disposable { } } +/** + * @internal + * Manages extra runtime context for the runtime. + */ +class RuntimeContext implements Disposable { + arrayGetItem : PackedFunc; + arrayGetSize : PackedFunc; + arrayMake : PackedFunc; + getSysLib: PackedFunc; + + private autoDisposeScope: Array> = []; + + constructor(getGlobalFunc: (name: string) => PackedFunc) { + this.arrayGetItem = getGlobalFunc("runtime.ArrayGetItem"); + this.arrayGetSize = getGlobalFunc("runtime.ArraySize"); + this.arrayMake = getGlobalFunc("runtime.Array"); + this.getSysLib = getGlobalFunc("runtime.SystemLib"); + } + + dispose(): void { + this.arrayGetItem.dispose(); + this.arrayGetSize.dispose(); + this.arrayMake.dispose(); + } + + beginScope() : void { + this.autoDisposeScope.push([]); + } + + endScope() : void { + if (this.autoDisposeScope.length == 0) { + throw Error("tvm.endScope called when the stack is empty."); + } + // automatically dispose all the tracked values in the current scope. + const currScope = this.autoDisposeScope.pop() as Array; + for (let i = 0; i < currScope.length; ++i) { + const val = currScope[i]; + if (val !== undefined) { + val.dispose(); + } + } + } + + /** + * Track object for dispose in current scope. + * + * @param obj The object to be tracked. + * @returns the same object. + * @note This function only needs to be called for raw system C API values. + * The return value of PackedFunc will be automatically tracked. + */ + attachToCurrentScope(obj: T): T { + if (this.autoDisposeScope.length == 0) { + throw Error("Must call beginScope to use functions that returns TVM objects"); + } + const currScope = this.autoDisposeScope[this.autoDisposeScope.length - 1]; + currScope.push(obj); + return obj; + } + + moveToParentScope(obj: T): T { + this.detachFromCurrentScope(obj); + if (this.autoDisposeScope.length < 2) { + throw Error("moveToParentScope: Parent scope do not exist"); + } + const parentScope = this.autoDisposeScope[this.autoDisposeScope.length - 2]; + parentScope.push(obj); + return obj; + } + + detachFromCurrentScope(obj: T): T { + const currScope = this.autoDisposeScope[this.autoDisposeScope.length - 1]; + let occurance = 0; + for (let i = 0; i < currScope.length; ++i) { + if (currScope[i] === obj) { + occurance += 1; + currScope[i] = undefined; + } + } + if (occurance == 0) { + throw Error("Cannot find obj in the current auto conversion pool"); + } + if (occurance > 1) { + throw Error("Value attached to scope multiple times"); + } + return obj; + } +} + /** * A typed scalar constant used to represent a typed number * argument to PackedFunc calls. @@ -154,7 +244,7 @@ export class Scalar { * Cell holds the PackedFunc object. */ class PackedFuncCell implements Disposable { - handle: Pointer; + private handle: Pointer; private lib: FFILibrary; constructor(handle: Pointer, lib: FFILibrary) { @@ -170,6 +260,13 @@ class PackedFuncCell implements Disposable { this.handle = 0; } } + + getHandle(requireNotNull : boolean = true): Pointer { + if (requireNotNull && this.handle == 0) { + throw Error("PackedFunc has already been disposed"); + } + return this.handle; + } } const DeviceEnumToStr: Record = { @@ -286,7 +383,7 @@ export class DLDataType { */ export class NDArray implements Disposable { /** Internal array handle. */ - handle: Pointer; + private handle: Pointer; /** Number of dimensions. */ ndim: number; /** Data type of the array. */ @@ -352,6 +449,19 @@ export class NDArray implements Disposable { this.byteOffset = lib.memory.loadI64(this.dltensor + arrayOffsetByteOffset); } + /** + * Get handle of ndarray, check it is not null. + * + * @param requireNotNull require handle is not null. + * @returns The handle. + */ + getHandle(requireNotNull : boolean = true): Pointer { + if (requireNotNull && this.handle == 0) { + throw Error("NDArray has already been disposed"); + } + return this.handle; + } + dispose(): void { if (this.handle != 0 && !this.isView) { this.lib.checkCall( @@ -371,8 +481,8 @@ export class NDArray implements Disposable { if (data instanceof NDArray) { this.lib.checkCall( (this.lib.exports.TVMArrayCopyFromTo as ctypes.FTVMArrayCopyFromTo)( - data.handle, - this.handle, + data.getHandle(), + this.getHandle(), 0 ) ); @@ -427,7 +537,7 @@ export class NDArray implements Disposable { this.lib.memory.storeRawBytes(tempPtr, data); this.lib.checkCall( (this.lib.exports.TVMArrayCopyFromBytes as ctypes.FTVMArrayCopyFromBytes)( - this.handle, + this.getHandle(), tempPtr, nbytes ) @@ -455,7 +565,7 @@ export class NDArray implements Disposable { const tempPtr = stack.ptrFromOffset(tempOffset); this.lib.checkCall( (this.lib.exports.TVMArrayCopyToBytes as ctypes.FTVMArrayCopyToBytes)( - this.handle, + this.getHandle(), tempPtr, nbytes ) @@ -499,7 +609,7 @@ export class NDArray implements Disposable { * Runtime Module. */ export class Module implements Disposable { - handle: Pointer; + private handle: Pointer; private lib: FFILibrary; private makePackedFunc: (ptr: Pointer) => PackedFunc; @@ -522,12 +632,28 @@ export class Module implements Disposable { } } + /** + * Get handle of module, check it is not null. + * + * @param requireNotNull require handle is not null. + * @returns The handle. + */ + getHandle(requireNotNull : boolean = true): Pointer { + if (requireNotNull && this.handle == 0) { + throw Error("Module has already been disposed"); + } + return this.handle; + } + /** * Get a function in the module. * @param name The name of the function. * @returns The result function. */ getFunction(name: string): PackedFunc { + if (this.handle == 0) { + throw Error("Module has already been disposed"); + } const stack = this.lib.getOrAllocCallStack(); const nameOffset = stack.allocRawBytes(name.length + 1); stack.storeRawBytes(nameOffset, StringToUint8Array(name)); @@ -539,7 +665,7 @@ export class Module implements Disposable { this.lib.checkCall( (this.lib.exports.TVMModGetFunction as ctypes.FTVMModGetFunction)( - this.handle, + this.getHandle(), stack.ptrFromOffset(nameOffset), 1, outPtr @@ -561,112 +687,122 @@ export class Module implements Disposable { importModule(mod: Module): void { this.lib.checkCall( (this.lib.exports.TVMModImport as ctypes.FTVMModImport)( - this.handle, - mod.handle + this.getHandle(), + mod.getHandle() ) ); } } /** - * Graph executor. - * - * This is a thin wrapper of the underlying TVM module. - * you can also directly call set_input, run, and get_output - * of underlying module functions + * Generic object base */ -class GraphExecutor implements Disposable { - module: Module; - private packedSetInput: PackedFunc; - private packedRun: PackedFunc; - private packedGetOutput: PackedFunc; - private packedLoadParams: PackedFunc; + export class TVMObject implements Disposable { + private handle: Pointer; + private lib: FFILibrary; + protected ctx: RuntimeContext; - /** - * COnstructor - * @param module The underlying module. - */ - constructor(module: Module) { - this.module = module; - this.packedSetInput = module.getFunction("set_input"); - this.packedRun = module.getFunction("run"); - this.packedGetOutput = module.getFunction("get_output"); - this.packedLoadParams = module.getFunction("load_params"); + constructor( + handle: Pointer, + lib: FFILibrary, + ctx: RuntimeContext + ) { + this.handle = handle; + this.lib = lib; + this.ctx = ctx; } dispose(): void { - this.packedSetInput.dispose(); - this.packedRun.dispose(); - this.packedGetOutput.dispose(); + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMObjectFree as ctypes.FTVMObjectFree)(this.handle) + ); + this.handle = 0; + } } /** - * Set input to the executor. + * Get handle of module, check it is not null. * - * @param key The input key. - * @param value The value to get set. + * @param requireNotNull require handle is not null. + * @returns The handle. */ - setInput(key: number | string, value: NDArray): void { - if (typeof key == "number") { - this.packedSetInput(new Scalar(key, "int32"), value); - } else { - this.packedSetInput(key, value); + getHandle(requireNotNull : boolean = true): Pointer { + if (requireNotNull && this.handle == 0) { + throw Error("Module has already been disposed"); + } + return this.handle; + } + /** get the type index of the object */ + typeIndex(): number { + if (this.handle == 0) { + throw Error("The current Object has already been disposed"); } + const stack = this.lib.getOrAllocCallStack(); + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + + this.lib.checkCall( + (this.lib.exports.TVMObjectGetTypeIndex as ctypes.FTVMObjectGetTypeIndex)( + this.getHandle(), + outPtr + ) + ); + const result = this.lib.memory.loadU32(outPtr); + this.lib.recycleCallStack(stack); + return result; } - /** - * Execute the underlying graph. - */ - run(): void { - this.packedRun(); + /** get the type key of the object */ + typeKey(): string { + const type_index = this.typeIndex(); + const stack = this.lib.getOrAllocCallStack(); + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + this.lib.checkCall( + (this.lib.exports.TVMObjectTypeIndex2Key as ctypes.FTVMObjectTypeIndex2Key)( + type_index, + outPtr + ) + ); + const result =this.lib.memory.loadCString( + this.lib.memory.loadPointer(outPtr) + ); + this.lib.recycleCallStack(stack); + return result; } +} - /** - * Get index-th output. - * @param index The index number. - * @param out The optional output storage parameters. - * @returns The output array. - */ - getOutput(index: number, out: NDArray | undefined = undefined): NDArray { - if (out !== undefined) { - this.packedGetOutput(new Scalar(index, "int32"), out) - return out; - } else { - return this.packedGetOutput(new Scalar(index, "int32")); - } +/** Objectconstructor */ +type FObjectConstructor = (handle: Pointer, lib: FFILibrary, ctx: RuntimeContext) => TVMObject; + +/** All possible object types. */ +type TVMObjectBase = TVMObject | NDArray | Module | PackedFunc; + +/** Runtime array object. */ +export class TVMArray extends TVMObject { + constructor( + handle: Pointer, + lib: FFILibrary, + ctx: RuntimeContext + ) { + super(handle, lib, ctx); } /** - * Load parameters from parameter binary. - * @param paramBinary The parameter binary. + * @returns the size of the array. */ - loadParams(paramBinary: Uint8Array): void { - this.packedLoadParams(paramBinary); + size() : number { + return this.ctx.arrayGetSize(this) as number; } - /** - * Benchmark stable execution of the graph(without data copy). - * @params dev The device to sync during each run. - * @number The number of times to compute the average. - * @repeat The number of times to repeat the run. + * Get index-th element of the array + * @param index the array index. + * @returns The element. */ - async benchmarkRuns(dev: DLDevice, number=10, repeat=4): Promise { - // Skip first run as it can involve GPU warmup and module loading time. - const perf = compact.getPerformance(); - const results = []; - this.run(); - await dev.sync(); - for (let k = 0; k < repeat; ++k) { - const tstart = perf.now(); - for (let i = 0; i < number; ++i) { - this.run(); - } - await dev.sync(); - const tend = perf.now(); - results.push((tend - tstart) / number); - } - return results; + get(index : number) : TVMObjectBase { + return this.ctx.arrayGetItem(this, new Scalar(index, "int32")) as TVMObjectBase; } } @@ -678,12 +814,28 @@ const enum AyncCallbackCode { /** * TVM runtime instance. + * + * All objects(NDArray, Module, PackedFunc) returned by TVM runtim function call + * and PackedFunc instance are tracked through a scope mechanism that will get + * auto-released when we call EndScope. + * + * This is necessarily to be able to release the underlying WASM and WebGPU memory that + * are not tracked through JS native garbage collection mechanism. + * + * This does mean that we have to get familar with the following functions: + * - {@link beginScope} + * - {@link endScope} + * - {@link withNewScope} + * - {@link attachToCurrentScope} + * - {@link detachFromCurrentScope} */ export class Instance implements Disposable { memory: Memory; exports: Record; private lib: FFILibrary; private env: Environment; + private objFactory: Map; + private ctx: RuntimeContext; /** * Internal function(registered by the runtime) @@ -726,22 +878,136 @@ export class Instance implements Disposable { this.lib = new FFILibrary(wasmInstance, env.imports); this.memory = this.lib.memory; this.exports = this.lib.exports; + this.objFactory = new Map(); + this.ctx = new RuntimeContext( + (name: string) => { + const autoAttachToScope = false; + // runtime context function do not auto-release. + return this.getGlobalFuncInternal(name, autoAttachToScope); + } + ); this.registerEnvGlobalPackedFuncs(); + this.registerObjectFactoryFuncs(); } + /** + * Benchmark stable execution of the run function. + * + * @params run The run function + * @params dev The device to sync during each run. + * @number The number of times to compute the average. + * @repeat The number of times to repeat the run. + */ + async benchmark(run: ()=>void, dev: DLDevice, number=10, repeat=4): Promise { + // Skip first run as it can involve GPU warmup and module loading time. + const perf = compact.getPerformance(); + const results = []; + + // run with new scope + this.withNewScope(run); + await dev.sync(); + + for (let k = 0; k < repeat; ++k) { + const tstart = perf.now(); + for (let i = 0; i < number; ++i) { + this.withNewScope(run); + } + await dev.sync(); + const tend = perf.now(); + results.push((tend - tstart) / number); + } + return results; + } + dispose(): void { + // order matters + // ctx release goes back into lib. + this.ctx.dispose(); this.lib.dispose(); } + + /** + * Begin a new scope for tracking object disposal. + */ + beginScope(): void { + this.ctx.beginScope(); + } + + /** + * End a scope and release all created TVM objects + * under the current scope. + * + * Exception: one can call retainToParentScope to move + * a value to parent scope. + */ + endScope(): void { + this.ctx.endScope(); + } + + /** + * Perform action under a new scope. + * + * @param action The action function. + * @returns The result value. + * + * @note For action to return a valid value, + * we will need to call {@link retainToParentScope} + * for the objects that are created in the scope. + */ + withNewScope(action: ()=>T): T { + this.beginScope(); + const val = action(); + this.endScope(); + return val; + } + + /** + * Attach a detached obj to the auto-release pool of the current scope. + * + * @param obj The input obj. + * @note Normally user do not need to call this function explicitly, as + * all library call return values are explicitly attached to + * the current scope. You only need to do so when you call + * {@link detachFromCurrentScope} to create a detached object. + */ + attachToCurrentScope(obj: T) : T { + return this.ctx.attachToCurrentScope(obj); + } + + /** + * Move obj's attachment to the parent scope. + * + * This function is useful to make sure objects are still + * alive when exit the current scope. + * + * @param obj The object to be moved. + * @returns The input obj. + */ + moveToParentScope(obj: T) : T { + return this.ctx.moveToParentScope(obj); + } + + /** + * Detach the object from the current scope + * so it won't be released via auto-release during endscope. + * + * User needs to either explicitly call obj.dispose(), or + * {@link attachToCurrentScope} to re-attach to the current scope. + * + * This function can be used to return values to the parent scope. + * @param obj The object. + */ + detachFromCurrentScope(obj: T): T { + return this.ctx.detachFromCurrentScope(obj); + } + /** * Get system-wide library module in the wasm. * System lib is a global module that contains self register functions in startup. * @returns The system library module. */ systemLib(): Module { - const getSysLib = this.getGlobalFunc("runtime.SystemLib"); - const mod = getSysLib() as Module; - getSysLib.dispose(); - return mod; + return this.ctx.getSysLib() as Module; } /** * List all the global function names registered in the runtime. @@ -791,29 +1057,39 @@ export class Instance implements Disposable { func: PackedFunc | Function, override = false ): void { - const packedFunc = this.toPackedFunc(func); - const ioverride = override ? 1 : 0; + this.withNewScope(() => { + const autoAttachToScope = true; + // packed func can be released once it is registered + const packedFunc = this.toPackedFuncInternal(func, autoAttachToScope); + const ioverride = override ? 1 : 0; - const stack = this.lib.getOrAllocCallStack(); - const nameOffset = stack.allocRawBytes(name.length + 1); - stack.storeRawBytes(nameOffset, StringToUint8Array(name)); - stack.commitToWasmMemory(); + const stack = this.lib.getOrAllocCallStack(); + const nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + stack.commitToWasmMemory(); - this.lib.checkCall( - (this.lib.exports.TVMFuncRegisterGlobal as ctypes.FTVMFuncRegisterGlobal)( - stack.ptrFromOffset(nameOffset), - packedFunc._tvmPackedCell.handle, - ioverride - ) - ); + this.lib.checkCall( + (this.lib.exports.TVMFuncRegisterGlobal as ctypes.FTVMFuncRegisterGlobal)( + stack.ptrFromOffset(nameOffset), + packedFunc._tvmPackedCell.getHandle(), + ioverride + ) + ); + this.lib.recycleCallStack(stack); + }); } /** * Get global PackedFunc from the runtime. * @param name The name of the function. + * @param autoAttachToScope Whether to track it via autoDispose * @returns The result function. */ getGlobalFunc(name: string): PackedFunc { + return this.getGlobalFuncInternal(name, true); + } + + private getGlobalFuncInternal(name: string, autoAttachToScope: boolean = true): PackedFunc { const stack = this.lib.getOrAllocCallStack(); const nameOffset = stack.allocRawBytes(name.length + 1); stack.storeRawBytes(nameOffset, StringToUint8Array(name)); @@ -834,6 +1110,7 @@ export class Instance implements Disposable { throw Error("Cannot find global function " + name); } const ret = this.makePackedFunc(handle); + if (autoAttachToScope) this.ctx.attachToCurrentScope(ret); return ret; } @@ -854,9 +1131,15 @@ export class Instance implements Disposable { * @param func Input function. * @returns The converted function. */ - toPackedFunc(func: Function): PackedFunc { + toPackedFunc(func: Function): PackedFunc { + return this.toPackedFuncInternal(func, true); + } + + private toPackedFuncInternal(func: Function, autoAttachToScope: boolean): PackedFunc { if (this.isPackedFunc(func)) return func as PackedFunc; - return this.createPackedFuncFromCFunc(this.wrapJSFuncAsPackedCFunc(func)); + const ret = this.createPackedFuncFromCFunc(this.wrapJSFuncAsPackedCFunc(func)); + if (autoAttachToScope) return this.ctx.attachToCurrentScope(ret); + return ret; } /** @@ -979,29 +1262,74 @@ export class Instance implements Disposable { outPtr ) ); - const ret = new NDArray(this.memory.loadPointer(outPtr), false, this.lib); + const ret = this.ctx.attachToCurrentScope( + new NDArray(this.memory.loadPointer(outPtr), false, this.lib) + ); this.lib.recycleCallStack(stack); return ret; } /** - * Create a new graph executor. + * Create an tuple {@link TVMArray} input array. + * + * The input array can be passed to tvm runtime function + * and needs to b explicitly disposed. * - * @param graphJson The graph executor json file. - * @param lib The underlying library. - * @param dev The execution device of the graph. + * @param inputs The input array + * @returns The result array. */ - createGraphExecutor(graphJson: string, lib: Module, dev: DLDevice): GraphExecutor { - const fcreate = this.getGlobalFunc('tvm.graph_executor.create'); - const module = fcreate( - graphJson, - lib, - this.scalar(dev.deviceType, "int32"), - this.scalar(dev.deviceId, "int32")) as Module; - return new GraphExecutor(module); + makeTVMArray( + inputs: Array + ): TVMArray { + return this.ctx.arrayMake(...inputs) as TVMArray; } + /** + * Get type index from type key. + * @param typeKey The type key. + * @returns The corresponding type index. + */ + typeKey2Index( + typeKey: string + ) : number { + const stack = this.lib.getOrAllocCallStack(); + const typeKeyOffset = stack.allocRawBytes(typeKey.length + 1); + stack.storeRawBytes(typeKeyOffset, StringToUint8Array(typeKey)); + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + + stack.commitToWasmMemory(outOffset); + + this.lib.checkCall( + (this.lib.exports.TVMObjectTypeKey2Index as ctypes.FTVMObjectTypeKey2Index)( + stack.ptrFromOffset(typeKeyOffset), + outPtr + ) + ); + const typeIndex = this.memory.loadU32(outPtr); + this.lib.recycleCallStack(stack); + return typeIndex; + } + /** + * Register an object constructor. + * @param typeKey The name of the function. + * @param func function to be registered. + * @param override Whether overwrite function in existing registry. + */ + registerObjectConstructor( + typeKey: string, + func: FObjectConstructor, + override = false + ): void { + const typeIndex = this.typeKey2Index(typeKey); + if (this.objFactory.has(typeIndex)) { + if (!override) { + throw new Error("Type " + typeKey + " already registered"); + } + } + this.objFactory.set(typeIndex, func); + } /** * Register an asyncfunction to be global function in the server. * @param name The name of the function. @@ -1017,10 +1345,12 @@ export class Instance implements Disposable { ): void { const asyncVariant = (...args: Array): void => { const fargs = args.slice(0, args.length - 1); - const callback = args[args.length - 1] as PackedFunc; + // need to keep it alive until callback is fulfilled. + const callback = this.detachFromCurrentScope(args[args.length - 1] as PackedFunc); const promise: Promise = func(...fargs); promise.then((rv: any) => { callback(this.scalar(AyncCallbackCode.kReturn, "int32"), rv); + callback.dispose(); }); }; this.registerFunc("__async." + name, asyncVariant, override); @@ -1046,6 +1376,14 @@ export class Instance implements Disposable { this.lib.webGPUContext = webGPUContext; } + /** Register all object factory */ + private registerObjectFactoryFuncs(): void { + this.registerObjectConstructor("Array", + (handle: number, lib: FFILibrary, ctx: RuntimeContext) => { + return new TVMArray(handle, lib, ctx); + }); + } + /** Register global packed functions needed by the backend to the env. */ private registerEnvGlobalPackedFuncs(): void { // Register the timer function to enable the time_evaluator. @@ -1062,6 +1400,11 @@ export class Instance implements Disposable { cooldownIntervalMs: number, repeatsToCooldown: number ): Promise => { + // detach and explicit dispose when tasks is fullfilled + // the promise will immediately return and we need to makesure + // finvoke do not get recycled. + this.ctx.detachFromCurrentScope(finvoke); + finvoke(this.scalar(1, "int32")); await dev.sync(); const result = []; @@ -1095,6 +1438,9 @@ export class Instance implements Disposable { } const ret = new Float64Array(result.length); ret.set(result); + + // dispose finvoke + finvoke.dispose(); return new Uint8Array(ret.buffer); }; @@ -1154,7 +1500,7 @@ export class Instance implements Disposable { const valueOffset = argsValue + i * SizeOf.TVMValue; const codeOffset = argsCode + i * SizeOf.I32; if (val instanceof NDArray) { - stack.storePtr(valueOffset, val.handle); + stack.storePtr(valueOffset, val.getHandle()); stack.storeI32(codeOffset, ArgTypeCode.TVMNDArrayHandle); } else if (val instanceof Scalar) { if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) { @@ -1177,7 +1523,7 @@ export class Instance implements Disposable { stack.storeI32(codeOffset, ArgTypeCode.Float); // eslint-disable-next-line no-prototype-builtins } else if (tp == "function" && val.hasOwnProperty("_tvmPackedCell")) { - stack.storePtr(valueOffset, val._tvmPackedCell.handle); + stack.storePtr(valueOffset, val._tvmPackedCell.getHandle()); stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle); } else if (val === null || val == undefined) { stack.storePtr(valueOffset, 0); @@ -1189,13 +1535,16 @@ export class Instance implements Disposable { stack.allocThenSetArgBytes(valueOffset, val); stack.storeI32(codeOffset, ArgTypeCode.TVMBytes); } else if (val instanceof Function) { - val = this.toPackedFunc(val); + val = this.toPackedFuncInternal(val, false); stack.tempArgs.push(val); - stack.storePtr(valueOffset, val._tvmPackedCell.handle); + stack.storePtr(valueOffset, val._tvmPackedCell.getHandle()); stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle); } else if (val instanceof Module) { - stack.storePtr(valueOffset, val.handle); + stack.storePtr(valueOffset, val.getHandle()); stack.storeI32(codeOffset, ArgTypeCode.TVMModuleHandle); + } else if (val instanceof TVMObject) { + stack.storePtr(valueOffset, val.getHandle()); + stack.storeI32(codeOffset, ArgTypeCode.TVMObjectHandle); } else { throw new Error("Unsupported argument type " + tp); } @@ -1213,6 +1562,8 @@ export class Instance implements Disposable { _handle: Pointer ): number => { const jsArgs = []; + // use scope to track js values. + this.ctx.beginScope(); for (let i = 0; i < nargs; ++i) { const valuePtr = argValues + i * SizeOf.TVMValue; const codePtr = argCodes + i * SizeOf.I32; @@ -1237,6 +1588,8 @@ export class Instance implements Disposable { } const rv = func(...jsArgs); + // recycle all js object value in function unless we want to retain them. + this.ctx.endScope(); if (rv !== undefined && rv !== null) { const stack = lib.getOrAllocCallStack(); @@ -1281,7 +1634,7 @@ export class Instance implements Disposable { this.lib.checkCall( (this.exports.TVMFuncCall as ctypes.FTVMFuncCall)( - handle, + cell.getHandle(), stack.ptrFromOffset(valueOffset), stack.ptrFromOffset(tcodeOffset), args.length, @@ -1304,6 +1657,13 @@ export class Instance implements Disposable { return ret as PackedFunc; } + /** + * Creaye return value of the packed func. The value us auto-tracked for dispose. + * @param rvaluePtr The location of rvalue + * @param tcode The type code. + * @param callbackArg Whether it is being used in callbackArg. + * @returns The JS value. + */ private retValueToJS(rvaluePtr: Pointer, tcode: number, callbackArg: boolean): any { switch (tcode) { case ArgTypeCode.Int: @@ -1315,23 +1675,45 @@ export class Instance implements Disposable { return this.memory.loadPointer(rvaluePtr); } case ArgTypeCode.TVMNDArrayHandle: { - return new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib); + return this.ctx.attachToCurrentScope( + new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib) + ); } case ArgTypeCode.TVMDLTensorHandle: { assert(callbackArg); + // no need to attach as we are only looking at view return new NDArray(this.memory.loadPointer(rvaluePtr), true, this.lib); } case ArgTypeCode.TVMPackedFuncHandle: { - return this.makePackedFunc(this.memory.loadPointer(rvaluePtr)); + return this.ctx.attachToCurrentScope( + this.makePackedFunc(this.memory.loadPointer(rvaluePtr)) + ); } case ArgTypeCode.TVMModuleHandle: { - return new Module( + return this.ctx.attachToCurrentScope( + new Module( + this.memory.loadPointer(rvaluePtr), + this.lib, + (ptr: Pointer) => { + return this.ctx.attachToCurrentScope(this.makePackedFunc(ptr)); + } + ) + ); + } + case ArgTypeCode.TVMObjectHandle: { + const obj = new TVMObject( this.memory.loadPointer(rvaluePtr), this.lib, - (ptr: Pointer) => { - return this.makePackedFunc(ptr); - } + this.ctx ); + const func = this.objFactory.get(obj.typeIndex()) + if (func != undefined) { + return this.ctx.attachToCurrentScope( + func(obj.getHandle(), this.lib, this.ctx) + ); + } else { + return this.ctx.attachToCurrentScope(obj); + } } case ArgTypeCode.Null: return undefined; case ArgTypeCode.DLDevice: { diff --git a/web/tests/node/test_module_load.js b/web/tests/node/test_module_load.js index 561de8aa5786..24acc66a1e4b 100644 --- a/web/tests/node/test_module_load.js +++ b/web/tests/node/test_module_load.js @@ -32,8 +32,6 @@ const tvm = new tvmjs.Instance( new EmccWASI() ); -// Load system library -const sysLib = tvm.systemLib(); function randomArray(length, max) { return Array.apply(null, Array(length)).map(function () { @@ -42,8 +40,13 @@ function randomArray(length, max) { } test("add one", () => { + tvm.beginScope(); + // Load system library + const sysLib = tvm.systemLib(); // grab pre-loaded function const faddOne = sysLib.getFunction("add_one"); + tvm.detachFromCurrentScope(faddOne); + assert(tvm.isPackedFunc(faddOne)); const n = 124; const A = tvm.empty(n).copyFrom(randomArray(n, 1)); @@ -56,5 +59,13 @@ test("add one", () => { for (var i = 0; i < BB.length; ++i) { assert(Math.abs(BB[i] - (AA[i] + 1)) < 1e-5); } + tvm.endScope(); + + // assert auto release scope behavior + assert(sysLib.getHandle(false) == 0); + // fadd is not released because it is detached + assert(faddOne._tvmPackedCell.handle != 0); faddOne.dispose(); + assert(A.getHandle(false) == 0); + assert(B.getHandle(false) == 0); }); diff --git a/web/tests/node/test_ndarray.js b/web/tests/node/test_ndarray.js index b7a5abdcb155..8393c668dd0f 100644 --- a/web/tests/node/test_ndarray.js +++ b/web/tests/node/test_ndarray.js @@ -28,6 +28,7 @@ const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); let tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI()); + // Basic fields. assert(tvm.listGlobalFuncNames() !== undefined); @@ -42,15 +43,14 @@ function testArrayCopy(dtype, arrayType) { let ret = a.toArray(); assert(ret instanceof arrayType); assert(ret.toString() == arrayType.from(data).toString()); - // test multiple dispose. - a.dispose(); - a.dispose(); } test("array copy", () => { - testArrayCopy("float32", Float32Array); - testArrayCopy("int", Int32Array); - testArrayCopy("int8", Int8Array); - testArrayCopy("uint8", Uint8Array); - testArrayCopy("float64", Float64Array); + tvm.withNewScope(() => { + testArrayCopy("float32", Float32Array); + testArrayCopy("int", Int32Array); + testArrayCopy("int8", Int8Array); + testArrayCopy("uint8", Uint8Array); + testArrayCopy("float64", Float64Array); + }); }); diff --git a/web/tests/node/test_object.js b/web/tests/node/test_object.js new file mode 100644 index 000000000000..3b7ee5bd0c46 --- /dev/null +++ b/web/tests/node/test_object.js @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* eslint-disable no-undef */ +const path = require("path"); +const fs = require("fs"); +const assert = require("assert"); +const tvmjs = require("../../dist/tvmjs.bundle") + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); + +let tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI()); + +test("object", () => { + tvm.withNewScope(() => { + let data = [1, 2, 3, 4, 5, 6]; + let a = tvm.empty([2, 3], "float32").copyFrom(data); + + let t = tvm.makeTVMArray([]); + let b = tvm.makeTVMArray([a, t]); + // assert b instanceof tvmjs.TVMArray + assert(b instanceof tvmjs.TVMArray); + assert(b.size() == 2); + + let t1 = b.get(1); + assert(t1.getHandle() == t.getHandle()); + }); +}); diff --git a/web/tests/node/test_packed_func.js b/web/tests/node/test_packed_func.js index 6e0546f39df1..98956ebf2b7a 100644 --- a/web/tests/node/test_packed_func.js +++ b/web/tests/node/test_packed_func.js @@ -31,7 +31,9 @@ let tvm = new tvmjs.Instance( new EmccWASI() ); + test("GetGlobal", () => { + tvm.beginScope(); let flist = tvm.listGlobalFuncNames(); let faddOne = tvm.getGlobalFunc("testing.add_one"); let fecho = tvm.getGlobalFunc("testing.echo"); @@ -51,25 +53,35 @@ test("GetGlobal", () => { assert(fecho(undefined) == undefined); + tvm.beginScope(); + let arr = tvm.empty([2, 2]).copyFrom([1, 2, 3, 4]); let arr2 = fecho(arr); - assert(arr.handle == arr2.handle); + assert(arr.getHandle() == arr2.getHandle()); assert(arr2.toArray().toString() == arr.toArray().toString()); + tvm.moveToParentScope(arr2); + tvm.endScope(); + // test move to parent scope and tracking + assert(arr.getHandle(false) == 0); + assert(arr2.handle != 0); + let mod = tvm.systemLib(); let ret = fecho(mod); - assert(ret.handle == mod.handle); + assert(ret.getHandle() == mod.getHandle()); assert(flist.length != 0); - - mod.dispose(); - ret.dispose(); - arr.dispose(); - arr2.dispose(); - fecho.dispose(); - faddOne.dispose(); + tvm.endScope(); + + // assert auto release scope behavior + assert(mod.getHandle(false) == 0); + assert(ret.getHandle(false) == 0); + assert(arr2.getHandle(false) == 0); + assert(fecho._tvmPackedCell.getHandle(false) == 0); + assert(faddOne._tvmPackedCell.getHandle(false) == 0); }); test("ReturnFunc", () => { + tvm.beginScope(); function addy(y) { function add(x, z) { return x + y + z; @@ -95,9 +107,11 @@ test("ReturnFunc", () => { // test multiple dispose. f.dispose(); f.dispose(); + tvm.endScope(); }); test("RegisterGlobal", () => { + tvm.beginScope(); tvm.registerFunc("xyz", function (x, y) { return x + y; }); @@ -108,23 +122,44 @@ test("RegisterGlobal", () => { let syslib = tvm.systemLib(); syslib.dispose(); + tvm.endScope(); }); test("NDArrayCbArg", () => { + tvm.beginScope(); let use_count = tvm.getGlobalFunc("testing.object_use_count"); + let record = []; - let fcheck = tvm.toPackedFunc(function (x) { + let fcheck = tvm.toPackedFunc(function (x, retain) { assert(use_count(x) == 2); - x.dispose(); + assert(x.handle != 0); + record.push(x); + if (retain) { + tvm.detachFromCurrentScope(x); + } }); + let x = tvm.empty([2], "float32").copyFrom([1, 2]); assert(use_count(x) == 1); - fcheck(x); + + fcheck(x, 0); + // auto-released when it is out of scope. + assert(record[0].getHandle(false) == 0); + assert(use_count(x) == 1); + + fcheck(x, 1); + assert(use_count(x) == 2); + assert(record[1].handle != 0); + tvm.attachToCurrentScope(record[1]); + tvm.endScope(); + assert(record[1].getHandle(false) == 0); }); test("Logging", () => { + tvm.beginScope(); const log_info = tvm.getGlobalFunc("testing.log_info_str"); log_info("helow world") log_info.dispose(); + tvm.endScope(); }); diff --git a/web/tests/python/websock_rpc_test.py b/web/tests/python/websock_rpc_test.py index 9aab1759f8dd..7de5ee956ec8 100644 --- a/web/tests/python/websock_rpc_test.py +++ b/web/tests/python/websock_rpc_test.py @@ -69,7 +69,6 @@ def check(remote): assert fecho(100, 2, 3) == 100 assert fecho("xyz") == "xyz" assert bytes(fecho(bytearray(b"123"))) == b"123" - # run the generated library. f1 = remote.system_lib() dev = remote.cpu(0)