Skip to content

Commit

Permalink
WIP: Modifying Predictor (sub)class for seamless VisualPredictor inte…
Browse files Browse the repository at this point in the history
…gration
  • Loading branch information
roomrys committed Jan 11, 2023
1 parent db11202 commit d0ad90a
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 123 deletions.
2 changes: 1 addition & 1 deletion sleap/gui/overlays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _add(

@classmethod
def make_viz_predictor(cls, filename: str) -> VisualPredictor:
return VisualPredictor.from_model_path(filename)
return VisualPredictor.from_trained_models(filename)

@classmethod
def from_model(cls, filename: str, *args, **kwargs):
Expand Down
188 changes: 137 additions & 51 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import numpy as np

import sleap
from sleap.io.video import NumpyVideo, Video
from sleap.nn.config import TrainingJobConfig, DataConfig
from sleap.nn.data.resizing import SizeMatcher
from sleap.nn.model import Model
Expand Down Expand Up @@ -117,6 +118,31 @@ def report_period(self) -> float:
"""Time between progress reports in seconds."""
return 1.0 / self.report_rate

@classmethod
def categorize_model_paths(
cls, model_paths: Union[str, List[str]]
) -> Tuple[List[TrainingJobConfig], List[str], List[str]]:
"""Read configs and find model types.
Args:
model_paths: A single or list of trained model paths.
Return:
model_paths: List of `TrainingJobConfig.filename`s for all the configs in
`model_configs`.
model_types: List of the model type for each config in `model_configs`.
"""
if isinstance(model_paths, str):
model_paths = [model_paths]
model_configs: List[TrainingJobConfig] = [
sleap.load_config(model_path) for model_path in model_paths
]
model_paths: List[str] = [cfg.filename for cfg in model_configs]
model_types: List[str] = [
cfg.model.heads.which_oneof_attrib_name() for cfg in model_configs
]
return model_paths, model_types

@classmethod
def from_model_paths(
cls,
Expand Down Expand Up @@ -150,13 +176,7 @@ def from_model_paths(
See also: `SingleInstancePredictor`, `TopDownPredictor`, `BottomUpPredictor`
"""
# Read configs and find model types.
if isinstance(model_paths, str):
model_paths = [model_paths]
model_configs = [sleap.load_config(model_path) for model_path in model_paths]
model_paths = [cfg.filename for cfg in model_configs]
model_types = [
cfg.model.heads.which_oneof_attrib_name() for cfg in model_configs
]
model_paths, model_types = cls.categorize_model_paths(model_paths)

if "single_instance" in model_types:
predictor = SingleInstancePredictor.from_trained_models(
Expand Down Expand Up @@ -251,6 +271,31 @@ def is_grayscale(self) -> bool:
"""Return whether the model expects grayscale inputs."""
pass

@abstractmethod
def _make_labeled_frames_from_generator(
self, generator: Iterator[Dict[str, np.ndarray]], data_provider: Provider
) -> List[sleap.LabeledFrame]:
"""Create labeled frames from a generator that yields inference results.
This method converts pure arrays into SLEAP-specific data structures and runs
them through the tracker if it is specified.
Args:
generator: A generator that returns dictionaries with inference results.
This should return dictionaries with keys `"image"`, `"video_ind"`,
`"frame_ind"`, `"instance_peaks"`, `"instance_peak_vals"`, and
`"centroid_vals"`. This can be created using the `_predict_generator()`
method.
data_provider: The `sleap.pipelines.Provider` that the predictions are being
created from. This is used to retrieve the `sleap.Video` instance
associated with each inference result.
Returns:
A list of `sleap.LabeledFrame`s with `sleap.PredictedInstance`s created from
arrays returned from the inference result generator.
"""
pass

def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:
"""Make a data loading pipeline.
Expand Down Expand Up @@ -434,7 +479,7 @@ def predict(
"""
# Create provider if necessary.
if isinstance(data, np.ndarray):
data = sleap.Video(backend=sleap.io.video.NumpyVideo(data))
data = sleap.Video(backend=NumpyVideo(data))
if isinstance(data, sleap.Labels):
data = LabelsReader(data)
elif isinstance(data, sleap.Video):
Expand Down Expand Up @@ -2133,6 +2178,7 @@ def from_trained_models(
cls,
centroid_model_path: Optional[Text] = None,
confmap_model_path: Optional[Text] = None,
model_paths: Optional[Union[Text, List[Text]]] = None,
batch_size: int = 4,
peak_threshold: float = 0.2,
integral_refinement: bool = True,
Expand Down Expand Up @@ -2168,44 +2214,61 @@ def from_trained_models(
One of the two models can be left as `None` to perform inference with ground
truth data. This will only work with `LabelsReader` as the provider.
"""
if centroid_model_path is None and confmap_model_path is None:

def load_from_model_path(
model_path: str,
) -> Tuple[Optional[TrainingJobConfig], Optional[Model]]:
if confmap_model_path is not None:
# Load model.
config = TrainingJobConfig.load_json(model_path)
keras_model_path = get_keras_model_path(model_path)
model = Model.from_config(config.model)
model.keras_model = tf.keras.models.load_model(
keras_model_path, compile=False
)
if resize_input_layer:
# Reset input layer dimensions to be more flexible
model.keras_model = reset_input_layer(
keras_model=model.keras_model, new_shape=None
)
else:
config = None
model = None

return (config, model)

def categorize_model_paths(
model_paths: Optional[Union[Text, List[Text]]]
) -> Tuple[Optional[str], Optional[str]]:

centroid_model_path = None
confmap_model_path = None
model_paths, model_types = Predictor.categorize_model_paths(model_paths)
if "centroid" in model_types:
centroid_model_path = model_paths[model_types.index("centroid")]
if "centered_instance" in model_types:
confmap_model_path = (
model_paths[model_types.index("centered_instance")] or None
)

return (centroid_model_path, confmap_model_path)

if (
centroid_model_path is None
and confmap_model_path is None
and model_paths is None
):
raise ValueError(
"Either the centroid or topdown confidence map model must be provided."
)

if centroid_model_path is not None:
# Load centroid model.
centroid_config = TrainingJobConfig.load_json(centroid_model_path)
centroid_keras_model_path = get_keras_model_path(centroid_model_path)
centroid_model = Model.from_config(centroid_config.model)
centroid_model.keras_model = tf.keras.models.load_model(
centroid_keras_model_path, compile=False
# Load model from path, preferring specified paths over generic model paths
if (centroid_model_path is None) and (confmap_model_path is None):
centroid_model_path, confmap_model_path = categorize_model_paths(
model_paths
)
if resize_input_layer:
# Reset input layer dimensions to be more flexible
centroid_model.keras_model = reset_input_layer(
keras_model=centroid_model.keras_model, new_shape=None
)
else:
centroid_config = None
centroid_model = None

if confmap_model_path is not None:
# Load confmap model.
confmap_config = TrainingJobConfig.load_json(confmap_model_path)
confmap_keras_model_path = get_keras_model_path(confmap_model_path)
confmap_model = Model.from_config(confmap_config.model)
confmap_model.keras_model = tf.keras.models.load_model(
confmap_keras_model_path, compile=False
)
if resize_input_layer:
# Reset input layer dimensions to be more flexible
confmap_model.keras_model = reset_input_layer(
keras_model=confmap_model.keras_model, new_shape=None
)
else:
confmap_config = None
confmap_model = None
centroid_config, centroid_model = load_from_model_path(centroid_model_path)
confmap_config, confmap_model = load_from_model_path(confmap_model_path)

obj = cls(
centroid_config=centroid_config,
Expand Down Expand Up @@ -4196,14 +4259,15 @@ def get_supported_predictors(
}

@classmethod
def from_model_path(
def from_trained_models(
cls,
model_path: str,
model_paths: str,
peak_threshold: float = 0.2,
integral_refinement: bool = True,
integral_patch_size: int = 5,
batch_size: int = 4,
) -> Predictor:
resize_input_layer: bool = True,
) -> "VisualPredictor":
"""Create the appropriate `Predictor` subclass from a list of model paths.
Args:
Expand All @@ -4220,17 +4284,20 @@ def from_model_path(
usage.
Returns:
A subclass of `Predictor`.
An instance of `VisualPredictor` with a subclass of `Predictor` stored in
`predictor` attribute.
See also: `SingleInstancePredictor`, `TopDownPredictor`, `BottomUpPredictor`,
`TopDownMultiClassPredictor`, and `BottomUpMultiClassPredictor`
"""

predictor = Predictor.from_model_paths(
model_paths=[model_path],
model_paths=[model_paths],
peak_threshold=peak_threshold,
integral_refinement=integral_refinement,
integral_patch_size=integral_patch_size,
batch_size=batch_size,
resize_input_layer=resize_input_layer,
)

# Link `model` and `config` of `VisualPredictor` to appropriate attributes
Expand All @@ -4252,10 +4319,13 @@ def from_model_path(
f"\ntype({predictor.__class__.__name__}.{cfg_attr}) = {type(config)}"
)
# TODO(LM): Remove when implement support for centered instance
elif model_attr == "confmap_model":
elif (model_attr == "confmap_model") and (
isinstance(predictor, TopDownPredictor)
or isinstance(predictor, TopDownMultiClassPredictor)
):
raise TypeError(
"Centered instance models are not currently suppported.\n"
f"Recieved model = {model_attr}.\n"
f"Recieved model = {type(predictor).__name__}: {model_attr}.\n"
"Please select a different type of model."
)

Expand Down Expand Up @@ -4365,9 +4435,25 @@ def process_batch(ds: tf.data.Dataset, progress: bool = True):
# Yield each example from dataset, catching and logging exceptions
return process_batch(self.pipeline.make_dataset())

def predict(self, data_provider: Provider):
generator = self._predict_generator(data_provider)
examples = list(generator)
def predict(
self,
data: Union[np.ndarray, Video, VideoReader],
make_labels: Optional[bool] = None,
):
# The `VisualPredictor.predict` only works with `VideoReader` of length 1.
if isinstance(data, VideoReader) and (len(data.example_indices) == 1):
# Use `VisualPredictor.make_pipeline()`
obj = self
make_labels = False
logger.info("Using VisualPredictor pipeline.")
else:
# Use `VisualPredictor.predictor.make_pipeline()`
obj = self.predictor
make_labels = True if make_labels is None else make_labels
logger.info(f"Using {type(self.predictor).__name__} pipeline.")

# Call `Predict.predict`
examples = Predictor.predict(obj, data, make_labels=make_labels)

return examples

Expand Down
Loading

0 comments on commit d0ad90a

Please sign in to comment.