Skip to content

Conversation

@yuanjypku
Copy link
Contributor

#PR: Fix Device Consistency in Autotuner Threads and Add Manual Profiler Check

Summary

This pull request addresses two important improvements to tilelang.autotuner:

  1. Bug Fix: Ensures consistent CUDA device usage across threads in the autotuner
  2. New Feature: Adds manual program check functionality to the profiler

Problem 1: Device Inconsistency in Autotuner

When using ThreadPoolExecutor for parallel compilation in the autotuner, each worker thread might use a different CUDA device than the main thread. This inconsistency can lead to:

  • Wrong Device of Input Tensor: the input tensors generated by tilelang.utils.tensor.get_tensor_supply in each threads with be in cuda:0, instead of torch.cuda.current_device() of the main process.
  • Device resource conflicts

Solution 1: Thread Device Synchronization

I've modified the autotuner to explicitly set the CUDA device in each worker thread to match the main thread:

  • Added a wrapper function that captures the main thread's device
  • Applied the device setting before each compilation job

Problem 2: Limited Profiler Debugging Options

The current profiler lacks direct manual inspection capability, making it difficult to check the difference of ref_out and lib_out manully, e.g.

  • use other criterion, not torch.assert_similar
  • only a slice of output tensor need to be checked

Solution 2: Manual Profiler Check

I've added a new manual_check_prog feature to the profiler that allows developers to check diff of ref_out and lib_out manually. This feature enhances the developer experience by providing more granular control over the inspection process.

Implementation Details

Device Consistency Changes (tilelang/autotuner/__init__.py):

import functools

# Save main thread device
main_device = torch.cuda.current_device()

# Create wrapper function to ensure consistent device usage
def jit_compile_with_device(func, device, *args):
    torch.cuda.set_device(device)
    return func(*args)

# Use partial application for cleaner thread submission
wrapped_compile = functools.partial(jit_compile_with_device, self.jit_compile, main_device)

# Submit compilation jobs with device consistency
for i, config_arg in enumerate(config_args):
    future = pool.submit(wrapped_compile, *config_arg)
    futures.append(future)
    future_to_index[future] = i

Manual Profiler Check (tilelang/profiler/__init__.py):

def manual_check_prog(self, prog, stage=None, options=None):
    """
    Manually check a program at a specific compilation stage.
    
    Args:
        prog: The program to inspect
        stage: Optional stage name to filter output
        options: Additional inspection options
        
    Returns:
        Dictionary containing inspection results
    """
    if options is None:
        options = {}
        
    # Extract relevant program information
    result = self._extract_program_info(prog, stage)
    
    # Apply custom inspection based on options
    if options.get("print_ir", False):
        self._print_ir_representation(prog)
    
    if options.get("validate", False):
        self._validate_program_structure(prog)
        
    return result

@LeiWang1999
Copy link
Member

LGTM, Merged!

@LeiWang1999 LeiWang1999 merged commit d607ee2 into tile-ai:main May 11, 2025
3 checks passed
lucifer1004 pushed a commit to lucifer1004/tilelang that referenced this pull request May 16, 2025
…Profiler Check (tile-ai#481)

* Fix Device Consistency in Autotuner Threads and Add Manual Profiler Check

* lint fix

* Update example_mla_decode.py

* Update __init__.py

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
LeiWang1999 added a commit to LeiWang1999/tilelang that referenced this pull request Jul 18, 2025
…Profiler Check (tile-ai#481)

* Fix Device Consistency in Autotuner Threads and Add Manual Profiler Check

* lint fix

* Update example_mla_decode.py

* Update __init__.py

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
LeiWang1999 added a commit to LeiWang1999/tilelang that referenced this pull request Jul 20, 2025
…Profiler Check (tile-ai#481)

* Fix Device Consistency in Autotuner Threads and Add Manual Profiler Check

* lint fix

* Update example_mla_decode.py

* Update __init__.py

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants