diff --git a/docs/detection/double_detection_filter.md b/docs/detection/double_detection_filter.md
index b02663715..6cb2393fd 100644
--- a/docs/detection/double_detection_filter.md
+++ b/docs/detection/double_detection_filter.md
@@ -16,12 +16,25 @@ comments: true
 
 :::supervision.detection.overlap_filter.box_non_max_suppression
 
+<div class="md-typeset">
+  <h2><a href="#supervision.detection.overlap_filter.box_soft_non_max_suppression">box_soft_non_max_suppression</a></h2>
+</div>
+
+:::supervision.detection.overlap_filter.box_soft_non_max_suppression
+
+
 <div class="md-typeset">
   <h2><a href="#supervision.detection.overlap_filter.mask_non_max_suppression">mask_non_max_suppression</a></h2>
 </div>
 
 :::supervision.detection.overlap_filter.mask_non_max_suppression
 
+<div class="md-typeset">
+  <h2><a href="#supervision.detection.overlap_filter.mask_soft_non_max_suppression">mask_soft_non_max_suppression</a></h2>
+</div>
+
+:::supervision.detection.overlap_filter.mask_soft_non_max_suppression
+
 <div class="md-typeset">
   <h2><a href="#supervision.detection.overlap_filter.box_non_max_merge">box_non_max_merge</a></h2>
 </div>
