Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions cmd/gh-aw/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/github/gh-aw/pkg/cli"
"github.com/github/gh-aw/pkg/console"
"github.com/github/gh-aw/pkg/constants"
"github.com/github/gh-aw/pkg/parser"
"github.com/github/gh-aw/pkg/workflow"
"github.com/spf13/cobra"
)
Expand All @@ -24,8 +25,23 @@ var bannerFlag bool

// validateEngine validates the engine flag value
func validateEngine(engine string) error {
if engine != "" && engine != "claude" && engine != "codex" && engine != "copilot" && engine != "custom" {
return fmt.Errorf("invalid engine value '%s'. Must be 'claude', 'codex', 'copilot', or 'custom'", engine)
// Get the global engine registry
registry := workflow.GetGlobalEngineRegistry()
validEngines := registry.GetSupportedEngines()

if engine != "" && !registry.IsValidEngine(engine) {
// Try to find close matches for "did you mean" suggestion
suggestions := parser.FindClosestMatches(engine, validEngines, 1)

errMsg := fmt.Sprintf("invalid engine value '%s'. Must be '%s'",
engine, strings.Join(validEngines, "', '"))

if len(suggestions) > 0 {
errMsg = fmt.Sprintf("invalid engine value '%s'. Must be '%s'.\n\nDid you mean: %s?",
engine, strings.Join(validEngines, "', '"), suggestions[0])
}

return fmt.Errorf("%s", errMsg)
}
return nil
}
Expand Down
30 changes: 28 additions & 2 deletions pkg/workflow/engine_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ package workflow
import (
"encoding/json"
"fmt"
"strings"

"github.com/github/gh-aw/pkg/constants"
"github.com/github/gh-aw/pkg/logger"
"github.com/github/gh-aw/pkg/parser"
)

var engineValidationLog = logger.New("workflow:engine_validation")
Expand Down Expand Up @@ -66,8 +68,32 @@ func (c *Compiler) validateEngine(engineID string) error {
}

engineValidationLog.Printf("Engine ID %s not found: %v", engineID, err)
// Provide helpful error with valid options
return fmt.Errorf("invalid engine: %s. Valid engines are: copilot, claude, codex, custom.\n\nExample:\nengine: copilot\n\nSee: %s", engineID, constants.DocsEnginesURL)

// Get list of valid engine IDs from the engine registry
validEngines := c.engineRegistry.GetSupportedEngines()

// Try to find close matches for "did you mean" suggestion
suggestions := parser.FindClosestMatches(engineID, validEngines, 1)

// Build comma-separated list of valid engines for error message
enginesStr := strings.Join(validEngines, ", ")

// Build error message with helpful context
errMsg := fmt.Sprintf("invalid engine: %s. Valid engines are: %s.\n\nExample:\nengine: copilot\n\nSee: %s",
engineID,
enginesStr,
constants.DocsEnginesURL)

// Add "did you mean" suggestion if we found a close match
if len(suggestions) > 0 {
errMsg = fmt.Sprintf("invalid engine: %s. Valid engines are: %s.\n\nDid you mean: %s?\n\nExample:\nengine: copilot\n\nSee: %s",
engineID,
enginesStr,
suggestions[0],
constants.DocsEnginesURL)
}

return fmt.Errorf("%s", errMsg)
}

// validateSingleEngineSpecification validates that only one engine field exists across all files
Expand Down
94 changes: 94 additions & 0 deletions pkg/workflow/engine_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,97 @@ func TestValidateSingleEngineSpecificationErrorMessageQuality(t *testing.T) {
}
})
}

