Contributor

@ParamThakkar123 ParamThakkar123 commented Sep 11, 2025

Work in progress. Once complete, this fixes #42.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Sep 11, 2025
@ParamThakkar123 ParamThakkar123 marked this pull request as draft September 11, 2025 12:35
Contributor Author

ParamThakkar123 commented Oct 23, 2025

This PR adds training support to BackendBench, moving it from inference-only to supporting training as well. It adds the following:

  1. Prompts for LLMs to generate the backward pass of a kernel, via a new create_backward_prompt() function on KernelTemplate, TritonKernelTemplate, PyTorchKernelTemplate, and CuTeDSLKernelTemplate.
  2. A register_kernel() method on the OpRegistry class to register forward and backward pass kernels.
  3. A new file, train.py, with the following:
    BackendBench.train.TrainingTestCase: container for inputs, target, params, and an optional loss_fn.
    BackendBench.train.TrainingTestSuite: collection type for multiple test cases.
    BackendBench.train._mse_loss: default MSE loss, used when no loss_fn is provided.
    BackendBench.train._compute_numerical_grads: finite-difference numerical gradient calculator, used as a fallback.
    BackendBench.train.train_one_op: main training loop that:
      • Prepares inputs and params (device placement and gradient tracking).
      • Runs the forward pass via the provided kernel implementation.
      • Synthesizes a target via a reference operator if none is given (resolved through the op registry).
      • Attempts autograd gradients on the kernel, falling back to numerical gradients if autograd is unavailable.
      • Computes reference gradients via autograd on a reference op (if available).
      • Compares the gradients against a relative-error threshold and applies simple SGD updates to params or inputs.
      • Returns metrics: grad_correct, grad_rel_error, step_time_ms, final_loss.
  4. Tests for each of the new functions, using dummy kernel implementations.
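
For readers who want a feel for the gradient-check-then-update flow described in item 3, here is a minimal, dependency-free sketch. It is not the code from train.py: the PR works on tensors and autograd, whereas this sketch uses plain Python floats, and every name in it (numerical_grads, relative_error, sgd_step, the toy op y = w * x + b, the 1e-4 threshold, the learning rate) is an assumption chosen for illustration. Only the metric names grad_correct, grad_rel_error, and final_loss are borrowed from the description above.

```python
import math
from typing import Callable, List


def mse_loss(pred: List[float], target: List[float]) -> float:
    """Mean squared error, standing in for a default loss."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def numerical_grads(loss_of_params: Callable[[List[float]], float],
                    params: List[float], eps: float = 1e-5) -> List[float]:
    """Central finite differences: perturb one parameter at a time."""
    grads = []
    for i in range(len(params)):
        plus = list(params)
        minus = list(params)
        plus[i] += eps
        minus[i] -= eps
        grads.append((loss_of_params(plus) - loss_of_params(minus)) / (2 * eps))
    return grads


def relative_error(grads: List[float], ref_grads: List[float],
                   tiny: float = 1e-12) -> float:
    """L2 norm of the difference, divided by the L2 norm of the reference."""
    diff = math.sqrt(sum((g - r) ** 2 for g, r in zip(grads, ref_grads)))
    ref = math.sqrt(sum(r * r for r in ref_grads))
    return diff / (ref + tiny)


def sgd_step(params: List[float], grads: List[float], lr: float) -> List[float]:
    """Plain SGD update: p <- p - lr * g."""
    return [p - lr * g for p, g in zip(params, grads)]


# Hypothetical toy op y = w * x + b; the data is generated with w=2, b=1.
xs = [0.0, 1.0, 2.0, 3.0]
target = [1.0, 3.0, 5.0, 7.0]


def loss_of_params(params: List[float]) -> float:
    w, b = params
    return mse_loss([w * x + b for x in xs], target)


def analytic_grads(params: List[float]) -> List[float]:
    """Hand-derived MSE gradients, playing the role of the reference gradients."""
    w, b = params
    n = len(xs)
    residuals = [w * x + b - t for x, t in zip(xs, target)]
    dw = 2.0 / n * sum(r * x for r, x in zip(residuals, xs))
    db = 2.0 / n * sum(residuals)
    return [dw, db]


params = [0.0, 0.0]
initial_loss = loss_of_params(params)

# Gradient check: finite-difference gradients against the reference gradients.
num = numerical_grads(loss_of_params, params)
ref = analytic_grads(params)
grad_rel_error = relative_error(num, ref)
grad_correct = grad_rel_error < 1e-4  # threshold is an assumption

# A few SGD steps driven by the numerical gradients.
for _ in range(200):
    params = sgd_step(params, numerical_grads(loss_of_params, params), lr=0.05)

metrics = {
    "grad_correct": grad_correct,
    "grad_rel_error": grad_rel_error,
    "final_loss": loss_of_params(params),
}
```

Central differences cost two loss evaluations per parameter, which is why a finite-difference routine is reasonable as a fallback for small test cases but not as the primary gradient path.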

@ParamThakkar123 ParamThakkar123 marked this pull request as ready for review October 23, 2025 08:19
@ParamThakkar123 ParamThakkar123 changed the title [WIP] Adding training support Adding training support Oct 23, 2025

Development

Successfully merging this pull request may close these issues.

Training support