Skip to content

Commit

Permalink
Allow user-defined model class in feature extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Milk committed Jun 12, 2024
1 parent fca901e commit c1e8c4a
Showing 1 changed file with 39 additions and 28 deletions.
67 changes: 39 additions & 28 deletions lazyslide/tl/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def feature_extraction(
wsi: WSI,
model: str | Any,
create_opts: dict = None,
model_func: Callable = None,
transform: Callable = None,
scriptable: bool = True,
compile: bool = True,
Expand All @@ -58,30 +59,33 @@ def feature_extraction(
except ImportError:
raise ImportError("Feature extraction requires pytorch and timm (optional).")

model_path = Path(model)
model_name = model_path.stem
if model_path.exists():
try:
model = torch.load(model)
except: # noqa: E722
model = torch.jit.load(model)
else:
try:
import timm
except ImportError:
raise ImportError("Using model from model market requires timm.")
try:
create_opts = {} if create_opts is None else create_opts
model_name = model
model = timm.create_model(
model, pretrained=True, scriptable=scriptable, **create_opts
)
if transform is None:
# data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
# transform = timm.data.create_transform(**data_cfg)
transform = get_default_transform()
except Exception as _: # noqa: E722
raise ValueError(f"Model {model} not found.")
try:
model_path = Path(model)
model_name = model_path.stem
if model_path.exists():
try:
model = torch.load(model)
except: # noqa: E722
model = torch.jit.load(model)
else:
try:
import timm
except ImportError:
raise ImportError("Using model from model market requires timm.")
try:
create_opts = {} if create_opts is None else create_opts
model_name = model
model = timm.create_model(
model, pretrained=True, scriptable=scriptable, **create_opts
)
if transform is None:
# data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
# transform = timm.data.create_transform(**data_cfg)
transform = get_default_transform()
except Exception as _: # noqa: E722
raise ValueError(f"Model {model} not found.")
except: # noqa: E722
model = model

if compile:
compile_opts = {} if compile_opts is None else compile_opts
Expand All @@ -90,6 +94,11 @@ def feature_extraction(
model = model.to(device)
model.eval()

if model_func is None:

def model_func(model, image):
return model(image)

# Create dataloader
# Auto chunk the wsi tile coordinates to the number of workers'
tiles_count = len(wsi.sdata.points[tile_key])
Expand Down Expand Up @@ -117,7 +126,9 @@ def feature_extraction(
chunks = chunker(np.arange(tiles_count), num_workers)
dataset = WSIImageDataset(wsi, transform=transform, key=tile_key)
futures = [
executor.submit(_inference, dataset, chunk, model, queue)
executor.submit(
_inference, dataset, chunk, model, queue, model_func
)
for chunk in chunks
]
while any(future.running() for future in futures):
Expand All @@ -140,7 +151,7 @@ def feature_extraction(
features = []
with torch.inference_mode():
for batch in loader:
output = model(batch.to(device))
output = model_func(model, (batch.to(device)))
features.append(output.cpu().numpy())
pbar.update(task, advance=batch_size)
features = np.vstack(features)
Expand All @@ -167,7 +178,7 @@ def chunker(seq, num_workers):
return out


def _inference(dataset, chunk, model, queue):
def _inference(dataset, chunk, model, queue, model_func):
import torch

with torch.inference_mode():
Expand All @@ -176,7 +187,7 @@ def _inference(dataset, chunk, model, queue):
img = dataset[c]
# image to 4d
img = img.unsqueeze(0)
output = model(img)
output = model_func(model, img)
X.append(output.cpu().numpy())
queue.put(1)
return X

0 comments on commit c1e8c4a

Please sign in to comment.