// TestValidateEngineDidYouMean tests the "did you mean" suggestion feature
func TestValidateEngineDidYouMean(t *testing.T) {
tests := []struct {
name string
invalidEngine string
expectedSuggestion string
shouldHaveSuggestion bool
}{
{
name: "typo copiilot suggests copilot",
invalidEngine: "copiilot",
expectedSuggestion: "copilot",
shouldHaveSuggestion: true,
},
{
name: "typo claud suggests claude",
invalidEngine: "claud",
expectedSuggestion: "claude",
shouldHaveSuggestion: true,
},
{
name: "typo codec suggests codex",
invalidEngine: "codec",
expectedSuggestion: "codex",
shouldHaveSuggestion: true,
},
{
name: "typo custon suggests custom",
invalidEngine: "custon",
expectedSuggestion: "custom",
shouldHaveSuggestion: true,
},
{
name: "case difference no suggestion (case-insensitive match)",
invalidEngine: "Copilot",
expectedSuggestion: "",
shouldHaveSuggestion: false,
},
{
name: "completely wrong gets no suggestion",
invalidEngine: "gpt4",
expectedSuggestion: "",
shouldHaveSuggestion: false,
},
{
name: "totally different gets no suggestion",
invalidEngine: "xyz",
expectedSuggestion: "",
shouldHaveSuggestion: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compiler := NewCompiler()
err := compiler.validateEngine(tt.invalidEngine)

if err == nil {
t.Fatal("Expected validation to fail for invalid engine")
}

errorMsg := err.Error()

if tt.shouldHaveSuggestion {
// Should have "Did you mean: X?" suggestion
if !strings.Contains(errorMsg, "Did you mean:") {
t.Errorf("Expected 'Did you mean:' in error message, got: %s", errorMsg)
}

if !strings.Contains(errorMsg, tt.expectedSuggestion) {
t.Errorf("Expected suggestion '%s' in error message, got: %s",
tt.expectedSuggestion, errorMsg)
}
} else {
// Should NOT have "Did you mean:" suggestion
if strings.Contains(errorMsg, "Did you mean:") {
t.Errorf("Should not suggest anything for '%s', but got: %s",
tt.invalidEngine, errorMsg)
}
}

// All errors should still list valid engines
if !strings.Contains(errorMsg, "copilot") {
t.Errorf("Error should always list valid engines, got: %s", errorMsg)
}

// All errors should still include an example
if !strings.Contains(errorMsg, "Example:") {
t.Errorf("Error should always include an example, got: %s", errorMsg)
}
})
}
}
76 changes: 75 additions & 1 deletion pkg/workflow/github_tool_to_toolset.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package workflow
import (
_ "embed"
"encoding/json"
"fmt"
"sort"

"github.com/github/gh-aw/pkg/logger"
"github.com/github/gh-aw/pkg/parser"
)

