diff --git a/README.md b/README.md index ae8f982..4b9722a 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ A type-safe validation framework for Go using generics. This package provides co - **Flexible**: Wrap go-playground/validator, ozzo-validation, or ANY validation library - **Zero dependencies**: Core package has no external dependencies - **Battle-tested**: Optional playground package provides 100+ validators from go-playground +- **Error Detection**: All errors are wrapped in `validation.Error` for easy identification - **Clean API**: Simple, readable validation code ## Philosophy @@ -97,6 +98,69 @@ validation.Validate(age, validation.MinLength(3)) validation.Validate(age, validation.Range("0", "120")) ``` +## Error Detection + +All validation errors in Protego are wrapped in a `validation.Error` type, making it easy to detect and handle Protego-specific errors: + +```go +import ( + "errors" + "github.com/quantumcycle/protego/validation" + "github.com/quantumcycle/protego/playground" +) + +func ProcessUser(input CreateUserInput) error { + err := input.Validate() + if err != nil { + // Check if this is a Protego validation error + if validation.IsValidationError(err) { + // Handle validation errors specifically + return fmt.Errorf("validation failed: %w", err) + } + // Handle other types of errors + return fmt.Errorf("unexpected error: %w", err) + } + // Process valid input + return nil +} +``` + +### Error Detection Features + +- **Type Detection**: Use `validation.IsValidationError(err)` to check if an error came from Protego +- **Error Wrapping**: All validators wrap errors using `validation.Error`, including playground validators +- **Error Unwrapping**: Supports Go's standard `errors.Unwrap()` and `errors.Is()` functions +- **Preserved Messages**: Original error messages remain unchanged for backward compatibility + +### Examples + +```go +// Detect validation errors from core validators +err := validation.Validate("", validation.Required[string]()) +if validation.IsValidationError(err) { + fmt.Println("Protego validation error:", err.Error()) // Output: required +} + +// Detect validation errors from playground validators +err = validation.Validate("invalid-email", playground.IsEmail) +if validation.IsValidationError(err) { + fmt.Println("Email validation failed:", err.Error()) +} + +// Use with errors.Join for multiple validations +err = errors.Join( + validation.Validate("", validation.Required[string]()), + validation.Validate("ab", validation.MinLength(3)), +) +// Check if any are validation errors +if validation.IsValidationError(err) { + fmt.Println("Contains validation errors") +} + +// Error unwrapping works +originalErr := errors.Unwrap(err) +``` + ## Available Validators ### Required Validators diff --git a/playground/error_test.go b/playground/error_test.go new file mode 100644 index 0000000..ef95de0 --- /dev/null +++ b/playground/error_test.go @@ -0,0 +1,77 @@ +package playground_test + +import ( + "testing" + + . "github.com/onsi/gomega" + + "github.com/quantumcycle/protego/playground" + "github.com/quantumcycle/protego/validation" +) + +func TestPlaygroundValidatorsReturnValidationError(t *testing.T) { + t.Run("IsEmail validator returns ValidationError", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate("not-an-email", playground.IsEmail) + g.Expect(err).ToNot(BeNil()) + g.Expect(validation.IsValidationError(err)).To(BeTrue()) + }) + + t.Run("IsUUID4 validator returns ValidationError", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate("not-a-uuid", playground.IsUUID4) + g.Expect(err).ToNot(BeNil()) + g.Expect(validation.IsValidationError(err)).To(BeTrue()) + }) + + t.Run("IsURL validator returns ValidationError", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate("not-a-url", playground.IsURL) + g.Expect(err).ToNot(BeNil()) + g.Expect(validation.IsValidationError(err)).To(BeTrue()) + }) + + t.Run("IsIPv4 validator returns ValidationError", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate("not-an-ip", playground.IsIPv4) + g.Expect(err).ToNot(BeNil()) + g.Expect(validation.IsValidationError(err)).To(BeTrue()) + }) + + t.Run("FromTag validator returns ValidationError", func(t *testing.T) { + g := NewWithT(t) + customValidator := playground.FromTag[string]("uuid") + err := validation.Validate("not-a-uuid", customValidator) + g.Expect(err).ToNot(BeNil()) + g.Expect(validation.IsValidationError(err)).To(BeTrue()) + }) + + t.Run("FromTagWithMessage validator returns ValidationError", func(t *testing.T) { + g := NewWithT(t) + customValidator := playground.FromTagWithMessage[string]("uuid", "custom error message") + err := validation.Validate("not-a-uuid", customValidator) + g.Expect(err).ToNot(BeNil()) + g.Expect(err.Error()).To(Equal("custom error message")) + g.Expect(validation.IsValidationError(err)).To(BeTrue()) + }) +} + +func TestPlaygroundValidatorsPass(t *testing.T) { + t.Run("IsEmail validator passes for valid email", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate("test@example.com", playground.IsEmail) + g.Expect(err).To(BeNil()) + }) + + t.Run("IsUUID4 validator passes for valid UUID", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate("f47ac10b-58cc-4372-a567-0e02b2c3d479", playground.IsUUID4) + g.Expect(err).To(BeNil()) + }) + + t.Run("IsURL validator passes for valid URL", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate("https://example.com", playground.IsURL) + g.Expect(err).To(BeNil()) + }) +} diff --git a/playground/playground.go b/playground/playground.go index 04e3d44..b7b563d 100644 --- a/playground/playground.go +++ b/playground/playground.go @@ -19,8 +19,6 @@ package playground import ( - "fmt" - "github.com/go-playground/validator/v10" "github.com/quantumcycle/protego/validation" @@ -43,7 +41,10 @@ var sharedValidator = validator.New() // See https://pkg.go.dev/github.com/go-playground/validator/v10 for all available tags. func FromTag[T any](tag string) validation.Validator[T] { return func(v T) error { - return sharedValidator.Var(v, tag) + if err := sharedValidator.Var(v, tag); err != nil { + return validation.WrapError(err) + } + return nil } } @@ -57,7 +58,7 @@ func FromTag[T any](tag string) validation.Validator[T] { func FromTagWithMessage[T any](tag, message string) validation.Validator[T] { return func(v T) error { if err := sharedValidator.Var(v, tag); err != nil { - return fmt.Errorf("%s", message) + return validation.NewValidationError(message) } return nil } diff --git a/validation/collection.go b/validation/collection.go index 54c9eb7..7a26139 100644 --- a/validation/collection.go +++ b/validation/collection.go @@ -27,7 +27,7 @@ func In[T comparable](caseInsensitive bool, allowed ...T) Validator[T] { } else if slices.Contains(allowed, v) { return nil } - return fmt.Errorf("must be one of: %v", allowed) + return NewValidationError(fmt.Sprintf("must be one of: %v", allowed)) } } @@ -53,11 +53,11 @@ func NotIn[T comparable](caseInsensitive bool, forbidden ...T) Validator[T] { vs := strings.ToLower(fmt.Sprint(v)) for _, f := range forbidden { if strings.ToLower(fmt.Sprint(f)) == vs { - return fmt.Errorf("cannot be one of: %v", forbidden) + return NewValidationError(fmt.Sprintf("cannot be one of: %v", forbidden)) } } } else if slices.Contains(forbidden, v) { - return fmt.Errorf("cannot be one of: %v", forbidden) + return NewValidationError(fmt.Sprintf("cannot be one of: %v", forbidden)) } return nil } @@ -75,7 +75,7 @@ func Each[T any](elementValidator Validator[T]) Validator[[]T] { var errs []error for i, v := range values { if err := elementValidator(v); err != nil { - errs = append(errs, fmt.Errorf("index %d: %w", i, err)) + errs = append(errs, WrapError(fmt.Errorf("index %d: %w", i, err))) } } return errors.Join(errs...) @@ -90,7 +90,7 @@ func Each[T any](elementValidator Validator[T]) Validator[[]T] { func NotEmpty[T any]() Validator[[]T] { return func(values []T) error { if len(values) == 0 { - return fmt.Errorf("cannot be empty") + return NewValidationError("cannot be empty") } return nil } @@ -104,7 +104,7 @@ func NotEmpty[T any]() Validator[[]T] { func MinItems[T any](minimum int) Validator[[]T] { return func(values []T) error { if len(values) < minimum { - return fmt.Errorf("must have at least %d items", minimum) + return NewValidationError(fmt.Sprintf("must have at least %d items", minimum)) } return nil } @@ -118,7 +118,7 @@ func MinItems[T any](minimum int) Validator[[]T] { func MaxItems[T any](maximum int) Validator[[]T] { return func(values []T) error { if len(values) > maximum { - return fmt.Errorf("must have at most %d items", maximum) + return NewValidationError(fmt.Sprintf("must have at most %d items", maximum)) } return nil } @@ -134,7 +134,7 @@ func UniqueItems[T comparable]() Validator[[]T] { seen := make(map[T]bool) for i, v := range values { if seen[v] { - return fmt.Errorf("duplicate item at index %d: %v", i, v) + return NewValidationError(fmt.Sprintf("duplicate item at index %d: %v", i, v)) } seen[v] = true } @@ -182,13 +182,13 @@ func ValidateStringMap(m map[string]string, allowExtra bool, rules ...MapKeyRule value, exists := m[rule.key] if !exists && rule.required { - return fmt.Errorf("key %q is required", rule.key) + return NewValidationError(fmt.Sprintf("key %q is required", rule.key)) } if exists { for _, validator := range rule.validators { if err := validator(value); err != nil { - return fmt.Errorf("key %q: %w", rule.key, err) + return WrapError(fmt.Errorf("key %q: %w", rule.key, err)) } } } @@ -198,7 +198,7 @@ func ValidateStringMap(m map[string]string, allowExtra bool, rules ...MapKeyRule if !allowExtra { for key := range m { if !validated[key] { - return fmt.Errorf("key %q not expected", key) + return NewValidationError(fmt.Sprintf("key %q not expected", key)) } } } @@ -229,13 +229,13 @@ func ValidateAnyMap(m map[string]any, allowExtra bool, rules ...MapKeyRule[any]) value, exists := m[rule.key] if !exists && rule.required { - return fmt.Errorf("key %q is required", rule.key) + return NewValidationError(fmt.Sprintf("key %q is required", rule.key)) } if exists { for _, validator := range rule.validators { if err := validator(value); err != nil { - return fmt.Errorf("key %q: %w", rule.key, err) + return WrapError(fmt.Errorf("key %q: %w", rule.key, err)) } } } @@ -245,7 +245,7 @@ func ValidateAnyMap(m map[string]any, allowExtra bool, rules ...MapKeyRule[any]) if !allowExtra { for key := range m { if !validated[key] { - return fmt.Errorf("key %q not expected", key) + return NewValidationError(fmt.Sprintf("key %q not expected", key)) } } } @@ -263,7 +263,7 @@ func StringValidator(validator Validator[string]) Validator[any] { return func(v any) error { str, ok := v.(string) if !ok { - return fmt.Errorf("must be a string") + return NewValidationError("must be a string") } return validator(str) } @@ -286,7 +286,7 @@ func IntValidator(validator Validator[int]) Validator[any] { case int64: return validator(int(val)) default: - return fmt.Errorf("must be a number") + return NewValidationError("must be a number") } } } @@ -310,7 +310,7 @@ func FloatValidator(validator Validator[float64]) Validator[any] { case int64: return validator(float64(val)) default: - return fmt.Errorf("must be a number") + return NewValidationError("must be a number") } } } @@ -328,7 +328,7 @@ func BoolValidator(validator Validator[bool]) Validator[any] { return func(v any) error { val, ok := v.(bool) if !ok { - return fmt.Errorf("must be a boolean") + return NewValidationError("must be a boolean") } return validator(val) } diff --git a/validation/date.go b/validation/date.go index 47f5d26..20d34d8 100644 --- a/validation/date.go +++ b/validation/date.go @@ -15,7 +15,7 @@ func IsRFC3339DateTime() Validator[string] { return func(v string) error { _, err := time.Parse(time.RFC3339, v) if err != nil { - return fmt.Errorf("must be a valid RFC3339 date-time") + return NewValidationError("must be a valid RFC3339 date-time") } return nil } @@ -30,7 +30,7 @@ func IsISO8601Date() Validator[string] { return func(v string) error { _, err := time.Parse("2006-01-02", v) if err != nil { - return fmt.Errorf("must be a valid ISO8601 date (YYYY-MM-DD)") + return NewValidationError("must be a valid ISO8601 date (YYYY-MM-DD)") } return nil } @@ -46,7 +46,7 @@ func IsDateFormat(layout string) Validator[string] { return func(v string) error { _, err := time.Parse(layout, v) if err != nil { - return fmt.Errorf("must match date format %q", layout) + return NewValidationError(fmt.Sprintf("must match date format %q", layout)) } return nil } @@ -62,10 +62,10 @@ func IsFutureDateFormat(layout string) Validator[string] { return func(v string) error { t, err := time.Parse(layout, v) if err != nil { - return fmt.Errorf("invalid date format") + return NewValidationError("invalid date format") } if !t.After(time.Now()) { - return fmt.Errorf("must be a future date") + return NewValidationError("must be a future date") } return nil } @@ -91,10 +91,10 @@ func IsPastDateFormat(layout string) Validator[string] { return func(v string) error { t, err := time.Parse(layout, v) if err != nil { - return fmt.Errorf("invalid date format") + return NewValidationError("invalid date format") } if !t.Before(time.Now()) { - return fmt.Errorf("must be a past date") + return NewValidationError("must be a past date") } return nil } @@ -120,14 +120,14 @@ func IsDateBeforeFormat(beforeDate, layout string) Validator[string] { return func(v string) error { t, err := time.Parse(layout, v) if err != nil { - return fmt.Errorf("invalid date format") + return NewValidationError("invalid date format") } before, err := time.Parse(layout, beforeDate) if err != nil { - return fmt.Errorf("invalid before date format") + return NewValidationError("invalid before date format") } if !t.Before(before) { - return fmt.Errorf("must be before %s", before.Format(layout)) + return NewValidationError(fmt.Sprintf("must be before %s", before.Format(layout))) } return nil } @@ -153,14 +153,14 @@ func IsDateAfterFormat(afterDate, layout string) Validator[string] { return func(v string) error { t, err := time.Parse(layout, v) if err != nil { - return fmt.Errorf("invalid date format") + return NewValidationError("invalid date format") } after, err := time.Parse(layout, afterDate) if err != nil { - return fmt.Errorf("invalid after date format") + return NewValidationError("invalid after date format") } if !t.After(after) { - return fmt.Errorf("must be after %s", after.Format(layout)) + return NewValidationError(fmt.Sprintf("must be after %s", after.Format(layout))) } return nil } @@ -184,7 +184,7 @@ func IsDateAfter(afterDate string) Validator[string] { func IsFutureTime() Validator[time.Time] { return func(v time.Time) error { if !v.After(time.Now()) { - return fmt.Errorf("must be a future time") + return NewValidationError("must be a future time") } return nil } @@ -198,7 +198,7 @@ func IsFutureTime() Validator[time.Time] { func IsPastTime() Validator[time.Time] { return func(v time.Time) error { if !v.Before(time.Now()) { - return fmt.Errorf("must be a past time") + return NewValidationError("must be a past time") } return nil } @@ -212,7 +212,7 @@ func IsPastTime() Validator[time.Time] { func IsTimeBefore(before time.Time) Validator[time.Time] { return func(v time.Time) error { if !v.Before(before) { - return fmt.Errorf("must be before %s", before.Format(time.RFC3339)) + return NewValidationError(fmt.Sprintf("must be before %s", before.Format(time.RFC3339))) } return nil } @@ -226,7 +226,7 @@ func IsTimeBefore(before time.Time) Validator[time.Time] { func IsTimeAfter(after time.Time) Validator[time.Time] { return func(v time.Time) error { if !v.After(after) { - return fmt.Errorf("must be after %s", after.Format(time.RFC3339)) + return NewValidationError(fmt.Sprintf("must be after %s", after.Format(time.RFC3339))) } return nil } diff --git a/validation/error_test.go b/validation/error_test.go new file mode 100644 index 0000000..dfbd893 --- /dev/null +++ b/validation/error_test.go @@ -0,0 +1,158 @@ +package validation_test + +import ( + "errors" + "testing" + + . "github.com/onsi/gomega" + + "github.com/quantumcycle/protego/validation" +) + +func TestValidationError(t *testing.T) { + t.Run("NewValidationError creates a ValidationError", func(t *testing.T) { + g := NewWithT(t) + err := validation.NewValidationError("test error") + g.Expect(err).ToNot(BeNil()) + g.Expect(err.Error()).To(Equal("test error")) + g.Expect(validation.IsValidationError(err)).To(BeTrue()) + }) + + t.Run("WrapError wraps an existing error", func(t *testing.T) { + g := NewWithT(t) + originalErr := errors.New("original error") + wrappedErr := validation.WrapError(originalErr) + g.Expect(wrappedErr).ToNot(BeNil()) + g.Expect(wrappedErr.Error()).To(Equal("original error")) + g.Expect(validation.IsValidationError(wrappedErr)).To(BeTrue()) + }) + + t.Run("WrapError preserves ValidationError", func(t *testing.T) { + g := NewWithT(t) + originalErr := validation.NewValidationError("validation error") + wrappedErr := validation.WrapError(originalErr) + g.Expect(wrappedErr).To(Equal(originalErr)) + }) + + t.Run("WrapError returns nil for nil error", func(t *testing.T) { + g := NewWithT(t) + wrappedErr := validation.WrapError(nil) + g.Expect(wrappedErr).To(BeNil()) + }) + + t.Run("IsValidationError returns false for non-ValidationError", func(t *testing.T) { + g := NewWithT(t) + err := errors.New("regular error") + g.Expect(validation.IsValidationError(err)).To(BeFalse()) + }) + + t.Run("IsValidationError returns false for nil", func(t *testing.T) { + g := NewWithT(t) + g.Expect(validation.IsValidationError(nil)).To(BeFalse()) + }) + + t.Run("Error unwrapping works", func(t *testing.T) { + g := NewWithT(t) + originalErr := errors.New("original error") + wrappedErr := validation.WrapError(originalErr) + g.Expect(errors.Unwrap(wrappedErr)).To(Equal(originalErr)) + }) + + t.Run("errors.Is works with validation.Error", func(t *testing.T) { + g := NewWithT(t) + err1 := validation.NewValidationError("test error") + err2 := validation.NewValidationError("another error") + g.Expect(errors.Is(err1, &validation.Error{})).To(BeTrue()) + g.Expect(errors.Is(err2, &validation.Error{})).To(BeTrue()) + }) +} + +func TestValidatorsReturnValidationError(t *testing.T) { + t.Run("Required validator returns ValidationError", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate("", validation.Required[string]()) + g.Expect(err).ToNot(BeNil()) + g.Expect(validation.IsValidationError(err)).To(BeTrue()) + }) + + t.Run("MinLength validator returns ValidationError", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate("ab", validation.MinLength(3)) + g.Expect(err).ToNot(BeNil()) + g.Expect(validation.IsValidationError(err)).To(BeTrue()) + }) + + t.Run("Range validator returns ValidationError", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate(150, validation.Range(0, 120)) + g.Expect(err).ToNot(BeNil()) + g.Expect(validation.IsValidationError(err)).To(BeTrue()) + }) + + t.Run("In validator returns ValidationError", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate("invalid", validation.In(false, "valid1", "valid2")) + g.Expect(err).ToNot(BeNil()) + g.Expect(validation.IsValidationError(err)).To(BeTrue()) + }) + + t.Run("NotEmpty validator returns ValidationError", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate([]string{}, validation.NotEmpty[string]()) + g.Expect(err).ToNot(BeNil()) + g.Expect(validation.IsValidationError(err)).To(BeTrue()) + }) + + t.Run("IsRFC3339DateTime validator returns ValidationError", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate("invalid-date", validation.IsRFC3339DateTime()) + g.Expect(err).ToNot(BeNil()) + g.Expect(validation.IsValidationError(err)).To(BeTrue()) + }) + + t.Run("WithMessage validator returns ValidationError", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate("", validation.WithMessage(validation.Required[string](), "custom message")) + g.Expect(err).ToNot(BeNil()) + g.Expect(err.Error()).To(Equal("custom message")) + g.Expect(validation.IsValidationError(err)).To(BeTrue()) + }) + + t.Run("Or validator returns ValidationError", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate("test", validation.Or( + validation.MinLength(10), + validation.MaxLength(2), + )) + g.Expect(err).ToNot(BeNil()) + g.Expect(validation.IsValidationError(err)).To(BeTrue()) + }) + + t.Run("NotNil validator returns ValidationError", func(t *testing.T) { + g := NewWithT(t) + var nilStr *string + err := validation.Validate(nilStr, validation.NotNil[string]()) + g.Expect(err).ToNot(BeNil()) + g.Expect(validation.IsValidationError(err)).To(BeTrue()) + }) +} + +func TestErrorMessagesPreserved(t *testing.T) { + t.Run("Required error message preserved", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate("", validation.Required[string]()) + g.Expect(err.Error()).To(Equal("required")) + }) + + t.Run("MinLength error message preserved", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate("ab", validation.MinLength(3)) + g.Expect(err.Error()).To(Equal("must be at least 3 characters")) + }) + + t.Run("Range error message preserved", func(t *testing.T) { + g := NewWithT(t) + err := validation.Validate(150, validation.Range(0, 120)) + g.Expect(err.Error()).To(Equal("must be between 0 and 120")) + }) +} diff --git a/validation/helpers.go b/validation/helpers.go index 026c96f..96f9f95 100644 --- a/validation/helpers.go +++ b/validation/helpers.go @@ -19,7 +19,7 @@ import ( func WithMessage[T any](validator Validator[T], message string) Validator[T] { return func(v T) error { if err := validator(v); err != nil { - return fmt.Errorf("%s", message) + return NewValidationError(message) } return nil } @@ -72,7 +72,7 @@ func Or[T any](validators ...Validator[T]) Validator[T] { if len(errs) == 1 { return errs[0] } - return fmt.Errorf("all validators failed: %w", errors.Join(errs...)) + return WrapError(fmt.Errorf("all validators failed: %w", errors.Join(errs...))) } } @@ -88,7 +88,7 @@ func Not[T any](validator Validator[T]) Validator[T] { if err := validator(v); err != nil { return nil // Validator failed, so Not passes } - return fmt.Errorf("validation should have failed but passed") + return NewValidationError("validation should have failed but passed") } } diff --git a/validation/numeric.go b/validation/numeric.go index ee534aa..bc243b0 100644 --- a/validation/numeric.go +++ b/validation/numeric.go @@ -15,7 +15,7 @@ import ( func Min[T constraints.Ordered](minimum T) Validator[T] { return func(v T) error { if v < minimum { - return fmt.Errorf("must be at least %v", minimum) + return NewValidationError(fmt.Sprintf("must be at least %v", minimum)) } return nil } @@ -30,7 +30,7 @@ func Min[T constraints.Ordered](minimum T) Validator[T] { func Max[T constraints.Ordered](maximum T) Validator[T] { return func(v T) error { if v > maximum { - return fmt.Errorf("must be at most %v", maximum) + return NewValidationError(fmt.Sprintf("must be at most %v", maximum)) } return nil } @@ -45,7 +45,7 @@ func Max[T constraints.Ordered](maximum T) Validator[T] { func Range[T constraints.Ordered](minimum, maximum T) Validator[T] { return func(v T) error { if v < minimum || v > maximum { - return fmt.Errorf("must be between %v and %v", minimum, maximum) + return NewValidationError(fmt.Sprintf("must be between %v and %v", minimum, maximum)) } return nil } @@ -59,7 +59,7 @@ func Range[T constraints.Ordered](minimum, maximum T) Validator[T] { func GreaterThan[T constraints.Ordered](threshold T) Validator[T] { return func(v T) error { if v <= threshold { - return fmt.Errorf("must be greater than %v", threshold) + return NewValidationError(fmt.Sprintf("must be greater than %v", threshold)) } return nil } @@ -73,7 +73,7 @@ func GreaterThan[T constraints.Ordered](threshold T) Validator[T] { func LessThan[T constraints.Ordered](threshold T) Validator[T] { return func(v T) error { if v >= threshold { - return fmt.Errorf("must be less than %v", threshold) + return NewValidationError(fmt.Sprintf("must be less than %v", threshold)) } return nil } @@ -89,7 +89,7 @@ func Positive[T constraints.Ordered]() Validator[T] { return func(v T) error { var zero T if v <= zero { - return fmt.Errorf("must be positive") + return NewValidationError("must be positive") } return nil } @@ -104,7 +104,7 @@ func NonNegative[T constraints.Ordered]() Validator[T] { return func(v T) error { var zero T if v < zero { - return fmt.Errorf("must be non-negative") + return NewValidationError("must be non-negative") } return nil } @@ -119,7 +119,7 @@ func Negative[T constraints.Ordered]() Validator[T] { return func(v T) error { var zero T if v >= zero { - return fmt.Errorf("must be negative") + return NewValidationError("must be negative") } return nil } @@ -133,7 +133,7 @@ func Negative[T constraints.Ordered]() Validator[T] { func MultipleOf[T constraints.Integer](divisor T) Validator[T] { return func(v T) error { if v%divisor != 0 { - return fmt.Errorf("must be a multiple of %v", divisor) + return NewValidationError(fmt.Sprintf("must be a multiple of %v", divisor)) } return nil } diff --git a/validation/optional.go b/validation/optional.go index e5ab824..05e44c7 100644 --- a/validation/optional.go +++ b/validation/optional.go @@ -1,7 +1,5 @@ package validation -import "fmt" - // NilOrNotEmpty validates that a pointer to string is either nil or not empty. // This is useful for optional fields that, if provided, must not be empty. // @@ -18,7 +16,7 @@ func NilOrNotEmpty() Validator[*string] { return nil // Nil is okay } if *v == "" { - return fmt.Errorf("cannot be empty string (must be nil or non-empty)") + return NewValidationError("cannot be empty string (must be nil or non-empty)") } return nil } @@ -65,7 +63,7 @@ func NilOr[T any](validator Validator[T]) Validator[*T] { func NotNil[T any]() Validator[*T] { return func(v *T) error { if v == nil { - return fmt.Errorf("cannot be nil") + return NewValidationError("cannot be nil") } return nil } diff --git a/validation/required.go b/validation/required.go index 8142477..ccb4473 100644 --- a/validation/required.go +++ b/validation/required.go @@ -1,7 +1,5 @@ package validation -import "fmt" - // Required validates that a value is not the zero value for its type. // For strings, this means not empty. For numbers, this means not zero. // For pointers, this means not nil. @@ -14,7 +12,7 @@ func Required[T comparable]() Validator[T] { return func(v T) error { var zero T if v == zero { - return fmt.Errorf("required") + return NewValidationError("required") } return nil } @@ -35,7 +33,7 @@ func RequiredIf[T comparable](condition bool) Validator[T] { } var zero T if v == zero { - return fmt.Errorf("required") + return NewValidationError("required") } return nil } diff --git a/validation/string.go b/validation/string.go index 7dbaa13..7bb2eed 100644 --- a/validation/string.go +++ b/validation/string.go @@ -15,7 +15,7 @@ import ( func MinLength(minimum int) Validator[string] { return func(v string) error { if len(v) < minimum { - return fmt.Errorf("must be at least %d characters", minimum) + return NewValidationError(fmt.Sprintf("must be at least %d characters", minimum)) } return nil } @@ -29,7 +29,7 @@ func MinLength(minimum int) Validator[string] { func MaxLength(maximum int) Validator[string] { return func(v string) error { if len(v) > maximum { - return fmt.Errorf("must be at most %d characters", maximum) + return NewValidationError(fmt.Sprintf("must be at most %d characters", maximum)) } return nil } @@ -44,7 +44,7 @@ func Length(minimum, maximum int) Validator[string] { return func(v string) error { length := len(v) if length < minimum || length > maximum { - return fmt.Errorf("must be between %d and %d characters", minimum, maximum) + return NewValidationError(fmt.Sprintf("must be between %d and %d characters", minimum, maximum)) } return nil } @@ -58,7 +58,7 @@ func Length(minimum, maximum int) Validator[string] { func IsInt() Validator[string] { return func(v string) error { if _, err := strconv.Atoi(v); err != nil { - return fmt.Errorf("must be a valid integer") + return NewValidationError("must be a valid integer") } return nil } @@ -73,7 +73,7 @@ func MatchesPattern(pattern string) Validator[string] { regex := regexp.MustCompile(pattern) return func(v string) error { if !regex.MatchString(v) { - return fmt.Errorf("must match pattern %q", pattern) + return NewValidationError(fmt.Sprintf("must match pattern %q", pattern)) } return nil } @@ -87,7 +87,7 @@ func MatchesPattern(pattern string) Validator[string] { func StartsWith(prefix string) Validator[string] { return func(v string) error { if !strings.HasPrefix(v, prefix) { - return fmt.Errorf("must start with %q", prefix) + return NewValidationError(fmt.Sprintf("must start with %q", prefix)) } return nil } @@ -101,7 +101,7 @@ func StartsWith(prefix string) Validator[string] { func EndsWith(suffix string) Validator[string] { return func(v string) error { if !strings.HasSuffix(v, suffix) { - return fmt.Errorf("must end with %q", suffix) + return NewValidationError(fmt.Sprintf("must end with %q", suffix)) } return nil } @@ -115,7 +115,7 @@ func EndsWith(suffix string) Validator[string] { func Contains(substring string) Validator[string] { return func(v string) error { if !strings.Contains(v, substring) { - return fmt.Errorf("must contain %q", substring) + return NewValidationError(fmt.Sprintf("must contain %q", substring)) } return nil } diff --git a/validation/validation.go b/validation/validation.go index cf2eb80..e87e842 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -20,6 +20,8 @@ // } package validation +import "errors" + // Validator is a generic validation function that validates a value of type T. // It returns an error if validation fails, or nil if the value is valid. type Validator[T any] func(T) error @@ -86,3 +88,70 @@ func Nested[T Validatable]() Validator[T] { return v.Validate() } } + +// Error represents an error that occurred during validation. +// It wraps validation errors to make them identifiable as Protego errors. +type Error struct { + msg string + err error +} + +// Error returns the error message. +func (e *Error) Error() string { + if e.err != nil { + return e.err.Error() + } + return e.msg +} + +// Unwrap returns the underlying error, if any. +func (e *Error) Unwrap() error { + return e.err +} + +// Is allows Error to work with errors.Is(). +func (e *Error) Is(target error) bool { + _, ok := target.(*Error) + return ok +} + +// NewValidationError creates a new validation Error with the given message. +// This should be used for creating new validation errors in validators. +// +// Example: +// +// return validation.NewValidationError("must be at least 3 characters") +func NewValidationError(msg string) error { + return &Error{msg: msg} +} + +// WrapError wraps an existing error as a validation Error. +// If the error is already a validation Error, it returns it as-is. +// This is useful for wrapping errors from external libraries (like go-playground/validator). +// +// Example: +// +// return validation.WrapError(externalLibraryError) +func WrapError(err error) error { + if err == nil { + return nil + } + var valErr *Error + if errors.As(err, &valErr) { + return err + } + return &Error{err: err} +} + +// IsValidationError checks if an error is a validation Error or wraps one. +// This allows users to detect if an error came from Protego validation. +// +// Example: +// +// if validation.IsValidationError(err) { +// // Handle validation error +// } +func IsValidationError(err error) bool { + var valErr *Error + return errors.As(err, &valErr) +}