Skip to content

Commit

Permalink
IA Debug
Browse files Browse the repository at this point in the history
  • Loading branch information
anrouxel committed Feb 28, 2024
1 parent e288308 commit 7007806
Show file tree
Hide file tree
Showing 16 changed files with 181 additions and 135 deletions.
10 changes: 9 additions & 1 deletion .idea/sonarlint/issuestore/index.pb

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion .idea/sonarlint/securityhotspotstore/index.pb

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 32 additions & 21 deletions app/src/main/java/fr/medicapp/medicapp/ai/PrescriptionAI.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import androidx.annotation.WorkerThread
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.text.TextRecognition
import com.google.mlkit.vision.text.latin.TextRecognizerOptions
import fr.medicapp.medicapp.ai.tokenization.Feature
import fr.medicapp.medicapp.ai.tokenization.FeatureConverter
import fr.medicapp.medicapp.database.repositories.medication.MedicationRepository
import fr.medicapp.medicapp.model.prescription.relationship.Prescription
import fr.medicapp.medicapp.tokenization.Feature
import fr.medicapp.medicapp.tokenization.FeatureConverter
import fr.medicapp.medicapp.utils.JaroWinkler
import org.pytorch.IValue
import org.pytorch.Module
Expand Down Expand Up @@ -128,7 +128,7 @@ class PrescriptionAI(
*/
init {
// Crée un nouveau thread en arrière-plan.
mBackgroundThread = HandlerThread("BackgroundPyTorchThread")
mBackgroundThread = HandlerThread("BackgroundThread")

// Démarre le thread en arrière-plan.
mBackgroundThread.start()
Expand Down Expand Up @@ -157,24 +157,24 @@ class PrescriptionAI(
// Attend que le module PyTorch soit chargé.
while (mModule == null) {
Thread.sleep(100)
Log.v(TAG, "Waiting for model to load.")
}
Log.d(TAG, "Model loaded.")

// Reconnaît le texte dans l'image spécifiée par l'URI.
val visionText = recognizeText(imageUri)
Log.d(TAG, "Text recognized: $visionText")

Log.d("visionText", visionText.toString())

if (visionText != null) {
// Exécute le modèle PyTorch sur le texte reconnu et génère des prédictions.
val sentenceTokenized = runModel(visionText)
Log.d(TAG, "Predictions: $sentenceTokenized")

// Appelle le callback avec les prédictions générées.
Log.d("sentenceTokenized", sentenceTokenized.toString())

val prescriptions = onPrediction(sentenceTokenized)
Log.d(TAG, "Prescriptions: $prescriptions")

// Appelle le callback lorsque l'analyse est terminée.
Log.d("prescriptions", prescriptions.toString())

// Appelle le callback avec les prédictions générées.
onDismiss(prescriptions)
}

Expand Down Expand Up @@ -287,6 +287,7 @@ class PrescriptionAI(
val feature: Feature = featureConverter.convert(visionText)
val inputIds = feature.inputIds
val inputMask = feature.inputMask
val segmentIds = feature.segmentIds
val startLogits = FloatArray(MAX_SEQ_LEN)
val endLogits = FloatArray(MAX_SEQ_LEN)

Expand Down Expand Up @@ -314,7 +315,7 @@ class PrescriptionAI(
)

// Aligne les IDs de mots avec les labels.
val labelIds = FeatureConverter.alignWordIDS(feature)
val labelIds = FeatureConverter.align_word_ids(feature)

// Exécute le modèle PyTorch avec les prédictions d'entrée et de masque.
val outputTensor = mModule!!.forward(
Expand Down Expand Up @@ -347,7 +348,7 @@ class PrescriptionAI(

// Convertit la liste de prédictions en liste de labels.
var predictionsLabelList: List<String> = startPredictionsList.map { index ->
labels.getValue(index)
labels[index]!!
}

// Parcourt la liste des labels prédits.
Expand All @@ -372,13 +373,18 @@ class PrescriptionAI(
sentenceTokenized.forEach { (word, label) ->
when {
label.startsWith("B-") -> {
if (label.removePrefix("B-") == "Drug") {
val medication = MedicationRepository(context).getAll().sortedBy {
JaroWinkler.jaroWinklerDistance(
it.medicationInformation.name,
if (label.removePrefix("B-") == "Drug" && query.isNotEmpty()) {
val medication = MedicationRepository(context).getAll().map { medication ->
val distance = JaroWinkler.jaroWinklerDistance(
medication.medicationInformation.name,
query.trim()
)
}.first()
Pair(medication, distance)
}.filter { (_, distance) ->
distance < 0.20
}.minByOrNull { (_, distance) ->
distance
}?.first
prescription.medication = medication
prescriptions.add(prescription)
prescription = Prescription()
Expand All @@ -404,12 +410,17 @@ class PrescriptionAI(
}
}
if (query.isNotEmpty()) {
val medication = MedicationRepository(context).getAll().sortedBy {
JaroWinkler.jaroWinklerDistance(
it.medicationInformation.name,
val medication = MedicationRepository(context).getAll().map { medication ->
val distance = JaroWinkler.jaroWinklerDistance(
medication.medicationInformation.name,
query.trim()
)
}.first()
Pair(medication, distance)
}.filter { (_, distance) ->
distance < 0.20
}.minByOrNull { (_, distance) ->
distance
}?.first
prescription.medication = medication
prescriptions.add(prescription)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package fr.medicapp.medicapp.ai.tokenization
package fr.medicapp.medicapp.tokenization

/**
* Classe BasicTokenizer pour la tokenization de base du texte.
Expand Down Expand Up @@ -107,13 +107,13 @@ class BasicTokenizer(
* @return Une liste de tokens.
*/
fun whitespaceTokenize(text: String): MutableList<String> {

// Divise le texte en tokens en utilisant les espaces blancs comme séparateurs.
return mutableListOf(
// Supprime les tokens vides à la fin de la liste.
*text.split(" ".toRegex()).dropLastWhile { it.isEmpty() }
// Convertit la liste de tokens en MutableList.
.toTypedArray()
)
.toTypedArray())
}

/**
Expand Down Expand Up @@ -155,4 +155,4 @@ class BasicTokenizer(
return tokens
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package fr.medicapp.medicapp.ai.tokenization
package fr.medicapp.medicapp.tokenization

/**
* Objet CamemBERT qui contient les tokens spéciaux utilisés dans le modèle de tokenization CamemBERT.
Expand Down Expand Up @@ -38,4 +38,4 @@ object CamemBERT {
* Token de masquage.
*/
val MASK_TOKEN: String = "<mask>"
}
}
100 changes: 47 additions & 53 deletions app/src/main/java/fr/medicapp/medicapp/ai/tokenization/CharChecker.kt
Original file line number Diff line number Diff line change
@@ -1,63 +1,57 @@
package fr.medicapp.medicapp.ai.tokenization
package fr.medicapp.medicapp.tokenization

/**
* Classe CharChecker pour vérifier les caractères spécifiques.
*/
object CharChecker {
/**
* Pour juger si c'est un caractère vide ou inconnu.
*
* @param ch Le caractère à vérifier.
* @return Vrai si le caractère est invalide, faux sinon.
*/
fun isInvalid(ch: Char): Boolean {
return ch.code == 0 || ch.code == 0xfffd
}
class CharChecker {
companion object {
/**
* Pour juger si c'est un caractère vide ou inconnu.
*
* @param ch Le caractère à vérifier.
* @return Vrai si le caractère est invalide, faux sinon.
*/
fun isInvalid(ch: Char): Boolean {
return ch.code == 0 || ch.code == 0xfffd
}

/**
* Pour juger si c'est un caractère de contrôle (exclut l'espace blanc).
*
* @param ch Le caractère à vérifier.
* @return Vrai si le caractère est un caractère de contrôle, faux sinon.
*/
fun isControl(ch: Char): Boolean {
if (Character.isWhitespace(ch)) {
return false
/**
* Pour juger si c'est un caractère de contrôle (exclut l'espace blanc).
*
* @param ch Le caractère à vérifier.
* @return Vrai si le caractère est un caractère de contrôle, faux sinon.
*/
fun isControl(ch: Char): Boolean {
if (Character.isWhitespace(ch)) {
return false
}
val type = Character.getType(ch)
return type == Character.CONTROL.toInt() || type == Character.FORMAT.toInt()
}
val type = Character.getType(ch)
return type == Character.CONTROL.toInt() || type == Character.FORMAT.toInt()
}

/**
* Pour juger si cela peut être considéré comme un espace blanc.
*
* @param ch Le caractère à vérifier.
* @return Vrai si le caractère est un espace blanc, faux sinon.
*/
fun isWhitespace(ch: Char): Boolean {
if (Character.isWhitespace(ch)) {
return true
/**
* Pour juger si cela peut être considéré comme un espace blanc.
*
* @param ch Le caractère à vérifier.
* @return Vrai si le caractère est un espace blanc, faux sinon.
*/
fun isWhitespace(ch: Char): Boolean {
if (Character.isWhitespace(ch)) {
return true
}
val type = Character.getType(ch)
return type == Character.SPACE_SEPARATOR.toInt() || type == Character.LINE_SEPARATOR.toInt() || type == Character.PARAGRAPH_SEPARATOR.toInt()
}
val type = Character.getType(ch)
return type == Character.SPACE_SEPARATOR.toInt() ||
type == Character.LINE_SEPARATOR.toInt() ||
type == Character.PARAGRAPH_SEPARATOR.toInt()
}

/**
* Pour juger si c'est une ponctuation.
*
* @param ch Le caractère à vérifier.
* @return Vrai si le caractère est une ponctuation, faux sinon.
*/
fun isPunctuation(ch: Char): Boolean {
val type = Character.getType(ch)
return type == Character.CONNECTOR_PUNCTUATION.toInt() ||
type == Character.DASH_PUNCTUATION.toInt() ||
type == Character.START_PUNCTUATION.toInt() ||
type == Character.END_PUNCTUATION.toInt() ||
type == Character.INITIAL_QUOTE_PUNCTUATION.toInt() ||
type == Character.FINAL_QUOTE_PUNCTUATION.toInt() ||
type == Character.OTHER_PUNCTUATION.toInt()
/**
* Pour juger si c'est une ponctuation.
*
* @param ch Le caractère à vérifier.
* @return Vrai si le caractère est une ponctuation, faux sinon.
*/
fun isPunctuation(ch: Char): Boolean {
val type = Character.getType(ch)
return type == Character.CONNECTOR_PUNCTUATION.toInt() || type == Character.DASH_PUNCTUATION.toInt() || type == Character.START_PUNCTUATION.toInt() || type == Character.END_PUNCTUATION.toInt() || type == Character.INITIAL_QUOTE_PUNCTUATION.toInt() || type == Character.FINAL_QUOTE_PUNCTUATION.toInt() || type == Character.OTHER_PUNCTUATION.toInt()
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package fr.medicapp.medicapp.ai.tokenization
package fr.medicapp.medicapp.tokenization

/**
* Classe Feature pour représenter une caractéristique d'un texte.
Expand Down Expand Up @@ -51,4 +51,4 @@ class Feature(
this.origTokens = origTokens
this.tokenToOrigMap = tokenToOrigMap
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package fr.medicapp.medicapp.ai.tokenization
package fr.medicapp.medicapp.tokenization

import java.util.Collections

Expand Down Expand Up @@ -159,7 +159,7 @@ class FeatureConverter(
* @param labelAllTokens Indique si tous les tokens doivent être étiquetés.
* @return Une liste d'identifiants de mots alignés.
*/
fun alignWordIDS(feature: Feature, labelAllTokens: Boolean = false): MutableList<Int> {
fun align_word_ids(feature: Feature, labelAllTokens: Boolean = false): MutableList<Int> {
// Récupère les identifiants d'entrée de la caractéristique.
val inputIds = feature.inputIds

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package fr.medicapp.medicapp.ai.tokenization
package fr.medicapp.medicapp.tokenization

/**
* Classe FullTokenizer pour la tokenization complète du texte.
Expand Down Expand Up @@ -61,17 +61,17 @@ class FullTokenizer(
* @param tokens La liste de tokens à convertir.
* @return Une liste d'identifiants.
*/
fun convertTokensToIds(tokens: List<String>): MutableList<Int> {
fun convertTokensToIds(tokens: MutableList<String>): MutableList<Int> {
// Crée une liste mutable pour stocker les identifiants de sortie.
val outputIds: MutableList<Int> = ArrayList()

// Parcourt chaque token dans la liste des tokens.
for (token in tokens) {
// Ajoute l'identifiant correspondant au token dans le dictionnaire à la liste des identifiants de sortie.
outputIds.add(dic.getValue(token))
outputIds.add(dic[token]!!)
}

// Retourne la liste des identifiants de sortie.
return outputIds
}
}
}
Loading

0 comments on commit 7007806

Please sign in to comment.