Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ ENV PYTHON_VERSION=${PYTHON_VERSION}
ENV DEBIAN_FRONTEND=noninteractive

# Install basic dependencies
RUN apt-get update
RUN apt install -y build-essential manpages-dev wget zlib1g software-properties-common git libssl-dev zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev wget ca-certificates curl llvm libncurses5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev mecab-ipadic-utf8
RUN apt-get update && apt-get install -y build-essential manpages-dev wget zlib1g software-properties-common git libssl-dev zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev wget ca-certificates curl llvm libncurses5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev mecab-ipadic-utf8

# Install PyEnv and desired Python version
ENV HOME="/root"
Expand All @@ -34,8 +33,7 @@ RUN pyenv global ${PYTHON_VERSION}
# Install TensorRT + dependencies
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub
RUN add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /"
RUN apt-get update
RUN TENSORRT_MAJOR_VERSION=`echo ${TENSORRT_VERSION} | cut -d '.' -f 1` && \
RUN apt-get update && TENSORRT_MAJOR_VERSION=`echo ${TENSORRT_VERSION} | cut -d '.' -f 1` && \
apt-get install -y libnvinfer${TENSORRT_MAJOR_VERSION}=${TENSORRT_VERSION}.* \
libnvinfer-plugin${TENSORRT_MAJOR_VERSION}=${TENSORRT_VERSION}.* \
libnvinfer-dev=${TENSORRT_VERSION}.* \
Expand All @@ -55,9 +53,9 @@ FROM base as torch-tensorrt-builder-base
ARG ARCH="x86_64"
ARG TARGETARCH="amd64"

RUN apt-get update
RUN apt-get install -y python3-setuptools
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub
RUN apt-get update && \
apt-get install -y python3-setuptools && \
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub

