// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package ast

import (
	"bytes"
	"fmt"
	"strings"

	"github.com/juju/errors"
	"github.com/pingcap/tidb/model"
	"github.com/pingcap/tidb/util/distinct"
	"github.com/pingcap/tidb/util/types"
)

// Compile-time checks that the function expression nodes declared in this
// file satisfy FuncNode. Typed nil pointers are enough for the check; no
// value has to be constructed.
var (
	_ FuncNode = (*AggregateFuncExpr)(nil)
	_ FuncNode = (*FuncCallExpr)(nil)
	_ FuncNode = (*FuncCastExpr)(nil)
)

// UnquoteString is a string type whose values are not quoted when printed.
// NOTE(review): the printing logic that honors this type is not in this
// file — confirm where the unquoted formatting is applied.
type UnquoteString string

// FuncCallExpr is for function expression.
// It pairs a function name with the list of argument expressions it is
// applied to. Accept walks Args in order and may replace individual entries
// with the nodes returned by the visitor.
type FuncCallExpr struct {
	funcNode
	// FnName is the function name.
	FnName model.CIStr
	// Args is the function args.
	Args []ExprNode
}

// Accept implements Node interface.
// The node is first offered to v.Enter. If the visitor asks to skip the
// children, the node it returned goes straight to v.Leave. Otherwise every
// argument is visited in order and replaced by the visitor's result; the
// first argument whose visit reports failure aborts the walk with
// (n, false). Finally the node is handed to v.Leave.
func (n *FuncCallExpr) Accept(v Visitor) (Node, bool) {
	entered, skip := v.Enter(n)
	if skip {
		return v.Leave(entered)
	}
	// The visitor may hand back a different node; continue with that one.
	n = entered.(*FuncCallExpr)
	for idx, arg := range n.Args {
		visited, ok := arg.Accept(v)
		if !ok {
			return n, false
		}
		n.Args[idx] = visited.(ExprNode)
	}
	return v.Leave(n)
}

// CastFunctionType is the type for cast function.
// It is stored in FuncCastExpr.FunctionType.
type CastFunctionType int

// CastFunction types.
// Numbering starts at 1, so the zero value of CastFunctionType does not
// name any of them.
const (
	// CastFunction presumably marks the CAST(expr AS type) form — confirm
	// against the parser.
	CastFunction CastFunctionType = iota + 1
	// CastConvertFunction presumably marks the CONVERT form — confirm
	// against the parser.
	CastConvertFunction
	// CastBinaryOperator presumably marks the BINARY operator form —
	// confirm against the parser.
	CastBinaryOperator
)

// FuncCastExpr is the cast function converting value to another type, e.g, cast(expr AS signed).
// See https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html
//
// Accept visits Expr as the node's only child and may replace it with the
// node returned by the visitor.
type FuncCastExpr struct {
	funcNode
	// Expr is the expression to be converted.
	Expr ExprNode
	// Tp is the conversion type.
	Tp *types.FieldType
	// FunctionType tells which form this node came from:
	// Cast, Convert and Binary share this struct.
	FunctionType CastFunctionType
}

// Accept implements Node Accept interface.
// After v.Enter, and unless the visitor asks to skip children, the single
// child Expr is visited and replaced by the visitor's result. A failed
// child visit returns (n, false) without calling v.Leave.
func (n *FuncCastExpr) Accept(v Visitor) (Node, bool) {
	entered, skip := v.Enter(n)
	if skip {
		return v.Leave(entered)
	}
	// The visitor may hand back a different node; continue with that one.
	n = entered.(*FuncCastExpr)
	if visited, ok := n.Expr.Accept(v); ok {
		n.Expr = visited.(ExprNode)
		return v.Leave(n)
	}
	return n, false
}

// TrimDirectionType is the type for trim direction.
type TrimDirectionType int

// Trim directions. TrimBothDefault is the zero value of TrimDirectionType.
const (
	// TrimBothDefault trims from both directions by default.
	TrimBothDefault TrimDirectionType = iota
	// TrimBoth trims from both directions with explicit notation.
	TrimBoth
	// TrimLeading trims from left.
	TrimLeading
	// TrimTrailing trims from right.
	TrimTrailing
)

// DateArithType is the type that selects the date arithmetic operation.
type DateArithType byte

