// Copyright 2018 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package errcode

import (
	"fmt"
	"strings"

	"github.com/pingcap/errors"
)

// ErrorCodes returns every error in err that implements ErrorCode.
// err is first expanded with errors.Errors, so the members of an ErrorGroup
// are inspected individually. Errors that are not an ErrorCode are skipped
// rather than left as nil entries.
// The result is never nil; it is empty when no ErrorCode is found.
func ErrorCodes(err error) []ErrorCode {
	errs := errors.Errors(err)
	codes := make([]ErrorCode, 0, len(errs))
	for _, errItem := range errs {
		if code, ok := errItem.(ErrorCode); ok {
			codes = append(codes, code)
		}
	}
	return codes
}

// MultiErrCode groups several errors behind a single primary ErrorCode.
// The primary code (ErrCode) is what satisfies ErrorCode and the related
// interfaces; the other grouped errors are kept in rest.
// Its Error method joins the messages of all the errors with "; " separators.
// Later code (such as a JSON response) needs to look for the ErrorGroup interface.
type MultiErrCode struct {
	// ErrCode is the primary code of the group.
	ErrCode ErrorCode

	// rest holds the grouped errors that follow ErrCode.
	rest []error
}

// Combine constructs a MultiErrCode whose primary code is initial.
// It will combine any other MultiErrCode into just one MultiErrCode:
// the members of every ErrorGroup among others (see errors.Errors) are
// appended individually, and an initial MultiErrCode is flattened rather
// than nested inside the result.
// This is "horizontal" composition.
// If you want normal "vertical" composition use BuildChain.
func Combine(initial ErrorCode, others ...ErrorCode) MultiErrCode {
	var rest []error
	switch group := initial.(type) {
	case MultiErrCode:
		// Reuse the group's primary code and members directly. Keeping the
		// group itself as ErrCode while also listing its members would make
		// Error() and Errors() repeat every one of them.
		initial = group.ErrCode
		rest = append(rest, group.rest...)
	case errors.ErrorGroup:
		// Copy the members so the appends below can never write into a
		// backing array owned by the group.
		rest = append(rest, group.Errors()...)
	}
	for _, other := range others {
		rest = append(rest, errors.Errors(other)...)
	}
	return MultiErrCode{
		ErrCode: initial,
		rest:    rest,
	}
}

// Compile-time assertions that MultiErrCode implements these interfaces.
var (
	_ ErrorCode         = (*MultiErrCode)(nil)
	_ HasClientData     = (*MultiErrCode)(nil)
	_ Causer            = (*MultiErrCode)(nil)
	_ errors.ErrorGroup = (*MultiErrCode)(nil)
	_ fmt.Formatter     = (*MultiErrCode)(nil)
)

// Error joins the message of ErrCode and the messages of the remaining
// errors, separated by "; ".
func (e MultiErrCode) Error() string {
	var b strings.Builder
	b.WriteString(e.ErrCode.Error())
	for _, item := range e.rest {
		b.WriteString("; ")
		b.WriteString(item.Error())
	}
	return b.String()
}

// Errors fulfills the ErrorGroup interface.
// It returns ErrCode followed by the remaining errors in a newly allocated
// slice, so callers may modify the result without affecting e.
func (e MultiErrCode) Errors() []error {
	errs := make([]error, 0, 1+len(e.rest))
	errs = append(errs, e.ErrCode.(error))
	return append(errs, e.rest...)
}

// Code fulfills the ErrorCode interface by returning the code of ErrCode.
func (e MultiErrCode) Code() Code {
	return e.ErrCode.Code()
}

// Cause fulfills the Causer interface by returning ErrCode.
func (e MultiErrCode) Cause() error {
	return e.ErrCode
}

// GetClientData fulfills the HasClientData interface using the client data
// of ErrCode.
func (e MultiErrCode) GetClientData() interface{} {
	return ClientData(e.ErrCode)
}

// CodeChain resolves an error chain down to a chain of just error codes.
// The chain is walked from err using errors.Unwrap.
// Any ErrorGroups found are converted to a MultiErrCode.
// Passed-over error information is retained using ChainContext.
// If a code was overridden in the chain, it will show up as a MultiErrCode.
// CodeChain returns nil when no ErrorCode is found.
func CodeChain(err error) ErrorCode {
	var code ErrorCode
	currentErr := err
	chainErrCode := func(errcode ErrorCode) {
		// When errcode is not the error currently being inspected, wrap it
		// so the message of currentErr is retained as the Top error.
		if !isSameErrorValue(errcode.(error), currentErr) {
			if chained, ok := errcode.(ChainContext); ok {
				// Perhaps this is a hack because we should be passing the context to recursive CodeChain calls
				chained.Top = currentErr
				errcode = chained
			} else {
				errcode = ChainContext{Top: currentErr, ErrCode: errcode}
			}
		}
		if code == nil {
			code = errcode
		} else {
			// code stays the primary code and errcode is recorded once.
			// MultiErrCode.Errors already prepends ErrCode, so also listing
			// code in rest would duplicate it in Errors() and Error().
			code = MultiErrCode{ErrCode: code, rest: []error{errcode.(error)}}
		}
		currentErr = errcode.(error)
	}

	for err != nil {
		if errcode, ok := err.(ErrorCode); ok {
			if code == nil || code.Code() != errcode.Code() {
				chainErrCode(errcode)
			}
		} else if eg, ok := err.(errors.ErrorGroup); ok {
			var group []ErrorCode
			for _, errItem := range eg.Errors() {
				if itemCode := CodeChain(errItem); itemCode != nil {
					group = append(group, itemCode)
				}
			}
			if len(group) > 0 {
				var codeGroup ErrorCode
				if len(group) == 1 {
					codeGroup = group[0]
				} else {
					codeGroup = Combine(group[0], group[1:]...)
				}
				chainErrCode(codeGroup)
			}
		}
		err = errors.Unwrap(err)
	}

	return code
}

