diff --git a/docs/config.md b/docs/config.md index 42bfbc3c..e72d4ef6 100644 --- a/docs/config.md +++ b/docs/config.md @@ -13,9 +13,9 @@ The config file has three main sections: - `data_config`: - `provider`: (str) Provider class to read the input sleap files. Only "LabelsReader" supported for the training pipeline. - - `pipeline`: (str) Pipeline for training data. One of "TopdownConfmaps", "SingleInstanceConfmaps", "CentroidConfmapsPipeline" or "BottomUp". - - `train`: - - `labels_path`: (str) Path to `.slp` files + - `train_labels_path`: (str) Path to training data (`.slp` file) + - `val_labels_path`: (str) Path to validation data (`.slp` file) + - `preprocessing`: - `is_rgb`: (bool) True if the image has 3 channels (RGB image). If input has only one channel when this is set to `True`, then the images from single-channel is replicated along the channel axis. If input has three channels and this @@ -25,100 +25,118 @@ The config file has three main sections: original image size will be retained. Default: None. - `max_width`: (int) Maximum width the image should be padded to. If not provided, the original image size will be retained. Default: None. - - `scale`: (float or List[float]) Factor to resize the image dimensions by, specified as either a float scalar or as a 2-tuple of [scale_x, scale_y]. If a scalar is provided, both dimensions are resized by the same factor. - - `preprocessing`: - - `crop_hw`: (List[int]) Crop height and width of each instance (h, w) for centered-instance model. - - `augmentation_config`: - - `random crop`: (Dict[float]) {"random_crop_p": None, "random_crop_hw": None}, where *random_crop_p* is the probability of applying random crop and *random_crop_hw* is the desired output size (out_h, out_w) of the crop. Must be Tuple[int, int], then out_h = size[0], out_w = size[1]. - - `use_augmentations`: (bool) True if the data augmentation should be applied to the data, else False. - - `augmentation`: - - `intensity`: - - `uniform_noise`: (Tuple[float]) Tuple of uniform noise (min_noise, max_noise). Must satisfy 0. <= min_noise <= max_noise <= 1. - - `uniform_noise_p`: (float) Probability of applying random uniform noise. *Default*=0.0 - - `gaussian_noise_mean`: (float) The mean of the gaussian noise distribution. - - `gaussian_noise_std`: (float) The standard deviation of the gaussian noise distribution. - - `gaussian_noise_p`: (float) Probability of applying random gaussian noise. *Default*=0.0 - - `contrast`: (List[float]) The contrast factor to apply. *Default*: (1.0, 1.0). - - `contrast_p`: (float) Probability of applying random contrast. *Default*=0.0 - - `brightness`: (float) The brightness factor to apply. *Default*: (1.0, 1.0). - - `brightness_p`: (float) Probability of applying random brightness. *Default*=0.0 - - `geometric`: - - `rotation`: (List[float]) Angles in degrees as a scalar float of the amount of rotation. A random angle in (-rotation, rotation) will be sampled and applied to both images and keypoints. Set to 0 to disable rotation augmentation. - - `scale`: (float) A scaling factor as a scalar float specifying the amount of scaling. A - random factor between (1 - scale, 1 + scale) will be sampled and applied to both images and keypoints. If `None`, no scaling augmentation will be applied. - - `translate`: (List[float]) tuple of maximum absolute fraction for horizontal and vertical translations. For example, translate=(a, b), then horizontal shift is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is randomly sampled in the range img_height * b < dy < img_height * b. Will not translate by default. - - `affine_p`: (float) Probability of applying random affine transformations. *Default*=0.0 - - `erase_scale`: (List[float]) Range of proportion of erased area against input image. *Default*: (0.0001, 0.01). - - `erase_ratio`: (List[float]) Range of aspect ratio of erased area. *Default*: (1, 1). - - `erase_p`: (float) Probability of applying random erase. *Default*=0.0 - - `mixup_lambda`: (float) min-max value of mixup strength. Default is 0-1. *Default*: `None`. - - `mixup_p`: (float) Probability of applying random mixup v2. *Default*=0.0 - - `input_key`: (str) Can be `image` or `instance`. The input_key `instance` expects the KorniaAugmenter to follow the InstanceCropper else `image` otherwise for default. - - `val`: (Similar to `train` structure) + - `scale`: (float or List[float]) Factor to resize the image dimensions by, specified as either a float scalar or as a 2-tuple of [scale_x, scale_y]. If a scalar is provided, both dimensions are resized by the same factor. + - `crop_hw`: (List[int]) Crop height and width of each instance (h, w) for centered-instance model. + - `use_augmentations_train`: (bool) True if the data augmentation should be applied to the training data, else False. + - `augmentation_config`: (only if `use_augmentations` is `True`) + - `random crop`: (Optional) (Dict[float]) {"random_crop_p": None, "crop_height": None. "crop_width": None}, where *random_crop_p* is the probability of applying random crop and *crop_height* and *crop_width* are the desired output size (out_h, out_w) of the crop. + - `intensity`: (Optional) + - `uniform_noise_min`: (float) Minimum value for uniform noise (uniform_noise_min >=0). + - `uniform_noise_max`: (float) Maximum value for uniform noise (uniform_noise_max <>=1). + - `uniform_noise_p`: (float) Probability of applying random uniform noise. *Default*=0.0 + - `gaussian_noise_mean`: (float) The mean of the gaussian noise distribution. + - `gaussian_noise_std`: (float) The standard deviation of the gaussian noise distribution. + - `gaussian_noise_p`: (float) Probability of applying random gaussian noise. *Default*=0.0 + - `contrast_min`: (float) Minimum contrast factor to apply. Default: 0.5. + - `contrast_max`: (float) Maximum contrast factor to apply. Default: 2.0. + - `contrast_p`: (float) Probability of applying random contrast. *Default*=0.0 + - `brightness`: (float) The brightness factor to apply. *Default*: (1.0, 1.0). + - `brightness_p`: (float) Probability of applying random brightness. *Default*=0.0 + - `geometric`: (Optional) + - `rotation`: (float) Angles in degrees as a scalar float of the amount of rotation. A random angle in (-rotation, rotation) will be sampled and applied to both images and keypoints. Set to 0 to disable rotation augmentation. + - `scale`: (float) scaling factor interval. If (a, b) represents isotropic scaling, the scale is randomly sampled from the range a <= scale <= b. If (a, b, c, d), the scale is randomly sampled from the range a <= scale_x <= b, c <= scale_y <= d Default: None. + - `translate_width`: (float) Maximum absolute fraction for horizontal translation. For example, if translate_width=a, then horizontal shift is randomly sampled in the range -img_width * a < dx < img_width * a. Will not translate by default. + - `translate_height`: (float) Maximum absolute fraction for vertical translation. For example, if translate_height=a, then vertical shift is randomly sampled in the range -img_height * a < dy < img_height * a. Will not translate by default. + - `affine_p`: (float) Probability of applying random affine transformations. *Default*=0.0 + - `erase_scale_min`: (float) Minimum value of range of proportion of erased area against input image. *Default*: 0.0001. + - `erase_scale_max`: (float) Maximum value of range of proportion of erased area against input image. *Default*: 0.01. + - `erase_ration_min`: (float) Minimum value of range of aspect ratio of erased area. *Default*: 1. + - `erase_ratio_max`: (float) Maximum value of range of aspect ratio of erased area. *Default*: 1. + - `erase_p`: (float) Probability of applying random erase. *Default*=0.0 + - `mixup_lambda`: (float) min-max value of mixup strength. Default is 0-1. *Default*: `None`. + - `mixup_p`: (float) Probability of applying random mixup v2. *Default*=0.0 + - `input_key`: (str) Can be `image` or `instance`. The input_key `instance` expects the KorniaAugmenter to follow the InstanceCropper else `image` otherwise for default. - `model_config`: - `init_weight`: (str) model weights initialization method. "default" uses kaiming uniform initialization and "xavier" uses Xavier initialization method. - `pre_trained_weights`: (str) Pretrained weights file name supported only for ConvNext and SwinT backbones. For ConvNext, one of ["ConvNeXt_Base_Weights","ConvNeXt_Tiny_Weights", "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights"]. For SwinT, one of ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"]. - - `backbone_config`: - - `backbone_type`: (str) Backbone architecture for the model to be trained. One of "unet", "convnext" or "swint". - - `backbone_config`: (for UNet) - - `in_channels`: (int) Number of input channels. Default is 1. - - `kernel_size`: (int) Size of the convolutional kernels. Default is 3. - - `filters`: (int) Base number of filters in the network. Default is 32 - - `filters_rate`: (float) Factor to adjust the number of filters per block. Default is 1.5. - - `max_stride`: (int) Scalar integer specifying the maximum stride that the image must be - divisible by. - - `stem_stride`: (int) If not None, will create additional "down" blocks for initial - downsampling based on the stride. These will be configured identically to the down blocks below. - - `middle_block`: (bool) If True, add an additional block at the end of the encoder. default: True - - `up_interpolate`: (bool) If True, use bilinear interpolation instead of transposed - convolutions for upsampling. Interpolation is faster but transposed - convolutions may be able to learn richer or more complex upsampling to - recover details from higher scales. Default: True. - - `stacks`: (int) Number of upsampling blocks in the decoder. Default is 3. - - `convs_per_block`: (int) Number of convolutional layers per block. Default is 2. - - `backbone_config`: (for ConvNext) - - `arch`: (Default is `Tiny` architecture config. No need to provide if `model_type` is provided) - - `depths`: (List(int)) Number of layers in each block. Default: [3, 3, 9, 3]. - - `channels`: (List(int)) Number of channels in each block. Default: [96, 192, 384, 768]. - - `model_type`: (str) One of the ConvNext architecture types: ["tiny", "small", "base", "large"]. Default: "tiny". - - `stem_patch_kernel`: (int) Size of the convolutional kernels in the stem layer. Default is 4. - - `stem_patch_stride`: (int) Convolutional stride in the stem layer. Default is 2. - - `in_channels`: (int) Number of input channels. Default is 1. - - `kernel_size`: (int) Size of the convolutional kernels. Default is 3. - - `filters_rate`: (float) Factor to adjust the number of filters per block. Default is 1.5. - - `convs_per_block`: (int) Number of convolutional layers per block. Default is 2. - - `up_interpolate`: (bool) If True, use bilinear interpolation instead of transposed - convolutions for upsampling. Interpolation is faster but transposed - convolutions may be able to learn richer or more complex upsampling to - recover details from higher scales. Default: True. - - `backbone_config`: (for SwinT. Default is `Tiny` architecture.) - - `model_type`: (str) One of the ConvNext architecture types: ["tiny", "small", "base"]. Default: "tiny". - - `arch`: Dictionary of embed dimension, depths and number of heads in each layer. - Default is "Tiny architecture". - {'embed': 96, 'depths': [2,2,6,2], 'channels':[3, 6, 12, 24]} - - `patch_size`: (List[int]) Patch size for the stem layer of SwinT. Default: [4,4]. - - `stem_patch_stride`: (int) Stride for the patch. Default is 2. - - `window_size`: (List[int]) Window size. Default: [7,7]. - - `in_channels`: (int) Number of input channels. Default is 1. - - `kernel_size`: (int) Size of the convolutional kernels. Default is 3. - - `filters_rate`: (float) Factor to adjust the number of filters per block. Default is 1.5. - - `convs_per_block`: (int) Number of convolutional layers per block. Default is 2. - - `up_interpolate`: (bool) If True, use bilinear interpolation instead of transposed - convolutions for upsampling. Interpolation is faster but transposed - convolutions may be able to learn richer or more complex upsampling to - recover details from higher scales. Default: True. - - `head_configs`: (Dict) Dictionary having head configs with keys `confmaps` and `pafs`. For eg, BottomUp model has both `confmaps` and `pafs` whereas Centroid model only has `confmaps` key. All the keys follow the same structure as given below: - - `confmaps`: - - `head_type`: (str) Name of the head. Supported values are 'SingleInstanceConfmapsHead', 'CentroidConfmapsHead', 'CenteredInstanceConfmapsHead', 'MultiInstanceConfmapsHead', 'PartAffinityFieldsHead', 'ClassMapsHead', 'ClassVectorsHead', 'OffsetRefinementHead' - - `head_config`: + - `backbone_type`: (str) Backbone architecture for the model to be trained. One of "unet", "convnext" or "swint". + - `backbone_config`: (for UNet) + - `in_channels`: (int) Number of input channels. Default is 1. + - `kernel_size`: (int) Size of the convolutional kernels. Default is 3. + - `filters`: (int) Base number of filters in the network. Default is 32 + - `filters_rate`: (float) Factor to adjust the number of filters per block. Default is 1.5. + - `max_stride`: (int) Scalar integer specifying the maximum stride that the image must be + divisible by. + - `stem_stride`: (int) If not None, will create additional "down" blocks for initial + downsampling based on the stride. These will be configured identically to the down blocks below. + - `middle_block`: (bool) If True, add an additional block at the end of the encoder. default: True + - `up_interpolate`: (bool) If True, use bilinear interpolation instead of transposed + convolutions for upsampling. Interpolation is faster but transposed + convolutions may be able to learn richer or more complex upsampling to + recover details from higher scales. Default: True. + - `stacks`: (int) Number of upsampling blocks in the decoder. Default is 3. + - `convs_per_block`: (int) Number of convolutional layers per block. Default is 2. + - `backbone_config`: (for ConvNext) + - `arch`: (Default is `Tiny` architecture config. No need to provide if `model_type` is provided) + - `depths`: (List(int)) Number of layers in each block. Default: [3, 3, 9, 3]. + - `channels`: (List(int)) Number of channels in each block. Default: [96, 192, 384, 768]. + - `model_type`: (str) One of the ConvNext architecture types: ["tiny", "small", "base", "large"]. Default: "tiny". + - `stem_patch_kernel`: (int) Size of the convolutional kernels in the stem layer. Default is 4. + - `stem_patch_stride`: (int) Convolutional stride in the stem layer. Default is 2. + - `in_channels`: (int) Number of input channels. Default is 1. + - `kernel_size`: (int) Size of the convolutional kernels. Default is 3. + - `filters_rate`: (float) Factor to adjust the number of filters per block. Default is 1.5. + - `convs_per_block`: (int) Number of convolutional layers per block. Default is 2. + - `up_interpolate`: (bool) If True, use bilinear interpolation instead of transposed + convolutions for upsampling. Interpolation is faster but transposed + convolutions may be able to learn richer or more complex upsampling to + recover details from higher scales. Default: True. + - `backbone_config`: (for SwinT. Default is `Tiny` architecture.) + - `model_type`: (str) One of the ConvNext architecture types: ["tiny", "small", "base"]. Default: "tiny". + - `arch`: Dictionary of embed dimension, depths and number of heads in each layer. + Default is "Tiny architecture". + {'embed': 96, 'depths': [2,2,6,2], 'channels':[3, 6, 12, 24]} + - `patch_size`: (List[int]) Patch size for the stem layer of SwinT. Default: [4,4]. + - `stem_patch_stride`: (int) Stride for the patch. Default is 2. + - `window_size`: (List[int]) Window size. Default: [7,7]. + - `in_channels`: (int) Number of input channels. Default is 1. + - `kernel_size`: (int) Size of the convolutional kernels. Default is 3. + - `filters_rate`: (float) Factor to adjust the number of filters per block. Default is 1.5. + - `convs_per_block`: (int) Number of convolutional layers per block. Default is 2. + - `up_interpolate`: (bool) If True, use bilinear interpolation instead of transposed + convolutions for upsampling. Interpolation is faster but transposed + convolutions may be able to learn richer or more complex upsampling to + recover details from higher scales. Default: True. + - `head_configs`: (Dict) Dictionary with the following keys having head configs for the model to be trained. **Note**: Configs should be provided only for the model to train and others should be `None`. + - `single_instance`: + - `confmaps`: + - `part_names`: (List[str]) `None` if nodes from `sio.Labels` file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'. + - `sigma`: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied. + - `output_stride`: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution. + - `centroid`: + - `confmaps`: + - `anchor_part`: (int) **Note**: Only for 'CenteredInstanceConfmapsHead'. Index of the anchor node to use as the anchor point. If None, the midpoint of the bounding box of all visible instance points will be used as the anchor. The bounding box midpoint will also be used if the anchor part is specified but not visible in the instance. Setting a reliable anchor point can significantly improve topdown model accuracy as they benefit from a consistent geometry of the body parts relative to the center of the image. + - `sigma`: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied. + - `output_stride`: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution. + - `centered_instance`: + - `confmaps`: - `part_names`: (List[str]) `None` if nodes from `sio.Labels` file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'. - - `edges`: (List[str]) `None` if edges from `sio.Labels` file can be used directly. **Note**: Only for 'PartAffinityFieldsHead'. List of indices `(src, dest)` that form an edge. - `anchor_part`: (int) **Note**: Only for 'CenteredInstanceConfmapsHead'. Index of the anchor node to use as the anchor point. If None, the midpoint of the bounding box of all visible instance points will be used as the anchor. The bounding box midpoint will also be used if the anchor part is specified but not visible in the instance. Setting a reliable anchor point can significantly improve topdown model accuracy as they benefit from a consistent geometry of the body parts relative to the center of the image. - `sigma`: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied. - `output_stride`: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution. + - `bottomup`: + - `confmaps`: + - `part_names`: (List[str]) `None` if nodes from `sio.Labels` file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'. + - `sigma`: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied. + - `output_stride`: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution. - `loss_weight`: (float) Scalar float used to weigh the loss term for this head during training. Increase this to encourage the optimization to focus on improving this specific output in multi-head models. - - `pafs`: (same structure as that of `confmaps`.**Note**: This section is only for BottomUp model.) + - `pafs`: (same structure as that of `confmaps`.**Note**: This section is only for BottomUp model.) + - `edges`: (List[str]) `None` if edges from `sio.Labels` file can be used directly. **Note**: Only for 'PartAffinityFieldsHead'. List of indices `(src, dest)` that form an edge. + - `sigma`: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied. + - `output_stride`: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution. + - `loss_weight`: (float) Scalar float used to weigh the loss term for this head during training. Increase this to encourage the optimization to focus on improving this specific output in multi-head models. + - `trainer_config`: - `train_data_loader`: (**Note**: Any parameters from [Torch's DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) could be used.) @@ -129,7 +147,6 @@ The config file has three main sections: - `model_ckpt`: (**Note**: Any parameters from [Lightning's ModelCheckpoint](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html) could be used.) - `save_top_k`: (int) If save_top_k == k, the best k models according to the quantity monitored will be saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved. Please note that the monitors are checked every every_n_epochs epochs. if save_top_k >= 2 and the callback is called multiple times inside an epoch, the name of the saved file will be appended with a version count starting with v1 unless enable_version_counter is set to False. - `save_last`: (bool) When True, saves a last.ckpt whenever a checkpoint file gets saved. On a local filesystem, this will be a symbolic link, and otherwise a copy of the checkpoint file. This allows accessing the latest checkpoint in a deterministic manner. *Default*: None. - - `device`: (str) Device on which torch.Tensor will be allocated. One of the ("cpu", "cuda", "mkldnn", "opengl", "opencl", "ideep", "hip", "msnpu"). - `trainer_devices`: (int) Number of devices to train on (int), which devices to train on (list or str), or "auto" to select automatically. - `trainer_accelerator`: (str) One of the ("cpu", "gpu", "tpu", "ipu", "auto"). "auto" recognises the machine the model is running on and chooses the appropriate accelerator for the `Trainer` to be connected to. - `enable_progress_bar`: (bool) When True, enables printing the logs during training. @@ -139,7 +156,7 @@ The config file has three main sections: - `use_wandb`: (bool) True to enable wandb logging. - `save_ckpt`: (bool) True to enable checkpointing. - `save_ckpt_path`: (str) Directory path to save the training config and checkpoint files. *Default*: "./" - - `wandb`: + - `wandb`: (Only if `use_wandb` is `True`, else skip this) - `entity`: (str) Entity of wandb project. - `project`: (str) Project name for the wandb project. - `name`: (str) Name of the current run. diff --git a/docs/config_bottomup.yaml b/docs/config_bottomup.yaml index 0124a1c4..89eeb50c 100644 --- a/docs/config_bottomup.yaml +++ b/docs/config_bottomup.yaml @@ -1,118 +1,36 @@ data_config: provider: LabelsReader - pipeline: BottomUp - train: - labels_path: minimal_instance.pkg.slp + train_labels_path: minimal_instance.pkg.slp + val_labels_path: minimal_instance.pkg.slp + preprocessing: max_width: null max_height: null scale: 1.0 is_rgb: false - preprocessing: - crop_hw: - - 160 - - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: true - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 - val: - labels_path: minimal_instance.pkg.slp - max_width: null - max_height: null - is_rgb: false - scale: 1.0 - preprocessing: - crop_hw: - - 160 - - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: false - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 + use_augmentations_train: true + augmentation_config: + geometric: + rotation: 180.0 + scale: null + translate_width: 0 + translate_height: 0 + affine_p: 0.5 + model_config: init_weights: xavier pre_trained_weights: null + backbone_type: unet backbone_config: - backbone_type: unet - backbone_config: - in_channels: 1 - kernel_size: 3 - filters: 16 - filters_rate: 2 - max_stride: 16 - convs_per_block: 2 - stacks: 1 - stem_stride: null - middle_block: true - up_interpolate: true + in_channels: 1 + kernel_size: 3 + filters: 16 + filters_rate: 2 + max_stride: 16 + convs_per_block: 2 + stacks: 1 + stem_stride: null + middle_block: true + up_interpolate: true # pre_trained_weights: ConvNeXt_Tiny_Weights # backbone_config: @@ -144,20 +62,20 @@ model_config: # stem_patch_stride: 2 head_configs: - confmaps: - head_type: MultiInstanceConfmapsHead - head_config: - part_names: None - sigma: 1.5 - output_stride: 2 - loss_weight: 1.0 - pafs: - head_type: PartAffinityFieldsHead - head_config: - edges: None - sigma: 50 - output_stride: 4 - loss_weight: 1.0 + single_instance: + centered_instance: + centroid: + bottomup: + confmaps: + part_names: None + sigma: 1.5 + output_stride: 2 + loss_weight: 1.0 + pafs: + edges: None + sigma: 50 + output_stride: 4 + loss_weight: 1.0 trainer_config: train_data_loader: batch_size: 4 @@ -169,7 +87,6 @@ trainer_config: model_ckpt: save_top_k: 1 save_last: true - device: cpu trainer_devices: 1 trainer_accelerator: cpu enable_progress_bar: false @@ -179,18 +96,6 @@ trainer_config: use_wandb: false save_ckpt: true save_ckpt_path: min_inst_bottomup1 - wandb: - entity: team-ucsd - project: test_centroid_centered - name: fly_unet_centered - wandb_mode: '' - api_key: '' - log_params: - - trainer_config.optimizer_name - - trainer_config.optimizer.amsgrad - - trainer_config.optimizer.lr - - model_config.backbone_config.backbone_type - - model_config.init_weights optimizer_name: Adam optimizer: lr: 0.0001 diff --git a/docs/config_centroid.yaml b/docs/config_centroid.yaml index 8fd5ff5d..d596dd03 100644 --- a/docs/config_centroid.yaml +++ b/docs/config_centroid.yaml @@ -1,118 +1,59 @@ data_config: provider: LabelsReader - pipeline: CentroidConfmaps - train: - labels_path: "minimal_instance.pkg.slp" + train_labels_path: minimal_instance.pkg.slp + val_labels_path: minimal_instance.pkg.slp + preprocessing: max_width: max_height: scale: 0.5 is_rgb: False - preprocessing: - crop_hw: - - 160 - - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: true - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 - val: - labels_path: "minimal_instance.pkg.slp" - max_width: - max_height: - is_rgb: False - scale: 0.5 - preprocessing: - crop_hw: - - 160 - - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: false - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 + use_augmentations_train: true + augmentation_config: # sample augmentation_config + random_crop: + random_crop_p: 0 + crop_height: 160 + crop_width: 160 + intensity: + uniform_noise_min: 0.0 + uniform_noise_max: 0.04 + uniform_noise_p: 0 + gaussian_noise_mean: 0.02 + gaussian_noise_std: 0.004 + gaussian_noise_p: 0 + contrast_min: 0.5 + contrast_max: 2.0 + contrast_p: 0 + brightness: 0.0 + brightness_p: 0 + geometric: + rotation: 180.0 + scale: null + translate_width: 0 + translate_height: 0 + affine_p: 0.5 + erase_scale_min: 0.0001 + erase_scale_max: 0.01 + erase_ratio_min: 1 + erase_ratio_max: 1 + erase_p: 0 + mixup_lambda: null + mixup_p: 0 + model_config: init_weights: xavier pre_trained_weights: + backbone_type: unet backbone_config: - backbone_type: unet - backbone_config: - in_channels: 1 - kernel_size: 3 - filters: 16 - filters_rate: 2 - max_stride: 16 - convs_per_block: 2 - stacks: 1 - stem_stride: - middle_block: True - up_interpolate: True + in_channels: 1 + kernel_size: 3 + filters: 16 + filters_rate: 2 + max_stride: 16 + convs_per_block: 2 + stacks: 1 + stem_stride: + middle_block: True + up_interpolate: True # pre_trained_weights: ConvNeXt_Tiny_Weights # backbone_config: @@ -143,14 +84,15 @@ model_config: # up_interpolate: True # stem_patch_stride: 2 - head_configs: - confmaps: - head_type: CentroidConfmapsHead - head_config: - anchor_part: 0 - sigma: 1.5 - output_stride: 2 - loss_weight: 1.0 + head_configs: + single_instance: + centered_instance: + bottomup: + centroid: + confmaps: + anchor_part: 0 + sigma: 1.5 + output_stride: 2 trainer_config: train_data_loader: batch_size: 4 @@ -162,7 +104,6 @@ trainer_config: model_ckpt: save_top_k: 1 save_last: true - device: cuda trainer_devices: 1 trainer_accelerator: gpu enable_progress_bar: false @@ -172,7 +113,7 @@ trainer_config: use_wandb: false save_ckpt: true save_ckpt_path: 'min_inst_centroid' - wandb: + wandb: # sample wandb config entity: project: 'test_centroid_centered' name: 'fly_unet_centered' @@ -182,7 +123,7 @@ trainer_config: - trainer_config.optimizer_name - trainer_config.optimizer.amsgrad - trainer_config.optimizer.lr - - model_config.backbone_config.backbone_type + - model_config.backbone_type - model_config.init_weights optimizer_name: Adam optimizer: diff --git a/docs/config_topdown_centered_instance.yaml b/docs/config_topdown_centered_instance.yaml index 1ae11039..a0dce2da 100644 --- a/docs/config_topdown_centered_instance.yaml +++ b/docs/config_topdown_centered_instance.yaml @@ -1,118 +1,39 @@ data_config: provider: LabelsReader - pipeline: TopdownConfmaps - train: - labels_path: "minimal_instance.pkg.slp" + train_labels_path: minimal_instance.pkg.slp + val_labels_path: minimal_instance.pkg.slp + preprocessing: max_width: max_height: scale: 1.0 is_rgb: False - preprocessing: - crop_hw: + crop_hw: - 160 - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: true - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 - val: - labels_path: "minimal_instance.pkg.slp" - max_width: - max_height: - is_rgb: False - scale: 1.0 - preprocessing: - crop_hw: - - 160 - - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: false - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 + use_augmentations_train: true + augmentation_config: + geometric: + rotation: 180.0 + scale: null + translate_width: 0 + translate_height: 0 + affine_p: 0.5 + model_config: init_weights: xavier pre_trained_weights: + backbone_type: unet backbone_config: - backbone_type: unet - backbone_config: - in_channels: 1 - kernel_size: 3 - filters: 16 - filters_rate: 2 - max_stride: 16 - convs_per_block: 2 - stacks: 1 - stem_stride: - middle_block: True - up_interpolate: True + in_channels: 1 + kernel_size: 3 + filters: 16 + filters_rate: 2 + max_stride: 16 + convs_per_block: 2 + stacks: 1 + stem_stride: + middle_block: True + up_interpolate: True # pre_trained_weights: ConvNeXt_Tiny_Weights # backbone_config: @@ -143,15 +64,16 @@ model_config: # up_interpolate: True # stem_patch_stride: 2 - head_configs: - confmaps: - head_type: CenteredInstanceConfmapsHead - head_config: - part_names: None - anchor_part: 0 - sigma: 1.5 - output_stride: 2 - loss_weight: 1.0 + head_configs: + single_instance: + centroid: + bottomup: + centered_instance: + confmaps: + part_names: None + anchor_part: 0 + sigma: 1.5 + output_stride: 2 trainer_config: train_data_loader: batch_size: 4 @@ -163,7 +85,6 @@ trainer_config: model_ckpt: save_top_k: 1 save_last: true - device: cuda trainer_devices: 1 trainer_accelerator: gpu # TODO: redo device + trainer_accelerator! enable_progress_bar: false @@ -173,18 +94,6 @@ trainer_config: use_wandb: false save_ckpt: true save_ckpt_path: 'min_inst_centered' - wandb: - entity: - project: 'test_centroid_centered' - name: 'fly_unet_centered' - wandb_mode: '' - api_key: '' - log_params: - - trainer_config.optimizer_name - - trainer_config.optimizer.amsgrad - - trainer_config.optimizer.lr - - model_config.backbone_config.backbone_type - - model_config.init_weights optimizer_name: Adam optimizer: lr: 0.0001 diff --git a/sleap_nn/architectures/model.py b/sleap_nn/architectures/model.py index 6e6e7e7a..2a06383c 100644 --- a/sleap_nn/architectures/model.py +++ b/sleap_nn/architectures/model.py @@ -61,86 +61,84 @@ def get_backbone( return backbone -def get_head(head: str, head_config: DictConfig) -> Head: +def get_head(model_type: str, head_config: DictConfig) -> Head: """Get a head `nn.Module` based on the provided name. This function returns an instance of a PyTorch `nn.Module` corresponding to the given head name. Args: - head (str): Name of the head. Supported values are - - 'SingleInstanceConfmapsHead' - - 'CentroidConfmapsHead' - - 'CenteredInstanceConfmapsHead' - - 'MultiInstanceConfmapsHead' - - 'PartAffinityFieldsHead' - - 'ClassMapsHead' - - 'ClassVectorsHead' - - 'OffsetRefinementHead' + model_type (str): Name of the head. Supported values are + - 'single_instance' + - 'centroid' + - 'centered_instance' + - 'bottomup' head_config (DictConfig): A config for the head. Returns: nn.Module: An instance of the requested head. - - Raises: - KeyError: If the provided head name is not one of the supported values. """ - heads = { - "SingleInstanceConfmapsHead": SingleInstanceConfmapsHead, - "CentroidConfmapsHead": CentroidConfmapsHead, - "CenteredInstanceConfmapsHead": CenteredInstanceConfmapsHead, - "MultiInstanceConfmapsHead": MultiInstanceConfmapsHead, - "PartAffinityFieldsHead": PartAffinityFieldsHead, - "ClassMapsHead": ClassMapsHead, - "ClassVectorsHead": ClassVectorsHead, - "OffsetRefinementHead": OffsetRefinementHead, - } - - if head not in heads: - raise KeyError( - f"Unsupported head: {head}. Supported heads are: {', '.join(heads.keys())}" - ) + heads = [] + if model_type == "single_instance": + heads.append(SingleInstanceConfmapsHead(**head_config.confmaps)) + + elif model_type == "centered_instance": + heads.append(CenteredInstanceConfmapsHead(**head_config.confmaps)) - head = heads[head](**head_config) + elif model_type == "centroid": + heads.append(CentroidConfmapsHead(**head_config.confmaps)) - return head + elif model_type == "bottomup": + heads.append(MultiInstanceConfmapsHead(**head_config.confmaps)) + heads.append(PartAffinityFieldsHead(**head_config.pafs)) + + else: + raise Exception( + f"{model_type} is not a defined model type. Please choose one of `single_instance`, `centered_instance`, `centroid`, `bottomup`." + ) + + return heads class Model(nn.Module): """Model creates a model consisting of a backbone and head. Attributes: + backbone_type: Backbone type. One of `unet`, `convnext` and `swint`. backbone_config: An `DictConfig` configuration dictionary for the model backbone. head_configs: An `DictConfig` configuration dictionary for the model heads. input_expand_channels: Integer representing the number of channels the image should be expanded to. + model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`. """ def __init__( self, + backbone_type: str, backbone_config: DictConfig, head_configs: DictConfig, input_expand_channels: int, + model_type: str, ) -> None: """Initialize the backbone and head based on the backbone_config.""" super().__init__() + self.backbone_type = backbone_type self.backbone_config = backbone_config self.head_configs = head_configs self.input_expand_channels = input_expand_channels - self.heads = [] + self.heads = get_head(model_type, self.head_configs) + output_strides = [] for head_type in head_configs: head_config = head_configs[head_type] - head = get_head(head_config.head_type, head_config.head_config) - self.heads.append(head) - output_strides.append(head_config.head_config.output_stride) + output_strides.append(head_config.output_stride) min_output_stride = min(output_strides) self.backbone = get_backbone( - backbone_config.backbone_type, - backbone_config.backbone_config, + self.backbone_type, + backbone_config, min_output_stride, ) @@ -150,7 +148,7 @@ def __init__( in_channels = int( self.backbone.max_channels / ( - self.backbone_config.backbone_config.filters_rate + self.backbone_config.filters_rate ** len(self.backbone.dec.decoder_stack) ) ) @@ -158,23 +156,25 @@ def __init__( factor = strides.index(min_output_stride) - strides.index( head.output_stride ) - in_channels = in_channels * ( - self.backbone_config.backbone_config.filters_rate**factor - ) + in_channels = in_channels * (self.backbone_config.filters_rate**factor) self.head_layers.append(head.make_head(x_in=int(in_channels))) @classmethod def from_config( cls, + backbone_type: str, backbone_config: DictConfig, head_configs: DictConfig, input_expand_channels: int, + model_type: str, ) -> "Model": """Create the model from a config dictionary.""" return cls( + backbone_type=backbone_type, backbone_config=backbone_config, head_configs=head_configs, input_expand_channels=input_expand_channels, + model_type=model_type, ) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/sleap_nn/data/augmentation.py b/sleap_nn/data/augmentation.py index 9f9983c4..4eefdbb7 100644 --- a/sleap_nn/data/augmentation.py +++ b/sleap_nn/data/augmentation.py @@ -101,32 +101,37 @@ class KorniaAugmenter(IterDataPipe): rotation: Angles in degrees as a scalar float of the amount of rotation. A random angle in `(-rotation, rotation)` will be sampled and applied to both images and keypoints. Set to 0 to disable rotation augmentation. - scale: A scaling factor as a scalar float specifying the amount of scaling. A - random factor between `(1 - scale, 1 + scale)` will be sampled and applied - to both images and keypoints. If `None`, no scaling augmentation will be - applied. - translate: tuple of maximum absolute fraction for horizontal - and vertical translations. For example translate=(a, b), then horizontal shift - is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is - randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. + scale: scaling factor interval. If (a, b) represents isotropic scaling, the scale + is randomly sampled from the range a <= scale <= b. If (a, b, c, d), the scale + is randomly sampled from the range a <= scale_x <= b, c <= scale_y <= d. + Default: None. + translate_width: Maximum absolute fraction for horizontal translation. For example, + if translate_width=a, then horizontal shift is randomly sampled in the range + -img_width * a < dx < img_width * a. Will not translate by default. + translate_height: Maximum absolute fraction for vertical translation. For example, + if translate_height=a, then vertical shift is randomly sampled in the range + -img_height * a < dy < img_height * a. Will not translate by default. affine_p: Probability of applying random affine transformations. - uniform_noise: tuple of uniform noise `(min_noise, max_noise)`. - Must satisfy 0. <= min_noise <= max_noise <= 1. + uniform_noise_min: Minimum value for uniform noise (uniform_noise_min >=0). + uniform_noise_max: Maximum value for uniform noise (uniform_noise_max <=1). uniform_noise_p: Probability of applying random uniform noise. gaussian_noise_mean: The mean of the gaussian distribution. gaussian_noise_std: The standard deviation of the gaussian distribution. gaussian_noise_p: Probability of applying random gaussian noise. - contrast: The contrast factor to apply. Default: `(1.0, 1.0)`. + contrast_min: Minimum contrast factor to apply. Default: 0.5. + contrast_max: Maximum contrast factor to apply. Default: 2.0. contrast_p: Probability of applying random contrast. - brightness: The brightness factor to apply Default: `(1.0, 1.0)`. + brightness: The brightness factor to apply Default: 0.0. brightness_p: Probability of applying random brightness. - erase_scale: Range of proportion of erased area against input image. Default: `(0.0001, 0.01)`. - erase_ratio: Range of aspect ratio of erased area. Default: `(1, 1)`. + erase_scale_min: Minimum value of range of proportion of erased area against input image. Default: 0.0001. + erase_scale_max: Maximum value of range of proportion of erased area against input image. Default: 0.01. + erase_ratio_min: Minimum value of range of aspect ratio of erased area. Default: 1. + erase_ratio_max: Maximum value of range of aspect ratio of erased area. Default: 1. erase_p: Probability of applying random erase. mixup_lambda: min-max value of mixup strength. Default is 0-1. Default: `None`. mixup_p: Probability of applying random mixup v2. - random_crop_hw: Desired output size (out_h, out_w) of the crop. Must be Tuple[int, int], - then out_h = size[0], out_w = size[1]. + random_crop_height: Desired output height of the crop. Must be int. + random_crop_width: Desired output width of the crop. Must be int. random_crop_p: Probability of applying random crop. input_key: Can be `image` or `instance`. The input_key `instance` expects the the KorniaAugmenter to follow the InstanceCropper else `image` otherwise @@ -149,24 +154,32 @@ def __init__( self, source_dp: IterDataPipe, rotation: Optional[float] = 15.0, - scale: Optional[float] = 0.05, - translate: Optional[Tuple[float, float]] = (0.02, 0.02), + scale: Union[ + Optional[float], Tuple[float, float], Tuple[float, float, float, float] + ] = None, + translate_width: Optional[float] = 0.02, + translate_height: Optional[float] = 0.02, affine_p: float = 0.0, - uniform_noise: Optional[Tuple[float, float]] = (0.0, 0.04), + uniform_noise_min: Optional[float] = 0.0, + uniform_noise_max: Optional[float] = 0.04, uniform_noise_p: float = 0.0, gaussian_noise_mean: Optional[float] = 0.02, gaussian_noise_std: Optional[float] = 0.004, gaussian_noise_p: float = 0.0, - contrast: Optional[Tuple[float, float]] = (0.5, 2.0), + contrast_min: Optional[float] = 0.5, + contrast_max: Optional[float] = 2.0, contrast_p: float = 0.0, brightness: Optional[float] = 0.0, brightness_p: float = 0.0, - erase_scale: Optional[Tuple[float, float]] = (0.0001, 0.01), - erase_ratio: Optional[Tuple[float, float]] = (1, 1), + erase_scale_min: Optional[float] = 0.0001, + erase_scale_max: Optional[float] = 0.01, + erase_ratio_min: Optional[float] = 1, + erase_ratio_max: Optional[float] = 1, erase_p: float = 0.0, mixup_lambda: Union[Optional[float], Tuple[float, float], None] = None, mixup_p: float = 0.0, - random_crop_hw: Tuple[int, int] = (0, 0), + random_crop_height: int = 0, + random_crop_width: int = 0, random_crop_p: float = 0.0, image_key: str = "image", instance_key: str = "instances", @@ -174,24 +187,32 @@ def __init__( """Initialize the block and the augmentation pipeline.""" self.source_dp = source_dp self.rotation = rotation - self.scale = (1 - scale, 1 + scale) - self.translate = translate + self.scale = scale + if isinstance(self.scale, float): + self.scale = (scale, scale) + self.translate_width = translate_width + self.translate_height = translate_height self.affine_p = affine_p - self.uniform_noise = uniform_noise + self.uniform_noise_min = uniform_noise_min + self.uniform_noise_max = uniform_noise_max self.uniform_noise_p = uniform_noise_p self.gaussian_noise_mean = gaussian_noise_mean self.gaussian_noise_std = gaussian_noise_std self.gaussian_noise_p = gaussian_noise_p - self.contrast = contrast + self.contrast_min = contrast_min + self.contrast_max = contrast_max self.contrast_p = contrast_p self.brightness = brightness self.brightness_p = brightness_p - self.erase_scale = erase_scale - self.erase_ratio = erase_ratio + self.erase_scale_min = erase_scale_min + self.erase_scale_max = erase_scale_max + self.erase_ratio_min = erase_ratio_min + self.erase_ratio_max = erase_ratio_max self.erase_p = erase_p self.mixup_lambda = mixup_lambda self.mixup_p = mixup_p - self.random_crop_hw = random_crop_hw + self.random_crop_height = random_crop_height + self.random_crop_width = random_crop_width self.random_crop_p = random_crop_p self.image_key = image_key self.instance_key = instance_key @@ -201,7 +222,7 @@ def __init__( aug_stack.append( K.augmentation.RandomAffine( degrees=self.rotation, - translate=self.translate, + translate=(self.translate_width, self.translate_height), scale=self.scale, p=self.affine_p, keepdim=True, @@ -211,7 +232,7 @@ def __init__( if self.uniform_noise_p > 0: aug_stack.append( RandomUniformNoise( - noise=self.uniform_noise, + noise=(self.uniform_noise_min, self.uniform_noise_max), p=self.uniform_noise_p, keepdim=True, same_on_batch=True, @@ -230,7 +251,7 @@ def __init__( if self.contrast_p > 0: aug_stack.append( K.augmentation.RandomContrast( - contrast=self.contrast, + contrast=(self.contrast_min, self.contrast_max), p=self.contrast_p, keepdim=True, same_on_batch=True, @@ -248,8 +269,8 @@ def __init__( if self.erase_p > 0: aug_stack.append( K.augmentation.RandomErasing( - scale=self.erase_scale, - ratio=self.erase_ratio, + scale=(self.erase_scale_min, self.erase_scale_max), + ratio=(self.erase_ratio_min, self.erase_ratio_max), p=self.erase_p, keepdim=True, same_on_batch=True, @@ -265,10 +286,10 @@ def __init__( ) ) if self.random_crop_p > 0: - if self.random_crop_hw[0] > 0 and self.random_crop_hw[1] > 0: + if self.random_crop_height > 0 and self.random_crop_width > 0: aug_stack.append( K.augmentation.RandomCrop( - size=self.random_crop_hw, + size=(self.random_crop_height, self.random_crop_width), pad_if_needed=True, p=self.random_crop_p, keepdim=True, diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py index fbd3cfbb..f5660cc5 100644 --- a/sleap_nn/data/instance_cropping.py +++ b/sleap_nn/data/instance_cropping.py @@ -81,6 +81,7 @@ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: centroids = ex["centroids"] # (n_samples, n_instances, 2) del ex["instances"] del ex["centroids"] + del ex["image"] for cnt, (instance, centroid) in enumerate(zip(instances[0], centroids[0])): if cnt == ex["num_instances"]: break diff --git a/sleap_nn/data/pipelines.py b/sleap_nn/data/pipelines.py index 199bdff0..ce1b09f8 100644 --- a/sleap_nn/data/pipelines.py +++ b/sleap_nn/data/pipelines.py @@ -43,29 +43,33 @@ def __init__( self.max_stride = max_stride self.confmap_head = confmap_head - def make_training_pipeline(self, data_provider: IterDataPipe) -> IterDataPipe: + def make_training_pipeline( + self, data_provider: IterDataPipe, use_augmentations: bool = False + ) -> IterDataPipe: """Create training pipeline with input data only. Args: data_provider: A `Provider` that generates data examples, typically a `LabelsReader` instance. + use_augmentations: `True` if augmentations should be applied to the training + pipeline, else `False`. Default: `False`. Returns: An `IterDataPipe` instance configured to produce input examples. """ provider = data_provider - datapipe = Normalizer(provider, self.data_config.is_rgb) + datapipe = Normalizer(provider, self.data_config.preprocessing.is_rgb) datapipe = SizeMatcher( datapipe, - max_height=self.data_config.max_height, - max_width=self.data_config.max_width, + max_height=self.data_config.preprocessing.max_height, + max_width=self.data_config.preprocessing.max_width, provider=provider, ) - if self.data_config.augmentation_config.use_augmentations: + if use_augmentations and "intensity" in self.data_config.augmentation_config: datapipe = KorniaAugmenter( datapipe, - **dict(self.data_config.augmentation_config.augmentations.intensity), + **dict(self.data_config.augmentation_config.intensity), image_key="image", instance_key="instances", ) @@ -79,26 +83,17 @@ def make_training_pipeline(self, data_provider: IterDataPipe) -> IterDataPipe: self.data_config.preprocessing.crop_hw, ) - if self.data_config.augmentation_config.random_crop.random_crop_p: + if use_augmentations and "geometric" in self.data_config.augmentation_config: datapipe = KorniaAugmenter( datapipe, - random_crop_hw=self.data_config.augmentation_config.random_crop.random_crop_hw, - random_crop_p=self.data_config.augmentation_config.random_crop.random_crop_p, - image_key="instance_image", - instance_key="instance", - ) - - if self.data_config.augmentation_config.use_augmentations: - datapipe = KorniaAugmenter( - datapipe, - **dict(self.data_config.augmentation_config.augmentations.geometric), + **dict(self.data_config.augmentation_config.geometric), image_key="instance_image", instance_key="instance", ) datapipe = Resizer( datapipe, - scale=self.data_config.scale, + scale=self.data_config.preprocessing.scale, image_key="instance_image", instances_key="instance", ) @@ -116,7 +111,6 @@ def make_training_pipeline(self, data_provider: IterDataPipe) -> IterDataPipe: datapipe = KeyFilter( datapipe, keep_keys=[ - "image", "video_idx", "frame_idx", "centroid", @@ -126,7 +120,6 @@ def make_training_pipeline(self, data_provider: IterDataPipe) -> IterDataPipe: "confidence_maps", "num_instances", "orig_size", - "scale", ], ) @@ -152,44 +145,56 @@ def __init__( self.max_stride = max_stride self.confmap_head = confmap_head - def make_training_pipeline(self, data_provider: IterDataPipe) -> IterDataPipe: + def make_training_pipeline( + self, data_provider: IterDataPipe, use_augmentations: bool = False + ) -> IterDataPipe: """Create training pipeline with input data only. Args: data_provider: A `Provider` that generates data examples, typically a `LabelsReader` instance. + use_augmentations: `True` if augmentations should be applied to the training + pipeline, else `False`. Default: `False`. Returns: An `IterDataPipe` instance configured to produce input examples. """ provider = data_provider - datapipe = Normalizer(provider, self.data_config.is_rgb) + datapipe = Normalizer(provider, self.data_config.preprocessing.is_rgb) datapipe = SizeMatcher( datapipe, - max_height=self.data_config.max_height, - max_width=self.data_config.max_width, + max_height=self.data_config.preprocessing.max_height, + max_width=self.data_config.preprocessing.max_width, provider=provider, ) - if self.data_config.augmentation_config.use_augmentations: - datapipe = KorniaAugmenter( - datapipe, - **dict(self.data_config.augmentation_config.augmentations.intensity), - **dict(self.data_config.augmentation_config.augmentations.geometric), - image_key="image", - instance_key="instances", - ) - - if self.data_config.augmentation_config.random_crop.random_crop_p: - datapipe = KorniaAugmenter( - datapipe, - random_crop_hw=self.data_config.augmentation_config.random_crop.random_crop_hw, - random_crop_p=self.data_config.augmentation_config.random_crop.random_crop_p, - image_key="image", - instance_key="instances", - ) - - datapipe = Resizer(datapipe, scale=self.data_config.scale) + if use_augmentations: + if "intensity" in self.data_config.augmentation_config: + datapipe = KorniaAugmenter( + datapipe, + **dict(self.data_config.augmentation_config.intensity), + image_key="image", + instance_key="instances", + ) + if "geometric" in self.data_config.augmentation_config: + datapipe = KorniaAugmenter( + datapipe, + **dict(self.data_config.augmentation_config.geometric), + image_key="image", + instance_key="instances", + ) + + if "random_crop" in self.data_config.augmentation_config: + datapipe = KorniaAugmenter( + datapipe, + random_crop_height=self.data_config.augmentation_config.random_crop.crop_height, + random_crop_width=self.data_config.augmentation_config.random_crop.crop_width, + random_crop_p=self.data_config.augmentation_config.random_crop.random_crop_p, + image_key="image", + instance_key="instances", + ) + + datapipe = Resizer(datapipe, scale=self.data_config.preprocessing.scale) datapipe = PadToStride(datapipe, max_stride=self.max_stride) datapipe = ConfidenceMapGenerator( @@ -208,7 +213,6 @@ def make_training_pipeline(self, data_provider: IterDataPipe) -> IterDataPipe: "instances", "confidence_maps", "orig_size", - "scale", ], ) @@ -234,12 +238,16 @@ def __init__( self.max_stride = max_stride self.confmap_head = confmap_head - def make_training_pipeline(self, data_provider: IterDataPipe) -> IterDataPipe: + def make_training_pipeline( + self, data_provider: IterDataPipe, use_augmentations: bool = False + ) -> IterDataPipe: """Create training pipeline with input data only. Args: data_provider: A `Provider` that generates data examples, typically a `LabelsReader` instance. + use_augmentations: `True` if augmentations should be applied to the training + pipeline, else `False`. Default: `False`. Returns: An `IterDataPipe` instance configured to produce input examples. @@ -252,42 +260,41 @@ def make_training_pipeline(self, data_provider: IterDataPipe) -> IterDataPipe: "centroids_confidence_maps", "orig_size", "num_instances", - "scale", ] - datapipe = Normalizer(provider, self.data_config.is_rgb) + datapipe = Normalizer(provider, self.data_config.preprocessing.is_rgb) datapipe = SizeMatcher( datapipe, - max_height=self.data_config.max_height, - max_width=self.data_config.max_width, + max_height=self.data_config.preprocessing.max_height, + max_width=self.data_config.preprocessing.max_width, provider=provider, ) - if self.data_config.augmentation_config.use_augmentations: - datapipe = KorniaAugmenter( - datapipe, - **dict(self.data_config.augmentation_config.augmentations.intensity), - image_key="image", - instance_key="instances", - ) - - if self.data_config.augmentation_config.random_crop.random_crop_p: - datapipe = KorniaAugmenter( - datapipe, - random_crop_hw=self.data_config.augmentation_config.random_crop.random_crop_hw, - random_crop_p=self.data_config.augmentation_config.random_crop.random_crop_p, - image_key="image", - instance_key="instances", - ) - - if self.data_config.augmentation_config.use_augmentations: - datapipe = KorniaAugmenter( - datapipe, - **dict(self.data_config.augmentation_config.augmentations.geometric), - image_key="image", - instance_key="instances", - ) - - datapipe = Resizer(datapipe, scale=self.data_config.scale) + if use_augmentations: + if "intensity" in self.data_config.augmentation_config: + datapipe = KorniaAugmenter( + datapipe, + **dict(self.data_config.augmentation_config.intensity), + image_key="image", + instance_key="instances", + ) + if "geometric" in self.data_config.augmentation_config: + datapipe = KorniaAugmenter( + datapipe, + **dict(self.data_config.augmentation_config.geometric), + image_key="image", + instance_key="instances", + ) + if "random_crop" in self.data_config.augmentation_config: + datapipe = KorniaAugmenter( + datapipe, + random_crop_height=self.data_config.augmentation_config.random_crop.crop_height, + random_crop_width=self.data_config.augmentation_config.random_crop.crop_width, + random_crop_p=self.data_config.augmentation_config.random_crop.random_crop_p, + image_key="image", + instance_key="instances", + ) + + datapipe = Resizer(datapipe, scale=self.data_config.preprocessing.scale) datapipe = PadToStride(datapipe, max_stride=self.max_stride) datapipe = InstanceCentroidFinder( datapipe, anchor_ind=self.confmap_head.anchor_part @@ -332,12 +339,16 @@ def __init__( self.confmap_head = confmap_head self.pafs_head = pafs_head - def make_training_pipeline(self, data_provider: IterDataPipe) -> IterDataPipe: + def make_training_pipeline( + self, data_provider: IterDataPipe, use_augmentations: bool = False + ) -> IterDataPipe: """Create training pipeline with input data only. Args: data_provider: A `Provider` that generates data examples, typically a `LabelsReader` instance. + use_augmentations: `True` if augmentations should be applied to the training + pipeline, else `False`. Default: `False`. Returns: An `IterDataPipe` instance configured to produce input examples. @@ -350,43 +361,43 @@ def make_training_pipeline(self, data_provider: IterDataPipe) -> IterDataPipe: "confidence_maps", "orig_size", "num_instances", - "scale", "part_affinity_fields", ] - datapipe = Normalizer(provider, self.data_config.is_rgb) + datapipe = Normalizer(provider, self.data_config.preprocessing.is_rgb) datapipe = SizeMatcher( datapipe, - max_height=self.data_config.max_height, - max_width=self.data_config.max_width, + max_height=self.data_config.preprocessing.max_height, + max_width=self.data_config.preprocessing.max_width, provider=provider, ) - if self.data_config.augmentation_config.use_augmentations: - datapipe = KorniaAugmenter( - datapipe, - **dict(self.data_config.augmentation_config.augmentations.intensity), - image_key="image", - instance_key="instances", - ) - - if self.data_config.augmentation_config.random_crop.random_crop_p: - datapipe = KorniaAugmenter( - datapipe, - random_crop_hw=self.data_config.augmentation_config.random_crop.random_crop_hw, - random_crop_p=self.data_config.augmentation_config.random_crop.random_crop_p, - image_key="image", - instance_key="instances", - ) - - if self.data_config.augmentation_config.use_augmentations: - datapipe = KorniaAugmenter( - datapipe, - **dict(self.data_config.augmentation_config.augmentations.geometric), - image_key="image", - instance_key="instances", - ) - - datapipe = Resizer(datapipe, scale=self.data_config.scale) + if use_augmentations: + if "intensity" in self.data_config.augmentation_config: + datapipe = KorniaAugmenter( + datapipe, + **dict(self.data_config.augmentation_config.intensity), + image_key="image", + instance_key="instances", + ) + if "geometric" in self.data_config.augmentation_config: + datapipe = KorniaAugmenter( + datapipe, + **dict(self.data_config.augmentation_config.geometric), + image_key="image", + instance_key="instances", + ) + + if "random_crop" in self.data_config.augmentation_config: + datapipe = KorniaAugmenter( + datapipe, + random_crop_height=self.data_config.augmentation_config.random_crop.crop_height, + random_crop_width=self.data_config.augmentation_config.random_crop.crop_width, + random_crop_p=self.data_config.augmentation_config.random_crop.random_crop_p, + image_key="image", + instance_key="instances", + ) + + datapipe = Resizer(datapipe, scale=self.data_config.preprocessing.scale) datapipe = PadToStride(datapipe, max_stride=self.max_stride) datapipe = MultiConfidenceMapGenerator( diff --git a/sleap_nn/data/resizing.py b/sleap_nn/data/resizing.py index 1cd1ac9b..b39a10b1 100644 --- a/sleap_nn/data/resizing.py +++ b/sleap_nn/data/resizing.py @@ -124,7 +124,6 @@ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: if self.scale != 1.0: ex[self.image_key] = resize_image(ex[self.image_key], self.scale) ex[self.instances_key] = ex[self.instances_key] * self.scale - ex["scale"] = self.scale yield ex diff --git a/sleap_nn/inference/predictors.py b/sleap_nn/inference/predictors.py index 7ec27d77..b95354e3 100644 --- a/sleap_nn/inference/predictors.py +++ b/sleap_nn/inference/predictors.py @@ -83,7 +83,7 @@ class Predictor(ABC): def from_model_paths( cls, model_paths: List[Text], - peak_threshold: float = 0.2, + peak_threshold: Union[float, List[float]] = 0.2, integral_refinement: str = None, integral_patch_size: int = 5, batch_size: int = 4, @@ -100,7 +100,10 @@ def from_model_paths( model_paths: (List[str]) List of paths to the directory where the best.ckpt and training_config.yaml are saved. peak_threshold: (float) Minimum confidence threshold. Peaks with values below - this will be ignored. Default: 0.2. + this will be ignored. Default: 0.2. This can also be `List[float]` for topdown + centroid and centered-instance model, where the first element corresponds + to centroid model peak finding threshold and the second element is for + centered-instance model peak finding. integral_refinement: If `None`, returns the grid-aligned peaks with no refinement. If `"integral"`, peaks will be refined with integral regression. Default: None. @@ -131,24 +134,17 @@ def from_model_paths( `MoveNetPredictor`, `TopDownMultiClassPredictor`, `BottomUpMultiClassPredictor`. """ - model_config_paths = [ + model_configs = [ OmegaConf.load(f"{Path(c)}/training_config.yaml") for c in model_paths ] - model_names = sum( - [ - [ - c.model_config.head_configs[head].head_type - for head in c.model_config.head_configs - ] - for c in model_config_paths - ], - [], - ) - - if "SingleInstanceConfmapsHead" in model_names: - confmap_ckpt_path = model_paths[ - model_names.index("SingleInstanceConfmapsHead") - ] + model_names = [] + for config in model_configs: + for k, v in config.model_config.head_configs.items(): + if v is not None: + model_names.append(k) + + if "single_instance" in model_names: + confmap_ckpt_path = model_paths[model_names.index("single_instance")] predictor = SingleInstancePredictor.from_trained_models( confmap_ckpt_path, peak_threshold=peak_threshold, @@ -161,20 +157,13 @@ def from_model_paths( preprocess_config=preprocess_config, ) - elif ( - "CentroidConfmapsHead" in model_names - or "CenteredInstanceConfmapsHead" in model_names - ): + elif "centroid" in model_names or "centered_instance" in model_names: centroid_ckpt_path = None confmap_ckpt_path = None - if "CentroidConfmapsHead" in model_names: - centroid_ckpt_path = model_paths[ - model_names.index("CentroidConfmapsHead") - ] - if "CenteredInstanceConfmapsHead" in model_names: - confmap_ckpt_path = model_paths[ - model_names.index("CenteredInstanceConfmapsHead") - ] + if "centroid" in model_names: + centroid_ckpt_path = model_paths[model_names.index("centroid")] + if "centered_instance" in model_names: + confmap_ckpt_path = model_paths[model_names.index("centered_instance")] # create an instance of the TopDown predictor class predictor = TopDownPredictor.from_trained_models( @@ -191,13 +180,8 @@ def from_model_paths( preprocess_config=preprocess_config, ) - elif ( - "MultiInstanceConfmapsHead" in model_names - or "PartAffinityFieldsHead" in model_names - ): - bottomup_ckpt_path = model_paths[ - model_names.index("MultiInstanceConfmapsHead") - ] + elif "bottomup" in model_names: + bottomup_ckpt_path = model_paths[model_names.index("bottomup")] predictor = BottomUpPredictor.from_trained_models( bottomup_ckpt_path, peak_threshold=peak_threshold, @@ -380,7 +364,10 @@ class TopDownPredictor(Predictor): skeletons: List of `sio.Skeleton` objects for creating `sio.Labels` object from the output predictions. peak_threshold: (float) Minimum confidence threshold. Peaks with values below - this will be ignored. Default: 0.2 + this will be ignored. Default: 0.2. This can also be `List[float]` for topdown + centroid and centered-instance model, where the first element corresponds + to centroid model peak finding threshold and the second element is for + centered-instance model peak finding. integral_refinement: If `None`, returns the grid-aligned peaks with no refinement. If `"integral"`, peaks will be refined with integral regression. Default: None. @@ -411,7 +398,7 @@ class TopDownPredictor(Predictor): confmap_model: Optional[L.LightningModule] = attrs.field(default=None) videos: Optional[List[sio.Video]] = attrs.field(default=None) skeletons: Optional[List[sio.Skeleton]] = attrs.field(default=None) - peak_threshold: float = 0.2 + peak_threshold: Union[float, List[float]] = 0.2 integral_refinement: str = None integral_patch_size: int = 5 batch_size: int = 4 @@ -425,12 +412,17 @@ def _initialize_inference_model(self): """Initialize the inference model from the trained models and configuration.""" # Create an instance of CentroidLayer if centroid_config is not None return_crops = False + if isinstance(self.peak_threshold, list): + centroid_peak_threshold = self.peak_threshold[0] + centered_instance_peak_threshold = self.peak_threshold[1] + else: + centroid_peak_threshold = self.peak_threshold + centered_instance_peak_threshold = self.peak_threshold + if self.centroid_config is None: centroid_crop_layer = None else: - max_stride = ( - self.centroid_config.model_config.backbone_config.backbone_config.max_stride - ) + max_stride = self.centroid_config.model_config.backbone_config.max_stride # if both centroid and centered-instance model are provided, set return crops to True if self.confmap_model: @@ -439,7 +431,7 @@ def _initialize_inference_model(self): # initialize centroid crop layer centroid_crop_layer = CentroidCrop( torch_model=self.centroid_model, - peak_threshold=self.peak_threshold, + peak_threshold=centroid_peak_threshold, output_stride=self.output_stride, refinement=self.integral_refinement, integral_patch_size=self.integral_patch_size, @@ -447,7 +439,7 @@ def _initialize_inference_model(self): return_crops=return_crops, max_instances=self.max_instances, max_stride=max_stride, - input_scale=self.data_config.scale, + input_scale=self.centroid_config.data_config.preprocessing.scale, crop_hw=self.data_config.crop_hw, ) @@ -456,18 +448,16 @@ def _initialize_inference_model(self): instance_peaks_layer = FindInstancePeaksGroundTruth() else: - max_stride = ( - self.confmap_config.model_config.backbone_config.backbone_config.max_stride - ) + max_stride = self.confmap_config.model_config.backbone_config.max_stride instance_peaks_layer = FindInstancePeaks( torch_model=self.confmap_model, - peak_threshold=self.peak_threshold, + peak_threshold=centered_instance_peak_threshold, output_stride=self.output_stride, refinement=self.integral_refinement, integral_patch_size=self.integral_patch_size, return_confmaps=self.return_confmaps, max_stride=max_stride, - input_scale=self.data_config.scale, + input_scale=self.confmap_config.data_config.preprocessing.scale, ) # Initialize the inference model with centroid and instance peak layers @@ -479,9 +469,9 @@ def _initialize_inference_model(self): def data_config(self) -> OmegaConf: """Returns data config section from the overall config.""" if self.centroid_config: - data_config = self.centroid_config.data_config.train.preprocessing + data_config = self.centroid_config.data_config.preprocessing else: - data_config = self.confmap_config.data_config.train.preprocessing + data_config = self.confmap_config.data_config.preprocessing if self.preprocess_config is None: return data_config return self.preprocess_config @@ -546,9 +536,9 @@ def from_trained_models( f"{centroid_ckpt_path}/best.ckpt", config=centroid_config, skeletons=skeletons, + model_type="centroid", ) centroid_model.to(device) - centroid_model.m_device = device else: centroid_config = None @@ -562,9 +552,9 @@ def from_trained_models( f"{confmap_ckpt_path}/best.ckpt", config=confmap_config, skeletons=skeletons, + model_type="centered_instance", ) confmap_model.to(device) - confmap_model.m_device = device else: confmap_config = None @@ -640,7 +630,7 @@ class (doesn't return a pipeline) and the Thread is started in if not self.centroid_model: pipeline = InstanceCentroidFinder( pipeline, - anchor_ind=self.confmap_config.model_config.head_configs.confmaps.head_config.anchor_part, + anchor_ind=self.confmap_config.model_config.head_configs.centered_instance.confmaps.anchor_part, ) pipeline = InstanceCropper( pipeline, @@ -660,7 +650,6 @@ class (doesn't return a pipeline) and the Thread is started in "confidence_maps", "num_instances", "orig_size", - "scale", ], ) @@ -685,10 +674,10 @@ class (doesn't return a pipeline) and the Thread is started in self.preprocess = False self.video_preprocess_config = { "batch_size": self.batch_size, - "scale": self.data_config.scale, + "scale": self.centroid_config.data_config.preprocessing.scale, "is_rgb": self.data_config.is_rgb, "max_stride": ( - self.centroid_config.model_config.backbone_config.backbone_config.max_stride + self.centroid_config.model_config.backbone_config.max_stride ), } @@ -842,13 +831,13 @@ def _initialize_inference_model(self): refinement=self.integral_refinement, integral_patch_size=self.integral_patch_size, return_confmaps=self.return_confmaps, - input_scale=self.data_config.scale, + input_scale=self.confmap_config.data_config.preprocessing.scale, ) @property def data_config(self) -> OmegaConf: """Returns data config section from the overall config.""" - data_config = self.confmap_config.data_config.train.preprocessing + data_config = self.confmap_config.data_config.preprocessing if self.preprocess_config is None: return data_config return self.preprocess_config @@ -899,10 +888,12 @@ def from_trained_models( confmap_config = OmegaConf.load(f"{confmap_ckpt_path}/training_config.yaml") skeletons = get_skeleton_from_config(confmap_config.data_config.skeletons) confmap_model = SingleInstanceModel.load_from_checkpoint( - f"{confmap_ckpt_path}/best.ckpt", config=confmap_config, skeletons=skeletons + f"{confmap_ckpt_path}/best.ckpt", + config=confmap_config, + skeletons=skeletons, + model_type="single_instance", ) confmap_model.to(device) - confmap_model.m_device = device # create an instance of SingleInstancePredictor class obj = cls( @@ -955,10 +946,12 @@ class (doesn't return a pipeline) and the Thread is started in max_width=self.data_config.max_width, provider=data_provider, ) - pipeline = Resizer(pipeline, scale=self.data_config.scale) + pipeline = Resizer( + pipeline, scale=self.confmap_config.data_config.preprocessing.scale + ) pipeline = PadToStride( pipeline, - max_stride=self.confmap_config.model_config.backbone_config.backbone_config.max_stride, + max_stride=self.confmap_config.model_config.backbone_config.max_stride, ) # Remove duplicates. @@ -975,10 +968,10 @@ class (doesn't return a pipeline) and the Thread is started in self.preprocess = True self.video_preprocess_config = { "batch_size": self.batch_size, - "scale": self.data_config.scale, + "scale": self.confmap_config.data_config.preprocessing.scale, "is_rgb": self.data_config.is_rgb, "max_stride": ( - self.confmap_config.model_config.backbone_config.backbone_config.max_stride + self.confmap_config.model_config.backbone_config.max_stride ), } frame_queue = Queue( @@ -1141,12 +1134,12 @@ def _initialize_inference_model(self): paf_scorer = PAFScorer.from_config( config=OmegaConf.create( { - "confmaps": self.bottomup_config.model_config.head_configs[ + "confmaps": self.bottomup_config.model_config.head_configs.bottomup[ "confmaps" - ].head_config, - "pafs": self.bottomup_config.model_config.head_configs[ + ], + "pafs": self.bottomup_config.model_config.head_configs.bottomup[ "pafs" - ].head_config, + ], } ), max_edge_length_ratio=self.max_edge_length_ratio, @@ -1166,13 +1159,13 @@ def _initialize_inference_model(self): refinement=self.integral_refinement, integral_patch_size=self.integral_patch_size, return_confmaps=self.return_confmaps, - input_scale=self.data_config.scale, + input_scale=self.bottomup_config.data_config.preprocessing.scale, ) @property def data_config(self) -> OmegaConf: """Returns data config section from the overall config.""" - data_config = self.bottomup_config.data_config.train.preprocessing + data_config = self.bottomup_config.data_config.preprocessing if self.preprocess_config is None: return data_config return self.preprocess_config @@ -1231,9 +1224,9 @@ def from_trained_models( f"{bottomup_ckpt_path}/best.ckpt", config=bottomup_config, skeletons=skeletons, + model_type="bottomup", ) bottomup_model.to(device) - bottomup_model.m_device = device # create an instance of SingleInstancePredictor class obj = cls( @@ -1287,10 +1280,10 @@ class (doesn't return a pipeline) and the Thread is started in max_width=self.data_config.max_width, provider=data_provider, ) - pipeline = Resizer(pipeline, scale=self.data_config.scale) - max_stride = ( - self.bottomup_config.model_config.backbone_config.backbone_config.max_stride + pipeline = Resizer( + pipeline, scale=self.bottomup_config.data_config.preprocessing.scale ) + max_stride = self.bottomup_config.model_config.backbone_config.max_stride pipeline = PadToStride(pipeline, max_stride=max_stride) # Remove duplicates. @@ -1307,10 +1300,10 @@ class (doesn't return a pipeline) and the Thread is started in self.preprocess = True self.video_preprocess_config = { "batch_size": self.batch_size, - "scale": self.data_config.scale, + "scale": self.bottomup_config.data_config.preprocessing.scale, "is_rgb": self.data_config.is_rgb, "max_stride": ( - self.bottomup_config.model_config.backbone_config.backbone_config.max_stride + self.bottomup_config.model_config.backbone_config.max_stride ), } frame_queue = Queue( @@ -1420,7 +1413,6 @@ def main( max_width: int = None, max_height: int = None, is_rgb: bool = False, - scale: float = 1.0, provider: str = "LabelsReader", batch_size: int = 4, num_workers: int = 0, @@ -1430,7 +1422,7 @@ def main( crop_hw: List[int] = (160, 160), output_stride: int = 2, pafs_output_stride: int = 4, - peak_threshold: float = 0.2, + peak_threshold: Union[float, List[float]] = 0.2, integral_refinement: str = None, integral_patch_size: int = 5, return_confmaps: bool = False, @@ -1449,7 +1441,8 @@ def main( Args: data_path: (str) Path to `.slp` file or `.mp4` to run inference on. - model_paths: TODO + model_paths: (List[str]) List of paths to the directory where the best.ckpt + and training_config.yaml are saved. max_instances: (int) Max number of instances to consider from the predictions. max_width: (int) Maximum width the image should be padded to. If not provided, the original image size will be retained. Default: None. @@ -1460,8 +1453,6 @@ def main( is replicated along the channel axis. If input has three channels and this is set to False, then we convert the image to grayscale (single-channel) image. Default: False. - scale: (float) Float indicating if the images should be resized before being - passed to the model. Default: 1.0. provider: (str) Provider class to read the input sleap files. Either "LabelsReader" or "VideoReader". Default: LabelsReader. batch_size: (int) Number of samples per batch. Default: 4. @@ -1480,7 +1471,10 @@ def main( pafs_output_stride: (int) Stride of the output part affinity fields relative to the input image. Default: 4. peak_threshold: (float) Minimum confidence threshold. Peaks with values below - this will be ignored. Default: 0.2. + this will be ignored. Default: 0.2. This can also be `List[float]` for topdown + centroid and centered-instance model, where the first element corresponds + to centroid model peak finding threshold and the second element is for + centered-instance model peak finding. integral_refinement: (str) If `None`, returns the grid-aligned peaks with no refinement. If `"integral"`, peaks will be refined with integral regression. Default: None. @@ -1527,7 +1521,6 @@ def main( """ preprocess_config = { # if not given, then use from training config "is_rgb": is_rgb, - "scale": scale, "crop_hw": crop_hw, "max_width": max_width, "max_height": max_height, diff --git a/sleap_nn/training/model_trainer.py b/sleap_nn/training/model_trainer.py index a18aef09..0458348d 100644 --- a/sleap_nn/training/model_trainer.py +++ b/sleap_nn/training/model_trainer.py @@ -63,11 +63,11 @@ def __init__(self, config: OmegaConf): """Initialise the class with configs and set the seed and device as class attributes.""" self.config = config - self.m_device = self.config.trainer_config.device self.seed = self.config.trainer_config.seed self.steps_per_epoch = self.config.trainer_config.steps_per_epoch # initialize attributes + self.model_type = None self.model = None self.provider = None self.skeletons = None @@ -83,68 +83,58 @@ def _create_data_loaders(self): if self.provider == "LabelsReader": self.provider = LabelsReader - if self.config.data_config.pipeline == "SingleInstanceConfmaps": - train_pipeline = SingleInstanceConfmapsPipeline( - data_config=self.config.data_config.train, - max_stride=self.config.model_config.backbone_config.backbone_config.max_stride, - confmap_head=self.config.model_config.head_configs.confmaps.head_config, - ) - val_pipeline = SingleInstanceConfmapsPipeline( - data_config=self.config.data_config.val, - max_stride=self.config.model_config.backbone_config.backbone_config.max_stride, - confmap_head=self.config.model_config.head_configs.confmaps.head_config, + # check which head type to choose the model + for k, v in self.config.model_config.head_configs.items(): + if v is not None: + self.model_type = k + break + + if self.model_type == "single_instance": + data_pipeline = SingleInstanceConfmapsPipeline( + data_config=self.config.data_config, + max_stride=self.config.model_config.backbone_config.max_stride, + confmap_head=self.config.model_config.head_configs.single_instance.confmaps, ) - elif self.config.data_config.pipeline == "TopdownConfmaps": - train_pipeline = TopdownConfmapsPipeline( - data_config=self.config.data_config.train, - max_stride=self.config.model_config.backbone_config.backbone_config.max_stride, - confmap_head=self.config.model_config.head_configs.confmaps.head_config, - ) - val_pipeline = TopdownConfmapsPipeline( - data_config=self.config.data_config.val, - max_stride=self.config.model_config.backbone_config.backbone_config.max_stride, - confmap_head=self.config.model_config.head_configs.confmaps.head_config, + elif self.model_type == "centered_instance": + data_pipeline = TopdownConfmapsPipeline( + data_config=self.config.data_config, + max_stride=self.config.model_config.backbone_config.max_stride, + confmap_head=self.config.model_config.head_configs.centered_instance.confmaps, ) - elif self.config.data_config.pipeline == "CentroidConfmaps": - train_pipeline = CentroidConfmapsPipeline( - data_config=self.config.data_config.train, - max_stride=self.config.model_config.backbone_config.backbone_config.max_stride, - confmap_head=self.config.model_config.head_configs.confmaps.head_config, - ) - val_pipeline = CentroidConfmapsPipeline( - data_config=self.config.data_config.val, - max_stride=self.config.model_config.backbone_config.backbone_config.max_stride, - confmap_head=self.config.model_config.head_configs.confmaps.head_config, + elif self.model_type == "centroid": + data_pipeline = CentroidConfmapsPipeline( + data_config=self.config.data_config, + max_stride=self.config.model_config.backbone_config.max_stride, + confmap_head=self.config.model_config.head_configs.centroid.confmaps, ) - elif self.config.data_config.pipeline == "BottomUp": - train_pipeline = BottomUpPipeline( - data_config=self.config.data_config.train, - max_stride=self.config.model_config.backbone_config.backbone_config.max_stride, - confmap_head=self.config.model_config.head_configs.confmaps.head_config, - pafs_head=self.config.model_config.head_configs.pafs.head_config, - ) - val_pipeline = BottomUpPipeline( - data_config=self.config.data_config.val, - max_stride=self.config.model_config.backbone_config.backbone_config.max_stride, - confmap_head=self.config.model_config.head_configs.confmaps.head_config, - pafs_head=self.config.model_config.head_configs.pafs.head_config, + elif self.model_type == "bottomup": + data_pipeline = BottomUpPipeline( + data_config=self.config.data_config, + max_stride=self.config.model_config.backbone_config.max_stride, + confmap_head=self.config.model_config.head_configs.bottomup.confmaps, + pafs_head=self.config.model_config.head_configs.bottomup.pafs, ) else: - raise Exception(f"{self.config.data_config.pipeline} is not defined.") + raise Exception( + f"{self.model_type} is not defined. Please choose one of `single_instance`, `centered_instance`, `centroid`, `bottomup`." + ) # train - train_labels = sio.load_slp(self.config.data_config.train.labels_path) + train_labels = sio.load_slp(self.config.data_config.train_labels_path) self.skeletons = train_labels.skeletons train_labels_reader = self.provider(train_labels) - train_datapipe = train_pipeline.make_training_pipeline( + train_datapipe = data_pipeline.make_training_pipeline( data_provider=train_labels_reader, + use_augmentations=self.config.data_config.use_augmentations_train, ) + + # Make sure an epoch runs for `steps_per_epoch` iterations if self.steps_per_epoch is not None: train_datapipe = Cycler(train_datapipe) @@ -152,20 +142,24 @@ def _create_data_loaders(self): train_datapipe = train_datapipe.sharding_filter() self.train_data_loader = DataLoader( train_datapipe, - **dict(self.config.trainer_config.train_data_loader), + batch_size=self.config.trainer_config.train_data_loader.batch_size, + shuffle=self.config.trainer_config.train_data_loader.shuffle, + num_workers=self.config.trainer_config.train_data_loader.num_workers, ) # val val_labels_reader = self.provider.from_filename( - self.config.data_config.val.labels_path, + self.config.data_config.val_labels_path, ) - val_datapipe = val_pipeline.make_training_pipeline( - data_provider=val_labels_reader, + val_datapipe = data_pipeline.make_training_pipeline( + data_provider=val_labels_reader, use_augmentations=False ) val_datapipe = val_datapipe.sharding_filter() self.val_data_loader = DataLoader( val_datapipe, - **dict(self.config.trainer_config.val_data_loader), + batch_size=self.config.trainer_config.val_data_loader.batch_size, + shuffle=False, + num_workers=self.config.trainer_config.val_data_loader.num_workers, ) def _set_wandb(self): @@ -173,13 +167,13 @@ def _set_wandb(self): def _initialize_model(self): models = { - "SingleInstanceConfmaps": SingleInstanceModel, - "TopdownConfmaps": TopDownCenteredInstanceModel, - "CentroidConfmaps": CentroidModel, - "BottomUp": BottomUpModel, + "single_instance": SingleInstanceModel, + "centered_instance": TopDownCenteredInstanceModel, + "centroid": CentroidModel, + "bottomup": BottomUpModel, } - self.model = models[self.config.data_config.pipeline]( - self.config, self.skeletons + self.model = models[self.model_type]( + self.config, self.skeletons, self.model_type ) def _get_param_count(self): @@ -201,11 +195,13 @@ def train(self): print( f"Cannot create a new folder. Check the permissions to the given Checkpoint directory. \n {e}" ) + if self.config.trainer_config.save_ckpt: # create checkpoint callback checkpoint_callback = ModelCheckpoint( - **dict(self.config.trainer_config.model_ckpt), + save_top_k=self.config.trainer_config.model_ckpt.save_top_k, + save_last=self.config.trainer_config.model_ckpt.save_last, dirpath=dir_path, filename="best", monitor="val_loss", @@ -241,8 +237,9 @@ def train(self): ) logger.append(wandb_logger) - # save the configs as yaml in the checkpoint dir - self.config.trainer_config.wandb.api_key = "" + # save the configs as yaml in the checkpoint dir + self.config.trainer_config.wandb.api_key = "" + OmegaConf.save(config=self.config, f=f"{dir_path}/initial_config.yaml") # save the skeleton in the config @@ -304,10 +301,14 @@ class TrainingModel(L.LightningModule): (ii) model_config: backbone and head configs to be passed to `Model` class. (iii) trainer_config: trainer configs like accelerator, optimiser params. skeletons: List of `sio.Skeleton` objects from the input `.slp` file. + model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`. """ def __init__( - self, config: OmegaConf, skeletons: Optional[List[sio.Skeleton]] = None + self, + config: OmegaConf, + skeletons: Optional[List[sio.Skeleton]], + model_type: str, ): """Initialise the configs and the model.""" super().__init__() @@ -316,52 +317,51 @@ def __init__( self.model_config = self.config.model_config self.trainer_config = self.config.trainer_config self.data_config = self.config.data_config - self.m_device = self.trainer_config.device - self.input_expand_channels = ( - self.model_config.backbone_config.backbone_config.in_channels - ) + self.model_type = model_type + self.input_expand_channels = self.model_config.backbone_config.in_channels if self.model_config.pre_trained_weights: ckpt = eval(self.model_config.pre_trained_weights).DEFAULT.get_state_dict( progress=True, check_hash=True ) input_channels = ckpt["features.0.0.weight"].shape[-3] - if ( - self.model_config.backbone_config.backbone_config.in_channels - != input_channels - ): + if self.model_config.backbone_config.in_channels != input_channels: self.input_expand_channels = input_channels OmegaConf.update( self.model_config, - "backbone_config.backbone_config.in_channels", + "backbone_config.in_channels", input_channels, ) # if edges and part names aren't set in config, get it from `sio.Labels` object. - head_configs = self.model_config.head_configs - for key in head_configs: - if "part_names" in head_configs[key].head_config.keys(): - if head_configs[key].head_config["part_names"] is None: + head_config = self.model_config.head_configs[self.model_type] + for key in head_config: + if "part_names" in head_config[key].keys(): + if head_config[key]["part_names"] is None: part_names = [x.name for x in self.skeletons[0].nodes] - head_configs[key].head_config["part_names"] = part_names + head_config[key]["part_names"] = part_names - if "edges" in head_configs[key].head_config.keys(): - if head_configs[key].head_config["edges"] is None: + if "edges" in head_config[key].keys(): + if head_config[key]["edges"] is None: edges = [ (x.source.name, x.destination.name) for x in self.skeletons[0].edges ] - head_configs[key].head_config["edges"] = edges + head_config[key]["edges"] = edges self.model = Model( + backbone_type=self.model_config.backbone_type, backbone_config=self.model_config.backbone_config, - head_configs=head_configs, + head_configs=head_config, input_expand_channels=self.input_expand_channels, - ).to(self.m_device) + model_type=self.model_type, + ) + + if len(self.model_config.head_configs[self.model_type]) > 1: + self.loss_weights = [ + self.model_config.head_configs[self.model_type][x].loss_weight + for x in self.model_config.head_configs[self.model_type] + ] - self.loss_weights = [ - self.model_config.head_configs[x].head_config.loss_weight - for x in self.model_config.head_configs - ] self.training_loss = {} self.val_loss = {} self.learning_rate = {} @@ -374,11 +374,6 @@ def __init__( if self.model_config.pre_trained_weights: self.model.backbone.enc.load_state_dict(ckpt, strict=False) - @property - def device(self): - """Save the device as an attribute to the class.""" - return next(self.model.parameters()).device - def forward(self, img): """Forward pass of the model.""" pass @@ -431,18 +426,24 @@ def validation_step(self, batch, batch_idx): def configure_optimizers(self): """Configure optimiser and learning rate scheduler.""" if self.trainer_config.optimizer_name == "Adam": - optimizer = torch.optim.Adam( - self.parameters(), - **dict(self.trainer_config.optimizer), - ) + optim = torch.optim.Adam elif self.trainer_config.optimizer_name == "AdamW": - optimizer = torch.optim.AdamW( - self.parameters(), - **dict(self.trainer_config.optimizer), - ) + optim = torch.optim.AdamW + + optimizer = optim( + self.parameters(), + lr=self.trainer_config.optimizer.lr, + amsgrad=self.trainer_config.optimizer.amsgrad, + ) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, - **dict(self.trainer_config.lr_scheduler), + mode="min", + threshold=self.trainer_config.lr_scheduler.threshold, + threshold_mode="rel", + cooldown=self.trainer_config.lr_scheduler.cooldown, + patience=self.trainer_config.lr_scheduler.patience, + factor=self.trainer_config.lr_scheduler.factor, + min_lr=self.trainer_config.lr_scheduler.min_lr, ) return { "optimizer": optimizer, @@ -466,29 +467,31 @@ class SingleInstanceModel(TrainingModel): (ii) model_config: backbone and head configs to be passed to `Model` class. (iii) trainer_config: trainer configs like accelerator, optimiser params. skeletons: List of `sio.Skeleton` objects from the input `.slp` file. + model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`. """ def __init__( - self, config: OmegaConf, skeletons: Optional[List[sio.Skeleton]] = None + self, + config: OmegaConf, + skeletons: Optional[List[sio.Skeleton]], + model_type: str, ): """Initialise the configs and the model.""" - super().__init__(config, skeletons) + super().__init__(config, skeletons, model_type) def forward(self, img): """Forward pass of the model.""" - img = torch.squeeze(img, dim=1) - img = img.to(self.m_device) + img = torch.squeeze(img, dim=1).to(self.device) return self.model(img)["SingleInstanceConfmapsHead"] def training_step(self, batch, batch_idx): """Training step.""" - X, y = torch.squeeze(batch["image"], dim=1).to(self.m_device), torch.squeeze( + X, y = torch.squeeze(batch["image"], dim=1).to(self.device), torch.squeeze( batch["confidence_maps"], dim=1 - ).to(self.m_device) + ).to(self.device) y_preds = self.model(X)["SingleInstanceConfmapsHead"] - y = y.to(self.m_device) loss = nn.MSELoss()(y_preds, y) self.log( "train_loss", loss, prog_bar=True, on_step=False, on_epoch=True, logger=True @@ -497,12 +500,11 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): """Validation step.""" - X, y = torch.squeeze(batch["image"], dim=1).to(self.m_device), torch.squeeze( + X, y = torch.squeeze(batch["image"], dim=1).to(self.device), torch.squeeze( batch["confidence_maps"], dim=1 - ).to(self.m_device) + ).to(self.device) y_preds = self.model(X)["SingleInstanceConfmapsHead"] - y = y.to(self.m_device) val_loss = nn.MSELoss()(y_preds, y) lr = self.optimizers().optimizer.param_groups[0]["lr"] self.log( @@ -536,29 +538,31 @@ class TopDownCenteredInstanceModel(TrainingModel): (ii) model_config: backbone and head configs to be passed to `Model` class. (iii) trainer_config: trainer configs like accelerator, optimiser params. skeletons: List of `sio.Skeleton` objects from the input `.slp` file. + model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`. """ def __init__( - self, config: OmegaConf, skeletons: Optional[List[sio.Skeleton]] = None + self, + config: OmegaConf, + skeletons: Optional[List[sio.Skeleton]], + model_type: str, ): """Initialise the configs and the model.""" - super().__init__(config, skeletons) + super().__init__(config, skeletons, model_type) def forward(self, img): """Forward pass of the model.""" - img = torch.squeeze(img, dim=1) - img = img.to(self.m_device) + img = torch.squeeze(img, dim=1).to(self.device) return self.model(img)["CenteredInstanceConfmapsHead"] def training_step(self, batch, batch_idx): """Training step.""" X, y = torch.squeeze(batch["instance_image"], dim=1).to( - self.m_device - ), torch.squeeze(batch["confidence_maps"], dim=1).to(self.m_device) + self.device + ), torch.squeeze(batch["confidence_maps"], dim=1).to(self.device) y_preds = self.model(X)["CenteredInstanceConfmapsHead"] - y = y.to(self.m_device) loss = nn.MSELoss()(y_preds, y) self.log( "train_loss", loss, prog_bar=True, on_step=False, on_epoch=True, logger=True @@ -568,11 +572,10 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): """Perform validation step.""" X, y = torch.squeeze(batch["instance_image"], dim=1).to( - self.m_device - ), torch.squeeze(batch["confidence_maps"], dim=1).to(self.m_device) + self.device + ), torch.squeeze(batch["confidence_maps"], dim=1).to(self.device) y_preds = self.model(X)["CenteredInstanceConfmapsHead"] - y = y.to(self.m_device) val_loss = nn.MSELoss()(y_preds, y) lr = self.optimizers().optimizer.param_groups[0]["lr"] self.log( @@ -606,29 +609,31 @@ class CentroidModel(TrainingModel): (ii) model_config: backbone and head configs to be passed to `Model` class. (iii) trainer_config: trainer configs like accelerator, optimiser params. skeletons: List of `sio.Skeleton` objects from the input `.slp` file. + model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`. """ def __init__( - self, config: OmegaConf, skeletons: Optional[List[sio.Skeleton]] = None + self, + config: OmegaConf, + skeletons: Optional[List[sio.Skeleton]], + model_type: str, ): """Initialise the configs and the model.""" - super().__init__(config, skeletons) + super().__init__(config, skeletons, model_type) def forward(self, img): """Forward pass of the model.""" - img = torch.squeeze(img, dim=1) - img = img.to(self.m_device) + img = torch.squeeze(img, dim=1).to(self.device) return self.model(img)["CentroidConfmapsHead"] def training_step(self, batch, batch_idx): """Training step.""" - X, y = torch.squeeze(batch["image"], dim=1).to(self.m_device), torch.squeeze( + X, y = torch.squeeze(batch["image"], dim=1).to(self.device), torch.squeeze( batch["centroids_confidence_maps"], dim=1 - ).to(self.m_device) + ).to(self.device) y_preds = self.model(X)["CentroidConfmapsHead"] - y = y.to(self.m_device) loss = nn.MSELoss()(y_preds, y) self.log( "train_loss", loss, prog_bar=True, on_step=False, on_epoch=True, logger=True @@ -637,12 +642,11 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): """Validation step.""" - X, y = torch.squeeze(batch["image"], dim=1).to(self.m_device), torch.squeeze( + X, y = torch.squeeze(batch["image"], dim=1).to(self.device), torch.squeeze( batch["centroids_confidence_maps"], dim=1 - ).to(self.m_device) + ).to(self.device) y_preds = self.model(X)["CentroidConfmapsHead"] - y = y.to(self.m_device) val_loss = nn.MSELoss()(y_preds, y) lr = self.optimizers().optimizer.param_groups[0]["lr"] self.log( @@ -676,19 +680,22 @@ class BottomUpModel(TrainingModel): (ii) model_config: backbone and head configs to be passed to `Model` class. (iii) trainer_config: trainer configs like accelerator, optimiser params. skeletons: List of `sio.Skeleton` objects from the input `.slp` file. + model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`. """ def __init__( - self, config: OmegaConf, skeletons: Optional[List[sio.Skeleton]] = None + self, + config: OmegaConf, + skeletons: Optional[List[sio.Skeleton]], + model_type: str, ): """Initialise the configs and the model.""" - super().__init__(config, skeletons) + super().__init__(config, skeletons, model_type) def forward(self, img): """Forward pass of the model.""" - img = torch.squeeze(img, dim=1) - img = img.to(self.m_device) + img = torch.squeeze(img, dim=1).to(self.device) output = self.model(img) return { "MultiInstanceConfmapsHead": output["MultiInstanceConfmapsHead"], @@ -697,9 +704,9 @@ def forward(self, img): def training_step(self, batch, batch_idx): """Training step.""" - X = torch.squeeze(batch["image"], dim=1).to(self.m_device) - y_confmap = torch.squeeze(batch["confidence_maps"], dim=1).to(self.m_device) - y_paf = batch["part_affinity_fields"].to(self.m_device) + X = torch.squeeze(batch["image"], dim=1).to(self.device) + y_confmap = torch.squeeze(batch["confidence_maps"], dim=1).to(self.device) + y_paf = batch["part_affinity_fields"].to(self.device) preds = self.model(X) pafs = preds["PartAffinityFieldsHead"].permute(0, 2, 3, 1) confmaps = preds["MultiInstanceConfmapsHead"] @@ -715,9 +722,9 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): """Validation step.""" - X = torch.squeeze(batch["image"], dim=1).to(self.m_device) - y_confmap = torch.squeeze(batch["confidence_maps"], dim=1).to(self.m_device) - y_paf = batch["part_affinity_fields"].to(self.m_device) + X = torch.squeeze(batch["image"], dim=1).to(self.device) + y_confmap = torch.squeeze(batch["confidence_maps"], dim=1).to(self.device) + y_paf = batch["part_affinity_fields"].to(self.device) preds = self.model(X) pafs = preds["PartAffinityFieldsHead"].permute(0, 2, 3, 1) diff --git a/tests/architectures/test_model.py b/tests/architectures/test_model.py index ed71e35c..fd52d862 100644 --- a/tests/architectures/test_model.py +++ b/tests/architectures/test_model.py @@ -12,27 +12,22 @@ def test_get_backbone(): # unet base_unet_model_config = OmegaConf.create( { - "backbone_type": "unet", - "init_weights": "default", - "pre_trained_weights": None, - "backbone_config": { - "in_channels": 1, - "kernel_size": 3, - "filters": 16, - "filters_rate": 2, - "max_stride": 16, - "convs_per_block": 2, - "stacks": 1, - "stem_stride": None, - "middle_block": True, - "up_interpolate": True, - }, + "in_channels": 1, + "kernel_size": 3, + "filters": 16, + "filters_rate": 2, + "max_stride": 16, + "convs_per_block": 2, + "stacks": 1, + "stem_stride": None, + "middle_block": True, + "up_interpolate": True, } ) backbone = get_backbone( - base_unet_model_config.backbone_type, - base_unet_model_config.backbone_config, + "unet", + base_unet_model_config, output_stride=1, ) assert isinstance(backbone, torch.nn.Module) @@ -40,85 +35,72 @@ def test_get_backbone(): # convnext base_convnext_model_config = OmegaConf.create( { - "backbone_type": "convnext", - "init_weights": "default", - "pretrained_weights": "", - "backbone_config": { - "in_channels": 1, - "model_type": "tiny", - "arch": None, - "kernel_size": 3, - "filters_rate": 2, - "convs_per_block": 2, - "up_interpolate": True, - "stem_patch_kernel": 4, - "stem_patch_stride": 2, - }, + "in_channels": 1, + "model_type": "tiny", + "arch": None, + "kernel_size": 3, + "filters_rate": 2, + "convs_per_block": 2, + "up_interpolate": True, + "stem_patch_kernel": 4, + "stem_patch_stride": 2, } ) backbone = get_backbone( - base_convnext_model_config.backbone_type, - base_convnext_model_config.backbone_config, + "convnext", + base_convnext_model_config, output_stride=1, ) assert isinstance(backbone, torch.nn.Module) with pytest.raises(KeyError): - _ = get_backbone("invalid_input", base_unet_model_config.backbone_config, 1) + _ = get_backbone("invalid_input", base_unet_model_config, 1) # swint base_convnext_model_config = OmegaConf.create( { - "backbone_type": "swint", - "init_weights": "default", - "pretrained_weights": "", - "backbone_config": { - "in_channels": 1, - "model_type": "tiny", - "arch": None, - "patch_size": [4, 4], - "window_size": [7, 7], - "kernel_size": 3, - "filters_rate": 2, - "convs_per_block": 2, - "up_interpolate": True, - "stem_patch_stride": 4, - }, + "in_channels": 1, + "model_type": "tiny", + "arch": None, + "patch_size": [4, 4], + "window_size": [7, 7], + "kernel_size": 3, + "filters_rate": 2, + "convs_per_block": 2, + "up_interpolate": True, + "stem_patch_stride": 4, } ) backbone = get_backbone( - base_convnext_model_config.backbone_type, - base_convnext_model_config.backbone_config, + "swint", + base_convnext_model_config, output_stride=1, ) assert isinstance(backbone, torch.nn.Module) with pytest.raises(KeyError): - _ = get_backbone( - "invalid_input", base_unet_model_config.backbone_config, output_stride=1 - ) + _ = get_backbone("invalid_input", base_unet_model_config, output_stride=1) def test_get_head(): base_unet_head_config = OmegaConf.create( { - "head_type": "SingleInstanceConfmapsHead", - "head_config": { + "confmaps": { "part_names": [f"{i}" for i in range(13)], "sigma": 5.0, "output_stride": 1, "loss_weight": 1.0, - }, + } } ) - head = get_head(base_unet_head_config.head_type, base_unet_head_config.head_config) - assert isinstance(head, Head) + head = get_head("single_instance", base_unet_head_config) + assert isinstance(head[0], Head) - with pytest.raises(KeyError): - _ = get_head("invalid_input", base_unet_head_config.head_config) + with pytest.raises(Exception): + _ = get_head("invalid_input", base_unet_head_config) def test_unet_model(): @@ -126,42 +108,40 @@ def test_unet_model(): base_unet_model_config = OmegaConf.create( { - "backbone_type": "unet", - "backbone_config": { - "in_channels": 1, - "kernel_size": 3, - "filters": 16, - "filters_rate": 2, - "max_stride": 16, - "convs_per_block": 2, - "stacks": 1, - "stem_stride": None, - "middle_block": True, - "up_interpolate": True, - }, + "in_channels": 1, + "kernel_size": 3, + "filters": 16, + "filters_rate": 2, + "max_stride": 16, + "convs_per_block": 2, + "stacks": 1, + "stem_stride": None, + "middle_block": True, + "up_interpolate": True, } ) base_unet_head_config = OmegaConf.create( { - "head_type": "SingleInstanceConfmapsHead", - "head_config": { + "confmaps": { "part_names": [f"{i}" for i in range(13)], "sigma": 5.0, "output_stride": 1, "loss_weight": 1.0, - }, + } } ) model = Model( + backbone_type="unet", backbone_config=base_unet_model_config, - head_configs=DictConfig({"confmap_head": base_unet_head_config}), + head_configs=base_unet_head_config, input_expand_channels=1, + model_type="single_instance", ).to(device) assert model.backbone_config == base_unet_model_config - assert model.head_configs == DictConfig({"confmap_head": base_unet_head_config}) + assert model.head_configs == base_unet_head_config x = torch.rand(1, 1, 192, 192).to(device) model.eval() @@ -171,48 +151,46 @@ def test_unet_model(): assert type(z) is dict assert len(z.keys()) == 1 - assert z[base_unet_head_config.head_type].shape == (1, 13, 192, 192) - assert z[base_unet_head_config.head_type].dtype == torch.float32 + assert z["SingleInstanceConfmapsHead"].shape == (1, 13, 192, 192) + assert z["SingleInstanceConfmapsHead"].dtype == torch.float32 # filter rate = 1.5 base_unet_model_config = OmegaConf.create( { - "backbone_type": "unet", - "backbone_config": { - "in_channels": 1, - "kernel_size": 3, - "filters": 16, - "filters_rate": 1.5, - "max_stride": 16, - "convs_per_block": 2, - "stacks": 1, - "stem_stride": None, - "middle_block": True, - "up_interpolate": True, - }, + "in_channels": 1, + "kernel_size": 3, + "filters": 16, + "filters_rate": 1.5, + "max_stride": 16, + "convs_per_block": 2, + "stacks": 1, + "stem_stride": None, + "middle_block": True, + "up_interpolate": True, } ) base_unet_head_config = OmegaConf.create( { - "head_type": "SingleInstanceConfmapsHead", - "head_config": { + "confmaps": { "part_names": [f"{i}" for i in range(13)], "sigma": 5.0, "output_stride": 1, "loss_weight": 1.0, - }, + } } ) model = Model( + backbone_type="unet", backbone_config=base_unet_model_config, - head_configs=DictConfig({"confmap_head": base_unet_head_config}), + head_configs=base_unet_head_config, input_expand_channels=1, + model_type="single_instance", ).to(device) assert model.backbone_config == base_unet_model_config - assert model.head_configs == DictConfig({"confmap_head": base_unet_head_config}) + assert model.head_configs == base_unet_head_config x = torch.rand(1, 1, 192, 192).to(device) model.eval() @@ -222,48 +200,46 @@ def test_unet_model(): assert type(z) is dict assert len(z.keys()) == 1 - assert z[base_unet_head_config.head_type].shape == (1, 13, 192, 192) - assert z[base_unet_head_config.head_type].dtype == torch.float32 + assert z["SingleInstanceConfmapsHead"].shape == (1, 13, 192, 192) + assert z["SingleInstanceConfmapsHead"].dtype == torch.float32 # upsampling stack with TransposeConv layers base_unet_model_config = OmegaConf.create( { - "backbone_type": "unet", - "backbone_config": { - "in_channels": 1, - "kernel_size": 3, - "filters": 16, - "filters_rate": 1.5, - "max_stride": 16, - "convs_per_block": 2, - "stacks": 1, - "stem_stride": None, - "middle_block": True, - "up_interpolate": False, - }, + "in_channels": 1, + "kernel_size": 3, + "filters": 16, + "filters_rate": 1.5, + "max_stride": 16, + "convs_per_block": 2, + "stacks": 1, + "stem_stride": None, + "middle_block": True, + "up_interpolate": False, } ) base_unet_head_config = OmegaConf.create( { - "head_type": "SingleInstanceConfmapsHead", - "head_config": { + "confmaps": { "part_names": [f"{i}" for i in range(13)], "sigma": 5.0, "output_stride": 1, "loss_weight": 1.0, - }, + } } ) model = Model( + backbone_type="unet", backbone_config=base_unet_model_config, - head_configs=DictConfig({"confmap_head": base_unet_head_config}), + head_configs=base_unet_head_config, input_expand_channels=1, + model_type="single_instance", ).to(device) assert model.backbone_config == base_unet_model_config - assert model.head_configs == DictConfig({"confmap_head": base_unet_head_config}) + assert model.head_configs == base_unet_head_config x = torch.rand(1, 1, 192, 192).to(device) model.eval() @@ -273,8 +249,8 @@ def test_unet_model(): assert type(z) is dict assert len(z.keys()) == 1 - assert z[base_unet_head_config.head_type].shape == (1, 13, 192, 192) - assert z[base_unet_head_config.head_type].dtype == torch.float32 + assert z["SingleInstanceConfmapsHead"].shape == (1, 13, 192, 192) + assert z["SingleInstanceConfmapsHead"].dtype == torch.float32 def test_convnext_model(): @@ -282,41 +258,39 @@ def test_convnext_model(): base_convnext_model_config = OmegaConf.create( { - "backbone_type": "convnext", - "backbone_config": { - "in_channels": 1, - "model_type": "tiny", - "arch": None, - "kernel_size": 3, - "filters_rate": 2, - "convs_per_block": 2, - "up_interpolate": True, - "stem_patch_kernel": 4, - "stem_patch_stride": 2, - }, + "in_channels": 1, + "model_type": "tiny", + "arch": None, + "kernel_size": 3, + "filters_rate": 2, + "convs_per_block": 2, + "up_interpolate": True, + "stem_patch_kernel": 4, + "stem_patch_stride": 2, } ) base_convnext_head_config = OmegaConf.create( { - "head_type": "SingleInstanceConfmapsHead", - "head_config": { + "confmaps": { "part_names": [f"{i}" for i in range(13)], "sigma": 5.0, "output_stride": 1, "loss_weight": 1.0, - }, + } } ) model = Model( + backbone_type="convnext", backbone_config=base_convnext_model_config, - head_configs=DictConfig({"confmap_head": base_convnext_head_config}), + head_configs=base_convnext_head_config, input_expand_channels=1, + model_type="single_instance", ).to(device) assert model.backbone_config == base_convnext_model_config - assert model.head_configs == DictConfig({"confmap_head": base_convnext_head_config}) + assert model.head_configs == base_convnext_head_config x = torch.rand(1, 1, 192, 192).to(device) model.eval() @@ -326,13 +300,15 @@ def test_convnext_model(): assert type(z) is dict assert len(z.keys()) == 1 - assert z[base_convnext_head_config.head_type].shape == (1, 13, 192, 192) - assert z[base_convnext_head_config.head_type].dtype == torch.float32 + assert z["SingleInstanceConfmapsHead"].shape == (1, 13, 192, 192) + assert z["SingleInstanceConfmapsHead"].dtype == torch.float32 model = Model.from_config( + backbone_type="convnext", backbone_config=base_convnext_model_config, - head_configs=DictConfig({"confmap_head": base_convnext_head_config}), + head_configs=base_convnext_head_config, input_expand_channels=1, + model_type="single_instance", ).to(device) x = torch.rand(1, 1, 192, 192).to(device) @@ -343,47 +319,45 @@ def test_convnext_model(): assert type(z) is dict assert len(z.keys()) == 1 - assert z[base_convnext_head_config.head_type].shape == (1, 13, 192, 192) - assert z[base_convnext_head_config.head_type].dtype == torch.float32 + assert z["SingleInstanceConfmapsHead"].shape == (1, 13, 192, 192) + assert z["SingleInstanceConfmapsHead"].dtype == torch.float32 # stride = 4 base_convnext_model_config = OmegaConf.create( { - "backbone_type": "convnext", - "backbone_config": { - "in_channels": 1, - "model_type": "tiny", - "arch": None, - "kernel_size": 3, - "filters_rate": 2, - "convs_per_block": 2, - "up_interpolate": True, - "stem_patch_kernel": 4, - "stem_patch_stride": 4, - }, + "in_channels": 1, + "model_type": "tiny", + "arch": None, + "kernel_size": 3, + "filters_rate": 2, + "convs_per_block": 2, + "up_interpolate": True, + "stem_patch_kernel": 4, + "stem_patch_stride": 4, } ) base_convnext_head_config = OmegaConf.create( { - "head_type": "SingleInstanceConfmapsHead", - "head_config": { + "confmaps": { "part_names": [f"{i}" for i in range(13)], "sigma": 5.0, "output_stride": 1, "loss_weight": 1.0, - }, + } } ) model = Model( + backbone_type="convnext", backbone_config=base_convnext_model_config, - head_configs=DictConfig({"confmap_head": base_convnext_head_config}), + head_configs=base_convnext_head_config, input_expand_channels=1, + model_type="single_instance", ).to(device) assert model.backbone_config == base_convnext_model_config - assert model.head_configs == DictConfig({"confmap_head": base_convnext_head_config}) + assert model.head_configs == base_convnext_head_config x = torch.rand(1, 1, 192, 192).to(device) model.eval() @@ -393,13 +367,15 @@ def test_convnext_model(): assert type(z) is dict assert len(z.keys()) == 1 - assert z[base_convnext_head_config.head_type].shape == (1, 13, 192, 192) - assert z[base_convnext_head_config.head_type].dtype == torch.float32 + assert z["SingleInstanceConfmapsHead"].shape == (1, 13, 192, 192) + assert z["SingleInstanceConfmapsHead"].dtype == torch.float32 model = Model.from_config( + backbone_type="convnext", backbone_config=base_convnext_model_config, - head_configs=DictConfig({"confmap_head": base_convnext_head_config}), + head_configs=base_convnext_head_config, input_expand_channels=1, + model_type="single_instance", ).to(device) x = torch.rand(1, 1, 192, 192).to(device) @@ -410,47 +386,45 @@ def test_convnext_model(): assert type(z) is dict assert len(z.keys()) == 1 - assert z[base_convnext_head_config.head_type].shape == (1, 13, 192, 192) - assert z[base_convnext_head_config.head_type].dtype == torch.float32 + assert z["SingleInstanceConfmapsHead"].shape == (1, 13, 192, 192) + assert z["SingleInstanceConfmapsHead"].dtype == torch.float32 # transposeconv as upsampling stack base_convnext_model_config = OmegaConf.create( { - "backbone_type": "convnext", - "backbone_config": { - "in_channels": 1, - "model_type": "tiny", - "arch": None, - "kernel_size": 3, - "filters_rate": 2, - "convs_per_block": 2, - "up_interpolate": False, - "stem_patch_kernel": 4, - "stem_patch_stride": 4, - }, + "in_channels": 1, + "model_type": "tiny", + "arch": None, + "kernel_size": 3, + "filters_rate": 2, + "convs_per_block": 2, + "up_interpolate": False, + "stem_patch_kernel": 4, + "stem_patch_stride": 4, } ) base_convnext_head_config = OmegaConf.create( { - "head_type": "SingleInstanceConfmapsHead", - "head_config": { + "confmaps": { "part_names": [f"{i}" for i in range(13)], "sigma": 5.0, "output_stride": 1, "loss_weight": 1.0, - }, + } } ) model = Model( + backbone_type="convnext", backbone_config=base_convnext_model_config, - head_configs=DictConfig({"confmap_head": base_convnext_head_config}), + head_configs=base_convnext_head_config, input_expand_channels=1, + model_type="single_instance", ).to(device) assert model.backbone_config == base_convnext_model_config - assert model.head_configs == DictConfig({"confmap_head": base_convnext_head_config}) + assert model.head_configs == base_convnext_head_config x = torch.rand(1, 1, 192, 192).to(device) model.eval() @@ -460,13 +434,15 @@ def test_convnext_model(): assert type(z) is dict assert len(z.keys()) == 1 - assert z[base_convnext_head_config.head_type].shape == (1, 13, 192, 192) - assert z[base_convnext_head_config.head_type].dtype == torch.float32 + assert z["SingleInstanceConfmapsHead"].shape == (1, 13, 192, 192) + assert z["SingleInstanceConfmapsHead"].dtype == torch.float32 model = Model.from_config( + backbone_type="convnext", backbone_config=base_convnext_model_config, - head_configs=DictConfig({"confmap_head": base_convnext_head_config}), + head_configs=base_convnext_head_config, input_expand_channels=1, + model_type="single_instance", ).to(device) x = torch.rand(1, 1, 192, 192).to(device) @@ -477,8 +453,8 @@ def test_convnext_model(): assert type(z) is dict assert len(z.keys()) == 1 - assert z[base_convnext_head_config.head_type].shape == (1, 13, 192, 192) - assert z[base_convnext_head_config.head_type].dtype == torch.float32 + assert z["SingleInstanceConfmapsHead"].shape == (1, 13, 192, 192) + assert z["SingleInstanceConfmapsHead"].dtype == torch.float32 def test_swint_model(): @@ -487,42 +463,40 @@ def test_swint_model(): # stride = 4 base_swint_model_config = OmegaConf.create( { - "backbone_type": "swint", - "backbone_config": { - "in_channels": 1, - "model_type": "tiny", - "arch": None, - "patch_size": [4, 4], - "window_size": [7, 7], - "kernel_size": 3, - "filters_rate": 2, - "convs_per_block": 2, - "up_interpolate": True, - "stem_patch_stride": 4, - }, + "in_channels": 1, + "model_type": "tiny", + "arch": None, + "patch_size": [4, 4], + "window_size": [7, 7], + "kernel_size": 3, + "filters_rate": 2, + "convs_per_block": 2, + "up_interpolate": True, + "stem_patch_stride": 4, } ) base_swint_head_config = OmegaConf.create( { - "head_type": "SingleInstanceConfmapsHead", - "head_config": { + "confmaps": { "part_names": [f"{i}" for i in range(13)], "sigma": 5.0, "output_stride": 1, "loss_weight": 1.0, - }, + } } ) model = Model( + backbone_type="swint", backbone_config=base_swint_model_config, - head_configs=DictConfig({"confmap_head": base_swint_head_config}), + head_configs=base_swint_head_config, input_expand_channels=1, + model_type="single_instance", ).to(device) assert model.backbone_config == base_swint_model_config - assert model.head_configs == DictConfig({"confmap_head": base_swint_head_config}) + assert model.head_configs == base_swint_head_config x = torch.rand(1, 1, 192, 192).to(device) model.eval() @@ -532,13 +506,15 @@ def test_swint_model(): assert type(z) is dict assert len(z.keys()) == 1 - assert z[base_swint_head_config.head_type].shape == (1, 13, 192, 192) - assert z[base_swint_head_config.head_type].dtype == torch.float32 + assert z["SingleInstanceConfmapsHead"].shape == (1, 13, 192, 192) + assert z["SingleInstanceConfmapsHead"].dtype == torch.float32 model = Model.from_config( + backbone_type="swint", backbone_config=base_swint_model_config, - head_configs=DictConfig({"confmap_head": base_swint_head_config}), + head_configs=base_swint_head_config, input_expand_channels=1, + model_type="single_instance", ).to(device) x = torch.rand(1, 1, 192, 192).to(device) @@ -549,49 +525,47 @@ def test_swint_model(): assert type(z) is dict assert len(z.keys()) == 1 - assert z[base_swint_head_config.head_type].shape == (1, 13, 192, 192) - assert z[base_swint_head_config.head_type].dtype == torch.float32 + assert z["SingleInstanceConfmapsHead"].shape == (1, 13, 192, 192) + assert z["SingleInstanceConfmapsHead"].dtype == torch.float32 # transposeConv for upsampling stack base_swint_model_config = OmegaConf.create( { - "backbone_type": "swint", - "backbone_config": { - "in_channels": 1, - "model_type": "tiny", - "arch": None, - "patch_size": [4, 4], - "window_size": [7, 7], - "kernel_size": 3, - "filters_rate": 2, - "convs_per_block": 2, - "up_interpolate": False, - "stem_patch_stride": 4, - "stem_stride": None, - }, + "in_channels": 1, + "model_type": "tiny", + "arch": None, + "patch_size": [4, 4], + "window_size": [7, 7], + "kernel_size": 3, + "filters_rate": 2, + "convs_per_block": 2, + "up_interpolate": False, + "stem_patch_stride": 4, + "stem_stride": None, } ) base_swint_head_config = OmegaConf.create( { - "head_type": "SingleInstanceConfmapsHead", - "head_config": { + "confmaps": { "part_names": [f"{i}" for i in range(13)], "sigma": 5.0, "output_stride": 1, "loss_weight": 1.0, - }, + } } ) model = Model( + backbone_type="swint", backbone_config=base_swint_model_config, - head_configs=DictConfig({"confmap_head": base_swint_head_config}), + head_configs=base_swint_head_config, input_expand_channels=1, + model_type="single_instance", ).to(device) assert model.backbone_config == base_swint_model_config - assert model.head_configs == DictConfig({"confmap_head": base_swint_head_config}) + assert model.head_configs == base_swint_head_config x = torch.rand(1, 1, 192, 192).to(device) model.eval() @@ -601,13 +575,15 @@ def test_swint_model(): assert type(z) is dict assert len(z.keys()) == 1 - assert z[base_swint_head_config.head_type].shape == (1, 13, 192, 192) - assert z[base_swint_head_config.head_type].dtype == torch.float32 + assert z["SingleInstanceConfmapsHead"].shape == (1, 13, 192, 192) + assert z["SingleInstanceConfmapsHead"].dtype == torch.float32 model = Model.from_config( + backbone_type="swint", backbone_config=base_swint_model_config, - head_configs=DictConfig({"confmap_head": base_swint_head_config}), + head_configs=base_swint_head_config, input_expand_channels=1, + model_type="single_instance", ).to(device) x = torch.rand(1, 1, 192, 192).to(device) @@ -618,5 +594,5 @@ def test_swint_model(): assert type(z) is dict assert len(z.keys()) == 1 - assert z[base_swint_head_config.head_type].shape == (1, 13, 192, 192) - assert z[base_swint_head_config.head_type].dtype == torch.float32 + assert z["SingleInstanceConfmapsHead"].shape == (1, 13, 192, 192) + assert z["SingleInstanceConfmapsHead"].dtype == torch.float32 diff --git a/tests/assets/minimal_instance/best.ckpt b/tests/assets/minimal_instance/best.ckpt index bd1bd6b8..6e4ce415 100644 Binary files a/tests/assets/minimal_instance/best.ckpt and b/tests/assets/minimal_instance/best.ckpt differ diff --git a/tests/assets/minimal_instance/initial_config.yaml b/tests/assets/minimal_instance/initial_config.yaml index 1046db6a..abcea35e 100644 --- a/tests/assets/minimal_instance/initial_config.yaml +++ b/tests/assets/minimal_instance/initial_config.yaml @@ -1,127 +1,48 @@ data_config: provider: LabelsReader - pipeline: TopdownConfmaps - train: - labels_path: minimal_instance.pkg.slp + train_labels_path: minimal_instance.pkg.slp + val_labels_path: minimal_instance.pkg.slp + preprocessing: max_width: null max_height: null scale: 1.0 is_rgb: false - preprocessing: - crop_hw: - - 160 - - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: true - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 - val: - labels_path: minimal_instance.pkg.slp - max_width: null - max_height: null - is_rgb: false - scale: 1.0 - preprocessing: - crop_hw: - - 160 - - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: false - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 + crop_hw: + - 160 + - 160 + use_augmentations_train: true + augmentation_config: + geometric: + rotation: 180.0 + scale: null + translate_width: 0 + translate_height: 0 + affine_p: 0.5 model_config: init_weights: xavier pre_trained_weights: null + backbone_type: unet backbone_config: - backbone_type: unet - backbone_config: - in_channels: 1 - kernel_size: 3 - filters: 16 - filters_rate: 2 - max_stride: 16 - convs_per_block: 2 - stacks: 1 - stem_stride: null - middle_block: true - up_interpolate: true + in_channels: 1 + kernel_size: 3 + filters: 16 + filters_rate: 2 + max_stride: 16 + convs_per_block: 2 + stacks: 1 + stem_stride: null + middle_block: true + up_interpolate: true head_configs: - confmaps: - head_type: CenteredInstanceConfmapsHead - head_config: + single_instance: null + bottomup: null + centroid: null + centered_instance: + confmaps: part_names: null anchor_part: 0 sigma: 1.5 output_stride: 2 - loss_weight: 1.0 trainer_config: train_data_loader: batch_size: 4 @@ -133,12 +54,11 @@ trainer_config: model_ckpt: save_top_k: 1 save_last: true - device: cpu trainer_devices: 1 trainer_accelerator: cpu enable_progress_bar: false steps_per_epoch: null - max_epochs: 10 + max_epochs: 1 seed: 1000 use_wandb: false save_ckpt: true @@ -153,7 +73,7 @@ trainer_config: - trainer_config.optimizer_name - trainer_config.optimizer.amsgrad - trainer_config.optimizer.lr - - model_config.backbone_config.backbone_type + - model_config.backbone_type - model_config.init_weights optimizer_name: Adam optimizer: diff --git a/tests/assets/minimal_instance/last.ckpt b/tests/assets/minimal_instance/last.ckpt index 2ed8a839..dfcdb48b 100644 Binary files a/tests/assets/minimal_instance/last.ckpt and b/tests/assets/minimal_instance/last.ckpt differ diff --git a/tests/assets/minimal_instance/training_config.yaml b/tests/assets/minimal_instance/training_config.yaml index 00655543..ca8438a7 100644 --- a/tests/assets/minimal_instance/training_config.yaml +++ b/tests/assets/minimal_instance/training_config.yaml @@ -1,102 +1,23 @@ data_config: provider: LabelsReader - pipeline: TopdownConfmaps - train: - labels_path: minimal_instance.pkg.slp + train_labels_path: minimal_instance.pkg.slp + val_labels_path: minimal_instance.pkg.slp + preprocessing: max_width: null max_height: null scale: 1.0 is_rgb: false - preprocessing: - crop_hw: - - 160 - - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: true - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 - val: - labels_path: minimal_instance.pkg.slp - max_width: null - max_height: null - is_rgb: false - scale: 1.0 - preprocessing: - crop_hw: - - 160 - - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: false - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 + crop_hw: + - 160 + - 160 + use_augmentations_train: true + augmentation_config: + geometric: + rotation: 180.0 + scale: null + translate_width: 0 + translate_height: 0 + affine_p: 0.5 skeletons: Skeleton-0: nodes: @@ -111,30 +32,30 @@ data_config: model_config: init_weights: xavier pre_trained_weights: null + backbone_type: unet backbone_config: - backbone_type: unet - backbone_config: - in_channels: 1 - kernel_size: 3 - filters: 16 - filters_rate: 2 - max_stride: 16 - convs_per_block: 2 - stacks: 1 - stem_stride: null - middle_block: true - up_interpolate: true + in_channels: 1 + kernel_size: 3 + filters: 16 + filters_rate: 2 + max_stride: 16 + convs_per_block: 2 + stacks: 1 + stem_stride: null + middle_block: true + up_interpolate: true head_configs: - confmaps: - head_type: CenteredInstanceConfmapsHead - head_config: + single_instance: null + bottomup: null + centroid: null + centered_instance: + confmaps: part_names: - A - B anchor_part: 0 sigma: 1.5 output_stride: 2 - loss_weight: 1.0 trainer_config: train_data_loader: batch_size: 4 @@ -146,12 +67,11 @@ trainer_config: model_ckpt: save_top_k: 1 save_last: true - device: cpu trainer_devices: 1 trainer_accelerator: cpu enable_progress_bar: false steps_per_epoch: null - max_epochs: 10 + max_epochs: 1 seed: 1000 use_wandb: false save_ckpt: true @@ -166,7 +86,7 @@ trainer_config: - trainer_config.optimizer_name - trainer_config.optimizer.amsgrad - trainer_config.optimizer.lr - - model_config.backbone_config.backbone_type + - model_config.backbone_type - model_config.init_weights optimizer_name: Adam optimizer: diff --git a/tests/assets/minimal_instance_bottomup/best.ckpt b/tests/assets/minimal_instance_bottomup/best.ckpt index d4525e06..a7f5134b 100644 Binary files a/tests/assets/minimal_instance_bottomup/best.ckpt and b/tests/assets/minimal_instance_bottomup/best.ckpt differ diff --git a/tests/assets/minimal_instance_bottomup/initial_config.yaml b/tests/assets/minimal_instance_bottomup/initial_config.yaml index 13d2ecb8..59f7a25e 100644 --- a/tests/assets/minimal_instance_bottomup/initial_config.yaml +++ b/tests/assets/minimal_instance_bottomup/initial_config.yaml @@ -1,129 +1,46 @@ data_config: provider: LabelsReader - pipeline: BottomUp - train: - labels_path: minimal_instance.pkg.slp + train_labels_path: minimal_instance.pkg.slp + val_labels_path: minimal_instance.pkg.slp + preprocessing: max_width: null max_height: null scale: 1.0 is_rgb: false - preprocessing: - crop_hw: - - 160 - - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: true - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 - val: - labels_path: minimal_instance.pkg.slp - max_width: null - max_height: null - is_rgb: false - scale: 1.0 - preprocessing: - crop_hw: - - 160 - - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: false - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 + use_augmentations_train: true + augmentation_config: + geometric: + rotation: 180.0 + scale: null + translate_width: 0 + translate_height: 0 + affine_p: 0.5 model_config: init_weights: xavier pre_trained_weights: null + backbone_type: unet backbone_config: - backbone_type: unet - backbone_config: - in_channels: 1 - kernel_size: 3 - filters: 16 - filters_rate: 2 - max_stride: 16 - convs_per_block: 2 - stacks: 1 - stem_stride: null - middle_block: true - up_interpolate: true + in_channels: 1 + kernel_size: 3 + filters: 16 + filters_rate: 2 + max_stride: 16 + convs_per_block: 2 + stacks: 1 + stem_stride: null + middle_block: true + up_interpolate: true head_configs: - confmaps: - head_type: MultiInstanceConfmapsHead - head_config: + single_instance: null + centered_instance: null + centroid: null + bottomup: + confmaps: part_names: null sigma: 1.5 output_stride: 2 loss_weight: 1.0 - pafs: - head_type: PartAffinityFieldsHead - head_config: + pafs: edges: null sigma: 50 output_stride: 4 @@ -139,7 +56,6 @@ trainer_config: model_ckpt: save_top_k: 1 save_last: true - device: cpu trainer_devices: 1 trainer_accelerator: cpu enable_progress_bar: false @@ -148,7 +64,7 @@ trainer_config: seed: 1000 use_wandb: false save_ckpt: true - save_ckpt_path: min_inst_bottomup1 + save_ckpt_path: min_inst_bottomup wandb: entity: team-ucsd project: test_centroid_centered @@ -159,7 +75,7 @@ trainer_config: - trainer_config.optimizer_name - trainer_config.optimizer.amsgrad - trainer_config.optimizer.lr - - model_config.backbone_config.backbone_type + - model_config.backbone_type - model_config.init_weights optimizer_name: Adam optimizer: diff --git a/tests/assets/minimal_instance_bottomup/last.ckpt b/tests/assets/minimal_instance_bottomup/last.ckpt index b18f96ad..6cce40c1 100644 Binary files a/tests/assets/minimal_instance_bottomup/last.ckpt and b/tests/assets/minimal_instance_bottomup/last.ckpt differ diff --git a/tests/assets/minimal_instance_bottomup/training_config.yaml b/tests/assets/minimal_instance_bottomup/training_config.yaml index 25c681a4..da08cd3c 100644 --- a/tests/assets/minimal_instance_bottomup/training_config.yaml +++ b/tests/assets/minimal_instance_bottomup/training_config.yaml @@ -1,102 +1,20 @@ data_config: provider: LabelsReader - pipeline: BottomUp - train: - labels_path: minimal_instance.pkg.slp + train_labels_path: minimal_instance.pkg.slp + val_labels_path: minimal_instance.pkg.slp + preprocessing: max_width: null max_height: null scale: 1.0 is_rgb: false - preprocessing: - crop_hw: - - 160 - - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: true - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 - val: - labels_path: minimal_instance.pkg.slp - max_width: null - max_height: null - is_rgb: false - scale: 1.0 - preprocessing: - crop_hw: - - 160 - - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: false - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 + use_augmentations_train: true + augmentation_config: + geometric: + rotation: 180.0 + scale: null + translate_width: 0 + translate_height: 0 + affine_p: 0.5 skeletons: Skeleton-0: nodes: @@ -111,32 +29,31 @@ data_config: model_config: init_weights: xavier pre_trained_weights: null + backbone_type: unet backbone_config: - backbone_type: unet - backbone_config: - in_channels: 1 - kernel_size: 3 - filters: 16 - filters_rate: 2 - max_stride: 16 - convs_per_block: 2 - stacks: 1 - stem_stride: null - middle_block: true - up_interpolate: true + in_channels: 1 + kernel_size: 3 + filters: 16 + filters_rate: 2 + max_stride: 16 + convs_per_block: 2 + stacks: 1 + stem_stride: null + middle_block: true + up_interpolate: true head_configs: - confmaps: - head_type: MultiInstanceConfmapsHead - head_config: + single_instance: null + centered_instance: null + centroid: null + bottomup: + confmaps: part_names: - A - B sigma: 1.5 output_stride: 2 loss_weight: 1.0 - pafs: - head_type: PartAffinityFieldsHead - head_config: + pafs: edges: - - A - B @@ -154,7 +71,6 @@ trainer_config: model_ckpt: save_top_k: 1 save_last: true - device: cpu trainer_devices: 1 trainer_accelerator: cpu enable_progress_bar: false @@ -163,7 +79,7 @@ trainer_config: seed: 1000 use_wandb: false save_ckpt: true - save_ckpt_path: min_inst_bottomup1 + save_ckpt_path: min_inst_bottomup wandb: entity: team-ucsd project: test_centroid_centered @@ -174,7 +90,7 @@ trainer_config: - trainer_config.optimizer_name - trainer_config.optimizer.amsgrad - trainer_config.optimizer.lr - - model_config.backbone_config.backbone_type + - model_config.backbone_type - model_config.init_weights optimizer_name: Adam optimizer: diff --git a/tests/assets/minimal_instance_centroid/best.ckpt b/tests/assets/minimal_instance_centroid/best.ckpt index 91684945..0fcf57d6 100644 Binary files a/tests/assets/minimal_instance_centroid/best.ckpt and b/tests/assets/minimal_instance_centroid/best.ckpt differ diff --git a/tests/assets/minimal_instance_centroid/initial_config.yaml b/tests/assets/minimal_instance_centroid/initial_config.yaml index d3f74a46..7bb9aaef 100644 --- a/tests/assets/minimal_instance_centroid/initial_config.yaml +++ b/tests/assets/minimal_instance_centroid/initial_config.yaml @@ -1,126 +1,48 @@ data_config: provider: LabelsReader - pipeline: CentroidConfmaps - train: - labels_path: minimal_instance.pkg.slp + train_labels_path: minimal_instance.pkg.slp + val_labels_path: minimal_instance.pkg.slp + preprocessing: max_width: null max_height: null scale: 1.0 is_rgb: false - preprocessing: - crop_hw: - - 160 - - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: true - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 - val: - labels_path: minimal_instance.pkg.slp - max_width: null - max_height: null - is_rgb: false - scale: 1.0 - preprocessing: - crop_hw: - - 160 - - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: false - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 + use_augmentations_train: true + augmentation_config: + random_crop: + random_crop_p: 0 + crop_height: 160 + crop_width: 160 + geometric: + rotation: 180.0 + scale: null + translate_width: 0 + translate_height: 0 + affine_p: 0.5 model_config: init_weights: xavier pre_trained_weights: null + backbone_type: unet backbone_config: - backbone_type: unet - backbone_config: - in_channels: 1 - kernel_size: 3 - filters: 16 - filters_rate: 2 - max_stride: 16 - convs_per_block: 2 - stacks: 1 - stem_stride: null - middle_block: true - up_interpolate: true + in_channels: 1 + kernel_size: 3 + filters: 16 + filters_rate: 2 + max_stride: 16 + convs_per_block: 2 + stacks: 1 + stem_stride: null + middle_block: true + up_interpolate: true head_configs: - confmaps: - head_type: CentroidConfmapsHead - head_config: + single_instance: null + centered_instance: null + bottomup: null + centroid: + confmaps: anchor_part: 0 sigma: 1.5 output_stride: 2 - loss_weight: 1.0 trainer_config: train_data_loader: batch_size: 4 @@ -132,7 +54,6 @@ trainer_config: model_ckpt: save_top_k: 1 save_last: true - device: cpu trainer_devices: 1 trainer_accelerator: cpu enable_progress_bar: false @@ -152,7 +73,7 @@ trainer_config: - trainer_config.optimizer_name - trainer_config.optimizer.amsgrad - trainer_config.optimizer.lr - - model_config.backbone_config.backbone_type + - model_config.backbone_type - model_config.init_weights optimizer_name: Adam optimizer: diff --git a/tests/assets/minimal_instance_centroid/last.ckpt b/tests/assets/minimal_instance_centroid/last.ckpt index 5a868e43..d6e65457 100644 Binary files a/tests/assets/minimal_instance_centroid/last.ckpt and b/tests/assets/minimal_instance_centroid/last.ckpt differ diff --git a/tests/assets/minimal_instance_centroid/training_config.yaml b/tests/assets/minimal_instance_centroid/training_config.yaml index 8eeafdd6..c625f1e3 100644 --- a/tests/assets/minimal_instance_centroid/training_config.yaml +++ b/tests/assets/minimal_instance_centroid/training_config.yaml @@ -1,102 +1,24 @@ data_config: provider: LabelsReader - pipeline: CentroidConfmaps - train: - labels_path: minimal_instance.pkg.slp + train_labels_path: minimal_instance.pkg.slp + val_labels_path: minimal_instance.pkg.slp + preprocessing: max_width: null max_height: null scale: 1.0 is_rgb: false - preprocessing: - crop_hw: - - 160 - - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: true - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 - val: - labels_path: minimal_instance.pkg.slp - max_width: null - max_height: null - is_rgb: false - scale: 1.0 - preprocessing: - crop_hw: - - 160 - - 160 - augmentation_config: - random_crop: - random_crop_p: 0 - random_crop_hw: - - 160 - - 160 - use_augmentations: false - augmentations: - intensity: - uniform_noise: - - 0.0 - - 0.04 - uniform_noise_p: 0 - gaussian_noise_mean: 0.02 - gaussian_noise_std: 0.004 - gaussian_noise_p: 0 - contrast: - - 0.5 - - 2.0 - contrast_p: 0 - brightness: 0.0 - brightness_p: 0 - geometric: - rotation: 180.0 - scale: 0 - translate: - - 0 - - 0 - affine_p: 0.5 - erase_scale: - - 0.0001 - - 0.01 - erase_ratio: - - 1 - - 1 - erase_p: 0 - mixup_lambda: null - mixup_p: 0 + use_augmentations_train: true + augmentation_config: + random_crop: + random_crop_p: 0 + crop_height: 160 + crop_width: 160 + geometric: + rotation: 180.0 + scale: null + translate_width: 0 + translate_height: 0 + affine_p: 0.5 skeletons: Skeleton-0: nodes: @@ -111,27 +33,27 @@ data_config: model_config: init_weights: xavier pre_trained_weights: null + backbone_type: unet backbone_config: - backbone_type: unet - backbone_config: - in_channels: 1 - kernel_size: 3 - filters: 16 - filters_rate: 2 - max_stride: 16 - convs_per_block: 2 - stacks: 1 - stem_stride: null - middle_block: true - up_interpolate: true + in_channels: 1 + kernel_size: 3 + filters: 16 + filters_rate: 2 + max_stride: 16 + convs_per_block: 2 + stacks: 1 + stem_stride: null + middle_block: true + up_interpolate: true head_configs: - confmaps: - head_type: CentroidConfmapsHead - head_config: + single_instance: null + centered_instance: null + bottomup: null + centroid: + confmaps: anchor_part: 0 sigma: 1.5 output_stride: 2 - loss_weight: 1.0 trainer_config: train_data_loader: batch_size: 4 @@ -143,7 +65,6 @@ trainer_config: model_ckpt: save_top_k: 1 save_last: true - device: cpu trainer_devices: 1 trainer_accelerator: cpu enable_progress_bar: false @@ -163,7 +84,7 @@ trainer_config: - trainer_config.optimizer_name - trainer_config.optimizer.amsgrad - trainer_config.optimizer.lr - - model_config.backbone_config.backbone_type + - model_config.backbone_type - model_config.init_weights optimizer_name: Adam optimizer: diff --git a/tests/data/test_augmentation.py b/tests/data/test_augmentation.py index 315dfb91..7c619c81 100644 --- a/tests/data/test_augmentation.py +++ b/tests/data/test_augmentation.py @@ -50,7 +50,8 @@ def test_kornia_augmentation(minimal_instance): erase_p=1.0, mixup_p=1.0, mixup_lambda=(0.0, 1.0), - random_crop_hw=(384, 384), + random_crop_height=384, + random_crop_width=384, random_crop_p=1.0, ) @@ -71,6 +72,7 @@ def test_kornia_augmentation(minimal_instance): ): p = KorniaAugmenter( p, - random_crop_hw=(0, 0), + random_crop_height=0, + random_crop_width=0, random_crop_p=1.0, ) diff --git a/tests/data/test_instance_cropping.py b/tests/data/test_instance_cropping.py index 91cb666d..77b94f97 100644 --- a/tests/data/test_instance_cropping.py +++ b/tests/data/test_instance_cropping.py @@ -35,7 +35,6 @@ def test_instance_cropper(minimal_instance): sample = next(iter(datapipe)) gt_sample_keys = [ - "image", "centroid", "instance", "instance_bbox", @@ -44,7 +43,6 @@ def test_instance_cropper(minimal_instance): "frame_idx", "num_instances", "orig_size", - "scale", ] # Test shapes. diff --git a/tests/data/test_pipelines.py b/tests/data/test_pipelines.py index b4f51c1d..db0abf6a 100644 --- a/tests/data/test_pipelines.py +++ b/tests/data/test_pipelines.py @@ -36,7 +36,6 @@ def test_key_filter(minimal_instance): datapipe = KeyFilter(datapipe, keep_keys=None) gt_sample_keys = [ - "image", "centroid", "instance", "instance_bbox", @@ -46,7 +45,6 @@ def test_key_filter(minimal_instance): "frame_idx", "num_instances", "orig_size", - "scale", ] sample = next(iter(datapipe)) @@ -76,7 +74,6 @@ def test_key_filter(minimal_instance): datapipe = KeyFilter(datapipe, keep_keys=None) gt_sample_keys = [ - "image", "centroid", "instance", "instance_bbox", @@ -86,7 +83,6 @@ def test_key_filter(minimal_instance): "frame_idx", "num_instances", "orig_size", - "scale", "original_image", ] @@ -98,42 +94,15 @@ def test_topdownconfmapspipeline(minimal_instance): """Test the TopdownConfmapsPipeline.""" base_topdown_data_config = OmegaConf.create( { - "max_height": None, - "max_width": None, - "scale": 1.0, - "is_rgb": False, "preprocessing": { + "max_height": None, + "max_width": None, + "scale": 1.0, + "is_rgb": False, "crop_hw": (160, 160), }, - "augmentation_config": { - "random_crop": {"random_crop_p": 1.0, "random_crop_hw": (160, 160)}, - "use_augmentations": False, - "augmentations": { - "intensity": { - "uniform_noise": (0.0, 0.04), - "uniform_noise_p": 0.5, - "gaussian_noise_mean": 0.02, - "gaussian_noise_std": 0.004, - "gaussian_noise_p": 0.5, - "contrast": (0.5, 2.0), - "contrast_p": 0.5, - "brightness": 0.0, - "brightness_p": 0.5, - }, - "geometric": { - "rotation": 15.0, - "scale": 0.05, - "translate": (0.02, 0.02), - "affine_p": 0.5, - "erase_scale": (0.0001, 0.01), - "erase_ratio": (1, 1), - "erase_p": 0.5, - "mixup_lambda": None, - "mixup_p": 0.5, - }, - }, - }, - } + "use_augmentations_train": False, + }, ) confmap_head = DictConfig({"sigma": 1.5, "output_stride": 2, "anchor_part": 0}) @@ -143,10 +112,12 @@ def test_topdownconfmapspipeline(minimal_instance): ) data_provider = LabelsReader(labels=sio.load_slp(minimal_instance)) - datapipe = pipeline.make_training_pipeline(data_provider=data_provider) + datapipe = pipeline.make_training_pipeline( + data_provider=data_provider, + use_augmentations=base_topdown_data_config.use_augmentations_train, + ) gt_sample_keys = [ - "image", "centroid", "instance", "instance_bbox", @@ -156,7 +127,6 @@ def test_topdownconfmapspipeline(minimal_instance): "video_idx", "orig_size", "num_instances", - "scale", ] sample = next(iter(datapipe)) assert len(sample.keys()) == len(gt_sample_keys) @@ -168,39 +138,46 @@ def test_topdownconfmapspipeline(minimal_instance): base_topdown_data_config = OmegaConf.create( { - "max_height": None, - "max_width": None, - "scale": 1.0, - "is_rgb": False, "preprocessing": { + "max_height": None, + "max_width": None, + "scale": 1.0, + "is_rgb": False, "crop_hw": (100, 100), }, + "use_augmentations_train": True, "augmentation_config": { - "random_crop": {"random_crop_p": 0.0, "random_crop_hw": (160, 160)}, - "use_augmentations": True, - "augmentations": { - "intensity": { - "uniform_noise": (0.0, 0.04), - "uniform_noise_p": 0.5, - "gaussian_noise_mean": 0.02, - "gaussian_noise_std": 0.004, - "gaussian_noise_p": 0.5, - "contrast": (0.5, 2.0), - "contrast_p": 0.5, - "brightness": 0.0, - "brightness_p": 0.5, - }, - "geometric": { - "rotation": 15.0, - "scale": 0.05, - "translate": (0.02, 0.02), - "affine_p": 0.5, - "erase_scale": (0.0001, 0.01), - "erase_ratio": (1, 1), - "erase_p": 0.5, - "mixup_lambda": None, - "mixup_p": 0.5, - }, + "random_crop": { + "random_crop_p": 0.0, + "crop_height": 160, + "crop_width": 160, + }, + "intensity": { + "uniform_noise_min": 0.0, + "uniform_noise_max": 0.04, + "uniform_noise_p": 0.5, + "gaussian_noise_mean": 0.02, + "gaussian_noise_std": 0.004, + "gaussian_noise_p": 0.5, + "contrast_min": 0.5, + "contrast_max": 2.0, + "contrast_p": 0.5, + "brightness": 0.0, + "brightness_p": 0.5, + }, + "geometric": { + "rotation": 15.0, + "scale": 0.05, + "translate_width": 0.02, + "translate_height": 0.02, + "affine_p": 0.5, + "erase_scale_min": 0.0001, + "erase_scale_max": 0.01, + "erase_ratio_min": 1, + "erase_ratio_max": 1, + "erase_p": 0.5, + "mixup_lambda": None, + "mixup_p": 0.5, }, }, } @@ -211,10 +188,12 @@ def test_topdownconfmapspipeline(minimal_instance): ) data_provider = LabelsReader(labels=sio.load_slp(minimal_instance)) - datapipe = pipeline.make_training_pipeline(data_provider=data_provider) + datapipe = pipeline.make_training_pipeline( + data_provider=data_provider, + use_augmentations=base_topdown_data_config.use_augmentations_train, + ) gt_sample_keys = [ - "image", "centroid", "instance", "instance_bbox", @@ -224,7 +203,6 @@ def test_topdownconfmapspipeline(minimal_instance): "video_idx", "orig_size", "num_instances", - "scale", ] sample = next(iter(datapipe)) @@ -238,39 +216,46 @@ def test_topdownconfmapspipeline(minimal_instance): # Test with resizing and padding base_topdown_data_config = OmegaConf.create( { - "max_height": None, - "max_width": None, - "scale": 2.0, - "is_rgb": False, "preprocessing": { + "max_height": None, + "max_width": None, + "scale": 2.0, + "is_rgb": False, "crop_hw": (100, 100), }, + "use_augmentations_train": True, "augmentation_config": { - "random_crop": {"random_crop_p": 0.0, "random_crop_hw": (160, 160)}, - "use_augmentations": True, - "augmentations": { - "intensity": { - "uniform_noise": (0.0, 0.04), - "uniform_noise_p": 0.5, - "gaussian_noise_mean": 0.02, - "gaussian_noise_std": 0.004, - "gaussian_noise_p": 0.5, - "contrast": (0.5, 2.0), - "contrast_p": 0.5, - "brightness": 0.0, - "brightness_p": 0.5, - }, - "geometric": { - "rotation": 15.0, - "scale": 0.05, - "translate": (0.02, 0.02), - "affine_p": 0.5, - "erase_scale": (0.0001, 0.01), - "erase_ratio": (1, 1), - "erase_p": 0.5, - "mixup_lambda": None, - "mixup_p": 0.5, - }, + "random_crop": { + "random_crop_p": 0.0, + "crop_height": 160, + "crop_width": 160, + }, + "intensity": { + "uniform_noise_min": 0.0, + "uniform_noise_max": 0.04, + "uniform_noise_p": 0.5, + "gaussian_noise_mean": 0.02, + "gaussian_noise_std": 0.004, + "gaussian_noise_p": 0.5, + "contrast_min": 0.5, + "contrast_max": 2.0, + "contrast_p": 0.5, + "brightness": 0.0, + "brightness_p": 0.5, + }, + "geometric": { + "rotation": 15.0, + "scale": 0.05, + "translate_width": 0.02, + "translate_height": 0.02, + "affine_p": 0.5, + "erase_scale_min": 0.0001, + "erase_scale_max": 0.01, + "erase_ratio_min": 1, + "erase_ratio_max": 1, + "erase_p": 0.5, + "mixup_lambda": None, + "mixup_p": 0.5, }, }, } @@ -281,10 +266,12 @@ def test_topdownconfmapspipeline(minimal_instance): ) data_provider = LabelsReader(labels=sio.load_slp(minimal_instance)) - datapipe = pipeline.make_training_pipeline(data_provider=data_provider) + datapipe = pipeline.make_training_pipeline( + data_provider=data_provider, + use_augmentations=base_topdown_data_config.use_augmentations_train, + ) gt_sample_keys = [ - "image", "centroid", "instance", "instance_bbox", @@ -294,7 +281,6 @@ def test_topdownconfmapspipeline(minimal_instance): "video_idx", "orig_size", "num_instances", - "scale", ] sample = next(iter(datapipe)) @@ -316,38 +302,13 @@ def test_singleinstanceconfmapspipeline(minimal_instance): base_singleinstance_data_config = OmegaConf.create( { - "max_height": None, - "max_width": None, - "scale": 2.0, - "is_rgb": False, - "augmentation_config": { - "random_crop": {"random_crop_p": 0.0, "random_crop_hw": (160, 160)}, - "use_augmentations": False, - "augmentations": { - "intensity": { - "uniform_noise": (0.0, 0.04), - "uniform_noise_p": 0.5, - "gaussian_noise_mean": 0.02, - "gaussian_noise_std": 0.004, - "gaussian_noise_p": 0.5, - "contrast": (0.5, 2.0), - "contrast_p": 0.5, - "brightness": 0.0, - "brightness_p": 0.5, - }, - "geometric": { - "rotation": 15.0, - "scale": 0.05, - "translate": (0.02, 0.02), - "affine_p": 0.5, - "erase_scale": (0.0001, 0.01), - "erase_ratio": (1, 1), - "erase_p": 0.5, - "mixup_lambda": None, - "mixup_p": 0.5, - }, - }, + "preprocessing": { + "max_height": None, + "max_width": None, + "scale": 2.0, + "is_rgb": False, }, + "use_augmentations_train": False, } ) @@ -360,7 +321,10 @@ def test_singleinstanceconfmapspipeline(minimal_instance): ) data_provider = LabelsReader(labels=labels) - datapipe = pipeline.make_training_pipeline(data_provider=data_provider) + datapipe = pipeline.make_training_pipeline( + data_provider=data_provider, + use_augmentations=base_singleinstance_data_config.use_augmentations_train, + ) sample = next(iter(datapipe)) @@ -371,7 +335,6 @@ def test_singleinstanceconfmapspipeline(minimal_instance): "instances", "confidence_maps", "orig_size", - "scale", ] for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())): @@ -381,36 +344,45 @@ def test_singleinstanceconfmapspipeline(minimal_instance): base_singleinstance_data_config = OmegaConf.create( { - "max_height": None, - "max_width": None, - "scale": 1.0, - "is_rgb": False, + "preprocessing": { + "max_height": None, + "max_width": None, + "scale": 1.0, + "is_rgb": False, + }, + "use_augmentations_train": True, "augmentation_config": { - "random_crop": {"random_crop_p": 1.0, "random_crop_hw": (160, 160)}, - "use_augmentations": True, - "augmentations": { - "intensity": { - "uniform_noise": (0.0, 0.04), - "uniform_noise_p": 0.5, - "gaussian_noise_mean": 0.02, - "gaussian_noise_std": 0.004, - "gaussian_noise_p": 0.5, - "contrast": (0.5, 2.0), - "contrast_p": 0.5, - "brightness": 0.0, - "brightness_p": 0.5, - }, - "geometric": { - "rotation": 15.0, - "scale": 0.05, - "translate": (0.02, 0.02), - "affine_p": 0.5, - "erase_scale": (0.0001, 0.01), - "erase_ratio": (1, 1), - "erase_p": 0.5, - "mixup_lambda": None, - "mixup_p": 0.5, - }, + "random_crop": { + "random_crop_p": 1.0, + "crop_height": 160, + "crop_width": 160, + }, + "intensity": { + "uniform_noise_min": 0.0, + "uniform_noise_max": 0.04, + "uniform_noise_p": 0.5, + "gaussian_noise_mean": 0.02, + "gaussian_noise_std": 0.004, + "gaussian_noise_p": 0.5, + "contrast_min": 0.5, + "contrast_max": 2.0, + "contrast_p": 0.5, + "brightness": 0.0, + "brightness_p": 0.5, + }, + "geometric": { + "rotation": 15.0, + "scale": 0.05, + "translate_width": 0.02, + "translate_height": 0.02, + "affine_p": 0.5, + "erase_scale_min": 0.0001, + "erase_scale_max": 0.01, + "erase_ratio_min": 1, + "erase_ratio_max": 1, + "erase_p": 0.5, + "mixup_lambda": None, + "mixup_p": 0.5, }, }, } @@ -423,7 +395,10 @@ def test_singleinstanceconfmapspipeline(minimal_instance): ) data_provider = LabelsReader(labels=labels) - datapipe = pipeline.make_training_pipeline(data_provider=data_provider) + datapipe = pipeline.make_training_pipeline( + data_provider=data_provider, + use_augmentations=base_singleinstance_data_config.use_augmentations_train, + ) sample = next(iter(datapipe)) @@ -434,7 +409,6 @@ def test_singleinstanceconfmapspipeline(minimal_instance): "instances", "confidence_maps", "orig_size", - "scale", ] for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())): @@ -448,39 +422,13 @@ def test_centroidconfmapspipeline(minimal_instance): """Test CentroidConfmapsPipeline class.""" base_centroid_data_config = OmegaConf.create( { - "max_height": None, - "max_width": None, - "scale": 1.0, - "is_rgb": False, - "preprocessing": {}, - "augmentation_config": { - "random_crop": {"random_crop_p": 0.0, "random_crop_hw": (160, 160)}, - "use_augmentations": False, - "augmentations": { - "intensity": { - "uniform_noise": (0.0, 0.04), - "uniform_noise_p": 0.5, - "gaussian_noise_mean": 0.02, - "gaussian_noise_std": 0.004, - "gaussian_noise_p": 0.5, - "contrast": (0.5, 2.0), - "contrast_p": 0.5, - "brightness": 0.0, - "brightness_p": 0.5, - }, - "geometric": { - "rotation": 15.0, - "scale": 0.05, - "translate": (0.02, 0.02), - "affine_p": 0.5, - "erase_scale": (0.0001, 0.01), - "erase_ratio": (1, 1), - "erase_p": 0.5, - "mixup_lambda": None, - "mixup_p": 0.5, - }, - }, + "preprocessing": { + "max_height": None, + "max_width": None, + "scale": 1.0, + "is_rgb": False, }, + "use_augmentations_train": False, } ) confmap_head = DictConfig({"sigma": 1.5, "output_stride": 2, "anchor_part": 0}) @@ -490,7 +438,10 @@ def test_centroidconfmapspipeline(minimal_instance): ) data_provider = LabelsReader(labels=sio.load_slp(minimal_instance)) - datapipe = pipeline.make_training_pipeline(data_provider=data_provider) + datapipe = pipeline.make_training_pipeline( + data_provider=data_provider, + use_augmentations=base_centroid_data_config.use_augmentations_train, + ) gt_sample_keys = [ "image", @@ -499,7 +450,6 @@ def test_centroidconfmapspipeline(minimal_instance): "centroids_confidence_maps", "orig_size", "num_instances", - "scale", ] sample = next(iter(datapipe)) assert len(sample.keys()) == len(gt_sample_keys) @@ -511,37 +461,45 @@ def test_centroidconfmapspipeline(minimal_instance): base_centroid_data_config = OmegaConf.create( { - "max_height": None, - "max_width": None, - "scale": 1.0, - "is_rgb": False, - "preprocessing": {}, + "preprocessing": { + "max_height": None, + "max_width": None, + "scale": 1.0, + "is_rgb": False, + }, + "use_augmentations_train": True, "augmentation_config": { - "random_crop": {"random_crop_p": 1.0, "random_crop_hw": (160, 160)}, - "use_augmentations": True, - "augmentations": { - "intensity": { - "uniform_noise": (0.0, 0.04), - "uniform_noise_p": 0.5, - "gaussian_noise_mean": 0.02, - "gaussian_noise_std": 0.004, - "gaussian_noise_p": 0.5, - "contrast": (0.5, 2.0), - "contrast_p": 0.5, - "brightness": 0.0, - "brightness_p": 0.5, - }, - "geometric": { - "rotation": 15.0, - "scale": 0.05, - "translate": (0.02, 0.02), - "affine_p": 0.5, - "erase_scale": (0.0001, 0.01), - "erase_ratio": (1, 1), - "erase_p": 0.5, - "mixup_lambda": None, - "mixup_p": 0.5, - }, + "random_crop": { + "random_crop_p": 1.0, + "crop_height": 160, + "crop_width": 160, + }, + "intensity": { + "uniform_noise_min": 0.0, + "uniform_noise_max": 0.04, + "uniform_noise_p": 0.5, + "gaussian_noise_mean": 0.02, + "gaussian_noise_std": 0.004, + "gaussian_noise_p": 0.5, + "contrast_min": 0.5, + "contrast_max": 2.0, + "contrast_p": 0.5, + "brightness": 0.0, + "brightness_p": 0.5, + }, + "geometric": { + "rotation": 15.0, + "scale": 0.05, + "translate_width": 0.02, + "translate_height": 0.02, + "affine_p": 0.5, + "erase_scale_min": 0.0001, + "erase_scale_max": 0.01, + "erase_ratio_min": 1, + "erase_ratio_max": 1, + "erase_p": 0.5, + "mixup_lambda": None, + "mixup_p": 0.5, }, }, } @@ -552,7 +510,10 @@ def test_centroidconfmapspipeline(minimal_instance): ) data_provider = LabelsReader(labels=sio.load_slp(minimal_instance)) - datapipe = pipeline.make_training_pipeline(data_provider=data_provider) + datapipe = pipeline.make_training_pipeline( + data_provider=data_provider, + use_augmentations=base_centroid_data_config.use_augmentations_train, + ) gt_sample_keys = [ "image", @@ -561,7 +522,6 @@ def test_centroidconfmapspipeline(minimal_instance): "centroids_confidence_maps", "orig_size", "num_instances", - "scale", ] sample = next(iter(datapipe)) @@ -577,39 +537,13 @@ def test_bottomuppipeline(minimal_instance): """Test BottomUpPipeline class.""" base_bottom_config = OmegaConf.create( { - "max_height": None, - "max_width": None, - "scale": 1.0, - "is_rgb": False, - "preprocessing": {}, - "augmentation_config": { - "random_crop": {"random_crop_p": 0.0, "random_crop_hw": (160, 160)}, - "use_augmentations": False, - "augmentations": { - "intensity": { - "uniform_noise": (0.0, 0.04), - "uniform_noise_p": 0.5, - "gaussian_noise_mean": 0.02, - "gaussian_noise_std": 0.004, - "gaussian_noise_p": 0.5, - "contrast": (0.5, 2.0), - "contrast_p": 0.5, - "brightness": 0.0, - "brightness_p": 0.5, - }, - "geometric": { - "rotation": 15.0, - "scale": 0.05, - "translate": (0.02, 0.02), - "affine_p": 0.5, - "erase_scale": (0.0001, 0.01), - "erase_ratio": (1, 1), - "erase_p": 0.5, - "mixup_lambda": None, - "mixup_p": 0.5, - }, - }, + "preprocessing": { + "max_height": None, + "max_width": None, + "scale": 1.0, + "is_rgb": False, }, + "use_augmentations_train": False, } ) @@ -624,7 +558,10 @@ def test_bottomuppipeline(minimal_instance): ) data_provider = LabelsReader(labels=sio.load_slp(minimal_instance)) - datapipe = pipeline.make_training_pipeline(data_provider=data_provider) + datapipe = pipeline.make_training_pipeline( + data_provider=data_provider, + use_augmentations=base_bottom_config.use_augmentations_train, + ) gt_sample_keys = [ "image", @@ -633,7 +570,6 @@ def test_bottomuppipeline(minimal_instance): "confidence_maps", "orig_size", "num_instances", - "scale", "part_affinity_fields", ] sample = next(iter(datapipe)) @@ -648,39 +584,13 @@ def test_bottomuppipeline(minimal_instance): # with scaling base_bottom_config = OmegaConf.create( { - "max_height": None, - "max_width": None, - "scale": 0.5, - "is_rgb": False, - "preprocessing": {}, - "augmentation_config": { - "random_crop": {"random_crop_p": 0.0, "random_crop_hw": (160, 160)}, - "use_augmentations": False, - "augmentations": { - "intensity": { - "uniform_noise": (0.0, 0.04), - "uniform_noise_p": 0.5, - "gaussian_noise_mean": 0.02, - "gaussian_noise_std": 0.004, - "gaussian_noise_p": 0.5, - "contrast": (0.5, 2.0), - "contrast_p": 0.5, - "brightness": 0.0, - "brightness_p": 0.5, - }, - "geometric": { - "rotation": 15.0, - "scale": 0.05, - "translate": (0.02, 0.02), - "affine_p": 0.5, - "erase_scale": (0.0001, 0.01), - "erase_ratio": (1, 1), - "erase_p": 0.5, - "mixup_lambda": None, - "mixup_p": 0.5, - }, - }, + "preprocessing": { + "max_height": None, + "max_width": None, + "scale": 0.5, + "is_rgb": False, }, + "use_augmentations_train": False, } ) @@ -692,7 +602,10 @@ def test_bottomuppipeline(minimal_instance): ) data_provider = LabelsReader(labels=sio.load_slp(minimal_instance)) - datapipe = pipeline.make_training_pipeline(data_provider=data_provider) + datapipe = pipeline.make_training_pipeline( + data_provider=data_provider, + use_augmentations=base_bottom_config.use_augmentations_train, + ) gt_sample_keys = [ "image", @@ -701,7 +614,6 @@ def test_bottomuppipeline(minimal_instance): "confidence_maps", "orig_size", "num_instances", - "scale", "part_affinity_fields", ] sample = next(iter(datapipe)) @@ -716,37 +628,45 @@ def test_bottomuppipeline(minimal_instance): # with padding base_bottom_config = OmegaConf.create( { - "max_height": None, - "max_width": None, - "scale": 1.0, - "is_rgb": False, - "preprocessing": {}, + "preprocessing": { + "max_height": None, + "max_width": None, + "scale": 1.0, + "is_rgb": False, + }, + "use_augmentations_train": True, "augmentation_config": { - "random_crop": {"random_crop_p": 1.0, "random_crop_hw": (100, 100)}, - "use_augmentations": False, - "augmentations": { - "intensity": { - "uniform_noise": (0.0, 0.04), - "uniform_noise_p": 0.5, - "gaussian_noise_mean": 0.02, - "gaussian_noise_std": 0.004, - "gaussian_noise_p": 0.5, - "contrast": (0.5, 2.0), - "contrast_p": 0.5, - "brightness": 0.0, - "brightness_p": 0.5, - }, - "geometric": { - "rotation": 15.0, - "scale": 0.05, - "translate": (0.02, 0.02), - "affine_p": 0.5, - "erase_scale": (0.0001, 0.01), - "erase_ratio": (1, 1), - "erase_p": 0.5, - "mixup_lambda": None, - "mixup_p": 0.5, - }, + "random_crop": { + "random_crop_p": 1.0, + "crop_height": 100, + "crop_width": 100, + }, + "intensity": { + "uniform_noise_min": 0.0, + "uniform_noise_max": 0.04, + "uniform_noise_p": 0.5, + "gaussian_noise_mean": 0.02, + "gaussian_noise_std": 0.004, + "gaussian_noise_p": 0.5, + "contrast_min": 0.5, + "contrast_max": 2.0, + "contrast_p": 0.5, + "brightness": 0.0, + "brightness_p": 0.5, + }, + "geometric": { + "rotation": 15.0, + "scale": 0.05, + "translate_width": 0.02, + "translate_height": 0.02, + "affine_p": 0.5, + "erase_scale_min": 0.0001, + "erase_scale_max": 0.01, + "erase_ratio_min": 1, + "erase_ratio_max": 1, + "erase_p": 0.5, + "mixup_lambda": None, + "mixup_p": 0.5, }, }, } @@ -760,7 +680,10 @@ def test_bottomuppipeline(minimal_instance): ) data_provider = LabelsReader(labels=sio.load_slp(minimal_instance)) - datapipe = pipeline.make_training_pipeline(data_provider=data_provider) + datapipe = pipeline.make_training_pipeline( + data_provider=data_provider, + use_augmentations=base_bottom_config.use_augmentations_train, + ) gt_sample_keys = [ "image", @@ -769,7 +692,6 @@ def test_bottomuppipeline(minimal_instance): "confidence_maps", "orig_size", "num_instances", - "scale", "part_affinity_fields", ] @@ -785,37 +707,45 @@ def test_bottomuppipeline(minimal_instance): # with random crop base_bottom_config = OmegaConf.create( { - "max_height": None, - "max_width": None, - "scale": 1.0, - "is_rgb": False, - "preprocessing": {}, + "preprocessing": { + "max_height": None, + "max_width": None, + "scale": 1.0, + "is_rgb": False, + }, + "use_augmentations_train": True, "augmentation_config": { - "random_crop": {"random_crop_p": 1.0, "random_crop_hw": (160, 160)}, - "use_augmentations": True, - "augmentations": { - "intensity": { - "uniform_noise": (0.0, 0.04), - "uniform_noise_p": 0.5, - "gaussian_noise_mean": 0.02, - "gaussian_noise_std": 0.004, - "gaussian_noise_p": 0.5, - "contrast": (0.5, 2.0), - "contrast_p": 0.5, - "brightness": 0.0, - "brightness_p": 0.5, - }, - "geometric": { - "rotation": 15.0, - "scale": 0.05, - "translate": (0.02, 0.02), - "affine_p": 0.5, - "erase_scale": (0.0001, 0.01), - "erase_ratio": (1, 1), - "erase_p": 0.5, - "mixup_lambda": None, - "mixup_p": 0.5, - }, + "random_crop": { + "random_crop_p": 1.0, + "crop_height": 160, + "crop_width": 160, + }, + "intensity": { + "uniform_noise_min": 0.0, + "uniform_noise_max": 0.04, + "uniform_noise_p": 0.5, + "gaussian_noise_mean": 0.02, + "gaussian_noise_std": 0.004, + "gaussian_noise_p": 0.5, + "contrast_min": 0.5, + "contrast_max": 2.0, + "contrast_p": 0.5, + "brightness": 0.0, + "brightness_p": 0.5, + }, + "geometric": { + "rotation": 15.0, + "scale": 0.05, + "translate_width": 0.02, + "translate_height": 0.02, + "affine_p": 0.5, + "erase_scale_min": 0.0001, + "erase_scale_max": 0.01, + "erase_ratio_min": 1, + "erase_ratio_max": 1, + "erase_p": 0.5, + "mixup_lambda": None, + "mixup_p": 0.5, }, }, } @@ -829,7 +759,10 @@ def test_bottomuppipeline(minimal_instance): ) data_provider = LabelsReader(labels=sio.load_slp(minimal_instance)) - datapipe = pipeline.make_training_pipeline(data_provider=data_provider) + datapipe = pipeline.make_training_pipeline( + data_provider=data_provider, + use_augmentations=base_bottom_config.use_augmentations_train, + ) gt_sample_keys = [ "image", @@ -838,7 +771,6 @@ def test_bottomuppipeline(minimal_instance): "confidence_maps", "orig_size", "num_instances", - "scale", "part_affinity_fields", ] diff --git a/tests/data/test_resizing.py b/tests/data/test_resizing.py index 555d8c34..f1c831d8 100644 --- a/tests/data/test_resizing.py +++ b/tests/data/test_resizing.py @@ -47,7 +47,6 @@ def test_resizer(minimal_instance): sample = next(iter(pipe)) image = sample["image"] assert image.shape == torch.Size([1, 1, 768, 768]) - assert sample["scale"] == 2 assert "original_image" not in sample.keys() pipe = Resizer(l, scale=2, keep_original=True) diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index 2d00cdb0..edd8e583 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -49,112 +49,48 @@ def config(sleap_data_dir): { "data_config": { "provider": "LabelsReader", - "pipeline": "TopdownConfmaps", - "train": { - "labels_path": f"{sleap_data_dir}/minimal_instance.pkg.slp", + "train_labels_path": f"{sleap_data_dir}/minimal_instance.pkg.slp", + "val_labels_path": f"{sleap_data_dir}/minimal_instance.pkg.slp", + "preprocessing": { "is_rgb": False, "max_width": None, "max_height": None, "scale": 1.0, - "preprocessing": { - "crop_hw": [160, 160], - }, - "augmentation_config": { - "random_crop": { - "random_crop_p": 0, - "random_crop_hw": [160, 160], - }, - "use_augmentations": False, - "augmentations": { - "intensity": { - "uniform_noise": [0.0, 0.04], - "uniform_noise_p": 0, - "gaussian_noise_mean": 0.02, - "gaussian_noise_std": 0.004, - "gaussian_noise_p": 0, - "contrast": [0.5, 2.0], - "contrast_p": 0, - "brightness": 0.0, - "brightness_p": 0, - }, - "geometric": { - "rotation": 180.0, - "scale": 0, - "translate": [0, 0], - "affine_p": 0.5, - "erase_scale": [0.0001, 0.01], - "erase_ratio": [1, 1], - "erase_p": 0, - "mixup_lambda": None, - "mixup_p": 0, - }, - }, - }, + "crop_hw": [160, 160], }, - "val": { - "labels_path": f"{sleap_data_dir}/minimal_instance.pkg.slp", - "is_rgb": False, - "max_width": None, - "max_height": None, - "scale": 1.0, - "preprocessing": { - "crop_hw": [160, 160], - }, - "augmentation_config": { - "random_crop": { - "random_crop_p": 0, - "random_crop_hw": [160, 160], - }, - "use_augmentations": False, - "augmentations": { - "intensity": { - "uniform_noise": [0.0, 0.04], - "uniform_noise_p": 0, - "gaussian_noise_mean": 0.02, - "gaussian_noise_std": 0.004, - "gaussian_noise_p": 0, - "contrast": [0.5, 2.0], - "contrast_p": 0, - "brightness": 0.0, - "brightness_p": 0, - }, - "geometric": { - "rotation": 180.0, - "scale": 0, - "translate": [0, 0], - "affine_p": 0.5, - "erase_scale": [0.0001, 0.01], - "erase_ratio": [1, 1], - "erase_p": 0, - "mixup_lambda": None, - "mixup_p": 0, - }, - }, + "use_augmentations_train": True, + "augmentation_config": { + "geometric": { + "rotation": 180.0, + "scale": None, + "translate_width": 0, + "translate_height": 0, + "affine_p": 0.5, }, }, }, "model_config": { "init_weights": "default", "pre_trained_weights": None, + "backbone_type": "unet", "backbone_config": { - "backbone_type": "unet", - "backbone_config": { - "in_channels": 1, - "kernel_size": 3, - "filters": 16, - "filters_rate": 2, - "max_stride": 16, - "convs_per_block": 2, - "stacks": 1, - "stem_stride": None, - "middle_block": True, - "up_interpolate": True, - }, + "in_channels": 1, + "kernel_size": 3, + "filters": 16, + "filters_rate": 2, + "max_stride": 16, + "convs_per_block": 2, + "stacks": 1, + "stem_stride": None, + "middle_block": True, + "up_interpolate": True, }, "head_configs": { - "confmaps": { - "head_type": "CenteredInstanceConfmapsHead", - "head_config": { + "single_instance": None, + "centroid": None, + "bottomup": None, + "centered_instance": { + "confmaps": { "part_names": [ "0", "1", @@ -162,9 +98,8 @@ def config(sleap_data_dir): "anchor_part": 1, "sigma": 1.5, "output_stride": 2, - "loss_weight": 1.0, - }, - } + } + }, }, }, "trainer_config": { @@ -186,7 +121,6 @@ def config(sleap_data_dir): "min_delta": 1e-08, "patience": 20, }, - "device": "cpu", "trainer_devices": 1, "trainer_accelerator": "cpu", "enable_progress_bar": False, @@ -206,7 +140,7 @@ def config(sleap_data_dir): "trainer_config.optimizer_name", "trainer_config.optimizer.amsgrad", "trainer_config.optimizer.lr", - "model_config.backbone_config.backbone_type", + "model_config.backbone_type", "model_config.init_weights", ], }, diff --git a/tests/inference/test_bottomup.py b/tests/inference/test_bottomup.py index 38367ab5..3659e5d3 100644 --- a/tests/inference/test_bottomup.py +++ b/tests/inference/test_bottomup.py @@ -14,12 +14,12 @@ def test_bottomup_inference_model(minimal_instance_bottomup_ckpt): ) OmegaConf.update( train_config, - "data_config.train.labels_path", + "data_config.train_labels_path", "./tests/assets/minimal_instance.pkg.slp", ) OmegaConf.update( train_config, - "data_config.val.labels_path", + "data_config.val_labels_path", "./tests/assets/minimal_instance.pkg.slp", ) # get dataloader @@ -28,7 +28,10 @@ def test_bottomup_inference_model(minimal_instance_bottomup_ckpt): loader = trainer.val_data_loader torch_model = BottomUpModel.load_from_checkpoint( - f"{minimal_instance_bottomup_ckpt}/best.ckpt", config=train_config + f"{minimal_instance_bottomup_ckpt}/best.ckpt", + config=train_config, + skeletons=None, + model_type="bottomup", ) inference_layer = BottomUpInferenceModel( @@ -36,10 +39,10 @@ def test_bottomup_inference_model(minimal_instance_bottomup_ckpt): paf_scorer=PAFScorer.from_config( config=OmegaConf.create( { - "confmaps": train_config.model_config.head_configs[ + "confmaps": train_config.model_config.head_configs.bottomup[ "confmaps" - ].head_config, - "pafs": train_config.model_config.head_configs["pafs"].head_config, + ], + "pafs": train_config.model_config.head_configs.bottomup["pafs"], } ) ), @@ -63,10 +66,10 @@ def test_bottomup_inference_model(minimal_instance_bottomup_ckpt): paf_scorer=PAFScorer.from_config( config=OmegaConf.create( { - "confmaps": train_config.model_config.head_configs[ + "confmaps": train_config.model_config.head_configs.bottomup[ "confmaps" - ].head_config, - "pafs": train_config.model_config.head_configs["pafs"].head_config, + ], + "pafs": train_config.model_config.head_configs.bottomup["pafs"], } ) ), diff --git a/tests/inference/test_predictors.py b/tests/inference/test_predictors.py index b7965c14..03f3ea0c 100644 --- a/tests/inference/test_predictors.py +++ b/tests/inference/test_predictors.py @@ -54,9 +54,11 @@ def test_topdown_predictor( # if model parameter is not set right with pytest.raises(ValueError): config = OmegaConf.load(f"{minimal_instance_ckpt}/training_config.yaml") - model_name = config.model_config.head_configs["confmaps"].head_type - config.model_config.head_configs["confmaps"].head_type = "instance" - OmegaConf.save(config, f"{minimal_instance_ckpt}/training_config.yaml") + config_copy = config.copy() + head_config = config_copy.model_config.head_configs.centered_instance + del config_copy.model_config.head_configs.centered_instance + OmegaConf.update(config_copy, "model_config.head_configs.topdown", head_config) + OmegaConf.save(config_copy, f"{minimal_instance_ckpt}/training_config.yaml") preds = main( model_paths=[minimal_instance_ckpt], data_path="./tests/assets/minimal_instance.pkg.slp", @@ -64,8 +66,6 @@ def test_topdown_predictor( make_labels=False, ) - config = OmegaConf.load(f"{minimal_instance_ckpt}/training_config.yaml") - config.model_config.head_configs["confmaps"].head_type = model_name OmegaConf.save(config, f"{minimal_instance_ckpt}/training_config.yaml") # centroid + centroid instance model @@ -75,7 +75,7 @@ def test_topdown_predictor( provider="LabelsReader", make_labels=True, max_instances=6, - peak_threshold=0.0, + peak_threshold=[0.0, 0.0], integral_refinement="integral", ) assert isinstance(pred_labels, sio.Labels) @@ -103,7 +103,7 @@ def test_topdown_predictor( provider="VideoReader", make_labels=True, max_instances=6, - peak_threshold=0.0, + peak_threshold=[0.0, 0.0], integral_refinement="integral", videoreader_start_idx=0, videoreader_end_idx=100, @@ -165,11 +165,14 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt): config = _config.copy() try: - OmegaConf.update(config, "data_config.pipeline", "SingleInstanceConfmaps") - config.model_config.head_configs["confmaps"].head_type = ( - "SingleInstanceConfmapsHead" + head_config = config.model_config.head_configs.centered_instance + del config.model_config.head_configs.centered_instance + OmegaConf.update( + config, "model_config.head_configs.single_instance", head_config ) - del config.model_config.head_configs["confmaps"].head_config.anchor_part + del config.model_config.head_configs.single_instance.confmaps.anchor_part + OmegaConf.update(config, "data_config.preprocessing.scale", 0.9) + OmegaConf.save(config, f"{minimal_instance_ckpt}/training_config.yaml") # check if labels are created from ckpt @@ -182,7 +185,6 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt): peak_threshold=0.3, max_height=500, max_width=500, - scale=0.9, ) assert isinstance(pred_labels, sio.Labels) assert len(pred_labels) == 1 @@ -205,7 +207,6 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt): peak_threshold=0.3, max_height=500, max_width=500, - scale=0.9, ) assert isinstance(preds, list) assert len(preds) == 1 @@ -222,11 +223,14 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt): config = _config.copy() try: - OmegaConf.update(config, "data_config.pipeline", "SingleInstanceConfmaps") - config.model_config.head_configs["confmaps"].head_type = ( - "SingleInstanceConfmapsHead" + head_config = config.model_config.head_configs.centered_instance + del config.model_config.head_configs.centered_instance + OmegaConf.update( + config, "model_config.head_configs.single_instance", head_config ) - del config.model_config.head_configs["confmaps"].head_config.anchor_part + del config.model_config.head_configs.single_instance.confmaps.anchor_part + OmegaConf.update(config, "data_config.preprocessing.scale", 0.9) + OmegaConf.save(config, f"{minimal_instance_ckpt}/training_config.yaml") # check if labels are created from ckpt @@ -236,7 +240,6 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt): provider="VideoReader", make_labels=True, peak_threshold=0.3, - scale=0.9, ) assert isinstance(pred_labels, sio.Labels) assert len(pred_labels) == 100 @@ -255,7 +258,6 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt): provider="VideoReader", make_labels=False, peak_threshold=0.3, - scale=0.9, ) assert isinstance(preds, list) assert len(preds) == 25 @@ -274,11 +276,12 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt): config = _config.copy() try: - OmegaConf.update(config, "data_config.pipeline", "SingleInstanceConfmaps") - config.model_config.head_configs["confmaps"].head_type = ( - "SingleInstanceConfmapsHead" + head_config = config.model_config.head_configs.centered_instance + del config.model_config.head_configs.centered_instance + OmegaConf.update( + config, "model_config.head_configs.single_instance", head_config ) - del config.model_config.head_configs["confmaps"].head_config.anchor_part + del config.model_config.head_configs.single_instance.confmaps.anchor_part OmegaConf.save(config, f"{minimal_instance_ckpt}/training_config.yaml") # check if labels are created from ckpt @@ -292,7 +295,6 @@ def test_single_instance_predictor(minimal_instance, minimal_instance_ckpt): provider="Reader", make_labels=False, peak_threshold=0.3, - scale=0.9, ) finally: diff --git a/tests/inference/test_single_instance.py b/tests/inference/test_single_instance.py index f6d69ea1..d7977dcc 100644 --- a/tests/inference/test_single_instance.py +++ b/tests/inference/test_single_instance.py @@ -18,14 +18,16 @@ def test_single_instance_inference_model( config, minimal_instance, minimal_instance_ckpt ): """Test SingleInstanceInferenceModel.""" - OmegaConf.update(config, "data_config.pipeline", "SingleInstanceConfmaps") - config.model_config.head_configs["confmaps"].head_type = ( - "SingleInstanceConfmapsHead" - ) - del config.model_config.head_configs["confmaps"].head_config.anchor_part + head_config = config.model_config.head_configs.centered_instance + del config.model_config.head_configs.centered_instance + OmegaConf.update(config, "model_config.head_configs.single_instance", head_config) + del config.model_config.head_configs.single_instance.confmaps.anchor_part torch_model = SingleInstanceModel.load_from_checkpoint( - f"{minimal_instance_ckpt}/best.ckpt", config=config + f"{minimal_instance_ckpt}/best.ckpt", + config=config, + skeletons=None, + model_type="single_instance", ) labels = sio.load_slp(minimal_instance) diff --git a/tests/inference/test_topdown.py b/tests/inference/test_topdown.py index 8fe3b934..8a59808d 100644 --- a/tests/inference/test_topdown.py +++ b/tests/inference/test_topdown.py @@ -28,7 +28,10 @@ def initialize_model(config, minimal_instance, minimal_instance_ckpt): # for centered instance model config = OmegaConf.load(f"{minimal_instance_ckpt}/training_config.yaml") torch_model = TopDownCenteredInstanceModel.load_from_checkpoint( - f"{minimal_instance_ckpt}/best.ckpt", config=config + f"{minimal_instance_ckpt}/best.ckpt", + config=config, + skeletons=None, + model_type="centered_instance", ) data_provider = LabelsReader.from_filename(minimal_instance) @@ -68,9 +71,14 @@ def initialize_model(config, minimal_instance, minimal_instance_ckpt): def test_centroid_inference_model(config): """Test CentroidCrop class to run inference on centroid models.""" - OmegaConf.update(config, "data_config.pipeline", "CentroidConfmaps") - config.model_config.head_configs["confmaps"].head_type = "CentroidConfmapsHead" - del config.model_config.head_configs["confmaps"].head_config.part_names + + OmegaConf.update( + config, + "model_config.head_configs.centroid", + config.model_config.head_configs.centered_instance, + ) + del config.model_config.head_configs.centered_instance + del config.model_config.head_configs.centroid["confmaps"].part_names trainer = ModelTrainer(config) trainer._create_data_loaders() @@ -168,12 +176,19 @@ def test_find_instance_peaks_groundtruth( batch_size=4, ) - OmegaConf.update(config, "data_config.pipeline", "CentroidConfmaps") - config.model_config.head_configs["confmaps"].head_type = "CentroidConfmapsHead" - del config.model_config.head_configs["confmaps"].head_config.part_names + OmegaConf.update( + config, + "model_config.head_configs.centroid", + config.model_config.head_configs.centered_instance, + ) + del config.model_config.head_configs.centered_instance + del config.model_config.head_configs.centroid["confmaps"].part_names config = OmegaConf.load(f"{minimal_instance_centroid_ckpt}/training_config.yaml") model = CentroidModel.load_from_checkpoint( - f"{minimal_instance_centroid_ckpt}/best.ckpt", config=config + f"{minimal_instance_centroid_ckpt}/best.ckpt", + config=config, + skeletons=None, + model_type="centroid", ) layer = CentroidCrop( @@ -223,11 +238,8 @@ def test_find_instance_peaks(config, minimal_instance, minimal_instance_ckpt): outputs = [] for x in data_pipeline: outputs.append(find_peaks_layer(x)) - print(f"outputs len: {len(outputs)}") for i in outputs: instance = i["pred_instance_peaks"].numpy() - print(f"imgs: {i['instance_image'].shape}") - print(f"pred vals: {i['pred_peak_values']}") assert np.all(np.isnan(instance)) # check return confmaps @@ -240,7 +252,6 @@ def test_find_instance_peaks(config, minimal_instance, minimal_instance_ckpt): ) outputs = [] for x in data_pipeline: - x["image"] = resize_image(x["image"], 0.5) outputs.append(find_peaks_layer(x)) assert "pred_confmaps" in outputs[0].keys() assert outputs[0]["pred_confmaps"].shape[-2:] == (40, 40) @@ -300,7 +311,10 @@ def test_topdown_inference_model( # centroid layer and find peaks config = OmegaConf.load(f"{minimal_instance_centroid_ckpt}/training_config.yaml") torch_model = CentroidModel.load_from_checkpoint( - f"{minimal_instance_centroid_ckpt}/best.ckpt", config=config + f"{minimal_instance_centroid_ckpt}/best.ckpt", + config=config, + skeletons=None, + model_type="centroid", ) data_provider = LabelsReader.from_filename(minimal_instance, instances_key=True) diff --git a/tests/training/test_model_trainer.py b/tests/training/test_model_trainer.py index fb0ab84d..77483e0a 100644 --- a/tests/training/test_model_trainer.py +++ b/tests/training/test_model_trainer.py @@ -40,19 +40,31 @@ def test_create_data_loader(config, tmp_path: str): assert len(list(iter(model_trainer.train_data_loader))) == 2 assert len(list(iter(model_trainer.val_data_loader))) == 2 - OmegaConf.update(config, "data_config.pipeline", "TopDown") - model_trainer = ModelTrainer(config) - with pytest.raises(Exception, match="TopDown is not defined."): + # test exception + config_copy = config.copy() + head_config = config_copy.model_config.head_configs.centered_instance + del config_copy.model_config.head_configs.centered_instance + OmegaConf.update(config_copy, "model_config.head_configs.topdown", head_config) + model_trainer = ModelTrainer(config_copy) + with pytest.raises(Exception): model_trainer._create_data_loaders() - OmegaConf.update(config, "data_config.pipeline", "SingleInstanceConfmaps") - model_trainer = ModelTrainer(config) + # test single instance pipeline + config_copy = config.copy() + del config_copy.model_config.head_configs.centered_instance + OmegaConf.update( + config_copy, "model_config.head_configs.single_instance", head_config + ) + model_trainer = ModelTrainer(config_copy) model_trainer._create_data_loaders() assert len(list(iter(model_trainer.train_data_loader))) == 1 assert len(list(iter(model_trainer.val_data_loader))) == 1 - OmegaConf.update(config, "data_config.pipeline", "CentroidConfmaps") - model_trainer = ModelTrainer(config) + # test centroid pipeline + config_copy = config.copy() + del config_copy.model_config.head_configs.centered_instance + OmegaConf.update(config_copy, "model_config.head_configs.centroid", head_config) + model_trainer = ModelTrainer(config_copy) model_trainer._create_data_loaders() assert len(list(iter(model_trainer.train_data_loader))) == 1 assert len(list(iter(model_trainer.val_data_loader))) == 1 @@ -70,7 +82,7 @@ def test_wandb(): def test_trainer(config, tmp_path: str): - # for topdown centered instance model + # # for topdown centered instance model model_trainer = ModelTrainer(config) OmegaConf.update( config, "trainer_config.save_ckpt_path", f"{tmp_path}/test_model_trainer/" @@ -179,15 +191,14 @@ def test_trainer(config, tmp_path: str): # For Single instance model single_instance_config = config.copy() + head_config = single_instance_config.model_config.head_configs.centered_instance + del single_instance_config.model_config.head_configs.centered_instance OmegaConf.update( - single_instance_config, "data_config.pipeline", "SingleInstanceConfmaps" + single_instance_config, "model_config.head_configs.single_instance", head_config ) - single_instance_config.model_config.head_configs["confmaps"].head_type = ( - "SingleInstanceConfmapsHead" + del ( + single_instance_config.model_config.head_configs.single_instance.confmaps.anchor_part ) - del single_instance_config.model_config.head_configs[ - "confmaps" - ].head_config.anchor_part trainer = ModelTrainer(single_instance_config) trainer._create_data_loaders() @@ -196,15 +207,12 @@ def test_trainer(config, tmp_path: str): # Centroid model centroid_config = config.copy() - OmegaConf.update(centroid_config, "data_config.pipeline", "CentroidConfmaps") - centroid_config.model_config.head_configs["confmaps"].head_type = ( - "CentroidConfmapsHead" - ) - - del centroid_config.model_config.head_configs["confmaps"].head_config.part_names + OmegaConf.update(centroid_config, "model_config.head_configs.centroid", head_config) + del centroid_config.model_config.head_configs.centered_instance + del centroid_config.model_config.head_configs.centroid["confmaps"].part_names - if Path(config.trainer_config.save_ckpt_path).exists(): - shutil.rmtree(config.trainer_config.save_ckpt_path) + if Path(centroid_config.trainer_config.save_ckpt_path).exists(): + shutil.rmtree(centroid_config.trainer_config.save_ckpt_path) OmegaConf.update(centroid_config, "trainer_config.save_ckpt", True) OmegaConf.update(centroid_config, "trainer_config.use_wandb", False) @@ -227,21 +235,17 @@ def test_trainer(config, tmp_path: str): # bottom up model bottomup_config = config.copy() - OmegaConf.update(bottomup_config, "data_config.pipeline", "BottomUp") - bottomup_config.model_config.head_configs["confmaps"].head_type = ( - "MultiInstanceConfmapsHead" - ) + OmegaConf.update(bottomup_config, "model_config.head_configs.bottomup", head_config) paf = { - "head_type": "PartAffinityFieldsHead", - "head_config": { - "edges": [("part1", "part2")], - "sigma": 4, - "output_stride": 4, - "loss_weight": 1.0, - }, + "edges": [("part1", "part2")], + "sigma": 4, + "output_stride": 4, + "loss_weight": 1.0, } - del bottomup_config.model_config.head_configs["confmaps"].head_config.anchor_part - bottomup_config.model_config.head_configs["pafs"] = paf + del bottomup_config.model_config.head_configs.bottomup["confmaps"].anchor_part + del bottomup_config.model_config.head_configs.centered_instance + bottomup_config.model_config.head_configs.bottomup["pafs"] = paf + bottomup_config.model_config.head_configs.bottomup.confmaps.loss_weight = 1.0 if Path(bottomup_config.trainer_config.save_ckpt_path).exists(): shutil.rmtree(bottomup_config.trainer_config.save_ckpt_path) @@ -269,7 +273,7 @@ def test_trainer(config, tmp_path: str): def test_topdown_centered_instance_model(config, tmp_path: str): # unet - model = TopDownCenteredInstanceModel(config) + model = TopDownCenteredInstanceModel(config, None, "centered_instance") OmegaConf.update( config, "trainer_config.save_ckpt_path", f"{tmp_path}/test_model_trainer/" ) @@ -290,10 +294,10 @@ def test_topdown_centered_instance_model(config, tmp_path: str): OmegaConf.update( config, "model_config.pre_trained_weights", "ConvNeXt_Tiny_Weights" ) - OmegaConf.update(config, "model_config.backbone_config.backbone_type", "convnext") + OmegaConf.update(config, "model_config.backbone_type", "convnext") OmegaConf.update( config, - "model_config.backbone_config.backbone_config", + "model_config.backbone_config", { "in_channels": 1, "model_type": "tiny", @@ -306,7 +310,7 @@ def test_topdown_centered_instance_model(config, tmp_path: str): "stem_patch_stride": 2, }, ) - model = TopDownCenteredInstanceModel(config) + model = TopDownCenteredInstanceModel(config, None, "centered_instance") OmegaConf.update( config, "trainer_config.save_ckpt_path", f"{tmp_path}/test_model_trainer/" ) @@ -330,11 +334,15 @@ def test_topdown_centered_instance_model(config, tmp_path: str): def test_centroid_model(config, tmp_path: str): """Test CentroidModel training.""" - OmegaConf.update(config, "data_config.pipeline", "CentroidConfmaps") - config.model_config.head_configs["confmaps"].head_type = "CentroidConfmapsHead" - del config.model_config.head_configs["confmaps"].head_config.part_names + OmegaConf.update( + config, + "model_config.head_configs.centroid", + config.model_config.head_configs.centered_instance, + ) + del config.model_config.head_configs.centered_instance + del config.model_config.head_configs.centroid["confmaps"].part_names - model = CentroidModel(config) + model = CentroidModel(config, None, "centroid") OmegaConf.update( config, "trainer_config.save_ckpt_path", f"{tmp_path}/test_model_trainer/" @@ -355,12 +363,12 @@ def test_centroid_model(config, tmp_path: str): def test_single_instance_model(config, tmp_path: str): """Test the SingleInstanceModel training.""" - OmegaConf.update(config, "data_config.pipeline", "SingleInstanceConfmaps") + head_config = config.model_config.head_configs.centered_instance + del config.model_config.head_configs.centered_instance + OmegaConf.update(config, "model_config.head_configs.single_instance", head_config) + del config.model_config.head_configs.single_instance.confmaps.anchor_part + OmegaConf.update(config, "model_config.init_weights", "xavier") - config.model_config.head_configs["confmaps"].head_type = ( - "SingleInstanceConfmapsHead" - ) - del config.model_config.head_configs["confmaps"].head_config.anchor_part OmegaConf.update( config, "trainer_config.save_ckpt_path", f"{tmp_path}/test_model_trainer/" @@ -368,7 +376,7 @@ def test_single_instance_model(config, tmp_path: str): model_trainer = ModelTrainer(config) model_trainer._create_data_loaders() input_ = next(iter(model_trainer.train_data_loader)) - model = SingleInstanceModel(config) + model = SingleInstanceModel(config, None, "single_instance") img = input_["image"] img_shape = img.shape[-2:] @@ -380,11 +388,11 @@ def test_single_instance_model(config, tmp_path: str): 2, int( img_shape[0] - / config.model_config.head_configs.confmaps.head_config.output_stride + / config.model_config.head_configs.single_instance.confmaps.output_stride ), int( img_shape[1] - / config.model_config.head_configs.confmaps.head_config.output_stride + / config.model_config.head_configs.single_instance.confmaps.output_stride ), ) @@ -398,19 +406,18 @@ def test_bottomup_model(config, tmp_path: str): """Test BottomUp model training.""" config_copy = config.copy() - OmegaConf.update(config, "data_config.pipeline", "BottomUp") - config.model_config.head_configs["confmaps"].head_type = "MultiInstanceConfmapsHead" + head_config = config.model_config.head_configs.centered_instance + OmegaConf.update(config, "model_config.head_configs.bottomup", head_config) paf = { - "head_type": "PartAffinityFieldsHead", - "head_config": { - "edges": [("part1", "part2")], - "sigma": 4, - "output_stride": 4, - "loss_weight": 1.0, - }, + "edges": [("part1", "part2")], + "sigma": 4, + "output_stride": 4, + "loss_weight": 1.0, } - del config.model_config.head_configs["confmaps"].head_config.anchor_part - config.model_config.head_configs["pafs"] = paf + del config.model_config.head_configs.bottomup["confmaps"].anchor_part + del config.model_config.head_configs.centered_instance + config.model_config.head_configs.bottomup["pafs"] = paf + config.model_config.head_configs.bottomup.confmaps.loss_weight = 1.0 OmegaConf.update( config, "trainer_config.save_ckpt_path", f"{tmp_path}/test_model_trainer/" @@ -419,7 +426,7 @@ def test_bottomup_model(config, tmp_path: str): model_trainer._create_data_loaders() input_ = next(iter(model_trainer.train_data_loader)) - model = BottomUpModel(config) + model = BottomUpModel(config, None, "bottomup") preds = model(input_["image"]) @@ -430,20 +437,19 @@ def test_bottomup_model(config, tmp_path: str): # with edges as None config = config_copy - OmegaConf.update(config, "data_config.pipeline", "BottomUp") - config.model_config.head_configs["confmaps"].head_type = "MultiInstanceConfmapsHead" - config.model_config.head_configs["confmaps"].head_config.part_names = None + head_config = config.model_config.head_configs.centered_instance + OmegaConf.update(config, "model_config.head_configs.bottomup", head_config) paf = { - "head_type": "PartAffinityFieldsHead", - "head_config": { - "edges": None, - "sigma": 4, - "output_stride": 4, - "loss_weight": 1.0, - }, + "edges": None, + "sigma": 4, + "output_stride": 4, + "loss_weight": 1.0, } - del config.model_config.head_configs["confmaps"].head_config.anchor_part - config.model_config.head_configs["pafs"] = paf + del config.model_config.head_configs.bottomup["confmaps"].anchor_part + del config.model_config.head_configs.centered_instance + config.model_config.head_configs.bottomup["pafs"] = paf + config.model_config.head_configs.bottomup.confmaps.loss_weight = 1.0 + OmegaConf.update( config, "trainer_config.save_ckpt_path", f"{tmp_path}/test_model_trainer/" ) @@ -452,7 +458,7 @@ def test_bottomup_model(config, tmp_path: str): skeletons = model_trainer.skeletons input_ = next(iter(model_trainer.train_data_loader)) - model = BottomUpModel(config, skeletons) + model = BottomUpModel(config, skeletons, "bottomup") preds = model(input_["image"])