Skip to content

Commit

Permalink
fix shape bucket
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen8838 committed Jan 24, 2025
1 parent 13b8373 commit 9d0bb7b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
8 changes: 4 additions & 4 deletions src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -335,10 +335,10 @@ public static void ArgsChecker(Expr[] newArgs)
throw new InvalidOperationException("Args has tuple");
}

if (newArgs.ToHashSet().Count != newArgs.Length)
{
throw new InvalidOperationException("Has Repeat args");
}
// if (newArgs.ToHashSet().Count != newArgs.Length)
// {
// throw new InvalidOperationException("Has Repeat args");
// }
}

// clone origin Expr and Do replace for var
Expand Down
23 changes: 20 additions & 3 deletions src/Nncase.Passes/Rules/ShapeBucket/SplitLLMStage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,26 @@ protected override async Task<IRModule> RunCoreAsync(IRModule input, RunPassCont
input.Add(decode);
Expr newBody;
{
var kvShape = IR.F.Tensors.ShapeOf(entry.Parameters[3]); // %past_key_values: f32[24,2,1,?,2,64]
var kvLen = IR.F.Tensors.GetItem(kvShape, 3);
var cond = IR.F.Math.Equal(kvLen, 0L);
Expr? history_len = null;
for (int j = 0; j < entry.Parameters.Length; j++)
{
var paramVar = entry.Parameters[j];
var dimExprs = CompileOptions.ShapeBucketOptions.VarMap[paramVar];
for (int i = 0; i < dimExprs.Length; i++)
{
if (dimExprs[i] is Var { Name: "history_len" } && history_len is null)
{
history_len = IR.F.Tensors.GetItem(IR.F.Tensors.ShapeOf(paramVar), i);
}
}
}

if (history_len is null)
{
throw new NotSupportedException("Can't get the history len from the function inputs!");
}

var cond = IR.F.Math.Equal(history_len, 0L);
newBody = new IR.If(cond, prefill, decode, entry.Parameters.ToArray());
}

Expand Down

0 comments on commit 9d0bb7b

Please sign in to comment.