var githubToolToToolsetLog = logger.New("workflow:github_tool_to_toolset")
Expand Down Expand Up @@ -45,10 +48,37 @@ func ValidateGitHubToolsAgainstToolsets(allowedTools []string, enabledToolsets [
// Track missing toolsets and which tools need them
missingToolsets := make(map[string][]string) // toolset -> list of tools that need it

// Track unknown tools for suggestions
var unknownTools []string
var suggestions []string

for _, tool := range allowedTools {
// Skip wildcard - it means "allow all tools"
if tool == "*" {
continue
}

requiredToolset, exists := GitHubToolToToolsetMap[tool]
if !exists {
githubToolToToolsetLog.Printf("Tool %s not found in mapping, skipping validation", tool)
githubToolToToolsetLog.Printf("Tool %s not found in mapping, checking for typo", tool)

// Get all valid tool names for suggestion
validTools := make([]string, 0, len(GitHubToolToToolsetMap))
for validTool := range GitHubToolToToolsetMap {
validTools = append(validTools, validTool)
}
sort.Strings(validTools)
Comment on lines +65 to +70
Copy link

Copilot AI Feb 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

validTools is rebuilt and sorted for every unknown tool encountered, and then rebuilt again later for the final error message. Consider computing the sorted validTools slice once outside the loop and reusing it for both suggestion lookups and the “Valid GitHub tools include …” section to avoid unnecessary work.

Copilot uses AI. Check for mistakes.

// Try to find close matches
matches := parser.FindClosestMatches(tool, validTools, 1)
if len(matches) > 0 {
githubToolToToolsetLog.Printf("Found suggestion for unknown tool %s: %s", tool, matches[0])
unknownTools = append(unknownTools, tool)
suggestions = append(suggestions, fmt.Sprintf("%s → %s", tool, matches[0]))
} else {
githubToolToToolsetLog.Printf("No suggestion found for unknown tool: %s", tool)
unknownTools = append(unknownTools, tool)
}
// Tool not in our mapping - this could be a new tool or a typo
// We'll skip validation for unknown tools to avoid false positives
Copy link

Copilot AI Feb 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment says unknown tools are skipped “to avoid false positives”, but the function now returns an error when unknownTools is non-empty. Update or remove this comment so it matches the new behavior (unknown tools are now treated as validation failures).

Suggested change
// We'll skip validation for unknown tools to avoid false positives
// Record the unknown tool and skip further checks for this entry; unknown tools will cause validation to fail later.

Copilot uses AI. Check for mistakes.
continue
Expand All @@ -60,6 +90,36 @@ func ValidateGitHubToolsAgainstToolsets(allowedTools []string, enabledToolsets [
}
}

// Report unknown tools with suggestions if any were found
if len(unknownTools) > 0 {
githubToolToToolsetLog.Printf("Found %d unknown tools", len(unknownTools))
errMsg := fmt.Sprintf("Unknown GitHub tool(s): %s\n\n", formatList(unknownTools))

if len(suggestions) > 0 {
errMsg += "Did you mean:\n"
for _, s := range suggestions {
errMsg += fmt.Sprintf(" %s\n", s)
}
errMsg += "\n"
}

// Show a few examples of valid tools
validTools := make([]string, 0, len(GitHubToolToToolsetMap))
for tool := range GitHubToolToToolsetMap {
validTools = append(validTools, tool)
}
sort.Strings(validTools)

exampleCount := 10
if len(validTools) < exampleCount {
exampleCount = len(validTools)
}
errMsg += fmt.Sprintf("Valid GitHub tools include: %s\n\n", formatList(validTools[:exampleCount]))
errMsg += "See all tools: https://github.com/github/gh-aw/blob/main/pkg/workflow/data/github_tool_to_toolset.json"

return fmt.Errorf("%s", errMsg)
}

if len(missingToolsets) > 0 {
githubToolToToolsetLog.Printf("Validation failed: missing %d toolsets", len(missingToolsets))
return NewGitHubToolsetValidationError(missingToolsets)
Expand All @@ -68,3 +128,17 @@ func ValidateGitHubToolsAgainstToolsets(allowedTools []string, enabledToolsets [
githubToolToToolsetLog.Print("Validation successful: all tools have required toolsets")
return nil
}

// formatList formats a list of strings as a comma-separated list
func formatList(items []string) string {
if len(items) == 0 {
return ""
}
if len(items) == 1 {
return items[0]
}
if len(items) == 2 {
return items[0] + " and " + items[1]
}
return fmt.Sprintf("%s, and %s", formatList(items[:len(items)-1]), items[len(items)-1])
Comment on lines +134 to +143
Copy link

Copilot AI Feb 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

formatList produces awkward/incorrect grammar for 3+ items (e.g., "a and b, and c" instead of "a, b, and c") due to the recursive construction. This will make the new error messages harder to read when listing multiple tools. Consider implementing proper comma-separated formatting (e.g., join all but last with ", ", then add ", and ").

Suggested change
if len(items) == 0 {
return ""
}
if len(items) == 1 {
return items[0]
}
if len(items) == 2 {
return items[0] + " and " + items[1]
}
return fmt.Sprintf("%s, and %s", formatList(items[:len(items)-1]), items[len(items)-1])
switch len(items) {
case 0:
return ""
case 1:
return items[0]
case 2:
return items[0] + " and " + items[1]
default:
result := items[0]
for i := 1; i < len(items); i++ {
if i == len(items)-1 {
result += ", and " + items[i]
} else {
result += ", " + items[i]
}
}
return result
}

Copilot uses AI. Check for mistakes.
}
Loading
Loading