Merge branch 'divya/fix-topdown-aug' into divya/tracker
gitttt-1234 authored Sep 11, 2024
2 parents 5e77de9 + d8f49a8 commit 58a9715
Showing 8 changed files with 161 additions and 39 deletions.
30 changes: 15 additions & 15 deletions sleap_nn/data/pipelines.py
@@ -73,28 +73,28 @@ def make_training_pipeline(
provider=provider,
)

if use_augmentations and "intensity" in self.data_config.augmentation_config:
datapipe = KorniaAugmenter(
datapipe,
**dict(self.data_config.augmentation_config.intensity),
image_key="image",
instance_key="instances",
)
if use_augmentations:
if "intensity" in self.data_config.augmentation_config:
datapipe = KorniaAugmenter(
datapipe,
**dict(self.data_config.augmentation_config.intensity),
image_key="image",
instance_key="instances",
)
if "geometric" in self.data_config.augmentation_config:
datapipe = KorniaAugmenter(
datapipe,
**dict(self.data_config.augmentation_config.geometric),
image_key="image",
instance_key="instances",
)

datapipe = InstanceCentroidFinder(
datapipe, anchor_ind=self.confmap_head.anchor_part
)

datapipe = InstanceCropper(datapipe, self.crop_hw)

if use_augmentations and "geometric" in self.data_config.augmentation_config:
datapipe = KorniaAugmenter(
datapipe,
**dict(self.data_config.augmentation_config.geometric),
image_key="instance_image",
instance_key="instance",
)

datapipe = Resizer(
datapipe,
scale=self.data_config.preprocessing.scale,
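For reference, the restructured block above only applies a `KorniaAugmenter` stage when its section is actually present in `augmentation_config`, and the geometric stage now runs on the full image (keys `image`/`instances`) before cropping instead of on cropped instances. A minimal sketch of a config that would exercise both branches is shown below; the `geometric` keys mirror the test configs changed in this commit, while the `intensity` key name is a hypothetical placeholder rather than an exact `KorniaAugmenter` kwarg.

from omegaconf import OmegaConf

# Hypothetical augmentation_config; only the sections that exist are applied.
augmentation_config = OmegaConf.create(
    {
        "intensity": {"uniform_noise_p": 0.5},  # assumed kwarg name, for illustration
        "geometric": {"rotation": 180.0, "scale": None},
    }
)

use_augmentations = True
if use_augmentations:
    if "intensity" in augmentation_config:
        print("intensity kwargs:", dict(augmentation_config.intensity))
    if "geometric" in augmentation_config:
        print("geometric kwargs:", dict(augmentation_config.geometric))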
33 changes: 21 additions & 12 deletions sleap_nn/inference/paf_grouping.py
@@ -31,6 +31,7 @@
from scipy.optimize import linear_sum_assignment
import networkx as nx
from omegaconf import OmegaConf
from sleap_nn.inference.utils import interp1d


@attrs.define(auto_attribs=True, frozen=True)
@@ -177,26 +178,34 @@ def make_line_subs(
dst_peaks = torch.index_select(peaks_sample, 0, edge_peak_inds[:, 1])
n_candidates = torch.tensor(src_peaks.shape[0], device=peaks_sample.device)

linspace_values = torch.linspace(0, 1, n_line_points, dtype=torch.float32).to(
device=peaks_sample.device
X = torch.cat(
(src_peaks[:, 0].unsqueeze(dim=-1), dst_peaks[:, 0].unsqueeze(dim=-1)), dim=-1
).to(torch.float64)
Y = torch.cat(
(src_peaks[:, 1].unsqueeze(dim=-1), dst_peaks[:, 1].unsqueeze(dim=-1)), dim=-1
).to(torch.float64)
samples = torch.Tensor([0, 1], device=X.device).repeat(n_candidates, 1)
samples_new = torch.linspace(0, 1, steps=n_line_points, device=X.device).repeat(
n_candidates, 1
)
linspace_values = linspace_values.repeat(n_candidates, 1).view(
n_candidates, n_line_points, 1
)
XY = (
src_peaks.view(n_candidates, 1, 2)
+ (dst_peaks - src_peaks).view(n_candidates, 1, 2) * linspace_values
)
XY = XY.transpose(1, 2)

X = interp1d(samples, X, samples_new).unsqueeze(
dim=1
) # (n_candidates, 1, n_line_points)
Y = interp1d(samples, Y, samples_new).unsqueeze(
dim=1
) # (n_candidates, 1, n_line_points)
XY = torch.concat([X, Y], dim=1)

XY = (
(XY / pafs_stride).round().int()
) # (n_candidates, 2, n_line_points) # dim 1 is [x, y]
XY = XY[:, [1, 0], :] # dim 1 is [row, col]

# clip coordinates to the size of the pafs tensor.
height, width = pafs_hw
XY[:, 0, :][XY[:, 0, :] >= height] = height - 1
XY[:, 1, :][XY[:, 1, :] >= width] = width - 1
XY[:, 0] = torch.clip(XY[:, 0], min=0, max=height - 1)
XY[:, 1] = torch.clip(XY[:, 1], min=0, max=width - 1)

edge_inds_expanded = (
edge_inds.view(-1, 1, 1)
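To make the new flow concrete, here is a minimal, self-contained sketch of what `make_line_subs` now does for a single candidate edge: the src/dst endpoints are fed to `interp1d` to get `n_line_points` evenly spaced points along the connection, which are then converted to PAF grid coordinates with the stride and clipped to the PAF height/width. The peak values, `pafs_stride`, and `pafs_hw` below are invented for illustration, not taken from the tests.

import torch
from sleap_nn.inference.utils import interp1d

# One candidate edge from src peak (x=0, y=0) to dst peak (x=4, y=8),
# sampled at n_line_points = 3 evenly spaced positions along the line.
src_peaks = torch.tensor([[0.0, 0.0]])
dst_peaks = torch.tensor([[4.0, 8.0]])
n_line_points = 3
pafs_stride = 2
height, width = 9, 9  # pafs_hw

X = torch.cat((src_peaks[:, 0:1], dst_peaks[:, 0:1]), dim=-1).to(torch.float64)
Y = torch.cat((src_peaks[:, 1:2], dst_peaks[:, 1:2]), dim=-1).to(torch.float64)
samples = torch.tensor([[0.0, 1.0]], dtype=torch.float64)
samples_new = torch.linspace(0, 1, steps=n_line_points, dtype=torch.float64).repeat(1, 1)

X = interp1d(samples, X, samples_new)  # ~ tensor([[0., 2., 4.]])
Y = interp1d(samples, Y, samples_new)  # ~ tensor([[0., 4., 8.]])

XY = torch.stack([X, Y], dim=1)  # (n_candidates, 2, n_line_points), dim 1 is [x, y]
XY = (XY / pafs_stride).round().int()[:, [1, 0], :]  # grid coords, dim 1 is [row, col]
XY[:, 0] = torch.clip(XY[:, 0], min=0, max=height - 1)
XY[:, 1] = torch.clip(XY[:, 1], min=0, max=width - 1)
print(XY)  # rows [0, 2, 4], cols [0, 1, 2]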
105 changes: 105 additions & 0 deletions sleap_nn/inference/utils.py
@@ -2,6 +2,7 @@

from omegaconf import OmegaConf
import sleap_io as sio
import torch


def get_skeleton_from_config(skeleton_config: OmegaConf):
@@ -50,3 +51,107 @@ def get_skeleton_from_config(skeleton_config: OmegaConf):
skeletons.append(sio.model.skeleton.Skeleton(nodes, edges, symmetries, name))

return skeletons


def interp1d(x: torch.Tensor, y: torch.Tensor, xnew: torch.Tensor) -> torch.Tensor:
"""Linear 1-D interpolation.
Src: https://github.com/aliutkus/torchinterp1d/blob/master/torchinterp1d/interp1d.py
Args:
x : (N, ) or (D, N) Tensor.
y : (N,) or (D, N) float Tensor. The length of `y` along its
last dimension must be the same as that of `x`.
xnew : (P,) or (D, P) Tensor. `xnew` can only be 1-D if
_both_ `x` and `y` are 1-D. Otherwise, its length along the first
dimension must be the same as that of whichever of `x` and `y` is 2-D.
Returns:
(P, ) or (D, P) Tensor.
"""
# making the vectors at least 2D
is_flat = {}
v = {}
eps = torch.finfo(y.dtype).eps
for name, vec in {"x": x, "y": y, "xnew": xnew}.items():
assert len(vec.shape) <= 2, "interp1d: all inputs must be " "at most 2-D."
if len(vec.shape) == 1:
v[name] = vec[None, :]
else:
v[name] = vec
is_flat[name] = v[name].shape[0] == 1
device = y.device

# Checking for the dimensions
assert v["x"].shape[1] == v["y"].shape[1] and (
v["x"].shape[0] == v["y"].shape[0]
or v["x"].shape[0] == 1
or v["y"].shape[0] == 1
), (
"x and y must have the same number of columns, and either "
"the same number of row or one of them having only one "
"row."
)

if (v["x"].shape[0] == 1) and (v["y"].shape[0] == 1) and (v["xnew"].shape[0] > 1):
# if there is only one row for both x and y, there is no need to
# loop over the rows of xnew because they will all have to face the
# same interpolation problem. We should just stack them together to
# call interp1d and put them back in place afterwards.
v["xnew"] = v["xnew"].contiguous().view(1, -1)

# identify the dimensions of output
D = max(v["x"].shape[0], v["xnew"].shape[0])
shape_ynew = (D, v["xnew"].shape[-1])
ynew = torch.zeros(*shape_ynew, device=device)

# moving everything to the desired device in case it was not there
# already (not handling the case things do not fit entirely, user will
# do it if required.)
for name in v:
v[name] = v[name].to(device)

# calling searchsorted on the x values.
ind = ynew.long()

# expanding xnew to match the number of rows of x in case only one xnew is
# provided
if v["xnew"].shape[0] == 1:
v["xnew"] = v["xnew"].expand(v["x"].shape[0], -1)

# the squeeze is because torch.searchsorted accepts either an N-D tensor with
# matching shapes for x and xnew, or a 1-D vector for x. Here we would
# sometimes have (1, len) for x
torch.searchsorted(v["x"].contiguous().squeeze(), v["xnew"].contiguous(), out=ind)

# the `-1` is because searchsorted looks for the index where the values
# must be inserted to preserve order. And we want the index of the
# preceding value.
ind -= 1
# we clamp the index because the number of intervals is x.shape[1] - 1,
# so the left neighbour can be at most the number of intervals - 1,
# i.e. the number of columns in x - 2
ind = torch.clamp(ind, 0, v["x"].shape[1] - 1 - 1)

# helper function to select stuff according to the found indices.
def sel(name):
if is_flat[name]:
return v[name].contiguous().view(-1)[ind]
return torch.gather(v[name], 1, ind)

# assuming x is sorted along dimension 1, compute the slopes of
# the segments
is_flat["slopes"] = is_flat["x"]
# now we have found the indices of the neighbors, we start building the
# output.
v["slopes"] = (v["y"][:, 1:] - v["y"][:, :-1]) / (
eps + (v["x"][:, 1:] - v["x"][:, :-1])
)

# now build the linear interpolation
ynew = sel("y") + sel("slopes") * (v["xnew"] - sel("x"))

if len(y.shape) == 1:
ynew = ynew.view(-1)

return ynew
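The `test_interp1d` test added further down only asserts output shapes, so a quick value check (not part of the commit) can be useful: interpolating y = 2x at a few query points returns the expected linear values, up to the tiny `eps` the function adds to the slope denominator.

import torch
from sleap_nn.inference.utils import interp1d

# y = 2 * x sampled at x = 0, 1, 2; query at intermediate points.
x = torch.tensor([0.0, 1.0, 2.0], dtype=torch.float64)
y = torch.tensor([0.0, 2.0, 4.0], dtype=torch.float64)
xq = torch.tensor([0.5, 1.0, 1.5], dtype=torch.float64)
print(interp1d(x, y, xq))  # ~ tensor([1., 2., 3.], dtype=torch.float64)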
4 changes: 0 additions & 4 deletions tests/assets/minimal_instance_centroid/initial_config.yaml
@@ -9,10 +9,6 @@ data_config:
is_rgb: false
use_augmentations_train: true
augmentation_config:
random_crop:
random_crop_p: 0
crop_height: 160
crop_width: 160
geometric:
rotation: 180.0
scale: null
5 changes: 1 addition & 4 deletions tests/assets/minimal_instance_centroid/training_config.yaml
@@ -9,10 +9,6 @@ data_config:
is_rgb: false
use_augmentations_train: true
augmentation_config:
random_crop:
random_crop_p: 0
crop_height: 160
crop_width: 160
geometric:
rotation: 180.0
scale: null
@@ -87,6 +83,7 @@ trainer_config:
- trainer_config.optimizer.amsgrad
- trainer_config.optimizer.lr
- model_config.backbone_type
- model_config.backbone_type
- model_config.init_weights
optimizer_name: Adam
optimizer:
2 changes: 1 addition & 1 deletion tests/inference/test_paf_grouping.py
@@ -52,7 +52,7 @@ def test_make_line_subs():
edge_inds,
n_line_points=3,
pafs_stride=2,
pafs_hw=(8, 8),
pafs_hw=(9, 9),
)

assert line_subs.numpy().tolist() == [
2 changes: 0 additions & 2 deletions tests/inference/test_single_instance.py
@@ -30,8 +30,6 @@ def test_single_instance_inference_model(minimal_instance, minimal_instance_ckpt
)
del training_config.model_config.head_configs.single_instance.confmaps.anchor_part

print(training_config.model_config.head_configs)

torch_model = SingleInstanceModel.load_from_checkpoint(
f"{minimal_instance_ckpt}/best.ckpt",
config=training_config,
19 changes: 18 additions & 1 deletion tests/inference/test_utils.py
@@ -1,6 +1,8 @@
from omegaconf import OmegaConf
import torch

import sleap_io as sio
from sleap_nn.inference.utils import get_skeleton_from_config
from sleap_nn.inference.utils import get_skeleton_from_config, interp1d


def test_get_skeleton_from_config(minimal_instance, minimal_instance_ckpt):
@@ -10,3 +12,18 @@ def test_get_skeleton_from_config(minimal_instance, minimal_instance_ckpt):
skeletons = get_skeleton_from_config(skeleton_config)
labels = sio.load_slp(f"{minimal_instance}")
assert skeletons[0] == labels.skeletons[0]


def test_interp1d():
"""Test function for `interp()` function."""
x = torch.linspace(0, 10, steps=10)
y = torch.randint(10, 30, (10,), dtype=torch.float64)
xq = torch.linspace(0, 10, steps=20)
yq = interp1d(x, y, xq)
assert yq.shape == (20,)

x = torch.linspace(0, 10, steps=10).repeat(5, 1)
y = torch.randint(10, 30, (5, 10), dtype=torch.float64)
xq = torch.linspace(0, 10, steps=20).repeat(5, 1)
yq = interp1d(x, y, xq)
assert yq.shape == (5, 20)
