diff --git a/config.yaml b/config.yaml
index b595df0..0fc41c4 100644
--- a/config.yaml
+++ b/config.yaml
@@ -4,7 +4,7 @@ DATASET:
 
 SAM:
   CHECKPOINT: "./sam_vit_b_01ec64.pth"
-  RANK: 6
+  RANK: 512
 TRAIN:
   BATCH_SIZE: 1
   NUM_EPOCHS: 50
\ No newline at end of file
diff --git a/inference_eval.py b/inference_eval.py
index 3a0d159..e177fb7 100644
--- a/inference_eval.py
+++ b/inference_eval.py
@@ -17,6 +17,10 @@ import torch.nn.functional as F
 import monai
 import numpy as np
 
+"""
+This file computes the evaluation metric (Dice cross-entropy loss) for all trained LoRA SAM models with different ranks. It produces the plot in ./plots/rank_comparison.jpg,
+which compares the performances on the test set.
+"""
 device = "cuda" if torch.cuda.is_available() else "cpu"
 seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
 
diff --git a/inference_plots.py b/inference_plots.py
index 40b89c3..73dc294 100644
--- a/inference_plots.py
+++ b/inference_plots.py
@@ -11,6 +11,11 @@ import json
 from torchvision.transforms import ToTensor
 
 
+"""
+This file plots the predictions of a model (either baseline or LoRA) on the train or test set. Most of it is hard-coded, so here are the parameters to change,
+referenced by line:
+line 22: change the LoRA rank; line 98: run inference on the train set (inference_train=True), otherwise on the test set; lines 101 and 111: the is_baseline argument in the function, True to use the baseline, False to use the LoRA model.
+"""
 sam_checkpoint = "sam_vit_b_01ec64.pth"
 device = "cuda" if torch.cuda.is_available() else "cpu"
 sam = build_sam_vit_b(checkpoint=sam_checkpoint)
diff --git a/src/lora.py b/src/lora.py
index 73d478f..33f2114 100644
--- a/src/lora.py
+++ b/src/lora.py
@@ -58,6 +58,14 @@ class LoRA_sam(nn.Module):
     """
     Class that takes the image encoder of SAM and add the lora weights to the attentions blocks
 
+    Arguments:
+        sam_model: Sam class of the segment anything model
+        rank: Rank of the matrices for LoRA
+        lora_layer: List of existing weights for LoRA
+
+    Return:
+        None
+
     """
 
     def __init__(self, sam_model: Sam, rank: int, lora_layer=None):
@@ -112,6 +120,9 @@ def __init__(self, sam_model: Sam, rank: int, lora_layer=None):
 
 
     def reset_parameters(self):
+        """
+        Initialize the LoRA A and B matrices like in the paper
+        """
         # Initalisation like in the paper
         for w_A in self.A_weights:
             nn.init.kaiming_uniform_(w_A.weight, a=np.sqrt(5))
@@ -120,7 +131,15 @@ def reset_parameters(self):
 
 
     def save_lora_parameters(self, filename: str):
-        "save lora and fc parameters"
+        """
+        Save the LoRA weights applied to the attention modules as safetensors.
+
+        Arguments:
+            filename: Name of the file that will be saved
+
+        Return:
+            None: Saves a safetensors file
+        """
         num_layer = len(self.A_weights)
         # sufix 03:d -> allows to have a name 1 instead of 001
         a_tensors = {f"w_a_{i:03d}": self.A_weights[i].weight for i in range(num_layer)}
@@ -130,7 +149,15 @@ def save_lora_parameters(self, filename: str):
 
 
     def load_lora_parameters(self, filename: str):
-        "load lora and fc parameters"
+        """
+        Load a safetensors file of LoRA weights for the attention modules
+
+        Arguments:
+            filename: Name of the file containing the saved weights
+
+        Return:
+            None: Loads the weights into the LoRA_sam class
+        """
         with safe_open(filename, framework="pt") as f:
             for i, w_A_linear in enumerate(self.A_weights):
                 saved_key = f"w_a_{i:03d}"
diff --git a/transform_to_mask.py b/transform_to_mask.py
index b167cce..4567140 100644
--- a/transform_to_mask.py
+++ b/transform_to_mask.py
@@ -2,6 +2,10 @@
 from PIL import Image
 import numpy as np
 
+"""
+This file takes the images in the "./dataset/image_before_mask" folder. I have prepared and outlined the images to get their masks. I set the pixel threshold to 10.
+If the intensity is < 10 the pixel is set to black.
+"""
 filename = "ring_test_2.jpg"
 mask_path = f"./dataset/image_before_mask/{filename}"
 mask = Image.open(mask_path)
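
Note on the RANK: 6 -> 512 change in config.yaml: a back-of-the-envelope sketch, assuming the LoRA A/B matrices factor the 768-dimensional attention projections of SAM's ViT-B image encoder (A is r x d, B is d x r; the 768 figure is the standard ViT-B embedding dimension, not something shown in this diff):

    # Extra parameters per adapted projection: A (r x d) plus B (d x r) = 2*r*d
    d = 768                       # ViT-B embedding dimension
    for r in (6, 512):
        print(f"rank {r}: {2 * r * d:,} LoRA params vs {d * d:,} in the full weight")
    # rank 6:     9,216 LoRA params vs 589,824 in the full weight
    # rank 512: 786,432 LoRA params vs 589,824 in the full weight

Under that assumption, at rank 512 the low-rank factors hold more parameters than the 768 x 768 weight they adapt, so the usual LoRA parameter savings no longer apply at that setting.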
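
For the new save_lora_parameters / load_lora_parameters docstrings in src/lora.py, a minimal usage sketch of the save/load round-trip. The import paths and the file name lora_rank512.safetensors are illustrative; build_sam_vit_b and the checkpoint name are taken from inference_plots.py above:

    from segment_anything import build_sam_vit_b  # assumed import path
    from src.lora import LoRA_sam

    # Wrap SAM ViT-B with LoRA adapters on its attention blocks
    sam = build_sam_vit_b(checkpoint="sam_vit_b_01ec64.pth")
    sam_lora = LoRA_sam(sam, rank=512)

    # ... fine-tuning happens here ...

    # Persist only the low-rank A/B matrices, not the full SAM checkpoint
    sam_lora.save_lora_parameters("lora_rank512.safetensors")

    # Restore: rebuild the wrapper with the SAME rank so the A/B shapes match,
    # then load the saved tensors back into its layers
    sam_restored = LoRA_sam(build_sam_vit_b(checkpoint="sam_vit_b_01ec64.pth"), rank=512)
    sam_restored.load_lora_parameters("lora_rank512.safetensors")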
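
The transform_to_mask.py docstring only states the rule "intensity < 10 -> black"; the rest of the script is not shown in this hunk. A sketch of that binarisation using the imports the file already has; mapping the remaining pixels to white and the ./dataset/masks/ output folder are my assumptions:

    from PIL import Image
    import numpy as np

    filename = "ring_test_2.jpg"
    img = np.array(Image.open(f"./dataset/image_before_mask/{filename}").convert("L"))

    # Threshold of 10: anything darker becomes black (0); the rest is assumed white (255)
    binary = np.where(img < 10, 0, 255).astype(np.uint8)
    Image.fromarray(binary).save(f"./dataset/masks/{filename}")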