From 057412db890b83001cb715615f2bb70eb57c220f Mon Sep 17 00:00:00 2001 From: "Romain F. Laine" Date: Fri, 7 Aug 2020 14:16:38 +0100 Subject: [PATCH] v1.9 --- .DS_Store | Bin 12292 -> 12292 bytes Colab_notebooks/.DS_Store | Bin 6148 -> 6148 bytes Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb | 2 +- Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb | 2 +- Colab_notebooks/ChangeLog.txt | 2 ++ Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb | 2 +- .../Deep-STORM_2D_ZeroCostDL4Mic.ipynb | 2 +- .../Noise2VOID_2D_ZeroCostDL4Mic.ipynb | 2 +- .../Noise2VOID_3D_ZeroCostDL4Mic.ipynb | 2 +- .../Stardist_2D_ZeroCostDL4Mic.ipynb | 2 +- .../Stardist_3D_ZeroCostDL4Mic.ipynb | 2 +- Colab_notebooks/U-net_2D_ZeroCostDL4Mic.ipynb | 2 +- Colab_notebooks/YOLOv2_ZeroCostDL4Mic.ipynb | 2 +- Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb | 2 +- Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb | 2 +- 15 files changed, 14 insertions(+), 12 deletions(-) diff --git a/.DS_Store b/.DS_Store index f888e9ef2378063f5be7df09dc952bf935d21bb4..b76d5913eda9d7bffdf76d7dab0ce284370d1ed8 100644 GIT binary patch delta 2077 zcmeH{U2Icj7{}lLXt&dI_VI3I?P=H3bst^X*0~n8m3_b*;~Nn%Oc4-P)@mAAH`pL< zV@?I5h#JkSU<4v^p;06unM0#|#S8f|F-kAw0*oQxg^3uwK_k)UeJfmep;z7*o1CWa z^R%bW`Tw8ae`oK`-lJ=5dAi|TSiDFVr);I-s4HH2IKR|eUNxg(W~`$tQ&3pMi%Xok zxZ12YRrCG!`_K-MI8tJ>ruCeRvJN6^H#eN3^ zpu+_>OjMu})d->vv(SVV#L$5SSct_~f)!Yehp-mwkU&4u*o-Y0!%hgihCSGegLn&v z@IH><7*2`PW_iu&G)=5!N_cTmTf1T2nNp6NEqC5+n&NV3Ik3TwLVAK~)F6Zzh++=r zqLb!cPIKRfl~{ufh+`xAu!-g###U^@1fIn%JdfRY6$jM(WruJWALBSq;3UrAEY9H~ zGPsQExFOb+my1hzGx8Z@po?CbF>lD{`PmXR0 z!(Y?yahonHU{z{l2iG$4s0S^FgHL4-R?w=`G@g@jxlK>F!P5CO+7yR6MrZ04x6@K+ z>ev#~b#9}r$#&a2sGF8zKeANx!G?GG{>69#~qqG7@f1E zw|``GZ2R~V`^8>ox%~j$t4MI}7r}~v!zeHpFIntbx=h4vU1EcGwm4XAh!ebY`r$&i zuOe92*wi|=J0r<7rW@k7ipB}~+4PuY8qs)%Otq1g-!#@qN|p~KQjesFktvc|?rNpq zG}p>b_m(8RCkP!`t_hijqE>k#EA?T+P}C|PWF-j+t~j1c&O?o zTYSov+P~Q15{IlubBuCvtHPY

Rt+uF{I6%&6O zE$7%*>|2VJUlkxaN(hkvMaf;5t)gTeVX*-BQk<+*SUimN=%q-Jfgz7h+q@Gq@lDsrO1inrB7Af!2TE fW1r*$`zOzb8ttOPT`K~?y#KZSSG5-B-EDsWdcNr7 delta 2047 zcmeH{Yitx{6o${+?v@!lZNKeyZ&SA0ZM$6-mfl;o)V4seEm%a51_A_n*(KRPTWE`1 zDlTe?F#?HxY7mV=h%o{&CZ#9}g~VW>8Z@|uXcUPtnD_@0{oyU(nb}H+2{HaP&X38Q zbG}K=dB5kpqg|t2@2s{=Q#0+YwzhV;%s*cq5Y-x1r>IGhos(sE**%IVR#t88iEr5&AKqb}%b(8YBGypC_^BjdS62+hcg93F z|72uBGgWoP#&~~J45_J1S-EQO+|=C@+a#i@VjillQSB=qjSUXP2l~Z>Q^`oZYG2hm z(AO91A0DzHJe6*2vfI~1yN6>P-9y7mH^=&8QE@?0lJc50m=GX^JQSuTYNgfGOZ~Ks z_R>DuPY3A;y-g=*f^&f6y(u4T1^>oN%EK0R%AvAyi=w z>M<7$Sct{wz!E%!l~|8%^q?1;uoc^}13U2)9M9tgyoi_ZIu7A5j^Q}QF(J>neI@63 zw4|Y`-k>!q#yEFoOXdn&PV+onm)~Z%)Aa&f{+1C+s62_6s%Rl~P!|o-2v2c<4)X*j z>0SDeKA|Lc|2=nplYXL~=~tv76IsYc4%~P3-h^h%!=kCqmtzgqVjXuMMIZXH8C$Ro zyYU2fKZ?C(^DpBdj_BO`2RMTXoW%uP#3YjV0$1?^Zpt&o9rqE)e3Z;Fo8q*Lm1-3R zSbA$-TktNhoVnT*OnHui;u+z}x`ui3y68+(z4Er-o#dXd66 zDikwR1A0&uLEjJbHP)YuvTVBY%X|;?ndbJ5=~z;TCqr>ig!l0KJL-V$?UdO=xVzDIZaihL!CIBigFgj&#-X;5wk zzMwB`GS178UsmqZM7n9GMp#{=`GnOpP9u!eYnrf_)@p=}OY5Pcg{L~K&p2=@mxh_=EJPqc_>#3qe;F@77J$>kK#p&cKQR`9366!PKygsGklRy#f%sbgS@L?h=xd7^a48 zM9xq`OeMzDU{{O~)9FtvE;Vc;#&iT-nRRUCU^kSYtJ9xoI6@lHtux>ZtTS*Fhkfb) zcf0TZ>q$Ox2AqL^#X#t0PuY-LvfjF}Iq9_t^_40jac#s^3Lmr;GgeyhDOH93L>`E# TVH=SiivJN%8r(Pof6BlQeEdmG delta 70 zcmZoMXfc=|#>AjHu~2NHo+1YW5HK<@2y9MdUdFPyfSHSVGdl-A2T%b}CARE is a neural network capable of image restoration from corrupted bio-images, first published in 2018 by [Weigert *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0216-7). The CARE network uses a U-Net network architecture and allows image restoration and resolution improvement in 2D and 3D images, in a supervised manner, using noisy images as input and low-noise images as targets for training. The function of the network is essentially determined by the set of images provided in the training dataset. 
For instance, if noisy images are provided as input and high signal-to-noise ratio images are provided as targets, the network will perform denoising.\n","\n"," **This particular notebook enables restoration of 2D dataset. If you are interested in restoring 3D dataset, you should use the CARE 3D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n","**Content-aware image restoration: pushing the limits of fluorescence microscopy**, by Weigert *et al.* published in Nature Methods in 2018 (https://www.nature.com/articles/s41592-018-0216-7)\n","\n","And source code found in: https://github.com/csbdeep/csbdeep\n","\n","For a more in-depth description of the features of the network,please refer to [this guide](http://csbdeep.bioimagecomputing.com/doc/) provided by the original authors of the work.\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://publications.mpi-cbg.de/publications-sites/7207/).\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 
2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of the play button stops. You can create a new code cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains the structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. 
This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. 
The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"BCPhV-pe-syw","colab_type":"text"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"VNZetvLiS1qV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UBrnApIUBgxv","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste it into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of the notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in to your Google Account. \n","\n","#@markdown * Copy the authorization code. 
\n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on the \"Files\" tab on the left. Refresh it. Your Google Drive folder should now be available here as \"gdrive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **2. Install CARE and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install CARE and dependencies\n","\n","#Libraries contain code for specific tasks. \n","#For example the tifffile library contains tools to read and write tif-files.\n","\n","#Here, we install libraries which are not already included in Colab.\n","!pip install tifffile # contains tools to handle tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microscopy images (Content-aware Image Restoration, CARE). 
It uses Keras and Tensorflow.\n","!pip install wget\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","import tensorflow \n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to CARE -------\n","from csbdeep.utils import download_and_extract_zip_file, plot_some, axes_dict, plot_history, Path, download_and_extract_zip_file\n","from csbdeep.data import RawData, create_patches \n","from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n","from csbdeep.models import Config, CARE\n","from csbdeep import data\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries 
installed\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"BLmBseWbRvxL","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd","colab_type":"text"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files, right-click on the folder, select **Copy path** and paste the path into the appropriate box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained for. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see section 5). **Default value: 50**\n","\n","**`patch_size`:** CARE divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. 
**Default value: 80**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`number_of_patches`:** Input the number of the patches per image. Increasing the number of patches allows for larger training datasets. **Default value: 100** \n","\n","**Decreasing the patch size or increasing the number of patches may improve the training but may also increase the training time.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0004**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.tif\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.tif\"\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 50#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 80#@param {type:\"number\"} # in pixels\n","number_of_patches = 100#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","batch_size = 16#@param {type:\"number\"}\n","number_of_steps = 400#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 16\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","#Here we define the percentage to use for validation\n","percentage = percentage_validation/100\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Folder already exists and has been removed !!\")\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# Here we disable pre-trained model by default (in case the cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","# The shape of the images.\n","x = imread(InputFile)\n","y = imread(OutputFile)\n","\n","print('Loaded Input images (number, width, length) =', x.shape)\n","print('Loaded Output images (number, width, length) =', y.shape)\n","print(\"Parameters initiated.\")\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print(\"Image dimensions (y,x)\",x.shape)\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_-CEUqlS8o3M","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"qe9zvEJ9qOH2","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. 
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by [Augmentor.](https://github.com/mdbloice/Augmentor)\n","\n","[Augmentor](https://github.com/mdbloice/Augmentor) was described in the following article:\n","\n","Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n","\n","**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** "]},{"cell_type":"code","metadata":{"id":"zmtlu9YU266X","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," !pip install Augmentor\n"," import Augmentor\n","\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 1 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n","\n","#@markdown ####Mirror and rotate images\n","rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image Zoom\n","\n","random_zoom = 0 #@param {type:\"slider\", min:0, max:1, 
step:0.1}\n","\n","random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image distortion\n","\n","random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","#@markdown ####Image shearing and skewing \n","\n","image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","max_image_shear = 1 #@param {type:\"slider\", min:1, max:25, step:1}\n","\n","skew_image = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","skew_image_magnitude = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","if Use_Default_Augmentation_Parameters:\n"," rotate_90_degrees = 0.5\n"," rotate_270_degrees = 0.5\n"," flip_left_right = 0.5\n"," flip_top_bottom = 0.5\n","\n"," if not Multiply_dataset_by >5:\n"," random_zoom = 0\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0\n"," image_shear = 0\n"," max_image_shear = 10\n"," skew_image = 0\n"," skew_image_magnitude = 0\n","\n"," if Multiply_dataset_by >5:\n"," random_zoom = 0.1\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0.5\n"," image_shear = 0.2\n"," max_image_shear = 5\n"," skew_image = 0.2\n"," skew_image_magnitude = 0.4\n","\n"," if Multiply_dataset_by >25:\n"," random_zoom = 0.5\n"," random_zoom_magnification = 0.8\n"," random_distortion = 0.5\n"," image_shear = 0.5\n"," max_image_shear = 20\n"," skew_image = 0.5\n"," skew_image_magnitude = 0.6\n","\n","\n","list_files = os.listdir(Training_source)\n","Nb_files = len(list_files)\n","\n","Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n","\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","# Here we set the path for the various folder were the augmented images will be loaded\n","\n","# All images are first saved into the augmented folder\n"," #Augmented_folder = \"/content/Augmented_Folder\"\n"," \n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = 
Saving_path+\"/Augmented_Folder\"\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n","\n"," #Training_source_augmented = \"/content/Training_source_augmented\"\n"," Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n","\n"," if os.path.exists(Training_source_augmented):\n"," shutil.rmtree(Training_source_augmented)\n"," os.makedirs(Training_source_augmented)\n","\n"," #Training_target_augmented = \"/content/Training_target_augmented\"\n"," Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n","\n"," if os.path.exists(Training_target_augmented):\n"," shutil.rmtree(Training_target_augmented)\n"," os.makedirs(Training_target_augmented)\n","\n","\n","# Here we generate the augmented images\n","#Load the images\n"," p = Augmentor.Pipeline(Training_source, Augmented_folder)\n","\n","#Define the matching images\n"," p.ground_truth(Training_target)\n","#Define the augmentation possibilities\n"," if not rotate_90_degrees == 0:\n"," p.rotate90(probability=rotate_90_degrees)\n"," \n"," if not rotate_270_degrees == 0:\n"," p.rotate270(probability=rotate_270_degrees)\n","\n"," if not flip_left_right == 0:\n"," p.flip_left_right(probability=flip_left_right)\n","\n"," if not flip_top_bottom == 0:\n"," p.flip_top_bottom(probability=flip_top_bottom)\n","\n"," if not random_zoom == 0:\n"," p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n"," \n"," if not random_distortion == 0:\n"," p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n","\n"," if not image_shear == 0:\n"," p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n"," \n"," if not skew_image == 0:\n"," p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n","\n"," p.sample(int(Nb_augmented_files))\n","\n"," print(int(Nb_augmented_files),\"matching images generated\")\n","\n","# Here we sort through the images and 
move them back to augmented training source and target folders\n","\n"," augmented_files = os.listdir(Augmented_folder)\n","\n"," for f in augmented_files:\n","\n"," if (f.startswith(\"_groundtruth_(1)_\")):\n"," shortname_noprefix = f[17:]\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n"," if not (f.startswith(\"_groundtruth_(1)_\")):\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n"," \n","\n"," for filename in os.listdir(Training_source_augmented):\n"," os.chdir(Training_source_augmented)\n"," os.rename(filename, filename.replace('_original', ''))\n"," \n"," #Here we clean up the extra files\n"," shutil.rmtree(Augmented_folder)\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"4kb3xSZMRzxU","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CARE 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"mlN-VNOgR-nr","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not 
os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. 
Default learning rate of '+str(bestLearningRate)+' will be used instead')\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead')\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"-A4ipz8gs3Ew","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. 
to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"LKYRNhA5Qnis","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","# This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.\n","# This file is saved in .npz format and later called when loading the trainig data.\n","\n","\n","raw_data = data.RawData.from_folder(\n"," basepath=base,\n"," source_dirs=[Training_source_dir], \n"," target_dir=Training_target_dir, \n"," axes='CYX', \n"," pattern='*.tif*')\n","\n","X, Y, XY_axes = data.create_patches(\n"," raw_data, \n"," patch_filter=None, \n"," patch_size=(patch_size,patch_size), \n"," n_patches_per_image=number_of_patches)\n","\n","print ('Creating 2D training dataset')\n","training_path = model_path+\"/rawdata\"\n","rawdata1 = training_path+\".npz\"\n","np.savez(training_path,X=X, Y=Y, axes=XY_axes)\n","\n","# Load Training Data\n","(X,Y), (X_val,Y_val), axes = load_training_data(rawdata1, validation_split=percentage, verbose=True)\n","c = axes_dict(axes)['C']\n","n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n","\n","%memit \n","\n","#plot of training patches.\n","plt.figure(figsize=(12,5))\n","plot_some(X[:5],Y[:5])\n","plt.suptitle('5 example training patches (top row: source, bottom row: target)');\n","\n","#plot of validation patches\n","plt.figure(figsize=(12,5))\n","plot_some(X_val[:5],Y_val[:5])\n","plt.suptitle('5 example 
validation patches (top row: source, bottom row: target)');\n","\n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here we create the configuration file\n","\n","config = Config(axes, n_channel_in, n_channel_out, probabilistic=True, train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, unet_kern_size=5, unet_n_depth=3, train_batch_size=batch_size, train_learning_rate=initial_learning_rate)\n","\n","print(config)\n","vars(config)\n","\n","# Compile the CARE model for network training\n","model_training= CARE(config, model_name, basedir=model_path)\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model_training.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR","colab_type":"text"},"source":["## **4.2. Train the network**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! 
If training takes longer than 12 hours, please decrease the number of epochs or number of patches."]},{"cell_type":"code","metadata":{"id":"biXiR017C4UU","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","\n","# Start Training\n","history = model_training.train(X,Y, validation_data=(X_val,Y_val))\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XQjQb_J_Qyku","colab_type":"text"},"source":["##**4.3. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"2HbZd7rFqAad","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"EdcnkCr9Nbl8","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. 
The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","**Note: Plots of the losses will be shown in a linear and in a log scale. This can help visualise changes in the losses at different magnitudes. However, note that if the losses are negative the plot on the log scale will be empty. 
This is not an error.**"]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"RZOPCVN0qcYb","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. 
It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"Nh8MlX3sqd_7","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","# Activate the pretrained model. \n","model_training = CARE(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," predicted = model_training.predict(img, axes='YX')\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(filename, predicted)\n","\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," 
return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. 
GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = 
np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99))\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source, 
norm=simple_norm(img_Source, percent = 99))\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Esqnbew8uznk"},"source":["# **6. 
Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"9ZmST3JRq-Ho","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," 
Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n","\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","\n","#Activate the pretrained model. \n","model_training = CARE(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","\n","# creates a loop, creating filenames and saving them\n","for filename in os.listdir(Data_folder):\n"," img = imread(os.path.join(Data_folder,filename))\n"," restored = model_training.predict(img, axes='YX')\n"," os.chdir(Result_folder)\n"," imsave(filename,restored)\n","\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EIe3CRD7XUxa","colab_type":"text"},"source":["## **6.2. 
Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"LmDP8xiwXTTL","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","plt.figure(figsize=(16,8))\n","\n","plt.subplot(1,2,1)\n","plt.axis('off')\n","plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n","plt.title('Input')\n","\n","plt.subplot(1,2,2)\n","plt.axis('off')\n","plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n","plt.title('Predicted output');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"Rn9zpWpo0xNw","colab_type":"text"},"source":["\n","#**Thank you for using CARE 2D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"CARE_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1mqcexfPBaIWuvMWWbJZUFtPoZoJJwrEA","timestamp":1589278334507},{"file_id":"159ARwlQE7-zi0EHxunOF_YPFLt-ZVU5x","timestamp":1587562499898},{"file_id":"1W-7NHehG5MRFILvZZzhPWWnOdJMkadb2","timestamp":1586332290412},{"file_id":"1pUetEQICxYWkYVaQIgdRH1EZBTl7oc2A","timestamp":1586292199692},{"file_id":"1MD36ZkM6XR9EuV12zimJmfCjzyeYZFWq","timestamp":1586269469061},{"file_id":"16A2mbaHzlEElntS8qkFBOsBvZG-mUeY6","timestamp":1586253795726},{"file_id":"1gJlcjOiSxr2buDOxmcFbT_d-GqwLjXtK","timestamp":1583343225796},{"file_id":"10yGI51WzHfgWgZAyE-EbkZFEvIOd6CP6","timestamp":1583171396283}],"collapsed_sections":[],"toc_visible":true},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I","colab_type":"text"},"source":["# **CARE: Content-aware image restoration (2D)**\n","\n","---\n","\n","CARE is a neural network capable of image restoration from corrupted bio-images, first published in 2018 by [Weigert *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0216-7). The CARE network uses a U-Net network architecture and allows image restoration and resolution improvement in 2D and 3D images, in a supervised manner, using noisy images as input and low-noise images as targets for training. The function of the network is essentially determined by the set of images provided in the training dataset. 
For instance, if noisy images are provided as input and high signal-to-noise ratio images are provided as targets, the network will perform denoising.\n","\n"," **This particular notebook enables restoration of 2D dataset. If you are interested in restoring 3D dataset, you should use the CARE 3D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n","**Content-aware image restoration: pushing the limits of fluorescence microscopy**, by Weigert *et al.* published in Nature Methods in 2018 (https://www.nature.com/articles/s41592-018-0216-7)\n","\n","And source code found in: https://github.com/csbdeep/csbdeep\n","\n","For a more in-depth description of the features of the network,please refer to [this guide](http://csbdeep.bioimagecomputing.com/doc/) provided by the original authors of the work.\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://publications.mpi-cbg.de/publications-sites/7207/).\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 
2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of the play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains the structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. 
This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. 
The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"BCPhV-pe-syw","colab_type":"text"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"VNZetvLiS1qV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UBrnApIUBgxv","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of the notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in to your Google Account. \n","\n","#@markdown * Copy the authorization code. 
\n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on the \"Files\" tab on the left. Refresh it. Your Google Drive folder should now be available here as \"gdrive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **2. Install CARE and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install CARE and dependencies\n","\n","#Libraries contain information on certain topics. \n","#For example the tifffile library contains information on how to handle tif-files.\n","\n","#Here, we install libraries which are not already included in Colab.\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microscopy images (Content-aware Image Restoration, CARE). 
It uses Keras and Tensorflow.\n","!pip install wget\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","import tensorflow \n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to CARE -------\n","from csbdeep.utils import download_and_extract_zip_file, plot_some, axes_dict, plot_history, Path, download_and_extract_zip_file\n","from csbdeep.data import RawData, create_patches \n","from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n","from csbdeep.models import Config, CARE\n","from csbdeep import data\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries 
installed\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"BLmBseWbRvxL","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd","colab_type":"text"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, selecting **Copy path** and pasting it into the appropriate box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained for. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see section 5). **Default value: 50**\n","\n","**`patch_size`:** CARE divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. 
**Default value: 80**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`number_of_patches`:** Input the number of the patches per image. Increasing the number of patches allows for larger training datasets. **Default value: 100** \n","\n","**Decreasing the patch size or increasing the number of patches may improve the training but may also increase the training time.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0004**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.tif\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.tif\"\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 50#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 80#@param {type:\"number\"} # in pixels\n","number_of_patches = 100#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","batch_size = 16#@param {type:\"number\"}\n","number_of_steps = 400#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 16\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","#Here we define the percentage to use for validation\n","percentage = percentage_validation/100\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Folder already exists and has been removed !!\")\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# Here we disable pre-trained model by default (in case the cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","# The shape of the images.\n","x = imread(InputFile)\n","y = imread(OutputFile)\n","\n","print('Loaded Input images (number, width, length) =', x.shape)\n","print('Loaded Output images (number, width, length) =', y.shape)\n","print(\"Parameters initiated.\")\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print(\"Image dimensions (y,x)\",x.shape)\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_-CEUqlS8o3M","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"qe9zvEJ9qOH2","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. 
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by [Augmentor.](https://github.com/mdbloice/Augmentor)\n","\n","[Augmentor](https://github.com/mdbloice/Augmentor) was described in the following article:\n","\n","Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n","\n","**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** "]},{"cell_type":"code","metadata":{"id":"zmtlu9YU266X","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," !pip install Augmentor\n"," import Augmentor\n","\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 1 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n","\n","#@markdown ####Mirror and rotate images\n","rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image Zoom\n","\n","random_zoom = 0 #@param {type:\"slider\", min:0, max:1, 
step:0.1}\n","\n","random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image distortion\n","\n","random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","#@markdown ####Image shearing and skewing \n","\n","image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","max_image_shear = 1 #@param {type:\"slider\", min:1, max:25, step:1}\n","\n","skew_image = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","skew_image_magnitude = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","if Use_Default_Augmentation_Parameters:\n"," rotate_90_degrees = 0.5\n"," rotate_270_degrees = 0.5\n"," flip_left_right = 0.5\n"," flip_top_bottom = 0.5\n","\n"," if not Multiply_dataset_by >5:\n"," random_zoom = 0\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0\n"," image_shear = 0\n"," max_image_shear = 10\n"," skew_image = 0\n"," skew_image_magnitude = 0\n","\n"," if Multiply_dataset_by >5:\n"," random_zoom = 0.1\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0.5\n"," image_shear = 0.2\n"," max_image_shear = 5\n"," skew_image = 0.2\n"," skew_image_magnitude = 0.4\n","\n"," if Multiply_dataset_by >25:\n"," random_zoom = 0.5\n"," random_zoom_magnification = 0.8\n"," random_distortion = 0.5\n"," image_shear = 0.5\n"," max_image_shear = 20\n"," skew_image = 0.5\n"," skew_image_magnitude = 0.6\n","\n","\n","list_files = os.listdir(Training_source)\n","Nb_files = len(list_files)\n","\n","Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n","\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","# Here we set the path for the various folder were the augmented images will be loaded\n","\n","# All images are first saved into the augmented folder\n"," #Augmented_folder = \"/content/Augmented_Folder\"\n"," \n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = 
Saving_path+\"/Augmented_Folder\"\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n","\n"," #Training_source_augmented = \"/content/Training_source_augmented\"\n"," Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n","\n"," if os.path.exists(Training_source_augmented):\n"," shutil.rmtree(Training_source_augmented)\n"," os.makedirs(Training_source_augmented)\n","\n"," #Training_target_augmented = \"/content/Training_target_augmented\"\n"," Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n","\n"," if os.path.exists(Training_target_augmented):\n"," shutil.rmtree(Training_target_augmented)\n"," os.makedirs(Training_target_augmented)\n","\n","\n","# Here we generate the augmented images\n","#Load the images\n"," p = Augmentor.Pipeline(Training_source, Augmented_folder)\n","\n","#Define the matching images\n"," p.ground_truth(Training_target)\n","#Define the augmentation possibilities\n"," if not rotate_90_degrees == 0:\n"," p.rotate90(probability=rotate_90_degrees)\n"," \n"," if not rotate_270_degrees == 0:\n"," p.rotate270(probability=rotate_270_degrees)\n","\n"," if not flip_left_right == 0:\n"," p.flip_left_right(probability=flip_left_right)\n","\n"," if not flip_top_bottom == 0:\n"," p.flip_top_bottom(probability=flip_top_bottom)\n","\n"," if not random_zoom == 0:\n"," p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n"," \n"," if not random_distortion == 0:\n"," p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n","\n"," if not image_shear == 0:\n"," p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n"," \n"," if not skew_image == 0:\n"," p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n","\n"," p.sample(int(Nb_augmented_files))\n","\n"," print(int(Nb_augmented_files),\"matching images generated\")\n","\n","# Here we sort through the images and 
move them back to the augmented training source and target folders\n","\n"," augmented_files = os.listdir(Augmented_folder)\n","\n"," for f in augmented_files:\n","\n"," if (f.startswith(\"_groundtruth_(1)_\")):\n"," shortname_noprefix = f[17:]\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n"," if not (f.startswith(\"_groundtruth_(1)_\")):\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n"," \n","\n"," for filename in os.listdir(Training_source_augmented):\n"," os.chdir(Training_source_augmented)\n"," os.rename(filename, filename.replace('_original', ''))\n"," \n"," #Here we clean up the extra files\n"," shutil.rmtree(Augmented_folder)\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"4kb3xSZMRzxU","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CARE 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"mlN-VNOgR-nr","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not 
os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrained model, we load the learning rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exists (compatibility with models trained in ZeroCostDL4Mic below 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. 
Default learning rate of '+str(bestLearningRate)+' will be used instead')\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead')\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"-A4ipz8gs3Ew","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. 
to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"LKYRNhA5Qnis","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","# This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.\n","# This file is saved in .npz format and later called when loading the trainig data.\n","\n","\n","raw_data = data.RawData.from_folder(\n"," basepath=base,\n"," source_dirs=[Training_source_dir], \n"," target_dir=Training_target_dir, \n"," axes='CYX', \n"," pattern='*.tif*')\n","\n","X, Y, XY_axes = data.create_patches(\n"," raw_data, \n"," patch_filter=None, \n"," patch_size=(patch_size,patch_size), \n"," n_patches_per_image=number_of_patches)\n","\n","print ('Creating 2D training dataset')\n","training_path = model_path+\"/rawdata\"\n","rawdata1 = training_path+\".npz\"\n","np.savez(training_path,X=X, Y=Y, axes=XY_axes)\n","\n","# Load Training Data\n","(X,Y), (X_val,Y_val), axes = load_training_data(rawdata1, validation_split=percentage, verbose=True)\n","c = axes_dict(axes)['C']\n","n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n","\n","%memit \n","\n","#plot of training patches.\n","plt.figure(figsize=(12,5))\n","plot_some(X[:5],Y[:5])\n","plt.suptitle('5 example training patches (top row: source, bottom row: target)');\n","\n","#plot of validation patches\n","plt.figure(figsize=(12,5))\n","plot_some(X_val[:5],Y_val[:5])\n","plt.suptitle('5 example 
validation patches (top row: source, bottom row: target)');\n","\n","\n","#Here we automatically define number_of_steps as a function of the training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate is set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here we create the configuration file\n","\n","config = Config(axes, n_channel_in, n_channel_out, probabilistic=True, train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, unet_kern_size=5, unet_n_depth=3, train_batch_size=batch_size, train_learning_rate=initial_learning_rate)\n","\n","print(config)\n","vars(config)\n","\n","# Compile the CARE model for network training\n","model_training= CARE(config, model_name, basedir=model_path)\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model_training.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR","colab_type":"text"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! 
If training takes longer than 12 hours, please decrease the number of epochs or number of patches."]},{"cell_type":"code","metadata":{"id":"biXiR017C4UU","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","\n","# Start Training\n","history = model_training.train(X,Y, validation_data=(X_val,Y_val))\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XQjQb_J_Qyku","colab_type":"text"},"source":["##**4.3. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"2HbZd7rFqAad","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"EdcnkCr9Nbl8","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. 
The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","**Note: Plots of the losses will be shown on a linear and on a log scale. This can help visualise changes in the losses at different magnitudes. However, note that if the losses are negative the plot on the log scale will be empty. 
This is not an error.**"]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"RZOPCVN0qcYb","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. 
It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the mean SSIM value, i.e. the SSIM map averaged across the entire image.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized prediction and the target, or between the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"Nh8MlX3sqd_7","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","# Activate the pretrained model. \n","model_training = CARE(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," predicted = model_training.predict(img, axes='YX')\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(filename, predicted)\n","\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," 
return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. 
GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = 
np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99))\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source, 
norm=simple_norm(img_Source, percent = 99))\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Esqnbew8uznk"},"source":["# **6. 
Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"9ZmST3JRq-Ho","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," 
Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n","\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","\n","#Activate the pretrained model. \n","model_training = CARE(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","\n","# creates a loop, creating filenames and saving them\n","for filename in os.listdir(Data_folder):\n"," img = imread(os.path.join(Data_folder,filename))\n"," restored = model_training.predict(img, axes='YX')\n"," os.chdir(Result_folder)\n"," imsave(filename,restored)\n","\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EIe3CRD7XUxa","colab_type":"text"},"source":["## **6.2. 
Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"LmDP8xiwXTTL","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","plt.figure(figsize=(16,8))\n","\n","plt.subplot(1,2,1)\n","plt.axis('off')\n","plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n","plt.title('Input')\n","\n","plt.subplot(1,2,2)\n","plt.axis('off')\n","plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n","plt.title('Predicted output');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"Rn9zpWpo0xNw","colab_type":"text"},"source":["\n","#**Thank you for using CARE 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb index 8686f7e5..a7ee7342 100755 --- a/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"CARE_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I","colab_type":"text"},"source":["# **CARE: Content-aware image restoration (3D)**\n","\n","---\n","\n","CARE is a neural network capable of image restoration from corrupted bio-images, first published in 2018 by [Weigert *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0216-7). The CARE network uses a U-Net network architecture and allows image restoration and resolution improvement in 2D and 3D images, in a supervised manner, using noisy images as input and low-noise images as targets for training. The function of the network is essentially determined by the set of images provided in the training dataset. 
For instance, if noisy images are provided as input and high signal-to-noise ratio images are provided as targets, the network will perform denoising.\n","\n"," **This particular notebook enables restoration of 3D dataset. If you are interested in restoring 2D dataset, you should use the CARE 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper: \n","\n","**Content-aware image restoration: pushing the limits of fluorescence microscopy**, by Weigert *et al.* published in Nature Methods in 2018 (https://www.nature.com/articles/s41592-018-0216-7)\n","\n","And source code found in: https://github.com/csbdeep/csbdeep\n","\n","For a more in-depth description of the features of the network,please refer to [this guide](http://csbdeep.bioimagecomputing.com/doc/) provided by the original authors of the work.\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://publications.mpi-cbg.de/publications-sites/7207/).\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 
2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading the text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contains all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. 
This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. 
\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"b4-r1gE7Iamv","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"BDhmUgqCStlm","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. 
\n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **2. Install CARE and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install CARE and dependencies\n","\n","\n","#Here, we install libraries which are not already included in Colab.\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install wget\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","\n","import tensorflow\n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to CARE -------\n","from csbdeep.utils import download_and_extract_zip_file, normalize, plot_some, axes_dict, plot_history, Path, download_and_extract_zip_file\n","from csbdeep.data import RawData, create_patches \n","from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n","from csbdeep.models import Config, CARE\n","from csbdeep import data\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, 
random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"WzYAA-MuaYrT","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd","colab_type":"text"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respecively. 
To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number of epochs`:** Input the number of epochs (rounds) for which the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5.). **Default value: 40**\n","\n","**`patch_size`:** CARE divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 80**\n","\n","**`patch_height`:** The value should be smaller than the Z dimension of the image and divisible by 4. When analysing isotropic stacks patch_size and patch_height should have similar values.\n","\n","**When choosing the patch_size and patch_height, the values should be i) large enough that they will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappears.**\n","\n","**`number_of_patches`:** Input the number of the patches per image. Increasing the number of patches allows for larger training datasets. 
**Default value: 200** \n","\n","**Decreasing the patch size or increasing the number of patches may improve the training but may also increase the training time.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","colab_type":"code","cellView":"form","colab":{}},"source":["\n","#@markdown ###Path to training images:\n","\n","# base folder of GT and low images\n","base = \"/content\"\n","\n","# low SNR images\n","Training_source = \"\" #@param {type:\"string\"}\n","lowfile = Training_source+\"/*.tif\"\n","# Ground truth images\n","Training_target = \"\" #@param {type:\"string\"}\n","GTfile = Training_target+\"/*.tif\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","\n","# create the training data file into model_path folder.\n","training_data = model_path+\"/my_training_data.npz\"\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","\n","number_of_epochs = 40#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and 
number\n","patch_size = 80#@param {type:\"number\"} # pixels in\n","patch_height = 8#@param {type:\"number\"}\n","number_of_patches = 200#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","batch_size = 16#@param {type:\"number\"}\n","number_of_steps = 300#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 16\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","percentage = percentage_validation/100\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Folder already exists and has been removed !!\")\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","#Load one randomly chosen training source file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","\n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","\n","#Load one randomly chosen training target file\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","\n","\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Low SNR image (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], 
norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('High SNR image (single Z plane)');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGcl7WGP4WHt","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---"]},{"cell_type":"markdown","metadata":{"id":"5Lio8hpZ4PJ1","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis.\n","\n","**The flip option alone will double the size of your dataset, rotation will quadruple and both together will increase the dataset by a factor of 8.**"]},{"cell_type":"code","metadata":{"id":"htqjkJWt5J_8","colab_type":"code","cellView":"form","colab":{}},"source":["Use_Data_augmentation = False #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = 
\"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," 
io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," 
os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n","\n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source,Training_target,flip=Flip)\n"," \n"," elif Rotation == False and Flip == True:\n"," flip(Training_source,Training_target)\n"," print(\"Done\")\n","\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bQDuybvyadKU","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CARE 3D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"8vPkzEBNamE4","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not 
os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. 
Default learning rate of '+str(bestLearningRate)+' will be used instead')\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead')\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"tGW2iaU6X5zi","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. 
to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"WMJnGJpCMa4y","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","# This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.\n","# This file is saved in .npz format and later called when loading the trainig data.\n","\n","if Use_Data_augmentation == True:\n"," Training_source = Saving_path+'/augmented_source'\n"," Training_target = Saving_path+'/augmented_target'\n","\n","raw_data = RawData.from_folder (\n"," basepath = base,\n"," source_dirs = [Training_source],\n"," target_dir = Training_target,\n"," axes = 'ZYX',\n"," pattern='*.tif*'\n",")\n","X, Y, XY_axes = create_patches (\n"," raw_data = raw_data,\n"," patch_size = (patch_height,patch_size,patch_size),\n"," n_patches_per_image = number_of_patches, \n"," save_file = training_data,\n",")\n","\n","assert X.shape == Y.shape\n","print(\"shape of X,Y =\", X.shape)\n","print(\"axes of X,Y =\", XY_axes)\n","\n","%memit \n","print ('Creating 3D training dataset')\n","\n","# Load Training Data\n","(X,Y), (X_val,Y_val), axes = load_training_data(training_data, validation_split=percentage, verbose=True)\n","c = axes_dict(axes)['C']\n","n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n","\n","#Plot example patches\n","\n","#plot of training patches.\n","plt.figure(figsize=(12,5))\n","plot_some(X[:5],Y[:5])\n","plt.suptitle('5 example training patches (top row: source, bottom row: target)');\n","\n","#plot of validation patches\n","plt.figure(figsize=(12,5))\n","plot_some(X_val[:5],Y_val[:5])\n","plt.suptitle('5 example validation patches (top row: source, bottom row: target)');\n","\n","%memit \n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= 
int(X.shape[0]/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here, we create the default Config object which sets the hyperparameters of the network training.\n","\n","config = Config(axes, n_channel_in, n_channel_out, train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, train_batch_size=batch_size, train_learning_rate=initial_learning_rate)\n","print(config)\n","vars(config)\n","\n","# Compile the CARE model for network training\n","\n","model_training= CARE(config, model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model_training.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR","colab_type":"text"},"source":["## **4.2. Train the network**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
Another way to circumvent this is to save the parameters of the model after training and start training again from this point."]},{"cell_type":"code","metadata":{"id":"j_Qm5JBmlvJg","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start Training\n","\n","start = time.time()\n","\n","# Start Training\n","history = model_training.train(X,Y, validation_data=(X_val,Y_val))\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w8Q_uYGgiico","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"QYuIOWQ3imuU","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"zazOZ3wDx0zQ","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. 
The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"biT9FI9Ri77_","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. 
The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"nAs4Wni7VYbq","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_metrics_save+'Prediction'):\n"," shutil.rmtree(path_metrics_save+'Prediction')\n","os.makedirs(path_metrics_save+'Prediction')\n","\n","#Here we allow the user to choose the number of tiles to be used when predicting the images\n","#@markdown #####To analyse large images, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expense of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increase these numbers until the OOM error disappears.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","# Activate the pretrained model. 
\n","model_training = CARE(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," n_slices = img.shape[0]\n"," predicted = model_training.predict(img, axes='ZYX', n_tiles=n_tilesZYX)\n"," os.chdir(path_metrics_save+'Prediction/')\n"," imsave('Predicted_'+filename, predicted)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: 
ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True if gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm, test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm, test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # 
-------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction, force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource, force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction, force_copy=False)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource, force_copy=False)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe 
output\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," 
io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. GT PSNR\"] = PSNR_GvS_list_mean\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane], norm=simple_norm(img_GT[z_mid_plane], percent = 99))\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane], norm=simple_norm(img_Source[z_mid_plane], percent = 99))\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","plt.imshow(img_Prediction[z_mid_plane], norm=simple_norm(img_Prediction[z_mid_plane], percent = 99))\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and 
Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Quality Control folder.')\n","pdResults.head()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"69aJVFfsqXbY","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"tcPNRq1TrMPB","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"Am2JSmpC0frj","colab_type":"code","cellView":"form","colab":{}},"source":["\n","#@markdown ##Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n"," \n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Here we allow the user to choose the number of tiles to be used when predicting the images\n","#@markdown #####To analyse large images, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expense of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increase these numbers until the OOM error disappears.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","#Activate the pretrained model. 
\n","model=CARE(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","print(\"Restoring images...\")\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","suffix = '.tif'\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='ZYX', n_tiles=n_tilesZYX)\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='ZYX') \n","\n","print(\"Images saved into the result folder:\", Result_folder)\n","\n","#Display an example\n","\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","z_mid_plane = int(x.shape[0] / 2)+1\n","\n","@interact\n","def show_results(file=os.listdir(Result_folder), z_plane=widgets.IntSlider(min=0, max=(x.shape[0]-1), step=1, value=z_mid_plane)):\n"," x = imread(Data_folder+\"/\"+file)\n"," y = imread(Result_folder+\"/\"+file)\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x[z_plane], norm=simple_norm(x[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Noisy Input (single Z plane)');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y[z_plane], norm=simple_norm(y[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Prediction (single Z plane)');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) 
if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J","colab_type":"text"},"source":["#**Thank you for using CARE 3D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"CARE_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I","colab_type":"text"},"source":["# **CARE: Content-aware image restoration (3D)**\n","\n","---\n","\n","CARE is a neural network capable of image restoration from corrupted bio-images, first published in 2018 by [Weigert *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0216-7). The CARE network uses a U-Net network architecture and allows image restoration and resolution improvement in 2D and 3D images, in a supervised manner, using noisy images as input and low-noise images as targets for training. The function of the network is essentially determined by the set of images provided in the training dataset. For instance, if noisy images are provided as input and high signal-to-noise ratio images are provided as targets, the network will perform denoising.\n","\n"," **This particular notebook enables restoration of 3D dataset. 
If you are interested in restoring 2D dataset, you should use the CARE 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper: \n","\n","**Content-aware image restoration: pushing the limits of fluorescence microscopy**, by Weigert *et al.* published in Nature Methods in 2018 (https://www.nature.com/articles/s41592-018-0216-7)\n","\n","And source code found in: https://github.com/csbdeep/csbdeep\n","\n","For a more in-depth description of the features of the network,please refer to [this guide](http://csbdeep.bioimagecomputing.com/doc/) provided by the original authors of the work.\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://publications.mpi-cbg.de/publications-sites/7207/).\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: 
\n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. Once execution is done, the animation of the play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains the structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z","colab_type":"text"},"source":["#**0. 
Before getting started**\n","---\n"," For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. 
\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"b4-r1gE7Iamv","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"BDhmUgqCStlm","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of the notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in to your Google Account. \n","\n","#@markdown * Copy the authorization code. 
\n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **2. Install CARE and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install CARE and dependencies\n","\n","\n","#Here, we install libraries which are not already included in Colab.\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install wget\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","\n","import tensorflow\n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to CARE -------\n","from csbdeep.utils import download_and_extract_zip_file, normalize, plot_some, axes_dict, plot_history, Path, download_and_extract_zip_file\n","from csbdeep.data import RawData, create_patches \n","from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n","from csbdeep.models import Config, CARE\n","from csbdeep import data\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, 
random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"WzYAA-MuaYrT","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd","colab_type":"text"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respectively. 
To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, selecting **Copy path** and pasting it into the appropriate box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input the number of epochs (rounds) for which the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5.). **Default value: 40**\n","\n","**`patch_size`:** CARE divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 80**\n","\n","**`patch_height`:** The value should be smaller than the Z dimension of the image and divisible by 4. When analysing isotropic stacks, patch_size and patch_height should have similar values.\n","\n","**When choosing the patch_size and patch_height, the values should be i) large enough that they will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappears.**\n","\n","**`number_of_patches`:** Input the number of patches per image. Increasing the number of patches allows for larger training datasets. 
**Default value: 200** \n","\n","**Decreasing the patch size or increasing the number of patches may improve the training but may also increase the training time.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","colab_type":"code","cellView":"form","colab":{}},"source":["\n","#@markdown ###Path to training images:\n","\n","# base folder of GT and low images\n","base = \"/content\"\n","\n","# low SNR images\n","Training_source = \"\" #@param {type:\"string\"}\n","lowfile = Training_source+\"/*.tif\"\n","# Ground truth images\n","Training_target = \"\" #@param {type:\"string\"}\n","GTfile = Training_target+\"/*.tif\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","\n","# create the training data file into model_path folder.\n","training_data = model_path+\"/my_training_data.npz\"\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","\n","number_of_epochs = 40#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and 
number\n","patch_size = 80#@param {type:\"number\"} # pixels in\n","patch_height = 8#@param {type:\"number\"}\n","number_of_patches = 200#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","batch_size = 16#@param {type:\"number\"}\n","number_of_steps = 300#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 16\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","percentage = percentage_validation/100\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Folder already exists and has been removed !!\")\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","#Load one randomly chosen training source file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","\n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","\n","#Load one randomly chosen training target file\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","\n","\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Low SNR image (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], 
norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('High SNR image (single Z plane)');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGcl7WGP4WHt","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---"]},{"cell_type":"markdown","metadata":{"id":"5Lio8hpZ4PJ1","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis.\n","\n","**The flip option alone will double the size of your dataset, rotation will quadruple and both together will increase the dataset by a factor of 8.**"]},{"cell_type":"code","metadata":{"id":"htqjkJWt5J_8","colab_type":"code","cellView":"form","colab":{}},"source":["Use_Data_augmentation = False #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = 
\"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," 
io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," 
os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n","\n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source,Training_target,flip=Flip)\n"," \n"," elif Rotation == False and Flip == True:\n"," flip(Training_source,Training_target)\n"," print(\"Done\")\n","\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bQDuybvyadKU","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CARE 3D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"8vPkzEBNamE4","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not 
os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. 
Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"tGW2iaU6X5zi","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. 
to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"WMJnGJpCMa4y","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","# This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.\n","# This file is saved in .npz format and later read when loading the training data.\n","\n","if Use_Data_augmentation == True:\n"," Training_source = Saving_path+'/augmented_source'\n"," Training_target = Saving_path+'/augmented_target'\n","\n","raw_data = RawData.from_folder (\n"," basepath = base,\n"," source_dirs = [Training_source],\n"," target_dir = Training_target,\n"," axes = 'ZYX',\n"," pattern='*.tif*'\n",")\n","X, Y, XY_axes = create_patches (\n"," raw_data = raw_data,\n"," patch_size = (patch_height,patch_size,patch_size),\n"," n_patches_per_image = number_of_patches, \n"," save_file = training_data,\n",")\n","\n","assert X.shape == Y.shape\n","print(\"shape of X,Y =\", X.shape)\n","print(\"axes of X,Y =\", XY_axes)\n","\n","%memit \n","print ('Creating 3D training dataset')\n","\n","# Load Training Data\n","(X,Y), (X_val,Y_val), axes = load_training_data(training_data, validation_split=percentage, verbose=True)\n","c = axes_dict(axes)['C']\n","n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n","\n","#Plot example patches\n","\n","#plot of training patches.\n","plt.figure(figsize=(12,5))\n","plot_some(X[:5],Y[:5])\n","plt.suptitle('5 example training patches (top row: source, bottom row: target)');\n","\n","#plot of validation patches\n","plt.figure(figsize=(12,5))\n","plot_some(X_val[:5],Y_val[:5])\n","plt.suptitle('5 example validation patches (top row: source, bottom row: target)');\n","\n","%memit \n","\n","#Here we automatically define number_of_steps as a function of the training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= 
int(X.shape[0]/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate is set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here, we create the default Config object which sets the hyperparameters of the network training.\n","\n","config = Config(axes, n_channel_in, n_channel_out, train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, train_batch_size=batch_size, train_learning_rate=initial_learning_rate)\n","print(config)\n","vars(config)\n","\n","# Compile the CARE model for network training\n","\n","model_training= CARE(config, model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model_training.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR","colab_type":"text"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
Another way to circumvent this is to save the parameters of the model after training and start training again from this point."]},{"cell_type":"code","metadata":{"id":"j_Qm5JBmlvJg","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start Training\n","\n","start = time.time()\n","\n","# Start Training\n","history = model_training.train(X,Y, validation_data=(X_val,Y_val))\n","\n","print(\"Training done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# training_evaluation.csv is saved (overwriting the file if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w8Q_uYGgiico","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is, however, wise to download the folder, as all data can be erased at the next training run if the same folder is used."]},{"cell_type":"markdown","metadata":{"id":"QYuIOWQ3imuU","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend performing quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"zazOZ3wDx0zQ","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. 
The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words, the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case, the training dataset should be enlarged."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"biT9FI9Ri77_","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. 
The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized prediction and the target, or between the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"nAs4Wni7VYbq","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_metrics_save+'Prediction'):\n"," shutil.rmtree(path_metrics_save+'Prediction')\n","os.makedirs(path_metrics_save+'Prediction')\n","\n","#Here we allow the user to choose the number of tiles to be used when predicting the images\n","#@markdown #####To analyse large images, they need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expense of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increase these numbers until the OOM error disappears.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","# Activate the pretrained model. 
\n","model_training = CARE(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," n_slices = img.shape[0]\n"," predicted = model_training.predict(img, axes='ZYX', n_tiles=n_tilesZYX)\n"," os.chdir(path_metrics_save+'Prediction/')\n"," imsave('Predicted_'+filename, predicted)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: 
ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm, test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm, test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # 
-------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction, force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource, force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction, force_copy=False)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource, force_copy=False)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe 
output\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," 
io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. GT PSNR\"] = PSNR_GvS_list_mean\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane], norm=simple_norm(img_GT[z_mid_plane], percent = 99))\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane], norm=simple_norm(img_Source[z_mid_plane], percent = 99))\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","plt.imshow(img_Prediction[z_mid_plane], norm=simple_norm(img_Prediction[z_mid_plane], percent = 99))\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and 
Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Quality Control folder.')\n","pdResults.head()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"69aJVFfsqXbY","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"tcPNRq1TrMPB","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"Am2JSmpC0frj","colab_type":"code","cellView":"form","colab":{}},"source":["\n","#@markdown ##Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n"," \n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Here we allow the user to choose the number of tiles to be used when predicting the images\n","#@markdown #####To analyse large images, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expense of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increase these numbers until the OOM error disappears.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","#Activate the pretrained model. 
\n","model=CARE(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","print(\"Restoring images...\")\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","suffix = '.tif'\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='ZYX', n_tiles=n_tilesZYX)\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='ZYX') \n","\n","print(\"Images saved into the result folder:\", Result_folder)\n","\n","#Display an example\n","\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","z_mid_plane = int(x.shape[0] / 2)+1\n","\n","@interact\n","def show_results(file=os.listdir(Result_folder), z_plane=widgets.IntSlider(min=0, max=(x.shape[0]-1), step=1, value=z_mid_plane)):\n"," x = imread(Data_folder+\"/\"+file)\n"," y = imread(Result_folder+\"/\"+file)\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x[z_plane], norm=simple_norm(x[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Noisy Input (single Z plane)');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y[z_plane], norm=simple_norm(y[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Prediction (single Z plane)');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) 
if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J","colab_type":"text"},"source":["#**Thank you for using CARE 3D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/ChangeLog.txt b/Colab_notebooks/ChangeLog.txt index 22c66371..47f0b572 100755 --- a/Colab_notebooks/ChangeLog.txt +++ b/Colab_notebooks/ChangeLog.txt @@ -19,12 +19,14 @@ Main notebooks: —————————————— - StarDist 2D, StarDist 3D, CARE 2D and CARE 3D notebooks now uses TensorFlow 2.2 (instead of TF 1.5.15) + - YOLOv2 notebook: QC section now uses the same mAP function as the training, for better consistency of results; ground-truth labels and predicted labels in the QC section are exported to the QC folder as csv files which hold the bounding box coordinates and class labels; Display of prediction results now more consistent with display of GT labels; Updated Augmentation Section, now between 2-8 times augmentation of dataset possible; Additional csv file with predicted bounding box coordinates in a format suitable for use in ImageJ as Results Table is now exported to the user's results folder in the Prediction section; Added 'training_times' as hyperparameter for improved tuning of model training; Tracking of mAP during training implemented; After training, model with best validation performance, best mAP score and the last model weights are saved to allow easier performance comparison by the user; Updated explanation of parameters and QC section. - 3D U-Net: 1. Added ability to train network on non-binary targets 2. Added ability to choose loss, metrics and optimizer 3. Fixed data generator bug leading to erroneous generator length when choosing random_crop + 4. 
Added support for using the imgaug library and creating custom augmentation pipelines + minor modifications and bug fixes diff --git a/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb b/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb index 1e90d044..31b1a40a 100755 --- a/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"CycleGAN_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1mqcexfPBaIWuvMWWbJZUFtPoZoJJwrEA","timestamp":1589278334507},{"file_id":"159ARwlQE7-zi0EHxunOF_YPFLt-ZVU5x","timestamp":1587562499898},{"file_id":"1W-7NHehG5MRFILvZZzhPWWnOdJMkadb2","timestamp":1586332290412},{"file_id":"1pUetEQICxYWkYVaQIgdRH1EZBTl7oc2A","timestamp":1586292199692},{"file_id":"1MD36ZkM6XR9EuV12zimJmfCjzyeYZFWq","timestamp":1586269469061},{"file_id":"16A2mbaHzlEElntS8qkFBOsBvZG-mUeY6","timestamp":1586253795726},{"file_id":"1gJlcjOiSxr2buDOxmcFbT_d-GqwLjXtK","timestamp":1583343225796},{"file_id":"10yGI51WzHfgWgZAyE-EbkZFEvIOd6CP6","timestamp":1583171396283}],"collapsed_sections":[],"toc_visible":true},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I","colab_type":"text"},"source":["# **CycleGAN**\n","\n","---\n","\n","CycleGAN is a method that can capture the characteristics of one image domain and learn how these characteristics can be translated into another image domain, all in the absence of any paired training examples. It was first published by [Zhu *et al.* in 2017](https://arxiv.org/abs/1703.10593). 
Unlike pix2pix, the image transformation performed does not require paired images for training (unsupervised learning) and is made possible here by using a set of two Generative Adversarial Networks (GANs) that learn to transform images both from the first domain to the second and vice-versa.\n","\n"," **This particular notebook enables unpaired image-to-image translation. If your dataset is paired, you should also consider using the pix2pix notebook.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). It is jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n"," **Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks** from Zhu *et al.* published on arXiv in 2017 (https://arxiv.org/abs/1703.10593)\n","\n","The source code of the CycleGAN PyTorch implementation can be found in: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"N3azwKB9O0oW","colab_type":"text"},"source":["# **License**\n","\n","---"]},{"cell_type":"code","metadata":{"id":"ByW6Vqdn9sYV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Double click to see the license information\n","\n","#------------------------- LICENSE FOR ZeroCostDL4Mic------------------------------------\n","#This ZeroCostDL4Mic notebook is distributed under the MIT licence\n","\n","\n","\n","#------------------------- LICENSE FOR CycleGAN ------------------------------------\n","\n","#Copyright (c) 2017, Jun-Yan Zhu and Taesung Park\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted 
provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the distribution.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n","#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n","#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n","#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n","#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n","#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n","#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n","#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n","#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n","#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n","\n","\n","#--------------------------- LICENSE FOR pix2pix --------------------------------\n","#BSD License\n","\n","#For pix2pix software\n","#Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the 
distribution.\n","\n","#----------------------------- LICENSE FOR DCGAN --------------------------------\n","#BSD License\n","\n","#For dcgan.torch software\n","\n","#Copyright (c) 2015, Facebook, Inc. All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n","\n","#Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n","\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n","\n","#Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of the play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. 
Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," To train CycleGAN, **you only need two folders containing PNG images**. The images do not need to be paired.\n","\n","While you do not need paired images to train CycleGAN, if possible, **we strongly recommend that you generate a paired dataset. This means that the same image needs to be acquired in the two conditions. These images can be used to assess the quality of your trained model (Quality control dataset)**. 
The quality control assessment can be done directly in this notebook.\n","\n","\n"," Please note that you currently can **only use .png files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset (non-matching images) **\n"," - Training_source\n"," - img_1.png, img_2.png, ...\n"," - Training_target\n"," - img_1.png, img_2.png, ...\n"," - **Quality control dataset (matching images)**\n"," - Training_source\n"," - img_1.png, img_2.png\n"," - Training_target\n"," - img_1.png, img_2.png\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"BCPhV-pe-syw","colab_type":"text"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"VNZetvLiS1qV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UBrnApIUBgxv","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of the notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. 
\n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **2. Install CycleGAN and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install CycleGAN and dependencies\n","\n","\n","#------- Code from the cycleGAN demo notebook starts here -------\n","\n","#Here, we install libraries which are not already included in Colab.\n","\n","\n","\n","!git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","import os\n","os.chdir('pytorch-CycleGAN-and-pix2pix/')\n","!pip install -r requirements.txt\n","\n","\n","import imageio\n","from skimage import data\n","from skimage import exposure\n","from skimage.exposure import match_histograms\n","\n","from skimage.util import img_as_int\n","\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage 
import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"BLmBseWbRvxL","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd","colab_type":"text"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target training data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files, copy the path by right-clicking on the folder and selecting **Copy path**, then paste it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained for. Preliminary results can already be observed after a few (10) epochs, but a full training should run for 200 epochs or more. Evaluate the performance after training (see section 5). 
**Default value: 200**\n","\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`patch_size`:** CycleGAN divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 4. **Default value: 512**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0002**"]},{"cell_type":"code","metadata":{"id":"pIrTwJjzwV-D","colab_type":"code","cellView":"form","colab":{}},"source":["\n","\n","#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.png\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.png\"\n","\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","patch_size = 512#@param {type:\"number\"} # in pixels\n","batch_size = 1#@param {type:\"number\"}\n","initial_learning_rate = 0.0002 #@param 
{type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," patch_size = 512\n"," initial_learning_rate = 0.0002\n","\n","#Here we check whether a model with the same name already exists; if so, warn that it will be deleted\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\")\n"," \n","\n","\n","#To use CycleGAN we need to organise the data in a way the model can understand\n","\n","Saving_path= \"/content/\"+model_name\n","#Saving_path= model_path+\"/\"+model_name\n","\n","if os.path.exists(Saving_path):\n"," shutil.rmtree(Saving_path)\n","os.makedirs(Saving_path)\n","\n","TrainA_Folder = Saving_path+\"/trainA\"\n","if os.path.exists(TrainA_Folder):\n"," shutil.rmtree(TrainA_Folder)\n","os.makedirs(TrainA_Folder)\n"," \n","TrainB_Folder = Saving_path+\"/trainB\"\n","if os.path.exists(TrainB_Folder):\n"," shutil.rmtree(TrainB_Folder)\n","os.makedirs(TrainB_Folder)\n","\n","# Here we disable pre-trained model by default (in case the cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not run)\n","\n","Use_Data_augmentation = True\n","\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imageio.imread(Training_source+\"/\"+random_choice)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","\n","#Hyperparameters failsafes\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of 
your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 4\n","if not patch_size % 4 == 0:\n"," patch_size = ((int(patch_size / 4)-1) * 4)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 4; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","random_choice_2 = random.choice(os.listdir(Training_target))\n","y = imageio.imread(Training_target+\"/\"+random_choice_2)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"FX6uxFvI-CsQ","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CwMaFU1T-GtN","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by flipping the patches. 
\n","\n"," By default data augmentation is enabled."]},{"cell_type":"code","metadata":{"id":"kLtHIATT-0un","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"v-leE8pEWRkn","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CycleGAN model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"CbOcS3wiWV9w","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n"," h5_file_path_A = os.path.join(pretrained_model_path, \"latest_net_G_A.pth\")\n"," h5_file_path_B = os.path.join(pretrained_model_path, \"latest_net_G_B.pth\")\n","\n","# --------------------- Check the model exist ------------------------\n","\n"," if not os.path.exists(h5_file_path_A) and os.path.exists(h5_file_path_B):\n"," print(bcolors.WARNING+'WARNING: Pretrained model does not exist')\n"," Use_pretrained_model = False\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"," if os.path.exists(h5_file_path_A) and os.path.exists(h5_file_path_B):\n"," print(\"Pretrained model \"+os.path.basename(pretrained_model_path)+\" was found and will be loaded prior to training.\")\n"," \n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"-A4ipz8gs3Ew","colab_type":"text"},"source":["## **4.1. Prepare the training data for training**\n","---\n","Here, we use the information from 3. 
to prepare the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"_V2ujGB60gDv","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Prepare the data for training\n","\n","print(\"Data preparation in progress\")\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","os.makedirs(model_path+'/'+model_name)\n","\n","#--------------- Here we copy the files to trainA and trainB ---------\n","\n","\n","for f in os.listdir(Training_source):\n"," shutil.copyfile(Training_source+\"/\"+f, TrainA_Folder+\"/\"+f)\n","\n","for files in os.listdir(Training_target):\n"," shutil.copyfile(Training_target+\"/\"+files, TrainB_Folder+\"/\"+files)\n","\n","#---------------------------------------------------------------------\n","\n","# CycleGAN trains for a number of epochs without lr decay, followed by a number of epochs with lr decay\n","\n","\n","number_of_epochs_lr_stable = int(number_of_epochs/2)\n","number_of_epochs_lr_decay = int(number_of_epochs/2)\n","\n","if Use_pretrained_model :\n"," for f in os.listdir(pretrained_model_path):\n"," if (f.startswith(\"latest_net_\")): \n"," shutil.copyfile(pretrained_model_path+\"/\"+f, model_path+'/'+model_name+\"/\"+f)\n","\n","print(\"Data ready for training\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR","colab_type":"text"},"source":["## **4.2. Train the network**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! 
If training takes longer than 12 hours, please decrease the number of epochs or number of patches or continue the training in a second Colab session."]},{"cell_type":"code","metadata":{"id":"eBD50tAgv5qf","colab_type":"code","cellView":"form","colab":{}},"source":["\n","#@markdown ##Start training\n","\n","start = time.time()\n","\n","os.chdir(\"/content\")\n","\n","#--------------------------------- Command line inputs to change CycleGAN paramaters------------\n","\n"," # basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n"," \n"," # model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. 
n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n"," # dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n"," # additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n"," # visdom and HTML visualization parameters\n"," #('--display_freq', type=int, default=400, help='frequency of showing training results on screen')\n"," #('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n"," #('--display_id', type=int, default=1, help='window id of the web display')\n"," #('--display_server', type=str, default=\"http://localhost\", help='visdom server of the web display')\n"," #('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n"," #('--display_port', type=int, default=8097, help='visdom port of the web display')\n"," #('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n"," #('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n"," #('--no_html', 
action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n"," \n"," # network saving and loading parameters\n"," #('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n"," #('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n"," #('--save_by_iter', action='store_true', help='whether saves model by iteration')\n"," #('--continue_train', action='store_true', help='continue training: load the latest model')\n"," #('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')\n"," #('--phase', type=str, default='train', help='train, val, test, etc')\n"," \n"," # training parameters\n"," #('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')\n"," #('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')\n"," #('--beta1', type=float, default=0.5, help='momentum term of adam')\n"," #('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n"," #('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')\n"," #('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n"," #('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]')\n"," #('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations'\n","\n","#---------------------------------------------------------\n","\n","#----- Start the training ------------------------------------\n","if not Use_pretrained_model:\n"," if Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5\n"," if not Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --no_flip\n","\n","if Use_pretrained_model:\n"," if Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train\n"," \n"," if not Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop 
--load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train --no_flip\n","\n","#---------------------------------------------------------\n","\n","print(\"Training done.\")\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",int(hour), \"hour(s)\",int(mins),\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XQjQb_J_Qyku","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is, however, wise to download the folder, as its contents can be erased the next time you train using the same folder."]},{"cell_type":"markdown","metadata":{"id":"2HbZd7rFqAad","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend performing quality control on all newly trained models.**\n","\n","Unfortunately, loss function curves are not very informative for GAN networks. Therefore, we perform the QC here using a test dataset.\n","\n","\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"PhcOwcgH3JAD","colab_type":"text"},"source":["## **5.1. 
Choose the model you want to assess**"]},{"cell_type":"code","metadata":{"id":"EdcnkCr9Nbl8","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"E4Yp7ogh3NGD","colab_type":"text"},"source":["## **5.2. Identify the best checkpoint to use to make predictions**"]},{"cell_type":"markdown","metadata":{"id":"1yauWCc78HKD","colab_type":"text"},"source":[" CycleGAN saves model checkpoints every five epochs. Due to the stochastic nature of GAN networks, the last checkpoint is not always the best one to use. As a consequence, it can be challenging to choose the most suitable checkpoint to use to make predictions.\n","\n","This section allows you to perform predictions using all the saved checkpoints and to estimate the quality of these predictions by comparing them to the provided ground truth images. The metrics used include:\n","\n","**1. 
The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates perfect similarity between two images. Therefore, for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the mean SSIM value calculated across the entire image.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized prediction and the target, or between the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n"]},{"cell_type":"code","metadata":{"id":"2nBPucJdK3KS","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Image_type = \"Grayscale\" #@param [\"Grayscale\", \"RGB\"]\n","\n","\n","\n","# average function\n","def Average(lst): \n"," return sum(lst) / len(lst) \n","\n","\n","# Create a quality control folder\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","# Here we need to move the data to be analysed so that cycleGAN can find them\n","\n","Saving_path_QC= \"/content/\"+QC_model_name\n","\n","if os.path.exists(Saving_path_QC):\n"," shutil.rmtree(Saving_path_QC)\n","os.makedirs(Saving_path_QC)\n","\n","Saving_path_QC_folder = Saving_path_QC+\"_images\"\n","\n","if os.path.exists(Saving_path_QC_folder):\n"," shutil.rmtree(Saving_path_QC_folder)\n","os.makedirs(Saving_path_QC_folder)\n","\n","\n","#Here we copy and rename the all the checkpoint to be analysed\n","\n","for f in os.listdir(full_QC_model_path):\n"," shortname = f[:-6]\n"," shortname = shortname + \".pth\"\n"," if f.endswith(\"net_G_A.pth\"):\n"," shutil.copyfile(full_QC_model_path+f, 
Saving_path_QC+\"/\"+shortname)\n","\n","\n","for files in os.listdir(Source_QC_folder):\n"," shutil.copyfile(Source_QC_folder+\"/\"+files, Saving_path_QC_folder+\"/\"+files)\n"," \n","\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = int(min(Image_Y, Image_X))\n","\n","Nb_Checkpoint = len(os.listdir(Saving_path_QC))\n","\n","print(Nb_Checkpoint)\n","\n","\n","\n","## Initiate list\n","\n","Checkpoint_list = []\n","Average_ssim_score_list = []\n","\n","\n","for j in range(1, len(os.listdir(Saving_path_QC))+1):\n"," checkpoints = j*5\n","\n"," if checkpoints == Nb_Checkpoint*5:\n"," checkpoints = \"latest\"\n","\n","\n"," print(\"The checkpoint currently analysed is =\"+str(checkpoints))\n","\n"," Checkpoint_list.append(checkpoints)\n","\n","\n"," # Create a quality control/Prediction Folder\n","\n"," QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\n","\n"," if os.path.exists(QC_prediction_results):\n"," shutil.rmtree(QC_prediction_results)\n","\n"," os.makedirs(QC_prediction_results)\n","\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n","\n"," os.chdir(\"/content\")\n","\n"," !python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$Saving_path_QC_folder\" --name \"$QC_model_name\" --model test --epoch $checkpoints --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$QC_prediction_results\" --checkpoints_dir \"/content/\"\n","\n","#-----------------------------------------------------------------------------------\n","\n","#Here we need to move the data again and remove all the unnecessary folders\n","\n"," Checkpoint_name = 
\"test_\"+str(checkpoints)\n","\n"," QC_results_images = QC_prediction_results+\"/\"+QC_model_name+\"/\"+Checkpoint_name+\"/images\"\n","\n"," QC_results_images_files = os.listdir(QC_results_images)\n","\n"," for f in QC_results_images_files: \n"," shutil.copyfile(QC_results_images+\"/\"+f, QC_prediction_results+\"/\"+f)\n","\n"," os.chdir(\"/content\") \n","\n"," #Here we clean up the extra files\n"," shutil.rmtree(QC_prediction_results+\"/\"+QC_model_name)\n","\n","\n","#-------------------------------- QC for RGB ------------------------------------\n"," if Image_type == \"RGB\":\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n"," random_choice = random.choice(os.listdir(Source_QC_folder))\n"," x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, multichannel=True)\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. 
GT mSSIM\"])\n"," \n"," \n"," # Initiate list\n"," ssim_score_list = [] \n","\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," shortname_no_PNG = i[:-4]\n"," \n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = imageio.imread(os.path.join(Target_QC_folder, i), as_gray=False, pilmode=\"RGB\")\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real.png\"))\n"," \n"," \n"," # -------------------------------- Prediction --------------------------------\n"," \n"," test_prediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake.png\"))\n"," \n"," #--------------------------- Here we normalise using histograms matching--------------------------------\n"," test_prediction_matched = match_histograms(test_prediction, test_GT, multichannel=True)\n"," test_source_matched = match_histograms(test_source, test_GT, multichannel=True)\n"," \n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT, test_prediction_matched)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT, test_source_matched)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality 
Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," \n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource)])\n","\n"," #Here we calculate the ssim average for each image in each checkpoints\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","\n","\n","\n","#------------------------------------------- QC for Grayscale ----------------------------------------------\n","\n"," if Image_type == \"Grayscale\":\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n"," def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n","\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n"," def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n"," def norm_minmse(gt, x, normalize_gt=True):\n"," \n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," 
x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," \n"," \n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," ssim_score_list = []\n"," shortname_no_PNG = i[:-4]\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT_raw = imageio.imread(os.path.join(Target_QC_folder, i), as_gray=False, pilmode=\"RGB\")\n"," \n"," test_GT = test_GT_raw[:,:,2]\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real.png\"))\n"," \n"," test_source = test_source_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality 
Control/\"+str(checkpoints),shortname_no_PNG+\"_fake.png\"))\n"," \n"," test_prediction = test_prediction_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," \n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_8bit = (img_RSE_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_RSE_GTvsPrediction_8bit)\n"," img_RSE_GTvsSource_8bit = (img_RSE_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsSource_\"+shortname_no_PNG+'.tif',img_RSE_GTvsSource_8bit)\n","\n","\n"," # 
-------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error: square the RSE map, average over the image, then take the root\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(np.square(img_RSE_GTvsPrediction)))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(np.square(img_RSE_GTvsSource)))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n"," #Here we calculate the average SSIM over the images for each checkpoint\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","\n","# All data is now processed and saved\n"," \n","\n","# -------------------------------- Display --------------------------------\n","\n","# Display the average SSIM vs. checkpoint plot\n","plt.figure(figsize=(20,5))\n","plt.plot(Checkpoint_list, Average_ssim_score_list, label=\"SSIM\")\n","plt.title('Checkpoints vs. 
SSIM')\n","plt.ylabel('SSIM')\n","plt.xlabel('Checkpoints')\n","plt.legend()\n","plt.show()\n","\n","\n","\n","# -------------------------------- Display RGB --------------------------------\n","\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","if Image_type == \"RGB\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n","\n","#Setting up colours\n"," \n"," cmap = None\n","\n"," plt.figure(figsize=(10,10))\n","\n","# Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(Target_QC_folder, file), as_gray=False, pilmode=\"RGB\")\n"," plt.imshow(img_GT, cmap = cmap)\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(Source_QC_folder, file), as_gray=False, pilmode=\"RGB\")\n"," plt.imshow(img_Source, cmap = cmap)\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n","\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake.png\"))\n","\n"," plt.imshow(img_Prediction, cmap = cmap)\n"," plt.title('Prediction',fontsize=15)\n","\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major 
and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","#plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","\n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","#plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","\n","# -------------------------------- Display Grayscale --------------------------------\n","\n","if Image_type == \"Grayscale\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n","\n"," NRMSE_GTvsPrediction = df2.loc[file, \"Prediction v. GT NRMSE\"]\n"," NRMSE_GTvsSource = df2.loc[file, \"Input v. GT NRMSE\"]\n"," PSNR_GTvsSource = df2.loc[file, \"Input v. GT PSNR\"]\n"," PSNR_GTvsPrediction = df2.loc[file, \"Prediction v. 
GT PSNR\"]\n"," \n","\n"," plt.figure(figsize=(15,15))\n","\n"," cmap = None\n"," \n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(Target_QC_folder, file), as_gray=True, pilmode=\"RGB\")\n","\n"," plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99), cmap = 'gray')\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real.png\"))\n"," plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake.png\"))\n"," plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n"," plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsSource = img_SSIM_GTvsSource / 255\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","\n"," \n"," 
plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," \n"," \n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsPrediction = img_SSIM_GTvsPrediction / 255\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","\n"," \n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_RSE_GTvsSource = img_RSE_GTvsSource / 255\n"," \n","\n"," imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. 
Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," img_RSE_GTvsPrediction = img_RSE_GTvsPrediction / 255\n","\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Esqnbew8uznk"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. 
If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as PNG images.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`checkpoint`:** Choose the checkpoint number you would like to use to perform predictions. To use the \"latest\" checkpoint, input \"latest\"."]},{"cell_type":"code","metadata":{"id":"yb3suNkfpNA9","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","import glob\n","import os.path\n","\n","\n","latest = \"latest\"\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###What model checkpoint would you like to use?\n","\n","checkpoint = latest#@param {type:\"raw\"}\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#here we check if we use the newly trained network or not\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","#here we check if the model exists\n","full_Prediction_model_path = 
Prediction_model_path+'/'+Prediction_model_name+'/'\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# Here we check that the checkpoint exists; if not, the closest one will be chosen \n","\n","Nb_Checkpoint = len(glob.glob(os.path.join(full_Prediction_model_path, '*G_A.pth')))\n","print(Nb_Checkpoint)\n","\n","\n","if not checkpoint == \"latest\":\n","\n"," if checkpoint < 10:\n"," checkpoint = 5\n","\n"," if not checkpoint % 5 == 0:\n"," checkpoint = ((int(checkpoint / 5)-1) * 5)\n"," print (bcolors.WARNING + \" Your chosen checkpoint is not divisible by 5; therefore the checkpoint chosen is now:\",checkpoint)\n"," \n"," if checkpoint > Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n"," if checkpoint == Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n","\n","\n","\n","# Here we need to move the data to be analysed so that cycleGAN can find them\n","\n","Saving_path_prediction= \"/content/\"+Prediction_model_name\n","\n","if os.path.exists(Saving_path_prediction):\n"," shutil.rmtree(Saving_path_prediction)\n","os.makedirs(Saving_path_prediction)\n","\n","Saving_path_Data_folder = Saving_path_prediction+\"/testA\"\n","\n","if os.path.exists(Saving_path_Data_folder):\n"," shutil.rmtree(Saving_path_Data_folder)\n","os.makedirs(Saving_path_Data_folder)\n","\n","for files in os.listdir(Data_folder):\n"," shutil.copyfile(Data_folder+\"/\"+files, Saving_path_Data_folder+\"/\"+files)\n","\n","\n","Nb_files_Data_folder = len(os.listdir(Data_folder)) +10\n","\n","\n","\n","#Here we copy and rename the checkpoint to be used\n","\n","shutil.copyfile(full_Prediction_model_path+\"/\"+str(checkpoint)+\"_net_G_A.pth\", 
full_Prediction_model_path+\"/\"+str(checkpoint)+\"_net_G.pth\")\n","\n","\n","# This will find the image dimension of a randomly choosen image in Data_folder \n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imageio.imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","print(Image_min_dim)\n","\n","\n","\n","#-------------------------------- Perform predictions -----------------------------\n","\n","#-------------------------------- Options that can be used to perform predictions -----------------------------\n","\n","# basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n","\n","# model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. 
n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n","# dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n","# additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n","\n"," #('--ntest', type=int, default=float(\"inf\"), help='# of test examples.')\n"," #('--results_dir', type=str, default='./results/', help='saves results here.')\n"," #('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n"," #('--phase', type=str, default='test', help='train, val, test, etc')\n","\n","# Dropout and Batchnorm has different behavioir during training and test.\n"," #('--eval', action='store_true', help='use eval mode during test time.')\n"," #('--num_test', type=int, default=50, help='how many test images to run')\n"," # rewrite devalue values\n"," \n","# To avoid cropping, the load_size should be the same as crop_size\n"," #parser.set_defaults(load_size=parser.get_default('crop_size'))\n","\n","#------------------------------------------------------------------------\n","\n","\n","#---------------------------- Predictions are performed here 
----------------------\n","\n","os.chdir(\"/content\")\n","\n","!python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$Saving_path_Data_folder\" --name \"$Prediction_model_name\" --model test --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$Result_folder\" --checkpoints_dir \"$Prediction_model_path\" --num_test $Nb_files_Data_folder --epoch $checkpoint\n","\n","#-----------------------------------------------------------------------------------\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EIe3CRD7XUxa","colab_type":"text"},"source":["## **6.2. Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"LmDP8xiwXTTL","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","import os\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","\n","\n","random_choice_no_extension = os.path.splitext(random_choice)\n","\n","\n","x = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_real.png\")\n","\n","\n","y = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_fake.png\")\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Prediction')\n","plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.3. 
Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"Rn9zpWpo0xNw","colab_type":"text"},"source":["\n","#**Thank you for using CycleGAN!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"CycleGAN_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1mqcexfPBaIWuvMWWbJZUFtPoZoJJwrEA","timestamp":1589278334507},{"file_id":"159ARwlQE7-zi0EHxunOF_YPFLt-ZVU5x","timestamp":1587562499898},{"file_id":"1W-7NHehG5MRFILvZZzhPWWnOdJMkadb2","timestamp":1586332290412},{"file_id":"1pUetEQICxYWkYVaQIgdRH1EZBTl7oc2A","timestamp":1586292199692},{"file_id":"1MD36ZkM6XR9EuV12zimJmfCjzyeYZFWq","timestamp":1586269469061},{"file_id":"16A2mbaHzlEElntS8qkFBOsBvZG-mUeY6","timestamp":1586253795726},{"file_id":"1gJlcjOiSxr2buDOxmcFbT_d-GqwLjXtK","timestamp":1583343225796},{"file_id":"10yGI51WzHfgWgZAyE-EbkZFEvIOd6CP6","timestamp":1583171396283}],"collapsed_sections":[],"toc_visible":true},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I","colab_type":"text"},"source":["# **CycleGAN**\n","\n","---\n","\n","CycleGAN is a method that can capture the characteristics of one image domain and learn how these characteristics can be translated into another image domain, all in the absence of any paired training examples. It was first published by [Zhu *et al.* in 2017](https://arxiv.org/abs/1703.10593). 
Unlike pix2pix, the image transformation performed does not require paired images for training (unsupervised learning) and is made possible here by using a set of two Generative Adversarial Networks (GANs) that learn to transform images both from the first domain to the second and vice-versa.\n","\n"," **This particular notebook enables unpaired image-to-image translation. If your dataset is paired, you should also consider using the pix2pix notebook.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). It is jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n"," **Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks** by Zhu *et al.*, published on arXiv in 2017 (https://arxiv.org/abs/1703.10593)\n","\n","The source code of the CycleGAN PyTorch implementation can be found in: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"N3azwKB9O0oW","colab_type":"text"},"source":["# **License**\n","\n","---"]},{"cell_type":"code","metadata":{"id":"ByW6Vqdn9sYV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Double click to see the license information\n","\n","#------------------------- LICENSE FOR ZeroCostDL4Mic------------------------------------\n","#This ZeroCostDL4Mic notebook is distributed under the MIT licence\n","\n","\n","\n","#------------------------- LICENSE FOR CycleGAN ------------------------------------\n","\n","#Copyright (c) 2017, Jun-Yan Zhu and Taesung Park\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted 
provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the distribution.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n","#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n","#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n","#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n","#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n","#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n","#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n","#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n","#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n","#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n","\n","\n","#--------------------------- LICENSE FOR pix2pix --------------------------------\n","#BSD License\n","\n","#For pix2pix software\n","#Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the 
distribution.\n","\n","#----------------------------- LICENSE FOR DCGAN --------------------------------\n","#BSD License\n","\n","#For dcgan.torch software\n","\n","#Copyright (c) 2015, Facebook, Inc. All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n","\n","#Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n","\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n","\n","#Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code, and the code can be modified by selecting the cell. To execute the cell, move your cursor to the `[ ]`-mark on the left side of the cell (a play button appears). Click to execute the cell. After execution is done, the animation of the play button stops. You can create a new code cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. 
Click the content to move quickly between sections.\n","\n","*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," To train CycleGAN, **you only need two folders containing PNG images**. The images do not need to be paired.\n","\n","While you do not need paired images to train CycleGAN, if possible, **we strongly recommend that you generate a paired dataset. This means that the same image needs to be acquired in the two conditions. These images can be used to assess the quality of your trained model (Quality control dataset)**. 
The quality control assessment can be done directly in this notebook.\n","\n","\n"," Please note that you currently can **only use .png files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset (non-matching images) **\n"," - Training_source\n"," - img_1.png, img_2.png, ...\n"," - Training_target\n"," - img_1.png, img_2.png, ...\n"," - **Quality control dataset (matching images)**\n"," - Training_source\n"," - img_1.png, img_2.png\n"," - Training_target\n"," - img_1.png, img_2.png\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"BCPhV-pe-syw","colab_type":"text"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"VNZetvLiS1qV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UBrnApIUBgxv","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in to your Google Account. \n","\n","#@markdown * Copy the authorization code. 
\n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **2. Install CycleGAN and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install CycleGAN and dependencies\n","\n","\n","#------- Code from the cycleGAN demo notebook starts here -------\n","\n","#Here, we install libraries which are not already included in Colab.\n","\n","\n","\n","!git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","import os\n","os.chdir('pytorch-CycleGAN-and-pix2pix/')\n","!pip install -r requirements.txt\n","\n","\n","import imageio\n","from skimage import data\n","from skimage import exposure\n","from skimage.exposure import match_histograms\n","\n","from skimage.util import img_as_int\n","\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage 
import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","\n","\n","# Colors for the warning messages\n","class bcolors:\n","  WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"BLmBseWbRvxL","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd","colab_type":"text"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target training data, respectively. To find the path of a folder containing one of these datasets, go to your Files on the left of the notebook, navigate to the folder containing your files, right-click on the folder, select **Copy path** and paste the path into the corresponding box below.\n","\n","**`model_name`:** Use the my_model style, not my-model (use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input the number of epochs (rounds) for which the network will be trained. Preliminary results can already be observed after a few (10) epochs, but a full training should run for 200 epochs or more. Evaluate the performance after training (see section 5). 
**Default value: 200**\n","\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`patch_size`:** CycleGAN divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 4. **Default value: 512**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0002**"]},{"cell_type":"code","metadata":{"id":"pIrTwJjzwV-D","colab_type":"code","cellView":"form","colab":{}},"source":["\n","\n","#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.png\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.png\"\n","\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","patch_size = 512#@param {type:\"number\"} # in pixels\n","batch_size = 1#@param {type:\"number\"}\n","initial_learning_rate = 0.0002 #@param 
{type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n","  print(\"Default advanced parameters enabled\")\n","  batch_size = 1\n","  patch_size = 512\n","  initial_learning_rate = 0.0002\n","\n","# Here we check whether a model with the same name already exists; if so, it will be deleted when the training data are prepared (section 4.1)\n","if os.path.exists(model_path+'/'+model_name):\n","  print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in section 4.1 !!\")\n","  print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\")\n","  \n","\n","\n","# To use CycleGAN we need to organise the data in a way the model can understand\n","\n","Saving_path= \"/content/\"+model_name\n","#Saving_path= model_path+\"/\"+model_name\n","\n","if os.path.exists(Saving_path):\n","  shutil.rmtree(Saving_path)\n","os.makedirs(Saving_path)\n","\n","TrainA_Folder = Saving_path+\"/trainA\"\n","if os.path.exists(TrainA_Folder):\n","  shutil.rmtree(TrainA_Folder)\n","os.makedirs(TrainA_Folder)\n","  \n","TrainB_Folder = Saving_path+\"/trainB\"\n","if os.path.exists(TrainB_Folder):\n","  shutil.rmtree(TrainB_Folder)\n","os.makedirs(TrainB_Folder)\n","\n","# Here we disable the pre-trained model by default (in case the cell in section 3.3 is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell in section 3.2 is not run)\n","\n","Use_Data_augmentation = True\n","\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imageio.imread(Training_source+\"/\"+random_choice)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","\n","#Hyperparameters failsafes\n","if patch_size > min(Image_Y, Image_X):\n","  patch_size = min(Image_Y, Image_X)\n","  print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of 
your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 4\n","if not patch_size % 4 == 0:\n"," patch_size = ((int(patch_size / 4)-1) * 4)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 4; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","random_choice_2 = random.choice(os.listdir(Training_target))\n","y = imageio.imread(Training_target+\"/\"+random_choice_2)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"FX6uxFvI-CsQ","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CwMaFU1T-GtN","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by flipping the patches. 
\n","\n"," By default data augmentation is enabled."]},{"cell_type":"code","metadata":{"id":"kLtHIATT-0un","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n","  print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n","  print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"v-leE8pEWRkn","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CycleGAN model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"CbOcS3wiWV9w","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If yes, please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","  h5_file_path_A = os.path.join(pretrained_model_path, \"latest_net_G_A.pth\")\n","  h5_file_path_B = os.path.join(pretrained_model_path, \"latest_net_G_B.pth\")\n","\n","  # --------------------- Check that both generator weight files exist ------------------------\n","\n","  # Both files are required, so the check must fail if either one is missing\n","  if not (os.path.exists(h5_file_path_A) and os.path.exists(h5_file_path_B)):\n","    print(bcolors.WARNING+'WARNING: Pretrained model does not exist')\n","    Use_pretrained_model = False\n","    print(bcolors.WARNING+'No pretrained network will be used.')\n","\n","  else:\n","    print(\"Pretrained model \"+os.path.basename(pretrained_model_path)+\" was found and will be loaded prior to training.\")\n","  \n","else:\n","  print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"-A4ipz8gs3Ew","colab_type":"text"},"source":["## **4.1. Prepare the training data for training**\n","---\n","Here, we use the information from section 3 
to prepare the training data in a suitable format for training."]},{"cell_type":"code","metadata":{"id":"_V2ujGB60gDv","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Prepare the data for training\n","\n","print(\"Data preparation in progress\")\n","\n","if os.path.exists(model_path+'/'+model_name):\n","  shutil.rmtree(model_path+'/'+model_name)\n","os.makedirs(model_path+'/'+model_name)\n","\n","#--------------- Here we copy the files to trainA and trainB ---------\n","\n","\n","for f in os.listdir(Training_source):\n","  shutil.copyfile(Training_source+\"/\"+f, TrainA_Folder+\"/\"+f)\n","\n","for files in os.listdir(Training_target):\n","  shutil.copyfile(Training_target+\"/\"+files, TrainB_Folder+\"/\"+files)\n","\n","#---------------------------------------------------------------------\n","\n","# CycleGAN takes a number of epochs without lr decay followed by a number of epochs with lr decay.\n","# Here number_of_epochs is split equally between the two phases.\n","\n","\n","number_of_epochs_lr_stable = int(number_of_epochs/2)\n","number_of_epochs_lr_decay = int(number_of_epochs/2)\n","\n","# Copy the latest weights of the pre-trained model into the new model folder\n","if Use_pretrained_model :\n","  for f in os.listdir(pretrained_model_path):\n","    if (f.startswith(\"latest_net_\")): \n","      shutil.copyfile(pretrained_model_path+\"/\"+f, model_path+'/'+model_name+\"/\"+f)\n","\n","print(\"Data ready for training\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR","colab_type":"text"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for data mining). Training time must be less than 12 hours! 
If training takes longer than 12 hours, please decrease the number of epochs or number of patches or continue the training in a second Colab session."]},{"cell_type":"code","metadata":{"id":"eBD50tAgv5qf","colab_type":"code","cellView":"form","colab":{}},"source":["\n","#@markdown ##Start training\n","\n","start = time.time()\n","\n","os.chdir(\"/content\")\n","\n","#--------------------------------- Command line inputs to change CycleGAN paramaters------------\n","\n"," # basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n"," \n"," # model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. 
n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n"," # dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n"," # additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n"," # visdom and HTML visualization parameters\n"," #('--display_freq', type=int, default=400, help='frequency of showing training results on screen')\n"," #('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n"," #('--display_id', type=int, default=1, help='window id of the web display')\n"," #('--display_server', type=str, default=\"http://localhost\", help='visdom server of the web display')\n"," #('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n"," #('--display_port', type=int, default=8097, help='visdom port of the web display')\n"," #('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n"," #('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n"," #('--no_html', 
action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n"," \n"," # network saving and loading parameters\n"," #('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n"," #('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n"," #('--save_by_iter', action='store_true', help='whether saves model by iteration')\n"," #('--continue_train', action='store_true', help='continue training: load the latest model')\n"," #('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')\n"," #('--phase', type=str, default='train', help='train, val, test, etc')\n"," \n"," # training parameters\n"," #('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')\n"," #('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')\n"," #('--beta1', type=float, default=0.5, help='momentum term of adam')\n"," #('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n"," #('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')\n"," #('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n"," #('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]')\n"," #('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')\n","\n","#---------------------------------------------------------\n","\n","#----- Start the training ------------------------------------\n","if not Use_pretrained_model:\n","  if Use_Data_augmentation:\n","    !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5\n","  if not Use_Data_augmentation:\n","    !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --no_flip\n","\n","if Use_pretrained_model:\n","  if Use_Data_augmentation:\n","    !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train\n","  \n","  if not Use_Data_augmentation:\n","    !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop 
--load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train --no_flip\n","\n","#---------------------------------------------------------\n","\n","print(\"Training done.\")\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",int(hour), \"hour(s)\",int(mins),\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XQjQb_J_Qyku","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is, however, wise to download the folder, as all data in it can be erased by the next training run if the same model name and folder are used."]},{"cell_type":"markdown","metadata":{"id":"2HbZd7rFqAad","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend performing quality control on all newly trained models.**\n","\n","Unfortunately, loss function curves are not very informative for GANs. Therefore, we perform the QC here using a test dataset.\n","\n","\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"PhcOwcgH3JAD","colab_type":"text"},"source":["## **5.1. 
Choose the model you want to assess**"]},{"cell_type":"code","metadata":{"id":"EdcnkCr9Nbl8","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n","  QC_model_name = model_name\n","  QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n","  print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n","  W = '\\033[0m' # white (normal)\n","  R = '\\033[31m' # red\n","  print(R+'!! WARNING: The chosen model does not exist !!'+W)\n","  print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"E4Yp7ogh3NGD","colab_type":"text"},"source":["## **5.2. Identify the best checkpoint to use to make predictions**"]},{"cell_type":"markdown","metadata":{"id":"1yauWCc78HKD","colab_type":"text"},"source":[" CycleGAN saves model checkpoints every five epochs. Due to the stochastic nature of GANs, the last checkpoint is not always the best one to use. As a consequence, it can be challenging to choose the most suitable checkpoint to use to make predictions.\n","\n","This section allows you to perform predictions using all the saved checkpoints and to estimate the quality of these predictions by comparing them to the provided ground truth images. The metrics used include:\n","\n","**1. 
The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n"]},{"cell_type":"code","metadata":{"id":"2nBPucJdK3KS","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Image_type = \"Grayscale\" #@param [\"Grayscale\", \"RGB\"]\n","\n","\n","\n","# average function\n","def Average(lst): \n"," return sum(lst) / len(lst) \n","\n","\n","# Create a quality control folder\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","# Here we need to move the data to be analysed so that cycleGAN can find them\n","\n","Saving_path_QC= \"/content/\"+QC_model_name\n","\n","if os.path.exists(Saving_path_QC):\n"," shutil.rmtree(Saving_path_QC)\n","os.makedirs(Saving_path_QC)\n","\n","Saving_path_QC_folder = Saving_path_QC+\"_images\"\n","\n","if os.path.exists(Saving_path_QC_folder):\n"," shutil.rmtree(Saving_path_QC_folder)\n","os.makedirs(Saving_path_QC_folder)\n","\n","\n","#Here we copy and rename the all the checkpoint to be analysed\n","\n","for f in os.listdir(full_QC_model_path):\n"," shortname = f[:-6]\n"," shortname = shortname + \".pth\"\n"," if f.endswith(\"net_G_A.pth\"):\n"," shutil.copyfile(full_QC_model_path+f, 
Saving_path_QC+\"/\"+shortname)\n","\n","\n","for files in os.listdir(Source_QC_folder):\n"," shutil.copyfile(Source_QC_folder+\"/\"+files, Saving_path_QC_folder+\"/\"+files)\n"," \n","\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = int(min(Image_Y, Image_X))\n","\n","Nb_Checkpoint = len(os.listdir(Saving_path_QC))\n","\n","print(Nb_Checkpoint)\n","\n","\n","\n","## Initiate list\n","\n","Checkpoint_list = []\n","Average_ssim_score_list = []\n","\n","\n","for j in range(1, len(os.listdir(Saving_path_QC))+1):\n"," checkpoints = j*5\n","\n"," if checkpoints == Nb_Checkpoint*5:\n"," checkpoints = \"latest\"\n","\n","\n"," print(\"The checkpoint currently analysed is =\"+str(checkpoints))\n","\n"," Checkpoint_list.append(checkpoints)\n","\n","\n"," # Create a quality control/Prediction Folder\n","\n"," QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\n","\n"," if os.path.exists(QC_prediction_results):\n"," shutil.rmtree(QC_prediction_results)\n","\n"," os.makedirs(QC_prediction_results)\n","\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n","\n"," os.chdir(\"/content\")\n","\n"," !python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$Saving_path_QC_folder\" --name \"$QC_model_name\" --model test --epoch $checkpoints --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$QC_prediction_results\" --checkpoints_dir \"/content/\"\n","\n","#-----------------------------------------------------------------------------------\n","\n","#Here we need to move the data again and remove all the unnecessary folders\n","\n"," Checkpoint_name = 
\"test_\"+str(checkpoints)\n","\n"," QC_results_images = QC_prediction_results+\"/\"+QC_model_name+\"/\"+Checkpoint_name+\"/images\"\n","\n"," QC_results_images_files = os.listdir(QC_results_images)\n","\n"," for f in QC_results_images_files: \n"," shutil.copyfile(QC_results_images+\"/\"+f, QC_prediction_results+\"/\"+f)\n","\n"," os.chdir(\"/content\") \n","\n"," #Here we clean up the extra files\n"," shutil.rmtree(QC_prediction_results+\"/\"+QC_model_name)\n","\n","\n","#-------------------------------- QC for RGB ------------------------------------\n"," if Image_type == \"RGB\":\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n"," random_choice = random.choice(os.listdir(Source_QC_folder))\n"," x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, multichannel=True)\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. 
GT mSSIM\"])\n"," \n"," \n"," # Initiate list\n"," ssim_score_list = [] \n","\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," shortname_no_PNG = i[:-4]\n"," \n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = imageio.imread(os.path.join(Target_QC_folder, i), as_gray=False, pilmode=\"RGB\")\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real.png\"))\n"," \n"," \n"," # -------------------------------- Prediction --------------------------------\n"," \n"," test_prediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake.png\"))\n"," \n"," #--------------------------- Here we normalise using histograms matching--------------------------------\n"," test_prediction_matched = match_histograms(test_prediction, test_GT, multichannel=True)\n"," test_source_matched = match_histograms(test_source, test_GT, multichannel=True)\n"," \n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT, test_prediction_matched)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT, test_source_matched)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality 
Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," \n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource)])\n","\n"," #Here we calculate the ssim average for each image in each checkpoints\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","\n","\n","\n","#------------------------------------------- QC for Grayscale ----------------------------------------------\n","\n"," if Image_type == \"Grayscale\":\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n"," def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n","\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n"," def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n"," def norm_minmse(gt, x, normalize_gt=True):\n"," \n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," 
x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," # Initiate list (once per checkpoint, so that the average covers every image)\n"," ssim_score_list = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," shortname_no_PNG = i[:-4]\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT_raw = imageio.imread(os.path.join(Target_QC_folder, i), as_gray=False, pilmode=\"RGB\")\n"," \n"," test_GT = test_GT_raw[:,:,2]\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real.png\"))\n"," \n"," test_source = test_source_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality 
Control/\"+str(checkpoints),shortname_no_PNG+\"_fake.png\"))\n"," \n"," test_prediction = test_prediction_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," \n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_8bit = (img_RSE_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_RSE_GTvsPrediction_8bit)\n"," img_RSE_GTvsSource_8bit = (img_RSE_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsSource_\"+shortname_no_PNG+'.tif',img_RSE_GTvsSource_8bit)\n","\n","\n"," # 
-------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n"," # Here we calculate the average SSIM over all images for each checkpoint\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","\n","# All data is now processed and saved\n"," \n","\n","# -------------------------------- Display --------------------------------\n","\n","# Display the SSIM vs. checkpoint plot\n","plt.figure(figsize=(20,5))\n","plt.plot(Checkpoint_list, Average_ssim_score_list, label=\"SSIM\")\n","plt.title('Checkpoints vs. 
SSIM')\n","plt.ylabel('SSIM')\n","plt.xlabel('Checkpoints')\n","plt.legend()\n","plt.show()\n","\n","\n","\n","# -------------------------------- Display RGB --------------------------------\n","\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","if Image_type == \"RGB\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n","\n","#Setting up colours\n"," \n"," cmap = None\n","\n"," plt.figure(figsize=(10,10))\n","\n","# Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(Target_QC_folder, file), as_gray=False, pilmode=\"RGB\")\n"," plt.imshow(img_GT, cmap = cmap)\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(Source_QC_folder, file), as_gray=False, pilmode=\"RGB\")\n"," plt.imshow(img_Source, cmap = cmap)\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n","\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake.png\"))\n","\n"," plt.imshow(img_Prediction, cmap = cmap)\n"," plt.title('Prediction',fontsize=15)\n","\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major 
and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","#plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","\n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","#plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","\n","# -------------------------------- Display Grayscale --------------------------------\n","\n","if Image_type == \"Grayscale\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n","\n"," NRMSE_GTvsPrediction = df2.loc[file, \"Prediction v. GT NRMSE\"]\n"," NRMSE_GTvsSource = df2.loc[file, \"Input v. GT NRMSE\"]\n"," PSNR_GTvsSource = df2.loc[file, \"Input v. GT PSNR\"]\n"," PSNR_GTvsPrediction = df2.loc[file, \"Prediction v. 
GT PSNR\"]\n"," \n","\n"," plt.figure(figsize=(15,15))\n","\n"," cmap = None\n"," \n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(Target_QC_folder, file), as_gray=True, pilmode=\"RGB\")\n","\n"," plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99), cmap = 'gray')\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real.png\"))\n"," plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake.png\"))\n"," plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n"," plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsSource = img_SSIM_GTvsSource / 255\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","\n"," \n"," 
plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," \n"," \n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsPrediction = img_SSIM_GTvsPrediction / 255\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","\n"," \n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_RSE_GTvsSource = img_RSE_GTvsSource / 255\n"," \n","\n"," imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. 
Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," img_RSE_GTvsPrediction = img_RSE_GTvsPrediction / 255\n","\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Esqnbew8uznk"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. 
If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as PNG images.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`checkpoint`:** Choose the checkpoint number you would like to use to perform predictions. To use the \"latest\" checkpoint, input \"latest\"."]},{"cell_type":"code","metadata":{"id":"yb3suNkfpNA9","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","import glob\n","import os.path\n","\n","\n","latest = \"latest\"\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###What model checkpoint would you like to use?\n","\n","checkpoint = latest#@param {type:\"raw\"}\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#here we check if we use the newly trained network or not\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","#here we check if the model exists\n","full_Prediction_model_path = 
Prediction_model_path+'/'+Prediction_model_name+'/'\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# Here we check that the checkpoint exists; if not, the closest one will be chosen \n","\n","Nb_Checkpoint = len(glob.glob(os.path.join(full_Prediction_model_path, '*G_A.pth')))\n","print(Nb_Checkpoint)\n","\n","\n","if not checkpoint == \"latest\":\n","\n"," if checkpoint < 10:\n"," checkpoint = 5\n","\n"," if not checkpoint % 5 == 0:\n"," checkpoint = ((int(checkpoint / 5)-1) * 5)\n"," print (bcolors.WARNING + \" Your chosen checkpoint is not divisible by 5; therefore the checkpoint chosen is now:\",checkpoint)\n"," \n"," if checkpoint > Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n"," if checkpoint == Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n","\n","\n","\n","# Here we need to move the data to be analysed so that cycleGAN can find them\n","\n","Saving_path_prediction= \"/content/\"+Prediction_model_name\n","\n","if os.path.exists(Saving_path_prediction):\n"," shutil.rmtree(Saving_path_prediction)\n","os.makedirs(Saving_path_prediction)\n","\n","Saving_path_Data_folder = Saving_path_prediction+\"/testA\"\n","\n","if os.path.exists(Saving_path_Data_folder):\n"," shutil.rmtree(Saving_path_Data_folder)\n","os.makedirs(Saving_path_Data_folder)\n","\n","for files in os.listdir(Data_folder):\n"," shutil.copyfile(Data_folder+\"/\"+files, Saving_path_Data_folder+\"/\"+files)\n","\n","\n","Nb_files_Data_folder = len(os.listdir(Data_folder)) +10\n","\n","\n","\n","#Here we copy and rename the checkpoint to be used\n","\n","shutil.copyfile(full_Prediction_model_path+\"/\"+str(checkpoint)+\"_net_G_A.pth\", 
full_Prediction_model_path+\"/\"+str(checkpoint)+\"_net_G.pth\")\n","\n","\n","# This will find the image dimension of a randomly choosen image in Data_folder \n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imageio.imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","print(Image_min_dim)\n","\n","\n","\n","#-------------------------------- Perform predictions -----------------------------\n","\n","#-------------------------------- Options that can be used to perform predictions -----------------------------\n","\n","# basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n","\n","# model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. 
n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n","# dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n","# additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n","\n"," #('--ntest', type=int, default=float(\"inf\"), help='# of test examples.')\n"," #('--results_dir', type=str, default='./results/', help='saves results here.')\n"," #('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n"," #('--phase', type=str, default='test', help='train, val, test, etc')\n","\n","# Dropout and Batchnorm have different behaviour during training and test.\n"," #('--eval', action='store_true', help='use eval mode during test time.')\n"," #('--num_test', type=int, default=50, help='how many test images to run')\n"," # rewrite default values\n"," \n","# To avoid cropping, the load_size should be the same as crop_size\n"," #parser.set_defaults(load_size=parser.get_default('crop_size'))\n","\n","#------------------------------------------------------------------------\n","\n","\n","#---------------------------- Predictions are performed here 
----------------------\n","\n","os.chdir(\"/content\")\n","\n","!python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$Saving_path_Data_folder\" --name \"$Prediction_model_name\" --model test --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$Result_folder\" --checkpoints_dir \"$Prediction_model_path\" --num_test $Nb_files_Data_folder --epoch $checkpoint\n","\n","#-----------------------------------------------------------------------------------\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EIe3CRD7XUxa","colab_type":"text"},"source":["## **6.2. Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"LmDP8xiwXTTL","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","import os\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","\n","\n","random_choice_no_extension = os.path.splitext(random_choice)\n","\n","\n","x = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_real.png\")\n","\n","\n","y = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_fake.png\")\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Prediction')\n","plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.3. 
Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"Rn9zpWpo0xNw","colab_type":"text"},"source":["\n","#**Thank you for using CycleGAN!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Deep-STORM_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Deep-STORM_2D_ZeroCostDL4Mic.ipynb index d480b27b..cb0b72d1 100755 --- a/Colab_notebooks/Deep-STORM_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Deep-STORM_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Deep-STORM_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"169qcwQo-yw15PwoGatXAdBvjs4wt_foD","timestamp":1592147948265},{"file_id":"1gjRCgDORKi_GNBu4QnVCBkSWrfPtqL-E","timestamp":1588525976305},{"file_id":"1DFy6aCi1XAVdjA5KLRZirB2aMZkMFdv-","timestamp":1587998755430},{"file_id":"1NpzigQoXGy3GFdxh4_jvG1PnBfyrcpBs","timestamp":1587569988032},{"file_id":"1jdI540qAfMSQwjnMhoAFkGJH9EbHwNSf","timestamp":1587486196143}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"FpCtYevLHfl4","colab_type":"text"},"source":["# **Deep-STORM (2D)**\n","\n","---\n","\n","Deep-STORM is a neural network capable of image reconstruction from high-density single-molecule localization microscopy (SMLM), first published in 2018 by [Nehme *et al.* in Optica](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458). The architecture used here is a U-Net based network without skip connections. This network allows image reconstruction of 2D super-resolution images, in a supervised training manner. 
The network is trained using simulated high-density SMLM data for which the ground-truth is available. These simulations are obtained from random distribution of single molecules in a field-of-view and therefore do not imprint structural priors during training. The network output a super-resolution image with increased pixel density (typically upsampling factor of 8 in each dimension).\n","\n","Deep-STORM has **two key advantages**:\n","- SMLM reconstruction at high density of emitters\n","- fast prediction (reconstruction) once the model is trained appropriately, compared to more common multi-emitter fitting processes.\n","\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n","**Deep-STORM: super-resolution single-molecule microscopy by deep learning**, Optica (2018) by *Elias Nehme, Lucien E. 
Weiss, Tomer Michaeli, and Yoav Shechtman* (https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458)\n","\n","And source code found in: https://github.com/EliasNehme/Deep-STORM\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"wyzTn3IcHq6Y","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) 
you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"bEy4EBXHHyAX","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," Deep-STORM is able to train on simulated dataset of SMLM data (see https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458 for more info). Here, we provide a simulator that will generate training dataset (section 3.1.b). A few parameters will allow you to match the simulation to your experimental data. 
Similarly to what is described in the paper, simulations obtained from ThunderSTORM can also be loaded here (section 3.1.a).\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"E04mOlG_H5Tz","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"F_tjlGzsH-Dn","colab_type":"text"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"gn-LaaNNICqL","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","# %tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.__version__ != '2.2.0':\n"," !pip install tensorflow==2.2.0\n","\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime settings are correct then Google did not allocate GPU to your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n","\n","# print the tensorflow version\n","print('Tensorflow version is ' + str(tf.__version__))\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"tnP7wM79IKW-","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 
\n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"1R-7Fo34_gOd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jRnQZWSZhArJ","colab_type":"text"},"source":["# **2. Install Deep-STORM and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"kSrZMo3X_NhO","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install Deep-STORM and dependencies\n","\n","# %% Model definition + helper functions\n","\n","# Import keras modules and libraries\n","from tensorflow import keras\n","from tensorflow.keras.models import Model\n","from tensorflow.keras.layers import Input, Activation, UpSampling2D, Convolution2D, MaxPooling2D, BatchNormalization, Layer\n","from tensorflow.keras.callbacks import Callback\n","from tensorflow.keras import backend as K\n","from tensorflow.keras import optimizers, losses\n","\n","from tensorflow.keras.preprocessing.image import ImageDataGenerator\n","from tensorflow.keras.callbacks import ModelCheckpoint\n","from tensorflow.keras.callbacks import ReduceLROnPlateau\n","from skimage.transform import warp\n","from skimage.transform import SimilarityTransform\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from scipy.signal import 
fftconvolve\n","\n","# Import common libraries\n","import tensorflow as tf\n","import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import h5py\n","import scipy.io as sio\n","from os.path import abspath\n","from sklearn.model_selection import train_test_split\n","from skimage import io\n","import time\n","import os\n","import shutil\n","import csv\n","from PIL import Image \n","from PIL.TiffTags import TAGS\n","from scipy.ndimage import gaussian_filter\n","import math\n","from astropy.visualization import simple_norm\n","from sys import getsizeof\n","\n","# For sliders and dropdown menu, progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","from tqdm import tqdm\n","\n","# For Multi-threading in simulation\n","from numba import njit, prange\n","\n","\n","# define a function that projects and rescales an image to the range [0,1]\n","def project_01(im):\n"," im = np.squeeze(im)\n"," min_val = im.min()\n"," max_val = im.max()\n"," return (im - min_val)/(max_val - min_val)\n","\n","# normalize image given mean and std\n","def normalize_im(im, dmean, dstd):\n"," im = np.squeeze(im)\n"," im_norm = np.zeros(im.shape,dtype=np.float32)\n"," im_norm = (im - dmean)/dstd\n"," return im_norm\n","\n","# Define the loss history recorder\n","class LossHistory(Callback):\n"," def on_train_begin(self, logs={}):\n"," self.losses = []\n","\n"," def on_batch_end(self, batch, logs={}):\n"," self.losses.append(logs.get('loss'))\n"," \n","# Define a matlab like gaussian 2D filter\n","def matlab_style_gauss2D(shape=(7,7),sigma=1):\n"," \"\"\" \n"," 2D gaussian filter - should give the same result as:\n"," MATLAB's fspecial('gaussian',[shape],[sigma]) \n"," \"\"\"\n"," m,n = [(ss-1.)/2. 
for ss in shape]\n"," y,x = np.ogrid[-m:m+1,-n:n+1]\n"," h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )\n"," h.astype(dtype=K.floatx())\n"," h[ h < np.finfo(h.dtype).eps*h.max() ] = 0\n"," sumh = h.sum()\n"," if sumh != 0:\n"," h /= sumh\n"," h = h*2.0\n"," h = h.astype('float32')\n"," return h\n","\n","# Expand the filter dimensions\n","psf_heatmap = matlab_style_gauss2D(shape = (7,7),sigma=1)\n","gfilter = tf.reshape(psf_heatmap, [7, 7, 1, 1])\n","\n","# Combined MSE + L1 loss\n","def L1L2loss(input_shape):\n"," def bump_mse(heatmap_true, spikes_pred):\n","\n"," # generate the heatmap corresponding to the predicted spikes\n"," heatmap_pred = K.conv2d(spikes_pred, gfilter, strides=(1, 1), padding='same')\n","\n"," # heatmaps MSE\n"," loss_heatmaps = losses.mean_squared_error(heatmap_true,heatmap_pred)\n","\n"," # l1 on the predicted spikes\n"," loss_spikes = losses.mean_absolute_error(spikes_pred,tf.zeros(input_shape))\n"," return loss_heatmaps + loss_spikes\n"," return bump_mse\n","\n","# Define the concatenated conv2, batch normalization, and relu block\n","def conv_bn_relu(nb_filter, rk, ck, name):\n"," def f(input):\n"," conv = Convolution2D(nb_filter, kernel_size=(rk, ck), strides=(1,1),\\\n"," padding=\"same\", use_bias=False,\\\n"," kernel_initializer=\"Orthogonal\",name='conv-'+name)(input)\n"," conv_norm = BatchNormalization(name='BN-'+name)(conv)\n"," conv_norm_relu = Activation(activation = \"relu\",name='Relu-'+name)(conv_norm)\n"," return conv_norm_relu\n"," return f\n","\n","# Define the model architechture\n","def CNN(input,names):\n"," Features1 = conv_bn_relu(32,3,3,names+'F1')(input)\n"," pool1 = MaxPooling2D(pool_size=(2,2),name=names+'Pool1')(Features1)\n"," Features2 = conv_bn_relu(64,3,3,names+'F2')(pool1)\n"," pool2 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool2')(Features2)\n"," Features3 = conv_bn_relu(128,3,3,names+'F3')(pool2)\n"," pool3 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool3')(Features3)\n"," Features4 = 
conv_bn_relu(512,3,3,names+'F4')(pool3)\n"," up5 = UpSampling2D(size=(2, 2),name=names+'Upsample1')(Features4)\n"," Features5 = conv_bn_relu(128,3,3,names+'F5')(up5)\n"," up6 = UpSampling2D(size=(2, 2),name=names+'Upsample2')(Features5)\n"," Features6 = conv_bn_relu(64,3,3,names+'F6')(up6)\n"," up7 = UpSampling2D(size=(2, 2),name=names+'Upsample3')(Features6)\n"," Features7 = conv_bn_relu(32,3,3,names+'F7')(up7)\n"," return Features7\n","\n","# Define the Model building for an arbitrary input size\n","def buildModel(input_dim, initial_learning_rate = 0.001):\n"," input_ = Input (shape = (input_dim))\n"," act_ = CNN (input_,'CNN')\n"," density_pred = Convolution2D(1, kernel_size=(1, 1), strides=(1, 1), padding=\"same\",\\\n"," activation=\"linear\", use_bias = False,\\\n"," kernel_initializer=\"Orthogonal\",name='Prediction')(act_)\n"," model = Model (inputs= input_, outputs=density_pred)\n"," opt = optimizers.Adam(lr = initial_learning_rate)\n"," model.compile(optimizer=opt, loss = L1L2loss(input_dim))\n"," return model\n","\n","\n","# define a function that trains a model for a given data SNR and density\n","def train_model(patches, heatmaps, modelPath, epochs, steps_per_epoch, batch_size, upsampling_factor=8, validation_split = 0.3, initial_learning_rate = 0.001, pretrained_model_path = '', L2_weighting_factor = 100):\n"," \n"," \"\"\"\n"," This function trains a CNN model on the desired training set, given the \n"," upsampled training images and labels generated in MATLAB.\n"," \n"," # Inputs\n"," # TO UPDATE ----------\n","\n"," # Outputs\n"," function saves the weights of the trained model to a hdf5, and the \n"," normalization factors to a mat file. These will be loaded later for testing \n"," the model in test_model. 
\n"," \"\"\"\n"," \n"," # for reproducibility\n"," np.random.seed(123)\n","\n"," X_train, X_test, y_train, y_test = train_test_split(patches, heatmaps, test_size = validation_split, random_state=42)\n"," print('Number of training examples: %d' % X_train.shape[0])\n"," print('Number of validation examples: %d' % X_test.shape[0])\n"," \n"," # Setting type\n"," X_train = X_train.astype('float32')\n"," X_test = X_test.astype('float32')\n"," y_train = y_train.astype('float32')\n"," y_test = y_test.astype('float32')\n","\n"," \n"," #===================== Training set normalization ==========================\n"," # normalize training images to be in the range [0,1] and calculate the \n"," # training set mean and std\n"," mean_train = np.zeros(X_train.shape[0],dtype=np.float32)\n"," std_train = np.zeros(X_train.shape[0], dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train[i, :, :] = project_01(X_train[i, :, :])\n"," mean_train[i] = X_train[i, :, :].mean()\n"," std_train[i] = X_train[i, :, :].std()\n","\n"," # resulting normalized training images\n"," mean_val_train = mean_train.mean()\n"," std_val_train = std_train.mean()\n"," X_train_norm = np.zeros(X_train.shape, dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train_norm[i, :, :] = normalize_im(X_train[i, :, :], mean_val_train, std_val_train)\n"," \n"," # patch size\n"," psize = X_train_norm.shape[1]\n","\n"," # Reshaping\n"," X_train_norm = X_train_norm.reshape(X_train.shape[0], psize, psize, 1)\n","\n"," # ===================== Test set normalization ==========================\n"," # normalize test images to be in the range [0,1] and calculate the test set \n"," # mean and std\n"," mean_test = np.zeros(X_test.shape[0],dtype=np.float32)\n"," std_test = np.zeros(X_test.shape[0], dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test[i, :, :] = project_01(X_test[i, :, :])\n"," mean_test[i] = X_test[i, :, :].mean()\n"," std_test[i] = X_test[i, :, :].std()\n","\n"," # resulting 
normalized test images\n"," mean_val_test = mean_test.mean()\n"," std_val_test = std_test.mean()\n"," X_test_norm = np.zeros(X_test.shape, dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test_norm[i, :, :] = normalize_im(X_test[i, :, :], mean_val_test, std_val_test)\n"," \n"," # Reshaping\n"," X_test_norm = X_test_norm.reshape(X_test.shape[0], psize, psize, 1)\n","\n"," # Reshaping labels\n"," Y_train = y_train.reshape(y_train.shape[0], psize, psize, 1)\n"," Y_test = y_test.reshape(y_test.shape[0], psize, psize, 1)\n","\n"," # Save datasets to a matfile to open later in matlab\n"," mdict = {\"mean_test\": mean_val_test, \"std_test\": std_val_test, \"upsampling_factor\": upsampling_factor, \"Normalization factor\": L2_weighting_factor}\n"," sio.savemat(os.path.join(modelPath,\"model_metadata.mat\"), mdict)\n","\n","\n"," # Set the dimensions ordering according to tensorflow consensous\n"," # K.set_image_dim_ordering('tf')\n"," K.set_image_data_format('channels_last')\n","\n"," # Save the model weights after each epoch if the validation loss decreased\n"," checkpointer = ModelCheckpoint(filepath=os.path.join(modelPath,\"weights_best.hdf5\"), verbose=1,\n"," save_best_only=True)\n","\n"," # Change learning when loss reaches a plataeu\n"," change_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=0.00005)\n"," \n"," # Model building and complitation\n"," model = buildModel((psize, psize, 1), initial_learning_rate = initial_learning_rate)\n"," model.summary()\n","\n"," # Load pretrained model\n"," if not pretrained_model_path:\n"," print('Using random initial model weights.')\n"," else:\n"," print('Loading model weights from '+pretrained_model_path)\n"," model.load_weights(pretrained_model_path)\n"," \n"," # Create an image data generator for real time data augmentation\n"," datagen = ImageDataGenerator(\n"," featurewise_center=False, # set input mean to 0 over the dataset\n"," samplewise_center=False, # set each sample mean to 
0\n"," featurewise_std_normalization=False, # divide inputs by std of the dataset\n"," samplewise_std_normalization=False, # divide each input by its std\n"," zca_whitening=False, # apply ZCA whitening\n"," rotation_range=0., # randomly rotate images in the range (degrees, 0 to 180)\n"," width_shift_range=0., # randomly shift images horizontally (fraction of total width)\n"," height_shift_range=0., # randomly shift images vertically (fraction of total height)\n"," zoom_range=0.,\n"," shear_range=0.,\n"," horizontal_flip=False, # randomly flip images\n"," vertical_flip=False, # randomly flip images\n"," fill_mode='constant',\n"," data_format=K.image_data_format())\n","\n"," # Fit the image generator on the training data\n"," datagen.fit(X_train_norm)\n"," \n"," # loss history recorder\n"," history = LossHistory()\n","\n"," # Inform user training begun\n"," print('-------------------------------')\n"," print('Training model...')\n","\n"," # Fit model on the batches generated by datagen.flow()\n"," train_history = model.fit_generator(datagen.flow(X_train_norm, Y_train, batch_size=batch_size), \n"," steps_per_epoch=steps_per_epoch, epochs=epochs, verbose=1, \n"," validation_data=(X_test_norm, Y_test), \n"," callbacks=[history, checkpointer, change_lr]) \n","\n"," # Inform user training ended\n"," print('-------------------------------')\n"," print('Training Complete!')\n"," \n"," # Save the last model\n"," model.save(os.path.join(modelPath, 'weights_last.hdf5'))\n","\n"," # convert the history.history dict to a pandas DataFrame: \n"," lossData = pd.DataFrame(train_history.history) \n","\n"," if os.path.exists(os.path.join(modelPath,\"Quality Control\")):\n"," shutil.rmtree(os.path.join(modelPath,\"Quality Control\"))\n","\n"," os.makedirs(os.path.join(modelPath,\"Quality Control\"))\n","\n"," # The training evaluation.csv is saved (overwrites the Files if needed). 
\n"," lossDataCSVpath = os.path.join(modelPath,\"Quality Control/training_evaluation.csv\")\n"," with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss','learning rate'])\n"," for i in range(len(train_history.history['loss'])):\n"," writer.writerow([train_history.history['loss'][i], train_history.history['val_loss'][i], train_history.history['lr'][i]])\n","\n"," return\n","\n","\n","# Normalization functions from Martin Weigert used in CARE\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, 
x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","# Multi-threaded Erf-based image construction\n","@njit(parallel=True)\n","def FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = (64,64), pixel_size = 100):\n"," w = image_size[0]\n"," h = image_size[1]\n"," erfImage = np.zeros((w, h))\n"," for ij in prange(w*h):\n"," j = int(ij/w)\n"," i = ij - j*w\n"," for (xc, yc, photon, sigma) in zip(xc_array, yc_array, photon_array, sigma_array):\n"," # Don't bother if the emitter has photons <= 0 or if Sigma <= 0\n"," if (sigma > 0) and (photon > 0):\n"," S = sigma*math.sqrt(2)\n"," x = i*pixel_size - xc\n"," y = j*pixel_size - yc\n"," # Don't bother if the emitter is further than 4 sigma from the centre of the pixel\n"," if (x+pixel_size/2)**2 + (y+pixel_size/2)**2 < 16*sigma**2:\n"," ErfX = math.erf((x+pixel_size)/S) - math.erf(x/S)\n"," ErfY = math.erf((y+pixel_size)/S) - math.erf(y/S)\n"," erfImage[j][i] += 0.25*photon*ErfX*ErfY\n"," return erfImage\n","\n","\n","@njit(parallel=True)\n","def FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = (64,64), pixel_size = 100):\n"," w = image_size[0]\n"," h = image_size[1]\n"," locImage = np.zeros((image_size[0],image_size[1]) )\n"," n_locs = len(xc_array)\n","\n"," for e in prange(n_locs):\n"," locImage[int(max(min(round(yc_array[e]/pixel_size),w-1),0))][int(max(min(round(xc_array[e]/pixel_size),h-1),0))] += 1\n","\n"," return locImage\n","\n","\n","\n","def getPixelSizeTIFFmetadata(TIFFpath, display=False):\n"," with Image.open(TIFFpath) as img:\n"," meta_dict = {TAGS[key] : img.tag[key] for key in img.tag.keys()}\n","\n","\n"," # TIFF 
tags\n"," # https://www.loc.gov/preservation/digital/formats/content/tiff_tags.shtml\n"," # https://www.awaresystems.be/imaging/tiff/tifftags/resolutionunit.html\n"," ResolutionUnit = meta_dict['ResolutionUnit'][0] # unit of resolution\n"," width = meta_dict['ImageWidth'][0]\n"," height = meta_dict['ImageLength'][0]\n","\n"," xResolution = meta_dict['XResolution'][0] # number of pixels / ResolutionUnit\n","\n"," if len(xResolution) == 1:\n"," xResolution = xResolution[0]\n"," elif len(xResolution) == 2:\n"," xResolution = xResolution[0]/xResolution[1]\n"," else:\n"," print('Image resolution not defined.')\n"," xResolution = 1\n","\n"," if ResolutionUnit == 2:\n"," # Units given are in inches\n"," pixel_size = 0.025*1e9/xResolution\n"," elif ResolutionUnit == 3:\n"," # Units given are in cm\n"," pixel_size = 0.01*1e9/xResolution\n"," else: \n"," # ResolutionUnit is therefore 1\n"," print('Resolution unit not defined. Assuming: um')\n"," pixel_size = 1e3/xResolution\n","\n"," if display:\n"," print('Pixel size obtained from metadata: '+str(pixel_size)+' nm')\n"," print('Image size: '+str(width)+'x'+str(height))\n"," \n"," return (pixel_size, width, height)\n","\n","\n","def saveAsTIF(path, filename, array, pixel_size):\n"," \"\"\"\n"," Image saving using PIL to save as .tif format\n"," # Input \n"," path - path where it will be saved\n"," filename - name of the file to save (no extension)\n"," array - numpy array conatining the data at the required format\n"," pixel_size - physical size of pixels in nanometers (identical for x and y)\n"," \"\"\"\n","\n"," # print('Data type: '+str(array.dtype))\n"," if (array.dtype == np.uint16):\n"," mode = 'I;16'\n"," elif (array.dtype == np.uint32):\n"," mode = 'I'\n"," else:\n"," mode = 'F'\n","\n"," # Rounding the pixel size to the nearest number that divides exactly 1cm.\n"," # Resolution needs to be a rational number --> see TIFF format\n"," # pixel_size = 10000/(round(10000/pixel_size))\n","\n"," if len(array.shape) == 
2:\n"," im = Image.fromarray(array)\n"," im.save(os.path.join(path, filename+'.tif'),\n"," mode = mode, \n"," resolution_unit = 3,\n"," resolution = 0.01*1e9/pixel_size)\n","\n","\n"," elif len(array.shape) == 3:\n"," imlist = []\n"," for frame in array:\n"," imlist.append(Image.fromarray(frame))\n","\n"," imlist[0].save(os.path.join(path, filename+'.tif'), save_all=True,\n"," append_images=imlist[1:],\n"," mode = mode, \n"," resolution_unit = 3,\n"," resolution = 0.01*1e9/pixel_size)\n","\n"," return\n","\n","\n","\n","\n","class Maximafinder(Layer):\n"," def __init__(self, thresh, neighborhood_size, use_local_avg, **kwargs):\n"," super(Maximafinder, self).__init__(**kwargs)\n"," self.thresh = tf.constant(thresh, dtype=tf.float32)\n"," self.nhood = neighborhood_size\n"," self.use_local_avg = use_local_avg\n","\n"," def build(self, input_shape):\n"," if self.use_local_avg is True:\n"," self.kernel_x = tf.reshape(tf.constant([[-1,0,1],[-1,0,1],[-1,0,1]], dtype=tf.float32), [3, 3, 1, 1])\n"," self.kernel_y = tf.reshape(tf.constant([[-1,-1,-1],[0,0,0],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])\n"," self.kernel_sum = tf.reshape(tf.constant([[1,1,1],[1,1,1],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])\n","\n"," def call(self, inputs):\n","\n"," # local maxima positions\n"," max_pool_image = MaxPooling2D(pool_size=(self.nhood,self.nhood), strides=(1,1), padding='same')(inputs)\n"," cond = tf.math.greater(max_pool_image, self.thresh) & tf.math.equal(max_pool_image, inputs)\n"," indices = tf.where(cond)\n"," bind, xind, yind = indices[:, 0], indices[:, 2], indices[:, 1]\n"," confidence = tf.gather_nd(inputs, indices)\n","\n"," # local CoG estimator\n"," if self.use_local_avg:\n"," x_image = K.conv2d(inputs, self.kernel_x, padding='same')\n"," y_image = K.conv2d(inputs, self.kernel_y, padding='same')\n"," sum_image = K.conv2d(inputs, self.kernel_sum, padding='same')\n"," confidence = tf.cast(tf.gather_nd(sum_image, indices), dtype=tf.float32)\n"," x_local = 
tf.math.divide(tf.gather_nd(x_image, indices),tf.gather_nd(sum_image, indices))\n"," y_local = tf.math.divide(tf.gather_nd(y_image, indices),tf.gather_nd(sum_image, indices))\n"," xind = tf.cast(xind, dtype=tf.float32) + tf.cast(x_local, dtype=tf.float32)\n"," yind = tf.cast(yind, dtype=tf.float32) + tf.cast(y_local, dtype=tf.float32)\n"," else:\n"," xind = tf.cast(xind, dtype=tf.float32)\n"," yind = tf.cast(yind, dtype=tf.float32)\n"," \n"," return bind, xind, yind, confidence\n","\n"," def get_config(self):\n","\n"," # Implement get_config to enable serialization. This is optional.\n"," base_config = super(Maximafinder, self).get_config()\n"," config = {}\n"," return dict(list(base_config.items()) + list(config.items()))\n","\n","\n","\n","# ------------------------------- Prediction with postprocessing function-------------------------------\n","def batchFramePredictionLocalization(dataPath, filename, modelPath, savePath, batch_size=1, thresh=0.1, neighborhood_size=3, use_local_avg = False, pixel_size = None):\n"," \"\"\"\n"," This function tests a trained model on the desired test set, given the \n"," tiff stack of test images, learned weights, and normalization factors.\n"," \n"," # Inputs\n"," dataPath - the path to the folder containing the tiff stack(s) to run prediction on \n"," filename - the name of the file to process\n"," modelPath - the path to the folder containing the weights file and the mean and standard deviation file generated in train_model\n"," savePath - the path to the folder where to save the prediction\n"," batch_size. 
- the number of frames to predict on for each iteration\n"," thresh - threshoold percentage from the maximum of the gaussian scaling\n"," neighborhood_size - the size of the neighborhood for local maxima finding\n"," use_local_average - Boolean whether to perform local averaging or not\n"," \"\"\"\n"," \n"," # load mean and std\n"," matfile = sio.loadmat(os.path.join(modelPath,'model_metadata.mat'))\n"," test_mean = np.array(matfile['mean_test'])\n"," test_std = np.array(matfile['std_test']) \n"," upsampling_factor = np.array(matfile['upsampling_factor'])\n"," upsampling_factor = upsampling_factor.item() # convert to scalar\n"," L2_weighting_factor = np.array(matfile['Normalization factor'])\n"," L2_weighting_factor = L2_weighting_factor.item() # convert to scalar\n","\n"," # Read in the raw file\n"," Images = io.imread(os.path.join(dataPath, filename))\n"," if pixel_size == None:\n"," pixel_size, _, _ = getPixelSizeTIFFmetadata(os.path.join(dataPath, filename), display=True)\n"," pixel_size_hr = pixel_size/upsampling_factor\n","\n"," # get dataset dimensions\n"," (nFrames, M, N) = Images.shape\n"," print('Input image is '+str(N)+'x'+str(M)+' with '+str(nFrames)+' frames.')\n","\n"," # Build the model for a bigger image\n"," model = buildModel((upsampling_factor*M, upsampling_factor*N, 1))\n","\n"," # Load the trained weights\n"," model.load_weights(os.path.join(modelPath,'weights_best.hdf5'))\n","\n"," # add a post-processing module\n"," max_layer = Maximafinder(thresh*L2_weighting_factor, neighborhood_size, use_local_avg)\n","\n"," # Initialise the results: lists will be used to collect all the localizations\n"," frame_number_list, x_nm_list, y_nm_list, confidence_au_list = [], [], [], []\n","\n"," # Initialise the results\n"," Prediction = np.zeros((M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n"," Widefield = np.zeros((M, N), dtype=np.float32)\n","\n"," # run model in batches\n"," n_batches = math.ceil(nFrames/batch_size)\n"," for b in 
tqdm(range(n_batches)):\n","\n"," nF = min(batch_size, nFrames - b*batch_size)\n"," Images_norm = np.zeros((nF, M, N),dtype=np.float32)\n"," Images_upsampled = np.zeros((nF, M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n","\n"," # Upsampling using a simple nearest neighbor interp and calculating - MULTI-THREAD this?\n"," for f in range(nF):\n"," Images_norm[f,:,:] = project_01(Images[b*batch_size+f,:,:])\n"," Images_norm[f,:,:] = normalize_im(Images_norm[f,:,:], test_mean, test_std)\n"," Images_upsampled[f,:,:] = np.kron(Images_norm[f,:,:], np.ones((upsampling_factor,upsampling_factor)))\n"," Widefield += Images[b*batch_size+f,:,:]\n","\n"," # Reshaping\n"," Images_upsampled = np.expand_dims(Images_upsampled,axis=3)\n","\n"," # Run prediction and local amxima finding\n"," predicted_density = model.predict_on_batch(Images_upsampled)\n"," predicted_density[predicted_density < 0] = 0\n"," Prediction += predicted_density.sum(axis = 3).sum(axis = 0)\n","\n"," bind, xind, yind, confidence = max_layer(predicted_density)\n"," \n"," # normalizing the confidence by the L2_weighting_factor\n"," confidence /= L2_weighting_factor \n","\n"," # turn indices to nms and append to the results\n"," xind, yind = xind*pixel_size_hr, yind*pixel_size_hr\n"," frmind = (bind.numpy() + b*batch_size + 1).tolist()\n"," xind = xind.numpy().tolist()\n"," yind = yind.numpy().tolist()\n"," confidence = confidence.numpy().tolist()\n"," frame_number_list += frmind\n"," x_nm_list += xind\n"," y_nm_list += yind\n"," confidence_au_list += confidence\n","\n"," # Open and create the csv file that will contain all the localizations\n"," if use_local_avg:\n"," ext = '_avg'\n"," else:\n"," ext = '_max'\n"," with open(os.path.join(savePath, 'Localizations_' + os.path.splitext(filename)[0] + ext + '.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow(['frame', 'x [nm]', 'y [nm]', 'confidence [a.u]'])\n"," locs = list(zip(frame_number_list, x_nm_list, 
y_nm_list, confidence_au_list))\n"," writer.writerows(locs)\n","\n"," # Save the prediction and widefield image\n"," Widefield = np.kron(Widefield, np.ones((upsampling_factor,upsampling_factor)))\n"," Widefield = np.float32(Widefield)\n","\n"," # io.imsave(os.path.join(savePath, 'Predicted_'+os.path.splitext(filename)[0]+'.tif'), Prediction)\n"," # io.imsave(os.path.join(savePath, 'Widefield_'+os.path.splitext(filename)[0]+'.tif'), Widefield)\n","\n"," saveAsTIF(savePath, 'Predicted_'+os.path.splitext(filename)[0], Prediction, pixel_size_hr)\n"," saveAsTIF(savePath, 'Widefield_'+os.path.splitext(filename)[0], Widefield, pixel_size_hr)\n","\n","\n"," return\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m' # white (normal)\n","\n","\n","\n","def list_files(directory, extension):\n"," return (f for f in os.listdir(directory) if f.endswith('.' + extension))\n","\n","\n","# @njit(parallel=True)\n","def subPixelMaxLocalization(array, method = 'CoM', patch_size = 3):\n"," xMaxInd, yMaxInd = np.unravel_index(array.argmax(), array.shape, order='C')\n"," centralPatch = XC[(xMaxInd-patch_size):(xMaxInd+patch_size+1),(yMaxInd-patch_size):(yMaxInd+patch_size+1)]\n","\n"," if (method == 'MAX'):\n"," x0 = xMaxInd\n"," y0 = yMaxInd\n","\n"," elif (method == 'CoM'):\n"," x0 = 0\n"," y0 = 0\n"," S = 0\n"," for xy in range(patch_size*patch_size):\n"," y = math.floor(xy/patch_size)\n"," x = xy - y*patch_size\n"," x0 += x*array[x,y]\n"," y0 += y*array[x,y]\n"," S = array[x,y]\n"," \n"," x0 = x0/S - patch_size/2 + xMaxInd\n"," y0 = y0/S - patch_size/2 + yMaxInd\n"," \n"," elif (method == 'Radiality'):\n"," # Not implemented yet\n"," x0 = xMaxInd\n"," y0 = yMaxInd\n"," \n"," return (x0, y0)\n","\n","\n","@njit(parallel=True)\n","def correctDriftLocalization(xc_array, yc_array, frames, xDrift, yDrift):\n"," n_locs = xc_array.shape[0]\n"," xc_array_Corr = np.empty(n_locs)\n"," yc_array_Corr = np.empty(n_locs)\n"," \n"," 
for loc in prange(n_locs):\n"," xc_array_Corr[loc] = xc_array[loc] - xDrift[frames[loc]]\n"," yc_array_Corr[loc] = yc_array[loc] - yDrift[frames[loc]]\n","\n"," return (xc_array_Corr, yc_array_Corr)\n","\n","\n","print('--------------------------------')\n","print('DeepSTORM installation complete.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vu8f5NGJkJos","colab_type":"text"},"source":["\n","# **3. Generate patches for training**\n","---\n","\n","For Deep-STORM the training data can be obtained in two ways:\n","* Simulated using ThunderSTORM or other simulation tool and loaded here (**using Section 3.1.a**)\n","* Directly simulated in this notebook (**using Section 3.1.b**)\n"]},{"cell_type":"markdown","metadata":{"id":"WSV8xnlynp0l","colab_type":"text"},"source":["## **3.1.a Load training data**\n","---\n","\n","Here you can load your simulated data along with its corresponding localization file.\n","* The `pixel_size` is defined in nanometer (nm). 
"]},{"cell_type":"code","metadata":{"id":"CT6SNcfNg6j0","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Load raw data\n","\n","# Get user input\n","ImageData_path = \"\" #@param {type:\"string\"}\n","LocalizationData_path = \"\" #@param {type: \"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value:\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","if get_pixel_size_from_file:\n"," pixel_size,_,_ = getPixelSizeTIFFmetadata(ImageData_path, True)\n","\n","# load the tiff data\n","Images = io.imread(ImageData_path)\n","# get dataset dimensions\n","if len(Images.shape) == 3:\n"," (number_of_frames, M, N) = Images.shape\n","elif len(Images.shape) == 2:\n"," (M, N) = Images.shape\n"," number_of_frames = 1\n","print('Loaded images: '+str(M)+'x'+str(N)+' with '+str(number_of_frames)+' frames')\n","\n","# Interactive display of the stack\n","def scroll_in_time(frame):\n"," f=plt.figure(figsize=(6,6))\n"," plt.imshow(Images[frame-1], interpolation='nearest', cmap = 'gray')\n"," plt.title('Training source at frame = ' + str(frame))\n"," plt.axis('off');\n","\n","if number_of_frames > 1:\n"," interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","else:\n"," f=plt.figure(figsize=(6,6))\n"," plt.imshow(Images, interpolation='nearest', cmap = 'gray')\n"," plt.title('Training source')\n"," plt.axis('off');\n","\n","# Load the localization file and display the first\n","LocData = pd.read_csv(LocalizationData_path, index_col=0)\n","LocData.tail()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K9xE5GeYiks9","colab_type":"text"},"source":["## **3.1.b Simulate training data**\n","---\n","This simulation tool allows you to generate SMLM data of randomly distrubuted emitters in a field-of-view. 
\n","The assumptions are as follows:\n","\n","* Gaussian Point Spread Function (PSF) with standard deviation defined by `sigma`. The nominal value of `sigma` can be evaluated using `sigma = 0.21 x Lambda / NA`, where `Lambda` is the emission wavelength and `NA` is the numerical aperture of the objective. \n","* Each emitter emits `n_photons` per frame, subject to the corresponding Poisson noise.\n","* The camera contributes Gaussian noise to the signal with a standard deviation defined by `ReadOutNoise_ADC`, in ADC units.\n","* The `emitter_density` is defined as the number of emitters / um^2 on any given frame. Variability in the emitter density can be applied by adjusting `emitter_density_std`. The latter parameter represents the standard deviation of the normal distribution that the density is drawn from for each individual frame. `emitter_density` **is defined in number of emitters / um^2**.\n","* The `n_photons` and `sigma` can additionally include some Gaussian variability by setting `n_photons_std` and `sigma_std`.\n","\n","Important note:\n","- All dimensions are in nanometers (e.g. 
`FOV_size` = 6400 represents a field of view of 6.4 um x 6.4 um).\n","\n"]},{"cell_type":"code","metadata":{"id":"sQyLXpEhitsg","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# ---------------------------- User input ----------------------------\n","#@markdown Run the simulation\n","#@markdown --- \n","#@markdown Camera settings: \n","FOV_size = 6400#@param {type:\"number\"}\n","pixel_size = 100#@param {type:\"number\"}\n","ADC_per_photon_conversion = 1 #@param {type:\"number\"}\n","ReadOutNoise_ADC = 4.5#@param {type:\"number\"}\n","ADC_offset = 50#@param {type:\"number\"}\n","\n","#@markdown Acquisition settings: \n","emitter_density = 6#@param {type:\"number\"}\n","emitter_density_std = 0#@param {type:\"number\"}\n","\n","number_of_frames = 20#@param {type:\"integer\"}\n","\n","sigma = 110 #@param {type:\"number\"}\n","sigma_std = 5 #@param {type:\"number\"}\n","# NA = 1.1 #@param {type:\"number\"}\n","# wavelength = 800#@param {type:\"number\"}\n","# wavelength_std = 150#@param {type:\"number\"}\n","n_photons = 2250#@param {type:\"number\"}\n","n_photons_std = 250#@param {type:\"number\"}\n","\n","\n","# ---------------------------- Variable initialisation ----------------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","print('-----------------------------------------------------------')\n","n_molecules = emitter_density*FOV_size*FOV_size/10**6\n","n_molecules_std = emitter_density_std*FOV_size*FOV_size/10**6\n","print('Number of molecules / FOV: '+str(round(n_molecules,2))+' +/- '+str((round(n_molecules_std,2))))\n","\n","# sigma = 0.21*wavelength/NA\n","# sigma_std = 0.21*wavelength_std/NA\n","# print('Gaussian PSF sigma: '+str(round(sigma,2))+' +/- '+str(round(sigma_std,2))+' nm')\n","\n","M = N = round(FOV_size/pixel_size)\n","FOV_size = M*pixel_size\n","print('Final image size: '+str(M)+'x'+str(M)+' ('+str(round(FOV_size/1000, 3))+'um x'+str(round(FOV_size/1000,3))+' 
um)')\n","\n","np.random.seed(1)\n","display_upsampling = 8 # used to display the loc map here\n","NoiseFreeImages = np.zeros((number_of_frames, M, M))\n","locImage = np.zeros((number_of_frames, display_upsampling*M, display_upsampling*N))\n","\n","frames = []\n","all_xloc = []\n","all_yloc = []\n","all_photons = []\n","all_sigmas = []\n","\n","# ---------------------------- Main simulation loop ----------------------------\n","print('-----------------------------------------------------------')\n","for f in tqdm(range(number_of_frames)):\n"," \n"," # Define the coordinates of emitters by randomly distributing them across the FOV\n"," n_mol = int(max(round(np.random.normal(n_molecules, n_molecules_std, size=1)[0]), 0))\n"," x_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n"," y_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n"," photon_array = np.random.normal(n_photons, n_photons_std, size=n_mol)\n"," sigma_array = np.random.normal(sigma, sigma_std, size=n_mol)\n"," # x_c = np.linspace(0,3000,5)\n"," # y_c = np.linspace(0,3000,5)\n","\n"," all_xloc += x_c.tolist()\n"," all_yloc += y_c.tolist()\n"," frames += ((f+1)*np.ones(x_c.shape[0])).tolist()\n"," all_photons += photon_array.tolist()\n"," all_sigmas += sigma_array.tolist()\n","\n"," locImage[f] = FromLoc2Image_SimpleHistogram(x_c, y_c, image_size = (N*display_upsampling, M*display_upsampling), pixel_size = pixel_size/display_upsampling)\n","\n"," # # Get the approximated locations according to the grid pixel size\n"," # Chr_emitters = [int(max(min(round(display_upsampling*x_c[i]/pixel_size),N*display_upsampling-1),0)) for i in range(len(x_c))]\n"," # Rhr_emitters = [int(max(min(round(display_upsampling*y_c[i]/pixel_size),M*display_upsampling-1),0)) for i in range(len(y_c))]\n","\n"," # # Build Localization image\n"," # for (r,c) in zip(Rhr_emitters, Chr_emitters):\n"," # locImage[f][r][c] += 1\n","\n"," NoiseFreeImages[f] = FromLoc2Image_Erf(x_c, y_c, photon_array, sigma_array, 
image_size = (M,M), pixel_size = pixel_size)\n","\n","\n","# ---------------------------- Create DataFrame for localization file ----------------------------\n","# Table with localization info as dataframe output\n","LocData = pd.DataFrame()\n","LocData[\"frame\"] = frames\n","LocData[\"x [nm]\"] = all_xloc\n","LocData[\"y [nm]\"] = all_yloc\n","LocData[\"Photon #\"] = all_photons\n","LocData[\"Sigma [nm]\"] = all_sigmas\n","LocData.index += 1 # set indices to start at 1 and not 0 (same as ThunderSTORM)\n","\n","\n","# ---------------------------- Estimation of SNR ----------------------------\n","n_frames_for_SNR = 100\n","M_SNR = 10\n","x_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n","y_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n","photon_array = np.random.normal(n_photons, n_photons_std, size=n_frames_for_SNR)\n","sigma_array = np.random.normal(sigma, sigma_std, size=n_frames_for_SNR)\n","\n","SNR = np.zeros(n_frames_for_SNR)\n","for i in range(n_frames_for_SNR):\n"," SingleEmitterImage = FromLoc2Image_Erf(np.array([x_c[i]]), np.array([y_c[i]]), np.array([photon_array[i]]), np.array([sigma_array[i]]), (M_SNR, M_SNR), pixel_size)\n"," Signal_photon = np.max(SingleEmitterImage)\n"," Noise_photon = math.sqrt((ReadOutNoise_ADC/ADC_per_photon_conversion)**2 + Signal_photon)\n"," SNR[i] = Signal_photon/Noise_photon\n","\n","print('SNR: '+str(round(np.mean(SNR),2))+' +/- '+str(round(np.std(SNR),2)))\n","# ---------------------------- ----------------------------\n","\n","\n","# Table with info\n","simParameters = pd.DataFrame()\n","simParameters[\"FOV size (nm)\"] = [FOV_size]\n","simParameters[\"Pixel size (nm)\"] = [pixel_size]\n","simParameters[\"ADC/photon\"] = [ADC_per_photon_conversion]\n","simParameters[\"Read-out noise (ADC)\"] = [ReadOutNoise_ADC]\n","simParameters[\"Constant offset (ADC)\"] = [ADC_offset]\n","\n","simParameters[\"Emitter density (emitters/um^2)\"] = 
[emitter_density]\n","simParameters[\"STD of emitter density (emitters/um^2)\"] = [emitter_density_std]\n","simParameters[\"Number of frames\"] = [number_of_frames]\n","# simParameters[\"NA\"] = [NA]\n","# simParameters[\"Wavelength (nm)\"] = [wavelength]\n","# simParameters[\"STD of wavelength (nm)\"] = [wavelength_std]\n","simParameters[\"Sigma (nm))\"] = [sigma]\n","simParameters[\"STD of Sigma (nm))\"] = [sigma_std]\n","simParameters[\"Number of photons\"] = [n_photons]\n","simParameters[\"STD of number of photons\"] = [n_photons_std]\n","simParameters[\"SNR\"] = [np.mean(SNR)]\n","simParameters[\"STD of SNR\"] = [np.std(SNR)]\n","\n","\n","# ---------------------------- Finish simulation ----------------------------\n","# Calculating the noisy image\n","Images = ADC_per_photon_conversion * np.random.poisson(NoiseFreeImages) + ReadOutNoise_ADC * np.random.normal(size = (number_of_frames, M, N)) + ADC_offset\n","Images[Images <= 0] = 0\n","\n","# Convert to 16-bit or 32-bits integers\n","if Images.max() < (2**16-1):\n"," Images = Images.astype(np.uint16)\n","else:\n"," Images = Images.astype(np.uint32)\n","\n","\n","# ---------------------------- Display ----------------------------\n","# Displaying the time elapsed for simulation\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds,1),\"sec(s)\")\n","\n","\n","# Interactively display the results using Widgets\n","def scroll_in_time(frame):\n"," f = plt.figure(figsize=(18,6))\n"," plt.subplot(1,3,1)\n"," plt.imshow(locImage[frame-1], interpolation='bilinear', vmin = 0, vmax=0.1)\n"," plt.title('Localization image')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(NoiseFreeImages[frame-1], interpolation='nearest', cmap='gray')\n"," plt.title('Noise-free simulation')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(Images[frame-1], 
interpolation='nearest', cmap='gray')\n"," plt.title('Noisy simulation')\n"," plt.axis('off');\n","\n","interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","\n","# Display the head of the dataframe with localizations\n","LocData.tail()\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Pz7RfSuoeJeq","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ---\n","# @markdown #Play this cell to save the simulated stack\n","# @markdown ####Please select a path to the folder where to save the simulated data. It is not necesary to save the data to run the training, but keeping the simulated for your own record can be useful to check its validity.\n","Save_path = \"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(Save_path):\n"," os.makedirs(Save_path)\n"," print('Folder created.')\n","else:\n"," print('Training data already exists in folder: Data overwritten.')\n","\n","saveAsTIF(Save_path, 'SimulatedDataset', Images, pixel_size)\n","# io.imsave(os.path.join(Save_path, 'SimulatedDataset.tif'),Images)\n","LocData.to_csv(os.path.join(Save_path, 'SimulatedDataset.csv'))\n","simParameters.to_csv(os.path.join(Save_path, 'SimulatedParameters.csv'))\n","print('Training dataset saved.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K_8e3kE-JhVY","colab_type":"text"},"source":["## **3.2. Generate training patches**\n","---\n","\n","Training patches need to be created from the training data generated above. \n","* The `patch_size` needs to give sufficient contextual information and for most cases a `patch_size` of 26 (corresponding to patches of 26x26 pixels) works fine. **DEFAULT: 26**\n","* The `upsampling_factor` defines the effective magnification of the final super-resolved image compared to the input image (this is called magnification in ThunderSTORM). 
This is used to generate the super-resolved patches as the target dataset. Using an `upsampling_factor` of 16 will require more memory, and it may be necessary to decrease the `patch_size` to 16, for example. **DEFAULT: 8**\n","* The `num_patches_per_frame` defines the number of patches extracted from each frame generated in section 3.1. **DEFAULT: 500**\n","* The `min_number_of_emitters_per_patch` defines the minimum number of emitters that need to be present in the patch to be a valid patch. An empty patch does not contain useful information for the network to learn from. **DEFAULT: 7**\n","* The `max_num_patches` defines the maximum number of patches to generate. Fewer may be generated depending on how many patches are rejected and how many frames are available. **DEFAULT: 10000**\n","* The `gaussian_sigma` defines the Gaussian standard deviation (in magnified pixels) applied to generate the super-resolved target image. **DEFAULT: 1**\n","* The `L2_weighting_factor` is a normalization factor used in the loss function. It helps balance the loss from the L2 norm. When using higher densities, this factor should be decreased and vice versa. This factor can be automatically calculated using an empirical formula. 
**DEFAULT: 100**\n","\n"]},{"cell_type":"code","metadata":{"id":"AsNx5KzcFNvC","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ## **Provide patch parameters**\n","\n","\n","# -------------------- User input --------------------\n","patch_size = 26 #@param {type:\"integer\"}\n","upsampling_factor = 8 #@param [\"4\", \"8\", \"16\"] {type:\"raw\"}\n","num_patches_per_frame = 500#@param {type:\"integer\"}\n","min_number_of_emitters_per_patch = 7#@param {type:\"integer\"}\n","max_num_patches = 10000#@param {type:\"integer\"}\n","gaussian_sigma = 1#@param {type:\"integer\"}\n","\n","#@markdown Estimate the optimal normalization factor automatically?\n","Automatic_normalization = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, it will use the following value:\n","L2_weighting_factor = 100 #@param {type:\"number\"}\n","\n","\n","# -------------------- Prepare variables --------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# Initialize some parameters\n","pixel_size_hr = pixel_size/upsampling_factor # in nm\n","n_patches = min(number_of_frames*num_patches_per_frame, max_num_patches)\n","patch_size = patch_size*upsampling_factor\n","\n","# Dimensions of the high-res grid\n","Mhr = upsampling_factor*M # in pixels\n","Nhr = upsampling_factor*N # in pixels\n","\n","# Initialize the training patches and labels\n","patches = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","spikes = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","heatmaps = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","\n","# Run over all frames and construct the training examples\n","k = 1 # current patch count\n","skip_counter = 0 # number of dataset skipped due to low density\n","id_start = 0 # id position in LocData for current frame\n","print('Generating '+str(n_patches)+' patches of '+str(patch_size)+'x'+str(patch_size))\n","\n","n_locs = 
len(LocData.index)\n","print('Total number of localizations: '+str(n_locs))\n","density = n_locs/(M*N*number_of_frames*(0.001*pixel_size)**2)\n","print('Density: '+str(round(density,2))+' locs/um^2')\n","n_locs_per_patch = patch_size**2*density\n","\n","if Automatic_normalization:\n"," # This empirical formulae attempts to balance the loss L2 function between the background and the bright spikes\n"," # A value of 100 was originally chosen to balance L2 for a patch size of 2.6x2.6^2 0.1um pixel size and density of 3 (hence the 20.28), at upsampling_factor = 8\n"," L2_weighting_factor = 100/math.sqrt(min(n_locs_per_patch, min_number_of_emitters_per_patch)*8**2/(upsampling_factor**2*20.28))\n"," print('Normalization factor: '+str(round(L2_weighting_factor,2)))\n","\n","# -------------------- Patch generation loop --------------------\n","\n","print('-----------------------------------------------------------')\n","for (f, thisFrame) in enumerate(tqdm(Images)):\n","\n"," # Upsample the frame\n"," upsampledFrame = np.kron(thisFrame, np.ones((upsampling_factor,upsampling_factor)))\n"," # Read all the provided high-resolution locations for current frame\n"," DataFrame = LocData[LocData['frame'] == f+1].copy()\n","\n"," # Get the approximated locations according to the high-res grid pixel size\n"," Chr_emitters = [int(max(min(round(DataFrame['x [nm]'][i]/pixel_size_hr),Nhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," Rhr_emitters = [int(max(min(round(DataFrame['y [nm]'][i]/pixel_size_hr),Mhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," id_start += len(DataFrame.index)\n","\n"," # Build Localization image\n"," LocImage = np.zeros((Mhr,Nhr))\n"," LocImage[(Rhr_emitters, Chr_emitters)] = 1\n","\n"," # Here, there's a choice between the original Gaussian (classification approach) and using the erf function\n"," HeatMapImage = L2_weighting_factor*gaussian_filter(LocImage, gaussian_sigma) \n"," # HeatMapImage = 
L2_weighting_factor*FromLoc2Image_MultiThreaded(np.array(list(DataFrame['x [nm]'])), np.array(list(DataFrame['y [nm]'])), \n"," # np.ones(len(DataFrame.index)), pixel_size_hr*gaussian_sigma*np.ones(len(DataFrame.index)), \n"," # Mhr, pixel_size_hr)\n"," \n","\n"," # Generate random position for the top left corner of the patch\n"," xc = np.random.randint(0, Mhr-patch_size, size=num_patches_per_frame)\n"," yc = np.random.randint(0, Nhr-patch_size, size=num_patches_per_frame)\n","\n"," for c in range(len(xc)):\n"," if LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size].sum() < min_number_of_emitters_per_patch:\n"," skip_counter += 1\n"," continue\n"," \n"," else:\n"," # Limit maximal number of training examples to 15k\n"," if k > max_num_patches:\n"," break\n"," else:\n"," # Assign the patches to the right part of the images\n"," patches[k-1] = upsampledFrame[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," spikes[k-1] = LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," heatmaps[k-1] = HeatMapImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," k += 1 # increment current patch count\n","\n","# Remove the empty data\n","patches = patches[:k-1]\n","spikes = spikes[:k-1]\n","heatmaps = heatmaps[:k-1]\n","n_patches = k-1\n","\n","# -------------------- Failsafe --------------------\n","# Check if the size of the training set is smaller than 5k to notify user to simulate more images using ThunderSTORM\n","if ((k-1) < 5000):\n"," # W = '\\033[0m' # white (normal)\n"," # R = '\\033[31m' # red\n"," print(bcolors.WARNING+'!! WARNING: Training set size is below 5K - Consider simulating more images in ThunderSTORM. 
!!'+bcolors.NORMAL)\n","\n","\n","\n","# -------------------- Displays --------------------\n","print('Number of patches skipped due to low density: '+str(skip_counter))\n","# dataSize = int((getsizeof(patches)+getsizeof(heatmaps)+getsizeof(spikes))/(1024*1024)) #rounded in MB\n","# print('Size of patches: '+str(dataSize)+' MB')\n","print(str(n_patches)+' patches were generated.')\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# Display patches interactively with a slider\n","def scroll_patches(patch):\n"," f = plt.figure(figsize=(16,6))\n"," plt.subplot(1,3,1)\n"," plt.imshow(patches[patch-1], interpolation='nearest', cmap='gray')\n"," plt.title('Raw data (frame #'+str(patch)+')')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(heatmaps[patch-1], interpolation='nearest')\n"," plt.title('Heat map')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(spikes[patch-1], interpolation='nearest')\n"," plt.title('Localization map')\n"," plt.axis('off');\n","\n","interact(scroll_patches, patch=widgets.IntSlider(min=1, max=patches.shape[0], step=1, value=0, continuous_update=False));\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DSjXFMevK7Iz","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"hVeyKU0MdAPx","colab_type":"text"},"source":["## **4.1. Select your paths and parameters**\n","\n","---\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. 
Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained for. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for ~100 epochs. Evaluate the performance after training (see section 5). **Default value: 80**\n","\n","**`batch_size`:** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps per epoch. **If this value is set to 0**, this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patches / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 30** \n","\n","**`initial_learning_rate`:** This parameter represents the initial value to be used as learning rate in the optimizer. 
**Default value: 0.001**"]},{"cell_type":"code","metadata":{"id":"oa5cDZ7f_PF6","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###Path to training images and parameters\n","\n","model_path = \"\" #@param {type: \"string\"} \n","model_name = \"\" #@param {type: \"string\"} \n","number_of_epochs = 80#@param {type:\"integer\"}\n","batch_size = 16#@param {type:\"integer\"}\n","\n","number_of_steps = 0#@param {type:\"integer\"}\n","percentage_validation = 30 #@param {type:\"number\"}\n","initial_learning_rate = 0.001 #@param {type:\"number\"}\n","\n","\n","percentage_validation /= 100\n","if number_of_steps == 0: \n"," number_of_steps = int((1-percentage_validation)*n_patches/batch_size)\n"," print('Number of steps: '+str(number_of_steps))\n","\n","# Pretrained model path initialised here so next cell does not need to be run\n","h5_file_path = ''\n","Use_pretrained_model = False\n","\n","if not ('patches' in locals()):\n"," # W = '\\033[0m' # white (normal)\n"," # R = '\\033[31m' # red\n"," print(bcolors.WARNING+'!! WARNING: No patches were found in memory currently. !!'+bcolors.NORMAL)\n","\n","Save_path = os.path.join(model_path, model_name)\n","if os.path.exists(Save_path):\n"," print(bcolors.WARNING+'The model folder already exists and will be overwritten.'+bcolors.NORMAL)\n","\n","print('-----------------------------')\n","print('Training parameters set.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"WIyEvQBWLp9n","colab_type":"text"},"source":["\n","## **4.2. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a Deep-STORM 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. 
**You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"oHL5g0w8LqR0","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," 
wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.hdf5 pretrained model does not exist'+bcolors.NORMAL)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used 
instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead.'+bcolors.NORMAL)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+bcolors.NORMAL)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print('No pretrained network will be used.')\n"," h5_file_path = ''\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"OADNcie-LHxA","colab_type":"text"},"source":["## **4.2. Train the network**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! 
If training takes longer than 12 hours, please decrease the number of epochs or number of patches."]},{"cell_type":"code","metadata":{"id":"qDgMu_mAK8US","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start training\n","\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(Save_path):\n"," shutil.rmtree(Save_path)\n","\n","# Create the model folder!\n","os.makedirs(Save_path)\n","\n","# Let's go !\n","train_model(patches, heatmaps, Save_path, \n"," steps_per_epoch=number_of_steps, epochs=number_of_epochs, batch_size=batch_size,\n"," upsampling_factor = upsampling_factor,\n"," validation_split = percentage_validation,\n"," initial_learning_rate = initial_learning_rate, \n"," pretrained_model_path = h5_file_path,\n"," L2_weighting_factor = L2_weighting_factor)\n","\n","# # Show info about the GPU memory useage\n","# !nvidia-smi\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"CHVTRjEOLRDH","colab_type":"text"},"source":["##**4.3. 
Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"4N7-ShZpLhwr","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"JDRsm7uKoBa-","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","#@markdown #####During training, the model files are automatically saved inside a folder named after the parameter `model_name` (see section 4.1). Provide the name of this folder as `QC_model_path` . \n","\n","QC_model_path = \"\" #@param {type:\"string\"}\n","\n","if (Use_the_current_trained_model): \n"," QC_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(QC_model_path):\n"," print(\"The \"+os.path.basename(QC_model_path)+\" model will be evaluated\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Gw7KaHZUoHC4","colab_type":"text"},"source":["## **5.1. 
Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"qUc-JMOcoGNZ","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(QC_model_path,'Quality Control/training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(QC_model_path,'Quality Control/lossCurvePlots.png'))\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"32eNQjFioQkY","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps and calculate total SSIM, NRMSE and PSNR metrics for all the images provided in the \"QC_image_folder\", using the corresponding localization data contained in \"QC_loc_folder\".\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. 
The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized prediction and the target, or between the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"dhlTnxC5lUZy","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# ------------------------ User input ------------------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","QC_image_folder = \"\" #@param{type:\"string\"}\n","QC_loc_folder = \"\" #@param{type:\"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value:\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","if get_pixel_size_from_file:\n"," pixel_size_INPUT = None\n","else:\n"," pixel_size_INPUT = pixel_size\n","\n","\n","# ------------------------ QC analysis loop over provided dataset ------------------------\n","\n","savePath = os.path.join(QC_model_path, 'Quality Control')\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(os.path.join(savePath, \"QC_metrics.csv\"), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"WF v. GT mSSIM\", \"Prediction v. GT NRMSE\",\"WF v. GT NRMSE\", \"Prediction v. GT PSNR\", \"WF v. 
GT PSNR\"])\n","\n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvWF_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvWF_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvWF_list = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n"," for (imageFilename, locFilename) in zip(list_files(QC_image_folder, 'tif'), list_files(QC_loc_folder, 'csv')):\n"," print('--------------')\n"," print(imageFilename)\n"," print(locFilename)\n","\n"," # Get the prediction\n"," batchFramePredictionLocalization(QC_image_folder, imageFilename, QC_model_path, savePath, pixel_size = pixel_size_INPUT)\n","\n"," # test_model(QC_image_folder, imageFilename, QC_model_path, savePath, display=False);\n"," thisPrediction = io.imread(os.path.join(savePath, 'Predicted_'+imageFilename))\n"," thisWidefield = io.imread(os.path.join(savePath, 'Widefield_'+imageFilename))\n","\n"," Mhr = thisPrediction.shape[0]\n"," Nhr = thisPrediction.shape[1]\n","\n"," if pixel_size_INPUT == None:\n"," pixel_size, N, M = getPixelSizeTIFFmetadata(os.path.join(QC_image_folder,imageFilename))\n","\n"," upsampling_factor = int(Mhr/M)\n"," print('Upsampling factor: '+str(upsampling_factor))\n"," pixel_size_hr = pixel_size/upsampling_factor # in nm\n","\n"," # Load the localization file and display the first\n"," LocData = pd.read_csv(os.path.join(QC_loc_folder,locFilename), index_col=0)\n","\n"," x = np.array(list(LocData['x [nm]']))\n"," y = np.array(list(LocData['y [nm]']))\n"," locImage = FromLoc2Image_SimpleHistogram(x, y, image_size = (Mhr,Nhr), pixel_size = pixel_size_hr)\n","\n"," # Remove extension from filename\n"," imageFilename_no_extension = os.path.splitext(imageFilename)[0]\n","\n"," # io.imsave(os.path.join(savePath, 'GT_image_'+imageFilename), locImage)\n"," saveAsTIF(savePath, 'GT_image_'+imageFilename_no_extension, locImage, pixel_size_hr)\n","\n"," # Normalize the 
images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm, test_prediction_norm = norm_minmse(locImage, thisPrediction, normalize_gt=True)\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm, test_wf_norm = norm_minmse(locImage, thisWidefield, normalize_gt=True)\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1., full=True)\n"," index_SSIM_GTvsWF, img_SSIM_GTvsWF = structural_similarity(test_GT_norm, test_wf_norm, data_range=1., full=True)\n","\n","\n"," # Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," # io.imsave(os.path.join(savePath,'SSIM_GTvsPrediction_'+imageFilename),img_SSIM_GTvsPrediction_32bit)\n"," saveAsTIF(savePath,'SSIM_GTvsPrediction_'+imageFilename_no_extension, img_SSIM_GTvsPrediction_32bit, pixel_size_hr)\n","\n","\n"," img_SSIM_GTvsWF_32bit = np.float32(img_SSIM_GTvsWF)\n"," # io.imsave(os.path.join(savePath,'SSIM_GTvsWF_'+imageFilename),img_SSIM_GTvsWF_32bit)\n"," saveAsTIF(savePath,'SSIM_GTvsWF_'+imageFilename_no_extension, img_SSIM_GTvsWF_32bit, pixel_size_hr)\n","\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsWF = np.sqrt(np.square(test_GT_norm - test_wf_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," # io.imsave(os.path.join(savePath,'RSE_GTvsPrediction_'+imageFilename),img_RSE_GTvsPrediction_32bit)\n"," saveAsTIF(savePath,'RSE_GTvsPrediction_'+imageFilename_no_extension, img_RSE_GTvsPrediction_32bit, pixel_size_hr)\n","\n"," img_RSE_GTvsWF_32bit = np.float32(img_RSE_GTvsWF)\n"," # 
io.imsave(os.path.join(savePath,'RSE_GTvsWF_'+imageFilename),img_RSE_GTvsWF_32bit)\n"," saveAsTIF(savePath,'RSE_GTvsWF_'+imageFilename_no_extension, img_RSE_GTvsWF_32bit, pixel_size_hr)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsWF = np.sqrt(np.mean(img_RSE_GTvsWF))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsWF = psnr(test_GT_norm,test_wf_norm,data_range=1.0)\n","\n"," writer.writerow([imageFilename,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsWF),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsWF),str(PSNR_GTvsPrediction), str(PSNR_GTvsWF)])\n","\n"," # Collect values to display in dataframe output\n"," file_name_list.append(imageFilename)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvWF_list.append(index_SSIM_GTvsWF)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvWF_list.append(NRMSE_GTvsWF)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvWF_list.append(PSNR_GTvsWF)\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = file_name_list)\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list\n","pdResults[\"Wide-field v. GT mSSIM\"] = mSSIM_GvWF_list\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list\n","pdResults[\"Wide-field v. GT NRMSE\"] = NRMSE_GvWF_list\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list\n","pdResults[\"Wide-field v. 
GT PSNR\"] = PSNR_GvWF_list\n","\n","\n","# ------------------------ Display ------------------------\n","\n","print('--------------------------------------------')\n","@interact\n","def show_QC_results(file = list_files(QC_image_folder, 'tif')):\n","\n"," plt.figure(figsize=(15,15))\n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = io.imread(os.path.join(savePath, 'GT_image_'+file))\n"," plt.imshow(img_GT, norm = simple_norm(img_GT, percent = 99.5))\n"," plt.title('Target',fontsize=15)\n","\n"," # Wide-field\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(savePath, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield',fontsize=15)\n","\n"," #Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(savePath, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Prediction',fontsize=15)\n","\n"," #Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n"," #SSIM between GT and Source\n"," plt.subplot(3,3,5)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsWF = io.imread(os.path.join(savePath, 'SSIM_GTvsWF_'+file))\n"," imSSIM_GTvsWF = plt.imshow(img_SSIM_GTvsWF, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imSSIM_GTvsWF,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Wide-field v. 
GT mSSIM\"],3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsPrediction = io.imread(os.path.join(savePath, 'SSIM_GTvsPrediction_'+file))\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Prediction v. GT mSSIM\"],3)),fontsize=14)\n","\n"," #Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsWF = io.imread(os.path.join(savePath, 'RSE_GTvsWF_'+file))\n"," imRSE_GTvsWF = plt.imshow(img_RSE_GTvsWF, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsWF,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Wide-field v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Wide-field v. 
GT PSNR\"],3)),fontsize=14)\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsPrediction = io.imread(os.path.join(savePath, 'RSE_GTvsPrediction_'+file))\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Prediction v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Prediction v. GT PSNR\"],3)),fontsize=14)\n","\n","print('--------------------------------------------')\n","pdResults.head()\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yTRou0izLjhd","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"eAf8aBDmWTx7"},"source":["## **6.1 Generate image prediction and localizations from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the found localizations csv.\n","\n","**`batch_size`:** This parameter determines how many frames are processed by any single pass on the GPU. A higher `batch_size` will make the prediction faster but will use more GPU memory. If an OutOfMemory (OOM) error occurs, decrease the `batch_size`. **DEFAULT: 4**\n","\n","**`threshold`:** This parameter determines the threshold for local maxima finding. The value is expected to reside in the range **[0,1]**. A higher `threshold` will result in fewer localizations. **DEFAULT: 0.1**\n","\n","**`neighborhood_size`:** This parameter determines the size of the neighborhood within which the prediction needs to be a local maximum, in recovery pixels (CCD pixel/upsampling_factor). A high `neighborhood_size` will make the prediction slower and potentially discard nearby localizations. **DEFAULT: 3**\n","\n","**`use_local_average`:** This parameter determines whether to locally average the prediction in a 3x3 neighborhood to get the final localizations. If set to **True** it will make inference slightly slower depending on the size of the FOV. 
**DEFAULT: True**\n"]},{"cell_type":"code","metadata":{"id":"7qn06T_A0lxf","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# ------------------------------- User input -------------------------------\n","#@markdown ### Data parameters\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value (in nm):\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","#@markdown ### Model parameters\n","#@markdown Do you want to use the model you just trained?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, please provide path to the model folder below\n","prediction_model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ### Prediction parameters\n","batch_size = 4#@param {type:\"integer\"}\n","\n","#@markdown ### Post processing parameters\n","threshold = 0.1#@param {type:\"number\"}\n","neighborhood_size = 3#@param {type:\"integer\"}\n","#@markdown Do you want to locally average the model output with CoG estimator ?\n","use_local_average = True #@param {type:\"boolean\"}\n","\n","\n","if get_pixel_size_from_file:\n"," pixel_size = None\n","\n","if (Use_the_current_trained_model): \n"," prediction_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(prediction_model_path):\n"," print(\"The \"+os.path.basename(prediction_model_path)+\" model will be used.\")\n","else:\n"," print(bcolors.WARNING+'!! 
WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n","\n","# inform user whether local averaging is being used\n","if use_local_average == True: \n"," print('Using local averaging')\n","\n","if not os.path.exists(Result_folder):\n"," print('Result folder was created.')\n"," os.makedirs(Result_folder)\n","\n","\n","# ------------------------------- Run predictions -------------------------------\n","\n","start = time.time()\n","#%% This script tests the trained fully convolutional network based on the \n","# saved training weights, and normalization created using train_model.\n","\n","if os.path.isdir(Data_folder): \n"," for filename in list_files(Data_folder, 'tif'):\n"," # run the testing/reconstruction process\n"," print(\"------------------------------------\")\n"," print(\"Running prediction on: \"+ filename)\n"," batchFramePredictionLocalization(Data_folder, filename, prediction_model_path, Result_folder, \n"," batch_size, \n"," threshold, \n"," neighborhood_size, \n"," use_local_average,\n"," pixel_size = pixel_size)\n","\n","elif os.path.isfile(Data_folder):\n"," batchFramePredictionLocalization(os.path.dirname(Data_folder), os.path.basename(Data_folder), prediction_model_path, Result_folder, \n"," batch_size, \n"," threshold, \n"," neighborhood_size, \n"," use_local_average, \n"," pixel_size = pixel_size)\n","\n","\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","# ------------------------------- Interactive display 
-------------------------------\n","\n","print('--------------------------------------------------------------------')\n","print('---------------------------- Previews ------------------------------')\n","print('--------------------------------------------------------------------')\n","\n","if os.path.isdir(Data_folder): \n"," @interact\n"," def show_QC_results(file = list_files(Data_folder, 'tif')):\n","\n"," plt.figure(figsize=(15,7.5))\n"," # Wide-field\n"," plt.subplot(1,2,1)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield', fontsize=15)\n"," # Prediction\n"," plt.subplot(1,2,2)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Predicted',fontsize=15)\n","\n","if os.path.isfile(Data_folder):\n","\n"," plt.figure(figsize=(15,7.5))\n"," # Wide-field\n"," plt.subplot(1,2,1)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+os.path.basename(Data_folder)))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield', fontsize=15)\n"," # Prediction\n"," plt.subplot(1,2,2)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+os.path.basename(Data_folder)))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Predicted',fontsize=15)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZekzexaPmzFZ","colab_type":"text"},"source":["## **6.2 Drift correction**\n","---\n","\n","The visualization above is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. The display is a preview without any drift correction applied. 
This section performs drift correction using cross-correlation between time bins to estimate the drift.\n","\n","**`Loc_file_path`:** is the path to the localization file to use for visualization.\n","\n","**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n","\n","**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the image reconstructions used for the Drift Correction estmication (in **nm**). A smaller pixel size will be more precise but will take longer to compute. **DEFAULT: 20**\n","\n","**`number_of_bins`:** This parameter defines how many temporal bins are used across the full dataset. All localizations in each bins are used ot build an image. This image is used to find the drift with respect to the image obtained from the very first bin. A typical value would correspond to about 500 frames per bin. **DEFAULT: Total number of frames / 500**\n","\n","**`polynomial_fit_degree`:** The drift obtained for each temporal bins needs to be interpolated to every single frames. This is performed by polynomial fit, the degree of which is defined here. **DEFAULT: 4**\n","\n"," The drift-corrected localization data is automaticaly saved in the `save_path` folder."]},{"cell_type":"code","metadata":{"id":"hYtP_vh6mzUP","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Data parameters\n","Loc_file_path = \"\" #@param {type:\"string\"}\n","# @markdown Provide information about original data. 
Get the info automatically from the raw data?\n","Get_info_from_file = True #@param {type:\"boolean\"}\n","# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n","original_image_path = \"\" #@param {type:\"string\"}\n","# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n","image_width = 256#@param {type:\"integer\"}\n","image_height = 256#@param {type:\"integer\"}\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","# @markdown ##Drift correction parameters\n","visualization_pixel_size = 20#@param {type:\"number\"}\n","number_of_bins = 50#@param {type:\"integer\"}\n","polynomial_fit_degree = 4#@param {type:\"integer\"}\n","\n","# @markdown ##Saving parameters\n","save_path = '' #@param {type:\"string\"}\n","\n","\n","# Let's go !\n","start = time.time()\n","\n","# Get info from the raw file if selected\n","if Get_info_from_file:\n"," pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n","\n","# Read the localizations in\n","LocData = pd.read_csv(Loc_file_path)\n","\n","# Calculate a few variables \n","Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n","Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n","nFrames = max(LocData['frame'])\n","x_max = max(LocData['x [nm]'])\n","y_max = max(LocData['y [nm]'])\n","image_size = (Mhr, Nhr)\n","n_locs = len(LocData.index)\n","\n","print('Image size: '+str(image_size))\n","print('Number of frames in data: '+str(nFrames))\n","print('Number of localizations in data: '+str(n_locs))\n","\n","blocksize = math.ceil(nFrames/number_of_bins)\n","print('Number of frames per block: '+str(blocksize))\n","\n","blockDataFrame = LocData[(LocData['frame'] < blocksize)].copy()\n","xc_array = blockDataFrame['x 
[nm]'].to_numpy(dtype=np.float32)\n","yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n","\n","# Preparing the Reference image\n","photon_array = np.ones(yc_array.shape[0])\n","sigma_array = np.ones(yc_array.shape[0])\n","ImageRef = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","ImagesRef = np.rot90(ImageRef, k=2)\n","\n","xDrift = np.zeros(number_of_bins)\n","yDrift = np.zeros(number_of_bins)\n","\n","filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n","\n","with open(os.path.join(save_path, filename_no_extension+\"_DriftCorrectionData.csv\"), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"Block #\", \"x-drift [nm]\",\"y-drift [nm]\"])\n","\n"," for b in tqdm(range(number_of_bins)):\n","\n"," blockDataFrame = LocData[(LocData['frame'] >= (b*blocksize)) & (LocData['frame'] < ((b+1)*blocksize))].copy()\n"," xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n"," yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n","\n"," photon_array = np.ones(yc_array.shape[0])\n"," sigma_array = np.ones(yc_array.shape[0])\n"," ImageBlock = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n"," XC = fftconvolve(ImagesRef, ImageBlock, mode = 'same')\n"," yDrift[b], xDrift[b] = subPixelMaxLocalization(XC, method = 'CoM')\n","\n"," # saveAsTIF(save_path, 'ImageBlock'+str(b), ImageBlock, visualization_pixel_size)\n"," # saveAsTIF(save_path, 'XCBlock'+str(b), XC, visualization_pixel_size)\n"," writer.writerow([str(b), str((xDrift[b]-xDrift[0])*visualization_pixel_size), str((yDrift[b]-yDrift[0])*visualization_pixel_size)])\n","\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, 
seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","print('Fitting drift data...')\n","bin_number = np.arange(number_of_bins)*blocksize + blocksize/2\n","xDrift = (xDrift-xDrift[0])*visualization_pixel_size\n","yDrift = (yDrift-yDrift[0])*visualization_pixel_size\n","\n","xDriftCoeff = np.polyfit(bin_number, xDrift, polynomial_fit_degree)\n","yDriftCoeff = np.polyfit(bin_number, yDrift, polynomial_fit_degree)\n","\n","xDriftFit = np.poly1d(xDriftCoeff)\n","yDriftFit = np.poly1d(yDriftCoeff)\n","bins = np.arange(nFrames)\n","xDriftInterpolated = xDriftFit(bins)\n","yDriftInterpolated = yDriftFit(bins)\n","\n","\n","# ------------------ Displaying the image results ------------------\n","\n","plt.figure(figsize=(15,10))\n","plt.plot(bin_number,xDrift, 'r+', label='x-drift')\n","plt.plot(bin_number,yDrift, 'b+', label='y-drift')\n","plt.plot(bins,xDriftInterpolated, 'r-', label='y-drift (fit)')\n","plt.plot(bins,yDriftInterpolated, 'b-', label='y-drift (fit)')\n","plt.title('Cross-correlation estimated drift')\n","plt.ylabel('Drift [nm]')\n","plt.xlabel('Bin number')\n","plt.legend();\n","\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\", hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","# ------------------ Actual drift correction -------------------\n","\n","print('Correcting localization data...')\n","xc_array = LocData['x [nm]'].to_numpy(dtype=np.float32)\n","yc_array = LocData['y [nm]'].to_numpy(dtype=np.float32)\n","frames = LocData['frame'].to_numpy(dtype=np.int32)\n","\n","\n","xc_array_Corr, yc_array_Corr = correctDriftLocalization(xc_array, yc_array, frames, xDriftInterpolated, yDriftInterpolated)\n","ImageRaw = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = 
visualization_pixel_size)\n","ImageCorr = FromLoc2Image_SimpleHistogram(xc_array_Corr, yc_array_Corr, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n","\n","# ------------------ Displaying the imge results ------------------\n","plt.figure(figsize=(15,7.5))\n","# Raw\n","plt.subplot(1,2,1)\n","plt.axis('off')\n","plt.imshow(ImageRaw, norm = simple_norm(ImageRaw, percent = 99.5))\n","plt.title('Raw', fontsize=15);\n","# Corrected\n","plt.subplot(1,2,2)\n","plt.axis('off')\n","plt.imshow(ImageCorr, norm = simple_norm(ImageCorr, percent = 99.5))\n","plt.title('Corrected',fontsize=15);\n","\n","\n","# ------------------ Table with info -------------------\n","driftCorrectedLocData = pd.DataFrame()\n","driftCorrectedLocData['frame'] = frames\n","driftCorrectedLocData['x [nm]'] = xc_array_Corr\n","driftCorrectedLocData['y [nm]'] = yc_array_Corr\n","driftCorrectedLocData['confidence [a.u]'] = LocData['confidence [a.u]']\n","\n","driftCorrectedLocData.to_csv(os.path.join(save_path, filename_no_extension+'_DriftCorrected.csv'))\n","print('-------------------------------')\n","print('Corrected localizations saved.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"mzOuc-V7rB-r","colab_type":"text"},"source":["## **6.3 Visualization of the localizations**\n","---\n","\n","\n","The visualization in section 6.1 is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. This section performs visualization of the result by plotting the localizations as a simple histogram.\n","\n","**`Loc_file_path`:** is the path to the localization file to use for visualization.\n","\n","**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n","\n","**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the final image reconstruction (in **nm**). 
**DEFAULT: 10**\n","\n","**`visualization_mode`:** This parameter defines what visualization method is used to visualize the final image. NOTES: The Integrated Gaussian can be quite slow. **DEFAULT: Simple histogram.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"876yIXnqq-nW","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Data parameters\n","Use_current_drift_corrected_localizations = True #@param {type:\"boolean\"}\n","# @markdown Otherwise provide a localization file path\n","Loc_file_path = \"\" #@param {type:\"string\"}\n","# @markdown Provide information about original data. Get the info automatically from the raw data?\n","Get_info_from_file = True #@param {type:\"boolean\"}\n","# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n","original_image_path = \"\" #@param {type:\"string\"}\n","# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n","image_width = 256#@param {type:\"integer\"}\n","image_height = 256#@param {type:\"integer\"}\n","pixel_size = 100#@param {type:\"number\"}\n","\n","# @markdown ##Visualization parameters\n","visualization_pixel_size = 10#@param {type:\"number\"}\n","visualization_mode = \"Simple histogram\" #@param [\"Simple histogram\", \"Integrated Gaussian (SLOW!)\"]\n","\n","if not Use_current_drift_corrected_localizations:\n"," filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n","\n","\n","if Get_info_from_file:\n"," pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n","\n","if Use_current_drift_corrected_localizations:\n"," LocData = driftCorrectedLocData\n","else:\n"," LocData = pd.read_csv(Loc_file_path)\n","\n","Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n","Nhr = 
int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n","\n","\n","nFrames = max(LocData['frame'])\n","x_max = max(LocData['x [nm]'])\n","y_max = max(LocData['y [nm]'])\n","image_size = (Mhr, Nhr)\n","\n","print('Image size: '+str(image_size))\n","print('Number of frames in data: '+str(nFrames))\n","print('Number of localizations in data: '+str(len(LocData.index)))\n","\n","xc_array = LocData['x [nm]'].to_numpy()\n","yc_array = LocData['y [nm]'].to_numpy()\n","if (visualization_mode == 'Simple histogram'):\n"," locImage = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","elif (visualization_mode == 'Shifted histogram'):\n"," print(bcolors.WARNING+'Method not implemented yet!'+bcolors.NORMAL)\n"," locImage = np.zeros(image_size)\n","elif (visualization_mode == 'Integrated Gaussian (SLOW!)'):\n"," photon_array = np.ones(xc_array.shape)\n"," sigma_array = np.ones(xc_array.shape)\n"," locImage = FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# Display\n","plt.figure(figsize=(20,10))\n","plt.axis('off')\n","# plt.imshow(locImage, cmap='gray');\n","plt.imshow(locImage, norm = simple_norm(locImage, percent = 99.5));\n","\n","\n","LocData.head()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"PdOhWwMn1zIT","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ---\n","# @markdown #Play this cell to save the visualization\n","# @markdown ####Please select a path to the folder where to save the visualization.\n","save_path = 
\"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(save_path):\n"," os.makedirs(save_path)\n"," print('Folder created.')\n","\n","saveAsTIF(save_path, filename_no_extension+'_Visualization', locImage, visualization_pixel_size)\n","print('Image saved.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1EszIF4Dkz_n","colab_type":"text"},"source":["## **6.4. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UgN-NooKk3nV","colab_type":"text"},"source":["\n","#**Thank you for using Deep-STORM 2D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Deep-STORM_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"169qcwQo-yw15PwoGatXAdBvjs4wt_foD","timestamp":1592147948265},{"file_id":"1gjRCgDORKi_GNBu4QnVCBkSWrfPtqL-E","timestamp":1588525976305},{"file_id":"1DFy6aCi1XAVdjA5KLRZirB2aMZkMFdv-","timestamp":1587998755430},{"file_id":"1NpzigQoXGy3GFdxh4_jvG1PnBfyrcpBs","timestamp":1587569988032},{"file_id":"1jdI540qAfMSQwjnMhoAFkGJH9EbHwNSf","timestamp":1587486196143}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"FpCtYevLHfl4","colab_type":"text"},"source":["# **Deep-STORM (2D)**\n","\n","---\n","\n","Deep-STORM is a neural network capable of image reconstruction from high-density single-molecule localization microscopy (SMLM), first published in 2018 by [Nehme *et al.* in Optica](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458). 
The architecture used here is a U-Net based network without skip connections. This network allows image reconstruction of 2D super-resolution images, in a supervised training manner. The network is trained using simulated high-density SMLM data for which the ground-truth is available. These simulations are obtained from a random distribution of single molecules in a field-of-view and therefore do not imprint structural priors during training. The network outputs a super-resolution image with increased pixel density (typically an upsampling factor of 8 in each dimension).\n","\n","Deep-STORM has **two key advantages**:\n","- SMLM reconstruction at high density of emitters\n","- fast prediction (reconstruction) once the model is trained appropriately, compared to more common multi-emitter fitting processes.\n","\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). It is jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n","**Deep-STORM: super-resolution single-molecule microscopy by deep learning**, Optica (2018) by *Elias Nehme, Lucien E. 
Weiss, Tomer Michaeli, and Yoav Shechtman* (https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458)\n","\n","The source code can be found at: https://github.com/EliasNehme/Deep-STORM\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"wyzTn3IcHq6Y","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cells: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code, which can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of the play button stops. You can create a new code cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains the structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contains all available files. After mounting your Google Drive (see section 1.) 
you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"bEy4EBXHHyAX","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," Deep-STORM is able to train on simulated dataset of SMLM data (see https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458 for more info). Here, we provide a simulator that will generate training dataset (section 3.1.b). A few parameters will allow you to match the simulation to your experimental data. 
Similarly to what is described in the paper, simulations obtained from ThunderSTORM can also be loaded here (section 3.1.a).\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"E04mOlG_H5Tz","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"F_tjlGzsH-Dn","colab_type":"text"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"gn-LaaNNICqL","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","# %tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.__version__ != '2.2.0':\n"," !pip install tensorflow==2.2.0\n","\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime settings are correct then Google did not allocate GPU to your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n","\n","# print the tensorflow version\n","print('Tensorflow version is ' + str(tf.__version__))\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"tnP7wM79IKW-","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 
\n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"1R-7Fo34_gOd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jRnQZWSZhArJ","colab_type":"text"},"source":["# **2. Install Deep-STORM and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"kSrZMo3X_NhO","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install Deep-STORM and dependencies\n","\n","# %% Model definition + helper functions\n","\n","# Import keras modules and libraries\n","from tensorflow import keras\n","from tensorflow.keras.models import Model\n","from tensorflow.keras.layers import Input, Activation, UpSampling2D, Convolution2D, MaxPooling2D, BatchNormalization, Layer\n","from tensorflow.keras.callbacks import Callback\n","from tensorflow.keras import backend as K\n","from tensorflow.keras import optimizers, losses\n","\n","from tensorflow.keras.preprocessing.image import ImageDataGenerator\n","from tensorflow.keras.callbacks import ModelCheckpoint\n","from tensorflow.keras.callbacks import ReduceLROnPlateau\n","from skimage.transform import warp\n","from skimage.transform import SimilarityTransform\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from scipy.signal import 
fftconvolve\n","\n","# Import common libraries\n","import tensorflow as tf\n","import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import h5py\n","import scipy.io as sio\n","from os.path import abspath\n","from sklearn.model_selection import train_test_split\n","from skimage import io\n","import time\n","import os\n","import shutil\n","import csv\n","from PIL import Image \n","from PIL.TiffTags import TAGS\n","from scipy.ndimage import gaussian_filter\n","import math\n","from astropy.visualization import simple_norm\n","from sys import getsizeof\n","\n","# For sliders and dropdown menu, progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","from tqdm import tqdm\n","\n","# For Multi-threading in simulation\n","from numba import njit, prange\n","\n","\n","# define a function that projects and rescales an image to the range [0,1]\n","def project_01(im):\n"," im = np.squeeze(im)\n"," min_val = im.min()\n"," max_val = im.max()\n"," return (im - min_val)/(max_val - min_val)\n","\n","# normalize image given mean and std\n","def normalize_im(im, dmean, dstd):\n"," im = np.squeeze(im)\n"," im_norm = np.zeros(im.shape,dtype=np.float32)\n"," im_norm = (im - dmean)/dstd\n"," return im_norm\n","\n","# Define the loss history recorder\n","class LossHistory(Callback):\n"," def on_train_begin(self, logs={}):\n"," self.losses = []\n","\n"," def on_batch_end(self, batch, logs={}):\n"," self.losses.append(logs.get('loss'))\n"," \n","# Define a matlab like gaussian 2D filter\n","def matlab_style_gauss2D(shape=(7,7),sigma=1):\n"," \"\"\" \n"," 2D gaussian filter - should give the same result as:\n"," MATLAB's fspecial('gaussian',[shape],[sigma]) \n"," \"\"\"\n"," m,n = [(ss-1.)/2. 
for ss in shape]\n"," y,x = np.ogrid[-m:m+1,-n:n+1]\n"," h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )\n"," h.astype(dtype=K.floatx())\n"," h[ h < np.finfo(h.dtype).eps*h.max() ] = 0\n"," sumh = h.sum()\n"," if sumh != 0:\n"," h /= sumh\n"," h = h*2.0\n"," h = h.astype('float32')\n"," return h\n","\n","# Expand the filter dimensions\n","psf_heatmap = matlab_style_gauss2D(shape = (7,7),sigma=1)\n","gfilter = tf.reshape(psf_heatmap, [7, 7, 1, 1])\n","\n","# Combined MSE + L1 loss\n","def L1L2loss(input_shape):\n"," def bump_mse(heatmap_true, spikes_pred):\n","\n"," # generate the heatmap corresponding to the predicted spikes\n"," heatmap_pred = K.conv2d(spikes_pred, gfilter, strides=(1, 1), padding='same')\n","\n"," # heatmaps MSE\n"," loss_heatmaps = losses.mean_squared_error(heatmap_true,heatmap_pred)\n","\n"," # l1 on the predicted spikes\n"," loss_spikes = losses.mean_absolute_error(spikes_pred,tf.zeros(input_shape))\n"," return loss_heatmaps + loss_spikes\n"," return bump_mse\n","\n","# Define the concatenated conv2, batch normalization, and relu block\n","def conv_bn_relu(nb_filter, rk, ck, name):\n"," def f(input):\n"," conv = Convolution2D(nb_filter, kernel_size=(rk, ck), strides=(1,1),\\\n"," padding=\"same\", use_bias=False,\\\n"," kernel_initializer=\"Orthogonal\",name='conv-'+name)(input)\n"," conv_norm = BatchNormalization(name='BN-'+name)(conv)\n"," conv_norm_relu = Activation(activation = \"relu\",name='Relu-'+name)(conv_norm)\n"," return conv_norm_relu\n"," return f\n","\n","# Define the model architechture\n","def CNN(input,names):\n"," Features1 = conv_bn_relu(32,3,3,names+'F1')(input)\n"," pool1 = MaxPooling2D(pool_size=(2,2),name=names+'Pool1')(Features1)\n"," Features2 = conv_bn_relu(64,3,3,names+'F2')(pool1)\n"," pool2 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool2')(Features2)\n"," Features3 = conv_bn_relu(128,3,3,names+'F3')(pool2)\n"," pool3 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool3')(Features3)\n"," Features4 = 
conv_bn_relu(512,3,3,names+'F4')(pool3)\n"," up5 = UpSampling2D(size=(2, 2),name=names+'Upsample1')(Features4)\n"," Features5 = conv_bn_relu(128,3,3,names+'F5')(up5)\n"," up6 = UpSampling2D(size=(2, 2),name=names+'Upsample2')(Features5)\n"," Features6 = conv_bn_relu(64,3,3,names+'F6')(up6)\n"," up7 = UpSampling2D(size=(2, 2),name=names+'Upsample3')(Features6)\n"," Features7 = conv_bn_relu(32,3,3,names+'F7')(up7)\n"," return Features7\n","\n","# Define the Model building for an arbitrary input size\n","def buildModel(input_dim, initial_learning_rate = 0.001):\n"," input_ = Input (shape = (input_dim))\n"," act_ = CNN (input_,'CNN')\n"," density_pred = Convolution2D(1, kernel_size=(1, 1), strides=(1, 1), padding=\"same\",\\\n"," activation=\"linear\", use_bias = False,\\\n"," kernel_initializer=\"Orthogonal\",name='Prediction')(act_)\n"," model = Model (inputs= input_, outputs=density_pred)\n"," opt = optimizers.Adam(lr = initial_learning_rate)\n"," model.compile(optimizer=opt, loss = L1L2loss(input_dim))\n"," return model\n","\n","\n","# define a function that trains a model for a given data SNR and density\n","def train_model(patches, heatmaps, modelPath, epochs, steps_per_epoch, batch_size, upsampling_factor=8, validation_split = 0.3, initial_learning_rate = 0.001, pretrained_model_path = '', L2_weighting_factor = 100):\n"," \n"," \"\"\"\n"," This function trains a CNN model on the desired training set, given the \n"," upsampled training images and labels generated in MATLAB.\n"," \n"," # Inputs\n"," # TO UPDATE ----------\n","\n"," # Outputs\n"," function saves the weights of the trained model to a hdf5, and the \n"," normalization factors to a mat file. These will be loaded later for testing \n"," the model in test_model. 
\n"," \"\"\"\n"," \n"," # for reproducibility\n"," np.random.seed(123)\n","\n"," X_train, X_test, y_train, y_test = train_test_split(patches, heatmaps, test_size = validation_split, random_state=42)\n"," print('Number of training examples: %d' % X_train.shape[0])\n"," print('Number of validation examples: %d' % X_test.shape[0])\n"," \n"," # Setting type\n"," X_train = X_train.astype('float32')\n"," X_test = X_test.astype('float32')\n"," y_train = y_train.astype('float32')\n"," y_test = y_test.astype('float32')\n","\n"," \n"," #===================== Training set normalization ==========================\n"," # normalize training images to be in the range [0,1] and calculate the \n"," # training set mean and std\n"," mean_train = np.zeros(X_train.shape[0],dtype=np.float32)\n"," std_train = np.zeros(X_train.shape[0], dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train[i, :, :] = project_01(X_train[i, :, :])\n"," mean_train[i] = X_train[i, :, :].mean()\n"," std_train[i] = X_train[i, :, :].std()\n","\n"," # resulting normalized training images\n"," mean_val_train = mean_train.mean()\n"," std_val_train = std_train.mean()\n"," X_train_norm = np.zeros(X_train.shape, dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train_norm[i, :, :] = normalize_im(X_train[i, :, :], mean_val_train, std_val_train)\n"," \n"," # patch size\n"," psize = X_train_norm.shape[1]\n","\n"," # Reshaping\n"," X_train_norm = X_train_norm.reshape(X_train.shape[0], psize, psize, 1)\n","\n"," # ===================== Test set normalization ==========================\n"," # normalize test images to be in the range [0,1] and calculate the test set \n"," # mean and std\n"," mean_test = np.zeros(X_test.shape[0],dtype=np.float32)\n"," std_test = np.zeros(X_test.shape[0], dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test[i, :, :] = project_01(X_test[i, :, :])\n"," mean_test[i] = X_test[i, :, :].mean()\n"," std_test[i] = X_test[i, :, :].std()\n","\n"," # resulting 
normalized test images\n"," mean_val_test = mean_test.mean()\n"," std_val_test = std_test.mean()\n"," X_test_norm = np.zeros(X_test.shape, dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test_norm[i, :, :] = normalize_im(X_test[i, :, :], mean_val_test, std_val_test)\n"," \n"," # Reshaping\n"," X_test_norm = X_test_norm.reshape(X_test.shape[0], psize, psize, 1)\n","\n"," # Reshaping labels\n"," Y_train = y_train.reshape(y_train.shape[0], psize, psize, 1)\n"," Y_test = y_test.reshape(y_test.shape[0], psize, psize, 1)\n","\n"," # Save the normalization metadata to a .mat file to open later in MATLAB\n"," mdict = {\"mean_test\": mean_val_test, \"std_test\": std_val_test, \"upsampling_factor\": upsampling_factor, \"Normalization factor\": L2_weighting_factor}\n"," sio.savemat(os.path.join(modelPath,\"model_metadata.mat\"), mdict)\n","\n","\n"," # Set the dimension ordering according to the TensorFlow convention\n"," # K.set_image_dim_ordering('tf')\n"," K.set_image_data_format('channels_last')\n","\n"," # Save the model weights after each epoch if the validation loss decreased\n"," checkpointer = ModelCheckpoint(filepath=os.path.join(modelPath,\"weights_best.hdf5\"), verbose=1,\n"," save_best_only=True)\n","\n"," # Reduce the learning rate when the validation loss reaches a plateau\n"," change_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=0.00005)\n"," \n"," # Model building and compilation\n"," model = buildModel((psize, psize, 1), initial_learning_rate = initial_learning_rate)\n"," model.summary()\n","\n"," # Load pretrained model\n"," if not pretrained_model_path:\n"," print('Using random initial model weights.')\n"," else:\n"," print('Loading model weights from '+pretrained_model_path)\n"," model.load_weights(pretrained_model_path)\n"," \n"," # Create an image data generator for real time data augmentation\n"," datagen = ImageDataGenerator(\n"," featurewise_center=False, # set input mean to 0 over the dataset\n"," samplewise_center=False, # set each sample mean to 
0\n"," featurewise_std_normalization=False, # divide inputs by std of the dataset\n"," samplewise_std_normalization=False, # divide each input by its std\n"," zca_whitening=False, # apply ZCA whitening\n"," rotation_range=0., # randomly rotate images in the range (degrees, 0 to 180)\n"," width_shift_range=0., # randomly shift images horizontally (fraction of total width)\n"," height_shift_range=0., # randomly shift images vertically (fraction of total height)\n"," zoom_range=0.,\n"," shear_range=0.,\n"," horizontal_flip=False, # randomly flip images\n"," vertical_flip=False, # randomly flip images\n"," fill_mode='constant',\n"," data_format=K.image_data_format())\n","\n"," # Fit the image generator on the training data\n"," datagen.fit(X_train_norm)\n"," \n"," # loss history recorder\n"," history = LossHistory()\n","\n"," # Inform user training begun\n"," print('-------------------------------')\n"," print('Training model...')\n","\n"," # Fit model on the batches generated by datagen.flow()\n"," train_history = model.fit_generator(datagen.flow(X_train_norm, Y_train, batch_size=batch_size), \n"," steps_per_epoch=steps_per_epoch, epochs=epochs, verbose=1, \n"," validation_data=(X_test_norm, Y_test), \n"," callbacks=[history, checkpointer, change_lr]) \n","\n"," # Inform user training ended\n"," print('-------------------------------')\n"," print('Training Complete!')\n"," \n"," # Save the last model\n"," model.save(os.path.join(modelPath, 'weights_last.hdf5'))\n","\n"," # convert the history.history dict to a pandas DataFrame: \n"," lossData = pd.DataFrame(train_history.history) \n","\n"," if os.path.exists(os.path.join(modelPath,\"Quality Control\")):\n"," shutil.rmtree(os.path.join(modelPath,\"Quality Control\"))\n","\n"," os.makedirs(os.path.join(modelPath,\"Quality Control\"))\n","\n"," # The training evaluation.csv is saved (overwrites the Files if needed). 
\n"," lossDataCSVpath = os.path.join(modelPath,\"Quality Control/training_evaluation.csv\")\n"," with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss','learning rate'])\n"," for i in range(len(train_history.history['loss'])):\n"," writer.writerow([train_history.history['loss'][i], train_history.history['val_loss'][i], train_history.history['lr'][i]])\n","\n"," return\n","\n","\n","# Normalization functions from Martin Weigert used in CARE\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, 
x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","# Multi-threaded Erf-based image construction\n","@njit(parallel=True)\n","def FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = (64,64), pixel_size = 100):\n"," w = image_size[0]\n"," h = image_size[1]\n"," erfImage = np.zeros((w, h))\n"," for ij in prange(w*h):\n"," j = int(ij/w)\n"," i = ij - j*w\n"," for (xc, yc, photon, sigma) in zip(xc_array, yc_array, photon_array, sigma_array):\n"," # Don't bother if the emitter has photons <= 0 or if Sigma <= 0\n"," if (sigma > 0) and (photon > 0):\n"," S = sigma*math.sqrt(2)\n"," x = i*pixel_size - xc\n"," y = j*pixel_size - yc\n"," # Don't bother if the emitter is further than 4 sigma from the centre of the pixel\n"," if (x+pixel_size/2)**2 + (y+pixel_size/2)**2 < 16*sigma**2:\n"," ErfX = math.erf((x+pixel_size)/S) - math.erf(x/S)\n"," ErfY = math.erf((y+pixel_size)/S) - math.erf(y/S)\n"," erfImage[j][i] += 0.25*photon*ErfX*ErfY\n"," return erfImage\n","\n","\n","@njit(parallel=True)\n","def FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = (64,64), pixel_size = 100):\n"," w = image_size[0]\n"," h = image_size[1]\n"," locImage = np.zeros((image_size[0],image_size[1]) )\n"," n_locs = len(xc_array)\n","\n"," for e in prange(n_locs):\n"," locImage[int(max(min(round(yc_array[e]/pixel_size),w-1),0))][int(max(min(round(xc_array[e]/pixel_size),h-1),0))] += 1\n","\n"," return locImage\n","\n","\n","\n","def getPixelSizeTIFFmetadata(TIFFpath, display=False):\n"," with Image.open(TIFFpath) as img:\n"," meta_dict = {TAGS[key] : img.tag[key] for key in img.tag.keys()}\n","\n","\n"," # TIFF 
tags\n"," # https://www.loc.gov/preservation/digital/formats/content/tiff_tags.shtml\n"," # https://www.awaresystems.be/imaging/tiff/tifftags/resolutionunit.html\n"," ResolutionUnit = meta_dict['ResolutionUnit'][0] # unit of resolution\n"," width = meta_dict['ImageWidth'][0]\n"," height = meta_dict['ImageLength'][0]\n","\n"," xResolution = meta_dict['XResolution'][0] # number of pixels / ResolutionUnit\n","\n"," if len(xResolution) == 1:\n"," xResolution = xResolution[0]\n"," elif len(xResolution) == 2:\n"," xResolution = xResolution[0]/xResolution[1]\n"," else:\n"," print('Image resolution not defined.')\n"," xResolution = 1\n","\n"," if ResolutionUnit == 2:\n"," # Units given are in inches (1 inch = 0.0254 m)\n"," pixel_size = 0.0254*1e9/xResolution\n"," elif ResolutionUnit == 3:\n"," # Units given are in cm\n"," pixel_size = 0.01*1e9/xResolution\n"," else: \n"," # ResolutionUnit is therefore 1\n"," print('Resolution unit not defined. Assuming: um')\n"," pixel_size = 1e3/xResolution\n","\n"," if display:\n"," print('Pixel size obtained from metadata: '+str(pixel_size)+' nm')\n"," print('Image size: '+str(width)+'x'+str(height))\n"," \n"," return (pixel_size, width, height)\n","\n","\n","def saveAsTIF(path, filename, array, pixel_size):\n"," \"\"\"\n"," Image saving using PIL to save as .tif format\n"," # Input \n"," path - path where it will be saved\n"," filename - name of the file to save (no extension)\n"," array - numpy array containing the data in the required format\n"," pixel_size - physical size of pixels in nanometers (identical for x and y)\n"," \"\"\"\n","\n"," # print('Data type: '+str(array.dtype))\n"," if (array.dtype == np.uint16):\n"," mode = 'I;16'\n"," elif (array.dtype == np.uint32):\n"," mode = 'I'\n"," else:\n"," mode = 'F'\n","\n"," # Rounding the pixel size to the nearest number that divides exactly 1cm.\n"," # Resolution needs to be a rational number --> see TIFF format\n"," # pixel_size = 10000/(round(10000/pixel_size))\n","\n"," if len(array.shape) == 
2:\n"," im = Image.fromarray(array)\n"," im.save(os.path.join(path, filename+'.tif'),\n"," mode = mode, \n"," resolution_unit = 3,\n"," resolution = 0.01*1e9/pixel_size)\n","\n","\n"," elif len(array.shape) == 3:\n"," imlist = []\n"," for frame in array:\n"," imlist.append(Image.fromarray(frame))\n","\n"," imlist[0].save(os.path.join(path, filename+'.tif'), save_all=True,\n"," append_images=imlist[1:],\n"," mode = mode, \n"," resolution_unit = 3,\n"," resolution = 0.01*1e9/pixel_size)\n","\n"," return\n","\n","\n","\n","\n","class Maximafinder(Layer):\n"," def __init__(self, thresh, neighborhood_size, use_local_avg, **kwargs):\n"," super(Maximafinder, self).__init__(**kwargs)\n"," self.thresh = tf.constant(thresh, dtype=tf.float32)\n"," self.nhood = neighborhood_size\n"," self.use_local_avg = use_local_avg\n","\n"," def build(self, input_shape):\n"," if self.use_local_avg is True:\n"," self.kernel_x = tf.reshape(tf.constant([[-1,0,1],[-1,0,1],[-1,0,1]], dtype=tf.float32), [3, 3, 1, 1])\n"," self.kernel_y = tf.reshape(tf.constant([[-1,-1,-1],[0,0,0],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])\n"," self.kernel_sum = tf.reshape(tf.constant([[1,1,1],[1,1,1],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])\n","\n"," def call(self, inputs):\n","\n"," # local maxima positions\n"," max_pool_image = MaxPooling2D(pool_size=(self.nhood,self.nhood), strides=(1,1), padding='same')(inputs)\n"," cond = tf.math.greater(max_pool_image, self.thresh) & tf.math.equal(max_pool_image, inputs)\n"," indices = tf.where(cond)\n"," bind, xind, yind = indices[:, 0], indices[:, 2], indices[:, 1]\n"," confidence = tf.gather_nd(inputs, indices)\n","\n"," # local CoG estimator\n"," if self.use_local_avg:\n"," x_image = K.conv2d(inputs, self.kernel_x, padding='same')\n"," y_image = K.conv2d(inputs, self.kernel_y, padding='same')\n"," sum_image = K.conv2d(inputs, self.kernel_sum, padding='same')\n"," confidence = tf.cast(tf.gather_nd(sum_image, indices), dtype=tf.float32)\n"," x_local = 
tf.math.divide(tf.gather_nd(x_image, indices),tf.gather_nd(sum_image, indices))\n"," y_local = tf.math.divide(tf.gather_nd(y_image, indices),tf.gather_nd(sum_image, indices))\n"," xind = tf.cast(xind, dtype=tf.float32) + tf.cast(x_local, dtype=tf.float32)\n"," yind = tf.cast(yind, dtype=tf.float32) + tf.cast(y_local, dtype=tf.float32)\n"," else:\n"," xind = tf.cast(xind, dtype=tf.float32)\n"," yind = tf.cast(yind, dtype=tf.float32)\n"," \n"," return bind, xind, yind, confidence\n","\n"," def get_config(self):\n","\n"," # Implement get_config to enable serialization. This is optional.\n"," base_config = super(Maximafinder, self).get_config()\n"," config = {}\n"," return dict(list(base_config.items()) + list(config.items()))\n","\n","\n","\n","# ------------------------------- Prediction with postprocessing function-------------------------------\n","def batchFramePredictionLocalization(dataPath, filename, modelPath, savePath, batch_size=1, thresh=0.1, neighborhood_size=3, use_local_avg = False, pixel_size = None):\n"," \"\"\"\n"," This function tests a trained model on the desired test set, given the \n"," tiff stack of test images, learned weights, and normalization factors.\n"," \n"," # Inputs\n"," dataPath - the path to the folder containing the tiff stack(s) to run prediction on \n"," filename - the name of the file to process\n"," modelPath - the path to the folder containing the weights file and the mean and standard deviation file generated in train_model\n"," savePath - the path to the folder where to save the prediction\n"," batch_size. 
- the number of frames to predict on for each iteration\n"," thresh - threshold percentage of the maximum of the Gaussian scaling\n"," neighborhood_size - the size of the neighborhood for local maxima finding\n"," use_local_avg - Boolean, whether to perform local averaging or not\n"," \"\"\"\n"," \n"," # load mean and std\n"," matfile = sio.loadmat(os.path.join(modelPath,'model_metadata.mat'))\n"," test_mean = np.array(matfile['mean_test'])\n"," test_std = np.array(matfile['std_test']) \n"," upsampling_factor = np.array(matfile['upsampling_factor'])\n"," upsampling_factor = upsampling_factor.item() # convert to scalar\n"," L2_weighting_factor = np.array(matfile['Normalization factor'])\n"," L2_weighting_factor = L2_weighting_factor.item() # convert to scalar\n","\n"," # Read in the raw file\n"," Images = io.imread(os.path.join(dataPath, filename))\n"," if pixel_size is None:\n"," pixel_size, _, _ = getPixelSizeTIFFmetadata(os.path.join(dataPath, filename), display=True)\n"," pixel_size_hr = pixel_size/upsampling_factor\n","\n"," # get dataset dimensions\n"," (nFrames, M, N) = Images.shape\n"," print('Input image is '+str(N)+'x'+str(M)+' with '+str(nFrames)+' frames.')\n","\n"," # Build the model for a bigger image\n"," model = buildModel((upsampling_factor*M, upsampling_factor*N, 1))\n","\n"," # Load the trained weights\n"," model.load_weights(os.path.join(modelPath,'weights_best.hdf5'))\n","\n"," # add a post-processing module\n"," max_layer = Maximafinder(thresh*L2_weighting_factor, neighborhood_size, use_local_avg)\n","\n"," # Initialise the results: lists will be used to collect all the localizations\n"," frame_number_list, x_nm_list, y_nm_list, confidence_au_list = [], [], [], []\n","\n"," # Initialise the results\n"," Prediction = np.zeros((M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n"," Widefield = np.zeros((M, N), dtype=np.float32)\n","\n"," # run model in batches\n"," n_batches = math.ceil(nFrames/batch_size)\n"," for b in 
tqdm(range(n_batches)):\n","\n"," nF = min(batch_size, nFrames - b*batch_size)\n"," Images_norm = np.zeros((nF, M, N),dtype=np.float32)\n"," Images_upsampled = np.zeros((nF, M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n","\n"," # Normalizing and upsampling using a simple nearest-neighbor interpolation - MULTI-THREAD this?\n"," for f in range(nF):\n"," Images_norm[f,:,:] = project_01(Images[b*batch_size+f,:,:])\n"," Images_norm[f,:,:] = normalize_im(Images_norm[f,:,:], test_mean, test_std)\n"," Images_upsampled[f,:,:] = np.kron(Images_norm[f,:,:], np.ones((upsampling_factor,upsampling_factor)))\n"," Widefield += Images[b*batch_size+f,:,:]\n","\n"," # Reshaping\n"," Images_upsampled = np.expand_dims(Images_upsampled,axis=3)\n","\n"," # Run prediction and local maxima finding\n"," predicted_density = model.predict_on_batch(Images_upsampled)\n"," predicted_density[predicted_density < 0] = 0\n"," Prediction += predicted_density.sum(axis = 3).sum(axis = 0)\n","\n"," bind, xind, yind, confidence = max_layer(predicted_density)\n"," \n"," # normalizing the confidence by the L2_weighting_factor\n"," confidence /= L2_weighting_factor \n","\n"," # convert indices to nm and append to the results\n"," xind, yind = xind*pixel_size_hr, yind*pixel_size_hr\n"," frmind = (bind.numpy() + b*batch_size + 1).tolist()\n"," xind = xind.numpy().tolist()\n"," yind = yind.numpy().tolist()\n"," confidence = confidence.numpy().tolist()\n"," frame_number_list += frmind\n"," x_nm_list += xind\n"," y_nm_list += yind\n"," confidence_au_list += confidence\n","\n"," # Open and create the csv file that will contain all the localizations\n"," if use_local_avg:\n"," ext = '_avg'\n"," else:\n"," ext = '_max'\n"," with open(os.path.join(savePath, 'Localizations_' + os.path.splitext(filename)[0] + ext + '.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow(['frame', 'x [nm]', 'y [nm]', 'confidence [a.u]'])\n"," locs = list(zip(frame_number_list, x_nm_list, 
y_nm_list, confidence_au_list))\n"," writer.writerows(locs)\n","\n"," # Save the prediction and widefield image\n"," Widefield = np.kron(Widefield, np.ones((upsampling_factor,upsampling_factor)))\n"," Widefield = np.float32(Widefield)\n","\n"," # io.imsave(os.path.join(savePath, 'Predicted_'+os.path.splitext(filename)[0]+'.tif'), Prediction)\n"," # io.imsave(os.path.join(savePath, 'Widefield_'+os.path.splitext(filename)[0]+'.tif'), Widefield)\n","\n"," saveAsTIF(savePath, 'Predicted_'+os.path.splitext(filename)[0], Prediction, pixel_size_hr)\n"," saveAsTIF(savePath, 'Widefield_'+os.path.splitext(filename)[0], Widefield, pixel_size_hr)\n","\n","\n"," return\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m' # white (normal)\n","\n","\n","\n","def list_files(directory, extension):\n"," return (f for f in os.listdir(directory) if f.endswith('.' + extension))\n","\n","\n","# @njit(parallel=True)\n","def subPixelMaxLocalization(array, method = 'CoM', patch_size = 3):\n"," xMaxInd, yMaxInd = np.unravel_index(array.argmax(), array.shape, order='C')\n"," centralPatch = XC[(xMaxInd-patch_size):(xMaxInd+patch_size+1),(yMaxInd-patch_size):(yMaxInd+patch_size+1)]\n","\n"," if (method == 'MAX'):\n"," x0 = xMaxInd\n"," y0 = yMaxInd\n","\n"," elif (method == 'CoM'):\n"," x0 = 0\n"," y0 = 0\n"," S = 0\n"," for xy in range(patch_size*patch_size):\n"," y = math.floor(xy/patch_size)\n"," x = xy - y*patch_size\n"," x0 += x*array[x,y]\n"," y0 += y*array[x,y]\n"," S = array[x,y]\n"," \n"," x0 = x0/S - patch_size/2 + xMaxInd\n"," y0 = y0/S - patch_size/2 + yMaxInd\n"," \n"," elif (method == 'Radiality'):\n"," # Not implemented yet\n"," x0 = xMaxInd\n"," y0 = yMaxInd\n"," \n"," return (x0, y0)\n","\n","\n","@njit(parallel=True)\n","def correctDriftLocalization(xc_array, yc_array, frames, xDrift, yDrift):\n"," n_locs = xc_array.shape[0]\n"," xc_array_Corr = np.empty(n_locs)\n"," yc_array_Corr = np.empty(n_locs)\n"," \n"," 
for loc in prange(n_locs):\n"," xc_array_Corr[loc] = xc_array[loc] - xDrift[frames[loc]]\n"," yc_array_Corr[loc] = yc_array[loc] - yDrift[frames[loc]]\n","\n"," return (xc_array_Corr, yc_array_Corr)\n","\n","\n","print('--------------------------------')\n","print('DeepSTORM installation complete.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vu8f5NGJkJos","colab_type":"text"},"source":["\n","# **3. Generate patches for training**\n","---\n","\n","For Deep-STORM the training data can be obtained in two ways:\n","* Simulated using ThunderSTORM or other simulation tool and loaded here (**using Section 3.1.a**)\n","* Directly simulated in this notebook (**using Section 3.1.b**)\n"]},{"cell_type":"markdown","metadata":{"id":"WSV8xnlynp0l","colab_type":"text"},"source":["## **3.1.a Load training data**\n","---\n","\n","Here you can load your simulated data along with its corresponding localization file.\n","* The `pixel_size` is defined in nanometer (nm). 
"]},{"cell_type":"code","metadata":{"id":"CT6SNcfNg6j0","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Load raw data\n","\n","# Get user input\n","ImageData_path = \"\" #@param {type:\"string\"}\n","LocalizationData_path = \"\" #@param {type: \"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value:\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","if get_pixel_size_from_file:\n"," pixel_size,_,_ = getPixelSizeTIFFmetadata(ImageData_path, True)\n","\n","# load the tiff data\n","Images = io.imread(ImageData_path)\n","# get dataset dimensions\n","if len(Images.shape) == 3:\n"," (number_of_frames, M, N) = Images.shape\n","elif len(Images.shape) == 2:\n"," (M, N) = Images.shape\n"," number_of_frames = 1\n","print('Loaded images: '+str(M)+'x'+str(N)+' with '+str(number_of_frames)+' frames')\n","\n","# Interactive display of the stack\n","def scroll_in_time(frame):\n"," f=plt.figure(figsize=(6,6))\n"," plt.imshow(Images[frame-1], interpolation='nearest', cmap = 'gray')\n"," plt.title('Training source at frame = ' + str(frame))\n"," plt.axis('off');\n","\n","if number_of_frames > 1:\n"," interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","else:\n"," f=plt.figure(figsize=(6,6))\n"," plt.imshow(Images, interpolation='nearest', cmap = 'gray')\n"," plt.title('Training source')\n"," plt.axis('off');\n","\n","# Load the localization file and display the first\n","LocData = pd.read_csv(LocalizationData_path, index_col=0)\n","LocData.tail()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K9xE5GeYiks9","colab_type":"text"},"source":["## **3.1.b Simulate training data**\n","---\n","This simulation tool allows you to generate SMLM data of randomly distrubuted emitters in a field-of-view. 
\n","The assumptions are as follows:\n","\n","* Gaussian Point Spread Function (PSF) with standard deviation defined by `sigma`. The nominal value of `sigma` can be estimated using `sigma = 0.21 x Lambda / NA`, where `Lambda` is the wavelength and `NA` is the numerical aperture. \n","* Each emitter emits `n_photons` per frame, with the corresponding Poisson noise.\n","* The camera contributes Gaussian noise to the signal, with a standard deviation defined by `ReadOutNoise_ADC` (in ADC units).\n","* The `emitter_density` **is defined as the number of emitters / um^2** on any given frame. Variability in the emitter density can be applied by adjusting `emitter_density_std`. The latter parameter is the standard deviation of the normal distribution from which the density is drawn for each individual frame.\n","* The `n_photons` and `sigma` can additionally include some Gaussian variability by setting `n_photons_std` and `sigma_std`.\n","\n","Important note:\n","- All dimensions are in nanometers (e.g. 
`FOV_size` = 6400 represents a field of view of 6.4 um x 6.4 um).\n","\n"]},{"cell_type":"code","metadata":{"id":"sQyLXpEhitsg","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# ---------------------------- User input ----------------------------\n","#@markdown Run the simulation\n","#@markdown --- \n","#@markdown Camera settings: \n","FOV_size = 6400#@param {type:\"number\"}\n","pixel_size = 100#@param {type:\"number\"}\n","ADC_per_photon_conversion = 1 #@param {type:\"number\"}\n","ReadOutNoise_ADC = 4.5#@param {type:\"number\"}\n","ADC_offset = 50#@param {type:\"number\"}\n","\n","#@markdown Acquisition settings: \n","emitter_density = 6#@param {type:\"number\"}\n","emitter_density_std = 0#@param {type:\"number\"}\n","\n","number_of_frames = 20#@param {type:\"integer\"}\n","\n","sigma = 110 #@param {type:\"number\"}\n","sigma_std = 5 #@param {type:\"number\"}\n","# NA = 1.1 #@param {type:\"number\"}\n","# wavelength = 800#@param {type:\"number\"}\n","# wavelength_std = 150#@param {type:\"number\"}\n","n_photons = 2250#@param {type:\"number\"}\n","n_photons_std = 250#@param {type:\"number\"}\n","\n","\n","# ---------------------------- Variable initialisation ----------------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","print('-----------------------------------------------------------')\n","n_molecules = emitter_density*FOV_size*FOV_size/10**6\n","n_molecules_std = emitter_density_std*FOV_size*FOV_size/10**6\n","print('Number of molecules / FOV: '+str(round(n_molecules,2))+' +/- '+str((round(n_molecules_std,2))))\n","\n","# sigma = 0.21*wavelength/NA\n","# sigma_std = 0.21*wavelength_std/NA\n","# print('Gaussian PSF sigma: '+str(round(sigma,2))+' +/- '+str(round(sigma_std,2))+' nm')\n","\n","M = N = round(FOV_size/pixel_size)\n","FOV_size = M*pixel_size\n","print('Final image size: '+str(M)+'x'+str(M)+' ('+str(round(FOV_size/1000, 3))+'um x'+str(round(FOV_size/1000,3))+' 
um)')\n","\n","np.random.seed(1)\n","display_upsampling = 8 # used to display the loc map here\n","NoiseFreeImages = np.zeros((number_of_frames, M, M))\n","locImage = np.zeros((number_of_frames, display_upsampling*M, display_upsampling*N))\n","\n","frames = []\n","all_xloc = []\n","all_yloc = []\n","all_photons = []\n","all_sigmas = []\n","\n","# ---------------------------- Main simulation loop ----------------------------\n","print('-----------------------------------------------------------')\n","for f in tqdm(range(number_of_frames)):\n"," \n"," # Define the coordinates of emitters by randomly distributing them across the FOV\n"," n_mol = int(max(round(np.random.normal(n_molecules, n_molecules_std, size=1)[0]), 0))\n"," x_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n"," y_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n"," photon_array = np.random.normal(n_photons, n_photons_std, size=n_mol)\n"," sigma_array = np.random.normal(sigma, sigma_std, size=n_mol)\n"," # x_c = np.linspace(0,3000,5)\n"," # y_c = np.linspace(0,3000,5)\n","\n"," all_xloc += x_c.tolist()\n"," all_yloc += y_c.tolist()\n"," frames += ((f+1)*np.ones(x_c.shape[0])).tolist()\n"," all_photons += photon_array.tolist()\n"," all_sigmas += sigma_array.tolist()\n","\n"," locImage[f] = FromLoc2Image_SimpleHistogram(x_c, y_c, image_size = (N*display_upsampling, M*display_upsampling), pixel_size = pixel_size/display_upsampling)\n","\n"," # # Get the approximated locations according to the grid pixel size\n"," # Chr_emitters = [int(max(min(round(display_upsampling*x_c[i]/pixel_size),N*display_upsampling-1),0)) for i in range(len(x_c))]\n"," # Rhr_emitters = [int(max(min(round(display_upsampling*y_c[i]/pixel_size),M*display_upsampling-1),0)) for i in range(len(y_c))]\n","\n"," # # Build Localization image\n"," # for (r,c) in zip(Rhr_emitters, Chr_emitters):\n"," # locImage[f][r][c] += 1\n","\n"," NoiseFreeImages[f] = FromLoc2Image_Erf(x_c, y_c, photon_array, sigma_array, 
image_size = (M,M), pixel_size = pixel_size)\n","\n","\n","# ---------------------------- Create DataFrame for localization file ----------------------------\n","# Table with localization info as dataframe output\n","LocData = pd.DataFrame()\n","LocData[\"frame\"] = frames\n","LocData[\"x [nm]\"] = all_xloc\n","LocData[\"y [nm]\"] = all_yloc\n","LocData[\"Photon #\"] = all_photons\n","LocData[\"Sigma [nm]\"] = all_sigmas\n","LocData.index += 1 # set indices to start at 1 and not 0 (same as ThunderSTORM)\n","\n","\n","# ---------------------------- Estimation of SNR ----------------------------\n","n_frames_for_SNR = 100\n","M_SNR = 10\n","x_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n","y_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n","photon_array = np.random.normal(n_photons, n_photons_std, size=n_frames_for_SNR)\n","sigma_array = np.random.normal(sigma, sigma_std, size=n_frames_for_SNR)\n","\n","SNR = np.zeros(n_frames_for_SNR)\n","for i in range(n_frames_for_SNR):\n"," # Simulate a single emitter at (x_c[i], y_c[i]) and estimate its peak SNR\n"," SingleEmitterImage = FromLoc2Image_Erf(np.array([x_c[i]]), np.array([y_c[i]]), np.array([photon_array[i]]), np.array([sigma_array[i]]), (M_SNR, M_SNR), pixel_size)\n"," Signal_photon = np.max(SingleEmitterImage)\n"," Noise_photon = math.sqrt((ReadOutNoise_ADC/ADC_per_photon_conversion)**2 + Signal_photon)\n"," SNR[i] = Signal_photon/Noise_photon\n","\n","print('SNR: '+str(round(np.mean(SNR),2))+' +/- '+str(round(np.std(SNR),2)))\n","# ---------------------------- ----------------------------\n","\n","\n","# Table with info\n","simParameters = pd.DataFrame()\n","simParameters[\"FOV size (nm)\"] = [FOV_size]\n","simParameters[\"Pixel size (nm)\"] = [pixel_size]\n","simParameters[\"ADC/photon\"] = [ADC_per_photon_conversion]\n","simParameters[\"Read-out noise (ADC)\"] = [ReadOutNoise_ADC]\n","simParameters[\"Constant offset (ADC)\"] = [ADC_offset]\n","\n","simParameters[\"Emitter density (emitters/um^2)\"] = 
[emitter_density]\n","simParameters[\"STD of emitter density (emitters/um^2)\"] = [emitter_density_std]\n","simParameters[\"Number of frames\"] = [number_of_frames]\n","# simParameters[\"NA\"] = [NA]\n","# simParameters[\"Wavelength (nm)\"] = [wavelength]\n","# simParameters[\"STD of wavelength (nm)\"] = [wavelength_std]\n","simParameters[\"Sigma (nm))\"] = [sigma]\n","simParameters[\"STD of Sigma (nm))\"] = [sigma_std]\n","simParameters[\"Number of photons\"] = [n_photons]\n","simParameters[\"STD of number of photons\"] = [n_photons_std]\n","simParameters[\"SNR\"] = [np.mean(SNR)]\n","simParameters[\"STD of SNR\"] = [np.std(SNR)]\n","\n","\n","# ---------------------------- Finish simulation ----------------------------\n","# Calculating the noisy image\n","Images = ADC_per_photon_conversion * np.random.poisson(NoiseFreeImages) + ReadOutNoise_ADC * np.random.normal(size = (number_of_frames, M, N)) + ADC_offset\n","Images[Images <= 0] = 0\n","\n","# Convert to 16-bit or 32-bits integers\n","if Images.max() < (2**16-1):\n"," Images = Images.astype(np.uint16)\n","else:\n"," Images = Images.astype(np.uint32)\n","\n","\n","# ---------------------------- Display ----------------------------\n","# Displaying the time elapsed for simulation\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds,1),\"sec(s)\")\n","\n","\n","# Interactively display the results using Widgets\n","def scroll_in_time(frame):\n"," f = plt.figure(figsize=(18,6))\n"," plt.subplot(1,3,1)\n"," plt.imshow(locImage[frame-1], interpolation='bilinear', vmin = 0, vmax=0.1)\n"," plt.title('Localization image')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(NoiseFreeImages[frame-1], interpolation='nearest', cmap='gray')\n"," plt.title('Noise-free simulation')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(Images[frame-1], 
interpolation='nearest', cmap='gray')\n"," plt.title('Noisy simulation')\n"," plt.axis('off');\n","\n","interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","\n","# Display the last rows of the dataframe with localizations\n","LocData.tail()\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Pz7RfSuoeJeq","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ---\n","# @markdown #Play this cell to save the simulated stack\n","# @markdown ####Please select a path to the folder where to save the simulated data. It is not necessary to save the data to run the training, but keeping the simulated data for your own records can be useful to check its validity.\n","Save_path = \"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(Save_path):\n"," os.makedirs(Save_path)\n"," print('Folder created.')\n","else:\n"," print('Training data already exists in folder: Data overwritten.')\n","\n","saveAsTIF(Save_path, 'SimulatedDataset', Images, pixel_size)\n","# io.imsave(os.path.join(Save_path, 'SimulatedDataset.tif'),Images)\n","LocData.to_csv(os.path.join(Save_path, 'SimulatedDataset.csv'))\n","simParameters.to_csv(os.path.join(Save_path, 'SimulatedParameters.csv'))\n","print('Training dataset saved.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K_8e3kE-JhVY","colab_type":"text"},"source":["## **3.2. Generate training patches**\n","---\n","\n","Training patches need to be created from the training data generated above. \n","* The `patch_size` needs to give sufficient contextual information and for most cases a `patch_size` of 26 (corresponding to patches of 26x26 pixels) works fine. **DEFAULT: 26**\n","* The `upsampling_factor` defines the effective magnification of the final super-resolved image compared to the input image (this is called magnification in ThunderSTORM). 
This is used to generate the super-resolved patches as the target dataset. Using an `upsampling_factor` of 16 will require the use of more memory and it may be necessary to decrease the `patch_size` to 16 for example. **DEFAULT: 8**\n","* The `num_patches_per_frame` defines the number of patches extracted from each frame generated in section 3.1. **DEFAULT: 500**\n","* The `min_number_of_emitters_per_patch` defines the minimum number of emitters that need to be present in the patch to be a valid patch. An empty patch does not contain useful information for the network to learn from. **DEFAULT: 7**\n","* The `max_num_patches` defines the maximum number of patches to generate. Fewer may be generated depending on how many patches are rejected and how many frames are available. **DEFAULT: 10000**\n","* The `gaussian_sigma` defines the Gaussian standard deviation (in magnified pixels) applied to generate the super-resolved target image. **DEFAULT: 1**\n","* The `L2_weighting_factor` is a normalization factor used in the loss function. It helps balance the loss from the L2 norm. When using higher densities, this factor should be decreased and vice-versa. This factor can be automatically calculated using an empirical formula. 
**DEFAULT: 100**\n","\n"]},{"cell_type":"code","metadata":{"id":"AsNx5KzcFNvC","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ## **Provide patch parameters**\n","\n","\n","# -------------------- User input --------------------\n","patch_size = 26 #@param {type:\"integer\"}\n","upsampling_factor = 8 #@param [\"4\", \"8\", \"16\"] {type:\"raw\"}\n","num_patches_per_frame = 500#@param {type:\"integer\"}\n","min_number_of_emitters_per_patch = 7#@param {type:\"integer\"}\n","max_num_patches = 10000#@param {type:\"integer\"}\n","gaussian_sigma = 1#@param {type:\"integer\"}\n","\n","#@markdown Estimate the optimal normalization factor automatically?\n","Automatic_normalization = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, it will use the following value:\n","L2_weighting_factor = 100 #@param {type:\"number\"}\n","\n","\n","# -------------------- Prepare variables --------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# Initialize some parameters\n","pixel_size_hr = pixel_size/upsampling_factor # in nm\n","n_patches = min(number_of_frames*num_patches_per_frame, max_num_patches)\n","patch_size = patch_size*upsampling_factor\n","\n","# Dimensions of the high-res grid\n","Mhr = upsampling_factor*M # in pixels\n","Nhr = upsampling_factor*N # in pixels\n","\n","# Initialize the training patches and labels\n","patches = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","spikes = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","heatmaps = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","\n","# Run over all frames and construct the training examples\n","k = 1 # current patch count\n","skip_counter = 0 # number of dataset skipped due to low density\n","id_start = 0 # id position in LocData for current frame\n","print('Generating '+str(n_patches)+' patches of '+str(patch_size)+'x'+str(patch_size))\n","\n","n_locs = 
len(LocData.index)\n","print('Total number of localizations: '+str(n_locs))\n","density = n_locs/(M*N*number_of_frames*(0.001*pixel_size)**2)\n","print('Density: '+str(round(density,2))+' locs/um^2')\n","n_locs_per_patch = patch_size**2*density\n","\n","if Automatic_normalization:\n"," # This empirical formulae attempts to balance the loss L2 function between the background and the bright spikes\n"," # A value of 100 was originally chosen to balance L2 for a patch size of 2.6x2.6^2 0.1um pixel size and density of 3 (hence the 20.28), at upsampling_factor = 8\n"," L2_weighting_factor = 100/math.sqrt(min(n_locs_per_patch, min_number_of_emitters_per_patch)*8**2/(upsampling_factor**2*20.28))\n"," print('Normalization factor: '+str(round(L2_weighting_factor,2)))\n","\n","# -------------------- Patch generation loop --------------------\n","\n","print('-----------------------------------------------------------')\n","for (f, thisFrame) in enumerate(tqdm(Images)):\n","\n"," # Upsample the frame\n"," upsampledFrame = np.kron(thisFrame, np.ones((upsampling_factor,upsampling_factor)))\n"," # Read all the provided high-resolution locations for current frame\n"," DataFrame = LocData[LocData['frame'] == f+1].copy()\n","\n"," # Get the approximated locations according to the high-res grid pixel size\n"," Chr_emitters = [int(max(min(round(DataFrame['x [nm]'][i]/pixel_size_hr),Nhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," Rhr_emitters = [int(max(min(round(DataFrame['y [nm]'][i]/pixel_size_hr),Mhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," id_start += len(DataFrame.index)\n","\n"," # Build Localization image\n"," LocImage = np.zeros((Mhr,Nhr))\n"," LocImage[(Rhr_emitters, Chr_emitters)] = 1\n","\n"," # Here, there's a choice between the original Gaussian (classification approach) and using the erf function\n"," HeatMapImage = L2_weighting_factor*gaussian_filter(LocImage, gaussian_sigma) \n"," # HeatMapImage = 
L2_weighting_factor*FromLoc2Image_MultiThreaded(np.array(list(DataFrame['x [nm]'])), np.array(list(DataFrame['y [nm]'])), \n"," # np.ones(len(DataFrame.index)), pixel_size_hr*gaussian_sigma*np.ones(len(DataFrame.index)), \n"," # Mhr, pixel_size_hr)\n"," \n","\n"," # Generate random position for the top left corner of the patch\n"," xc = np.random.randint(0, Mhr-patch_size, size=num_patches_per_frame)\n"," yc = np.random.randint(0, Nhr-patch_size, size=num_patches_per_frame)\n","\n"," for c in range(len(xc)):\n"," if LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size].sum() < min_number_of_emitters_per_patch:\n"," skip_counter += 1\n"," continue\n"," \n"," else:\n"," # Limit maximal number of training examples to 15k\n"," if k > max_num_patches:\n"," break\n"," else:\n"," # Assign the patches to the right part of the images\n"," patches[k-1] = upsampledFrame[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," spikes[k-1] = LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," heatmaps[k-1] = HeatMapImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," k += 1 # increment current patch count\n","\n","# Remove the empty data\n","patches = patches[:k-1]\n","spikes = spikes[:k-1]\n","heatmaps = heatmaps[:k-1]\n","n_patches = k-1\n","\n","# -------------------- Failsafe --------------------\n","# Check if the size of the training set is smaller than 5k to notify user to simulate more images using ThunderSTORM\n","if ((k-1) < 5000):\n"," # W = '\\033[0m' # white (normal)\n"," # R = '\\033[31m' # red\n"," print(bcolors.WARNING+'!! WARNING: Training set size is below 5K - Consider simulating more images in ThunderSTORM. 
!!'+bcolors.NORMAL)\n","\n","\n","\n","# -------------------- Displays --------------------\n","print('Number of patches skipped due to low density: '+str(skip_counter))\n","# dataSize = int((getsizeof(patches)+getsizeof(heatmaps)+getsizeof(spikes))/(1024*1024)) #rounded in MB\n","# print('Size of patches: '+str(dataSize)+' MB')\n","print(str(n_patches)+' patches were generated.')\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# Display patches interactively with a slider\n","def scroll_patches(patch):\n"," f = plt.figure(figsize=(16,6))\n"," plt.subplot(1,3,1)\n"," plt.imshow(patches[patch-1], interpolation='nearest', cmap='gray')\n"," plt.title('Raw data (frame #'+str(patch)+')')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(heatmaps[patch-1], interpolation='nearest')\n"," plt.title('Heat map')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(spikes[patch-1], interpolation='nearest')\n"," plt.title('Localization map')\n"," plt.axis('off');\n","\n","interact(scroll_patches, patch=widgets.IntSlider(min=1, max=patches.shape[0], step=1, value=0, continuous_update=False));\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DSjXFMevK7Iz","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"hVeyKU0MdAPx","colab_type":"text"},"source":["## **4.1. Select your paths and parameters**\n","\n","---\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. 
Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained for. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for ~100 epochs. Evaluate the performance after training (see section 5). **Default value: 80**\n","\n","**`batch_size`:** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps per epoch. **If this value is set to 0**, this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patches / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 30** \n","\n","**`initial_learning_rate`:** This parameter represents the initial value to be used as the learning rate in the optimizer. 
**Default value: 0.001**"]},{"cell_type":"code","metadata":{"id":"oa5cDZ7f_PF6","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###Path to training images and parameters\n","\n","model_path = \"\" #@param {type: \"string\"} \n","model_name = \"\" #@param {type: \"string\"} \n","number_of_epochs = 80#@param {type:\"integer\"}\n","batch_size = 16#@param {type:\"integer\"}\n","\n","number_of_steps = 0#@param {type:\"integer\"}\n","percentage_validation = 30 #@param {type:\"number\"}\n","initial_learning_rate = 0.001 #@param {type:\"number\"}\n","\n","\n","percentage_validation /= 100\n","if number_of_steps == 0: \n"," number_of_steps = int((1-percentage_validation)*n_patches/batch_size)\n"," print('Number of steps: '+str(number_of_steps))\n","\n","# Pretrained model path initialised here so next cell does not need to be run\n","h5_file_path = ''\n","Use_pretrained_model = False\n","\n","if not ('patches' in locals()):\n"," # W = '\\033[0m' # white (normal)\n"," # R = '\\033[31m' # red\n"," print(bcolors.WARNING+'!! WARNING: No patches were found in memory currently. !!'+bcolors.NORMAL)\n","\n","Save_path = os.path.join(model_path, model_name)\n","if os.path.exists(Save_path):\n"," print(bcolors.WARNING+'The model folder already exists and will be overwritten.'+bcolors.NORMAL)\n","\n","print('-----------------------------')\n","print('Training parameters set.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"WIyEvQBWLp9n","colab_type":"text"},"source":["\n","## **4.2. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a Deep-STORM 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. 
**You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"oHL5g0w8LqR0","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," 
wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.hdf5 pretrained model does not exist'+bcolors.NORMAL)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used 
instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead.'+bcolors.NORMAL)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+bcolors.NORMAL)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print('No pretrained network will be used.')\n"," h5_file_path = ''\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"OADNcie-LHxA","colab_type":"text"},"source":["## **4.4. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! 
If training takes longer than 12 hours, please decrease the number of epochs or number of patches."]},{"cell_type":"code","metadata":{"id":"qDgMu_mAK8US","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start training\n","\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(Save_path):\n"," shutil.rmtree(Save_path)\n","\n","# Create the model folder!\n","os.makedirs(Save_path)\n","\n","# Let's go !\n","train_model(patches, heatmaps, Save_path, \n"," steps_per_epoch=number_of_steps, epochs=number_of_epochs, batch_size=batch_size,\n"," upsampling_factor = upsampling_factor,\n"," validation_split = percentage_validation,\n"," initial_learning_rate = initial_learning_rate, \n"," pretrained_model_path = h5_file_path,\n"," L2_weighting_factor = L2_weighting_factor)\n","\n","# # Show info about the GPU memory useage\n","# !nvidia-smi\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"CHVTRjEOLRDH","colab_type":"text"},"source":["##**4.5. 
Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in section 4.1. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"4N7-ShZpLhwr","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend performing quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"JDRsm7uKoBa-","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","#@markdown #####During training, the model files are automatically saved inside a folder named after the parameter `model_name` (see section 4.1). Provide the path to this folder as `QC_model_path`. \n","\n","QC_model_path = \"\" #@param {type:\"string\"}\n","\n","if (Use_the_current_trained_model): \n"," QC_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(QC_model_path):\n"," print(\"The \"+os.path.basename(QC_model_path)+\" model will be evaluated\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Gw7KaHZUoHC4","colab_type":"text"},"source":["## **5.1. 
Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the size of the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"qUc-JMOcoGNZ","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(QC_model_path,'Quality Control/training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(QC_model_path,'Quality Control/lossCurvePlots.png'))\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"32eNQjFioQkY","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps and calculate total SSIM, NRMSE and PSNR metrics for all the images provided in the \"QC_image_folder\" using the corresponding localization data contained in \"QC_loc_folder\".\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. 
The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"dhlTnxC5lUZy","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# ------------------------ User input ------------------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","QC_image_folder = \"\" #@param{type:\"string\"}\n","QC_loc_folder = \"\" #@param{type:\"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value:\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","if get_pixel_size_from_file:\n"," pixel_size_INPUT = None\n","else:\n"," pixel_size_INPUT = pixel_size\n","\n","\n","# ------------------------ QC analysis loop over provided dataset ------------------------\n","\n","savePath = os.path.join(QC_model_path, 'Quality Control')\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(os.path.join(savePath, \"QC_metrics.csv\"), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"WF v. GT mSSIM\", \"Prediction v. GT NRMSE\",\"WF v. GT NRMSE\", \"Prediction v. GT PSNR\", \"WF v. 
GT PSNR\"])\n","\n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvWF_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvWF_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvWF_list = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n"," for (imageFilename, locFilename) in zip(list_files(QC_image_folder, 'tif'), list_files(QC_loc_folder, 'csv')):\n"," print('--------------')\n"," print(imageFilename)\n"," print(locFilename)\n","\n"," # Get the prediction\n"," batchFramePredictionLocalization(QC_image_folder, imageFilename, QC_model_path, savePath, pixel_size = pixel_size_INPUT)\n","\n"," # test_model(QC_image_folder, imageFilename, QC_model_path, savePath, display=False);\n"," thisPrediction = io.imread(os.path.join(savePath, 'Predicted_'+imageFilename))\n"," thisWidefield = io.imread(os.path.join(savePath, 'Widefield_'+imageFilename))\n","\n"," Mhr = thisPrediction.shape[0]\n"," Nhr = thisPrediction.shape[1]\n","\n"," if pixel_size_INPUT == None:\n"," pixel_size, N, M = getPixelSizeTIFFmetadata(os.path.join(QC_image_folder,imageFilename))\n","\n"," upsampling_factor = int(Mhr/M)\n"," print('Upsampling factor: '+str(upsampling_factor))\n"," pixel_size_hr = pixel_size/upsampling_factor # in nm\n","\n"," # Load the localization file and display the first\n"," LocData = pd.read_csv(os.path.join(QC_loc_folder,locFilename), index_col=0)\n","\n"," x = np.array(list(LocData['x [nm]']))\n"," y = np.array(list(LocData['y [nm]']))\n"," locImage = FromLoc2Image_SimpleHistogram(x, y, image_size = (Mhr,Nhr), pixel_size = pixel_size_hr)\n","\n"," # Remove extension from filename\n"," imageFilename_no_extension = os.path.splitext(imageFilename)[0]\n","\n"," # io.imsave(os.path.join(savePath, 'GT_image_'+imageFilename), locImage)\n"," saveAsTIF(savePath, 'GT_image_'+imageFilename_no_extension, locImage, pixel_size_hr)\n","\n"," # Normalize the 
images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm, test_prediction_norm = norm_minmse(locImage, thisPrediction, normalize_gt=True)\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm, test_wf_norm = norm_minmse(locImage, thisWidefield, normalize_gt=True)\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1., full=True)\n"," index_SSIM_GTvsWF, img_SSIM_GTvsWF = structural_similarity(test_GT_norm, test_wf_norm, data_range=1., full=True)\n","\n","\n"," # Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," # io.imsave(os.path.join(savePath,'SSIM_GTvsPrediction_'+imageFilename),img_SSIM_GTvsPrediction_32bit)\n"," saveAsTIF(savePath,'SSIM_GTvsPrediction_'+imageFilename_no_extension, img_SSIM_GTvsPrediction_32bit, pixel_size_hr)\n","\n","\n"," img_SSIM_GTvsWF_32bit = np.float32(img_SSIM_GTvsWF)\n"," # io.imsave(os.path.join(savePath,'SSIM_GTvsWF_'+imageFilename),img_SSIM_GTvsWF_32bit)\n"," saveAsTIF(savePath,'SSIM_GTvsWF_'+imageFilename_no_extension, img_SSIM_GTvsWF_32bit, pixel_size_hr)\n","\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsWF = np.sqrt(np.square(test_GT_norm - test_wf_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," # io.imsave(os.path.join(savePath,'RSE_GTvsPrediction_'+imageFilename),img_RSE_GTvsPrediction_32bit)\n"," saveAsTIF(savePath,'RSE_GTvsPrediction_'+imageFilename_no_extension, img_RSE_GTvsPrediction_32bit, pixel_size_hr)\n","\n"," img_RSE_GTvsWF_32bit = np.float32(img_RSE_GTvsWF)\n"," # 
io.imsave(os.path.join(savePath,'RSE_GTvsWF_'+imageFilename),img_RSE_GTvsWF_32bit)\n"," saveAsTIF(savePath,'RSE_GTvsWF_'+imageFilename_no_extension, img_RSE_GTvsWF_32bit, pixel_size_hr)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsWF = np.sqrt(np.mean(img_RSE_GTvsWF))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsWF = psnr(test_GT_norm,test_wf_norm,data_range=1.0)\n","\n"," writer.writerow([imageFilename,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsWF),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsWF),str(PSNR_GTvsPrediction), str(PSNR_GTvsWF)])\n","\n"," # Collect values to display in dataframe output\n"," file_name_list.append(imageFilename)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvWF_list.append(index_SSIM_GTvsWF)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvWF_list.append(NRMSE_GTvsWF)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvWF_list.append(PSNR_GTvsWF)\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = file_name_list)\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list\n","pdResults[\"Wide-field v. GT mSSIM\"] = mSSIM_GvWF_list\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list\n","pdResults[\"Wide-field v. GT NRMSE\"] = NRMSE_GvWF_list\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list\n","pdResults[\"Wide-field v. 
GT PSNR\"] = PSNR_GvWF_list\n","\n","\n","# ------------------------ Display ------------------------\n","\n","print('--------------------------------------------')\n","@interact\n","def show_QC_results(file = list_files(QC_image_folder, 'tif')):\n","\n"," plt.figure(figsize=(15,15))\n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = io.imread(os.path.join(savePath, 'GT_image_'+file))\n"," plt.imshow(img_GT, norm = simple_norm(img_GT, percent = 99.5))\n"," plt.title('Target',fontsize=15)\n","\n"," # Wide-field\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(savePath, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield',fontsize=15)\n","\n"," #Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(savePath, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Prediction',fontsize=15)\n","\n"," #Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n"," #SSIM between GT and Source\n"," plt.subplot(3,3,5)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsWF = io.imread(os.path.join(savePath, 'SSIM_GTvsWF_'+file))\n"," imSSIM_GTvsWF = plt.imshow(img_SSIM_GTvsWF, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imSSIM_GTvsWF,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Wide-field v. 
GT mSSIM\"],3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsPrediction = io.imread(os.path.join(savePath, 'SSIM_GTvsPrediction_'+file))\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Prediction v. GT mSSIM\"],3)),fontsize=14)\n","\n"," #Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsWF = io.imread(os.path.join(savePath, 'RSE_GTvsWF_'+file))\n"," imRSE_GTvsWF = plt.imshow(img_RSE_GTvsWF, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsWF,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Wide-field v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Wide-field v. 
GT PSNR\"],3)),fontsize=14)\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsPrediction = io.imread(os.path.join(savePath, 'RSE_GTvsPrediction_'+file))\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Prediction v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Prediction v. GT PSNR\"],3)),fontsize=14)\n","\n","print('--------------------------------------------')\n","pdResults.head()\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yTRou0izLjhd","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"eAf8aBDmWTx7"},"source":["## **6.1 Generate image prediction and localizations from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the found localizations csv.\n","\n","**`batch_size`:** This parameter determines how many frames are processed by any single pass on the GPU. A higher `batch_size` will make the prediction faster but will use more GPU memory. If an OutOfMemory (OOM) error occurs, decrease the `batch_size`. **DEFAULT: 4**\n","\n","**`threshold`:** This parameter determines the threshold for local maxima finding. The value is expected to reside in the range **[0,1]**. A higher `threshold` will result in fewer localizations. **DEFAULT: 0.1**\n","\n","**`neighborhood_size`:** This parameter determines the size of the neighborhood within which the prediction needs to be a local maximum, in recovery pixels (CCD pixel/upsampling_factor). A high `neighborhood_size` will make the prediction slower and potentially discard nearby localizations. **DEFAULT: 3**\n","\n","**`use_local_average`:** This parameter determines whether to locally average the prediction in a 3x3 neighborhood to get the final localizations. If set to **True** it will make inference slightly slower depending on the size of the FOV. 
**DEFAULT: True**\n"]},{"cell_type":"code","metadata":{"id":"7qn06T_A0lxf","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# ------------------------------- User input -------------------------------\n","#@markdown ### Data parameters\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value (in nm):\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","#@markdown ### Model parameters\n","#@markdown Do you want to use the model you just trained?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, please provide path to the model folder below\n","prediction_model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ### Prediction parameters\n","batch_size = 4#@param {type:\"integer\"}\n","\n","#@markdown ### Post processing parameters\n","threshold = 0.1#@param {type:\"number\"}\n","neighborhood_size = 3#@param {type:\"integer\"}\n","#@markdown Do you want to locally average the model output with CoG estimator ?\n","use_local_average = True #@param {type:\"boolean\"}\n","\n","\n","if get_pixel_size_from_file:\n"," pixel_size = None\n","\n","if (Use_the_current_trained_model): \n"," prediction_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(prediction_model_path):\n"," print(\"The \"+os.path.basename(prediction_model_path)+\" model will be used.\")\n","else:\n"," print(bcolors.WARNING+'!! 
WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n","\n","# inform user whether local averaging is being used\n","if use_local_average == True: \n"," print('Using local averaging')\n","\n","if not os.path.exists(Result_folder):\n"," print('Result folder was created.')\n"," os.makedirs(Result_folder)\n","\n","\n","# ------------------------------- Run predictions -------------------------------\n","\n","start = time.time()\n","#%% This script tests the trained fully convolutional network based on the \n","# saved training weights, and normalization created using train_model.\n","\n","if os.path.isdir(Data_folder): \n"," for filename in list_files(Data_folder, 'tif'):\n"," # run the testing/reconstruction process\n"," print(\"------------------------------------\")\n"," print(\"Running prediction on: \"+ filename)\n"," batchFramePredictionLocalization(Data_folder, filename, prediction_model_path, Result_folder, \n"," batch_size, \n"," threshold, \n"," neighborhood_size, \n"," use_local_average,\n"," pixel_size = pixel_size)\n","\n","elif os.path.isfile(Data_folder):\n"," batchFramePredictionLocalization(os.path.dirname(Data_folder), os.path.basename(Data_folder), prediction_model_path, Result_folder, \n"," batch_size, \n"," threshold, \n"," neighborhood_size, \n"," use_local_average, \n"," pixel_size = pixel_size)\n","\n","\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","# ------------------------------- Interactive display 
-------------------------------\n","\n","print('--------------------------------------------------------------------')\n","print('---------------------------- Previews ------------------------------')\n","print('--------------------------------------------------------------------')\n","\n","if os.path.isdir(Data_folder): \n"," @interact\n"," def show_QC_results(file = list_files(Data_folder, 'tif')):\n","\n"," plt.figure(figsize=(15,7.5))\n"," # Wide-field\n"," plt.subplot(1,2,1)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield', fontsize=15)\n"," # Prediction\n"," plt.subplot(1,2,2)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Predicted',fontsize=15)\n","\n","if os.path.isfile(Data_folder):\n","\n"," plt.figure(figsize=(15,7.5))\n"," # Wide-field\n"," plt.subplot(1,2,1)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+os.path.basename(Data_folder)))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield', fontsize=15)\n"," # Prediction\n"," plt.subplot(1,2,2)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+os.path.basename(Data_folder)))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Predicted',fontsize=15)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZekzexaPmzFZ","colab_type":"text"},"source":["## **6.2 Drift correction**\n","---\n","\n","The visualization above is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. The display is a preview without any drift correction applied. 
This section performs drift correction using cross-correlation between time bins to estimate the drift.\n","\n","**`Loc_file_path`:** is the path to the localization file to use for visualization.\n","\n","**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n","\n","**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the image reconstructions used for the Drift Correction estimation (in **nm**). A smaller pixel size will be more precise but will take longer to compute. **DEFAULT: 20**\n","\n","**`number_of_bins`:** This parameter defines how many temporal bins are used across the full dataset. All localizations in each bin are used to build an image. This image is used to find the drift with respect to the image obtained from the very first bin. A typical value would correspond to about 500 frames per bin. **DEFAULT: Total number of frames / 500**\n","\n","**`polynomial_fit_degree`:** The drift obtained for each temporal bin needs to be interpolated to every single frame. This is performed by polynomial fit, the degree of which is defined here. **DEFAULT: 4**\n","\n"," The drift-corrected localization data is automatically saved in the `save_path` folder."]},{"cell_type":"code","metadata":{"id":"hYtP_vh6mzUP","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Data parameters\n","Loc_file_path = \"\" #@param {type:\"string\"}\n","# @markdown Provide information about original data. 
Get the info automatically from the raw data?\n","Get_info_from_file = True #@param {type:\"boolean\"}\n","# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n","original_image_path = \"\" #@param {type:\"string\"}\n","# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n","image_width = 256#@param {type:\"integer\"}\n","image_height = 256#@param {type:\"integer\"}\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","# @markdown ##Drift correction parameters\n","visualization_pixel_size = 20#@param {type:\"number\"}\n","number_of_bins = 50#@param {type:\"integer\"}\n","polynomial_fit_degree = 4#@param {type:\"integer\"}\n","\n","# @markdown ##Saving parameters\n","save_path = '' #@param {type:\"string\"}\n","\n","\n","# Let's go !\n","start = time.time()\n","\n","# Get info from the raw file if selected\n","if Get_info_from_file:\n"," pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n","\n","# Read the localizations in\n","LocData = pd.read_csv(Loc_file_path)\n","\n","# Calculate a few variables \n","Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n","Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n","nFrames = max(LocData['frame'])\n","x_max = max(LocData['x [nm]'])\n","y_max = max(LocData['y [nm]'])\n","image_size = (Mhr, Nhr)\n","n_locs = len(LocData.index)\n","\n","print('Image size: '+str(image_size))\n","print('Number of frames in data: '+str(nFrames))\n","print('Number of localizations in data: '+str(n_locs))\n","\n","blocksize = math.ceil(nFrames/number_of_bins)\n","print('Number of frames per block: '+str(blocksize))\n","\n","blockDataFrame = LocData[(LocData['frame'] < blocksize)].copy()\n","xc_array = blockDataFrame['x 
[nm]'].to_numpy(dtype=np.float32)\n","yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n","\n","# Preparing the Reference image\n","photon_array = np.ones(yc_array.shape[0])\n","sigma_array = np.ones(yc_array.shape[0])\n","ImageRef = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","ImagesRef = np.rot90(ImageRef, k=2)\n","\n","xDrift = np.zeros(number_of_bins)\n","yDrift = np.zeros(number_of_bins)\n","\n","filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n","\n","with open(os.path.join(save_path, filename_no_extension+\"_DriftCorrectionData.csv\"), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"Block #\", \"x-drift [nm]\",\"y-drift [nm]\"])\n","\n"," for b in tqdm(range(number_of_bins)):\n","\n"," blockDataFrame = LocData[(LocData['frame'] >= (b*blocksize)) & (LocData['frame'] < ((b+1)*blocksize))].copy()\n"," xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n"," yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n","\n"," photon_array = np.ones(yc_array.shape[0])\n"," sigma_array = np.ones(yc_array.shape[0])\n"," ImageBlock = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n"," XC = fftconvolve(ImagesRef, ImageBlock, mode = 'same')\n"," yDrift[b], xDrift[b] = subPixelMaxLocalization(XC, method = 'CoM')\n","\n"," # saveAsTIF(save_path, 'ImageBlock'+str(b), ImageBlock, visualization_pixel_size)\n"," # saveAsTIF(save_path, 'XCBlock'+str(b), XC, visualization_pixel_size)\n"," writer.writerow([str(b), str((xDrift[b]-xDrift[0])*visualization_pixel_size), str((yDrift[b]-yDrift[0])*visualization_pixel_size)])\n","\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, 
seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","print('Fitting drift data...')\n","bin_number = np.arange(number_of_bins)*blocksize + blocksize/2\n","xDrift = (xDrift-xDrift[0])*visualization_pixel_size\n","yDrift = (yDrift-yDrift[0])*visualization_pixel_size\n","\n","xDriftCoeff = np.polyfit(bin_number, xDrift, polynomial_fit_degree)\n","yDriftCoeff = np.polyfit(bin_number, yDrift, polynomial_fit_degree)\n","\n","xDriftFit = np.poly1d(xDriftCoeff)\n","yDriftFit = np.poly1d(yDriftCoeff)\n","bins = np.arange(nFrames)\n","xDriftInterpolated = xDriftFit(bins)\n","yDriftInterpolated = yDriftFit(bins)\n","\n","\n","# ------------------ Displaying the image results ------------------\n","\n","plt.figure(figsize=(15,10))\n","plt.plot(bin_number,xDrift, 'r+', label='x-drift')\n","plt.plot(bin_number,yDrift, 'b+', label='y-drift')\n","plt.plot(bins,xDriftInterpolated, 'r-', label='x-drift (fit)')\n","plt.plot(bins,yDriftInterpolated, 'b-', label='y-drift (fit)')\n","plt.title('Cross-correlation estimated drift')\n","plt.ylabel('Drift [nm]')\n","plt.xlabel('Bin number')\n","plt.legend();\n","\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\", hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","# ------------------ Actual drift correction -------------------\n","\n","print('Correcting localization data...')\n","xc_array = LocData['x [nm]'].to_numpy(dtype=np.float32)\n","yc_array = LocData['y [nm]'].to_numpy(dtype=np.float32)\n","frames = LocData['frame'].to_numpy(dtype=np.int32)\n","\n","\n","xc_array_Corr, yc_array_Corr = correctDriftLocalization(xc_array, yc_array, frames, xDriftInterpolated, yDriftInterpolated)\n","ImageRaw = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = 
visualization_pixel_size)\n","ImageCorr = FromLoc2Image_SimpleHistogram(xc_array_Corr, yc_array_Corr, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n","\n","# ------------------ Displaying the image results ------------------\n","plt.figure(figsize=(15,7.5))\n","# Raw\n","plt.subplot(1,2,1)\n","plt.axis('off')\n","plt.imshow(ImageRaw, norm = simple_norm(ImageRaw, percent = 99.5))\n","plt.title('Raw', fontsize=15);\n","# Corrected\n","plt.subplot(1,2,2)\n","plt.axis('off')\n","plt.imshow(ImageCorr, norm = simple_norm(ImageCorr, percent = 99.5))\n","plt.title('Corrected',fontsize=15);\n","\n","\n","# ------------------ Table with info -------------------\n","driftCorrectedLocData = pd.DataFrame()\n","driftCorrectedLocData['frame'] = frames\n","driftCorrectedLocData['x [nm]'] = xc_array_Corr\n","driftCorrectedLocData['y [nm]'] = yc_array_Corr\n","driftCorrectedLocData['confidence [a.u]'] = LocData['confidence [a.u]']\n","\n","driftCorrectedLocData.to_csv(os.path.join(save_path, filename_no_extension+'_DriftCorrected.csv'))\n","print('-------------------------------')\n","print('Corrected localizations saved.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"mzOuc-V7rB-r","colab_type":"text"},"source":["## **6.3 Visualization of the localizations**\n","---\n","\n","\n","The visualization in section 6.1 is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. This section performs visualization of the result by plotting the localizations as a simple histogram.\n","\n","**`Loc_file_path`:** is the path to the localization file to use for visualization.\n","\n","**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n","\n","**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the final image reconstruction (in **nm**). 
**DEFAULT: 10**\n","\n","**`visualization_mode`:** This parameter defines what visualization method is used to visualize the final image. NOTES: The Integrated Gaussian can be quite slow. **DEFAULT: Simple histogram.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"876yIXnqq-nW","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Data parameters\n","Use_current_drift_corrected_localizations = True #@param {type:\"boolean\"}\n","# @markdown Otherwise provide a localization file path\n","Loc_file_path = \"\" #@param {type:\"string\"}\n","# @markdown Provide information about original data. Get the info automatically from the raw data?\n","Get_info_from_file = True #@param {type:\"boolean\"}\n","# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n","original_image_path = \"\" #@param {type:\"string\"}\n","# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n","image_width = 256#@param {type:\"integer\"}\n","image_height = 256#@param {type:\"integer\"}\n","pixel_size = 100#@param {type:\"number\"}\n","\n","# @markdown ##Visualization parameters\n","visualization_pixel_size = 10#@param {type:\"number\"}\n","visualization_mode = \"Simple histogram\" #@param [\"Simple histogram\", \"Integrated Gaussian (SLOW!)\"]\n","\n","if not Use_current_drift_corrected_localizations:\n"," filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n","\n","\n","if Get_info_from_file:\n"," pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n","\n","if Use_current_drift_corrected_localizations:\n"," LocData = driftCorrectedLocData\n","else:\n"," LocData = pd.read_csv(Loc_file_path)\n","\n","Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n","Nhr = 
int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n","\n","\n","nFrames = max(LocData['frame'])\n","x_max = max(LocData['x [nm]'])\n","y_max = max(LocData['y [nm]'])\n","image_size = (Mhr, Nhr)\n","\n","print('Image size: '+str(image_size))\n","print('Number of frames in data: '+str(nFrames))\n","print('Number of localizations in data: '+str(len(LocData.index)))\n","\n","xc_array = LocData['x [nm]'].to_numpy()\n","yc_array = LocData['y [nm]'].to_numpy()\n","if (visualization_mode == 'Simple histogram'):\n"," locImage = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","elif (visualization_mode == 'Shifted histogram'):\n"," print(bcolors.WARNING+'Method not implemented yet!'+bcolors.NORMAL)\n"," locImage = np.zeros(image_size)\n","elif (visualization_mode == 'Integrated Gaussian (SLOW!)'):\n"," photon_array = np.ones(xc_array.shape)\n"," sigma_array = np.ones(xc_array.shape)\n"," locImage = FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# Display\n","plt.figure(figsize=(20,10))\n","plt.axis('off')\n","# plt.imshow(locImage, cmap='gray');\n","plt.imshow(locImage, norm = simple_norm(locImage, percent = 99.5));\n","\n","\n","LocData.head()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"PdOhWwMn1zIT","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ---\n","# @markdown #Play this cell to save the visualization\n","# @markdown ####Please select a path to the folder where to save the visualization.\n","save_path = 
\"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(save_path):\n"," os.makedirs(save_path)\n"," print('Folder created.')\n","\n","saveAsTIF(save_path, filename_no_extension+'_Visualization', locImage, visualization_pixel_size)\n","print('Image saved.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1EszIF4Dkz_n","colab_type":"text"},"source":["## **6.4. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UgN-NooKk3nV","colab_type":"text"},"source":["\n","#**Thank you for using Deep-STORM 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Noise2VOID_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Noise2VOID_2D_ZeroCostDL4Mic.ipynb index 469219a5..b1c93e14 100755 --- a/Colab_notebooks/Noise2VOID_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Noise2VOID_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"Noise2Void_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hzAI0joLETcG5sI2Qvo8AKDr0TWRKySJ","timestamp":1587653755731},{"file_id":"1QFcz4NnQv4rMwDNl7AzHajN-Ola9sUFW","timestamp":1586411847878},{"file_id":"12UDRQ7abcnXcf5FctR9IUStgCpBiQWn7","timestamp":1584466922281},{"file_id":"1zXCn3A39GI1MCnXK_g_Z-AWh9vkB0YhU","timestamp":1583244415636}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"display_name":"Python 
3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.9"}},"cells":[{"cell_type":"markdown","metadata":{"colab_type":"text","id":"IkSguVy8Xv83"},"source":["# **Noise2Void (2D)**\n","\n","---\n","\n"," Noise2Void is a deep-learning method that can be used to denoise many types of images, including microscopy images, and was originally published by [Krull *et al.* on arXiv](https://arxiv.org/abs/1811.10980). It allows denoising of image data in a self-supervised manner, therefore high-quality, low noise equivalent images are not necessary to train this network. This is performed by \"masking\" a random subset of pixels in the noisy image and training the network to predict the values in these pixels. The resulting output is a denoised version of the image. Noise2Void is based on the popular U-Net network architecture, adapted from [CARE](https://www.nature.com/articles/s41592-018-0216-7).\n","\n"," **This particular notebook enables self-supervised denoising of 2D datasets. If you are interested in 3D datasets, you should use the Noise2Void 3D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper:\n","\n","**Noise2Void - Learning Denoising from Single Noisy Images**\n","from Krull *et al.* published on arXiv in 2018 (https://arxiv.org/abs/1811.10980)\n","\n","And source code found in: https://github.com/juglab/n2v\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. 
Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"gKDLkLWUd-YX"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For Noise2Void to train, it only requires a single noisy image but multiple images can be used. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","Please note that you currently can **only use .tif files!**\n","\n","**We strongly recommend that you generate high signal-to-noise ratio versions of your noisy images (Quality control dataset). These images can be used to assess the quality of your trained model**. 
The quality control assessment can be done directly in this notebook.\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," - **Training dataset**\n"," - **Quality control dataset** (Optional but recommended)\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif \n"," - **Data to be predicted** \n"," - Results\n","\n","\n","The **Results** folder will contain the processed images, the trained model and the network parameters as a csv file. Your original images remain unmodified.\n","\n","---\n","**Important note**\n","\n","- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"cbTknRcviyT7"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"DMNHVZfHmbKb"},"source":["## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"h5i5CS2bSmZr","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"n3B3meGTbYVi"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste it into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of the notebook."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"01Djr8v-5pPk","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in to your Google Account. \n","\n","#@markdown * Copy the authorization code. 
\n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"n4yWFoJNnoin"},"source":["# **2. Install Noise2Void and dependencies**\n","---"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"fq21zJVFNASx","colab":{}},"source":["#@markdown ##Install Noise2Void and dependencies\n","\n","# Here we enable Tensorflow 1. \n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","\n","# Here we install Noise2Void and other required packages\n","!pip install n2v\n","!pip install wget\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","print(\"Noise2Void installed.\")\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to N2V -------\n","from n2v.models import N2VConfig, N2V\n","from csbdeep.utils import plot_history\n","from n2v.utils.n2v_utils import manipulate_val_data\n","from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n","from csbdeep.io import save_tiff_imagej_compatible\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from 
skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"Kbn9_JdqnNnK","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source`:** This is the path to the folder containing your training images (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files, right-click on the folder, select **Copy path** and paste the path into the corresponding box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input the number of epochs (rounds) for which the network will be trained. 
Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 4.3.). **Default value: 30**\n"," \n","**`patch_size`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be between 64 and the dimensions of the image and divisible by 8. **Default value: 64**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"ewpNJ_I0Mv47","colab":{}},"source":["# create DataGenerator-object.\n","\n","datagen = N2V_DataGenerator()\n","\n","#@markdown ###Path to training image(s): \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","#compatibility to easily change the name of the parameters\n","training_images = Training_source \n","imgs = datagen.load_imgs_from_directory(directory = Training_source)\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 30#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels)\n","patch_size = 64#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 128#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n"," \n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name): \n"," print(R + \"!! 
WARNING: Folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","\n","# This will open a randomly chosen dataset input image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print(\"Image dimensions (y,x)\",x.shape)\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = True\n","\n","print(\"Parameters initiated.\")\n","\n","#Here we display one image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"STDOuNOFsTTJ","colab_type":"text"},"source":["## 
**3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"E4QW-tvYsWhX","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in the XY-plane and flipping them along the X-axis. This only works if the patches are square in XY.\n","\n"," **By default data augmentation is enabled. Disable this option if you run out of RAM during the training**.\n"," "]},{"cell_type":"code","metadata":{"id":"-Vy-vV7ssabS","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"W6pZg0KVnPzf","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be an N2V 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. 
**You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"l-EDcv3Wyvqb","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the chosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", 
pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial 
learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.' + W)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"keIQhCmOMv5S"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"PXcLuX5jbNUv"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. 
to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"rBelu-LtbOTh","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","# split patches from the training images\n","Xdata = datagen.generate_patches_from_list(imgs, shape=(patch_size,patch_size), augment=Use_Data_augmentation)\n","shape_of_Xdata = Xdata.shape\n","# create a threshold (percentage_validation % of the patches are used for validation)\n","threshold = int(shape_of_Xdata[0]*(percentage_validation/100))\n","# split the patches into training patches and validation patches\n","X = Xdata[threshold:]\n","X_val = Xdata[:threshold]\n","print(Xdata.shape[0],\"patches created.\")\n","print(threshold,\"patch images for validation (\",percentage_validation,\"%).\")\n","# X already excludes the validation patches, so its length is the training count\n","print(X.shape[0],\"patch images for training.\")\n","%memit\n","\n","#Here we automatically define number_of_steps as a function of the training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate is set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# create a Config object\n","config = N2VConfig(X, unet_kern_size=3, \n"," train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, \n"," train_loss='mse', batch_norm=True, train_batch_size=batch_size, n2v_perc_pix=0.198, \n"," n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate = initial_learning_rate)\n","\n","# Let's look at the parameters stored in the config-object.\n","vars(config)\n"," \n"," 
\n","# create network model.\n","model = N2V(config=config, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","print(\"Setup done.\")\n","print(config)\n","\n","\n","# creates a plot and shows one training patch and one validation patch.\n","plt.figure(figsize=(16,87))\n","plt.subplot(1,2,1)\n","plt.imshow(X[0,...,0], cmap='magma')\n","plt.axis('off')\n","plt.title('Training Patch');\n","plt.subplot(1,2,2)\n","plt.imshow(X_val[0,...,0], cmap='magma')\n","plt.axis('off')\n","plt.title('Validation Patch');"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Train the network**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
Another way to circumvent this is to save the parameters of the model after training and start training again from this point."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"fisJmA13Mv5e","scrolled":true,"colab":{}},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","%memit\n","\n","history = model.train(X, X_val)\n","print(\"Training done.\")\n","%memit\n","\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training_evaluation.csv file is saved (overwriting any existing file). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Vd9igRYvSnTr"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"sTMDT1u7rK9g","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"OVxLyPyPiv85","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"WZDvRjLZu-Lm"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","It is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact noise patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"vMzSP50kMv5p","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"lreUY7-SsGkI","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. 
The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n"]},{"cell_type":"code","metadata":{"id":"kjbHJHbtsg2R","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","# Activate the pretrained model. 
\n","model_training = N2V(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," predicted = model_training.predict(img, axes='YX', n_tiles=(2,1))\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(filename, predicted)\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin
Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm 
- test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT)\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, 
Test_FileList[-1]))\n","plt.imshow(img_Source)\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","plt.imshow(img_Prediction)\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"DWAhOBc7gpzN"},"source":["# **6. 
Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. Then your trained model from section 4 is activated and the predicted output is saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"KAILvLGFS2-1"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to process using the trained network.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"bl3EdYFVS7X9","colab":{}},"source":["#Activate the pretrained model. 
\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING +'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Activate the pretrained model. 
\n","config = None\n","model = N2V(config, Prediction_model_name, basedir=Prediction_model_path)\n","\n","\n","# creates a loop, creating filenames and saving them\n","print(\"Saving the images...\")\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","# The code by Lucas von Chamier.\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='YX', n_tiles=(2,1))\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='YX') \n","\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"PfTw_pQUUAqB"},"source":["## **6.2. Assess predicted output**\n","---\n","\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"jFp-0y4zT_gL","colab":{}},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","plt.axis('off');\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Predicted output')\n","plt.axis('off');"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"wgO7Ok1PBFQj"},"source":["## **6.3. 
Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"nlyPYwZu4VVS","colab_type":"text"},"source":["#**Thank you for using Noise2Void 2D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"Noise2Void_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hzAI0joLETcG5sI2Qvo8AKDr0TWRKySJ","timestamp":1587653755731},{"file_id":"1QFcz4NnQv4rMwDNl7AzHajN-Ola9sUFW","timestamp":1586411847878},{"file_id":"12UDRQ7abcnXcf5FctR9IUStgCpBiQWn7","timestamp":1584466922281},{"file_id":"1zXCn3A39GI1MCnXK_g_Z-AWh9vkB0YhU","timestamp":1583244415636}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.9"}},"cells":[{"cell_type":"markdown","metadata":{"colab_type":"text","id":"IkSguVy8Xv83"},"source":["# **Noise2Void (2D)**\n","\n","---\n","\n"," Noise2Void is a deep-learning method that can be used to denoise many types of images, including microscopy images and which was originally published by [Krull *et al.* on arXiv](https://arxiv.org/abs/1811.10980). It allows denoising of image data in a self-supervised manner, therefore high-quality, low noise equivalent images are not necessary to train this network. This is performed by \"masking\" a random subset of pixels in the noisy image and training the network to predict the values in these pixels. The resulting output is a denoised version of the image. 
Noise2Void is based on the popular U-Net network architecture, adapted from [CARE](https://www.nature.com/articles/s41592-018-0216-7).\n","\n"," **This particular notebook enables self-supervised denoising of 2D datasets. If you are interested in 3D datasets, you should use the Noise2Void 3D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). It is jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper:\n","\n","**Noise2Void - Learning Denoising from Single Noisy Images**\n","from Krull *et al.* published on arXiv in 2018 (https://arxiv.org/abs/1811.10980)\n","\n","And source code found in: https://github.com/juglab/n2v\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. 
To execute the cell, move your cursor over the `[ ]`-mark on the left side of the cell (a play button appears). Click to execute the cell. After execution is done, the animation of the play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you will find three tabs which contain, from top to bottom:\n","\n","*Table of contents* = contains the structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"gKDLkLWUd-YX"},"source":["# **0. 
Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training data and/or the data to process in your Google Drive.\n","\n","Noise2Void requires only a single noisy image for training, but multiple images can be used. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","Please note that you currently can **only use .tif files!**\n","\n","**We strongly recommend that you generate high signal-to-noise ratio versions of your noisy images (Quality control dataset). These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," - **Training dataset**\n"," - **Quality control dataset** (Optional but recommended)\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif \n"," - **Data to be predicted** \n"," - Results\n","\n","\n","The **Results** folder will contain the processed images, trained model and network parameters as a csv file. 
Your original images remain unmodified.\n","\n","---\n","**Important note**\n","\n","- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"cbTknRcviyT7"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"DMNHVZfHmbKb"},"source":["## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"h5i5CS2bSmZr","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"n3B3meGTbYVi"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of the notebook."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"01Djr8v-5pPk","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in to your Google Account. \n","\n","#@markdown * Copy the authorization code. 
\n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"n4yWFoJNnoin"},"source":["# **2. Install Noise2Void and dependencies**\n","---"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"fq21zJVFNASx","colab":{}},"source":["#@markdown ##Install Noise2Void and dependencies\n","\n","# Here we enable Tensorflow 1. \n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","\n","# Here we install Noise2Void and other required packages\n","!pip install n2v\n","!pip install wget\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","print(\"Noise2Void installed.\")\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to N2V -------\n","from n2v.models import N2VConfig, N2V\n","from csbdeep.utils import plot_history\n","from n2v.utils.n2v_utils import manipulate_val_data\n","from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n","from csbdeep.io import save_tiff_imagej_compatible\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from 
skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"Kbn9_JdqnNnK","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source`:** This is the path to the folder containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, selecting **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained for. 
Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 4.3.). **Default value: 30**\n"," \n","**`patch_size`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be between 64 and the dimensions of the image and divisible by 8. **Default value: 64**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"ewpNJ_I0Mv47","colab":{}},"source":["# create DataGenerator-object.\n","\n","datagen = N2V_DataGenerator()\n","\n","#@markdown ###Path to training image(s): \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","#compatibility to easily change the name of the parameters\n","training_images = Training_source \n","imgs = datagen.load_imgs_from_directory(directory = Training_source)\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 30#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels)\n","patch_size = 64#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 128#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n"," \n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name): \n"," print(R + \"!! 
WARNING: Folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","\n","# This will open a randomly chosen dataset input image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print(\"Image dimensions (y,x)\",x.shape)\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = True\n","\n","print(\"Parameters initiated.\")\n","\n","#Here we display one image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"STDOuNOFsTTJ","colab_type":"text"},"source":["## 
**3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"E4QW-tvYsWhX","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in the XY-plane and flipping them along the X-axis. This only works if the patches are square in XY.\n","\n"," **By default data augmentation is enabled. Disable this option if you run out of RAM during the training**.\n"," "]},{"cell_type":"code","metadata":{"id":"-Vy-vV7ssabS","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"W6pZg0KVnPzf","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be an N2V 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. 
**You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"l-EDcv3Wyvqb","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the chosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", 
pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check that the model exists ------------------------\n","# If the model path chosen does not contain a pretrained model then Use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrained model, we load the learning rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exists (compatibility with models trained in ZeroCostDL4Mic below 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial 
learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic (the default learning rate will be used)\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"keIQhCmOMv5S"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"PXcLuX5jbNUv"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. 
to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"rBelu-LtbOTh","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","# split patches from the training images\n","Xdata = datagen.generate_patches_from_list(imgs, shape=(patch_size,patch_size), augment=Use_Data_augmentation)\n","shape_of_Xdata = Xdata.shape\n","# create a threshold (percentage_validation % of the patches are used for validation)\n","threshold = int(shape_of_Xdata[0]*(percentage_validation/100))\n","# split the patches into training patches and validation patches\n","X = Xdata[threshold:]\n","X_val = Xdata[:threshold]\n","print(Xdata.shape[0],\"patches created.\")\n","print(threshold,\"patch images for validation (\",percentage_validation,\"%).\")\n","# X already excludes the validation patches, so its length is the number of training patches\n","print(X.shape[0],\"patch images for training.\")\n","%memit\n","\n","#Here we automatically define number_of_steps as a function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate is set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# create a Config object\n","config = N2VConfig(X, unet_kern_size=3, \n"," train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, \n"," train_loss='mse', batch_norm=True, train_batch_size=batch_size, n2v_perc_pix=0.198, \n"," n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate = initial_learning_rate)\n","\n","# Let's look at the parameters stored in the config-object.\n","vars(config)\n"," \n"," 
\n","# create network model.\n","model = N2V(config=config, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","print(\"Setup done.\")\n","print(config)\n","\n","\n","# creates a plot and shows one training patch and one validation patch.\n","plt.figure(figsize=(16,87))\n","plt.subplot(1,2,1)\n","plt.imshow(X[0,...,0], cmap='magma')\n","plt.axis('off')\n","plt.title('Training Patch');\n","plt.subplot(1,2,2)\n","plt.imshow(X_val[0,...,0], cmap='magma')\n","plt.axis('off')\n","plt.title('Validation Patch');"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
Another way to circumvent this limit is to save the parameters of the model after training and start training again from this point."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"fisJmA13Mv5e","scrolled":true,"colab":{}},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","%memit\n","\n","history = model.train(X, X_val)\n","print(\"Training done.\")\n","%memit\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training_evaluation.csv file is saved (overwrites the file if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Vd9igRYvSnTr"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"sTMDT1u7rK9g","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"OVxLyPyPiv85","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"WZDvRjLZu-Lm"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","It is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact noise patterns from the training data and no longer generalizes well to unseen data. In this case the size of the training dataset has to be increased."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"vMzSP50kMv5p","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"lreUY7-SsGkI","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. 
The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n"]},{"cell_type":"code","metadata":{"id":"kjbHJHbtsg2R","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","# Activate the pretrained model. 
\n","model_training = N2V(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","\n","print('Number of test datasets found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","# (use model_training, the model loaded above for QC, rather than the model from the current training session)\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," predicted = model_training.predict(img, axes='YX', n_tiles=(2,1))\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(filename, predicted)\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin 
Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm 
- test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT)\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, 
Test_FileList[-1]))\n","plt.imshow(img_Source)\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","plt.imshow(img_Prediction)\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"DWAhOBc7gpzN"},"source":["# **6. 
Using the trained model**\n","\n","---\n","\n","In this section, unseen data is processed using the model trained in section 4. First, your unseen images are uploaded and prepared for prediction. After that, your trained model from section 4 is activated and the predictions are saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"KAILvLGFS2-1"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to predict using the trained network.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"bl3EdYFVS7X9","colab":{}},"source":["#Activate the pretrained model. 
\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING +'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Activate the pretrained model. 
\n","config = None\n","model = N2V(config, Prediction_model_name, basedir=Prediction_model_path)\n","\n","\n","# creates a loop, creating filenames and saving them\n","print(\"Saving the images...\")\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","# The code by Lucas von Chamier.\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='YX', n_tiles=(2,1))\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='YX') \n","\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"PfTw_pQUUAqB"},"source":["## **6.2. Assess predicted output**\n","---\n","\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","colab_type":"code","id":"jFp-0y4zT_gL","colab":{}},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","plt.axis('off');\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Predicted output')\n","plt.axis('off');"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"wgO7Ok1PBFQj"},"source":["## **6.3. 
Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"nlyPYwZu4VVS","colab_type":"text"},"source":["#**Thank you for using Noise2Void 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Noise2VOID_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Noise2VOID_3D_ZeroCostDL4Mic.ipynb index 21bd0b6d..1116dcf5 100755 --- a/Colab_notebooks/Noise2VOID_3D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Noise2VOID_3D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Noise2Void_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83","colab_type":"text"},"source":["# **Noise2Void (3D)**\n","\n","---\n","\n"," Noise2Void is a deep-learning method that can be used to denoise many types of images, including microscopy images and which was originally published by [Krull *et al.* on arXiv](https://arxiv.org/abs/1811.10980). 
It allows denoising of image data in a self-supervised manner; therefore, high-quality, low-noise equivalent images are not necessary to train this network. This is performed by \"masking\" a random subset of pixels in the noisy image and training the network to predict the values in these pixels. The resulting output is a denoised version of the image. Noise2Void is based on the popular U-Net network architecture, adapted from [CARE](https://www.nature.com/articles/s41592-018-0216-7).\n","\n"," **This particular notebook enables self-supervised denoising of 3D datasets. If you are interested in 2D datasets, you should use the Noise2Void 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). It was jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper:\n","\n","**Noise2Void - Learning Denoising from Single Noisy Images**\n","from Krull *et al.* published on arXiv in 2018 (https://arxiv.org/abs/1811.10980)\n","\n","And source code found in: https://github.com/juglab/n2v\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","\n","\n","###**Structure of a 
notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done, the animation of the play button stops. You can create a new code cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains the structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. 
This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX","colab_type":"text"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training data and/or the data to process in your Google Drive.\n","\n","Noise2Void requires only a single noisy image to train, but multiple images can be used. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","Please note that you currently can **only use .tif files!**\n","\n","**We strongly recommend that you generate high signal-to-noise ratio versions of your noisy images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," - **Training dataset**\n"," - **Quality control dataset** (Optional but recommended)\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif \n"," - **Data to be predicted** \n"," - **Results**\n","\n","\n","The **Results** folder will contain the processed images, trained model and network parameters as a csv file. 
Your original images remain unmodified.\n","\n","---\n","**Important note**\n","\n","- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of the notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in to your Google Account. \n","\n","#@markdown * Copy the authorization code. 
\n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"AdN8B91xZO0x"},"source":["# **2. Install Noise2Void and dependencies**\n","---"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install Noise2Void and dependencies\n","\n","# Enable the Tensorflow 1 instead of the Tensorflow 2.\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","# Here we install Noise2Void and other required packages\n","!pip install n2v\n","!pip install wget\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","print(\"Noise2Void installed.\")\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to N2V -------\n","from n2v.models import N2VConfig, N2V\n","from csbdeep.utils import plot_history\n","from n2v.utils.n2v_utils import manipulate_val_data\n","from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n","from csbdeep.io import save_tiff_imagej_compatible\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import 
LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43","colab_type":"text"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`:** This is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. 
Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 5.). **Default value: 30**\n","\n","**`patch_size`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 64**\n","\n","**`patch_height`:** The value should be smaller than the Z dimension of the image and divisible by 4. When analysing isotropic stacks, patch_size and patch_height should have similar values.\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappears.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`number_of_steps`:** Define the number of training steps per epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patches / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as the learning rate. 
**Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# Create DataGenerator-object.\n","datagen = N2V_DataGenerator()\n","\n","#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","imgs = datagen.load_imgs_from_directory(directory = Training_source, dims='ZYX')\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of steps and epochs:\n","\n","number_of_epochs = 30#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 64#@param {type:\"number\"}\n","\n","patch_height = 4#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 128#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name): \n"," print(bcolors.WARNING +\"!! 
WARNING: Folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","\n","#Load one randomly chosen training target file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING + \"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; 
therefore the patch_height chosen is now:\",patch_height)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not run)\n","\n","Use_Data_augmentation = True\n","\n","print(\"Parameters initiated.\")\n","\n","\n","#Here we display a single z plane\n","\n","norm = simple_norm(x[mid_plane], percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in the XY-plane and flipping them along the X-axis. This only works if the patches are square in XY.\n","\n"," By default data augmentation is enabled. 
Disable this option if you run out of RAM during the training.\n"," "]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be an N2V 3D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not 
os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. 
Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M","colab_type":"text"},"source":["#**4. Train your network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","#Disable some of the warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","# Create batches from the training data.\n","patches = datagen.generate_patches_from_list(imgs, shape=(patch_height, patch_size, patch_size), augment=Use_Data_augmentation)\n","\n","# Patches are divited into training and validation patch set. This inhibits over-lapping of patches. 
\n","number_train_images =int(len(patches)*(percentage_validation/100))\n","X = patches[number_train_images:]\n","X_val = patches[:number_train_images]\n","\n","print(len(patches),\"patches created.\")\n","print(number_train_images,\"patch images for validation (\",percentage_validation,\"%).\")\n","print((len(patches)-number_train_images),\"patch images for training.\")\n","%memit \n","\n","#Here we automatically define number_of_steps as a function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size) + 1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate is set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# creates Config object. 
\n","config = N2VConfig(X, unet_kern_size=3, \n"," train_steps_per_epoch=number_of_steps,train_epochs=number_of_epochs, train_loss='mse', batch_norm=True, \n"," train_batch_size=batch_size, n2v_perc_pix=0.198, n2v_patch_shape=(patch_height, patch_size, patch_size), \n"," n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate = initial_learning_rate)\n","\n","vars(config)\n","\n","# Create the default model.\n","model = N2V(config=config, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","print(\"Parameters transferred into the model.\")\n","print(config)\n","\n","# Shows a training batch and a validation batch.\n","plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(X[0,1,...,0],cmap='magma')\n","plt.axis('off')\n","plt.title('Training Patch');\n","plt.subplot(1,2,2)\n","plt.imshow(X_val[0,1,...,0],cmap='magma')\n","plt.axis('off')\n","plt.title('Validation Patch');\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d","colab_type":"text"},"source":["## **4.2. Train the network**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
Another way to circumvent this is to save the parameters of the model after training and start training again from this point."]},{"cell_type":"code","metadata":{"scrolled":true,"colab_type":"code","cellView":"form","id":"iwNmp1PUzRDQ","colab":{}},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","%memit\n","# the training starts.\n","history = model.train(X, X_val)\n","%memit\n","print(\"Model training is now done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nRaaG02xZh_N","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. 
The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_metrics_save+'Prediction'):\n"," shutil.rmtree(path_metrics_save+'Prediction')\n","os.makedirs(path_metrics_save+'Prediction')\n","\n","#Here we allow the user to choose the number of tiles to be used when predicting the images\n","#@markdown #####To analyse large images, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expense of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = True #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increase these numbers until the OOM error disappears.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","# Activate the pretrained model. 
\n","model_training = N2V(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," n_slices = img.shape[0]\n"," predicted = model_training.predict(img, axes='ZYX', n_tiles=n_tilesZYX)\n"," os.chdir(path_metrics_save+'Prediction/')\n"," imsave('Predicted_'+filename, predicted)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: 
ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # 
-------------------------------- Calculate the SSIM metric and maps --------------------------------\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction,force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource,force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe output\n"," slice_number_list.append(z)\n"," 
mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," 
io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. GT PSNR\"] = PSNR_GvS_list_mean\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane])\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are 
affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Quality Control folder.')\n","pdResults.head()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that, your trained model from section 4 is activated and the predictions are saved to your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","colab_type":"code","cellView":"form","colab":{}},"source":["#Activate the pretrained model. 
\n","#model_training = CARE(config=None, name=model_name, basedir=model_path)\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. 
Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = True #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 5#@param {type:\"number\"}\n","n_tiles_X = 5#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","#Activate the pretrained model.\n","config = None\n","model = N2V(config, Prediction_model_name, basedir=Prediction_model_path)\n","\n","print(\"Denoising images...\")\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","suffix = '.tif'\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","# The code by Lucas von Chamier\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='ZYX', n_tiles=n_tilesZYX)\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='ZYX')\n"," \n","print(\"Prediction of images done.\")\n","\n","print(\"One example is displayed here.\")\n","\n","\n","#Display an example\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","os.chdir(Result_folder)\n","y = 
imread(Result_folder+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest')\n","plt.title('Noisy Input (single Z plane)');\n","plt.axis('off');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], interpolation='nearest')\n","plt.title('Prediction (single Z plane)');\n","plt.axis('off');"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t","colab_type":"text"},"source":["#**Thank you for using Noise2Void 3D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Noise2Void_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83","colab_type":"text"},"source":["# **Noise2Void (3D)**\n","\n","---\n","\n"," Noise2Void is a deep-learning method that can be used to denoise many 
types of images, including microscopy images, and was originally published by [Krull *et al.* on arXiv](https://arxiv.org/abs/1811.10980). It allows denoising of image data in a self-supervised manner, so high-quality, low-noise equivalent images are not necessary to train this network. This is performed by \"masking\" a random subset of pixels in the noisy image and training the network to predict the values in these pixels. The resulting output is a denoised version of the image. Noise2Void is based on the popular U-Net network architecture, adapted from [CARE](https://www.nature.com/articles/s41592-018-0216-7).\n","\n"," **This particular notebook enables self-supervised denoising of 3D datasets. If you are interested in 2D datasets, you should use the Noise2Void 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). It is jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper:\n","\n","**Noise2Void - Learning Denoising from Single Noisy Images**\n","from Krull *et al.* published on arXiv in 2018 (https://arxiv.org/abs/1811.10980)\n","\n","And source code found in: https://github.com/juglab/n2v\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 
2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","\n","\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of the play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains the structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. 
This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX","colab_type":"text"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training data and/or the data to process in your Google Drive.\n","\n","Noise2Void only requires a single noisy image for training, but multiple images can be used. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","Please note that you currently can **only use .tif files!**\n","\n","**We strongly recommend that you generate high signal-to-noise ratio versions of your noisy images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," - **Training dataset**\n"," - **Quality control dataset** (Optional but recommended)\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif \n"," - **Data to be predicted** \n"," - **Results**\n","\n","\n","The **Results** folder will contain the processed images, trained model and network parameters as a csv file. 
Your original images remain unmodified.\n","\n","---\n","**Important note**\n","\n","- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of the notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in to your Google Account. \n","\n","#@markdown * Copy the authorization code. 
\n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"AdN8B91xZO0x"},"source":["# **2. Install Noise2Void and dependencies**\n","---"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install Noise2Void and dependencies\n","\n","# Enable the Tensorflow 1 instead of the Tensorflow 2.\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","# Here we install Noise2Void and other required packages\n","!pip install n2v\n","!pip install wget\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","print(\"Noise2Void installed.\")\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to N2V -------\n","from n2v.models import N2VConfig, N2V\n","from csbdeep.utils import plot_history\n","from n2v.utils.n2v_utils import manipulate_val_data\n","from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n","from csbdeep.io import save_tiff_imagej_compatible\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import 
LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43","colab_type":"text"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`:** This is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. 
Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 5.). **Default value: 30**\n","\n","**`patch_size`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 64**\n","\n","**`patch_height`:** The value should be smaller than the Z dimension of the image and divisible by 4. When analysing isotropic stacks patch_size and patch_height should have similar values.\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappears.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`number_of_steps`:** Define the number of training steps per epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patches / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as the learning rate. 
**Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# Create DataGenerator-object.\n","datagen = N2V_DataGenerator()\n","\n","#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","imgs = datagen.load_imgs_from_directory(directory = Training_source, dims='ZYX')\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of steps and epochs:\n","\n","number_of_epochs = 30#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 64#@param {type:\"number\"}\n","\n","patch_height = 4#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 128#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name): \n"," print(bcolors.WARNING +\"!! 
WARNING: Folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","\n","#Load one randomly chosen training target file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING + \"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; 
therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = True\n","\n","print(\"Parameters initiated.\")\n","\n","\n","#Here we display a single z plane\n","\n","norm = simple_norm(x[mid_plane], percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n"," By default data augmentation is enabled. 
Disable this option if you run out of RAM during the training.\n"," "]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be an N2V 3D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not 
os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. 
Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M","colab_type":"text"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","#Disable some of the warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","# Create batches from the training data.\n","patches = datagen.generate_patches_from_list(imgs, shape=(patch_height, patch_size, patch_size), augment=Use_Data_augmentation)\n","\n","# Patches are divited into training and validation patch set. This inhibits over-lapping of patches. 
\n","number_train_images =int(len(patches)*(percentage_validation/100))\n","X = patches[number_train_images:]\n","X_val = patches[:number_train_images]\n","\n","print(len(patches),\"patches created.\")\n","print(number_train_images,\"patch images for validation (\",percentage_validation,\"%).\")\n","print((len(patches)-number_train_images),\"patch images for training.\")\n","%memit \n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size) + 1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# creates Congfig object. 
\n","config = N2VConfig(X, unet_kern_size=3, \n"," train_steps_per_epoch=number_of_steps,train_epochs=number_of_epochs, train_loss='mse', batch_norm=True, \n"," train_batch_size=batch_size, n2v_perc_pix=0.198, n2v_patch_shape=(patch_height, patch_size, patch_size), \n"," n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate = initial_learning_rate)\n","\n","vars(config)\n","\n","# Create the default model.\n","model = N2V(config=config, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","print(\"Parameters transferred into the model.\")\n","print(config)\n","\n","# Shows a training batch and a validation batch.\n","plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(X[0,1,...,0],cmap='magma')\n","plt.axis('off')\n","plt.title('Training Patch');\n","plt.subplot(1,2,2)\n","plt.imshow(X_val[0,1,...,0],cmap='magma')\n","plt.axis('off')\n","plt.title('Validation Patch');\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d","colab_type":"text"},"source":["## **4.2. Start Trainning**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
Another way to circumvent this is to save the parameters of the model after training and start training again from this point."]},{"cell_type":"code","metadata":{"scrolled":true,"colab_type":"code","cellView":"form","id":"iwNmp1PUzRDQ","colab":{}},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","%memit\n","# the training starts.\n","history = model.train(X, X_val)\n","%memit\n","print(\"Model training is now done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n","  shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n","  writer = csv.writer(f)\n","  writer.writerow(['loss','val_loss', 'learning rate'])\n","  for i in range(len(history.history['loss'])):\n","    writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nRaaG02xZh_N","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. 
The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_metrics_save+'Prediction'):\n"," shutil.rmtree(path_metrics_save+'Prediction')\n","os.makedirs(path_metrics_save+'Prediction')\n","\n","#Here we allow the user to choose the number of tiles to be used when predicting the images\n","#@markdown #####To analyse large images, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expense of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = True #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increase these numbers until the OOM error disappears.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","# Activate the pretrained model. 
\n","model_training = N2V(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," n_slices = img.shape[0]\n"," predicted = model_training.predict(img, axes='ZYX', n_tiles=n_tilesZYX)\n"," os.chdir(path_metrics_save+'Prediction/')\n"," imsave('Predicted_'+filename, predicted)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: 
ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True if gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # 
-------------------------------- Calculate the SSIM metric and maps --------------------------------\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction,force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource,force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe output\n"," slice_number_list.append(z)\n"," 
mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," 
io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. GT PSNR\"] = PSNR_GvS_list_mean\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane])\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are 
affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Quality Control folder.')\n","pdResults.head()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","colab_type":"code","cellView":"form","colab":{}},"source":["#Activate the pretrained model. 
\n","#model_training = CARE(config=None, name=model_name, basedir=model_path)\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#Here we allow the user to choose the number of tiles to be used when predicting the images\n","#@markdown #####To analyse large images, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expense of your runtime. 
Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = True #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increase these numbers until the OOM error disappears.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 5#@param {type:\"number\"}\n","n_tiles_X = 5#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","#Activate the pretrained model.\n","config = None\n","model = N2V(config, Prediction_model_name, basedir=Prediction_model_path)\n","\n","print(\"Denoising images...\")\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","suffix = '.tif'\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","# The code by Lucas von Chamier\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='ZYX', n_tiles=n_tilesZYX)\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='ZYX')\n"," \n","print(\"Prediction of images done.\")\n","\n","print(\"One example is displayed here.\")\n","\n","\n","#Display an example\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","os.chdir(Result_folder)\n","y = 
imread(Result_folder+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest')\n","plt.title('Noisy Input (single Z plane)');\n","plt.axis('off');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], interpolation='nearest')\n","plt.title('Prediction (single Z plane)');\n","plt.axis('off');"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t","colab_type":"text"},"source":["#**Thank you for using Noise2Void 3D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Stardist_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Stardist_2D_ZeroCostDL4Mic.ipynb index f78918aa..5a7c52e2 100755 --- a/Colab_notebooks/Stardist_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Stardist_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"name":"python3","display_name":"Python 
3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"colab":{"name":"StarDist_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1WAfQW1Mj3wy1XQZZUfU4DJVS_R_E8Cn3","timestamp":1585665697353},{"file_id":"1PKVyox_mx2rEE3VlMFQtdnVULJFhYPaD","timestamp":1583443864213},{"file_id":"1XSclOkhhHmn-9LQc9k8c3Y6seT1LEi-Y","timestamp":1583264105465},{"file_id":"1VPZYk3MeSVyZVVEmesz10VtujbD4diJk","timestamp":1579481583477},{"file_id":"1ENdOZir1Gytf6JxzyfbjgfxO3_C1dLHK","timestamp":1575415287126},{"file_id":"1G8b4dF2kCs3ePBGZthPUGOyjJpZ2G_Dm","timestamp":1575379725785},{"file_id":"1P0tT0RR_b3SFKvOcON_MzcAIcxRUQK5B","timestamp":1575377313115},{"file_id":"1hQz8PyJzBRkBZc9NwxM9mU9azRSvghBk","timestamp":1574783624098},{"file_id":"14mWTNjHgIbuuWAxb-0lhmhdIvMoZgrI0","timestamp":1574099686195},{"file_id":"1IWvFuBb0gqaJcUXhhfbcTWNh9cZEXW4S","timestamp":1573647131082},{"file_id":"1hFulBwI57YU6GoVc8sBt5KNIkCS7ynQ3","timestamp":1573579952409},{"file_id":"1Ba_Bu-PXN_2Mq5W6YHMgUYsJEfgbPtS-","timestamp":1573035984524},{"file_id":"1ePC44Qq_C2hSFGPM3PKyb0J6UBXSPddp","timestamp":1573032545399},{"file_id":"https://github.com/mpicbg-csbd/stardist/blob/master/examples/2D/2_training.ipynb","timestamp":1572984225873}],"collapsed_sections":[],"toc_visible":true},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"kiFRRolPa-Rb","colab_type":"text"},"source":["# **StarDist (2D)**\n","---\n","\n","**StarDist 2D** is a deep-learning method that can be used to segment cell nuclei from bioimages and was first published by [Schmidt *et al.* in 2018, on arXiv](https://arxiv.org/abs/1806.03535). It uses a shape representation based on star-convex polygons for nuclei in an image to predict the presence and the shape of these nuclei. 
This StarDist 2D network is based on an adapted U-Net network architecture.\n","\n"," **This particular notebook enables nuclei segmentation of 2D datasets. If you are interested in 3D datasets, you should use the StarDist 3D notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper:\n","\n","**Cell Detection with Star-convex Polygons** from Schmidt *et al.*, International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. (https://arxiv.org/abs/1806.03535)\n","\n","and the 3D extension of the approach:\n","\n","**Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy** from Weigert *et al.* published on arXiv in 2019 (https://arxiv.org/abs/1908.03636)\n","\n","**The Original code** is freely available in GitHub:\n","https://github.com/mpicbg-csbd/stardist\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"iSuNqQ2ZMVGM","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be 
modified by double-clicking the cell. You are currently reading the text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done, the animation of the play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains the structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"4-oByBSdE6DE","colab_type":"text"},"source":["#**0. 
Before getting started**\n","---\n"," For StarDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. 
This can include Test dataset for which you have the equivalent output and can compare to what the network provides.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Images of nuclei (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Images of nuclei\n"," - img_1.tif, img_2.tif\n"," - Masks \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"t1sYuLChbRV3","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CDxBu1-19OyC","colab_type":"text"},"source":["\n","\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"4waLStm0RPFo","colab_type":"code","cellView":"form","colab":{"base_uri":"https://localhost:8080/","height":362},"executionInfo":{"status":"ok","timestamp":1596557087130,"user_tz":-60,"elapsed":9715,"user":{"displayName":"Romain Laine","photoUrl":"","userId":"09656923706700292222"}},"outputId":"128b12db-f59a-46d3-c9de-81918e83960b"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":1,"outputs":[{"output_type":"stream","text":["You have GPU access\n","Tue Aug 4 16:04:44 2020 \n","+-----------------------------------------------------------------------------+\n","| NVIDIA-SMI 450.57 Driver Version: 418.67 CUDA Version: 10.1 |\n","|-------------------------------+----------------------+----------------------+\n","| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n","| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n","| | | MIG M. |\n","|===============================+======================+======================|\n","| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n","| N/A 37C P0 57W / 149W | 134MiB / 11441MiB | 0% Default |\n","| | | ERR! 
|\n","+-------------------------------+----------------------+----------------------+\n"," \n","+-----------------------------------------------------------------------------+\n","| Processes: |\n","| GPU GI CI PID Type Process name GPU Memory |\n","| ID ID Usage |\n","|=============================================================================|\n","| No running processes found |\n","+-----------------------------------------------------------------------------+\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"ZLY4qhgj8w-R","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"Ukil4yuS8seC","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bB0IaQMZmWYM","colab_type":"text"},"source":["# **2. 
Install StarDist and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"j0w7C8P5zPIp","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install StarDist and dependencies\n","\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# Install packages which are not included in Google Colab\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install stardist # contains tools to operate STARDIST.\n","!pip install gputools # improves STARDIST performances\n","!pip install edt # improves STARDIST performances\n","!pip install wget\n","\n","\n","# ------- Variable specific to Stardist -------\n","from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available, relabel_image_stardist, random_label_cmap, relabel_image_stardist, _draw_polygons, export_imagej_rois\n","from stardist.models import Config2D, StarDist2D, StarDistData2D # import objects\n","from stardist.matching import matching_dataset\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","from csbdeep.utils import Path, normalize, download_and_extract_zip_file, plot_history # for loss plot\n","from csbdeep.io import save_tiff_imagej_compatible\n","import numpy as np\n","np.random.seed(42)\n","lbl_cmap = random_label_cmap()\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob 
import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32, img_as_ubyte, img_as_float\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","import cv2\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DPWhXaltAYgH","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"KWpu5p8utpE2","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"HJKFAmuXc6d1"},"source":[" **Paths for training, predictions and results**\n","\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. 
Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained for. Preliminary results can already be observed after 50-100 epochs, but a full training should run for up to 400 epochs. Evaluate the performance after training (see 5.). **Default value: 100**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 2**\n","\n","**`number_of_steps`:** Define the number of training steps per epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patches / batch_size**\n","\n","**`patch_size`:** Input the size of the patches used to train StarDist 2D (length of a side). The value should be smaller than or equal to the dimensions of the image. Make the patch size as large as possible and divisible by 8. **Default value: dimension of the training images** \n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`n_rays`:** Set the number of rays (corners) used for StarDist (for instance, a square has 4 corners). **Default value: 32** \n","\n","**`grid_parameter`:** Increase this number if the cells/nuclei are very large or decrease it if they are very small. **Default value: 2**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as the learning rate. 
**Default value: 0.0003**\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size value until the OOM error disappear.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"colab_type":"code","cellView":"form","id":"CNJImzzVnr7h","colab":{}},"source":["#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","\n","model_path = \"\" #@param {type:\"string\"}\n","#trained_model = model_path \n","\n","\n","#@markdown ### Other parameters for training:\n","number_of_epochs = 100#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","\n","#GPU_limit = 90 #@param {type:\"number\"}\n","batch_size = 2 #@param {type:\"number\"}\n","number_of_steps = 20#@param {type:\"number\"}\n","patch_size = 1024 #@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","n_rays = 32 #@param {type:\"number\"}\n","grid_parameter = 2#@param [1, 2, 4, 8, 16, 32] {type:\"raw\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 2\n"," n_rays = 32\n"," percentage_validation = 10\n"," grid_parameter = 2\n"," initial_learning_rate = 0.0003\n","\n","percentage = percentage_validation/100\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","# Here we open will randomly chosen input and output image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check the image dimensions\n","\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","print('Loaded images (width, length) =', x.shape)\n","\n","# If default parameters, patch size is the same as image size\n","if (Use_Default_Advanced_Parameters):\n"," patch_size = min(Image_Y, Image_X)\n"," \n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","# Here we check that the patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is:\",patch_size)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","print(\"Parameters initiated.\")\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we use a simple normalisation strategy to visualise the image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, 
interpolation='nearest', cmap=lbl_cmap)\n","plt.title('Training target')\n","plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vgT0NU3P6Bwt","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"8in3wzAw6G6g","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by [Augmentor.](https://github.com/mdbloice/Augmentor)\n","\n","[Augmentor](https://github.com/mdbloice/Augmentor) was described in the following article:\n","\n","Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n","\n","**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** "]},{"cell_type":"code","metadata":{"id":"2zk1H8J06aJH","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," !pip install Augmentor\n"," import Augmentor\n","\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 2 #@param {type:\"slider\", min:1, max:30, 
step:1}\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n","\n","#@markdown ####Mirror and rotate images\n","rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image Zoom\n","\n","random_zoom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image distortion\n","\n","random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","#@markdown ####Image shearing and skewing \n","\n","image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","max_image_shear = 1 #@param {type:\"slider\", min:1, max:25, step:1}\n","\n","skew_image = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","skew_image_magnitude = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","if Use_Default_Augmentation_Parameters:\n"," rotate_90_degrees = 0.5\n"," rotate_270_degrees = 0.5\n"," flip_left_right = 0.5\n"," flip_top_bottom = 0.5\n","\n"," if not Multiply_dataset_by >5:\n"," random_zoom = 0\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0\n"," image_shear = 0\n"," max_image_shear = 10\n"," skew_image = 0\n"," skew_image_magnitude = 0\n","\n"," if Multiply_dataset_by >5:\n"," random_zoom = 0.1\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0.5\n"," image_shear = 0.2\n"," max_image_shear = 
5\n"," skew_image = 0.2\n"," skew_image_magnitude = 0.4\n","\n"," if Multiply_dataset_by >25:\n"," random_zoom = 0.5\n"," random_zoom_magnification = 0.8\n"," random_distortion = 0.5\n"," image_shear = 0.5\n"," max_image_shear = 20\n"," skew_image = 0.5\n"," skew_image_magnitude = 0.6\n","\n","\n","list_files = os.listdir(Training_source)\n","Nb_files = len(list_files)\n","\n","Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n","\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","# Here we set the path for the various folder were the augmented images will be loaded\n","\n","# All images are first saved into the augmented folder\n"," #Augmented_folder = \"/content/Augmented_Folder\"\n"," \n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n","\n"," #Training_source_augmented = \"/content/Training_source_augmented\"\n"," Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n","\n"," if os.path.exists(Training_source_augmented):\n"," shutil.rmtree(Training_source_augmented)\n"," os.makedirs(Training_source_augmented)\n","\n"," #Training_target_augmented = \"/content/Training_target_augmented\"\n"," Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n","\n"," if os.path.exists(Training_target_augmented):\n"," shutil.rmtree(Training_target_augmented)\n"," os.makedirs(Training_target_augmented)\n","\n","\n","# Here we generate the augmented images\n","#Load the images\n"," p = Augmentor.Pipeline(Training_source, Augmented_folder)\n","\n","#Define the matching images\n"," p.ground_truth(Training_target)\n","#Define the augmentation possibilities\n"," if not rotate_90_degrees == 0:\n"," p.rotate90(probability=rotate_90_degrees)\n"," \n"," if not rotate_270_degrees == 0:\n"," 
p.rotate270(probability=rotate_270_degrees)\n","\n"," if not flip_left_right == 0:\n"," p.flip_left_right(probability=flip_left_right)\n","\n"," if not flip_top_bottom == 0:\n"," p.flip_top_bottom(probability=flip_top_bottom)\n","\n"," if not random_zoom == 0:\n"," p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n"," \n"," if not random_distortion == 0:\n"," p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n","\n"," if not image_shear == 0:\n"," p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n"," \n"," if not skew_image == 0:\n"," p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n","\n"," p.sample(int(Nb_augmented_files))\n","\n"," print(int(Nb_augmented_files),\"matching images generated\")\n","\n","# Here we sort through the images and move them back to augmented trainning source and targets folders\n","\n"," augmented_files = os.listdir(Augmented_folder)\n","\n"," for f in augmented_files:\n","\n"," if (f.startswith(\"_groundtruth_(1)_\")):\n"," shortname_noprefix = f[17:]\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n"," if not (f.startswith(\"_groundtruth_(1)_\")):\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n"," \n","\n"," for filename in os.listdir(Training_source_augmented):\n"," os.chdir(Training_source_augmented)\n"," os.rename(filename, filename.replace('_original', ''))\n"," \n"," #Here we clean up the extra files\n"," shutil.rmtree(Augmented_folder)\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"x4zMG4lMths-","colab_type":"text"},"source":["\n","## **3.3. 
Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a StarDist model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"SfQeukJJtv9u","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"2D_versatile_fluo_from_Stardist_Fiji\" #@param [\"Model_from_file\", \"2D_versatile_fluo_from_Stardist_Fiji\", \"2D_Demo_Model_from_Stardist_Github\", \"Versatile_H&E_nuclei\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the chosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# 
--------------------- Download the Demo 2D model provided in the Stardist 2D github ------------------------\n","\n"," if pretrained_model_choice == \"2D_Demo_Model_from_Stardist_Github\":\n"," pretrained_model_name = \"2D_Demo\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_Github\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/2D_demo/config.json\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/2D_demo/thresholds.json\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/2D_demo/weights_best.h5?raw=true\", pretrained_model_path) \n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/2D_demo/weights_last.h5?raw=true\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Download the Demo 2D_versatile_fluo_from_Stardist_Fiji ------------------------\n","\n"," if pretrained_model_choice == \"2D_versatile_fluo_from_Stardist_Fiji\":\n"," print(\"Downloading the 2D_versatile_fluo_from_Stardist_Fiji\")\n"," pretrained_model_name = \"2D_versatile_fluo\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," \n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," \n"," wget.download(\"https://cloud.mpi-cbg.de/index.php/s/1k5Zcy7PpFWRb0Q/download?path=/versatile&files=2D_versatile_fluo.zip\", pretrained_model_path)\n"," \n"," with zipfile.ZipFile(pretrained_model_path+\"/2D_versatile_fluo.zip\", 'r') as zip_ref:\n"," zip_ref.extractall(pretrained_model_path)\n"," \n"," h5_file_path = 
os.path.join(pretrained_model_path, \"weights_best.h5\")\n","\n","# --------------------- Download the Versatile (H&E nuclei)_fluo_from_Stardist_Fiji ------------------------\n","\n"," if pretrained_model_choice == \"Versatile_H&E_nuclei\":\n"," print(\"Downloading the Versatile_H&E_nuclei from_Stardist_Fiji\")\n"," pretrained_model_name = \"2D_versatile_he\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," \n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," \n"," wget.download(\"https://cloud.mpi-cbg.de/index.php/s/1k5Zcy7PpFWRb0Q/download?path=/versatile&files=2D_versatile_he.zip\", pretrained_model_path)\n"," \n"," with zipfile.ZipFile(pretrained_model_path+\"/2D_versatile_he.zip\", 'r') as zip_ref:\n"," zip_ref.extractall(pretrained_model_path)\n"," \n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_best.h5\")\n","\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist' + W)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the 
learning rate column exists (compatibility with models trained in ZeroCostDL4Mic below 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic; the default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. 
Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DECuc3HZDbwG","colab_type":"text"},"source":["#**4. Train your network**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"NwV5LweiavgQ","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"uTM781rCKT8r","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","# this funtion imports training images and masks and sorts them suitable for the network\n","X = sorted(glob(training_images_tiff)) \n","Y = sorted(glob(mask_images_tiff)) \n","\n","# assert -funtion check that X and Y really have images. 
If not this cell raises an error\n","assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n","\n","# Here we map the training dataset (images and masks).\n","X = list(map(imread,X))\n","Y = list(map(imread,Y))\n","n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n","\n","#Normalize images and fill small label holes.\n","axis_norm = (0,1) # normalize channels independently\n","# axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n"," sys.stdout.flush()\n","\n","X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n","Y = [fill_label_holes(y) for y in tqdm(Y)]\n","\n","#Here we split the your training dataset into training images (90 %) and validation images (10 %). \n","#It is advisable to use 10 % of your training dataset for validation. This ensures the truthfull validation error value. If only few validation images are used network may choose too easy or too challenging images for validation. 
\n","# split training data (images and masks) into training images and validation images.\n","assert len(X) > 1, \"not enough training data\"\n","rng = np.random.RandomState(42)\n","ind = rng.permutation(len(X))\n","n_val = max(1, int(round(percentage * len(ind))))\n","ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n","X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n","X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n","print('number of images: %3d' % len(X))\n","print('- training: %3d' % len(X_trn))\n","print('- validation: %3d' % len(X_val))\n","\n","# Use OpenCL-based computations for data generator during training (requires 'gputools')\n","use_gpu = False and gputools_available()\n","\n","#Here we ensure that our network has a minimal number of steps\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(len(X)/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","\n","conf = Config2D (\n"," n_rays = n_rays,\n"," use_gpu = use_gpu,\n"," train_batch_size = batch_size,\n"," n_channel_in = n_channel,\n"," train_patch_size = (patch_size, patch_size),\n"," grid = (grid_parameter, grid_parameter),\n"," train_learning_rate = initial_learning_rate,\n",")\n","\n","# Here we create a model according to section 5.3.\n","model = StarDist2D(conf, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","\n","\n","# --------------------- 
---------------------- ------------------------\n","\n","#Here we check the FOV of the network.\n","median_size = calculate_extents(list(Y), np.median)\n","fov = np.array(model._axes_tile_overlap('YX'))\n","if any(median_size > fov):\n"," print(bcolors.WARNING+\"WARNING: median object size larger than field of view of the neural network.\")\n","print(conf)\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nnMCvu2PKT9W","colab_type":"text"},"source":["\n","## **4.2. Train the network**\n","---\n","\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for data mining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way to circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the StarDist Fiji plugin. You can find it in your model folder (TF_SavedModel.zip). In Fiji, make sure to choose the right version of TensorFlow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"id":"XfCF-Q4lKT9e","colab_type":"code","cellView":"form","colab":{}},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","augmenter = None\n","\n","# def augmenter(X_batch, Y_batch):\n","# \"\"\"Augmentation for data batch.\n","# X_batch is a list of input images (length at most batch_size)\n","# Y_batch is the corresponding list of ground-truth label images\n","# \"\"\"\n","# # ...\n","# return X_batch, Y_batch\n","\n","# Training the model. 
\n","# 'input_epochs' and 'steps' refers to your input data in section 5.1 \n","history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n"," epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n","None;\n","\n","print(\"Training done\")\n","\n","print(\"Network optimization in progress\")\n","#Here we optimize the network.\n","model.optimize_thresholds(X_val, Y_val)\n","\n","print(\"Done\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF()\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the Stardist Fiji plugin\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iYRrmh0dCrNs","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is however wise to download the folder as all data can be erased at the next training if using the same folder.\n","\n"]},{"cell_type":"markdown","metadata":{"id":"U8H7QRfKBzI8","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","\n","**We highly recommend performing quality control on all newly trained models.**\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"o2O0QnO4PFlz","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else: \n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-2b4RMU_Ec2y","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"KG8wZrA3Ef4n","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"GFJBwr5TEgcq","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. 
\n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"."]},{"cell_type":"code","metadata":{"id":"EvCMiYaeElc4","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","\n","# Generate predictions from the Source_QC_folder and save them in the QC folder\n","\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","\n","model = StarDist2D(None, name=QC_model_name, basedir=QC_model_path)\n","\n","names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n","\n"," \n","# modify the names to suitable form: path_images/image_numberX.tif\n"," \n","lenght_of_Z = len(Z)\n"," \n","for i in range(lenght_of_Z):\n"," img = normalize(Z[i], 1,99.8, axis=axis_norm)\n"," labels, 
polygons = model.predict_instances(img)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(names[i], labels, polygons)\n","\n","\n","# Here we start testing the differences between GT and predicted masks\n","\n","\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"image\",\"Prediction v. GT Intersection over Union\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," \n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n"," #Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n"," #Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n"," # Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score)])\n","\n","\n","#Display the last image\n","\n","f = plt.figure(figsize=(25,25))\n","\n","from astropy.visualization import simple_norm\n","norm = simple_norm(test_input, percent = 99)\n","\n","#Input\n","plt.subplot(1,4,1)\n","plt.axis('off')\n","plt.imshow(test_input, aspect='equal', norm=norm, cmap='magma', 
interpolation='nearest')\n","plt.title('Input')\n","\n","\n","#Ground-truth\n","plt.subplot(1,4,2)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255, aspect='equal', cmap='Greens')\n","plt.title('Ground Truth')\n","\n","#Prediction\n","plt.subplot(1,4,3)\n","plt.axis('off')\n","plt.imshow(test_prediction_0_to_255, aspect='equal', cmap='Purples')\n","plt.title('Prediction')\n","\n","#Overlay\n","plt.subplot(1,4,4)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255, cmap='Greens')\n","plt.imshow(test_prediction_0_to_255, alpha=0.5, cmap='Purples')\n","plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3)));\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iAPmwlxCEzxQ","colab_type":"text"},"source":["# **6. Using the trained model**\n","---"]},{"cell_type":"markdown","metadata":{"id":"btXwwnVpBEMB","colab_type":"text"},"source":["\n","\n","## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive.\n","\n","---\n","\n","The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Results_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to predict using the trained network.\n","\n","**`Results_folder`:** This folder will contain the predicted output ROI.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks.\n","\n","\n","In StarDist the following results can be exported:\n","- Regions of interest (ROIs) that can be opened in ImageJ / Fiji. The ROIs are saved inside a .zip file in your chosen result folder. To open the ROIs in Fiji, just drag and drop the zip file!\n","- The predicted mask images\n","- A tracking file that can easily be imported into TrackMate to track the nuclei (Stacks only).\n","- A CSV file that contains the number of nuclei detected per image (single image only). \n","\n"]},{"cell_type":"code","metadata":{"id":"x8UXP8S2eoo_","colab_type":"code","cellView":"form","colab":{}},"source":["Single_Images = 1\n","Stacks = 2\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Are your data single images or stacks?\n","\n","Data_type = Single_Images #@param [\"Single_Images\", \"Stacks\"] {type:\"raw\"}\n","\n","#@markdown ###What outputs would you like to generate?\n","Region_of_interests = True #@param {type:\"boolean\"}\n","Mask_images = True #@param {type:\"boolean\"}\n","Tracking_file = False #@param {type:\"boolean\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model 
folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#single images\n","Data_folder = Data_folder+\"/*.tif\"\n","\n","if Data_type == 1 :\n"," print(\"Single images are now beeing predicted\")\n"," np.random.seed(16)\n"," lbl_cmap = random_label_cmap()\n"," X = sorted(glob(Data_folder))\n"," X = list(map(imread,X))\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n"," axis_norm = (0,1) # normalize channels independently\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n"," if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n"," model = StarDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," names = [os.path.basename(f) for f in sorted(glob(Data_folder))]\n"," \n"," Nuclei_number = []\n","\n"," # modify the names to suitable form: path_images/image_numberX.tif\n"," FILEnames = []\n"," for m in names:\n"," m = Results_folder+'/'+m\n"," FILEnames.append(m)\n","\n"," # Create a list of name with no extension\n"," \n"," name_no_extension=[]\n"," for n in names:\n"," name_no_extension.append(os.path.splitext(n)[0])\n"," \n","\n"," # Save 
all ROIs and masks into results folder\n"," \n"," for i in range(len(X)):\n"," img = normalize(X[i], 1,99.8, axis = axis_norm)\n"," labels, polygons = model.predict_instances(img)\n"," \n"," os.chdir(Results_folder)\n","\n"," if Mask_images:\n"," imsave(FILEnames[i], labels, polygons)\n","\n"," if Region_of_interests:\n"," export_imagej_rois(name_no_extension[i], polygons['coord'])\n","\n"," if Tracking_file:\n"," print(bcolors.WARNING+\"Tracking files are only generated when stacks are predicted\"+W) \n"," \n"," \n"," Nuclei_array = polygons['coord']\n"," Nuclei_array2 = [names[i], Nuclei_array.shape[0]]\n"," Nuclei_number.append(Nuclei_array2) \n","\n"," my_df = pd.DataFrame(Nuclei_number)\n"," my_df.to_csv(Results_folder+'/Nuclei_count.csv', index=False, header=False)\n"," \n","\n"," # One example is displayed\n","\n"," print(\"One example image is displayed bellow:\")\n"," plt.figure(figsize=(10,10))\n"," plt.imshow(img if img.ndim==2 else img[...,:3], clim=(0,1), cmap='gray')\n"," plt.imshow(labels, cmap=lbl_cmap, alpha=0.5)\n"," plt.axis('off');\n"," plt.savefig(name_no_extension[i]+\"_overlay.tif\")\n","\n","if Data_type == 2 :\n"," print(\"Stacks are now beeing predicted\")\n"," np.random.seed(42)\n"," lbl_cmap = random_label_cmap()\n"," Y = sorted(glob(Data_folder))\n"," X = list(map(imread,Y))\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n"," axis_norm = (0,1) # normalize channels independently\n"," # axis_norm = (0,1,2) # normalize channels jointly\n"," if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n"," #Load a pretrained network\n"," model = StarDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," names = [os.path.basename(f) for f in sorted(glob(Data_folder))]\n","\n"," # Create a list of name with no extension\n"," \n"," name_no_extension = []\n"," for n in names:\n"," 
name_no_extension.append(os.path.splitext(n)[0])\n","\n"," outputdir = Path(Results_folder)\n","\n","# Save all ROIs and images in Results folder.\n"," for num, i in enumerate(X):\n"," print(\"Performing prediction on: \"+names[num])\n","\n"," \n"," timelapse = np.stack(i)\n"," timelapse = normalize(timelapse, 1,99.8, axis=(0,)+tuple(1+np.array(axis_norm)))\n"," timelapse.shape\n","\n"," if Region_of_interests: \n"," polygons = [model.predict_instances(frame)[1]['coord'] for frame in tqdm(timelapse)] \n"," export_imagej_rois(os.path.join(outputdir, name_no_extension[num]), polygons) \n"," \n"," n_timepoint = timelapse.shape[0]\n"," prediction_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n"," Tracking_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n","\n","# Save the masks in the result folder\n"," if Mask_images or Tracking_file:\n"," for t in range(n_timepoint):\n"," img_t = timelapse[t]\n"," labels, polygons = model.predict_instances(img_t) \n"," prediction_stack[t] = labels\n","\n","# Create a tracking file for trackmate\n","\n"," for point in polygons['points']:\n"," cv2.circle(Tracking_stack[t],tuple(point),0,(1), -1)\n","\n"," prediction_stack_32 = img_as_float32(prediction_stack, force_copy=False)\n"," Tracking_stack_32 = img_as_float32(Tracking_stack, force_copy=False)\n"," Tracking_stack_8 = img_as_ubyte(Tracking_stack_32, force_copy=True)\n"," \n"," Tracking_stack_8_rot = np.rot90(Tracking_stack_8, axes=(1,2))\n"," Tracking_stack_8_rot_flip = np.fliplr(Tracking_stack_8_rot)\n","\n"," os.chdir(Results_folder)\n"," if Mask_images:\n"," imsave(names[num], prediction_stack_32)\n"," if Tracking_file:\n"," imsave(name_no_extension[num]+\"_tracking_file.tif\", Tracking_stack_8_rot_flip)\n","\n"," \n","\n","print(\"Predictions completed\") "],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"SxJsrw3kTcFx","colab_type":"text"},"source":["## **6.2. 
Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"rH_J20ydXWRQ","colab_type":"text"},"source":["\n","#**Thank you for using StarDist 2D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"colab":{"name":"StarDist_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1WAfQW1Mj3wy1XQZZUfU4DJVS_R_E8Cn3","timestamp":1585665697353},{"file_id":"1PKVyox_mx2rEE3VlMFQtdnVULJFhYPaD","timestamp":1583443864213},{"file_id":"1XSclOkhhHmn-9LQc9k8c3Y6seT1LEi-Y","timestamp":1583264105465},{"file_id":"1VPZYk3MeSVyZVVEmesz10VtujbD4diJk","timestamp":1579481583477},{"file_id":"1ENdOZir1Gytf6JxzyfbjgfxO3_C1dLHK","timestamp":1575415287126},{"file_id":"1G8b4dF2kCs3ePBGZthPUGOyjJpZ2G_Dm","timestamp":1575379725785},{"file_id":"1P0tT0RR_b3SFKvOcON_MzcAIcxRUQK5B","timestamp":1575377313115},{"file_id":"1hQz8PyJzBRkBZc9NwxM9mU9azRSvghBk","timestamp":1574783624098},{"file_id":"14mWTNjHgIbuuWAxb-0lhmhdIvMoZgrI0","timestamp":1574099686195},{"file_id":"1IWvFuBb0gqaJcUXhhfbcTWNh9cZEXW4S","timestamp":1573647131082},{"file_id":"1hFulBwI57YU6GoVc8sBt5KNIkCS7ynQ3","timestamp":1573579952409},{"file_id":"1Ba_Bu-PXN_2Mq5W6YHMgUYsJEfgbPtS-","timestamp":1573035984524},{"file_id":"1ePC44Qq_C2hSFGPM3PKyb0J6UBXSPddp","timestamp":1573032545399},{"file_id":"https://github.com/mpicbg-csbd/stardist/blob/master/examples/2D/2_training.ipynb","timestamp":1572984225873}],"collapsed_sectio
ns":[],"toc_visible":true},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"kiFRRolPa-Rb","colab_type":"text"},"source":["# **StarDist (2D)**\n","---\n","\n","**StarDist 2D** is a deep-learning method that can be used to segment cell nuclei from bioimages and was first published by [Schmidt *et al.* in 2018, on arXiv](https://arxiv.org/abs/1806.03535). It uses a shape representation based on star-convex polygons for nuclei in an image to predict the presence and the shape of these nuclei. This StarDist 2D network is based on an adapted U-Net network architecture.\n","\n"," **This particular notebook enables nuclei segmentation of 2D dataset. If you are interested in 3D dataset, you should use the StarDist 3D notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper:\n","\n","**Cell Detection with Star-convex Polygons** from Schmidt *et al.*, International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. 
(https://arxiv.org/abs/1806.03535)\n","\n","and the 3D extension of the approach:\n","\n","**Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy** from Weigert *et al.* published on arXiv in 2019 (https://arxiv.org/abs/1908.03636)\n","\n","**The original code** is freely available on GitHub:\n","https://github.com/mpicbg-csbd/stardist\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"iSuNqQ2ZMVGM","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cells: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of the play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains the structure of the notebook. 
Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"4-oByBSdE6DE","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," For StarDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. 
The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. This can include Test dataset for which you have the equivalent output and can compare to what the network provides.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Images of nuclei (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Images of nuclei\n"," - img_1.tif, img_2.tif\n"," - Masks \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"t1sYuLChbRV3","colab_type":"text"},"source":["# **1. 
Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CDxBu1-19OyC","colab_type":"text"},"source":["\n","\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"4waLStm0RPFo","colab_type":"code","cellView":"form","colab":{"base_uri":"https://localhost:8080/","height":362},"executionInfo":{"status":"ok","timestamp":1596557087130,"user_tz":-60,"elapsed":9715,"user":{"displayName":"Romain Laine","photoUrl":"","userId":"09656923706700292222"}},"outputId":"128b12db-f59a-46d3-c9de-81918e83960b"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[{"output_type":"stream","text":["You have GPU access\n","Tue Aug 4 16:04:44 2020 \n","+-----------------------------------------------------------------------------+\n","| NVIDIA-SMI 450.57 Driver Version: 418.67 CUDA Version: 10.1 |\n","|-------------------------------+----------------------+----------------------+\n","| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n","| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n","| | | MIG M. 
|\n","|===============================+======================+======================|\n","| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n","| N/A 37C P0 57W / 149W | 134MiB / 11441MiB | 0% Default |\n","| | | ERR! |\n","+-------------------------------+----------------------+----------------------+\n"," \n","+-----------------------------------------------------------------------------+\n","| Processes: |\n","| GPU GI CI PID Type Process name GPU Memory |\n","| ID ID Usage |\n","|=============================================================================|\n","| No running processes found |\n","+-----------------------------------------------------------------------------+\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"ZLY4qhgj8w-R","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"Ukil4yuS8seC","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bB0IaQMZmWYM","colab_type":"text"},"source":["# **2. 
Install StarDist and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"j0w7C8P5zPIp","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install StarDist and dependencies\n","\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# Install packages which are not included in Google Colab\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install stardist # contains tools to operate STARDIST.\n","!pip install gputools # improves STARDIST performances\n","!pip install edt # improves STARDIST performances\n","!pip install wget\n","\n","\n","# ------- Variable specific to Stardist -------\n","from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available, relabel_image_stardist, random_label_cmap, relabel_image_stardist, _draw_polygons, export_imagej_rois\n","from stardist.models import Config2D, StarDist2D, StarDistData2D # import objects\n","from stardist.matching import matching_dataset\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","from csbdeep.utils import Path, normalize, download_and_extract_zip_file, plot_history # for loss plot\n","from csbdeep.io import save_tiff_imagej_compatible\n","import numpy as np\n","np.random.seed(42)\n","lbl_cmap = random_label_cmap()\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob 
import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32, img_as_ubyte, img_as_float\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","import cv2\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DPWhXaltAYgH","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"KWpu5p8utpE2","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"HJKFAmuXc6d1"},"source":[" **Paths for training, predictions and results**\n","\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. 
Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input for how many epochs (rounds) the network will be trained. Preliminary results can already be observed after 50-100 epochs, but a full training should run for up to 400 epochs. Evaluate the performance after training (see 5.). **Default value: 100**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 2**\n","\n","**`number_of_steps`:** Define the number of training steps per epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patches / batch_size**\n","\n","**`patch_size`:** Input the size of the patches used to train StarDist 2D (length of a side). The value should be smaller than or equal to the dimensions of the image. Make the patch size as large as possible and divisible by 8. **Default value: dimension of the training images** \n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`n_rays`:** Set the number of rays (corners) used for StarDist (for instance, a square has 4 corners). **Default value: 32** \n","\n","**`grid_parameter`:** Increase this number if the cells/nuclei are very large or decrease it if they are very small. **Default value: 2**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as the learning rate. 
**Default value: 0.0003**\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size value until the OOM error disappears.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"colab_type":"code","cellView":"form","id":"CNJImzzVnr7h","colab":{}},"source":["#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","\n","model_path = \"\" #@param {type:\"string\"}\n","#trained_model = model_path \n","\n","\n","#@markdown ### Other parameters for training:\n","number_of_epochs = 100#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","\n","#GPU_limit = 90 #@param {type:\"number\"}\n","batch_size = 2 #@param {type:\"number\"}\n","number_of_steps = 20#@param {type:\"number\"}\n","patch_size = 1024 #@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","n_rays = 32 #@param {type:\"number\"}\n","grid_parameter = 2#@param [1, 2, 4, 8, 16, 32] {type:\"raw\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 2\n"," n_rays = 32\n"," percentage_validation = 10\n"," grid_parameter = 2\n"," initial_learning_rate = 0.0003\n","\n","percentage = percentage_validation/100\n","\n","#here we check that no model with the same name already exists; if so, delete it\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","# Here we open will randomly chosen input and output image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check the image dimensions\n","\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","print('Loaded images (width, length) =', x.shape)\n","\n","# If default parameters, patch size is the same as image size\n","if (Use_Default_Advanced_Parameters):\n"," patch_size = min(Image_Y, Image_X)\n"," \n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","# Here we check that the patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is:\",patch_size)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","print(\"Parameters initiated.\")\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we use a simple normalisation strategy to visualise the image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, 
interpolation='nearest', cmap=lbl_cmap)\n","plt.title('Training target')\n","plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vgT0NU3P6Bwt","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"8in3wzAw6G6g","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by [Augmentor.](https://github.com/mdbloice/Augmentor)\n","\n","[Augmentor](https://github.com/mdbloice/Augmentor) was described in the following article:\n","\n","Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n","\n","**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** "]},{"cell_type":"code","metadata":{"id":"2zk1H8J06aJH","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," !pip install Augmentor\n"," import Augmentor\n","\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 2 #@param {type:\"slider\", min:1, max:30, 
step:1}\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n","\n","#@markdown ####Mirror and rotate images\n","rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image Zoom\n","\n","random_zoom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image distortion\n","\n","random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","#@markdown ####Image shearing and skewing \n","\n","image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","max_image_shear = 1 #@param {type:\"slider\", min:1, max:25, step:1}\n","\n","skew_image = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","skew_image_magnitude = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","if Use_Default_Augmentation_Parameters:\n"," rotate_90_degrees = 0.5\n"," rotate_270_degrees = 0.5\n"," flip_left_right = 0.5\n"," flip_top_bottom = 0.5\n","\n"," if not Multiply_dataset_by >5:\n"," random_zoom = 0\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0\n"," image_shear = 0\n"," max_image_shear = 10\n"," skew_image = 0\n"," skew_image_magnitude = 0\n","\n"," if Multiply_dataset_by >5:\n"," random_zoom = 0.1\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0.5\n"," image_shear = 0.2\n"," max_image_shear = 
5\n"," skew_image = 0.2\n"," skew_image_magnitude = 0.4\n","\n"," if Multiply_dataset_by >25:\n"," random_zoom = 0.5\n"," random_zoom_magnification = 0.8\n"," random_distortion = 0.5\n"," image_shear = 0.5\n"," max_image_shear = 20\n"," skew_image = 0.5\n"," skew_image_magnitude = 0.6\n","\n","\n","list_files = os.listdir(Training_source)\n","Nb_files = len(list_files)\n","\n","Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n","\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","# Here we set the path for the various folder were the augmented images will be loaded\n","\n","# All images are first saved into the augmented folder\n"," #Augmented_folder = \"/content/Augmented_Folder\"\n"," \n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n","\n"," #Training_source_augmented = \"/content/Training_source_augmented\"\n"," Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n","\n"," if os.path.exists(Training_source_augmented):\n"," shutil.rmtree(Training_source_augmented)\n"," os.makedirs(Training_source_augmented)\n","\n"," #Training_target_augmented = \"/content/Training_target_augmented\"\n"," Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n","\n"," if os.path.exists(Training_target_augmented):\n"," shutil.rmtree(Training_target_augmented)\n"," os.makedirs(Training_target_augmented)\n","\n","\n","# Here we generate the augmented images\n","#Load the images\n"," p = Augmentor.Pipeline(Training_source, Augmented_folder)\n","\n","#Define the matching images\n"," p.ground_truth(Training_target)\n","#Define the augmentation possibilities\n"," if not rotate_90_degrees == 0:\n"," p.rotate90(probability=rotate_90_degrees)\n"," \n"," if not rotate_270_degrees == 0:\n"," 
p.rotate270(probability=rotate_270_degrees)\n","\n"," if not flip_left_right == 0:\n"," p.flip_left_right(probability=flip_left_right)\n","\n"," if not flip_top_bottom == 0:\n"," p.flip_top_bottom(probability=flip_top_bottom)\n","\n"," if not random_zoom == 0:\n"," p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n"," \n"," if not random_distortion == 0:\n"," p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n","\n"," if not image_shear == 0:\n"," p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n"," \n"," if not skew_image == 0:\n"," p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n","\n"," p.sample(int(Nb_augmented_files))\n","\n"," print(int(Nb_augmented_files),\"matching images generated\")\n","\n","# Here we sort through the images and move them back to augmented trainning source and targets folders\n","\n"," augmented_files = os.listdir(Augmented_folder)\n","\n"," for f in augmented_files:\n","\n"," if (f.startswith(\"_groundtruth_(1)_\")):\n"," shortname_noprefix = f[17:]\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n"," if not (f.startswith(\"_groundtruth_(1)_\")):\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n"," \n","\n"," for filename in os.listdir(Training_source_augmented):\n"," os.chdir(Training_source_augmented)\n"," os.rename(filename, filename.replace('_original', ''))\n"," \n"," #Here we clean up the extra files\n"," shutil.rmtree(Augmented_folder)\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"x4zMG4lMths-","colab_type":"text"},"source":["\n","## **3.3. 
Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a StarDist model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"SfQeukJJtv9u","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"2D_versatile_fluo_from_Stardist_Fiji\" #@param [\"Model_from_file\", \"2D_versatile_fluo_from_Stardist_Fiji\", \"2D_Demo_Model_from_Stardist_Github\", \"Versatile_H&E_nuclei\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the chosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# 
--------------------- Download the Demo 2D model provided in the Stardist 2D github ------------------------\n","\n"," if pretrained_model_choice == \"2D_Demo_Model_from_Stardist_Github\":\n"," pretrained_model_name = \"2D_Demo\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_Github\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/2D_demo/config.json\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/2D_demo/thresholds.json\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/2D_demo/weights_best.h5?raw=true\", pretrained_model_path) \n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/2D_demo/weights_last.h5?raw=true\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Download the Demo 2D_versatile_fluo_from_Stardist_Fiji ------------------------\n","\n"," if pretrained_model_choice == \"2D_versatile_fluo_from_Stardist_Fiji\":\n"," print(\"Downloading the 2D_versatile_fluo_from_Stardist_Fiji\")\n"," pretrained_model_name = \"2D_versatile_fluo\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," \n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," \n"," wget.download(\"https://cloud.mpi-cbg.de/index.php/s/1k5Zcy7PpFWRb0Q/download?path=/versatile&files=2D_versatile_fluo.zip\", pretrained_model_path)\n"," \n"," with zipfile.ZipFile(pretrained_model_path+\"/2D_versatile_fluo.zip\", 'r') as zip_ref:\n"," zip_ref.extractall(pretrained_model_path)\n"," \n"," h5_file_path = 
os.path.join(pretrained_model_path, \"weights_best.h5\")\n","\n","# --------------------- Download the Versatile (H&E nuclei)_fluo_from_Stardist_Fiji ------------------------\n","\n"," if pretrained_model_choice == \"Versatile_H&E_nuclei\":\n"," print(\"Downloading the Versatile_H&E_nuclei from_Stardist_Fiji\")\n"," pretrained_model_name = \"2D_versatile_he\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," \n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," \n"," wget.download(\"https://cloud.mpi-cbg.de/index.php/s/1k5Zcy7PpFWRb0Q/download?path=/versatile&files=2D_versatile_he.zip\", pretrained_model_path)\n"," \n"," with zipfile.ZipFile(pretrained_model_path+\"/2D_versatile_he.zip\", 'r') as zip_ref:\n"," zip_ref.extractall(pretrained_model_path)\n"," \n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_best.h5\")\n","\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist' + W)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the 
learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. 
Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DECuc3HZDbwG","colab_type":"text"},"source":["#**4. Train the network**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"NwV5LweiavgQ","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"uTM781rCKT8r","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","# this funtion imports training images and masks and sorts them suitable for the network\n","X = sorted(glob(training_images_tiff)) \n","Y = sorted(glob(mask_images_tiff)) \n","\n","# assert -funtion check that X and Y really have images. 
If not this cell raises an error\n","assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n","\n","# Here we map the training dataset (images and masks).\n","X = list(map(imread,X))\n","Y = list(map(imread,Y))\n","n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n","\n","#Normalize images and fill small label holes.\n","axis_norm = (0,1) # normalize channels independently\n","# axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n"," sys.stdout.flush()\n","\n","X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n","Y = [fill_label_holes(y) for y in tqdm(Y)]\n","\n","#Here we split the your training dataset into training images (90 %) and validation images (10 %). \n","#It is advisable to use 10 % of your training dataset for validation. This ensures the truthfull validation error value. If only few validation images are used network may choose too easy or too challenging images for validation. 
\n","# split training data (images and masks) into training images and validation images.\n","assert len(X) > 1, \"not enough training data\"\n","rng = np.random.RandomState(42)\n","ind = rng.permutation(len(X))\n","n_val = max(1, int(round(percentage * len(ind))))\n","ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n","X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n","X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n","print('number of images: %3d' % len(X))\n","print('- training: %3d' % len(X_trn))\n","print('- validation: %3d' % len(X_val))\n","\n","# Use OpenCL-based computations for data generator during training (requires 'gputools')\n","use_gpu = False and gputools_available()\n","\n","#Here we ensure that our network has a minimal number of steps\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(len(X)/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","\n","conf = Config2D (\n"," n_rays = n_rays,\n"," use_gpu = use_gpu,\n"," train_batch_size = batch_size,\n"," n_channel_in = n_channel,\n"," train_patch_size = (patch_size, patch_size),\n"," grid = (grid_parameter, grid_parameter),\n"," train_learning_rate = initial_learning_rate,\n",")\n","\n","# Here we create a model according to section 5.3.\n","model = StarDist2D(conf, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","\n","\n","# --------------------- 
---------------------- ------------------------\n","\n","#Here we check the FOV of the network.\n","median_size = calculate_extents(list(Y), np.median)\n","fov = np.array(model._axes_tile_overlap('YX'))\n","if any(median_size > fov):\n"," print(bcolors.WARNING+\"WARNING: median object size larger than field of view of the neural network.\")\n","print(conf)\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nnMCvu2PKT9W","colab_type":"text"},"source":["\n","## **4.2. Start Training**\n","---\n","\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for data mining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way to circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the StarDist Fiji plugin. You can find it in your model folder (TF_SavedModel.zip). In Fiji, make sure to choose the right version of TensorFlow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"id":"XfCF-Q4lKT9e","colab_type":"code","cellView":"form","colab":{}},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","augmenter = None\n","\n","# def augmenter(X_batch, Y_batch):\n","# \"\"\"Augmentation for data batch.\n","# X_batch is a list of input images (length at most batch_size)\n","# Y_batch is the corresponding list of ground-truth label images\n","# \"\"\"\n","# # ...\n","# return X_batch, Y_batch\n","\n","# Training the model. 
\n","# 'input_epochs' and 'steps' refers to your input data in section 5.1 \n","history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n"," epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n","None;\n","\n","print(\"Training done\")\n","\n","print(\"Network optimization in progress\")\n","#Here we optimize the network.\n","model.optimize_thresholds(X_val, Y_val)\n","\n","print(\"Done\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF()\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the Stardist Fiji plugin\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iYRrmh0dCrNs","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is however wise to download the folder as all data can be erased at the next training if using the same folder.\n","\n"]},{"cell_type":"markdown","metadata":{"id":"U8H7QRfKBzI8","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"o2O0QnO4PFlz","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else: \n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-2b4RMU_Ec2y","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words, the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case, the training dataset should be enlarged.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"KG8wZrA3Ef4n","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"GFJBwr5TEgcq","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder. The result for one of the images will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. 
\n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"."]},{"cell_type":"code","metadata":{"id":"EvCMiYaeElc4","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","\n","# Generate predictions from the Source_QC_folder and save them in the QC folder\n","\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","\n","model = StarDist2D(None, name=QC_model_name, basedir=QC_model_path)\n","\n","names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n","\n"," \n","# modify the names to suitable form: path_images/image_numberX.tif\n"," \n","lenght_of_Z = len(Z)\n"," \n","for i in range(lenght_of_Z):\n"," img = normalize(Z[i], 1,99.8, axis=axis_norm)\n"," labels, 
polygons = model.predict_instances(img)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(names[i], labels, polygons)\n","\n","\n","# Here we start testing the differences between GT and predicted masks\n","\n","\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"image\",\"Prediction v. GT Intersection over Union\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," \n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n"," #Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n"," #Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n"," # Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score)])\n","\n","\n","#Display the last image\n","\n","f = plt.figure(figsize=(25,25))\n","\n","from astropy.visualization import simple_norm\n","norm = simple_norm(test_input, percent = 99)\n","\n","#Input\n","plt.subplot(1,4,1)\n","plt.axis('off')\n","plt.imshow(test_input, aspect='equal', norm=norm, cmap='magma', 
interpolation='nearest')\n","plt.title('Input')\n","\n","\n","#Ground-truth\n","plt.subplot(1,4,2)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255, aspect='equal', cmap='Greens')\n","plt.title('Ground Truth')\n","\n","#Prediction\n","plt.subplot(1,4,3)\n","plt.axis('off')\n","plt.imshow(test_prediction_0_to_255, aspect='equal', cmap='Purples')\n","plt.title('Prediction')\n","\n","#Overlay\n","plt.subplot(1,4,4)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255, cmap='Greens')\n","plt.imshow(test_prediction_0_to_255, alpha=0.5, cmap='Purples')\n","plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3)));\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iAPmwlxCEzxQ","colab_type":"text"},"source":["# **6. Using the trained model**\n","---"]},{"cell_type":"markdown","metadata":{"id":"btXwwnVpBEMB","colab_type":"text"},"source":["\n","\n","## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive.\n","\n","---\n","\n","The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Results_folder** folder as ImageJ-compatible TIFF images.\n","\n","**`Data_folder`:** This folder should contain the images that you want to predict using the network that you trained.\n","\n","**`Results_folder`:** This folder will contain the predicted output ROI.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks.\n","\n","\n","In StarDist the following results can be exported:\n","- Regions of interest (ROI) that can be opened in ImageJ / Fiji. The ROI are saved inside a .zip file in your chosen result folder. To open the ROI in Fiji, just drag and drop the zip file!\n","- The predicted mask images\n","- A tracking file that can easily be imported into TrackMate to track the nuclei (Stacks only).\n","- A CSV file that contains the number of nuclei detected per image (single image only). \n","\n"]},{"cell_type":"code","metadata":{"id":"x8UXP8S2eoo_","colab_type":"code","cellView":"form","colab":{}},"source":["Single_Images = 1\n","Stacks = 2\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Are your data single images or stacks?\n","\n","Data_type = Single_Images #@param [\"Single_Images\", \"Stacks\"] {type:\"raw\"}\n","\n","#@markdown ###What outputs would you like to generate?\n","Region_of_interests = True #@param {type:\"boolean\"}\n","Mask_images = True #@param {type:\"boolean\"}\n","Tracking_file = False #@param {type:\"boolean\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model 
folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#single images\n","Data_folder = Data_folder+\"/*.tif\"\n","\n","if Data_type == 1 :\n"," print(\"Single images are now beeing predicted\")\n"," np.random.seed(16)\n"," lbl_cmap = random_label_cmap()\n"," X = sorted(glob(Data_folder))\n"," X = list(map(imread,X))\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n"," axis_norm = (0,1) # normalize channels independently\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n"," if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n"," model = StarDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," names = [os.path.basename(f) for f in sorted(glob(Data_folder))]\n"," \n"," Nuclei_number = []\n","\n"," # modify the names to suitable form: path_images/image_numberX.tif\n"," FILEnames = []\n"," for m in names:\n"," m = Results_folder+'/'+m\n"," FILEnames.append(m)\n","\n"," # Create a list of name with no extension\n"," \n"," name_no_extension=[]\n"," for n in names:\n"," name_no_extension.append(os.path.splitext(n)[0])\n"," \n","\n"," # Save 
all ROIs and masks into results folder\n"," \n"," for i in range(len(X)):\n"," img = normalize(X[i], 1,99.8, axis = axis_norm)\n"," labels, polygons = model.predict_instances(img)\n"," \n"," os.chdir(Results_folder)\n","\n"," if Mask_images:\n"," imsave(FILEnames[i], labels, polygons)\n","\n"," if Region_of_interests:\n"," export_imagej_rois(name_no_extension[i], polygons['coord'])\n","\n"," if Tracking_file:\n"," print(bcolors.WARNING+\"Tracking files are only generated when stacks are predicted\"+W) \n"," \n"," \n"," Nuclei_array = polygons['coord']\n"," Nuclei_array2 = [names[i], Nuclei_array.shape[0]]\n"," Nuclei_number.append(Nuclei_array2) \n","\n"," my_df = pd.DataFrame(Nuclei_number)\n"," my_df.to_csv(Results_folder+'/Nuclei_count.csv', index=False, header=False)\n"," \n","\n"," # One example is displayed\n","\n"," print(\"One example image is displayed bellow:\")\n"," plt.figure(figsize=(10,10))\n"," plt.imshow(img if img.ndim==2 else img[...,:3], clim=(0,1), cmap='gray')\n"," plt.imshow(labels, cmap=lbl_cmap, alpha=0.5)\n"," plt.axis('off');\n"," plt.savefig(name_no_extension[i]+\"_overlay.tif\")\n","\n","if Data_type == 2 :\n"," print(\"Stacks are now beeing predicted\")\n"," np.random.seed(42)\n"," lbl_cmap = random_label_cmap()\n"," Y = sorted(glob(Data_folder))\n"," X = list(map(imread,Y))\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n"," axis_norm = (0,1) # normalize channels independently\n"," # axis_norm = (0,1,2) # normalize channels jointly\n"," if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n"," #Load a pretrained network\n"," model = StarDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," names = [os.path.basename(f) for f in sorted(glob(Data_folder))]\n","\n"," # Create a list of name with no extension\n"," \n"," name_no_extension = []\n"," for n in names:\n"," 
name_no_extension.append(os.path.splitext(n)[0])\n","\n"," outputdir = Path(Results_folder)\n","\n","# Save all ROIs and images in Results folder.\n"," for num, i in enumerate(X):\n"," print(\"Performing prediction on: \"+names[num])\n","\n"," \n"," timelapse = np.stack(i)\n"," timelapse = normalize(timelapse, 1,99.8, axis=(0,)+tuple(1+np.array(axis_norm)))\n"," timelapse.shape\n","\n"," if Region_of_interests: \n"," polygons = [model.predict_instances(frame)[1]['coord'] for frame in tqdm(timelapse)] \n"," export_imagej_rois(os.path.join(outputdir, name_no_extension[num]), polygons) \n"," \n"," n_timepoint = timelapse.shape[0]\n"," prediction_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n"," Tracking_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n","\n","# Save the masks in the result folder\n"," if Mask_images or Tracking_file:\n"," for t in range(n_timepoint):\n"," img_t = timelapse[t]\n"," labels, polygons = model.predict_instances(img_t) \n"," prediction_stack[t] = labels\n","\n","# Create a tracking file for trackmate\n","\n"," for point in polygons['points']:\n"," cv2.circle(Tracking_stack[t],tuple(point),0,(1), -1)\n","\n"," prediction_stack_32 = img_as_float32(prediction_stack, force_copy=False)\n"," Tracking_stack_32 = img_as_float32(Tracking_stack, force_copy=False)\n"," Tracking_stack_8 = img_as_ubyte(Tracking_stack_32, force_copy=True)\n"," \n"," Tracking_stack_8_rot = np.rot90(Tracking_stack_8, axes=(1,2))\n"," Tracking_stack_8_rot_flip = np.fliplr(Tracking_stack_8_rot)\n","\n"," os.chdir(Results_folder)\n"," if Mask_images:\n"," imsave(names[num], prediction_stack_32)\n"," if Tracking_file:\n"," imsave(name_no_extension[num]+\"_tracking_file.tif\", Tracking_stack_8_rot_flip)\n","\n"," \n","\n","print(\"Predictions completed\") "],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"SxJsrw3kTcFx","colab_type":"text"},"source":["## **6.2. 
Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"rH_J20ydXWRQ","colab_type":"text"},"source":["\n","#**Thank you for using StarDist 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Stardist_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Stardist_3D_ZeroCostDL4Mic.ipynb index 2a7b2e0b..d36ffe5d 100755 --- a/Colab_notebooks/Stardist_3D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Stardist_3D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"colab":{"name":"StarDist_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1Ur-4VIQ6gf4ONupD6hK0M-AcJkoTzMlU","timestamp":1586789439593},{"file_id":"1PKVyox_mx2rEE3VlMFQtdnVULJFhYPaD","timestamp":1583443864213},{"file_id":"1XSclOkhhHmn-9LQc9k8c3Y6seT1LEi-Y","timestamp":1583264105465},{"file_id":"1VPZYk3MeSVyZVVEmesz10VtujbD4diJk","timestamp":1579481583477},{"file_id":"1ENdOZir1Gytf6JxzyfbjgfxO3_C1dLHK","timestamp":1575415287126},{"file_id":"1G8b4dF2kCs3ePBGZthPUGOyjJpZ2G_Dm","timestamp":1575379725785},{"file_id":"1P0tT0RR_b3SFKvOcON_MzcAIcxRUQK5B","timestamp":1575377313115},{"file_id":"1hQz8PyJzBRkBZc9NwxM9mU9azRSvghBk","timestamp":1574783624098},{"file_id":"14mWTNjHgIbuuWAxb-0lhmhdIvMoZgrI0","timestamp":1574099686195},{"file_id":"1IWvFuBb0gqaJcUXhhfbcTWNh9cZEXW4S","timestamp":1573647131082},{"file_id":"1hFulBwI57YU6GoVc8sBt5KNIkCS7ynQ3","timestamp":1573579952409},{"file_id":"1Ba_Bu-P
XN_2Mq5W6YHMgUYsJEfgbPtS-","timestamp":1573035984524},{"file_id":"1ePC44Qq_C2hSFGPM3PKyb0J6UBXSPddp","timestamp":1573032545399},{"file_id":"https://github.com/mpicbg-csbd/stardist/blob/master/examples/2D/2_training.ipynb","timestamp":1572984225873}],"collapsed_sections":[],"toc_visible":true},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"kiFRRolPa-Rb","colab_type":"text"},"source":["# **StarDist (3D)**\n","---\n","\n","**StarDist 3D** is a deep-learning method that can be used to segment cell nuclei from 3D bioimages and was first published by [Weigert *et al.* in 2019 on arXiv](https://arxiv.org/abs/1908.03636), extending to 3D the 2D approach from [Schmidt *et al.* in 2018](https://arxiv.org/abs/1806.03535). It uses a shape representation based on star-convex polyhedra for nuclei in an image to predict the presence and the shape of these nuclei. This StarDist 3D network is based on an adapted ResNet network architecture.\n","\n"," **This particular notebook enables nuclei segmentation of 3D datasets. If you are interested in 2D datasets, you should use the StarDist 2D notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). It is jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper:\n","\n","**Cell Detection with Star-convex Polygons** from Schmidt *et al.*, International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. 
(https://arxiv.org/abs/1806.03535)\n","\n","and the 3D extension of the approach:\n","\n","**Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy** from Weigert *et al.* published on arXiv in 2019 (https://arxiv.org/abs/1908.03636)\n","\n","**The Original code** is freely available on GitHub:\n","https://github.com/mpicbg-csbd/stardist\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"iSuNqQ2ZMVGM","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of the play button stops. You can create a new code cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. 
Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"4-oByBSdE6DE","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," For StarDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. 
The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Images of nuclei (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Images of nuclei\n"," - img_1.tif, img_2.tif\n"," - **Masks** \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"t1sYuLChbRV3","colab_type":"text"},"source":["# **1. 
Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CDxBu1-19OyC","colab_type":"text"},"source":["\n","\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"4waLStm0RPFo","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZLY4qhgj8w-R","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 
\n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"Ukil4yuS8seC","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bB0IaQMZmWYM","colab_type":"text"},"source":["# **2. Install StarDist and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"j0w7C8P5zPIp","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install StarDist and dependencies\n","\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# Install packages which are not included in Google Colab\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). 
It uses Keras and Tensorflow.\n","!pip install stardist # contains tools to operate STARDIST.\n","!pip install gputools\n","!pip install edt\n","!pip install wget\n","\n","\n","# ------- Variable specific to Stardist -------\n","from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available\n","from stardist.models import Config3D, StarDist3D, StarDistData3D\n","from stardist import relabel_image_stardist3D, Rays_GoldenSpiral, calculate_extents\n","from stardist.matching import matching_dataset\n","from csbdeep.utils import Path, normalize, download_and_extract_zip_file, plot_history # for loss plot\n","from csbdeep.io import save_tiff_imagej_compatible\n","import numpy as np\n","np.random.seed(42)\n","lbl_cmap = random_label_cmap()\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","import cv2\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # 
red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DPWhXaltAYgH","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n","\n"]},{"cell_type":"markdown","metadata":{"id":"nAW3oU60htR_","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"HJKFAmuXc6d1"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files, copy the path by right-clicking on the folder and selecting **Copy path**, and paste it into the corresponding box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained for. Preliminary results can already be observed after 400 epochs, but a full training should run for more. Evaluate the performance after training (see 5.). **Default value: 400**\n","\n","**Advanced parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. 
Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1** \n","\n","**`number_of_steps`:** Define the number of training steps per epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patches / batch_size**\n","\n","**`patch_size`:** and **`patch_height`:** Input the size of the patches used to train StarDist 3D (length of a side). The value should be smaller than or equal to the dimensions of the image. Make patch_size and patch_height as large as possible and divisible by 8 and 4, respectively. **Default value: dimension of the training images**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`n_rays`:** Set the number of rays (corners) used for StarDist (for instance a cube has 8 corners). **Default value: 96** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0003**\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**"]},{"cell_type":"code","metadata":{"colab_type":"code","cellView":"form","id":"CNJImzzVnr7h","colab":{}},"source":["\n","\n","#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","training_images = Training_source\n","\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","mask_images = Training_target \n","\n","\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","\n","model_path = \"\" #@param {type:\"string\"}\n","trained_model = model_path \n","\n","#@markdown ### Other parameters for training:\n","number_of_epochs = 400#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","\n","#GPU_limit = 90 #@param {type:\"number\"}\n","batch_size = 1#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","patch_size = 64#@param {type:\"number\"} # pixels in\n","patch_height = 64#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","n_rays = 96 #@param {type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," n_rays = 96\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0003\n","\n","\n","percentage = percentage_validation/100\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if 
os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Folder already exists and has been removed !!\")\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","# If default parameters, patch size is the same as image size\n","if (Use_Default_Advanced_Parameters): \n"," patch_size = min(Image_Y, Image_X) \n"," patch_height = Image_Z\n","\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","\n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible 
by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","print(\"Parameters initiated.\")\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we use a simple normalisation strategy to visualise the image\n","from astropy.visualization import simple_norm\n","norm = simple_norm(x, percent = 99)\n","\n","mid_plane = int(Image_Z / 2)+1\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n","plt.axis('off')\n","plt.title('Training source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n","plt.axis('off')\n","plt.title('Training target (single Z plane)');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nbyf-RevQhDL","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"UQ2hultWQlT9","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. 
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis as well as performing elastic deformations\n","\n","**The flip option and the elastic deformation will double the size of your dataset, rotation will quadruple and all together will increase the dataset by a factor of 16.**\n","\n"," Elastic deformations performed by [Elasticdeform.](https://elasticdeform.readthedocs.io/en/latest/index.html).\n"]},{"cell_type":"code","metadata":{"id":"wYdTY6ULg01b","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###See Elasticdeform’s license\n","#Copyright (c) 2001, 2002 Enthought, Inc. All rights reserved.\n","\n","#Copyright (c) 2003-2017 SciPy Developers. All rights reserved.\n","\n","#Copyright (c) 2018 Gijs van Tulder. All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n","\n","##Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n","#Neither the name of Enthought nor the names of the SciPy Developers may be used to endorse or promote products derived from this software without specific prior written permission.\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n","\n","print(\"Double click to see elasticdeform’s license\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"kKLB47jgQrxr","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","#@markdown **Deform your images**\n","\n","Elastic_deformation = True #@param {type:\"boolean\"}\n","\n","Deformation_Sigma = 3 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","Save_augmented_images = True #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = 
np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Training_source_augmented+'/'+image,source_img)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Training_target_augmented+'/'+image,target_img)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," 
io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Training_source_augmented+'/'+image,source_img)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Training_target_augmented+'/'+image,target_img)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","\n","\n","if Use_Data_augmentation:\n","\n","\n"," if Elastic_deformation:\n"," !pip install elasticdeform\n"," import numpy, imageio, elasticdeform\n","\n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n","\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n"," Training_source_augmented = Augmented_folder+\"/Training_source\"\n"," os.makedirs(Training_source_augmented)\n"," Training_target_augmented = Augmented_folder+\"/Training_target\"\n"," os.makedirs(Training_target_augmented)\n"," print(\"Data augmentation enabled\")\n"," print(\"Generation of the augmented dataset in progress\")\n","\n"," if Elastic_deformation:\n"," for filename in os.listdir(Training_source):\n"," X = imread(os.path.join(Training_source, filename))\n"," Y = imread(os.path.join(Training_target, filename))\n"," [X_deformed, 
Y_deformed] = elasticdeform.deform_random_grid([X, Y], sigma=Deformation_Sigma, order=0)\n","\n"," os.chdir(Augmented_folder+\"/Training_source\")\n"," imsave(filename, X)\n"," imsave(filename+\"_deformed.tif\", X_deformed)\n","\n"," os.chdir(Augmented_folder+\"/Training_target\")\n"," imsave(filename, Y)\n"," imsave(filename+\"_deformed.tif\", Y_deformed)\n","\n"," Training_source_rot = Training_source_augmented\n"," Training_target_rot = Training_target_augmented\n"," \n"," if not Elastic_deformation:\n"," Training_source_rot = Training_source\n"," Training_target_rot = Training_target\n","\n"," \n"," if Rotation == True:\n"," rotation_aug(Training_source_rot,Training_target_rot,flip=Flip)\n"," elif Rotation == False and Flip == True:\n"," flip(Training_source_rot,Training_target_rot)\n","\n"," print(\"Done\")\n","\n"," if Elastic_deformation:\n"," from astropy.visualization import simple_norm\n"," norm = simple_norm(x, percent = 99)\n","\n"," random_choice=random.choice(os.listdir(Training_source))\n"," x = imread(Augmented_folder+\"/Training_source/\"+random_choice)\n"," x_deformed = imread(Augmented_folder+\"/Training_source/\"+random_choice+\"_deformed.tif\")\n"," y = imread(Augmented_folder+\"/Training_target/\"+random_choice)\n"," y_deformed = imread(Augmented_folder+\"/Training_target/\"+random_choice+\"_deformed.tif\") \n","\n"," Image_Z = x.shape[0]\n"," mid_plane = int(Image_Z / 2)+1\n","\n"," f=plt.figure(figsize=(10,10))\n"," plt.subplot(2,2,1)\n"," plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n"," plt.axis('off')\n"," plt.title('Training source (single Z plane)');\n"," plt.subplot(2,2,2)\n"," plt.imshow(y[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n"," plt.axis('off')\n"," plt.title('Training target (single Z plane)');\n"," plt.subplot(2,2,3)\n"," plt.imshow(x_deformed[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n"," plt.axis('off')\n"," plt.title('Deformed training source (single Z 
plane)');\n"," plt.subplot(2,2,4)\n"," plt.imshow(y_deformed[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n"," plt.axis('off')\n"," plt.title('Deformed training target (single Z plane)');\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")\n","\n","\n","\n"," \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"pjz-5bRVh1ja","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a StarDist model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"zeSUtd2Thw-O","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Demo_3D_Model_from_Stardist_3D_paper\" #@param [\"Model_from_file\", \"Demo_3D_Model_from_Stardist_3D_paper\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the Demo 3D model provided in the Stardist 3D github ------------------------\n","\n"," if pretrained_model_choice == \"Demo_3D_Model_from_Stardist_3D_paper\":\n"," pretrained_model_name = \"Demo_3D\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the Demo 3D model from the Stardist_3D paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"https://raw.githubusercontent.com/mpicbg-csbd/stardist/master/models/examples/3D_demo/config.json\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/3D_demo/thresholds.json\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/3D_demo/weights_best.h5?raw=true\", pretrained_model_path)\n"," 
wget.download(\"https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/3D_demo/weights_last.h5?raw=true\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning 
rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print(bcolors.WARNING+'Weights found in:')\n"," print(h5_file_path)\n"," print(bcolors.WARNING+'will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DECuc3HZDbwG","colab_type":"text"},"source":["#**4. Train your network**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"NwV5LweiavgQ","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","\n","Here, we use the information from 3. 
to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"uTM781rCKT8r","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","import warnings\n","warnings.simplefilter(\"ignore\")\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","\n","# this funtion imports training images and masks and sorts them suitable for the network\n","X = sorted(glob(training_images_tiff)) \n","Y = sorted(glob(mask_images_tiff)) \n","\n","# assert -funtion check that X and Y really have images. If not this cell raises an error\n","assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n","\n","# Here we map the training dataset (images and masks).\n","X = list(map(imread,X))\n","Y = list(map(imread,Y))\n","\n","n_channel = 1 if X[0].ndim == 3 else X[0].shape[-1]\n","\n","\n","\n","#Normalize images and fill small label holes.\n","axis_norm = (0,1,2) # normalize channels independently\n","# axis_norm = (0,1,2,3) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 3 in axis_norm else 'independently'))\n"," sys.stdout.flush()\n","\n","X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n","Y = [fill_label_holes(y) for y in tqdm(Y)]\n","\n","#Here we split the your training dataset into training images (90 %) and validation images (10 %). 
\n","\n","assert len(X) > 1, \"not enough training data\"\n","rng = np.random.RandomState(42)\n","ind = rng.permutation(len(X))\n","n_val = max(1, int(round(percentage * len(ind))))\n","ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n","X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n","X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n","print('number of images: %3d' % len(X))\n","print('- training: %3d' % len(X_trn))\n","print('- validation: %3d' % len(X_val))\n","\n","\n","\n","extents = calculate_extents(Y)\n","anisotropy = tuple(np.max(extents) / extents)\n","print('empirical anisotropy of labeled objects = %s' % str(anisotropy))\n","\n","\n","# Use OpenCL-based computations for data generator during training (requires 'gputools')\n","use_gpu = False and gputools_available()\n","\n","\n","#Here we ensure that our network has a minimal number of steps\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(len(X)/batch_size)+1\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# Predict on subsampled grid for increased efficiency and larger field of view\n","grid = tuple(1 if a > 1.5 else 2 for a in anisotropy)\n","\n","# Use rays on a Fibonacci lattice adjusted for measured anisotropy of the training data\n","rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)\n","\n","conf = Config3D (\n"," rays = rays,\n"," grid = grid,\n"," anisotropy = anisotropy,\n"," use_gpu = use_gpu,\n"," n_channel_in = n_channel,\n"," train_learning_rate = initial_learning_rate,\n"," train_patch_size = (patch_height, patch_size, 
patch_size),\n"," train_batch_size = batch_size,\n",")\n","print(conf)\n","vars(conf)\n","\n","\n","# --------------------- This is currently disabled as it give an error ------------------------\n","#here we limit GPU to 80%\n","if use_gpu:\n"," from csbdeep.utils.tf import limit_gpu_memory\n"," # adjust as necessary: limit GPU memory to be used by TensorFlow to leave some to OpenCL-based computations\n"," limit_gpu_memory(0.8)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# Here we create a model according to section 5.3.\n","model = StarDist3D(conf, name=model_name, basedir=trained_model)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here we check the FOV of the network.\n","median_size = calculate_extents(Y, np.median)\n","fov = np.array(model._axes_tile_overlap('ZYX'))\n","if any(median_size > fov):\n"," print(\"WARNING: median object size larger than field of view of the neural network.\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nnMCvu2PKT9W","colab_type":"text"},"source":["## **4.2. Train the network**\n","---\n","\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
Another way circumvent this is to save the parameters of the model after training and start training again from this point."]},{"cell_type":"code","metadata":{"id":"XfCF-Q4lKT9e","colab_type":"code","cellView":"form","colab":{}},"source":["import time\n","start = time.time()\n","\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","#@markdown ##Start training\n","\n","augmenter = None\n","\n","# def augmenter(X_batch, Y_batch):\n","# \"\"\"Augmentation for data batch.\n","# X_batch is a list of input images (length at most batch_size)\n","# Y_batch is the corresponding list of ground-truth label images\n","# \"\"\"\n","# # ...\n","# return X_batch, Y_batch\n","\n","# Training the model. \n","# 'input_epochs' and 'steps' refers to your input data in section 5.1 \n","history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n"," epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n","None;\n","\n","print(\"Training done\")\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). 
\n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","print(\"Network optimization in progress\")\n","\n","#Here we optimize the network.\n","model.optimize_thresholds(X_val, Y_val)\n","print(\"Done\")\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iYRrmh0dCrNs","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"LqH54fYhdbXU","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"RzAHUsi-78Ak","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w3Z7Jkv8bPvq","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"05dbg6UrGunj","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"mBkuXf5zhHUd","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. 
\n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"."]},{"cell_type":"code","metadata":{"id":"i9ek_kIHhK1R","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Give the paths to an image to test the performance of the model with.\n","\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = True #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. 
Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 1#@param {type:\"number\"}\n","n_tiles_X = 1#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","\n","# Generate predictions from the Source_QC_folder and save them in the QC folder\n","\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","\n","model = StarDist3D(None, name=QC_model_name, basedir=QC_model_path)\n","\n","names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n","\n"," \n","# modify the names to suitable form: path_images/image_numberX.tif\n"," \n","lenght_of_Z = len(Z)\n"," \n","for i in range(lenght_of_Z):\n"," img = normalize(Z[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img, 
n_tiles=n_tilesZYX)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(names[i], labels, polygons)\n","\n","\n","# Here we start testing the differences between GT and predicted masks\n","\n","\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"image\",\"Prediction v. GT Intersection over Union\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," \n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n","#Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n","#Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n","# Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score)])\n","\n","\n","Image_Z = test_input.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Display the last image\n","\n","f=plt.figure(figsize=(25,25))\n","\n","from astropy.visualization import simple_norm\n","norm = simple_norm(test_input, percent = 99)\n","\n","#Input\n","plt.subplot(1,4,1)\n","plt.axis('off')\n","plt.imshow(test_input[mid_plane], aspect='equal', norm=norm, cmap='magma', 
interpolation='nearest')\n","plt.title('Input')\n","\n","#Ground-truth\n","plt.subplot(1,4,2)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255[mid_plane], aspect='equal', cmap='Greens')\n","plt.title('Ground Truth')\n","\n","#Prediction\n","plt.subplot(1,4,3)\n","plt.axis('off')\n","plt.imshow(test_prediction_0_to_255[mid_plane], aspect='equal', cmap='Purples')\n","plt.title('Prediction')\n","\n","#Overlay\n","plt.subplot(1,4,4)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255[mid_plane], cmap='Greens')\n","plt.imshow(test_prediction_0_to_255[mid_plane], alpha=0.5, cmap='Purples')\n","plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3)))\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"U8H7QRfKBzI8","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"btXwwnVpBEMB","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you trained.\n","\n","**`Result_folder`:** This folder will contain the predicted output ROI.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"x8UXP8S2eoo_","colab_type":"code","cellView":"form","colab":{}},"source":["from PIL import Image\n","\n","\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","#test_dataset = Data_folder\n","\n","Results_folder = \"\" #@param {type:\"string\"}\n","#results = results_folder\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. 
\n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 2#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#single images\n","#testDATA = test_dataset\n","Dataset = Data_folder+\"/*.tif\"\n","\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","X = sorted(glob(Dataset))\n","X = list(map(imread,X))\n","n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n"," \n","# axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","model = StarDist3D(None, name=Prediction_model_name, basedir=Prediction_model_path)\n"," \n","#Sorting and mapping original test dataset\n","X = sorted(glob(Dataset))\n","X = list(map(imread,X))\n","names = [os.path.basename(f) for f in sorted(glob(Dataset))]\n","\n","# modify the names to suitable form: path_images/image_numberX.tif\n","FILEnames=[]\n","for m in names:\n"," m=Results_folder+'/'+m\n"," FILEnames.append(m)\n","\n"," # Predictions folder\n","lenght_of_X = len(X)\n","for i in range(lenght_of_X):\n"," img = normalize(X[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img, n_tiles=n_tilesZYX)\n"," \n","# Save the predicted mask in the result folder\n"," os.chdir(Results_folder)\n"," imsave(FILEnames[i], labels, polygons)\n","\n"," # One example image \n","print(\"One example image is displayed bellow:\")\n","plt.figure(figsize=(13,10))\n","z = max(0, img.shape[0] // 2 - 5)\n","plt.subplot(121)\n","plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')\n","plt.title('Raw image (XY slice)')\n","plt.axis('off')\n","plt.subplot(122)\n","plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')\n","plt.imshow(labels[z], cmap=lbl_cmap, alpha=0.5)\n","plt.title('Image and predicted labels (XY 
slice)')\n","plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"SxJsrw3kTcFx","colab_type":"text"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"rH_J20ydXWRQ","colab_type":"text"},"source":["#**Thank you for using StarDist 3D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"colab":{"name":"StarDist_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1Ur-4VIQ6gf4ONupD6hK0M-AcJkoTzMlU","timestamp":1586789439593},{"file_id":"1PKVyox_mx2rEE3VlMFQtdnVULJFhYPaD","timestamp":1583443864213},{"file_id":"1XSclOkhhHmn-9LQc9k8c3Y6seT1LEi-Y","timestamp":1583264105465},{"file_id":"1VPZYk3MeSVyZVVEmesz10VtujbD4diJk","timestamp":1579481583477},{"file_id":"1ENdOZir1Gytf6JxzyfbjgfxO3_C1dLHK","timestamp":1575415287126},{"file_id":"1G8b4dF2kCs3ePBGZthPUGOyjJpZ2G_Dm","timestamp":1575379725785},{"file_id":"1P0tT0RR_b3SFKvOcON_MzcAIcxRUQK5B","timestamp":1575377313115},{"file_id":"1hQz8PyJzBRkBZc9NwxM9mU9azRSvghBk","timestamp":1574783624098},{"file_id":"14mWTNjHgIbuuWAxb-0lhmhdIvMoZgrI0","timestamp":1574099686195},{"file_id":"1IWvFuBb0gqaJcUXhhfbcTWNh9cZEXW4S","timestamp":1573647131082},{"file_id":"1hFulBwI57YU6GoVc8sBt5KNIkCS7ynQ3","timestamp":1573579952409},{"file_id":"1Ba_Bu-PXN_2Mq5W6YHMgUYsJEfgbPtS-","timestamp":1573035984524},{"file_id":"1ePC44Qq_C2hSFGPM3PKyb0J6UB
XSPddp","timestamp":1573032545399},{"file_id":"https://github.com/mpicbg-csbd/stardist/blob/master/examples/2D/2_training.ipynb","timestamp":1572984225873}],"collapsed_sections":[],"toc_visible":true},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"kiFRRolPa-Rb","colab_type":"text"},"source":["# **StarDist (3D)**\n","---\n","\n","**StarDist 3D** is a deep-learning method that can be used to segment cell nuclei from 3D bioimages and was first published by [Weigert *et al.* in 2019 on arXiv](https://arxiv.org/abs/1908.03636), extending to 3D the 2D approach from [Schmidt *et al.* in 2018](https://arxiv.org/abs/1806.03535). It uses a shape representation based on star-convex polygons for nuclei in an image to predict the presence and the shape of these nuclei. This StarDist 3D network is based on an adapted ResNet network architecture.\n","\n"," **This particular notebook enables nuclei segmentation of 3D datasets. If you are interested in 2D datasets, you should use the StarDist 2D notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper:\n","\n","**Cell Detection with Star-convex Polygons** from Schmidt *et al.*, International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. 
(https://arxiv.org/abs/1806.03535)\n","\n","and the 3D extension of the approach:\n","\n","**Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy** from Weigert *et al.* published on arXiv in 2019 (https://arxiv.org/abs/1908.03636)\n","\n","**The Original code** is freely available on GitHub:\n","https://github.com/mpicbg-csbd/stardist\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"iSuNqQ2ZMVGM","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cells: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of the play button stops. You can create a new code cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. 
Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"4-oByBSdE6DE","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," For StarDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. 
The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Images of nuclei (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Images of nuclei\n"," - img_1.tif, img_2.tif\n"," - **Masks** \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"t1sYuLChbRV3","colab_type":"text"},"source":["# **1. 
Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CDxBu1-19OyC","colab_type":"text"},"source":["\n","\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"4waLStm0RPFo","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZLY4qhgj8w-R","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 
\n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"Ukil4yuS8seC","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bB0IaQMZmWYM","colab_type":"text"},"source":["# **2. Install StarDist and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"j0w7C8P5zPIp","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install StarDist and dependencies\n","\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# Install packages which are not included in Google Colab\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). 
It uses Keras and Tensorflow.\n","!pip install stardist # contains tools to operate STARDIST.\n","!pip install gputools\n","!pip install edt\n","!pip install wget\n","\n","\n","# ------- Variable specific to Stardist -------\n","from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available\n","from stardist.models import Config3D, StarDist3D, StarDistData3D\n","from stardist import relabel_image_stardist3D, Rays_GoldenSpiral, calculate_extents\n","from stardist.matching import matching_dataset\n","from csbdeep.utils import Path, normalize, download_and_extract_zip_file, plot_history # for loss plot\n","from csbdeep.io import save_tiff_imagej_compatible\n","import numpy as np\n","np.random.seed(42)\n","lbl_cmap = random_label_cmap()\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","import cv2\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # 
red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DPWhXaltAYgH","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n","\n"]},{"cell_type":"markdown","metadata":{"id":"nAW3oU60htR_","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"HJKFAmuXc6d1"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data, respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files, right-click on the folder, select **Copy path** and paste the path into the corresponding box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained for. Preliminary results can already be observed after 400 epochs, but a full training should run for more. Evaluate the performance after training (see 5.). **Default value: 400**\n","\n","**Advanced parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. 
Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1** \n","\n","**`number_of_steps`:** Define the number of training steps per epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patches / batch_size**\n","\n","**`patch_size`:** and **`patch_height`:** Input the size of the patches used to train StarDist 3D (length of a side). The values should be smaller than or equal to the dimensions of the image. Make patch_size and patch_height as large as possible and divisible by 8 and 4, respectively. **Default value: dimension of the training images**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`n_rays`:** Set the number of rays (corners) used for StarDist (for instance a cube has 8 corners). **Default value: 96** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0003**\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**"]},{"cell_type":"code","metadata":{"colab_type":"code","cellView":"form","id":"CNJImzzVnr7h","colab":{}},"source":["\n","\n","#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","training_images = Training_source\n","\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","mask_images = Training_target \n","\n","\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","\n","model_path = \"\" #@param {type:\"string\"}\n","trained_model = model_path \n","\n","#@markdown ### Other parameters for training:\n","number_of_epochs = 400#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","\n","#GPU_limit = 90 #@param {type:\"number\"}\n","batch_size = 1#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","patch_size = 64#@param {type:\"number\"} # pixels in\n","patch_height = 64#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","n_rays = 96 #@param {type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," n_rays = 96\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0003\n","\n","\n","percentage = percentage_validation/100\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if 
os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Folder already exists and has been removed !!\")\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","# If default parameters, patch size is the same as image size\n","if (Use_Default_Advanced_Parameters): \n"," patch_size = min(Image_Y, Image_X) \n"," patch_height = Image_Z\n","\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","\n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible 
by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","print(\"Parameters initiated.\")\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we use a simple normalisation strategy to visualise the image\n","from astropy.visualization import simple_norm\n","norm = simple_norm(x, percent = 99)\n","\n","mid_plane = int(Image_Z / 2)+1\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n","plt.axis('off')\n","plt.title('Training source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n","plt.axis('off')\n","plt.title('Training target (single Z plane)');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nbyf-RevQhDL","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"UQ2hultWQlT9","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. 
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis as well as performing elastic deformations\n","\n","**The flip option and the elastic deformation will double the size of your dataset, rotation will quadruple and all together will increase the dataset by a factor of 16.**\n","\n"," Elastic deformations performed by [Elasticdeform.](https://elasticdeform.readthedocs.io/en/latest/index.html).\n"]},{"cell_type":"code","metadata":{"id":"wYdTY6ULg01b","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###See Elasticdeform’s license\n","#Copyright (c) 2001, 2002 Enthought, Inc. All rights reserved.\n","\n","#Copyright (c) 2003-2017 SciPy Developers. All rights reserved.\n","\n","#Copyright (c) 2018 Gijs van Tulder. All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n","\n","##Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n","#Neither the name of Enthought nor the names of the SciPy Developers may be used to endorse or promote products derived from this software without specific prior written permission.\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n","\n","print(\"Double click to see elasticdeform’s license\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"kKLB47jgQrxr","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","#@markdown **Deform your images**\n","\n","Elastic_deformation = True #@param {type:\"boolean\"}\n","\n","Deformation_Sigma = 3 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","Save_augmented_images = True #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = 
np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Training_source_augmented+'/'+image,source_img)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Training_target_augmented+'/'+image,target_img)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," 
io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Training_source_augmented+'/'+image,source_img)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Training_target_augmented+'/'+image,target_img)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","\n","\n","if Use_Data_augmentation:\n","\n","\n"," if Elastic_deformation:\n"," !pip install elasticdeform\n"," import numpy, imageio, elasticdeform\n","\n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n","\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n"," Training_source_augmented = Augmented_folder+\"/Training_source\"\n"," os.makedirs(Training_source_augmented)\n"," Training_target_augmented = Augmented_folder+\"/Training_target\"\n"," os.makedirs(Training_target_augmented)\n"," print(\"Data augmentation enabled\")\n"," print(\"Generation of the augmented dataset in progress\")\n","\n"," if Elastic_deformation:\n"," for filename in os.listdir(Training_source):\n"," X = imread(os.path.join(Training_source, filename))\n"," Y = imread(os.path.join(Training_target, filename))\n"," [X_deformed, 
Y_deformed] = elasticdeform.deform_random_grid([X, Y], sigma=Deformation_Sigma, order=0)\n","\n"," os.chdir(Augmented_folder+\"/Training_source\")\n"," imsave(filename, X)\n"," imsave(filename+\"_deformed.tif\", X_deformed)\n","\n"," os.chdir(Augmented_folder+\"/Training_target\")\n"," imsave(filename, Y)\n"," imsave(filename+\"_deformed.tif\", Y_deformed)\n","\n"," Training_source_rot = Training_source_augmented\n"," Training_target_rot = Training_target_augmented\n"," \n"," if not Elastic_deformation:\n"," Training_source_rot = Training_source\n"," Training_target_rot = Training_target\n","\n"," \n"," if Rotation == True:\n"," rotation_aug(Training_source_rot,Training_target_rot,flip=Flip)\n"," elif Rotation == False and Flip == True:\n"," flip(Training_source_rot,Training_target_rot)\n","\n"," print(\"Done\")\n","\n"," if Elastic_deformation:\n"," from astropy.visualization import simple_norm\n"," norm = simple_norm(x, percent = 99)\n","\n"," random_choice=random.choice(os.listdir(Training_source))\n"," x = imread(Augmented_folder+\"/Training_source/\"+random_choice)\n"," x_deformed = imread(Augmented_folder+\"/Training_source/\"+random_choice+\"_deformed.tif\")\n"," y = imread(Augmented_folder+\"/Training_target/\"+random_choice)\n"," y_deformed = imread(Augmented_folder+\"/Training_target/\"+random_choice+\"_deformed.tif\") \n","\n"," Image_Z = x.shape[0]\n"," mid_plane = int(Image_Z / 2)+1\n","\n"," f=plt.figure(figsize=(10,10))\n"," plt.subplot(2,2,1)\n"," plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n"," plt.axis('off')\n"," plt.title('Training source (single Z plane)');\n"," plt.subplot(2,2,2)\n"," plt.imshow(y[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n"," plt.axis('off')\n"," plt.title('Training target (single Z plane)');\n"," plt.subplot(2,2,3)\n"," plt.imshow(x_deformed[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n"," plt.axis('off')\n"," plt.title('Deformed training source (single Z 
plane)');\n"," plt.subplot(2,2,4)\n"," plt.imshow(y_deformed[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n"," plt.axis('off')\n"," plt.title('Deformed training target (single Z plane)');\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")\n","\n","\n","\n"," \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"pjz-5bRVh1ja","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a StarDist model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"zeSUtd2Thw-O","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Demo_3D_Model_from_Stardist_3D_paper\" #@param [\"Model_from_file\", \"Demo_3D_Model_from_Stardist_3D_paper\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the Demo 3D model provided in the Stardist 3D github ------------------------\n","\n"," if pretrained_model_choice == \"Demo_3D_Model_from_Stardist_3D_paper\":\n"," pretrained_model_name = \"Demo_3D\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the Demo 3D model from the Stardist_3D paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"https://raw.githubusercontent.com/mpicbg-csbd/stardist/master/models/examples/3D_demo/config.json\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/3D_demo/thresholds.json\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/3D_demo/weights_best.h5?raw=true\", pretrained_model_path)\n"," 
wget.download(\"https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/3D_demo/weights_last.h5?raw=true\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exists ------------------------\n","# If the chosen model path does not contain a pretrained model, then Use_pretrained_model is disabled\n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrained model, we load the learning rate\n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exists (compatibility with models trained in ZeroCostDL4Mic below 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning 
rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print(bcolors.WARNING+'Weights found in:')\n"," print(h5_file_path)\n"," print(bcolors.WARNING+'will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DECuc3HZDbwG","colab_type":"text"},"source":["#**4. Train the network**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"NwV5LweiavgQ","colab_type":"text"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","\n","Here, we use the information from 3. 
to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"uTM781rCKT8r","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Create the model and dataset objects\n","import warnings\n","warnings.simplefilter(\"ignore\")\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","\n","# this funtion imports training images and masks and sorts them suitable for the network\n","X = sorted(glob(training_images_tiff)) \n","Y = sorted(glob(mask_images_tiff)) \n","\n","# assert -funtion check that X and Y really have images. If not this cell raises an error\n","assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n","\n","# Here we map the training dataset (images and masks).\n","X = list(map(imread,X))\n","Y = list(map(imread,Y))\n","\n","n_channel = 1 if X[0].ndim == 3 else X[0].shape[-1]\n","\n","\n","\n","#Normalize images and fill small label holes.\n","axis_norm = (0,1,2) # normalize channels independently\n","# axis_norm = (0,1,2,3) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 3 in axis_norm else 'independently'))\n"," sys.stdout.flush()\n","\n","X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n","Y = [fill_label_holes(y) for y in tqdm(Y)]\n","\n","#Here we split the your training dataset into training images (90 %) and validation images (10 %). 
\n","\n","assert len(X) > 1, \"not enough training data\"\n","rng = np.random.RandomState(42)\n","ind = rng.permutation(len(X))\n","n_val = max(1, int(round(percentage * len(ind))))\n","ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n","X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n","X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n","print('number of images: %3d' % len(X))\n","print('- training: %3d' % len(X_trn))\n","print('- validation: %3d' % len(X_val))\n","\n","\n","\n","extents = calculate_extents(Y)\n","anisotropy = tuple(np.max(extents) / extents)\n","print('empirical anisotropy of labeled objects = %s' % str(anisotropy))\n","\n","\n","# Use OpenCL-based computations for data generator during training (requires 'gputools')\n","use_gpu = False and gputools_available()\n","\n","\n","#Here we ensure that our network has a minimal number of steps\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(len(X)/batch_size)+1\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# Predict on subsampled grid for increased efficiency and larger field of view\n","grid = tuple(1 if a > 1.5 else 2 for a in anisotropy)\n","\n","# Use rays on a Fibonacci lattice adjusted for measured anisotropy of the training data\n","rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)\n","\n","conf = Config3D (\n"," rays = rays,\n"," grid = grid,\n"," anisotropy = anisotropy,\n"," use_gpu = use_gpu,\n"," n_channel_in = n_channel,\n"," train_learning_rate = initial_learning_rate,\n"," train_patch_size = (patch_height, patch_size, 
patch_size),\n"," train_batch_size = batch_size,\n",")\n","print(conf)\n","vars(conf)\n","\n","\n","# --------------------- This is currently disabled as it give an error ------------------------\n","#here we limit GPU to 80%\n","if use_gpu:\n"," from csbdeep.utils.tf import limit_gpu_memory\n"," # adjust as necessary: limit GPU memory to be used by TensorFlow to leave some to OpenCL-based computations\n"," limit_gpu_memory(0.8)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# Here we create a model according to section 5.3.\n","model = StarDist3D(conf, name=model_name, basedir=trained_model)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here we check the FOV of the network.\n","median_size = calculate_extents(Y, np.median)\n","fov = np.array(model._axes_tile_overlap('ZYX'))\n","if any(median_size > fov):\n"," print(\"WARNING: median object size larger than field of view of the neural network.\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nnMCvu2PKT9W","colab_type":"text"},"source":["## **4.2. Start Trainning**\n","---\n","\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
Another way to circumvent this is to save the parameters of the model after training and start training again from this point."]},{"cell_type":"code","metadata":{"id":"XfCF-Q4lKT9e","colab_type":"code","cellView":"form","colab":{}},"source":["import time\n","start = time.time()\n","\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","#@markdown ##Start training\n","\n","augmenter = None\n","\n","# def augmenter(X_batch, Y_batch):\n","# \"\"\"Augmentation for data batch.\n","# X_batch is a list of input images (length at most batch_size)\n","# Y_batch is the corresponding list of ground-truth label images\n","# \"\"\"\n","# # ...\n","# return X_batch, Y_batch\n","\n","# Training the model. \n","# 'input_epochs' and 'steps' refers to your input data in section 5.1 \n","history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n"," epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n","None;\n","\n","print(\"Training done\")\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). 
\n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","print(\"Network optimization in progress\")\n","\n","#Here we optimize the network.\n","model.optimize_thresholds(X_val, Y_val)\n","print(\"Done\")\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iYRrmh0dCrNs","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"LqH54fYhdbXU","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"RzAHUsi-78Ak","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w3Z7Jkv8bPvq","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"05dbg6UrGunj","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"mBkuXf5zhHUd","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. 
\n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"."]},{"cell_type":"code","metadata":{"id":"i9ek_kIHhK1R","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Give the paths to an image to test the performance of the model with.\n","\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","#Here we allow the user to choose the number of tiles to be used when predicting the images\n","#@markdown #####To analyse large images, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expense of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = True #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. 
Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 1#@param {type:\"number\"}\n","n_tiles_X = 1#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","\n","# Generate predictions from the Source_QC_folder and save them in the QC folder\n","\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","\n","model = StarDist3D(None, name=QC_model_name, basedir=QC_model_path)\n","\n","names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n","\n"," \n","# modify the names to suitable form: path_images/image_numberX.tif\n"," \n","lenght_of_Z = len(Z)\n"," \n","for i in range(lenght_of_Z):\n"," img = normalize(Z[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img, 
n_tiles=n_tilesZYX)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(names[i], labels, polygons)\n","\n","\n","# Here we start testing the differences between GT and predicted masks\n","\n","\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"image\",\"Prediction v. GT Intersection over Union\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," \n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n","#Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n","#Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n","# Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score)])\n","\n","\n","Image_Z = test_input.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Display the last image\n","\n","f=plt.figure(figsize=(25,25))\n","\n","from astropy.visualization import simple_norm\n","norm = simple_norm(test_input, percent = 99)\n","\n","#Input\n","plt.subplot(1,4,1)\n","plt.axis('off')\n","plt.imshow(test_input[mid_plane], aspect='equal', norm=norm, cmap='magma', 
interpolation='nearest')\n","plt.title('Input')\n","\n","#Ground-truth\n","plt.subplot(1,4,2)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255[mid_plane], aspect='equal', cmap='Greens')\n","plt.title('Ground Truth')\n","\n","#Prediction\n","plt.subplot(1,4,3)\n","plt.axis('off')\n","plt.imshow(test_prediction_0_to_255[mid_plane], aspect='equal', cmap='Purples')\n","plt.title('Prediction')\n","\n","#Overlay\n","plt.subplot(1,4,4)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255[mid_plane], cmap='Greens')\n","plt.imshow(test_prediction_0_to_255[mid_plane], alpha=0.5, cmap='Purples')\n","plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3)))\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"U8H7QRfKBzI8","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"btXwwnVpBEMB","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you trained.\n","\n","**`Result_folder`:** This folder will contain the predicted output ROI.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"x8UXP8S2eoo_","colab_type":"code","cellView":"form","colab":{}},"source":["from PIL import Image\n","\n","\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","#test_dataset = Data_folder\n","\n","Results_folder = \"\" #@param {type:\"string\"}\n","#results = results_folder\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. 
\n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 2#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#single images\n","#testDATA = test_dataset\n","Dataset = Data_folder+\"/*.tif\"\n","\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","X = sorted(glob(Dataset))\n","X = list(map(imread,X))\n","n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n"," \n","# axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","model = StarDist3D(None, name=Prediction_model_name, basedir=Prediction_model_path)\n"," \n","#Sorting and mapping original test dataset\n","X = sorted(glob(Dataset))\n","X = list(map(imread,X))\n","names = [os.path.basename(f) for f in sorted(glob(Dataset))]\n","\n","# modify the names to suitable form: path_images/image_numberX.tif\n","FILEnames=[]\n","for m in names:\n"," m=Results_folder+'/'+m\n"," FILEnames.append(m)\n","\n"," # Predictions folder\n","lenght_of_X = len(X)\n","for i in range(lenght_of_X):\n"," img = normalize(X[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img, n_tiles=n_tilesZYX)\n"," \n","# Save the predicted mask in the result folder\n"," os.chdir(Results_folder)\n"," imsave(FILEnames[i], labels, polygons)\n","\n"," # One example image \n","print(\"One example image is displayed bellow:\")\n","plt.figure(figsize=(13,10))\n","z = max(0, img.shape[0] // 2 - 5)\n","plt.subplot(121)\n","plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')\n","plt.title('Raw image (XY slice)')\n","plt.axis('off')\n","plt.subplot(122)\n","plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')\n","plt.imshow(labels[z], cmap=lbl_cmap, alpha=0.5)\n","plt.title('Image and predicted labels (XY 
slice)')\n","plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"SxJsrw3kTcFx","colab_type":"text"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"rH_J20ydXWRQ","colab_type":"text"},"source":["#**Thank you for using StarDist 3D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/U-net_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/U-net_2D_ZeroCostDL4Mic.ipynb index a60dc73b..e98d0c25 100755 --- a/Colab_notebooks/U-net_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/U-net_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"U-Net_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1VcTsLOL28ntbr23gYrhY3upxkztZeUvn","timestamp":1591024690909},{"file_id":"19jT_GoHGN-UTM1aEgkgrOjB8pcFz5AW4","timestamp":1591017297795},{"file_id":"1UkoWB27ZWh5j_qivSZIOeOJP1h2EqrVz","timestamp":1589363183397},{"file_id":"1ofNqOc7lz-m6NL4B-m4BIheaU5N0GMln","timestamp":1588873191434},{"file_id":"1rJnsgIKyL6vuneydIfjCKMtMhV3XlQ6o","timestamp":1588583580765},{"file_id":"1RUYrp8beEgDKL1kOWw5LgR1QQb4yHQtG","timestamp":1587061416704},{"file_id":"1FVax0eY3-m8DbJHx0B8Dnep-uGlp30Zt","timestamp":1586601038120},{"file_id":"1TTqmCf2mFQ_PNIZEXX9sRAhoixjYP_AB","timestamp":1585842446113},{"file_id":"1cWwS-jbLYTDOpPp_hhKOLGFXfu06ccpG","timestamp":1585821375983},{"file_id":"1TPEE_AtGTLedawgVBwwXofEJEcJUCgo3","timestamp":1585137343783},{"file_id":"1SxFRb38aC_kmKzKVQfkwWzkK9n7YFxVv","timestamp":1585053829456},{"file_id":"15iw9IOwHNF_GhiHxkh_rWbJG8JnW14Wh","timestamp":1584375074441},{"file_id":"15oMbXnMa4LDEMhPHBr3ga0xhJomMLhDo","timest
amp":1584105762670},{"file_id":"1__NtYFNA3DxNB7LrUY13Bt8_frye3iWl","timestamp":1583445015203},{"file_id":"11jsQfqKeDU1Zk3nPykjWKwYhFmvJ1zJ-","timestamp":1575289898486}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"WDrFAwpFIpE0","colab_type":"text"},"source":["# **U-Net (2D)**\n","---\n","\n","U-Net is an encoder-decoder network architecture originally used for image segmentation, first published by [Ronneberger *et al.*](https://arxiv.org/abs/1505.04597). The first half of the U-Net architecture is a downsampling convolutional neural network which acts as a feature extractor from input images. The other half upsamples these results and restores an image by combining results from downsampling with the upsampled images.\n","\n"," **This particular notebook enables image segmentation of 2D dataset. If you are interested in 3D dataset, you should use the 3D U-Net notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the papers: \n","\n","**U-Net: Convolutional Networks for Biomedical Image Segmentation** by Ronneberger *et al.* published on arXiv in 2015 (https://arxiv.org/abs/1505.04597)\n","\n","and \n","\n","**U-Net: deep learning for cell counting, detection, and morphometry** by Thorsten Falk *et al.* in Nature Methods 2019\n","(https://www.nature.com/articles/s41592-018-0261-2)\n","And source code found in: https://github.com/zhixuhao/unet by *Zhixuhao*\n","\n","**Please also cite this original paper when using or developing this notebook.** "]},{"cell_type":"markdown","metadata":{"id":"ABNu2p4stHeB","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. 
You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you will find three tabs which contain, from top to bottom:\n","\n","*Table of contents* = contains the structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contains all available files. After mounting your Google Drive (see section 1.2) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"HVwncY_NvlYi","colab_type":"text"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training data and/or the data to process in your Google Drive.\n","\n","For U-Net to train, **it needs to have access to a paired training dataset consisting of images and their corresponding masks**. 
Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif, ...\n"," - Training_target\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif\n"," - Training_target \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"JrGNzgEyxzGQ","colab_type":"text"},"source":["# **1. 
Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"wYoajeT54sQM","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"TpT6gbwURzrV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"quzkzlRD45HF","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 
\n","\n"," Once this is done, your data are available in the **Files** tab on the top left of the notebook."]},{"cell_type":"code","metadata":{"id":"eLwDxBnp4-bc","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in to your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on the \"Files\" tab on the left and refresh it. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"leK5kmgD5Ism","colab_type":"text"},"source":["# **2. Install U-Net dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"vOeLpQfT0QF1","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play to install U-Net dependencies\n","\n","#As this notebook depends mostly on keras which runs a tensorflow backend (which in turn is pre-installed in colab)\n","#only the data library needs to be additionally installed.\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","#We enforce the keras==2.2.5 release to ensure that the notebook continues working even if keras is updated.\n","\n","!pip install keras==2.2.5\n","!pip install data\n","\n","# Keras imports\n","from keras import models\n","from keras.models import Model, load_model\n","from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D\n","from keras.optimizers import Adam\n","# from keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger # we currently don't use any other callbacks from ModelCheckpoints\n","from keras.callbacks import 
ModelCheckpoint\n","from keras.callbacks import ReduceLROnPlateau\n","from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img\n","from keras import backend as keras\n","\n","# General import\n","from __future__ import print_function\n","import numpy as np\n","import pandas as pd\n","import os\n","import glob\n","from skimage import img_as_ubyte, io, transform\n","import matplotlib as mpl\n","from matplotlib import pyplot as plt\n","from matplotlib.pyplot import imread\n","from pathlib import Path\n","import shutil\n","import random\n","import time\n","import csv\n","import sys\n","from math import ceil\n","\n","# Imports for QC\n","from PIL import Image\n","from scipy import signal\n","from scipy import ndimage\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","# from tqdm import tqdm\n","from tqdm.notebook import tqdm\n","\n","from sklearn.feature_extraction import image\n","from skimage import img_as_ubyte, io, transform\n","from skimage.util.shape import view_as_windows\n","\n","# Suppressing some warnings\n","import warnings\n","warnings.filterwarnings('ignore')\n","\n","\n","\n","def create_patches(Training_source, Training_target, patch_width, patch_height):\n"," \"\"\"\n"," Function creates patches from the Training_source and Training_target images. 
\n"," The steps parameter indicates the offset between patches and, if integer, is the same in x and y.\n"," Saves all created patches in two new directories in the /content folder.\n","\n"," Returns: - Two paths to where the patches are now saved\n"," \"\"\"\n"," DEBUG = False\n","\n"," Patch_source = os.path.join('/content','img_patches')\n"," Patch_target = os.path.join('/content','mask_patches')\n"," Patch_rejected = os.path.join('/content','rejected')\n"," \n","\n"," #Here we save the patches, in the /content directory as they will not usually be needed after training\n"," if os.path.exists(Patch_source):\n"," shutil.rmtree(Patch_source)\n"," if os.path.exists(Patch_target):\n"," shutil.rmtree(Patch_target)\n"," if os.path.exists(Patch_rejected):\n"," shutil.rmtree(Patch_rejected)\n","\n"," os.mkdir(Patch_source)\n"," os.mkdir(Patch_target)\n"," os.mkdir(Patch_rejected) #This directory will contain the images that have too little signal.\n"," \n","\n"," all_patches_img = np.empty([0,patch_width, patch_height])\n"," all_patches_mask = np.empty([0,patch_width, patch_height])\n","\n"," for file in os.listdir(Training_source):\n","\n"," img = io.imread(os.path.join(Training_source, file))\n"," mask = io.imread(os.path.join(Training_target, file),as_gray=True)\n","\n"," if DEBUG:\n"," print(file)\n"," print(img.dtype)\n","\n"," # Using view_as_windows with step size equal to the patch size to ensure there is no overlap\n"," patches_img = view_as_windows(img, (patch_width, patch_height), (patch_width, patch_height))\n"," patches_mask = view_as_windows(mask, (patch_width, patch_height), (patch_width, patch_height))\n"," #the shape of patches_img and patches_mask will be (number of patches along x, number of patches along y,patch_width,patch_height)\n","\n"," all_patches_img = np.concatenate((all_patches_img, patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width,patch_height)), axis = 0)\n"," all_patches_mask = np.concatenate((all_patches_mask, 
patches_mask.reshape(patches_mask.shape[0]*patches_mask.shape[1], patch_width,patch_height)), axis = 0)\n","\n"," number_of_patches = all_patches_img.shape[0]\n"," print('number of patches: '+str(number_of_patches))\n","\n"," if DEBUG:\n"," print(all_patches_img.shape)\n"," print(all_patches_img.dtype)\n","\n"," for i in range(number_of_patches):\n"," img_save_path = os.path.join(Patch_source,'patch_'+str(i)+'.tif')\n"," mask_save_path = os.path.join(Patch_target,'patch_'+str(i)+'.tif')\n","\n"," # if the mask conatins at least 2% of its total number pixels as mask, then go ahead and save the images\n"," pixel_threshold_array = sorted(all_patches_mask[i].flatten())\n"," if pixel_threshold_array[int(round(len(pixel_threshold_array)*0.98))]>0:\n"," io.imsave(img_save_path, img_as_ubyte(normalizeMinMax(all_patches_img[i])))\n"," io.imsave(mask_save_path, convert2Mask(normalizeMinMax(all_patches_mask[i]),0))\n"," else:\n"," io.imsave(Patch_rejected+'/patch_'+str(i)+'_image.tif', img_as_ubyte(normalizeMinMax(all_patches_img[i])))\n"," io.imsave(Patch_rejected+'/patch_'+str(i)+'_mask.tif', convert2Mask(normalizeMinMax(all_patches_mask[i]),0))\n","\n"," return Patch_source, Patch_target\n","\n","\n","def estimatePatchSize(data_path, max_width = 512, max_height = 512):\n","\n"," files = os.listdir(data_path)\n"," \n"," # Get the size of the first image found in the folder and initialise the variables to that\n"," n = 0 \n"," while os.path.isdir(os.path.join(data_path, files[n])):\n"," n += 1\n"," (height_min, width_min) = Image.open(os.path.join(data_path, files[n])).size\n","\n"," # Screen the size of all dataset to find the minimum image size\n"," for file in files:\n"," if not os.path.isdir(os.path.join(data_path, file)):\n"," (height, width) = Image.open(os.path.join(data_path, file)).size\n"," if width < width_min:\n"," width_min = width\n"," if height < height_min:\n"," height_min = height\n"," \n"," # Find the power of patches that will fit within the smallest 
dataset\n"," width_min, height_min = (fittingPowerOfTwo(width_min), fittingPowerOfTwo(height_min))\n","\n"," # Clip values at maximum permissible values\n"," if width_min > max_width:\n"," width_min = max_width\n","\n"," if height_min > max_height:\n"," height_min = max_height\n"," \n"," return (width_min, height_min)\n","\n","def fittingPowerOfTwo(number):\n"," n = 0\n"," while 2**n <= number:\n"," n += 1 \n"," return 2**(n-1)\n","\n","\n","def getClassWeights(Training_target_path):\n","\n"," Mask_dir_list = os.listdir(Training_target_path)\n"," number_of_dataset = len(Mask_dir_list)\n","\n"," class_count = np.zeros(2, dtype=int)\n"," for i in tqdm(range(number_of_dataset)):\n"," mask = io.imread(os.path.join(Training_target_path, Mask_dir_list[i]))\n"," mask = normalizeMinMax(mask)\n"," class_count[0] += mask.shape[0]*mask.shape[1] - mask.sum()\n"," class_count[1] += mask.sum()\n","\n"," n_samples = class_count.sum()\n"," n_classes = 2\n","\n"," class_weights = n_samples / (n_classes * class_count)\n"," return class_weights\n","\n","def weighted_binary_crossentropy(class_weights):\n","\n"," def _weighted_binary_crossentropy(y_true, y_pred):\n"," binary_crossentropy = keras.binary_crossentropy(y_true, y_pred)\n"," weight_vector = y_true * class_weights[1] + (1. 
- y_true) * class_weights[0]\n"," weighted_binary_crossentropy = weight_vector * binary_crossentropy\n","\n"," return keras.mean(weighted_binary_crossentropy)\n","\n"," return _weighted_binary_crossentropy\n","\n","\n","def save_augment(datagen,orig_img,dir_augmented_data=\"/content/augment\"):\n"," \"\"\"\n"," Saves a subset of the augmented data for visualisation, by default in /content.\n","\n"," This is adapted from: https://fairyonice.github.io/Learn-about-ImageDataGenerator.html\n"," \n"," \"\"\"\n"," try:\n"," os.mkdir(dir_augmented_data)\n"," except:\n"," ## if the preview folder exists, then remove\n"," ## the contents (pictures) in the folder\n"," for item in os.listdir(dir_augmented_data):\n"," os.remove(dir_augmented_data + \"/\" + item)\n","\n"," ## convert the original image to array\n"," x = img_to_array(orig_img)\n"," ## reshape (Sampke, Nrow, Ncol, 3) 3 = R, G or B\n"," #print(x.shape)\n"," x = x.reshape((1,) + x.shape)\n"," #print(x.shape)\n"," ## -------------------------- ##\n"," ## randomly generate pictures\n"," ## -------------------------- ##\n"," i = 0\n"," #We will just save 5 images,\n"," #but this can be changed, but note the visualisation in 3. 
currently uses 5.\n"," Nplot = 5\n"," for batch in datagen.flow(x,batch_size=1,\n"," save_to_dir=dir_augmented_data,\n"," save_format='tif',\n"," seed=42):\n"," i += 1\n"," if i > Nplot - 1:\n"," break\n","\n","# Generators\n","def buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, subset, batch_size, target_size):\n"," '''\n"," Can generate image and mask at the same time use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same\n"," \n"," datagen: ImageDataGenerator \n"," subset: can take either 'training' or 'validation'\n"," '''\n"," seed = 1\n"," image_generator = image_datagen.flow_from_directory(\n"," os.path.dirname(image_folder_path),\n"," classes = [os.path.basename(image_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"bicubic\",\n"," seed = seed)\n"," \n"," mask_generator = mask_datagen.flow_from_directory(\n"," os.path.dirname(mask_folder_path),\n"," classes = [os.path.basename(mask_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"nearest\",\n"," seed = seed)\n"," \n"," this_generator = zip(image_generator, mask_generator)\n"," for (img,mask) in this_generator:\n"," # img,mask = adjustData(img,mask)\n"," yield (img,mask)\n","\n","\n","def prepareGenerators(image_folder_path, mask_folder_path, datagen_parameters, batch_size = 4, target_size = (512, 512)):\n"," image_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizePercentile)\n"," mask_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizeMinMax)\n","\n"," train_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'training', batch_size, 
target_size)\n"," validation_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'validation', batch_size, target_size)\n","\n"," return (train_datagen, validation_datagen)\n","\n","\n","# Normalization functions from Martin Weigert\n","def normalizePercentile(x, pmin=1, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","\n","\n","# Simple normalization to min/max fir the Mask\n","def normalizeMinMax(x, dtype=np.float32):\n"," x = x.astype(dtype,copy=False)\n"," x = (x - np.amin(x)) / (np.amax(x) - np.amin(x))\n"," return x\n","\n","\n","# def predictionGenerator(Data_path, target_size = (256,256), as_gray = True):\n","# for filename in os.listdir(Data_path):\n","# if not os.path.isdir(os.path.join(Data_path, filename)):\n","# img = io.imread(os.path.join(Data_path, filename), as_gray = as_gray)\n","# img = normalizePercentile(img)\n","# # img = img/255 # WARNING: this is expecting 8bit images\n","# img = transform.resize(img,target_size, preserve_range=True, anti_aliasing=True, order = 1) # 
liner interpolation\n","# img = np.reshape(img,img.shape+(1,))\n","# img = np.reshape(img,(1,)+img.shape)\n","# yield img\n","\n","\n","# def predictionResize(Data_path, predictions):\n","# resized_predictions = []\n","# for (i, filename) in enumerate(os.listdir(Data_path)):\n","# if not os.path.isdir(os.path.join(Data_path, filename)):\n","# img = Image.open(os.path.join(Data_path, filename))\n","# (width, height) = img.size\n","# resized_predictions.append(transform.resize(predictions[i], (height, width), preserve_range=True, anti_aliasing=True, order = 1))\n","# return resized_predictions\n","\n","\n","# This is code outlines the architecture of U-net. The choice of pooling steps decides the depth of the network. \n","def unet(pretrained_weights = None, input_size = (256,256,1), pooling_steps = 4, learning_rate = 1e-4, verbose=True, class_weights=np.ones(2)):\n"," inputs = Input(input_size)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)\n"," # Downsampling steps\n"," pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)\n"," \n"," if pooling_steps > 1:\n"," pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)\n","\n"," if pooling_steps > 2:\n"," pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 
'he_normal')(conv4)\n"," drop4 = Dropout(0.5)(conv4)\n"," \n"," if pooling_steps > 3:\n"," pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)\n"," drop5 = Dropout(0.5)(conv5)\n","\n"," #Upsampling steps\n"," up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))\n"," merge6 = concatenate([drop4,up6], axis = 3)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)\n"," \n"," if pooling_steps > 2:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop4))\n"," if pooling_steps > 3:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))\n"," merge7 = concatenate([conv3,up7], axis = 3)\n"," conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)\n"," \n"," if pooling_steps > 1:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv3))\n"," if pooling_steps > 2:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))\n"," merge8 = concatenate([conv2,up8], axis = 3)\n"," conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)\n"," \n"," if pooling_steps == 1:\n"," up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv2))\n"," else:\n"," up9 = Conv2D(64, 2, padding = 'same', 
kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) #activation = 'relu'\n"," \n"," merge9 = concatenate([conv1,up9], axis = 3)\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(merge9) #activation = 'relu'\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv9 = Conv2D(2, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)\n","\n"," model = Model(inputs = inputs, outputs = conv10)\n","\n"," # model.compile(optimizer = Adam(lr = learning_rate), loss = 'binary_crossentropy', metrics = ['acc'])\n"," model.compile(optimizer = Adam(lr = learning_rate), loss = weighted_binary_crossentropy(class_weights))\n","\n","\n"," if verbose:\n"," model.summary()\n","\n"," if(pretrained_weights):\n"," \tmodel.load_weights(pretrained_weights);\n","\n"," return model\n","\n","\n","\n","def predict_as_tiles(Image_path, model):\n","\n"," # Read the data in and normalize\n"," Image_raw = io.imread(Image_path, as_gray = True)\n"," Image_raw = normalizePercentile(Image_raw)\n","\n"," # Get the patch size from the input layer of the model\n"," patch_size = model.layers[0].output_shape[1:3]\n","\n"," # Pad the image with zeros if any of its dimensions is smaller than the patch size\n"," if Image_raw.shape[0] < patch_size[0] or Image_raw.shape[1] < patch_size[1]:\n"," Image = np.zeros((max(Image_raw.shape[0], patch_size[0]), max(Image_raw.shape[1], patch_size[1])))\n"," Image[0:Image_raw.shape[0], 0: Image_raw.shape[1]] = Image_raw\n"," else:\n"," Image = Image_raw\n","\n"," # Calculate the number of patches in each dimension\n"," n_patch_in_width = ceil(Image.shape[0]/patch_size[0])\n"," n_patch_in_height = ceil(Image.shape[1]/patch_size[1])\n","\n"," prediction = np.zeros(Image.shape)\n","\n"," for x in range(n_patch_in_width):\n"," for y in range(n_patch_in_height):\n"," xi = 
patch_size[0]*x\n"," yi = patch_size[1]*y\n","\n"," # If the patch exceeds the edge of the image shift it back \n"," if xi+patch_size[0] >= Image.shape[0]:\n"," xi = Image.shape[0]-patch_size[0]\n","\n"," if yi+patch_size[1] >= Image.shape[1]:\n"," yi = Image.shape[1]-patch_size[1]\n"," \n"," # Extract and reshape the patch\n"," patch = Image[xi:xi+patch_size[0], yi:yi+patch_size[1]]\n"," patch = np.reshape(patch,patch.shape+(1,))\n"," patch = np.reshape(patch,(1,)+patch.shape)\n","\n"," # Get the prediction from the patch and paste it in the prediction in the right place\n"," predicted_patch = model.predict(patch, batch_size = 1)\n"," prediction[xi:xi+patch_size[0], yi:yi+patch_size[1]] = np.squeeze(predicted_patch)\n","\n","\n"," return prediction[0:Image_raw.shape[0], 0: Image_raw.shape[1]]\n"," \n","\n","\n","\n","def saveResult(save_path, nparray, source_dir_list, prefix='', threshold=None):\n"," for (filename, image) in zip(source_dir_list, nparray):\n"," io.imsave(os.path.join(save_path, prefix+os.path.splitext(filename)[0]+'.tif'), img_as_ubyte(image)) # saving as unsigned 8-bit image\n"," \n"," # For masks, threshold the images and return 8 bit image\n"," if threshold is not None:\n"," mask = convert2Mask(image, threshold)\n"," io.imsave(os.path.join(save_path, prefix+'mask_'+os.path.splitext(filename)[0]+'.tif'), mask)\n","\n","\n","def convert2Mask(image, threshold):\n"," mask = img_as_ubyte(image, force_copy=True)\n"," mask[mask > threshold] = 255\n"," mask[mask <= threshold] = 0\n"," return mask\n","\n","\n","def getIoUvsThreshold(prediction_filepath, groud_truth_filepath):\n"," prediction = io.imread(prediction_filepath)\n"," ground_truth_image = img_as_ubyte(io.imread(groud_truth_filepath, as_gray=True), force_copy=True)\n","\n"," threshold_list = []\n"," IoU_scores_list = []\n","\n"," for threshold in range(0,256): \n"," # Convert to 8-bit for calculating the IoU\n"," mask = img_as_ubyte(prediction, force_copy=True)\n"," mask[mask > threshold] = 
255\n"," mask[mask <= threshold] = 0\n","\n"," # Intersection over Union metric\n"," intersection = np.logical_and(ground_truth_image, np.squeeze(mask))\n"," union = np.logical_or(ground_truth_image, np.squeeze(mask))\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," threshold_list.append(threshold)\n"," IoU_scores_list.append(iou_score)\n","\n"," return (threshold_list, IoU_scores_list)\n","\n","\n","\n","# -------------- Other definitions -----------\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","prediction_prefix = 'Predicted_'\n","\n","\n","print('-------------------')\n","print('U-Net and dependencies installed.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7hTKImff6Est","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---"]},{"cell_type":"markdown","metadata":{"id":"S74FbqV6PNNv","colab_type":"text"},"source":["##**3.1. Parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"3np5EpJF8_q2","colab_type":"text"},"source":[" **Paths for training data and models**\n","\n","**`Training_source`, `Training_target`:** These are the folders containing your source (e.g. EM images) and target files (segmentation masks). Enter the path to the source and target images for training. **These should be located in the same parent folder.**\n","\n","**`model_name`:** Use only my_model -style, not my-model. If you want to use a previously trained model, enter the name of the pretrained model (which should be contained in the trained_model -folder after training).\n","\n","**`model_path`**: Enter the path of the folder where you want to save your model.\n","\n","**`visual_validation_after_training`**: If you select this option, a random image pair will be set aside from your training set and will be used to display a predicted image of the trained network next to the input and the ground-truth. 
This can aid in visually assessing the performance of your network after training. **Note: Your training set size will decrease by 1 if you select this option.**\n","\n","**Make sure the directories exist before entering them!**\n","\n"," **Select training parameters**\n","\n","**`number_of_epochs`**: Choose more epochs for larger training sets. Observing how much the loss reduces between epochs during training may help determine the optimal value. **Default: 200**\n","\n","**Advanced parameters - experienced users only**\n","\n","**`batch_size`**: This parameter describes the number of images that are loaded into the network per step. Smaller batch sizes may improve training performance slightly but may increase training time. If the notebook crashes while loading the dataset, the batch size may be too large. Decrease the number in this case. **Default: 4**\n","\n","**`number_of_steps`**: This number should be equivalent to the number of samples in the training set divided by the batch size, to ensure the training iterates through the entire training set. Smaller values can be used for testing. **Default: 6**\n","\n"," **`pooling_steps`**: Choosing a different number of pooling layers can affect the performance of the network. Each additional pooling step will also add two additional convolutions. The network can learn more complex information but is also more likely to overfit. Achieving best performance may require testing different values here. **Default: 2**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n","\n","**`patch_width` and `patch_height`:** The notebook crops the data in patches of fixed size prior to training. The dimensions of the patches can be defined here. 
When `Use_Default_Advanced_Parameters` is selected, the largest 2^n x 2^n patch that fits in the smallest dataset is chosen. Larger patches than 512x512 should **NOT** be selected for network stability.\n","\n"]},{"cell_type":"code","metadata":{"id":"7deNuPZd5d-B","colab_type":"code","cellView":"form","colab":{}},"source":["# ------------- Initial user input ------------\n","#@markdown ###Path to training images:\n","Training_source = '' #@param {type:\"string\"}\n","Training_target = '' #@param {type:\"string\"}\n","\n","model_name = '' #@param {type:\"string\"}\n","model_path = '' #@param {type:\"string\"}\n","\n","#@markdown ###Training parameters:\n","#@markdown Number of epochs\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced parameters:\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 4#@param {type:\"integer\"}\n","number_of_steps = 6#@param {type:\"number\"}\n","pooling_steps = 2 #@param [1,2,3,4]{type:\"raw\"}\n","percentage_validation = 10#@param{type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","patch_width = 512#@param{type:\"number\"}\n","patch_height = 512#@param{type:\"number\"}\n","\n","\n","# ------------- Initialising folder, variables and failsafes ------------\n","# Create the folders where to save the model and the QC\n","full_model_path = os.path.join(model_path, model_name)\n","if os.path.exists(full_model_path):\n"," print(R+'!! 
WARNING: Folder already exists and will be overwritten !!'+W)\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 4\n"," pooling_steps = 2\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0003\n"," patch_width, patch_height = estimatePatchSize(Training_source)\n","\n","\n","#The create_patches function will create the two folders below\n","# Patch_source = '/content/img_patches'\n","# Patch_target = '/content/mask_patches'\n","print('Training on patches of size (x,y): ('+str(patch_width)+','+str(patch_height)+')')\n","\n","#Create patches\n","print('Creating patches...')\n","Patch_source, Patch_target = create_patches(Training_source, Training_target, patch_width, patch_height)\n","\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = False\n","\n","# ------------- Display ------------\n","\n","#if not os.path.exists('/content/img_patches/'):\n","random_choice = random.choice(os.listdir(Patch_source))\n","x = io.imread(os.path.join(Patch_source, random_choice))\n","\n","#os.chdir(Training_target)\n","y = io.imread(os.path.join(Patch_target, random_choice), as_gray=True)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest',cmap='gray')\n","plt.title('Training image patch')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest',cmap='gray')\n","plt.title('Training mask patch')\n","plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"V9UCjlLJ5Rfc","colab_type":"text"},"source":["##**3.2. Data augmentation**\n","\n","---\n","\n"," Data augmentation can improve training progress by amplifying differences in the dataset. 
This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if the dataset is large the values can be set to 0.\n","\n"," The augmentation options below are to be used as follows:\n","\n","* **shift**: a translation of the image by a fraction of the image size (width or height), **default: 10%**\n","* **zoom_range**: Increasing or decreasing the field of view. E.g. 10% will result in a zoom range of (0.9 to 1.1), with pixels added or interpolated, depending on the transformation, **default: 10%**\n","* **shear_range**: Shear angle in counter-clockwise direction, **default: 10%**\n","* **flip**: creating a mirror image along specified axis (horizontal or vertical), **default: True**\n","* **rotation_range**: range of allowed rotation angles in degrees (from 0 to *value*), **default: 180**"]},{"cell_type":"code","metadata":{"id":"i-PahNX94-pl","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##**Augmentation options**\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," if Use_Default_Augmentation_Parameters:\n"," horizontal_shift = 10 \n"," vertical_shift = 20 \n"," zoom_range = 10\n"," shear_range = 10\n"," horizontal_flip = True\n"," vertical_flip = True\n"," rotation_range = 180\n","#@markdown ###If you are not using the default settings, please provide the values below:\n","\n","#@markdown ###**Image shift, zoom, shear and flip (%)**\n"," else:\n"," horizontal_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," vertical_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," zoom_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," shear_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," 
horizontal_flip = True #@param {type:\"boolean\"}\n"," vertical_flip = True #@param {type:\"boolean\"}\n","\n","#@markdown ###**Rotate image within angle range (degrees):**\n"," rotation_range = 180 #@param {type:\"slider\", min:0, max:180, step:1}\n","\n","#given behind the # are the default values for each parameter.\n","\n","else:\n"," horizontal_shift = 0 \n"," vertical_shift = 0 \n"," zoom_range = 0\n"," shear_range = 0\n"," horizontal_flip = False\n"," vertical_flip = False\n"," rotation_range = 0\n","\n","\n","# Build the dict for the ImageDataGenerator\n","data_gen_args = dict(width_shift_range = horizontal_shift/100.,\n"," height_shift_range = vertical_shift/100.,\n"," rotation_range = rotation_range, #90\n"," zoom_range = zoom_range/100.,\n"," shear_range = shear_range/100.,\n"," horizontal_flip = horizontal_flip,\n"," vertical_flip = vertical_flip,\n"," validation_split = percentage_validation/100,\n"," fill_mode = 'reflect')\n","\n","\n","\n","# ------------- Display ------------\n","dir_augmented_data_imgs=\"/content/augment_img\"\n","dir_augmented_data_masks=\"/content/augment_mask\"\n","random_choice = random.choice(os.listdir(Patch_source))\n","orig_img = load_img(os.path.join(Patch_source,random_choice))\n","orig_mask = load_img(os.path.join(Patch_target,random_choice))\n","\n","augment_view = ImageDataGenerator(**data_gen_args)\n","\n","if Use_Data_augmentation:\n"," print(\"Parameters enabled\")\n"," print(\"Here is what a subset of your augmentations looks like:\")\n"," save_augment(augment_view, orig_img, dir_augmented_data=dir_augmented_data_imgs)\n"," save_augment(augment_view, orig_mask, dir_augmented_data=dir_augmented_data_masks)\n","\n"," fig = plt.figure(figsize=(15, 7))\n"," fig.subplots_adjust(hspace=0.0,wspace=0.1,left=0,right=1.1,bottom=0, top=0.8)\n","\n"," \n"," ax = fig.add_subplot(2, 6, 1,xticks=[],yticks=[]) \n"," new_img=img_as_ubyte(normalizeMinMax(img_to_array(orig_img)))\n"," ax.imshow(new_img)\n"," ax.set_title('Original 
Image')\n"," i = 2\n"," for imgnm in os.listdir(dir_augmented_data_imgs):\n"," ax = fig.add_subplot(2, 6, i,xticks=[],yticks=[]) \n"," img = load_img(dir_augmented_data_imgs + \"/\" + imgnm)\n"," ax.imshow(img)\n"," i += 1\n","\n"," ax = fig.add_subplot(2, 6, 7,xticks=[],yticks=[]) \n"," new_mask=img_as_ubyte(normalizeMinMax(img_to_array(orig_mask)))\n"," ax.imshow(new_mask)\n"," ax.set_title('Original Mask')\n"," j=2\n"," for imgnm in os.listdir(dir_augmented_data_masks):\n"," ax = fig.add_subplot(2, 6, j+6,xticks=[],yticks=[]) \n"," mask = load_img(dir_augmented_data_masks + \"/\" + imgnm)\n"," ax.imshow(mask)\n"," j += 1\n"," plt.show()\n","\n","else:\n"," print(\"No augmentation will be used\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7vFEIHbNAuOs","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a U-Net model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"RfR9UyKAAulw","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the UNET_Model_from_\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," 
print(R+'WARNING: pretrained model does not exist')\n"," Use_pretrained_model = False\n"," \n","\n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. 
Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(R+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"94FX4wzE8w1W","colab_type":"text"},"source":["# **4. Train the network**\n","---\n","####**Troubleshooting:** If you receive a time-out or exhausted error, try reducing the batchsize of your training set. This reduces the amount of data loaded into the model at one point in time. "]},{"cell_type":"markdown","metadata":{"id":"tlTDGcmDDHDe","colab_type":"text"},"source":["## **4.1. 
Prepare model for training**\n","---"]},{"cell_type":"code","metadata":{"id":"ezFy_mpz_op4","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play this cell to prepare the model for training\n","\n","\n","# ------------------ Set the generators, model and logger ------------------\n","# This will take the image size and set that as a patch size (arguable...)\n","# Read image size (without actuall reading the data)\n","\n","\n","# n = 0 \n","# while os.path.isdir(os.path.join(Training_source, source_images[n])):\n","# n += 1\n","\n","# (width, height) = Image.open(os.path.join(Training_target, source_images[n])).size\n","# ImageSize = (height, width) # np.shape different from PIL image.size return !\n","\n","# !!! WARNING !!! Check potential issues with resizing at the ImageDataGenerator level\n","# (train_datagen, validation_datagen) = prepareGenerators(Training_source, Training_target, data_gen_args, batch_size, target_size = ImageSize)\n","(train_datagen, validation_datagen) = prepareGenerators(Patch_source, Patch_target, data_gen_args, batch_size, target_size = (patch_width, patch_height))\n","\n","\n","# This modelcheckpoint will only save the best model from the validation loss point of view\n","model_checkpoint = ModelCheckpoint(os.path.join(full_model_path, 'weights_best.hdf5'), monitor='val_loss',verbose=1, save_best_only=True)\n","\n","print('Getting class weights...')\n","class_weights = getClassWeights(Training_target)\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we make sure this is properly defined\n","if not Use_pretrained_model:\n"," h5_file_path = None\n","# --------------------- ---------------------- ------------------------\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," 
initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# --------------------- Reduce learning rate on plateau ------------------------\n","\n","reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, verbose=1, mode='auto',\n"," patience=10, min_lr=0)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# Define the model\n","model = unet(pretrained_weights = h5_file_path, \n"," input_size = (patch_width,patch_height,1), \n"," pooling_steps = pooling_steps, \n"," learning_rate = initial_learning_rate, \n"," class_weights = class_weights)\n","\n","# Dfine CSV logger that will create the loss file (we're not using this anylonger)\n","# csv_log = CSVLogger(os.path.join(full_model_path, 'Quality Control', 'training_evaluation.csv'), separator=',', append=False)\n","\n","number_of_training_dataset = len(os.listdir(Patch_source))\n","\n","if Use_Default_Advanced_Parameters:\n"," number_of_steps = ceil((100-percentage_validation)/100*number_of_training_dataset/batch_size)\n","\n","# Calculate the number of steps to use for validation\n","validation_steps = max(1, ceil(percentage_validation/100*number_of_training_dataset/batch_size))\n","\n","config_model= model.optimizer.get_config()\n","print(config_model)\n","\n","\n","# ------------------ Failsafes ------------------\n","if os.path.exists(full_model_path):\n"," print(R+'!! 
WARNING: Model folder already existed and has been removed !!'+W)\n"," shutil.rmtree(full_model_path)\n","\n","os.makedirs(full_model_path)\n","os.makedirs(os.path.join(full_model_path,'Quality Control'))\n","\n","\n","# ------------------ Display ------------------\n","print('---------------------------- Main training parameters ----------------------------')\n","print('Number of epochs: '+str(number_of_epochs))\n","print('Batch size: '+str(batch_size))\n","print('Number of training dataset: '+str(number_of_training_dataset))\n","print('Number of training steps: '+str(number_of_steps))\n","print('Number of validation steps: '+str(validation_steps))\n","print('---------------------------- ------------------------ ----------------------------')\n","\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"urpQ9UM-6NBE","colab_type":"text"},"source":["## **4.2. Train the network**\n","---\n","\n","####**Be patient**. Training may take a while, but the verbose output allows you to estimate how fast the network is training and how long it will take. 
While it's training, please make sure that the computer is not powering down due to inactivity, otherwise this will interrupt the runtime."]},{"cell_type":"code","metadata":{"id":"sMyCENd29TKz","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","# history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs=epochs, callbacks=[model_checkpoint,csv_log], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs = number_of_epochs, callbacks=[model_checkpoint, reduce_lr], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","\n","# Save the last model\n","model.save(os.path.join(full_model_path, 'weights_last.hdf5'))\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","# The training_evaluation.csv file is saved (overwriting any existing file). \n","lossDataCSVpath = os.path.join(full_model_path,'Quality Control/training_evaluation.csv')\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n"," \n","\n","\n","# Displaying the time elapsed for training\n","print(\"------------------------------------------\")\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\", hour, \"hour(s)\", mins,\"min(s)\",round(sec),\"sec(s)\")\n","print(\"------------------------------------------\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"LWaFk0JNda-N","colab_type":"text"},"source":["## **4.3. 
Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"mEMcFNHZdmTz","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"X11zGW0Ldu-z","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","\n","full_QC_model_path = os.path.join(QC_model_path, QC_model_name)\n","if os.path.exists(os.path.join(full_QC_model_path, 'weights_best.hdf5')):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"pkJyRzWJCrKG","colab_type":"text"},"source":["## **5.1. 
Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset should be enlarged."]},{"cell_type":"code","metadata":{"id":"qul6BpaX1GqS","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","epochNumber = []\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(full_QC_model_path, 'Quality Control', 'lossCurvePlots.png'))\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"h33P0C2geqZu","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder. The result for one of the images will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess how accurately your model predicts nuclei. 
\n","\n","The Input, Ground Truth, Prediction and IoU maps are shown below for the last example in the QC set.\n","\n"," The results for all QC examples can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\".\n","\n","### **Thresholds for image masks**\n","\n"," Since the output from Unet is not a binary mask, the output images are converted to binary masks using thresholding. This section will test different thresholds (from 0 to 255) to find the one yielding the best IoU score compared with the ground truth. The best threshold for each image and the average of these thresholds will be displayed below. **These values can be a guideline when creating masks for unseen data in section 6.**"]},{"cell_type":"code","metadata":{"id":"Tpqjvwv2zug-","colab_type":"code","cellView":"form","colab":{}},"source":["# ------------- User input ------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","# ------------- Initialise folders ------------\n","# Create a quality control/Prediction Folder\n","prediction_QC_folder = os.path.join(full_QC_model_path, 'Quality Control', 'Prediction')\n","if os.path.exists(prediction_QC_folder):\n"," shutil.rmtree(prediction_QC_folder)\n","\n","os.makedirs(prediction_QC_folder)\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model\n","unet = load_model(os.path.join(full_QC_model_path, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Source_QC_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: 
'+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Source_QC_folder, source_dir_list[i]), unet))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(prediction_QC_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","#-----------------------------Calculate Metrics----------------------------------------#\n","\n","f = plt.figure(figsize=((5,5)))\n","\n","with open(os.path.join(full_QC_model_path,'Quality Control', 'QC_metrics_'+QC_model_name+'.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"File name\",\"IoU\", \"IoU-optimised threshold\"]) \n","\n"," # Initialise the lists \n"," filename_list = []\n"," best_threshold_list = []\n"," best_IoU_score_list = []\n","\n"," for filename in os.listdir(Source_QC_folder):\n","\n"," if not os.path.isdir(os.path.join(Source_QC_folder, filename)):\n"," print('Running QC on: '+filename)\n"," test_input = io.imread(os.path.join(Source_QC_folder, filename), as_gray=True)\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, filename), as_gray=True)\n","\n"," (threshold_list, iou_scores_per_threshold) = getIoUvsThreshold(os.path.join(prediction_QC_folder, prediction_prefix+filename), os.path.join(Target_QC_folder, filename))\n"," plt.plot(threshold_list,iou_scores_per_threshold, label=filename)\n","\n"," # Here we find which threshold yielded the highest IoU score for image n.\n"," best_IoU_score = max(iou_scores_per_threshold)\n"," best_threshold = iou_scores_per_threshold.index(best_IoU_score)\n","\n"," # Write the results in the CSV file\n"," writer.writerow([filename, str(best_IoU_score), str(best_threshold)])\n","\n"," # Here we append the best threshold and score to the lists\n"," filename_list.append(filename)\n"," best_IoU_score_list.append(best_IoU_score)\n"," 
best_threshold_list.append(best_threshold)\n","\n","# Display the IoU vs Threshold plot (threshold on the x-axis, IoU on the y-axis)\n","plt.title('IoU vs. Threshold')\n","plt.ylabel('IoU')\n","plt.xlabel('Threshold value')\n","plt.legend()\n","plt.show()\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = filename_list)\n","pdResults[\"IoU\"] = best_IoU_score_list\n","pdResults[\"IoU-optimised threshold\"] = best_threshold_list\n","\n","\n","\n","average_best_threshold = sum(best_threshold_list)/len(best_threshold_list)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file=os.listdir(Source_QC_folder)):\n"," \n"," plt.figure(figsize=(25,5))\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(plt.imread(os.path.join(Source_QC_folder, file)), aspect='equal', cmap='gray', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, file),as_gray=True)\n"," plt.imshow(test_ground_truth_image, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," test_prediction = plt.imread(os.path.join(prediction_QC_folder, prediction_prefix+file))\n"," test_prediction_mask = np.empty_like(test_prediction)\n"," test_prediction_mask[test_prediction > average_best_threshold] = 255\n"," test_prediction_mask[test_prediction <= average_best_threshold] = 0\n"," plt.imshow(test_prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(test_ground_truth_image, cmap='Greens')\n"," plt.imshow(test_prediction_mask, alpha=0.5, cmap='Purples')\n"," metrics_title = 'Overlay (IoU: ' + str(round(pdResults.loc[file][\"IoU\"],3)) + ' T: ' + 
str(round(pdResults.loc[file][\"IoU-optimised threshold\"])) + ')'\n"," plt.title(metrics_title)\n","\n","\n","\n","print('--------------------------------------------------------------')\n","print('Best average threshold is: '+str(round(average_best_threshold)))\n","print('--------------------------------------------------------------')\n","\n","pdResults.head()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"gofmRsLP96O8","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"Pv_v1Ru2OJkU","colab_type":"text"},"source":["## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.1) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n"," Once the predictions are complete the cell will display a random example prediction beside the input image and the calculated mask for visual inspection.\n","\n"," **Troubleshooting:** If there is a low contrast image warning when saving the images, this may be due to overfitting of the model to the data. It may result in images containing only a single colour. 
Train the network again with different network hyperparameters."]},{"cell_type":"code","metadata":{"id":"FJAe55ZoOJGs","colab_type":"code","cellView":"form","colab":{}},"source":["# ------------- Initial user input ------------\n","#@markdown ###Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","Data_folder = '' #@param {type:\"string\"}\n","Results_folder = '' #@param {type:\"string\"}\n","\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","# ------------- Failsafes ------------\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model and prepare generator\n","\n","unet = load_model(os.path.join(Prediction_model_path, Prediction_model_name, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Data_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: '+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Data_folder, source_dir_list[i]), unet))\n"," # predictions.append(prediction(os.path.join(Data_folder, source_dir_list[i]), os.path.join(Prediction_model_path, Prediction_model_name)))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","\n","def show_prediction_mask(file=os.listdir(Data_folder), threshold=(0,255,1)):\n","\n"," plt.figure(figsize=(18,6))\n"," # Wide-field\n"," plt.subplot(1,3,1)\n"," plt.axis('off')\n"," img_Source = plt.imread(os.path.join(Data_folder, file))\n"," plt.imshow(img_Source, cmap='gray')\n"," plt.title('Source image',fontsize=15)\n"," # Prediction\n"," plt.subplot(1,3,2)\n"," plt.axis('off')\n"," img_Prediction = plt.imread(os.path.join(Results_folder, prediction_prefix+file))\n"," plt.imshow(img_Prediction, cmap='gray')\n"," 
plt.title('Prediction',fontsize=15)\n","\n"," # Thresholded mask\n"," plt.subplot(1,3,3)\n"," plt.axis('off')\n"," img_Mask = convert2Mask(img_Prediction, threshold)\n"," plt.imshow(img_Mask, cmap='gray')\n"," plt.title('Mask (Threshold: '+str(round(threshold))+')',fontsize=15)\n","\n","\n","interact(show_prediction_mask, continuous_update=False);\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"su-Mo2POVpja","colab_type":"text"},"source":["## **6.2. Export results as masks**\n","---\n"]},{"cell_type":"code","metadata":{"id":"iC_B_9lxNUny","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# @markdown #Play this cell to save results as masks with the chosen threshold\n","threshold = 120#@param {type:\"number\"}\n","\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=threshold)\n","print('-------------------')\n","print('Masks were saved in: '+Results_folder)\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wYmwCQKjYsJ7","colab_type":"text"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"sCXzzvnh2_rc","colab_type":"text"},"source":["#**Thank you for using U-Net!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"U-Net_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1VcTsLOL28ntbr23gYrhY3upxkztZeUvn","timestamp":1591024690909},{"file_id":"19jT_GoHGN-UTM1aEgkgrOjB8pcFz5AW4","timestamp":1591017297795},{"file_id":"1UkoWB27ZWh5j_qivSZIOeOJP1h2EqrVz","timestamp":1589363183397},{"file_id":"1ofNqOc7lz-m6NL4B-m4BIheaU5N0GMln","timestamp":1588873191434},{"file_id":"1rJnsgIKyL6vuneydIfjCKMtMhV3XlQ6o","timestamp":1588583580765},{"file_id":"1RUYrp8beEgDKL1kOWw5LgR1QQb4yHQtG","timestamp":1587061416704},{"file_id":"1FVax0eY3-m8DbJHx0B8Dnep-uGlp30Zt","timestamp":1586601038120},{"file_id":"1TTqmCf2mFQ_PNIZEXX9sRAhoixjYP_AB","timestamp":1585842446113},{"file_id":"1cWwS-jbLYTDOpPp_hhKOLGFXfu06ccpG","timestamp":1585821375983},{"file_id":"1TPEE_AtGTLedawgVBwwXofEJEcJUCgo3","timestamp":1585137343783},{"file_id":"1SxFRb38aC_kmKzKVQfkwWzkK9n7YFxVv","timestamp":1585053829456},{"file_id":"15iw9IOwHNF_GhiHxkh_rWbJG8JnW14Wh","timestamp":1584375074441},{"file_id":"15oMbXnMa4LDEMhPHBr3ga0xhJomMLhDo","timestamp":1584105762670},{"file_id":"1__NtYFNA3DxNB7LrUY13Bt8_frye3iWl","timestamp":1583445015203},{"file_id":"11jsQfqKeDU1Zk3nPykjWKwYhFmvJ1zJ-","timestamp":1575289898486}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"WDrFAwpFIpE0","colab_type":"text"},"source":["# **U-Net (2D)**\n","---\n","\n","U-Net is an encoder-decoder network architecture originally used for image segmentation, first published by [Ronneberger *et al.*](https://arxiv.org/abs/1505.04597). 
The first half of the U-Net architecture is a downsampling convolutional neural network which acts as a feature extractor from input images. The other half upsamples these results and restores an image by combining results from downsampling with the upsampled images.\n","\n"," **This particular notebook enables image segmentation of 2D datasets. If you are interested in 3D datasets, you should use the 3D U-Net notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). It is jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the papers: \n","\n","**U-Net: Convolutional Networks for Biomedical Image Segmentation** by Ronneberger *et al.* published on arXiv in 2015 (https://arxiv.org/abs/1505.04597)\n","\n","and \n","\n","**U-Net: deep learning for cell counting, detection, and morphometry** by Thorsten Falk *et al.* in Nature Methods 2019\n","(https://www.nature.com/articles/s41592-018-0261-2)\n","And source code found in: https://github.com/zhixuhao/unet by *Zhixuhao*\n","\n","**Please also cite this original paper when using or developing this notebook.** "]},{"cell_type":"markdown","metadata":{"id":"ABNu2p4stHeB","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types 
of cell: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading the text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done, the animation of the play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains the structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"HVwncY_NvlYi","colab_type":"text"},"source":["# **0. 
Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For U-Net to train, **it needs to have access to a paired training dataset corresponding to images and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif, ...\n"," - Training_target\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif\n"," - Training_target \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then 
use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"JrGNzgEyxzGQ","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"wYoajeT54sQM","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"TpT6gbwURzrV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"quzkzlRD45HF","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. 
This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of the notebook."]},{"cell_type":"code","metadata":{"id":"eLwDxBnp4-bc","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in to your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"leK5kmgD5Ism","colab_type":"text"},"source":["# **2. Install U-Net dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"vOeLpQfT0QF1","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play to install U-Net dependencies\n","\n","#As this notebook depends mostly on keras which runs a tensorflow backend (which in turn is pre-installed in colab)\n","#only the data library needs to be additionally installed.\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","#We enforce the keras==2.2.5 release to ensure that the notebook continues working even if keras is updated.\n","\n","!pip install keras==2.2.5\n","!pip install data\n","\n","# Keras imports\n","from keras import models\n","from keras.models import Model, load_model\n","from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D\n","from keras.optimizers import Adam\n","# from keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger # we currently don't use any other callbacks from 
ModelCheckpoints\n","from keras.callbacks import ModelCheckpoint\n","from keras.callbacks import ReduceLROnPlateau\n","from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img\n","from keras import backend as keras\n","\n","# General import\n","from __future__ import print_function\n","import numpy as np\n","import pandas as pd\n","import os\n","import glob\n","from skimage import img_as_ubyte, io, transform\n","import matplotlib as mpl\n","from matplotlib import pyplot as plt\n","from matplotlib.pyplot import imread\n","from pathlib import Path\n","import shutil\n","import random\n","import time\n","import csv\n","import sys\n","from math import ceil\n","\n","# Imports for QC\n","from PIL import Image\n","from scipy import signal\n","from scipy import ndimage\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","# from tqdm import tqdm\n","from tqdm.notebook import tqdm\n","\n","from sklearn.feature_extraction import image\n","from skimage import img_as_ubyte, io, transform\n","from skimage.util.shape import view_as_windows\n","\n","# Suppressing some warnings\n","import warnings\n","warnings.filterwarnings('ignore')\n","\n","\n","\n","def create_patches(Training_source, Training_target, patch_width, patch_height):\n"," \"\"\"\n"," Function creates patches from the Training_source and Training_target images. 
\n"," The steps parameter indicates the offset between patches and, if integer, is the same in x and y.\n"," Saves all created patches in two new directories in the /content folder.\n","\n"," Returns: - Two paths to where the patches are now saved\n"," \"\"\"\n"," DEBUG = False\n","\n"," Patch_source = os.path.join('/content','img_patches')\n"," Patch_target = os.path.join('/content','mask_patches')\n"," Patch_rejected = os.path.join('/content','rejected')\n"," \n","\n"," #Here we save the patches, in the /content directory as they will not usually be needed after training\n"," if os.path.exists(Patch_source):\n"," shutil.rmtree(Patch_source)\n"," if os.path.exists(Patch_target):\n"," shutil.rmtree(Patch_target)\n"," if os.path.exists(Patch_rejected):\n"," shutil.rmtree(Patch_rejected)\n","\n"," os.mkdir(Patch_source)\n"," os.mkdir(Patch_target)\n"," os.mkdir(Patch_rejected) #This directory will contain the images that have too little signal.\n"," \n","\n"," all_patches_img = np.empty([0,patch_width, patch_height])\n"," all_patches_mask = np.empty([0,patch_width, patch_height])\n","\n"," for file in os.listdir(Training_source):\n","\n"," img = io.imread(os.path.join(Training_source, file))\n"," mask = io.imread(os.path.join(Training_target, file),as_gray=True)\n","\n"," if DEBUG:\n"," print(file)\n"," print(img.dtype)\n","\n"," # Using view_as_windows with step size equal to the patch size to ensure there is no overlap\n"," patches_img = view_as_windows(img, (patch_width, patch_height), (patch_width, patch_height))\n"," patches_mask = view_as_windows(mask, (patch_width, patch_height), (patch_width, patch_height))\n"," #the shape of patches_img and patches_mask will be (number of patches along x, number of patches along y,patch_width,patch_height)\n","\n"," all_patches_img = np.concatenate((all_patches_img, patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width,patch_height)), axis = 0)\n"," all_patches_mask = np.concatenate((all_patches_mask, 
patches_mask.reshape(patches_mask.shape[0]*patches_mask.shape[1], patch_width,patch_height)), axis = 0)\n","\n"," number_of_patches = all_patches_img.shape[0]\n"," print('number of patches: '+str(number_of_patches))\n","\n"," if DEBUG:\n"," print(all_patches_img.shape)\n"," print(all_patches_img.dtype)\n","\n"," for i in range(number_of_patches):\n"," img_save_path = os.path.join(Patch_source,'patch_'+str(i)+'.tif')\n"," mask_save_path = os.path.join(Patch_target,'patch_'+str(i)+'.tif')\n","\n"," # if the mask conatins at least 2% of its total number pixels as mask, then go ahead and save the images\n"," pixel_threshold_array = sorted(all_patches_mask[i].flatten())\n"," if pixel_threshold_array[int(round(len(pixel_threshold_array)*0.98))]>0:\n"," io.imsave(img_save_path, img_as_ubyte(normalizeMinMax(all_patches_img[i])))\n"," io.imsave(mask_save_path, convert2Mask(normalizeMinMax(all_patches_mask[i]),0))\n"," else:\n"," io.imsave(Patch_rejected+'/patch_'+str(i)+'_image.tif', img_as_ubyte(normalizeMinMax(all_patches_img[i])))\n"," io.imsave(Patch_rejected+'/patch_'+str(i)+'_mask.tif', convert2Mask(normalizeMinMax(all_patches_mask[i]),0))\n","\n"," return Patch_source, Patch_target\n","\n","\n","def estimatePatchSize(data_path, max_width = 512, max_height = 512):\n","\n"," files = os.listdir(data_path)\n"," \n"," # Get the size of the first image found in the folder and initialise the variables to that\n"," n = 0 \n"," while os.path.isdir(os.path.join(data_path, files[n])):\n"," n += 1\n"," (height_min, width_min) = Image.open(os.path.join(data_path, files[n])).size\n","\n"," # Screen the size of all dataset to find the minimum image size\n"," for file in files:\n"," if not os.path.isdir(os.path.join(data_path, file)):\n"," (height, width) = Image.open(os.path.join(data_path, file)).size\n"," if width < width_min:\n"," width_min = width\n"," if height < height_min:\n"," height_min = height\n"," \n"," # Find the power of patches that will fit within the smallest 
dataset\n"," width_min, height_min = (fittingPowerOfTwo(width_min), fittingPowerOfTwo(height_min))\n","\n"," # Clip values at maximum permissible values\n"," if width_min > max_width:\n"," width_min = max_width\n","\n"," if height_min > max_height:\n"," height_min = max_height\n"," \n"," return (width_min, height_min)\n","\n","def fittingPowerOfTwo(number):\n"," n = 0\n"," while 2**n <= number:\n"," n += 1 \n"," return 2**(n-1)\n","\n","\n","def getClassWeights(Training_target_path):\n","\n"," Mask_dir_list = os.listdir(Training_target_path)\n"," number_of_dataset = len(Mask_dir_list)\n","\n"," class_count = np.zeros(2, dtype=int)\n"," for i in tqdm(range(number_of_dataset)):\n"," mask = io.imread(os.path.join(Training_target_path, Mask_dir_list[i]))\n"," mask = normalizeMinMax(mask)\n"," class_count[0] += mask.shape[0]*mask.shape[1] - mask.sum()\n"," class_count[1] += mask.sum()\n","\n"," n_samples = class_count.sum()\n"," n_classes = 2\n","\n"," class_weights = n_samples / (n_classes * class_count)\n"," return class_weights\n","\n","def weighted_binary_crossentropy(class_weights):\n","\n"," def _weighted_binary_crossentropy(y_true, y_pred):\n"," binary_crossentropy = keras.binary_crossentropy(y_true, y_pred)\n"," weight_vector = y_true * class_weights[1] + (1. 
- y_true) * class_weights[0]\n"," weighted_binary_crossentropy = weight_vector * binary_crossentropy\n","\n"," return keras.mean(weighted_binary_crossentropy)\n","\n"," return _weighted_binary_crossentropy\n","\n","\n","def save_augment(datagen,orig_img,dir_augmented_data=\"/content/augment\"):\n"," \"\"\"\n"," Saves a subset of the augmented data for visualisation, by default in /content.\n","\n"," This is adapted from: https://fairyonice.github.io/Learn-about-ImageDataGenerator.html\n"," \n"," \"\"\"\n"," try:\n"," os.mkdir(dir_augmented_data)\n"," except:\n"," ## if the preview folder exists, then remove\n"," ## the contents (pictures) in the folder\n"," for item in os.listdir(dir_augmented_data):\n"," os.remove(dir_augmented_data + \"/\" + item)\n","\n"," ## convert the original image to array\n"," x = img_to_array(orig_img)\n"," ## reshape (Sampke, Nrow, Ncol, 3) 3 = R, G or B\n"," #print(x.shape)\n"," x = x.reshape((1,) + x.shape)\n"," #print(x.shape)\n"," ## -------------------------- ##\n"," ## randomly generate pictures\n"," ## -------------------------- ##\n"," i = 0\n"," #We will just save 5 images,\n"," #but this can be changed, but note the visualisation in 3. 
currently uses 5.\n"," Nplot = 5\n"," for batch in datagen.flow(x,batch_size=1,\n"," save_to_dir=dir_augmented_data,\n"," save_format='tif',\n"," seed=42):\n"," i += 1\n"," if i > Nplot - 1:\n"," break\n","\n","# Generators\n","def buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, subset, batch_size, target_size):\n"," '''\n"," Can generate image and mask at the same time use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same\n"," \n"," datagen: ImageDataGenerator \n"," subset: can take either 'training' or 'validation'\n"," '''\n"," seed = 1\n"," image_generator = image_datagen.flow_from_directory(\n"," os.path.dirname(image_folder_path),\n"," classes = [os.path.basename(image_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"bicubic\",\n"," seed = seed)\n"," \n"," mask_generator = mask_datagen.flow_from_directory(\n"," os.path.dirname(mask_folder_path),\n"," classes = [os.path.basename(mask_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"nearest\",\n"," seed = seed)\n"," \n"," this_generator = zip(image_generator, mask_generator)\n"," for (img,mask) in this_generator:\n"," # img,mask = adjustData(img,mask)\n"," yield (img,mask)\n","\n","\n","def prepareGenerators(image_folder_path, mask_folder_path, datagen_parameters, batch_size = 4, target_size = (512, 512)):\n"," image_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizePercentile)\n"," mask_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizeMinMax)\n","\n"," train_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'training', batch_size, 
target_size)\n"," validation_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'validation', batch_size, target_size)\n","\n"," return (train_datagen, validation_datagen)\n","\n","\n","# Normalization functions from Martin Weigert\n","def normalizePercentile(x, pmin=1, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","\n","\n","# Simple normalization to min/max fir the Mask\n","def normalizeMinMax(x, dtype=np.float32):\n"," x = x.astype(dtype,copy=False)\n"," x = (x - np.amin(x)) / (np.amax(x) - np.amin(x))\n"," return x\n","\n","\n","# def predictionGenerator(Data_path, target_size = (256,256), as_gray = True):\n","# for filename in os.listdir(Data_path):\n","# if not os.path.isdir(os.path.join(Data_path, filename)):\n","# img = io.imread(os.path.join(Data_path, filename), as_gray = as_gray)\n","# img = normalizePercentile(img)\n","# # img = img/255 # WARNING: this is expecting 8bit images\n","# img = transform.resize(img,target_size, preserve_range=True, anti_aliasing=True, order = 1) # 
liner interpolation\n","# img = np.reshape(img,img.shape+(1,))\n","# img = np.reshape(img,(1,)+img.shape)\n","# yield img\n","\n","\n","# def predictionResize(Data_path, predictions):\n","# resized_predictions = []\n","# for (i, filename) in enumerate(os.listdir(Data_path)):\n","# if not os.path.isdir(os.path.join(Data_path, filename)):\n","# img = Image.open(os.path.join(Data_path, filename))\n","# (width, height) = img.size\n","# resized_predictions.append(transform.resize(predictions[i], (height, width), preserve_range=True, anti_aliasing=True, order = 1))\n","# return resized_predictions\n","\n","\n","# This is code outlines the architecture of U-net. The choice of pooling steps decides the depth of the network. \n","def unet(pretrained_weights = None, input_size = (256,256,1), pooling_steps = 4, learning_rate = 1e-4, verbose=True, class_weights=np.ones(2)):\n"," inputs = Input(input_size)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)\n"," # Downsampling steps\n"," pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)\n"," \n"," if pooling_steps > 1:\n"," pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)\n","\n"," if pooling_steps > 2:\n"," pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 
'he_normal')(conv4)\n"," drop4 = Dropout(0.5)(conv4)\n"," \n"," if pooling_steps > 3:\n"," pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)\n"," drop5 = Dropout(0.5)(conv5)\n","\n"," #Upsampling steps\n"," up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))\n"," merge6 = concatenate([drop4,up6], axis = 3)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)\n"," \n"," if pooling_steps > 2:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop4))\n"," if pooling_steps > 3:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))\n"," merge7 = concatenate([conv3,up7], axis = 3)\n"," conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)\n"," \n"," if pooling_steps > 1:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv3))\n"," if pooling_steps > 2:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))\n"," merge8 = concatenate([conv2,up8], axis = 3)\n"," conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)\n"," \n"," if pooling_steps == 1:\n"," up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv2))\n"," else:\n"," up9 = Conv2D(64, 2, padding = 'same', 
kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) #activation = 'relu'\n"," \n"," merge9 = concatenate([conv1,up9], axis = 3)\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(merge9) #activation = 'relu'\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv9 = Conv2D(2, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)\n","\n"," model = Model(inputs = inputs, outputs = conv10)\n","\n"," # model.compile(optimizer = Adam(lr = learning_rate), loss = 'binary_crossentropy', metrics = ['acc'])\n"," model.compile(optimizer = Adam(lr = learning_rate), loss = weighted_binary_crossentropy(class_weights))\n","\n","\n"," if verbose:\n"," model.summary()\n","\n"," if(pretrained_weights):\n"," \tmodel.load_weights(pretrained_weights);\n","\n"," return model\n","\n","\n","\n","def predict_as_tiles(Image_path, model):\n","\n"," # Read the data in and normalize\n"," Image_raw = io.imread(Image_path, as_gray = True)\n"," Image_raw = normalizePercentile(Image_raw)\n","\n"," # Get the patch size from the input layer of the model\n"," patch_size = model.layers[0].output_shape[1:3]\n","\n"," # Pad the image with zeros if any of its dimensions is smaller than the patch size\n"," if Image_raw.shape[0] < patch_size[0] or Image_raw.shape[1] < patch_size[1]:\n"," Image = np.zeros((max(Image_raw.shape[0], patch_size[0]), max(Image_raw.shape[1], patch_size[1])))\n"," Image[0:Image_raw.shape[0], 0: Image_raw.shape[1]] = Image_raw\n"," else:\n"," Image = Image_raw\n","\n"," # Calculate the number of patches in each dimension\n"," n_patch_in_width = ceil(Image.shape[0]/patch_size[0])\n"," n_patch_in_height = ceil(Image.shape[1]/patch_size[1])\n","\n"," prediction = np.zeros(Image.shape)\n","\n"," for x in range(n_patch_in_width):\n"," for y in range(n_patch_in_height):\n"," xi = 
patch_size[0]*x\n"," yi = patch_size[1]*y\n","\n"," # If the patch exceeds the edge of the image shift it back \n"," if xi+patch_size[0] >= Image.shape[0]:\n"," xi = Image.shape[0]-patch_size[0]\n","\n"," if yi+patch_size[1] >= Image.shape[1]:\n"," yi = Image.shape[1]-patch_size[1]\n"," \n"," # Extract and reshape the patch\n"," patch = Image[xi:xi+patch_size[0], yi:yi+patch_size[1]]\n"," patch = np.reshape(patch,patch.shape+(1,))\n"," patch = np.reshape(patch,(1,)+patch.shape)\n","\n"," # Get the prediction from the patch and paste it in the prediction in the right place\n"," predicted_patch = model.predict(patch, batch_size = 1)\n"," prediction[xi:xi+patch_size[0], yi:yi+patch_size[1]] = np.squeeze(predicted_patch)\n","\n","\n"," return prediction[0:Image_raw.shape[0], 0: Image_raw.shape[1]]\n"," \n","\n","\n","\n","def saveResult(save_path, nparray, source_dir_list, prefix='', threshold=None):\n"," for (filename, image) in zip(source_dir_list, nparray):\n"," io.imsave(os.path.join(save_path, prefix+os.path.splitext(filename)[0]+'.tif'), img_as_ubyte(image)) # saving as unsigned 8-bit image\n"," \n"," # For masks, threshold the images and return 8 bit image\n"," if threshold is not None:\n"," mask = convert2Mask(image, threshold)\n"," io.imsave(os.path.join(save_path, prefix+'mask_'+os.path.splitext(filename)[0]+'.tif'), mask)\n","\n","\n","def convert2Mask(image, threshold):\n"," mask = img_as_ubyte(image, force_copy=True)\n"," mask[mask > threshold] = 255\n"," mask[mask <= threshold] = 0\n"," return mask\n","\n","\n","def getIoUvsThreshold(prediction_filepath, groud_truth_filepath):\n"," prediction = io.imread(prediction_filepath)\n"," ground_truth_image = img_as_ubyte(io.imread(groud_truth_filepath, as_gray=True), force_copy=True)\n","\n"," threshold_list = []\n"," IoU_scores_list = []\n","\n"," for threshold in range(0,256): \n"," # Convert to 8-bit for calculating the IoU\n"," mask = img_as_ubyte(prediction, force_copy=True)\n"," mask[mask > threshold] = 
255\n"," mask[mask <= threshold] = 0\n","\n"," # Intersection over Union metric\n"," intersection = np.logical_and(ground_truth_image, np.squeeze(mask))\n"," union = np.logical_or(ground_truth_image, np.squeeze(mask))\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," threshold_list.append(threshold)\n"," IoU_scores_list.append(iou_score)\n","\n"," return (threshold_list, IoU_scores_list)\n","\n","\n","\n","# -------------- Other definitions -----------\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","prediction_prefix = 'Predicted_'\n","\n","\n","print('-------------------')\n","print('U-Net and dependencies installed.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7hTKImff6Est","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---"]},{"cell_type":"markdown","metadata":{"id":"S74FbqV6PNNv","colab_type":"text"},"source":["##**3.1. Parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"3np5EpJF8_q2","colab_type":"text"},"source":[" **Paths for training data and models**\n","\n","**`Training_source`, `Training_target`:** These are the folders containing your source (e.g. EM images) and target files (segmentation masks). Enter the path to the source and target images for training. **These should be located in the same parent folder.**\n","\n","**`model_name`:** Use only my_model -style, not my-model. If you want to use a previously trained model, enter the name of the pretrained model (which should be contained in the trained_model -folder after training).\n","\n","**`model_path`**: Enter the path of the folder where you want to save your model.\n","\n","**`visual_validation_after_training`**: If you select this option, a random image pair will be set aside from your training set and will be used to display a predicted image of the trained network next to the input and the ground-truth. 
This can aid in visually assessing the performance of your network after training. **Note: Your training set size will decrease by 1 if you select this option.**\n","\n","**Make sure the directories exist before entering them!**\n","\n"," **Select training parameters**\n","\n","**`number_of_epochs`**: Choose more epochs for larger training sets. Observing how much the loss reduces between epochs during training may help determine the optimal value. **Default: 200**\n","\n","**Advanced parameters - experienced users only**\n","\n","**`batch_size`**: This parameter describes the number of images that are loaded into the network per step. Smaller batch sizes may improve training performance slightly but may increase training time. If the notebook crashes while loading the dataset, the batch size may be too large. Decrease the number in this case. **Default: 4**\n","\n","**`number_of_steps`**: This number should be equivalent to the number of samples in the training set divided by the batch size, to ensure the training iterates through the entire training set. Smaller values can be used for testing. **Default: 6**\n","\n"," **`pooling_steps`**: Choosing a different number of pooling layers can affect the performance of the network. Each additional pooling step will also add two additional convolutions. The network can learn more complex information but is also more likely to overfit. Achieving the best performance may require testing different values here. **Default: 2**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n","\n","**`patch_width` and `patch_height`:** The notebook crops the data into patches of fixed size prior to training. The dimensions of the patches can be defined here. 
When `Use_Default_Advanced_Parameters` is selected, the largest 2^n x 2^n patch that fits in the smallest dataset is chosen. Larger patches than 512x512 should **NOT** be selected for network stability.\n","\n"]},{"cell_type":"code","metadata":{"id":"7deNuPZd5d-B","colab_type":"code","cellView":"form","colab":{}},"source":["# ------------- Initial user input ------------\n","#@markdown ###Path to training images:\n","Training_source = '' #@param {type:\"string\"}\n","Training_target = '' #@param {type:\"string\"}\n","\n","model_name = '' #@param {type:\"string\"}\n","model_path = '' #@param {type:\"string\"}\n","\n","#@markdown ###Training parameters:\n","#@markdown Number of epochs\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced parameters:\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 4#@param {type:\"integer\"}\n","number_of_steps = 6#@param {type:\"number\"}\n","pooling_steps = 2 #@param [1,2,3,4]{type:\"raw\"}\n","percentage_validation = 10#@param{type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","patch_width = 512#@param{type:\"number\"}\n","patch_height = 512#@param{type:\"number\"}\n","\n","\n","# ------------- Initialising folder, variables and failsafes ------------\n","# Create the folders where to save the model and the QC\n","full_model_path = os.path.join(model_path, model_name)\n","if os.path.exists(full_model_path):\n"," print(R+'!! 
WARNING: Folder already exists and will be overwritten !!'+W)\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 4\n"," pooling_steps = 2\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0003\n"," patch_width, patch_height = estimatePatchSize(Training_source)\n","\n","\n","#The create_patches function will create the two folders below\n","# Patch_source = '/content/img_patches'\n","# Patch_target = '/content/mask_patches'\n","print('Training on patches of size (x,y): ('+str(patch_width)+','+str(patch_height)+')')\n","\n","#Create patches\n","print('Creating patches...')\n","Patch_source, Patch_target = create_patches(Training_source, Training_target, patch_width, patch_height)\n","\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = False\n","\n","# ------------- Display ------------\n","\n","#if not os.path.exists('/content/img_patches/'):\n","random_choice = random.choice(os.listdir(Patch_source))\n","x = io.imread(os.path.join(Patch_source, random_choice))\n","\n","#os.chdir(Training_target)\n","y = io.imread(os.path.join(Patch_target, random_choice), as_gray=True)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest',cmap='gray')\n","plt.title('Training image patch')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest',cmap='gray')\n","plt.title('Training mask patch')\n","plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"V9UCjlLJ5Rfc","colab_type":"text"},"source":["##**3.2. Data augmentation**\n","\n","---\n","\n"," Data augmentation can improve training progress by amplifying differences in the dataset. 
This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if the dataset is large the values can be set to 0.\n","\n"," The augmentation options below are to be used as follows:\n","\n","* **shift**: a translation of the image by a fraction of the image size (width or height), **default: 10%**\n","* **zoom_range**: Increasing or decreasing the field of view. E.g. 10% will result in a zoom range of (0.9 to 1.1), with pixels added or interpolated, depending on the transformation, **default: 10%**\n","* **shear_range**: Shear angle in counter-clockwise direction, **default: 10%**\n","* **flip**: creating a mirror image along specified axis (horizontal or vertical), **default: True**\n","* **rotation_range**: range of allowed rotation angles in degrees (from 0 to *value*), **default: 180**"]},{"cell_type":"code","metadata":{"id":"i-PahNX94-pl","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##**Augmentation options**\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," if Use_Default_Augmentation_Parameters:\n"," horizontal_shift = 10 \n"," vertical_shift = 20 \n"," zoom_range = 10\n"," shear_range = 10\n"," horizontal_flip = True\n"," vertical_flip = True\n"," rotation_range = 180\n","#@markdown ###If you are not using the default settings, please provide the values below:\n","\n","#@markdown ###**Image shift, zoom, shear and flip (%)**\n"," else:\n"," horizontal_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," vertical_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," zoom_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," shear_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," 
horizontal_flip = True #@param {type:\"boolean\"}\n"," vertical_flip = True #@param {type:\"boolean\"}\n","\n","#@markdown ###**Rotate image within angle range (degrees):**\n"," rotation_range = 180 #@param {type:\"slider\", min:0, max:180, step:1}\n","\n","#given behind the # are the default values for each parameter.\n","\n","else:\n"," horizontal_shift = 0 \n"," vertical_shift = 0 \n"," zoom_range = 0\n"," shear_range = 0\n"," horizontal_flip = False\n"," vertical_flip = False\n"," rotation_range = 0\n","\n","\n","# Build the dict for the ImageDataGenerator\n","data_gen_args = dict(width_shift_range = horizontal_shift/100.,\n"," height_shift_range = vertical_shift/100.,\n"," rotation_range = rotation_range, #90\n"," zoom_range = zoom_range/100.,\n"," shear_range = shear_range/100.,\n"," horizontal_flip = horizontal_flip,\n"," vertical_flip = vertical_flip,\n"," validation_split = percentage_validation/100,\n"," fill_mode = 'reflect')\n","\n","\n","\n","# ------------- Display ------------\n","dir_augmented_data_imgs=\"/content/augment_img\"\n","dir_augmented_data_masks=\"/content/augment_mask\"\n","random_choice = random.choice(os.listdir(Patch_source))\n","orig_img = load_img(os.path.join(Patch_source,random_choice))\n","orig_mask = load_img(os.path.join(Patch_target,random_choice))\n","\n","augment_view = ImageDataGenerator(**data_gen_args)\n","\n","if Use_Data_augmentation:\n"," print(\"Parameters enabled\")\n"," print(\"Here is what a subset of your augmentations looks like:\")\n"," save_augment(augment_view, orig_img, dir_augmented_data=dir_augmented_data_imgs)\n"," save_augment(augment_view, orig_mask, dir_augmented_data=dir_augmented_data_masks)\n","\n"," fig = plt.figure(figsize=(15, 7))\n"," fig.subplots_adjust(hspace=0.0,wspace=0.1,left=0,right=1.1,bottom=0, top=0.8)\n","\n"," \n"," ax = fig.add_subplot(2, 6, 1,xticks=[],yticks=[]) \n"," new_img=img_as_ubyte(normalizeMinMax(img_to_array(orig_img)))\n"," ax.imshow(new_img)\n"," ax.set_title('Original 
Image')\n"," i = 2\n"," for imgnm in os.listdir(dir_augmented_data_imgs):\n"," ax = fig.add_subplot(2, 6, i,xticks=[],yticks=[]) \n"," img = load_img(dir_augmented_data_imgs + \"/\" + imgnm)\n"," ax.imshow(img)\n"," i += 1\n","\n"," ax = fig.add_subplot(2, 6, 7,xticks=[],yticks=[]) \n"," new_mask=img_as_ubyte(normalizeMinMax(img_to_array(orig_mask)))\n"," ax.imshow(new_mask)\n"," ax.set_title('Original Mask')\n"," j=2\n"," for imgnm in os.listdir(dir_augmented_data_masks):\n"," ax = fig.add_subplot(2, 6, j+6,xticks=[],yticks=[]) \n"," mask = load_img(dir_augmented_data_masks + \"/\" + imgnm)\n"," ax.imshow(mask)\n"," j += 1\n"," plt.show()\n","\n","else:\n"," print(\"No augmentation will be used\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7vFEIHbNAuOs","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a U-Net model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"RfR9UyKAAulw","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the UNET_Model_from_\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," 
print(R+'WARNING: pretrained model does not exist')\n"," Use_pretrained_model = False\n"," \n","\n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. 
Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(R+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"94FX4wzE8w1W","colab_type":"text"},"source":["# **4. Train the network**\n","---\n","####**Troubleshooting:** If you receive a time-out or exhausted error, try reducing the batchsize of your training set. This reduces the amount of data loaded into the model at one point in time. "]},{"cell_type":"markdown","metadata":{"id":"tlTDGcmDDHDe","colab_type":"text"},"source":["## **4.1. 
Prepare model for training**\n","---"]},{"cell_type":"code","metadata":{"id":"ezFy_mpz_op4","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play this cell to prepare the model for training\n","\n","\n","# ------------------ Set the generators, model and logger ------------------\n","# This will take the image size and set that as a patch size (arguable...)\n","# Read image size (without actuall reading the data)\n","\n","\n","# n = 0 \n","# while os.path.isdir(os.path.join(Training_source, source_images[n])):\n","# n += 1\n","\n","# (width, height) = Image.open(os.path.join(Training_target, source_images[n])).size\n","# ImageSize = (height, width) # np.shape different from PIL image.size return !\n","\n","# !!! WARNING !!! Check potential issues with resizing at the ImageDataGenerator level\n","# (train_datagen, validation_datagen) = prepareGenerators(Training_source, Training_target, data_gen_args, batch_size, target_size = ImageSize)\n","(train_datagen, validation_datagen) = prepareGenerators(Patch_source, Patch_target, data_gen_args, batch_size, target_size = (patch_width, patch_height))\n","\n","\n","# This modelcheckpoint will only save the best model from the validation loss point of view\n","model_checkpoint = ModelCheckpoint(os.path.join(full_model_path, 'weights_best.hdf5'), monitor='val_loss',verbose=1, save_best_only=True)\n","\n","print('Getting class weights...')\n","class_weights = getClassWeights(Training_target)\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we make sure this is properly defined\n","if not Use_pretrained_model:\n"," h5_file_path = None\n","# --------------------- ---------------------- ------------------------\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," 
initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# --------------------- Reduce learning rate on plateau ------------------------\n","\n","reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, verbose=1, mode='auto',\n"," patience=10, min_lr=0)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# Define the model\n","model = unet(pretrained_weights = h5_file_path, \n"," input_size = (patch_width,patch_height,1), \n"," pooling_steps = pooling_steps, \n"," learning_rate = initial_learning_rate, \n"," class_weights = class_weights)\n","\n","# Dfine CSV logger that will create the loss file (we're not using this anylonger)\n","# csv_log = CSVLogger(os.path.join(full_model_path, 'Quality Control', 'training_evaluation.csv'), separator=',', append=False)\n","\n","number_of_training_dataset = len(os.listdir(Patch_source))\n","\n","if Use_Default_Advanced_Parameters:\n"," number_of_steps = ceil((100-percentage_validation)/100*number_of_training_dataset/batch_size)\n","\n","# Calculate the number of steps to use for validation\n","validation_steps = max(1, ceil(percentage_validation/100*number_of_training_dataset/batch_size))\n","\n","config_model= model.optimizer.get_config()\n","print(config_model)\n","\n","\n","# ------------------ Failsafes ------------------\n","if os.path.exists(full_model_path):\n"," print(R+'!! 
WARNING: Model folder already existed and has been removed !!'+W)\n"," shutil.rmtree(full_model_path)\n","\n","os.makedirs(full_model_path)\n","os.makedirs(os.path.join(full_model_path,'Quality Control'))\n","\n","\n","# ------------------ Display ------------------\n","print('---------------------------- Main training parameters ----------------------------')\n","print('Number of epochs: '+str(number_of_epochs))\n","print('Batch size: '+str(batch_size))\n","print('Number of training dataset: '+str(number_of_training_dataset))\n","print('Number of training steps: '+str(number_of_steps))\n","print('Number of validation steps: '+str(validation_steps))\n","print('---------------------------- ------------------------ ----------------------------')\n","\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"urpQ9UM-6NBE","colab_type":"text"},"source":["## **4.2. Start Training**\n","---\n","\n","####**Be patient**. Training may take a while, but the verbose output allows you to estimate how fast it is progressing and how long it will take. 
While it's training, please make sure that the computer does not power down due to inactivity; otherwise this will interrupt the runtime."]},{"cell_type":"code","metadata":{"id":"sMyCENd29TKz","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","# history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs=epochs, callbacks=[model_checkpoint,csv_log], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs = number_of_epochs, callbacks=[model_checkpoint, reduce_lr], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","\n","# Save the last model\n","model.save(os.path.join(full_model_path, 'weights_last.hdf5'))\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","# The training_evaluation.csv is saved (overwrites the file if needed). \n","lossDataCSVpath = os.path.join(full_model_path,'Quality Control/training_evaluation.csv')\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n"," \n","\n","\n","# Displaying the time elapsed for training\n","print(\"------------------------------------------\")\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\", hour, \"hour(s)\", mins,\"min(s)\",round(sec),\"sec(s)\")\n","print(\"------------------------------------------\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"LWaFk0JNda-N","colab_type":"text"},"source":["## **4.3. 
Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is, however, wise to download the folder, since all data can be erased at the next training run if the same folder is used."]},{"cell_type":"markdown","metadata":{"id":"mEMcFNHZdmTz","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend performing quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"X11zGW0Ldu-z","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###Do you want to assess the model you just trained?\n","\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","\n","full_QC_model_path = os.path.join(QC_model_path, QC_model_name)\n","if os.path.exists(os.path.join(full_QC_model_path, 'weights_best.hdf5')):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"pkJyRzWJCrKG","colab_type":"text"},"source":["## **5.1. 
Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"qul6BpaX1GqS","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","epochNumber = []\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(full_QC_model_path, 'Quality Control', 'lossCurvePlots.png'))\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"h33P0C2geqZu","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder. The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. 
\n","\n","The Input, Ground Truth, Prediction and IoU maps are shown below for the last example in the QC set.\n","\n"," The results for all QC examples can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\".\n","\n","### **Thresholds for image masks**\n","\n"," Since the output from Unet is not a binary mask, the output images are converted to binary masks using thresholding. This section will test different thresholds (from 0 to 255) to find the one yielding the best IoU score compared with the ground truth. The best threshold for each image and the average of these thresholds will be displayed below. **These values can be a guideline when creating masks for unseen data in section 6.**"]},{"cell_type":"code","metadata":{"id":"Tpqjvwv2zug-","colab_type":"code","cellView":"form","colab":{}},"source":["# ------------- User input ------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","# ------------- Initialise folders ------------\n","# Create a quality control/Prediction Folder\n","prediction_QC_folder = os.path.join(full_QC_model_path, 'Quality Control', 'Prediction')\n","if os.path.exists(prediction_QC_folder):\n"," shutil.rmtree(prediction_QC_folder)\n","\n","os.makedirs(prediction_QC_folder)\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model\n","unet = load_model(os.path.join(full_QC_model_path, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Source_QC_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: 
'+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Source_QC_folder, source_dir_list[i]), unet))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(prediction_QC_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","#-----------------------------Calculate Metrics----------------------------------------#\n","\n","f = plt.figure(figsize=((5,5)))\n","\n","with open(os.path.join(full_QC_model_path,'Quality Control', 'QC_metrics_'+QC_model_name+'.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"File name\",\"IoU\", \"IoU-optimised threshold\"]) \n","\n"," # Initialise the lists \n"," filename_list = []\n"," best_threshold_list = []\n"," best_IoU_score_list = []\n","\n"," for filename in os.listdir(Source_QC_folder):\n","\n"," if not os.path.isdir(os.path.join(Source_QC_folder, filename)):\n"," print('Running QC on: '+filename)\n"," test_input = io.imread(os.path.join(Source_QC_folder, filename), as_gray=True)\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, filename), as_gray=True)\n","\n"," (threshold_list, iou_scores_per_threshold) = getIoUvsThreshold(os.path.join(prediction_QC_folder, prediction_prefix+filename), os.path.join(Target_QC_folder, filename))\n"," plt.plot(threshold_list,iou_scores_per_threshold, label=filename)\n","\n"," # Here we find which threshold yielded the highest IoU score for image n.\n"," best_IoU_score = max(iou_scores_per_threshold)\n"," best_threshold = iou_scores_per_threshold.index(best_IoU_score)\n","\n"," # Write the results in the CSV file\n"," writer.writerow([filename, str(best_IoU_score), str(best_threshold)])\n","\n"," # Here we append the best threshold and score to the lists\n"," filename_list.append(filename)\n"," best_IoU_score_list.append(best_IoU_score)\n"," 
best_threshold_list.append(best_threshold)\n","\n","# Display the IoU vs Threshold plot\n","plt.title('IoU vs. Threshold')\n","plt.ylabel('IoU')\n","plt.xlabel('Threshold value')\n","plt.legend()\n","plt.show()\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = filename_list)\n","pdResults[\"IoU\"] = best_IoU_score_list\n","pdResults[\"IoU-optimised threshold\"] = best_threshold_list\n","\n","\n","\n","average_best_threshold = sum(best_threshold_list)/len(best_threshold_list)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file=os.listdir(Source_QC_folder)):\n"," \n"," plt.figure(figsize=(25,5))\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(plt.imread(os.path.join(Source_QC_folder, file)), aspect='equal', cmap='gray', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, file),as_gray=True)\n"," plt.imshow(test_ground_truth_image, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," test_prediction = plt.imread(os.path.join(prediction_QC_folder, prediction_prefix+file))\n"," test_prediction_mask = np.empty_like(test_prediction)\n"," test_prediction_mask[test_prediction > average_best_threshold] = 255\n"," test_prediction_mask[test_prediction <= average_best_threshold] = 0\n"," plt.imshow(test_prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(test_ground_truth_image, cmap='Greens')\n"," plt.imshow(test_prediction_mask, alpha=0.5, cmap='Purples')\n"," metrics_title = 'Overlay (IoU: ' + str(round(pdResults.loc[file][\"IoU\"],3)) + ' T: ' + 
str(round(pdResults.loc[file][\"IoU-optimised threshold\"])) + ')'\n"," plt.title(metrics_title)\n","\n","\n","\n","print('--------------------------------------------------------------')\n","print('Best average threshold is: '+str(round(average_best_threshold)))\n","print('--------------------------------------------------------------')\n","\n","pdResults.head()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"gofmRsLP96O8","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"Pv_v1Ru2OJkU","colab_type":"text"},"source":["## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.1) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n"," Once the predictions are complete the cell will display a random example prediction beside the input image and the calculated mask for visual inspection.\n","\n"," **Troubleshooting:** If there is a low contrast image warning when saving the images, this may be due to overfitting of the model to the data. It may result in images containing only a single colour. 
Train the network again with different network hyperparameters."]},{"cell_type":"code","metadata":{"id":"FJAe55ZoOJGs","colab_type":"code","cellView":"form","colab":{}},"source":["# ------------- Initial user input ------------\n","#@markdown ###Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","Data_folder = '' #@param {type:\"string\"}\n","Results_folder = '' #@param {type:\"string\"}\n","\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","# ------------- Failsafes ------------\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model and prepare generator\n","\n","unet = load_model(os.path.join(Prediction_model_path, Prediction_model_name, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Data_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: '+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Data_folder, source_dir_list[i]), unet))\n"," # predictions.append(prediction(os.path.join(Data_folder, source_dir_list[i]), os.path.join(Prediction_model_path, Prediction_model_name)))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","\n","def show_prediction_mask(file=os.listdir(Data_folder), threshold=(0,255,1)):\n","\n"," plt.figure(figsize=(18,6))\n"," # Wide-field\n"," plt.subplot(1,3,1)\n"," plt.axis('off')\n"," img_Source = plt.imread(os.path.join(Data_folder, file))\n"," plt.imshow(img_Source, cmap='gray')\n"," plt.title('Source image',fontsize=15)\n"," # Prediction\n"," plt.subplot(1,3,2)\n"," plt.axis('off')\n"," img_Prediction = plt.imread(os.path.join(Results_folder, prediction_prefix+file))\n"," plt.imshow(img_Prediction, cmap='gray')\n"," 
plt.title('Prediction',fontsize=15)\n","\n"," # Thresholded mask\n"," plt.subplot(1,3,3)\n"," plt.axis('off')\n"," img_Mask = convert2Mask(img_Prediction, threshold)\n"," plt.imshow(img_Mask, cmap='gray')\n"," plt.title('Mask (Threshold: '+str(round(threshold))+')',fontsize=15)\n","\n","\n","interact(show_prediction_mask, continuous_update=False);\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"su-Mo2POVpja","colab_type":"text"},"source":["## **6.2. Export results as masks**\n","---\n"]},{"cell_type":"code","metadata":{"id":"iC_B_9lxNUny","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# @markdown #Play this cell to save results as masks with the chosen threshold\n","threshold = 120#@param {type:\"number\"}\n","\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=threshold)\n","print('-------------------')\n","print('Masks were saved in: '+Results_folder)\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wYmwCQKjYsJ7","colab_type":"text"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"sCXzzvnh2_rc","colab_type":"text"},"source":["#**Thank you for using U-Net!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/YOLOv2_ZeroCostDL4Mic.ipynb b/Colab_notebooks/YOLOv2_ZeroCostDL4Mic.ipynb index 752fc61f..ff31023c 100755 --- a/Colab_notebooks/YOLOv2_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/YOLOv2_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"YOLOv2_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1LWs9bFbYclR1nWaupcSPUYFN6yyUU_5t","timestamp":1596536407170},{"file_id":"1uUjR8Sm2l6vAJfclb84gUUH4MCwzQUWO","timestamp":1594734310956},{"file_id":"1zileODcR2RNrVSidXNuBfgFDv68JRRa0","timestamp":1593093410185},{"file_id":"1EpgWlJK6U_ZwlBGiomLfbxx9UUtRPBTy","timestamp":1592904104821},{"file_id":"1f5usS6p8Cu_efegMwcR3v68AVOXBSyIf","timestamp":1588870626184},{"file_id":"1fM7obTEQKnSgVZMDa1KjiBgiBar2b0t8","timestamp":1588693012611},{"file_id":"1owWtQQucUxUOZMaPh2x_mxe_qXKHCZhp","timestamp":1588074588514},{"file_id":"159ARwlQE7-zi0EHxunOF_YPFLt-ZVU5x","timestamp":1587562499898},{"file_id":"1W-7NHehG5MRFILvZZzhPWWnOdJMkadb2","timestamp":1586332290412},{"file_id":"1pUetEQICxYWkYVaQIgdRH1EZBTl7oc2A","timestamp":1586292199692},{"file_id":"1MD36ZkM6XR9EuV12zimJmfCjzyeYZFWq","timestamp":1586269469061},{"file_id":"16A2mbaHzlEElntS8qkFBOsBvZG-mUeY6","timestamp":1586253795726},{"file_id":"1gJlcjOiSxr2buDOxmcFbT_d-GqwLjXtK","timestamp":1583343225796},{"file_id":"10yGI51WzHfgWgZAyE-EbkZFEvIOd6CP6","timestamp":1583171396283}],"collapsed_sections":[]},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 
3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I","colab_type":"text"},"source":["# **YOLOv2**\n","---\n","\n"," YOLOv2 is a deep-learning method designed to perform object detection and classification of objects in images, published by [Redmon and Farhadi](https://ieeexplore.ieee.org/document/8100173). This is based on the original [YOLO](https://arxiv.org/abs/1506.02640) implementation published by the same authors. YOLOv2 is trained on images with class annotations in the form of bounding boxes drawn around the objects of interest. The images are downsampled by a convolutional neural network (CNN) and objects are classified in two final fully connected layers in the network. YOLOv2 learns classification and object detection simultaneously by taking the whole input image into account, predicting many possible bounding box solutions, and then using regression to find the best bounding boxes and classifications for each object.\n","\n","**This particular notebook enables object detection and classification on 2D images given ground truth bounding boxes. If you are interested in image segmentation, you should use our U-net or Stardist notebooks instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following papers: \n","\n","**YOLO9000: Better, Faster, Stronger** from Joseph Redmon and Ali Farhadi in Proceedings of the IEEE conference on computer vision and pattern recognition, 2017, (https://ieeexplore.ieee.org/document/8100173)\n","\n","**You Only Look Once: Unified, Real-Time Object Detection** from Joseph Redmon, Santosh Divvala, Ross Girshick, Ali Farhadi in IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, (https://ieeexplore.ieee.org/document/7780460)\n","\n","**Note: The source code for this notebook is adapted for keras and can be found in: (https://github.com/experiencor/keras-yolo2)**\n","\n","\n","**Please also cite these original papers when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use ZeroCostDL4Mic notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. 
After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," Preparing the dataset carefully is essential to make this YOLOv2 notebook work. This model requires as input a set of images (currently .jpg) and as target a list of annotation files in Pascal VOC format. The annotation files should have the exact same name as the input files, except with an .xml instead of the .jpg extension. 
The annotation files contain the class labels and all bounding boxes for the objects for each image in your dataset. Most datasets give the option of saving the annotations in this format, and software for hand-annotation will typically save the annotations in this format as well. \n","\n"," If you want to assemble your own dataset we recommend using the open source https://www.makesense.ai/ resource. You can follow our instructions on how to label your dataset with this tool on our [wiki](https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki/Object-Detection-(YOLOv2)).\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .png or .jpg files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Input images (Training_source)\n"," - img_1.png, img_2.png, ...\n"," - Annotation files (Training_source_annotations)\n"," - img_1.xml, img_2.xml, ...\n"," - **Quality control dataset**\n"," - Input images\n"," - img_1.png, img_2.png\n"," - Annotation files\n"," - img_1.xml, img_2.xml\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the 
quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"BCPhV-pe-syw","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"VNZetvLiS1qV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UBrnApIUBgxv","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. 
In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **2. 
Install YOLOv2 and Dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install Network and Dependencies\n","%tensorflow_version 1.x\n","!pip install pascal-voc-writer\n","from pascal_voc_writer import Writer\n","from __future__ import division\n","from __future__ import print_function\n","from __future__ import absolute_import\n","import csv\n","import random\n","import pprint\n","import sys\n","import time\n","import numpy as np\n","from optparse import OptionParser\n","import pickle\n","import math\n","import cv2\n","import copy\n","import math\n","from matplotlib import pyplot as plt\n","import matplotlib.patches as patches\n","import tensorflow as tf\n","import pandas as pd\n","import os\n","import shutil\n","from skimage import io\n","from sklearn.metrics import average_precision_score\n","\n","from keras.models import Model\n","from keras.layers import Flatten, Dense, Input, Conv2D, MaxPooling2D, Dropout, Reshape, Activation, Conv2D, MaxPooling2D, BatchNormalization, Lambda\n","from keras.layers.advanced_activations import LeakyReLU\n","from keras.layers.merge import concatenate\n","from keras.applications.mobilenet import MobileNet\n","from keras.applications import InceptionV3\n","from keras.applications.vgg16 import VGG16\n","from keras.applications.resnet50 import ResNet50\n","\n","from keras import backend as K\n","from keras.optimizers import Adam, SGD, RMSprop\n","from keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, TimeDistributed\n","from keras.engine.topology import get_source_inputs\n","from keras.utils import layer_utils\n","from keras.utils.data_utils import get_file\n","from keras.objectives import categorical_crossentropy\n","from keras.models import Model\n","from keras.utils import generic_utils\n","from keras.engine import Layer, InputSpec\n","from keras import initializers, regularizers\n","from keras.utils import 
Sequence\n","import xml.etree.ElementTree as ET\n","from collections import OrderedDict, Counter\n","import json\n","import imageio\n","import imgaug as ia\n","from imgaug import augmenters as iaa\n","import copy\n","import cv2\n","from tqdm import tqdm\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","ia.seed(1)\n","# imgaug uses matplotlib backend for displaying images\n","from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage\n","import re\n","import glob\n","\n","!git clone https://github.com/rodrigo2019/keras_yolo2.git\n","\n","if os.path.exists('/content/gdrive/My Drive/keras-yolo2'):\n"," shutil.rmtree('/content/gdrive/My Drive/keras-yolo2')\n","\n","!git clone https://github.com/experiencor/keras-yolo2.git\n","shutil.move('/content/keras-yolo2','/content/gdrive/My Drive/keras-yolo2')\n","shutil.move('/content/keras_yolo2/keras_yolov2/map_evaluation.py','/content/gdrive/My Drive/keras-yolo2/map_evaluation.py')\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","\n","from backend import BaseFeatureExtractor, FullYoloFeature\n","from preprocessing import parse_annotation, BatchGenerator\n","\n","#shutil.move('/content/map_evaluation.py','/content/gdrive/My Drive/keras-yolo2/map_evaluation.py')\n","\n","print(\"Depencies installed and imported.\")\n","\n","def plt_rectangle(plt,label,x1,y1,x2,y2,fontsize=10):\n"," '''\n"," == Input ==\n"," \n"," plt : matplotlib.pyplot object\n"," label : string containing the object class name\n"," x1 : top left corner x coordinate\n"," y1 : top left corner y coordinate\n"," x2 : bottom right corner x coordinate\n"," y2 : bottom right corner y coordinate\n"," '''\n"," linewidth = 1\n"," color = \"yellow\"\n"," plt.text(x1,y1,label,fontsize=fontsize,backgroundcolor=\"magenta\")\n"," plt.plot([x1,x1],[y1,y2], linewidth=linewidth,color=color)\n"," plt.plot([x2,x2],[y1,y2], linewidth=linewidth,color=color)\n"," plt.plot([x1,x2],[y1,y1], 
linewidth=linewidth,color=color)\n"," plt.plot([x1,x2],[y2,y2], linewidth=linewidth,color=color)\n","\n","def extract_single_xml_file(tree,object_count=True):\n"," Nobj = 0\n"," row = OrderedDict()\n"," for elems in tree.iter():\n","\n"," if elems.tag == \"size\":\n"," for elem in elems:\n"," row[elem.tag] = int(elem.text)\n"," if elems.tag == \"object\":\n"," for elem in elems:\n"," if elem.tag == \"name\":\n"," row[\"bbx_{}_{}\".format(Nobj,elem.tag)] = str(elem.text) \n"," if elem.tag == \"bndbox\":\n"," for k in elem:\n"," row[\"bbx_{}_{}\".format(Nobj,k.tag)] = float(k.text)\n"," Nobj += 1\n"," if object_count == True:\n"," row[\"Nobj\"] = Nobj\n"," return(row)\n","\n","def count_objects(tree):\n"," Nobj=0\n"," for elems in tree.iter():\n"," if elems.tag == \"object\":\n"," for elem in elems:\n"," if elem.tag == \"bndbox\":\n"," Nobj += 1\n"," return(Nobj)\n","\n","def compute_overlap(a, b):\n"," \"\"\"\n"," Code originally from https://github.com/rbgirshick/py-faster-rcnn.\n"," Parameters\n"," ----------\n"," a: (N, 4) ndarray of float\n"," b: (K, 4) ndarray of float\n"," Returns\n"," -------\n"," overlaps: (N, K) ndarray of overlap between boxes and query_boxes\n"," \"\"\"\n"," area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])\n","\n"," iw = np.minimum(np.expand_dims(a[:, 2], axis=1), b[:, 2]) - np.maximum(np.expand_dims(a[:, 0], 1), b[:, 0])\n"," ih = np.minimum(np.expand_dims(a[:, 3], axis=1), b[:, 3]) - np.maximum(np.expand_dims(a[:, 1], 1), b[:, 1])\n","\n"," iw = np.maximum(iw, 0)\n"," ih = np.maximum(ih, 0)\n","\n"," ua = np.expand_dims((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), axis=1) + area - iw * ih\n","\n"," ua = np.maximum(ua, np.finfo(float).eps)\n","\n"," intersection = iw * ih\n","\n"," return intersection / ua\n","\n","def compute_ap(recall, precision):\n"," \"\"\" Compute the average precision, given the recall and precision curves.\n"," Code originally from https://github.com/rbgirshick/py-faster-rcnn.\n","\n"," # Arguments\n"," recall: The 
recall curve (list).\n"," precision: The precision curve (list).\n"," # Returns\n"," The average precision as computed in py-faster-rcnn.\n"," \"\"\"\n"," # correct AP calculation\n"," # first append sentinel values at the end\n"," mrec = np.concatenate(([0.], recall, [1.]))\n"," mpre = np.concatenate(([0.], precision, [0.]))\n","\n"," # compute the precision envelope\n"," for i in range(mpre.size - 1, 0, -1):\n"," mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])\n","\n"," # to calculate area under PR curve, look for points\n"," # where X axis (recall) changes value\n"," i = np.where(mrec[1:] != mrec[:-1])[0]\n","\n"," # and sum (\\Delta recall) * prec\n"," ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])\n"," return ap \n","\n","def load_annotation(image_folder,annotations_folder, i, config):\n"," annots = []\n"," imgs, anns = parse_annotation(annotations_folder,image_folder,config['model']['labels'])\n"," for obj in imgs[i]['object']:\n"," annot = [obj['xmin'], obj['ymin'], obj['xmax'], obj['ymax'], config['model']['labels'].index(obj['name'])]\n"," annots += [annot]\n","\n"," if len(annots) == 0: annots = [[]]\n","\n"," return np.array(annots)\n","\n","def _calc_avg_precisions(config,image_folder,annotations_folder,weights_path,iou_threshold,score_threshold):\n","\n"," # gather all detections and annotations\n"," all_detections = [[None for _ in range(len(config['model']['labels']))] for _ in range(len(os.listdir(image_folder)))]\n"," all_annotations = [[None for _ in range(len(config['model']['labels']))] for _ in range(len(os.listdir(annotations_folder)))]\n","\n"," for i in range(len(os.listdir(image_folder))):\n"," raw_image = cv2.imread(os.path.join(image_folder,sorted(os.listdir(image_folder))[i]))\n"," raw_height, raw_width, _ = raw_image.shape\n"," #print(raw_height)\n"," # make the boxes and the labels\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," 
max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n"," yolo.load_weights(weights_path)\n"," pred_boxes = yolo.predict(raw_image,iou_threshold=iou_threshold,score_threshold=score_threshold)\n","\n"," score = np.array([box.score for box in pred_boxes])\n"," #print(score)\n"," pred_labels = np.array([box.label for box in pred_boxes])\n"," #print(len(pred_boxes))\n"," if len(pred_boxes) > 0:\n"," pred_boxes = np.array([[box.xmin * raw_width, box.ymin * raw_height, box.xmax * raw_width,\n"," box.ymax * raw_height, box.score] for box in pred_boxes])\n"," else:\n"," pred_boxes = np.array([[]])\n","\n"," # sort the boxes and the labels according to scores\n"," score_sort = np.argsort(-score)\n"," pred_labels = pred_labels[score_sort]\n"," pred_boxes = pred_boxes[score_sort]\n","\n"," # copy detections to all_detections\n"," for label in range(len(config['model']['labels'])):\n"," all_detections[i][label] = pred_boxes[pred_labels == label, :]\n","\n"," annotations = load_annotation(image_folder,annotations_folder,i,config)\n","\n"," # copy ground truth to all_annotations\n"," for label in range(len(config['model']['labels'])):\n"," all_annotations[i][label] = annotations[annotations[:, 4] == label, :4].copy()\n","\n"," # compute mAP by comparing all detections and all annotations\n"," average_precisions = {}\n"," total_recall = []\n"," total_precision = []\n"," for label in range(len(config['model']['labels'])):\n"," false_positives = np.zeros((0,))\n"," true_positives = np.zeros((0,))\n"," scores = np.zeros((0,))\n"," num_annotations = 0.0\n","\n"," for i in range(len(os.listdir(image_folder))):\n"," detections = all_detections[i][label]\n"," annotations = all_annotations[i][label]\n"," num_annotations += annotations.shape[0]\n"," detected_annotations = []\n","\n"," for d in detections:\n"," scores = np.append(scores, d[4])\n","\n"," if annotations.shape[0] == 0:\n"," false_positives = np.append(false_positives, 1)\n"," 
true_positives = np.append(true_positives, 0)\n"," continue\n","\n"," overlaps = compute_overlap(np.expand_dims(d, axis=0), annotations)\n"," assigned_annotation = np.argmax(overlaps, axis=1)\n"," max_overlap = overlaps[0, assigned_annotation]\n","\n"," if max_overlap >= iou_threshold and assigned_annotation not in detected_annotations:\n"," false_positives = np.append(false_positives, 0)\n"," true_positives = np.append(true_positives, 1)\n"," detected_annotations.append(assigned_annotation)\n"," else:\n"," false_positives = np.append(false_positives, 1)\n"," true_positives = np.append(true_positives, 0)\n","\n"," # no annotations -> AP for this class is 0 (is this correct?)\n"," if num_annotations == 0:\n"," average_precisions[label] = 0\n"," continue\n","\n"," # sort by score\n"," indices = np.argsort(-scores)\n"," false_positives = false_positives[indices]\n"," true_positives = true_positives[indices]\n","\n"," # compute false positives and true positives\n"," false_positives = np.cumsum(false_positives)\n"," true_positives = np.cumsum(true_positives)\n","\n"," # compute recall and precision\n"," recall = true_positives / num_annotations\n"," precision = true_positives / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)\n"," total_recall.append(recall)\n"," total_precision.append(precision)\n"," #print(precision)\n"," # compute average precision\n"," average_precision = compute_ap(recall, precision)\n"," average_precisions[label] = average_precision\n","\n"," return average_precisions, total_recall, total_precision\n","\n","\n","def show_frame(pred_bb, pred_classes, pred_conf, gt_bb, gt_classes, class_dict, background=np.zeros((512, 512, 3)), show_confidence=True):\n"," \"\"\"\n"," Here, we are adapting classes and functions from https://github.com/MathGaron/mean_average_precision\n"," \"\"\"\n"," \"\"\"\n"," Plot the boundingboxes\n"," :param pred_bb: (np.array) Predicted Bounding Boxes [x1, y1, x2, y2] : Shape [n_pred, 4]\n"," :param 
pred_classes: (np.array) Predicted Classes : Shape [n_pred]\n"," :param pred_conf: (np.array) Predicted Confidences [0.-1.] : Shape [n_pred]\n"," :param gt_bb: (np.array) Ground Truth Bounding Boxes [x1, y1, x2, y2] : Shape [n_gt, 4]\n"," :param gt_classes: (np.array) Ground Truth Classes : Shape [n_gt]\n"," :param class_dict: (dictionary) Key value pairs of classes, e.g. {0:'dog',1:'cat',2:'horse'}\n"," :return:\n"," \"\"\"\n"," n_pred = pred_bb.shape[0]\n"," n_gt = gt_bb.shape[0]\n"," n_class = int(np.max(np.append(pred_classes, gt_classes)) + 1)\n"," #print(n_class)\n"," if len(background.shape) < 3:\n"," h, w = background.shape\n"," else:\n"," h, w, c = background.shape\n","\n"," ax = plt.subplot(\"111\")\n"," ax.imshow(background)\n"," cmap = plt.cm.get_cmap('hsv')\n","\n"," confidence_alpha = pred_conf.copy()\n"," if not show_confidence:\n"," confidence_alpha.fill(1)\n","\n"," for i in range(n_pred):\n"," x1 = pred_bb[i, 0]# * w\n"," y1 = pred_bb[i, 1]# * h\n"," x2 = pred_bb[i, 2]# * w\n"," y2 = pred_bb[i, 3]# * h\n"," rect_w = x2 - x1\n"," rect_h = y2 - y1\n"," #print(x1, y1)\n"," ax.add_patch(patches.Rectangle((x1, y1), rect_w, rect_h,\n"," fill=False,\n"," edgecolor=cmap(float(pred_classes[i]) / n_class),\n"," linestyle='dashdot',\n"," alpha=confidence_alpha[i]))\n","\n"," for i in range(n_gt):\n"," x1 = gt_bb[i, 0]# * w\n"," y1 = gt_bb[i, 1]# * h\n"," x2 = gt_bb[i, 2]# * w\n"," y2 = gt_bb[i, 3]# * h\n"," rect_w = x2 - x1\n"," rect_h = y2 - y1\n"," ax.add_patch(patches.Rectangle((x1, y1), rect_w, rect_h,\n"," fill=False,\n"," edgecolor=cmap(float(gt_classes[i]) / n_class)))\n","\n"," legend_handles = []\n","\n"," for i in range(n_class):\n"," legend_handles.append(patches.Patch(color=cmap(float(i) / n_class), label=class_dict[i]))\n"," \n"," ax.legend(handles=legend_handles)\n"," plt.show()\n","\n","class BoundBox:\n"," \"\"\"\n"," Here, we are adapting classes and functions from https://github.com/MathGaron/mean_average_precision\n"," \"\"\"\n"," def 
__init__(self, xmin, ymin, xmax, ymax, c = None, classes = None):\n"," self.xmin = xmin\n"," self.ymin = ymin\n"," self.xmax = xmax\n"," self.ymax = ymax\n"," \n"," self.c = c\n"," self.classes = classes\n","\n"," self.label = -1\n"," self.score = -1\n","\n"," def get_label(self):\n"," if self.label == -1:\n"," self.label = np.argmax(self.classes)\n"," \n"," return self.label\n"," \n"," def get_score(self):\n"," if self.score == -1:\n"," self.score = self.classes[self.get_label()]\n"," \n"," return self.score\n","\n","class WeightReader:\n"," def __init__(self, weight_file):\n"," self.offset = 4\n"," self.all_weights = np.fromfile(weight_file, dtype='float32')\n"," \n"," def read_bytes(self, size):\n"," self.offset = self.offset + size\n"," return self.all_weights[self.offset-size:self.offset]\n"," \n"," def reset(self):\n"," self.offset = 4\n","\n","def bbox_iou(box1, box2):\n"," intersect_w = _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax])\n"," intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax]) \n"," \n"," intersect = intersect_w * intersect_h\n","\n"," w1, h1 = box1.xmax-box1.xmin, box1.ymax-box1.ymin\n"," w2, h2 = box2.xmax-box2.xmin, box2.ymax-box2.ymin\n"," \n"," union = w1*h1 + w2*h2 - intersect\n"," \n"," return float(intersect) / union\n","\n","def draw_boxes(image, boxes, labels):\n"," image_h, image_w, _ = image.shape\n"," #Changes in box color added by LvC\n"," # class_colours = []\n"," # for c in range(len(labels)):\n"," # colour = np.random.randint(low=0,high=255,size=3).tolist()\n"," # class_colours.append(tuple(colour))\n"," for box in boxes:\n"," xmin = int(box.xmin*image_w)\n"," ymin = int(box.ymin*image_h)\n"," xmax = int(box.xmax*image_w)\n"," ymax = int(box.ymax*image_h)\n"," if box.get_label() == 0:\n"," cv2.rectangle(image, (xmin,ymin), (xmax,ymax), (255,0,0), 3)\n"," elif box.get_label() == 1:\n"," cv2.rectangle(image, (xmin,ymin), (xmax,ymax), (0,255,0), 3)\n"," else:\n"," cv2.rectangle(image, 
(xmin,ymin), (xmax,ymax), (0,0,255), 3)\n"," #cv2.rectangle(image, (xmin,ymin), (xmax,ymax), class_colours[box.get_label()], 3)\n"," cv2.putText(image, \n"," labels[box.get_label()] + ' ' + str(round(box.get_score(),3)), \n"," (xmin, ymin - 13), \n"," cv2.FONT_HERSHEY_SIMPLEX, \n"," 1e-3 * image_h, \n"," (0,0,0), 2)\n"," #print(box.get_label()) \n"," return image \n","\n","#Function added by LvC\n","def save_boxes(image_path, boxes, labels):#, save_path):\n"," image = cv2.imread(image_path)\n"," image_h, image_w, _ = image.shape\n"," save_boxes =[]\n"," save_boxes_names = []\n"," save_boxes.append(os.path.basename(image_path))\n"," save_boxes_names.append(os.path.basename(image_path))\n"," for box in boxes:\n"," # xmin = box.xmin\n"," save_boxes.append(int(box.xmin*image_w))\n"," save_boxes_names.append(int(box.xmin*image_w))\n"," # ymin = box.ymin\n"," save_boxes.append(int(box.ymin*image_h))\n"," save_boxes_names.append(int(box.ymin*image_h))\n"," # xmax = box.xmax\n"," save_boxes.append(int(box.xmax*image_w))\n"," save_boxes_names.append(int(box.xmax*image_w))\n"," # ymax = box.ymax\n"," save_boxes.append(int(box.ymax*image_h))\n"," save_boxes_names.append(int(box.ymax*image_h))\n"," score = box.get_score()\n"," save_boxes.append(score)\n"," save_boxes_names.append(score)\n"," label = box.get_label()\n"," save_boxes.append(label)\n"," save_boxes_names.append(labels[label])\n"," \n"," #This file will be for later analysis of the bounding boxes in imagej\n"," if not os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," with open('/content/predicted_bounding_boxes.csv', 'w', newline='') as csvfile:\n"," csvwriter = csv.writer(csvfile, delimiter=',')\n"," specs_list = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*len(boxes)\n"," csvwriter.writerow(specs_list)\n"," csvwriter.writerow(save_boxes)\n"," else:\n"," with open('/content/predicted_bounding_boxes.csv', 'a+', newline='') as csvfile:\n"," csvwriter = csv.writer(csvfile)\n"," 
csvwriter.writerow(save_boxes)\n"," \n"," if not os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," with open('/content/predicted_bounding_boxes_names.csv', 'w', newline='') as csvfile_names:\n"," csvwriter = csv.writer(csvfile_names, delimiter=',')\n"," specs_list = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*len(boxes)\n"," csvwriter.writerow(specs_list)\n"," csvwriter.writerow(save_boxes_names)\n"," else:\n"," with open('/content/predicted_bounding_boxes_names.csv', 'a+', newline='') as csvfile_names:\n"," csvwriter = csv.writer(csvfile_names)\n"," csvwriter.writerow(save_boxes_names)\n"," # #This file is to create a nicer display for the output images\n"," # if not os.path.exists('/content/predicted_bounding_boxes_display.csv'):\n"," # with open('/content/predicted_bounding_boxes_display.csv', 'w', newline='') as csvfile_new:\n"," # csvwriter2 = csv.writer(csvfile_new, delimiter=',')\n"," # specs_list = ['filename','width','height','class','xmin','ymin','xmax','ymax']\n"," # csvwriter2.writerow(specs_list)\n"," # else:\n"," # with open('/content/predicted_bounding_boxes_display.csv','a+',newline='') as csvfile_new:\n"," # csvwriter2 = csv.writer(csvfile_new)\n"," # for box in boxes:\n"," # row = [os.path.basename(image_path),image_w,image_h,box.get_label(),int(box.xmin*image_w),int(box.ymin*image_h),int(box.xmax*image_w),int(box.ymax*image_h)]\n"," # csvwriter2.writerow(row)\n","\n","def add_header(inFilePath,outFilePath):\n"," header = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n"," with open(inFilePath, newline='') as inFile, open(outFilePath, 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n"," \n","def decode_netout(netout, anchors, nb_class, obj_threshold=0.3, 
nms_threshold=0.5):\n"," grid_h, grid_w, nb_box = netout.shape[:3]\n","\n"," boxes = []\n"," \n"," # decode the output by the network\n"," netout[..., 4] = _sigmoid(netout[..., 4])\n"," netout[..., 5:] = netout[..., 4][..., np.newaxis] * _softmax(netout[..., 5:])\n"," netout[..., 5:] *= netout[..., 5:] > obj_threshold\n"," \n"," for row in range(grid_h):\n"," for col in range(grid_w):\n"," for b in range(nb_box):\n"," # from 4th element onwards are confidence and class classes\n"," classes = netout[row,col,b,5:]\n"," \n"," if np.sum(classes) > 0:\n"," # first 4 elements are x, y, w, and h\n"," x, y, w, h = netout[row,col,b,:4]\n","\n"," x = (col + _sigmoid(x)) / grid_w # center position, unit: image width\n"," y = (row + _sigmoid(y)) / grid_h # center position, unit: image height\n"," w = anchors[2 * b + 0] * np.exp(w) / grid_w # unit: image width\n"," h = anchors[2 * b + 1] * np.exp(h) / grid_h # unit: image height\n"," confidence = netout[row,col,b,4]\n"," \n"," box = BoundBox(x-w/2, y-h/2, x+w/2, y+h/2, confidence, classes)\n"," \n"," boxes.append(box)\n","\n"," # suppress non-maximal boxes\n"," for c in range(nb_class):\n"," sorted_indices = list(reversed(np.argsort([box.classes[c] for box in boxes])))\n","\n"," for i in range(len(sorted_indices)):\n"," index_i = sorted_indices[i]\n"," \n"," if boxes[index_i].classes[c] == 0: \n"," continue\n"," else:\n"," for j in range(i+1, len(sorted_indices)):\n"," index_j = sorted_indices[j]\n"," \n"," if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_threshold:\n"," boxes[index_j].classes[c] = 0\n"," \n"," # remove the boxes which are less likely than a obj_threshold\n"," boxes = [box for box in boxes if box.get_score() > obj_threshold]\n"," \n"," return boxes\n","\n","def replace(file_path, pattern, subst):\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," 
#Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","with open(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"r\") as check:\n"," lineReader = check.readlines()\n"," reduce_lr = False\n"," for line in lineReader:\n"," if \"reduce_lr\" in line:\n"," reduce_lr = True\n"," break\n","\n","if reduce_lr == False:\n"," #replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"period=1)\",\"period=1)\\n csv_logger=CSVLogger('/content/training_evaluation.csv')\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"period=1)\",\"period=1)\\n reduce_lr=ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, verbose=1)\")\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"import EarlyStopping\",\"import ReduceLROnPlateau, EarlyStopping\")\n","\n","with open(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"r\") as check:\n"," lineReader = check.readlines()\n"," map_eval = False\n"," for line in lineReader:\n"," if \"map_evaluation\" in line:\n"," map_eval = True\n"," break\n","\n","if map_eval == False:\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"import cv2\",\"import cv2\\nfrom map_evaluation import MapEvaluation\")\n"," new_callback = ' map_evaluator = MapEvaluation(self, valid_generator,save_best=True,save_name=\"/content/gdrive/My Drive/keras-yolo2/best_map_weights.h5\",iou_threshold=0.3,score_threshold=0.3)'\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"write_images=False)\",\"write_images=False)\\n\"+new_callback)\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"import keras\",\"import keras\\nimport csv\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"from .utils\",\"from utils\")\n"," replace(\"/content/gdrive/My 
Drive/keras-yolo2/map_evaluation.py\",\".format(_map))\",\".format(_map))\\n with open('/content/gdrive/My Drive/mAP.csv','a+', newline='') as mAP_csv:\\n csv_writer=csv.writer(mAP_csv)\\n csv_writer.writerow(['mAP:','{:.4f}'.format(_map)])\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"iou_threshold=0.5\",\"iou_threshold=0.3\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"score_threshold=0.5\",\"score_threshold=0.3\")\n","\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"[early_stop, checkpoint, tensorboard]\",\"[checkpoint, reduce_lr, map_evaluator]\")\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"predict(self, image)\",\"predict(self,image,iou_threshold=0.3,score_threshold=0.3)\")\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"self.model.summary()\",\"#self.model.summary()\")\n","from frontend import YOLO\n","\n","def train(config_path, model_path, percentage_validation):\n"," #config_path = args.conf\n","\n"," with open(config_path) as config_buffer: \n"," config = json.loads(config_buffer.read())\n","\n"," ###############################\n"," # Parse the annotations \n"," ###############################\n","\n"," # parse annotations of the training set\n"," train_imgs, train_labels = parse_annotation(config['train']['train_annot_folder'], \n"," config['train']['train_image_folder'], \n"," config['model']['labels'])\n","\n"," # parse annotations of the validation set, if any, otherwise split the training set\n"," if os.path.exists(config['valid']['valid_annot_folder']):\n"," valid_imgs, valid_labels = parse_annotation(config['valid']['valid_annot_folder'], \n"," config['valid']['valid_image_folder'], \n"," config['model']['labels'])\n"," else:\n"," train_valid_split = int((1-percentage_validation/100.)*len(train_imgs))\n"," np.random.shuffle(train_imgs)\n","\n"," valid_imgs = train_imgs[train_valid_split:]\n"," train_imgs = 
train_imgs[:train_valid_split]\n","\n"," if len(config['model']['labels']) > 0:\n"," overlap_labels = set(config['model']['labels']).intersection(set(train_labels.keys()))\n","\n"," print('Seen labels:\\t', train_labels)\n"," print('Given labels:\\t', config['model']['labels'])\n"," print('Overlap labels:\\t', overlap_labels) \n","\n"," if len(overlap_labels) < len(config['model']['labels']):\n"," print('Some labels have no annotations! Please revise the list of labels in the config.json file!')\n"," return\n"," else:\n"," print('No labels are provided. Train on all seen labels.')\n"," config['model']['labels'] = train_labels.keys()\n"," \n"," ###############################\n"," # Construct the model \n"," ###############################\n","\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n","\n"," ###############################\n"," # Load the pretrained weights (if any) \n"," ############################### \n","\n"," if os.path.exists(config['train']['pretrained_weights']):\n"," print(\"Loading pre-trained weights in\", config['train']['pretrained_weights'])\n"," yolo.load_weights(config['train']['pretrained_weights'])\n"," if os.path.exists('/content/gdrive/My Drive/mAP.csv'):\n"," os.remove('/content/gdrive/My Drive/mAP.csv')\n"," ###############################\n"," # Start the training process \n"," ###############################\n","\n"," yolo.train(train_imgs = train_imgs,\n"," valid_imgs = valid_imgs,\n"," train_times = config['train']['train_times'],\n"," valid_times = config['valid']['valid_times'],\n"," nb_epochs = config['train']['nb_epochs'], \n"," learning_rate = config['train']['learning_rate'], \n"," batch_size = config['train']['batch_size'],\n"," warmup_epochs = config['train']['warmup_epochs'],\n"," object_scale = 
config['train']['object_scale'],\n"," no_object_scale = config['train']['no_object_scale'],\n"," coord_scale = config['train']['coord_scale'],\n"," class_scale = config['train']['class_scale'],\n"," saved_weights_name = config['train']['saved_weights_name'],\n"," debug = config['train']['debug'])\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n"," lossDataCSVpath = os.path.join(model_path,'Quality Control/training_evaluation.csv')\n"," with open(lossDataCSVpath, 'w') as f1:\n"," writer = csv.writer(f1)\n"," mAP_df = pd.read_csv('/content/gdrive/My Drive/mAP.csv',header=None)\n"," writer.writerow(['loss','val_loss','mAP','learning rate'])\n"," for i in range(len(yolo.model.history.history['loss'])):\n"," writer.writerow([yolo.model.history.history['loss'][i], yolo.model.history.history['val_loss'][i], float(mAP_df[1][i]), yolo.model.history.history['lr'][i]])\n","\n"," yolo.model.save(model_path+'/last_weights.h5')\n","\n","def predict(config, weights_path, image_path):#, model_path):\n","\n"," with open(config) as config_buffer: \n"," config = json.load(config_buffer)\n","\n"," ###############################\n"," # Make the model \n"," ###############################\n","\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n","\n"," ###############################\n"," # Load trained weights\n"," ############################### \n","\n"," yolo.load_weights(weights_path)\n","\n"," ###############################\n"," # Predict bounding boxes \n"," ###############################\n","\n"," if image_path[-4:] == '.mp4':\n"," video_out = image_path[:-4] + '_detected' + image_path[-4:]\n"," video_reader = cv2.VideoCapture(image_path)\n","\n"," nb_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))\n"," frame_h = 
int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))\n"," frame_w = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))\n","\n"," video_writer = cv2.VideoWriter(video_out,\n"," cv2.VideoWriter_fourcc(*'MPEG'), \n"," 50.0, \n"," (frame_w, frame_h))\n","\n"," for i in tqdm(range(nb_frames)):\n"," _, image = video_reader.read()\n"," \n"," boxes = yolo.predict(image)\n"," image = draw_boxes(image, boxes, config['model']['labels'])\n","\n"," video_writer.write(np.uint8(image))\n","\n"," video_reader.release()\n"," video_writer.release() \n"," else:\n"," image = cv2.imread(image_path)\n"," boxes = yolo.predict(image)\n"," image = draw_boxes(image, boxes, config['model']['labels'])\n"," save_boxes(image_path,boxes,config['model']['labels'])#,model_path)#added by LvC\n"," print(len(boxes), 'boxes are found')\n"," #print(image)\n"," cv2.imwrite(image_path[:-4] + '_detected' + image_path[-4:], image)\n"," \n"," return len(boxes)\n","\n","# function to convert BoundingBoxesOnImage object into DataFrame\n","def bbs_obj_to_df(bbs_object):\n","# convert BoundingBoxesOnImage object into array\n"," bbs_array = bbs_object.to_xyxy_array()\n","# convert array into a DataFrame ['xmin', 'ymin', 'xmax', 'ymax'] columns\n"," df_bbs = pd.DataFrame(bbs_array, columns=['xmin', 'ymin', 'xmax', 'ymax'])\n"," return df_bbs\n","\n","# Function that will extract column data for our CSV file\n","def xml_to_csv(path):\n"," xml_list = []\n"," for xml_file in glob.glob(path + '/*.xml'):\n"," tree = ET.parse(xml_file)\n"," root = tree.getroot()\n"," for member in root.findall('object'):\n"," value = (root.find('filename').text,\n"," int(root.find('size')[0].text),\n"," int(root.find('size')[1].text),\n"," member[0].text,\n"," int(member[4][0].text),\n"," int(member[4][1].text),\n"," int(member[4][2].text),\n"," int(member[4][3].text)\n"," )\n"," xml_list.append(value)\n"," column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']\n"," xml_df = pd.DataFrame(xml_list, 
columns=column_name)\n"," return xml_df"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4","colab_type":"text"},"source":["# **3. Select your paths and parameters**\n","\n","---\n","\n","The code below allows the user to enter the paths to where the training data is and to define the training parameters.\n","\n","After it is run, the cell will display some quantitative metrics of your dataset, including a count of objects per image and the number of instances per class.\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd","colab_type":"text"},"source":["# **3.1. Parameters and paths**\n","---\n","\n","**`Training_Source`, `Training_Source_annotations`:** These are the paths to the folders containing your training images and the annotation data, respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files, copy the path by right-clicking on the folder and selecting **Copy path**, then paste it into the corresponding box below.\n","\n","**`model_name`:** Use only my_model-style names, not my-model (use \"_\", not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`:** Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input the number of epochs to train the network for. **Default value: 27**\n","\n","**Note that YOLOv2 uses 3 warm-up epochs, which improve the model's performance. This means the network will train for number_of_epochs + 3 epochs.**\n","\n","**`backend`:** Several backends are available to be trained for YOLO. These are usually slightly different model architectures, with pretrained weights. 
Take a look at the available backends and research which one will be best suited for your dataset.\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`train_times:`** Input how many times to cycle through the dataset per epoch. This is more useful for smaller datasets (but risks overfitting). **Default value: 4**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 8**\n","\n","**`learning_rate:`** Input the initial value to be used as learning rate. **Default value: 0.0001**\n","\n","**`false_negative_penalty:`** Penalize wrong detection of 'no-object'. **Default: 5.0**\n","\n","**`false_positive_penalty:`** Penalize wrong detection of 'object'. **Default: 1.0**\n","\n","**`position_size_penalty:`** Penalize inaccurate positioning or size of bounding boxes. **Default: 1.0**\n","\n","**`false_class_penalty:`** Penalize misclassification of object in bounding box. **Default: 1.0**\n","\n","**`percentage_validation:`** Input the percentage of your training dataset you want to use to validate the network during training. 
**Default value: 10** "]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","colab_type":"code","cellView":"form","colab":{}},"source":["class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#@markdown ###Path to training images:\n","\n","Training_Source = \"\" #@param {type:\"string\"}\n","\n","# Ground truth images\n","Training_Source_annotations = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# backend\n","#@markdown ###Choose a backend\n","#os.chdir(model_path+'/keras-yolo2')\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/full_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/inception_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/mobilenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/squeezenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget 
https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n","\n","#os.chdir('/content/drive/My Drive/Zero-Cost Deep-Learning to Enhance Microscopy/Various dataset/Detection_Dataset_2/BCCD.v2.voc')\n","#if not os.path.exists(model_path+'/full_raccoon.h5'):\n"," # !wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1NWbrpMGLc84ow-4gXn2mloFocFGU595s' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id=1NWbrpMGLc84ow-4gXn2mloFocFGU595s\" -O full_yolo_raccoon.h5 && rm -rf /tmp/cookies.txt\n","\n","full_model_path = os.path.join(model_path,model_name)\n","if os.path.exists(full_model_path):\n"," print('Existing model path will be overwritten')\n"," shutil.rmtree(full_model_path)\n","os.mkdir(full_model_path)\n","\n","full_model_file_path = full_model_path+'/best_weights.h5'\n","os.chdir('/content/gdrive/My Drive/keras-yolo2/')\n","\n","#Change backend name\n","!sed -i 's@\\\"backend\\\":.*,@\\\"backend\\\": \\\"$backend\\\",@g' config.json\n","\n","#Change the name of the training folder\n","!sed -i 's@\\\"train_image_folder\\\":.*,@\\\"train_image_folder\\\": \\\"$Training_Source/\\\",@g' config.json\n","\n","#Change annotation folder\n","!sed -i 's@\\\"train_annot_folder\\\":.*,@\\\"train_annot_folder\\\": \\\"$Training_Source_annotations/\\\",@g' config.json\n","\n","#Change the name of the saved model\n","!sed -i 's@\\\"saved_weights_name\\\":.*,@\\\"saved_weights_name\\\": \\\"$full_model_file_path\\\",@g' config.json\n","\n","#Change warmup epochs for untrained model\n","!sed -i 's@\\\"warmup_epochs\\\":.*,@\\\"warmup_epochs\\\": 3,@g' config.json\n","\n","#When defining a new model we should reset the pretrained model parameter\n","!sed -i 's@\\\"pretrained_weights\\\":.*,@\\\"pretrained_weights\\\": \\\"No_pretrained_weights\\\",@g' 
config.json\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","\n","number_of_epochs = 10#@param {type:\"number\"}\n","!sed -i 's@\\\"nb_epochs\\\":.*,@\\\"nb_epochs\\\": $number_of_epochs,@g' config.json\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","train_times = 4 #@param {type:\"integer\"}\n","batch_size = 4#@param {type:\"number\"}\n","learning_rate = 1e-4 #@param{type:\"number\"}\n","false_negative_penalty = 5.0 #@param{type:\"number\"}\n","false_positive_penalty = 2.0 #@param{type:\"number\"}\n","position_size_penalty = 1.0 #@param{type:\"number\"}\n","false_class_penalty = 1.0 #@param{type:\"number\"}\n","percentage_validation = 10#@param{type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," train_times = 4\n"," batch_size = 8\n"," learning_rate = 1e-4\n"," false_negative_penalty = 5.0\n"," false_positive_penalty = 1.0\n"," position_size_penalty = 1.0\n"," false_class_penalty = 1.0\n"," percentage_validation = 10\n","\n","!sed -i 's@\\\"train_times\\\":.*,@\\\"train_times\\\": $train_times,@g' config.json\n","!sed -i 's@\\\"batch_size\\\":.*,@\\\"batch_size\\\": $batch_size,@g' config.json\n","!sed -i 's@\\\"learning_rate\\\":.*,@\\\"learning_rate\\\": $learning_rate,@g' config.json\n","!sed -i 's@\\\"object_scale\":.*,@\\\"object_scale\\\": $false_negative_penalty,@g' config.json\n","!sed -i 's@\\\"no_object_scale\":.*,@\\\"no_object_scale\\\": $false_positive_penalty,@g' config.json\n","!sed -i 's@\\\"coord_scale\\\":.*,@\\\"coord_scale\\\": $position_size_penalty,@g' config.json\n","!sed -i 's@\\\"class_scale\\\":.*,@\\\"class_scale\\\": $false_class_penalty,@g' config.json\n","\n","df_anno = []\n","dir_anno = Training_Source_annotations\n","for fnm in os.listdir(dir_anno): \n"," if not 
fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(dir_anno,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno.append(row)\n","df_anno = pd.DataFrame(df_anno)\n","\n","maxNobj = np.max(df_anno[\"Nobj\"])\n","\n","#Write the annotations to a csv file\n","df_anno.to_csv(model_path+'/annot.csv', index=False)#header=False, sep=',')\n","\n","file_suffix = os.path.splitext(os.listdir(Training_Source)[0])[1]\n","\n","#Show how many objects there are in the images\n","plt.figure()\n","plt.subplot(2,1,1)\n","plt.hist(df_anno[\"Nobj\"].values,bins=50)\n","plt.title(\"max N of objects per image={}\".format(maxNobj))\n","plt.show()\n","\n","#Show the classes and how many there are of each in the dataset\n","from collections import Counter\n","class_obj = []\n","for ibbx in range(maxNobj):\n"," class_obj.extend(df_anno[\"bbx_{}_name\".format(ibbx)].values)\n","class_obj = np.array(class_obj)\n","\n","count = Counter(class_obj[class_obj != 'nan'])\n","print(count)\n","class_nm = list(count.keys())\n","class_labels = json.dumps(class_nm)\n","class_count = list(count.values())\n","asort_class_count = np.argsort(class_count)\n","\n","class_nm = np.array(class_nm)[asort_class_count]\n","class_count = np.array(class_count)[asort_class_count]\n","\n","!sed -i 's@\\\"labels\\\":.*@\\\"labels\\\": $class_labels@g' config.json\n","xs = range(len(class_count))\n","\n","plt.subplot(2,1,2)\n","plt.barh(xs,class_count)\n","plt.yticks(xs,class_nm)\n","plt.title(\"The number of objects per class: {} objects in total\".format(len(count)))\n","plt.show()\n","\n","\n","#Generate anchors for the bounding boxes\n","import subprocess as sp\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","output = sp.getoutput('python ./gen_anchors.py -c ./config.json')\n","\n","anchors_1 = output.find(\"[\")\n","anchors_2 = output.find(\"]\")\n","\n","config_anchors = output[anchors_1:anchors_2+1]\n","!sed -i 
's@\\\"anchors\\\":.*,@\\\"anchors\\\": $config_anchors,@g' config.json\n","#here we check that no model with the same name already exist, if so delete\n","#if os.path.exists(model_path+'/'+model_name):\n"," # shutil.rmtree(model_path+'/'+model_name)\n","\n","Use_pretrained_model = False"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab_type":"code","cellView":"form","id":"NXxj-Xi3Kang","colab":{}},"source":["#@markdown ###Play this cell to visualise some example images from your dataset to make sure annotations and images are properly matched.\n","import imageio\n"," \n","size = 3 \n","ind_random = np.random.randint(0,df_anno.shape[0],size=size)\n","img_dir=Training_Source\n","\n","file_suffix = os.path.splitext(os.listdir(Training_Source)[0])[1]\n","for irow in ind_random:\n"," row = df_anno.iloc[irow,:]\n"," path = os.path.join(img_dir, row[\"fileID\"] + file_suffix)\n"," # read in image\n"," img = imageio.imread(path)\n","\n"," plt.figure(figsize=(12,12))\n"," plt.imshow(img) # plot image\n"," plt.title(\"Nobj={}, height={}, width={}\".format(row[\"Nobj\"],row[\"height\"],row[\"width\"]))\n"," # for each object in the image, plot the bounding box\n"," for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])\n"," plt.show() ## show the plot"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"eik5zLKWpN_O","colab_type":"text"},"source":["##**3.2. Data augmentation**\n","\n","---\n","\n"," Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. 
Augmentation is not necessary for training, and if the dataset is large the `Use_Data_augmentation` box can be unticked.\n","\n","Here, the images and bounding boxes are augmented by flipping and rotation. When doubling the dataset the images are only flipped. With each higher factor of augmentation the images added to the dataset represent one further rotation to the right by 90 degrees. 8x augmentation will give a dataset that is fully rotated and flipped once."]},{"cell_type":"code","metadata":{"id":"RmTSfMO-pNMc","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##**Augmentation Options**\n","\n","def image_aug(df, images_path, aug_images_path, image_prefix, augmentor):\n"," # create data frame which we're going to populate with augmented image info\n"," aug_bbs_xy = pd.DataFrame(columns=\n"," ['filename','width','height','class', 'xmin', 'ymin', 'xmax', 'ymax']\n"," )\n"," grouped = df.groupby('filename')\n"," \n"," for filename in df['filename'].unique():\n"," # get separate data frame grouped by file name\n"," group_df = grouped.get_group(filename)\n"," group_df = group_df.reset_index()\n"," group_df = group_df.drop(['index'], axis=1) \n"," # read the image\n"," image = imageio.imread(images_path+filename)\n"," # get bounding boxes coordinates and write into array \n"," bb_array = group_df.drop(['filename', 'width', 'height', 'class'], axis=1).values\n"," # pass the array of bounding boxes coordinates to the imgaug library\n"," bbs = BoundingBoxesOnImage.from_xyxy_array(bb_array, shape=image.shape)\n"," # apply augmentation on image and on the bounding boxes\n"," image_aug, bbs_aug = augmentor(image=image, bounding_boxes=bbs)\n"," # disregard bounding boxes which have fallen out of image pane \n"," bbs_aug = bbs_aug.remove_out_of_image()\n"," # clip bounding boxes which are partially outside of image pane\n"," bbs_aug = bbs_aug.clip_out_of_image()\n"," \n"," # don't perform any actions with the image if there are no bounding boxes left in it 
\n"," if re.findall('Image...', str(bbs_aug)) == ['Image([]']:\n"," pass\n"," \n"," # otherwise continue\n"," else:\n"," # write augmented image to a file\n"," imageio.imwrite(aug_images_path+image_prefix+filename, image_aug) \n"," # create a data frame with augmented values of image width and height\n"," info_df = group_df.drop(['xmin', 'ymin', 'xmax', 'ymax'], axis=1) \n"," for index, _ in info_df.iterrows():\n"," info_df.at[index, 'width'] = image_aug.shape[1]\n"," info_df.at[index, 'height'] = image_aug.shape[0]\n"," # rename filenames by adding the predifined prefix\n"," info_df['filename'] = info_df['filename'].apply(lambda x: image_prefix+x)\n"," # create a data frame with augmented bounding boxes coordinates using the function we created earlier\n"," bbs_df = bbs_obj_to_df(bbs_aug)\n"," # concat all new augmented info into new data frame\n"," aug_df = pd.concat([info_df, bbs_df], axis=1)\n"," # append rows to aug_bbs_xy data frame\n"," aug_bbs_xy = pd.concat([aug_bbs_xy, aug_df]) \n"," \n"," # return dataframe with updated images and bounding boxes annotations \n"," aug_bbs_xy = aug_bbs_xy.reset_index()\n"," aug_bbs_xy = aug_bbs_xy.drop(['index'], axis=1)\n"," return aug_bbs_xy\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","multiply_dataset_by = 3 #@param {type:\"slider\", min:2, max:8, step:1}\n","\n","rotation_range = 90\n","\n","if (Use_Data_augmentation):\n"," print('Data Augmentation enabled')\n"," # load images as NumPy arrays and append them to images list\n"," if os.path.exists(Training_Source+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Training_Source+'/.ipynb_checkpoints')\n"," \n"," images = []\n"," for index, file in enumerate(glob.glob(Training_Source+'/*'+file_suffix)):\n"," images.append(imageio.imread(file))\n"," \n"," # how many images we have\n"," print('Augmenting {} images'.format(len(images)))\n","\n"," # apply xml_to_csv() function to convert all XML files in images/ folder into labels.csv\n"," labels_df = 
xml_to_csv(Training_Source_annotations)\n"," labels_df.to_csv(('/content/original_labels.csv'), index=None)\n"," \n"," # Apply flip augmentation\n"," aug = iaa.OneOf([ \n"," iaa.Fliplr(1),\n"," iaa.Flipud(1)\n"," ])\n"," aug_2 = iaa.Affine(rotate=rotation_range, fit_output=True)\n"," aug_3 = iaa.Affine(rotate=rotation_range*2, fit_output=True)\n"," aug_4 = iaa.Affine(rotate=rotation_range*3, fit_output=True)\n","\n"," #Here we create a folder that will hold the original image dataset and the augmented image dataset\n"," augmented_training_source = os.path.dirname(Training_Source)+'/'+os.path.basename(Training_Source)+'_augmentation'\n"," if os.path.exists(augmented_training_source):\n"," shutil.rmtree(augmented_training_source)\n"," os.mkdir(augmented_training_source)\n","\n"," #Here we create a folder that will hold the original image annotation dataset and the augmented image annotation dataset (the bounding boxes).\n"," augmented_training_source_annotation = os.path.dirname(Training_Source_annotations)+'/'+os.path.basename(Training_Source_annotations)+'_augmentation'\n"," if os.path.exists(augmented_training_source_annotation):\n"," shutil.rmtree(augmented_training_source_annotation)\n"," os.mkdir(augmented_training_source_annotation)\n","\n"," #Create the augmentation\n"," augmented_images_df = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'flip_', aug)\n"," \n"," # Concat resized_images_df and augmented_images_df together and save in a new all_labels.csv file\n"," all_labels_df = pd.concat([labels_df, augmented_images_df])\n"," all_labels_df.to_csv('/content/combined_labels.csv', index=False)\n","\n"," #Here we convert the new bounding boxes for the augmented images to PASCAL VOC .xml format\n"," def convert_to_xml(df,source,target_folder):\n"," grouped = df.groupby('filename')\n"," for file in os.listdir(source):\n"," #if file in grouped.filename:\n"," group_df = grouped.get_group(file)\n"," group_df = group_df.reset_index()\n"," 
group_df = group_df.drop(['index'], axis=1)\n"," #group_df = group_df.dropna(axis=0)\n"," writer = Writer(source+'/'+file,group_df.iloc[1]['width'],group_df.iloc[1]['height'])\n"," for i, row in group_df.iterrows():\n"," writer.addObject(row['class'],round(row['xmin']),round(row['ymin']),round(row['xmax']),round(row['ymax']))\n"," writer.save(target_folder+'/'+os.path.splitext(file)[0]+'.xml')\n"," convert_to_xml(all_labels_df,augmented_training_source,augmented_training_source_annotation)\n"," \n"," #Second round of augmentation\n"," if multiply_dataset_by > 2:\n"," aug_labels_df_2 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_2_df = image_aug(aug_labels_df_2, augmented_training_source+'/', augmented_training_source+'/', 'rot1_90_', aug_2)\n"," all_aug_labels_df = pd.concat([augmented_images_df, augmented_images_2_df])\n"," #all_labels_df.to_csv('/content/all_labels_aug.csv', index=False)\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 3:\n"," print('Augmenting again')\n"," aug_labels_df_3 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_3_df = image_aug(aug_labels_df_3, augmented_training_source+'/', augmented_training_source+'/', 'rot2_90_', aug_2)\n"," all_aug_labels_df_3 = pd.concat([all_aug_labels_df, augmented_images_3_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_3,augmented_training_source,augmented_training_source_annotation)\n"," \n"," #This is a preliminary remover of potential duplicates in the augmentation\n"," #Ideally, duplicates are not even produced, but this acts as a fail safe.\n"," if multiply_dataset_by==4:\n"," for file in 
os.listdir(augmented_training_source):\n"," if file.startswith('rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n","\n"," if multiply_dataset_by > 4:\n"," print('And Again')\n"," aug_labels_df_4 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_4_df = image_aug(aug_labels_df_4, augmented_training_source+'/',augmented_training_source+'/','rot3_90_', aug_2)\n"," all_aug_labels_df_4 = pd.concat([all_aug_labels_df_3, augmented_images_4_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_4,augmented_training_source,augmented_training_source_annotation)\n","\n"," for file in os.listdir(augmented_training_source):\n"," if file.startswith('rot3_90_rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if file.startswith('rot3_90_rot1_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if file.startswith('rot3_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if file.startswith('rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n","\n","\n"," if multiply_dataset_by > 5:\n"," print('And again')\n"," augmented_images_5_df = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_90_', aug_2)\n"," all_aug_labels_df_5 = 
pd.concat([all_aug_labels_df_4,augmented_images_5_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," \n"," convert_to_xml(all_aug_labels_df_5,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 6:\n"," print('And again')\n"," augmented_images_df_6 = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_180_', aug_3)\n"," all_aug_labels_df_6 = pd.concat([all_aug_labels_df_5,augmented_images_df_6])\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_6,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 7:\n"," print('And again')\n"," augmented_images_df_7 = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_270_', aug_4)\n"," all_aug_labels_df_7 = pd.concat([all_aug_labels_df_6,augmented_images_df_7])\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_7,augmented_training_source,augmented_training_source_annotation)\n","\n"," for file in os.listdir(Training_Source):\n"," shutil.copyfile(Training_Source+'/'+file,augmented_training_source+'/'+file)\n"," shutil.copyfile(Training_Source_annotations+'/'+os.path.splitext(file)[0]+'.xml',augmented_training_source_annotation+'/'+os.path.splitext(file)[0]+'.xml')\n"," # display new dataframe\n"," #augmented_images_df\n"," \n"," os.chdir('/content/gdrive/My Drive/keras-yolo2')\n"," #Change the name of the training folder\n"," !sed -i 's@\\\"train_image_folder\\\":.*,@\\\"train_image_folder\\\": \\\"$augmented_training_source/\\\",@g' config.json\n","\n"," #Change annotation folder\n"," !sed -i 
's@\\\"train_annot_folder\\\":.*,@\\\"train_annot_folder\\\": \\\"$augmented_training_source_annotation/\\\",@g' config.json\n","\n"," df_anno = []\n"," dir_anno = augmented_training_source_annotation\n"," for fnm in os.listdir(dir_anno): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(dir_anno,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno.append(row)\n"," df_anno = pd.DataFrame(df_anno)\n","\n"," maxNobj = np.max(df_anno[\"Nobj\"])\n","\n"," #Write the annotations to a csv file\n"," #df_anno.to_csv(model_path+'/annot.csv', index=False)#header=False, sep=',')\n","\n"," #Show how many objects there are in the images\n"," plt.figure()\n"," plt.subplot(2,1,1)\n"," plt.hist(df_anno[\"Nobj\"].values,bins=50)\n"," plt.title(\"max N of objects per image={}\".format(maxNobj))\n"," plt.show()\n","\n"," #Show the classes and how many there are of each in the dataset\n"," from collections import Counter\n"," class_obj = []\n"," for ibbx in range(maxNobj):\n"," class_obj.extend(df_anno[\"bbx_{}_name\".format(ibbx)].values)\n"," class_obj = np.array(class_obj)\n","\n"," count = Counter(class_obj[class_obj != 'nan'])\n"," print(count)\n"," class_nm = list(count.keys())\n"," class_labels = json.dumps(class_nm)\n"," class_count = list(count.values())\n"," asort_class_count = np.argsort(class_count)\n","\n"," class_nm = np.array(class_nm)[asort_class_count]\n"," class_count = np.array(class_count)[asort_class_count]\n","\n"," xs = range(len(class_count))\n","\n"," plt.subplot(2,1,2)\n"," plt.barh(xs,class_count)\n"," plt.yticks(xs,class_nm)\n"," plt.title(\"The number of objects per class: {} objects in total\".format(len(count)))\n"," plt.show()\n","\n","else:\n"," print('No augmentation will be used')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"tZvcYmxTdXQm","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown 
###Play this cell to visualise some example images from your **augmented** dataset to make sure annotations and images are properly matched.\n","if (Use_Data_augmentation):\n"," df_anno_aug = []\n"," dir_anno_aug = augmented_training_source_annotation\n"," for fnm in os.listdir(dir_anno_aug): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(dir_anno_aug,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno_aug.append(row)\n"," df_anno_aug = pd.DataFrame(df_anno_aug)\n","\n"," size = 3 \n"," ind_random = np.random.randint(0,df_anno_aug.shape[0],size=size)\n"," img_dir=augmented_training_source\n","\n"," file_suffix = os.path.splitext(os.listdir(augmented_training_source)[0])[1]\n"," for irow in ind_random:\n"," row = df_anno_aug.iloc[irow,:]\n"," path = os.path.join(img_dir, row[\"fileID\"] + file_suffix)\n"," # read in image\n"," img = imageio.imread(path)\n","\n"," plt.figure(figsize=(12,12))\n"," plt.imshow(img) # plot image\n"," plt.title(\"Nobj={}, height={}, width={}\".format(row[\"Nobj\"],row[\"height\"],row[\"width\"]))\n"," # for each object in the image, plot the bounding box\n"," for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])\n"," plt.show() ## show the plot\n"," print('These are the augmented training images.')\n","\n","else:\n"," for irow in ind_random:\n"," row = df_anno.iloc[irow,:]\n"," path = os.path.join(img_dir, row[\"fileID\"] + file_suffix)\n"," # read in image\n"," img = imageio.imread(path)\n","\n"," plt.figure(figsize=(12,12))\n"," plt.imshow(img) # plot image\n"," plt.title(\"Nobj={}, height={}, width={}\".format(row[\"Nobj\"],row[\"height\"],row[\"width\"]))\n"," # for each object in the image, 
plot the bounding box\n"," for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])\n"," plt.show() ## show the plot\n"," print('These are the non-augmented training images.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"code","metadata":{"id":"_cvRRrStGe3y","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pretrained network\n","\n","# Training_Source = \"\" #@param{type:\"string\"}\n","# Training_Source_annotation = \"\" #@param{type:\"string\"}\n","# Check if the right files exist\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","pretrained_model_path = \"\" #@param{type:\"string\"}\n","h5_file_path = pretrained_model_path+'/'+Weights_choice+'_weights.h5'\n","\n","if not os.path.exists(h5_file_path):\n"," print('WARNING pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","!sed -i 's@\\\"pretrained_weights\\\":.*,@\\\"pretrained_weights\\\": \\\"$h5_file_path\\\",@g' config.json\n","\n","if Use_pretrained_model == True:\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4):\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the 
learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," learning_rate = bestLearningRate\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then the learning rate set in section 3 is used instead\n"," #bestLearningRate = learning_rate\n"," #lastLearningRate = learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(learning_rate)+' will be used instead' + W)\n"," \n"," !sed -i 's@\\\"warmup_epochs\\\":.*,@\\\"warmup_epochs\\\": 0,@g' config.json\n"," !sed -i 's@\\\"learning_rate\\\":.*,@\\\"learning_rate\\\": $learning_rate,@g' config.json\n","\n","# with open(os.path.join(pretrained_model_path, 'Quality Control', 'lr.csv'),'r') as csvfile:\n","# csvRead = pd.read_csv(csvfile, sep=',')\n","# #print(csvRead)\n"," \n","# if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exists (compatibility with models trained in ZeroCostDL4Mic below 1.4)\n","# print(\"pretrained network learning rate found\")\n","# #find the last learning rate\n","# lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n","# #Find the learning rate corresponding to the lowest validation loss\n","# min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n","# #print(min_val_loss)\n","# bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n","# if Weights_choice == \"last\":\n","# print('Last learning rate: '+str(lastLearningRate))\n","\n","# if Weights_choice == \"best\":\n","# print('Learning rate of 
best validation loss: '+str(bestLearningRate))\n","\n","# if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n","# bestLearningRate = initial_learning_rate\n","# lastLearningRate = initial_learning_rate\n","# print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR","colab_type":"text"},"source":["## **4.1. Train the network**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches."]},{"cell_type":"code","metadata":{"id":"EZnoS3rb8BSR","colab_type":"code","cellView":"form","colab":{}},"source":["import time\n","import csv\n","#from frontend import YOLO\n","\n","if os.path.exists(full_model_path+\"/Quality Control\"):\n"," shutil.rmtree(full_model_path+\"/Quality Control\")\n","os.makedirs(full_model_path+\"/Quality Control\")\n","\n","start = time.time()\n","\n","#@markdown ##Start Training\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","train('config.json', full_model_path, percentage_validation)\n","\n","shutil.copyfile('/content/gdrive/My Drive/keras-yolo2/config.json',full_model_path+'/config.json')\n","\n","if os.path.exists('/content/gdrive/My Drive/keras-yolo2/best_map_weights.h5'):\n"," shutil.move('/content/gdrive/My Drive/keras-yolo2/best_map_weights.h5',full_model_path+'/best_map_weights.h5')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XQjQb_J_Qyku","colab_type":"text"},"source":["##**4.3. 
Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"2HbZd7rFqAad","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"EdcnkCr9Nbl8","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","if (Use_the_current_trained_model): \n"," QC_model_folder = full_model_path\n","\n","#print(os.path.join(model_path, model_name))\n","\n","if os.path.exists(QC_model_folder):\n"," print(\"The \"+os.path.basename(QC_model_folder)+\" model will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n","\n","if Use_the_current_trained_model == False:\n"," if os.path.exists('/content/gdrive/My Drive/keras-yolo2/config.json'):\n"," os.remove('/content/gdrive/My Drive/keras-yolo2/config.json')\n"," shutil.copyfile(QC_model_folder+'/config.json','/content/gdrive/My Drive/keras-yolo2/config.json')\n","\n","#@markdown ###Which backend is the model using?\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/full_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/inception_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/mobilenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/squeezenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh","colab_type":"text"},"source":["## **5.1. 
Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words, the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","mAPDataFromCSV = []\n","with open(QC_model_folder+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n"," mAPDataFromCSV.append(float(row[2]))\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(20,15))\n","\n","plt.subplot(3,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(3,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","#plt.savefig(os.path.dirname(QC_model_folder)+'/Quality Control/lossCurvePlots.png')\n","#plt.show()\n","\n","plt.subplot(3,1,3)\n","plt.plot(epochNumber,mAPDataFromCSV, label='mAP score')\n","plt.title('mean average precision (mAP) vs. epoch number (linear scale)')\n","plt.ylabel('mAP score')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png')\n","plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"RZOPCVN0qcYb","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display an overlay of the input images ground-truth (solid lines) and predicted boxes (dashed lines). 
Additionally, the cell below will show the mAP value of the model on the QC data together with plots of the Precision-Recall curves for all the classes in the dataset. If you want to read in more detail about these scores, we recommend [this brief explanation](https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173).\n","\n"," The folders provided as \"Source_QC_folder\" and \"Annotations_QC_folder\" should contain images (e.g. as .jpg) and annotations (.xml files), respectively!\n","\n","Since the training saves three different models, for the best validation loss (`best_weights`), best average precision (`best_map_weights`) and the model after the last epoch (`last_weights`), you should choose which one you want to use for quality control or prediction. We recommend using `best_map_weights` because they should yield the best performance on the dataset. However, it can be worth testing how well `best_weights` perform too.\n","\n","**mAP score:** This refers to the mean average precision of the model on the given dataset. This value gives an indication of how precise the predictions of the classes on this dataset are when compared to the ground-truth. 
Values closer to 1 indicate a good fit.\n","\n","**Precision:** This is the proportion of the correct classifications (true positives) in all the predictions made by the model.\n","\n","**Recall:** This is the proportion of the detected true positives in all the detectable data."]},{"cell_type":"code","metadata":{"id":"Nh8MlX3sqd_7","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Annotations_QC_folder = \"\" #@param{type:\"string\"}\n","\n","#@markdown ##Choose which model you want to evaluate:\n","model_choice = \"best_map_weights\" #@param[\"best_weights\",\"last_weights\",\"best_map_weights\"]\n","\n","file_suffix = os.path.splitext(os.listdir(Source_QC_folder)[0])[1]\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_folder+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_folder+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_folder+\"/Quality Control/Prediction\")\n","\n","#Delete old csv with box predictions if one exists\n","\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," os.remove('/content/predicted_bounding_boxes.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names.csv')\n","if os.path.exists(Source_QC_folder+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Source_QC_folder+'/.ipynb_checkpoints')\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","n_objects = []\n","for img in os.listdir(Source_QC_folder):\n"," full_image_path = Source_QC_folder+'/'+img\n"," n_obj = predict('config.json',QC_model_folder+'/'+model_choice+'.h5',full_image_path)\n"," n_objects.append(n_obj)\n","\n","for img in os.listdir(Source_QC_folder):\n"," if img.endswith('detected'+file_suffix):\n"," shutil.move(Source_QC_folder+'/'+img,QC_model_folder+\"/Quality 
Control/Prediction/\"+img)\n","\n","### Get the coordinates of the predicted boxes, ###\n","### box classes and confidence scores ###\n","\n","# from the csv containing the predicted boxes\n","with open('/content/predicted_bounding_boxes.csv','r', newline='') as csvfile:\n"," csv_reader = csv.reader(csvfile)\n"," next(csv_reader)\n"," pred_boxes = []\n"," pred_classes = []\n"," pred_conf = []\n"," for row in csv_reader:\n"," image_boxes = []\n"," box_classes = []\n"," box_conf = []\n"," for i in range(1,len(row),6):\n"," image_boxes.append(list(map(float,row[i:i+4])))\n"," box_classes.append(int(row[i+5]))\n"," box_conf.append(float(row[i+4]))\n"," pred_boxes.append(image_boxes) # The rows of this list contain the coordinates for all boxes per image\n"," pred_classes.append(box_classes) # The rows of this list contain the predicted classes for each box in the pred_boxes\n"," pred_conf.append(box_conf) # The rows of this list contain the confidence scores for each predicted box in pred_boxes\n","\n","#shutil.move('/content/predicted_bounding_boxes.csv',QC_model_folder+\"/Quality Control/Prediction/predicted_boxes_QC.csv\")\n","\n","#### Get the coordinates of the GT boxes ###\n","\n","df_anno_QC_gt = []\n","#dir_anno = Training_Source_annotations\n","for fnm in os.listdir(Annotations_QC_folder): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(Annotations_QC_folder,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno_QC_gt.append(row)\n","df_anno_QC_gt = pd.DataFrame(df_anno_QC_gt)\n","\n","#df_anno_QC_gt.to_csv('/content/gt_bboxes_QC.csv')\n","maxNobj = np.max(df_anno_QC_gt[\"Nobj\"])\n","\n","config_path = '/content/gdrive/My Drive/keras-yolo2/config.json'\n","class_dict = {}\n","\n","with open(config_path) as config_buffer:\n"," config = json.load(config_buffer)\n"," for i in config[\"model\"][\"labels\"]:\n"," class_dict[i] = 
int(config[\"model\"][\"labels\"].index(i))\n","\n","reverse_class_dict = {value : key for (key, value) in class_dict.items()}\n","\n","df_anno_QC_gt = df_anno_QC_gt.replace(class_dict)\n","df_anno_QC_gt.to_csv(QC_model_folder+'/Quality Control/gt_bboxes_QC.csv')\n","\n","gt_boxes = []\n","gt_labels = []\n","gt_label_names = []\n","for j in range(0,df_anno_QC_gt.shape[0]):\n"," row = df_anno_QC_gt.iloc[j]\n"," width = int(row[\"width\"])\n"," height = int(row[\"height\"])\n"," gt_box = []\n"," gt_label = []\n"," gt_label_name = []\n"," for i in range(row[\"Nobj\"]):\n"," label = int(float(row[\"bbx_{}_name\".format(i)]))\n"," label_name = row[\"bbx_{}_name\".format(i)]\n"," x1=row[\"bbx_{}_xmin\".format(i)]\n"," y1=row[\"bbx_{}_ymin\".format(i)]\n"," x2=row[\"bbx_{}_xmax\".format(i)]\n"," y2=row[\"bbx_{}_ymax\".format(i)]\n"," #gt_box.append([x1/width,y1/height,x2/width,y2/height])\n"," gt_box.append([x1,y1,x2,y2])\n","\n"," gt_label.append(label)\n"," gt_label_name.append(label_name)\n"," gt_boxes.append(gt_box)\n"," gt_labels.append(gt_label)\n"," gt_label_names.append(gt_label_name)\n","\n","#The essential outputs from this are gt_array and gt_classes_full\n","#Each row contains all bounding boxes and classes for each gt image.\n","\n","#Here we create the Detection Maps for the first three predictions\n","#Prediction\n","\n","pred_box_1 = np.array(pred_boxes[0])\n","#pred_box_2 = np.array(pred_boxes[1])\n","#pred_box_3 = np.array(pred_boxes[2])\n","\n","pred_class_1 = np.array(pred_classes[0])\n","#pred_class_2 = np.array(pred_classes[1])\n","#pred_class_3 = np.array(pred_classes[2])\n","\n","pred_conf_1 = np.array(pred_conf[0])\n","#pred_conf_2 = np.array(pred_conf[1])\n","#pred_conf_3 = np.array(pred_conf[2])\n"," \n","#print(pred_box_1)\n","\n","#print(pred_conf_1)\n","\n","# #GT\n","#print(gt_box_1[0])\n","gt_box_1 = np.array(gt_boxes[0])\n","#gt_box_2 = np.array(gt_boxes[1])\n","#gt_box_3 = np.array(gt_boxes[2])\n","#print(gt_box_1)\n","\n","gt_class_1 = 
np.array(gt_labels[0])\n","#gt_class_2 = np.array(gt_labels[1])\n","#gt_class_3 = np.array(gt_labels[2])\n","\n","frames = [(pred_box_1, pred_class_1, pred_conf_1, gt_box_1, gt_class_1)]\n"," #(pred_box_2, pred_class_2, pred_conf_2, gt_box_3, gt_class_3),#]#,\n"," #(pred_box_3, pred_class_3, pred_conf_3, gt_box_1, gt_class_1)#]#,\n"," #]\n"," #]\n","\n","n_class = len(config['model']['labels'])\n","\n","plt.figure(figsize=(15,5))\n","for i, frame in enumerate(frames):\n"," img = np.array(io.imread(os.path.join(Source_QC_folder,os.path.splitext(sorted(os.listdir(Annotations_QC_folder))[i])[0]+file_suffix)))\n"," show_frame(*frame, reverse_class_dict, background = img)\n","\n","\n","#Make a csv file to read into imagej macro, to create custom bounding boxes\n","header = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n","with open('/content/predicted_bounding_boxes.csv', newline='') as inFile, open('/content/predicted_bounding_boxes_new.csv', 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n","\n","df_bbox=pd.read_csv('/content/predicted_bounding_boxes_new.csv')\n","df_bbox=df_bbox.transpose()\n","new_header = df_bbox.iloc[0] #grab the first row for the header\n","df_bbox = df_bbox[1:] #take the data less the header row\n","df_bbox.columns = new_header #set the header row as the df header\n","df_bbox.sort_values(by='filename',axis=1,inplace=True)\n","df_bbox.to_csv(QC_model_folder+'/Quality Control/predicted_bounding_boxes_for_custom_ROI_QC.csv')\n","\n","AP, recall, precision = _calc_avg_precisions(config,Source_QC_folder,Annotations_QC_folder+'/',QC_model_folder+'/'+model_choice+'.h5',0.3,0.3)\n","\n","print('mAP score for QC dataset: '+str(sum(AP.values())/len(AP)))\n","for i in range(len(AP)):\n"," if 
AP[i]!=0:\n"," if len(recall[i]) == 1:\n"," new_recall = np.linspace(0,list(recall[i])[0],10)\n"," new_precision = list(precision[i])*10\n"," fig = plt.figure(figsize=(3,2))\n"," plt.plot(new_recall,new_precision)\n"," plt.axis([min(new_recall),1,0,1.02])\n"," plt.xlabel('Recall',fontsize=14)\n"," plt.ylabel('Precision',fontsize=14)\n"," plt.title(config['model']['labels'][i]+', AP: '+str(round(AP[i],3)),fontsize=14)\n"," plt.fill_between(new_recall,new_precision,alpha=0.3)\n"," plt.savefig('/content/P-R_curve_'+str(i)+'.png')\n"," plt.show()\n"," else:\n"," new_recall = list(recall[i])\n"," new_recall.append(new_recall[len(new_recall)-1])\n"," new_precision = list(precision[i])\n"," new_precision.append(0)\n"," fig = plt.figure(figsize=(3,2))\n"," plt.plot(new_recall,new_precision)\n"," plt.axis([min(new_recall),1,0,1.02])\n"," plt.xlabel('Recall',fontsize=14)\n"," plt.ylabel('Precision',fontsize=14)\n"," plt.title(config['model']['labels'][i]+', AP: '+str(round(AP[i],3)),fontsize=14)\n"," plt.fill_between(new_recall,new_precision,alpha=0.3)\n"," plt.savefig('/content/P-R_curve_'+str(i)+'.png')\n"," plt.show()\n"," else:\n"," print('No object of class '+config['model']['labels'][i]+' was detected. This will lower the mAP score. 
Consider adding an image containing this class to your QC dataset to see if the model can detect this class at all.')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"-n9CLLJ77FAA","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Inspect example output from QC\n","import random\n","from matplotlib.pyplot import imread\n","import imageio\n","\n","add_header('/content/predicted_bounding_boxes_names.csv','/content/predicted_bounding_boxes_names_new.csv')\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","file_suffix = os.path.splitext(random_choice)[1]\n","\n","plt.figure(figsize=(30,15))\n","\n","\n","### Display Raw input ###\n","\n","x = imread(Source_QC_folder+\"/\"+random_choice)\n","plt.subplot(1,3,1)\n","plt.axis('off')\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","\n","### Display Predicted annotation ###\n","\n","df_bbox2 = pd.read_csv('/content/predicted_bounding_boxes_names_new.csv')\n","for img in range(0,df_bbox2.shape[0]):\n"," df_bbox2.iloc[img]\n"," row = pd.DataFrame(df_bbox2.iloc[img])\n"," if row[img][0] == random_choice:\n"," row = row.dropna()\n"," image = imageio.imread(Source_QC_folder+'/'+row[img][0])\n"," #plt.figure(figsize=(12,12))\n"," plt.subplot(1,3,2)\n"," plt.axis('off')\n"," plt.imshow(image) # plot image\n"," plt.title('Prediction')\n"," for i in range(1,int(len(row)-1),6):\n"," plt_rectangle(plt,\n"," label = row[img][i+5],\n"," x1=row[img][i],#.format(iplot)],\n"," y1=row[img][i+1],\n"," x2=row[img][i+2],\n"," y2=row[img][i+3])#,\n"," #fontsize=8)\n","\n","\n","### Display GT Annotation ###\n","\n","df_anno_QC_gt = []\n","for fnm in os.listdir(Annotations_QC_folder): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(Annotations_QC_folder,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] 
= os.path.splitext(fnm)[0]\n"," df_anno_QC_gt.append(row)\n","df_anno_QC_gt = pd.DataFrame(df_anno_QC_gt)\n","#maxNobj = np.max(df_anno_QC_gt[\"Nobj\"])\n","\n","for i in range(0,df_anno_QC_gt.shape[0]):\n"," if df_anno_QC_gt.iloc[i][\"fileID\"]+file_suffix == random_choice:\n"," row = df_anno_QC_gt.iloc[i]\n","\n","img = imageio.imread(Source_QC_folder+'/'+random_choice)\n","plt.subplot(1,3,3)\n","plt.axis('off')\n","plt.imshow(img) # plot image\n","plt.title('Ground Truth annotations')\n","\n","# for each object in the image, plot the bounding box\n","for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])#,\n"," #fontsize=8)\n","\n","### Show the plot ###\n","plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Esqnbew8uznk"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`Prediction_model_path`:** This should be the folder that contains your model."]},{"cell_type":"code","metadata":{"id":"9ZmST3JRq-Ho","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","file_suffix = os.path.splitext(os.listdir(Data_folder)[0])[1]\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, provide the name of the model and path to model folder:\n","\n","Prediction_model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Which model do you want to use?\n","model_choice = \"best_map_weights\" #@param[\"best_weights\",\"last_weights\",\"best_map_weights\"]\n","\n","#@markdown ###Which backend is the model using?\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/full_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/inception_backend.h5'):\n"," !wget 
https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/mobilenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/squeezenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_path = full_model_path\n","\n","if Use_the_current_trained_model == False:\n"," if os.path.exists('/content/gdrive/My Drive/keras-yolo2/config.json'):\n"," os.remove('/content/gdrive/My Drive/keras-yolo2/config.json')\n"," shutil.copyfile(Prediction_model_path+'/config.json','/content/gdrive/My Drive/keras-yolo2/config.json')\n","\n","if os.path.exists(Prediction_model_path+'/'+model_choice+'.h5'):\n"," print(\"The \"+os.path.basename(Prediction_model_path)+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","# Provide the code for performing predictions and saving them\n","print(\"Images saved into folder:\", Result_folder)\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"GcmBwMJVcFh1","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run Prediction\n","\n","#Remove any files that might be from the prediction of QC examples.\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," os.remove('/content/predicted_bounding_boxes.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_new.csv'):\n"," os.remove('/content/predicted_bounding_boxes_new.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names_new.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names_new.csv')\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","\n","if os.path.exists(Data_folder+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Data_folder+'/.ipynb_checkpoints')\n","\n","n_objects = []\n","for img in os.listdir(Data_folder):\n"," full_image_path = Data_folder+'/'+img\n"," n_obj = predict('config.json',Prediction_model_path+'/'+model_choice+'.h5',full_image_path)#,Result_folder)\n"," n_objects.append(n_obj)\n","\n","for img in os.listdir(Data_folder):\n"," if img.endswith('detected'+file_suffix):\n"," shutil.move(Data_folder+'/'+img,Result_folder+'/'+img)\n","\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," #shutil.move('/content/predicted_bounding_boxes.csv',Result_folder+'/predicted_bounding_boxes.csv')\n"," print('Bounding box labels and coordinates saved to '+ Result_folder)\n","else:\n"," print('For some reason the bounding box labels and coordinates were not saved. 
Check that your predictions look as expected.')\n","\n","#Make a csv file to read into imagej macro, to create custom bounding boxes\n","header = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n","with open('/content/predicted_bounding_boxes.csv', newline='') as inFile, open('/content/predicted_bounding_boxes_new.csv', 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n","\n","df_bbox=pd.read_csv('/content/predicted_bounding_boxes_new.csv')\n","df_bbox=df_bbox.transpose()\n","new_header = df_bbox.iloc[0] #grab the first row for the header\n","df_bbox = df_bbox[1:] #take the data less the header row\n","df_bbox.columns = new_header #set the header row as the df header\n","df_bbox.sort_values(by='filename',axis=1,inplace=True)\n","df_bbox.to_csv(Result_folder+'/predicted_bounding_boxes_for_custom_ROI.csv')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EIe3CRD7XUxa","colab_type":"text"},"source":["## **6.2. 
Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"LmDP8xiwXTTL","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","import random\n","from matplotlib.pyplot import imread\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","print(random_choice)\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+os.path.splitext(random_choice)[0]+'_detected'+file_suffix)\n","\n","plt.figure(figsize=(20,8))\n","\n","plt.subplot(1,3,1)\n","plt.axis('off')\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","\n","plt.subplot(1,3,2)\n","plt.axis('off')\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Predicted output');\n","\n","add_header('/content/predicted_bounding_boxes_names.csv','/content/predicted_bounding_boxes_names_new.csv')\n","\n","#We need to edit this predicted_bounding_boxes_new.csv file slightly to display the bounding boxes\n","df_bbox2 = pd.read_csv('/content/predicted_bounding_boxes_names_new.csv')\n","for img in range(0,df_bbox2.shape[0]):\n"," df_bbox2.iloc[img]\n"," row = pd.DataFrame(df_bbox2.iloc[img])\n"," if row[img][0] == random_choice:\n"," row = row.dropna()\n"," image = imageio.imread(Data_folder+'/'+row[img][0])\n"," #plt.figure(figsize=(12,12))\n"," plt.subplot(1,3,3)\n"," plt.axis('off')\n"," plt.title('Alternative Display of Prediction')\n"," plt.imshow(image) # plot image\n","\n"," for i in range(1,int(len(row)-1),6):\n"," plt_rectangle(plt,\n"," label = row[img][i+5],\n"," x1=row[img][i],#.format(iplot)],\n"," y1=row[img][i+1],\n"," x2=row[img][i+2],\n"," y2=row[img][i+3])#,\n"," #fontsize=8)\n"," #plt.margins(0,0)\n"," #plt.subplots_adjust(left=0., right=1., top=1., bottom=0.)\n"," 
#plt.gca().xaxis.set_major_locator(plt.NullLocator())\n"," #plt.gca().yaxis.set_major_locator(plt.NullLocator())\n"," plt.savefig('/content/detected_cells.png',bbox_inches='tight',transparent=True,pad_inches=0)\n","plt.show() ## show the plot\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"Rn9zpWpo0xNw","colab_type":"text"},"source":["\n","#**Thank you for using YOLOv2!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"YOLOv2_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1LWs9bFbYclR1nWaupcSPUYFN6yyUU_5t","timestamp":1596536407170},{"file_id":"1uUjR8Sm2l6vAJfclb84gUUH4MCwzQUWO","timestamp":1594734310956},{"file_id":"1zileODcR2RNrVSidXNuBfgFDv68JRRa0","timestamp":1593093410185},{"file_id":"1EpgWlJK6U_ZwlBGiomLfbxx9UUtRPBTy","timestamp":1592904104821},{"file_id":"1f5usS6p8Cu_efegMwcR3v68AVOXBSyIf","timestamp":1588870626184},{"file_id":"1fM7obTEQKnSgVZMDa1KjiBgiBar2b0t8","timestamp":1588693012611},{"file_id":"1owWtQQucUxUOZMaPh2x_mxe_qXKHCZhp","timestamp":1588074588514},{"file_id":"159ARwlQE7-zi0EHxunOF_YPFLt-ZVU5x","timestamp":1587562499898},{"file_id":"1W-7NHehG5MRFILvZZzhPWWnOdJMkadb2","timestamp":1586332290412},{"file_id":"1pUetEQICxYWkYVaQIgdRH1EZBTl7oc2A","timestamp":1586292199692},{"file_id":"1MD36ZkM6XR9EuV12zimJmfCjzyeYZFWq","timestamp":1586269469061},{"file_id":"16A2mbaHzlEElntS8qkFBOsBvZG-mUeY6","timestamp":1586253795726},{"file_id":"1gJlcjOiSxr2buDOxmcFbT_d-GqwLjXtK","timestamp":1583343225796},{"file_id":"10yGI51WzHfgWgZAyE-Ebk
ZFEvIOd6CP6","timestamp":1583171396283}],"collapsed_sections":[],"toc_visible":true},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I","colab_type":"text"},"source":["# **YOLOv2**\n","---\n","\n"," YOLOv2 is a deep-learning method designed to perform object detection and classification of objects in images, published by [Redmon and Farhadi](https://ieeexplore.ieee.org/document/8100173). This is based on the original [YOLO](https://arxiv.org/abs/1506.02640) implementation published by the same authors. YOLOv2 is trained on images with class annotations in the form of bounding boxes drawn around the objects of interest. The images are downsampled by a convolutional neural network (CNN) and objects are classified in two final fully connected layers in the network. YOLOv2 learns classification and object detection simultaneously by taking the whole input image into account, predicting many possible bounding box solutions, and then using regression to find the best bounding boxes and classifications for each object.\n","\n","**This particular notebook enables object detection and classification on 2D images given ground truth bounding boxes. If you are interested in image segmentation, you should use our U-net or Stardist notebooks instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following papers: \n","\n","**YOLO9000: Better, Faster, Stronger** from Joseph Redmon and Ali Farhadi in Proceedings of the IEEE conference on computer vision and pattern recognition, 2017, (https://ieeexplore.ieee.org/document/8100173)\n","\n","**You Only Look Once: Unified, Real-Time Object Detection** from Joseph Redmon, Santosh Divvala, Ross Girshick, Ali Farhadi in IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, (https://ieeexplore.ieee.org/document/7780460)\n","\n","**Note: The source code for this notebook is adapted for keras and can be found in: (https://github.com/experiencor/keras-yolo2)**\n","\n","\n","**Please also cite these original papers when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use ZeroCostDL4Mic notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. 
Once execution is done, the animation of the play button stops. You can create a new code cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you will find three tabs which contain, from top to bottom:\n","\n","*Table of contents* = contains the structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," Preparing the dataset carefully is essential to make this YOLOv2 notebook work. This model requires as input a set of images (currently .jpg) and as target a list of annotation files in Pascal VOC format. The annotation files should have the exact same name as the input files, except with an .xml instead of the .jpg extension. 
The annotation files contain the class labels and all bounding boxes for the objects for each image in your dataset. Most datasets give the option of saving the annotations in this format, and software for hand-annotation can save the annotations in this format as well. \n","\n"," If you want to assemble your own dataset we recommend using the open source https://www.makesense.ai/ resource. You can follow our instructions on how to label your dataset with this tool on our [wiki](https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki/Object-Detection-(YOLOv2)).\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and annotation files need to have the same name (apart from the extension)**.\n","\n"," Please note that you currently can **only use .png or .jpg files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Input images (Training_source)\n"," - img_1.png, img_2.png, ...\n"," - Annotation files (Training_source_annotations)\n"," - img_1.xml, img_2.xml, ...\n"," - **Quality control dataset**\n"," - Input images\n"," - img_1.png, img_2.png\n"," - Annotation files\n"," - img_1.xml, img_2.xml\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the 
quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"BCPhV-pe-syw","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"VNZetvLiS1qV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UBrnApIUBgxv","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. 
In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of the notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in to your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on the \"Files\" tab on the left. Refresh the tab. Your Google Drive folder should now be available here as \"gdrive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **2. 
Install YOLOv2 and Dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install Network and Dependencies\n","%tensorflow_version 1.x\n","!pip install pascal-voc-writer\n","from pascal_voc_writer import Writer\n","from __future__ import division\n","from __future__ import print_function\n","from __future__ import absolute_import\n","import csv\n","import random\n","import pprint\n","import sys\n","import time\n","import numpy as np\n","from optparse import OptionParser\n","import pickle\n","import math\n","import cv2\n","import copy\n","import math\n","from matplotlib import pyplot as plt\n","import matplotlib.patches as patches\n","import tensorflow as tf\n","import pandas as pd\n","import os\n","import shutil\n","from skimage import io\n","from sklearn.metrics import average_precision_score\n","\n","from keras.models import Model\n","from keras.layers import Flatten, Dense, Input, Conv2D, MaxPooling2D, Dropout, Reshape, Activation, Conv2D, MaxPooling2D, BatchNormalization, Lambda\n","from keras.layers.advanced_activations import LeakyReLU\n","from keras.layers.merge import concatenate\n","from keras.applications.mobilenet import MobileNet\n","from keras.applications import InceptionV3\n","from keras.applications.vgg16 import VGG16\n","from keras.applications.resnet50 import ResNet50\n","\n","from keras import backend as K\n","from keras.optimizers import Adam, SGD, RMSprop\n","from keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, TimeDistributed\n","from keras.engine.topology import get_source_inputs\n","from keras.utils import layer_utils\n","from keras.utils.data_utils import get_file\n","from keras.objectives import categorical_crossentropy\n","from keras.models import Model\n","from keras.utils import generic_utils\n","from keras.engine import Layer, InputSpec\n","from keras import initializers, regularizers\n","from keras.utils import 
Sequence\n","import xml.etree.ElementTree as ET\n","from collections import OrderedDict, Counter\n","import json\n","import imageio\n","import imgaug as ia\n","from imgaug import augmenters as iaa\n","import copy\n","import cv2\n","from tqdm import tqdm\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","ia.seed(1)\n","# imgaug uses matplotlib backend for displaying images\n","from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage\n","import re\n","import glob\n","\n","!git clone https://github.com/rodrigo2019/keras_yolo2.git\n","\n","if os.path.exists('/content/gdrive/My Drive/keras-yolo2'):\n"," shutil.rmtree('/content/gdrive/My Drive/keras-yolo2')\n","\n","!git clone https://github.com/experiencor/keras-yolo2.git\n","shutil.move('/content/keras-yolo2','/content/gdrive/My Drive/keras-yolo2')\n","shutil.move('/content/keras_yolo2/keras_yolov2/map_evaluation.py','/content/gdrive/My Drive/keras-yolo2/map_evaluation.py')\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","\n","from backend import BaseFeatureExtractor, FullYoloFeature\n","from preprocessing import parse_annotation, BatchGenerator\n","\n","#shutil.move('/content/map_evaluation.py','/content/gdrive/My Drive/keras-yolo2/map_evaluation.py')\n","\n","print(\"Depencies installed and imported.\")\n","\n","def plt_rectangle(plt,label,x1,y1,x2,y2,fontsize=10):\n"," '''\n"," == Input ==\n"," \n"," plt : matplotlib.pyplot object\n"," label : string containing the object class name\n"," x1 : top left corner x coordinate\n"," y1 : top left corner y coordinate\n"," x2 : bottom right corner x coordinate\n"," y2 : bottom right corner y coordinate\n"," '''\n"," linewidth = 1\n"," color = \"yellow\"\n"," plt.text(x1,y1,label,fontsize=fontsize,backgroundcolor=\"magenta\")\n"," plt.plot([x1,x1],[y1,y2], linewidth=linewidth,color=color)\n"," plt.plot([x2,x2],[y1,y2], linewidth=linewidth,color=color)\n"," plt.plot([x1,x2],[y1,y1], 
linewidth=linewidth,color=color)\n"," plt.plot([x1,x2],[y2,y2], linewidth=linewidth,color=color)\n","\n","def extract_single_xml_file(tree,object_count=True):\n"," Nobj = 0\n"," row = OrderedDict()\n"," for elems in tree.iter():\n","\n"," if elems.tag == \"size\":\n"," for elem in elems:\n"," row[elem.tag] = int(elem.text)\n"," if elems.tag == \"object\":\n"," for elem in elems:\n"," if elem.tag == \"name\":\n"," row[\"bbx_{}_{}\".format(Nobj,elem.tag)] = str(elem.text) \n"," if elem.tag == \"bndbox\":\n"," for k in elem:\n"," row[\"bbx_{}_{}\".format(Nobj,k.tag)] = float(k.text)\n"," Nobj += 1\n"," if object_count == True:\n"," row[\"Nobj\"] = Nobj\n"," return(row)\n","\n","def count_objects(tree):\n"," Nobj=0\n"," for elems in tree.iter():\n"," if elems.tag == \"object\":\n"," for elem in elems:\n"," if elem.tag == \"bndbox\":\n"," Nobj += 1\n"," return(Nobj)\n","\n","def compute_overlap(a, b):\n"," \"\"\"\n"," Code originally from https://github.com/rbgirshick/py-faster-rcnn.\n"," Parameters\n"," ----------\n"," a: (N, 4) ndarray of float\n"," b: (K, 4) ndarray of float\n"," Returns\n"," -------\n"," overlaps: (N, K) ndarray of overlap between boxes and query_boxes\n"," \"\"\"\n"," area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])\n","\n"," iw = np.minimum(np.expand_dims(a[:, 2], axis=1), b[:, 2]) - np.maximum(np.expand_dims(a[:, 0], 1), b[:, 0])\n"," ih = np.minimum(np.expand_dims(a[:, 3], axis=1), b[:, 3]) - np.maximum(np.expand_dims(a[:, 1], 1), b[:, 1])\n","\n"," iw = np.maximum(iw, 0)\n"," ih = np.maximum(ih, 0)\n","\n"," ua = np.expand_dims((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), axis=1) + area - iw * ih\n","\n"," ua = np.maximum(ua, np.finfo(float).eps)\n","\n"," intersection = iw * ih\n","\n"," return intersection / ua\n","\n","def compute_ap(recall, precision):\n"," \"\"\" Compute the average precision, given the recall and precision curves.\n"," Code originally from https://github.com/rbgirshick/py-faster-rcnn.\n","\n"," # Arguments\n"," recall: The 
recall curve (list).\n"," precision: The precision curve (list).\n"," # Returns\n"," The average precision as computed in py-faster-rcnn.\n"," \"\"\"\n"," # correct AP calculation\n"," # first append sentinel values at the end\n"," mrec = np.concatenate(([0.], recall, [1.]))\n"," mpre = np.concatenate(([0.], precision, [0.]))\n","\n"," # compute the precision envelope\n"," for i in range(mpre.size - 1, 0, -1):\n"," mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])\n","\n"," # to calculate area under PR curve, look for points\n"," # where X axis (recall) changes value\n"," i = np.where(mrec[1:] != mrec[:-1])[0]\n","\n"," # and sum (\\Delta recall) * prec\n"," ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])\n"," return ap \n","\n","def load_annotation(image_folder,annotations_folder, i, config):\n"," annots = []\n"," imgs, anns = parse_annotation(annotations_folder,image_folder,config['model']['labels'])\n"," for obj in imgs[i]['object']:\n"," annot = [obj['xmin'], obj['ymin'], obj['xmax'], obj['ymax'], config['model']['labels'].index(obj['name'])]\n"," annots += [annot]\n","\n"," if len(annots) == 0: annots = [[]]\n","\n"," return np.array(annots)\n","\n","def _calc_avg_precisions(config,image_folder,annotations_folder,weights_path,iou_threshold,score_threshold):\n","\n"," # gather all detections and annotations\n"," all_detections = [[None for _ in range(len(config['model']['labels']))] for _ in range(len(os.listdir(image_folder)))]\n"," all_annotations = [[None for _ in range(len(config['model']['labels']))] for _ in range(len(os.listdir(annotations_folder)))]\n","\n"," for i in range(len(os.listdir(image_folder))):\n"," raw_image = cv2.imread(os.path.join(image_folder,sorted(os.listdir(image_folder))[i]))\n"," raw_height, raw_width, _ = raw_image.shape\n"," #print(raw_height)\n"," # make the boxes and the labels\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," 
max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n"," yolo.load_weights(weights_path)\n"," pred_boxes = yolo.predict(raw_image,iou_threshold=iou_threshold,score_threshold=score_threshold)\n","\n"," score = np.array([box.score for box in pred_boxes])\n"," #print(score)\n"," pred_labels = np.array([box.label for box in pred_boxes])\n"," #print(len(pred_boxes))\n"," if len(pred_boxes) > 0:\n"," pred_boxes = np.array([[box.xmin * raw_width, box.ymin * raw_height, box.xmax * raw_width,\n"," box.ymax * raw_height, box.score] for box in pred_boxes])\n"," else:\n"," pred_boxes = np.array([[]])\n","\n"," # sort the boxes and the labels according to scores\n"," score_sort = np.argsort(-score)\n"," pred_labels = pred_labels[score_sort]\n"," pred_boxes = pred_boxes[score_sort]\n","\n"," # copy detections to all_detections\n"," for label in range(len(config['model']['labels'])):\n"," all_detections[i][label] = pred_boxes[pred_labels == label, :]\n","\n"," annotations = load_annotation(image_folder,annotations_folder,i,config)\n","\n"," # copy ground truth to all_annotations\n"," for label in range(len(config['model']['labels'])):\n"," all_annotations[i][label] = annotations[annotations[:, 4] == label, :4].copy()\n","\n"," # compute mAP by comparing all detections and all annotations\n"," average_precisions = {}\n"," total_recall = []\n"," total_precision = []\n"," for label in range(len(config['model']['labels'])):\n"," false_positives = np.zeros((0,))\n"," true_positives = np.zeros((0,))\n"," scores = np.zeros((0,))\n"," num_annotations = 0.0\n","\n"," for i in range(len(os.listdir(image_folder))):\n"," detections = all_detections[i][label]\n"," annotations = all_annotations[i][label]\n"," num_annotations += annotations.shape[0]\n"," detected_annotations = []\n","\n"," for d in detections:\n"," scores = np.append(scores, d[4])\n","\n"," if annotations.shape[0] == 0:\n"," false_positives = np.append(false_positives, 1)\n"," 
true_positives = np.append(true_positives, 0)\n"," continue\n","\n"," overlaps = compute_overlap(np.expand_dims(d, axis=0), annotations)\n"," assigned_annotation = np.argmax(overlaps, axis=1)\n"," max_overlap = overlaps[0, assigned_annotation]\n","\n"," if max_overlap >= iou_threshold and assigned_annotation not in detected_annotations:\n"," false_positives = np.append(false_positives, 0)\n"," true_positives = np.append(true_positives, 1)\n"," detected_annotations.append(assigned_annotation)\n"," else:\n"," false_positives = np.append(false_positives, 1)\n"," true_positives = np.append(true_positives, 0)\n","\n"," # no annotations -> AP for this class is 0 (is this correct?)\n"," if num_annotations == 0:\n"," average_precisions[label] = 0\n"," continue\n","\n"," # sort by score\n"," indices = np.argsort(-scores)\n"," false_positives = false_positives[indices]\n"," true_positives = true_positives[indices]\n","\n"," # compute false positives and true positives\n"," false_positives = np.cumsum(false_positives)\n"," true_positives = np.cumsum(true_positives)\n","\n"," # compute recall and precision\n"," recall = true_positives / num_annotations\n"," precision = true_positives / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)\n"," total_recall.append(recall)\n"," total_precision.append(precision)\n"," #print(precision)\n"," # compute average precision\n"," average_precision = compute_ap(recall, precision)\n"," average_precisions[label] = average_precision\n","\n"," return average_precisions, total_recall, total_precision\n","\n","\n","def show_frame(pred_bb, pred_classes, pred_conf, gt_bb, gt_classes, class_dict, background=np.zeros((512, 512, 3)), show_confidence=True):\n"," \"\"\"\n"," Here, we are adapting classes and functions from https://github.com/MathGaron/mean_average_precision\n"," \"\"\"\n"," \"\"\"\n"," Plot the boundingboxes\n"," :param pred_bb: (np.array) Predicted Bounding Boxes [x1, y1, x2, y2] : Shape [n_pred, 4]\n"," :param 
pred_classes: (np.array) Predicted Classes : Shape [n_pred]\n"," :param pred_conf: (np.array) Predicted Confidences [0.-1.] : Shape [n_pred]\n"," :param gt_bb: (np.array) Ground Truth Bounding Boxes [x1, y1, x2, y2] : Shape [n_gt, 4]\n"," :param gt_classes: (np.array) Ground Truth Classes : Shape [n_gt]\n"," :param class_dict: (dictionary) Key value pairs of classes, e.g. {0:'dog',1:'cat',2:'horse'}\n"," :return:\n"," \"\"\"\n"," n_pred = pred_bb.shape[0]\n"," n_gt = gt_bb.shape[0]\n"," n_class = int(np.max(np.append(pred_classes, gt_classes)) + 1)\n"," #print(n_class)\n"," if len(background.shape) < 3:\n"," h, w = background.shape\n"," else:\n"," h, w, c = background.shape\n","\n"," ax = plt.subplot(\"111\")\n"," ax.imshow(background)\n"," cmap = plt.cm.get_cmap('hsv')\n","\n"," confidence_alpha = pred_conf.copy()\n"," if not show_confidence:\n"," confidence_alpha.fill(1)\n","\n"," for i in range(n_pred):\n"," x1 = pred_bb[i, 0]# * w\n"," y1 = pred_bb[i, 1]# * h\n"," x2 = pred_bb[i, 2]# * w\n"," y2 = pred_bb[i, 3]# * h\n"," rect_w = x2 - x1\n"," rect_h = y2 - y1\n"," #print(x1, y1)\n"," ax.add_patch(patches.Rectangle((x1, y1), rect_w, rect_h,\n"," fill=False,\n"," edgecolor=cmap(float(pred_classes[i]) / n_class),\n"," linestyle='dashdot',\n"," alpha=confidence_alpha[i]))\n","\n"," for i in range(n_gt):\n"," x1 = gt_bb[i, 0]# * w\n"," y1 = gt_bb[i, 1]# * h\n"," x2 = gt_bb[i, 2]# * w\n"," y2 = gt_bb[i, 3]# * h\n"," rect_w = x2 - x1\n"," rect_h = y2 - y1\n"," ax.add_patch(patches.Rectangle((x1, y1), rect_w, rect_h,\n"," fill=False,\n"," edgecolor=cmap(float(gt_classes[i]) / n_class)))\n","\n"," legend_handles = []\n","\n"," for i in range(n_class):\n"," legend_handles.append(patches.Patch(color=cmap(float(i) / n_class), label=class_dict[i]))\n"," \n"," ax.legend(handles=legend_handles)\n"," plt.show()\n","\n","class BoundBox:\n"," \"\"\"\n"," Here, we are adapting classes and functions from https://github.com/MathGaron/mean_average_precision\n"," \"\"\"\n"," def 
__init__(self, xmin, ymin, xmax, ymax, c = None, classes = None):\n"," self.xmin = xmin\n"," self.ymin = ymin\n"," self.xmax = xmax\n"," self.ymax = ymax\n"," \n"," self.c = c\n"," self.classes = classes\n","\n"," self.label = -1\n"," self.score = -1\n","\n"," def get_label(self):\n"," if self.label == -1:\n"," self.label = np.argmax(self.classes)\n"," \n"," return self.label\n"," \n"," def get_score(self):\n"," if self.score == -1:\n"," self.score = self.classes[self.get_label()]\n"," \n"," return self.score\n","\n","class WeightReader:\n"," def __init__(self, weight_file):\n"," self.offset = 4\n"," self.all_weights = np.fromfile(weight_file, dtype='float32')\n"," \n"," def read_bytes(self, size):\n"," self.offset = self.offset + size\n"," return self.all_weights[self.offset-size:self.offset]\n"," \n"," def reset(self):\n"," self.offset = 4\n","\n","def bbox_iou(box1, box2):\n"," intersect_w = _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax])\n"," intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax]) \n"," \n"," intersect = intersect_w * intersect_h\n","\n"," w1, h1 = box1.xmax-box1.xmin, box1.ymax-box1.ymin\n"," w2, h2 = box2.xmax-box2.xmin, box2.ymax-box2.ymin\n"," \n"," union = w1*h1 + w2*h2 - intersect\n"," \n"," return float(intersect) / union\n","\n","def draw_boxes(image, boxes, labels):\n"," image_h, image_w, _ = image.shape\n"," #Changes in box color added by LvC\n"," # class_colours = []\n"," # for c in range(len(labels)):\n"," # colour = np.random.randint(low=0,high=255,size=3).tolist()\n"," # class_colours.append(tuple(colour))\n"," for box in boxes:\n"," xmin = int(box.xmin*image_w)\n"," ymin = int(box.ymin*image_h)\n"," xmax = int(box.xmax*image_w)\n"," ymax = int(box.ymax*image_h)\n"," if box.get_label() == 0:\n"," cv2.rectangle(image, (xmin,ymin), (xmax,ymax), (255,0,0), 3)\n"," elif box.get_label() == 1:\n"," cv2.rectangle(image, (xmin,ymin), (xmax,ymax), (0,255,0), 3)\n"," else:\n"," cv2.rectangle(image, 
(xmin,ymin), (xmax,ymax), (0,0,255), 3)\n"," #cv2.rectangle(image, (xmin,ymin), (xmax,ymax), class_colours[box.get_label()], 3)\n"," cv2.putText(image, \n"," labels[box.get_label()] + ' ' + str(round(box.get_score(),3)), \n"," (xmin, ymin - 13), \n"," cv2.FONT_HERSHEY_SIMPLEX, \n"," 1e-3 * image_h, \n"," (0,0,0), 2)\n"," #print(box.get_label()) \n"," return image \n","\n","#Function added by LvC\n","def save_boxes(image_path, boxes, labels):#, save_path):\n"," image = cv2.imread(image_path)\n"," image_h, image_w, _ = image.shape\n"," save_boxes =[]\n"," save_boxes_names = []\n"," save_boxes.append(os.path.basename(image_path))\n"," save_boxes_names.append(os.path.basename(image_path))\n"," for box in boxes:\n"," # xmin = box.xmin\n"," save_boxes.append(int(box.xmin*image_w))\n"," save_boxes_names.append(int(box.xmin*image_w))\n"," # ymin = box.ymin\n"," save_boxes.append(int(box.ymin*image_h))\n"," save_boxes_names.append(int(box.ymin*image_h))\n"," # xmax = box.xmax\n"," save_boxes.append(int(box.xmax*image_w))\n"," save_boxes_names.append(int(box.xmax*image_w))\n"," # ymax = box.ymax\n"," save_boxes.append(int(box.ymax*image_h))\n"," save_boxes_names.append(int(box.ymax*image_h))\n"," score = box.get_score()\n"," save_boxes.append(score)\n"," save_boxes_names.append(score)\n"," label = box.get_label()\n"," save_boxes.append(label)\n"," save_boxes_names.append(labels[label])\n"," \n"," #This file will be for later analysis of the bounding boxes in imagej\n"," if not os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," with open('/content/predicted_bounding_boxes.csv', 'w', newline='') as csvfile:\n"," csvwriter = csv.writer(csvfile, delimiter=',')\n"," specs_list = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*len(boxes)\n"," csvwriter.writerow(specs_list)\n"," csvwriter.writerow(save_boxes)\n"," else:\n"," with open('/content/predicted_bounding_boxes.csv', 'a+', newline='') as csvfile:\n"," csvwriter = csv.writer(csvfile)\n"," 
csvwriter.writerow(save_boxes)\n"," \n"," if not os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," with open('/content/predicted_bounding_boxes_names.csv', 'w', newline='') as csvfile_names:\n"," csvwriter = csv.writer(csvfile_names, delimiter=',')\n"," specs_list = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*len(boxes)\n"," csvwriter.writerow(specs_list)\n"," csvwriter.writerow(save_boxes_names)\n"," else:\n"," with open('/content/predicted_bounding_boxes_names.csv', 'a+', newline='') as csvfile_names:\n"," csvwriter = csv.writer(csvfile_names)\n"," csvwriter.writerow(save_boxes_names)\n"," # #This file is to create a nicer display for the output images\n"," # if not os.path.exists('/content/predicted_bounding_boxes_display.csv'):\n"," # with open('/content/predicted_bounding_boxes_display.csv', 'w', newline='') as csvfile_new:\n"," # csvwriter2 = csv.writer(csvfile_new, delimiter=',')\n"," # specs_list = ['filename','width','height','class','xmin','ymin','xmax','ymax']\n"," # csvwriter2.writerow(specs_list)\n"," # else:\n"," # with open('/content/predicted_bounding_boxes_display.csv','a+',newline='') as csvfile_new:\n"," # csvwriter2 = csv.writer(csvfile_new)\n"," # for box in boxes:\n"," # row = [os.path.basename(image_path),image_w,image_h,box.get_label(),int(box.xmin*image_w),int(box.ymin*image_h),int(box.xmax*image_w),int(box.ymax*image_h)]\n"," # csvwriter2.writerow(row)\n","\n","def add_header(inFilePath,outFilePath):\n"," header = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n"," with open(inFilePath, newline='') as inFile, open(outFilePath, 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n"," \n","def decode_netout(netout, anchors, nb_class, obj_threshold=0.3, 
nms_threshold=0.5):\n"," grid_h, grid_w, nb_box = netout.shape[:3]\n","\n"," boxes = []\n"," \n"," # decode the output by the network\n"," netout[..., 4] = _sigmoid(netout[..., 4])\n"," netout[..., 5:] = netout[..., 4][..., np.newaxis] * _softmax(netout[..., 5:])\n"," netout[..., 5:] *= netout[..., 5:] > obj_threshold\n"," \n"," for row in range(grid_h):\n"," for col in range(grid_w):\n"," for b in range(nb_box):\n"," # from 4th element onwards are confidence and class classes\n"," classes = netout[row,col,b,5:]\n"," \n"," if np.sum(classes) > 0:\n"," # first 4 elements are x, y, w, and h\n"," x, y, w, h = netout[row,col,b,:4]\n","\n"," x = (col + _sigmoid(x)) / grid_w # center position, unit: image width\n"," y = (row + _sigmoid(y)) / grid_h # center position, unit: image height\n"," w = anchors[2 * b + 0] * np.exp(w) / grid_w # unit: image width\n"," h = anchors[2 * b + 1] * np.exp(h) / grid_h # unit: image height\n"," confidence = netout[row,col,b,4]\n"," \n"," box = BoundBox(x-w/2, y-h/2, x+w/2, y+h/2, confidence, classes)\n"," \n"," boxes.append(box)\n","\n"," # suppress non-maximal boxes\n"," for c in range(nb_class):\n"," sorted_indices = list(reversed(np.argsort([box.classes[c] for box in boxes])))\n","\n"," for i in range(len(sorted_indices)):\n"," index_i = sorted_indices[i]\n"," \n"," if boxes[index_i].classes[c] == 0: \n"," continue\n"," else:\n"," for j in range(i+1, len(sorted_indices)):\n"," index_j = sorted_indices[j]\n"," \n"," if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_threshold:\n"," boxes[index_j].classes[c] = 0\n"," \n"," # remove the boxes which are less likely than a obj_threshold\n"," boxes = [box for box in boxes if box.get_score() > obj_threshold]\n"," \n"," return boxes\n","\n","def replace(file_path, pattern, subst):\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," 
#Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","with open(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"r\") as check:\n"," lineReader = check.readlines()\n"," reduce_lr = False\n"," for line in lineReader:\n"," if \"reduce_lr\" in line:\n"," reduce_lr = True\n"," break\n","\n","if reduce_lr == False:\n"," #replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"period=1)\",\"period=1)\\n csv_logger=CSVLogger('/content/training_evaluation.csv')\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"period=1)\",\"period=1)\\n reduce_lr=ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, verbose=1)\")\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"import EarlyStopping\",\"import ReduceLROnPlateau, EarlyStopping\")\n","\n","with open(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"r\") as check:\n"," lineReader = check.readlines()\n"," map_eval = False\n"," for line in lineReader:\n"," if \"map_evaluation\" in line:\n"," map_eval = True\n"," break\n","\n","if map_eval == False:\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"import cv2\",\"import cv2\\nfrom map_evaluation import MapEvaluation\")\n"," new_callback = ' map_evaluator = MapEvaluation(self, valid_generator,save_best=True,save_name=\"/content/gdrive/My Drive/keras-yolo2/best_map_weights.h5\",iou_threshold=0.3,score_threshold=0.3)'\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"write_images=False)\",\"write_images=False)\\n\"+new_callback)\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"import keras\",\"import keras\\nimport csv\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"from .utils\",\"from utils\")\n"," replace(\"/content/gdrive/My 
Drive/keras-yolo2/map_evaluation.py\",\".format(_map))\",\".format(_map))\\n with open('/content/gdrive/My Drive/mAP.csv','a+', newline='') as mAP_csv:\\n csv_writer=csv.writer(mAP_csv)\\n csv_writer.writerow(['mAP:','{:.4f}'.format(_map)])\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"iou_threshold=0.5\",\"iou_threshold=0.3\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"score_threshold=0.5\",\"score_threshold=0.3\")\n","\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"[early_stop, checkpoint, tensorboard]\",\"[checkpoint, reduce_lr, map_evaluator]\")\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"predict(self, image)\",\"predict(self,image,iou_threshold=0.3,score_threshold=0.3)\")\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"self.model.summary()\",\"#self.model.summary()\")\n","from frontend import YOLO\n","\n","def train(config_path, model_path, percentage_validation):\n"," #config_path = args.conf\n","\n"," with open(config_path) as config_buffer: \n"," config = json.loads(config_buffer.read())\n","\n"," ###############################\n"," # Parse the annotations \n"," ###############################\n","\n"," # parse annotations of the training set\n"," train_imgs, train_labels = parse_annotation(config['train']['train_annot_folder'], \n"," config['train']['train_image_folder'], \n"," config['model']['labels'])\n","\n"," # parse annotations of the validation set, if any, otherwise split the training set\n"," if os.path.exists(config['valid']['valid_annot_folder']):\n"," valid_imgs, valid_labels = parse_annotation(config['valid']['valid_annot_folder'], \n"," config['valid']['valid_image_folder'], \n"," config['model']['labels'])\n"," else:\n"," train_valid_split = int((1-percentage_validation/100.)*len(train_imgs))\n"," np.random.shuffle(train_imgs)\n","\n"," valid_imgs = train_imgs[train_valid_split:]\n"," train_imgs = 
train_imgs[:train_valid_split]\n","\n"," if len(config['model']['labels']) > 0:\n"," overlap_labels = set(config['model']['labels']).intersection(set(train_labels.keys()))\n","\n"," print('Seen labels:\\t', train_labels)\n"," print('Given labels:\\t', config['model']['labels'])\n"," print('Overlap labels:\\t', overlap_labels) \n","\n"," if len(overlap_labels) < len(config['model']['labels']):\n"," print('Some labels have no annotations! Please revise the list of labels in the config.json file!')\n"," return\n"," else:\n"," print('No labels are provided. Train on all seen labels.')\n"," config['model']['labels'] = train_labels.keys()\n"," \n"," ###############################\n"," # Construct the model \n"," ###############################\n","\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n","\n"," ###############################\n"," # Load the pretrained weights (if any) \n"," ############################### \n","\n"," if os.path.exists(config['train']['pretrained_weights']):\n"," print(\"Loading pre-trained weights in\", config['train']['pretrained_weights'])\n"," yolo.load_weights(config['train']['pretrained_weights'])\n"," if os.path.exists('/content/gdrive/My Drive/mAP.csv'):\n"," os.remove('/content/gdrive/My Drive/mAP.csv')\n"," ###############################\n"," # Start the training process \n"," ###############################\n","\n"," yolo.train(train_imgs = train_imgs,\n"," valid_imgs = valid_imgs,\n"," train_times = config['train']['train_times'],\n"," valid_times = config['valid']['valid_times'],\n"," nb_epochs = config['train']['nb_epochs'], \n"," learning_rate = config['train']['learning_rate'], \n"," batch_size = config['train']['batch_size'],\n"," warmup_epochs = config['train']['warmup_epochs'],\n"," object_scale = 
config['train']['object_scale'],\n"," no_object_scale = config['train']['no_object_scale'],\n"," coord_scale = config['train']['coord_scale'],\n"," class_scale = config['train']['class_scale'],\n"," saved_weights_name = config['train']['saved_weights_name'],\n"," debug = config['train']['debug'])\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n"," lossDataCSVpath = os.path.join(model_path,'Quality Control/training_evaluation.csv')\n"," with open(lossDataCSVpath, 'w') as f1:\n"," writer = csv.writer(f1)\n"," mAP_df = pd.read_csv('/content/gdrive/My Drive/mAP.csv',header=None)\n"," writer.writerow(['loss','val_loss','mAP','learning rate'])\n"," for i in range(len(yolo.model.history.history['loss'])):\n"," writer.writerow([yolo.model.history.history['loss'][i], yolo.model.history.history['val_loss'][i], float(mAP_df[1][i]), yolo.model.history.history['lr'][i]])\n","\n"," yolo.model.save(model_path+'/last_weights.h5')\n","\n","def predict(config, weights_path, image_path):#, model_path):\n","\n"," with open(config) as config_buffer: \n"," config = json.load(config_buffer)\n","\n"," ###############################\n"," # Make the model \n"," ###############################\n","\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n","\n"," ###############################\n"," # Load trained weights\n"," ############################### \n","\n"," yolo.load_weights(weights_path)\n","\n"," ###############################\n"," # Predict bounding boxes \n"," ###############################\n","\n"," if image_path[-4:] == '.mp4':\n"," video_out = image_path[:-4] + '_detected' + image_path[-4:]\n"," video_reader = cv2.VideoCapture(image_path)\n","\n"," nb_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))\n"," frame_h = 
int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))\n"," frame_w = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))\n","\n"," video_writer = cv2.VideoWriter(video_out,\n"," cv2.VideoWriter_fourcc(*'MPEG'), \n"," 50.0, \n"," (frame_w, frame_h))\n","\n"," for i in tqdm(range(nb_frames)):\n"," _, image = video_reader.read()\n"," \n"," boxes = yolo.predict(image)\n"," image = draw_boxes(image, boxes, config['model']['labels'])\n","\n"," video_writer.write(np.uint8(image))\n","\n"," video_reader.release()\n"," video_writer.release() \n"," else:\n"," image = cv2.imread(image_path)\n"," boxes = yolo.predict(image)\n"," image = draw_boxes(image, boxes, config['model']['labels'])\n"," save_boxes(image_path,boxes,config['model']['labels'])#,model_path)#added by LvC\n"," print(len(boxes), 'boxes are found')\n"," #print(image)\n"," cv2.imwrite(image_path[:-4] + '_detected' + image_path[-4:], image)\n"," \n"," return len(boxes)\n","\n","# function to convert BoundingBoxesOnImage object into DataFrame\n","def bbs_obj_to_df(bbs_object):\n","# convert BoundingBoxesOnImage object into array\n"," bbs_array = bbs_object.to_xyxy_array()\n","# convert array into a DataFrame ['xmin', 'ymin', 'xmax', 'ymax'] columns\n"," df_bbs = pd.DataFrame(bbs_array, columns=['xmin', 'ymin', 'xmax', 'ymax'])\n"," return df_bbs\n","\n","# Function that will extract column data for our CSV file\n","def xml_to_csv(path):\n"," xml_list = []\n"," for xml_file in glob.glob(path + '/*.xml'):\n"," tree = ET.parse(xml_file)\n"," root = tree.getroot()\n"," for member in root.findall('object'):\n"," value = (root.find('filename').text,\n"," int(root.find('size')[0].text),\n"," int(root.find('size')[1].text),\n"," member[0].text,\n"," int(member[4][0].text),\n"," int(member[4][1].text),\n"," int(member[4][2].text),\n"," int(member[4][3].text)\n"," )\n"," xml_list.append(value)\n"," column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']\n"," xml_df = pd.DataFrame(xml_list, 
columns=column_name)\n"," return xml_df"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4","colab_type":"text"},"source":["# **3. Select your paths and parameters**\n","\n","---\n","\n","The code below allows the user to enter the paths to the training data and to define the training parameters.\n","\n","Once played, the cell will display some quantitative metrics of your dataset, including a count of objects per image and the number of instances per class.\n"]},{"cell_type":"markdown","metadata":{"id":"grFtuWsY5LZm","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd","colab_type":"text"},"source":[" **Paths for training, predictions and results**\n","\n","\n","**`Training_source:`, `Training_source_annotations`:** These are the paths to the folders containing your training images and their annotation data, respectively. To find the path of a folder, go to your Files on the left of the notebook, navigate to the folder containing your files, right-click on the folder, select **Copy path** and paste the path into the corresponding box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input the number of epochs the network will be trained for. **Default value: 27**\n","\n","**Note that YOLOv2 uses 3 Warm-up epochs which improves the model's performance. 
This means the network will train for number_of_epochs + 3 epochs.**\n","\n","**`backend`:** Several backends are available for training YOLO. These are usually slightly different model architectures, with pretrained weights. Take a look at the available backends and research which one will be best suited for your dataset.\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`train_times:`** Input how many times to cycle through the dataset per epoch. This is more useful for smaller datasets (but risks overfitting). **Default value: 4**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 8**\n","\n","**`learning_rate:`** Input the initial value to be used as learning rate. **Default value: 0.0001**\n","\n","**`false_negative_penalty:`** Penalize wrong detection of 'no-object'. **Default: 5.0**\n","\n","**`false_positive_penalty:`** Penalize wrong detection of 'object'. **Default: 1.0**\n","\n","**`position_size_penalty:`** Penalize inaccurate positioning or size of bounding boxes. **Default: 1.0**\n","\n","**`false_class_penalty:`** Penalize misclassification of object in bounding box. **Default: 1.0**\n","\n","**`percentage_validation:`** Input the percentage of your training dataset you want to use to validate the network during training. 
**Default value: 10** "]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","colab_type":"code","cellView":"form","colab":{}},"source":["class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#@markdown ###Path to training images:\n","\n","Training_Source = \"\" #@param {type:\"string\"}\n","\n","# Ground truth images\n","Training_Source_annotations = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# backend\n","#@markdown ###Choose a backend\n","#os.chdir(model_path+'/keras-yolo2')\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/full_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/inception_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/mobilenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/squeezenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget 
https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n","\n","#os.chdir('/content/drive/My Drive/Zero-Cost Deep-Learning to Enhance Microscopy/Various dataset/Detection_Dataset_2/BCCD.v2.voc')\n","#if not os.path.exists(model_path+'/full_raccoon.h5'):\n"," # !wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1NWbrpMGLc84ow-4gXn2mloFocFGU595s' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id=1NWbrpMGLc84ow-4gXn2mloFocFGU595s\" -O full_yolo_raccoon.h5 && rm -rf /tmp/cookies.txt\n","\n","full_model_path = os.path.join(model_path,model_name)\n","if os.path.exists(full_model_path):\n"," print('Existing model path will be overwritten')\n"," shutil.rmtree(full_model_path)\n","os.mkdir(full_model_path)\n","\n","full_model_file_path = full_model_path+'/best_weights.h5'\n","os.chdir('/content/gdrive/My Drive/keras-yolo2/')\n","\n","#Change backend name\n","!sed -i 's@\\\"backend\\\":.*,@\\\"backend\\\": \\\"$backend\\\",@g' config.json\n","\n","#Change the name of the training folder\n","!sed -i 's@\\\"train_image_folder\\\":.*,@\\\"train_image_folder\\\": \\\"$Training_Source/\\\",@g' config.json\n","\n","#Change annotation folder\n","!sed -i 's@\\\"train_annot_folder\\\":.*,@\\\"train_annot_folder\\\": \\\"$Training_Source_annotations/\\\",@g' config.json\n","\n","#Change the name of the saved model\n","!sed -i 's@\\\"saved_weights_name\\\":.*,@\\\"saved_weights_name\\\": \\\"$full_model_file_path\\\",@g' config.json\n","\n","#Change warmup epochs for untrained model\n","!sed -i 's@\\\"warmup_epochs\\\":.*,@\\\"warmup_epochs\\\": 3,@g' config.json\n","\n","#When defining a new model we should reset the pretrained model parameter\n","!sed -i 's@\\\"pretrained_weights\\\":.*,@\\\"pretrained_weights\\\": \\\"No_pretrained_weights\\\",@g' 
config.json\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","\n","number_of_epochs = 10#@param {type:\"number\"}\n","!sed -i 's@\\\"nb_epochs\\\":.*,@\\\"nb_epochs\\\": $number_of_epochs,@g' config.json\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","train_times = 4 #@param {type:\"integer\"}\n","batch_size = 4#@param {type:\"number\"}\n","learning_rate = 1e-4 #@param{type:\"number\"}\n","false_negative_penalty = 5.0 #@param{type:\"number\"}\n","false_positive_penalty = 2.0 #@param{type:\"number\"}\n","position_size_penalty = 1.0 #@param{type:\"number\"}\n","false_class_penalty = 1.0 #@param{type:\"number\"}\n","percentage_validation = 10#@param{type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," train_times = 4\n"," batch_size = 8\n"," learning_rate = 1e-4\n"," false_negative_penalty = 5.0\n"," false_positive_penalty = 1.0\n"," position_size_penalty = 1.0\n"," false_class_penalty = 1.0\n"," percentage_validation = 10\n","\n","!sed -i 's@\\\"train_times\\\":.*,@\\\"train_times\\\": $train_times,@g' config.json\n","!sed -i 's@\\\"batch_size\\\":.*,@\\\"batch_size\\\": $batch_size,@g' config.json\n","!sed -i 's@\\\"learning_rate\\\":.*,@\\\"learning_rate\\\": $learning_rate,@g' config.json\n","!sed -i 's@\\\"object_scale\":.*,@\\\"object_scale\\\": $false_negative_penalty,@g' config.json\n","!sed -i 's@\\\"no_object_scale\":.*,@\\\"no_object_scale\\\": $false_positive_penalty,@g' config.json\n","!sed -i 's@\\\"coord_scale\\\":.*,@\\\"coord_scale\\\": $position_size_penalty,@g' config.json\n","!sed -i 's@\\\"class_scale\\\":.*,@\\\"class_scale\\\": $false_class_penalty,@g' config.json\n","\n","df_anno = []\n","dir_anno = Training_Source_annotations\n","for fnm in os.listdir(dir_anno): \n"," if not 
fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(dir_anno,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno.append(row)\n","df_anno = pd.DataFrame(df_anno)\n","\n","maxNobj = np.max(df_anno[\"Nobj\"])\n","\n","#Write the annotations to a csv file\n","df_anno.to_csv(model_path+'/annot.csv', index=False)#header=False, sep=',')\n","\n","file_suffix = os.path.splitext(os.listdir(Training_Source)[0])[1]\n","\n","#Show how many objects there are in the images\n","plt.figure()\n","plt.subplot(2,1,1)\n","plt.hist(df_anno[\"Nobj\"].values,bins=50)\n","plt.title(\"max N of objects per image={}\".format(maxNobj))\n","plt.show()\n","\n","#Show the classes and how many there are of each in the dataset\n","from collections import Counter\n","class_obj = []\n","for ibbx in range(maxNobj):\n"," class_obj.extend(df_anno[\"bbx_{}_name\".format(ibbx)].values)\n","class_obj = np.array(class_obj)\n","\n","count = Counter(class_obj[class_obj != 'nan'])\n","print(count)\n","class_nm = list(count.keys())\n","class_labels = json.dumps(class_nm)\n","class_count = list(count.values())\n","asort_class_count = np.argsort(class_count)\n","\n","class_nm = np.array(class_nm)[asort_class_count]\n","class_count = np.array(class_count)[asort_class_count]\n","\n","!sed -i 's@\\\"labels\\\":.*@\\\"labels\\\": $class_labels@g' config.json\n","xs = range(len(class_count))\n","\n","plt.subplot(2,1,2)\n","plt.barh(xs,class_count)\n","plt.yticks(xs,class_nm)\n","plt.title(\"The number of objects per class: {} objects in total\".format(len(count)))\n","plt.show()\n","\n","\n","#Generate anchors for the bounding boxes\n","import subprocess as sp\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","output = sp.getoutput('python ./gen_anchors.py -c ./config.json')\n","\n","anchors_1 = output.find(\"[\")\n","anchors_2 = output.find(\"]\")\n","\n","config_anchors = output[anchors_1:anchors_2+1]\n","!sed -i 
's@\\\"anchors\\\":.*,@\\\"anchors\\\": $config_anchors,@g' config.json\n","#here we check that no model with the same name already exists; if so, delete it\n","#if os.path.exists(model_path+'/'+model_name):\n"," # shutil.rmtree(model_path+'/'+model_name)\n","\n","Use_pretrained_model = False"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab_type":"code","cellView":"form","id":"NXxj-Xi3Kang","colab":{}},"source":["#@markdown ###Play this cell to visualise some example images from your dataset to make sure annotations and images are properly matched.\n","import imageio\n"," \n","size = 3 \n","ind_random = np.random.randint(0,df_anno.shape[0],size=size)\n","img_dir=Training_Source\n","\n","file_suffix = os.path.splitext(os.listdir(Training_Source)[0])[1]\n","for irow in ind_random:\n"," row = df_anno.iloc[irow,:]\n"," path = os.path.join(img_dir, row[\"fileID\"] + file_suffix)\n"," # read in image\n"," img = imageio.imread(path)\n","\n"," plt.figure(figsize=(12,12))\n"," plt.imshow(img) # plot image\n"," plt.title(\"Nobj={}, height={}, width={}\".format(row[\"Nobj\"],row[\"height\"],row[\"width\"]))\n"," # for each object in the image, plot the bounding box\n"," for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])\n"," plt.show() ## show the plot"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"eik5zLKWpN_O","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","\n","---\n","\n"," Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. 
Augmentation is not necessary for training and if the dataset is large enough the `Use_Data_augmentation` box can be unticked.\n","\n","Here, the images and bounding boxes are augmented by flipping and rotation. When doubling the dataset the images are only flipped. With each higher factor of augmentation the images added to the dataset represent one further rotation to the right by 90 degrees. 8x augmentation will give a dataset that is fully rotated and flipped once."]},{"cell_type":"code","metadata":{"id":"RmTSfMO-pNMc","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##**Augmentation Options**\n","\n","def image_aug(df, images_path, aug_images_path, image_prefix, augmentor):\n"," # create data frame which we're going to populate with augmented image info\n"," aug_bbs_xy = pd.DataFrame(columns=\n"," ['filename','width','height','class', 'xmin', 'ymin', 'xmax', 'ymax']\n"," )\n"," grouped = df.groupby('filename')\n"," \n"," for filename in df['filename'].unique():\n"," # get separate data frame grouped by file name\n"," group_df = grouped.get_group(filename)\n"," group_df = group_df.reset_index()\n"," group_df = group_df.drop(['index'], axis=1) \n"," # read the image\n"," image = imageio.imread(images_path+filename)\n"," # get bounding boxes coordinates and write into array \n"," bb_array = group_df.drop(['filename', 'width', 'height', 'class'], axis=1).values\n"," # pass the array of bounding boxes coordinates to the imgaug library\n"," bbs = BoundingBoxesOnImage.from_xyxy_array(bb_array, shape=image.shape)\n"," # apply augmentation on image and on the bounding boxes\n"," image_aug, bbs_aug = augmentor(image=image, bounding_boxes=bbs)\n"," # disregard bounding boxes which have fallen out of image pane \n"," bbs_aug = bbs_aug.remove_out_of_image()\n"," # clip bounding boxes which are partially outside of image pane\n"," bbs_aug = bbs_aug.clip_out_of_image()\n"," \n"," # don't perform any actions with the image if there are no bounding boxes left in it 
\n"," if re.findall('Image...', str(bbs_aug)) == ['Image([]']:\n"," pass\n"," \n"," # otherwise continue\n"," else:\n"," # write augmented image to a file\n"," imageio.imwrite(aug_images_path+image_prefix+filename, image_aug) \n"," # create a data frame with augmented values of image width and height\n"," info_df = group_df.drop(['xmin', 'ymin', 'xmax', 'ymax'], axis=1) \n"," for index, _ in info_df.iterrows():\n"," info_df.at[index, 'width'] = image_aug.shape[1]\n"," info_df.at[index, 'height'] = image_aug.shape[0]\n"," # rename filenames by adding the predifined prefix\n"," info_df['filename'] = info_df['filename'].apply(lambda x: image_prefix+x)\n"," # create a data frame with augmented bounding boxes coordinates using the function we created earlier\n"," bbs_df = bbs_obj_to_df(bbs_aug)\n"," # concat all new augmented info into new data frame\n"," aug_df = pd.concat([info_df, bbs_df], axis=1)\n"," # append rows to aug_bbs_xy data frame\n"," aug_bbs_xy = pd.concat([aug_bbs_xy, aug_df]) \n"," \n"," # return dataframe with updated images and bounding boxes annotations \n"," aug_bbs_xy = aug_bbs_xy.reset_index()\n"," aug_bbs_xy = aug_bbs_xy.drop(['index'], axis=1)\n"," return aug_bbs_xy\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","multiply_dataset_by = 3 #@param {type:\"slider\", min:2, max:8, step:1}\n","\n","rotation_range = 90\n","\n","if (Use_Data_augmentation):\n"," print('Data Augmentation enabled')\n"," # load images as NumPy arrays and append them to images list\n"," if os.path.exists(Training_Source+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Training_Source+'/.ipynb_checkpoints')\n"," \n"," images = []\n"," for index, file in enumerate(glob.glob(Training_Source+'/*'+file_suffix)):\n"," images.append(imageio.imread(file))\n"," \n"," # how many images we have\n"," print('Augmenting {} images'.format(len(images)))\n","\n"," # apply xml_to_csv() function to convert all XML files in images/ folder into labels.csv\n"," labels_df = 
xml_to_csv(Training_Source_annotations)\n"," labels_df.to_csv(('/content/original_labels.csv'), index=None)\n"," \n"," # Apply flip augmentation\n"," aug = iaa.OneOf([ \n"," iaa.Fliplr(1),\n"," iaa.Flipud(1)\n"," ])\n"," aug_2 = iaa.Affine(rotate=rotation_range, fit_output=True)\n"," aug_3 = iaa.Affine(rotate=rotation_range*2, fit_output=True)\n"," aug_4 = iaa.Affine(rotate=rotation_range*3, fit_output=True)\n","\n"," #Here we create a folder that will hold the original image dataset and the augmented image dataset\n"," augmented_training_source = os.path.dirname(Training_Source)+'/'+os.path.basename(Training_Source)+'_augmentation'\n"," if os.path.exists(augmented_training_source):\n"," shutil.rmtree(augmented_training_source)\n"," os.mkdir(augmented_training_source)\n","\n"," #Here we create a folder that will hold the original image annotation dataset and the augmented image annotation dataset (the bounding boxes).\n"," augmented_training_source_annotation = os.path.dirname(Training_Source_annotations)+'/'+os.path.basename(Training_Source_annotations)+'_augmentation'\n"," if os.path.exists(augmented_training_source_annotation):\n"," shutil.rmtree(augmented_training_source_annotation)\n"," os.mkdir(augmented_training_source_annotation)\n","\n"," #Create the augmentation\n"," augmented_images_df = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'flip_', aug)\n"," \n"," # Concat resized_images_df and augmented_images_df together and save in a new all_labels.csv file\n"," all_labels_df = pd.concat([labels_df, augmented_images_df])\n"," all_labels_df.to_csv('/content/combined_labels.csv', index=False)\n","\n"," #Here we convert the new bounding boxes for the augmented images to PASCAL VOC .xml format\n"," def convert_to_xml(df,source,target_folder):\n"," grouped = df.groupby('filename')\n"," for file in os.listdir(source):\n"," #if file in grouped.filename:\n"," group_df = grouped.get_group(file)\n"," group_df = group_df.reset_index()\n"," 
group_df = group_df.drop(['index'], axis=1)\n"," #group_df = group_df.dropna(axis=0)\n"," writer = Writer(source+'/'+file,group_df.iloc[0]['width'],group_df.iloc[0]['height'])\n"," for i, row in group_df.iterrows():\n"," writer.addObject(row['class'],round(row['xmin']),round(row['ymin']),round(row['xmax']),round(row['ymax']))\n"," writer.save(target_folder+'/'+os.path.splitext(file)[0]+'.xml')\n"," convert_to_xml(all_labels_df,augmented_training_source,augmented_training_source_annotation)\n"," \n"," #Second round of augmentation\n"," if multiply_dataset_by > 2:\n"," aug_labels_df_2 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_2_df = image_aug(aug_labels_df_2, augmented_training_source+'/', augmented_training_source+'/', 'rot1_90_', aug_2)\n"," all_aug_labels_df = pd.concat([augmented_images_df, augmented_images_2_df])\n"," #all_labels_df.to_csv('/content/all_labels_aug.csv', index=False)\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 3:\n"," print('Augmenting again')\n"," aug_labels_df_3 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_3_df = image_aug(aug_labels_df_3, augmented_training_source+'/', augmented_training_source+'/', 'rot2_90_', aug_2)\n"," all_aug_labels_df_3 = pd.concat([all_aug_labels_df, augmented_images_3_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_3,augmented_training_source,augmented_training_source_annotation)\n"," \n"," #This is a preliminary remover of potential duplicates in the augmentation\n"," #Ideally, duplicates are not even produced, but this acts as a fail safe.\n"," if multiply_dataset_by==4:\n"," for file in 
os.listdir(augmented_training_source):\n"," if file.startswith('rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n","\n"," if multiply_dataset_by > 4:\n"," print('And Again')\n"," aug_labels_df_4 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_4_df = image_aug(aug_labels_df_4, augmented_training_source+'/',augmented_training_source+'/','rot3_90_', aug_2)\n"," all_aug_labels_df_4 = pd.concat([all_aug_labels_df_3, augmented_images_4_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_4,augmented_training_source,augmented_training_source_annotation)\n","\n"," for file in os.listdir(augmented_training_source):\n"," if file.startswith('rot3_90_rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if file.startswith('rot3_90_rot1_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if file.startswith('rot3_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if file.startswith('rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n","\n","\n"," if multiply_dataset_by > 5:\n"," print('And again')\n"," augmented_images_5_df = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_90_', aug_2)\n"," all_aug_labels_df_5 = 
pd.concat([all_aug_labels_df_4,augmented_images_5_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," \n"," convert_to_xml(all_aug_labels_df_5,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 6:\n"," print('And again')\n"," augmented_images_df_6 = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_180_', aug_3)\n"," all_aug_labels_df_6 = pd.concat([all_aug_labels_df_5,augmented_images_df_6])\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_6,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 7:\n"," print('And again')\n"," augmented_images_df_7 = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_270_', aug_4)\n"," all_aug_labels_df_7 = pd.concat([all_aug_labels_df_6,augmented_images_df_7])\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_7,augmented_training_source,augmented_training_source_annotation)\n","\n"," for file in os.listdir(Training_Source):\n"," shutil.copyfile(Training_Source+'/'+file,augmented_training_source+'/'+file)\n"," shutil.copyfile(Training_Source_annotations+'/'+os.path.splitext(file)[0]+'.xml',augmented_training_source_annotation+'/'+os.path.splitext(file)[0]+'.xml')\n"," # display new dataframe\n"," #augmented_images_df\n"," \n"," os.chdir('/content/gdrive/My Drive/keras-yolo2')\n"," #Change the name of the training folder\n"," !sed -i 's@\\\"train_image_folder\\\":.*,@\\\"train_image_folder\\\": \\\"$augmented_training_source/\\\",@g' config.json\n","\n"," #Change annotation folder\n"," !sed -i 
's@\\\"train_annot_folder\\\":.*,@\\\"train_annot_folder\\\": \\\"$augmented_training_source_annotation/\\\",@g' config.json\n","\n"," df_anno = []\n"," dir_anno = augmented_training_source_annotation\n"," for fnm in os.listdir(dir_anno): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(dir_anno,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno.append(row)\n"," df_anno = pd.DataFrame(df_anno)\n","\n"," maxNobj = np.max(df_anno[\"Nobj\"])\n","\n"," #Write the annotations to a csv file\n"," #df_anno.to_csv(model_path+'/annot.csv', index=False)#header=False, sep=',')\n","\n"," #Show how many objects there are in the images\n"," plt.figure()\n"," plt.subplot(2,1,1)\n"," plt.hist(df_anno[\"Nobj\"].values,bins=50)\n"," plt.title(\"max N of objects per image={}\".format(maxNobj))\n"," plt.show()\n","\n"," #Show the classes and how many there are of each in the dataset\n"," from collections import Counter\n"," class_obj = []\n"," for ibbx in range(maxNobj):\n"," class_obj.extend(df_anno[\"bbx_{}_name\".format(ibbx)].values)\n"," class_obj = np.array(class_obj)\n","\n"," count = Counter(class_obj[class_obj != 'nan'])\n"," print(count)\n"," class_nm = list(count.keys())\n"," class_labels = json.dumps(class_nm)\n"," class_count = list(count.values())\n"," asort_class_count = np.argsort(class_count)\n","\n"," class_nm = np.array(class_nm)[asort_class_count]\n"," class_count = np.array(class_count)[asort_class_count]\n","\n"," xs = range(len(class_count))\n","\n"," plt.subplot(2,1,2)\n"," plt.barh(xs,class_count)\n"," plt.yticks(xs,class_nm)\n"," plt.title(\"The number of objects per class: {} objects in total\".format(len(count)))\n"," plt.show()\n","\n","else:\n"," print('No augmentation will be used')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"tZvcYmxTdXQm","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown 
###Play this cell to visualise some example images from your **augmented** dataset to make sure annotations and images are properly matched.\n","if (Use_Data_augmentation):\n"," df_anno_aug = []\n"," dir_anno_aug = augmented_training_source_annotation\n"," for fnm in os.listdir(dir_anno_aug): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(dir_anno_aug,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno_aug.append(row)\n"," df_anno_aug = pd.DataFrame(df_anno_aug)\n","\n"," size = 3 \n"," ind_random = np.random.randint(0,df_anno_aug.shape[0],size=size)\n"," img_dir=augmented_training_source\n","\n"," file_suffix = os.path.splitext(os.listdir(augmented_training_source)[0])[1]\n"," for irow in ind_random:\n"," row = df_anno_aug.iloc[irow,:]\n"," path = os.path.join(img_dir, row[\"fileID\"] + file_suffix)\n"," # read in image\n"," img = imageio.imread(path)\n","\n"," plt.figure(figsize=(12,12))\n"," plt.imshow(img) # plot image\n"," plt.title(\"Nobj={}, height={}, width={}\".format(row[\"Nobj\"],row[\"height\"],row[\"width\"]))\n"," # for each object in the image, plot the bounding box\n"," for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])\n"," plt.show() ## show the plot\n"," print('These are the augmented training images.')\n","\n","else:\n"," for irow in ind_random:\n"," row = df_anno.iloc[irow,:]\n"," path = os.path.join(img_dir, row[\"fileID\"] + file_suffix)\n"," # read in image\n"," img = imageio.imread(path)\n","\n"," plt.figure(figsize=(12,12))\n"," plt.imshow(img) # plot image\n"," plt.title(\"Nobj={}, height={}, width={}\".format(row[\"Nobj\"],row[\"height\"],row[\"width\"]))\n"," # for each object in the image, 
plot the bounding box\n"," for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])\n"," plt.show() ## show the plot\n"," print('These are the non-augmented training images.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ud_Sx7MT5f4_","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a YOLOv2 model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"_cvRRrStGe3y","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pretrained network\n","\n","# Training_Source = \"\" #@param{type:\"string\"}\n","# Training_Source_annotation = \"\" #@param{type:\"string\"}\n","# Check if the right files exist\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","pretrained_model_path = \"\" #@param{type:\"string\"}\n","h5_file_path = pretrained_model_path+'/'+Weights_choice+'_weights.h5'\n","\n","if not os.path.exists(h5_file_path):\n"," print('WARNING pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","!sed -i 's@\\\"pretrained_weights\\\":.*,@\\\"pretrained_weights\\\": \\\"$h5_file_path\\\",@g' config.json\n","\n","if Use_pretrained_model == True:\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4):\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," learning_rate = bestLearningRate\n","\n"," if not \"learning rate\" in 
csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," #bestLearningRate = learning_rate\n"," #lastLearningRate = learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(learning_rate)+' will be used instead' + W)\n"," \n"," !sed -i 's@\\\"warmup_epochs\\\":.*,@\\\"warmup_epochs\\\": 0,@g' config.json\n"," !sed -i 's@\\\"learning_rate\\\":.*,@\\\"learning_rate\\\": $learning_rate,@g' config.json\n","\n","# with open(os.path.join(pretrained_model_path, 'Quality Control', 'lr.csv'),'r') as csvfile:\n","# csvRead = pd.read_csv(csvfile, sep=',')\n","# #print(csvRead)\n"," \n","# if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exists (compatibility with models trained in ZeroCostDL4Mic below 1.4)\n","# print(\"pretrained network learning rate found\")\n","# #find the last learning rate\n","# lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n","# #Find the learning rate corresponding to the lowest validation loss\n","# min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n","# #print(min_val_loss)\n","# bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n","# if Weights_choice == \"last\":\n","# print('Last learning rate: '+str(lastLearningRate))\n","\n","# if Weights_choice == \"best\":\n","# print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n","# if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n","# bestLearningRate = initial_learning_rate\n","# lastLearningRate = initial_learning_rate\n","# print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. 
Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR","colab_type":"text"},"source":["## **4.1. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches."]},{"cell_type":"code","metadata":{"id":"EZnoS3rb8BSR","colab_type":"code","cellView":"form","colab":{}},"source":["import time\n","import csv\n","#from frontend import YOLO\n","\n","if os.path.exists(full_model_path+\"/Quality Control\"):\n"," shutil.rmtree(full_model_path+\"/Quality Control\")\n","os.makedirs(full_model_path+\"/Quality Control\")\n","\n","start = time.time()\n","\n","#@markdown ##Start Training\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","train('config.json', full_model_path, percentage_validation)\n","\n","shutil.copyfile('/content/gdrive/My Drive/keras-yolo2/config.json',full_model_path+'/config.json')\n","\n","if os.path.exists('/content/gdrive/My Drive/keras-yolo2/best_map_weights.h5'):\n"," shutil.move('/content/gdrive/My Drive/keras-yolo2/best_map_weights.h5',full_model_path+'/best_map_weights.h5')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XQjQb_J_Qyku","colab_type":"text"},"source":["## **4.2. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is, however, wise to download the folder, as all data can be erased at the next training if the same folder is used."]},{"cell_type":"markdown","metadata":{"id":"2HbZd7rFqAad","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend performing quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"EdcnkCr9Nbl8","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","if (Use_the_current_trained_model): \n"," QC_model_folder = full_model_path\n","\n","#print(os.path.join(model_path, model_name))\n","\n","if os.path.exists(QC_model_folder):\n"," print(\"The \"+os.path.basename(QC_model_folder)+\" model will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n","\n","if Use_the_current_trained_model == False:\n"," if os.path.exists('/content/gdrive/My Drive/keras-yolo2/config.json'):\n"," os.remove('/content/gdrive/My Drive/keras-yolo2/config.json')\n"," shutil.copyfile(QC_model_folder+'/config.json','/content/gdrive/My Drive/keras-yolo2/config.json')\n","\n","#@markdown ###Which backend is the model using?\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/full_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/inception_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/mobilenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/squeezenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh","colab_type":"text"},"source":["## **5.1. 
Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","mAPDataFromCSV = []\n","with open(QC_model_folder+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n"," mAPDataFromCSV.append(float(row[2]))\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(20,15))\n","\n","plt.subplot(3,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(3,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","#plt.savefig(os.path.dirname(QC_model_folder)+'/Quality Control/lossCurvePlots.png')\n","#plt.show()\n","\n","plt.subplot(3,1,3)\n","plt.plot(epochNumber,mAPDataFromCSV, label='mAP score')\n","plt.title('mean average precision (mAP) vs. epoch number (linear scale)')\n","plt.ylabel('mAP score')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png')\n","plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"RZOPCVN0qcYb","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display an overlay of the input images ground-truth (solid lines) and predicted boxes (dashed lines). 
Additionally, the cell below will show the mAP value of the model on the QC data together with plots of the Precision-Recall curves for all the classes in the dataset. If you want to read in more detail about these scores, we recommend [this brief explanation](https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173).\n","\n"," The folders provided as \"Source_QC_folder\" and \"Annotations_QC_folder\" should contain the images (e.g. as .jpg) and the annotations (.xml files), respectively!\n","\n","Since the training saves three different models, for the best validation loss (`best_weights`), best average precision (`best_map_weights`) and the model after the last epoch (`last_weights`), you should choose which ones you want to use for quality control or prediction. We recommend using `best_map_weights` because they should yield the best performance on the dataset. However, it can be worth testing how well `best_weights` perform too.\n","\n","**mAP score:** This refers to the mean average precision of the model on the given dataset. This value gives an indication of how precise the predictions of the classes on this dataset are when compared to the ground-truth. 
Values closer to 1 indicate a good fit.\n","\n","**Precision:** This is the proportion of the correct classifications (true positives) in all the predictions made by the model.\n","\n","**Recall:** This is the proportion of the detected true positives in all the detectable data."]},{"cell_type":"code","metadata":{"id":"Nh8MlX3sqd_7","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Annotations_QC_folder = \"\" #@param{type:\"string\"}\n","\n","#@markdown ##Choose which model you want to evaluate:\n","model_choice = \"best_map_weights\" #@param[\"best_weights\",\"last_weights\",\"best_map_weights\"]\n","\n","file_suffix = os.path.splitext(os.listdir(Source_QC_folder)[0])[1]\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_folder+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_folder+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_folder+\"/Quality Control/Prediction\")\n","\n","#Delete old csv with box predictions if one exists\n","\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," os.remove('/content/predicted_bounding_boxes.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names.csv')\n","if os.path.exists(Source_QC_folder+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Source_QC_folder+'/.ipynb_checkpoints')\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","n_objects = []\n","for img in os.listdir(Source_QC_folder):\n"," full_image_path = Source_QC_folder+'/'+img\n"," n_obj = predict('config.json',QC_model_folder+'/'+model_choice+'.h5',full_image_path)\n"," n_objects.append(n_obj)\n","\n","for img in os.listdir(Source_QC_folder):\n"," if img.endswith('detected'+file_suffix):\n"," shutil.move(Source_QC_folder+'/'+img,QC_model_folder+\"/Quality 
Control/Prediction/\"+img)\n","\n","### Get the coordinates of the predicted boxes, ###\n","### box classes and confidence scores ###\n","\n","# from the csv containing the predicted boxes\n","with open('/content/predicted_bounding_boxes.csv','r', newline='') as csvfile:\n"," csv_reader = csv.reader(csvfile)\n"," next(csv_reader)\n"," pred_boxes = []\n"," pred_classes = []\n"," pred_conf = []\n"," for row in csv_reader:\n"," image_boxes = []\n"," box_classes = []\n"," box_conf = []\n"," for i in range(1,len(row),6):\n"," image_boxes.append(list(map(float,row[i:i+4])))\n"," box_classes.append(int(row[i+5]))\n"," box_conf.append(float(row[i+4]))\n"," pred_boxes.append(image_boxes) # The rows of this list contain the coordinates for all boxes per image\n"," pred_classes.append(box_classes) # The rows of this list contain the predicted classes for each box in the pred_boxes\n"," pred_conf.append(box_conf) # The rows of this list contain the confidence scores for each predicted box in pred_boxes\n","\n","#shutil.move('/content/predicted_bounding_boxes.csv',QC_model_folder+\"/Quality Control/Prediction/predicted_boxes_QC.csv\")\n","\n","#### Get the coordinates of the GT boxes ###\n","\n","df_anno_QC_gt = []\n","#dir_anno = Training_Source_annotations\n","for fnm in os.listdir(Annotations_QC_folder): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(Annotations_QC_folder,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno_QC_gt.append(row)\n","df_anno_QC_gt = pd.DataFrame(df_anno_QC_gt)\n","\n","#df_anno_QC_gt.to_csv('/content/gt_bboxes_QC.csv')\n","maxNobj = np.max(df_anno_QC_gt[\"Nobj\"])\n","\n","config_path = '/content/gdrive/My Drive/keras-yolo2/config.json'\n","class_dict = {}\n","\n","with open(config_path) as config_buffer:\n"," config = json.load(config_buffer)\n"," for i in config[\"model\"][\"labels\"]:\n"," class_dict[i] = 
int(config[\"model\"][\"labels\"].index(i))\n","\n","reverse_class_dict = {value : key for (key, value) in class_dict.items()}\n","\n","df_anno_QC_gt = df_anno_QC_gt.replace(class_dict)\n","df_anno_QC_gt.to_csv(QC_model_folder+'/Quality Control/gt_bboxes_QC.csv')\n","\n","gt_boxes = []\n","gt_labels = []\n","gt_label_names = []\n","for j in range(0,df_anno_QC_gt.shape[0]):\n"," row = df_anno_QC_gt.iloc[j]\n"," width = int(row[\"width\"])\n"," height = int(row[\"height\"])\n"," gt_box = []\n"," gt_label = []\n"," gt_label_name = []\n"," for i in range(row[\"Nobj\"]):\n"," label = int(float(row[\"bbx_{}_name\".format(i)]))\n"," label_name = row[\"bbx_{}_name\".format(i)]\n"," x1=row[\"bbx_{}_xmin\".format(i)]\n"," y1=row[\"bbx_{}_ymin\".format(i)]\n"," x2=row[\"bbx_{}_xmax\".format(i)]\n"," y2=row[\"bbx_{}_ymax\".format(i)]\n"," #gt_box.append([x1/width,y1/height,x2/width,y2/height])\n"," gt_box.append([x1,y1,x2,y2])\n","\n"," gt_label.append(label)\n"," gt_label_name.append(label_name)\n"," gt_boxes.append(gt_box)\n"," gt_labels.append(gt_label)\n"," gt_label_names.append(gt_label_name)\n","\n","#The essential outputs from this are gt_array and gt_classes_full\n","#Each row contains all bounding boxes and classes for each gt image.\n","\n","#Here we create the Detection Maps for the first three predictions\n","#Prediction\n","\n","pred_box_1 = np.array(pred_boxes[0])\n","#pred_box_2 = np.array(pred_boxes[1])\n","#pred_box_3 = np.array(pred_boxes[2])\n","\n","pred_class_1 = np.array(pred_classes[0])\n","#pred_class_2 = np.array(pred_classes[1])\n","#pred_class_3 = np.array(pred_classes[2])\n","\n","pred_conf_1 = np.array(pred_conf[0])\n","#pred_conf_2 = np.array(pred_conf[1])\n","#pred_conf_3 = np.array(pred_conf[2])\n"," \n","#print(pred_box_1)\n","\n","#print(pred_conf_1)\n","\n","# #GT\n","#print(gt_box_1[0])\n","gt_box_1 = np.array(gt_boxes[0])\n","#gt_box_2 = np.array(gt_boxes[1])\n","#gt_box_3 = np.array(gt_boxes[2])\n","#print(gt_box_1)\n","\n","gt_class_1 = 
np.array(gt_labels[0])\n","#gt_class_2 = np.array(gt_labels[1])\n","#gt_class_3 = np.array(gt_labels[2])\n","\n","frames = [(pred_box_1, pred_class_1, pred_conf_1, gt_box_1, gt_class_1)]\n"," #(pred_box_2, pred_class_2, pred_conf_2, gt_box_3, gt_class_3),#]#,\n"," #(pred_box_3, pred_class_3, pred_conf_3, gt_box_1, gt_class_1)#]#,\n"," #]\n"," #]\n","\n","n_class = len(config['model']['labels'])\n","\n","plt.figure(figsize=(15,5))\n","for i, frame in enumerate(frames):\n"," img = np.array(io.imread(os.path.join(Source_QC_folder,os.path.splitext(sorted(os.listdir(Annotations_QC_folder))[i])[0]+file_suffix)))\n"," show_frame(*frame, reverse_class_dict, background = img)\n","\n","\n","#Make a csv file to read into imagej macro, to create custom bounding boxes\n","header = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n","with open('/content/predicted_bounding_boxes.csv', newline='') as inFile, open('/content/predicted_bounding_boxes_new.csv', 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n","\n","df_bbox=pd.read_csv('/content/predicted_bounding_boxes_new.csv')\n","df_bbox=df_bbox.transpose()\n","new_header = df_bbox.iloc[0] #grab the first row for the header\n","df_bbox = df_bbox[1:] #take the data less the header row\n","df_bbox.columns = new_header #set the header row as the df header\n","df_bbox.sort_values(by='filename',axis=1,inplace=True)\n","df_bbox.to_csv(QC_model_folder+'/Quality Control/predicted_bounding_boxes_for_custom_ROI_QC.csv')\n","\n","AP, recall, precision = _calc_avg_precisions(config,Source_QC_folder,Annotations_QC_folder+'/',QC_model_folder+'/'+model_choice+'.h5',0.3,0.3)\n","\n","print('mAP score for QC dataset: '+str(sum(AP.values())/len(AP)))\n","for i in range(len(AP)):\n"," if 
AP[i]!=0:\n"," if len(recall[i]) == 1:\n"," new_recall = np.linspace(0,list(recall[i])[0],10)\n"," new_precision = list(precision[i])*10\n"," fig = plt.figure(figsize=(3,2))\n"," plt.plot(new_recall,new_precision)\n"," plt.axis([min(new_recall),1,0,1.02])\n"," plt.xlabel('Recall',fontsize=14)\n"," plt.ylabel('Precision',fontsize=14)\n"," plt.title(config['model']['labels'][i]+', AP: '+str(round(AP[i],3)),fontsize=14)\n"," plt.fill_between(new_recall,new_precision,alpha=0.3)\n"," plt.savefig('/content/P-R_curve_'+str(i)+'.png')\n"," plt.show()\n"," else:\n"," new_recall = list(recall[i])\n"," new_recall.append(new_recall[len(new_recall)-1])\n"," new_precision = list(precision[i])\n"," new_precision.append(0)\n"," fig = plt.figure(figsize=(3,2))\n"," plt.plot(new_recall,new_precision)\n"," plt.axis([min(new_recall),1,0,1.02])\n"," plt.xlabel('Recall',fontsize=14)\n"," plt.ylabel('Precision',fontsize=14)\n"," plt.title(config['model']['labels'][i]+', AP: '+str(round(AP[i],3)),fontsize=14)\n"," plt.fill_between(new_recall,new_precision,alpha=0.3)\n"," plt.savefig('/content/P-R_curve_'+str(i)+'.png')\n"," plt.show()\n"," else:\n"," print('No object of class '+config['model']['labels'][i]+' was detected. This will lower the mAP score. 
Consider adding an image containing this class to your QC dataset to see if the model can detect this class at all.')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"-n9CLLJ77FAA","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Inspect example output from QC\n","import random\n","from matplotlib.pyplot import imread\n","import imageio\n","\n","add_header('/content/predicted_bounding_boxes_names.csv','/content/predicted_bounding_boxes_names_new.csv')\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","file_suffix = os.path.splitext(random_choice)[1]\n","\n","plt.figure(figsize=(30,15))\n","\n","\n","### Display Raw input ###\n","\n","x = imread(Source_QC_folder+\"/\"+random_choice)\n","plt.subplot(1,3,1)\n","plt.axis('off')\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","\n","### Display Predicted annotation ###\n","\n","df_bbox2 = pd.read_csv('/content/predicted_bounding_boxes_names_new.csv')\n","for img in range(0,df_bbox2.shape[0]):\n"," df_bbox2.iloc[img]\n"," row = pd.DataFrame(df_bbox2.iloc[img])\n"," if row[img][0] == random_choice:\n"," row = row.dropna()\n"," image = imageio.imread(Source_QC_folder+'/'+row[img][0])\n"," #plt.figure(figsize=(12,12))\n"," plt.subplot(1,3,2)\n"," plt.axis('off')\n"," plt.imshow(image) # plot image\n"," plt.title('Prediction')\n"," for i in range(1,int(len(row)-1),6):\n"," plt_rectangle(plt,\n"," label = row[img][i+5],\n"," x1=row[img][i],#.format(iplot)],\n"," y1=row[img][i+1],\n"," x2=row[img][i+2],\n"," y2=row[img][i+3])#,\n"," #fontsize=8)\n","\n","\n","### Display GT Annotation ###\n","\n","df_anno_QC_gt = []\n","for fnm in os.listdir(Annotations_QC_folder): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(Annotations_QC_folder,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] 
= os.path.splitext(fnm)[0]\n"," df_anno_QC_gt.append(row)\n","df_anno_QC_gt = pd.DataFrame(df_anno_QC_gt)\n","#maxNobj = np.max(df_anno_QC_gt[\"Nobj\"])\n","\n","for i in range(0,df_anno_QC_gt.shape[0]):\n"," if df_anno_QC_gt.iloc[i][\"fileID\"]+file_suffix == random_choice:\n"," row = df_anno_QC_gt.iloc[i]\n","\n","img = imageio.imread(Source_QC_folder+'/'+random_choice)\n","plt.subplot(1,3,3)\n","plt.axis('off')\n","plt.imshow(img) # plot image\n","plt.title('Ground Truth annotations')\n","\n","# for each object in the image, plot the bounding box\n","for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])#,\n"," #fontsize=8)\n","\n","### Show the plot ###\n","plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Esqnbew8uznk"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images (the `_detected` files) are saved in your **Result_folder**, together with a CSV file of the predicted bounding boxes.\n","\n","**`Data_folder`:** This folder should contain the images that you want to process with your trained network.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`Prediction_model_path`:** This should be the folder that contains your model."]},{"cell_type":"code","metadata":{"id":"9ZmST3JRq-Ho","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","file_suffix = os.path.splitext(os.listdir(Data_folder)[0])[1]\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, provide the name of the model and path to model folder:\n","\n","Prediction_model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Which model do you want to use?\n","model_choice = \"best_map_weights\" #@param[\"best_weights\",\"last_weights\",\"best_map_weights\"]\n","\n","#@markdown ###Which backend is the model using?\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/full_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/inception_backend.h5'):\n"," !wget 
https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/mobilenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/squeezenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_path = full_model_path\n","\n","if Use_the_current_trained_model == False:\n"," if os.path.exists('/content/gdrive/My Drive/keras-yolo2/config.json'):\n"," os.remove('/content/gdrive/My Drive/keras-yolo2/config.json')\n"," shutil.copyfile(Prediction_model_path+'/config.json','/content/gdrive/My Drive/keras-yolo2/config.json')\n","\n","if os.path.exists(Prediction_model_path+'/'+model_choice+'.h5'):\n"," print(\"The \"+os.path.basename(Prediction_model_path)+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","# Provide the code for performing predictions and saving them\n","print(\"Images saved into folder:\", Result_folder)\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"GcmBwMJVcFh1","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run Prediction\n","\n","#Remove any files that might be from the prediction of QC examples.\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," os.remove('/content/predicted_bounding_boxes.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_new.csv'):\n"," os.remove('/content/predicted_bounding_boxes_new.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names_new.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names_new.csv')\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","\n","if os.path.exists(Data_folder+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Data_folder+'/.ipynb_checkpoints')\n","\n","n_objects = []\n","for img in os.listdir(Data_folder):\n"," full_image_path = Data_folder+'/'+img\n"," n_obj = predict('config.json',Prediction_model_path+'/'+model_choice+'.h5',full_image_path)#,Result_folder)\n"," n_objects.append(n_obj)\n","\n","for img in os.listdir(Data_folder):\n"," if img.endswith('detected'+file_suffix):\n"," shutil.move(Data_folder+'/'+img,Result_folder+'/'+img)\n","\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," #shutil.move('/content/predicted_bounding_boxes.csv',Result_folder+'/predicted_bounding_boxes.csv')\n"," print('Bounding box labels and coordinates saved to '+ Result_folder)\n","else:\n"," print('For some reason the bounding box labels and coordinates were not saved. 
Check that your predictions look as expected.')\n","\n","#Make a csv file to read into imagej macro, to create custom bounding boxes\n","header = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n","with open('/content/predicted_bounding_boxes.csv', newline='') as inFile, open('/content/predicted_bounding_boxes_new.csv', 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n","\n","df_bbox=pd.read_csv('/content/predicted_bounding_boxes_new.csv')\n","df_bbox=df_bbox.transpose()\n","new_header = df_bbox.iloc[0] #grab the first row for the header\n","df_bbox = df_bbox[1:] #take the data less the header row\n","df_bbox.columns = new_header #set the header row as the df header\n","df_bbox.sort_values(by='filename',axis=1,inplace=True)\n","df_bbox.to_csv(Result_folder+'/predicted_bounding_boxes_for_custom_ROI.csv')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EIe3CRD7XUxa","colab_type":"text"},"source":["## **6.2. 
Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"LmDP8xiwXTTL","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","import random\n","from matplotlib.pyplot import imread\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","print(random_choice)\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+os.path.splitext(random_choice)[0]+'_detected'+file_suffix)\n","\n","plt.figure(figsize=(20,8))\n","\n","plt.subplot(1,3,1)\n","plt.axis('off')\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","\n","plt.subplot(1,3,2)\n","plt.axis('off')\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Predicted output');\n","\n","add_header('/content/predicted_bounding_boxes_names.csv','/content/predicted_bounding_boxes_names_new.csv')\n","\n","#We need to edit this predicted_bounding_boxes_new.csv file slightly to display the bounding boxes\n","df_bbox2 = pd.read_csv('/content/predicted_bounding_boxes_names_new.csv')\n","for img in range(0,df_bbox2.shape[0]):\n"," df_bbox2.iloc[img]\n"," row = pd.DataFrame(df_bbox2.iloc[img])\n"," if row[img][0] == random_choice:\n"," row = row.dropna()\n"," image = imageio.imread(Data_folder+'/'+row[img][0])\n"," #plt.figure(figsize=(12,12))\n"," plt.subplot(1,3,3)\n"," plt.axis('off')\n"," plt.title('Alternative Display of Prediction')\n"," plt.imshow(image) # plot image\n","\n"," for i in range(1,int(len(row)-1),6):\n"," plt_rectangle(plt,\n"," label = row[img][i+5],\n"," x1=row[img][i],#.format(iplot)],\n"," y1=row[img][i+1],\n"," x2=row[img][i+2],\n"," y2=row[img][i+3])#,\n"," #fontsize=8)\n"," #plt.margins(0,0)\n"," #plt.subplots_adjust(left=0., right=1., top=1., bottom=0.)\n"," 
#plt.gca().xaxis.set_major_locator(plt.NullLocator())\n"," #plt.gca().yaxis.set_major_locator(plt.NullLocator())\n"," plt.savefig('/content/detected_cells.png',bbox_inches='tight',transparent=True,pad_inches=0)\n","plt.show() ## show the plot\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"Rn9zpWpo0xNw","colab_type":"text"},"source":["\n","#**Thank you for using YOLOv2!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb b/Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb index a87c5ff2..7995fce1 100755 --- a/Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ 
-{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"fnet_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1G6lQzjd259Yoy_OozBhJolF4HraE52PG","timestamp":1591353884724},{"file_id":"1pSC680miQesRinU8Tjn7X6AmJXNtxUNI","timestamp":1591182507229},{"file_id":"1ajYZgvhQfpcUZ5YWlUeB-GUU_j-njsqw","timestamp":1589209398121},{"file_id":"1QiFrHg_cVlOl_yzu-RO9mMIrA2L1dXwj","timestamp":1587744376035},{"file_id":"1_S3UtNcuAaZhVc4yqlFDHc2eKq1x-ynn","timestamp":1587058075616},{"file_id":"1Gce_llcAX7yJTFZP2HiNpTL56gXR7PQ-","timestamp":1586854238074},{"file_id":"10l0NA5VWlqRvDlJRTxOiOUgN5LxEo2gy","timestamp":1586601464429},{"file_id":"1NSdad2BEDJZ16AO3SEEaG-ZSe0o4u3eY","timestamp":1586368373257},{"file_id":"1ubiSLYW3G4eNGNF31e2Vbw_3jMHJ9Y7M","timestamp":1585303720184},{"file_id":"1O6YzESEk9VFr6Nc6ijOAYCtiP80uuh7I","timestamp":1585248652537},{"file_id":"1DPrSIbf-ML-LIO2e4YhL1KedWVsVcFlT","timestamp":1585232236512},{"file_id":"1Qanbeybd44tHmdzKxTJAMDD4trFdCYwD","timestamp":1585049767771},{"file_id":"1Fr9Ea5QdUgK0CKfQKpq9KrxtxxAkSVwc","timestamp":1584619265981},{"file_id":"1RQ6XuOBIRaWgId2WKO2i-MMnXoKn_tNA","timestamp":1584541702239},{"file_id":"1mAvQKCCelwK8zPkAWFvKtiAsE_35KSpW","timestamp":1584533728194},{"file_id":"1LdMzIh-v-gUXnd6v9U2Ov28T-XpeT1PP","timestamp":1584463518766},{"file_id":"18Y0NabtThelB0uOAJlg7UbjHPYMEoCqW","timestamp":1584455459923},{"file_id":"1ZCnLW6HUl0bXrPa-54-bv_C9f6jYL0T4","timestamp":1584436296801},{"file_id":"1gTLXTd_rOpXmlktZz2yeEW62gY8ety-I","timestamp":1583941948440},{"file_id":"1gC_pmaDD73tD-yNoFGjHEolYfLd_7czL","timestamp":1583593255888},{"file_id":"17pZee2Vp0kCh3W8pfzRYk8asqk35mOfw","timestamp":1583335080677},{"file_id":"1KyYm3JglQpPYnf-aBLLiP-sFgi_A0Og1","timestamp":1583291424450},{"file_id":"1ZJCI2p66noTaLCnVUQJkTR16ig6GAqAx","timestamp":1576151149296}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 
3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"C-wdtVN5KUFi","colab_type":"text"},"source":["#**Label-free prediction - fnet**\n","---\n","\n"," \n","Label-free prediction (fnet) is a neural network developed to infer the distribution of specific cellular structures from label-free images such as brightfield or EM images. It was first published in 2018 by [Ounkomol *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0111-2). The network uses a common U-Net architecture and is trained using paired imaging volumes from the same field of view, imaged in a label-free (e.g. brightfield) and labelled condition (e.g. fluorescence images of a specific label of interest). When trained, this allows the user to identify certain structures from brightfield images alone. The performance of fnet may depend significantly on the structure at hand.\n","\n","---\n"," *Disclaimer*:\n","\n"," This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n"," This notebook is largely based on the paper: \n","\n","**Label-free prediction of three-dimensional fluorescence images from transmitted light microscopy** by Ounkomol *et al.* in Nature Methods, 2018 (https://www.nature.com/articles/s41592-018-0111-2)\n","\n"," And source code found in: https://github.com/AllenCellModeling/pytorch_fnet\n","\n"," **Please also cite this original paper when using or developing this notebook.** \n"]},{"cell_type":"markdown","metadata":{"id":"Qt5Yt1vsD163","colab_type":"text"},"source":["# **How to use this notebook?**\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. 
Click the content to move quickly between sections.\n","\n","*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"zwILBhMkzKp_","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n","\n"," This notebook provides two options: firstly, to download and train Fnet with data published in the original manuscript or secondly, to upload a personal dataset and train Fnet on it.\n"," The notebook may require a large amount of disk space. If using the datasets from the paper, the user's Google Drive should have at least 40GB of free space."]},{"cell_type":"markdown","metadata":{"id":"pcNfrIVpNZC-","colab_type":"text"},"source":["---\n","**Data Format**\n","\n"," **The data used to train fnet must be 3D stacks in .tiff (.tif) file format and contain the signal (e.g. bright-field image) and the target channel (e.g. fluorescence) for each field of view**. 
To use this notebook on user data, upload the data in the following format to your google drive. To ensure corresponding images are used during training give corresponding signal and target images the same name.\n","\n","Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n"," **Note: Your *dataset_folder* should not have spaces or brackets in its name as this is not recognized by the fnet code and will throw an error** \n","\n","\n","* Experiment A\n"," - **Training dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif, ...\n"," - fluorescence images\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif\n"," - fluorescence images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"I0aF5U_Y0IFW","colab_type":"text"},"source":["# **1. 
Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"EBHobPtQ8wx7","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"UphYcwdDS8yO","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"kRVmtCZB9OQ2","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 
\n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"QTEFQc6j9RTv","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yk96o-_u-27d","colab_type":"text"},"source":["#**2. Install fnet and dependencies**\n","---\n","Running fnet requires the fnet folder to be downloaded into the session's Files. As fnet needs several packages to be installed, this step may take a few minutes.\n","\n","You can ignore **the error warnings** as they refer to packages not required for this notebook.\n","\n","**Note: It is not necessary to keep the pytorch_fnet folder after you are finished using the notebook, so it can be deleted afterwards by playing the last cell (bottom).**"]},{"cell_type":"code","metadata":{"id":"BbYpGlfskzrO","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play this cell to download fnet to your drive. 
If it is already installed this will only install the fnet dependencies.\n","import os\n","import csv\n","import shutil\n","import random\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","import sys\n","import numpy as np\n","import shutil\n","import os\n","from tempfile import mkstemp\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from skimage import img_as_float32\n","from distutils.dir_util import copy_tree\n","import datetime\n","import time\n","\n","#Ensure tensorflow 1.x\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","#clone fnet from github to colab\n","!pip install -U scipy==1.2.0\n","!pip install matplotlib==2.2.3\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet'):\n"," !git clone -b release_1 --single-branch https://github.com/AllenCellModeling/pytorch_fnet.git; cd pytorch_fnet; pip install .\n"," shutil.move('/content/pytorch_fnet','/content/gdrive/My Drive/pytorch_fnet')\n","from skimage import io\n","from matplotlib import pyplot as plt\n","import pandas as pd\n","#from skimage.util import img_as_uint\n","import matplotlib as mpl\n","#from scipy import signal\n","#from scipy import ndimage\n","\n","\n","#This function replaces the old default files with new values\n","def replace(file_path, pattern, subst):\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," #Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","def insert_line_to_file(filepath,line_number,insertion):\n"," f = open(filepath, \"r\")\n"," contents = 
f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not insertion in f.read():\n"," contents.insert(line_number, insertion)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","def add_validation(filepath,line_number,insert,append):\n"," f = open(filepath, \"r\")\n"," contents = f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not 'PATH_DATASET_VAL_CSV=' in f.read():\n"," contents.insert(line_number, insert)\n"," contents.append(append)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: 
ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","#Here we replace values in the old files\n","#Change maximum pixel number\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/fnet/transforms.py\",'n_max_pixels=9732096','n_max_pixels=20000000')\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/predict.py\",'6000000','20000000')\n","\n","#Prevent resizing in the training and the prediction\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/predict.py\",\"0.37241\",\"1.0\")\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"0.37241\",\"1.0\")\n","\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"JqCe6m-C_PrH","colab_type":"text"},"source":["#**3. Select your paths and parameters**\n","---"]},{"cell_type":"markdown","metadata":{"id":"w5NmDpJ4xvWE","colab_type":"text"},"source":["## **3.1. Setting the main training parameters**\n","---\n"," **Paths for training data**\n","\n"," **`Training_source`,`Training_target:`** These are the paths to your folders containing the Training_source (brightfield) and Training_target (fluorescent label) training data respectively. 
To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**Note: The stacks for fnet should either have 32 or more slices or have a number of slices which are a power of 2 (e.g. 2,4,8,16).**\n","\n"," **Training Parameters**\n","\n"," **`steps:`** Input how many iterations you want to train the network for. A larger number may improve performance but risks overfitting to the training data. To reach good performance of fnet requires several 10000's iterations which will usually require **several hours**, depending on the dataset size. **Default: 50000**\n","\n","**`batch_size:`** Reducing or increasing the **batch size** may speed up or slow down your training, respectively and can influence network performance. **Default: 4**"]},{"cell_type":"code","metadata":{"id":"PWxNzzgKu9Kb","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###Datasets\n","#Datasets\n","\n","#Change checkpoints\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","#replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","#@markdown ###Model name and model path\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = 
\"\" #@param {type:\"string\"}\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","#dataset = model_name #The name of the dataset and the model will be the same\n","\n","#Here, we check if the dataset already exists. If not, copy the dataset from google drive to the data folder\n"," \n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name):\n"," #shutil.copytree(own_dataset,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset)\n"," os.makedirs('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name)\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","elif os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name) and not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","elif os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name) and os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n","#Create a path_csv file to point to the training images\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","source = 
os.listdir('./'+model_name+'/'+source_name)\n","target = os.listdir('./'+model_name+'/'+target_name)\n","\n","#print(\"Selected \"+dataset+\" as training set\")\n","\n","model_name_x = model_name+\"}\" # this variable is only used to ensure closed curly brackets when editing the .sh files\n","\n","#We need to declare that we will run validation on the dataset\n","#We need to add a new line to the train.sh file\n","with open(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\", \"r\") as f:\n"," if not \"gpu_ids ${GPU_IDS} \\\\\" in f.read():\n"," replace(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",\" --gpu_ids ${GPU_IDS}\",\" --gpu_ids ${GPU_IDS} \\\\\")\n","\n","#We add the necessary validation parameters here.\n","insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n","append = '\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}'\n","add_validation(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",10,insert,append)\n","\n","#Clear the White space from train.sh\n","\n","with open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh','/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","\n","#Here we define the random set of training files to be used for validation\n","val_files = random.sample(source,len(source)//10)\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target'):\n"," 
shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n","#Make validation directories\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","#Move a random set of files from the training to the validation folders\n","for file in val_files:\n"," shutil.move('./'+model_name+'/'+source_name+'/'+file,'./'+model_name+'/Validation_Input/'+file)\n"," shutil.move('./'+model_name+'/'+target_name+'/'+file,'./'+model_name+'/Validation_Target/'+file)\n","\n","#Redefine the source and target lists after moving the validation files\n","source = os.listdir('./'+model_name+'/'+source_name)\n","target = os.listdir('./'+model_name+'/'+target_name)\n","\n","#Define Validation file lists\n","val_signal = os.listdir('./'+model_name+'/Validation_Input')\n","val_target = os.listdir('./'+model_name+'/Validation_Target')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n","#Finally, we create a validation csv file to construct the validation dataset\n","with open(model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(val_signal)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Input/\"+val_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Target/\"+val_target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv'):\n"," 
os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","with open(model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","\n","#@markdown ---\n","\n","#@markdown ###Training Parameters\n","\n","#Training parameters in fnet are indicated in the train_model.sh file.\n","#Here, we edit this file to include the desired parameters\n","\n","#1. Add permissions to train_model.sh\n","os.chdir(\"/content/gdrive/My Drive/pytorch_fnet/scripts\")\n","!chmod u+x train_model.sh\n","\n","#2. Select parameters\n","steps = 50000#@param {type:\"number\"}\n","batch_size = 4#@param {type:\"number\"}\n","number_of_images = len(source)\n","\n","#3. Insert the above values into train_model.sh\n","!if ! 
grep saved_models\\/\\${ train_model.sh;then sed -i 's/saved_models\\/.*/saved_models\\/\\${DATASET}\"/g' train_model.sh; fi \n","!sed -i \"s/1:-.*/1:-$model_name_x/g\" train_model.sh #change the dataset to be trained with\n","!sed -i \"s/N_ITER=.*/N_ITER=$steps/g\" train_model.sh #change the number of training iterations (steps)\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","#We also change the training split as in our notebook the test images are used separately for prediction and we want fnet to train on the whole training data set.\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" train_model.sh\n","\n","#If new parameters are inserted here for training a model with the same name\n","#the previous training csv needs to be removed, to prevent the model using the old training split or paths.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"BCKcSJxkxi33","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"msrTTcPI1Cav","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating images in XY-Plane and flip them along X-Axis. 
This only works if the patches are square in XY.\n","\n","**Note:** Using a full augmented dataset can exceed the RAM limitations of the colab notebook. If the augmented dataset is too large, the notebook will therefore only pick a subset of the augmented dataset for training. Make sure you only augment datasets which are small (ca. 20-30 images)."]},{"cell_type":"code","metadata":{"id":"u_YFN6Bd594L","colab_type":"code","cellView":"form","colab":{}},"source":["from skimage import io\n","import numpy as np\n","\n","Use_Data_augmentation = True #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = 
np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," 
io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n","\n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name,flip=Flip)\n"," \n"," elif Rotation == False and Flip == True:\n"," flip('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," \n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.rmtree('/content/gdrive/My 
Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," #Fetch the path and extract the name of the signal folder\n"," Training_source = Saving_path+\"/augmented_source\"\n"," source_name = os.path.basename(os.path.normpath(Training_source))\n","\n"," #Fetch the path and extract the name of the signal folder\n"," Training_target = Saving_path+\"/augmented_target\"\n"," target_name = os.path.basename(os.path.normpath(Training_target))\n","\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n"," #Redefine the source and target lists after moving the validation files\n"," source = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," target = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n"," with open(model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n"," 
shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","\n"," #Here, we ensure that the all files, including Validation are saved somewhere together for later access, e.g. for retraining.\n"," for image in os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input'):\n"," shutil.copyfile(os.path.join('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input',image),Saving_path+'/augmented_source/'+image)\n"," shutil.copyfile(os.path.join('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target',image),Saving_path+'/augmented_target/'+image)\n"," \n"," if len(source)>130:\n"," number_of_images = 130\n"," else:\n"," number_of_images = len(source)\n","\n"," os.chdir(\"/content/gdrive/My Drive/pytorch_fnet/scripts\")\n"," !chmod u+x train_model.sh\n"," !sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","\n"," print(\"Done\")\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"heuBzM5JADYf","colab_type":"text"},"source":["#**4. Train the model**\n","---\n","Before training, carefully read the different options. This applies especially if you have trained fnet on a dataset before.\n","\n","\n","###**Choose one of the options to train fnet**.\n","\n","**4.1.** If this is the first training on the chosen dataset, play this section to start training.\n","\n","**4.2.** If you want to continue training on an already pre-trained model choose this section\n","\n"," **Carefully read the options before starting training.**"]},{"cell_type":"markdown","metadata":{"id":"eLllOs_rA62U","colab_type":"text"},"source":["##**4.1. Train a new model**\n","---\n","\n","####Play the cell below to start training. 
\n","\n","**Note:** If you are training with a model of the same name as before, the model will be overwritten. If you want to keep the previous model, save it before playing the cell below or give your model a different name (section 3)."]},{"cell_type":"code","metadata":{"id":"xe3TLu7M-3Dk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##4.1. Start training\n","\n","start = time.time()\n","\n","#Delete any previous model with the same name from saved_models before training\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name)\n","\n","#This tifffile release runs error-free in this version of fnet.\n","!pip install tifffile==2019.7.26\n","\n","#Here we import an additional module to the functions.py file to run it without errors.\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","insert_line_to_file(\"/content/gdrive/My Drive/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\")\n","\n","\n","print(\"Let's start the training!\")\n","#Here we start the training\n","!./scripts/train_model.sh $model_name 0\n","\n","#After training overwrite any existing model in the model_path with the new trained model.\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,model_path+'/'+model_name)\n","\n","shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv',model_path+'/'+model_name+'/'+model_name+'_val.csv')\n","#Get rid of duplicates of training data in pytorch_fnet after training completes\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","shutil.rmtree('/content/gdrive/My 
Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"fpXr4JlCd5uV","colab_type":"text"},"source":["**Note:** Fnet takes a long time for training. If your notebook times out due to the length of the training or due to a loss of GPU acceleration the last checkpoint will be saved in the saved_models folder in the pytorch_fnet folder. If you want to save it in a more convenient location on your drive, remount the drive (if you got disconnected) and in the next cell enter the location (`model_path`) where you want to save the model (`model_name`) before continuing in 4.2. **If you did not time out you can ignore this section.**"]},{"cell_type":"code","metadata":{"id":"x41OhmO-hsX3","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play this cell if your model training timed out and indicate where you want to save the last checkpoint.\n","\n","import shutil\n","import os\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name):\n"," shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,model_path+'/'+model_name)\n","else:\n"," print('This model name does not exist in your saved_models folder. Make sure you have entered the name of the model that timed out.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QefQX9WUBz0G","colab_type":"text"},"source":["##**4.2. 
Training from a previously saved model**\n","---\n","This section allows you to use networks you have previously trained and saved and to continue training them for more training steps. The folders have the same meaning as above (3.1.). If you want to save the previously trained model, create a copy now as this section will overwrite the weights of the old model. **You can currently only train the model with the same dataset and batch size that the network was previously trained on.**\n","\n","**Note: To use this section the *pytorch_fnet* folder must be in your *gdrive/My Drive*. (Simply, play cell 2. to make sure).**"]},{"cell_type":"code","metadata":{"id":"2-0m_-tF9oo-","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown To test if performance improves after the initial training, you can continue training on the old model. This option can also be useful if Colab disconnects or times out.\n","#@markdown Enter the paths of the datasets you want to continue training on.\n","\n","#Here we replace values in the old files\n","\n","insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n","append = '\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}'\n","\n","add_validation(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",10,insert,append)\n","#Clear the White space from train.sh\n","\n","with open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh','/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","\n","#Datasets\n","\n","#Change checkpoints\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, 
default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","Pretrained_model_folder = \"\" #@param{type:\"string\"}\n","#model_name = \"\" #@param {type:\"string\"}\n","\n","Pretrained_model_name = os.path.basename(Pretrained_model_folder)\n","Pretrained_model_path = os.path.dirname(Pretrained_model_folder)\n","batch_size = 4 #@param {type:\"number\"}\n","\n","Pretrained_model_name_x = Pretrained_model_name+\"}\"\n","\n","#Move your model to fnet\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name):\n"," shutil.copytree(Pretrained_model_folder,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name)\n","\n","#Move the datasets into fnet\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name)\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name)\n","shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name)\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')\n","\n","### number_of_images = len(os.listdir(Training_source)) ###\n","\n","#Change the train_model.sh file to include chosen dataset\n","!chmod u+x 
./train_model.sh\n","!sed -i \"s/1:-.*/1:-$Pretrained_model_name_x/g\" train_model.sh\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" train_model.sh #Use the whole training dataset for training\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","\n","# We will use the same validation files from the training dataset as used before,\n","# This makes sure that the model is not validated with files it has seen in training before saving.\n","\n","#First we get the names of the validation files from the previous training which are saved in the validation csv.\n","val_source_list = []\n","\n","#Read the validation csv saved inside the pretrained model folder.\n","#Pretrained_model_folder is a full path, so it cannot be appended to the csvs folder path.\n","with open(Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv', 'r') as f:\n"," contents = csv.reader(f,delimiter=',')\n"," for row in contents:\n"," val_source_list.append(row[0])\n","\n","#Get the file list without the header\n","val_source_list = val_source_list[1::]\n","\n","#Get only the file names and not the full path\n","for i in range(0,len(val_source_list)):\n"," val_source_list[i] = os.path.basename(os.path.normpath(val_source_list[i]))\n","\n","source = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input')\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target')\n","\n","#Make validation 
directories\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input')\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target')\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","#Move a random set of files from the training to the validation folders\n","for file in val_source_list:\n"," #os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name+'/'+file,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input/'+file)\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name+'/'+file,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target/'+file)\n","\n","#Redefine the source and target lists after moving the validation files\n","source = os.listdir('./'+Pretrained_model_name+'/'+source_name)\n","target = os.listdir('./'+Pretrained_model_name+'/'+target_name)\n","\n","#Define Validation file lists\n","val_signal = os.listdir('./'+Pretrained_model_name+'/Validation_Input')\n","val_target = os.listdir('./'+Pretrained_model_name+'/Validation_Target')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv')\n","\n","shutil.copyfile(Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv')\n","\n","#Make a training csv file.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name)\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","source = 
os.listdir('./'+Pretrained_model_name+'/'+source_name)\n","target = os.listdir('./'+Pretrained_model_name+'/'+target_name)\n","with open(Pretrained_model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Pretrained_model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Pretrained_model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'.csv')\n","\n","#Find the number of previous training iterations (steps) from loss csv file\n","\n","with open(Pretrained_model_folder+'/losses.csv') as f:\n"," previous_steps = sum(1 for line in f)\n","print('continuing training after step '+str(previous_steps-1))\n","\n","print('To start re-training play section 4.2. below')\n","\n","#@markdown For how many additional steps do you want to train the model?\n","add_steps = 50000#@param {type:\"number\"}\n","\n","#Calculate the new number of total training epochs. Subtract 1 to discount the title row of the csv file.\n","new_steps = previous_steps + add_steps -1\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')\n","\n","#Edit train_model.sh file to include new total number of training epochs\n","!sed -i \"s/N_ITER=.*/N_ITER=$new_steps/g\" train_model.sh"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"vH3EzxbfD6Uk","colab_type":"code","cellView":"form","colab":{}},"source":["import datetime\n","import time\n","start = time.time()\n","\n","#@markdown ##4.2. 
Start re-training model\n","!pip install tifffile==2019.7.26\n","import os\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/fnet')\n","\n","insert_line_to_file(\"/content/gdrive/My Drive/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\")\n","\n","#Here we retrain the model on the chosen dataset.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x ./scripts/train_model.sh\n","!./scripts/train_model.sh $Pretrained_model_name 0\n","\n","if os.path.exists(Pretrained_model_folder):\n"," shutil.rmtree(Pretrained_model_folder)\n","shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name,Pretrained_model_folder)\n","\n","#Get rid of duplicates of training data in pytorch_fnet after training completes\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name)\n","\n","shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv',Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","min, sec = divmod(dt, 60) \n","hour, min = divmod(min, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",min,\"min(s)\",round(sec),\"sec(s)\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jwORXPtcqRHZ","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend performing quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"rVBx2b2MpoFf","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model and path to model folder:\n","\n","QC_model_folder = \"/content/gdrive/My Drive/NewFnet_2\" #@param {type:\"string\"}\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","\n","#Check that the model folder exists before creating the quality control folder inside it,\n","#otherwise os.makedirs would create the model folder and the warning below would never be shown.\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n"," #Create a folder for the quality control metrics\n"," if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n","  shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"aNR6bAk6oZJD","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each training iteration for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training, both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the number of training `steps` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words, the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. 
In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"ratRdSDlcQ9G","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show figure of training errors\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","iterationNumber_training = []\n","iterationNumber_val = []\n","\n","import csv\n","from matplotlib import pyplot as plt\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses.csv','r') as csvfile:\n"," plots = csv.reader(csvfile, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_training.append(int(row[0]))\n"," lossDataFromCSV.append(float(row[1]))\n","\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses_val.csv','r') as csvfile_val:\n"," plots = csv.reader(csvfile_val, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_val.append(int(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.plot(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.semilogy(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/'+'losses.png')\n","plt.show()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"YkhOGv3Hp2xI","colab_type":"text"},"source":["## **5.2. 
Error mapping and quality metrics estimation**\n","---\n","\n","This section displays SSIM maps and RSE maps and calculates the total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\".\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized prediction and the target, or between the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score, the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that the model needs to predict or by the type of GPU allocated to your session. To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"vqSH6EQb4BwU","colab_type":"code","cellView":"form","colab":{}},"source":["#Overwrite results folder if it already exists at the given location\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","!pip install -U scipy==1.2.0\n","!pip install --no-cache-dir tifffile==2019.7.26 \n","from distutils.dir_util import copy_tree\n","\n","#----------------CREATING PREDICTIONS FOR QUALITY CONTROL----------------------------------#\n","\n","\n","#Choose the folder with the quality control datasets\n","Source_QC_folder = \"/content/gdrive/My Drive/Label-free_prediction_(fnet)_v2/Test_dataset/Test-Transmitted_light_stacks_Split_data\" #@param{type:\"string\"}\n","Target_QC_folder = \"/content/gdrive/My Drive/Label-free_prediction_(fnet)_v2/Test_dataset/Test-TOM20_fluorescence_stacks_Split_data\" #@param{type:\"string\"}\n","\n","Predictions_name = \"QualityControl\" \n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+Predictions_name):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+Predictions_name)\n","\n","if Use_the_current_trained_model == True:\n"," #Move the contents of the saved_models folder from your training to the new folder\n"," #Here, we use a different copyfunction as we only need the 
contents of the trained_model folder\n"," copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+QC_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n","else:\n"," copy_tree(QC_model_path+'/'+QC_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n"," #dataset = QC_model_name\n","\n","# Get the name of the folder the test data is in\n","source_dataset_name = os.path.basename(os.path.normpath(Source_QC_folder))\n","target_dataset_name = os.path.basename(os.path.normpath(Target_QC_folder))\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Check that we are using .tif files\n","file_list = os.listdir(Source_QC_folder)\n","text = file_list[0]\n","\n","if text.endswith('.tif') or text.endswith('.tiff'):\n"," !chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet//scripts/predict.sh\n"," !if ! 
grep class_dataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi\n"," !if grep CziDataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi \n","\n","#Create test_data folder in pytorch_fnet\n","\n","# If your test data is not in the pytorch_fnet data folder it needs to be copied there.\n","if Use_the_current_trained_model == True:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+source_dataset_name):\n"," shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+source_dataset_name)\n"," shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+target_dataset_name)\n","else:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name):\n"," shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name)\n"," shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+target_dataset_name)\n","\n","\n","# Make a folder that will hold the test.csv file in your new folder\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)\n","\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')\n","\n","#Make a new folder in saved_models to use the trained model for inference.\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My 
Drive/pytorch_fnet/saved_models/'+Predictions_name) \n","\n","\n","#Get file list from the folders containing the files you want to use for inference.\n","#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_dataset_name)\n","test_signal = os.listdir(Source_QC_folder)\n","test_target = os.listdir(Target_QC_folder)\n","#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')\n","\n","#If an old test csv exists we want to overwrite it, so we can insert new test data.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')\n","\n","#Here we create a new test.csv\n","with open('test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," if Use_the_current_trained_model == True:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+target_dataset_name+\"/\"+test_signal[i]])\n"," # This currently assumes that the names are identical for source and target: see \"test_target\" variable is never used\n"," else:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+target_dataset_name+\"/\"+test_signal[i]])\n","\n","#We run the predictions\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!/content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh $Predictions_name 0\n","\n","#Save the results\n","QC_results_files = os.listdir('/content/gdrive/My 
Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Target'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')\n","\n","for i in range(len(QC_results_files)-2):\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/prediction_'+Predictions_name+'.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction/'+'Predicted_'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/signal.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Signal/'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/target.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Target/'+test_signal[i])\n","\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name)\n","\n","\n","#-----------------------------METRICS EVALUATION-------------------------------#\n","\n","# Calculating the position of the mid-plane slice\n","# Perform prediction on all datasets in the Source_QC folder\n","\n","#Finding the middle 
slice\n","img = io.imread(os.path.join(Source_QC_folder, os.listdir(Source_QC_folder)[0]))\n","n_slices = img.shape[0]\n","z_mid_plane = int(n_slices / 2)+1\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Prediction v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," NRMSE_GvP_list = []\n"," PSNR_GvP_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," PSNR_GvP_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," test_prediction_stack = np.squeeze(test_prediction_stack,axis=(0,))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," \n"," # -------------------------------- Prediction --------------------------------\n","\n"," test_GT_norm,test_prediction_norm = 
norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = np.float32(img_SSIM_GTvsPrediction)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = np.float32(img_RSE_GTvsPrediction)\n","\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n","\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n","\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(NRMSE_GTvsPrediction),str(PSNR_GTvsPrediction)])\n"," \n"," # Collect values to display in dataframe output\n"," #file_name_list.append(thisFile)\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n","\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n","\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n","\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n","\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n","\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n","\n"," 
NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n","\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","\n","pdResults.head()\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,10))\n","# Currently only displays the last computed set, from memory\n","\n","# Target (Ground-truth)\n","plt.subplot(2,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Setting up colours\n","cmap = plt.cm.Greys\n","\n","\n","# Source\n","plt.subplot(2,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane],aspect='equal',cmap=cmap)\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Prediction\n","plt.subplot(2,3,3)\n","plt.axis('off')\n","img_Prediction = 
io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","img_Prediction = np.squeeze(img_Prediction,axis=(0,))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Prediction\n","plt.subplot(2,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('SSIM map: Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(2,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('RSE map Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Quality Control folder.')\n","pdResults.head()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"V2ghLobACMy6","colab_type":"text"},"source":["#**6. Using the trained model**\n","---\n","\n","In this section, the unseen data is processed using the model trained in section 4. First, your unseen images are uploaded and prepared for prediction. After that, your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"SMw0nWXeeC1N","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4) can now be used to process images. 
If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Results_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Results_folder`:** This folder will contain the predicted output images.\n","\n","If you want to use a model different from the most recently trained one, untick the box and enter the path of the model in **`Prediction_model_folder`**.\n","\n","**Note: `Prediction_model_folder` expects a folder name which contains a model.p file from a previous training.**\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that the model needs to predict or the type of GPU allocated to your session. To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"8yoXStc8Lo27","colab_type":"code","cellView":"form","colab":{}},"source":["#Before prediction we will remove the old prediction folder because fnet won't execute if a path already exists that has the same name.\n","#This is just in case you have already trained on a dataset with the same name\n","#The data will be saved outside of the pytorch_folder (Results_folder) so it won't be lost when you run this section again.\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","!pip install -U scipy==1.2.0\n","!pip install --no-cache-dir tifffile==2019.7.26 \n","\n","#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param 
{type:\"string\"}\n","\n","Predictions_name = 'TempPredictionFolder'\n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(Results_folder+'/'+Predictions_name):\n"," shutil.rmtree(Results_folder+'/'+Predictions_name)\n","\n","#@markdown ###Do you want to use the current trained model?\n","\n","Use_the_current_trained_model = True #@param{type:\"boolean\"}\n","\n","#@markdown ###If not, provide the name of the model you want to use \n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","if Use_the_current_trained_model:\n"," #Move the contents of the saved_models folder from your training to the new folder\n"," #Here, we use a different copyfunction as we only need the contents of the trained_model folder\n"," copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n","else:\n"," copy_tree(Prediction_model_path+'/'+Prediction_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n"," #dataset = Prediction_model_name\n","\n","# Get the name of the folder the test data is in\n","test_dataset_name = os.path.basename(os.path.normpath(Data_folder))\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Check that we are using .tif files\n","file_list = os.listdir(Data_folder)\n","text = file_list[0]\n","\n","if text.endswith('.tif') or text.endswith('.tiff'):\n"," !chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet//scripts/predict.sh\n"," !if ! 
grep class_dataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi\n"," !if grep CziDataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi \n","\n","#Create test_data folder in pytorch_fnet\n","\n","# If your test data is not in the pytorch_fnet data folder it needs to be copied there.\n","if Use_the_current_trained_model == True:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Prediction_model_name+'/'+test_dataset_name):\n"," shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Prediction_model_name+'/'+test_dataset_name)\n","else:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name):\n"," shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name)\n","\n","\n","# Make a folder that will hold the test.csv file in your new folder\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)\n","\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')\n","\n","#Make a new folder in saved_models to use the trained model for inference.\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name) \n","\n","\n","#Get file list from the folders containing the files you want to use for inference.\n","#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+test_dataset_name)\n","test_signal = 
os.listdir(Data_folder)\n","\n","#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')\n","\n","#If an old test csv exists we want to overwrite it, so we can insert new test data.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')\n","\n","#Here we create a new test.csv\n","with open('test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," if Use_the_current_trained_model ==True:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Prediction_model_name+\"/\"+test_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Prediction_model_name+\"/\"+test_dataset_name+\"/\"+test_signal[i]])\n"," else:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+test_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+test_dataset_name+\"/\"+test_signal[i]])\n","\n","#We run the predictions\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!/content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh $Predictions_name 0\n","\n","#Save the results\n","results_files = os.listdir('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test')\n","for i in range(len(results_files)-2):\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+results_files[i]+'/prediction_'+Predictions_name+'.tiff', Results_folder+'/'+'Prediction_'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+results_files[i]+'/signal.tiff', 
Results_folder+'/'+test_signal[i])\n","\n","#Comment this out if you want to see the total original results from the prediction in the pytorch_fnet folder.\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"e2f-coEkCf58","colab_type":"text"},"source":["##**6.2. Assess predicted output**\n","---\n","Here, we inspect an example prediction from the predictions on the test dataset. Select the slice you want to visualize."]},{"cell_type":"code","metadata":{"id":"Uzv5rp6LrYQF","colab_type":"code","cellView":"form","colab":{}},"source":["!pip install matplotlib==2.2.3\n","import numpy as np\n","import matplotlib.pyplot as plt\n","from skimage import io\n","import os\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","#@markdown ###Which slice would you like to view?\n","slice_number = 1#@param {type:\"number\"}\n","\n","def show_image(file=os.listdir(Data_folder)):\n"," os.chdir(Results_folder)\n","\n","#source_image = io.imread(test_signal[0])\n"," source_image = io.imread(os.path.join(Data_folder,file))\n"," prediction_image = io.imread(os.path.join(Results_folder,'Prediction_'+file))\n"," prediction_image = np.squeeze(prediction_image, axis=(0,))\n","\n","#Create the figure\n"," fig = plt.figure(figsize=(10,20))\n","\n"," #Setting up colours\n"," cmap = plt.cm.Greys\n","\n"," plt.subplot(1,2,1)\n"," print(prediction_image.shape)\n"," plt.imshow(source_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Source')\n"," plt.subplot(1,2,2)\n"," plt.imshow(prediction_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Prediction')\n","\n","interact(show_image, continuous_update=False);"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3dP2CrCVee1m","colab_type":"text"},"source":["## **6.3. 
Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"IXXOocFl3on8","colab_type":"text"},"source":["## **6.4. Purge unnecessary folders**\n","---\n"]},{"cell_type":"code","metadata":{"id":"emO85anSThPJ","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##If you have checked that all your data is saved you can delete the pytorch_fnet folder from your drive by playing this cell.\n","\n","import shutil\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"l52zLRCn3z9v","colab_type":"text"},"source":["#**Thank you for using fnet!**"]}]} \ No newline at end of file 
+{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"fnet_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1G6lQzjd259Yoy_OozBhJolF4HraE52PG","timestamp":1591353884724},{"file_id":"1pSC680miQesRinU8Tjn7X6AmJXNtxUNI","timestamp":1591182507229},{"file_id":"1ajYZgvhQfpcUZ5YWlUeB-GUU_j-njsqw","timestamp":1589209398121},{"file_id":"1QiFrHg_cVlOl_yzu-RO9mMIrA2L1dXwj","timestamp":1587744376035},{"file_id":"1_S3UtNcuAaZhVc4yqlFDHc2eKq1x-ynn","timestamp":1587058075616},{"file_id":"1Gce_llcAX7yJTFZP2HiNpTL56gXR7PQ-","timestamp":1586854238074},{"file_id":"10l0NA5VWlqRvDlJRTxOiOUgN5LxEo2gy","timestamp":1586601464429},{"file_id":"1NSdad2BEDJZ16AO3SEEaG-ZSe0o4u3eY","timestamp":1586368373257},{"file_id":"1ubiSLYW3G4eNGNF31e2Vbw_3jMHJ9Y7M","timestamp":1585303720184},{"file_id":"1O6YzESEk9VFr6Nc6ijOAYCtiP80uuh7I","timestamp":1585248652537},{"file_id":"1DPrSIbf-ML-LIO2e4YhL1KedWVsVcFlT","timestamp":1585232236512},{"file_id":"1Qanbeybd44tHmdzKxTJAMDD4trFdCYwD","timestamp":1585049767771},{"file_id":"1Fr9Ea5QdUgK0CKfQKpq9KrxtxxAkSVwc","timestamp":1584619265981},{"file_id":"1RQ6XuOBIRaWgId2WKO2i-MMnXoKn_tNA","timestamp":1584541702239},{"file_id":"1mAvQKCCelwK8zPkAWFvKtiAsE_35KSpW","timestamp":1584533728194},{"file_id":"1LdMzIh-v-gUXnd6v9U2Ov28T-XpeT1PP","timestamp":1584463518766},{"file_id":"18Y0NabtThelB0uOAJlg7UbjHPYMEoCqW","timestamp":1584455459923},{"file_id":"1ZCnLW6HUl0bXrPa-54-bv_C9f6jYL0T4","timestamp":1584436296801},{"file_id":"1gTLXTd_rOpXmlktZz2yeEW62gY8ety-I","timestamp":1583941948440},{"file_id":"1gC_pmaDD73tD-yNoFGjHEolYfLd_7czL","timestamp":1583593255888},{"file_id":"17pZee2Vp0kCh3W8pfzRYk8asqk35mOfw","timestamp":1583335080677},{"file_id":"1KyYm3JglQpPYnf-aBLLiP-sFgi_A0Og1","timestamp":1583291424450},{"file_id":"1ZJCI2p66noTaLCnVUQJkTR16ig6GAqAx","timestamp":1576151149296}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 
3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"C-wdtVN5KUFi","colab_type":"text"},"source":["#**Label-free prediction - fnet**\n","---\n","\n"," \n","Label-free prediction (fnet) is a neural network developed to infer the distribution of specific cellular structures from label-free images such as brightfield or EM images. It was first published in 2018 by [Ounkomol *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0111-2). The network uses a common U-Net architecture and is trained using paired imaging volumes from the same field of view, imaged in a label-free (e.g. brightfield) and labelled condition (e.g. fluorescence images of a specific label of interest). When trained, this allows the user to identify certain structures from brightfield images alone. The performance of fnet may depend significantly on the structure at hand.\n","\n","---\n"," *Disclaimer*:\n","\n"," This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n"," This notebook is largely based on the paper: \n","\n","**Label-free prediction of three-dimensional fluorescence images from transmitted light microscopy** by Ounkomol *et al.* in Nature Methods, 2018 (https://www.nature.com/articles/s41592-018-0111-2)\n","\n"," And source code found in: https://github.com/AllenCellModeling/pytorch_fnet\n","\n"," **Please also cite this original paper when using or developing this notebook.** \n"]},{"cell_type":"markdown","metadata":{"id":"Qt5Yt1vsD163","colab_type":"text"},"source":["# **How to use this notebook?**\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you will find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. 
Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"zwILBhMkzKp_","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n","\n"," This notebook provides two opportunities: firstly, to download and train Fnet with data published in the original manuscript or secondly, to upload a personal dataset and train Fnet on it.\n"," The notebook may require a large amount of disk space. If using the datasets from the paper, the available disk space on the user's google drive should contain at least 40GB."]},{"cell_type":"markdown","metadata":{"id":"pcNfrIVpNZC-","colab_type":"text"},"source":["---\n","**Data Format**\n","\n"," **The data used to train fnet must be 3D stacks in .tiff (.tif) file format and contain the signal (e.g. bright-field image) and the target channel (e.g. fluorescence) for each field of view**. 
To use this notebook on user data, upload the data in the following format to your google drive. To ensure corresponding images are used during training give corresponding signal and target images the same name.\n","\n","Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n"," **Note: Your *dataset_folder* should not have spaces or brackets in its name as this is not recognized by the fnet code and will throw an error** \n","\n","\n","* Experiment A\n"," - **Training dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif, ...\n"," - fluorescence images\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif\n"," - fluorescence images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"I0aF5U_Y0IFW","colab_type":"text"},"source":["# **1. 
Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"EBHobPtQ8wx7","colab_type":"text"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"UphYcwdDS8yO","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"kRVmtCZB9OQ2","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste it into the cell and press enter. This will give Colab access to the data on the drive. 
\n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"QTEFQc6j9RTv","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yk96o-_u-27d","colab_type":"text"},"source":["#**2. Install fnet and dependencies**\n","---\n","Running fnet requires the fnet folder to be downloaded into the session's Files. As fnet needs several packages to be installed, this step may take a few minutes.\n","\n","You can ignore **the error warnings** as they refer to packages not required for this notebook.\n","\n","**Note: It is not necessary to keep the pytorch_fnet folder after you are finished using the notebook, so it can be deleted afterwards by playing the last cell (bottom).**"]},{"cell_type":"code","metadata":{"id":"BbYpGlfskzrO","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play this cell to download fnet to your drive. 
If it is already installed this will only install the fnet dependencies.\n","import os\n","import csv\n","import shutil\n","import random\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","import sys\n","import numpy as np\n","import shutil\n","import os\n","from tempfile import mkstemp\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from skimage import img_as_float32\n","from distutils.dir_util import copy_tree\n","import datetime\n","import time\n","\n","#Ensure tensorflow 1.x\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","#clone fnet from github to colab\n","!pip install -U scipy==1.2.0\n","!pip install matplotlib==2.2.3\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet'):\n"," !git clone -b release_1 --single-branch https://github.com/AllenCellModeling/pytorch_fnet.git; cd pytorch_fnet; pip install .\n"," shutil.move('/content/pytorch_fnet','/content/gdrive/My Drive/pytorch_fnet')\n","from skimage import io\n","from matplotlib import pyplot as plt\n","import pandas as pd\n","#from skimage.util import img_as_uint\n","import matplotlib as mpl\n","#from scipy import signal\n","#from scipy import ndimage\n","\n","\n","#This function replaces the old default files with new values\n","def replace(file_path, pattern, subst):\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," #Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","def insert_line_to_file(filepath,line_number,insertion):\n"," f = open(filepath, \"r\")\n"," contents = 
f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not insertion in f.read():\n"," contents.insert(line_number, insertion)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","def add_validation(filepath,line_number,insert,append):\n"," f = open(filepath, \"r\")\n"," contents = f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not 'PATH_DATASET_VAL_CSV=' in f.read():\n"," contents.insert(line_number, insert)\n"," contents.append(append)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: 
ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","#Here we replace values in the old files\n","#Change maximum pixel number\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/fnet/transforms.py\",'n_max_pixels=9732096','n_max_pixels=20000000')\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/predict.py\",'6000000','20000000')\n","\n","#Prevent resizing in the training and the prediction\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/predict.py\",\"0.37241\",\"1.0\")\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"0.37241\",\"1.0\")\n","\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"JqCe6m-C_PrH","colab_type":"text"},"source":["#**3. Select your paths and parameters**\n","---"]},{"cell_type":"markdown","metadata":{"id":"w5NmDpJ4xvWE","colab_type":"text"},"source":["## **3.1. Setting the main training parameters**\n","---\n"," **Paths for training data**\n","\n"," **`Training_source`,`Training_target:`** These are the paths to your folders containing the Training_source (brightfield) and Training_target (fluorescent label) training data respectively. 
To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files, right-click on the folder, select **Copy path** and paste the path into the corresponding box below.\n","\n","**Note: The stacks for fnet should either have 32 or more slices or have a number of slices which is a power of 2 (e.g. 2,4,8,16).**\n","\n"," **Training Parameters**\n","\n"," **`steps:`** Input how many iterations you want to train the network for. A larger number may improve performance but risks overfitting to the training data. Reaching good performance with fnet requires several tens of thousands of iterations, which will usually take **several hours**, depending on the dataset size. **Default: 50000**\n","\n","**`batch_size:`** Reducing or increasing the **batch size** may speed up or slow down your training, respectively, and can influence network performance. **Default: 4**"]},{"cell_type":"code","metadata":{"id":"PWxNzzgKu9Kb","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###Datasets\n","#Datasets\n","\n","#Change checkpoints\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","#replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the target folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","#@markdown ###Model name and model path\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = 
\"\" #@param {type:\"string\"}\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","#dataset = model_name #The name of the dataset and the model will be the same\n","\n","#Here, we check if the dataset already exists. If not, copy the dataset from google drive to the data folder\n"," \n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name):\n"," #shutil.copytree(own_dataset,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset)\n"," os.makedirs('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name)\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","elif os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name) and not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","elif os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name) and os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n","#Create a path_csv file to point to the training images\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","source = 
os.listdir('./'+model_name+'/'+source_name)\n","target = os.listdir('./'+model_name+'/'+target_name)\n","\n","#print(\"Selected \"+dataset+\" as training set\")\n","\n","model_name_x = model_name+\"}\" # this variable is only used to ensure closed curly brackets when editing the .sh files\n","\n","#We need to declare that we will run validation on the dataset\n","#We need to add a new line to the train.sh file\n","with open(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\", \"r\") as f:\n"," if not \"gpu_ids ${GPU_IDS} \\\\\" in f.read():\n"," replace(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",\" --gpu_ids ${GPU_IDS}\",\" --gpu_ids ${GPU_IDS} \\\\\")\n","\n","#We add the necessary validation parameters here.\n","insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n","append = '\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}'\n","add_validation(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",10,insert,append)\n","\n","#Clear the White space from train.sh\n","\n","with open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh','/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","\n","#Here we define the random set of training files to be used for validation\n","val_files = random.sample(source,len(source)//10)\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target'):\n"," 
shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n","#Make validation directories\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","#Move a random set of files from the training to the validation folders\n","for file in val_files:\n"," shutil.move('./'+model_name+'/'+source_name+'/'+file,'./'+model_name+'/Validation_Input/'+file)\n"," shutil.move('./'+model_name+'/'+target_name+'/'+file,'./'+model_name+'/Validation_Target/'+file)\n","\n","#Redefine the source and target lists after moving the validation files\n","source = os.listdir('./'+model_name+'/'+source_name)\n","target = os.listdir('./'+model_name+'/'+target_name)\n","\n","#Define Validation file lists\n","val_signal = os.listdir('./'+model_name+'/Validation_Input')\n","val_target = os.listdir('./'+model_name+'/Validation_Target')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n","#Finally, we create a validation csv file to construct the validation dataset\n","with open(model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(val_signal)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Input/\"+val_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Target/\"+val_target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv'):\n"," 
os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","with open(model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","\n","#@markdown ---\n","\n","#@markdown ###Training Parameters\n","\n","#Training parameters in fnet are indicated in the train_model.sh file.\n","#Here, we edit this file to include the desired parameters\n","\n","#1. Add permissions to train_model.sh\n","os.chdir(\"/content/gdrive/My Drive/pytorch_fnet/scripts\")\n","!chmod u+x train_model.sh\n","\n","#2. Select parameters\n","steps = 50000#@param {type:\"number\"}\n","batch_size = 4#@param {type:\"number\"}\n","number_of_images = len(source)\n","\n","#3. Insert the above values into train_model.sh\n","!if ! 
grep saved_models\\/\\${ train_model.sh;then sed -i 's/saved_models\\/.*/saved_models\\/\\${DATASET}\"/g' train_model.sh; fi \n","!sed -i \"s/1:-.*/1:-$model_name_x/g\" train_model.sh #change the dataset to be trained with\n","!sed -i \"s/N_ITER=.*/N_ITER=$steps/g\" train_model.sh #change the number of training iterations (steps)\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","#We also change the training split as in our notebook the test images are used separately for prediction and we want fnet to train on the whole training data set.\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" train_model.sh\n","\n","#If new parameters are inserted here for training a model with the same name\n","#the previous training csv needs to be removed, to prevent the model using the old training split or paths.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"BCKcSJxkxi33","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"msrTTcPI1Cav","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating images in XY-Plane and flip them along X-Axis. 
This only works if the patches are square in XY.\n","\n","**Note:** Using a full augmented dataset can exceed the RAM limitations of the colab notebook. If the augmented dataset is too large, the notebook will therefore only pick a subset of the augmented dataset for training. Make sure you only augment datasets which are small (ca. 20-30 images)."]},{"cell_type":"code","metadata":{"id":"u_YFN6Bd594L","colab_type":"code","cellView":"form","colab":{}},"source":["from skimage import io\n","import numpy as np\n","\n","Use_Data_augmentation = True #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = 
np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," 
io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n","\n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name,flip=Flip)\n"," \n"," elif Rotation == False and Flip == True:\n"," flip('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," \n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.rmtree('/content/gdrive/My 
Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," #Fetch the path and extract the name of the signal folder\n"," Training_source = Saving_path+\"/augmented_source\"\n"," source_name = os.path.basename(os.path.normpath(Training_source))\n","\n"," #Fetch the path and extract the name of the signal folder\n"," Training_target = Saving_path+\"/augmented_target\"\n"," target_name = os.path.basename(os.path.normpath(Training_target))\n","\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n"," #Redefine the source and target lists after moving the validation files\n"," source = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," target = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n"," with open(model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n"," 
shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","\n"," #Here, we ensure that all files, including the validation files, are saved somewhere together for later access, e.g. for retraining.\n"," for image in os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input'):\n"," shutil.copyfile(os.path.join('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input',image),Saving_path+'/augmented_source/'+image)\n"," shutil.copyfile(os.path.join('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target',image),Saving_path+'/augmented_target/'+image)\n"," \n"," if len(source)>130:\n"," number_of_images = 130\n"," else:\n"," number_of_images = len(source)\n","\n"," os.chdir(\"/content/gdrive/My Drive/pytorch_fnet/scripts\")\n"," !chmod u+x train_model.sh\n"," !sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","\n"," print(\"Done\")\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"heuBzM5JADYf","colab_type":"text"},"source":["#**4. Train the network**\n","---\n","Before training, carefully read the different options. This applies especially if you have trained fnet on a dataset before.\n","\n","\n","###**Choose one of the options to train fnet**.\n","\n","**4.1.** If this is the first training on the chosen dataset, play this section to start training.\n","\n","**4.2.** If you want to continue training a pre-trained model, choose this section.\n","\n"," **Carefully read the options before starting training.**"]},{"cell_type":"markdown","metadata":{"id":"eLllOs_rA62U","colab_type":"text"},"source":["##**4.1. Start Training**\n","---\n","\n","####Play the cell below to start training. 
\n","\n","**Note:** If you are training with a model of the same name as before, the model will be overwritten. If you want to keep the previous model, save it before playing the cell below or give your model a different name (section 3)."]},{"cell_type":"code","metadata":{"id":"xe3TLu7M-3Dk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","\n","#Remove any previous model of the same name from saved_models so that it is overwritten\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name)\n","\n","#This tifffile release runs error-free in this version of fnet.\n","!pip install tifffile==2019.7.26\n","\n","#Here we import an additional module to the functions.py file to run it without errors.\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","insert_line_to_file(\"/content/gdrive/My Drive/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\")\n","\n","\n","print(\"Let's start the training!\")\n","#Here we start the training\n","!./scripts/train_model.sh $model_name 0\n","\n","#After training overwrite any existing model in the model_path with the new trained model.\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,model_path+'/'+model_name)\n","\n","shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv',model_path+'/'+model_name+'/'+model_name+'_val.csv')\n","#Get rid of duplicates of training data in pytorch_fnet after training completes\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","shutil.rmtree('/content/gdrive/My 
Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"fpXr4JlCd5uV","colab_type":"text"},"source":["**Note:** Fnet takes a long time for training. If your notebook times out due to the length of the training or due to a loss of GPU acceleration the last checkpoint will be saved in the saved_models folder in the pytorch_fnet folder. If you want to save it in a more convenient location on your drive, remount the drive (if you got disconnected) and in the next cell enter the location (`model_path`) where you want to save the model (`model_name`) before continuing in 4.2. **If you did not time out you can ignore this section.**"]},{"cell_type":"code","metadata":{"id":"x41OhmO-hsX3","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play this cell if your model training timed out and indicate where you want to save the last checkpoint.\n","\n","import shutil\n","import os\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name):\n"," shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,model_path+'/'+model_name)\n","else:\n"," print('This model name does not exist in your saved_models folder. Make sure you have entered the name of the model that timed out.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QefQX9WUBz0G","colab_type":"text"},"source":["##**4.2. 
Training from a previously saved model**\n","---\n","This section allows you to use networks you have previously trained and saved and to continue training them for more training steps. The folders have the same meaning as above (3.1.). If you want to save the previously trained model, create a copy now as this section will overwrite the weights of the old model. **You can currently only train the model with the same dataset and batch size that the network was previously trained on.**\n","\n","**Note: To use this section the *pytorch_fnet* folder must be in your *gdrive/My Drive*. (Simply, play cell 2. to make sure).**"]},{"cell_type":"code","metadata":{"id":"2-0m_-tF9oo-","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown To test if performance improves after the initial training, you can continue training on the old model. This option can also be useful if Colab disconnects or times out.\n","#@markdown Enter the paths of the datasets you want to continue training on.\n","\n","#Here we replace values in the old files\n","\n","insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n","append = '\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}'\n","\n","add_validation(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",10,insert,append)\n","#Clear the White space from train.sh\n","\n","with open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh','/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","\n","#Datasets\n","\n","#Change checkpoints\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, 
default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","Pretrained_model_folder = \"\" #@param{type:\"string\"}\n","#model_name = \"\" #@param {type:\"string\"}\n","\n","Pretrained_model_name = os.path.basename(Pretrained_model_folder)\n","Pretrained_model_path = os.path.dirname(Pretrained_model_folder)\n","batch_size = 4 #@param {type:\"number\"}\n","\n","Pretrained_model_name_x = Pretrained_model_name+\"}\"\n","\n","#Move your model to fnet\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name):\n"," shutil.copytree(Pretrained_model_folder,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name)\n","\n","#Move the datasets into fnet\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name)\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name)\n","shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name)\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')\n","\n","### number_of_images = len(os.listdir(Training_source)) ###\n","\n","#Change the train_model.sh file to include chosen dataset\n","!chmod u+x 
./train_model.sh\n","!sed -i \"s/1:-.*/1:-$Pretrained_model_name_x/g\" train_model.sh\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" train_model.sh #Use the whole training dataset for training\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","\n","# We will use the same validation files from the training dataset as used before.\n","# This makes sure that the model is not validated with files it has seen in training before saving.\n","\n","#First we get the names of the validation files from the previous training which are saved in the validation csv.\n","val_source_list = []\n","\n","#The validation csv is read from the pretrained model folder, where it was saved after the previous training.\n","with open(Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv', 'r') as f:\n"," contents = csv.reader(f,delimiter=',')\n"," for row in contents:\n"," val_source_list.append(row[0])\n","\n","#Get the file list without the header\n","val_source_list = val_source_list[1::]\n","\n","#Get only the file names and not the full path\n","for i in range(0,len(val_source_list)):\n"," val_source_list[i] = os.path.basename(os.path.normpath(val_source_list[i]))\n","\n","source = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input')\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target')\n","\n","#Make validation 
directories\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input')\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target')\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","#Move a random set of files from the training to the validation folders\n","for file in val_source_list:\n"," #os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name+'/'+file,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input/'+file)\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name+'/'+file,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target/'+file)\n","\n","#Redefine the source and target lists after moving the validation files\n","source = os.listdir('./'+Pretrained_model_name+'/'+source_name)\n","target = os.listdir('./'+Pretrained_model_name+'/'+target_name)\n","\n","#Define Validation file lists\n","val_signal = os.listdir('./'+Pretrained_model_name+'/Validation_Input')\n","val_target = os.listdir('./'+Pretrained_model_name+'/Validation_Target')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv')\n","\n","shutil.copyfile(Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv')\n","\n","#Make a training csv file.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name)\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","source = 
os.listdir('./'+Pretrained_model_name+'/'+source_name)\n","target = os.listdir('./'+Pretrained_model_name+'/'+target_name)\n","with open(Pretrained_model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Pretrained_model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Pretrained_model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'.csv')\n","\n","#Find the number of previous training iterations (steps) from loss csv file\n","\n","with open(Pretrained_model_folder+'/losses.csv') as f:\n"," previous_steps = sum(1 for line in f)\n","print('continuing training after step '+str(previous_steps-1))\n","\n","print('To start re-training play section 4.2. below')\n","\n","#@markdown For how many additional steps do you want to train the model?\n","add_steps = 50000#@param {type:\"number\"}\n","\n","#Calculate the new number of total training epochs. Subtract 1 to discount the title row of the csv file.\n","new_steps = previous_steps + add_steps -1\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')\n","\n","#Edit train_model.sh file to include new total number of training epochs\n","!sed -i \"s/N_ITER=.*/N_ITER=$new_steps/g\" train_model.sh"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"vH3EzxbfD6Uk","colab_type":"code","cellView":"form","colab":{}},"source":["import datetime\n","import time\n","start = time.time()\n","\n","#@markdown ##4.2. 
Start re-training model\n","!pip install tifffile==2019.7.26\n","import os\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/fnet')\n","\n","insert_line_to_file(\"/content/gdrive/My Drive/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\")\n","\n","#Here we retrain the model on the chosen dataset.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x ./scripts/train_model.sh\n","!./scripts/train_model.sh $Pretrained_model_name 0\n","\n","if os.path.exists(Pretrained_model_folder):\n"," shutil.rmtree(Pretrained_model_folder)\n","shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name,Pretrained_model_folder)\n","\n","#Get rid of duplicates of training data in pytorch_fnet after training completes\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name)\n","\n","shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv',Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jwORXPtcqRHZ","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend performing quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"rVBx2b2MpoFf","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model and path to model folder:\n","\n","QC_model_folder = \"/content/gdrive/My Drive/NewFnet_2\" #@param {type:\"string\"}\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","#Create a folder for the quality control metrics\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"aNR6bAk6oZJD","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words, the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. 
In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"ratRdSDlcQ9G","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show figure of training errors\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","iterationNumber_training = []\n","iterationNumber_val = []\n","\n","import csv\n","from matplotlib import pyplot as plt\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses.csv','r') as csvfile:\n"," plots = csv.reader(csvfile, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_training.append(int(row[0]))\n"," lossDataFromCSV.append(float(row[1]))\n","\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses_val.csv','r') as csvfile_val:\n"," plots = csv.reader(csvfile_val, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_val.append(int(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.plot(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.semilogy(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/'+'losses.png')\n","plt.show()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"YkhOGv3Hp2xI","colab_type":"text"},"source":["## **5.2. 
Error mapping and quality metrics estimation**\n","---\n","\n","This section displays SSIM maps and RSE maps and calculates the total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\".\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized prediction and the target, or between the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that the model needs to predict or by the type of GPU that has been allocated to your session. To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"vqSH6EQb4BwU","colab_type":"code","cellView":"form","colab":{}},"source":["#Overwrite results folder if it already exists at the given location\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","!pip install -U scipy==1.2.0\n","!pip install --no-cache-dir tifffile==2019.7.26 \n","from distutils.dir_util import copy_tree\n","\n","#----------------CREATING PREDICTIONS FOR QUALITY CONTROL----------------------------------#\n","\n","\n","#Choose the folder with the quality control datasets\n","Source_QC_folder = \"/content/gdrive/My Drive/Label-free_prediction_(fnet)_v2/Test_dataset/Test-Transmitted_light_stacks_Split_data\" #@param{type:\"string\"}\n","Target_QC_folder = \"/content/gdrive/My Drive/Label-free_prediction_(fnet)_v2/Test_dataset/Test-TOM20_fluorescence_stacks_Split_data\" #@param{type:\"string\"}\n","\n","Predictions_name = \"QualityControl\" \n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+Predictions_name):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+Predictions_name)\n","\n","if Use_the_current_trained_model == True:\n"," #Move the contents of the saved_models folder from your training to the new folder\n"," #Here, we use a different copyfunction as we only need the 
contents of the trained_model folder\n"," copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+QC_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n","else:\n"," copy_tree(QC_model_path+'/'+QC_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n"," #dataset = QC_model_name\n","\n","# Get the name of the folder the test data is in\n","source_dataset_name = os.path.basename(os.path.normpath(Source_QC_folder))\n","target_dataset_name = os.path.basename(os.path.normpath(Target_QC_folder))\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Check that we are using .tif files\n","file_list = os.listdir(Source_QC_folder)\n","text = file_list[0]\n","\n","if text.endswith('.tif') or text.endswith('.tiff'):\n"," !chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet//scripts/predict.sh\n"," !if ! 
grep class_dataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi\n"," !if grep CziDataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi \n","\n","#Create test_data folder in pytorch_fnet\n","\n","# If your test data is not in the pytorch_fnet data folder it needs to be copied there.\n","if Use_the_current_trained_model == True:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+source_dataset_name):\n"," shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+source_dataset_name)\n"," shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+target_dataset_name)\n","else:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name):\n"," shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name)\n"," shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+target_dataset_name)\n","\n","\n","# Make a folder that will hold the test.csv file in your new folder\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)\n","\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')\n","\n","#Make a new folder in saved_models to use the trained model for inference.\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My 
Drive/pytorch_fnet/saved_models/'+Predictions_name) \n","\n","\n","#Get file list from the folders containing the files you want to use for inference.\n","#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_dataset_name)\n","test_signal = os.listdir(Source_QC_folder)\n","test_target = os.listdir(Target_QC_folder)\n","#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')\n","\n","#If an old test csv exists we want to overwrite it, so we can insert new test data.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')\n","\n","#Here we create a new test.csv\n","with open('test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," if Use_the_current_trained_model == True:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+target_dataset_name+\"/\"+test_signal[i]])\n"," # This currently assumes that the names are identical for source and target: see \"test_target\" variable is never used\n"," else:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+target_dataset_name+\"/\"+test_signal[i]])\n","\n","#We run the predictions\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!/content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh $Predictions_name 0\n","\n","#Save the results\n","QC_results_files = os.listdir('/content/gdrive/My 
Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Target'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')\n","\n","for i in range(len(QC_results_files)-2):\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/prediction_'+Predictions_name+'.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction/'+'Predicted_'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/signal.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Signal/'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/target.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Target/'+test_signal[i])\n","\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name)\n","\n","\n","#-----------------------------METRICS EVALUATION-------------------------------#\n","\n","# Calculating the position of the mid-plane slice\n","# Perform prediction on all datasets in the Source_QC folder\n","\n","#Finding the middle 
slice\n","img = io.imread(os.path.join(Source_QC_folder, os.listdir(Source_QC_folder)[0]))\n","n_slices = img.shape[0]\n","z_mid_plane = int(n_slices / 2)+1\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Prediction v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," NRMSE_GvP_list = []\n"," PSNR_GvP_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," PSNR_GvP_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," test_prediction_stack = np.squeeze(test_prediction_stack,axis=(0,))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," \n"," # -------------------------------- Prediction --------------------------------\n","\n"," test_GT_norm,test_prediction_norm = 
norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = np.float32(img_SSIM_GTvsPrediction)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = np.float32(img_RSE_GTvsPrediction)\n","\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n","\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n","\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(NRMSE_GTvsPrediction),str(PSNR_GTvsPrediction)])\n"," \n"," # Collect values to display in dataframe output\n"," #file_name_list.append(thisFile)\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n","\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n","\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n","\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n","\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n","\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n","\n"," 
NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n","\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","\n","pdResults.head()\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,10))\n","# Currently only displays the last computed set, from memory\n","\n","# Target (Ground-truth)\n","plt.subplot(2,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Setting up colours\n","cmap = plt.cm.Greys\n","\n","\n","# Source\n","plt.subplot(2,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane],aspect='equal',cmap=cmap)\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Prediction\n","plt.subplot(2,3,3)\n","plt.axis('off')\n","img_Prediction = 
io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","img_Prediction = np.squeeze(img_Prediction,axis=(0,))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Prediction\n","plt.subplot(2,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('SSIM map: Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(2,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('RSE map Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Quality Control folder.')\n","pdResults.head()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"V2ghLobACMy6","colab_type":"text"},"source":["#**6. Using the trained model**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"SMw0nWXeeC1N","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4) can now be used to process images. 
If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Results_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Results_folder`:** This folder will contain the predicted output images.\n","\n","If you want to use a model different from the most recently trained one, untick the box and enter the path of the model in **`Prediction_model_folder`**.\n","\n","**Note: `Prediction_model_folder` expects a folder name which contains a model.p file from a previous training.**\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that the model needs to predict or the type of GPU allocated to your session. To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"8yoXStc8Lo27","colab_type":"code","cellView":"form","colab":{}},"source":["#Before prediction we will remove the old prediction folder because fnet won't execute if a path already exists that has the same name.\n","#This is just in case you have already trained on a dataset with the same name\n","#The data will be saved outside of the pytorch_folder (Results_folder) so it won't be lost when you run this section again.\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","!pip install -U scipy==1.2.0\n","!pip install --no-cache-dir tifffile==2019.7.26 \n","\n","#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param 
{type:\"string\"}\n","\n","Predictions_name = 'TempPredictionFolder'\n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(Results_folder+'/'+Predictions_name):\n"," shutil.rmtree(Results_folder+'/'+Predictions_name)\n","\n","#@markdown ###Do you want to use the current trained model?\n","\n","Use_the_current_trained_model = True #@param{type:\"boolean\"}\n","\n","#@markdown ###If not, provide the name of the model you want to use \n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","if Use_the_current_trained_model:\n"," #Move the contents of the saved_models folder from your training to the new folder\n"," #Here, we use a different copyfunction as we only need the contents of the trained_model folder\n"," copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n","else:\n"," copy_tree(Prediction_model_path+'/'+Prediction_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n"," #dataset = Prediction_model_name\n","\n","# Get the name of the folder the test data is in\n","test_dataset_name = os.path.basename(os.path.normpath(Data_folder))\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Check that we are using .tif files\n","file_list = os.listdir(Data_folder)\n","text = file_list[0]\n","\n","if text.endswith('.tif') or text.endswith('.tiff'):\n"," !chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet//scripts/predict.sh\n"," !if ! 
grep class_dataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi\n"," !if grep CziDataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi \n","\n","#Create test_data folder in pytorch_fnet\n","\n","# If your test data is not in the pytorch_fnet data folder it needs to be copied there.\n","if Use_the_current_trained_model == True:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Prediction_model_name+'/'+test_dataset_name):\n"," shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Prediction_model_name+'/'+test_dataset_name)\n","else:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name):\n"," shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name)\n","\n","\n","# Make a folder that will hold the test.csv file in your new folder\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)\n","\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')\n","\n","#Make a new folder in saved_models to use the trained model for inference.\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name) \n","\n","\n","#Get file list from the folders containing the files you want to use for inference.\n","#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+test_dataset_name)\n","test_signal = 
os.listdir(Data_folder)\n","\n","#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')\n","\n","#If an old test csv exists we want to overwrite it, so we can insert new test data.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')\n","\n","#Here we create a new test.csv\n","with open('test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," if Use_the_current_trained_model ==True:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Prediction_model_name+\"/\"+test_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Prediction_model_name+\"/\"+test_dataset_name+\"/\"+test_signal[i]])\n"," else:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+test_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+test_dataset_name+\"/\"+test_signal[i]])\n","\n","#We run the predictions\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!/content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh $Predictions_name 0\n","\n","#Save the results\n","results_files = os.listdir('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test')\n","for i in range(len(results_files)-2):\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+results_files[i]+'/prediction_'+Predictions_name+'.tiff', Results_folder+'/'+'Prediction_'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+results_files[i]+'/signal.tiff', 
Results_folder+'/'+test_signal[i])\n","\n","#Comment this out if you want to see the total original results from the prediction in the pytorch_fnet folder.\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"e2f-coEkCf58","colab_type":"text"},"source":["##**6.2. Assess predicted output**\n","---\n","Here, we inspect an example prediction from the predictions on the test dataset. Select the number of the slice you want to visualize."]},{"cell_type":"code","metadata":{"id":"Uzv5rp6LrYQF","colab_type":"code","cellView":"form","colab":{}},"source":["!pip install matplotlib==2.2.3\n","import numpy as np\n","import matplotlib.pyplot as plt\n","from skimage import io\n","import os\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","#@markdown ###Which slice would you like to view?\n","slice_number = 1#@param {type:\"number\"}\n","\n","def show_image(file=os.listdir(Data_folder)):\n"," os.chdir(Results_folder)\n","\n","#source_image = io.imread(test_signal[0])\n"," source_image = io.imread(os.path.join(Data_folder,file))\n"," prediction_image = io.imread(os.path.join(Results_folder,'Prediction_'+file))\n"," prediction_image = np.squeeze(prediction_image, axis=(0,))\n","\n","#Create the figure\n"," fig = plt.figure(figsize=(10,20))\n","\n"," #Setting up colours\n"," cmap = plt.cm.Greys\n","\n"," plt.subplot(1,2,1)\n"," print(prediction_image.shape)\n"," plt.imshow(source_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Source')\n"," plt.subplot(1,2,2)\n"," plt.imshow(prediction_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Prediction')\n","\n","interact(show_image, continuous_update=False);"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3dP2CrCVee1m","colab_type":"text"},"source":["## **6.3. 
Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"IXXOocFl3on8","colab_type":"text"},"source":["## **6.4. Purge unnecessary folders**\n","---\n"]},{"cell_type":"code","metadata":{"id":"emO85anSThPJ","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##If you have checked that all your data is saved you can delete the pytorch_fnet folder from your drive by playing this cell.\n","\n","import shutil\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"l52zLRCn3z9v","colab_type":"text"},"source":["#**Thank you for using fnet!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb b/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb index 4d9bb56a..d6900525 100755 --- a/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ 
-{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"pix2pix_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1mqcexfPBaIWuvMWWbJZUFtPoZoJJwrEA","timestamp":1589278334507},{"file_id":"159ARwlQE7-zi0EHxunOF_YPFLt-ZVU5x","timestamp":1587562499898},{"file_id":"1W-7NHehG5MRFILvZZzhPWWnOdJMkadb2","timestamp":1586332290412},{"file_id":"1pUetEQICxYWkYVaQIgdRH1EZBTl7oc2A","timestamp":1586292199692},{"file_id":"1MD36ZkM6XR9EuV12zimJmfCjzyeYZFWq","timestamp":1586269469061},{"file_id":"16A2mbaHzlEElntS8qkFBOsBvZG-mUeY6","timestamp":1586253795726},{"file_id":"1gJlcjOiSxr2buDOxmcFbT_d-GqwLjXtK","timestamp":1583343225796},{"file_id":"10yGI51WzHfgWgZAyE-EbkZFEvIOd6CP6","timestamp":1583171396283}],"collapsed_sections":[],"toc_visible":true},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I","colab_type":"text"},"source":["# **pix2pix**\n","\n","---\n","\n","pix2pix is a deep-learning method allowing image-to-image translation from one image domain type to another image domain type. It was first published by [Isola *et al.* in 2016](https://arxiv.org/abs/1611.07004). The image transformation requires paired images for training (supervised learning) and is made possible here by using a conditional Generative Adversarial Network (GAN) architecture to use information from the input image and obtain the equivalent translated image.\n","\n"," **This particular notebook enables image-to-image translation learned from paired dataset. 
If you are interested in performing unpaired image-to-image translation, you should consider using the CycleGAN notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n"," **Image-to-Image Translation with Conditional Adversarial Networks** by Isola *et al.* on arXiv in 2016 (https://arxiv.org/abs/1611.07004)\n","\n","The source code of the PyTorch implementation of pix2pix can be found here: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"N3azwKB9O0oW","colab_type":"text"},"source":["# **License**\n","\n","---"]},{"cell_type":"code","metadata":{"id":"ByW6Vqdn9sYV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Double click to see the license information\n","\n","#------------------------- LICENSE FOR ZeroCostDL4Mic------------------------------------\n","#This ZeroCostDL4Mic notebook is distributed under the MIT licence\n","\n","\n","\n","#------------------------- LICENSE FOR CycleGAN ------------------------------------\n","\n","#Copyright (c) 2017, Jun-Yan Zhu and Taesung Park\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the 
documentation\n","# and/or other materials provided with the distribution.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n","#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n","#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n","#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n","#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n","#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n","#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n","#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n","#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n","#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n","\n","\n","#--------------------------- LICENSE FOR pix2pix --------------------------------\n","#BSD License\n","\n","#For pix2pix software\n","#Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the distribution.\n","\n","#----------------------------- LICENSE FOR DCGAN --------------------------------\n","#BSD License\n","\n","#For dcgan.torch software\n","\n","#Copyright (c) 2015, Facebook, Inc. 
All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n","\n","#Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n","\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n","\n","#Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 
2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. 
This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," For pix2pix to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called Training_source and Training_target. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. 
The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .PNG files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Training_source\n"," - img_1.png, img_2.png, ...\n"," - Training_target\n"," - img_1.png, img_2.png, ...\n"," - **Quality control dataset**\n"," - Training_source\n"," - img_1.png, img_2.png\n"," - Training_target\n"," - img_1.png, img_2.png\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"BCPhV-pe-syw","colab_type":"text"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"VNZetvLiS1qV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UBrnApIUBgxv","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. 
\n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **2. Install pix2pix and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install pix2pix and dependencies\n","\n","\n","\n","#Here, we install libraries which are not already included in Colab.\n","\n","\n","!git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","import os\n","os.chdir('pytorch-CycleGAN-and-pix2pix/')\n","!pip install -r requirements.txt\n","\n","import imageio\n","from skimage import data\n","from skimage import exposure\n","from skimage.exposure import match_histograms\n","import glob\n","import os.path\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm 
\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print('----------------------------')\n","print(\"Libraries installed\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"BLmBseWbRvxL","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd","colab_type":"text"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target training data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, selecting **Copy path** and pasting it into the corresponding box below.\n","\n","**`model_name`:** Use my_model-style names, not my-model (use \"_\", not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input the number of epochs (rounds) for which the network will be trained. Preliminary results can already be observed after a few (10) epochs, but a full training should run for 200 epochs or more. Evaluate the performance after training (see section 5). **Default value: 200**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`patch_size`:** pix2pix divides the image into patches for training. 
Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 512**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0002**"]},{"cell_type":"code","metadata":{"id":"pIrTwJjzwV-D","colab_type":"code","cellView":"form","colab":{}},"source":["\n","\n","#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.png\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.png\"\n","\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","patch_size = 512#@param {type:\"number\"} # in pixels\n","batch_size = 1#@param {type:\"number\"}\n","initial_learning_rate = 0.0002 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," patch_size = 512\n"," 
initial_learning_rate = 0.0002\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\")\n"," \n","#To use pix2pix we need to organise the data in a way the network can understand\n","\n","Saving_path= \"/content/\"+model_name\n","#Saving_path= model_path+\"/\"+model_name\n","\n","if os.path.exists(Saving_path):\n"," shutil.rmtree(Saving_path)\n","os.makedirs(Saving_path)\n","\n","imageA_folder = Saving_path+\"/A\"\n","os.makedirs(imageA_folder)\n","\n","imageB_folder = Saving_path+\"/B\"\n","os.makedirs(imageB_folder)\n","\n","imageAB_folder = Saving_path+\"/AB\"\n","os.makedirs(imageAB_folder)\n","\n","TrainA_Folder = Saving_path+\"/A/train\"\n","os.makedirs(TrainA_Folder)\n"," \n","TrainB_Folder = Saving_path+\"/B/train\"\n","os.makedirs(TrainB_Folder)\n","\n","# Here we disable pre-trained model by default (in case the cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imageio.imread(Training_source+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","#Hyperparameters failsafes\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that 
patch_size is divisible by 4\n","if not patch_size % 4 == 0:\n"," patch_size = ((int(patch_size / 4)-1) * 4)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 4; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is at least bigger than 256\n","if patch_size < 256:\n"," patch_size = 256\n"," print (bcolors.WARNING + \" Your chosen patch_size is too small; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","\n","y = imageio.imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"5LEowmfAWqPs","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"Flz3qoQrWv0v","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. 
Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by [Augmentor.](https://github.com/mdbloice/Augmentor)\n","\n","[Augmentor](https://github.com/mdbloice/Augmentor) was described in the following article:\n","\n","Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n","\n","**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** "]},{"cell_type":"code","metadata":{"id":"OsIBK-sywkfy","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," !pip install Augmentor\n"," import Augmentor\n","\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 4 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n","\n","#@markdown ####Mirror and rotate images\n","rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image Zoom\n","\n","random_zoom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, 
step:0.1}\n","\n","#@markdown ####Random image distortion\n","\n","random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","#@markdown ####Image shearing and skewing \n","\n","image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","max_image_shear = 1 #@param {type:\"slider\", min:1, max:25, step:1}\n","\n","skew_image = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","skew_image_magnitude = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","if Use_Default_Augmentation_Parameters:\n"," rotate_90_degrees = 0.5\n"," rotate_270_degrees = 0.5\n"," flip_left_right = 0.5\n"," flip_top_bottom = 0.5\n","\n"," if not Multiply_dataset_by >5:\n"," random_zoom = 0\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0\n"," image_shear = 0\n"," max_image_shear = 10\n"," skew_image = 0\n"," skew_image_magnitude = 0\n","\n"," if Multiply_dataset_by >5:\n"," random_zoom = 0.1\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0.5\n"," image_shear = 0.2\n"," max_image_shear = 5\n"," skew_image = 0.2\n"," skew_image_magnitude = 0.4\n","\n"," if Multiply_dataset_by >25:\n"," random_zoom = 0.5\n"," random_zoom_magnification = 0.8\n"," random_distortion = 0.5\n"," image_shear = 0.5\n"," max_image_shear = 20\n"," skew_image = 0.5\n"," skew_image_magnitude = 0.6\n","\n","\n","list_files = os.listdir(Training_source)\n","Nb_files = len(list_files)\n","\n","Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n","\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","# Here we set the path for the various folder were the augmented images will be loaded\n","\n","# All images are first saved into the augmented folder\n"," #Augmented_folder = \"/content/Augmented_Folder\"\n"," \n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n"," if os.path.exists(Augmented_folder):\n"," 
shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n","\n"," #Training_source_augmented = \"/content/Training_source_augmented\"\n"," Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n","\n"," if os.path.exists(Training_source_augmented):\n"," shutil.rmtree(Training_source_augmented)\n"," os.makedirs(Training_source_augmented)\n","\n"," #Training_target_augmented = \"/content/Training_target_augmented\"\n"," Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n","\n"," if os.path.exists(Training_target_augmented):\n"," shutil.rmtree(Training_target_augmented)\n"," os.makedirs(Training_target_augmented)\n","\n","\n","# Here we generate the augmented images\n","#Load the images\n"," p = Augmentor.Pipeline(Training_source, Augmented_folder)\n","\n","#Define the matching images\n"," p.ground_truth(Training_target)\n","#Define the augmentation possibilities\n"," if not rotate_90_degrees == 0:\n"," p.rotate90(probability=rotate_90_degrees)\n"," \n"," if not rotate_270_degrees == 0:\n"," p.rotate270(probability=rotate_270_degrees)\n","\n"," if not flip_left_right == 0:\n"," p.flip_left_right(probability=flip_left_right)\n","\n"," if not flip_top_bottom == 0:\n"," p.flip_top_bottom(probability=flip_top_bottom)\n","\n"," if not random_zoom == 0:\n"," p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n"," \n"," if not random_distortion == 0:\n"," p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n","\n"," if not image_shear == 0:\n"," p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n"," \n"," if not skew_image == 0:\n"," p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n","\n"," p.sample(int(Nb_augmented_files))\n","\n"," print(int(Nb_augmented_files),\"matching images generated\")\n","\n","# Here we sort through the images and move them back to augmented trainning source and targets folders\n","\n"," 
augmented_files = os.listdir(Augmented_folder)\n","\n"," for f in augmented_files:\n","\n"," if (f.startswith(\"_groundtruth_(1)_\")):\n"," shortname_noprefix = f[17:]\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n"," if not (f.startswith(\"_groundtruth_(1)_\")):\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n"," \n","\n"," for filename in os.listdir(Training_source_augmented):\n"," os.chdir(Training_source_augmented)\n"," os.rename(filename, filename.replace('_original', ''))\n"," \n"," #Here we clean up the extra files\n"," shutil.rmtree(Augmented_folder)\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"v-leE8pEWRkn","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a pix2pix model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. 
**You do not need to run this section if you want to train a network from scratch**.\n"]},{"cell_type":"code","metadata":{"id":"CbOcS3wiWV9w","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If yes, please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n"," h5_file_path = os.path.join(pretrained_model_path, \"latest_net_G.pth\")\n"," \n","\n","# --------------------- Check that the model exists ------------------------\n","\n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: Pretrained model does not exist')\n"," Use_pretrained_model = False\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"," if os.path.exists(h5_file_path):\n"," print(\"Pretrained model \"+os.path.basename(pretrained_model_path)+\" was found and will be loaded prior to training.\")\n"," \n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"-A4ipz8gs3Ew","colab_type":"text"},"source":["## **4.1. Prepare the training data for training**\n","---\n","Here, we use the information from Section 3 to convert the training data into a suitable format for training. 
**Your data will be copied in the google Colab \"content\" folder which may take some time depending on the size of your dataset.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"_V2ujGB60gDv","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Prepare the data for training\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","print(\"Data preparation in progress\")\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","os.makedirs(model_path+'/'+model_name)\n","\n","#--------------- Here we move the files to trainA and train B ---------\n","\n","print('Copying training source data...')\n","for f in tqdm(os.listdir(Training_source_dir)):\n"," shutil.copyfile(Training_source_dir+\"/\"+f, TrainA_Folder+\"/\"+f)\n","\n","print('Copying training target data...')\n","for f in tqdm(os.listdir(Training_target_dir)):\n"," shutil.copyfile(Training_target_dir+\"/\"+f, TrainB_Folder+\"/\"+f)\n","\n","#---------------------------------------------------------------------\n","\n","#--------------- Here we combined A and B images---------\n","os.chdir(\"/content\")\n","!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n","\n","\n","\n","# pix2pix uses EPOCH without lr decay and EPOCH with lr decay, here we automatically choose half and half\n","\n","number_of_epochs_lr_stable = int(number_of_epochs/2)\n","number_of_epochs_lr_decay = int(number_of_epochs/2)\n","\n","if Use_pretrained_model :\n"," 
for f in os.listdir(pretrained_model_path):\n"," if (f.startswith(\"latest_net_\")): \n"," shutil.copyfile(pretrained_model_path+\"/\"+f, model_path+'/'+model_name+\"/\"+f)\n","\n","print('------------------------')\n","print(\"Data ready for training\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR","colab_type":"text"},"source":["## **4.2. Train the network**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for data mining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches or continue the training in a second Colab session. **Pix2pix will save model checkpoints every 5 epochs.**"]},{"cell_type":"code","metadata":{"id":"eBD50tAgv5qf","colab_type":"code","cellView":"form","colab":{}},"source":["\n","#@markdown ##Start training\n","\n","start = time.time()\n","\n","os.chdir(\"/content\")\n","\n","#--------------------------------- Command line inputs to change pix2pix parameters------------\n","\n"," # basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n"," \n"," # model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. 
[cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n"," # dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. 
[unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n"," # additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n"," # visdom and HTML visualization parameters\n"," #('--display_freq', type=int, default=400, help='frequency of showing training results on screen')\n"," #('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n"," #('--display_id', type=int, default=1, help='window id of the web display')\n"," #('--display_server', type=str, default=\"http://localhost\", help='visdom server of the web display')\n"," #('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n"," #('--display_port', type=int, default=8097, help='visdom port of the web display')\n"," #('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n"," #('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n"," #('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n"," \n"," # network saving and loading parameters\n"," #('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n"," #('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n"," #('--save_by_iter', action='store_true', help='whether saves model by iteration')\n"," #('--continue_train', action='store_true', help='continue training: load the latest model')\n"," #('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')\n"," #('--phase', type=str, default='train', 
help='train, val, test, etc')\n"," \n"," # training parameters\n"," #('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')\n"," #('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')\n"," #('--beta1', type=float, default=0.5, help='momentum term of adam')\n"," #('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n"," #('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')\n"," #('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n"," #('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')\n"," #('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations'\n","\n","#---------------------------------------------------------\n","\n","#----- Start the training ------------------------------------\n","if not Use_pretrained_model:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$imageAB_folder\" --name $model_name --model pix2pix --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5\n","\n","if Use_pretrained_model:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$imageAB_folder\" --name $model_name --model pix2pix --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 
--save_epoch_freq 5 --continue_train\n","\n","\n","#---------------------------------------------------------\n","\n","print(\"Training, done.\")\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XQjQb_J_Qyku","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is, however, wise to download the folder, as all data can be erased at the next training if the same folder is used."]},{"cell_type":"markdown","metadata":{"id":"2HbZd7rFqAad","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend performing quality control on all newly trained models.**\n","\n","\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"NEBRRG8QyEDG","colab_type":"text"},"source":["## **5.1. 
Choose the model you want to assess**"]},{"cell_type":"code","metadata":{"id":"EdcnkCr9Nbl8","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ry9qN2tlydXq","colab_type":"text"},"source":["## **5.2. Identify the best checkpoint to use to make predictions**"]},{"cell_type":"markdown","metadata":{"id":"1yauWCc78HKD","colab_type":"text"},"source":[" Pix2pix saves model checkpoints every five epochs. Due to the stochastic nature of GAN networks, the last checkpoint is not always the best one to use. As a consequence, it can be challenging to choose the most suitable checkpoint to use to make predictions.\n","\n","This section allows you to perform predictions using all the saved checkpoints and to estimate the quality of these predictions by comparing them to the provided ground-truth images. The metrics used include:\n","\n","**1. 
The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric, and an SSIM of 1 indicates perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric at each pixel from the structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the mean SSIM value, i.e. the SSIM map averaged over the entire image.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized prediction and the target, or between the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark). In this notebook, the RSE maps, NRMSE and PSNR are only computed when **Image_type** is set to Grayscale.\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n"]},{"cell_type":"code","metadata":{"id":"2nBPucJdK3KS","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","import glob\n","import os.path\n","\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Image_type = \"RGB\" #@param [\"Grayscale\", \"RGB\"]\n","\n","\n","# average function\n","def Average(lst): \n"," return sum(lst) / len(lst) \n","\n","\n","\n","# Create a quality control folder\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","\n","# Create a quality control/Prediction Folder\n","\n","QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"\n","\n","if os.path.exists(QC_prediction_results):\n"," shutil.rmtree(QC_prediction_results)\n","\n","os.makedirs(QC_prediction_results)\n","\n","# Here we count how many images are in our folder to be predicted and we had a few\n","Nb_files_Data_folder = len(os.listdir(Source_QC_folder)) +10\n","\n","\n","\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","# Here we need to move the data to be analysed so that pix2pix can find them\n","\n","\n","Saving_path_QC= \"/content/\"+QC_model_name+\"_images\"\n","\n","if os.path.exists(Saving_path_QC):\n"," 
shutil.rmtree(Saving_path_QC)\n","os.makedirs(Saving_path_QC)\n","\n","Saving_path_QC_folder = Saving_path_QC+\"/QC\"\n","\n","if os.path.exists(Saving_path_QC_folder):\n"," shutil.rmtree(Saving_path_QC_folder)\n","os.makedirs(Saving_path_QC_folder)\n","\n","\n","imageA_folder = Saving_path_QC_folder+\"/A\"\n","os.makedirs(imageA_folder)\n","\n","imageB_folder = Saving_path_QC_folder+\"/B\"\n","os.makedirs(imageB_folder)\n","\n","imageAB_folder = Saving_path_QC_folder+\"/AB\"\n","os.makedirs(imageAB_folder)\n","\n","testAB_folder = Saving_path_QC_folder+\"/AB/test\"\n","os.makedirs(testAB_folder)\n","\n","testA_Folder = Saving_path_QC_folder+\"/A/test\"\n","os.makedirs(testA_Folder)\n"," \n","testB_Folder = Saving_path_QC_folder+\"/B/test\"\n","os.makedirs(testB_Folder)\n","\n","QC_checkpoint_folders = \"/content/\"+QC_model_name\n","\n","if os.path.exists(QC_checkpoint_folders):\n"," shutil.rmtree(QC_checkpoint_folders)\n","os.makedirs(QC_checkpoint_folders)\n","\n","\n","\n","for files in os.listdir(Source_QC_folder):\n"," shutil.copyfile(Source_QC_folder+\"/\"+files, testA_Folder+\"/\"+files)\n","\n","for files in os.listdir(Target_QC_folder):\n"," shutil.copyfile(Target_QC_folder+\"/\"+files, testB_Folder+\"/\"+files)\n"," \n","#Here we create a merged folder containing only imageA\n","os.chdir(\"/content\")\n","\n","!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n","\n","\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = int(min(Image_Y, Image_X))\n","\n","patch_size_QC = Image_min_dim\n","\n","if not patch_size_QC % 256 == 0:\n"," patch_size_QC = ((int(patch_size_QC / 256)) * 256)\n"," 
print (\" Your image dimensions are not divisible by 256; therefore your images will be resized to:\",patch_size_QC)\n","\n","if patch_size_QC < 256:\n"," patch_size_QC = 256\n","\n","\n","Nb_Checkpoint = len(glob.glob(os.path.join(full_QC_model_path, '*G.pth')))\n","\n","\n","print(Nb_Checkpoint)\n","\n","\n","## Initiate list\n","\n","Checkpoint_list = []\n","Average_ssim_score_list = []\n","\n","\n","for j in range(1, Nb_Checkpoint+1):\n"," checkpoints = j*5\n","\n"," if checkpoints == Nb_Checkpoint*5:\n"," checkpoints = \"latest\"\n","\n","\n"," print(\"The checkpoint currently analysed is =\"+str(checkpoints))\n","\n"," Checkpoint_list.append(checkpoints)\n","\n","\n"," # Create a Quality Control folder for the current checkpoint\n","\n"," QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\n","\n"," if os.path.exists(QC_prediction_results):\n"," shutil.rmtree(QC_prediction_results)\n","\n"," os.makedirs(QC_prediction_results)\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n"," os.chdir(\"/content\")\n"," !python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$imageAB_folder\" --name \"$QC_model_name\" --model pix2pix --epoch $checkpoints --no_dropout --preprocess scale_width --load_size $patch_size_QC --crop_size $patch_size_QC --results_dir \"$QC_prediction_results\" --checkpoints_dir \"$QC_model_path\" --direction AtoB --num_test $Nb_files_Data_folder\n","#-----------------------------------------------------------------------------------\n","\n","#Here we need to move the data again and remove all the unnecessary 
folders\n","\n"," Checkpoint_name = \"test_\"+str(checkpoints)\n","\n"," QC_results_images = QC_prediction_results+\"/\"+QC_model_name+\"/\"+Checkpoint_name+\"/images\"\n","\n"," QC_results_images_files = os.listdir(QC_results_images)\n","\n"," for f in QC_results_images_files: \n"," shutil.copyfile(QC_results_images+\"/\"+f, QC_prediction_results+\"/\"+f)\n","\n"," os.chdir(\"/content\") \n","\n"," #Here we clean up the extra files\n"," shutil.rmtree(QC_prediction_results+\"/\"+QC_model_name)\n","\n","\n"," #-------------------------------- QC for RGB ------------------------------------\n"," if Image_type == \"RGB\":\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n"," random_choice = random.choice(os.listdir(Source_QC_folder))\n"," x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, multichannel=True)\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. 
GT mSSIM\"])\n"," \n"," \n"," # Initiate list\n"," ssim_score_list = [] \n","\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," shortname_no_PNG = i[:-4]\n"," \n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," \n"," test_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), shortname_no_PNG+\"_real_B.png\"))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real_A.png\"))\n"," \n"," \n"," # -------------------------------- Prediction --------------------------------\n"," \n"," test_prediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake_B.png\"))\n"," \n"," #--------------------------- Here we normalise using histograms matching--------------------------------\n"," test_prediction_matched = match_histograms(test_prediction, test_GT, multichannel=True)\n"," test_source_matched = match_histograms(test_source, test_GT, multichannel=True)\n"," \n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT, test_prediction_matched)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT, test_source_matched)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality 
Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," \n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource)])\n","\n"," #Here we calculate the ssim average for each image in each checkpoints\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","#------------------------------------------- QC for Grayscale ----------------------------------------------\n","\n"," if Image_type == \"Grayscale\":\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n"," def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n","\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n"," def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n"," def norm_minmse(gt, x, normalize_gt=True):\n"," \n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = 
x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," \n"," # Initiate list (once per checkpoint, so that the scores of all images are kept)\n"," ssim_score_list = []\n"," \n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," shortname_no_PNG = i[:-4]\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), shortname_no_PNG+\"_real_B.png\"))\n"," \n"," test_GT = test_GT_raw[:,:,2]\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real_A.png\"))\n"," \n"," test_source = test_source_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction_raw = 
imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake_B.png\"))\n"," \n"," test_prediction = test_prediction_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," \n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_8bit = (img_RSE_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_RSE_GTvsPrediction_8bit)\n"," img_RSE_GTvsSource_8bit = (img_RSE_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality 
Control/\"+str(checkpoints)+\"/RSE_GTvsSource_\"+shortname_no_PNG+'.tif',img_RSE_GTvsSource_8bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error: square the RSE map, average over the image, then take the root\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(np.square(img_RSE_GTvsPrediction)))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(np.square(img_RSE_GTvsSource)))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n"," # Here we calculate the average SSIM over all images for the current checkpoint\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","\n","# All data is now processed and saved\n"," \n","\n","# -------------------------------- Display --------------------------------\n","\n","# Display the average SSIM vs. checkpoint plot\n","plt.figure(figsize=(20,5))\n","plt.plot(Checkpoint_list, Average_ssim_score_list, label=\"SSIM\")\n","plt.title('Checkpoints vs. 
SSIM')\n","plt.ylabel('SSIM')\n","plt.xlabel('Checkpoints')\n","plt.legend()\n","plt.show()\n","\n","\n","\n","# -------------------------------- Display RGB --------------------------------\n","\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","if Image_type == \"RGB\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n","\n","#Setting up colours\n"," cmap = None\n","\n","\n"," plt.figure(figsize=(15,15))\n","\n","# Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_B.png\"), as_gray=False, pilmode=\"RGB\")\n"," \n"," plt.imshow(img_GT, cmap = cmap)\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_A.png\"), as_gray=False, pilmode=\"RGB\")\n"," plt.imshow(img_Source, cmap = cmap)\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n","\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake_B.png\"))\n","\n"," plt.imshow(img_Prediction, cmap = cmap)\n"," 
plt.title('Prediction',fontsize=15)\n","\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","#plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","\n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","#plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","\n","# -------------------------------- Display Grayscale --------------------------------\n","\n","if Image_type == \"Grayscale\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n","\n"," NRMSE_GTvsPrediction = df2.loc[file, \"Prediction v. GT NRMSE\"]\n"," NRMSE_GTvsSource = df2.loc[file, \"Input v. GT NRMSE\"]\n"," PSNR_GTvsSource = df2.loc[file, \"Input v. GT PSNR\"]\n"," PSNR_GTvsPrediction = df2.loc[file, \"Prediction v. 
GT PSNR\"]\n"," \n","\n"," plt.figure(figsize=(20,20))\n"," # Currently only displays the last computed set, from memory\n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_B.png\"))\n","\n"," plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99))\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_A.png\"))\n"," plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake_B.png\"))\n"," plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n"," plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsSource = img_SSIM_GTvsSource / 255\n"," imSSIM_GTvsSource = 
plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","\n"," \n"," plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," \n"," \n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsPrediction = img_SSIM_GTvsPrediction / 255\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","\n"," \n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_RSE_GTvsSource = img_RSE_GTvsSource / 255\n"," \n","\n"," imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. 
Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," img_RSE_GTvsPrediction = img_RSE_GTvsPrediction / 255\n","\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Esqnbew8uznk"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. 
If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as PNG images.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`checkpoint`:** Choose the checkpoint number you would like to use to perform predictions. To use the \"latest\" checkpoint, input \"latest\".\n"]},{"cell_type":"code","metadata":{"id":"yb3suNkfpNA9","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","import glob\n","import os.path\n","\n","latest = \"latest\"\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###What model checkpoint would you like to use?\n","\n","checkpoint = latest#@param {type:\"raw\"}\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#here we check if we use the newly trained network or not\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","\n","#here we check if the model exists\n","full_Prediction_model_path = 
Prediction_model_path+'/'+Prediction_model_name+'/'\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","Nb_Checkpoint = len(glob.glob(os.path.join(full_Prediction_model_path, '*G.pth')))+1\n","\n","\n","if not checkpoint == \"latest\":\n","\n"," if checkpoint < 10:\n"," checkpoint = 5\n","\n"," if not checkpoint % 5 == 0:\n"," checkpoint = ((int(checkpoint / 5)-1) * 5)\n"," print (bcolors.WARNING + \" Your chosen checkpoint is not divisible by 5; therefore the checkpoint used is now:\",checkpoint)\n","\n","\n"," \n"," # Any checkpoint at or beyond the last saved one falls back to \"latest\"\n"," if checkpoint >= Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n","\n","# Here we need to move the data to be analysed so that pix2pix can find them\n","\n","Saving_path_prediction= \"/content/\"+Prediction_model_name\n","\n","if os.path.exists(Saving_path_prediction):\n"," shutil.rmtree(Saving_path_prediction)\n","os.makedirs(Saving_path_prediction)\n","\n","\n","imageA_folder = Saving_path_prediction+\"/A\"\n","os.makedirs(imageA_folder)\n","\n","imageB_folder = Saving_path_prediction+\"/B\"\n","os.makedirs(imageB_folder)\n","\n","imageAB_folder = Saving_path_prediction+\"/AB\"\n","os.makedirs(imageAB_folder)\n","\n","testAB_Folder = Saving_path_prediction+\"/AB/test\"\n","os.makedirs(testAB_Folder)\n","\n","testA_Folder = Saving_path_prediction+\"/A/test\"\n","os.makedirs(testA_Folder)\n"," \n","testB_Folder = Saving_path_prediction+\"/B/test\"\n","os.makedirs(testB_Folder)\n","\n","for files in os.listdir(Data_folder):\n"," shutil.copyfile(Data_folder+\"/\"+files, testA_Folder+\"/\"+files)\n"," shutil.copyfile(Data_folder+\"/\"+files, 
testB_Folder+\"/\"+files)\n"," \n","# Here we create a merged A / A image for the prediction\n","os.chdir(\"/content\")\n","!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n","\n","\n","# Here we count how many images are in our folder to be predicted and we add a few\n","Nb_files_Data_folder = len(os.listdir(Data_folder)) +10\n","\n","\n","# This will find the image dimension of a randomly chosen image in Data_folder \n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imageio.imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","\n","#-------------------------------- Perform predictions -----------------------------\n","\n","#-------------------------------- Options that can be used to perform predictions -----------------------------\n","\n","# basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n","\n","# model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. 
[cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n","# dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. 
[unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n","# additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n","\n"," #('--ntest', type=int, default=float(\"inf\"), help='# of test examples.')\n"," #('--results_dir', type=str, default='./results/', help='saves results here.')\n"," #('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n"," #('--phase', type=str, default='test', help='train, val, test, etc')\n","\n","# Dropout and Batchnorm have different behaviour during training and test.\n"," #('--eval', action='store_true', help='use eval mode during test time.')\n"," #('--num_test', type=int, default=50, help='how many test images to run')\n"," # rewrite default values\n"," \n","# To avoid cropping, the load_size should be the same as crop_size\n"," #parser.set_defaults(load_size=parser.get_default('crop_size'))\n","\n","#------------------------------------------------------------------------\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n","\n","os.chdir(\"/content\")\n","\n","!python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$imageAB_folder\" --name \"$Prediction_model_name\" --model pix2pix --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$Result_folder\" --checkpoints_dir \"$Prediction_model_path\" --num_test $Nb_files_Data_folder --epoch $checkpoint\n","\n","#-----------------------------------------------------------------------------------\n","\n","\n","Checkpoint_name = \"test_\"+str(checkpoint)\n","\n","\n","Prediction_results_folder = Result_folder+\"/\"+Prediction_model_name+\"/\"+Checkpoint_name+\"/images\"\n","\n","Prediction_results_images = 
os.listdir(Prediction_results_folder)\n","\n","for f in Prediction_results_images: \n"," if (f.endswith(\"_real_B.png\")): \n"," os.remove(Prediction_results_folder+\"/\"+f)\n","\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EIe3CRD7XUxa","colab_type":"text"},"source":["## **6.2. Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"LmDP8xiwXTTL","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","import os\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","\n","\n","random_choice_no_extension = os.path.splitext(random_choice)\n","\n","\n","x = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_real_A.png\")\n","\n","\n","y = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_fake_B.png\")\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Prediction')\n","plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"Rn9zpWpo0xNw","colab_type":"text"},"source":["\n","#**Thank you for using pix2pix!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"pix2pix_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1mqcexfPBaIWuvMWWbJZUFtPoZoJJwrEA","timestamp":1589278334507},{"file_id":"159ARwlQE7-zi0EHxunOF_YPFLt-ZVU5x","timestamp":1587562499898},{"file_id":"1W-7NHehG5MRFILvZZzhPWWnOdJMkadb2","timestamp":1586332290412},{"file_id":"1pUetEQICxYWkYVaQIgdRH1EZBTl7oc2A","timestamp":1586292199692},{"file_id":"1MD36ZkM6XR9EuV12zimJmfCjzyeYZFWq","timestamp":1586269469061},{"file_id":"16A2mbaHzlEElntS8qkFBOsBvZG-mUeY6","timestamp":1586253795726},{"file_id":"1gJlcjOiSxr2buDOxmcFbT_d-GqwLjXtK","timestamp":1583343225796},{"file_id":"10yGI51WzHfgWgZAyE-EbkZFEvIOd6CP6","timestamp":1583171396283}],"collapsed_sections":[],"toc_visible":true},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I","colab_type":"text"},"source":["# **pix2pix**\n","\n","---\n","\n","pix2pix is a deep-learning method allowing image-to-image translation from one image domain type to another image domain type. It was first published by [Isola *et al.* in 2016](https://arxiv.org/abs/1611.07004). 
The image transformation requires paired images for training (supervised learning) and is made possible here by using a conditional Generative Adversarial Network (GAN) architecture to use information from the input image and obtain the equivalent translated image.\n","\n"," **This particular notebook enables image-to-image translation learned from paired dataset. If you are interested in performing unpaired image-to-image translation, you should consider using the CycleGAN notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n"," **Image-to-Image Translation with Conditional Adversarial Networks** by Isola *et al.* on arXiv in 2016 (https://arxiv.org/abs/1611.07004)\n","\n","The source code of the PyTorch implementation of pix2pix can be found here: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"N3azwKB9O0oW","colab_type":"text"},"source":["# **License**\n","\n","---"]},{"cell_type":"code","metadata":{"id":"ByW6Vqdn9sYV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Double click to see the license information\n","\n","#------------------------- LICENSE FOR ZeroCostDL4Mic------------------------------------\n","#This ZeroCostDL4Mic notebook is distributed under the MIT licence\n","\n","\n","\n","#------------------------- LICENSE FOR CycleGAN ------------------------------------\n","\n","#Copyright (c) 2017, Jun-Yan Zhu and Taesung Park\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, 
are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the distribution.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n","#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n","#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n","#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n","#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n","#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n","#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n","#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n","#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n","#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n","\n","\n","#--------------------------- LICENSE FOR pix2pix --------------------------------\n","#BSD License\n","\n","#For pix2pix software\n","#Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the 
distribution.\n","\n","#----------------------------- LICENSE FOR DCGAN --------------------------------\n","#BSD License\n","\n","#For dcgan.torch software\n","\n","#Copyright (c) 2015, Facebook, Inc. All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n","\n","#Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n","\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n","\n","#Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor over the `[ ]`-mark on the left side of the cell (a play button appears). Click to execute the cell. After execution is done, the animation of the play button stops. You can create a new code cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. 
Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," For pix2pix to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called Training_source and Training_target. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. 
These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .PNG files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Training_source\n"," - img_1.png, img_2.png, ...\n"," - Training_target\n"," - img_1.png, img_2.png, ...\n"," - **Quality control dataset**\n"," - Training_source\n"," - img_1.png, img_2.png\n"," - Training_target\n"," - img_1.png, img_2.png\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"BCPhV-pe-syw","colab_type":"text"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"VNZetvLiS1qV","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UBrnApIUBgxv","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of the notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in to your Google Account. \n","\n","#@markdown * Copy the authorization code. 
\n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin","colab_type":"text"},"source":["# **2. Install pix2pix and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install pix2pix and dependencies\n","\n","\n","\n","#Here, we install libraries which are not already included in Colab.\n","\n","\n","!git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","import os\n","os.chdir('pytorch-CycleGAN-and-pix2pix/')\n","!pip install -r requirements.txt\n","\n","import imageio\n","from skimage import data\n","from skimage import exposure\n","from skimage.exposure import match_histograms\n","import glob\n","import os.path\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm 
\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print('----------------------------')\n","print(\"Libraries installed\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4","colab_type":"text"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"BLmBseWbRvxL","colab_type":"text"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd","colab_type":"text"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target training data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input for how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10) epochs, but a full training should run for 200 epochs or more. Evaluate the performance after training (see 5). **Default value: 200**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`patch_size`:** pix2pix divides the image into patches for training. 
Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 512**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0002**"]},{"cell_type":"code","metadata":{"id":"pIrTwJjzwV-D","colab_type":"code","cellView":"form","colab":{}},"source":["\n","\n","#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.png\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.png\"\n","\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","patch_size = 512#@param {type:\"number\"} # in pixels\n","batch_size = 1#@param {type:\"number\"}\n","initial_learning_rate = 0.0002 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," patch_size = 512\n"," 
initial_learning_rate = 0.0002\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\")\n"," \n","#To use pix2pix we need to organise the data in a way the network can understand\n","\n","Saving_path= \"/content/\"+model_name\n","#Saving_path= model_path+\"/\"+model_name\n","\n","if os.path.exists(Saving_path):\n"," shutil.rmtree(Saving_path)\n","os.makedirs(Saving_path)\n","\n","imageA_folder = Saving_path+\"/A\"\n","os.makedirs(imageA_folder)\n","\n","imageB_folder = Saving_path+\"/B\"\n","os.makedirs(imageB_folder)\n","\n","imageAB_folder = Saving_path+\"/AB\"\n","os.makedirs(imageAB_folder)\n","\n","TrainA_Folder = Saving_path+\"/A/train\"\n","os.makedirs(TrainA_Folder)\n"," \n","TrainB_Folder = Saving_path+\"/B/train\"\n","os.makedirs(TrainB_Folder)\n","\n","# Here we disable pre-trained model by default (in case the cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imageio.imread(Training_source+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","#Hyperparameters failsafes\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that 
patch_size is divisible by 4\n","if not patch_size % 4 == 0:\n"," patch_size = ((int(patch_size / 4)-1) * 4)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 4; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is at least bigger than 256\n","if patch_size < 256:\n"," patch_size = 256\n"," print (bcolors.WARNING + \" Your chosen patch_size is too small; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","\n","y = imageio.imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"5LEowmfAWqPs","colab_type":"text"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"Flz3qoQrWv0v","colab_type":"text"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. 
Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by [Augmentor.](https://github.com/mdbloice/Augmentor)\n","\n","[Augmentor](https://github.com/mdbloice/Augmentor) was described in the following article:\n","\n","Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n","\n","**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** "]},{"cell_type":"code","metadata":{"id":"OsIBK-sywkfy","colab_type":"code","cellView":"form","colab":{}},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," !pip install Augmentor\n"," import Augmentor\n","\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 4 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n","\n","#@markdown ####Mirror and rotate images\n","rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image Zoom\n","\n","random_zoom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, 
step:0.1}\n","\n","#@markdown ####Random image distortion\n","\n","random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","#@markdown ####Image shearing and skewing \n","\n","image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","max_image_shear = 1 #@param {type:\"slider\", min:1, max:25, step:1}\n","\n","skew_image = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","skew_image_magnitude = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","if Use_Default_Augmentation_Parameters:\n"," rotate_90_degrees = 0.5\n"," rotate_270_degrees = 0.5\n"," flip_left_right = 0.5\n"," flip_top_bottom = 0.5\n","\n"," if not Multiply_dataset_by >5:\n"," random_zoom = 0\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0\n"," image_shear = 0\n"," max_image_shear = 10\n"," skew_image = 0\n"," skew_image_magnitude = 0\n","\n"," if Multiply_dataset_by >5:\n"," random_zoom = 0.1\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0.5\n"," image_shear = 0.2\n"," max_image_shear = 5\n"," skew_image = 0.2\n"," skew_image_magnitude = 0.4\n","\n"," if Multiply_dataset_by >25:\n"," random_zoom = 0.5\n"," random_zoom_magnification = 0.8\n"," random_distortion = 0.5\n"," image_shear = 0.5\n"," max_image_shear = 20\n"," skew_image = 0.5\n"," skew_image_magnitude = 0.6\n","\n","\n","list_files = os.listdir(Training_source)\n","Nb_files = len(list_files)\n","\n","Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n","\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","# Here we set the path for the various folder were the augmented images will be loaded\n","\n","# All images are first saved into the augmented folder\n"," #Augmented_folder = \"/content/Augmented_Folder\"\n"," \n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n"," if os.path.exists(Augmented_folder):\n"," 
shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n","\n"," #Training_source_augmented = \"/content/Training_source_augmented\"\n"," Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n","\n"," if os.path.exists(Training_source_augmented):\n"," shutil.rmtree(Training_source_augmented)\n"," os.makedirs(Training_source_augmented)\n","\n"," #Training_target_augmented = \"/content/Training_target_augmented\"\n"," Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n","\n"," if os.path.exists(Training_target_augmented):\n"," shutil.rmtree(Training_target_augmented)\n"," os.makedirs(Training_target_augmented)\n","\n","\n","# Here we generate the augmented images\n","#Load the images\n"," p = Augmentor.Pipeline(Training_source, Augmented_folder)\n","\n","#Define the matching images\n"," p.ground_truth(Training_target)\n","#Define the augmentation possibilities\n"," if not rotate_90_degrees == 0:\n"," p.rotate90(probability=rotate_90_degrees)\n"," \n"," if not rotate_270_degrees == 0:\n"," p.rotate270(probability=rotate_270_degrees)\n","\n"," if not flip_left_right == 0:\n"," p.flip_left_right(probability=flip_left_right)\n","\n"," if not flip_top_bottom == 0:\n"," p.flip_top_bottom(probability=flip_top_bottom)\n","\n"," if not random_zoom == 0:\n"," p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n"," \n"," if not random_distortion == 0:\n"," p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n","\n"," if not image_shear == 0:\n"," p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n"," \n"," if not skew_image == 0:\n"," p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n","\n"," p.sample(int(Nb_augmented_files))\n","\n"," print(int(Nb_augmented_files),\"matching images generated\")\n","\n","# Here we sort through the images and move them back to augmented trainning source and targets folders\n","\n"," 
augmented_files = os.listdir(Augmented_folder)\n","\n"," for f in augmented_files:\n","\n"," if (f.startswith(\"_groundtruth_(1)_\")):\n"," shortname_noprefix = f[17:]\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n"," if not (f.startswith(\"_groundtruth_(1)_\")):\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n"," \n","\n"," for filename in os.listdir(Training_source_augmented):\n"," os.chdir(Training_source_augmented)\n"," os.rename(filename, filename.replace('_original', ''))\n"," \n"," #Here we clean up the extra files\n"," shutil.rmtree(Augmented_folder)\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"v-leE8pEWRkn","colab_type":"text"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a pix2pix model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. 
**You do not need to run this section if you want to train a network from scratch**.\n"]},{"cell_type":"code","metadata":{"id":"CbOcS3wiWV9w","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If yes, please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n"," h5_file_path = os.path.join(pretrained_model_path, \"latest_net_G.pth\")\n"," \n","\n","# --------------------- Check that the model exists ------------------------\n","\n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: Pretrained model does not exist')\n"," Use_pretrained_model = False\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"," if os.path.exists(h5_file_path):\n"," print(\"Pretrained model \"+os.path.basename(pretrained_model_path)+\" was found and will be loaded prior to training.\")\n"," \n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"-A4ipz8gs3Ew","colab_type":"text"},"source":["## **4.1. Prepare the training data for training**\n","---\n","Here, we use the information from Section 3 to prepare the training data into a suitable format for training. 
**Your data will be copied to the Google Colab \"content\" folder, which may take some time depending on the size of your dataset.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"_V2ujGB60gDv","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Prepare the data for training\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","print(\"Data preparation in progress\")\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","os.makedirs(model_path+'/'+model_name)\n","\n","#--------------- Here we copy the files to trainA and trainB ---------\n","\n","print('Copying training source data...')\n","for f in tqdm(os.listdir(Training_source_dir)):\n"," shutil.copyfile(Training_source_dir+\"/\"+f, TrainA_Folder+\"/\"+f)\n","\n","print('Copying training target data...')\n","for f in tqdm(os.listdir(Training_target_dir)):\n"," shutil.copyfile(Training_target_dir+\"/\"+f, TrainB_Folder+\"/\"+f)\n","\n","#---------------------------------------------------------------------\n","\n","#--------------- Here we combine A and B images---------\n","os.chdir(\"/content\")\n","!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n","\n","\n","\n","# pix2pix uses EPOCH without lr decay and EPOCH with lr decay, here we automatically choose half and half\n","\n","number_of_epochs_lr_stable = int(number_of_epochs/2)\n","number_of_epochs_lr_decay = int(number_of_epochs/2)\n","\n","if Use_pretrained_model :\n"," 
for f in os.listdir(pretrained_model_path):\n"," if (f.startswith(\"latest_net_\")): \n"," shutil.copyfile(pretrained_model_path+\"/\"+f, model_path+'/'+model_name+\"/\"+f)\n","\n","print('------------------------')\n","print(\"Data ready for training\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR","colab_type":"text"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches or continue the training in a second Colab session. **Pix2pix will save model checkpoints every 5 epochs.**"]},{"cell_type":"code","metadata":{"id":"eBD50tAgv5qf","colab_type":"code","cellView":"form","colab":{}},"source":["\n","#@markdown ##Start training\n","\n","start = time.time()\n","\n","os.chdir(\"/content\")\n","\n","#--------------------------------- Command line inputs to change pix2pix parameters------------\n","\n"," # basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n"," \n"," # model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. 
[cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n"," # dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. 
[unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n"," # additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n"," # visdom and HTML visualization parameters\n"," #('--display_freq', type=int, default=400, help='frequency of showing training results on screen')\n"," #('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n"," #('--display_id', type=int, default=1, help='window id of the web display')\n"," #('--display_server', type=str, default=\"http://localhost\", help='visdom server of the web display')\n"," #('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n"," #('--display_port', type=int, default=8097, help='visdom port of the web display')\n"," #('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n"," #('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n"," #('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n"," \n"," # network saving and loading parameters\n"," #('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n"," #('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n"," #('--save_by_iter', action='store_true', help='whether saves model by iteration')\n"," #('--continue_train', action='store_true', help='continue training: load the latest model')\n"," #('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')\n"," #('--phase', type=str, default='train', 
help='train, val, test, etc')\n"," \n"," # training parameters\n"," #('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')\n"," #('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')\n"," #('--beta1', type=float, default=0.5, help='momentum term of adam')\n"," #('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n"," #('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')\n"," #('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n"," #('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')\n"," #('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations'\n","\n","#---------------------------------------------------------\n","\n","#----- Start the training ------------------------------------\n","if not Use_pretrained_model:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$imageAB_folder\" --name $model_name --model pix2pix --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5\n","\n","if Use_pretrained_model:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$imageAB_folder\" --name $model_name --model pix2pix --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 
--save_epoch_freq 5 --continue_train\n","\n","\n","#---------------------------------------------------------\n","\n","print(\"Training done.\")\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XQjQb_J_Qyku","colab_type":"text"},"source":["## **4.3. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is, however, wise to download the folder, as all data can be erased during the next training session if the same folder is used."]},{"cell_type":"markdown","metadata":{"id":"2HbZd7rFqAad","colab_type":"text"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend performing quality control on all newly trained models.**\n","\n","\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"NEBRRG8QyEDG","colab_type":"text"},"source":["## **5.1. 
Choose the model you want to assess**"]},{"cell_type":"code","metadata":{"id":"EdcnkCr9Nbl8","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ry9qN2tlydXq","colab_type":"text"},"source":["## **5.2. Identify the best checkpoint to use to make predictions**"]},{"cell_type":"markdown","metadata":{"id":"1yauWCc78HKD","colab_type":"text"},"source":[" Pix2pix saves model checkpoints every five epochs. Due to the stochastic nature of GAN networks, the last checkpoint is not always the best one to use. As a consequence, it can be challenging to choose the most suitable checkpoint to use to make predictions.\n","\n","This section allows you to perform predictions using all the saved checkpoints and to estimate the quality of these predictions by comparing them to the provided ground truth images. Metrics used include:\n","\n","**1. 
The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n"]},{"cell_type":"code","metadata":{"id":"2nBPucJdK3KS","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","import glob\n","import os.path\n","\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Image_type = \"RGB\" #@param [\"Grayscale\", \"RGB\"]\n","\n","\n","# average function\n","def Average(lst): \n"," return sum(lst) / len(lst) \n","\n","\n","\n","# Create a quality control folder\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","\n","# Create a quality control/Prediction Folder\n","\n","QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"\n","\n","if os.path.exists(QC_prediction_results):\n"," shutil.rmtree(QC_prediction_results)\n","\n","os.makedirs(QC_prediction_results)\n","\n","# Here we count how many images are in our folder to be predicted and we add a few\n","Nb_files_Data_folder = len(os.listdir(Source_QC_folder)) +10\n","\n","\n","\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly chosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","# Here we need to move the data to be analysed so that pix2pix can find them\n","\n","\n","Saving_path_QC= \"/content/\"+QC_model_name+\"_images\"\n","\n","if os.path.exists(Saving_path_QC):\n"," 
shutil.rmtree(Saving_path_QC)\n","os.makedirs(Saving_path_QC)\n","\n","Saving_path_QC_folder = Saving_path_QC+\"/QC\"\n","\n","if os.path.exists(Saving_path_QC_folder):\n"," shutil.rmtree(Saving_path_QC_folder)\n","os.makedirs(Saving_path_QC_folder)\n","\n","\n","imageA_folder = Saving_path_QC_folder+\"/A\"\n","os.makedirs(imageA_folder)\n","\n","imageB_folder = Saving_path_QC_folder+\"/B\"\n","os.makedirs(imageB_folder)\n","\n","imageAB_folder = Saving_path_QC_folder+\"/AB\"\n","os.makedirs(imageAB_folder)\n","\n","testAB_folder = Saving_path_QC_folder+\"/AB/test\"\n","os.makedirs(testAB_folder)\n","\n","testA_Folder = Saving_path_QC_folder+\"/A/test\"\n","os.makedirs(testA_Folder)\n"," \n","testB_Folder = Saving_path_QC_folder+\"/B/test\"\n","os.makedirs(testB_Folder)\n","\n","QC_checkpoint_folders = \"/content/\"+QC_model_name\n","\n","if os.path.exists(QC_checkpoint_folders):\n"," shutil.rmtree(QC_checkpoint_folders)\n","os.makedirs(QC_checkpoint_folders)\n","\n","\n","\n","for files in os.listdir(Source_QC_folder):\n"," shutil.copyfile(Source_QC_folder+\"/\"+files, testA_Folder+\"/\"+files)\n","\n","for files in os.listdir(Target_QC_folder):\n"," shutil.copyfile(Target_QC_folder+\"/\"+files, testB_Folder+\"/\"+files)\n"," \n","#Here we create a merged folder containing only imageA\n","os.chdir(\"/content\")\n","\n","!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n","\n","\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = int(min(Image_Y, Image_X))\n","\n","patch_size_QC = Image_min_dim\n","\n","if not patch_size_QC % 256 == 0:\n"," patch_size_QC = ((int(patch_size_QC / 256)) * 256)\n"," 
print (\" Your image dimensions are not divisible by 256; therefore your images have now been resized to:\",patch_size_QC)\n","\n","if patch_size_QC < 256:\n"," patch_size_QC = 256\n","\n","\n","Nb_Checkpoint = len(glob.glob(os.path.join(full_QC_model_path, '*G.pth')))\n","\n","\n","print(Nb_Checkpoint)\n","\n","\n","## Initiate list\n","\n","Checkpoint_list = []\n","Average_ssim_score_list = []\n","\n","\n","for j in range(1, len(glob.glob(os.path.join(full_QC_model_path, '*G.pth')))+1):\n"," checkpoints = j*5\n","\n"," if checkpoints == Nb_Checkpoint*5:\n"," checkpoints = \"latest\"\n","\n","\n"," print(\"The checkpoint currently analysed is =\"+str(checkpoints))\n","\n"," Checkpoint_list.append(checkpoints)\n","\n","\n"," # Create a quality control/Prediction Folder\n","\n"," QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\n","\n"," if os.path.exists(QC_prediction_results):\n"," shutil.rmtree(QC_prediction_results)\n","\n"," os.makedirs(QC_prediction_results)\n","\n","\n"," # Create a quality control/Prediction Folder\n","\n"," QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\n","\n"," if os.path.exists(QC_prediction_results):\n"," shutil.rmtree(QC_prediction_results)\n","\n"," os.makedirs(QC_prediction_results)\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n"," os.chdir(\"/content\")\n"," !python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$imageAB_folder\" --name \"$QC_model_name\" --model pix2pix --epoch $checkpoints --no_dropout --preprocess scale_width --load_size $patch_size_QC --crop_size $patch_size_QC --results_dir \"$QC_prediction_results\" --checkpoints_dir \"$QC_model_path\" --direction AtoB --num_test $Nb_files_Data_folder\n","#-----------------------------------------------------------------------------------\n","\n","#Here we need to move the data again and remove all the unnecessary 
folders\n","\n"," Checkpoint_name = \"test_\"+str(checkpoints)\n","\n"," QC_results_images = QC_prediction_results+\"/\"+QC_model_name+\"/\"+Checkpoint_name+\"/images\"\n","\n"," QC_results_images_files = os.listdir(QC_results_images)\n","\n"," for f in QC_results_images_files: \n"," shutil.copyfile(QC_results_images+\"/\"+f, QC_prediction_results+\"/\"+f)\n","\n"," os.chdir(\"/content\") \n","\n"," #Here we clean up the extra files\n"," shutil.rmtree(QC_prediction_results+\"/\"+QC_model_name)\n","\n","\n"," #-------------------------------- QC for RGB ------------------------------------\n"," if Image_type == \"RGB\":\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n"," random_choice = random.choice(os.listdir(Source_QC_folder))\n"," x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, multichannel=True)\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. 
GT mSSIM\"])\n"," \n"," \n"," # Initiate list\n"," ssim_score_list = [] \n","\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," shortname_no_PNG = i[:-4]\n"," \n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," \n"," test_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), shortname_no_PNG+\"_real_B.png\"))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real_A.png\"))\n"," \n"," \n"," # -------------------------------- Prediction --------------------------------\n"," \n"," test_prediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake_B.png\"))\n"," \n"," #--------------------------- Here we normalise using histograms matching--------------------------------\n"," test_prediction_matched = match_histograms(test_prediction, test_GT, multichannel=True)\n"," test_source_matched = match_histograms(test_source, test_GT, multichannel=True)\n"," \n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT, test_prediction_matched)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT, test_source_matched)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality 
Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," \n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource)])\n","\n"," #Here we calculate the ssim average for each image in each checkpoints\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","#------------------------------------------- QC for Grayscale ----------------------------------------------\n","\n"," if Image_type == \"Grayscale\":\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n"," def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n","\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n"," def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n"," def norm_minmse(gt, x, normalize_gt=True):\n"," \n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = 
x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," \n"," \n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," ssim_score_list = []\n"," shortname_no_PNG = i[:-4]\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), shortname_no_PNG+\"_real_B.png\"))\n"," \n"," test_GT = test_GT_raw[:,:,2]\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real_A.png\"))\n"," \n"," test_source = test_source_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction_raw = 
imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake_B.png\"))\n"," \n"," test_prediction = test_prediction_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," \n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_8bit = (img_RSE_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_RSE_GTvsPrediction_8bit)\n"," img_RSE_GTvsSource_8bit = (img_RSE_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality 
Control/\"+str(checkpoints)+\"/RSE_GTvsSource_\"+shortname_no_PNG+'.tif',img_RSE_GTvsSource_8bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n"," #Here we calculate the SSIM average for each checkpoint\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","\n","# All data is now processed and saved\n"," \n","\n","# -------------------------------- Display --------------------------------\n","\n","# Display the Checkpoints vs SSIM plot\n","plt.figure(figsize=(20,5))\n","plt.plot(Checkpoint_list, Average_ssim_score_list, label=\"SSIM\")\n","plt.title('Checkpoints vs. 
SSIM')\n","plt.ylabel('SSIM')\n","plt.xlabel('Checkpoints')\n","plt.legend()\n","plt.show()\n","\n","\n","\n","# -------------------------------- Display RGB --------------------------------\n","\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","if Image_type == \"RGB\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n","\n","#Setting up colours\n"," cmap = None\n","\n","\n"," plt.figure(figsize=(15,15))\n","\n","# Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_B.png\"), as_gray=False, pilmode=\"RGB\")\n"," \n"," plt.imshow(img_GT, cmap = cmap)\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_A.png\"), as_gray=False, pilmode=\"RGB\")\n"," plt.imshow(img_Source, cmap = cmap)\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n","\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake_B.png\"))\n","\n"," plt.imshow(img_Prediction, cmap = cmap)\n"," 
plt.title('Prediction',fontsize=15)\n","\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","#plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","\n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","#plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","\n","# -------------------------------- Display Grayscale --------------------------------\n","\n","if Image_type == \"Grayscale\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n","\n"," NRMSE_GTvsPrediction = df2.loc[file, \"Prediction v. GT NRMSE\"]\n"," NRMSE_GTvsSource = df2.loc[file, \"Input v. GT NRMSE\"]\n"," PSNR_GTvsSource = df2.loc[file, \"Input v. GT PSNR\"]\n"," PSNR_GTvsPrediction = df2.loc[file, \"Prediction v. 
GT PSNR\"]\n"," \n","\n"," plt.figure(figsize=(20,20))\n"," # Currently only displays the last computed set, from memory\n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_B.png\"))\n","\n"," plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99))\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_A.png\"))\n"," plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake_B.png\"))\n"," plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n"," plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsSource = img_SSIM_GTvsSource / 255\n"," imSSIM_GTvsSource = 
plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","\n"," \n"," plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," \n"," \n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsPrediction = img_SSIM_GTvsPrediction / 255\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","\n"," \n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_RSE_GTvsSource = img_RSE_GTvsSource / 255\n"," \n","\n"," imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. 
Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," img_RSE_GTvsPrediction = img_RSE_GTvsPrediction / 255\n","\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Esqnbew8uznk"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section, unseen data is processed using the model trained in section 4. First, your unseen images are uploaded and prepared for prediction. Your trained model from section 4 is then activated, and the results are finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. 
If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** as PNG images.\n","\n","**`Data_folder`:** This folder should contain the images that you want to process with your trained network.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`checkpoint`:** Choose the checkpoint number you would like to use to perform predictions. To use the \"latest\" checkpoint, input \"latest\".\n"]},{"cell_type":"code","metadata":{"id":"yb3suNkfpNA9","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","import glob\n","import os.path\n","import random\n","import shutil\n","\n","latest = \"latest\"\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###What model checkpoint would you like to use?\n","\n","checkpoint = latest#@param {type:\"raw\"}\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#here we check if we use the newly trained network or not\n","if (Use_the_current_trained_model): \n","  print(\"Using current trained network\")\n","  Prediction_model_name = model_name\n","  Prediction_model_path = model_path\n","\n","\n","#here we check if the model exists\n","full_Prediction_model_path = 
Prediction_model_path+'/'+Prediction_model_name+'/'\n","\n","if os.path.exists(full_Prediction_model_path):\n","  print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n","  W = '\\033[0m' # white (normal)\n","  R = '\\033[31m' # red\n","  print(R+'!! WARNING: The chosen model does not exist !!'+W)\n","  print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","Nb_Checkpoint = len(glob.glob(os.path.join(full_Prediction_model_path, '*G.pth')))+1\n","\n","\n","if not checkpoint == \"latest\":\n","\n","  if checkpoint < 10:\n","    checkpoint = 5\n","\n","  if not checkpoint % 5 == 0:\n","    checkpoint = ((int(checkpoint / 5)-1) * 5)\n","    print(bcolors.WARNING + \" Your chosen checkpoint is not divisible by 5; therefore the checkpoint used is now:\", checkpoint)\n","\n","\n","  # A checkpoint at or beyond the last saved one falls back to \"latest\".\n","  # A single comparison avoids comparing the string \"latest\" with an integer.\n","  if checkpoint >= Nb_Checkpoint*5:\n","    checkpoint = \"latest\"\n","\n","\n","# Here we need to move the data to be analysed so that pix2pix can find them\n","\n","Saving_path_prediction= \"/content/\"+Prediction_model_name\n","\n","if os.path.exists(Saving_path_prediction):\n","  shutil.rmtree(Saving_path_prediction)\n","os.makedirs(Saving_path_prediction)\n","\n","\n","imageA_folder = Saving_path_prediction+\"/A\"\n","os.makedirs(imageA_folder)\n","\n","imageB_folder = Saving_path_prediction+\"/B\"\n","os.makedirs(imageB_folder)\n","\n","imageAB_folder = Saving_path_prediction+\"/AB\"\n","os.makedirs(imageAB_folder)\n","\n","testAB_Folder = Saving_path_prediction+\"/AB/test\"\n","os.makedirs(testAB_Folder)\n","\n","testA_Folder = Saving_path_prediction+\"/A/test\"\n","os.makedirs(testA_Folder)\n"," \n","testB_Folder = Saving_path_prediction+\"/B/test\"\n","os.makedirs(testB_Folder)\n","\n","for files in os.listdir(Data_folder):\n","  shutil.copyfile(Data_folder+\"/\"+files, testA_Folder+\"/\"+files)\n","  shutil.copyfile(Data_folder+\"/\"+files, 
testB_Folder+\"/\"+files)\n"," \n","# Here we create a merged A / A image for the prediction\n","os.chdir(\"/content\")\n","!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n","\n","\n","# Here we count how many images are in the folder to be predicted and add a margin of 10\n","Nb_files_Data_folder = len(os.listdir(Data_folder)) +10\n","\n","\n","# This will find the image dimension of a randomly chosen image in Data_folder \n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imageio.imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","\n","#-------------------------------- Perform predictions -----------------------------\n","\n","#-------------------------------- Options that can be used to perform predictions -----------------------------\n","\n","# basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n","\n","# model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. 
[cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n","# dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. 
[unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n","# additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n","\n"," #('--ntest', type=int, default=float(\"inf\"), help='# of test examples.')\n"," #('--results_dir', type=str, default='./results/', help='saves results here.')\n"," #('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n"," #('--phase', type=str, default='test', help='train, val, test, etc')\n","\n","# Dropout and Batchnorm have different behavior during training and test.\n"," #('--eval', action='store_true', help='use eval mode during test time.')\n"," #('--num_test', type=int, default=50, help='how many test images to run')\n"," # rewrite default values\n"," \n","# To avoid cropping, the load_size should be the same as crop_size\n"," #parser.set_defaults(load_size=parser.get_default('crop_size'))\n","\n","#------------------------------------------------------------------------\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n","\n","os.chdir(\"/content\")\n","\n","!python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$imageAB_folder\" --name \"$Prediction_model_name\" --model pix2pix --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$Result_folder\" --checkpoints_dir \"$Prediction_model_path\" --num_test $Nb_files_Data_folder --epoch $checkpoint\n","\n","#-----------------------------------------------------------------------------------\n","\n","\n","Checkpoint_name = \"test_\"+str(checkpoint)\n","\n","\n","Prediction_results_folder = Result_folder+\"/\"+Prediction_model_name+\"/\"+Checkpoint_name+\"/images\"\n","\n","Prediction_results_images = 
os.listdir(Prediction_results_folder)\n","\n","# The _real_B.png outputs are only copies of the inputs here, so we remove them\n","for f in Prediction_results_images: \n","  if (f.endswith(\"_real_B.png\")): \n","    os.remove(Prediction_results_folder+\"/\"+f)\n","\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EIe3CRD7XUxa","colab_type":"text"},"source":["## **6.2. Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"LmDP8xiwXTTL","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","import os\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","\n","\n","random_choice_no_extension = os.path.splitext(random_choice)\n","\n","\n","x = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_real_A.png\")\n","\n","\n","y = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_fake_B.png\")\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Prediction')\n","plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB","colab_type":"text"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading them from Google Drive, then clean the original folder tree (datasets, results, trained model, etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"Rn9zpWpo0xNw","colab_type":"text"},"source":["\n","#**Thank you for using pix2pix!**"]}]} \ No newline at end of file