// Date arithmetic operations. Numbering starts at 1, so the zero value of
// DateArithType does not name either of them.
const (
	// DateAdd is to run adddate or date_add function option.
	// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_adddate
	// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add
	DateAdd DateArithType = iota + 1
	// DateSub is to run subdate or date_sub function option.
	// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subdate
	// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub
	DateSub
)

// DateArithInterval is the struct of DateArith interval part.
// Unit holds the interval unit as a string and Interval is the expression
// paired with it.
// NOTE(review): the accepted spellings of Unit are defined outside this
// file — confirm against the parser/evaluator.
type DateArithInterval struct {
	Unit     string
	Interval ExprNode
}

// Names of the supported aggregate functions. AggregateFuncExpr.Update
// lower-cases AggregateFuncExpr.F and dispatches on these values, so they
// are all spelled in lower case.
const (
	// AggFuncCount is the name of Count function.
	AggFuncCount = "count"
	// AggFuncSum is the name of Sum function.
	AggFuncSum = "sum"
	// AggFuncAvg is the name of Avg function.
	AggFuncAvg = "avg"
	// AggFuncFirstRow is the name of the first-row function. It is the name
	// AggregateFuncExtractor gives to the aggregate it wraps around a column
	// that appears outside any aggregate function.
	AggFuncFirstRow = "firstrow"
	// AggFuncMax is the name of max function.
	AggFuncMax = "max"
	// AggFuncMin is the name of min function.
	AggFuncMin = "min"
	// AggFuncGroupConcat is the name of group_concat function.
	AggFuncGroupConcat = "group_concat"
)

// AggregateFuncExpr represents aggregate function expression.
// Besides the parsed shape (F, Args, Distinct) it carries the running
// evaluation state, one AggEvaluateContext per group.
type AggregateFuncExpr struct {
	funcNode
	// F is the function name. Update compares it case-insensitively
	// against the AggFunc* constants.
	F string
	// Args is the function args.
	Args []ExprNode
	// Distinct reports whether the function only aggregates distinct values.
	// For example, column c1 values are "1", "2", "2",  "sum(c1)" is "5",
	// but "sum(distinct c1)" is "3".
	Distinct bool

	// CurrentGroup is the key of the group being evaluated; GetContext uses
	// it to select the entry in contextPerGroupMap.
	CurrentGroup string
	// contextPerGroupMap is used to store aggregate evaluation context.
	// Each entry for a group. It is created lazily by GetContext and
	// dropped by Clear.
	contextPerGroupMap map[string](*AggEvaluateContext)
}

// Accept implements Node Accept interface.
// Unless the visitor asks to skip children in v.Enter, each argument is
// visited in order and replaced by the node the visitor returns. The walk
// stops with (n, false) at the first argument whose visit fails; otherwise
// the node is passed to v.Leave.
func (n *AggregateFuncExpr) Accept(v Visitor) (Node, bool) {
	entered, skip := v.Enter(n)
	if skip {
		return v.Leave(entered)
	}
	// The visitor may hand back a different node; continue with that one.
	n = entered.(*AggregateFuncExpr)
	for idx, arg := range n.Args {
		visited, ok := arg.Accept(v)
		if !ok {
			return n, false
		}
		n.Args[idx] = visited.(ExprNode)
	}
	return v.Leave(n)
}

// Clear clears aggregate computing context.
// It drops every per-group evaluation context and resets the current group
// key to the empty string.
func (n *AggregateFuncExpr) Clear() {
	n.contextPerGroupMap, n.CurrentGroup = nil, ""
}

// Update is used for update aggregate context.
// It lower-cases the function name and forwards to the matching updater for
// the current group. A name that matches none of the AggFunc* constants is
// ignored and nil is returned.
func (n *AggregateFuncExpr) Update() error {
	switch fn := strings.ToLower(n.F); fn {
	case AggFuncSum, AggFuncAvg:
		// avg shares the sum updater, which tracks both the sum and the count.
		return n.updateSum()
	case AggFuncMax, AggFuncMin:
		return n.updateMaxMin(fn == AggFuncMax)
	case AggFuncCount:
		return n.updateCount()
	case AggFuncGroupConcat:
		return n.updateGroupConcat()
	case AggFuncFirstRow:
		return n.updateFirstRow()
	default:
		return nil
	}
}

