Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return full structured error info instead of just error message #76

Merged
merged 2 commits into from
Mar 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions normalize_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package pg_query_test

import (
"errors"
"reflect"
"testing"

pg_query "github.com/pganalyze/pg_query_go/v4"
"github.com/pganalyze/pg_query_go/v4/parser"
)

var normalizeTests = []struct {
Expand Down Expand Up @@ -36,7 +36,12 @@ var normalizeErrorTests = []struct {
}{
{
"SELECT $",
errors.New("syntax error at or near \"$\""),
&parser.Error{
Message: "syntax error at or near \"$\"",
Cursorpos: 8,
Filename: "scan.l",
Funcname: "scanner_yyerror",
},
},
}

Expand All @@ -46,8 +51,17 @@ func TestNormalizeError(t *testing.T) {

if actualErr == nil {
t.Errorf("Normalize(%s)\nexpected error but none returned\n\n", test.input)
} else if !reflect.DeepEqual(actualErr, test.expectedErr) {
t.Errorf("Normalize(%s)\nexpected error %s\nactual error %s\n\n", test.input, test.expectedErr, actualErr)
} else {
exp := test.expectedErr.(*parser.Error)
act := actualErr.(*parser.Error)
act.Lineno = 0 // Line number is architecture dependent, so we ignore it
if !reflect.DeepEqual(act, exp) {
t.Errorf(
"Normalize(%s)\nexpected error %s at %d (%s:%d), func: %s, context: %s\nactual error %+v at %d (%s:%d), func: %s, context: %s\n\n",
test.input,
exp.Message, exp.Cursorpos, exp.Filename, exp.Lineno, exp.Funcname, exp.Context,
act.Message, act.Cursorpos, act.Filename, act.Lineno, act.Funcname, act.Context)
}
}
}
}
29 changes: 24 additions & 5 deletions parse_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package pg_query_test

import (
"errors"
"fmt"
"reflect"
"sync"
"testing"

"github.com/google/go-cmp/cmp"
pg_query "github.com/pganalyze/pg_query_go/v4"
"github.com/pganalyze/pg_query_go/v4/parser"
"google.golang.org/protobuf/testing/protocmp"
)

Expand Down Expand Up @@ -599,11 +599,21 @@ var parseErrorTests = []struct {
}{
{
"SELECT $",
errors.New("syntax error at or near \"$\""),
&parser.Error{
Message: "syntax error at or near \"$\"",
Cursorpos: 8,
Filename: "scan.l",
Funcname: "scanner_yyerror",
},
},
{
"SELECT * FROM y WHERE x IN ($1, ",
errors.New("syntax error at end of input"),
&parser.Error{
Message: "syntax error at end of input",
Cursorpos: 33,
Filename: "scan.l",
Funcname: "scanner_yyerror",
},
},
}

