Skip to content

Commit

Permalink
add demo_patchvq
Browse files Browse the repository at this point in the history
  • Loading branch information
baidut committed Apr 13, 2022
1 parent 5ec3152 commit 191715b
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 6 deletions.
198 changes: 198 additions & 0 deletions demo_patchvq.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "demo_patchvq.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# Set up environment"
],
"metadata": {
"id": "EIMzjDkAbioJ"
}
},
{
"cell_type": "markdown",
"source": [
"| torch.__version__ | torchvision.__version__ | fastai.__version__ | |\n",
"| ----------------- | ----------------------- | ------------------ | --------------------------------------------------------- |\n",
"| 1.11.0 | 0.12.0 | 2.5.6 | ✓ |\n",
"| 1.9.1, 1.10.0 | 0.10.1 | 2.5.3 | ✓ |\n",
"| 1.6.0 | 0.7.0 | 2.0.18 | ✗ (use [this](https://github.com/baidut/PatchVQ) instead) |\n",
"\n"
],
"metadata": {
"id": "PA1nRYPfdEfW"
}
},
{
"cell_type": "markdown",
"source": [
"[Optional] Create a virtual environment\n",
"\n",
"1. install [miniconda](https://docs.conda.io/en/latest/miniconda.html)\n",
"1. `conda create --name venv python=3.9`\n",
"1. `conda activate venv`\n",
"\n",
"Install packages:\n",
"\n",
"1. [Pytorch](https://pytorch.org/get-started/locally/) \n",
"2. `pip install jupyter timm seaborn wandb logru fastai`\n"
],
"metadata": {
"id": "1t7NV4aAdYG1"
}
},
{
"cell_type": "code",
"source": [
"from fastiqa.basics import *"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "u6SG-FXJdWiW",
"outputId": "6a5702c3-e367-4d3a-ea4f-6c0d53da27a6"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"3.9.12 (main, Apr 5 2022, 06:56:58) \n",
"[GCC 7.5.0]\n",
"fastai.__version__(>= 2.5.3): 2.5.6\n",
"fastcore.__version__: 1.4.2\n",
"torch.__version__(>= 1.9.1): 1.11.0 w/ cuda \n",
"torchvision.__version__(>= 0.10.1): 0.12.0\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Test PatchVQ on your dataset"
],
"metadata": {
"id": "8h2ze4s8Y3-R"
}
},
{
"cell_type": "markdown",
"source": [
"```\n",
"database_folder\n",
"├── labels.csv\n",
"├── dbinfo.json\n",
"├── jpg\n",
"│ ├── video1_image_folder\n",
"│ ├── video2_image_folder\n",
"│ └── ...\n",
"|\n",
"```"
],
"metadata": {
"id": "uft_fjPsjkPG"
}
},
{
"cell_type": "markdown",
"source": [
"## Prepare dataset\n",
"\n",
"1. Extract video frames\n",
"2. Under your dataset folder, put a json file (`dbinfo.json`) containing its key information:\n",
"\n",
"```json\n",
"{\n",
" \"__name__\": \"LIVE_VQC\", # the name of the database\n",
" \"csv_labels\": \"labels.csv\", # path to the label CSV file\n",
" \"fn_col\": \"name\", # filename column in the CSV table \n",
" \"label_col\": \"mos\", # label column in the CSV table\n",
" \"folder\": \"jpg\" # the folder containing extracted video frames\n",
"}\n",
"```"
],
"metadata": {
"id": "POQcjUU1Y85x"
}
},
{
"cell_type": "markdown",
"source": [
"## Extract features"
],
"metadata": {
"id": "2SUqexqdbayN"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9JkWMtvXYww-"
},
"outputs": [],
"source": [
"from fastiqa.patchvq.main import *\n",
"dbinfo = load_dbinfo(\"/path/to/dbinfo.json\")\n",
"\n",
"model = PatchVQ()\n",
"model.extractFeatures(dbinfo)"
]
},
{
"cell_type": "markdown",
"source": [
"The extracted features will be saved under `features` folder\n",
"\n",
"```\n",
"database_folder\n",
"├── labels.csv\n",
"├── dbinfo.json\n",
"├── jpg\n",
"│ ├── video1_image_folder\n",
"│ ├── video2_image_folder\n",
"│ └── ...\n",
"|\n",
"├── features \n",
"│ ├── r3d18\n",
"│ └── r3d18_pooled\n",
"│ ├── paq2piq\n",
"│ └── paq2piq_pooled\n",
"\n",
"```"
],
"metadata": {
"id": "hxKRf_tyjs0M"
}
},
{
"cell_type": "markdown",
"source": [
"## Extract Scores (coming soon):"
],
"metadata": {
"id": "NUoWPzicfTQP"
}
}
]
}
7 changes: 4 additions & 3 deletions fastiqa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from functools import partial
import sys; print (sys.version)
import fastai; print(f'fastai.__version__(expected 2.5.3): {fastai.__version__}')
import torch; print(f"torch.__version__(expected 1.9.1): {torch.__version__} {'w/' if torch.cuda.is_available() else 'w/o'} cuda ")
import torchvision; print(f'torchvision.__version__(expected 0.10.1): {torchvision.__version__}')
import fastai; print(f'fastai.__version__(>= 2.5.3): {fastai.__version__}')
import fastcore; print(f'fastcore.__version__: {fastcore.__version__}')
import torch; print(f"torch.__version__(>= 1.9.1): {torch.__version__} {'w/' if torch.cuda.is_available() else 'w/o'} cuda ")
import torchvision; print(f'torchvision.__version__(>= 0.10.1): {torchvision.__version__}')
6 changes: 3 additions & 3 deletions fastiqa/patchvq/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,12 @@ def bunch(self, dls, bs=128):
self.n_out = 1
return dls

def extractFeatures(self, featname, dbinfo):
def extractFeatures(self, dbinfo):
model_state = load_state_dict_from_url('https://github.com/baidut/PatchVQ/releases/download/v0.1/RoIPoolModel-fit.10.bs.120.pth')
self.roipool('paq2piq', dbinfo, backbone=resnet18, model_state=model_state)
self.soipool('paq2piq', dbinfo)

self.roipool('r3d18', dbinfo, backbone=r3d18_K_200ep, batch_size=1)

self.soipool('paq2piq', dbinfo)
self.soipool('r3d18', dbinfo)


Expand Down

0 comments on commit 191715b

Please sign in to comment.