// GetContext gets aggregate evaluation context for the current group.
// If the group has no context yet, a new one is created, stored in
// contextPerGroupMap under CurrentGroup and returned. The map itself is
// allocated on first use. For a Distinct aggregate the new context is given
// its own distinct checker.
func (n *AggregateFuncExpr) GetContext() *AggEvaluateContext {
	if n.contextPerGroupMap == nil {
		n.contextPerGroupMap = make(map[string](*AggEvaluateContext))
	}
	// A single lookup serves both the existence test and the result.
	if c, ok := n.contextPerGroupMap[n.CurrentGroup]; ok {
		return c
	}
	c := &AggEvaluateContext{}
	if n.Distinct {
		c.distinctChecker = distinct.CreateDistinctChecker()
	}
	n.contextPerGroupMap[n.CurrentGroup] = c
	return c
}

// updateCount adds the current row to the count of the current group.
// A row is ignored when any argument evaluates to nil, or, for a Distinct
// aggregate, when the distinct checker reports that the argument values are
// not new. Errors from the distinct checker are returned traced.
func (n *AggregateFuncExpr) updateCount() error {
	// Fetch the context first so the group gets one even if the row is skipped.
	aggCtx := n.GetContext()
	row := make([]interface{}, 0, len(n.Args))
	for _, arg := range n.Args {
		v := arg.GetValue()
		if v == nil {
			return nil
		}
		row = append(row, v)
	}
	if n.Distinct {
		isNew, err := aggCtx.distinctChecker.Check(row)
		switch {
		case err != nil:
			return errors.Trace(err)
		case !isNew:
			return nil
		}
	}
	aggCtx.Count++
	return nil
}

// updateFirstRow records the argument's value for the first row seen in the
// current group and ignores every later row. It returns an error when the
// function does not have exactly one argument; that check only runs while
// the group has no recorded value yet.
func (n *AggregateFuncExpr) updateFirstRow() error {
	aggCtx := n.GetContext()
	switch {
	case aggCtx.evaluated:
		return nil
	case len(n.Args) != 1:
		return errors.New("Wrong number of args for AggFuncFirstRow")
	}
	aggCtx.Value, aggCtx.evaluated = n.Args[0].GetValue(), true
	return nil
}

// updateMaxMin folds the current row into the running maximum (max == true)
// or minimum (max == false) of the current group.
//
// It returns an error when the function does not have exactly one argument,
// or when comparing the stored value with the new one fails. A nil (NULL)
// argument value is skipped, as the count, sum and group_concat updaters do,
// so NULL neither seeds the result nor takes part in a comparison. A group
// that only sees NULL keeps a nil Value.
func (n *AggregateFuncExpr) updateMaxMin(max bool) error {
	ctx := n.GetContext()
	if len(n.Args) != 1 {
		return errors.New("Wrong number of args for AggFuncMaxMin")
	}
	v := n.Args[0].GetValue()
	if v == nil {
		return nil
	}
	if !ctx.evaluated {
		// The first non-NULL value seeds the result.
		ctx.Value = v
		ctx.evaluated = true
		return nil
	}
	c, err := types.Compare(ctx.Value, v)
	if err != nil {
		return errors.Trace(err)
	}
	// Replace the stored value only on a strict improvement: c == -1 when
	// looking for the maximum, c == 1 when looking for the minimum.
	if (max && c == -1) || (!max && c == 1) {
		ctx.Value = v
	}
	return nil
}

// updateSum folds the current row into the running sum of the current group
// and counts it. Update uses it for both sum and avg.
//
// It returns an error when the function does not have exactly one argument
// (instead of panicking on an empty argument list), and returns traced
// errors from the distinct checker and from types.CalculateSum. A nil (NULL)
// argument value is skipped, and for a Distinct aggregate so is a value the
// distinct checker reports as not new.
func (n *AggregateFuncExpr) updateSum() error {
	ctx := n.GetContext()
	if len(n.Args) != 1 {
		return errors.New("Wrong number of args for AggFuncSum/AggFuncAvg")
	}
	value := n.Args[0].GetValue()
	if value == nil {
		return nil
	}
	if n.Distinct {
		d, err := ctx.distinctChecker.Check([]interface{}{value})
		if err != nil {
			return errors.Trace(err)
		}
		if !d {
			return nil
		}
	}
	var err error
	ctx.Value, err = types.CalculateSum(ctx.Value, value)
	if err != nil {
		return errors.Trace(err)
	}
	ctx.Count++
	return nil
}

