Skip to content

Commit

Permalink
[WEB] Reduce memleak in web runtime (apache#14086)
Browse files Browse the repository at this point in the history
This PR robustifies the web runtime to reduce memory leak
and enhances the runtime with object support.

Specifically we introduce scoping and auto-release
mechanism when we exit the scope. The improvements
are helpful to deal with memory leak in wasm and webgpu settings
  • Loading branch information
tqchen authored and yongwww committed Feb 27, 2023
1 parent b956917 commit cb66e6c
Show file tree
Hide file tree
Showing 13 changed files with 670 additions and 161 deletions.
1 change: 1 addition & 0 deletions web/.eslintignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
dist
debug
2 changes: 1 addition & 1 deletion web/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ The following is an example to reproduce this.
- Start the WebSocket RPC
- Browswer version: open https://localhost:8888, click connect to proxy
- NodeJS version: `npm run rpc`
- run `python tests/node/websock_rpc_test.py` to run the rpc client.
- run `python tests/python/websock_rpc_test.py` to run the rpc test.


## WebGPU Experiments
Expand Down
2 changes: 2 additions & 0 deletions web/apps/node/example.js
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm"));
// the async version of the API.
tvmjs.instantiate(wasmSource, new EmccWASI())
.then((tvm) => {
tvm.beginScope();
const log_info = tvm.getGlobalFunc("testing.log_info_str");
log_info("hello world");
// List all the global functions from the runtime.
console.log("Runtime functions using EmccWASI\n", tvm.listGlobalFuncNames());
tvm.endScope();
});
2 changes: 1 addition & 1 deletion web/emcc/wasm_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
#include <tvm/runtime/logging.h>

#include "src/runtime/c_runtime_api.cc"
#include "src/runtime/container.cc"
#include "src/runtime/contrib/sort/sort.cc"
#include "src/runtime/cpu_device_api.cc"
#include "src/runtime/file_utils.cc"
#include "src/runtime/graph_executor/graph_executor.cc"
#include "src/runtime/library_module.cc"
#include "src/runtime/logging.cc"
#include "src/runtime/module.cc"
Expand Down
22 changes: 22 additions & 0 deletions web/src/ctypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ export type FTVMModGetFunction = (
* TVMModuleHandle dep);
*/
export type FTVMModImport = (mod: Pointer, dep: Pointer) => number;

/**
* int TVMModFree(TVMModuleHandle mod);
*/
Expand Down Expand Up @@ -161,6 +162,27 @@ export type FTVMBackendPackedCFunc = (
argValues: Pointer, argCodes: Pointer, nargs: number,
outValue: Pointer, outCode: Pointer) => number;


/**
* int TVMObjectFree(TVMObjectHandle obj);
*/
export type FTVMObjectFree = (obj: Pointer) => number;

/**
* int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex);
*/
export type FTVMObjectGetTypeIndex = (obj: Pointer, out_tindex: Pointer) => number;

/**
* int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key);
*/
export type FTVMObjectTypeIndex2Key = (type_index: number, out_type_key: Pointer) => number;

/**
* int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);
*/
export type FTVMObjectTypeKey2Index = (type_key: Pointer, out_tindex: Pointer) => number;

// -- TVM Wasm Auxiliary C API --

/** void* TVMWasmAllocSpace(int size); */
Expand Down
5 changes: 3 additions & 2 deletions web/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@

export {
Scalar, DLDevice, DLDataType,
PackedFunc, Module, NDArray, Instance,
instantiate
PackedFunc, Module, NDArray,
TVMArray,
Instance, instantiate
} from "./runtime";
export { Disposable, LibraryProvider } from "./types";
export { RPCServer } from "./rpc_server";
Expand Down
19 changes: 15 additions & 4 deletions web/src/rpc_server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import { assert, StringToUint8Array, Uint8ArrayToString } from "./support";
import { detectGPUDevice } from "./webgpu";
import * as compact from "./compact";
import * as runtime from "./runtime";
import { timeStamp } from "console";
import { Disposable } from "./types";

enum RPCServerState {
InitHeader,
Expand Down Expand Up @@ -83,6 +85,7 @@ export class RPCServer {
private pendingSend: Promise<void> = Promise.resolve();
private name: string;
private inst?: runtime.Instance = undefined;
private globalObjects: Array<Disposable> = [];
private serverRecvData?: (header: Uint8Array, body: Uint8Array) => void;
private currPacketHeader?: Uint8Array;
private currPacketLength = 0;
Expand Down Expand Up @@ -121,6 +124,9 @@ export class RPCServer {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
private onClose(_event: CloseEvent): void {
if (this.inst !== undefined) {
this.globalObjects.forEach(obj => {
obj.dispose();
});
this.inst.dispose();
}
if (this.state == RPCServerState.ReceivePacketHeader) {
Expand Down Expand Up @@ -263,6 +269,9 @@ export class RPCServer {
}

this.inst = inst;
// begin scope to allow handling of objects
// the object should stay alive during all sessions.
this.inst.beginScope();
const fcreate = this.inst.getGlobalFunc("rpc.CreateEventDrivenServer");

const messageHandler = fcreate(
Expand Down Expand Up @@ -301,8 +310,10 @@ export class RPCServer {
this.name,
this.key
);

fcreate.dispose();
// message handler should persist across RPC runs
this.globalObjects.push(
this.inst.detachFromCurrentScope(messageHandler)
);
const writeFlag = this.inst.scalar(3, "int32");

this.serverRecvData = (header: Uint8Array, body: Uint8Array): void => {
Expand All @@ -320,7 +331,6 @@ export class RPCServer {
// register the callback to redirect the session to local.
const flocal = this.inst.getGlobalFunc("wasm.LocalSession");
const localSession = flocal();
flocal.dispose();
assert(localSession instanceof runtime.Module);

// eslint-disable-next-line @typescript-eslint/no-unused-vars
Expand All @@ -333,13 +343,14 @@ export class RPCServer {
);
messageHandler(header, writeFlag);
messageHandler(body, writeFlag);
localSession.dispose();

this.log("Finish initializing the Wasm Server..");
this.requestBytes(SizeOf.I64);
this.state = RPCServerState.ReceivePacketHeader;
// call process events in case there are bufferred data.
this.processEvents();
// recycle all values.
this.inst.endScope();
};

this.state = RPCServerState.WaitForCallback;
Expand Down
Loading

0 comments on commit cb66e6c

Please sign in to comment.