// Copyright 2019 Tim Shannon. All rights reserved.
// Use of this source code is governed by the MIT license
// that can be found in the LICENSE file.

package badgerhold

import (
	"fmt"
	"reflect"
	"sort"

	"github.com/dgraph-io/badger/v3"
)

// AggregateResult allows you to access the results of an aggregate query.
// Each AggregateResult holds the records that fell into one grouping together
// with the grouped-by values; read them through its methods (Group,
// Reduction, Sort, Min, Max, Avg, Sum, Count).
type AggregateResult struct {
	// reduction holds the records in this grouping. Entries are always
	// pointers: every accessor calls Elem() on them before reading fields.
	reduction []reflect.Value // always pointers
	// group holds the values handed out by Group; the i-th Group argument
	// receives group[i] (presumably one entry per groupBy field — confirm in
	// aggregateQuery).
	group []reflect.Value
	// sortby is the field name reduction is currently ordered by, used by Sort
	// to skip re-sorting. Empty until Sort is first called.
	sortby string
}

// Group returns the field(s) grouped by in the query by copying them into the
// passed in pointers: the i-th argument receives the i-th grouped-by value.
//
// Arguments are filled in order. Group panics if an argument is not a non-nil
// pointer, or if more arguments are passed than there are grouped values.
// Arguments before the offending one have already been set when it panics.
// A type mismatch between an argument and its group value panics inside
// reflect.Value.Set.
func (a *AggregateResult) Group(result ...interface{}) {
	for i := range result {
		resultVal := reflect.ValueOf(result[i])
		if resultVal.Kind() != reflect.Ptr {
			panic("result argument must be an address")
		}

		// Same guard as Max / Min: without it a typed nil pointer fails deep
		// inside reflect with an unhelpful "Set on zero Value" message.
		if resultVal.IsNil() {
			panic("result argument must not be nil")
		}

		if i >= len(a.group) {
			panic(fmt.Sprintf("There are only %d elements in the grouping, but %d results were passed",
				len(a.group), len(result)))
		}

		resultVal.Elem().Set(a.group[i])
	}
}

// Reduction is the collection of records that are part of the AggregateResult
// Group. The records are appended to the slice that result points to.
//
// If the slice's element type is a pointer, the stored record pointers are
// appended as-is. Otherwise each record is dereferenced and its value copied
// in. Reduction panics if result is not a pointer to a slice.
func (a *AggregateResult) Reduction(result interface{}) {
	target := reflect.ValueOf(result)
	if !(target.Kind() == reflect.Ptr && target.Elem().Kind() == reflect.Slice) {
		panic("result argument must be a slice address")
	}

	out := target.Elem()
	wantPointers := out.Type().Elem().Kind() == reflect.Ptr

	for _, record := range a.reduction {
		if !wantPointers {
			// reduction entries are pointers; copy the record value itself
			record = record.Elem()
		}
		out = reflect.Append(out, record)
	}

	target.Elem().Set(out)
}

// aggregateResultSort adapts AggregateResult to sort.Interface. It orders the
// reduction by the field named in sortby.
type aggregateResultSort AggregateResult

// Len returns the number of records in the reduction.
func (a *aggregateResultSort) Len() int { return len(a.reduction) }

// Swap exchanges records i and j in the reduction.
func (a *aggregateResultSort) Swap(i, j int) {
	a.reduction[i], a.reduction[j] = a.reduction[j], a.reduction[i]
}

// Less reports whether record i's sortby field orders before record j's.
// It panics if either record lacks the field or if compare returns an error.
func (a *aggregateResultSort) Less(i, j int) bool {
	iVal := a.sortField(i)
	jVal := a.sortField(j)

	c, err := compare(iVal.Interface(), jVal.Interface())
	if err != nil {
		panic(err)
	}

	// Treat any negative result as "less". compare (or a Comparer
	// implementation it may delegate to) is not guaranteed to return exactly
	// -1, e.g. a Compare that returns a difference.
	return c < 0
}

// sortField returns the sortby field of the i-th record. It panics with a
// descriptive message if the record's type has no such field.
func (a *aggregateResultSort) sortField(i int) reflect.Value {
	// reduction values are always pointers, so dereference before lookup
	v := a.reduction[i].Elem().FieldByName(a.sortby)
	if !v.IsValid() {
		panic(fmt.Sprintf("The field %s does not exist in the type %s", a.sortby, a.reduction[i].Type()))
	}
	return v
}

// Sort sorts the aggregate reduction by the passed in field in ascending order.
// Sort is called automatically by calls to Min / Max to get the min and max
// values. Sorting again by the field the reduction is already ordered by is a
// no-op.
//
// It panics if field does not start with an upper-case letter, if a record has
// no such field, or if two field values cannot be compared.
func (a *AggregateResult) Sort(field string) {
	if !startsUpper(field) {
		panic("The first letter of a field must be upper-case")
	}
	if a.sortby == field {
		// already sorted
		return
	}

	// The comparator reads a.sortby, so it has to be set before sorting.
	// If sorting panics part-way (missing field, incomparable values), the
	// reduction is left partially ordered. In that case clear the cached field
	// so a later Sort / Min / Max (after a recover) re-sorts instead of
	// trusting a stale cache.
	a.sortby = field
	completed := false
	defer func() {
		if !completed {
			a.sortby = ""
		}
	}()

	sort.Sort((*aggregateResultSort)(a))
	completed = true
}

