-
Notifications
You must be signed in to change notification settings - Fork 844
/
Copy pathmain.js
112 lines (87 loc) · 3.65 KB
/
main.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import './style.css';
import { AutoModel, AutoProcessor, env, RawImage } from '@xenova/transformers';
// Since we will download the model from the Hugging Face Hub, we can skip the local model check
env.allowLocalModels = false;
// Proxy the WASM backend to prevent the UI from freezing
env.backends.onnx.wasm.proxy = true;
// Constants
const EXAMPLE_URL = 'https://images.pexels.com/photos/5965592/pexels-photo-5965592.jpeg?auto=compress&cs=tinysrgb&w=1024';
// Reference the elements that we will need
const status = document.getElementById('status');
const fileUpload = document.getElementById('upload');
const imageContainer = document.getElementById('container');
const example = document.getElementById('example');
// Load model and processor
status.textContent = 'Loading model...';
const model = await AutoModel.from_pretrained('briaai/RMBG-1.4', {
// Do not require config.json to be present in the repository
config: { model_type: 'custom' },
});
const processor = await AutoProcessor.from_pretrained('briaai/RMBG-1.4', {
// Do not require config.json to be present in the repository
config: {
do_normalize: true,
do_pad: false,
do_rescale: true,
do_resize: true,
image_mean: [0.5, 0.5, 0.5],
feature_extractor_type: "ImageFeatureExtractor",
image_std: [1, 1, 1],
resample: 2,
rescale_factor: 0.00392156862745098,
size: { width: 1024, height: 1024 },
}
});
status.textContent = 'Ready';
example.addEventListener('click', (e) => {
e.preventDefault();
predict(EXAMPLE_URL);
});
fileUpload.addEventListener('change', function (e) {
const file = e.target.files[0];
if (!file) {
return;
}
const reader = new FileReader();
// Set up a callback when the file is loaded
reader.onload = e2 => predict(e2.target.result);
reader.readAsDataURL(file);
});
// Predict foreground of the given image
async function predict(url) {
// Read image
const image = await RawImage.fromURL(url);
// Update UI
imageContainer.innerHTML = '';
imageContainer.style.backgroundImage = `url(${url})`;
// Set container width and height depending on the image aspect ratio
const ar = image.width / image.height;
const [cw, ch] = (ar > 720 / 480) ? [720, 720 / ar] : [480 * ar, 480];
imageContainer.style.width = `${cw}px`;
imageContainer.style.height = `${ch}px`;
status.textContent = 'Analysing...';
// Preprocess image
const { pixel_values } = await processor(image);
// Predict alpha matte
const { output } = await model({ input: pixel_values });
// Resize mask back to original size
const mask = await RawImage.fromTensor(output[0].mul(255).to('uint8')).resize(image.width, image.height);
// Create new canvas
const canvas = document.createElement('canvas');
canvas.width = image.width;
canvas.height = image.height;
const ctx = canvas.getContext('2d');
// Draw original image output to canvas
ctx.drawImage(image.toCanvas(), 0, 0);
// Update alpha channel
const pixelData = ctx.getImageData(0, 0, image.width, image.height);
for (let i = 0; i < mask.data.length; ++i) {
pixelData.data[4 * i + 3] = mask.data[i];
}
ctx.putImageData(pixelData, 0, 0);
// Update UI
imageContainer.append(canvas);
imageContainer.style.removeProperty('background-image');
imageContainer.style.background = `url("data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQBAMAAADt3eJSAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAAGUExURb+/v////5nD/3QAAAAJcEhZcwAADsMAAA7DAcdvqGQAAAAUSURBVBjTYwABQSCglEENMxgYGAAynwRB8BEAgQAAAABJRU5ErkJggg==")`;
status.textContent = 'Done!';
}