diff --git a/.gitignore b/.gitignore
index 881bd15..4691721 100644
--- a/.gitignore
+++ b/.gitignore
@@ -29,7 +29,11 @@ suite2p.egg-info/
build/
FaceMap/ops_user.npy
*.ipynb
+*.pth
+*.pt
+*.gif
+#
# =========================
# Operating System Files
# =========================
diff --git a/README.md b/README.md
index bdf0b57..dce8b35 100644
--- a/README.md
+++ b/README.md
@@ -1,28 +1,25 @@
-# facemap
+[![Downloads](https://pepy.tech/badge/facemap)](https://pepy.tech/project/facemap)
+[![Downloads](https://pepy.tech/badge/facemap/month)](https://pepy.tech/project/facemap)
+[![GitHub stars](https://badgen.net/github/stars/Mouseland/facemap)](https://github.com/MouseLand/facemap/stargazers)
+[![GitHub forks](https://badgen.net/github/forks/Mouseland/facemap)](https://github.com/MouseLand/facemap/network/members)
+[![](https://img.shields.io/github/license/MouseLand/facemap)](https://github.com/MouseLand/facemap/blob/main/LICENSE)
+[![PyPI version](https://badge.fury.io/py/facemap.svg)](https://badge.fury.io/py/facemap)
+[![Documentation Status](https://readthedocs.org/projects/ansicolortags/badge/?version=latest)](https://pypi.org/project/facemap/)
+[![GitHub open issues](https://badgen.net/github/open-issues/Mouseland/facemap)](https://github.com/MouseLand/facemap/issues)
-GUI for processing videos of rodents, implemented in python and MATLAB. Works for grayscale and RGB movies. Can process multi-camera videos. Some example movies to test the GUI on are located [here](https://drive.google.com/open?id=1cRWCDl8jxWToz50dCX1Op-dHcAC-ttto). You can save the output from both the python and matlab versions as a matlab file with a checkbox in the GUI (if you'd like to use the python version - it has a better GUI).
+# Facemap
-# Data acquisition info
+Pose tracking of the mouse face from different camera views (Python only) and SVD processing of videos (Python and MATLAB). Includes a GUI and CLI for easy use.
-IR ILLUMINATION:
+## [Installation](https://github.com/MouseLand/facemap/blob/dev/docs/installation.md)
-For recording in darkness we use [IR illumination](https://www.amazon.com/Logisaf-Invisible-Infrared-Security-Cameras/dp/B01MQW8K7Z/ref=sr_1_12?s=security-surveillance&ie=UTF8&qid=1505507302&sr=1-12&keywords=ir+light) at 850nm, which works well with 2p imaging at 970nm and even 920nm. Depending on your needs, you might want to choose a different wavelength, which changes all the filters below as well. 950nm works just as well, and probably so does 750nm, which still outside of the visible range for rodents.
+- For the latest released version (from PyPI), which includes SVD processing only, run `pip install facemap` for the headless version or `pip install facemap[gui]` to include the GUI.
-If you want to focus the illumination on the mouse eye or face, you will need a different, more expensive system. Here is an example, courtesy of Michael Krumin from the Carandini lab: [driver](https://www.thorlabs.com/thorproduct.cfm?partnumber=LEDD1B), [power supply](https://www.thorlabs.com/newgrouppage9.cfm?objectgroup_id=1710&pn=KPS101#8865), [LED](https://www.thorlabs.com/newgrouppage9.cfm?objectgroup_id=2692&pn=M850L3#4426), [lens](https://www.thorlabs.com/newgrouppage9.cfm?objectgroup_id=259&pn=AC254-030-B#2231), and [lens tube](https://www.thorlabs.com/newgrouppage9.cfm?objectgroup_id=4109&pn=SM1V10#3389), and another [lens tube](https://www.thorlabs.com/thorproduct.cfm?partnumber=SM1L10).
-
-CAMERAS:
-
-We use [ptgrey cameras](https://www.ptgrey.com/flea3-13-mp-mono-usb3-vision-vita-1300-camera). The software we use for simultaneous acquisition from multiple cameras is [BIAS](http://public.iorodeo.com/notes/bias/) software. A basic lens that works for zoomed out views [here](https://www.bhphotovideo.com/c/product/414195-REG/Tamron_12VM412ASIR_12VM412ASIR_1_2_4_12_F_1_2.html). To see the pupil well you might need a better zoom lens [10x here](https://www.edmundoptics.com/imaging-lenses/zoom-lenses/10x-13-130mm-fl-c-mount-close-focus-zoom-lens/#specs).
-
-For 2p imaging, you'll need a tighter filter around 850nm so you don't see the laser shining through the mouse's eye/head, for example [this](https://www.thorlabs.de/thorproduct.cfm?partnumber=FB850-40). Depending on your lenses you'll need to figure out the right adapter(s) for such a filter. For our 10x lens above, you might need all of these: [adapter1](https://www.edmundoptics.com/optics/optical-filters/optical-filter-accessories/M52-to-M46-Filter-Thread-Adapter/), [adapter2](https://www.thorlabs.de/thorproduct.cfm?partnumber=SM2A53), [adapter3](https://www.thorlabs.de/thorproduct.cfm?partnumber=SM2A6), [adapter4](https://www.thorlabs.de/thorproduct.cfm?partnumber=SM1L03).
+- For the pose tracker and SVD processing, run `pip install git+https://github.com/mouseland/facemap.git`, which installs the latest development version from GitHub.
-# Installation
+To upgrade Facemap ([PyPI package](https://pypi.org/project/facemap/)), within the environment run: `pip install facemap --upgrade`
-## PYTHON
-
-This package only supports python 3. We recommend installing python 3 with **[Anaconda](https://www.anaconda.com/download/)**.
-
-### Using the environment.yml file (recommended)
+Installing Facemap using the `environment.yml` file is recommended:
1. Download the `environment.yml` file from the repository
2. Open an anaconda prompt / command prompt with `conda` for **python 3** in the path
@@ -30,98 +27,67 @@ This package only supports python 3. We recommend installing python 3 with **[An
4. To activate this new environment, run `conda activate facemap`
5. You should see `(facemap)` on the left side of the terminal line. Now run `python -m facemap` and you're all set.
-To upgrade FaceMap (package [here](https://pypi.org/project/facemap/)), within the environment run:
-~~~~
-pip install facemap --upgrade
-~~~~
-
-### Using pip package
-
-Run the following
-~~~
-pip install facemap
-~~~
+# Pose tracking
-If you are using running ROIs, you will want to install mkl_fft via conda:
-~~~~
-conda install -c conda-forge mkl_fft
-~~~~
+
-### Common issues
+The latest Python version includes the Facemap network for tracking 14 distinct keypoints on the mouse face and an additional point for tracking the paw. The keypoints can be tracked from different camera views (some examples are shown below).
-If you have pip issues, there might be some interaction between pre-installed dependencies and the ones FaceMap needs. First thing to try is
-~~~~
-python -m pip install --upgrade pip
-~~~~
+
+
+
+
+
+## [GUI Instructions](docs/pose_tracking_gui_tutorial.md)
+For pose tracking, load a video, check `keypoints`, and then click the `process` button. A dialog box will appear for selecting a bounding box for the face. The keypoints will be tracked within the selected bounding box. Please ensure that the bounding box is centered on the face so that all the keypoints shown above are visible. See example frames [here](figs/mouse_views.png).
-If when running `python -m facemap`, you receive the error: `No module named PyQt5.sip`, then try uninstalling and reinstalling pyqt5
-~~~
-pip uninstall pyqt5 pyqt5-tools
-pip install pyqt5 pyqt5-tools pyqt5.sip
-~~~
+Use the file menu to set the path of the output folder. The processed keypoints file will be saved in the output folder with the extension `.h5`, along with a corresponding metadata file with the extension `.pkl`.
-If you are on Yosemite Mac OS, PyQt doesn't work, and you won't be able to install FaceMap. More recent versions of Mac OS are fine.
+## [CLI Instructions](docs/pose_tracking_cli_tutorial.md)
-The software has been heavily tested on Ubuntu 18.04, and less well tested on Windows 10 and Mac OS. Please post an issue if you have installation problems.
+For more examples, please see [tutorial notebooks](https://github.com/MouseLand/facemap/tree/dev/notebooks).
-### Dependencies
+## :mega: User contributions :video_camera: :camera:
+Facemap's goal is to provide a simple way to generate keypoints for rodent face tracking. However, we need a large dataset of images from different camera views to reduce errors on videos of new mice. Hence, we would like your help to further expand our dataset. You can contribute by sending us a video or a few frames of your mouse at the following email address(es): `syedaa[at]janelia.hhmi.org` or `stringerc[at]janelia.hhmi.org`. Please let us know of any issues with the software by sending us an email or [opening an issue on GitHub](https://github.com/MouseLand/facemap/issues).
-FaceMap python relies on these awesome packages:
-- [pyqtgraph](http://pyqtgraph.org/)
-- [PyQt5](http://pyqt.sourceforge.net/Docs/PyQt5/)
-- [numpy](http://www.numpy.org/) (>=1.13.0)
-- [scipy](https://www.scipy.org/)
-- [opencv](https://opencv.org/)
-- [numba](http://numba.pydata.org/numba-doc/latest/user/5minguide.html)
-- [natsort](https://natsort.readthedocs.io/en/master/)
-## MATLAB
+# SVD processing
-The matlab version needs to be downloaded/cloned from github (no install required). It works in Matlab 2014b and above - please submit issues if it's not working. The Image Processing Toolbox is necessary to use the GUI. For GPU functionality, the Parallel Processing Toolbox is required. If you don't have the Parallel Processing Toolbox, uncheck the box next to "use GPU" in the GUI before processing.
+Works for grayscale and RGB movies. Can process multi-camera videos. Some example movies to test the GUI on are located [here](https://drive.google.com/open?id=1cRWCDl8jxWToz50dCX1Op-dHcAC-ttto). You can save the output from both the Python and MATLAB versions as a MATLAB file via a checkbox in the GUI (the Python version is recommended - it has a better GUI).
-### Supported movie files
+Supported movie files:
'.mj2','.mp4','.mkv','.avi','.mpeg','.mpg','.asf'
-# Start processing! *HOW TO GUI*
+### Data acquisition info
-## PYTHON
+IR ILLUMINATION:
-([video](https://www.youtube.com/watch?v=Rq8fEQ-DOm4) with old install instructions)
+For recording in darkness we use [IR illumination](https://www.amazon.com/Logisaf-Invisible-Infrared-Security-Cameras/dp/B01MQW8K7Z/ref=sr_1_12?s=security-surveillance&ie=UTF8&qid=1505507302&sr=1-12&keywords=ir+light) at 850nm, which works well with 2p imaging at 970nm and even 920nm. Depending on your needs, you might want to choose a different wavelength, which changes all the filters below as well. 950nm works just as well, and probably so does 750nm, which is still outside the visible range for rodents.
-
+If you want to focus the illumination on the mouse eye or face, you will need a different, more expensive system. Here is an example, courtesy of Michael Krumin from the Carandini lab: [driver](https://www.thorlabs.com/thorproduct.cfm?partnumber=LEDD1B), [power supply](https://www.thorlabs.com/newgrouppage9.cfm?objectgroup_id=1710&pn=KPS101#8865), [LED](https://www.thorlabs.com/newgrouppage9.cfm?objectgroup_id=2692&pn=M850L3#4426), [lens](https://www.thorlabs.com/newgrouppage9.cfm?objectgroup_id=259&pn=AC254-030-B#2231), and [lens tube](https://www.thorlabs.com/newgrouppage9.cfm?objectgroup_id=4109&pn=SM1V10#3389), and another [lens tube](https://www.thorlabs.com/thorproduct.cfm?partnumber=SM1L10).
-Run the following command in a terminal
-```
-python -m facemap
-```
-The following window should appear. The upper left "file" button loads single files, the upper middle "folder" button loads whole folders (from which you can select movies), and the upper right "folder" button loads processed files ("_proc.npy" files). Load a video or a group of videos (see below for file formats for simultaneous videos). The video(s) will pop up in the left side of the GUI. You can zoom in and out with the mouse wheel, and you can drag by holding down the mouse. Double-click to return to the original, full view.
+CAMERAS:
-Choose a type of ROI to add and then click "add ROI" to add it to the view. The pixels in the ROI will show up in the right window (with different processing depending on the ROI type - see below). You can move it and resize the ROI anytime. You can delete the ROI with "right-click" and selecting "remove". You can change the saturation of the ROI with the upper right saturation bar. You can also just click on the ROI at any time to see what it looks like in the right view.
+We use [ptgrey cameras](https://www.ptgrey.com/flea3-13-mp-mono-usb3-vision-vita-1300-camera). The software we use for simultaneous acquisition from multiple cameras is [BIAS](http://public.iorodeo.com/notes/bias/). A basic lens that works for zoomed-out views is [here](https://www.bhphotovideo.com/c/product/414195-REG/Tamron_12VM412ASIR_12VM412ASIR_1_2_4_12_F_1_2.html). To see the pupil well, you might need a better zoom lens ([10x here](https://www.edmundoptics.com/imaging-lenses/zoom-lenses/10x-13-130mm-fl-c-mount-close-focus-zoom-lens/#specs)).
-By default, the "Compute multivideo SVD" box is unchecked. If you check it, then the motion SVD is computed across ALL videos - all videos are concatenated at each timepoint, and the SVD of this matrix of ALL_PIXELS x timepoints is computed. If you have just one video acquired at a time, then it is the SVD of this video.
+For 2p imaging, you'll need a tighter filter around 850nm so you don't see the laser shining through the mouse's eye/head, for example [this](https://www.thorlabs.de/thorproduct.cfm?partnumber=FB850-40). Depending on your lenses you'll need to figure out the right adapter(s) for such a filter. For our 10x lens above, you might need all of these: [adapter1](https://www.edmundoptics.com/optics/optical-filters/optical-filter-accessories/M52-to-M46-Filter-Thread-Adapter/), [adapter2](https://www.thorlabs.de/thorproduct.cfm?partnumber=SM2A53), [adapter3](https://www.thorlabs.de/thorproduct.cfm?partnumber=SM2A6), [adapter4](https://www.thorlabs.de/thorproduct.cfm?partnumber=SM1L03).
-
-
-
-Once processing starts, the interface will no longer be clickable and all information about processing will be in the terminal in which you opened FaceMap:
-
-
-
+## [*HOW TO GUI* (Python)](docs/svd_python_tutorial.md)
-If you want to open the GUI with a movie file specified and/or save path specified, the following command will allow this:
-~~~
-python -m facemap --movie '/home/carsen/movie.avi' --savedir '/media/carsen/SSD/'
-~~~
-Note this will only work if you only have one file that you need to load (can't have multiple in series / multiple views).
+([video](https://www.youtube.com/watch?v=Rq8fEQ-DOm4) with old install instructions)
+
-### Batch processing (python only)
+Run the following command in a terminal
+```
+python -m facemap
+```
+The default starting folder is set to wherever you run `python -m facemap`
-Load a video or a set of videos and draw your ROIs and choose your processing settings. Then click **save ROIs**. This will save a *_proc.npy file in the folder in the specified **save folder**. The name of this proc file will be listed below **process batch** (this button will also activate). You can then repeat this process: load the video(s), draw ROIs, choose settings, and click **save ROIs**. Then to process all the listed *_proc.npy files click **process batch**.
-## MATLAB
+## [*HOW TO GUI* (MATLAB)](docs/svd_matlab_tutorial.md)
To start the GUI, run the command `MovieGUI` in this folder. The following window should appear. After you click an ROI button and draw an area, you have to **double-click** inside the drawn box to confirm it. To compute the SVD across multiple simultaneously acquired videos you need to use the "multivideo SVD" options to draw ROI's on each video one at a time.
@@ -129,174 +95,4 @@ To start the GUI, run the command `MovieGUI` in this folder. The following windo
-## Default starting folder
-
-**python**: wherever you run `python -m FaceMap`
-
-**MATLAB**: set at line 59 of MovieGUI.m (h.filepath)
-
-## File loading structure
-
-If you choose a folder instead of a single file, it will assemble a list of all video files in that folder and also all videos 1 folder down. The MATLAB GUI will ask *"would you like to process all movies?"*. If you say no, then a list of movies to choose from will appear. By default the python version shows you a list of movies. If you choose no movies in the python version then it's assumed you want to process ALL of them.
-
-## Processing movies captured simultaneously (multiple camera setups)
-
-Both GUIs will then ask *"are you processing multiple videos taken simultaneously?"*. If you say yes, then the script will look if across movies the **FIRST FOUR** letters of the filename vary. If the first four letters of two movies are the same, then the GUI assumed that they were acquired *sequentially* not *simultaneously*.
-
-Example file list:
-+ cam1_G7c1_1.avi
-+ cam1_G7c1_2.avi
-+ cam2_G7c1_1.avi
-+ cam2_G7c1_2.avi
-+ cam3_G7c1_1.avi
-+ cam3_G7c1_2.avi
-
-*"are you processing multiple videos taken simultaneously?"* ANSWER: Yes
-
-Then the GUIs assume {cam1_G7c1_1.avi, cam2_G7c1_1.avi, cam3_G7c1_1.avi} were acquired simultaneously and {cam1_G7c1_2.avi, cam2_G7c1_2.avi, cam3_G7c1_2.avi} were acquired simultaneously. They will be processed in alphabetical order (1 before 2) and the results from the videos will be concatenated in time. If one of these files was missing, then the GUI will error and you will have to choose file folders again. Also you will get errors if the files acquired at the same time aren't the same frame length (e.g. {cam1_G7c1_1.avi, cam2_G7c1_1.avi, cam3_G7c1_1.avi} should all have the same number of frames).
-
-Note: if you have many simultaneous videos / overall pixels (e.g. 2000 x 2000) you will need around 32GB of RAM to compute the full SVD motion masks.
-
-**python**: you will be able to see all the videos that were simultaneously collected at once. However, you can only draw ROIs that are within ONE video. Only the "multivideo SVD" is computed over all videos.
-
-**MATLAB**: after the file choosing process is over, you will see all the movies in the drop down menu (by filename). You can switch between them and inspect how well an ROI works for each of the movies.
-
-# ROI types
-
-## Pupil computation
-
-The minimum pixel value is subtracted from the ROI. Use the saturation bar to reduce the background of the eye. The algorithm zeros out any pixels less than the saturation level (I recommend a *very* low value - so most pixels are white in the GUI).
-
-Next it finds the pixel with the largest magnitude. It draws a box around that area (1/2 the size of the ROI) and then finds the center-of-mass of that region. It then centers the box on that area. It fits a multivariate gaussian to the pixels in the box using maximum likelihood (see [pupil.py](facemap/pupil.py) or [fitMVGaus.m](matlab/utils/fitMVGaus.m)).
-
-After a Gaussian is fit, it zeros out pixels whose squared distance from the center (normalized by the standard deviation of the Gaussian fit) is greater than 2 * sigma^2 where sigma is set by the user in the GUI (default sigma = 2.5). It now performs the fit again with these points erased, and repeats this process 4 more times. The pupil is then defined as an ellipse sigma standard deviations away from the center-of-mass of the gaussian. This is plotted with '+' around the ellipse and with one '+' at the center.
-
-If there are reflections on the mouse's eye, then you can draw ellipses to account for this "corneal reflection" (plotted in black). You can add as many of these per pupil ROI as needed. The algorithm fills in these areas of the image with the predicted values, which allows for smooth transitions between big and small pupils.
-
-
-
-This raw pupil area trace is post-processed (see [smoothPupil.m](pupil/smoothPupil.m))). The trace is median filtered with a window of 30 timeframes. At each timepoint, the difference between the raw trace and the median filtered trace is computed. If the difference at a given point exceeds half the standard deviation of the raw trace, then the raw value is replaced by the median filtered value.
-
-![pupil](/figs/pupilfilter.png?raw=true "pupil filtering")
-
-## Blink computation
-
-You may want to ignore frames in which the animal is blinking if you are looking at pupil size. The blink area is the number of pixels above the saturation level that you set (all non-white pixels).
-
-
-## Motion SVD
-
-The motion SVDs (small ROIs / multivideo) are computed on the movie downsampled in space by the spatial downsampling input box in the GUI (default 4 pixels). Note the saturation set in this window is NOT used for any processing.
-
-The motion *M* is defined as the abs(current_frame - previous_frame), and the average motion energy across frames is computed using a subset of frames (*avgmot*) (at least 1000 frames - set at line 45 in [subsampledMean.m](matlab/subsampledMean.m) or line 183 in [process.py](facemap/process.py)). Then the singular vectors of the motion energy are computed on chunks of data, also from a subset of frames (15 chunks of 1000 frames each). Let *F* be the chunk of frames [pixels x time]. Then
-```
-uMot = [];
-for j = 1:nchunks
- M = abs(diff(F,1,2));
- [u,~,~] = svd(M - avgmot);
- uMot = cat(2, uMot, u);
-end
-[uMot,~,~] = svd(uMot);
-uMotMask = normc(uMot(:, 1:500)); % keep 500 components
-```
-*uMotMask* are the motion masks that are then projected onto the video at all timepoints (done in chunks of size *nt*=500):
-```
-for j = 1:nchunks
- M = abs(diff(F,1,2));
- motSVD0 = (M - avgmot)' * uMotMask;
- motSVD((j-1)*nt + [1:nt],:) = motSVD0;
-end
-```
-Example motion masks *uMotMask* and traces *motSVD*:
-
-
-
-We found that these extracted singular vectors explained up to half of the total explainable variance in neural activity in visual cortex and in other forebrain areas. See our [paper](https://science.sciencemag.org/content/364/6437/eaav7893) for more details.
-
-In the python version, we also compute the average of *M* across all pixels in each motion ROI and that is returned as the **motion**. The first **motion** field is non-empty if "multivideo SVD" is on, and in that case it is the average motion energy across all pixels in all views.
-
-## Running computation
-
-The phase-correlation between consecutive frames (in running ROI) are computed in the fourier domain (see [running.py](/facemap/running.py) or [processRunning.m](/matlab/running/processRunning.m)). The XY position of maximal correlation gives the amount of shift between the two consecutive frames. Depending on how fast the movement is frame-to-frame you may want at least a 50x50 pixel ROI to compute this.
-
-## Multivideo SVD ROIs
-
-**PYTHON**: Check box "Compute multivideo SVD" to compute the SVD of all pixels in all videos.
-
-**MATLAB**: You can draw areas to be included and excluded in the multivideo SVD (or single video if you only have one view). The buttons are "area to keep" and "area to exclude" and will draw blue and red boxes respectively. The union of all pixels in "areas to include" are used, excluding any pixels that intersect this union from "areas to exclude" (you can toggle between viewing the boxes and viewing the included pixels using the "Show areas" checkbox, see example below).
-
-
-
-The motion energy is then computed from these non-red pixels.
-
-# Output of processing
-
-The GUIs create one file for all videos (saved in current folder), the npy file has name "videofile_proc.npy" and the mat file has name "videofile_proc.mat".
-
-**PYTHON**:
-- **filenames**: list of lists of video filenames - each list are the videos taken simultaneously
-- **Ly**, **Lx**: list of number of pixels in Y (Ly) and X (Lx) for each video taken simultaneously
-- **sbin**: spatial bin size for motion SVDs
-- **Lybin**, **Lxbin**: list of number of pixels binned by sbin in Y (Ly) and X (Lx) for each video taken simultaneously
-- **sybin**, **sxbin**: coordinates of multivideo (for plotting/reshaping ONLY)
-- **LYbin**, **LXbin**: full-size of all videos embedded in rectangle (binned)
-- **fullSVD**: whether or not "multivideo SVD" is computed
-- **save_mat**: whether or not to save proc as *.mat file
-- **avgframe**: list of average frames for each video from a subset of frames (binned by sbin)
-- **avgframe_reshape**: average frame reshaped to be y-pixels x x-pixels
-- **avgmotion**: list of average motions for each video from a subset of frames (binned by sbin)
-- **avgmotion_reshape**: average motion reshaped to be y-pixels x x-pixels
-- **iframes**: array containing number of frames in each consecutive video
-- **motion**: list of absolute motion energies across time - first is "multivideo" motion energy (empty if not computed)
-- **motSVD**: list of motion SVDs - first is "multivideo SVD" (empty if not computed) - each is nframes x components
-- **motMask**: list of motion masks for each motion SVD - each motMask is pixels x components
-- **motMask_reshape**: motion masks reshaped to be y-pixels x x-pixels x components
-- **pupil**: list of pupil ROI outputs - each is a dict with 'area', 'area_smooth', and 'com' (center-of-mass)
-- **blink**: list of blink ROI outputs - each is nframes, the blink area on each frame
-- **running**: list of running ROI outputs - each is nframes x 2, for X and Y motion on each frame
-- **rois**: ROIs that were drawn and computed
- - *rind*: type of ROI in number
- - *rtype*: what type of ROI ('motion SVD', 'pupil', 'blink', 'running')
- - *ivid*: in which video is the ROI
- - *color*: color of ROI
- - *yrange*: y indices of ROI
- - *xrange*: x indices of ROI
- - *saturation*: saturation of ROI (0-255)
- - *pupil_sigma*: number of stddevs used to compute pupil radius (for pupil ROIs)
- - *yrange_bin*: binned indices in y (if motion SVD)
- - *xrange_bin*: binned indices in x (if motion SVD)
-
-Note this is a dict, so say *.item() after loading:
-```
-import numpy as np
-proc = np.load('cam1_proc.npy').item()
-```
-These *_proc.npy* files can be loaded into the GUI (and will automatically be loaded after processing). The checkboxes in the lower left allow you to view different traces from the processing.
-
-**MATLAB**:
-- **nX**,**nY**: cell arrays of number of pixels in X and Y in each video taken simultaneously
-- **sc**: spatial downsampling constant used
-- **ROI**: [# of videos x # of areas] - areas to be included for multivideo SVD (in downsampled reference)
-- **eROI**: [# of videos x # of areas] - areas to be excluded from multivideo SVD (in downsampled reference)
-- **locROI**: location of small ROIs (in order running, ROI1, ROI2, ROI3, pupil1, pupil2); in downsampled reference
-- **ROIfile**: in which movie is the small ROI
-- **plotROIs**: which ROIs are being processed (these are the ones shown on the frame in the GUI)
-- **files**: all the files you processed together
-- **npix**: array of number of pixels from each video used for multivideo SVD
-- **tpix**: array of number of pixels in each view that was used for SVD processing
-- **wpix**: cell array of which pixels were used from each video for multivideo SVD
-- **avgframe**: [sum(tpix) x 1] average frame across videos computed on a subset of frames
-- **avgmotion**: [sum(tpix) x 1] average frame across videos computed on a subset of frames
-- **motSVD**: cell array of motion SVDs [components x time] (in order: multivideo, ROI1, ROI2, ROI3)
-- **uMotMask**: cell array of motion masks [pixels x time]
-- **runSpeed**: 2D running speed computed using phase correlation [time x 2]
-- **pupil**: structure of size 2 (pupil1 and pupil2) with 3 fields: area, area_raw, and com
-- **thres**: pupil sigma used
-- **saturation**: saturation levels (array in order running, ROI1, ROI2, ROI3, pupil1, pupil2); only saturation levels for pupil1 and pupil2 are used in the processing, others are just for viewing ROIs
-
-an ROI is [1x4]: [y0 x0 Ly Lx]
-
-## Motion SVD Masks in MATLAB
-
-Use the script [plotSVDmasks.m](figs/plotSVDmasks.m) to easily view motion masks from the multivideo SVD. The motion masks from the smaller ROIs have been reshaped to be [xpixels x ypixels x components].
diff --git a/docs/installation.md b/docs/installation.md
new file mode 100644
index 0000000..b812c94
--- /dev/null
+++ b/docs/installation.md
@@ -0,0 +1,75 @@
+# Installation (Python)
+
+This package only supports python 3. We recommend installing python 3 with **[Anaconda](https://www.anaconda.com/download/)**.
+
+
+### For the pose tracker and SVD processing
+Please run
+~~~
+pip install git+https://github.com/mouseland/facemap.git
+~~~
+This will install the latest development version from GitHub.
+
+### For the latest released version (from PyPI), with SVD processing only
+
+Run the following for the command line interface (CLI), i.e. the headless version:
+~~~
+pip install facemap
+~~~
+or the following to use the GUI:
+~~~~
+pip install facemap[gui]
+~~~~
+
+To upgrade Facemap (package [here](https://pypi.org/project/facemap/)), within the environment run:
+~~~~
+pip install facemap --upgrade
+~~~~
+
+Using the environment.yml file (recommended installation method):
+
+1. Download the `environment.yml` file from the repository or clone the github repository: `git clone https://www.github.com/mouseland/facemap.git`
+2. Open an anaconda prompt / command prompt with `conda` for **python 3** in the path
+3. Change directory to facemap folder `cd facemap`
+4. Run `conda env create -f environment.yml`
+5. To activate this new environment, run `conda activate facemap`
+6. You should see `(facemap)` on the left side of the terminal line. Now run `python -m facemap` and you're all set.
+
+## Common installation issues
+
+If you have pip issues, there might be some interaction between pre-installed dependencies and the ones Facemap needs. The first thing to try is
+~~~~
+python -m pip install --upgrade pip
+~~~~
+
+While running `python -m facemap`, if you receive the error: `No module named PyQt5.sip`, then try uninstalling and reinstalling pyqt5
+~~~
+pip uninstall pyqt5 pyqt5-tools
+pip install pyqt5 pyqt5-tools pyqt5.sip
+~~~
+
+If you are on Yosemite Mac OS, PyQt doesn't work, and you won't be able to install Facemap. More recent versions of Mac OS are fine.
+
+The software has been heavily tested on Ubuntu 18.04, and less well tested on Windows 10 and Mac OS. Please post an issue if you have installation problems.
+
+### Python dependencies
+
+Facemap python relies on these awesome packages:
+- [pyqtgraph](http://pyqtgraph.org/)
+- [PyQt5](http://pyqt.sourceforge.net/Docs/PyQt5/)
+- [numpy](http://www.numpy.org/) (>=1.13.0)
+- [scipy](https://www.scipy.org/)
+- [opencv](https://opencv.org/)
+- [numba](http://numba.pydata.org/numba-doc/latest/user/5minguide.html)
+- [natsort](https://natsort.readthedocs.io/en/master/)
+- [PyTorch](https://pytorch.org)
+- [Matplotlib](https://matplotlib.org)
+- [tqdm](https://tqdm.github.io)
+- [pandas](https://pandas.pydata.org)
+- [UMAP](https://umap-learn.readthedocs.io/en/latest/)
+
+
+# Installation (MATLAB)
+
+The MATLAB version supports SVD processing only and does not include the face tracker. The package can be downloaded/cloned from GitHub (no install required). It works in MATLAB 2014b and above - please submit issues if it's not working. The Image Processing Toolbox is necessary to use the GUI. For GPU functionality, the Parallel Processing Toolbox is required. If you don't have the Parallel Processing Toolbox, uncheck the box next to "use GPU" in the GUI before processing.
diff --git a/docs/pose_tracking_cli_tutorial.md b/docs/pose_tracking_cli_tutorial.md
new file mode 100644
index 0000000..3de3c74
--- /dev/null
+++ b/docs/pose_tracking_cli_tutorial.md
@@ -0,0 +1 @@
+# Pose tracking **(CLI)**
\ No newline at end of file
diff --git a/docs/pose_tracking_gui_tutorial.md b/docs/pose_tracking_gui_tutorial.md
new file mode 100644
index 0000000..a7de406
--- /dev/null
+++ b/docs/pose_tracking_gui_tutorial.md
@@ -0,0 +1,34 @@
+# Pose tracking **(GUI)** :mouse:
+
+
+
+The latest Python version includes the Facemap network for tracking 14 distinct keypoints on the mouse face and an additional point for tracking the paw. The keypoints can be tracked from different camera views (some examples are shown below).
+
+
+
+
+
+
+For pose tracking using the GUI, after following the [installation instructions](installation.md), proceed with the following steps:
+
+1. Load video
+    - Select `File` from the menu bar
+    - For processing a single video, select `Load single movie file`
+    - Alternatively, for processing multiple videos, select `Open folder of movies` and then select the files you want to process. Please note that multiple videos are processed sequentially.
+2. Select output folder
+    - Use the file menu to `Set output folder`.
+    - The processed keypoints file will be saved in the selected output folder with the extension `.h5`, along with a corresponding metadata file with the extension `.pkl` (see the loading sketch at the bottom of this page).
+3. Choose processing options
+    - Check at least one of the following boxes:
+        - `Keypoints` for face pose tracking.
+        - `motSVD` for SVD processing of the difference across frames over time.
+        - `movSVD` for SVD processing of raw movie frames.
+    - Click the `process` button and monitor the progress bar at the bottom of the window for updates.
+4. Select ROI/bounding box for face
+    - Once you hit `process`, a dialog box will appear for selecting a bounding box for the face. The keypoints will be tracked within the selected bounding box. Please ensure that the bounding box is centered on the face so that all the keypoints shown above are visible. See example frames [here](../figs/mouse_views.png). Once the bounding box is set, click 'Done' to continue.
+    - Alternatively, if you wish to use the entire frame, click 'Skip' to continue.
+    - If a 'Face (pose)' ROI has already been selected using the ROI dropdown menu, the bounding box will be set automatically and the keypoints will be tracked within that ROI.
+
+The videos will be processed in the order they appear in the file list, and the output will be saved in the output folder. Below is an example GIF demonstrating the above-mentioned steps for tracking keypoints in a video.
+
+
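+Below is a hypothetical loading sketch for the saved output files. The filenames and the exact layout of the `.h5` file here are assumptions (inspect your own output files to confirm); `pytables`, listed in `environment.yml`, provides the pandas HDF backend used below.
+
+```
+import pickle
+import pandas as pd
+
+# hypothetical filenames -- use the files written to your output folder
+keypoints = pd.read_hdf("mouse_video.h5")    # keypoint traces (pandas/pytables HDF)
+with open("mouse_video.pkl", "rb") as f:
+    metadata = pickle.load(f)                # processing metadata
+
+print(keypoints.shape)
+print(type(metadata))
+```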
diff --git a/docs/svd_matlab_tutorial.md b/docs/svd_matlab_tutorial.md
new file mode 100644
index 0000000..db4b6aa
--- /dev/null
+++ b/docs/svd_matlab_tutorial.md
@@ -0,0 +1,130 @@
+# *HOW TO GUI* (MATLAB)
+
+To start the GUI, run the command `MovieGUI` in this folder. The following window should appear. After you click an ROI button and draw an area, you have to **double-click** inside the drawn box to confirm it. To compute the SVD across multiple simultaneously acquired videos you need to use the "multivideo SVD" options to draw ROI's on each video one at a time.
+
+
+
+
+
+The default starting folder is set at line 59 of MovieGUI.m (h.filepath).
+
+#### File loading structure
+
+If you choose a folder instead of a single file, it will assemble a list of all video files in that folder and also all videos 1 folder down. The MATLAB GUI will ask *"would you like to process all movies?"*. If you say no, then a list of movies to choose from will appear. By default the python version shows you a list of movies. If you choose no movies in the python version then it's assumed you want to process ALL of them.
+
+#### Processing movies captured simultaneously (multiple camera setups)
+
+Both GUIs will then ask *"are you processing multiple videos taken simultaneously?"*. If you say yes, then the script will check whether the **FIRST FOUR** letters of the filenames vary across movies. If the first four letters of two movies are the same, then the GUI assumes that they were acquired *sequentially*, not *simultaneously*.
+
+Example file list:
++ cam1_G7c1_1.avi
++ cam1_G7c1_2.avi
++ cam2_G7c1_1.avi
++ cam2_G7c1_2.avi
++ cam3_G7c1_1.avi
++ cam3_G7c1_2.avi
+
+*"are you processing multiple videos taken simultaneously?"* ANSWER: Yes
+
+Then the GUIs assume {cam1_G7c1_1.avi, cam2_G7c1_1.avi, cam3_G7c1_1.avi} were acquired simultaneously and {cam1_G7c1_2.avi, cam2_G7c1_2.avi, cam3_G7c1_2.avi} were acquired simultaneously. They will be processed in alphabetical order (1 before 2) and the results from the videos will be concatenated in time. If one of these files was missing, then the GUI will error and you will have to choose file folders again. Also you will get errors if the files acquired at the same time aren't the same frame length (e.g. {cam1_G7c1_1.avi, cam2_G7c1_1.avi, cam3_G7c1_1.avi} should all have the same number of frames).
+
+Note: if you have many simultaneous videos / overall pixels (e.g. 2000 x 2000) you will need around 32GB of RAM to compute the full SVD motion masks.
+
+After the file choosing process is over, you will see all the movies in the drop down menu (by filename). You can switch between them and inspect how well an ROI works for each of the movies.
+
+### ROI types
+
+#### Pupil computation
+
+The minimum pixel value is subtracted from the ROI. Use the saturation bar to reduce the background of the eye. The algorithm zeros out any pixels less than the saturation level (I recommend a *very* low value - so most pixels are white in the GUI).
+
+Next it finds the pixel with the largest magnitude. It draws a box around that area (1/2 the size of the ROI) and then finds the center-of-mass of that region. It then centers the box on that area. It fits a multivariate gaussian to the pixels in the box using maximum likelihood (see [pupil.py](../facemap/pupil.py) or [fitMVGaus.m](../matlab/utils/fitMVGaus.m)).
+
+After a Gaussian is fit, it zeros out pixels whose squared distance from the center (normalized by the standard deviation of the Gaussian fit) is greater than 2 * sigma^2 where sigma is set by the user in the GUI (default sigma = 2.5). It now performs the fit again with these points erased, and repeats this process 4 more times. The pupil is then defined as an ellipse sigma standard deviations away from the center-of-mass of the gaussian. This is plotted with '+' around the ellipse and with one '+' at the center.
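+
+A maximum-likelihood fit of a 2D Gaussian to an intensity image reduces to an intensity-weighted mean and covariance. The following Python sketch illustrates that single fit step only (an illustration of the approach; the shipped code in pupil.py / fitMVGaus.m additionally iterates and excludes outlying pixels as described above):
+
+```
+import numpy as np
+
+def fit_mv_gaussian(img):
+    # pixels with nonzero intensity contribute, weighted by their intensity
+    ys, xs = np.nonzero(img)
+    w = img[ys, xs].astype(float)
+    w /= w.sum()
+    mu = np.array([np.sum(w * ys), np.sum(w * xs)])    # center of mass
+    d = np.stack([ys - mu[0], xs - mu[1]])             # 2 x npix deviations
+    cov = (w * d) @ d.T                                # weighted covariance (2 x 2)
+    return mu, cov
+```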
+
+If there are reflections on the mouse's eye, then you can draw ellipses to account for this "corneal reflection" (plotted in black). You can add as many of these per pupil ROI as needed. The algorithm fills in these areas of the image with the predicted values, which allows for smooth transitions between big and small pupils.
+
+
+
+This raw pupil area trace is post-processed (see [smoothPupil.m](../pupil/smoothPupil.m)). The trace is median filtered with a window of 30 timeframes. At each timepoint, the difference between the raw trace and the median filtered trace is computed. If the difference at a given point exceeds half the standard deviation of the raw trace, then the raw value is replaced by the median filtered value.
+
+![pupil](../figs/pupilfilter.png?raw=true "pupil filtering")
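+
+The post-processing rule above maps directly to a few lines of code. Below is an illustrative Python sketch of that logic (the shipped implementation is the MATLAB function smoothPupil.m; this is not that code):
+
+```
+import numpy as np
+from scipy.ndimage import median_filter
+
+def smooth_pupil_area(area, win=30):
+    # running median with a 30-frame window
+    filt = median_filter(area, size=win)
+    # replace points that deviate from the median trace by more than
+    # half the standard deviation of the raw trace
+    outliers = np.abs(area - filt) > 0.5 * area.std()
+    smoothed = area.copy()
+    smoothed[outliers] = filt[outliers]
+    return smoothed
+```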
+
+#### Blink computation
+
+You may want to ignore frames in which the animal is blinking if you are looking at pupil size. The blink area is the number of pixels above the saturation level that you set (all non-white pixels).
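+
+In array terms this is simply `blink_area = (roi_pixels > saturation).sum()` computed on each frame (an illustrative expression, not the shipped code).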
+
+#### Motion SVD
+
+The motion SVDs (small ROIs / multivideo) are computed on the movie downsampled in space by the spatial downsampling input box in the GUI (default 4 pixels). Note the saturation set in this window is NOT used for any processing.
+
+The motion *M* is defined as the abs(current_frame - previous_frame), and the average motion energy across frames is computed using a subset of frames (*avgmot*) (at least 1000 frames - set at line 45 in [subsampledMean.m](../matlab/subsampledMean.m) or line 183 in [process.py](../facemap/process.py)). Then the singular vectors of the motion energy are computed on chunks of data, also from a subset of frames (15 chunks of 1000 frames each). Let *F* be the chunk of frames [pixels x time]. Then
+```
+uMot = [];
+for j = 1:nchunks
+ M = abs(diff(F,1,2));
+ [u,~,~] = svd(M - avgmot);
+ uMot = cat(2, uMot, u);
+end
+[uMot,~,~] = svd(uMot);
+uMotMask = normc(uMot(:, 1:500)); % keep 500 components
+```
+*uMotMask* are the motion masks that are then projected onto the video at all timepoints (done in chunks of size *nt*=500):
+```
+for j = 1:nchunks
+ M = abs(diff(F,1,2));
+ motSVD0 = (M - avgmot)' * uMotMask;
+ motSVD((j-1)*nt + [1:nt],:) = motSVD0;
+end
+```
+Example motion masks *uMotMask* and traces *motSVD*:
+
+
+
+We found that these extracted singular vectors explained up to half of the total explainable variance in neural activity in visual cortex and in other forebrain areas. See our [paper](https://science.sciencemag.org/content/364/6437/eaav7893) for more details.
+
+In the python version, we also compute the average of *M* across all pixels in each motion ROI and that is returned as the **motion**. The first **motion** field is non-empty if "multivideo SVD" is on, and in that case it is the average motion energy across all pixels in all views.
+
+#### Running computation
+
+The phase correlation between consecutive frames (in the running ROI) is computed in the Fourier domain (see [running.py](../facemap/running.py) or [processRunning.m](../matlab/running/processRunning.m)). The XY position of maximal correlation gives the amount of shift between the two consecutive frames. Depending on how fast the movement is frame-to-frame, you may want at least a 50x50 pixel ROI to compute this.
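+
+As an illustration of the idea (not the shipped code in running.py / processRunning.m), a minimal Python sketch of phase correlation between two frames could look like this:
+
+```
+import numpy as np
+
+def phase_corr_shift(frame0, frame1):
+    # correlate in the Fourier domain, keeping only the phase
+    F0 = np.fft.fft2(frame0)
+    F1 = np.fft.fft2(frame1)
+    R = F0 * np.conj(F1)
+    R /= np.abs(R) + 1e-8
+    corr = np.real(np.fft.ifft2(R))
+    # the location of the correlation peak gives the XY shift between frames
+    dy, dx = np.unravel_index(np.argmax(corr), corr.shape)
+    # shifts larger than half the ROI wrap around
+    if dy > corr.shape[0] // 2:
+        dy -= corr.shape[0]
+    if dx > corr.shape[1] // 2:
+        dx -= corr.shape[1]
+    return dy, dx
+```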
+
+#### Multivideo SVD ROIs
+
+You can draw areas to be included and excluded in the multivideo SVD (or single video if you only have one view). The buttons are "area to keep" and "area to exclude" and will draw blue and red boxes respectively. The union of all pixels in "areas to include" are used, excluding any pixels that intersect this union from "areas to exclude" (you can toggle between viewing the boxes and viewing the included pixels using the "Show areas" checkbox, see example below).
+
+
+
+The motion energy is then computed from these non-red pixels.
+
+### Processed output
+
+The GUI creates one file for all videos (saved in the current folder); the processed MAT file is named "videofile_proc.mat".
+
+**MATLAB output**:
+- **nX**,**nY**: cell arrays of number of pixels in X and Y in each video taken simultaneously
+- **sc**: spatial downsampling constant used
+- **ROI**: [# of videos x # of areas] - areas to be included for multivideo SVD (in downsampled reference)
+- **eROI**: [# of videos x # of areas] - areas to be excluded from multivideo SVD (in downsampled reference)
+- **locROI**: location of small ROIs (in order running, ROI1, ROI2, ROI3, pupil1, pupil2); in downsampled reference
+- **ROIfile**: in which movie is the small ROI
+- **plotROIs**: which ROIs are being processed (these are the ones shown on the frame in the GUI)
+- **files**: all the files you processed together
+- **npix**: array of number of pixels from each video used for multivideo SVD
+- **tpix**: array of number of pixels in each view that was used for SVD processing
+- **wpix**: cell array of which pixels were used from each video for multivideo SVD
+- **avgframe**: [sum(tpix) x 1] average frame across videos computed on a subset of frames
+- **avgmotion**: [sum(tpix) x 1] average motion across videos computed on a subset of frames
+- **motSVD**: cell array of motion SVDs [components x time] (in order: multivideo, ROI1, ROI2, ROI3)
+- **uMotMask**: cell array of motion masks [pixels x time]
+- **runSpeed**: 2D running speed computed using phase correlation [time x 2]
+- **pupil**: structure of size 2 (pupil1 and pupil2) with 3 fields: area, area_raw, and com
+- **thres**: pupil sigma used
+- **saturation**: saturation levels (array in order running, ROI1, ROI2, ROI3, pupil1, pupil2); only saturation levels for pupil1 and pupil2 are used in the processing, others are just for viewing ROIs
+
+an ROI is [1x4]: [y0 x0 Ly Lx]
+
+#### Motion SVD Masks in MATLAB
+
+Use the script [plotSVDmasks.m](../figs/plotSVDmasks.m) to easily view motion masks from the multivideo SVD. The motion masks from the smaller ROIs have been reshaped to be [xpixels x ypixels x components].
+
diff --git a/docs/svd_python_tutorial.md b/docs/svd_python_tutorial.md
new file mode 100644
index 0000000..376c6be
--- /dev/null
+++ b/docs/svd_python_tutorial.md
@@ -0,0 +1,124 @@
+
+
+## *HOW TO GUI* (Python)
+
+([video](https://www.youtube.com/watch?v=Rq8fEQ-DOm4) with old install instructions)
+
+
+
+Run the following command in a terminal
+```
+python -m facemap
+```
+The default starting folder is set to wherever you run `python -m facemap`
+
+The following window should appear. The upper left "file" button loads single files, the upper middle "folder" button loads whole folders (from which you can select movies), and the upper right "folder" button loads processed files ("_proc.npy" files). Load a video or a group of videos (see below for file formats for simultaneous videos). The video(s) will pop up in the left side of the GUI. You can zoom in and out with the mouse wheel, and you can drag by holding down the mouse. Double-click to return to the original, full view.
+
+Choose a type of ROI to add and then click "add ROI" to add it to the view. The pixels in the ROI will show up in the right window (with different processing depending on the ROI type - see below). You can move it and resize the ROI anytime. You can delete the ROI with "right-click" and selecting "remove". You can change the saturation of the ROI with the upper right saturation bar. You can also just click on the ROI at any time to see what it looks like in the right view.
+
+By default, the "Compute multivideo SVD" box is unchecked. If you check it, then the motion SVD is computed across ALL videos - all videos are concatenated at each timepoint, and the SVD of this matrix of ALL_PIXELS x timepoints is computed. If you have just one video acquired at a time, then it is the SVD of this video.
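+
+Conceptually, the multivideo motion SVD amounts to the following numpy sketch (illustrative only; the shipped implementation in facemap/process.py works in chunks and spatially bins the movie first):
+
+```
+import numpy as np
+
+# two hypothetical simultaneously acquired views, each (pixels, time)
+F1 = np.random.rand(1000, 600)
+F2 = np.random.rand(800, 600)
+
+F = np.concatenate([F1, F2], axis=0)     # ALL_PIXELS x timepoints
+M = np.abs(np.diff(F, axis=1))           # motion energy between consecutive frames
+M -= M.mean(axis=1, keepdims=True)       # subtract the average motion energy
+U, S, Vt = np.linalg.svd(M, full_matrices=False)
+motMask = U[:, :500]                     # spatial motion masks (pixels x components)
+motSVD = M.T @ motMask                   # projections over time (nframes x components)
+```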
+
+
+
+
+
+Once processing starts, the interface will no longer be clickable and all information about processing will be in the terminal in which you opened FaceMap:
+
+
+
+
+If you want to open the GUI with a movie file specified and/or save path specified, the following command will allow this:
+~~~
+python -m facemap --movie '/home/carsen/movie.avi' --savedir '/media/carsen/SSD/'
+~~~
+Note this will only work if you have a single file to load (it can't handle multiple files in series / multiple views).
+
+#### Processing movies captured simultaneously (multiple camera setups)
+
+Both GUIs will then ask *"are you processing multiple videos taken simultaneously?"*. If you say yes, then the script will check whether the **FIRST FOUR** letters of the filenames vary across movies. If the first four letters of two movies are the same, then the GUI assumes that they were acquired *sequentially*, not *simultaneously*.
+
+Example file list:
++ cam1_G7c1_1.avi
++ cam1_G7c1_2.avi
++ cam2_G7c1_1.avi
++ cam2_G7c1_2.avi
++ cam3_G7c1_1.avi
++ cam3_G7c1_2.avi
+
+*"are you processing multiple videos taken simultaneously?"* ANSWER: Yes
+
+Then the GUIs assume {cam1_G7c1_1.avi, cam2_G7c1_1.avi, cam3_G7c1_1.avi} were acquired simultaneously and {cam1_G7c1_2.avi, cam2_G7c1_2.avi, cam3_G7c1_2.avi} were acquired simultaneously. They will be processed in alphabetical order (1 before 2) and the results from the videos will be concatenated in time. If one of these files was missing, then the GUI will error and you will have to choose file folders again. Also you will get errors if the files acquired at the same time aren't the same frame length (e.g. {cam1_G7c1_1.avi, cam2_G7c1_1.avi, cam3_G7c1_1.avi} should all have the same number of frames).
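+
+For illustration, the grouping rule described above could be expressed as the following sketch (an assumption about the logic, not the shipped code):
+
+```
+from collections import defaultdict
+from natsort import natsorted
+
+files = ["cam1_G7c1_1.avi", "cam1_G7c1_2.avi",
+         "cam2_G7c1_1.avi", "cam2_G7c1_2.avi",
+         "cam3_G7c1_1.avi", "cam3_G7c1_2.avi"]
+
+# files sharing the same first four letters come from the same camera (sequential)
+by_camera = defaultdict(list)
+for f in natsorted(files):
+    by_camera[f[:4]].append(f)
+
+# the i-th file from each camera was acquired simultaneously with the others
+simultaneous_groups = list(zip(*by_camera.values()))
+# [('cam1_G7c1_1.avi', 'cam2_G7c1_1.avi', 'cam3_G7c1_1.avi'),
+#  ('cam1_G7c1_2.avi', 'cam2_G7c1_2.avi', 'cam3_G7c1_2.avi')]
+```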
+
+Note: if you have many simultaneous videos / overall pixels (e.g. 2000 x 2000) you will need around 32GB of RAM to compute the full SVD motion masks.
+
+You will be able to see all the videos that were simultaneously collected at once. However, you can only draw ROIs that are within ONE video. Only the "multivideo SVD" is computed over all videos.
+
+
+##### Batch processing (python only)
+
+Load a video or a set of videos, draw your ROIs, and choose your processing settings. Then click `save ROIs`. This will save a *_proc.npy file in the specified `save folder`. The name of this proc file will be listed below `process batch` (this button will also activate). You can then repeat this process: load the video(s), draw ROIs, choose settings, and click `save ROIs`. Then, to process all the listed *_proc.npy files, click `process batch`.
+
+#### Multivideo SVD ROIs
+
+Check box "Compute multivideo SVD" to compute the SVD of all pixels in all videos.
+
+The GUI creates one file for all videos (saved in the current folder); the processed .npy file is named "videofile_proc.npy" and contains:
+
+**PYTHON**:
+- **filenames**: list of lists of video filenames - each list are the videos taken simultaneously
+- **Ly**, **Lx**: list of number of pixels in Y (Ly) and X (Lx) for each video taken simultaneously
+- **sbin**: spatial bin size for motion SVDs
+- **Lybin**, **Lxbin**: list of number of pixels binned by sbin in Y (Ly) and X (Lx) for each video taken simultaneously
+- **sybin**, **sxbin**: coordinates of multivideo (for plotting/reshaping ONLY)
+- **LYbin**, **LXbin**: full-size of all videos embedded in rectangle (binned)
+- **fullSVD**: whether or not "multivideo SVD" is computed
+- **save_mat**: whether or not to save proc as *.mat file
+- **avgframe**: list of average frames for each video from a subset of frames (binned by sbin)
+- **avgframe_reshape**: average frame reshaped to be y-pixels x x-pixels
+- **avgmotion**: list of average motions for each video from a subset of frames (binned by sbin)
+- **avgmotion_reshape**: average motion reshaped to be y-pixels x x-pixels
+- **motion**: list of absolute motion energies across time - first is "multivideo" motion energy (empty if not computed)
+- **motSVD**: list of motion SVDs - first is "multivideo SVD" (empty if not computed) - each is nframes x components
+- **motMask**: list of motion masks for each motion SVD - each motMask is pixels x components
+- **motMask_reshape**: motion masks reshaped to be y-pixels x x-pixels x components
+- **pupil**: list of pupil ROI outputs - each is a dict with 'area', 'area_smooth', and 'com' (center-of-mass)
+- **blink**: list of blink ROI outputs - each is nframes, the blink area on each frame
+- **running**: list of running ROI outputs - each is nframes x 2, for X and Y motion on each frame
+- **rois**: ROIs that were drawn and computed
+ - *rind*: type of ROI in number
+ - *rtype*: what type of ROI ('motion SVD', 'pupil', 'blink', 'running')
+ - *ivid*: in which video is the ROI
+ - *color*: color of ROI
+ - *yrange*: y indices of ROI
+ - *xrange*: x indices of ROI
+ - *saturation*: saturation of ROI (0-255)
+ - *pupil_sigma*: number of stddevs used to compute pupil radius (for pupil ROIs)
+ - *yrange_bin*: binned indices in y (if motion SVD)
+ - *xrange_bin*: binned indices in x (if motion SVD)
+
+The above variables are related to motion energy, which uses the absolute value of the difference between consecutive frames, i.e. `abs(frame(t+1) - frame(t))`. To instead compute the SVD of the raw movie frames over time, use the flag `movSVD=True` (default=False) in the `process.run()` function call. Variables pertaining to movie SVDs include:
+- movSVD: list of movie SVDs - first is "multivideo SVD" (empty if not computed) - each is nframes x components
+- movMask: list of movie masks for each movie SVD - each movMask is pixels x components
+- movMask_reshape: movie masks reshaped to be y-pixels x x-pixels x components
+
+New variables:
+- motSv: array containing singular values for motSVD
+- movSv: array containing singular values for movSVD
+
+The `process.run()` function call takes the following parameters:
+- filenames: a 2D list of video filenames - each inner list contains the videos taken simultaneously
+- motSVD: default=True
+- movSVD: default=False
+- GUIobject: default=None
+- parent: default=None, parent is from GUI
+- proc: default=None, proc can be a saved ROI file from GUI
+- savepath: default=None => set to the video folder; specify a folder path in which to save _proc.npy
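+
+For example, a call using these parameters might look like the following sketch (assuming facemap is installed; the filenames are placeholders):
+
+```
+from facemap import process
+
+# 2D list: one inner list per group of simultaneously acquired videos
+filenames = [["cam1_G7c1_1.avi", "cam2_G7c1_1.avi"]]
+
+process.run(filenames,
+            motSVD=True,                 # SVD of the motion energy (frame differences)
+            movSVD=True,                 # SVD of the raw movie frames
+            savepath="/path/to/output")  # defaults to the video folder if None
+```
+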
+Note this is a dict, so use the following command for loading processed output:
+```
+import numpy as np
+proc = np.load('cam1_proc.npy', allow_pickle=True).item()
+```
+
+The `_proc.npy` files can be loaded into the GUI (and will automatically be loaded after processing). The checkboxes in the lower left allow you to view different traces from the processed video(s).
+
+For example usage, see [notebook](https://github.com/MouseLand/facemap/blob/dev/notebooks/process.ipynb).
diff --git a/environment.yml b/environment.yml
index 4cc8a2d..e31c9f0 100644
--- a/environment.yml
+++ b/environment.yml
@@ -2,21 +2,23 @@ name: facemap
channels:
- conda-forge
dependencies:
- - python>3.6
+ - python>=3.6
- pip
- - mkl_fft=1.0.10
- - mkl=2019.3
- numpy>=1.16
- numba>=0.43.1
- scipy
- pyqt
- hdbscan
+ - pytables
- pip:
- pyqtgraph==0.12.0
- pyqt5
+ - torch>=1.9
+ - torchvision==0.12.0
- matplotlib
- natsort
- opencv_python_headless
- tqdm
- umap-learn
- pandas
+
diff --git a/facemap/.github/workflows/test_and_deploy.yml b/facemap/.github/workflows/test_and_deploy.yml
new file mode 100755
index 0000000..487a66b
--- /dev/null
+++ b/facemap/.github/workflows/test_and_deploy.yml
@@ -0,0 +1,65 @@
+# This workflow will upload a Python Package using Twine when a release is created
+# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
+
+name: tests
+
+on:
+ push:
+ branches:
+ - master
+ - main
+ tags:
+ - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10
+ pull_request:
+ branches:
+ - master
+ - main
+ workflow_dispatch:
+
+jobs:
+ test:
+ name: ${{ matrix.platform }} py${{ matrix.python-version }}
+ runs-on: ${{ matrix.platform }}
+ strategy:
+ matrix:
+ platform: [ubuntu-latest, windows-latest, macos-latest]
+ python-version: [3.7, 3.8]
+
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ # these libraries, along with pytest-xvfb (added in the `deps` in tox.ini),
+ # enable testing on Qt on linux
+ - name: Install Linux libraries
+ if: runner.os == 'Linux'
+ run: |
+ sudo apt-get install -y libdbus-1-3 libxkbcommon-x11-0 libxcb-icccm4 \
+ libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 \
+ libxcb-xinerama0 libxcb-xinput0 libxcb-xfixes0
+ # strategy borrowed from vispy for installing opengl libs on windows
+ - name: Install Windows OpenGL
+ if: runner.os == 'Windows'
+ run: |
+ git clone --depth 1 git://github.com/pyvista/gl-ci-helpers.git
+ powershell gl-ci-helpers/appveyor/install_opengl.ps1
+ # note: if you need dependencies from conda, considering using
+ # setup-miniconda: https://github.com/conda-incubator/setup-miniconda
+ # and
+ # tox-conda: https://github.com/tox-dev/tox-conda
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install setuptools tox tox-gh-actions
+ # this runs the platform-specific tests declared in tox.ini
+ - name: Test with tox
+ run: tox
+ env:
+ PLATFORM: ${{ matrix.platform }}
+
+ - name: Coverage
+ uses: codecov/codecov-action@v1
diff --git a/facemap/.gitignore b/facemap/.gitignore
new file mode 100755
index 0000000..881bd15
--- /dev/null
+++ b/facemap/.gitignore
@@ -0,0 +1,56 @@
+# Windows image file caches
+Thumbs.db
+ehthumbs.db
+
+# Folder config file
+Desktop.ini
+
+# Recycle Bin used on file shares
+$RECYCLE.BIN/
+
+# Compiled source
+*.mexa64
+*.mexw64
+*.asv
+
+# Windows shortcuts
+*.lnk
+
+# temporary emacs
+*~
+
+# python and jupyter
+jupyter/.ipynb_checkpoints/
+jupyter/__pycache__/
+.ipynb_checkpoints/
+__pycache__/
+dist/
+suite2p.egg-info/
+build/
+FaceMap/ops_user.npy
+*.ipynb
+
+# =========================
+# Operating System Files
+# =========================
+
+# OSX
+# =========================
+
+.DS_Store
+.AppleDouble
+.LSOverride
+
+# Thumbnails
+._*
+
+# Files that might appear on external disk
+.Spotlight-V100
+.Trashes
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
diff --git a/facemap/.vscode/settings.json b/facemap/.vscode/settings.json
old mode 100644
new mode 100755
diff --git a/facemap/__init__.py b/facemap/__init__.py
old mode 100644
new mode 100755
diff --git a/facemap/__main__.py b/facemap/__main__.py
old mode 100644
new mode 100755
index f1a6306..515a9df
--- a/facemap/__main__.py
+++ b/facemap/__main__.py
@@ -1,6 +1,7 @@
import numpy as np
import time, os
-from facemap import gui,process
+from .gui import gui
+from . import process
from scipy import stats
import argparse
@@ -18,6 +19,11 @@ def main():
parser.add_argument('--ops', default=[], type=str, help='options')
parser.add_argument('--movie', default=[], type=str, help='moviefile')
parser.add_argument('--savedir', default=[], type=str, help='savedir')
+
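+    # optional flags to launch pose tracking with the GUI (default) or run it from the command line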
+ parser.add_argument('--poseGUI', dest='poseGUI', action='store_true', help="Pose GUI")
+ parser.add_argument('--no-poseGUI', dest='poseGUI', action='store_false', help="Pose CLI")
+ parser.set_defaults(poseGUI=True)
+
args = parser.parse_args()
if len(args.movie)>0:
@@ -29,6 +35,11 @@ def main():
else:
savedir = None
+ if args.poseGUI:
+ print("Running Facemap GUI w/ pose tracker")
+ else:
+ print("Running Facemap pose CLI")
+
ops = {}
if len(args.ops)>0:
ops = np.load(args.ops)
diff --git a/facemap/cluster.py b/facemap/cluster.py
old mode 100644
new mode 100755
index 32d5ef9..e2609ab
--- a/facemap/cluster.py
+++ b/facemap/cluster.py
@@ -4,12 +4,15 @@
import pyqtgraph as pg
#import pyqtgraph.opengl as gl
from sklearn.cluster import MiniBatchKMeans
-import hdbscan
+#import hdbscan
from matplotlib import cm
-import matplotlib.pyplot as plt
-from . import io, utils
+from . import utils
+from .gui import io
import cv2
import os
+from PyQt5.QtGui import QFont
+from PyQt5.QtWidgets import (QLabel, QPushButton, QRadioButton, QSpinBox, QButtonGroup,
+ QMessageBox, QLineEdit, QCheckBox)
class Cluster():
def __init__(self, parent, method=None, cluster_labels=None, cluster_labels_method=None, data_type=None):
@@ -25,62 +28,62 @@ def __init__(self, parent, method=None, cluster_labels=None, cluster_labels_meth
def create_clustering_widgets(self, parent):
# Add options to change params for embedding using user input
- parent.ClusteringLabel = QtGui.QLabel("Clustering")
+ parent.ClusteringLabel = QLabel("Clustering")
parent.ClusteringLabel.setStyleSheet("color: white;")
parent.ClusteringLabel.setAlignment(QtCore.Qt.AlignCenter)
- parent.ClusteringLabel.setFont(QtGui.QFont("Arial", 12, QtGui.QFont.Bold))
+ parent.ClusteringLabel.setFont(QFont("Arial", 12, QFont.Bold))
- parent.min_dist_label = QtGui.QLabel("min_dist:")
+ parent.min_dist_label = QLabel("min_dist:")
parent.min_dist_label.setStyleSheet("color: gray;")
- parent.min_dist_value = QtGui.QLineEdit()
+ parent.min_dist_value = QLineEdit()
parent.min_dist_value.setText(str(0.5))
parent.min_dist_value.setFixedWidth(50)
- parent.n_neighbors_label = QtGui.QLabel("n_neighbors:")
+ parent.n_neighbors_label = QLabel("n_neighbors:")
parent.n_neighbors_label.setStyleSheet("color: gray;")
- parent.n_neighbors_value = QtGui.QLineEdit()
+ parent.n_neighbors_value = QLineEdit()
parent.n_neighbors_value.setText(str(30))
parent.n_neighbors_value.setFixedWidth(50)
- parent.n_components_label = QtGui.QLabel("n_components:")
+ parent.n_components_label = QLabel("n_components:")
parent.n_components_label.setStyleSheet("color: gray;")
- parent.n_components_value = QtGui.QSpinBox()
+ parent.n_components_value = QSpinBox()
parent.n_components_value.setRange(2, 3)
parent.n_components_value.setValue(2)
parent.n_components_value.setFixedWidth(50)
#metric
- parent.cluster_method_label = QtGui.QLabel("Cluster labels")
+ parent.cluster_method_label = QLabel("Cluster labels")
parent.cluster_method_label.setStyleSheet("color: gray;")
parent.cluster_method_label.setAlignment(QtCore.Qt.AlignCenter)
- parent.load_umap_embedding_button = QtGui.QPushButton('Load emmbedding')
+        parent.load_embedding_button = QPushButton('Load embedding')
- parent.RadioGroup = QtGui.QButtonGroup()
- parent.load_cluster_labels_button = QtGui.QPushButton('Load')
- parent.loadlabels_radiobutton = QtGui.QRadioButton("Load labels")
+ parent.RadioGroup = QButtonGroup()
+ parent.load_cluster_labels_button = QPushButton('Load')
+ parent.loadlabels_radiobutton = QRadioButton("Load labels")
parent.loadlabels_radiobutton.setStyleSheet("color: gray;")
parent.RadioGroup.addButton(parent.loadlabels_radiobutton)
- parent.kmeans_radiobutton = QtGui.QRadioButton("KMeans")
+ parent.kmeans_radiobutton = QRadioButton("KMeans")
parent.kmeans_radiobutton.setStyleSheet("color: gray;")
parent.RadioGroup.addButton(parent.kmeans_radiobutton)
- parent.hdbscan_radiobutton = QtGui.QRadioButton("HDBSCAN")
- parent.hdbscan_radiobutton.setStyleSheet("color: gray;")
- parent.RadioGroup.addButton(parent.hdbscan_radiobutton)
+ #parent.hdbscan_radiobutton = QRadioButton("HDBSCAN")
+ #parent.hdbscan_radiobutton.setStyleSheet("color: gray;")
+ #parent.RadioGroup.addButton(parent.hdbscan_radiobutton)
- parent.min_cluster_size_label = QtGui.QLabel("min_cluster_size:")
+ parent.min_cluster_size_label = QLabel("min_cluster_size:")
parent.min_cluster_size_label.setStyleSheet("color: gray;")
- parent.min_cluster_size = QtGui.QLineEdit()
+ parent.min_cluster_size = QLineEdit()
parent.min_cluster_size.setFixedWidth(50)
parent.min_cluster_size.setText(str(500))
- parent.num_clusters_label = QtGui.QLabel("num_clusters:")
+ parent.num_clusters_label = QLabel("num_clusters:")
parent.num_clusters_label.setStyleSheet("color: gray;")
- parent.num_clusters = QtGui.QLineEdit()
+ parent.num_clusters = QLineEdit()
parent.num_clusters.setFixedWidth(50)
parent.num_clusters.setText(str(5))
- istretch = 11
+ istretch = 12
parent.l0.addWidget(parent.ClusteringLabel, istretch, 0, 1, 2)
parent.l0.addWidget(parent.min_dist_label, istretch+1, 0, 1, 1)
parent.l0.addWidget(parent.min_dist_value, istretch+1, 1, 1, 1)
@@ -88,22 +91,22 @@ def create_clustering_widgets(self, parent):
parent.l0.addWidget(parent.n_neighbors_value, istretch+2, 1, 1, 1)
parent.l0.addWidget(parent.n_components_label, istretch+3, 0, 1, 1)
parent.l0.addWidget(parent.n_components_value, istretch+3, 1, 1, 1)
- parent.l0.addWidget(parent.load_umap_embedding_button, istretch+4, 0, 1, 2)
+ parent.l0.addWidget(parent.load_embedding_button, istretch+4, 0, 1, 2)
parent.l0.addWidget(parent.cluster_method_label, istretch+5, 0, 1, 2)
parent.l0.addWidget(parent.loadlabels_radiobutton, istretch+6, 0, 1, 1)
parent.l0.addWidget(parent.load_cluster_labels_button, istretch+6, 1, 1, 1)
parent.l0.addWidget(parent.kmeans_radiobutton, istretch+7, 0, 1, 1)
- parent.l0.addWidget(parent.hdbscan_radiobutton, istretch+7, 1, 1, 1)
+ # parent.l0.addWidget(parent.hdbscan_radiobutton, istretch+7, 1, 1, 1)
parent.l0.addWidget(parent.min_cluster_size_label, istretch+8, 0, 1, 1)
parent.l0.addWidget(parent.min_cluster_size, istretch+8, 1, 1, 1)
parent.l0.addWidget(parent.num_clusters_label, istretch+8, 0, 1, 1)
parent.l0.addWidget(parent.num_clusters, istretch+8, 1, 1, 1)
self.hide_umap_param(parent)
- parent.load_umap_embedding_button.clicked.connect(lambda: self.load_umap(parent))
+ parent.load_embedding_button.clicked.connect(lambda: self.load_umap(parent))
parent.loadlabels_radiobutton.toggled.connect(lambda: self.show_cluster_method_param(parent))
parent.kmeans_radiobutton.toggled.connect(lambda: self.show_cluster_method_param(parent))
- parent.hdbscan_radiobutton.toggled.connect(lambda: self.show_cluster_method_param(parent))
+ #parent.hdbscan_radiobutton.toggled.connect(lambda: self.show_cluster_method_param(parent))
parent.load_cluster_labels_button.clicked.connect(lambda: self.load_cluster_labels(parent))
def load_umap(self, parent):
@@ -111,6 +114,10 @@ def load_umap(self, parent):
self.plot_clustering_output(parent)
def load_cluster_labels(self, parent):
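+        # clear any stale legend entries from a previous run before loading new cluster labels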
+ try:
+ self.ClusteringPlot_legend.clear()
+ except Exception as e:
+ pass
io.load_cluster_labels(parent)
self.plot_clustering_output(parent)
@@ -130,10 +137,12 @@ def enable_data_clustering_features(self, parent):
parent.run_clustering_button.show()
- cluster_method = parent.clusteringVisComboBox.currentText() ######
- if cluster_method == "UMAP":
- #parent.data_clustering_combobox.show()
+ embed_method = parent.clusteringVisComboBox.currentText() ######
+ if embed_method == "UMAP":
self.show_umap_param(parent)
+ elif embed_method == "tSNE":
+ self.hide_umap_param(parent)
+ self.show_tsne_options(parent)
else:
self.disable_data_clustering_features(parent)
@@ -146,14 +155,14 @@ def show_processed_data(self, parent):
parent.data_clustering_combobox.setCurrentIndex(index)
if self.cluster_labels_method == "KMeans":
parent.kmeans_radiobutton.setChecked(True)
- elif self.cluster_labels_method == "HDBSCAN":
- parent.hdbscan_radiobutton.setChecked(True)
+ #elif self.cluster_labels_method == "HDBSCAN":
+ # parent.hdbscan_radiobutton.setChecked(True)
elif self.cluster_labels_method == "User labels":
parent.loadlabels_radiobutton.setChecked(True)
else:
parent.RadioGroup.setExclusive(False)
parent.kmeans_radiobutton.setChecked(False)
- parent.hdbscan_radiobutton.setChecked(False)
+ #parent.hdbscan_radiobutton.setChecked(False)
parent.loadlabels_radiobutton.setChecked(False)
parent.RadioGroup.setExclusive(True)
self.plot_clustering_output(parent)
@@ -161,6 +170,8 @@ def show_processed_data(self, parent):
def disable_data_clustering_features(self, parent):
parent.data_clustering_combobox.hide()
parent.ClusteringPlot.clear()
+ parent.zoom_in_button.hide()
+ parent.zoom_out_button.hide()
self.hide_umap_param(parent)
parent.run_clustering_button.hide()
parent.save_clustering_button.hide()
@@ -176,9 +187,11 @@ def show_umap_param(self, parent):
parent.cluster_method_label.show()
parent.loadlabels_radiobutton.show()
parent.kmeans_radiobutton.show()
- parent.hdbscan_radiobutton.show()
+ #parent.hdbscan_radiobutton.show()
self.show_cluster_method_param(parent)
- parent.load_umap_embedding_button.show()
+ parent.load_embedding_button.show()
+ parent.zoom_in_button.show()
+ parent.zoom_out_button.show()
def hide_umap_param(self, parent):
parent.ClusteringLabel.hide()
@@ -192,12 +205,15 @@ def hide_umap_param(self, parent):
parent.load_cluster_labels_button.hide()
parent.loadlabels_radiobutton.hide()
parent.kmeans_radiobutton.hide()
- parent.hdbscan_radiobutton.hide()
+ #parent.hdbscan_radiobutton.hide()
parent.num_clusters_label.hide()
parent.num_clusters.hide()
parent.min_cluster_size_label.hide()
parent.min_cluster_size.hide()
- parent.load_umap_embedding_button.hide()
+ parent.load_embedding_button.hide()
+
+ def show_tsne_options(self, parent):
+ parent.load_embedding_button.show()
def show_cluster_method_param(self, parent):
if parent.loadlabels_radiobutton.isChecked():
@@ -212,14 +228,16 @@ def show_cluster_method_param(self, parent):
parent.load_cluster_labels_button.hide()
parent.num_clusters_label.show()
parent.num_clusters.show()
+ else:
+ return
+ """
elif parent.hdbscan_radiobutton.isChecked():
parent.num_clusters_label.hide()
parent.num_clusters.hide()
parent.load_cluster_labels_button.hide()
parent.min_cluster_size_label.show()
parent.min_cluster_size.show()
- else:
- return
+ """
def get_cluster_labels(self, data, parent):
try:
@@ -230,21 +248,17 @@ def get_cluster_labels(self, data, parent):
batch_size=100, max_iter=50)
kmeans.fit(data)
self.cluster_labels = kmeans.labels_
- elif parent.hdbscan_radiobutton.isChecked():
- self.cluster_labels_method = "HDBSCAN"
- clusterer = hdbscan.HDBSCAN(min_cluster_size=int(parent.min_cluster_size.text())).fit(data)
- self.cluster_labels = clusterer.labels_
elif parent.loadlabels_radiobutton.isChecked():
if parent.is_cluster_labels_loaded:
self.cluster_labels = parent.loaded_cluster_labels
self.cluster_labels_method = "User labels"
else:
- QtGui.QMessageBox.about(parent, 'Error','Please load cluster labels file')
+ QMessageBox.about(parent, 'Error','Please load cluster labels file')
pass
else:
return
except Exception as e:
- QtGui.QMessageBox.about(parent, 'Error','Invalid input entered')
+ QMessageBox.about(parent, 'Error','Invalid input entered')
print(e)
pass
@@ -253,11 +267,15 @@ def get_colors(self):
colors = cm.get_cmap('gist_rainbow')(np.linspace(0, 1., num_classes))
colors *= 255
colors = colors.astype(int)
- colors[:,-1] = 127
+ colors[:,-1] = 200#127
brushes = [pg.mkBrush(color=c) for c in colors]
+ """
+ num_classes = len(np.unique(self.cluster_labels))
+ brushes = [pg.mkBrush(color=c) for c in colors_list[:num_classes]]
+ colors = colors_list[:num_classes]"""
#if -1 in np.unique(self.cluster_labels):
# brushes[0] = pg.mkBrush(color=(220,220,220))
- return brushes
+ return brushes, colors
def reset(self, parent):
self.cluster_labels = None
@@ -278,7 +296,7 @@ def set_variables(self, parent):
self.data_type = parent.data_clustering_combobox.currentText()
self.cluster_method = parent.clusteringVisComboBox.currentText()
except Exception as e:
- QtGui.QMessageBox.about(parent, 'Error','Parameter input can only be a number')
+ QMessageBox.about(parent, 'Error','Parameter input can only be a number')
print(e)
pass
@@ -297,10 +315,10 @@ def run(self, clicked, parent):
"""
else:
self.data_type = None
- msg = QtGui.QMessageBox(parent)
- msg.setIcon(QtGui.QMessageBox.Warning)
+ msg = QMessageBox(parent)
+ msg.setIcon(QMessageBox.Warning)
msg.setText("Please select data for clustering")
- msg.setStandardButtons(QtGui.QMessageBox.Ok)
+ msg.setStandardButtons(QMessageBox.Ok)
msg.exec_()
return
if self.cluster_method == "UMAP":
@@ -329,48 +347,66 @@ def plot_clustering_output(self, parent):
name = None
# Get cluster labels if clustering method selected for embedded output
- if parent.kmeans_radiobutton.isChecked() or parent.hdbscan_radiobutton.isChecked() or parent.loadlabels_radiobutton.isChecked():
+ if parent.kmeans_radiobutton.isChecked() or parent.loadlabels_radiobutton.isChecked():
self.get_cluster_labels(self.embedded_output, parent)
- brushes = self.get_colors()
+ brushes, colors = self.get_colors()
name = self.cluster_labels
if len(brushes) > 1:
is_cluster_colored = True
# Plot output (i) w/ cluster labels (ii) w/o cluster labels and (iii) 3D output
if num_comps == 2:
+ # Set pixel size of embedded points on plot
+ if self.embedded_output.shape[0]<500:
+ point_size=6
+ elif self.embedded_output.shape[0]<2000:
+ point_size=4
+ else:
+ point_size=2
if is_cluster_colored:
scatter_plots = []
- if max(self.cluster_labels) > 4: #Adjust legend
- legend_num_row = 4
- legend_num_col = int(np.ceil(max(self.cluster_labels)/legend_num_row))+1
+ if len(np.unique(self.cluster_labels)) > 9: #Adjust legend
+ legend_num_col = 9
+ legend_num_row = int(len(np.unique(self.cluster_labels))/legend_num_col)+1
else:
- legend_num_col, legend_num_row = [1, 1]
- parent.ClusteringPlot_legend = pg.LegendItem(labelTextSize='12pt', horSpacing=12,
+ legend_num_col, legend_num_row = [len(np.unique(self.cluster_labels)), 1]
+ parent.ClusteringPlot_legend = pg.LegendItem(labelTextSize='11pt', horSpacing=10,
colCount=legend_num_col, rowCount=legend_num_row)
+ parent.ClusteringPlot_legend.setPos(0,20)
for i, cluster in enumerate(np.unique(self.cluster_labels)):#range(max(self.cluster_labels)+1):
+ print("cluster", cluster)
ind = np.where(self.cluster_labels==cluster)[0]
data = self.embedded_output[ind,:]
if cluster == -1:
scatter_plots.append(pg.ScatterPlotItem(data[:,0], data[:,1], symbol='o', brush=pg.mkBrush(color=(0,1,1,1)),
- hoverable=True, hoverSize=15, pen=(0,.0001,0,0), data=ind, name=str(i))) #pg.mkPen(pg.hsvColor(hue=.01,sat=.01,alpha=0.01))
+ hoverable=True, hoverSize=15, hoverSymbol="x", hoverBrush='r',
+                                            pen=(0,.0001,0,0), data=ind, name=str(cluster), size=point_size)) #pg.mkPen(pg.hsvColor(hue=.01,sat=.01,alpha=0.01))
else:
scatter_plots.append(pg.ScatterPlotItem(data[:,0], data[:,1], symbol='o', brush=brushes[i],
- hoverable=True, hoverSize=15, data=ind, name=str(i)))
+ hoverable=True, hoverSize=15, hoverSymbol="x", hoverBrush='r',
+ data=ind, name=str(cluster), size=point_size))
parent.ClusteringPlot.addItem(scatter_plots[i])
- parent.ClusteringPlot_legend.addItem(scatter_plots[i], name=str(i))
+ parent.ClusteringPlot_legend.addItem(scatter_plots[i], name=str(cluster))
# Add all points (transparent) to connect them to hovered function
- parent.clustering_scatterplot.setData(self.embedded_output[:,0], self.embedded_output[:,1], symbol='o', brush=(0,0,0,0),
- hoverable=True, hoverSize=15, pen=(0,0,0,0), data=np.arange(num_feat),name=name)
+ parent.clustering_scatterplot.setData(self.embedded_output[:,0], self.embedded_output[:,1], symbol='o',
+ brush=(0,0,0,0),pxMode=True, hoverable=True, hoverSize=15,
+ hoverSymbol="x", hoverBrush='r',pen=(0,0,0,0),
+ data=np.arange(num_feat), name=name,size=point_size)
parent.ClusteringPlot.addItem(parent.clustering_scatterplot)
+ parent.ClusteringPlot.addItem(parent.clustering_highlight_scatterplot)
parent.ClusteringPlot_legend.setPos(parent.clustering_scatterplot.x()+5,parent.clustering_scatterplot.y())
parent.ClusteringPlot_legend.setParentItem(parent.ClusteringPlot)
+ parent.plot_cluster_labels_p1(self.cluster_labels, colors)
else:
- parent.clustering_scatterplot.setData(self.embedded_output[:,0], self.embedded_output[:,1], symbol='o', brush=all_spots_colors,
- hoverable=True, hoverSize=15, data=np.arange(num_feat),name=name)
+ parent.clustering_scatterplot.setData(self.embedded_output[:,0], self.embedded_output[:,1], symbol='o',
+ brush=all_spots_colors,pxMode=True,hoverable=True,
+ hoverSize=15, hoverSymbol="x",hoverBrush='r',
+ data=np.arange(num_feat),name=name, size=point_size)
parent.ClusteringPlot.addItem(parent.clustering_scatterplot)
+ parent.ClusteringPlot.addItem(parent.clustering_highlight_scatterplot)
parent.ClusteringPlot.showAxis('left')
parent.ClusteringPlot.showAxis('bottom')
- parent.ClusteringPlot.setLabels(bottom='UMAP coordinate 1',left='UMAP coordinate 2')
+ parent.ClusteringPlot.setLabels(bottom='Dimension 1',left='Dimension 2')
else: # 3D embedded visualization
view = gl.GLViewWidget()
view.setWindowTitle("3D plot of embedded points")
@@ -385,7 +421,6 @@ def plot_clustering_output(self, parent):
parent.save_clustering_button.show()
def embedded_points_hovered(self, obj, ev, parent):
-
"""
point_hovered = np.where(parent.clustering_scatterplot.data['hovered'])[0]
if point_hovered.shape[0] >= 1: # Show tooltip only when hovering over a point i.e. no empty array
@@ -440,24 +475,24 @@ def save_dialog(self, clicked, parent):
dialog.label.setText("Save files:")
dialog.verticalLayout.addWidget(dialog.label)
- dialog.data_checkbox = QtGui.QCheckBox("Cluster data (*.npy)")
+ dialog.data_checkbox = QCheckBox("Cluster data (*.npy)")
dialog.data_checkbox.setChecked(True)
- dialog.videos_checkbox = QtGui.QCheckBox("Cluster videos (*.avi)")
+ dialog.videos_checkbox = QCheckBox("Cluster videos (*.avi)")
dialog.videos_checkbox.stateChanged.connect(lambda: self.enable_video_options(dialog))
- dialog.num_frames_label = QtGui.QLabel("#Frames/cluster:")
- dialog.num_frames = QtGui.QLineEdit()
+ dialog.num_frames_label = QLabel("#Frames/cluster:")
+ dialog.num_frames = QLineEdit()
dialog.num_frames.setText(str(30))
dialog.num_frames.setEnabled(False)
- dialog.fps_label = QtGui.QLabel("FPS:")
- dialog.fps = QtGui.QLineEdit()
+ dialog.fps_label = QLabel("FPS:")
+ dialog.fps = QLineEdit()
dialog.fps.setText(str(10.0))
dialog.fps.setEnabled(False)
- dialog.ok_button = QtGui.QPushButton('Ok')
+ dialog.ok_button = QPushButton('Ok')
dialog.ok_button.setDefault(True)
dialog.ok_button.clicked.connect(lambda: self.ok_save(dialog, parent))
- dialog.cancel_button = QtGui.QPushButton('Cancel')
+ dialog.cancel_button = QPushButton('Cancel')
dialog.cancel_button.clicked.connect(dialog.close)
# Add options to dialog box
@@ -495,14 +530,14 @@ def ok_save(self, dialogBox, parent):
try:
self.save_cluster_video(parent, float(dialogBox.fps.text()), int(dialogBox.num_frames.text()))
except Exception as e:
- QtGui.QMessageBox.about(parent, 'Error','Invalid input entered')
+ QMessageBox.about(parent, 'Error','Invalid input entered')
print(e)
pass
else:
- msg = QtGui.QMessageBox(parent)
- msg.setIcon(QtGui.QMessageBox.Warning)
+ msg = QMessageBox(parent)
+ msg.setIcon(QMessageBox.Warning)
msg.setText("Please generate cluster labels for saving cluster videos")
- msg.setStandardButtons(QtGui.QMessageBox.Ok)
+ msg.setStandardButtons(QMessageBox.Ok)
msg.exec_()
if dialogBox.data_checkbox.isChecked():
self.save_cluster_output(parent)
diff --git a/facemap/gui.py b/facemap/gui/gui.py
old mode 100644
new mode 100755
similarity index 52%
rename from facemap/gui.py
rename to facemap/gui/gui.py
index aa1b558..8e9ae5c
--- a/facemap/gui.py
+++ b/facemap/gui/gui.py
@@ -1,25 +1,29 @@
-import sys, os, shutil, glob, time
+import sys, os
import numpy as np
-from PyQt5 import QtGui, QtCore
+from PyQt5 import QtGui, QtCore, QtWidgets
import pyqtgraph as pg
from scipy.stats import zscore, skew
from matplotlib import cm
-from natsort import natsorted
-import pathlib
-import cv2
+import matplotlib.pyplot as plt
import pandas as pd
-from PyQt5.QtGui import QPixmap
-from . import process, roi, utils, io, menus, guiparts, cluster
-
-istr = ['pupil', 'motSVD', 'blink', 'running']
-
-class MainW(QtGui.QMainWindow):
+from .. import process, roi, utils, cluster
+from ..pose import pose_gui, pose
+from . import io, menus, guiparts
+from PyQt5.QtGui import QPixmap, QFont, QPainterPath, QIcon, QColor
+from PyQt5.QtWidgets import ( QLabel, QPushButton, QLineEdit, QCheckBox,
+ QComboBox, QToolButton, QStatusBar, QSlider,
+ QProgressBar, QSpinBox, QMessageBox, QButtonGroup,
+                                QGridLayout, QWidget)
+
+istr = ['pupil', 'motSVD', 'blink', 'running', 'movSVD']
+
+class MainW(QtWidgets.QMainWindow):
def __init__(self, moviefile=None, savedir=None):
super(MainW, self).__init__()
icon_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "mouse.png"
)
- app_icon = QtGui.QIcon()
+ app_icon = QIcon()
app_icon.addFile(icon_path, QtCore.QSize(16, 16))
app_icon.addFile(icon_path, QtCore.QSize(24, 24))
app_icon.addFile(icon_path, QtCore.QSize(32, 32))
@@ -51,21 +55,20 @@ def __init__(self, moviefile=None, savedir=None):
'save_path': '', 'save_mat': False}
self.save_path = self.ops['save_path']
- self.DLC_filepath = ""
menus.mainmenu(self)
self.online_mode=False
#menus.onlinemenu(self)
- self.cwidget = QtGui.QWidget(self)
+ self.cwidget = QWidget(self)
self.setCentralWidget(self.cwidget)
- self.l0 = QtGui.QGridLayout()
+ self.l0 = QGridLayout()
self.cwidget.setLayout(self.l0)
# --- cells image
self.win = pg.GraphicsLayoutWidget()
self.win.move(600,0)
self.win.resize(1000,500)
- self.l0.addWidget(self.win,1,2,27,15)
+ self.l0.addWidget(self.win,1,2,25,15)
layout = self.win.ci.layout
# Add logo
@@ -103,22 +106,22 @@ def __init__(self, moviefile=None, savedir=None):
for j in range(2):
self.sl.append(guiparts.Slider(j, self))
self.l0.addWidget(self.sl[j],1,3+3*j,1,2)#+5*j,1,2)
- qlabel = QtGui.QLabel(txt[j])
+ qlabel = QLabel(txt[j])
qlabel.setStyleSheet('color: white;')
self.l0.addWidget(qlabel,0,3+3*j,1,1)
self.sl[0].valueChanged.connect(self.set_saturation_label)
self.sl[1].valueChanged.connect(self.set_ROI_saturation_label)
# Add label to indicate saturation level
- self.saturationLevelLabel = QtGui.QLabel(str(self.sl[0].value()))
+ self.saturationLevelLabel = QLabel(str(self.sl[0].value()))
self.saturationLevelLabel.setStyleSheet("color: white;")
self.l0.addWidget(self.saturationLevelLabel,0,5,1,1)
- self.roiSaturationLevelLabel = QtGui.QLabel(str(self.sl[1].value()))
+ self.roiSaturationLevelLabel = QLabel(str(self.sl[1].value()))
self.roiSaturationLevelLabel.setStyleSheet("color: white;")
self.l0.addWidget(self.roiSaturationLevelLabel,0,8,1,1)
# Reflector
- self.reflector = QtGui.QPushButton('Add corneal reflection')
+ self.reflector = QPushButton('Add corneal reflection')
self.reflector.setEnabled(False)
self.reflector.clicked.connect(self.add_reflectROI)
self.rROI=[]
@@ -142,19 +145,29 @@ def __init__(self, moviefile=None, savedir=None):
self.nframes = 0
self.cframe = 0
- ## DLC plot
- self.DLC_scatterplot = pg.ScatterPlotItem(hover=True)
- self.DLC_scatterplot.sigClicked.connect(self.DLC_points_clicked)
- self.DLC_scatterplot.sigHovered.connect(self.DLC_points_hovered)
- self.make_buttons()
-
+ ## Pose plot
+ self.Pose_scatterplot = pg.ScatterPlotItem(hover=True)
+ self.Pose_scatterplot.sigClicked.connect(self.keypoints_clicked)
+ self.Pose_scatterplot.sigHovered.connect(self.keypoints_hovered)
+ self.pose_model = None
+ self.poseFilepath = []
+ self.keypoints_labels = []
+ self.pose_x_coord = []
+ self.pose_y_coord = []
+ self.pose_likelihood = []
+ self.keypoints_brushes = []
+ self.bbox = []
+ self.bbox_set = False
+
self.ClusteringPlot = self.win.addPlot(row=0, col=1, lockAspect=True, enableMouse=False)
self.ClusteringPlot.hideAxis('left')
self.ClusteringPlot.hideAxis('bottom')
self.clustering_scatterplot = pg.ScatterPlotItem(hover=True)
- #self.clustering_scatterplot.sigClicked.connect(lambda obj, ev: cluster.embeddedPointsClicked(obj, ev, self))
+ #self.clustering_scatterplot.sigClicked.connect(lambda obj, ev: self.cluster_model.highlight_embedded_point(obj, ev, parent=self))
self.clustering_scatterplot.sigHovered.connect(lambda obj, ev: self.cluster_model.embedded_points_hovered(obj, ev, parent=self))
#self.ClusteringPlot.scene().sigMouseMoved.connect(lambda pos: self.cluster_model.mouse_moved_embedding(pos, parent=self))
+ self.clustering_highlight_scatterplot = pg.ScatterPlotItem(hover=True)
+ self.clustering_highlight_scatterplot.sigHovered.connect(lambda obj, ev: self.cluster_model.embedded_points_hovered(obj, ev, parent=self))
self.ClusteringPlot_legend = pg.LegendItem(labelTextSize='12pt', title="Cluster")
self.cluster_model = cluster.Cluster(parent=self)
@@ -171,19 +184,20 @@ def __init__(self, moviefile=None, savedir=None):
self.show()
self.processed = False
if moviefile is not None:
- self.load_movies([[moviefile]])
+ io.load_movies(self, [[moviefile]])
if savedir is not None:
self.save_path = savedir
self.savelabel.setText("..."+savedir[-20:])
# Status bar
- self.statusBar = QtGui.QStatusBar()
+ self.statusBar = QStatusBar()
self.setStatusBar(self.statusBar)
- self.progressBar = QtGui.QProgressBar()
+ self.progressBar = QProgressBar()
self.statusBar.addPermanentWidget(self.progressBar)
self.progressBar.setGeometry(0, 0, 300, 25)
self.progressBar.setMaximum(100)
self.progressBar.hide()
+ self.make_buttons()
def set_saturation_label(self):
self.saturationLevelLabel.setText(str(self.sl[0].value()))
@@ -196,302 +210,242 @@ def set_ROI_saturation_label(self, val=None):
def make_buttons(self):
# create frame slider
- VideoLabel = QtGui.QLabel("Analyze Videos")
+ VideoLabel = QLabel("Facemap - SVDs & Tracker")
VideoLabel.setStyleSheet("color: white;")
VideoLabel.setAlignment(QtCore.Qt.AlignCenter)
- VideoLabel.setFont(QtGui.QFont("Arial", 12, QtGui.QFont.Bold))
- #fileIOlabel = QtGui.QLabel("File I/O")
- #fileIOlabel.setStyleSheet("color: white;")
- #fileIOlabel.setAlignment(QtCore.Qt.AlignCenter)
- #fileIOlabel.setFont(QtGui.QFont("Arial", 12, QtGui.QFont.Bold))
- SVDbinLabel = QtGui.QLabel("SVD spatial bin:")
+ VideoLabel.setFont(QFont("Arial", 16, QFont.Bold))
+ SVDbinLabel = QLabel("SVD spatial bin:")
SVDbinLabel.setStyleSheet("color: gray;")
- self.binSpinBox = QtGui.QSpinBox()
+ self.binSpinBox = QSpinBox()
self.binSpinBox.setRange(1, 20)
self.binSpinBox.setValue(self.ops['sbin'])
self.binSpinBox.setFixedWidth(50)
- binLabel = QtGui.QLabel("Pupil sigma:")
+ binLabel = QLabel("Pupil sigma:")
binLabel.setStyleSheet("color: gray;")
- self.sigmaBox = QtGui.QLineEdit()
+ self.sigmaBox = QLineEdit()
self.sigmaBox.setText(str(self.ops['pupil_sigma']))
self.sigmaBox.setFixedWidth(45)
self.pupil_sigma = float(self.sigmaBox.text())
self.sigmaBox.returnPressed.connect(self.pupil_sigma_change)
- self.frameLabel = QtGui.QLabel("Frame:")
- self.frameLabel.setStyleSheet("color: white;")
- self.totalFrameLabel = QtGui.QLabel("Total frames:")
- self.totalFrameLabel.setStyleSheet("color: white;")
- self.setFrame = QtGui.QLineEdit()
+ self.setFrame = QLineEdit()
self.setFrame.setMaxLength(10)
self.setFrame.setFixedWidth(50)
self.setFrame.textChanged[str].connect(self.set_frame_changed)
- self.totalFrameNumber = QtGui.QLabel("0") #########
+ self.totalFrameNumber = QLabel("0") #########
self.totalFrameNumber.setStyleSheet("color: white;") #########
- self.frameSlider = QtGui.QSlider(QtCore.Qt.Horizontal)
+ self.frameSlider = QSlider(QtCore.Qt.Horizontal)
self.frameSlider.setTickInterval(5)
self.frameSlider.setTracking(False)
self.frameSlider.valueChanged.connect(self.go_to_frame)
self.frameDelta = 10
- istretch = 19
+ istretch = 20
iplay = istretch+10
iconSize = QtCore.QSize(20, 20)
- self.process = QtGui.QPushButton('process ROIs')
- self.process.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold))
+ self.process = QPushButton('process')
+ self.process.setFont(QFont("Arial", 10, QFont.Bold))
self.process.clicked.connect(self.process_ROIs)
self.process.setEnabled(False)
- self.savefolder = QtGui.QPushButton("Output folder \u2b07")
- self.savefolder.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold))
- self.savefolder.clicked.connect(self.save_folder)
- self.savefolder.setEnabled(False)
- self.savelabel = QtGui.QLabel('same as video')
+ self.savelabel = QLabel('same as video')
self.savelabel.setStyleSheet("color: white;")
self.savelabel.setAlignment(QtCore.Qt.AlignCenter)
- self.saverois = QtGui.QPushButton('save ROIs')
- self.saverois.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold))
+ self.saverois = QPushButton('save ROIs')
+ self.saverois.setFont(QFont("Arial", 10, QFont.Bold))
self.saverois.clicked.connect(self.save_ROIs)
self.saverois.setEnabled(False)
- # DLC features
- self.loadDLC = QtGui.QPushButton("Load DLC data")
- self.loadDLC.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold))
- self.loadDLC.clicked.connect(self.get_DLC_file)
- self.loadDLC.setEnabled(False)
- self.DLC_file_loaded = False
- self.DLClabels_checkBox = QtGui.QCheckBox("Labels")
- self.DLClabels_checkBox.setStyleSheet("color: gray;")
- self.DLClabels_checkBox.stateChanged.connect(self.update_DLC_points)
- self.DLClabels_checkBox.setEnabled(False)
+ # Pose/labels variables
+ self.poseFileLoaded = False
+ self.Labels_checkBox = QCheckBox("Keypoints")
+ self.Labels_checkBox.setStyleSheet("color: gray;")
+ self.Labels_checkBox.stateChanged.connect(self.update_pose)
+ self.Labels_checkBox.setEnabled(False)
# Process features
self.batchlist=[]
- """
self.batchname=[]
- for k in range(6):
- self.batchname.append(QtGui.QLabel(''))
+ for k in range(5):
+ self.batchname.append(QLabel(''))
self.batchname[-1].setStyleSheet("color: white;")
- self.l0.addWidget(self.batchname[-1],18+k,0,1,4)
- """
+ self.l0.addWidget(self.batchname[-1],9+k,0,1,4)
- self.processbatch = QtGui.QPushButton(u"process batch \u2b07")
- self.processbatch.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold))
+ self.processbatch = QPushButton(u"process batch \u2b07")
+ self.processbatch.setFont(QFont("Arial", 10, QFont.Bold))
self.processbatch.clicked.connect(self.process_batch)
self.processbatch.setEnabled(False)
# Play/pause features
iconSize = QtCore.QSize(30, 30)
- self.playButton = QtGui.QToolButton()
- self.playButton.setIcon(self.style().standardIcon(QtGui.QStyle.SP_MediaPlay))
+ self.playButton = QToolButton()
+ self.playButton.setIcon(self.style().standardIcon(QtWidgets.QStyle.SP_MediaPlay))
self.playButton.setIconSize(iconSize)
self.playButton.setToolTip("Play")
self.playButton.setCheckable(True)
self.playButton.clicked.connect(self.start)
- self.pauseButton = QtGui.QToolButton()
+ self.pauseButton = QToolButton()
self.pauseButton.setCheckable(True)
- self.pauseButton.setIcon(self.style().standardIcon(QtGui.QStyle.SP_MediaPause))
+ self.pauseButton.setIcon(self.style().standardIcon(QtWidgets.QStyle.SP_MediaPause))
self.pauseButton.setIconSize(iconSize)
self.pauseButton.setToolTip("Pause")
self.pauseButton.clicked.connect(self.pause)
- btns = QtGui.QButtonGroup(self)
+ btns = QButtonGroup(self)
btns.addButton(self.playButton,0)
btns.addButton(self.pauseButton,1)
btns.setExclusive(True)
- # Add ROI features
- self.comboBox = QtGui.QComboBox(self)
- self.comboBox.setFixedWidth(100)
+ # Create ROI features
+ self.comboBox = QComboBox(self)
+ self.comboBox.setFixedWidth(110)
self.comboBox.addItem("Select ROI")
self.comboBox.addItem("Pupil")
self.comboBox.addItem("motion SVD")
self.comboBox.addItem("Blink")
self.comboBox.addItem("Running")
+ self.comboBox.addItem("Face (pose)")
self.newROI = 0
self.comboBox.setCurrentIndex(0)
#self.comboBox.currentIndexChanged.connect(self.mode_change)
- self.addROI = QtGui.QPushButton("Add ROI")
- self.addROI.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold))
- self.addROI.clicked.connect(self.add_ROI)
+ self.addROI = QPushButton("Add ROI")
+ self.addROI.setFont(QFont("Arial", 10, QFont.Bold))
+ self.addROI.clicked.connect(lambda clicked: self.add_ROI())
self.addROI.setFixedWidth(70)
self.addROI.setEnabled(False)
# Add clustering analysis/visualization features
- self.clusteringVisComboBox = QtGui.QComboBox(self)
+ self.clusteringVisComboBox = QComboBox(self)
self.clusteringVisComboBox.setFixedWidth(200)
self.clusteringVisComboBox.addItem("--Select display--")
self.clusteringVisComboBox.addItem("ROI")
self.clusteringVisComboBox.addItem("UMAP")
+ self.clusteringVisComboBox.addItem("tSNE")
self.clusteringVisComboBox.currentIndexChanged.connect(self.vis_combobox_selection_changed)
self.clusteringVisComboBox.setFixedWidth(140)
- self.roiVisComboBox = QtGui.QComboBox(self)
+ self.roiVisComboBox = QComboBox(self)
self.roiVisComboBox.setFixedWidth(100)
self.roiVisComboBox.hide()
self.roiVisComboBox.activated.connect(self.display_ROI)
- self.run_clustering_button = QtGui.QPushButton("Run")
- self.run_clustering_button.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold))
+ self.run_clustering_button = QPushButton("Run")
+ self.run_clustering_button.setFont(QFont("Arial", 10, QFont.Bold))
self.run_clustering_button.clicked.connect(lambda clicked: self.cluster_model.run(clicked, self))
self.run_clustering_button.hide()
- self.save_clustering_button = QtGui.QPushButton("Save")
- self.save_clustering_button.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold))
+ self.save_clustering_button = QPushButton("Save")
+ self.save_clustering_button.setFont(QFont("Arial", 10, QFont.Bold))
self.save_clustering_button.clicked.connect(lambda clicked: self.cluster_model.save_dialog(clicked, self))
self.save_clustering_button.hide()
- self.data_clustering_combobox = QtGui.QComboBox(self)
+ self.data_clustering_combobox = QComboBox(self)
self.data_clustering_combobox.setFixedWidth(100)
self.data_clustering_combobox.hide()
+ self.zoom_in_button = QPushButton('+')
+ self.zoom_in_button.setMaximumWidth(int(0.3*self.data_clustering_combobox.width()))
+ self.zoom_in_button.clicked.connect(lambda clicked: self.cluster_plot_zoom_buttons("in"))
+ self.zoom_in_button.hide()
+ self.zoom_out_button = QPushButton('-')
+ self.zoom_out_button.setMaximumWidth(int(0.3*self.data_clustering_combobox.width()))
+ self.zoom_out_button.clicked.connect(lambda clicked: self.cluster_plot_zoom_buttons("out"))
+ self.zoom_out_button.hide()
# Check boxes
- self.checkBox = QtGui.QCheckBox("multivideo SVD")
+ self.checkBox = QCheckBox("multivideo SVD")
self.checkBox.setStyleSheet("color: gray;")
if self.ops['fullSVD']:
self.checkBox.toggle()
- self.save_mat = QtGui.QCheckBox("Save *.mat")
+ self.save_mat = QCheckBox("Save *.mat")
self.save_mat.setStyleSheet("color: gray;")
if self.ops['save_mat']:
self.save_mat.toggle()
- self.motSVD_checkbox = QtGui.QCheckBox("motSVD")
+ self.motSVD_checkbox = QCheckBox("motSVD")
self.motSVD_checkbox.setStyleSheet("color: gray;")
- self.motSVD_checkbox.setChecked(True)
- self.movSVD_checkbox = QtGui.QCheckBox("movSVD")
+ self.movSVD_checkbox = QCheckBox("movSVD")
self.movSVD_checkbox.setStyleSheet("color: gray;")
# Add features to window
+ # ~~~~~~~~~~ motsvd/movsvd options ~~~~~~~~~~
self.l0.addWidget(VideoLabel,0,0,1,2)
self.l0.addWidget(self.comboBox,1,0,1,2)
self.l0.addWidget(self.addROI,1,1,1,1)
- self.l0.addWidget(self.reflector, 2, 0, 1, 2)
- self.l0.addWidget(SVDbinLabel, 3, 0, 1, 2)
- self.l0.addWidget(self.binSpinBox,3, 1, 1, 2)
- self.l0.addWidget(binLabel, 4, 0, 1, 1)
- self.l0.addWidget(self.sigmaBox, 4, 1, 1, 1)
- self.l0.addWidget(self.motSVD_checkbox, 5, 0, 1, 1)
- self.l0.addWidget(self.movSVD_checkbox, 5, 1, 1, 1)
- self.l0.addWidget(self.checkBox, 6, 0, 1, 1)
- self.l0.addWidget(self.save_mat, 6, 1, 1, 1)
- self.l0.addWidget(self.saverois, 7, 0, 1, 1)
- self.l0.addWidget(self.process, 7, 1, 1, 1)
- self.l0.addWidget(self.processbatch, 8, 0, 1, 1)
-
- self.l0.addWidget(self.savefolder, 8, 1, 1, 1)
- self.l0.addWidget(self.savelabel, 9, 0, 1, 2)
- self.l0.addWidget(self.loadDLC, 10, 0, 1, 1) # DLC features
- self.l0.addWidget(self.DLClabels_checkBox, 10, 1, 1, 1)
- self.l0.addWidget(self.clusteringVisComboBox, 0, 11, 1, 1) # clustering visualization window features
- self.l0.addWidget(self.data_clustering_combobox, 0, 12, 1, 2) # clustering visualization window features
- self.l0.addWidget(self.roiVisComboBox, 0, 12, 1, 2) # ROI visualization window features
- self.l0.addWidget(self.run_clustering_button, 0, 14, 1, 1) # clustering visualization window features
- self.l0.addWidget(self.save_clustering_button, 0, 15, 1, 1) # clustering visualization window features
+ self.l0.addWidget(self.reflector, 0, 14, 1, 2)
+ self.l0.addWidget(SVDbinLabel, 2, 0, 1, 2)
+ self.l0.addWidget(self.binSpinBox,2, 1, 1, 2)
+ self.l0.addWidget(binLabel, 3, 0, 1, 1)
+ self.l0.addWidget(self.sigmaBox, 3, 1, 1, 1)
+ self.l0.addWidget(self.motSVD_checkbox, 4, 0, 1, 1)
+ self.l0.addWidget(self.movSVD_checkbox, 4, 1, 1, 1)
+ self.l0.addWidget(self.checkBox, 5, 0, 1, 1)
+ self.l0.addWidget(self.save_mat, 5, 1, 1, 1)
+ self.l0.addWidget(self.saverois, 6, 1, 1, 1)
+ self.l0.addWidget(self.process, 7, 0, 1, 1)
+ self.l0.addWidget(self.processbatch, 7, 1, 1, 1)
+ # ~~~~~~~~~~ Save/file IO ~~~~~~~~~~
+ self.l0.addWidget(self.savelabel, 8, 0, 1, 2)
+ # ~~~~~~~~~~ Pose features ~~~~~~~~~~
+ self.l0.addWidget(self.Labels_checkBox, 6, 0, 1, 1)
+        # ~~~~~~~~~~ clustering & ROI visualization window features ~~~~~~~~~~
+ self.l0.addWidget(self.clusteringVisComboBox, 0, 11, 1, 1)
+ self.l0.addWidget(self.data_clustering_combobox, 0, 12, 1, 2)
+ self.l0.addWidget(self.roiVisComboBox, 0, 12, 1, 2)
+ self.l0.addWidget(self.zoom_in_button, 0, 12, 1, 1)
+ self.l0.addWidget(self.zoom_out_button, 0, 13, 1, 1)
+ self.l0.addWidget(self.run_clustering_button, 0, 14, 1, 1)
+ self.l0.addWidget(self.save_clustering_button, 0, 15, 1, 1)
+ # ~~~~~~~~~~ Video playback ~~~~~~~~~~
self.l0.addWidget(self.playButton,iplay,0,1,1)
self.l0.addWidget(self.pauseButton,iplay,1,1,1)
self.playButton.setEnabled(False)
self.pauseButton.setEnabled(False)
self.pauseButton.setChecked(True)
- self.l0.addWidget(QtGui.QLabel(''),istretch,0,1,3)
+ self.l0.addWidget(QLabel(''),istretch,0,1,3)
self.l0.setRowStretch(istretch,1)
- self.l0.addWidget(self.frameLabel, istretch+7,0,1,1)
- self.l0.addWidget(self.setFrame, istretch+7,1,1,1)
- self.l0.addWidget(self.totalFrameLabel, istretch+8,0,1,1)
- self.l0.addWidget(self.totalFrameNumber, istretch+8,1,1,1)
+ self.l0.addWidget(self.setFrame, istretch+7,0,1,1)
+ self.l0.addWidget(self.totalFrameNumber, istretch+7,1,1,1)
self.l0.addWidget(self.frameSlider, istretch+10,2,1,15)
- # plotting boxes
- #pl = QtGui.QLabel("Plot output")
- #pl.setStyleSheet("color: white")
- #pl.setAlignment(QtCore.Qt.AlignCenter)
- #self.l0.addWidget(pl, istretch+1, 0, 1, 2)
- pl = QtGui.QLabel("Plot 1")
+ # Plot 1 and 2 features
+ pl = QLabel("Plot 1")
pl.setStyleSheet("color: gray;")
self.l0.addWidget(pl, istretch, 0, 1, 1)
- pl = QtGui.QLabel("Plot 2")
+ pl = QLabel("Plot 2")
pl.setStyleSheet("color: gray;")
self.l0.addWidget(pl, istretch, 1, 1, 1)
- pl = QtGui.QLabel("ROI")
- pl.setStyleSheet("color: gray;")
- #self.l0.addWidget(pl, istretch+2, 2, 1, 1)
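+        # buttons to load external 1D traces for display in plot 1 and plot 2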
+ self.load_trace1_button = QPushButton('Load 1D data')
+ self.load_trace1_button.setFont(QFont("Arial", 12))
+ self.load_trace1_button.clicked.connect(lambda: self.load_trace_button_clicked(1))
+ self.load_trace1_button.setEnabled(False)
+ self.trace1_data_loaded = None
+ self.trace1_legend = pg.LegendItem(labelTextSize='12pt', horSpacing=30)
+ self.load_trace2_button = QPushButton('Load 1D data')
+ self.load_trace2_button.setFont(QFont("Arial", 12))
+ self.load_trace2_button.clicked.connect(lambda: self.load_trace_button_clicked(2))
+ self.load_trace2_button.setEnabled(False)
+ self.trace2_data_loaded = None
+ self.trace2_legend = pg.LegendItem(labelTextSize='12pt', horSpacing=30)
+ self.l0.addWidget(self.load_trace1_button, istretch+1, 0, 1, 1)
+ self.l0.addWidget(self.load_trace2_button, istretch+1, 1, 1, 1)
self.cbs1 = []
self.cbs2 = []
self.lbls = []
- for k in range(5):
- self.cbs1.append(QtGui.QCheckBox(""))
- self.l0.addWidget(self.cbs1[-1], istretch+1+k, 0, 1, 1)
- self.cbs2.append(QtGui.QCheckBox(""))
- self.l0.addWidget(self.cbs2[-1], istretch+1+k, 1, 1, 1)
+ for k in range(4):
+ self.cbs1.append(QCheckBox(""))
+ self.l0.addWidget(self.cbs1[-1], istretch+2+k, 0, 1, 1)
+ self.cbs2.append(QCheckBox(""))
+ self.l0.addWidget(self.cbs2[-1], istretch+2+k, 1, 1, 1)
self.cbs1[-1].toggled.connect(self.plot_processed)
self.cbs2[-1].toggled.connect(self.plot_processed)
self.cbs1[-1].setEnabled(False)
self.cbs2[-1].setEnabled(False)
self.cbs1[k].setStyleSheet("color: gray;")
self.cbs2[k].setStyleSheet("color: gray;")
- self.lbls.append(QtGui.QLabel(''))
+ self.lbls.append(QLabel(''))
self.lbls[-1].setStyleSheet("color: white;")
- #self.l0.addWidget(self.lbls[-1], istretch+3+k, 2, 1, 1)
- #ll = QtGui.QLabel('play/pause [SPACE]')
- #ll.setStyleSheet("color: gray;")
- #self.l0.addWidget(ll, istretch+3+k+1,0,1,1)
self.update_frame_slider()
- def vis_combobox_selection_changed(self):
- """
- Call clustering or ROI display functions upon user selection from combo box
- """
- self.clear_visualization_window()
- visualization_request = self.clusteringVisComboBox.currentText()
- if visualization_request == "ROI":
- self.cluster_model.disable_data_clustering_features(self)
- if len(self.ROIs)>0:
- self.update_ROI_vis_comboBox()
- self.update_status_bar("")
- else:
- self.update_status_bar("Please add ROIs for display")
- elif visualization_request == "UMAP":
- self.cluster_model.enable_data_clustering_features(parent=self)
- self.update_status_bar("")
- else:
- self.cluster_model.disable_data_clustering_features(self)
-
- def clear_visualization_window(self):
- self.roiVisComboBox.hide()
- self.pROIimg.clear()
- self.pROI.removeItem(self.scatter)
- self.ClusteringPlot.clear()
- self.ClusteringPlot.hideAxis('left')
- self.ClusteringPlot.hideAxis('bottom')
- self.ClusteringPlot.removeItem(self.clustering_scatterplot)
- self.ClusteringPlot_legend.setParentItem(None)
- self.ClusteringPlot_legend.hide()
-
- def update_ROI_vis_comboBox(self):
- """
- Update ROI selection combo box
- """
- self.roiVisComboBox.clear()
- self.pROIimg.clear()
- self.roiVisComboBox.addItem("--Type--")
- for i in range(len(self.ROIs)):
- selected = self.ROIs[i]
- self.roiVisComboBox.addItem(str(selected.iROI+1)+". "+selected.rtype)
- if self.clusteringVisComboBox.currentText() == "ROI":
- self.roiVisComboBox.show()
-
- def display_ROI(self):
- """
- Plot selected ROI on visualizaiton window
- """
- self.roiVisComboBox.show()
- roi_request = self.roiVisComboBox.currentText()
- if roi_request != "--Type--":
- self.pROI.addItem(self.scatter)
- roi_request_ind = int(roi_request.split(".")[0]) - 1
- self.ROIs[int(roi_request_ind)].plot(self)
- #self.set_ROI_saturation_label(self.ROIs[int(roi_request_ind)].saturation)
- else:
- self.pROIimg.clear()
- self.pROI.removeItem(self.scatter)
-
def set_frame_changed(self, text):
self.cframe = int(float(self.setFrame.text()))
self.jump_to_frame()
-
+ if self.cluster_model.embedded_output is not None:
+ self.highlight_embed_point(self.cframe)
+
def reset(self):
if len(self.rROI)>0:
for r in self.rROI:
@@ -513,11 +467,13 @@ def reset(self):
self.cluster_model.disable_data_clustering_features(self)
self.clusteringVisComboBox.setCurrentIndex(0)
self.ClusteringPlot.clear()
- # Clear DLC variables when a new file is loaded
- #self.DLCplot.clear()
- self.DLC_scatterplot.clear()
+ self.ClusteringPlot_legend.clear()
+ # Clear keypoints when a new file is loaded
+ self.Pose_scatterplot.clear()
#self.p0.clear()
- self.DLC_file_loaded = False
+ self.poseFileLoaded = False
+ self.trace1_data_loaded = None
+ self.trace2_data_loaded = None
# clear checkboxes
for k in range(len(self.cbs1)):
self.cbs1[k].setText("")
@@ -527,6 +483,17 @@ def reset(self):
self.cbs2[k].setEnabled(False)
self.cbs1[k].setChecked(False)
self.cbs2[k].setChecked(False)
+ # Clear pose variables
+ self.pose_model = None
+        self.poseFilepath = []
+ self.keypoints_labels = []
+ self.pose_x_coord = []
+ self.pose_y_coord = []
+ self.pose_likelihood = []
+ self.keypoints_brushes = []
+ self.bbox = []
+ self.bbox_set = False
def pupil_sigma_change(self):
self.pupil_sigma = float(self.sigmaBox.text())
@@ -536,15 +503,19 @@ def pupil_sigma_change(self):
def add_reflectROI(self):
self.rROI[self.iROI].append(roi.reflectROI(iROI=self.iROI, wROI=len(self.rROI[self.iROI]), moveable=True, parent=self))
- def add_ROI(self):
- roitype = self.comboBox.currentIndex()
- roistr = self.comboBox.currentText()
- if roitype > 0:
+ def add_ROI(self, roitype=None, roistr=None, pos=None, ivid=None, xrange=None, yrange=None,
+ moveable=True, resizable=True):
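+        # when called from the "Add ROI" button no arguments are passed, so read the ROI type from the combo box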
+ if roitype is None and roistr is None:
+ roitype = self.comboBox.currentIndex()
+ roistr = self.comboBox.currentText()
+ if "pose" in roistr:
+ self.bbox, self.bbox_set, cancel = self.set_pose_bbox()
+ elif roitype > 0:
if self.online_mode and roitype>1:
- msg = QtGui.QMessageBox(self)
- msg.setIcon(QtGui.QMessageBox.Warning)
+ msg = QMessageBox(self)
+ msg.setIcon(QMessageBox.Warning)
msg.setText("only pupil ROI allowed during online mode")
- msg.setStandardButtons(QtGui.QMessageBox.Ok)
+ msg.setStandardButtons(QMessageBox.Ok)
msg.exec_()
return
self.saturation.append(255.)
@@ -553,98 +524,38 @@ def add_ROI(self):
for i in range(len(self.rROI[self.iROI])):
self.pROI.removeItem(self.rROI[self.iROI][i].ROI)
self.iROI = self.nROIs
- self.ROIs.append(roi.sROI(rind=roitype-1, rtype=roistr, iROI=self.nROIs, moveable=True, parent=self))
+ self.ROIs.append(roi.sROI(rind=roitype-1, rtype=roistr, iROI=self.nROIs, moveable=moveable,
+ resizable=resizable, pos=pos, parent=self, ivid=ivid, xrange=xrange,
+ yrange=yrange, saturation=255))
self.rROI.append([])
self.reflectors.append([])
self.nROIs += 1
self.update_ROI_vis_comboBox()
self.ROIs[-1].position(self)
else:
- msg = QtGui.QMessageBox(self)
- msg.setIcon(QtGui.QMessageBox.Warning)
- msg.setText("You have to choose an ROI type before creating ROI")
- msg.setStandardButtons(QtGui.QMessageBox.Ok)
+ msg = QMessageBox(self)
+ msg.setIcon(QMessageBox.Warning)
+ msg.setText("Please select an ROI type")
+ msg.setStandardButtons(QMessageBox.Ok)
msg.exec_()
return
- def update_status_bar(self, message, update_progress=False):
+ def update_status_bar(self, message, update_progress=False, hide_progress=False):
if update_progress:
self.progressBar.show()
progressBar_value = [int(s) for s in message.split("%")[0].split() if s.isdigit()]
- self.progressBar.setValue(progressBar_value[0])
- frames_processed = np.floor((progressBar_value[0]/100)*float(self.totalFrameNumber.text()))
- self.setFrame.setText(str(frames_processed))
- self.statusBar.showMessage(message.split("|")[0])
- else:
- self.progressBar.hide()
- self.statusBar.showMessage(message)
-
- def save_folder(self):
- folderName = QtGui.QFileDialog.getExistingDirectory(self,
- "Choose save folder")
- # load ops in same folder
- if folderName:
- self.save_path = folderName
- if len(folderName) > 30:
- self.savelabel.setText("..."+folderName[-30:])
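+            # update the progress bar only if an integer percentage was found in the message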
+ if len(progressBar_value)>0:
+ self.progressBar.setValue(progressBar_value[0])
+ total_frames = self.totalFrameNumber.text().split()[1]
+ frames_processed = np.floor((progressBar_value[0]/100)*float(total_frames))
+ self.setFrame.setText(str(frames_processed))
+ self.statusBar.showMessage(message.split("|")[0])
else:
- self.savelabel.setText(folderName)
-
- def get_DLC_file(self):
- filepath = QtGui.QFileDialog.getOpenFileName(self,
- "Choose DLC file", "", "DLC labels file (*.h5)")
- if filepath[0]:
- self.DLC_filepath = filepath[0]
- self.DLC_file_loaded = True
- self.update_status_bar("DLC file loaded: "+self.DLC_filepath)
- self.load_DLC_points()
-
- def load_DLC_points(self):
- # Read DLC file
- self.DLC_data = pd.read_hdf(self.DLC_filepath, 'df_with_missing')
- all_labels = self.DLC_data.columns.get_level_values("bodyparts")
- self.DLC_keypoints_labels = [all_labels[i] for i in sorted(np.unique(all_labels, return_index=True)[1])]#np.unique(self.DLC_data.columns.get_level_values("bodyparts"))
- self.DLC_x_coord = self.DLC_data.T[self.DLC_data.columns.get_level_values("coords").values=="x"].values #size: key points x frames
- self.DLC_y_coord = self.DLC_data.T[self.DLC_data.columns.get_level_values("coords").values=="y"].values #size: key points x frames
- self.DLC_likelihood = self.DLC_data.T[self.DLC_data.columns.get_level_values("coords").values=="likelihood"].values #size: key points x frames
- # Choose colors for each label: provide option for color blindness as well
- self.colors = cm.get_cmap('gist_rainbow')(np.linspace(0, 1., len(self.DLC_keypoints_labels)))
- self.colors *= 255
- self.colors = self.colors.astype(int)
- self.colors[:,-1] = 127
- self.brushes = np.array([pg.mkBrush(color=c) for c in self.colors])
-
- def update_DLC_points(self):
- if self.DLC_file_loaded and self.DLClabels_checkBox.isChecked():
- self.statusBar.clearMessage()
- self.p0.addItem(self.DLC_scatterplot)
- self.p0.setRange(xRange=(0,self.LX), yRange=(0,self.LY), padding=0.0)
- filtered_keypoints = np.where(self.DLC_likelihood[:,self.cframe] > 0.9)[0]
- x = self.DLC_x_coord[filtered_keypoints,self.cframe]
- y = self.DLC_y_coord[filtered_keypoints,self.cframe]
- self.DLC_scatterplot.setData(x, y, size=15, symbol='o', brush=self.brushes[filtered_keypoints], hoverable=True, hoverSize=15)
- elif not self.DLC_file_loaded and self.DLClabels_checkBox.isChecked():
- self.update_status_bar("Please upload a DLC (*.h5) file")
+ self.statusBar.showMessage("Done!")
else:
- self.statusBar.clearMessage()
- self.DLC_scatterplot.clear()
-
- def DLC_points_clicked(self, obj, points):
- ## Can add functionality for clicking key points
- return ""
-
- def DLC_points_hovered(self, obj, ev):
- point_hovered = np.where(self.DLC_scatterplot.data['hovered'])[0]
- if point_hovered.shape[0] >= 1: # Show tooltip only when hovering over a point i.e. no empty array
- points = self.DLC_scatterplot.points()
- vb = self.DLC_scatterplot.getViewBox()
- if vb is not None and self.DLC_scatterplot.opts['tip'] is not None:
- cutoff = 1 # Display info of only one point when hovering over multiple points
- tip = [self.DLC_scatterplot.opts['tip'](data = self.DLC_keypoints_labels[pt],x=points[pt].pos().x(), y=points[pt].pos().y())
- for pt in point_hovered[:cutoff]]
- if len(point_hovered) > cutoff:
- tip.append('({} other...)'.format(len(point_hovered) - cutoff))
- vb.setToolTip('\n\n'.join(tip))
+ if hide_progress:
+ self.progressBar.hide()
+ self.statusBar.showMessage(message)
def keyPressEvent(self, event):
bid = -1
@@ -666,46 +577,6 @@ def keyPressEvent(self, event):
else:
self.pause()
- def plot_clicked(self, event):
- items = self.win.scene().items(event.scenePos())
- posx = 0
- posy = 0
- iplot = 0
- zoom = False
- zoomImg = False
- choose = False
- if self.loaded:
- for x in items:
- if x==self.p1:
- vb = self.p1.vb
- pos = vb.mapSceneToView(event.scenePos())
- posx = pos.x()
- iplot = 1
- elif x==self.p2:
- vb = self.p1.vb
- pos = vb.mapSceneToView(event.scenePos())
- posx = pos.x()
- iplot = 2
- elif x==self.p0:
- if event.button()==1:
- if event.double():
- zoomImg=True
- if iplot==1 or iplot==2:
- if event.button()==1:
- if event.double():
- zoom=True
- else:
- choose=True
- if zoomImg:
- self.p0.setRange(xRange=(0,self.LX),yRange=(0,self.LY))
- if zoom:
- self.p1.setRange(xRange=(0,self.nframes))
- if choose:
- if self.playButton.isEnabled() and not self.online_mode:
- self.cframe = np.maximum(0, np.minimum(self.nframes-1, int(np.round(posx))))
- self.frameSlider.setValue(self.cframe)
- #self.jump_to_frame()
-
def go_to_frame(self):
self.cframe = int(self.frameSlider.value())
self.setFrame.setText(str(self.cframe))
@@ -717,62 +588,21 @@ def fitToWindow(self):
def update_frame_slider(self):
self.frameSlider.setMaximum(self.nframes-1)
self.frameSlider.setMinimum(0)
- self.frameLabel.setEnabled(True)
- self.totalFrameLabel.setEnabled(True)
self.frameSlider.setEnabled(True)
- def update_buttons(self):
- self.playButton.setEnabled(True)
- self.pauseButton.setEnabled(False)
- self.addROI.setEnabled(True)
- self.pauseButton.setChecked(True)
- self.process.setEnabled(True)
- self.savefolder.setEnabled(True)
- self.saverois.setEnabled(True)
- self.checkBox.setChecked(True)
- self.save_mat.setChecked(True)
-
- # Enable DLC features for single video only
- if len(self.img)==1:
- self.loadDLC.setEnabled(True)
- self.DLClabels_checkBox.setEnabled(True)
- else:
- self.loadDLC.setEnabled(False)
- self.DLClabels_checkBox.setEnabled(False)
-
def jump_to_frame(self):
if self.playButton.isEnabled():
self.cframe = np.maximum(0, np.minimum(self.nframes-1, self.cframe))
self.cframe = int(self.cframe)
self.cframe -= 1
- self.img = self.get_frame(self.cframe)
+ self.img = utils.get_frame(self.cframe, self.nframes, self.cumframes, self.video)
for i in range(len(self.img)):
self.imgs[i][:,:,:,1] = self.img[i].copy()
- img = self.get_frame(self.cframe+1)
+ img = utils.get_frame(self.cframe+1, self.nframes, self.cumframes, self.video)
for i in range(len(self.img)):
self.imgs[i][:,:,:,2] = img[i]
self.next_frame()
- def get_frame(self, cframe):
- cframe = np.maximum(0, np.minimum(self.nframes-1, cframe))
- cframe = int(cframe)
- try:
- ivid = (self.cumframes < cframe).nonzero()[0][-1]
- except:
- ivid = 0
- img = []
- for vs in self.video[ivid]:
- frame_ind = cframe - self.cumframes[ivid]
- capture = vs
- if int(capture.get(cv2.CAP_PROP_POS_FRAMES)) != frame_ind:
- capture.set(cv2.CAP_PROP_POS_FRAMES, frame_ind)
- ret, frame = capture.read()
- if ret:
- img.append(frame)
- else:
- print("Error reading frame")
- return img
-
def next_frame(self):
if not self.online_mode:
# loop after video finishes
@@ -781,14 +611,14 @@ def next_frame(self):
self.cframe = 0
for i in range(len(self.imgs)):
self.imgs[i][:,:,:,:2] = self.imgs[i][:,:,:,1:]
- im = self.get_frame(self.cframe+1)
+ im = utils.get_frame(self.cframe+1, self.nframes, self.cumframes, self.video)
for i in range(len(self.imgs)):
self.imgs[i][:,:,:,2] = im[i]
self.img[i] = self.imgs[i][:,:,:,1].copy()
self.fullimg[self.sy[i]:self.sy[i]+self.Ly[i],
self.sx[i]:self.sx[i]+self.Lx[i]] = self.img[i]
self.frameSlider.setValue(self.cframe)
- if self.processed:
+ if self.processed or self.trace1_data_loaded is not None or self.trace2_data_loaded is not None:
self.plot_scatter()
else:
self.online_plotted = False
@@ -800,9 +630,9 @@ def next_frame(self):
self.pimg.setImage(self.fullimg)
self.pimg.setLevels([0,self.sat[0]])
self.setFrame.setText(str(self.cframe))
- self.update_DLC_points()
+ self.update_pose()
#self.frameNumber.setText(str(self.cframe))
- self.totalFrameNumber.setText(str(self.nframes))
+ self.totalFrameNumber.setText("/ "+str(self.nframes)+" frames")
self.win.show()
self.show()
@@ -814,13 +644,13 @@ def start(self):
self.playButton.setEnabled(False)
self.pauseButton.setEnabled(True)
self.frameSlider.setEnabled(False)
- self.updateTimer.start(25)
+ self.updateTimer.start(50) #25
elif self.cframe < self.nframes - 1:
self.playButton.setEnabled(False)
self.pauseButton.setEnabled(True)
self.frameSlider.setEnabled(False)
- self.updateTimer.start(25)
- self.update_DLC_points()
+ self.updateTimer.start(50) #25
+ self.update_pose()
def pause(self):
self.updateTimer.stop()
@@ -829,7 +659,7 @@ def pause(self):
self.frameSlider.setEnabled(True)
if self.online_mode:
self.online_traces = None
- self.update_DLC_points()
+ self.update_pose()
def save_ops(self):
ops = {'sbin': self.sbin, 'pupil_sigma': float(self.sigmaBox.text()),
@@ -843,43 +673,41 @@ def save_ROIs(self):
self.sbin = int(self.binSpinBox.value())
# save running parameters as defaults
ops = self.save_ops()
-
if len(self.save_path) > 0:
savepath = self.save_path
else:
savepath = None
- print("ROIs saved in:", savepath)
if len(self.ROIs)>0:
rois = utils.roi_to_dict(self.ROIs, self.rROI)
else:
rois = None
proc = {'Ly':self.Ly, 'Lx':self.Lx, 'sy': self.sy, 'sx': self.sx, 'LY':self.LY, 'LX':self.LX,
'sbin': ops['sbin'], 'fullSVD': ops['fullSVD'], 'rois': rois,
+ 'motSVD': self.motSVD_checkbox.isChecked(), 'movSVD': self.movSVD_checkbox.isChecked(),
+ 'bbox': self.bbox, 'bbox_set': self.bbox_set,
'save_mat': ops['save_mat'], 'save_path': ops['save_path'],
'filenames': self.filenames}
savename = process.save(proc, savepath=savepath)
- self.update_status_bar("File saved in "+savepath) ####
+ self.update_status_bar("ROIs saved in "+savepath)
self.batchlist.append(savename)
- basename,filename = os.path.split(savename)
- filename, ext = os.path.splitext(filename)
- #self.batchname[len(self.batchlist)-1].setText(filename)
+ _,filename = os.path.split(savename)
+ filename, _ = os.path.splitext(filename)
+ self.batchname[len(self.batchlist)-1].setText(filename)
self.processbatch.setEnabled(True)
def process_batch(self):
- if self.motSVD_checkbox.isChecked() or self.movSVD_checkbox.isChecked():
- files = self.batchlist
- for f in files:
- proc = np.load(f, allow_pickle=True).item()
- savename = process.run(proc['filenames'], GUIobject=QtGui, parent=self, proc=proc, savepath=proc['save_path'])
- if len(files)==1:
- io.open_proc(self, file_name=savename)
- else:
- msg = QtGui.QMessageBox(self)
- msg.setIcon(QtGui.QMessageBox.Warning)
- msg.setText("Please check at least one of: motSVD, movSVD")
- msg.setStandardButtons(QtGui.QMessageBox.Ok)
- msg.exec_()
- return
+ files = self.batchlist
+ for f in files:
+ proc = np.load(f, allow_pickle=True).item()
+ if proc['motSVD'] or proc['movSVD']:
+ savename = process.run(proc['filenames'], motSVD=proc['motSVD'], movSVD=proc['movSVD'],
+ GUIobject=QtGui, proc=proc, savepath=proc['save_path'])
+ self.update_status_bar("Processed "+savename)
+
+ pose.Pose(gui=None, filenames=proc['filenames'],
+ bbox=proc['bbox'], bbox_set=proc['bbox_set']).run(plot=False)
+ if len(files)==1 and (proc['motSVD'] or proc['movSVD']):
+ io.open_proc(self, file_name=savename)
def process_ROIs(self):
self.sbin = int(self.binSpinBox.value())
@@ -894,13 +722,364 @@ def process_ROIs(self):
io.open_proc(self, file_name=savename)
print("Output saved in",savepath)
self.update_status_bar("Output saved in "+savepath)
- else:
- msg = QtGui.QMessageBox(self)
- msg.setIcon(QtGui.QMessageBox.Warning)
- msg.setText("Please check at least one of: motSVD, movSVD")
- msg.setStandardButtons(QtGui.QMessageBox.Ok)
- msg.exec_()
- return
+ if self.Labels_checkBox.isChecked():
+ self.get_pose_labels()
+ if self.pose_model is not None:
+ self.pose_model.run()
+ self.update_status_bar("Pose labels saved in "+savepath)
+
+ def update_buttons(self):
+ self.playButton.setEnabled(True)
+ self.pauseButton.setEnabled(False)
+ self.addROI.setEnabled(True)
+ self.pauseButton.setChecked(True)
+ self.process.setEnabled(True)
+ self.saverois.setEnabled(True)
+ self.checkBox.setChecked(True)
+ self.save_mat.setChecked(True)
+ self.load_trace1_button.setEnabled(True)
+ self.load_trace2_button.setEnabled(True)
+
+ # Enable pose features for single video only
+ self.Labels_checkBox.setEnabled(True)
+
+ def button_status(self, status):
+ self.playButton.setEnabled(status)
+ self.pauseButton.setEnabled(status)
+ self.frameSlider.setEnabled(status)
+ self.process.setEnabled(status)
+ self.saverois.setEnabled(status)
+
+ ### ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Clustering and ROI ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ def vis_combobox_selection_changed(self):
+ """
+ Call clustering or ROI display functions upon user selection from combo box
+ """
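+        # Combo box indices, per the branches below: 1 = ROI, 2 = tSNE, 3 = UMAP;
+        # any other selection disables the clustering display.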
+ self.clear_visualization_window()
+ visualization_request = int(self.clusteringVisComboBox.currentIndex())
+ self.reflector.show()
+ if visualization_request == 1: # ROI
+ self.cluster_model.disable_data_clustering_features(self)
+ if len(self.ROIs)>0:
+ self.update_ROI_vis_comboBox()
+ self.update_status_bar("")
+ else:
+ self.update_status_bar("Please add ROIs for display")
+ elif visualization_request == 2 or visualization_request == 3: # tSNE/UMAP
+ self.reflector.hide()
+ self.cluster_model.enable_data_clustering_features(parent=self)
+ self.update_status_bar("")
+ else:
+ self.cluster_model.disable_data_clustering_features(self)
+
+ def clear_visualization_window(self):
+ self.roiVisComboBox.hide()
+ self.pROIimg.clear()
+ self.pROI.removeItem(self.scatter)
+ self.ClusteringPlot.clear()
+ self.ClusteringPlot.hideAxis('left')
+ self.ClusteringPlot.hideAxis('bottom')
+ self.ClusteringPlot.removeItem(self.clustering_scatterplot)
+ self.ClusteringPlot_legend.setParentItem(None)
+ self.ClusteringPlot_legend.hide()
+
+ def cluster_plot_zoom_buttons(self, in_or_out):
+ """
+ see ViewBox.scaleBy()
+ pyqtgraph wheel zoom is s = ~0.75
+ """
+ s = 0.9
+ zoom = (s, s) if in_or_out == "in" else (1/s, 1/s)
+ self.ClusteringPlot.vb.scaleBy(zoom)
+
+ def update_ROI_vis_comboBox(self):
+ """
+ Update ROI selection combo box
+ """
+ self.roiVisComboBox.clear()
+ self.pROIimg.clear()
+ self.roiVisComboBox.addItem("--Type--")
+ for i in range(len(self.ROIs)):
+ selected = self.ROIs[i]
+ self.roiVisComboBox.addItem(str(selected.iROI+1)+". "+selected.rtype)
+ if self.clusteringVisComboBox.currentText() == "ROI":
+ self.roiVisComboBox.show()
+
+ def display_ROI(self):
+ """
+        Plot selected ROI on visualization window
+ """
+ self.roiVisComboBox.show()
+ roi_request = self.roiVisComboBox.currentText()
+ if roi_request != "--Type--":
+ self.pROI.addItem(self.scatter)
+ roi_request_ind = int(roi_request.split(".")[0]) - 1
+ self.ROIs[int(roi_request_ind)].plot(self)
+ #self.set_ROI_saturation_label(self.ROIs[int(roi_request_ind)].saturation)
+ else:
+ self.pROIimg.clear()
+ self.pROI.removeItem(self.scatter)
+
+ def highlight_embed_point(self, playback_point):
+ x = [np.array(self.clustering_scatterplot.points()[playback_point].pos().x())]
+ y = [np.array(self.clustering_scatterplot.points()[playback_point].pos().y())]
+ self.clustering_highlight_scatterplot.setData(x=x, y=y,
+ symbol='x', brush='r',pxMode=True, hoverable=True, hoverSize=20,
+ hoverSymbol="x", hoverBrush='r',pen=(0,0,0,0),
+ data=playback_point, size=15)
+ """
+ old = self.clustering_scatterplot.data['hovered']
+ self.clustering_scatterplot.data['sourceRect'][old] = 1
+ bool_mask = np.full((len(self.clustering_scatterplot.data)), False, dtype=bool)
+ self.clustering_scatterplot.data['hovered'] = bool_mask
+ self.clustering_scatterplot.invalidate()
+ self.clustering_scatterplot.updateSpots()
+ self.clustering_scatterplot.sigPlotChanged.emit(self.clustering_scatterplot)
+
+ bool_mask[playback_point] = True
+ self.clustering_scatterplot.data['hovered'] = bool_mask
+ self.clustering_scatterplot.data['sourceRect'][bool_mask] = 0
+ self.clustering_scatterplot.updateSpots()
+ #points = self.clustering_scatterplot.points()
+ #self.clustering_scatterplot.sigClicked.emit([points[playback_point]], None, self)
+ """
+
+ ### ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pose functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ def set_pose_bbox(self):
+ # User defined bbox selection
+ self.pose_gui = pose_gui.PoseGUI(gui=self)
+ self.bbox, self.bbox_set, cancel = self.pose_gui.draw_user_bbox()
+ return self.bbox, self.bbox_set, cancel
+
+ def get_pose_labels(self):
+ if not self.bbox_set:
+ self.bbox, self.bbox_set, _ = self.set_pose_bbox()
+ if self.pose_model is None:
+ self.pose_model = pose.Pose(gui=self, GUIobject=QtGui, filenames=self.filenames,
+ bbox=self.bbox, bbox_set=self.bbox_set)
+
+ def load_labels(self):
+ # Read Pose file
+ for video_id in range(len(self.poseFilepath)):
+ pose_data = pd.read_hdf(self.poseFilepath[video_id], 'df_with_missing')
+ # Append pose data to list for each video_id
+ self.keypoints_labels.append(pd.unique(pose_data.columns.get_level_values("bodyparts")))
+ self.pose_x_coord.append(pose_data.T[pose_data.columns.get_level_values("coords").values=="x"].values) #size: key points x frames
+ self.pose_y_coord.append(pose_data.T[pose_data.columns.get_level_values("coords").values=="y"].values) #size: key points x frames
+ self.pose_likelihood.append(pose_data.T[pose_data.columns.get_level_values("coords").values=="likelihood"].values) #size: key points x frames
+            # Choose colors for each label; TODO: provide an option for a color-blindness-friendly palette
+ colors = cm.get_cmap('jet')(np.linspace(0, 1., len(self.keypoints_labels[video_id])))
+ colors *= 255
+ colors = colors.astype(int)
+ self.keypoints_brushes.append(np.array([pg.mkBrush(color=c) for c in colors]))
+
+ def update_pose(self):
+ if self.poseFileLoaded and self.Labels_checkBox.isChecked():
+ self.statusBar.clearMessage()
+ self.p0.addItem(self.Pose_scatterplot)
+ self.p0.setRange(xRange=(0,self.LX), yRange=(0,self.LY), padding=0.0)
+ threshold = np.nanpercentile(self.pose_likelihood, 10) # Determine threshold
+ x, y, labels, brushes = np.array([]), np.array([]), np.array([]), np.array([])
+ for video_id in range(len(self.poseFilepath)):
+ filtered_keypoints = np.where(self.pose_likelihood[video_id][:,self.cframe] > threshold)[0]
+ x_coord = self.pose_x_coord[video_id] + self.sx[video_id] # shift x coordinates
+ x = np.append(x, x_coord[filtered_keypoints,self.cframe])
+ y_coord = self.pose_y_coord[video_id] + self.sy[video_id] # shift y coordinates
+ y = np.append(y, y_coord[filtered_keypoints,self.cframe])
+ labels = np.append(labels, self.keypoints_labels[video_id][filtered_keypoints])
+ brushes = np.append(brushes, self.keypoints_brushes[video_id][filtered_keypoints])
+ self.Pose_scatterplot.setData(x, y, size=12, symbol='o', brush=brushes, hoverable=True, hoverSize=10,
+ data=labels)
+ elif not self.poseFileLoaded and self.Labels_checkBox.isChecked():
+ self.update_status_bar("Please upload a pose (*.h5) file")
+ else:
+ self.statusBar.clearMessage()
+ self.Pose_scatterplot.clear()
+
+ def keypoints_clicked(self, obj, points):
+ ## Can add functionality for clicking key points
+ return ""
+
+ def keypoints_hovered(self, obj, ev):
+ point_hovered = np.where(self.Pose_scatterplot.data['hovered'])[0]
+ if point_hovered.shape[0] >= 1: # Show tooltip only when hovering over a point i.e. no empty array
+ points = self.Pose_scatterplot.points()
+ vb = self.Pose_scatterplot.getViewBox()
+ if vb is not None and self.Pose_scatterplot.opts['tip'] is not None:
+ cutoff = 1 # Display info of only one point when hovering over multiple points
+ tip = [self.Pose_scatterplot.opts['tip'](data=points[pt].data(), x=points[pt].pos().x(), y=points[pt].pos().y())
+ for pt in point_hovered[:cutoff]]
+ if len(point_hovered) > cutoff:
+ tip.append('({} other...)'.format(len(point_hovered) - cutoff))
+ vb.setToolTip('\n\n'.join(tip))
+
+ ### ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Plot 1 and 2 functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ def load_trace_button_clicked(self, plot_id):
+ try:
+ data = io.load_trace_data(parent=self)
+ if data.ndim == 1:
+ # Open a QDialog box containing two radio buttons horizontally centered
+ # and a QLineEdit to enter the name of the trace
+ # If the user presses OK, the trace is added to the list of traces
+ # and the combo box is updated
+ # If the user presses Cancel, the trace is not added
+ dialog = QtWidgets.QDialog()
+ dialog.setWindowTitle("Set data type")
+ dialog.setFixedWidth(400)
+ dialog.verticalLayout = QtWidgets.QVBoxLayout(dialog)
+ dialog.verticalLayout.setContentsMargins(10, 10, 10, 10)
+
+ dialog.horizontalLayout = QtWidgets.QHBoxLayout()
+ dialog.verticalLayout.addLayout(dialog.horizontalLayout)
+ dialog.label = QtWidgets.QLabel("Data type:")
+ dialog.horizontalLayout.addWidget(dialog.label)
+
+ # Create radio buttons
+ dialog.radio_button_group = QtWidgets.QButtonGroup()
+ dialog.radio_button_group.setExclusive(True)
+ dialog.radioButton1 = QtWidgets.QRadioButton("Continuous")
+ dialog.radioButton1.setChecked(True)
+ dialog.horizontalLayout.addWidget(dialog.radioButton1)
+ dialog.radioButton2 = QtWidgets.QRadioButton("Discrete")
+ dialog.radioButton2.setChecked(False)
+ dialog.horizontalLayout.addWidget(dialog.radioButton2)
+ # Add radio buttons to radio buttons group
+ dialog.radio_button_group.addButton(dialog.radioButton1)
+ dialog.radio_button_group.addButton(dialog.radioButton2)
+
+ dialog.horizontalLayout2 = QtWidgets.QHBoxLayout()
+ dialog.label = QtWidgets.QLabel("Data name:")
+ dialog.horizontalLayout2.addWidget(dialog.label)
+ dialog.lineEdit = QtWidgets.QLineEdit()
+ dialog.lineEdit.setText("Trace 1")
+ # Adjust size of line edit
+ dialog.lineEdit.setFixedWidth(200)
+ dialog.horizontalLayout2.addWidget(dialog.lineEdit)
+ dialog.verticalLayout.addLayout(dialog.horizontalLayout2)
+ dialog.horizontalLayout3 = QtWidgets.QHBoxLayout()
+ dialog.buttonBox = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel)
+ dialog.buttonBox.accepted.connect(dialog.accept)
+ dialog.buttonBox.rejected.connect(dialog.reject)
+ dialog.horizontalLayout3.addWidget(dialog.buttonBox)
+ dialog.verticalLayout.addLayout(dialog.horizontalLayout3)
+ if dialog.exec_():
+ data_name = dialog.lineEdit.text()
+ if data_name == "":
+ data_name = "trace"
+ data_type = "continuous"
+ if dialog.radioButton2.isChecked():
+ data_type = "discrete"
+                    # Create a color palette with one color per unique value in the data
+                    # (tab10/tab20 for up to 10/20 classes, gist_rainbow otherwise);
+                    # the palette is used to color the points in the plot
+ if len(np.unique(data))<=10:
+ color_palette = np.array(plt.get_cmap('tab10').colors)
+ elif len(np.unique(data))<=20:
+ color_palette = np.array(plt.get_cmap('tab20').colors)
+ else:
+ num_classes = len(np.unique(data))
+ color_palette = cm.get_cmap('gist_rainbow')(np.linspace(0, 1., num_classes))
+ color_palette *= 255
+ color_palette = color_palette.astype(int)
+ #color_palette = color_palette[:len(np.unique(data))]
+ # Create a list of pens for each unique value in data
+ # The pen is used to color the points in the scatter plot
+ pen_list = np.empty(len(data), dtype=object)
+ for j, value in enumerate(np.unique(data)):
+ ind = np.where(data==value)[0]
+ pen_list[ind] = pg.mkPen(color_palette[j])
+ vtick = QPainterPath()
+ vtick.moveTo(0, -1)
+ vtick.lineTo(0, 1)
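+                    # vtick is a short vertical tick; drawing one tick per sample,
+                    # colored by its class pen, renders the discrete trace as a raster-style row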
+
+ if plot_id == 1:
+ self.trace1_data_loaded = data
+ self.trace1_data_type = data_type
+ self.trace1_name = data_name
+ if data_type == "discrete":
+ x = np.arange(len(data))
+ y = np.ones((len(x)))
+ self.trace1_plot = pg.ScatterPlotItem()
+ self.trace1_plot.setData(x, y, pen=pen_list, brush='g',pxMode=False,
+ symbol=vtick, size=1, symbol_pen=pen_list)
+ else:
+ self.trace1_plot = pg.PlotDataItem()
+ self.trace1_plot.setData(data, pen=pg.mkPen("g", width=1))
+ self.trace1_legend.clear()
+ self.trace1_legend.addItem(self.trace1_plot, name=data_name)
+ self.trace1_legend.setPos(self.trace1_plot.x(), self.trace1_plot.y())
+ self.trace1_legend.setParentItem(self.p1)
+ self.trace1_legend.setVisible(True)
+ self.trace1_plot.setVisible(True)
+ self.update_status_bar("Trace 1 data updated")
+ try:
+ self.trace1_legend.sigClicked.connect(self.mouseClickEvent)
+ except Exception as e:
+ pass
+ elif plot_id == 2:
+ self.trace2_data_loaded = data
+ self.trace2_data_type = data_type
+ self.trace2_name = data_name
+ if data_type == "discrete":
+ x = np.arange(len(data))
+ y = np.ones((len(x)))
+ self.trace2_plot = pg.ScatterPlotItem()
+ self.trace2_plot.setData(x, y, pen=pen_list, brush='g',pxMode=False,
+ symbol=vtick, size=1, symbol_pen=pen_list)
+ else:
+ self.trace2_plot = pg.PlotDataItem()
+ self.trace2_plot.setData(data, pen=pg.mkPen("g", width=1))
+ self.trace2_legend.clear()
+ self.trace2_legend.addItem(self.trace2_plot, name=data_name)
+ self.trace2_legend.setPos(self.trace2_plot.x(), self.trace2_plot.y())
+ self.trace2_legend.setParentItem(self.p2)
+ self.trace2_legend.setVisible(True)
+ self.trace2_plot.setVisible(True)
+ self.update_status_bar("Trace 2 data updated")
+ try:
+ self.trace2_legend.sigClicked.connect(self.mouseClickEvent)
+ except Exception as e:
+ pass
+ else:
+ self.update_status_bar("Error: plot ID not recognized")
+ pass
+ self.plot_processed()
+ except Exception as e:
+ print(e)
+ self.update_status_bar("Error: data not recognized")
+
+ # Plot trace on p1 showing cluster labels as discrete data
+ def plot_cluster_labels_p1(self, labels, color_palette):
+ x = np.arange(len(labels))
+ y = np.ones((len(x)))
+ self.trace1_data_loaded = y
+ self.trace1_data_type = "discrete"
+ self.trace1_name = "Cluster Labels"
+ # Create a list of pens for each unique value in data
+ # The pen is used to color the points in the scatter plot
+ pen_list = np.empty(len(labels), dtype=object)
+ for j, value in enumerate(np.unique(labels)):
+ ind = np.where(labels==value)[0]
+ pen_list[ind] = pg.mkPen(color_palette[j])
+ vtick = QPainterPath()
+ vtick.moveTo(0, -1)
+ vtick.lineTo(0, 1)
+ # Plot trace 1 data points
+ self.trace1_plot = pg.ScatterPlotItem()
+ self.trace1_plot.setData(x, y, pen=pen_list, brush='g',pxMode=False,
+ symbol=vtick, size=1, symbol_pen=pen_list)
+ self.trace1_legend.clear()
+ self.trace1_legend.addItem(self.trace1_plot, name=self.trace1_name)
+ self.trace1_legend.setPos(self.trace1_plot.x(), self.trace1_plot.y())
+ self.trace1_legend.setParentItem(self.p1)
+ self.trace1_legend.setVisible(True)
+ self.trace1_plot.setVisible(True)
+ self.update_status_bar("Trace 1 data updated")
+ try:
+ self.trace1_legend.sigClicked.connect(self.mouseClickEvent)
+ except Exception as e:
+ pass
+ self.plot_processed()
def plot_processed(self):
self.p1.clear()
@@ -932,6 +1111,12 @@ def plot_processed(self):
else:
self.cbs2[k].setText(self.lbls[k].text())
self.cbs2[k].setStyleSheet("color: gray")
+ if self.trace1_data_loaded is not None:
+ self.p1.addItem(self.trace1_plot)
+ self.traces1 = np.concatenate((self.traces1, self.trace1_data_loaded[np.newaxis,:]), axis=0)
+ if self.trace2_data_loaded is not None:
+ self.p2.addItem(self.trace2_plot)
+ self.traces2 = np.concatenate((self.traces2, self.trace2_data_loaded[np.newaxis,:]), axis=0)
self.p1.setRange(xRange=(0,self.nframes),
yRange=(-4, 4),
padding=0.0)
@@ -951,7 +1136,7 @@ def plot_scatter(self):
self.p1.removeItem(self.scatter1)
self.scatter1.setData(self.cframe*np.ones((ntr,)),
self.traces1[:, self.cframe],
- size=10, brush=pg.mkBrush(255,255,255))
+ size=8, brush=pg.mkBrush(255,255,255))
self.p1.addItem(self.scatter1)
if self.traces2.shape[0] > 0:
@@ -959,7 +1144,7 @@ def plot_scatter(self):
self.p2.removeItem(self.scatter2)
self.scatter2.setData(self.cframe*np.ones((ntr,)),
self.traces2[:, self.cframe],
- size=10, brush=pg.mkBrush(255,255,255))
+ size=8, brush=pg.mkBrush(255,255,255))
self.p2.addItem(self.scatter2)
def plot_trace(self, wplot, proctype, wroi, color):
@@ -1023,20 +1208,55 @@ def plot_trace(self, wplot, proctype, wroi, color):
tr = running.T
return tr
- def button_status(self, status):
- self.playButton.setEnabled(status)
- self.pauseButton.setEnabled(status)
- self.frameSlider.setEnabled(status)
- self.process.setEnabled(status)
- self.saverois.setEnabled(status)
+ def plot_clicked(self, event):
+ items = self.win.scene().items(event.scenePos())
+ posx = 0
+ posy = 0
+ iplot = 0
+ zoom = False
+ zoomImg = False
+ choose = False
+ if self.loaded:
+ for x in items:
+ if x==self.p1:
+ vb = self.p1.vb
+ pos = vb.mapSceneToView(event.scenePos())
+ posx = pos.x()
+ iplot = 1
+ elif x==self.p2:
+                    vb = self.p2.vb
+ pos = vb.mapSceneToView(event.scenePos())
+ posx = pos.x()
+ iplot = 2
+ elif x==self.p0:
+ if event.button()==1:
+ if event.double():
+ zoomImg=True
+ if iplot==1 or iplot==2:
+ if event.button()==1:
+ if event.double():
+ zoom=True
+ else:
+ choose=True
+ if zoomImg:
+ self.p0.setRange(xRange=(0,self.LX),yRange=(0,self.LY))
+ if zoom:
+ self.p1.setRange(xRange=(0,self.nframes))
+ if choose:
+ if self.playButton.isEnabled() and not self.online_mode:
+ self.cframe = np.maximum(0, np.minimum(self.nframes-1, int(np.round(posx))))
+ self.frameSlider.setValue(self.cframe)
+ #self.jump_to_frame()
+
+### ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Main ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-def run(moviefile=None,savedir=None):
+def run(moviefile=None, savedir=None):
# Always start by initializing Qt (only once per application)
- app = QtGui.QApplication(sys.argv)
+ app = QtWidgets.QApplication(sys.argv)
icon_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "mouse.png"
)
- app_icon = QtGui.QIcon()
+ app_icon = QIcon()
app_icon.addFile(icon_path, QtCore.QSize(16, 16))
app_icon.addFile(icon_path, QtCore.QSize(24, 24))
app_icon.addFile(icon_path, QtCore.QSize(32, 32))
@@ -1044,11 +1264,6 @@ def run(moviefile=None,savedir=None):
app_icon.addFile(icon_path, QtCore.QSize(96, 96))
app_icon.addFile(icon_path, QtCore.QSize(256, 256))
app.setWindowIcon(app_icon)
- GUI = MainW(moviefile,savedir)
- #p = GUI.palette()
+ GUI = MainW(moviefile, savedir)
ret = app.exec_()
- # GUI.save_gui_data()
sys.exit(ret)
-
-
-# run()
diff --git a/facemap/guiparts.py b/facemap/gui/guiparts.py
old mode 100644
new mode 100755
similarity index 94%
rename from facemap/guiparts.py
rename to facemap/gui/guiparts.py
index 14c9d1f..7599552
--- a/facemap/guiparts.py
+++ b/facemap/gui/guiparts.py
@@ -3,25 +3,29 @@
from pyqtgraph import functions as fn
from pyqtgraph import Point
import numpy as np
+from PyQt5.QtWidgets import (
+ QListWidget, QDialog, QPushButton, QWidget, QGridLayout, QRadioButton, QLabel, QLineEdit,
+    QAbstractItemView, QSlider, QButtonGroup, QStyle, QStyleOptionSlider
+)
### custom QDialog which makes a list of items you can include/exclude
-class ListChooser(QtGui.QDialog):
+class ListChooser(QDialog):
def __init__(self, title, parent):
super(ListChooser, self).__init__(parent)
self.setGeometry(300,300,320,320)
self.setWindowTitle(title)
- self.win = QtGui.QWidget(self)
- layout = QtGui.QGridLayout()
+ self.win = QWidget(self)
+ layout = QGridLayout()
self.win.setLayout(layout)
#self.setCentralWidget(self.win)
- layout.addWidget(QtGui.QLabel('click to select videos (none selected => all used)'),0,0,1,1)
- self.list = QtGui.QListWidget(parent)
+ layout.addWidget(QLabel('click to select videos (none selected => all used)'),0,0,1,1)
+ self.list = QListWidget(parent)
for f in parent.filelist:
self.list.addItem(f)
layout.addWidget(self.list,1,0,7,4)
#self.list.resize(450,250)
- self.list.setSelectionMode(QtGui.QAbstractItemView.MultiSelection)
- done = QtGui.QPushButton('done')
+ self.list.setSelectionMode(QAbstractItemView.MultiSelection)
+ done = QPushButton('done')
done.clicked.connect(lambda: self.exit_list(parent))
layout.addWidget(done,8,0,1,1)
@@ -32,7 +36,7 @@ def exit_list(self, parent):
parent.filelist.append(str(self.list.selectedItems()[i].text()))
self.accept()
-class Slider(QtGui.QSlider):
+class Slider(QSlider):
def __init__(self, bid, parent=None):
super(self.__class__, self).__init__()
initval = [99,99]
@@ -57,18 +61,18 @@ def level_change(self, parent, bid):
parent.win.show()
-class TextChooser(QtGui.QDialog):
+class TextChooser(QDialog):
def __init__(self,parent=None):
super(TextChooser, self).__init__(parent)
self.setGeometry(300,300,350,100)
self.setWindowTitle('folder path')
- self.win = QtGui.QWidget(self)
- layout = QtGui.QGridLayout()
+ self.win = QWidget(self)
+ layout = QGridLayout()
self.win.setLayout(layout)
- self.qedit = QtGui.QLineEdit('')
- layout.addWidget(QtGui.QLabel('folder name (does not have to exist yet)'),0,0,1,3)
+ self.qedit = QLineEdit('')
+ layout.addWidget(QLabel('folder name (does not have to exist yet)'),0,0,1,3)
layout.addWidget(self.qedit,1,0,1,3)
- done = QtGui.QPushButton('OK')
+ done = QPushButton('OK')
done.clicked.connect(self.exit)
layout.addWidget(done,2,1,1,1)
@@ -76,16 +80,15 @@ def exit(self):
self.folder = self.qedit.text()
self.accept()
-class RGBRadioButtons(QtGui.QButtonGroup):
+class RGBRadioButtons(QButtonGroup):
def __init__(self, parent=None, row=0, col=0):
super(RGBRadioButtons, self).__init__()
parent.color = 0
self.parent = parent
self.bstr = ["image", "flowsX", "flowsY", "flowsZ", "cellprob"]
- #self.buttons = QtGui.QButtonGroup()
self.dropdown = []
for b in range(len(self.bstr)):
- button = QtGui.QRadioButton(self.bstr[b])
+ button = QRadioButton(self.bstr[b])
button.setStyleSheet('color: white;')
if b==0:
button.setChecked(True)
@@ -318,7 +321,7 @@ def setDrawKernel(self, kernel_size=3):
self.greenmask = np.concatenate((onmask,offmask,onmask,opamask), axis=-1)
-class RangeSlider(QtGui.QSlider):
+class RangeSlider(QSlider):
""" A slider for ranges.
This class provides a dual-slider for ranges, where there is a defined
@@ -337,12 +340,12 @@ def __init__(self, parent=None, *args):
self._low = self.minimum()
self._high = self.maximum()
- self.pressed_control = QtGui.QStyle.SC_None
- self.hover_control = QtGui.QStyle.SC_None
+ self.pressed_control = QStyle.SC_None
+ self.hover_control = QStyle.SC_None
self.click_offset = 0
self.setOrientation(QtCore.Qt.Vertical)
- self.setTickPosition(QtGui.QSlider.TicksRight)
+ self.setTickPosition(QSlider.TicksRight)
self.setStyleSheet(\
"QSlider::handle:horizontal {\
background-color: white;\
@@ -389,7 +392,7 @@ def paintEvent(self, event):
style = QtGui.QApplication.style()
for i, value in enumerate([self._low, self._high]):
- opt = QtGui.QStyleOptionSlider()
+ opt = QStyleOptionSlider()
self.initStyleOption(opt)
# Only draw the groove for the first slider so it doesn't get drawn
diff --git a/facemap/io.py b/facemap/gui/io.py
old mode 100644
new mode 100755
similarity index 82%
rename from facemap/io.py
rename to facemap/gui/io.py
index 667cc0b..5e8a253
--- a/facemap/io.py
+++ b/facemap/gui/io.py
@@ -2,14 +2,16 @@
import numpy as np
from PyQt5 import QtGui, QtCore
import pyqtgraph as pg
-from . import guiparts, roi, utils
+from . import guiparts
+from .. import roi, utils
from natsort import natsorted
import pickle
+from PyQt5.QtWidgets import (QFileDialog, QMessageBox)
def open_file(parent, file_name=None):
if file_name is None:
- file_name = QtGui.QFileDialog.getOpenFileName(parent,
- "Open movie file", "", "Movie files (*.h5 *.mj2 *.mp4 *.mkv *.avi *.mpeg *.mpg *.asf)")
+ file_name = QFileDialog.getOpenFileName(parent,
+        "Open movie file", "", "Movie files (*.h5 *.mj2 *.mp4 *.mkv *.avi *.mpeg *.mpg *.asf *.m4v)")
# load ops in same folder
if file_name:
parent.filelist = [[file_name[0]]]
@@ -17,7 +19,7 @@ def open_file(parent, file_name=None):
def open_folder(parent, folder_name=None):
if folder_name is None:
- folder_name = QtGui.QFileDialog.getExistingDirectory(parent,
+ folder_name = QFileDialog.getExistingDirectory(parent,
"Choose folder with movies")
# load ops in same folder
if folder_name:
@@ -44,13 +46,13 @@ def choose_files(parent, file_name):
parent.filelist=file_name
parent.filelist = natsorted(parent.filelist)
if len(parent.filelist)>1:
- dm = QtGui.QMessageBox.question(
+ dm = QMessageBox.question(
parent,
"multiple videos found",
"are you processing multiple videos taken simultaneously?",
- QtGui.QMessageBox.Yes | QtGui.QMessageBox.No,
+ QMessageBox.Yes | QMessageBox.No,
)
- if dm == QtGui.QMessageBox.Yes:
+ if dm == QMessageBox.Yes:
print('multi camera view')
# expects first 4 letters to be different e.g. cam0, cam1, ...
files = []
@@ -80,11 +82,10 @@ def choose_files(parent, file_name):
else:
parent.filelist = [parent.filelist]
parent.filelist = natsorted(parent.filelist)
- print(parent.filelist)
def open_proc(parent, file_name=None):
if file_name is None:
- file_name = QtGui.QFileDialog.getOpenFileName(parent,
+ file_name = QFileDialog.getOpenFileName(parent,
"Open processed file", filter="*.npy")
file_name = file_name[0]
try:
@@ -96,8 +97,6 @@ def open_proc(parent, file_name=None):
print("ERROR: not a processed movie file")
if good:
v = []
- nframes = 0
- #iframes = []
good = load_movies(parent, filelist=parent.filenames)
if good:
if 'fullSVD' in proc:
@@ -195,7 +194,7 @@ def load_movies(parent, filelist=None):
if filelist is not None:
parent.filelist = filelist
try:
- cumframes, Ly, Lx, v = utils.get_frame_details(parent.filelist) # v is containers/videos
+ cumframes, Ly, Lx, containers = utils.get_frame_details(parent.filelist)
nframes = cumframes[-1]
good = True
except Exception as e:
@@ -204,7 +203,7 @@ def load_movies(parent, filelist=None):
good = False
if good:
parent.reset()
- parent.video = v
+ parent.video = containers
parent.filenames = parent.filelist
parent.nframes = nframes
parent.cumframes = np.array(cumframes).astype(int)
@@ -251,9 +250,34 @@ def load_movies(parent, filelist=None):
parent.jump_to_frame()
return good
+def save_folder(parent):
+ folderName = QFileDialog.getExistingDirectory(parent,
+ "Choose save folder")
+ # load ops in same folder
+ if folderName:
+ parent.save_path = folderName
+ if len(folderName) > 30:
+ parent.savelabel.setText("..."+folderName[-30:])
+ else:
+ parent.savelabel.setText(folderName)
+
+def get_pose_file(parent):
+ # Open a folder and allow selection of multiple files with extension *.h5 only
+ # Returns a list of files
+ filelist = []
+ filelist = QFileDialog.getOpenFileNames(parent, 'Open Pose File', parent.save_path, '*.h5')
+    if not filelist[0]:
+ return
+ else:
+ parent.poseFilepath = natsorted(filelist[0])
+ parent.poseFileLoaded = True
+ parent.load_labels()
+ parent.Labels_checkBox.setChecked(True)
+ parent.update_status_bar("Pose file(s) loaded")
+
def load_cluster_labels(parent):
try:
- file_name = QtGui.QFileDialog.getOpenFileName(parent,
+ file_name = QFileDialog.getOpenFileName(parent,
"Select cluster labels file", "", "Cluster label files (*.npy *.pkl)")[0]
extension = file_name.split(".")[-1]
if extension == "npy":
@@ -266,16 +290,16 @@ def load_cluster_labels(parent):
else:
return
except Exception as e:
- msg = QtGui.QMessageBox(parent)
- msg.setIcon(QtGui.QMessageBox.Warning)
+ msg = QMessageBox(parent)
+ msg.setIcon(QMessageBox.Warning)
msg.setText("Error: not a supported filetype")
- msg.setStandardButtons(QtGui.QMessageBox.Ok)
+ msg.setStandardButtons(QMessageBox.Ok)
msg.exec_()
print(e)
def load_umap(parent):
try:
- file_name = QtGui.QFileDialog.getOpenFileName(parent,
+ file_name = QFileDialog.getOpenFileName(parent,
"Select UMAP data file", "", "UMAP label files (*.npy *.pkl)")[0]
extension = file_name.split(".")[-1]
if extension == "npy":
@@ -287,10 +311,31 @@ def load_umap(parent):
return
return embedded_data
except Exception as e:
- msg = QtGui.QMessageBox(parent)
- msg.setIcon(QtGui.QMessageBox.Warning)
+ msg = QMessageBox(parent)
+ msg.setIcon(QMessageBox.Warning)
+ msg.setText("Error: not a supported filetype")
+ msg.setStandardButtons(QMessageBox.Ok)
+ msg.exec_()
+ print(e)
+
+def load_trace_data(parent):
+ try:
+ file_name = QFileDialog.getOpenFileName(parent,
+ "Select data file", "", "(*.npy *.pkl)")[0]
+ extension = file_name.split(".")[-1]
+ if extension == "npy":
+ dat = np.load(file_name, allow_pickle=True)
+ elif extension == "pkl":
+ with open(file_name, 'rb') as f:
+ dat = pickle.load(f)
+ else:
+ return
+ return dat
+ except Exception as e:
+ msg = QMessageBox(parent)
+ msg.setIcon(QMessageBox.Warning)
msg.setText("Error: not a supported filetype")
- msg.setStandardButtons(QtGui.QMessageBox.Ok)
+ msg.setStandardButtons(QMessageBox.Ok)
msg.exec_()
print(e)
diff --git a/facemap/menus.py b/facemap/gui/menus.py
old mode 100644
new mode 100755
similarity index 74%
rename from facemap/menus.py
rename to facemap/gui/menus.py
index 4f653f5..06bedce
--- a/facemap/menus.py
+++ b/facemap/gui/menus.py
@@ -2,29 +2,40 @@
import pyqtgraph as pg
import os
from . import guiparts, io
-from PyQt5.QtGui import QPixmap
+from PyQt5.QtGui import QPixmap, QFont, QPainterPath, QPainter, QBrush
+from PyQt5.QtWidgets import QAction, QLabel
def mainmenu(parent):
# --------------- MENU BAR --------------------------
# run suite2p from scratch
- openFile = QtGui.QAction("&Load single movie file", parent)
+ openFile = QAction("&Load single movie file", parent)
openFile.setShortcut("Ctrl+L")
openFile.triggered.connect(lambda: io.open_file(parent))
parent.addAction(openFile)
- openFolder = QtGui.QAction("Open &Folder of movies", parent)
+ openFolder = QAction("Open &Folder of movies", parent)
openFolder.setShortcut("Ctrl+F")
openFolder.triggered.connect(lambda: io.open_folder(parent))
parent.addAction(openFolder)
# load processed data
- loadProc = QtGui.QAction("Load &Processed data", parent)
+ loadProc = QAction("Load &Processed data", parent)
loadProc.setShortcut("Ctrl+P")
loadProc.triggered.connect(lambda: io.open_proc(parent))
parent.addAction(loadProc)
+ # Set output folder
+ setOutputFolder = QAction("Set &output folder", parent)
+ setOutputFolder.setShortcut("Ctrl+O")
+ setOutputFolder.triggered.connect(lambda: io.save_folder(parent))
+ parent.addAction(setOutputFolder)
+
+ loadPose = QAction("Load &pose data", parent)
+ loadPose.triggered.connect(lambda: io.get_pose_file(parent))
+ parent.addAction(loadPose)
+
# Help menu actions
- helpContent = QtGui.QAction("Help Content", parent)
+ helpContent = QAction("Help Content", parent)
helpContent.setShortcut("Ctrl+H")
helpContent.triggered.connect(lambda: launch_user_manual(parent))
parent.addAction(helpContent)
@@ -35,6 +46,8 @@ def mainmenu(parent):
file_menu.addAction(openFile)
file_menu.addAction(openFolder)
file_menu.addAction(loadProc)
+ file_menu.addAction(loadPose)
+ file_menu.addAction(setOutputFolder)
help_menu = main_menu.addMenu("&Help")
help_menu.addAction(helpContent)
@@ -49,7 +62,7 @@ def __init__(self, *args, **kwargs):
self.setFixedSize(630, 470)
icon_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mouse.png")
self.logo = QPixmap(icon_path).scaled(120, 90, QtCore.Qt.KeepAspectRatio, QtCore.Qt.SmoothTransformation)
- self.logoLabel = QtGui.QLabel(self)
+ self.logoLabel = QLabel(self)
self.logoLabel.setPixmap(self.logo)
self.logoLabel.setScaledContents(True)
self.logoLabel.move(240,10)
@@ -63,14 +76,14 @@ def __init__(self, *args, **kwargs):
self.helpText.setReadOnly(True)
def paintEvent(self, event):
- painter = QtGui.QPainter(self)
- painter.setRenderHint(QtGui.QPainter.Antialiasing)
- painter.setBrush(QtGui.QBrush(QtCore.Qt.black))
+ painter = QPainter(self)
+ painter.setRenderHint(QPainter.Antialiasing)
+ painter.setBrush(QBrush(QtCore.Qt.black))
painter.setPen(QtCore.Qt.NoPen)
- path = QtGui.QPainterPath()
- path.addText(QtCore.QPoint(235, 130), QtGui.QFont("Times", 30, QtGui.QFont.Bold), "Facemap")
+ path = QPainterPath()
+ path.addText(QtCore.QPoint(235, 130), QFont("Times", 30, QFont.Bold), "Facemap")
help_text = "Help content"
- path.addText(QtCore.QPoint(10, 150), QtGui.QFont("Times", 20), help_text)
+ path.addText(QtCore.QPoint(10, 150), QFont("Times", 20), help_text)
painter.drawPath(path)
diff --git a/facemap/gui/ops_user.npy b/facemap/gui/ops_user.npy
new file mode 100755
index 0000000..1887885
Binary files /dev/null and b/facemap/gui/ops_user.npy differ
diff --git a/facemap/labeller.py b/facemap/labeller.py
old mode 100644
new mode 100755
index cc0fd68..d5370a6
--- a/facemap/labeller.py
+++ b/facemap/labeller.py
@@ -1,25 +1,29 @@
-import sys
+import argparse
+import copy
import os
import shutil
+import sys
import time
-import numpy as np
-from PyQt5 import QtGui, QtCore, Qt, QtWidgets
-import pyqtgraph as pg
-from pyqtgraph import GraphicsScene
-from scipy.ndimage import gaussian_filter1d
-from scipy.interpolate import interp1d
-from skimage import io
-from skimage import transform, draw, measure, segmentation
import warnings
-from . import guiparts
-from guiparts import ImageDraw, RangeSlider, RGBRadioButtons, ViewBoxNoRightDrag
+from glob import glob
+
import matplotlib.pyplot as plt
-import copy
import mxnet as mx
+import numpy as np
+import pyqtgraph as pg
+from .gui.guiparts import (ImageDraw, RangeSlider, RGBRadioButtons,
+                           ViewBoxNoRightDrag)
from mxnet import nd
-from glob import glob
-from natsort import natsorted
-import argparse
+from PyQt5 import Qt, QtCore, QtGui, QtWidgets
+from pyqtgraph import GraphicsScene
+from scipy.interpolate import interp1d
+from scipy.ndimage import gaussian_filter1d
+from skimage import draw, io, measure, segmentation, transform
+from PyQt5.QtWidgets import (QLabel, QPushButton, QFileDialog, QWidget, QAction, QGridLayout,
+ QSlider, QComboBox, QCheckBox)
+
+from .gui import guiparts
+
def make_bwr():
# make a bwr colormap
@@ -60,7 +64,7 @@ def __init__(self, images=None):
main_menu = self.menuBar()
file_menu = main_menu.addMenu("&File")
# load processed data
- loadImg = QtGui.QAction("&Load image (*.tif, *.png, *.jpg)", self)
+ loadImg = QAction("&Load image (*.tif, *.png, *.jpg)", self)
loadImg.setShortcut("Ctrl+L")
loadImg.triggered.connect(lambda: self.load_images(images))
file_menu.addAction(loadImg)
@@ -79,8 +83,8 @@ def __init__(self, images=None):
self.loaded = False
# ---- MAIN WIDGET LAYOUT ---- #
- self.cwidget = QtGui.QWidget(self)
- self.l0 = QtGui.QGridLayout()
+ self.cwidget = QWidget(self)
+ self.l0 = QGridLayout()
self.cwidget.setLayout(self.l0)
self.setCentralWidget(self.cwidget)
self.l0.setVerticalSpacing(4)
@@ -119,15 +123,15 @@ def make_buttons(self):
self.slider.setMaximum(255)
self.slider.setLow(0)
self.slider.setHigh(255)
- self.slider.setTickPosition(QtGui.QSlider.TicksBelow)
+ self.slider.setTickPosition(QSlider.TicksBelow)
self.l0.addWidget(self.slider, 3,0,1,1)
self.brush_size = 3
- self.BrushChoose = QtGui.QComboBox()
+ self.BrushChoose = QComboBox()
self.BrushChoose.addItems(["1","3","5","7"])
self.BrushChoose.currentIndexChanged.connect(self.brush_choose)
self.l0.addWidget(self.BrushChoose, 1, 5,1,1)
- label = QtGui.QLabel('brush size:')
+ label = QLabel('brush size:')
label.setStyleSheet('color: white;')
self.l0.addWidget(label, 0, 5,1,1)
@@ -136,13 +140,13 @@ def make_buttons(self):
self.hLine = pg.InfiniteLine(angle=0, movable=False)
# turn on crosshairs
- self.CHCheckBox = QtGui.QCheckBox('cross-hairs')
+ self.CHCheckBox = QCheckBox('cross-hairs')
self.CHCheckBox.setStyleSheet('color: white;')
self.CHCheckBox.toggled.connect(self.cross_hairs)
self.l0.addWidget(self.CHCheckBox, 1,4,1,1)
# turn off masks
- self.MCheckBox = QtGui.QCheckBox('masks on [SPACE]')
+ self.MCheckBox = QCheckBox('masks on [SPACE]')
self.MCheckBox.setStyleSheet('color: white;')
self.MCheckBox.setChecked(True)
self.MCheckBox.toggled.connect(self.masks_on)
@@ -150,25 +154,25 @@ def make_buttons(self):
self.masksOn=True
# clear all masks
- self.ClearButton = QtGui.QPushButton('clear all masks')
+ self.ClearButton = QPushButton('clear all masks')
self.ClearButton.clicked.connect(self.clear_all)
self.l0.addWidget(self.ClearButton, 1,6,1,1)
self.ClearButton.setEnabled(False)
# choose models
- self.ModelChoose = QtGui.QComboBox()
+ self.ModelChoose = QComboBox()
self.model_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'models'))
models = glob(self.model_dir+'/*')
models = [os.path.split(m)[-1] for m in models]
print(models)
self.ModelChoose.addItems(models)
self.l0.addWidget(self.ModelChoose, 1, 7,1,1)
- label = QtGui.QLabel('model: ')
+ label = QLabel('model: ')
label.setStyleSheet('color: white;')
self.l0.addWidget(label, 0, 7,1,1)
# recompute model
- self.ModelButton = QtGui.QPushButton('compute masks')
+ self.ModelButton = QPushButton('compute masks')
self.ModelButton.clicked.connect(self.compute_model)
self.l0.addWidget(self.ModelButton, 1,10,1,1)
self.ModelButton.setEnabled(False)
@@ -475,7 +479,7 @@ def initialize_images(self, image):
def load_manual(self, filename=None, image=None, image_file=None):
if filename is None:
- name = QtGui.QFileDialog.getOpenFileName(
+ name = QFileDialog.getOpenFileName(
self, "Load manual labels", filter="*_manual.npy"
)
filename = name[0]
@@ -625,7 +629,7 @@ def load_images(self, filename=None):
QtWidgets.QApplication.setOverrideCursor(QtCore.Qt.WaitCursor)
types = ["*.png","*.jpg","*.tif","*.tiff"] # supported image types
if filename is None:
- name = QtGui.QFileDialog.getOpenFileName(
+ name = QFileDialog.getOpenFileName(
self, "Load image"
)
filename = name[0]
diff --git a/facemap/mouse.png b/facemap/mouse.png
old mode 100644
new mode 100755
diff --git a/facemap/ops_user.npy b/facemap/ops_user.npy
old mode 100644
new mode 100755
index f093b5f..965df93
Binary files a/facemap/ops_user.npy and b/facemap/ops_user.npy differ
diff --git a/facemap/pose/FMnet_torch.py b/facemap/pose/FMnet_torch.py
new file mode 100755
index 0000000..6941db1
--- /dev/null
+++ b/facemap/pose/FMnet_torch.py
@@ -0,0 +1,82 @@
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Network ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+import torch
+import torch.nn as nn
+from torch import optim
+import torch.nn.functional as F
+
+class FMnet(nn.Module):
+ def __init__(self,img_ch, output_ch, labels_id, channels, device,
+ kernel=3, shape=(256,256), n_upsample=2):
+ super().__init__()
+ self.n_upsample = n_upsample
+ self.image_shape = shape
+ self.bodyparts = labels_id
+ self.device = device
+
+ self.Conv = nn.Sequential()
+ self.Conv.add_module('conv0', convblock(ch_in=img_ch,ch_out=channels[0],
+ kernel_sz=kernel, block=0))
+ for k in range(1,len(channels)):
+ self.Conv.add_module(f'conv{k}', convblock(ch_in=channels[k-1],ch_out=channels[k],
+ kernel_sz=kernel, block=k))
+
+ self.Up_conv = nn.Sequential()
+ for k in range(n_upsample):
+ self.Up_conv.add_module(f'upconv{k}', convblock(ch_in=channels[-1-k]+channels[-2-k],
+ ch_out=channels[-2-k], kernel_sz=kernel))
+
+ self.Conv2_1x1 = nn.Sequential()
+ for j in range(3):
+ self.Conv2_1x1.add_module(f'conv{j}', nn.Conv2d(channels[-2-k], output_ch, kernel_size=1,
+ padding=0))
+
+ def forward(self,x,verbose=False):
+ # encoding path
+ xout = []
+ x = self.Conv[0](x)
+ xout.append(x)
+ for k in range(1, len(self.Conv)):
+ x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
+ x = self.Conv[k](x )
+ xout.append(x)
+
+ for k in range(len(self.Up_conv)):
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
+ x = self.Up_conv[k](torch.cat((x, xout[-2-k]), axis=1))
+
+ locx = self.Conv2_1x1[1](x)
+ locy = self.Conv2_1x1[2](x)
+ hm = self.Conv2_1x1[0](x)
+ hm = F.relu(hm)
+ hm = 10 * hm / (1e-4 + hm.sum(axis=(-2,-1)).unsqueeze(-1).unsqueeze(-1))
+
+ return hm, locx, locy
+
+class convblock(nn.Module):
+ def __init__(self, ch_in, ch_out, kernel_sz, block=-1):
+ super().__init__()
+ self.conv = nn.Sequential()
+ self.block = block
+ if self.block!=0:
+ self.conv.add_module('conv_0', batchconv(ch_in, ch_out, kernel_sz))
+ else:
+ self.conv.add_module('conv_0', batchconv0(ch_in, ch_out, kernel_sz))
+ self.conv.add_module('conv_1', batchconv(ch_out, ch_out, kernel_sz))
+
+ def forward(self, x):
+ x = self.conv[1](self.conv[0](x) )
+ return x
+
+def batchconv0(ch_in, ch_out, kernel_sz):
+ return nn.Sequential(
+ nn.BatchNorm2d(ch_in, eps=1e-5, momentum = 0.1),
+ nn.Conv2d(ch_in, ch_out, kernel_sz, padding=kernel_sz//2, bias=False),
+ )
+
+def batchconv(ch_in, ch_out, sz):
+ return nn.Sequential(
+ nn.BatchNorm2d(ch_in, eps=1e-5, momentum = 0.1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(ch_in, ch_out, sz, padding=sz//2, bias=False),
+ )
+
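+# Minimal usage sketch for the network above (illustrative only; the channel sizes and the
+# number of keypoints are placeholders, the real values are read from the downloaded
+# facemap_model_params.pth in pose.Pose.load_model):
+#
+#   import torch
+#   net = FMnet(img_ch=1, output_ch=15, labels_id=["bp_%d" % i for i in range(15)],
+#               channels=[64, 64, 128, 256, 256], device="cpu",
+#               kernel=3, shape=(256, 256), n_upsample=2)
+#   hm, locx, locy = net(torch.zeros(1, 1, 256, 256))
+#   # with five encoder blocks and two upsampling stages, each output is (1, 15, 64, 64):
+#   # a heatmap plus x/y offset maps, one channel per bodypart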
diff --git a/facemap/pose/models.py b/facemap/pose/models.py
new file mode 100755
index 0000000..ff959dd
--- /dev/null
+++ b/facemap/pose/models.py
@@ -0,0 +1,55 @@
+"""
+Helpers for the pre-trained Facemap pose model. Contains functions for:
+- locating the directory where model files are cached
+- downloading the pre-trained model parameters and state files
+"""
+import os
+from pathlib import Path
+from urllib.parse import urlparse
+from urllib.request import urlretrieve
+
+MODEL_PARAMS_URL = "https://www.facemappy.org/models/facemap_model_params.pth"
+MODEL_STATE_URL = "https://www.facemappy.org/models/facemap_model_state.pt"
+
+def get_data_dir():
+ """
+    Get the path to the directory where model files are cached.
+ """
+ current_workdir = os.getcwd()
+ model_dir = os.path.join(current_workdir, "facemap", "pose")
+ # Change model directory to path object
+ model_dir = Path(model_dir)
+ return model_dir
+
+def get_model_params_path():
+ """
+ Get the path to the model parameters file.
+ """
+ model_dir = get_data_dir()
+ cached_params_file = str(model_dir.joinpath("facemap_model_params.pth"))
+ if not os.path.exists(cached_params_file):
+ download_url_to_file(MODEL_PARAMS_URL, cached_params_file)
+ return cached_params_file
+
+def get_model_state_path():
+ """
+ Get the path to the model state file.
+ """
+ model_dir = get_data_dir()
+ cached_state_file = str(model_dir.joinpath("facemap_model_state.pt"))
+ if not os.path.exists(cached_state_file):
+ download_url_to_file(MODEL_STATE_URL, cached_state_file)
+ return cached_state_file
+
+def download_url_to_file(url, filename):
+ """
+ Download a file from a URL to a local file.
+ """
+ # Check if file already exists
+ if os.path.exists(filename):
+ return
+
+ # Download file
+ print("Downloading %s to %s" % (url, filename))
+ urlretrieve(url, filename)
+
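+# Illustrative usage (a sketch; the paths are the cache locations defined by get_data_dir
+# above, and the files are only downloaded if the cached copies are missing):
+#
+#   params_path = get_model_params_path()  # .../facemap/pose/facemap_model_params.pth
+#   state_path = get_model_state_path()    # .../facemap/pose/facemap_model_state.pt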
diff --git a/facemap/pose/pose.py b/facemap/pose/pose.py
new file mode 100755
index 0000000..853cdf1
--- /dev/null
+++ b/facemap/pose/pose.py
@@ -0,0 +1,215 @@
+import os
+import time
+
+from tqdm import tqdm
+
+import numpy as np
+import pandas as pd
+import torch
+import pickle
+from io import StringIO
+
+from .. import utils
+from . import FMnet_torch, pose_helper_functions as pose_utils
+from . import transforms, models
+
+"""
+Base class for generating pose estimates.
+Contains functions that can be used through the CLI or GUI.
+Currently supports single-video processing; multiple videos are processed sequentially.
+"""
+
+class Pose():
+ def __init__(self, filenames=None, bbox=[], bbox_set=False, gui=None, GUIobject=None):
+ self.gui = gui
+ self.GUIobject = GUIobject
+ if self.gui is not None:
+ self.filenames = self.gui.filenames
+ else:
+ self.filenames = filenames
+ self.cumframes, self.Ly, self.Lx, self.containers = utils.get_frame_details(self.filenames)
+ self.nframes = self.cumframes[-1]
+ self.pose_labels = None
+ self.bodyparts = None
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.bbox = bbox
+ self.bbox_set = bbox_set
+
+ def run(self, plot=True):
+ start_time = time.time()
+ self.net = self.load_model()
+ # Predict and save pose
+ if not self.bbox_set:
+ resize = True
+ for i in range(len(self.Ly)):
+ x1, x2, y1, y2 = 0, self.Ly[i], 0, self.Lx[i]
+ self.bbox.append([x1, x2, y1, y2, resize])
+            prompt = "No bbox set. Using entire frame view: {} and resize={}".format(self.bbox, resize)
+ utils.update_mainwindow_message(MainWindow=self.gui, GUIobject=self.GUIobject,
+ prompt=prompt, hide_progress=True)
+ self.bbox_set = True
+ for video_id in range(len(self.bbox)):
+ utils.update_mainwindow_message(MainWindow=self.gui, GUIobject=self.GUIobject,
+ prompt="Processing video: {}".format(self.filenames[0][video_id]), hide_progress=True)
+ pred_data, metadata = self.predict_landmarks(video_id)
+ dataFrame = self.write_dataframe(pred_data)
+ savepath = self.save_pose_prediction(dataFrame, video_id)
+ utils.update_mainwindow_message(MainWindow=self.gui, GUIobject=self.GUIobject,
+ prompt="Saved pose prediction outputs to: {}".format(savepath), hide_progress=True)
+ print("Saved pose prediction outputs to:", savepath)
+ # Save metadata to a pickle file
+ metadata_file = os.path.splitext(savepath)[0]+"_Facemap_metadata.pkl"
+ with open(metadata_file, 'wb') as f:
+ pickle.dump(metadata, f, pickle.HIGHEST_PROTOCOL)
+ if self.gui is not None:
+ self.gui.poseFilepath.append(savepath)
+ self.gui.Labels_checkBox.setChecked(True)
+ self.gui.start()
+ if plot and self.gui is not None:
+ self.plot_pose_estimates()
+ end_time = time.time()
+ print("Time elapsed:", end_time-start_time, "seconds")
+ utils.update_mainwindow_message(MainWindow=self.gui, GUIobject=self.GUIobject,
+ prompt="Time elapsed: {} seconds".format(end_time-start_time), hide_progress=True)
+
+ def write_dataframe(self, data):
+ scorer = "Facemap"
+ bodyparts = self.net.bodyparts
+ # Create an empty dataframe
+ for index, bodypart in enumerate(bodyparts):
+ columnindex = pd.MultiIndex.from_product(
+ [[scorer], [bodypart], ["x", "y", "likelihood"]],
+ names=["scorer", "bodyparts", "coords"])
+ frame = pd.DataFrame(
+ np.nan,
+ columns=columnindex,
+ index=np.arange(self.cumframes[-1]))
+ if index == 0:
+ dataFrame = frame
+ else:
+ dataFrame = pd.concat([dataFrame, frame], axis=1)
+
+ # Fill dataframe with data
+ dataFrame.iloc[:,::3] = data[:,:,0].cpu().numpy()
+ dataFrame.iloc[:,1::3] = data[:,:,1].cpu().numpy()
+ dataFrame.iloc[:,2::3] = data[:,:,2].cpu().numpy()
+
+ return dataFrame
+
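+    # The dataframe returned above has a three-level column MultiIndex
+    # (scorer="Facemap", bodypart, coord in {"x", "y", "likelihood"}) and one row per frame,
+    # which is the layout the GUI's load_labels() reads back from the saved .h5 file.
+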
+ def predict_landmarks(self, video_id):
+ """
+ Predict labels for all frames in video and save output as .h5 file
+ """
+ nchannels = 1
+        batch_size = 1  # frames are processed one at a time on both CPU and GPU
+
+ # Create array for storing predictions
+ pred_data = torch.zeros(self.cumframes[-1], len(self.net.bodyparts), 3)
+
+ # Store predictions in dataframe
+ self.net.eval()
+ start = 0
+ end = batch_size
+ Xstart, Xstop, Ystart, Ystop, resize = self.bbox[video_id]
+ inference_time = 0
+
+ progress_output = StringIO()
+ with tqdm(total=self.cumframes[-1], unit='frame', unit_scale=True, file=progress_output) as pbar:
+ while start != self.cumframes[-1]: # for analyzing entire video
+
+                # Pre-process images
+ imall = np.zeros((end-start, nchannels, self.Ly[video_id], self.Lx[video_id]))
+ cframes = np.arange(start, end)
+ utils.get_frames(imall, self.containers, cframes, self.cumframes)
+
+ # Inference time includes: pre-processing, inference, post-processing
+ t0 = time.time()
+ imall = torch.from_numpy(imall).to(self.net.device, dtype=torch.float32)
+ frame_grayscale = transforms.crop_resize(imall, Ystart, Ystop,
+ Xstart, Xstop, resize).clone().detach()
+ imall = transforms.preprocess_img(frame_grayscale)
+
+ # Network prediction
+ Xlabel, Ylabel, likelihood = pose_utils.get_predicted_landmarks(self.net, imall,
+ batchsize=batch_size, smooth=False)
+
+ # Get adjusted landmarks that fit to original image size
+ Xlabel, Ylabel = transforms.labels_crop_resize(Xlabel, Ylabel,
+ Ystart, Xstart,
+ current_size=(256, 256),
+ desired_size=(self.bbox[video_id][1]-self.bbox[video_id][0],
+ self.bbox[video_id][3]-self.bbox[video_id][2]))
+
+ # Assign predictions to dataframe
+ pred_data[start:end, :, 0] = Xlabel
+ pred_data[start:end, :, 1] = Ylabel
+ pred_data[start:end, :, 2] = likelihood
+ inference_time += time.time() - t0
+
+ pbar.update(batch_size)
+ start = end
+ end += batch_size
+ end = min(end, self.cumframes[-1])
+ # Update progress bar for every 5% of the total frames
+                if end % max(1, int(self.cumframes[-1] * 0.05)) == 0:
+ utils.update_mainwindow_progressbar(MainWindow=self.gui,
+ GUIobject=self.GUIobject, s=progress_output,
+ prompt="Pose prediction progress:")
+
+ if batch_size == 1:
+ inference_speed = self.cumframes[-1] / inference_time
+ print("Inference speed:", inference_speed, "fps")
+
+ metadata = {"batch_size": batch_size,
+ "image_size": (self.Ly, self.Lx),
+ "bbox": self.bbox[video_id],
+ "total_frames": self.cumframes[-1],
+ "bodyparts": self.net.bodyparts,
+ "inference_speed": inference_speed,
+ }
+ return pred_data, metadata
+
+ def save_pose_prediction(self, dataFrame, video_id):
+ # Save prediction to .h5 file
+ if self.gui is not None:
+ basename = self.gui.save_path
+ _, filename = os.path.split(self.filenames[0][video_id])
+ videoname, _ = os.path.splitext(filename)
+ else:
+ basename, filename = os.path.split(self.filenames[0][video_id])
+ videoname, _ = os.path.splitext(filename)
+ poseFilepath = os.path.join(basename, videoname+"_FacemapPose.h5")
+ dataFrame.to_hdf(poseFilepath, "df_with_missing", mode="w")
+ return poseFilepath
+
+ def plot_pose_estimates(self):
+ # Plot labels
+ self.gui.poseFileLoaded = True
+ self.gui.load_labels()
+ self.gui.Labels_checkBox.setChecked(True)
+
+ def load_model(self):
+ """
+ Load pre-trained UNet model for labels prediction
+ """
+ model_params_file = models.get_model_params_path()
+ model_state_file = models.get_model_state_path()
+ if torch.cuda.is_available():
+ print("Using cuda as device")
+ else:
+ print("Using cpu as device")
+ print("LOADING MODEL....", model_params_file)
+ model_params = torch.load(model_params_file, map_location=self.device)
+ self.bodyparts = model_params['params']['bodyparts']
+ channels = model_params['params']['channels']
+ kernel_size = 3
+ nout = len(self.bodyparts) # number of outputs from the model
+ net = FMnet_torch.FMnet(img_ch=1, output_ch=nout, labels_id=self.bodyparts,
+ channels=channels, kernel=kernel_size, device=self.device)
+ net.load_state_dict(torch.load(model_state_file, map_location=self.device))
+ net.to(self.device);
+ return net
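+# Illustrative script-level usage of the class above (the video path is a hypothetical
+# example; filenames are nested one list per set of simultaneously acquired views):
+#
+#   from facemap.pose import pose
+#   p = pose.Pose(filenames=[["cam0.avi"]])
+#   p.run(plot=False)  # writes <video>_FacemapPose.h5 plus a *_Facemap_metadata.pkl file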
diff --git a/facemap/pose/pose_gui.py b/facemap/pose/pose_gui.py
new file mode 100755
index 0000000..9ce3103
--- /dev/null
+++ b/facemap/pose/pose_gui.py
@@ -0,0 +1,170 @@
+import numpy as np
+import pyqtgraph as pg
+from PyQt5 import QtWidgets
+from PyQt5.QtWidgets import (
+ QDialog,
+ QPushButton)
+
+from .. import roi
+from .pose import Pose
+from . import transforms
+
+from .. import utils
+
+"""
+Pose subclass for obtaining a bounding box from user input.
+Currently supports single video processing only.
+"""
+class PoseGUI(Pose):
+ def __init__(self, gui=None):
+ self.gui = gui
+ super(PoseGUI, self).__init__(gui=self.gui)
+ self.bbox_set = False
+ self.bbox = []
+ self.cancel = False
+
+ # Draw box on GUI using user's input
+ def draw_user_bbox(self):
+ """
+ Function for user to draw a bbox
+ """
+ # Get sample frame from each video in case of multiple videos
+ sample_frame = utils.get_frame(0, self.nframes, self.cumframes, self.containers)
+ last_video=False
+ for video_id, frame in enumerate(sample_frame):
+ # Trigger new window for ROI selection of each frame
+ if video_id == len(sample_frame)-1:
+ last_video = True
+ ROI_popup(frame, video_id, self.gui, self, last_video)
+ return self.bbox, self.bbox_set, self.cancel
+
+ def adjust_bbox_params(self):
+ # This function adjusts bbox so that it is of minimum dimension: 256,256
+ sample_frame = utils.get_frame(0, self.nframes, self.cumframes, self.containers)
+ for i, bbox in enumerate(self.bbox):
+ x1, x2, y1, y2, resize = transforms.get_crop_resize_params(sample_frame[i],
+ x_dims=(bbox[0], bbox[1]),
+ y_dims=(bbox[2], bbox[3]))
+ self.bbox[i] = [x1, x2, y1, y2, resize]
+ print("user selected bbox after adjustment:", self.bbox)
+
+ def plot_bbox_roi(self):
+ self.adjust_bbox_params()
+ for i, bbox in enumerate(self.bbox):
+ x1, x2, y1, y2, _ = bbox
+ dy, dx = y2-y1, x2-x1
+ xrange = np.arange(y1+self.gui.sx[i], y2+self.gui.sx[i]).astype(np.int32)
+ yrange = np.arange(x1+self.gui.sy[i], x2+self.gui.sy[i]).astype(np.int32)
+ x1, y1 = yrange[0], xrange[0]
+ self.gui.add_ROI(roitype=4+1, roistr="bbox_{}".format(i), moveable=False, resizable=False,
+ pos=(x1, y1, dx, dy), ivid=i, yrange=yrange, xrange=xrange)
+ self.bbox_set = True
+
+class ROI_popup(QDialog):
+ def __init__(self, frame, video_id, gui, pose, last_video):
+ super().__init__()
+ self.gui = gui
+ self.frame = frame
+ self.pose = pose
+ self.last_video = last_video
+ self.setWindowTitle('Select ROI for video: '+str(video_id))
+
+ # Add image and ROI bbox
+ self.verticalLayout = QtWidgets.QVBoxLayout(self)
+ self.win = pg.GraphicsLayoutWidget()
+ self.win.setObjectName("Dialog "+str(video_id+1))
+ ROI_win = self.win.addViewBox(invertY=True)
+ self.img = pg.ImageItem(self.frame)
+ ROI_win.addItem(self.img)
+ self.roi = pg.RectROI([0,0],[100,100],pen=pg.mkPen('r',width=2), movable=True,resizable=True)
+ ROI_win.addItem(self.roi)
+ self.win.show()
+ self.verticalLayout.addWidget(self.win)
+
+ # Add buttons to dialog box
+ self.done_button = QPushButton('Done')
+ self.done_button.setDefault(True)
+ self.done_button.clicked.connect(self.done_exec)
+ self.cancel_button = QPushButton('Cancel')
+ self.cancel_button.clicked.connect(self.cancel_exec)
+ # Add a next button to the dialog box horizontally centered with cancel button and done button
+ self.next_button = QPushButton('Next')
+ self.next_button.setDefault(True)
+ self.next_button.clicked.connect(self.next_exec)
+ # Add a skip button to the dialog box horizontally centered with cancel button and done button
+ self.skip_button = QPushButton('Skip')
+ self.skip_button.setDefault(True)
+ self.skip_button.clicked.connect(self.skip_exec)
+
+ # Position buttons
+ self.widget = QtWidgets.QWidget(self)
+ self.horizontalLayout = QtWidgets.QHBoxLayout(self.widget)
+ self.horizontalLayout.setContentsMargins(-1, -1, -1, 0)
+ self.horizontalLayout.setObjectName("horizontalLayout")
+ self.horizontalLayout.addWidget(self.cancel_button)
+ self.horizontalLayout.addWidget(self.skip_button)
+ if self.last_video:
+ self.horizontalLayout.addWidget(self.done_button)
+ else:
+ self.horizontalLayout.addWidget(self.next_button)
+ self.verticalLayout.addWidget(self.widget)
+
+ self.exec_()
+
+ def get_coordinates(self):
+ roi_tuple, _ = self.roi.getArraySlice(self.frame, self.img, returnSlice=False)
+ (x1, x2), (y1, y2) = roi_tuple[0], roi_tuple[1]
+ return (x1, x2), (y1, y2)
+
+ def skip_exec(self):
+ self.pose.bbox = []
+ self.pose.bbox_set = False
+ self.close()
+
+ def next_exec(self):
+ (x1, x2), (y1, y2) = self.get_coordinates()
+ self.pose.bbox.append([x1, x2, y1, y2, False])
+ self.close()
+
+ def cancel_exec(self):
+ self.pose.cancel = True
+ self.close()
+
+ def done_exec(self):
+ # User finished drawing ROI
+ (x1, x2), (y1, y2) = self.get_coordinates()
+ self.pose.bbox.append([x1, x2, y1, y2, False])
+ self.pose.plot_bbox_roi()
+ self.close()
+
+# Following used to check cropped sections of frames
+class test_popup(QDialog):
+ def __init__(self, frame, gui):
+ super().__init__(gui)
+ self.gui = gui
+ self.frame = frame
+
+ self.setWindowTitle('Chosen ROI')
+ self.verticalLayout = QtWidgets.QVBoxLayout(self)
+
+ # Add image and ROI bbox
+ self.win = pg.GraphicsLayoutWidget()
+ ROI_win = self.win.addViewBox(invertY=True)
+ self.img = pg.ImageItem(self.frame)
+ ROI_win.addItem(self.img)
+ self.win.show()
+ self.verticalLayout.addWidget(self.win)
+
+ self.cancel_button = QPushButton('Cancel')
+ self.cancel_button.clicked.connect(self.close)
+
+ # Position buttons
+ self.widget = QtWidgets.QWidget(self)
+ self.horizontalLayout = QtWidgets.QHBoxLayout(self.widget)
+ self.horizontalLayout.setContentsMargins(-1, -1, -1, 0)
+ self.horizontalLayout.setObjectName("horizontalLayout")
+ self.horizontalLayout.addWidget(self.cancel_button)
+ self.verticalLayout.addWidget(self.widget)
+
+ self.show()
+
diff --git a/facemap/pose/pose_helper_functions.py b/facemap/pose/pose_helper_functions.py
new file mode 100755
index 0000000..4f3f1bf
--- /dev/null
+++ b/facemap/pose/pose_helper_functions.py
@@ -0,0 +1,113 @@
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Import packages ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+import numpy as np
+
+print('numpy version: %s'%np.__version__)
+import cv2 # opencv
+import torch # pytorch
+import os # file path stuff
+import random
+from glob import glob # listing files
+from platform import python_version
+
+import pandas as pd
+from tqdm import tqdm # waitbar
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import optim
+from scipy.ndimage import gaussian_filter
+
+print("python version:", python_version())
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Global variables ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+N_FACTOR = 2**4 // (2 ** 2)
+SIGMA = 3 * 4 / N_FACTOR
+Lx = 64
+print("Global varaibles set:")
+print("N_FACTOR:", N_FACTOR)
+print("SIGMA:", SIGMA)
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Helper functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+def set_seed(seed):
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+
+def normalize_mean(in_img):
+ zz = in_img.astype('float')
+ # subtract mean for each img.
+ mm = zz.mean(axis=(2,3))
+ xx = zz - mm[:, :, np.newaxis, np.newaxis]
+ return xx
+
+def normalize99(X):
+ """ normalize image so 0.0 is 1st percentile and 1.0 is 99th percentile """
+ x01 = torch.quantile(X, .01)
+ x99 = torch.quantile(X, .99)
+ X = (X - x01) / (x99 - x01)
+ return X
+
+def get_predicted_landmarks(net, im_input, batchsize=1, smooth=True):
+
+ xmesh, ymesh = np.meshgrid(torch.arange(net.image_shape[0]/N_FACTOR),
+ torch.arange(net.image_shape[1]/N_FACTOR))
+ ymesh = torch.from_numpy(ymesh).to(net.device)
+ xmesh = torch.from_numpy(xmesh).to(net.device)
+
+ # Predict
+ with torch.no_grad():
+ if im_input.ndim == 3:
+ im_input = im_input[np.newaxis, ...]
+ hm_pred, locx_pred, locy_pred = net(im_input)
+
+ hm_pred = hm_pred.squeeze()
+ locx_pred = locx_pred.squeeze()
+ locy_pred = locy_pred.squeeze()
+
+ if smooth:
+        hm_smo = gaussian_filter(hm_pred.cpu().numpy(), [0, 1, 1])
+        # convert back to a torch tensor and flatten the spatial dims so the torch ops below apply
+        hm_smo = torch.from_numpy(hm_smo.reshape(hm_smo.shape[0], Lx*Lx)).to(net.device)
+        imax = torch.argmax(hm_smo, 1)
+        likelihood = torch.diag(hm_smo[:, imax])
+ else:
+ hm_pred = hm_pred.reshape(hm_pred.shape[0], Lx*Lx)
+ imax = torch.argmax(hm_pred, 1)
+ likelihood = torch.diag(hm_pred[:,imax])
+
+ # this part computes the position error on the training set
+ locx_pred = locx_pred.reshape(locx_pred.shape[0], Lx*Lx)
+ locy_pred = locy_pred.reshape(locy_pred.shape[0], Lx*Lx)
+
+ nn = hm_pred.shape[0]
+ x_pred = ymesh.flatten()[imax] - (2*SIGMA) * locx_pred[torch.arange(nn), imax]
+ y_pred = xmesh.flatten()[imax] - (2*SIGMA) * locy_pred[torch.arange(nn), imax]
+
+ return y_pred*N_FACTOR, x_pred*N_FACTOR, likelihood
+
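+# Illustrative usage sketch (an assumption, not wired up in this patch): given a trained keypoint
+# network `net` exposing `image_shape` and `device` (as used above) and a preprocessed frame tensor
+# `im` of shape [1 x 1 x Ly x Lx] on net.device:
+#   y_pred, x_pred, likelihood = get_predicted_landmarks(net, im, smooth=False)
+#   # y_pred/x_pred: per-keypoint coordinates in input-image pixels; likelihood: heatmap peak per keypoint
+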
+def add_motion_blur(img, kernel_size=None, vertical=True, horizontal=True):
+    if kernel_size is None:
+        kernel_size = 5  # assumed default; a kernel size must be set before building the kernels below
+    # Create the vertical kernel.
+    kernel_v = np.zeros((kernel_size, kernel_size))
+
+ # Create a copy of the same for creating the horizontal kernel.
+ kernel_h = np.copy(kernel_v)
+
+ # Fill the middle row with ones.
+ kernel_v[:, int((kernel_size - 1)/2)] = np.ones(kernel_size)
+ kernel_h[int((kernel_size - 1)/2), :] = np.ones(kernel_size)
+
+ # Normalize.
+ kernel_v /= kernel_size
+ kernel_h /= kernel_size
+
+ if vertical:
+ # Apply the vertical kernel.
+ img = cv2.filter2D(img, -1, kernel_v)
+ if horizontal:
+ # Apply the horizontal kernel.
+ img = cv2.filter2D(img, -1, kernel_h)
+
+ return img
+
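+# Illustrative usage sketch (assumption: applied as a training-time augmentation):
+#   blurred = add_motion_blur(img, kernel_size=5, vertical=True, horizontal=False)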
+
diff --git a/facemap/pose/transforms.py b/facemap/pose/transforms.py
new file mode 100755
index 0000000..7955e5d
--- /dev/null
+++ b/facemap/pose/transforms.py
@@ -0,0 +1,226 @@
+"""
+Facemap functions for:
+- bounding box: (suggested ROI) for UNet input images
+- image preprocessing
+- image augmentation
+"""
+import cv2
+import numpy as np
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from torch.nn import functional as F
+
+from . import pose_helper_functions
+
+normalize = transforms.Normalize(mean=[0.445], std=[0.269])
+
+def preprocess_img(im):
+ """
+    Preprocessing of image involves: conversion to float32 in range 0-1, normalize99, and padding image size to be
+    compatible with UNet model input, i.e. divisible by 16
+ Parameters
+ -------------
+ im: ND-array
+ image of size [(Lz) x Ly x Lx]
+ Returns
+ --------------
+ im: ND-array
+ preprocessed image of size [1 x Ly x Lx] if input dimensions==2, else [Lz x Ly x Lx]
+ """
+ if im.ndim==2:
+ im = im[np.newaxis, ...]
+ # Adjust image contrast
+ im = pose_helper_functions.normalize99(im)
+ return im
+
+def get_cropped_imgs(imgs, bbox):
+ """
+    Crop images to the given bounding box, returning the cropped region for each image and channel
+ Parameters
+ -------------
+ imgs: ND-array
+ images of size [batch_size x nchan x Ly x Lx]
+ bbox: tuple of size (4,)
+ bounding box positions in order x1, x2, y1, y2
+ Returns
+ --------------
+ cropped_imgs: ND-array
+ images of size [batch_size x nchan x Ly' x Lx'] where Ly' = y2-y1 and Lx'=x2-x1
+ """
+ x1, x2, y1, y2 = (np.round(bbox)).astype(int)
+ batch_size = imgs.shape[0]
+ nchannels = imgs.shape[1]
+ cropped_imgs = np.empty((batch_size, nchannels, x2-x1, y2-y1))
+ for i in range(batch_size):
+ for n in range(nchannels):
+ cropped_imgs[i,n] = imgs[i, n, x1:x2, y1:y2]
+ return cropped_imgs
+
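+# Illustrative usage sketch (assumption, names are placeholders): normalize a 2-D frame and crop it
+# to a bounding box before feeding it to the network:
+#   im = preprocess_img(frame)                                   # [1 x Ly x Lx]
+#   crops = get_cropped_imgs(im[np.newaxis], (x1, x2, y1, y2))   # [1 x 1 x (x2-x1) x (y2-y1)]
+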
+def get_crop_resize_params(img, x_dims, y_dims, xy=(256,256)):
+ """
+ Get cropped and resized image dimensions
+ Input:-
+ img: image
+ x_dims: min,max x pos
+ y_dims: min,max y pos
+ xy: final (desired) image size
+ Output:-
+ Xstart: (int) x dim start pos
+ Xstop: (int) x dim stop pos
+ Ystart: (int) y dim start pos
+ Ystop: (int) y dim stop pos
+ resize: (bool) whether to resize image
+ """
+ Xstart = int(x_dims[0])
+ Xstop = int(x_dims[1])
+ Ystart = int(y_dims[0])
+ Ystop = int(y_dims[1])
+
+ resize = False
+ if abs(Ystop-Ystart) > xy[0]: # if cropped image larger than desired size
+ # crop image then resize image and landmarks
+ resize = True
+ else: # if cropped image smaller than desired size then add padding accounting for labels in view
+ y_pad = abs(abs(Ystop-Ystart) - xy[0])
+ if y_pad % 2 == 0:
+ y_pad = y_pad//2
+ Ystart, Ystop = Ystart-y_pad, Ystop+y_pad
+ else: # odd number division so add 1
+ y_pad = y_pad//2
+ Ystart, Ystop = Ystart-y_pad, Ystop+y_pad+1
+
+ if abs(Xstop-Xstart) > xy[1]: # if cropped image larger than desired size
+ resize = True
+ else:
+ x_pad = abs(abs(Xstop-Xstart) - xy[1])
+ if x_pad % 2 == 0 :
+ x_pad = x_pad//2
+ Xstart, Xstop = Xstart-x_pad, Xstop+x_pad
+ else:
+ x_pad = x_pad//2
+ Xstart, Xstop = Xstart-x_pad, Xstop+x_pad+1
+
+ if Ystop > img.shape[1]:
+ Ystart -= (Ystop - img.shape[1])
+ if Xstop > img.shape[0]:
+ Xstart -= (Xstop - img.shape[0])
+
+ Ystop, Xstop = min(Ystop, img.shape[1]), min(Xstop, img.shape[0])
+ Ystart, Xstart = max(0, Ystart), max(0, Xstart)
+ Ystop, Xstop = max(Ystop, xy[0]), max(Xstop, xy[1])
+
+ return Xstart, Xstop, Ystart, Ystop, resize
+
+def crop_resize(img, Xstart, Xstop, Ystart, Ystop, resize, xy=[256,256]):
+ """
+ Crop and resize image using dimensions provided
+ Input:-
+ img: (2D array) image
+ Xstart: (int) x dim start pos
+ Xstop: (int) x dim stop pos
+ Ystart: (int) y dim start pos
+ Ystop: (int) y dim stop pos
+ resize: (bool) whether to resize image
+ Output:-
+ im_cropped: (2D array) cropped image
+ """
+ # Crop image and landmarks
+ im_cropped = img[:,:,Ystart:Ystop,Xstart:Xstop]
+ # Resize image
+ if resize:
+        im_cropped = F.interpolate(im_cropped, size=tuple(xy), mode='bilinear')
+ return im_cropped
+
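+# Illustrative usage sketch (assumption, names are placeholders): bound a region of interest on a
+# 2-D frame, then crop (and resize only if the region exceeds 256 pixels) a batched tensor of frames:
+#   Xstart, Xstop, Ystart, Ystop, resize = get_crop_resize_params(frame, (x1, x2), (y1, y2))
+#   cropped = crop_resize(imgs, Xstart, Xstop, Ystart, Ystop, resize)   # imgs: [N x C x Ly x Lx]
+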
+def labels_crop_resize(Xlabel, Ylabel, Xstart, Ystart, current_size, desired_size):
+ """
+ Adjust x,y labels on a 2D image to perform a resize operation
+ Parameters
+ -------------
+ Xlabel: ND-array
+ Ylabel: ND-array
+ current_size: tuple or array of size(2,)
+ desired_size: tuple or array of size(2,)
+ Returns
+ --------------
+ Xlabel: ND-array
+ adjusted x values on new/desired_size of image
+ Ylabel: ND-array
+ adjusted y values on new/desired_size of image
+ """
+ #Xlabel, Ylabel = Xlabel.astype(float), Ylabel.astype(float)
+ Xlabel *= (desired_size[1]/current_size[1]) # x_scale
+ Ylabel *= (desired_size[0]/current_size[0]) # y_scale
+ Xlabel = Xlabel+Xstart
+ Ylabel = Ylabel+Ystart
+ return Xlabel, Ylabel
+
+def adjust_bbox(prev_bbox, img_yx, div=16, extra=1):
+ """
+ Takes a bounding box as an input and the original image size. Adjusts bounding box to be square
+ instead of a rectangle. Uses longest dimension of prev_bbox for final image size that cannot
+ exceed img_yx
+ Parameters
+ -------------
+ prev_bbox: tuple of size (4,)
+ bounding box positions in order x1, x2, y1, y2
+ img_yx: tuple of size (2,)
+ image size for y and x dimensions
+ Returns
+ --------------
+ bbox: tuple of size (4,)
+ bounding box positions in order x1, x2, y1, y2
+ """
+ x1, x2, y1, y2 = np.round(prev_bbox)
+ xdim, ydim = (x2-x1), (y2-y1)
+
+ # Pad bbox dimensions to be divisible by div
+ Lpad = int(div * np.ceil(xdim/div) - xdim)
+ xpad1 = extra*div//2 + Lpad//2
+ xpad2 = extra*div//2 + Lpad - Lpad//2
+ Lpad = int(div * np.ceil(ydim/div) - ydim)
+ ypad1 = extra*div//2 + Lpad//2
+ ypad2 = extra*div//2+Lpad - Lpad//2
+
+ x1, x2, y1, y2 = x1-xpad1, x2+xpad2, y1-ypad1, y2+ypad2
+ xdim = min(x2-x1, img_yx[1])
+ ydim = min(y2-y1, img_yx[0])
+
+ # Choose largest dimension for image size
+ if xdim > ydim:
+ # Adjust ydim
+ ypad = xdim-ydim
+ if ypad%2!=0:
+ ypad+=1
+ y1 = max(0, y1-ypad//2)
+ y2 = min(y2+ypad//2, img_yx[0])
+ else:
+ # Adjust xdim
+ xpad = ydim-xdim
+ if xpad%2!=0:
+ xpad+=1
+ x1 = max(0, x1-xpad//2)
+ x2 = min(x2+xpad//2, img_yx[1])
+ adjusted_bbox = (x1, x2, y1, y2)
+ return adjusted_bbox
+
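+# Illustrative usage sketch (assumption): expand a user-drawn bounding box so it is roughly square
+# and padded to a multiple of `div` (16, matching the network downsampling factor), clipped to the
+# frame size:
+#   x1, x2, y1, y2 = adjust_bbox((x1, x2, y1, y2), img_yx=(Ly, Lx))
+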
+# Following function adapted from cellpose:
+# https://github.com/MouseLand/cellpose/blob/35c16c94e285a4ec2fa17f148f06bbd414deb5b8/cellpose/transforms.py#L187
+def normalize99(img):
+ """
+ Normalize image so 0.0 is 1st percentile and 1.0 is 99th percentile
+ Parameters
+ -------------
+ img: ND-array
+ image of size [Ly x Lx]
+ Returns
+ --------------
+ X: ND-array
+ normalized image of size [Ly x Lx]
+ """
+ X = img.copy()
+ x01 = np.percentile(X, 1)
+ x99 = np.percentile(X, 99)
+ X = (X - x01) / (x99 - x01)
+ return X
diff --git a/facemap/process.py b/facemap/process.py
old mode 100644
new mode 100755
index e413e0f..873969d
--- a/facemap/process.py
+++ b/facemap/process.py
@@ -1,15 +1,18 @@
-import numpy as np
-from facemap import pupil, running, utils
-from numba import vectorize,uint8,float32
-import time
-import os, sys, subprocess
+import os
import pdb
+import subprocess
+import sys
+import time
from io import StringIO
+
+import numpy as np
+from numba import float32, uint8, vectorize
from scipy import io
from scipy.ndimage import gaussian_filter
-import cv2
from tqdm import tqdm
+from facemap import pupil, running, utils
+
def binned_inds(Ly, Lx, sbin):
Lyb = np.zeros((len(Ly),), np.int32)
Lxb = np.zeros((len(Ly),), np.int32)
@@ -72,8 +75,7 @@ def subsampled_mean(containers, cumframes, Ly, Lx, sbin=3, GUIobject=None, MainW
imbin = np.abs(np.diff(imbin, axis=0))
avgmotion[ir[n]] += imbin.mean(axis=0)
ns+=1
- update_mainwindow(MainWindow, GUIobject, s, "Computing subsampled mean ")
-
+ utils.update_mainwindow_progressbar(MainWindow, GUIobject, s, "Computing subsampled mean ")
avgframe /= float(ns)
avgmotion /= float(ns)
@@ -195,7 +197,7 @@ def compute_SVD(containers, cumframes, Ly, Lx, avgframe, avgmotion, motSVD=True,
ncb = usv[0].shape[-1]
U_mov[wmot[i]+1][:, ni_mov[wmot[i]+1]:ni_mov[wmot[i]+1]+ncb] = usv[0] * usv[1]#U[wmot[i]+1][:, ni[wmot[i]+1]:ni[wmot[i]+1]+ncb] = usv[0]
ni_mov[wmot[i]+1] += ncb
- update_mainwindow(MainWindow, GUIobject, w, "Computing SVD ")
+ utils.update_mainwindow_progressbar(MainWindow, GUIobject, w, "Computing SVD ")
if fullSVD:
if motSVD:
@@ -400,15 +402,9 @@ def process_ROIs(containers, cumframes, Ly, Lx, avgframe, avgmotion, U_mot, U_mo
if n==0:
vproj = np.concatenate((vproj[0,:][np.newaxis, :], vproj), axis=0)
V_mov[0][t:t+vproj.shape[0], :] = vproj
- update_mainwindow(MainWindow, GUIobject, s, "Computing projection ")
+ utils.update_mainwindow_progressbar(MainWindow, GUIobject, s, "Computing projection ")
return V_mot, V_mov, M, pups, blinks, runs
-
-def update_mainwindow(MainWindow, GUIobject, s, prompt):
- if MainWindow is not None and GUIobject is not None:
- message = s.getvalue().split('\x1b[A\n\r')[0].split('\r')[-1]
- MainWindow.update_status_bar(prompt+message, update_progress=True)
- GUIobject.QApplication.processEvents()
def process_pupil_ROIs(self, t, nt, img, ivid, rois, pupind, pups):
"""
@@ -486,7 +482,7 @@ def save(proc, savepath=None):
del d2
return savename
-def run(filenames, motSVD=True, movSVD=False, GUIobject=None, parent=None, proc=None, savepath=None):
+def run(filenames, sbin=1, motSVD=True, movSVD=False, GUIobject=None, parent=None, proc=None, savepath=None):
'''
Parameters
----------
@@ -512,11 +508,11 @@ def run(filenames, motSVD=True, movSVD=False, GUIobject=None, parent=None, proc=
save_mat = parent.save_mat.isChecked()
sy = parent.sy
sx = parent.sx
- motSVD, movSVD = parent.motSVD_checkbox.isChecked(), parent.movSVD_checkbox.isChecked(),
+ motSVD, movSVD = parent.motSVD_checkbox.isChecked(), parent.movSVD_checkbox.isChecked()
else:
cumframes, Ly, Lx, containers = utils.get_frame_details(filenames)
if proc is None:
- sbin = 1
+ sbin = sbin
fullSVD = True
save_mat = False
rois=None
@@ -640,4 +636,4 @@ def run(filenames, motSVD=True, movSVD=False, GUIobject=None, parent=None, proc=
GUIobject.QApplication.processEvents()
tqdm.write('run time %0.2fs'%(time.time() - start))
- return savename
\ No newline at end of file
+ return savename
diff --git a/facemap/pupil.py b/facemap/pupil.py
old mode 100644
new mode 100755
index fe82a4c..3d62a96
--- a/facemap/pupil.py
+++ b/facemap/pupil.py
@@ -1,6 +1,7 @@
import numpy as np
from scipy.ndimage import gaussian_filter
+
def fit_gaussian(im, sigma=2.0, do_xy=False, missing=None):
''' iterative fitting of pupil with gaussian @ sigma '''
ix,iy = im.nonzero()
diff --git a/facemap/registration.py b/facemap/registration.py
old mode 100644
new mode 100755
index 73b0f19..87bb410
--- a/facemap/registration.py
+++ b/facemap/registration.py
@@ -1,24 +1,21 @@
-import numpy as np
-import matplotlib.pyplot as plt
+import time
+from math import pi
+
import matplotlib.cm
+import matplotlib.pyplot as plt
+import numpy as np
import scipy.stats
-from scipy.ndimage import filters
-from math import pi
-import skimage.transform
import skimage.registration
+import skimage.transform
import sklearn.cluster
-import time
-import skimage.transform
-import skimage.registration
-from . import utils, process
-
+from scipy.ndimage import filters
+from . import process, utils
'''
MOTION TRACES
'''
-
def imall_init(nfr, Ly, Lx):
imall = []
for n in range(len(Ly)):
diff --git a/facemap/roi.py b/facemap/roi.py
old mode 100644
new mode 100755
index dea0a59..c16e9db
--- a/facemap/roi.py
+++ b/facemap/roi.py
@@ -1,17 +1,25 @@
-import sys
import os
import shutil
+import sys
import time
+
import numpy as np
-from PyQt5 import QtGui, QtCore
import pyqtgraph as pg
+from matplotlib import cm
+from PyQt5 import QtCore
from pyqtgraph import GraphicsScene
-from facemap import utils, pupil
-from scipy.stats import zscore, skew
from scipy.ndimage import gaussian_filter
-from matplotlib import cm
+from scipy.stats import skew, zscore
+
+from facemap import pupil, utils
-colors = np.array([[0,200,50],[180,0,50],[40,100,250],[150,50,150]])
+# Types of ROI and their ID:
+# 0: Pupil
+# 1: motion SVD
+# 2: Blink
+# 3: Running
+# 4: Pose bbox
+colors = np.array([[0,200,50],[180,0,50],[40,100,250],[150,50,150],[0, 255, 255]])
class reflectROI():
def __init__(self, iROI, wROI, moveable=True,
@@ -105,7 +113,7 @@ def position(self, parent):
parent.ROIs[self.iROI].plot(parent)
class sROI():
- def __init__(self, rind, rtype, iROI, moveable=True,
+ def __init__(self, rind, rtype, iROI, moveable=True, resizable=True,
parent=None, saturation=None, color=None, pos=None,
yrange=None, xrange=None,
ivid=None, pupil_sigma=None):
@@ -113,6 +121,7 @@ def __init__(self, rind, rtype, iROI, moveable=True,
self.iROI = iROI
self.rind = rind
self.rtype = rtype
+ self.pos = pos
if saturation is None:
self.saturation = 0
else:
@@ -127,7 +136,8 @@ def __init__(self, rind, rtype, iROI, moveable=True,
else:
self.pupil_sigma = 0
self.moveable = moveable
- if pos is None:
+ self.resizable = resizable
+ if self.pos is None:
view = parent.p0.viewRange()
imx = (view[0][1] + view[0][0]) / 2
imy = (view[1][1] + view[1][0]) / 2
@@ -138,10 +148,10 @@ def __init__(self, rind, rtype, iROI, moveable=True,
imx = imx - dx / 2
imy = imy - dy / 2
else:
- imy = pos[0]
- imx = pos[1]
- dy = pos[2]
- dx = pos[3]
+ imy = self.pos[0]
+ imx = self.pos[1]
+ dy = self.pos[2]
+ dx = self.pos[3]
if ivid is None:
self.ivid=0
else:
@@ -157,14 +167,14 @@ def __init__(self, rind, rtype, iROI, moveable=True,
def draw(self, parent, imy, imx, dy, dx):
roipen = pg.mkPen(self.color, width=3,
style=QtCore.Qt.SolidLine)
- if self.rind==1 or self.rind==3:
+ if self.rind==1 or self.rind==3 or self.rind==4:
self.ROI = pg.RectROI(
- [imx, imy], [dx, dy], movable = self.moveable,
+ [imx, imy], [dx, dy], movable = self.moveable, resizable=self.resizable,
pen=roipen, sideScalers=True, removable=self.moveable
)
else:
self.ROI = pg.EllipseROI(
- [imx, imy], [dx, dy], movable = self.moveable,
+ [imx, imy], [dx, dy], movable = self.moveable, resizable=self.resizable,
pen=roipen, removable=self.moveable
)
self.ROI.handleSize = 8
@@ -191,6 +201,8 @@ def position(self, parent):
sizex, sizey = self.ROI.size()
xrange = (np.arange(-1 * int(sizex), 1) + int(posx)).astype(np.int32)
yrange = (np.arange(-1 * int(sizey), 1) + int(posy)).astype(np.int32)
+        self.pos = posy, posx, posy+sizey, posx+sizex  # store current ROI position
if self.rind==0 or self.rind==2:
yrange += int(sizey/2)
# what is ellipse circling?
@@ -225,14 +237,14 @@ def position(self, parent):
yrange -= parent.sy[ivid]
self.xrange = xrange
self.yrange = yrange
- self.ivid = ivid
+ self.ivid = ivid
if self.rind==0:
self.rmin = 0
parent.reflectors[self.iROI] = utils.get_reflector(parent.ROIs[self.iROI].yrange,
parent.ROIs[self.iROI].xrange,
rROI=parent.rROI[self.iROI])
- parent.sl[1].setValue(parent.saturation[self.iROI] * 100 / 255)
+ parent.sl[1].setValue(int(parent.saturation[self.iROI] * 100 / 255))
index = parent.clusteringVisComboBox.findText("ROI", QtCore.Qt.MatchFixedString)
if index >= 0:
@@ -339,7 +351,7 @@ def plot(self, parent):
parent.show()
parent.online_plotted = True
#self.p2.setLimits(xMin=0,xMax=self.nframes)
- elif self.rind==1 or self.rind==3:
+ elif self.rind==1 or self.rind==3 or self.rind==4:
parent.pROI.removeItem(parent.scatter)
parent.scatter = pg.ScatterPlotItem([0], [0], pen='k', symbol='+')
parent.pROI.addItem(parent.scatter)
@@ -363,3 +375,4 @@ def plot(self, parent):
padding=0.0)
parent.win.show()
parent.show()
+
diff --git a/facemap/running.py b/facemap/running.py
old mode 100644
new mode 100755
index 68fd329..454ecc0
--- a/facemap/running.py
+++ b/facemap/running.py
@@ -1,11 +1,10 @@
# outputs the dx, dy offsets between frames by registering frame N to frame
# N-1. If the movement is larger than half the frame size, outputs NaN.
# ops.yrange, xrange are ranges to use for rectangular section of movie
-from scipy.fftpack import next_fast_len
import numpy as np
-from numpy.fft import ifftshift
-from mkl_fft import fft2, ifft2
-from numba import vectorize, float32, complex64, uint8, int16
+from numba import complex64, float32, int16, uint8, vectorize
+from numpy.fft import fft2, ifft2, ifftshift
+from scipy.fftpack import next_fast_len
eps0 = 1e-20
diff --git a/facemap/utils.py b/facemap/utils.py
old mode 100644
new mode 100755
index 0276ef1..3e6254b
--- a/facemap/utils.py
+++ b/facemap/utils.py
@@ -1,12 +1,23 @@
-import numpy as np
-from scipy.sparse.linalg import eigsh
import cv2
-from scipy.ndimage import gaussian_filter1d
+import numpy as np
from scipy.interpolate import interp1d
from scipy.linalg import eigh
+from scipy.ndimage import gaussian_filter1d
+from scipy.sparse.linalg import eigsh
from sklearn.decomposition import PCA
from tqdm import tqdm
+def update_mainwindow_progressbar(MainWindow, GUIobject, s, prompt):
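+    # Parse the latest progress line from the redirected tqdm/stdout buffer `s` and show it in the
+    # GUI status bar; does nothing when MainWindow/GUIobject are None (headless runs).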
+ if MainWindow is not None and GUIobject is not None:
+ message = s.getvalue().split('\x1b[A\n\r')[0].split('\r')[-1]
+ MainWindow.update_status_bar(prompt+message, update_progress=True, hide_progress=False)
+ GUIobject.QApplication.processEvents()
+
+def update_mainwindow_message(MainWindow, GUIobject, prompt, hide_progress=True):
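+    # Show a plain status message in the GUI status bar (optionally hiding the progress bar); no-op when run headless.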
+ if MainWindow is not None and GUIobject is not None:
+ MainWindow.update_status_bar(prompt, update_progress=False, hide_progress=hide_progress)
+ GUIobject.QApplication.processEvents()
+
def bin1d(X, tbin):
""" bin over first axis of data with bin tbin """
size = list(X.shape)
@@ -22,18 +33,37 @@ def split_testtrain(n_t, frac=0.25):
itest = (ninds[:,np.newaxis] + np.arange(0,n_len * frac,1,int)).flatten()
itrain = np.ones(n_t,np.bool)
itrain[itest] = 0
-
return itest, itrain
-
-def rrr_prediction(X, Y, rank=None, lam=0):
+def get_frame(cframe, nframes, cumframes, containers):
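+    # Return frame number `cframe` (indexed across sequentially recorded videos via `cumframes`)
+    # as a list of images, one per simultaneously recorded camera view in `containers`.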
+ cframe = np.maximum(0, np.minimum(nframes-1, cframe))
+ cframe = int(cframe)
+ try:
+ ivid = (cumframes < cframe).nonzero()[0][-1]
+ except:
+ ivid = 0
+ img = []
+ for vs in containers[ivid]:
+ frame_ind = cframe - cumframes[ivid]
+ capture = vs
+ if int(capture.get(cv2.CAP_PROP_POS_FRAMES)) != frame_ind:
+ capture.set(cv2.CAP_PROP_POS_FRAMES, frame_ind)
+ ret, frame = capture.read()
+ if ret:
+ img.append(frame)
+ else:
+ print("Error reading frame")
+ return img
+
+def rrr_prediction(X, Y, rank=None, lam=0, itrain=None, itest=None):
""" predict Y from X using regularized reduced rank regression
returns prediction accuracy on test data + model params
"""
n_t, n_feats = Y.shape
- itest, itrain = split_testtrain(n_t)
+ if itrain is None and itest is None:
+ itest, itrain = split_testtrain(n_t)
A,B = reduced_rank_regression(X[itrain], Y[itrain], rank=rank, lam=lam)
rank = A.shape[1]
corrf = np.zeros((rank, n_feats))
@@ -196,7 +226,6 @@ def get_frames(imall, containers, cframes, cumframes):
for ii,im in enumerate(imall):
imall[ii] = im[:nk].copy()
-
def close_videos(containers):
''' Method is called to close all videos/containers open for reading
using openCV.
@@ -361,6 +390,7 @@ def video_placement(Ly, Lx):
def svdecon(X, k=100):
np.random.seed(0) # Fix seed to get same output for eigsh
+ """
v0 = np.random.uniform(-1,1,size=min(X.shape))
NN, NT = X.shape
if NN>NT:
@@ -379,4 +409,6 @@ def svdecon(X, k=100):
else:
V = (U.T @ X).T
V = V/(V**2).sum(axis=0)**.5
+ """
+ U, Sv, V = PCA(n_components=k, svd_solver='randomized', random_state=np.random.RandomState(0))._fit(X)
return U, Sv, V
diff --git a/figs/mouse_face0_keypoints.png b/figs/mouse_face0_keypoints.png
new file mode 100644
index 0000000..a708444
Binary files /dev/null and b/figs/mouse_face0_keypoints.png differ
diff --git a/figs/mouse_face1_keypoints.png b/figs/mouse_face1_keypoints.png
new file mode 100644
index 0000000..942a210
Binary files /dev/null and b/figs/mouse_face1_keypoints.png differ
diff --git a/figs/mouse_views.png b/figs/mouse_views.png
new file mode 100644
index 0000000..95c1139
Binary files /dev/null and b/figs/mouse_views.png differ
diff --git a/figs/tracker.gif b/figs/tracker.gif
new file mode 100644
index 0000000..d0be1af
Binary files /dev/null and b/figs/tracker.gif differ
diff --git a/notebooks/process.ipynb b/notebooks/process.ipynb
new file mode 100644
index 0000000..809f4aa
--- /dev/null
+++ b/notebooks/process.ipynb
@@ -0,0 +1,204 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Facemap"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Function call `process.run()` saves a `.npy` file that contains the following variables:\n",
+ "- filenames: list of lists of video filenames - each list are the videos taken simultaneously\n",
+ "- Ly, Lx: list of number of pixels in Y (Ly) and X (Lx) for each video taken simultaneously\n",
+ "- sbin: spatial bin size for motion SVDs\n",
+ "- Lybin, Lxbin: list of number of pixels binned by sbin in Y (Ly) and X (Lx) for each video taken simultaneously\n",
+ "- sybin, sxbin: coordinates of multivideo (for plotting/reshaping ONLY)\n",
+ "- LYbin, LXbin: full-size of all videos embedded in rectangle (binned)\n",
+ "- fullSVD: (bool) whether or not \"multivideo SVD\" is computed\n",
+ "- save_mat: (bool) whether or not to save proc as *.mat file\n",
+ "- avgframe: list of average frames for each video from a subset of frames (binned by sbin)\n",
+ "- avgframe_reshape: average frame reshaped to be y-pixels x x-pixels\n",
+ "- avgmotion: list of average motions for each video from a subset of frames (binned by sbin)\n",
+ "- avgmotion_reshape: average motion reshaped to be y-pixels x x-pixels\n",
+ "- motion: list of absolute motion energies across time - first is \"multivideo\" motion energy (empty if not computed)\n",
+ "- motSVD: list of motion SVDs - first is \"multivideo SVD\" (empty if not computed) - each is nframes x components\n",
+ "- motMask: list of motion masks for each motion SVD - each motMask is pixels x components\n",
+ "- motMask_reshape: motion masks reshaped to be y-pixels x x-pixels x components\n",
+ "- pupil: list of pupil ROI outputs - each is a dict with 'area', 'area_smooth', and 'com' (center-of-mass)\n",
+ "- blink: list of blink ROI outputs - each is nframes, the blink area on each frame\n",
+ "- running: list of running ROI outputs - each is nframes x 2, for X and Y motion on each frame\n",
+ "- rois: ROIs that were drawn and computed\n",
+ " - rind: type of ROI in number\n",
+ " - rtype: what type of ROI ('motion SVD', 'pupil', 'blink', 'running')\n",
+ " - ivid: in which video is the ROI\n",
+ " - color: color of ROI\n",
+ " - yrange: y indices of ROI\n",
+ " - xrange: x indices of ROI\n",
+    "    - saturation: saturation of ROI (0-255)\n",
+    "    - pupil_sigma: number of stddevs used to compute pupil radius (for pupil ROIs)\n",
+    "    - yrange_bin: binned indices in y (if motion SVD)\n",
+    "    - xrange_bin: binned indices in x (if motion SVD)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "The above variables are related to motion energy, which uses the absolute value of differences between consecutive frames, i.e. abs(frame$_{t+1}$ - frame$_{t}$). To perform SVD computation on the raw movie frames over time, use the flag `movSVD=True` [default=False] in the `process.run()` function call. Variables pertaining to movie SVDs include:\n",
+ "- movSVD: list of movie SVDs - first is \"multivideo SVD\" (empty if not computed) - each is nframes x components\n",
+ "- movMask: list of movie masks for each movie SVD - each movMask is pixels x components\n",
+ "- movMask_reshape: movie masks reshaped to be y-pixels x x-pixels x components\n",
+    "\n",
+    "New variables:\n",
+    "- motSv: array containing singular values for motSVD\n",
+    "- movSv: array containing singular values for movSVD"
+ ]
+ },
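+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A minimal sketch (illustrative; the path is a placeholder) of loading the saved `_proc.npy` file and reading a few of the variables listed above:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "proc = np.load('.../cam1_G7c1_1_proc.npy', allow_pickle=True).item()\n",
+    "print(proc['Ly'], proc['Lx'])   # pixel sizes of each video\n",
+    "motSVD = proc['motSVD'][0]      # first entry is the \"multivideo\" motion SVD (if fullSVD was computed)\n",
+    "print(motSVD.shape)             # nframes x components"
+   ]
+  },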
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "`process.run()` function call takes the following parameters:\n",
+ "- filenames: A 2D list of names of video(s) to get\n",
+ "- motSVD: default=True\n",
+ "- movSVD: default=False\n",
+ "- GUIobject=None\n",
+ "- parent: default=None, parent is from GUI\n",
+ "- proc: default=None, proc can be a saved ROI file from GUI \n",
+ "- savepath: default=None => set to video folder, specify a folder path in which to save _proc.npy "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Import packages"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from facemap import process\n",
+ "from glob import glob"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Set variables"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Example file list:\n",
+ " - cam1_G7c1_1.avi\n",
+ " - cam1_G7c1_2.avi\n",
+ " - cam2_G7c1_1.avi\n",
+ " - cam2_G7c1_2.avi\n",
+ " - cam3_G7c1_1.avi\n",
+ " - cam3_G7c1_2.avi"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "simultaneous_video_list = [['.../cam1_G7c1_1.avi',\n",
+ " '.../cam2_G7c1_1.avi',\n",
+ " '.../cam3_G7c1_1.avi',\n",
+ " '.../cam4_G7c1_1.avi']]\n",
+ "sequential_video_list = [['.../cam1_G7c1_1.avi', '.../cam1_G7c1_2.avi'],\n",
+ " ['.../cam2_G7c1_1.avi', '.../cam2_G7c1_2.avi'],\n",
+ " ['.../cam3_G7c1_1.avi', '.../cam3_G7c1_2.avi'],\n",
+ " ['.../cam4_G7c1_1.avi', '.../cam4_G7c1_2.avi']]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Process videos recorded simultaneously from different cam/views"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "savename = process.run(simultaneous_video_list)\n",
+ "print(\"Output saved in\", savename)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Process videos recorded sequentially"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "savename = process.run(sequential_video_list)\n",
+ "print(\"Output saved in\", savename)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Process videos from multiple sessions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "session_folders = ['location1', 'location2', ..., 'locationN']  # replace with paths to session folders\n",
+    "for folder in session_folders:\n",
+    "    video_files = glob(folder+\"/*.ext\")  # replace .ext with one of ['*.mj2','*.mp4','*.mkv','*.avi','*.mpeg','*.mpg','*.asf']\n",
+    "    process.run([video_files])\n",
+    "    # if SVDs of ROIs are required, use 'save ROIs' from GUI and use the following command\n",
+    "    process.run([video_files], proc=\"/path_to_saved_rois\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/tutorial.ipynb b/notebooks/tutorial.ipynb
old mode 100644
new mode 100755
similarity index 99%
rename from tutorial.ipynb
rename to notebooks/tutorial.ipynb
index 55e4ff1..900de3a
--- a/tutorial.ipynb
+++ b/notebooks/tutorial.ipynb
@@ -1138,7 +1138,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.7.6"
+ "version": "3.8.3"
}
},
"nbformat": 4,
diff --git a/setup.py b/setup.py
index 689c877..cf855c3 100644
--- a/setup.py
+++ b/setup.py
@@ -1,10 +1,24 @@
import setuptools
+install_deps = ['numpy>=1.16',
+ 'scipy',
+ 'matplotlib',
+ 'natsort',
+ 'tqdm',
+ 'numba>=0.43.1',
+ 'opencv-python-headless',
+ 'hdbscan',
+ 'torch>=1.9',
+ 'umap-learn',
+ 'pandas',
+ 'scikit-image']
+
with open("README.md", "r") as fh:
long_description = fh.read()
setuptools.setup(
name="facemap",
+ license="GPLv3",
version="0.2.0",
author="Carsen Stringer & Atika Syeda & Renee Tung",
author_email="carsen.stringer@gmail.com",
@@ -13,8 +27,15 @@
long_description_content_type="text/markdown",
url="https://github.com/MouseLand/FaceMap",
packages=setuptools.find_packages(),
- install_requires = ['pyqtgraph==0.11.0rc0', 'PyQt5', 'PyQt5.sip',
- 'numpy>=1.13.0', 'scipy', 'numba', 'natsort'],
+ install_requires = install_deps,
+ tests_require = ['pytest', 'tqdm'],
+ extras_require = {
+ 'gui': [
+ 'pyqtgraph==0.12.0',
+ 'pyqt5',
+ 'pyqt5.sip',
+ ]
+    },
include_package_data=True,
classifiers=(
"Programming Language :: Python :: 3",
diff --git a/tests/conftest.py b/tests/conftest.py
index 8f3a781..c769b4b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,4 @@
+from genericpath import exists
import pytest
import os, sys, tempfile, shutil
from tqdm import tqdm
@@ -6,13 +7,12 @@
@pytest.fixture()
def video_names():
- video1_names = ['cam1_test.avi']
- video2_names = ['cam2_test.avi']
- return video1_names, video2_names
+ video1_name = 'cam1_test.avi'
+ video2_name = 'cam2_test.avi'
+ return video1_name, video2_name
@pytest.fixture()
def data_dir(video_names):
- #data_dir = os.path.join(os.getcwd(),'Video_samples/sample_movies/')
fm_dir = Path.home().joinpath('.facemap')
fm_dir.mkdir(exist_ok=True)
data_dir = fm_dir.joinpath('data')
@@ -30,8 +30,22 @@ def data_dir(video_names):
cached_file = str(data_dir_cam2.joinpath(video_name))
if not os.path.exists(cached_file):
download_url_to_file(url, cached_file)
+
return data_dir
+@pytest.fixture()
+def expected_output_dir(data_dir):
+ expected_output_dir = data_dir.joinpath('expected_output')
+ expected_output_dir.mkdir(exist_ok=True)
+ # Download expected output files
+ """
+ download_url_to_file('https://www.facemappy.org/test_data/singlevideo_proc.npy',
+ expected_output_dir.joinpath('singlevideo_proc.npy'))
+ download_url_to_file('https://www.facemappy.org/test_data/multivideo_proc.npy',
+ expected_output_dir.joinpath('multivideo_proc.npy'))
+ """
+ return expected_output_dir
+
def download_url_to_file(url, dst, progress=True):
# Following adapted from https://github.com/MouseLand/cellpose/blob/35c16c94e285a4ec2fa17f148f06bbd414deb5b8/cellpose/utils.py#L45
"""Download object at the given URL to a local path.
diff --git a/tests/test_output.py b/tests/test_output.py
index 9245cbc..ea1faa6 100644
--- a/tests/test_output.py
+++ b/tests/test_output.py
@@ -1,4 +1,5 @@
"Test facemap pipeline by comparing outputs"
+from numpy.lib.npyio import save
from facemap import process
import numpy as np
from pathlib import Path
@@ -6,47 +7,59 @@
r_tol, a_tol = 1e-2, 1e-2
-def test_output_single_video(data_dir, video_names):
+def test_output_single_video(data_dir, video_names, expected_output_dir):
clear_output(data_dir, video_names)
v1, _ = video_names
- test_filenames = [[str(data_dir.joinpath('data').joinpath('cam1').joinpath(v1))]] # [[data_dir+video for video in v1]]
- save_path = str(data_dir.joinpath('data').joinpath('cam1'))
- process.run(test_filenames, movSVD=True, savepath=save_path)
+ test_filenames = [[str(data_dir.joinpath('cam1').joinpath(v1))]] # [[data_dir+video for video in v1]]
+ save_path = str(data_dir.joinpath('cam1'))
+ output_filename, _ = v1.split(".")
+ test_proc_filename = os.path.join(save_path,output_filename+"_proc.npy")
+ # Process video
+ process.run(test_filenames, sbin=7, motSVD=True, movSVD=True, savepath=save_path)
- output_filename, _ = os.path.splitext(v1[0])
- test_proc_filename = save_path.joinpath(output_filename+"_proc.npy")
+ # Compare output
output = np.load(test_proc_filename,allow_pickle=True).item()
- expected_proc_filename = os.getcwd()+"/tests/expected_output/singlevideo_proc.npy"
+ expected_proc_filename = expected_output_dir.joinpath("singlevideo_proc.npy")
expected_output = np.load(expected_proc_filename,allow_pickle=True).item()
clear_output(data_dir, video_names)
assert is_output_correct(output, expected_output)
-def test_output_multivideo(data_dir, video_names):
+def test_output_multivideo(data_dir, video_names, expected_output_dir):
clear_output(data_dir, video_names)
v1, v2 = video_names
- test1 = str(data_dir.joinpath('data').joinpath('cam1').joinpath(v1))#os.path.join(data_dir,v1[0])
- test2 = str(data_dir.joinpath('data').joinpath('cam2').joinpath(v2))#os.path.join(data_dir,v2[0])
+ test1 = str(data_dir.joinpath('cam1').joinpath(v1))
+ test2 = str(data_dir.joinpath('cam2').joinpath(v2))
+
# For videos recorded simultaneously from multiple cams
test_filenames = [[test1, test2]]
- save_path = str(data_dir.joinpath('data').joinpath('cam2'))
- process.run(test_filenames, movSVD=True, savepath=save_path)
+ save_path = str(data_dir.joinpath('cam2'))
+ output_filename, _ = v1.split(".")
+ test_proc_filename = os.path.join(save_path, output_filename+"_proc.npy")
+ print(test_proc_filename)
+ # Process videos
+ process.run(test_filenames, sbin=12, motSVD=True, movSVD=True, savepath=save_path)
- output_filename, _ = os.path.splitext(v2[0])
- test_proc_filename = save_path.joinpath(output_filename+"_proc.npy")
+ # Compare output
output = np.load(test_proc_filename,allow_pickle=True).item()
- expected_proc_filename = os.getcwd()+"/tests/expected_output/multivideo_proc.npy"
+ expected_proc_filename = expected_output_dir.joinpath("multivideo_proc.npy")
expected_output = np.load(expected_proc_filename,allow_pickle=True).item()
clear_output(data_dir, video_names)
-
+
assert is_output_correct(output, expected_output)
+ clear_expected_output(expected_output_dir)
def is_output_correct(test_output, expected_output):
params_match = check_params(test_output, expected_output)
+ print("params match", params_match)
frames_match = check_frames(test_output, expected_output)
+ print("frames_match", frames_match)
motion_match = check_motion(test_output, expected_output)
+ print("motion_match", motion_match)
U_match = check_U(test_output, expected_output)
+ print("U_match", U_match)
V_match = check_V(test_output, expected_output)
+ print("V_match", V_match)
return params_match and frames_match and motion_match and U_match and V_match
def check_params(test_output, expected_output):
@@ -62,7 +75,6 @@ def check_params(test_output, expected_output):
return all_outputs_match
def check_frames(test_output, expected_output):
- print(test_output['avgframe'][0].shape, expected_output['avgframe'][0].shape)
avgframes_match = np.allclose(test_output['avgframe'][0], expected_output['avgframe'][0],
rtol=r_tol, atol=a_tol)
avgmotion_match = np.allclose(test_output['avgmotion'][0], expected_output['avgmotion'][0],
@@ -102,3 +114,9 @@ def clear_output(data_dir, video_names):
output = name + '_proc.npy'
if os.path.exists(output):
os.remove(output)
+
+def clear_expected_output(expected_output_dir):
+ files = ['singlevideo_proc.npy', 'multivideo_proc.npy']
+ for f in files:
+ if os.path.exists(expected_output_dir.joinpath(f)):
+ os.remove(expected_output_dir.joinpath(f))
diff --git a/tox.ini b/tox.ini
index 4d69aa8..b84af20 100644
--- a/tox.ini
+++ b/tox.ini
@@ -26,7 +26,8 @@ passenv =
NUMPY_EXPERIMENTAL_ARRAY_FUNCTION
PYVISTA_OFF_SCREEN
deps =
+ .[all]
pytest # https://docs.pytest.org/en/latest/contents.html
pytest-cov # https://pytest-cov.readthedocs.io/en/latest/
pytest-xvfb ; sys_platform == 'linux'
-commands = pytest -v --color=yes --cov=facemap --cov-report=xml
\ No newline at end of file
+commands = pytest -v --color=yes --cov=facemap --cov-report=xml