// Max returns the maximum value of the Aggregate Grouping, uses the Comparer
// interface. It sorts the reduction by field and copies the last (largest)
// record into the value pointed to by result.
//
// It panics if result is not a non-nil pointer, if the grouping is empty, or
// for any reason Sort panics.
func (a *AggregateResult) Max(field string, result interface{}) {
	a.setSortedRecord(field, result, true)
}

// Min returns the minimum value of the Aggregate Grouping, uses the Comparer
// interface. It sorts the reduction by field and copies the first (smallest)
// record into the value pointed to by result.
//
// It panics if result is not a non-nil pointer, if the grouping is empty, or
// for any reason Sort panics.
func (a *AggregateResult) Min(field string, result interface{}) {
	a.setSortedRecord(field, result, false)
}

// setSortedRecord validates result, sorts the reduction by field, and stores
// the last record (last == true) or the first record into *result.
func (a *AggregateResult) setSortedRecord(field string, result interface{}, last bool) {
	// Validate the destination before reordering the reduction, so a bad
	// argument does not leave side effects behind.
	resultVal := reflect.ValueOf(result)
	if resultVal.Kind() != reflect.Ptr {
		panic("result argument must be an address")
	}

	if resultVal.IsNil() {
		panic("result argument must not be nil")
	}

	a.Sort(field)

	// Without this check an empty grouping fails with an opaque
	// "index out of range" runtime panic.
	if len(a.reduction) == 0 {
		panic("There are no records in the aggregate grouping")
	}

	idx := 0
	if last {
		idx = len(a.reduction) - 1
	}

	// reduction values are always pointers; copy the pointed-to record
	resultVal.Elem().Set(a.reduction[idx].Elem())
}

// Avg returns the arithmetic mean of the named field across the records in
// the aggregate grouping, as a float64.
// It panics under the same conditions as Sum (missing field, or a field that
// cannot be converted to a float64). For an empty grouping the result is NaN,
// because the division is 0/0 in floating point.
func (a *AggregateResult) Avg(field string) float64 {
	count := float64(len(a.reduction))
	return a.Sum(field) / count
}

// Sum returns the total of the named field across every record in the
// aggregate grouping, as a float64.
// It panics if a record has no such field, or if tryFloat cannot convert the
// field's value to a float64.
func (a *AggregateResult) Sum(field string) float64 {
	total := 0.0

	for _, record := range a.reduction {
		// records are stored as pointers; dereference before the field lookup
		value := record.Elem().FieldByName(field)
		if !value.IsValid() {
			panic(fmt.Sprintf("The field %s does not exist in the type %s", field, record.Type()))
		}
		total += tryFloat(value)
	}

	return total
}

// Count reports how many records belong to this aggregate grouping.
func (a *AggregateResult) Count() uint64 {
	n := len(a.reduction)
	return uint64(n)
}

// FindAggregate runs the passed in query for dataType inside a Badger View
// transaction and returns the resulting aggregate groupings.
// groupBy is optional. If the query or the transaction fails, the returned
// slice is nil.
func (s *Store) FindAggregate(dataType interface{}, query *Query, groupBy ...string) ([]*AggregateResult, error) {
	var groups []*AggregateResult

	viewErr := s.Badger().View(func(tx *badger.Txn) error {
		found, err := s.TxFindAggregate(tx, dataType, query, groupBy...)
		groups = found
		return err
	})
	if viewErr != nil {
		return nil, viewErr
	}

	return groups, nil
}

// TxFindAggregate behaves like FindAggregate, but runs inside the
// caller-supplied transaction instead of opening its own View transaction.
// groupBy is optional.
func (s *Store) TxFindAggregate(tx *badger.Txn, dataType interface{}, query *Query, groupBy ...string) ([]*AggregateResult, error) {
	groups, err := s.aggregateQuery(tx, dataType, query, groupBy...)
	return groups, err
}

// tryFloat converts a numeric reflect.Value to a float64.
//
// It supports the signed integer kinds, the unsigned integer kinds (including
// uintptr) and both float kinds. Any other kind, including an invalid Value,
// panics with a message naming the kind. Integers beyond 2^53 in magnitude
// lose precision, as with any conversion to float64.
func tryFloat(val reflect.Value) float64 {
	switch val.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return float64(val.Int())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Uintptr:
		// Value.Uint accepts Uintptr as well as the sized unsigned kinds.
		return float64(val.Uint())
	case reflect.Float32, reflect.Float64:
		return val.Float()
	default:
		panic(fmt.Sprintf("The field is of Kind %s and cannot be converted to a float64", val.Kind()))
	}
}