// updateGroupConcat appends the current row to the group_concat buffer of
// the current group. The values of all arguments are written back to back
// using their default formatting, and rows are separated by a comma. A row
// is ignored when any argument evaluates to nil, or, for a Distinct
// aggregate, when the distinct checker reports the values as not new.
func (n *AggregateFuncExpr) updateGroupConcat() error {
	aggCtx := n.GetContext()
	row := make([]interface{}, 0, len(n.Args))
	for _, arg := range n.Args {
		v := arg.GetValue()
		if v == nil {
			return nil
		}
		row = append(row, v)
	}
	if n.Distinct {
		isNew, err := aggCtx.distinctChecker.Check(row)
		switch {
		case err != nil:
			return errors.Trace(err)
		case !isNew:
			return nil
		}
	}
	if aggCtx.Buffer != nil {
		// Not the first row of the group: now use comma separator.
		aggCtx.Buffer.WriteString(",")
	} else {
		aggCtx.Buffer = &bytes.Buffer{}
	}
	for _, v := range row {
		fmt.Fprintf(aggCtx.Buffer, "%v", v)
	}
	// TODO: if total length is greater than global var group_concat_max_len, truncate it.
	return nil
}

// AggregateFuncExtractor visits Expr tree.
// It converts ColumnNameExpr to AggregateFuncExpr and collects AggregateFuncExpr.
//
// extracting becomes true after the first node has been entered; from then
// on Enter skips the children of any nested statement node.
type AggregateFuncExtractor struct {
	// inAggregateFuncExpr is set by Enter on an AggregateFuncExpr and cleared
	// by Leave on one; Leave consults it to decide whether a column needs
	// wrapping.
	inAggregateFuncExpr bool
	// AggFuncs is the collected AggregateFuncExprs.
	AggFuncs   []*AggregateFuncExpr
	extracting bool
}

// Enter implements Visitor interface.
// It notes when the walk goes into an aggregate function. Once extraction
// has started, the children of a nested select/insert/delete/update
// statement are skipped; the first node entered is never skipped.
func (a *AggregateFuncExtractor) Enter(n Node) (node Node, skipChildren bool) {
	switch n.(type) {
	case *SelectStmt, *InsertStmt, *DeleteStmt, *UpdateStmt:
		// Enter a new context, skip it.
		// For example: select sum(c) + c + exists(select c from t) from t;
		if a.extracting {
			return n, true
		}
	case *AggregateFuncExpr:
		a.inAggregateFuncExpr = true
	}
	a.extracting = true
	return n, false
}

// Leave implements Visitor interface.
// An AggregateFuncExpr is appended to AggFuncs as-is. A ColumnNameExpr met
// outside any aggregate function is wrapped in a new firstrow aggregate,
// which is collected and returned in place of the column. Every other node
// is returned unchanged. The second result is always true.
func (a *AggregateFuncExtractor) Leave(n Node) (node Node, ok bool) {
	switch expr := n.(type) {
	case *ColumnNameExpr:
		if a.inAggregateFuncExpr {
			// Columns inside an aggregate are evaluated for each row; keep them.
			break
		}
		// For example: select sum(c) + c from t;
		// The c in sum() should be evaluated for each row.
		// The c after plus should be evaluated only once.
		wrapped := &AggregateFuncExpr{
			F:    AggFuncFirstRow,
			Args: []ExprNode{expr},
		}
		a.AggFuncs = append(a.AggFuncs, wrapped)
		return wrapped, true
	case *AggregateFuncExpr:
		a.AggFuncs = append(a.AggFuncs, expr)
		a.inAggregateFuncExpr = false
	}
	return n, true
}

// AggEvaluateContext is used to store intermediate results when calculating aggregate functions.
// One context exists per group (see AggregateFuncExpr.GetContext).
//
// distinctChecker is only set for Distinct aggregates. Count is advanced by
// the count and sum updaters. Value holds the running sum, the running
// max/min, or the first-row value. evaluated marks that Value has been
// seeded by the first-row or max/min updater.
type AggEvaluateContext struct {
	distinctChecker *distinct.Checker
	Count           int64
	Value           interface{}
	Buffer          *bytes.Buffer // Buffer is used for group_concat.
	evaluated       bool
}