Add minimal pretrained checkpoints for tests and fix PAF grouping interpolation (#73)

* Refactor Augmentation config

* Modify ckpts

* Refactor pipeline config (#68)

* Refactor preprocessing config

* Merge train and val data configs

* Remove pipeline name

* Modify backbone_config

* Modify ckpts

* Fix inference tests

* Fix device for inference

* Fix scale in inference

* Fix Predictor

* Modify `bottom_up` to `bottomup`

* Fix bottomup inference

* Fix scale in augmentation

* Change parameters

* Remove `image` key from TopDownConfmaps pipeline (#71)

* Refactor preprocessing config

* Merge train and val data configs

* Remove pipeline name

* Modify backbone_config

* Modify ckpts

* Fix inference tests

* Fix device for inference

* Fix scale in inference

* Fix Predictor

* Modify `bottom_up` to `bottomup`

* Fix bottomup inference

* Fix scale in augmentation

* Test remove image

* Fix instance cropping

* Fix tests

* Remove scale in pipelines

* Adjust bottomup params

* Fix paf grouping

* Fix test

* Add interpolate function

* Fix tests

* Fix augmentation in TopdownConfmaps pipeline (#78)

* Fix topdown aug

* Implement tracker module (#70)

* Refactor preprocessing config

* Merge train and val data configs

* Remove pipeline name

* Modify backbone_config

* Modify ckpts

* Fix inference tests

* Fix device for inference

* Fix scale in inference

* Fix Predictor

* Modify `bottom_up` to `bottomup`

* Fix bottomup inference

* Fix scale in augmentation

* Add tracker

* Fix tracker queue

* Add local queues

* Modify local queues

* Add features

* Add optical flow

* Add Optical flow

* Add tracking score

* Refactor candidate update

* Integrate with Predictors

* Fix lint

* Fix tracks

* Resume training and automatically compute crop size for TopDownConfmaps pipeline (#79)

* Add option to automatically compute crop size

* Move find_crop_size to Trainer

* Fix skeleton name

* Add crop size to config

* Add resumable training option

* Add tests for resuming training

* Fix tests

* Fix test for wandb folder

* LitData Refactor PR1: Get individual functions for data pipelines (#81)

* Add functions for data pipelines

* Add test cases

* Format file

* Add more test cases

* Fix augmentation test

* Revert merging "Fix augmentation in TopdownConfmaps pipeline (#78)"

This reverts commit 9d65cc7.

---------

Co-authored-by: gitttt-1234 <divyasesh11@gmail.com>
Co-authored-by: DivyaSesh <64513125+gitttt-1234@users.noreply.github.com>
Co-authored-by: grquach <101067674+grquach@users.noreply.github.com>
4 people authored Sep 12, 2024
1 parent 2f44d2d commit 4425b2b
Showing 19 changed files with 357 additions and 211 deletions.
40 changes: 29 additions & 11 deletions sleap_nn/inference/paf_grouping.py
@@ -31,6 +31,7 @@
 from scipy.optimize import linear_sum_assignment
 import networkx as nx
 from omegaconf import OmegaConf
+from sleap_nn.inference.utils import interp1d


 @attrs.define(auto_attribs=True, frozen=True)
@@ -135,6 +136,7 @@ def make_line_subs(
     edge_inds: torch.Tensor,
     n_line_points: int,
     pafs_stride: int,
+    pafs_hw: tuple,
 ) -> torch.Tensor:
     """Create the lines between candidate connections for evaluating the PAFs.
@@ -155,6 +157,7 @@
         pafs_stride: The stride (1/scale) of the PAFs that these lines will need to
             index into relative to the image. Coordinates in `peaks_sample` will be
             divided by this value to adjust the indexing into the PAFs tensor.
+        pafs_hw: Tuple (height, width) with the dimension of PAFs tensor.

     Returns:
         The line subscripts as a `torch.Tensor` of shape
@@ -175,22 +178,35 @@
     dst_peaks = torch.index_select(peaks_sample, 0, edge_peak_inds[:, 1])
     n_candidates = torch.tensor(src_peaks.shape[0], device=peaks_sample.device)

-    linspace_values = torch.linspace(0, 1, n_line_points, dtype=torch.float32).to(
-        device=peaks_sample.device
+    X = torch.cat(
+        (src_peaks[:, 0].unsqueeze(dim=-1), dst_peaks[:, 0].unsqueeze(dim=-1)), dim=-1
+    ).to(torch.float64)
+    Y = torch.cat(
+        (src_peaks[:, 1].unsqueeze(dim=-1), dst_peaks[:, 1].unsqueeze(dim=-1)), dim=-1
+    ).to(torch.float64)
+    samples = torch.Tensor([0, 1], device=X.device).repeat(n_candidates, 1)
+    samples_new = torch.linspace(0, 1, steps=n_line_points, device=X.device).repeat(
+        n_candidates, 1
     )
-    linspace_values = linspace_values.repeat(n_candidates, 1).view(
-        n_candidates, n_line_points, 1
-    )
-    XY = (
-        src_peaks.view(n_candidates, 1, 2)
-        + (dst_peaks - src_peaks).view(n_candidates, 1, 2) * linspace_values
-    )
-    XY = XY.transpose(1, 2)
+
+    X = interp1d(samples, X, samples_new).unsqueeze(
+        dim=1
+    )  # (n_candidates, 1, n_line_points)
+    Y = interp1d(samples, Y, samples_new).unsqueeze(
+        dim=1
+    )  # (n_candidates, 1, n_line_points)
+    XY = torch.concat([X, Y], dim=1)

     XY = (
         (XY / pafs_stride).round().int()
     )  # (n_candidates, 2, n_line_points) # dim 1 is [x, y]
     XY = XY[:, [1, 0], :]  # dim 1 is [row, col]
+
+    # clip coordinates for size of pafs tensor.
+    height, width = pafs_hw
+    XY[:, 0] = torch.clip(XY[:, 0], min=0, max=height - 1)
+    XY[:, 1] = torch.clip(XY[:, 1], min=0, max=width - 1)

     edge_inds_expanded = (
         edge_inds.view(-1, 1, 1)
         .expand(-1, 1, n_line_points)
@@ -263,8 +279,9 @@ def get_paf_lines(
     See also: get_connection_candidates, make_line_subs, score_paf_lines
     """
+    pafs_hw = pafs_sample.shape[:2]
     line_subs = make_line_subs(
-        peaks_sample, edge_peak_inds, edge_inds, n_line_points, pafs_stride
+        peaks_sample, edge_peak_inds, edge_inds, n_line_points, pafs_stride, pafs_hw
     )
     lines = pafs_sample[line_subs[..., 0], line_subs[..., 1], line_subs[..., 2]]
     return lines
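
For context on the fix: the old code built each line with a manual linspace between peaks, while the new code routes through interp1d and clips the rounded subscripts to the PAF grid, so a peak near the image edge can no longer index past the tensor. A minimal sketch of the new entry point (toy peak values; the shape comment is inferred from the docstring, not repo test data):

import torch
from sleap_nn.inference.paf_grouping import make_line_subs

peaks = torch.tensor([[10.0, 12.0], [50.0, 40.0]])  # (n_peaks, 2) as (x, y)
edge_peak_inds = torch.tensor([[0, 1]])             # one src -> dst candidate
edge_inds = torch.tensor([0])                       # skeleton edge index
line_subs = make_line_subs(
    peaks, edge_peak_inds, edge_inds,
    n_line_points=10, pafs_stride=2,
    pafs_hw=(24, 24),  # dst col = 50 / 2 = 25 overshoots, so it is clipped to 23
)
print(line_subs.shape)  # expected (1, 10, 2, 3): candidate, line point, paf channel, subscript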
@@ -565,6 +582,7 @@ def match_candidates_sample(

     # Convert cost matrix to numpy for use with scipy's linear_sum_assignment.
     cost_matrix_np = cost_matrix.numpy()
+    cost_matrix_np[np.isnan(cost_matrix_np)] = np.inf

     # Match.
     match_src_inds, match_dst_inds = linear_sum_assignment(cost_matrix_np)
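
The added line matters because scipy's linear_sum_assignment rejects NaN entries outright; mapping NaN to +inf keeps those pairings out of the optimum while leaving the matrix feasible. A minimal sketch with toy costs:

import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[0.2, np.nan], [np.nan, 0.5]])
# linear_sum_assignment(cost) would raise: "matrix contains invalid numeric entries"
cost[np.isnan(cost)] = np.inf
rows, cols = linear_sum_assignment(cost)
print(rows, cols)  # [0 1] [0 1]: the inf (formerly NaN) pairings are avoided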
100 changes: 49 additions & 51 deletions sleap_nn/inference/predictors.py
@@ -254,57 +254,55 @@ def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]:
                 yield output

         elif self.provider == "VideoReader":
-            try:
-                self.pipeline.start()
-                batch_size = self.video_preprocess_config["batch_size"]
-                done = False
-                while not done:
-                    imgs = []
-                    fidxs = []
-                    org_szs = []
-                    for _ in range(batch_size):
-                        frame = self.pipeline.frame_buffer.get()
-                        if frame[0] is None:
-                            done = True
-                            break
-                        imgs.append(frame[0].unsqueeze(dim=0))
-                        fidxs.append(frame[1])
-                        org_szs.append(frame[2].unsqueeze(dim=0))
-                    if imgs:
-                        imgs = torch.concatenate(imgs, dim=0)
-                        fidxs = torch.tensor(fidxs, dtype=torch.int32)
-                        org_szs = torch.concatenate(org_szs, dim=0)
-                        ex = {
-                            "image": imgs,
-                            "frame_idx": fidxs,
-                            "video_idx": torch.tensor(
-                                [0] * batch_size, dtype=torch.int32
-                            ),
-                            "orig_size": org_szs,
-                        }
-                        if not torch.is_floating_point(ex["image"]):  # normalization
-                            ex["image"] = ex["image"].to(torch.float32) / 255.0
-                        if self.video_preprocess_config["is_rgb"]:
-                            ex["image"] = convert_to_rgb(ex["image"])
-                        else:
-                            ex["image"] = convert_to_grayscale(ex["image"])
-                        if self.preprocess:
-                            scale = self.video_preprocess_config["scale"]
-                            if scale != 1.0:
-                                ex["image"] = resize_image(ex["image"], scale)
-                            ex["image"] = pad_to_stride(
-                                ex["image"], self.video_preprocess_config["max_stride"]
-                            )
-                        outputs_list = self.inference_model(ex)
-                        for output in outputs_list:
-                            output = self._convert_tensors_to_numpy(output)
-                            yield output
-
-            except Exception as e:
-                raise Exception(f"Error in VideoReader during data processing: {e}")
-
-            finally:
-                self.pipeline.join()
+            # try:
+            self.pipeline.start()
+            batch_size = self.video_preprocess_config["batch_size"]
+            done = False
+            while not done:
+                imgs = []
+                fidxs = []
+                org_szs = []
+                for _ in range(batch_size):
+                    frame = self.pipeline.frame_buffer.get()
+                    if frame[0] is None:
+                        done = True
+                        break
+                    imgs.append(frame[0].unsqueeze(dim=0))
+                    fidxs.append(frame[1])
+                    org_szs.append(frame[2].unsqueeze(dim=0))
+                if imgs:
+                    imgs = torch.concatenate(imgs, dim=0)
+                    fidxs = torch.tensor(fidxs, dtype=torch.int32)
+                    org_szs = torch.concatenate(org_szs, dim=0)
+                    ex = {
+                        "image": imgs,
+                        "frame_idx": fidxs,
+                        "video_idx": torch.tensor([0] * batch_size, dtype=torch.int32),
+                        "orig_size": org_szs,
+                    }
+                    if not torch.is_floating_point(ex["image"]):  # normalization
+                        ex["image"] = ex["image"].to(torch.float32) / 255.0
+                    if self.video_preprocess_config["is_rgb"]:
+                        ex["image"] = convert_to_rgb(ex["image"])
+                    else:
+                        ex["image"] = convert_to_grayscale(ex["image"])
+                    if self.preprocess:
+                        scale = self.video_preprocess_config["scale"]
+                        if scale != 1.0:
+                            ex["image"] = resize_image(ex["image"], scale)
+                        ex["image"] = pad_to_stride(
+                            ex["image"], self.video_preprocess_config["max_stride"]
+                        )
+                    outputs_list = self.inference_model(ex)
+                    for output in outputs_list:
+                        output = self._convert_tensors_to_numpy(output)
+                        yield output

+            # except Exception as e:
+            #     raise Exception(f"Error in VideoReader during data processing: {e}")
+
+            # finally:
+            self.pipeline.join()

     def predict(
         self,
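
The refactor comments out the blanket try/except/finally, which had re-raised every error as a bare Exception and obscured the original traceback, so failures now surface directly; the sentinel-driven batching loop itself is unchanged. A standalone sketch of that producer/consumer pattern (hypothetical frame tuples, not the repo's VideoReader class):

import queue
import threading

def producer(buf: queue.Queue, n_frames: int) -> None:
    for i in range(n_frames):
        buf.put((f"frame{i}", i))   # stand-in for (image, frame_idx, orig_size)
    buf.put((None, None))           # sentinel: the video is exhausted

buf: queue.Queue = queue.Queue(maxsize=8)
thread = threading.Thread(target=producer, args=(buf, 5))
thread.start()

done, batch_size = False, 4
while not done:
    batch = []
    for _ in range(batch_size):
        item = buf.get()
        if item[0] is None:         # stop on the sentinel, even mid-batch
            done = True
            break
        batch.append(item)
    if batch:                       # a final partial batch is still processed
        print([idx for _, idx in batch])  # -> [0, 1, 2, 3] then [4]
thread.join()                       # mirrors self.pipeline.join()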
105 changes: 105 additions & 0 deletions sleap_nn/inference/utils.py
@@ -2,6 +2,7 @@

 from omegaconf import OmegaConf
 import sleap_io as sio
+import torch


 def get_skeleton_from_config(skeleton_config: OmegaConf):
@@ -50,3 +51,107 @@ def get_skeleton_from_config(skeleton_config: OmegaConf):
         skeletons.append(sio.model.skeleton.Skeleton(nodes, edges, symmetries, name))

     return skeletons
+
+
+def interp1d(x: torch.Tensor, y: torch.Tensor, xnew: torch.Tensor) -> torch.Tensor:
+    """Linear 1-D interpolation.
+
+    Src: https://github.com/aliutkus/torchinterp1d/blob/master/torchinterp1d/interp1d.py
+
+    Args:
+        x : (N, ) or (D, N) Tensor.
+        y : (N,) or (D, N) float Tensor. The length of `y` along its
+            last dimension must be the same as that of `x`.
+        xnew : (P,) or (D, P) Tensor. `xnew` can only be 1-D if
+            _both_ `x` and `y` are 1-D. Otherwise, its length along the first
+            dimension must be the same as that of whichever `x` and `y` is 2-D.
+
+    Returns:
+        (P, ) or (D, P) Tensor.
+    """
+    # making the vectors at least 2D
+    is_flat = {}
+    v = {}
+    eps = torch.finfo(y.dtype).eps
+    for name, vec in {"x": x, "y": y, "xnew": xnew}.items():
+        assert len(vec.shape) <= 2, "interp1d: all inputs must be at most 2-D."
+        if len(vec.shape) == 1:
+            v[name] = vec[None, :]
+        else:
+            v[name] = vec
+        is_flat[name] = v[name].shape[0] == 1
+    device = y.device
+
+    # Checking for the dimensions
+    assert v["x"].shape[1] == v["y"].shape[1] and (
+        v["x"].shape[0] == v["y"].shape[0]
+        or v["x"].shape[0] == 1
+        or v["y"].shape[0] == 1
+    ), (
+        "x and y must have the same number of columns, and either "
+        "the same number of rows or one of them having only one "
+        "row."
+    )
+
+    if (v["x"].shape[0] == 1) and (v["y"].shape[0] == 1) and (v["xnew"].shape[0] > 1):
+        # if there is only one row for both x and y, there is no need to
+        # loop over the rows of xnew because they will all have to face the
+        # same interpolation problem. We should just stack them together to
+        # call interp1d and put them back in place afterwards.
+        v["xnew"] = v["xnew"].contiguous().view(1, -1)
+
+    # identify the dimensions of output
+    D = max(v["x"].shape[0], v["xnew"].shape[0])
+    shape_ynew = (D, v["xnew"].shape[-1])
+    ynew = torch.zeros(*shape_ynew, device=device)
+
+    # moving everything to the desired device in case it was not there
+    # already (not handling the case things do not fit entirely, user will
+    # do it if required.)
+    for name in v:
+        v[name] = v[name].to(device)
+
+    # calling searchsorted on the x values.
+    ind = ynew.long()
+
+    # expanding xnew to match the number of rows of x in case only one xnew is
+    # provided
+    if v["xnew"].shape[0] == 1:
+        v["xnew"] = v["xnew"].expand(v["x"].shape[0], -1)
+
+    # the squeeze is because torch.searchsorted does accept either a nd with
+    # matching shapes for x and xnew or a 1d vector for x. Here we would
+    # have (1,len) for x sometimes
+    torch.searchsorted(v["x"].contiguous().squeeze(), v["xnew"].contiguous(), out=ind)
+
+    # the `-1` is because searchsorted looks for the index where the values
+    # must be inserted to preserve order. And we want the index of the
+    # preceding value.
+    ind -= 1
+    # we clamp the index, because the number of intervals is x.shape-1,
+    # and the left neighbour should hence be at most number of intervals
+    # -1, i.e. number of columns in x -2
+    ind = torch.clamp(ind, 0, v["x"].shape[1] - 1 - 1)
+
+    # helper function to select stuff according to the found indices.
+    def sel(name):
+        if is_flat[name]:
+            return v[name].contiguous().view(-1)[ind]
+        return torch.gather(v[name], 1, ind)
+
+    # assuming x are sorted in the dimension 1, computing the slopes for
+    # the segments
+    is_flat["slopes"] = is_flat["x"]
+    # now we have found the indices of the neighbors, we start building the
+    # output.
+    v["slopes"] = (v["y"][:, 1:] - v["y"][:, :-1]) / (
+        eps + (v["x"][:, 1:] - v["x"][:, :-1])
+    )
+
+    # now build the linear interpolation
+    ynew = sel("y") + sel("slopes") * (v["xnew"] - sel("x"))
+
+    if len(y.shape) == 1:
+        ynew = ynew.view(-1)
+
+    return ynew
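
A small usage sketch of interp1d with toy values (not from the repo's tests), mirroring how make_line_subs calls it: each row holds one segment's endpoint coordinates at positions 0 and 1, interpolated at evenly spaced new positions.

import torch
from sleap_nn.inference.utils import interp1d

# Known samples: x-coordinates of two candidate segments at t = 0 and t = 1.
t = torch.tensor([[0.0, 1.0], [0.0, 1.0]])
X = torch.tensor([[0.0, 4.0], [2.0, 8.0]])
t_new = torch.linspace(0, 1, steps=5).repeat(2, 1)  # 5 points per segment

print(interp1d(t, X, t_new))
# -> rows [0.0, 1.0, 2.0, 3.0, 4.0] and [2.0, 3.5, 5.0, 6.5, 8.0]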
Binary file modified tests/assets/minimal_instance/best.ckpt
100644 → 100755
10 changes: 5 additions & 5 deletions tests/assets/minimal_instance/initial_config.yaml
100644 → 100755
@@ -26,21 +26,21 @@ model_config:
     in_channels: 1
     kernel_size: 3
     filters: 16
-    filters_rate: 2
-    max_stride: 16
+    filters_rate: 1.5
+    max_stride: 8
     convs_per_block: 2
     stacks: 1
     stem_stride: null
     middle_block: true
-    up_interpolate: true
+    up_interpolate: false
   head_configs:
     single_instance: null
-    bottomup: null
+    bottom_up: null
     centroid: null
     centered_instance:
       confmaps:
         part_names: null
-        anchor_part: 0
+        anchor_part: null
         sigma: 1.5
         output_stride: 2
 trainer_config:
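
These backbone edits shrink the test model so the committed checkpoints stay small. As a rough sketch of the usual convention for such configs (an assumption about the semantics, not code from this repo), max_stride 8 implies three pooling steps and filters_rate 1.5 grows block widths geometrically:

import math

# Assumed convention: strides are powers of two, widths scale by filters_rate.
filters, filters_rate, max_stride = 16, 1.5, 8
n_down_blocks = int(math.log2(max_stride))              # 8 -> 3 pooling steps
widths = [int(filters * filters_rate**i) for i in range(n_down_blocks + 1)]
print(n_down_blocks, widths)                            # 3 [16, 24, 36, 54]

Likewise, `up_interpolate: false` presumably swaps fixed upsampling for learned transposed convolutions in the decoder, changing the parameter count, which is consistent with the regenerated best.ckpt and last.ckpt files.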
Binary file modified tests/assets/minimal_instance/last.ckpt
100644 → 100755
