This repository has been archived by the owner on Oct 3, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathApp.tsx
153 lines (139 loc) · 4.26 KB
/
App.tsx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import React, { useState, useEffect } from "react";
import { StyleSheet, Text, View, Image, TextInput, Button } from "react-native";
import * as tf from "@tensorflow/tfjs";
import { fetch as tfjsFetch, decodeJpeg } from "@tensorflow/tfjs-react-native";
import * as mobilenet from "@tensorflow-models/mobilenet";
import * as FileSystem from "expo-file-system";
import * as ImagePicker from 'expo-image-picker';
import * as jpeg from "jpeg-js";
/** A single MobileNet classification result. */
type Prediction = {
  className: string;
  probability: number;
};

/**
 * Builds the sentence shown to the user for a list of MobileNet predictions.
 *
 * MobileNet reports `probability` as a fraction in [0, 1], so it is scaled by
 * 100 before being rendered as a percentage. Any number of alternative
 * predictions is handled instead of assuming exactly three results.
 *
 * @param predictions Results ordered from most to least likely.
 * @returns A human-readable description of the top prediction and alternatives.
 */
const describePredictions = (predictions: Prediction[]): string => {
  const [best, ...alternatives] = predictions;
  if (!best) {
    return "I couldn't recognise anything in this image.";
  }
  const confidence = (best.probability * 100).toFixed(2);
  let sentence = `I'm ${confidence}% sure it's a ${best.className.toLowerCase()}`;
  const others = alternatives.map(p => p.className.toLowerCase());
  if (others.length === 1) {
    sentence += `, it might also be a ${others[0]}`;
  } else if (others.length > 1) {
    // "a, b or c" — all but the last joined with commas, last one with "or".
    const last = others[others.length - 1];
    sentence += `, it might also be a ${others.slice(0, -1).join(", ")} or ${last}`;
  }
  return sentence;
};

/**
 * Demo screen: loads TensorFlow.js and MobileNet, lets the user pick an image
 * from the library, and shows MobileNet's classification of it.
 */
const App: React.FC = () => {
  const [isTfReady, setIsTfReady] = useState(false);
  const [mobilenetModel, setMobilenetModel] = useState<mobilenet.MobileNet | null>(null);
  const [image, setImage] = useState<{ uri: string | undefined }>({
    uri: undefined,
  });
  const [predictions, setPredictions] = useState<Prediction[] | null>(null);

  // === START CLASSIFICATION ==
  /**
   * Reads the JPEG at `imgUri`, decodes it to a tensor and stores MobileNet's
   * predictions in state. Errors are logged, not thrown.
   */
  const classifyImage = async (imgUri: string) => {
    if (!mobilenetModel) {
      return;
    }
    // Drop the previous result so it is never displayed next to a new image.
    setPredictions(null);
    let imageTensor: tf.Tensor3D | undefined;
    try {
      const imgB64 = await FileSystem.readAsStringAsync(imgUri, {
        encoding: FileSystem.EncodingType.Base64,
      });
      // encodeString already yields the raw bytes as a Uint8Array.
      const imgBytes = tf.util.encodeString(imgB64, "base64");
      imageTensor = decodeJpeg(imgBytes); // transforms byte array into 3d tensor
      const prediction = await mobilenetModel.classify(imageTensor);
      setPredictions(prediction);
      console.info(prediction);
    } catch (error) {
      console.log(error);
    } finally {
      // Tensor memory is not garbage collected; release it explicitly.
      if (imageTensor) {
        imageTensor.dispose();
      }
    }
  };

  /** Opens the image library and classifies the chosen image. */
  const selectImage = async () => {
    try {
      const response = await ImagePicker.launchImageLibraryAsync({
        mediaTypes: ImagePicker.MediaTypeOptions.Images,
        allowsEditing: false,
        aspect: [4, 3],
      });
      if (!response.cancelled) {
        // setState does not return a promise, so there is nothing to await.
        setImage({ uri: response.uri });
        await classifyImage(response.uri);
      }
    } catch (error) {
      console.log(error);
    }
  };

  // === CLEAR PREDICTIONS ==
  const clearPredictions = () => {
    setPredictions(null);
  };

  useEffect(() => {
    // Prevents state updates if the component unmounts while loading.
    let isMounted = true;
    (async () => {
      try {
        await tf.ready();
        if (isMounted) {
          setIsTfReady(true);
        }
        const myModel = await mobilenet.load();
        if (isMounted) {
          setMobilenetModel(myModel);
        }
      } catch (error) {
        console.log(error);
      }
    })();
    return () => {
      isMounted = false;
    };
  }, []);

  return (
    <View style={styles.container}>
      <Text>Hello!</Text>
      <Text>TF Status: {isTfReady ? "👌" : "⏳"}</Text>
      <Text>Mobilenet Model Status: {mobilenetModel ? "👌" : "⏳"}</Text>
      <Image
        source={{ uri: image.uri }}
        style={{ width: 200, height: 200, margin: 20 }}
      />
      <TextInput
        style={{
          marginBottom: 20,
          width: 200,
          height: 40,
          borderColor: "gray",
          borderWidth: 1,
        }}
        onChangeText={text => setImage({ uri: text })}
        value={image.uri}
      />
      {/* `disabled` keeps the button inert until the model has loaded. */}
      <Button
        title="Predict"
        onPress={selectImage}
        disabled={!mobilenetModel}
      />
      {predictions ? (
        <View style={styles.predictions}>
          <Text style={{ marginBottom: 20 }}>
            {describePredictions(predictions)}
          </Text>
          <Button title="Clear" onPress={clearPredictions} />
        </View>
      ) : null}
    </View>
  );
};
/** Shared spacing used around and inside the predictions panel. */
const PANEL_SPACING = 20;

/** Layout styles for the App screen. */
const styles = StyleSheet.create({
  /** Full-height white canvas that centres its children on both axes. */
  container: {
    alignItems: "center",
    backgroundColor: "#fff",
    flex: 1,
    justifyContent: "center",
  },
  /** Outlined, fixed-width panel holding the prediction text. */
  predictions: {
    borderColor: "grey",
    borderWidth: 1,
    margin: PANEL_SPACING,
    padding: PANEL_SPACING,
    width: 300,
  },
});

export default App;