diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py
index 5b00bfe2d..3a52cda73 100644
--- a/coremltools/converters/mil/frontend/torch/ops.py
+++ b/coremltools/converters/mil/frontend/torch/ops.py
@@ -2598,6 +2598,10 @@ def upsample_nearest2d(context, node):
 def tupleunpack(context, node):
     inputs = _get_inputs(context, node, expected=1)
     values = inputs[0]
+
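+    # A single-output unpack receives the bare value rather than a sequence;
+    # wrap it so the generic unpacking logic below treats it uniformly.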
+    if len(node.outputs) == 1:
+        values = [values]
+
     # Node input could have been turned into constant array in @tupleconstruct
     if not isinstance(values, tuple) and not isinstance(values, list):
         values = values.val
@@ -3097,8 +3101,11 @@ def index(context, node):
-    # For multiple index axes case, we now assume that all the index have equal shape
-    for index in valid_indices:
+    # For the multiple index axes case, broadcast all indices to a common shape.
+    for i, index in enumerate(valid_indices):
         if not is_compatible_symbolic_vector(index.shape, valid_indices[0].shape):
-            raise NotImplementedError("Broadcasable tensor index not supported.")
-
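+            # _broadcast_tensors returns both operands expanded to a common shape.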
+            broadcast_inputs = _broadcast_tensors([valid_indices[0], index])
+            valid_indices[0] = broadcast_inputs[0]
+            valid_indices[i] = broadcast_inputs[1]
+
     # First stack the index together
     indices_rank = valid_indices[0].rank
     indices = mb.stack(values=valid_indices, axis=indices_rank)
@@ -3398,6 +3405,18 @@ def _slice(context, node):
     context.add(res)
 
 
+def _num_splits_and_sizes(split_sizes):
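+    # Returns (num_splits, sizes); when the concrete sizes are unknown at
+    # conversion time, fresh symbols stand in for them.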
+    if split_sizes.sym_val is not None:
+        return len(split_sizes.sym_val), split_sizes.sym_val
+
+    if any_symbolic(split_sizes.shape):
+        raise ValueError("Unable to determine number of splits")
+
+    num_splits = split_sizes.shape[0]
+    sizes = [get_new_symbol() for _ in range(num_splits)]
+    return num_splits, sizes
+
+
 @register_torch_op(torch_alias=["split_with_sizes"])
 def split(context, node):
     inputs = _get_inputs(context, node, expected=3)
@@ -3425,6 +3444,14 @@ def split(context, node):
         else:
             partial_size = mb.mul(x=tmp, y=remainder)
             split_sizes = mb.concat(values=[whole_sizes, partial_size], axis=0)
+
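+    # mb.split is unnecessary when there is only one split; pass the input through.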
+    num_splits, sizes = _num_splits_and_sizes(split_sizes=split_sizes)
+    if num_splits == 1:
+        out = mb.identity(x=x, name=node.name)
+        context.add(out, node.name)
+        return
+
     res = mb.split(x=x, split_sizes=split_sizes, axis=dim, name=node.name)
     context.add(res, torch_name=node.name)
 
@@ -3482,6 +3509,13 @@ def to(context, node):
             "Received invalid arguments for PyTorch conversion of op {}".format(node)
         )
 
+    # If dtype is not given, it should be inferred from the input tensor's dtype.
+    # See https://pytorch.org/docs/stable/generated/torch.Tensor.to.html
+    if dtype is None:
+        out = mb.identity(x=_input, name=node.name)
+        context.add(out, node.name)
+        return
+
     torch_dtype = NUM_TO_TORCH_DTYPE[dtype]
     if isinstance(_input, Var) and _input.val is not None:
         _input = _input.val
@@ -3924,8 +3958,20 @@ def ceil(context, node):
 @register_torch_op
 def clamp(context, node):
     inputs = _get_inputs(context, node, expected=3)
-    min_val = inputs[1] if inputs[1] else _np.finfo(_np.float32).min
-    max_val = inputs[2] if inputs[2] else _np.finfo(_np.float32).max
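+    # mb.clip expects float bounds, so integer min/max values are cast to fp32.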
+    if not inputs[1]:
+        min_val = _np.finfo(_np.float32).min
+    else:
+        min_val = inputs[1]
+        if types.builtin_to_string(min_val.dtype).startswith('int'):
+            min_val = mb.cast(x=min_val, dtype='fp32')
+
+    if not inputs[2]:
+        max_val = _np.finfo(_np.float32).max
+    else:
+        max_val = inputs[2]
+        if types.builtin_to_string(max_val.dtype).startswith('int'):
+            max_val = mb.cast(x=max_val, dtype='fp32')
+
     context.add(mb.clip(x=inputs[0], alpha=min_val, beta=max_val, name=node.name))
 
 @register_torch_op
@@ -4074,7 +4120,7 @@ def is_floating_point(context, node):
     is_float = types.is_float(inputs[0].dtype)
     context.add(mb.const(val=is_float, name=node.name))
 
-@register_torch_op()
+@register_torch_op(torch_alias=["__and_", "__and__"])
 def logical_and(context, node):
     inputs = _get_inputs(context, node, expected=2)
     x, y = inputs
@@ -4253,6 +4299,11 @@ def _make_tensor(list_of_tensor, name, rank):
         context.add(mb.identity(x=val, name=node.name))
         return
 
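+    # Wrap the scalar value in a rank-1 const and return early.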
+    if inputs[2] is None:
+        res = mb.const(val=[val.val], name=node.name)
+        context.add(res, torch_name=node.name)
+        return
+
     # Case 2: Create a tensor filled with a single value
     val = val.val # element val to fill
     msg_prefix = 'torch::tensor {} '.format(node.name)
@@ -4483,7 +4534,6 @@ def _scatter(context, inputs, mode, name):
                                    axis=axis, mode=mode, name=name)
     context.add(result)
 
 @register_torch_op
 def scatter(context, node):
     inputs = _get_inputs(context, node)
@@ -4501,8 +4551,106 @@ def scatter(context, node):
 
     _scatter(context, inputs, mode, node.name)
 
 @register_torch_op
 def scatter_add(context, node):
     inputs = _get_inputs(context, node)
     _scatter(context, inputs, 'add', node.name)
+
+
+@register_torch_op
+def roi_align(context, node):
+    inputs = _get_inputs(context, node)
+
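+    # torchvision::roi_align(input, rois, spatial_scale, pooled_height,
+    # pooled_width, sampling_ratio, aligned)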
+    x = context[node.inputs[0]]
+    input_shape = x.shape  # (B, C, h_in, w_in)
+    if len(input_shape) != 4:
+        raise ValueError(
+            '"CropResize" op: expected input rank 4, got {}'.format(x.rank)
+        )
+
+    const_box_info = True
+    if context[node.inputs[1]].val is None or context[node.inputs[2]].val is None:
+        const_box_info = False
+
+    spatial_scale = context[node.inputs[2]].val
+
+    # Core ML expects the batch index to travel together with the box coordinates
+    if const_box_info:
+        boxes = context[node.inputs[1]].val
+        # CoreML expects boxes/ROI in
+        # [N, 1, 5, 1, 1] format
+        boxes = boxes.reshape(boxes.shape[0], 1, boxes.shape[1], 1, 1)
+    else:
+        boxes = inputs[1]
+        boxes = mb.reshape(x=boxes, shape=[boxes.shape[0], 1, boxes.shape[1], 1, 1])
+    # Get Height and Width of crop
+    h_out = inputs[3]
+    w_out = inputs[4]
+
+    # Torch input format: [B, C, h_in, w_in]
+    # CoreML input format: [B, C, h_in, w_in]
+
+    # Crop Resize
+    x = mb.crop_resize(
+        x=x,
+        roi=boxes,
+        target_height=h_out.val,
+        target_width=w_out.val,
+        normalized_coordinates=True,
+        spatial_scale=spatial_scale,
+        box_coordinate_mode="CORNERS_HEIGHT_FIRST",
+        sampling_mode='OFFSET_CORNERS',
+    )
+
+    # CoreML output format: [N, 1, C, h_out, w_out]
+    # Torch output format: [N, C, h_out, w_out]
+    x = mb.squeeze(x=x, axes=[1])
+
+    context.add(x, torch_name=node.outputs[0])
+
+
+@register_torch_op
+def numel(context, node):
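+    # numel is the number of elements, i.e. the product of the tensor's shape.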
+    inputs = _get_inputs(context, node, expected=1)
+    x = mb.shape(x=inputs[0])
+    x = mb.reduce_prod(x=x, axes=[0], name=node.name)
+    context.add(x, torch_name=node.outputs[0])
+
+
+@register_torch_op
+def nms(context, node):
+    inputs = _get_inputs(context, node)
+    boxes = inputs[0]
+
+    num_boxes = boxes.shape[0]
+    max_boxes = num_boxes  # allow every input box to be kept
+
+    scores = inputs[1]
+    iou_threshold = inputs[2]
+    boxes = mb.expand_dims(x=boxes, axes=[0])
+    scores = mb.expand_dims(x=scores, axes=[0, -1])
+
+    # Match TensorFlow's default score_threshold. Core ML has no float('-inf')
+    # support, so use the lowest finite float32 value instead.
+    score_threshold = -3.4e38
+
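+    # non_maximum_suppression returns (boxes, scores, indices, num_valid);
+    # torch's nms only returns the kept indices.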
+    _, _, x, _ = mb.non_maximum_suppression(
+        boxes=boxes,
+        scores=scores,
+        iou_threshold=iou_threshold,
+        score_threshold=score_threshold,
+        max_boxes=max_boxes
+    )
+
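+    # Squeeze out the batch dim; with a static box count also slice to max_boxes
+    # so the output shape stays static.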
+    if not is_symbolic(num_boxes):
+        x = mb.squeeze(x=x, axes=[0])
+        x = mb.slice_by_index(x=x, begin=[0], end=[max_boxes], name=node.name)
+    else:
+        x = mb.squeeze(x=x, axes=[0], name=node.name)
+    context.add(x, torch_name=node.name)
+
+
+@register_torch_op
+def narrow(context, node):
+    data, dim, start, length = _get_inputs(context, node, expected=4)
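+    # narrow(dim, start, length) is a static slice over [start, start + length)
+    # along a single axis.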
+    data_shape = mb.shape(x=data).val
+    begin = [0] * len(data_shape)
+    end = list(data_shape)
+    begin[dim.val] = start.val
+    end[dim.val] = start.val + length.val
+    out = mb.slice_by_index(x=data, begin=begin, end=end, name=node.name)
+    context.add(out, torch_name=node.name)
diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
index 95e6690b6..8ebb78982 100644
--- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
+++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
@@ -10,6 +10,8 @@
 import pytest
 import torch.nn as nn
 
+import torchvision
+
 from .testing_utils import (
     contains_op,
     generate_input_data,
@@ -4564,3 +4566,76 @@ def forward(self, x):
             backend=backend,
             converter_input_type=converter_input_type,
         )
+
+
+class TestNumel(TorchBaseTest):
+    @pytest.mark.parametrize(
+        "shapes, backend",
+        itertools.product(
+            [
+                [(2, 1)],
+                [(5, 1, 4, 1)],
+                [(1,)],
+            ],
+            backends
+        ),
+    )
+    def test_numel(self, shapes, backend):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                v = torch.numel(x)
+                return torch.tensor(v)
+
+        model = Model()
+        self.run_compare_torch(shapes, model, backend=backend)
+
+
+class TestNarrow(TorchBaseTest):
+    @pytest.mark.parametrize(
+        "shapes, dim_start_length, backend",
+        itertools.product(
+            [
+                [(3, 3)],
+            ],
+            [
+                (0, 0, 2)
+            ]
+            ,
+            backends
+        ),
+    )
+    def test_narrow(self, shapes, dim_start_length, backend):
+        dim, start, length = dim_start_length
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.narrow(x, dim, start, length)
+
+        model = Model()
+        self.run_compare_torch(shapes, model, backend=backend)
+
+
+class TestNonMaximumSuppression(TorchBaseTest):
+    @pytest.mark.parametrize(
+        "shapes, scores, backend",
+        itertools.product(
+            [[(2, 4)]],
+            [(2,)],
+            backends
+        ),
+    )
+    def test_non_maximum_suppression(self, shapes, score_shape, backend):
+        scores = torch.rand(score_shape)
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torchvision.ops.nms(x, scores, iou_threshold=0.7)
+
+        model = Model()
+        self.run_compare_torch(shapes, model, backend=backend)