Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RPE metal #2049

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Implement batch dimension transposition.
almaudoh committed Dec 4, 2024
commit d0bf5965c0f67c1e43bfcfa302aa40f9ca2d8533
83 changes: 41 additions & 42 deletions src/neural/metal/mps/NetworkGraph.mm
Original file line number Diff line number Diff line change
@@ -784,75 +784,74 @@ -(nonnull MPSGraphTensor *) relativePositionEncodingWithTensor:(MPSGraphTensor *
name:[NSString stringWithFormat:@"%@/reshape", label]];

// Permutations to implement einsum.
// First permute rpeTensor to get D to dimension 3, then expand.
// First permute rpeTensor to get D to dimension 3.
if (type == 0) {
// RPE-Q
// rpe: [D, H, Q, K] -> [H, Q, D, K]
rpeTensor = [self transposeTensor:rpeTensor dimension:0 withDimension:1 name:[NSString stringWithFormat:@"%@/transpose_1", label]];
rpeTensor = [self transposeTensor:rpeTensor dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_2", label]];
// Reshape rpe for the matmul.
rpeTensor = [self reshapeTensor:rpeTensor
withShape:@[@(heads * queries), @(depth), @(keys)]
name:[NSString stringWithFormat:@"%@/reshape_1", label]];
} else if (type == 1) {
// RPE-K
// rpe: [D, H, Q, K] -> [H, K, D, Q]
rpeTensor = [self transposeTensor:rpeTensor dimension:2 withDimension:3 name:[NSString stringWithFormat:@"%@/transpose_1", label]];
rpeTensor = [self transposeTensor:rpeTensor dimension:0 withDimension:1 name:[NSString stringWithFormat:@"%@/transpose_2", label]];
rpeTensor = [self transposeTensor:rpeTensor dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_3", label]];
// Reshape rpe for the matmul.
rpeTensor = [self reshapeTensor:rpeTensor
withShape:@[@(heads * keys), @(depth), @(queries)]
name:[NSString stringWithFormat:@"%@/reshape_1", label]];
} else if (type == 2) {
// RPE-V
// rpe: [D, H, Q, K] -> [H, Q, K, D]
rpeTensor = [self transposeTensor:rpeTensor dimension:0 withDimension:1 name:[NSString stringWithFormat:@"%@/transpose_1", label]];
rpeTensor = [self transposeTensor:rpeTensor dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_2", label]];
rpeTensor = [self transposeTensor:rpeTensor dimension:2 withDimension:3 name:[NSString stringWithFormat:@"%@/transpose_3", label]];
// Reshape rpe for the matmul.
rpeTensor = [self reshapeTensor:rpeTensor
withShape:@[@(heads * queries), @(keys), @(depth)]
name:[NSString stringWithFormat:@"%@/reshape_1", label]];
}

// Second transpose Nabc -> abNc to allow abNc × abcd -> abNd, where N is the batch dimension.
// x: [B, H, Q, D] -> [H, Q, B, D] # RPE-Q
// x: [B, H, K, D] -> [H, K, B, D] # RPE-K
// x: [B, H, Q, K] -> [H, Q, B, K] # RPE-V
tensor = [self transposeTensor:tensor dimension:0 withDimension:1 name:[NSString stringWithFormat:@"%@/a_transpose_1", label]];
tensor = [self transposeTensor:tensor dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/a_transpose_2", label]];
if (type == 2) {
tensor = [self reshapeTensor:tensor
withShape:@[@(heads * queries), @(-1), @(keys)]
name:[NSString stringWithFormat:@"%@/reshape_2", label]];
} else {
tensor = [self reshapeTensor:tensor
withShape:@[@(heads * queries), @(-1), @(depth)]
name:[NSString stringWithFormat:@"%@/reshape_2", label]];
}
// Expand dimension of RPE tensor.
// rpe: [H, Q, D, K] -> [1, H, Q, D, K] # RPE-Q
// rpe: [H, K, D, Q] -> [1, H, K, D, Q] # RPE-K
// rpe: [H, Q, K, D] -> [1, H, Q, K, D] # RPE-V
// rpeTensor = [self expandDimsOfTensor:rpeTensor axis:0 name:[NSString stringWithFormat:@"%@/expand_dims", label]];

// Second broadcast rpeTensor to match the batch size. Fuse with above expand dims.
// rpe_permuted_Q: [1, H, Q, D, K] -> [B, H, Q, D, K] -> [BH, Q, D, K]
// rpe_permuted_K: [1, H, K, D, Q] -> [B, H, K, D, Q] -> [BH, K, D, Q]
// rpe_permuted_V: [1, H, Q, K, D] -> [B, H, Q, K, D] -> [BH, Q, K, D]
// Rather than broadcast, reshape to allow proper auto-broadcast.
// rpeTensor = [self reshapeTensor:rpeTensor
// withShape:@[@(-1), rpeTensor.shape[2], rpeTensor.shape[3], rpeTensor.shape[4]]
// name:[NSString stringWithFormat:@"%@/rpe/reshape_1", label]];
// NSArray * toShape = type == 0 ? @[tensor.shape[0], @(heads), @(queries), @(depth), @(keys)]
// : (type == 1 ? @[tensor.shape[0], @(heads), @(keys), @(depth), @(queries)]
// : @[tensor.shape[0], @(heads), @(queries), @(keys), @(depth)]);
//
// rpeTensor = [self broadcastTensor:rpeTensor
// toShape:toShape
// name:[NSString stringWithFormat:@"%@/rpe/broadcast", label]];

// Third add a singleton dimension to input tensor for K.
// Rather than a dimension add, a reshape is needed to allow proper broadcasting.
// x: [B, H, Q, D] -> [B, H, Q, 1, D] -> [BH, Q, 1, D] # RPE-Q
// x: [B, H, K, D] -> [B, H, K, 1, D] -> [BH, K, 1, D] # RPE-K
// x: [B, H, Q, K] -> [B, H, Q, 1, K] -> [BH, Q, 1, K] # RPE-V
// tensor = [self expandDimsOfTensor:tensor axis:3 name:[NSString stringWithFormat:@"%@/rpe/expand_dims_2", label]];
tensor = [self reshapeTensor:tensor
withShape:@[@(-1), tensor.shape[2], @(1), tensor.shape[3]]
name:[NSString stringWithFormat:@"%@/rpe/reshape_2", label]];

// Finally matrix multiplication and squeeze.
// x: [H, Q, B, D] x [H, Q, D, K] -> [H, Q, B, K] # RPE-Q
// x: [H, K, B, D] x [H, K, D, Q] -> [H, K, B, Q] # RPE-K
// x: [H, Q, B, K] x [H, Q, K, D] -> [H, Q, B, D] # RPE-V
tensor = [self matrixMultiplicationWithPrimaryTensor:tensor
secondaryTensor:rpeTensor
name:[NSString stringWithFormat:@"%@/rpe/matmul", label]];

// x: [B, H, Q, 1, K] -> [B, H, Q, K] # RPE-Q
// x: [B, H, K, 1, Q] -> [B, H, K, Q] # RPE-K
// x: [B, H, Q, 1, D] -> [B, H, Q, D] # RPE-V
// tensor = [self squeezeTensor:tensor
// axis:3
// name:[NSString stringWithFormat:@"%@/rpe/squeeze", label]];
tensor = [self reshapeTensor:tensor withShape:@[@(-1), @(heads), tensor.shape[1], tensor.shape[3]] name:[NSString stringWithFormat:@"%@/rpe/reshape_3", label]];


// Reverse the last reshape and transposition.
NSUInteger dim = type == 2 ? depth : keys;
tensor = [self reshapeTensor:tensor withShape:@[@(heads), @(queries), @(-1), @(dim)] name:[NSString stringWithFormat:@"%@/reshape_3", label]];
tensor = [self transposeTensor:tensor dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/a_transpose_4", label]];
tensor = [self transposeTensor:tensor dimension:0 withDimension:1 name:[NSString stringWithFormat:@"%@/a_transpose_5", label]];


if (type == 1) {
// RPE-K needs to transpose back.
// RPE-K needs another transposition back to BHQK.
// x: [B, H, K, Q] -> [B, H, Q, K] # RPE-K
return [self transposeTensor:tensor dimension:2 withDimension:3 name:[NSString stringWithFormat:@"%@/rpe/transpose_3", label]];
return [self transposeTensor:tensor dimension:2 withDimension:3 name:[NSString stringWithFormat:@"%@/rpe/transpose_6", label]];
}

// x: [B, H, Q, K] # RPE-Q or RPE-K
8 changes: 4 additions & 4 deletions src/neural/metal/network_metal.cc
Original file line number Diff line number Diff line change
@@ -203,10 +203,10 @@ void MetalNetwork::forwardEval(InputsOutputs* io, int batchSize) {
// The next thread can start using the GPU now.
lock_.unlock();

int start = 0;
for (auto i = 0; i < 16 * 4096; i++) {
CERR << i + start << ";" << io->op_policy_raw_mem_[i + start];
}
// int start = 0;
// for (auto i = 0; i < 16 * 4096; i++) {
// CERR << i + start << ";" << io->op_policy_raw_mem_[i + start];
// }

if (attn_policy_) {
// Promotion offset calculation.