// isSameErrorValue reports whether a == b, treating values that cannot be
// compared as different. Comparing two interface values panics at run time
// when their shared dynamic type is not comparable (MultiErrCode, for
// instance, holds a slice), so the panic is recovered instead of crashing
// the caller.
func isSameErrorValue(a, b error) (same bool) {
	defer func() {
		if recover() != nil {
			same = false
		}
	}()
	return a == b
}

// ChainContext is returned by CodeChain
// to retain the full wrapped error message of the error chain.
// If you annotated an ErrorCode with additional information, it is retained in the Top field.
// The Top field is used for the Error() and Cause() methods.
type ChainContext struct {
	// Top is the outermost error whose message is reported by Error.
	Top error

	// ErrCode is the code found beneath Top; it supplies Code and client data.
	ErrCode ErrorCode
}

// Code satisfies the ErrorCode interface by reporting the code of the
// wrapped ErrorCode.
func (err ChainContext) Code() Code {
	inner := err.ErrCode
	return inner.Code()
}

// Error satisfies the error interface. The message comes from Top, so any
// annotation placed around the code is preserved.
func (err ChainContext) Error() string {
	top := err.Top
	return top.Error()
}

// Cause satisfies the Causer interface. It returns whatever Top wraps
// (per errors.Unwrap) and falls back to ErrCode when Top wraps nothing.
func (err ChainContext) Cause() error {
	next := errors.Unwrap(err.Top)
	if next == nil {
		return err.ErrCode
	}
	return next
}

// GetClientData satisfies the HasClientData interface using the client data
// of the wrapped ErrorCode.
func (err ChainContext) GetClientData() interface{} {
	return ClientData(err.ErrCode)
}

// Compile-time assertions that ChainContext implements these interfaces.
var (
	_ ErrorCode     = (*ChainContext)(nil)
	_ HasClientData = (*ChainContext)(nil)
	_ Causer        = (*ChainContext)(nil)
)

// Format implements the fmt.Formatter interface.
//
//	%+v     ErrCode with %+v and a newline, then Top: with %v when ErrCode
//	        already has a stack trace, otherwise with %+v.
//	%#v     ChainContext{Code: <ErrCode %#v>, Top: <Top %#v>}
//	%v, %s  "Code: <code string>. Top Error: <Top>"
//	%q      like %s, with both parts quoted.
//
// Other verbs produce no output.
func (err ChainContext) Format(s fmt.State, verb rune) {
	switch {
	case verb == 'v' && s.Flag('+'):
		fmt.Fprintf(s, "%+v\n", err.ErrCode)
		if !errors.HasStack(err.ErrCode) {
			fmt.Fprintf(s, "%+v", err.Top)
		} else {
			fmt.Fprintf(s, "%v", err.Top)
		}
	case verb == 'v' && s.Flag('#'):
		fmt.Fprintf(s, "ChainContext{Code: %#v, Top: %#v}", err.ErrCode, err.Top)
	case verb == 'v', verb == 's':
		fmt.Fprintf(s, "Code: %s. Top Error: %s", err.ErrCode.Code().CodeStr(), err.Top)
	case verb == 'q':
		fmt.Fprintf(s, "Code: %q. Top Error: %q", err.ErrCode.Code().CodeStr(), err.Top)
	}
}

// Format implements the fmt.Formatter interface.
//
//	%+v     ErrCode with %+v and a newline, then each remaining error on its
//	        own line. The remaining errors use %+v only when ErrCode has no
//	        stack trace; otherwise they use %v.
//	%v, %s  ErrCode with %s and a newline, then the remaining errors in
//	        fmt's bracketed slice form.
//	%q      like %s, with every element quoted.
//
// Other verbs produce no output.
func (e MultiErrCode) Format(s fmt.State, verb rune) {
	switch verb {
	case 'v':
		if s.Flag('+') {
			fmt.Fprintf(s, "%+v\n", e.ErrCode)
			withStacks := !errors.HasStack(e.ErrCode)
			for i, nextErr := range e.rest {
				if i > 0 {
					// Put each remaining error on its own line instead of
					// letting consecutive entries run together.
					fmt.Fprint(s, "\n")
				}
				if withStacks {
					fmt.Fprintf(s, "%+v", nextErr)
				} else {
					fmt.Fprintf(s, "%v", nextErr)
				}
			}
			return
		}
		fallthrough
	case 's':
		fmt.Fprintf(s, "%s\n", e.ErrCode)
		fmt.Fprintf(s, "%s", e.rest)
	case 'q':
		fmt.Fprintf(s, "%q\n", e.ErrCode)
		// No trailing newline, matching the %s form.
		fmt.Fprintf(s, "%q", e.rest)
	}
}