diff --git a/normalize_test.go b/normalize_test.go index 10adc60e..16001244 100644 --- a/normalize_test.go +++ b/normalize_test.go @@ -1,11 +1,11 @@ package pg_query_test import ( - "errors" "reflect" "testing" pg_query "github.com/pganalyze/pg_query_go/v4" + "github.com/pganalyze/pg_query_go/v4/parser" ) var normalizeTests = []struct { @@ -36,7 +36,12 @@ var normalizeErrorTests = []struct { }{ { "SELECT $", - errors.New("syntax error at or near \"$\""), + &parser.Error{ + Message: "syntax error at or near \"$\"", + Cursorpos: 8, + Filename: "scan.l", + Funcname: "scanner_yyerror", + }, }, } @@ -46,8 +51,17 @@ func TestNormalizeError(t *testing.T) { if actualErr == nil { t.Errorf("Normalize(%s)\nexpected error but none returned\n\n", test.input) - } else if !reflect.DeepEqual(actualErr, test.expectedErr) { - t.Errorf("Normalize(%s)\nexpected error %s\nactual error %s\n\n", test.input, test.expectedErr, actualErr) + } else { + exp := test.expectedErr.(*parser.Error) + act := actualErr.(*parser.Error) + act.Lineno = 0 // Line number is architecture dependent, so we ignore it + if !reflect.DeepEqual(act, exp) { + t.Errorf( + "Normalize(%s)\nexpected error %s at %d (%s:%d), func: %s, context: %s\nactual error %+v at %d (%s:%d), func: %s, context: %s\n\n", + test.input, + exp.Message, exp.Cursorpos, exp.Filename, exp.Lineno, exp.Funcname, exp.Context, + act.Message, act.Cursorpos, act.Filename, act.Lineno, act.Funcname, act.Context) + } } } } diff --git a/parse_test.go b/parse_test.go index 29d85b65..1f639776 100644 --- a/parse_test.go +++ b/parse_test.go @@ -1,7 +1,6 @@ package pg_query_test import ( - "errors" "fmt" "reflect" "sync" @@ -9,6 +8,7 @@ import ( "github.com/google/go-cmp/cmp" pg_query "github.com/pganalyze/pg_query_go/v4" + "github.com/pganalyze/pg_query_go/v4/parser" "google.golang.org/protobuf/testing/protocmp" ) @@ -599,11 +599,21 @@ var parseErrorTests = []struct { }{ { "SELECT $", - errors.New("syntax error at or near \"$\""), + &parser.Error{ + Message: "syntax error at or near \"$\"", + Cursorpos: 8, + Filename: "scan.l", + Funcname: "scanner_yyerror", + }, }, { "SELECT * FROM y WHERE x IN ($1, ", - errors.New("syntax error at end of input"), + &parser.Error{ + Message: "syntax error at end of input", + Cursorpos: 33, + Filename: "scan.l", + Funcname: "scanner_yyerror", + }, }, } @@ -613,8 +623,17 @@ func TestParseError(t *testing.T) { if actualErr == nil { t.Errorf("Parse(%s)\nexpected error but none returned\n\n", test.input) - } else if !reflect.DeepEqual(actualErr, test.expectedErr) { - t.Errorf("Parse(%s)\nexpected error %s\nactual error %s\n\n", test.input, test.expectedErr, actualErr) + } else { + exp := test.expectedErr.(*parser.Error) + act := actualErr.(*parser.Error) + act.Lineno = 0 // Line number is architecture dependent, so we ignore it + if !reflect.DeepEqual(act, exp) { + t.Errorf( + "Parse(%s)\nexpected error %s at %d (%s:%d), func: %s, context: %s\nactual error %+v at %d (%s:%d), func: %s, context: %s\n\n", + test.input, + exp.Message, exp.Cursorpos, exp.Filename, exp.Lineno, exp.Funcname, exp.Context, + act.Message, act.Cursorpos, act.Filename, act.Lineno, act.Funcname, act.Context) + } } } } diff --git a/parser/parser.go b/parser/parser.go index 1e855784..cdb8dafc 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -23,7 +23,6 @@ uint64_t pg_query_hash_xxh3_64(void *data, size_t len, size_t seed) { import "C" import ( - "errors" "unsafe" ) @@ -31,6 +30,37 @@ func init() { C.pg_query_init() } +type Error struct { + Message string // exception message + Funcname string // source function of exception (e.g. SearchSysCache) + Filename string // source of exception (e.g. parse.l) + Lineno int // source of exception (e.g. 104) + Cursorpos int // char in query at which exception occurred + Context string // additional context (optional, can be NULL) +} + +func (e *Error) Error() string { + return e.Message +} + +func newPgQueryError(errC *C.PgQueryError) *Error { + err := &Error{ + Message: C.GoString(errC.message), + Lineno: int(errC.lineno), + Cursorpos: int(errC.cursorpos), + } + if errC.funcname != nil { + err.Funcname = C.GoString(errC.funcname) + } + if errC.filename != nil { + err.Filename = C.GoString(errC.filename) + } + if errC.context != nil { + err.Context = C.GoString(errC.context) + } + return err +} + // ParseToJSON - Parses the given SQL statement into a parse tree (JSON format) func ParseToJSON(input string) (result string, err error) { inputC := C.CString(input) @@ -41,8 +71,7 @@ func ParseToJSON(input string) (result string, err error) { defer C.pg_query_free_parse_result(resultC) if resultC.error != nil { - errMessage := C.GoString(resultC.error.message) - err = errors.New(errMessage) + err = newPgQueryError(resultC.error) return } @@ -60,8 +89,7 @@ func ScanToProtobuf(input string) (result []byte, err error) { defer C.pg_query_free_scan_result(resultC) if resultC.error != nil { - errMessage := C.GoString(resultC.error.message) - err = errors.New(errMessage) + err = newPgQueryError(resultC.error) return } @@ -80,8 +108,7 @@ func ParseToProtobuf(input string) (result []byte, err error) { defer C.pg_query_free_protobuf_parse_result(resultC) if resultC.error != nil { - errMessage := C.GoString(resultC.error.message) - err = errors.New(errMessage) + err = newPgQueryError(resultC.error) return } @@ -100,8 +127,7 @@ func DeparseFromProtobuf(input []byte) (result string, err error) { defer C.pg_query_free_deparse_result(resultC) if resultC.error != nil { - errMessage := C.GoString(resultC.error.message) - err = errors.New(errMessage) + err = newPgQueryError(resultC.error) return } @@ -120,8 +146,7 @@ func ParsePlPgSqlToJSON(input string) (result string, err error) { defer C.pg_query_free_plpgsql_parse_result(resultC) if resultC.error != nil { - errMessage := C.GoString(resultC.error.message) - err = errors.New(errMessage) + err = newPgQueryError(resultC.error) return } @@ -139,8 +164,7 @@ func Normalize(input string) (result string, err error) { defer C.pg_query_free_normalize_result(resultC) if resultC.error != nil { - errMessage := C.GoString(resultC.error.message) - err = errors.New(errMessage) + err = newPgQueryError(resultC.error) return } @@ -158,8 +182,7 @@ func FingerprintToUInt64(input string) (result uint64, err error) { defer C.pg_query_free_fingerprint_result(resultC) if resultC.error != nil { - errMessage := C.GoString(resultC.error.message) - err = errors.New(errMessage) + err = newPgQueryError(resultC.error) return } @@ -178,8 +201,7 @@ func FingerprintToHexStr(input string) (result string, err error) { defer C.pg_query_free_fingerprint_result(resultC) if resultC.error != nil { - errMessage := C.GoString(resultC.error.message) - err = errors.New(errMessage) + err = newPgQueryError(resultC.error) return }