diff --git a/supervision/detection/core.py b/supervision/detection/core.py
index 113948fc9..ab3ab348d 100644
--- a/supervision/detection/core.py
+++ b/supervision/detection/core.py
@@ -19,7 +19,9 @@
 from supervision.detection.overlap_filter import (
     box_non_max_merge,
     box_non_max_suppression,
+    box_soft_non_max_suppression,
     mask_non_max_suppression,
+    mask_soft_non_max_suppression,
 )
 from supervision.detection.tools.transformers import (
     process_transformers_detection_result,
@@ -1320,6 +1322,63 @@ def with_nms(
 
         return self[indices]
 
+    def with_soft_nms(
+        self, sigma: float = 0.5, class_agnostic: bool = False
+    ) -> Detections:
+        """
+        Perform soft non-maximum suppression on the current set of object detections.
+
+        Args:
+            sigma (float): The sigma value to use for the soft non-maximum suppression
+                algorithm. Defaults to 0.5.
+            class_agnostic (bool): Whether to perform class-agnostic
+                non-maximum suppression. If True, the class_id of each detection
+                will be ignored. Defaults to False.
+
+        Returns:
+            Detections: A new Detections object containing the subset of detections
+                after non-maximum suppression.
+
+        Raises:
+            AssertionError: If `confidence` is None and class_agnostic is False.
+        """
+        if len(self) == 0:
+            return self
+
+        assert (
+            self.confidence is not None
+        ), "Detections confidence must be given for NMS to be executed."
+
+        if class_agnostic:
+            predictions = np.hstack((self.xyxy, self.confidence.reshape(-1, 1)))
+        else:
+            assert self.class_id is not None, (
+                "Detections class_id must be given for NMS to be executed. If you"
+                " intended to perform class agnostic NMS set class_agnostic=True."
+            )
+            predictions = np.hstack(
+                (
+                    self.xyxy,
+                    self.confidence.reshape(-1, 1),
+                    self.class_id.reshape(-1, 1),
+                )
+            )
+
+        if self.mask is not None:
+            soft_confidences = mask_soft_non_max_suppression(
+                predictions=predictions,
+                masks=self.mask,
+                sigma=sigma,
+            )
+            self.confidence = soft_confidences
+        else:
+            soft_confidences = box_soft_non_max_suppression(
+                predictions=predictions, sigma=sigma
+            )
+            self.confidence = soft_confidences
+
+        return self
+
     def with_nmm(
         self, threshold: float = 0.5, class_agnostic: bool = False
     ) -> Detections:
diff --git a/supervision/detection/overlap_filter.py b/supervision/detection/overlap_filter.py
index 4c59295f6..a7ef40c19 100644
--- a/supervision/detection/overlap_filter.py
+++ b/supervision/detection/overlap_filter.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from enum import Enum
-from typing import List, Union
+from typing import List, Tuple, Union
 
 import numpy as np
 import numpy.typing as npt
@@ -38,6 +38,48 @@ def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray:
     return resized_masks
 
 
+def __prepare_data_for_mask_nms(
+    mask_dimension: int,
+    masks: np.ndarray,
+    predictions: np.ndarray,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]:
+    """
+    Get IOUs from mask. Prepare the data for non-max suppression.
+
+    Args:
+        mask_dimension (int): The dimension to which the masks should be
+            resized before computing IOU values.
+        masks (np.ndarray): A 3D array of binary masks corresponding to the predictions.
+            Shape: `(N, H, W)`, where N is the number of predictions, and H, W are the
+            dimensions of each
+        predictions (np.ndarray): An array of object detection predictions in the format
+            of `(x_min, y_min, x_max, y_max, score)` or
+            `(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or `(N, 6)`,
+            where N is the number of predictions.
+
+    Returns:
+        Tuple[np.ndarray, np.ndarray, int, np.ndarray]: A tuple containing the
+            predictions, categories, IOUs, number of rows, and the sorted indices.
+
+    Raises:
+        AssertionError: If `iou_threshold` is not within the closed range from
+            `0` to `1`.
+    """
+    rows, columns = predictions.shape
+
+    if columns == 5:
+        predictions = np.c_[predictions, np.zeros(rows)]
+
+    sort_index = predictions[:, 4].argsort()[::-1]
+    predictions = predictions[sort_index]
+    masks = masks[sort_index]
+    masks_resized = resize_masks(masks, mask_dimension)
+    ious = mask_iou_batch(masks_resized, masks_resized)
+    categories = predictions[:, 5]
+
+    return predictions, categories, ious, rows, sort_index
+
+
 def mask_non_max_suppression(
     predictions: np.ndarray,
     masks: np.ndarray,
@@ -72,17 +114,9 @@ def mask_non_max_suppression(
         "Value of `iou_threshold` must be in the closed range from 0 to 1, "
         f"{iou_threshold} given."
     )
-    rows, columns = predictions.shape
-
-    if columns == 5:
-        predictions = np.c_[predictions, np.zeros(rows)]
-
-    sort_index = predictions[:, 4].argsort()[::-1]
-    predictions = predictions[sort_index]
-    masks = masks[sort_index]
-    masks_resized = resize_masks(masks, mask_dimension)
-    ious = mask_iou_batch(masks_resized, masks_resized)
-    categories = predictions[:, 5]
+    _, categories, ious, rows, sort_index = __prepare_data_for_mask_nms(
+        mask_dimension, masks, predictions
+    )
 
     keep = np.ones(rows, dtype=bool)
     for i in range(rows):
@@ -93,31 +127,71 @@ def mask_non_max_suppression(
     return keep[sort_index.argsort()]
 
 
-def box_non_max_suppression(
-    predictions: np.ndarray, iou_threshold: float = 0.5
+def mask_soft_non_max_suppression(
+    predictions: np.ndarray,
+    masks: np.ndarray,
+    mask_dimension: int = 640,
+    sigma: float = 0.5,
 ) -> np.ndarray:
     """
-    Perform Non-Maximum Suppression (NMS) on object detection predictions.
+    Perform Soft Non-Maximum Suppression (Soft-NMS) on segmentation predictions.
 
-    Args:
+     Args:
         predictions (np.ndarray): An array of object detection predictions in
             the format of `(x_min, y_min, x_max, y_max, score)`
             or `(x_min, y_min, x_max, y_max, score, class)`.
         iou_threshold (float): The intersection-over-union threshold
             to use for non-maximum suppression.
+        sigma (float): The sigma value to use for soft non-maximum suppression.
 
     Returns:
-        np.ndarray: A boolean array indicating which predictions to keep after n
-            on-maximum suppression.
+        np.ndarray: An array containing the updated confidence scores.
 
     Raises:
         AssertionError: If `iou_threshold` is not within the
             closed range from `0` to `1`.
+        AssertionError: If `sigma` is not within the open range from `0` to `1`.
     """
-    assert 0 <= iou_threshold <= 1, (
-        "Value of `iou_threshold` must be in the closed range from 0 to 1, "
-        f"{iou_threshold} given."
+    assert (
+        0 < sigma < 1
+    ), f"Value of `sigma` must be greater than 0 and less than 1, {sigma} given."
+    predictions, categories, ious, rows, sort_index = __prepare_data_for_mask_nms(
+        mask_dimension, masks, predictions
     )
+
+    not_this_row = np.ones(rows)
+    for i in range(rows):
+        not_this_row[i] = 0
+        condition = (categories[i] == categories) * not_this_row
+        predictions[:, 4] = predictions[:, 4] * np.exp(
+            -(ious[i] ** 2) / sigma * condition
+        )
+
+    return predictions[sort_index.argsort(), 4]
+
+
+def __prepare_data_for_box_nsm(
+    predictions: np.ndarray,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]:
+    """
+    Prepare the data for non-max suppression.
+
+    Args:
+        predictions (np.ndarray): An array of object detection predictions in the
+            format of `(x_min, y_min, x_max, y_max, score)` or
+            `(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or `(N, 6)`,
+            where N is the number of predictions.
+
+    Returns:
+        Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]: A tuple containing
+            the predictions, categories, IOUs, number of rows, and the sorted indices
+
+    Raises:
+        AssertionError: If `iou_threshold` is not within the closed range from `0`
+            to `1`.
+
+
+    """
     rows, columns = predictions.shape
 
     # add column #5 - category filled with zeros for agnostic nms
@@ -127,14 +201,42 @@ def box_non_max_suppression(
     # sort predictions column #4 - score
     sort_index = np.flip(predictions[:, 4].argsort())
     predictions = predictions[sort_index]
-
     boxes = predictions[:, :4]
     categories = predictions[:, 5]
     ious = box_iou_batch(boxes, boxes)
     ious = ious - np.eye(rows)
 
-    keep = np.ones(rows, dtype=bool)
+    return predictions, categories, ious, rows, sort_index
+
+
+def box_non_max_suppression(
+    predictions: np.ndarray, iou_threshold: float = 0.5
+) -> np.ndarray:
+    """
+    Perform Non-Maximum Suppression (NMS) on object detection predictions.
+
+    Args:
+        predictions (np.ndarray): An array of object detection predictions in
+            the format of `(x_min, y_min, x_max, y_max, score)`
+            or `(x_min, y_min, x_max, y_max, score, class)`.
+        iou_threshold (float): The intersection-over-union threshold
+            to use for non-maximum suppression.
+
+    Returns:
+        np.ndarray: A boolean array indicating which predictions to keep after n
+            on-maximum suppression.
+
+    Raises:
+        AssertionError: If `iou_threshold` is not within the
+            closed range from `0` to `1`.
+    """
+    assert 0 <= iou_threshold <= 1, (
+        "Value of `iou_threshold` must be in the closed range from 0 to 1, "
+        f"{iou_threshold} given."
+    )
+    _, categories, ious, rows, sort_index = __prepare_data_for_box_nsm(predictions)
 
+    keep = np.ones(rows, dtype=bool)
     for index, (iou, category) in enumerate(zip(ious, categories)):
         if not keep[index]:
             continue
@@ -147,6 +249,46 @@ def box_non_max_suppression(
     return keep[sort_index.argsort()]
 
 
+def box_soft_non_max_suppression(
+    predictions: np.ndarray, sigma: float = 0.5
+) -> np.ndarray:
+    """
+    Perform Soft Non-Maximum Suppression (Soft-NMS) on object detection predictions.
+
+    Args:
+        predictions (np.ndarray): An array of object detection predictions in
+            the format of `(x_min, y_min, x_max, y_max, score)`
+            or `(x_min, y_min, x_max, y_max, score, class)`.
+        iou_threshold (float): The intersection-over-union threshold
+            to use for soft non-maximum suppression.
+        sigma (float): The sigma value to use for soft non-maximum suppression.
+
+    Returns:
+        np.ndarray: An array containing the updated confidence scores.
+    Raises:
+        AssertionError: If `iou_threshold` is not within the
+            closed range from `0` to `1`.
+        AssertionError: If `sigma` is not within the opened range from `0` to `1`.
+    """
+
+    assert (
+        0 < sigma < 1
+    ), f"Value of `sigma` must be greater than 0 and less than 1, {sigma} given."
+    predictions, categories, ious, rows, sort_index = __prepare_data_for_box_nsm(
+        predictions
+    )
+
+    not_this_row = np.ones(rows)
+    for i in range(rows):
+        not_this_row[i] = 0
+        condition = (categories[i] == categories) * not_this_row
+        predictions[:, 4] = predictions[:, 4] * np.exp(
+            -(ious[i] ** 2) / sigma * condition
+        )
+
+    return predictions[sort_index.argsort(), 4]
+
+
 def group_overlapping_boxes(
     predictions: npt.NDArray[np.float64], iou_threshold: float = 0.5
 ) -> List[List[int]]:
diff --git a/test/detection/test_overlap_filter.py b/test/detection/test_overlap_filter.py
index f628c30f9..6b0df77a4 100644
--- a/test/detection/test_overlap_filter.py
+++ b/test/detection/test_overlap_filter.py
@@ -6,8 +6,10 @@
 
 from supervision.detection.overlap_filter import (
     box_non_max_suppression,
+    box_soft_non_max_suppression,
     group_overlapping_boxes,
     mask_non_max_suppression,
+    mask_soft_non_max_suppression,
 )
 
 
@@ -243,6 +245,109 @@ def test_box_non_max_suppression(
         assert np.array_equal(result, expected_result)
 
 
+@pytest.mark.parametrize(
+    "predictions, sigma, expected_result, exception",
+    [
+        (
+            np.empty(shape=(0, 5)),
+            0.1,
+            np.array([]),
+            DoesNotRaise(),
+        ),  # single box with no category
+        (
+            np.array([[10.0, 10.0, 40.0, 40.0, 0.8]]),
+            0.8,
+            np.array([0.8]),
+            DoesNotRaise(),
+        ),  # single box with no category
+        (
+            np.array([[10.0, 10.0, 40.0, 40.0, 0.8, 0]]),
+            0.9,
+            np.array([0.8]),
+            DoesNotRaise(),
+        ),  # single box with category
+        (
+            np.array(
+                [
+                    [10.0, 10.0, 40.0, 40.0, 0.8],
+                    [15.0, 15.0, 40.0, 40.0, 0.9],
+                ]
+            ),
+            0.2,
+            np.array([0.07176137, 0.9]),
+            DoesNotRaise(),
+        ),  # two boxes with no category
+        (
+            np.array(
+                [
+                    [10.0, 10.0, 40.0, 40.0, 0.8, 0],
+                    [15.0, 15.0, 40.0, 40.0, 0.9, 1],
+                ]
+            ),
+            0.3,
+            np.array([0.8, 0.9]),
+            DoesNotRaise(),
+        ),  # two boxes with different category
+        (
+            np.array(
+                [
+                    [10.0, 10.0, 40.0, 40.0, 0.8, 0],
+                    [15.0, 15.0, 40.0, 40.0, 0.9, 0],
+                ]
+            ),
+            0.9,
+            np.array([0.46814354, 0.9]),
+            DoesNotRaise(),
+        ),  # two boxes with same category
+        (
+            np.array(
+                [
+                    [0.0, 0.0, 30.0, 40.0, 0.8],
+                    [5.0, 5.0, 35.0, 45.0, 0.9],
+                    [10.0, 10.0, 40.0, 50.0, 0.85],
+                ]
+            ),
+            0.7,
+            np.array([0.42648529, 0.9, 0.53109062]),
+            DoesNotRaise(),
+        ),  # three boxes with no category
+        (
+            np.array(
+                [
+                    [0.0, 0.0, 30.0, 40.0, 0.8, 0],
+                    [5.0, 5.0, 35.0, 45.0, 0.9, 1],
+                    [10.0, 10.0, 40.0, 50.0, 0.85, 2],
+                ]
+            ),
+            0.5,
+            np.array([0.8, 0.9, 0.85]),
+            DoesNotRaise(),
+        ),  # three boxes with same category
+        (
+            np.array(
+                [
+                    [0.0, 0.0, 30.0, 40.0, 0.8, 0],
+                    [5.0, 5.0, 35.0, 45.0, 0.9, 0],
+                    [10.0, 10.0, 40.0, 50.0, 0.85, 1],
+                ]
+            ),
+            0.9,
+            np.array([0.55491779, 0.9, 0.85]),
+            DoesNotRaise(),
+        ),  # three boxes with different category
+    ],
+)
+def test_box_soft_non_max_suppression(
+    predictions: np.ndarray,
+    sigma: float,
+    expected_result: Optional[np.ndarray],
+    exception: Exception,
+) -> None:
+    with exception:
+        result = box_soft_non_max_suppression(predictions=predictions, sigma=sigma)
+        np.testing.assert_almost_equal(result, expected_result, decimal=5)
+
+
 @pytest.mark.parametrize(
     "predictions, masks, iou_threshold, expected_result, exception",
     [
@@ -447,3 +552,211 @@ def test_mask_non_max_suppression(
             predictions=predictions, masks=masks, iou_threshold=iou_threshold
         )
         assert np.array_equal(result, expected_result)
+
+
+@pytest.mark.parametrize(
+    "predictions, masks, sigma, expected_result, exception",
+    [
+        (
+            np.empty((0, 6)),
+            np.empty((0, 5, 5)),
+            0.1,
+            np.array([]),
+            DoesNotRaise(),
+        ),  # empty predictions and masks
+        (
+            np.array([[0, 0, 0, 0, 0.8]]),
+            np.array(
+                [
+                    [
+                        [False, False, False, False, False],
+                        [False, True, True, True, False],
+                        [False, True, True, True, False],
+                        [False, True, True, True, False],
+                        [False, False, False, False, False],
+                    ]
+                ]
+            ),
+            0.2,
+            np.array([0.8]),
+            DoesNotRaise(),
+        ),  # single mask with no category
+        (
+            np.array([[0, 0, 0, 0, 0.8, 0]]),
+            np.array(
+                [
+                    [
+                        [False, False, False, False, False],
+                        [False, True, True, True, False],
+                        [False, True, True, True, False],
+                        [False, True, True, True, False],
+                        [False, False, False, False, False],
+                    ]
+                ]
+            ),
+            0.99,
+            np.array([0.8]),
+            DoesNotRaise(),
+        ),  # single mask with category
+        (
+            np.array([[0, 0, 0, 0, 0.8], [0, 0, 0, 0, 0.9]]),
+            np.array(
+                [
+                    [
+                        [False, False, False, False, False],
+                        [False, True, True, False, False],
+                        [False, True, True, False, False],
+                        [False, False, False, False, False],
+                        [False, False, False, False, False],
+                    ],
+                    [
+                        [False, False, False, False, False],
+                        [False, False, False, False, False],
+                        [False, False, False, True, True],
+                        [False, False, False, True, True],
+                        [False, False, False, False, False],
+                    ],
+                ]
+            ),
+            0.8,
+            np.array([0.8, 0.9]),
+            DoesNotRaise(),
+        ),  # two masks non-overlapping with no category
+        (
+            np.array([[0, 0, 0, 0, 0.8], [0, 0, 0, 0, 0.9]]),
+            np.array(
+                [
+                    [
+                        [False, False, False, False, False],
+                        [False, True, True, True, False],
+                        [False, True, True, True, False],
+                        [False, True, True, True, False],
+                        [False, False, False, False, False],
+                    ],
+                    [
+                        [False, False, False, False, False],
+                        [False, False, True, True, True],
+                        [False, False, True, True, True],
+                        [False, False, True, True, True],
+                        [False, False, False, False, False],
+                    ],
+                ]
+            ),
+            0.6,
+            np.array([0.3831756, 0.9]),
+            DoesNotRaise(),
+        ),  # two masks partially overlapping with no category
+        (
+            np.array([[0, 0, 0, 0, 0.8, 0], [0, 0, 0, 0, 0.9, 1]]),
+            np.array(
+                [
+                    [
+                        [False, False, False, False, False],
+                        [False, True, True, True, False],
+                        [False, True, True, True, False],
+                        [False, True, True, True, False],
+                        [False, False, False, False, False],
+                    ],
+                    [
+                        [False, False, False, False, False],
+                        [False, False, True, True, True],
+                        [False, False, True, True, True],
+                        [False, False, True, True, True],
+                        [False, False, False, False, False],
+                    ],
+                ]
+            ),
+            0.9,
+            np.array([0.8, 0.9]),
+            DoesNotRaise(),
+        ),  # two masks partially overlapping with different category
+        (
+            np.array(
+                [
+                    [0, 0, 0, 0, 0.8],
+                    [0, 0, 0, 0, 0.85],
+                    [0, 0, 0, 0, 0.9],
+                ]
+            ),
+            np.array(
+                [
+                    [
+                        [False, False, False, False, False],
+                        [False, True, True, False, False],
+                        [False, True, True, False, False],
+                        [False, False, False, False, False],
+                        [False, False, False, False, False],
+                    ],
+                    [
+                        [False, False, False, False, False],
+                        [False, True, True, False, False],
+                        [False, True, True, False, False],
+                        [False, False, False, False, False],
+                        [False, False, False, False, False],
+                    ],
+                    [
+                        [False, False, False, False, False],
+                        [False, False, False, True, True],
+                        [False, False, False, True, True],
+                        [False, False, False, False, False],
+                        [False, False, False, False, False],
+                    ],
+                ]
+            ),
+            0.3,
+            np.array([0.02853919, 0.85, 0.9]),
+            DoesNotRaise(),
+        ),  # three masks with no category
+        (
+            np.array(
+                [
+                    [0, 0, 0, 0, 0.8, 0],
+                    [0, 0, 0, 0, 0.85, 1],
+                    [0, 0, 0, 0, 0.9, 2],
+                ]
+            ),
+            np.array(
+                [
+                    [
+                        [False, False, False, False, False],
+                        [False, True, True, False, False],
+                        [False, True, True, False, False],
+                        [False, False, False, False, False],
+                        [False, False, False, False, False],
+                    ],
+                    [
+                        [False, False, False, False, False],
+                        [False, True, True, False, False],
+                        [False, True, True, False, False],
+                        [False, True, True, False, False],
+                        [False, False, False, False, False],
+                    ],
+                    [
+                        [False, False, False, False, False],
+                        [False, True, True, False, False],
+                        [False, True, True, False, False],
+                        [False, False, False, False, False],
+                        [False, False, False, False, False],
+                    ],
+                ]
+            ),
+            0.1,
+            np.array([0.8, 0.85, 0.9]),
+            DoesNotRaise(),
+        ),  # three masks with different category
+    ],
+)
+def test_mask_soft_non_max_suppression(
+    predictions: np.ndarray,
+    masks: np.ndarray,
+    sigma: float,
+    expected_result: Optional[np.ndarray],
+    exception: Exception,
+) -> None:
+    with exception:
+        result = mask_soft_non_max_suppression(
+            predictions=predictions,
+            masks=masks,
+            sigma=sigma,
+        )
+        np.testing.assert_almost_equal(result, expected_result, decimal=6)