Expand All @@ -613,8 +623,17 @@ func TestParseError(t *testing.T) {

if actualErr == nil {
t.Errorf("Parse(%s)\nexpected error but none returned\n\n", test.input)
} else if !reflect.DeepEqual(actualErr, test.expectedErr) {
t.Errorf("Parse(%s)\nexpected error %s\nactual error %s\n\n", test.input, test.expectedErr, actualErr)
} else {
exp := test.expectedErr.(*parser.Error)
act := actualErr.(*parser.Error)
act.Lineno = 0 // Line number is architecture dependent, so we ignore it
if !reflect.DeepEqual(act, exp) {
t.Errorf(
"Parse(%s)\nexpected error %s at %d (%s:%d), func: %s, context: %s\nactual error %+v at %d (%s:%d), func: %s, context: %s\n\n",
test.input,
exp.Message, exp.Cursorpos, exp.Filename, exp.Lineno, exp.Funcname, exp.Context,
act.Message, act.Cursorpos, act.Filename, act.Lineno, act.Funcname, act.Context)
}
}
}
}
Expand Down
56 changes: 39 additions & 17 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,44 @@ uint64_t pg_query_hash_xxh3_64(void *data, size_t len, size_t seed) {
import "C"

import (
"errors"
"unsafe"
)

func init() {
C.pg_query_init()
}

type Error struct {
Message string // exception message
Funcname string // source function of exception (e.g. SearchSysCache)
Filename string // source of exception (e.g. parse.l)
Lineno int // source of exception (e.g. 104)
Cursorpos int // char in query at which exception occurred
Context string // additional context (optional, can be NULL)
}

func (e *Error) Error() string {
return e.Message
}

func newPgQueryError(errC *C.PgQueryError) *Error {
err := &Error{
Message: C.GoString(errC.message),
Lineno: int(errC.lineno),
Cursorpos: int(errC.cursorpos),
}
if errC.funcname != nil {
err.Funcname = C.GoString(errC.funcname)
}
if errC.filename != nil {
err.Filename = C.GoString(errC.filename)
}
if errC.context != nil {
err.Context = C.GoString(errC.context)
}
return err
}

// ParseToJSON - Parses the given SQL statement into a parse tree (JSON format)
func ParseToJSON(input string) (result string, err error) {
inputC := C.CString(input)
Expand All @@ -41,8 +71,7 @@ func ParseToJSON(input string) (result string, err error) {
defer C.pg_query_free_parse_result(resultC)

if resultC.error != nil {
errMessage := C.GoString(resultC.error.message)
err = errors.New(errMessage)
err = newPgQueryError(resultC.error)
return
}

Expand All @@ -60,8 +89,7 @@ func ScanToProtobuf(input string) (result []byte, err error) {
defer C.pg_query_free_scan_result(resultC)

if resultC.error != nil {
errMessage := C.GoString(resultC.error.message)
err = errors.New(errMessage)
err = newPgQueryError(resultC.error)
return
}

Expand All @@ -80,8 +108,7 @@ func ParseToProtobuf(input string) (result []byte, err error) {
defer C.pg_query_free_protobuf_parse_result(resultC)

if resultC.error != nil {
errMessage := C.GoString(resultC.error.message)
err = errors.New(errMessage)
err = newPgQueryError(resultC.error)
return
}

Expand All @@ -100,8 +127,7 @@ func DeparseFromProtobuf(input []byte) (result string, err error) {
defer C.pg_query_free_deparse_result(resultC)

if resultC.error != nil {
errMessage := C.GoString(resultC.error.message)
err = errors.New(errMessage)
err = newPgQueryError(resultC.error)
return
}

Expand All @@ -120,8 +146,7 @@ func ParsePlPgSqlToJSON(input string) (result string, err error) {
defer C.pg_query_free_plpgsql_parse_result(resultC)

if resultC.error != nil {
errMessage := C.GoString(resultC.error.message)
err = errors.New(errMessage)
err = newPgQueryError(resultC.error)
return
}

Expand All @@ -139,8 +164,7 @@ func Normalize(input string) (result string, err error) {
defer C.pg_query_free_normalize_result(resultC)

if resultC.error != nil {
errMessage := C.GoString(resultC.error.message)
err = errors.New(errMessage)
err = newPgQueryError(resultC.error)
return
}

Expand All @@ -158,8 +182,7 @@ func FingerprintToUInt64(input string) (result uint64, err error) {
defer C.pg_query_free_fingerprint_result(resultC)

if resultC.error != nil {
errMessage := C.GoString(resultC.error.message)
err = errors.New(errMessage)
err = newPgQueryError(resultC.error)
return
}

Expand All @@ -178,8 +201,7 @@ func FingerprintToHexStr(input string) (result string, err error) {
defer C.pg_query_free_fingerprint_result(resultC)

if resultC.error != nil {
errMessage := C.GoString(resultC.error.message)
err = errors.New(errMessage)
err = newPgQueryError(resultC.error)
return
}

Expand Down