TensorRT-LLM LayerNorm layer with no weights error (Hackathon 2023) #92

Open
col-in-coding opened this issue Sep 19, 2023 · 1 comment

@col-in-coding

  • Environment
    • TensorRT 9.0.0.1
    • CUDA 12.0
    • Container used: registry.cn-hangzhou.aliyuncs.com/trt-hackathon/trt-hackathon:final_v1
    • NVIDIA driver version
  • Reproduction Steps
    • When elementwise_affine is set to False in LayerNorm, an error is raised.
    • Relevant code:
from tensorrt_llm.layers import LayerNorm
from tensorrt_llm.module import Module
from tensorrt_llm.functional import Tensor
from tensorrt_llm._utils import str_dtype_to_trt

class TestModel(Module):
    def __init__(self):
        super().__init__()
        dtype = str_dtype_to_trt('float32')
        self.dtype = dtype
        # elementwise_affine=False means the layer has no weight/bias parameters
        self.layernorm = LayerNorm(1280, dtype=dtype, elementwise_affine=False)

    def forward(self, inp):
        out = self.layernorm(inp)
        out.mark_output("output", self.dtype)
        return out

    def prepare_inputs(self):
        inp = Tensor(name="input", dtype=self.dtype, shape=[1, 64, 64, 1280])
        return (inp, )

llm_model = TestModel()
  • Expected Behavior
    • LayerNorm should not raise an error when it has no learnable weights.
  • Actual Behavior
Traceback (most recent call last):
  File "build.py", line 217, in <module>
    main(args)
  File "build.py", line 201, in main
    trt_llm_model(*inputs)
  File "/usr/local/lib/python3.8/dist-packages/tensorrt_llm/module.py", line 26, in __call__
    return self.forward(*args, **kwargs)
  File "/workspace/examples/sam/models/segment_anything/model.py", line 21, in forward
    out = self.layernorm(inp)
  File "/usr/local/lib/python3.8/dist-packages/tensorrt_llm/module.py", line 26, in __call__
    return self.forward(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/tensorrt_llm/layers/normalization.py", line 30, in forward
    return layer_norm(x, self.normalized_shape, weight, bias, self.eps)
  File "/usr/local/lib/python3.8/dist-packages/tensorrt_llm/functional.py", line 3152, in layer_norm
    input, weight = broadcast_helper(input, weight)
  File "/usr/local/lib/python3.8/dist-packages/tensorrt_llm/functional.py", line 1740, in broadcast_helper
    if left.rank() == right.rank():
AttributeError: 'NoneType' object has no attribute 'rank'
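
The traceback shows the root cause: with elementwise_affine=False, LayerNorm.forward passes weight=None (and bias=None) into functional.layer_norm, whose broadcast_helper(input, weight) then calls .rank() on None. Until the no-weight case is handled in the layer, one workaround sketch (assuming Parameter.value accepts a NumPy array, as in the usual TensorRT-LLM weight-loading flow) is to keep the default affine parameters and pin them to an identity transform, which computes the same result as elementwise_affine=False:

import numpy as np
from tensorrt_llm.layers import LayerNorm
from tensorrt_llm.module import Module
from tensorrt_llm._utils import str_dtype_to_trt

class TestModelWorkaround(Module):
    def __init__(self):
        super().__init__()
        self.dtype = str_dtype_to_trt('float32')
        # Keep the default elementwise_affine=True so weight/bias exist ...
        self.layernorm = LayerNorm(1280, dtype=self.dtype)
        # ... and pin them to an identity transform (scale 1, shift 0).
        self.layernorm.weight.value = np.ones(1280, dtype=np.float32)
        self.layernorm.bias.value = np.zeros(1280, dtype=np.float32)

forward and prepare_inputs stay the same as in the failing TestModel; the only cost is two small constant tensors baked into the engine.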
@col-in-coding
Author

col-in-coding commented Sep 20, 2023

import time
from pathlib import Path
from tensorrt_llm.network import net_guard
from tensorrt_llm.builder import Builder
from tensorrt_llm.logger import logger
from tensorrt_llm.layers import LayerNorm
from tensorrt_llm.module import Module
from tensorrt_llm.functional import Tensor
from tensorrt_llm._utils import str_dtype_to_trt

logger.set_level("info")


class TestModel(Module):
    def __init__(self):
        super().__init__()
        dtype = str_dtype_to_trt('float32')
        self.dtype = dtype
        self.layernorm = LayerNorm(1280, dtype=dtype, elementwise_affine=False)

    def forward(self, inp):
        out = self.layernorm(inp)
        out.mark_output("output", self.dtype)
        return out

    def prepare_inputs(self):
        inp = Tensor(name="input", dtype=self.dtype, shape=[1, 64, 64, 1280])
        return (inp, )


def serialize_engine(engine, path):
    logger.info(f'Serializing engine to {path}...')
    tik = time.time()
    with open(path, 'wb') as f:
        f.write(bytearray(engine))
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Engine serialized. Total time: {t}')


if __name__ == "__main__":

    engine_dir = ""
    engine_name = "test.engine"
    dtype = "float32"
    engine_dir = Path(engine_dir)
    engine_path = engine_dir / engine_name
    # Build TRT network
    trt_llm_model = TestModel()

    # Module -> Network
    builder = Builder()
    builder_config = builder.create_builder_config(
        name="test",
        precision="float32",
        timing_cache=None,
        tensor_parallel=1,
        parallel_build=False,
    )
    network = builder.create_network()
    network.trt_network.name = engine_name

    with net_guard(network):
        # Prepare
        network.set_named_parameters(trt_llm_model.named_parameters())
        # Forward
        inputs = trt_llm_model.prepare_inputs()
        trt_llm_model(*inputs)

    # Network -> Engine
    # engine = None
    engine = builder.build_engine(network, builder_config)
    serialize_engine(engine, engine_path)
