Add float cast op for JIT backend #2511

Merged · 20 commits · Nov 21, 2024
28 changes: 17 additions & 11 deletions burn-book/src/advanced/backend-extension/custom-cubecl-kernel.md
@@ -171,8 +171,14 @@ impl<R: JitRuntime, F: FloatElement, I: IntElement> Backend for JitBackend<R, F,
.empty(shape_out.num_elements() * core::mem::size_of::<F>());

// Create the output tensor primitive.
let output =
JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);
// Create the output tensor primitive.
let output = JitTensor::new_contiguous(
lhs.client.clone(),
lhs.device.clone(),
shape_out,
buffer,
F::dtype(),
);

// Declare the wgsl workgroup with the number of cubes in x, y and z.
let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
@@ -186,10 +192,10 @@ impl<R: JitRuntime, F: FloatElement, I: IntElement> Backend for JitBackend<R, F,
&lhs.client,
cube_count,
cube_dim,
lhs.as_tensor_arg(1),
rhs.as_tensor_arg(1),
bias.as_tensor_arg(1),
output.as_tensor_arg(1),
lhs.as_tensor_arg::<F>(1),
rhs.as_tensor_arg::<F>(1),
bias.as_tensor_arg::<F>(1),
output.as_tensor_arg::<F>(1),
);

// Return the output tensor.
@@ -251,12 +257,12 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {

// Set our state.
let (lhs_state, rhs_state, output, shape_bias) = ops.state;
let lhs = checkpointer.retrieve_node_output(lhs_state);
let rhs = checkpointer.retrieve_node_output(rhs_state);
let lhs: FloatTensor<B> = checkpointer.retrieve_node_output(lhs_state);
let rhs: FloatTensor<B> = checkpointer.retrieve_node_output(rhs_state);

// Fetch shapes of our tensor to support broadcasting.
let shape_lhs = B::float_shape(&lhs);
let shape_rhs = B::float_shape(&rhs);
let shape_lhs = lhs.shape();
let shape_rhs = rhs.shape();

// Compute the gradient of the output using the already existing `relu_backward`
// function in the basic Burn backend trait.
@@ -314,7 +320,7 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
// compute bound operation.
let lhs_state = prep.checkpoint(&lhs);
let rhs_state = prep.checkpoint(&rhs);
let bias_shape = B::float_shape(&bias.primitive);
let bias_shape = bias.primitive.shape();

let output = B::fused_matmul_add_relu(
lhs.primitive.clone(),
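In the forward-pass hunks above, the output tensor is now constructed with an explicit element dtype, and each kernel argument names its element type. A condensed sketch of the updated pattern, assuming the chapter's surrounding bindings (`lhs`, `rhs`, `bias`, `shape_out`, `buffer`, `cube_count`, `cube_dim`) and the chapter's fused kernel, whose generated launch function is assumed to be `fused_matmul_add_relu::launch`:

```rust
// The output primitive now records its element dtype explicitly instead of
// relying on the byte buffer size alone.
let output = JitTensor::new_contiguous(
    lhs.client.clone(),
    lhs.device.clone(),
    shape_out,
    buffer,
    F::dtype(),
);

// Kernel arguments likewise carry the float element type via a turbofish.
fused_matmul_add_relu::launch::<F, R>(
    &lhs.client,
    cube_count,
    cube_dim,
    lhs.as_tensor_arg::<F>(1),
    rhs.as_tensor_arg::<F>(1),
    bias.as_tensor_arg::<F>(1),
    output.as_tensor_arg::<F>(1),
);
```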
21 changes: 13 additions & 8 deletions burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md
@@ -239,14 +239,19 @@ impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
.empty(shape_out.num_elements() * core::mem::size_of::<F>());

// Create the output tensor primitive.
let output =
JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);
let output = JitTensor::new_contiguous(
lhs.client.clone(),
lhs.device.clone(),
shape_out,
buffer,
F::dtype(),
);

// Create the kernel.
let kernel = FusedMatmulAddRelu::<F>::new(cube_dim);

// Build info buffer with tensor information needed by the kernel, such as shapes and strides.
let info = build_info(&[&lhs, &rhs, &output]);
let info = build_info::<_, F>(&[&lhs, &rhs, &output]);
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));

// Declare the wgsl workgroup with the number of cubes in x, y and z.
@@ -331,12 +336,12 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {

// Set our state.
let (lhs_state, rhs_state, output, shape_bias) = ops.state;
let lhs = checkpointer.retrieve_node_output(lhs_state);
let rhs = checkpointer.retrieve_node_output(rhs_state);
let lhs: FloatTensor<B> = checkpointer.retrieve_node_output(lhs_state);
let rhs: FloatTensor<B> = checkpointer.retrieve_node_output(rhs_state);

// Fetch shapes of our tensor to support broadcasting.
let shape_lhs = B::float_shape(&lhs);
let shape_rhs = B::float_shape(&rhs);
let shape_lhs = lhs.shape();
let shape_rhs = rhs.shape();

// Compute the gradient of the output using the already existing `relu_backward`
// function in the basic Burn backend trait.
@@ -392,7 +397,7 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
// during the backward pass. Here we choose to save it in the state because it's a compute bound operation.
let lhs_state = prep.checkpoint(&lhs);
let rhs_state = prep.checkpoint(&rhs);
let bias_shape = B::float_shape(&bias.primitive);
let bias_shape = bias.primitive.shape();

let output = B::fused_matmul_add_relu(
lhs.primitive.clone(),
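On the autodiff side, both book chapters change in the same way: the checkpointed forward tensors now need an explicit `FloatTensor<B>` annotation (nothing downstream pins their type anymore), and shapes are read from the primitives themselves rather than through `B::float_shape`. A condensed sketch, assuming the chapter's `checkpointer`, `lhs_state`, `rhs_state`, and `bias` bindings:

```rust
// Retrieve the checkpointed forward tensors; the target type must be spelled out.
let lhs: FloatTensor<B> = checkpointer.retrieve_node_output(lhs_state);
let rhs: FloatTensor<B> = checkpointer.retrieve_node_output(rhs_state);

// Shapes used for broadcasting come straight off the primitives.
let shape_lhs = lhs.shape();
let shape_rhs = rhs.shape();

// The forward preparation follows the same pattern for the bias shape.
let bias_shape = bias.primitive.shape();
```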
4 changes: 2 additions & 2 deletions crates/burn-autodiff/src/grads.rs
@@ -1,4 +1,4 @@
use burn_tensor::{backend::Backend, container::TensorContainer, ops::FloatTensor};
use burn_tensor::{backend::Backend, container::TensorContainer, ops::FloatTensor, TensorMetadata};

use crate::{
graph::{NodeRef, Requirement},
@@ -22,7 +22,7 @@ impl Gradients {
};
gradients.register::<B>(
root_node.id,
B::float_ones(B::float_shape(&root_tensor), &B::float_device(&root_tensor)),
B::float_ones(root_tensor.shape(), &B::float_device(&root_tensor)),
);
gradients
}
4 changes: 2 additions & 2 deletions crates/burn-autodiff/src/ops/base.rs
@@ -10,7 +10,7 @@ use crate::{
graph::{ComputingProperty, NodeID, NodeRef, Requirement, Step},
tensor::AutodiffTensor,
};
use burn_tensor::{backend::Backend, ops::FloatTensor, Shape};
use burn_tensor::{backend::Backend, ops::FloatTensor, Shape, TensorMetadata};
use std::marker::PhantomData;

/// Operation in preparation.
@@ -292,7 +292,7 @@ impl<const N: usize> Step for UntrackedOpsStep<N> {
/// If broadcasting happened during the forward pass, the gradients will be sum along the
/// broadcasted dimension.
pub fn broadcast_shape<B: Backend>(mut grad: FloatTensor<B>, shape: &Shape) -> FloatTensor<B> {
let shape_grad = B::float_shape(&grad);
let shape_grad = grad.shape();
let ndims = shape_grad.num_dims();

for i in 0..ndims {
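The doc comment above describes what `broadcast_shape` does; only its shape lookup changes in this PR (`grad.shape()` instead of `B::float_shape(&grad)`). The loop body is collapsed in the diff, so the version below is a simplified guess at the reduction, shown only to make the comment concrete, not the crate's actual code:

```rust
use burn_tensor::{backend::Backend, ops::FloatTensor, Shape, TensorMetadata};

/// Sum the gradient along every dimension that was broadcast in the forward pass,
/// so that it matches `shape` again (sum_dim keeps the reduced dimension as size 1).
pub fn broadcast_shape<B: Backend>(mut grad: FloatTensor<B>, shape: &Shape) -> FloatTensor<B> {
    let shape_grad = grad.shape();
    let ndims = shape_grad.num_dims();

    for i in 0..ndims {
        if shape_grad.dims[i] != shape.dims[i] {
            // The forward pass broadcast a size-1 dimension; collapse the gradient back.
            grad = B::float_sum_dim(grad, i);
        }
    }

    grad
}
```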
4 changes: 0 additions & 4 deletions crates/burn-autodiff/src/ops/bool_tensor.rs
@@ -11,10 +11,6 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
B::bool_from_data(data, device)
}

fn bool_shape(tensor: &BoolTensor<B>) -> Shape {
B::bool_shape(tensor)
}

async fn bool_into_data(tensor: BoolTensor<B>) -> TensorData {
B::bool_into_data(tensor).await
}
4 changes: 0 additions & 4 deletions crates/burn-autodiff/src/ops/int_tensor.rs
@@ -11,10 +11,6 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_from_data(data, device)
}

fn int_shape(tensor: &IntTensor<B>) -> Shape {
B::int_shape(tensor)
}

async fn int_into_data(tensor: IntTensor<B>) -> TensorData {
B::int_into_data(tensor).await
}
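These per-kind `bool_shape` / `int_shape` passthroughs can be dropped because the shape is now read directly from the primitive through `burn_tensor::TensorMetadata`, the trait imported in the earlier hunks. A hypothetical helper, not part of the PR, just to illustrate that one piece of code now serves every tensor kind; it assumes the trait also exposes a `dtype()` accessor alongside the `shape()` method used throughout this diff:

```rust
use burn_tensor::TensorMetadata;

// Works for float, int, and bool primitives alike, since metadata lives on the
// primitive itself rather than behind per-kind backend functions.
fn describe<T: TensorMetadata>(tensor: &T) -> String {
    format!("{:?} tensor with shape {:?}", tensor.dtype(), tensor.shape())
}
```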