diff --git a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs index 8bfe31acb..005c7850f 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs @@ -335,10 +335,10 @@ public static void ArgsChecker(Expr[] newArgs) throw new InvalidOperationException("Args has tuple"); } - if (newArgs.ToHashSet().Count != newArgs.Length) - { - throw new InvalidOperationException("Has Repeat args"); - } + // if (newArgs.ToHashSet().Count != newArgs.Length) + // { + // throw new InvalidOperationException("Has Repeat args"); + // } } // clone origin Expr and Do replace for var diff --git a/src/Nncase.Passes/Rules/ShapeBucket/SplitLLMStage.cs b/src/Nncase.Passes/Rules/ShapeBucket/SplitLLMStage.cs index 2d2bce5b2..723915c69 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/SplitLLMStage.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/SplitLLMStage.cs @@ -41,9 +41,26 @@ protected override async Task RunCoreAsync(IRModule input, RunPassCont input.Add(decode); Expr newBody; { - var kvShape = IR.F.Tensors.ShapeOf(entry.Parameters[3]); // %past_key_values: f32[24,2,1,?,2,64] - var kvLen = IR.F.Tensors.GetItem(kvShape, 3); - var cond = IR.F.Math.Equal(kvLen, 0L); + Expr? history_len = null; + for (int j = 0; j < entry.Parameters.Length; j++) + { + var paramVar = entry.Parameters[j]; + var dimExprs = CompileOptions.ShapeBucketOptions.VarMap[paramVar]; + for (int i = 0; i < dimExprs.Length; i++) + { + if (dimExprs[i] is Var { Name: "history_len" } && history_len is null) + { + history_len = IR.F.Tensors.GetItem(IR.F.Tensors.ShapeOf(paramVar), i); + } + } + } + + if (history_len is null) + { + throw new NotSupportedException("Can't get the history len from the function inputs!"); + } + + var cond = IR.F.Math.Equal(history_len, 0L); newBody = new IR.If(cond, prefill, decode, entry.Parameters.ToArray()); }