Skip to content

Commit

Permalink
Enable CUDA backend
Browse files Browse the repository at this point in the history
  • Loading branch information
vicentebolea committed Jan 31, 2024
1 parent bb15d68 commit c40706a
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 8 deletions.
1 change: 1 addition & 0 deletions nrtk_explorer/app/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
r"""
Define your classes and create the instances that you need to expose
"""

import logging
from trame.app import get_server
from trame.ui.quasar import QLayout
Expand Down
1 change: 1 addition & 0 deletions nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
r"""
Define your classes and create the instances that you need to expose
"""

import logging
from typing import Dict

Expand Down
44 changes: 40 additions & 4 deletions nrtk_explorer/library/embeddings_extractor.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,43 @@
from nrtk_explorer.library import images_manager

import logging
import numpy as np
import warnings

import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

with warnings.catch_warnings():
warnings.simplefilter("ignore")
warnings.simplefilter("ignore", category=UserWarning)
import timm
import numpy as np
import torch


class EmbeddingsExtractor:
def __init__(self, model_name="resnet50d", manager=None):
self.images = dict()
self.features = dict()
self.model = model_name
if manager is not None:
self.manager = manager
else:
self.manager = images_manager.ImagesManager()

if torch.cuda.is_available():
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
self.device = torch.device("cuda")
logging.info("Using CUDA devices for feature extraction")
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
self.device = torch.device("cpu")
logging.info("Using CPU devices for feature extraction")

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
self.model = model_name

@property
def model(self):
return self._model
Expand All @@ -26,6 +46,10 @@ def model(self):
def model(self, model_name):
# Create model but do not train it
model = timm.create_model(model_name, pretrained=True, num_classes=0)

# Copy the model to the requested device
model = model.to(self.device)

for param in model.parameters():
param.requires_grad = False

Expand All @@ -45,7 +69,19 @@ def extract(self, paths, cache=True, content=None):
img = content[path]
else:
img = self.manager.LoadImage(path)
features = self.model(self.transforms(img).unsqueeze(0))

img_transformation = self.transforms(img).unsqueeze(0)

# Copy image to device if using device
if self.device.type == "cuda":
img_transformation = img_transformation.cuda()

features = self.model(img_transformation)

# Copy output to cpu if using device
if self.device.type == "cuda":
features = features.cpu()

self.features[path] = features[0]
requested_features.append(self.features[path])

Expand Down
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ dependencies = [
"numpy",
"pandas",
"Pillow",
"plotly",
"scikit-image==0.22.0",
"scikit-learn==1.4.0",
"smqtk-classifier==0.19.0",
Expand All @@ -39,17 +38,16 @@ dependencies = [
"smqtk-descriptors==0.19.0",
"smqtk-detection[torch,centernet]==0.20.1",
"smqtk-image-io==0.17.1",
"tensorflow[and-cuda]>=2.15.0",
"timm",
"torch",
"torchvision",
"trame",
"trame-client>=2.15.0",
"trame-server>=2.15.0",
"trame-plotly",
"trame-quasar",
"trame-server>=2.15.0",
"ubelt==1.3.4",
"umap-learn",
"xaitk-saliency==0.7.0",
]

[project.optional-dependencies]
Expand Down

0 comments on commit c40706a

Please sign in to comment.