Skip to content

Commit

Permalink
perf(backend): use mutex for nsfw model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
popkirby authored and riku6460 committed Jul 5, 2023
1 parent 3e93450 commit d4a9eba
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 78 deletions.
1 change: 1 addition & 0 deletions packages/backend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
"accepts": "1.3.8",
"ajv": "8.12.0",
"archiver": "5.3.1",
"async-mutex": "^0.4.0",
"autwh": "0.1.0",
"bcryptjs": "2.4.3",
"blurhash": "2.0.5",
Expand Down
18 changes: 13 additions & 5 deletions packages/backend/src/core/AiService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { dirname } from 'node:path';
import { Inject, Injectable } from '@nestjs/common';
import * as nsfw from 'nsfwjs';
import si from 'systeminformation';
import { Mutex } from 'async-mutex';
import type { Config } from '@/config.js';
import { DI } from '@/di-symbols.js';
import { bindThis } from '@/decorators.js';
Expand All @@ -17,6 +18,7 @@ let isSupportedCpu: undefined | boolean = undefined;
@Injectable()
export class AiService {
private model: nsfw.NSFWJS;
private modelLoadMutex: Mutex = new Mutex();

constructor(
@Inject(DI.config)
Expand All @@ -31,16 +33,22 @@ export class AiService {
const cpuFlags = await this.getCpuFlags();
isSupportedCpu = REQUIRED_CPU_FLAGS.every(required => cpuFlags.includes(required));
}

if (!isSupportedCpu) {
console.error('This CPU cannot use TensorFlow.');
return null;
}

const tf = await import('@tensorflow/tfjs-node');

if (this.model == null) this.model = await nsfw.load(`file://${_dirname}/../../nsfw-model/`, { size: 299 });


if (this.model == null) {
await this.modelLoadMutex.runExclusive(async () => {
if (this.model == null) {
this.model = await nsfw.load(`file://${_dirname}/../../nsfw-model/`, { size: 299 });
}
});
}

const buffer = await fs.promises.readFile(path);
const image = await tf.node.decodeImage(buffer, 3) as any;
try {
Expand Down
Loading

0 comments on commit d4a9eba

Please sign in to comment.