RUN apt-get update &&\
apt-get install -y --no-install-recommends locales ninja-build &&\
Expand Down
80 changes: 37 additions & 43 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,53 +446,47 @@ def create_constant(
else:
shape = list(torch_value.shape)

if torch_value is not None:

if torch_value.dtype == torch.uint8:
if is_tensorrt_version_supported("10.8.0"):
if (
target_quantized_type is None
or target_quantized_type != trt.DataType.FP4
):
# Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8
raise ValueError(
"Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}"
)
shape[-1] = shape[-1] * 2
weights = to_trt_weights(
ctx,
torch_value,
name,
"CONSTANT",
"CONSTANT",
dtype=trt.DataType.FP4,
count=torch_value.numel() * 2,
)
constant = ctx.net.add_constant(
shape,
weights,
)
constant.name = name
return constant.get_output(0)
else:
if torch_value.dtype == torch.uint8:
if is_tensorrt_version_supported("10.8.0"):
if (
target_quantized_type is None
or target_quantized_type != trt.DataType.FP4
):
# Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8
raise ValueError(
"Currently FP4 is only supported in TensorRT 10.8.0 and above"
"Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}"
)
# Record the weight in ctx for refit and cpu memory reference
shape[-1] = shape[-1] * 2
weights = to_trt_weights(
ctx,
torch_value,
name,
"CONSTANT",
"CONSTANT",
dtype=trt.DataType.FP4,
count=torch_value.numel() * 2,
)
constant = ctx.net.add_constant(
shape,
weights,
)
constant.name = name
return constant.get_output(0)
else:
raise ValueError(
"Currently FP4 is only supported in TensorRT 10.8.0 and above"
)
# Record the weight in ctx for refit and cpu memory reference

# Convert the torch.Tensor to a trt.Weights object
trt_weights = to_trt_weights(ctx, torch_value, name, "CONSTANT", "CONSTANT")
constant = ctx.net.add_constant(
shape,
trt_weights,
)
constant.name = name
# Convert the torch.Tensor to a trt.Weights object
trt_weights = to_trt_weights(ctx, torch_value, name, "CONSTANT", "CONSTANT")
constant = ctx.net.add_constant(
shape,
trt_weights,
)
constant.name = name

return constant.get_output(0)
else:
raise ValueError(
f"Cannot convert tensor '{name}' to a TensorRT constant because its value is None."
)
return constant.get_output(0)


def get_trt_tensor(
Expand Down
27 changes: 15 additions & 12 deletions py/torch_tensorrt/dynamo/conversion/impl/deconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,16 @@ def deconvNd(
assert len(kernel_shape) > 0, "Deconvolution kernel shape must be non-empty"

# add deconv layer
if groups is not None:
num_output_maps = num_output_maps * groups
deconv_layer = ctx.net.add_deconvolution_nd(
input=input,
num_output_maps=num_output_maps * groups,
num_output_maps=num_output_maps,
kernel_shape=kernel_shape,
kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
)
assert deconv_layer is not None, "Deconvolution layer is None"
set_layer_name(deconv_layer, target, name, source_ir)

# If the weight is a TRTTensor, set it as an input of the layer
Expand Down Expand Up @@ -145,7 +148,6 @@ def deconvNd(
if output_padding is not None
else output_padding
)

# Set relevant attributes of deconvolution layer
if padding is not None:
deconv_layer.padding_nd = padding
Expand All @@ -156,19 +158,20 @@ def deconvNd(
if groups is not None:
deconv_layer.num_groups = groups

ndims = len(padding)
pre_padding_values = []
post_padding_values = []
if padding is not None:
ndims = len(padding)
pre_padding_values = []
post_padding_values = []

for dim in range(ndims):
pre_padding = padding[dim]
post_padding = padding[dim] - output_padding[dim]
for dim in range(ndims):
pre_padding = padding[dim]
post_padding = padding[dim] - output_padding[dim]

pre_padding_values.append(pre_padding)
post_padding_values.append(post_padding)
pre_padding_values.append(pre_padding)
post_padding_values.append(post_padding)

deconv_layer.pre_padding = tuple(pre_padding_values)
deconv_layer.post_padding = tuple(post_padding_values)
deconv_layer.pre_padding = tuple(pre_padding_values)
deconv_layer.post_padding = tuple(post_padding_values)

result = deconv_layer.get_output(0)

Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def reduce_operation_with_scatter(
scatter_tensor = initial_tensor
else:
# This case would not be encountered from torch itself
print("Invalid Operation for Reduce op!!")
raise ValueError(f"Invalid Operation for Reduce op: {self}")

operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
device = to_torch_device(scatter_tensor.device)
Expand Down
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,13 +826,13 @@ def get_output_metadata(
return [node.meta for node in nodes]


def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]:
def get_output_dtypes(output: Any, truncate_double: bool = False) -> List[dtype]:
output_dtypes = []
if isinstance(output, torch.fx.node.Node):
if "val" in output.meta:
output_meta = output.meta["val"]
if isinstance(output_meta, (FakeTensor, torch.Tensor)):
if truncate_doulbe and output_meta.dtype == torch.float64:
if truncate_double and output_meta.dtype == torch.float64:
output_dtypes.append(dtype.float32)
else:
output_dtypes.append(dtype._from(output_meta.dtype))
Expand Down
5 changes: 4 additions & 1 deletion py/torch_tensorrt/fx/converters/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,10 @@ def to_numpy(
"""
output = None

if value is None or isinstance(value, np.ndarray):
if value is None:
return None

elif isinstance(value, np.ndarray):
output = value

elif isinstance(value, torch.Tensor):
Expand Down
4 changes: 2 additions & 2 deletions py/torch_tensorrt/fx/tools/timing_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ def get_timing_cache_trt(self, timing_cache_file: str) -> bytearray:
return None

def update_timing_cache(
self, timing_cache_file: str, serilized_cache: bytearray
self, timing_cache_file: str, serialized_cache: bytearray
) -> None:
if not self.save_timing_cache:
return
timing_cache_file = self.get_file_full_name(timing_cache_file)
with open(timing_cache_file, "wb") as local_cache:
local_cache.seek(0)
local_cache.write(serilized_cache)
local_cache.write(serialized_cache)
local_cache.truncate()
Loading