Skip to content

Conversation

@LeiWang1999
Copy link
Member

This pull request introduces a new debugging feature for CUDA code generation and includes several related changes across multiple files. The most important changes include the addition of a new header file for debugging, the implementation of device-side debug printing templates, and the creation of tests for these new debugging features.

Debugging Feature Implementation:

  • src/target/codegen_cuda.cc: Added an include statement for the new debug.h header file.
  • src/tl_templates/cuda/debug.h: Implemented template functions for device-side debug printing, including specializations for various data types such as int, float, half, half_t, bfloat16_t, and double.

Testing:

Language Enhancements:

Utility Functions:

  • tilelang/language/print.py: Created a new module for debug printing macros and utilities, including functions for printing variables, buffers, and conditionally executing debug prints.

@LeiWang1999 LeiWang1999 changed the title [Debug] Introduce T.print to help us print buffer and variables on frontend [Debug] Introduce T.print to help print buffer and variables on frontend Jan 24, 2025
@LeiWang1999 LeiWang1999 changed the title [Debug] Introduce T.print to help print buffer and variables on frontend [Debug] Introduce T.print for buffer and variables logging on frontend Jan 24, 2025
@LeiWang1999
Copy link
Member Author

print Implementation and Usage Guide

Overview

The print function in the provided code is a utility designed to help with debugging in TileLang programs. It supports printing both TIR primitive expressions (tir.PrimExpr) and TIR buffers (tir.Buffer). This functionality is particularly useful in GPU kernels for inspecting variables and intermediate results during execution.

The print function ensures controlled output by restricting printing to specific threads, preventing excessive or redundant debug information in parallel execution contexts.


print Implementation

Function Signature

def print(obj: Any) -> tir.PrimExpr:

Supported Input Types

  1. tir.Buffer: A buffer whose elements can be printed.

    • Buffers must be flattened into a 1D array before printing.
    • Printing is restricted to the first thread (tx=0, ty=0, tz=0) to avoid redundant output from multiple threads.
  2. tir.PrimExpr: A scalar expression to print its value directly.

Behavior

  • Buffer Printing:

    • Uses the print_flat_buffer_with_condition macro to print elements of a flattened buffer when certain thread conditions are met.
  • Expression Printing:

    • Uses the print_var macro to print the value of a scalar expression.
  • Unsupported Types:

    • Raises a ValueError if the input object is neither a tir.Buffer nor a tir.PrimExpr.

Macros Used by print

  1. print_var(var: tir.PrimExpr)

    • Prints the value of a scalar primitive expression.
    • Example: T.print(42) prints the scalar value 42.
  2. print_flat_buffer_with_condition(condition, buffer, elems)

    • Prints the elements of a buffer when the specified condition is True.
    • Iterates over the buffer elements and calls an external print function for each element.

Example Usage

1. Printing a Buffer

The debug_print_buffer function demonstrates how to print a shared buffer.

def debug_print_buffer(M=16, N=16):
    dtype = "float16"

    @T.prim_func
    def program(Q: T.Buffer((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            # Allocate a shared buffer
            shared_buf = T.alloc_shared([M, N], dtype)
            
            # Print the buffer
            T.print(shared_buf)

    # Compile and execute the kernel
    jit_kernel = tilelang.JITKernel(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()

2. Conditional Printing of a Buffer

To restrict printing to a specific thread, use conditions. In this case, only the thread with (bx=0, by=0, bz=0) will print the buffer.

def debug_print_buffer_conditional(M=16, N=16):
    dtype = "float16"

    @T.prim_func
    def program(Q: T.Buffer((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            shared_buf = T.alloc_shared([M, N], dtype)

            # Print the buffer only if bx, by, bz are 0
            if bx == 0 and by == 0 and bz == 0:
                T.print(shared_buf)

    # Compile and execute the kernel
    jit_kernel = tilelang.JITKernel(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()

3. Printing a Scalar Expression Conditionally

To print a scalar expression, use the T.print function directly. For example, the following function prints the sum of bx, by, and bz but only for the thread with tid=0.

def debug_print_value_conditional(M=16, N=16):
    dtype = "float16"

    @T.prim_func
    def program(Q: T.Buffer((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            # Get the thread ID
            tid = T.get_thread_binding()
            
            # Print bx + by + bz only for thread tid=0
            if tid == 0:
                T.print(bx + by + bz)

    # Compile and execute the kernel
    jit_kernel = tilelang.JITKernel(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()

Important Notes

  1. Thread Safety:

    • Printing within GPU kernels should be restricted to specific threads to avoid excessive and redundant output, which can overwhelm the console.
    • Use conditions like if bx == 0 and by == 0 and bz == 0 to limit output.
  2. Flattened Buffers:

    • Buffers must be flattened to 1D before printing to ensure consistent output.
    • Ensure the buffer shape matches the expectation of T.print.
  3. Debugging with Profiler:

    • The profiler (profiler.run_once()) is used to execute and debug the kernel.
  4. Type Restrictions:

    • The print function only supports tir.Buffer and tir.PrimExpr. Passing unsupported types will raise an error.

How to Run Tests

The provided test cases can be executed directly to verify the functionality of print:

if __name__ == "__main__":
    tilelang.testing.main()

Test Functions

  • test_debug_print_buffer(): Tests printing of an entire shared buffer.
  • test_debug_print_buffer_conditional(): Tests conditional printing of a shared buffer.
  • test_debug_print_value_conditional(): Tests conditional printing of scalar expressions.

@LeiWang1999
Copy link
Member Author

LeiWang1999 commented Jan 24, 2025

We should also need to implement test case for register files.

  • Implement T.print for register files.

It's not simple for us to print fragment data.

@LeiWang1999 LeiWang1999 merged commit 5b48f8d into tile-ai:main Jan 24, 2025
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant