Skip to content

Commit

Permalink
feat(compiler): enhance afterStreaming execution flow
Browse files Browse the repository at this point in the history
Store and return result of afterStreaming execution, update tests to reflect changes.
  • Loading branch information
phodal committed Jul 22, 2024
1 parent ec1d02b commit c2b7227
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class PostCodeHandleContext(
*/
val editor: Editor? = null,

val lastTaskOutput: String? = null,
var lastTaskOutput: String? = null,

var compiledVariables: Map<String, Any> = mapOf(),
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,10 @@ open class HobbitHole(
myProject: Project,
console: ConsoleView?,
context: PostCodeHandleContext,
) {
afterStreaming?.execute(myProject, context, this)
): Any? {
val result = afterStreaming?.execute(myProject, context, this)
context.lastTaskOutput = result as? String
return result
}

companion object {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package com.phodal.shirelang.compiler.hobbit.ast

import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.project.Project
import com.phodal.shirecore.middleware.PostCodeHandleContext
import com.phodal.shirelang.compiler.hobbit.HobbitHole
import com.phodal.shirelang.compiler.hobbit.execute.FunctionStatementProcessor
import com.phodal.shirelang.compiler.patternaction.PatternActionFunc


data class TaskRoutesContext(
Expand Down Expand Up @@ -40,7 +42,7 @@ data class TaskRoutes(
myProject: Project,
context: PostCodeHandleContext,
hobbitHole: HobbitHole,
): List<Case> {
): Any? {
val conditionResult = mutableMapOf<String, Any?>()
val variableTable = mutableMapOf<String, Any?>()

Expand Down Expand Up @@ -71,20 +73,22 @@ data class TaskRoutes(
}
}

var result: Any? = null
if (matchedCase.isEmpty()) {
((defaultTask as? Task.Default)?.expression?.value as? Statement)?.let {
val result = ((defaultTask as? Task.Default)?.expression?.value as? Statement)?.let {
processor.execute(it, variableTable)
}

return emptyList()
logger<TaskRoutes>().info("no matched case, execute default task: $result")
return result
}

matchedCase.forEach {
val statement = (it.valueExpression as Task.CustomTask).expression?.value as Statement
processor.execute(statement, variableTable)
result = processor.execute(statement, variableTable)
}

return matchedCase
return result
}

companion object {
Expand All @@ -111,7 +115,7 @@ data class TaskRoutes(
* @param conditionCase The [ConditionCase] object to transform. This object contains conditions and cases that determine routing logic.
* @return A [TaskRoutes] object that encapsulates the transformed conditions and cases, along with a default task if specified.
*/
fun transformConditionCasesToRoutes(conditionCase: ConditionCase): TaskRoutes {
private fun transformConditionCasesToRoutes(conditionCase: ConditionCase): TaskRoutes {
val conditions: List<Condition> = conditionCase.conditions.map {
val caseKeyValue = it.value as CaseKeyValue

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ open class PatternFuncProcessor(open val myProject: Project, open val hole: Hobb
*
* @param action This is an instance of `PatternActionFunc` which is a sealed class. The function behavior changes based on the type of `PatternActionFunc`.
* @param input This is a generic parameter which can be of any type. It is used in the `PatternActionFunc.Cat` case.
* @param lastResult This is a generic parameter which can be of any type. It is used in all cases except `PatternActionFunc.Prompt`, `PatternActionFunc.Cat`, `PatternActionFunc.Print` and `PatternActionFunc.Xargs`.
* @param lastResult This is a generic parameter which can be of any type. It is used in all cases except `PatternActionFunc.Prompt`, `PatternActionFunc.Cat`, `PatternActionFunc.`Print`` and `PatternActionFunc.Xargs`.
*
* @return The return type is `Any`. The actual return type depends on the type of `PatternActionFunc`. For example, if `PatternActionFunc` is `Prompt`, it returns a `String`. If `PatternActionFunc` is `Grep`, it returns a `String` joined by "\n" from an `Array` or `String` that contains the specified patterns. If `PatternActionFunc` is `Sed`, it returns a `String` joined by "\n" from an `Array` or `String` where the specified pattern has been replaced. If `PatternActionFunc` is `Sort`, it returns a sorted `String` joined by "\n" from an `Array` or `String`. If `PatternActionFunc` is `Uniq`, it returns a `String` joined by "\n" from an `Array` or `String` with distinct elements. If `PatternActionFunc` is `Head`, it returns a `String` joined by "\n" from the first 'n' elements of an `Array` or `String`. If `PatternActionFunc` is `Tail`, it returns a `String` joined by "\n" from the last 'n' elements of an `Array` or `String`. If `PatternActionFunc` is `Cat`, it executes the `executeCatFunc` function. If `PatternActionFunc` is `Print`, it returns a `String` joined by "\n" from the `texts` property of `Print`. If `PatternActionFunc` is `Xargs`, it returns the `variables` property of `Xargs`. If `PatternActionFunc` is `UserCustom`, it logs an error message. If `PatternActionFunc` is of an unknown type, it logs an error message and returns an empty `String`.
*/
Expand Down Expand Up @@ -129,7 +129,16 @@ open class PatternFuncProcessor(open val myProject: Project, open val hole: Hobb
}

is PatternActionFunc.Print -> {
action.texts.joinToString("\n")
action.texts.map {
// load from variable table
// if (it.startsWith("\$")) {
// val variable = it.substring(1)
// hole.variables[variable]?.toString() ?: ""
// } else {
// it
// }
it
}.joinToString("\n")
}

is PatternActionFunc.Xargs -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,7 @@ class ShireLifecycleTest : BasePlatformTestCase() {
editor = null
)

val matchedCase = hole.afterStreaming?.execute(myFixture.project, handleContext, hole)
assertEquals(matchedCase?.size, 2)

assertEquals(matchedCase?.get(0)?.caseKey, "\"success\"")
assertEquals(matchedCase?.get(1)?.caseKey, "\"json-result\"")
val result = hole.afterStreaming?.execute(myFixture.project, handleContext, hole)
TestCase.assertEquals(result, "File not found")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,51 @@ class ShirePatternPipelineTest : BasePlatformTestCase() {
"Summary webpage: \$fileName\n" +
"when: \$fileName.matches(\"/.*.java/\")", context.genText)
}

fun testShouldSupportAfterStreamingPattern() {
@Language("Shire")
val code = """
---
name: Summary
description: "Generate Summary"
interaction: AppendCursor
variables:
"var2": "sample"
afterStreaming: {
case condition {
default { print(${'$'}output) }
}
}
---
Summary webpage: ${'$'}fileName
""".trimIndent()

val file = myFixture.addFileToProject("sample.shire", code)

myFixture.openFileInEditor(file.virtualFile)

val compile = ShireSyntaxAnalyzer(project, file as ShireFile, myFixture.editor).parse()
val hole = compile.config!!

val context = PostCodeHandleContext(
genText = "User prompt:\n\n",
)

runBlocking {
val compiledVariables =
ShireTemplateCompiler(project, hole, compile.variableTable, code).compileVariable(myFixture.editor)

context.compiledVariables = compiledVariables

hole.variables.mapValues {
PatternActionProcessor(project, hole).execute(it.value)
}

hole.setupStreamingEndProcessor(project, context = context)
hole.executeAfterStreamingProcessor(project, null, context = context)
}

assertEquals("\$output", context.lastTaskOutput)
}
}

0 comments on commit c2b7227

Please sign in to comment.