package sanitize

import (
	
	
	
	
	
	
	
	
	
)

// Part is either a string or an int. A string is raw SQL. An int is an
// argument placeholder.
type Part any

// Query is SQL text split into Parts: raw SQL strings interleaved with
// 1-based placeholder numbers (placeholder N refers to the N-th argument
// passed to Sanitize).
type Query struct {
	Parts []Part
}

const (
	// utf8.DecodeRuneInString reports undecodable input as utf8.RuneError, but
	// that value is also the legitimate rune U+FFFD -- the Unicode replacement
	// character. A RuneError decoded with a width of 3 bytes is a genuine
	// U+FFFD in the input, not an error.
	//
	// https://github.com/jackc/pgx/issues/1380
	replacementcharacterwidth = 3

	// maxBufSize is the size threshold at or above which a scratch buffer is
	// discarded instead of being returned to bufPool.
	maxBufSize = 16384 // 16 Ki
)

var bufPool = &pool[*bytes.Buffer]{
	new: func() *bytes.Buffer {
		return &bytes.Buffer{}
	},
	reset: func( *bytes.Buffer) bool {
		 := .Len()
		.Reset()
		return  < maxBufSize
	},
}

var null = []byte("null")

// Sanitize renders q as one SQL string in which every placeholder is replaced
// by the literal form of its argument (placeholder N takes args[N-1]).
//
// Supported argument types are:
//   - nil, written as null
//   - int64, float64 and bool
//   - []byte, as a hex bytea literal
//   - string, single-quoted with embedded quotes doubled
//   - time.Time, truncated to microseconds and single-quoted
//
// Each substituted literal is padded with a space on both sides.
//
// An error is returned when:
//   - a placeholder is below 1
//   - a placeholder has no matching argument
//   - an argument has any other type
//   - a Part is neither a string nor an int
//   - some argument is never referenced
func (q *Query) Sanitize(args ...any) (string, error) {
	used := make([]bool, len(args))
	buf := bufPool.get()
	defer bufPool.put(buf)

	for _, part := range q.Parts {
		var argIdx int
		switch p := part.(type) {
		case string:
			buf.WriteString(p)
			continue
		case int:
			argIdx = p - 1
		default:
			return "", fmt.Errorf("invalid Part type: %T", p)
		}

		if argIdx < 0 {
			return "", fmt.Errorf("first sql argument must be > 0")
		}
		if argIdx >= len(args) {
			return "", fmt.Errorf("insufficient arguments")
		}

		// Prevent SQL injection via Line Comment Creation
		// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
		buf.WriteByte(' ')

		// Literals are appended into the buffer's spare capacity, which avoids
		// an intermediate allocation when the buffer already has room.
		var literal []byte
		switch a := args[argIdx].(type) {
		case nil:
			literal = null
		case int64:
			literal = strconv.AppendInt(buf.AvailableBuffer(), a, 10)
		case float64:
			literal = strconv.AppendFloat(buf.AvailableBuffer(), a, 'f', -1, 64)
		case bool:
			literal = strconv.AppendBool(buf.AvailableBuffer(), a)
		case []byte:
			literal = QuoteBytes(buf.AvailableBuffer(), a)
		case string:
			literal = QuoteString(buf.AvailableBuffer(), a)
		case time.Time:
			literal = a.Truncate(time.Microsecond).
				AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
		default:
			return "", fmt.Errorf("invalid arg type: %T", a)
		}
		used[argIdx] = true
		buf.Write(literal)

		// Same line-comment protection on the trailing side of the literal.
		buf.WriteByte(' ')
	}

	if i := slices.Index(used, false); i >= 0 {
		return "", fmt.Errorf("unused argument: %d", i)
	}
	return buf.String(), nil
}

func ( string) (*Query, error) {
	 := &Query{}
	.init()

	return , nil
}

var sqlLexerPool = &pool[*sqlLexer]{
	new: func() *sqlLexer {
		return &sqlLexer{}
	},
	reset: func( *sqlLexer) bool {
		* = sqlLexer{}
		return true
	},
}

// init lexes sql into q.Parts. It reuses the existing Parts backing array when
// there is one, as there is for Queries recycled through queryPool. Otherwise
// it preallocates a slice sized by a cheap estimate of the number of Parts.
func (q *Query) init(sql string) {
	lex := sqlLexerPool.get()
	defer sqlLexerPool.put(lex)

	parts := q.Parts[:0]
	if parts == nil {
		// Dirty but fast heuristic that preallocates enough for ~90% of use cases.
		parts = make([]Part, 0, 1+strings.Count(sql, "$")+strings.Count(sql, "--"))
	}

	lex.src = sql
	lex.parts = parts

	// Run the state machine until a state function signals completion with nil.
	for lex.stateFn = rawState; lex.stateFn != nil; {
		lex.stateFn = lex.stateFn(lex)
	}

	q.Parts = lex.parts
}

// QuoteString appends str to dst as a single-quoted SQL string literal and
// returns the extended slice. Embedded single quotes are doubled and no other
// byte is escaped. The result is therefore only a safe literal when
// standard_conforming_strings is on.
func QuoteString(dst []byte, str string) []byte {
	const quote = '\''

	// Reserve room for the worst case (every byte a quote) plus both delimiters.
	dst = slices.Grow(dst, len(str)*2+2)
	dst = append(dst, quote)

	// Copy whole runs between quotes, emitting each quote twice.
	rest := str
	for len(rest) > 0 {
		i := strings.IndexByte(rest, quote)
		if i < 0 {
			dst = append(dst, rest...)
			break
		}
		dst = append(dst, rest[:i+1]...)
		dst = append(dst, quote)
		rest = rest[i+1:]
	}

	return append(dst, quote)
}

func (,  []byte) []byte {
	if len() == 0 {
		return append(, `'\x'`...)
	}

	// Calculate required length
	 := 3 + hex.EncodedLen(len()) + 1

	// Ensure dst has enough capacity
	if cap()-len() <  {
		 := make([]byte, len(), len()+)
		copy(, )
		 = 
	}

	// Record original length and extend slice
	 := len()
	 = [:+]

	// Add prefix
	[] = '\''
	[+1] = '\\'
	[+2] = 'x'

	// Encode bytes directly into dst
	hex.Encode([+3:len()-1], )

	// Add suffix
	[len()-1] = '\''

	return 
}

// sqlLexer is the state of one scan of SQL text into Parts. Each state
// function advances pos through src and appends completed Parts to parts.
type sqlLexer struct {
	src     string  // SQL text being scanned
	start   int     // byte offset where the not-yet-emitted raw SQL begins
	pos     int     // byte offset of the next rune to decode
	nested  int     // multiline comment nesting level.
	stateFn stateFn // state function to run next; nil ends the scan
	parts   []Part  // Parts emitted so far
}

// stateFn is one lexer state. It consumes input from the lexer and returns
// the state to run next, or nil when scanning is finished.
type stateFn func(*sqlLexer) stateFn

// rawState scans ordinary SQL outside any string, quoted identifier or
// comment. It hands off to the matching state when one of those begins. On a
// $N placeholder it first emits the raw SQL gathered so far as a string Part,
// then switches to placeholderState. Scanning ends (nil is returned) at end
// of input or at the first invalid UTF-8 byte, after any pending raw SQL is
// emitted.
func rawState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += width

		switch r {
		case 'e', 'E':
			// E'...' is an escape string only when the E starts a token. In
			// input such as time'12:00' the e ends an identifier, and
			// PostgreSQL lexes the quote as an ordinary standard-conforming
			// string. Backslashes must not be treated as escapes there, or the
			// sanitizer and the server disagree about where the string ends.
			if ePos := l.pos - width; ePos > 0 && isIdentContByte(l.src[ePos-1]) {
				continue
			}
			nextRune, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune == '\'' {
				l.pos += nextWidth
				return escapeStringState
			}
		case '\'':
			return singleQuoteState
		case '"':
			return doubleQuoteState
		case '$':
			nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
			if '0' <= nextRune && nextRune <= '9' {
				// Emit the raw SQL before the '$', skipping it when it is empty
				// (placeholder at the start of the text or right after another).
				if rawEnd := l.pos - width; rawEnd > l.start {
					l.parts = append(l.parts, l.src[l.start:rawEnd])
				}
				l.start = l.pos
				return placeholderState
			}
		case '-':
			nextRune, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune == '-' {
				l.pos += nextWidth
				return oneLineCommentState
			}
		case '/':
			nextRune, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune == '*' {
				l.pos += nextWidth
				return multilineCommentState
			}
		case utf8.RuneError:
			if width != replacementcharacterwidth {
				if l.pos-l.start > 0 {
					l.parts = append(l.parts, l.src[l.start:l.pos])
					l.start = l.pos
				}
				return nil
			}
		}
	}
}

// isIdentContByte reports whether c can continue a PostgreSQL identifier or
// keyword. That covers an ASCII letter, a digit, '_', '$', or any byte of a
// multi-byte UTF-8 sequence.
func isIdentContByte(c byte) bool {
	switch {
	case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', '0' <= c && c <= '9':
		return true
	case c == '_' || c == '$' || c >= utf8.RuneSelf:
		return true
	}
	return false
}

// singleQuoteState scans a standard-conforming '...' string; the opening
// quote has already been consumed. A doubled '' is an escaped quote and
// stays inside the string. Backslashes are ordinary characters. After the
// closing quote the lexer returns to rawState. At end of input or on an
// invalid UTF-8 byte it emits the pending raw SQL and stops scanning.
func singleQuoteState(l *sqlLexer) stateFn {
	for {
		r, n := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += n

		if r == '\'' {
			next, m := utf8.DecodeRuneInString(l.src[l.pos:])
			if next != '\'' {
				return rawState
			}
			l.pos += m // '' escape: still inside the string
			continue
		}

		if r == utf8.RuneError && n != replacementcharacterwidth {
			if l.pos > l.start {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
	}
}

// doubleQuoteState scans a "..." quoted identifier; the opening quote has
// already been consumed. A doubled "" is an escaped quote and stays inside
// the identifier. After the closing quote the lexer returns to rawState. At
// end of input or on an invalid UTF-8 byte it emits the pending raw SQL and
// stops scanning.
func doubleQuoteState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += width

		switch {
		case r == '"':
			next, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if next == '"' {
				l.pos += nextWidth // "" escape: still inside the identifier
				continue
			}
			return rawState
		case r == utf8.RuneError && width != replacementcharacterwidth:
			if l.pos > l.start {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
	}
}

// placeholderState consumes a placeholder value. The $ must have already has
// already been consumed. The first rune must be a digit.
func ( *sqlLexer) stateFn {
	 := 0

	for {
		,  := utf8.DecodeRuneInString(.src[.pos:])
		.pos += 

		if '0' <=  &&  <= '9' {
			 *= 10
			 += int( - '0')
		} else {
			.parts = append(.parts, )
			.pos -= 
			.start = .pos
			return rawState
		}
	}
}

// escapeStringState scans an E'...' escape string; the E and the opening
// quote have already been consumed. A backslash escapes the rune after it,
// and a doubled '' is an escaped quote. After the closing quote the lexer
// returns to rawState. At end of input or on an invalid UTF-8 byte it emits
// the pending raw SQL and stops scanning.
func escapeStringState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += width

		switch {
		case r == '\\':
			// Skip the escaped rune, whatever it is (including a quote).
			_, skip := utf8.DecodeRuneInString(l.src[l.pos:])
			l.pos += skip
		case r == '\'':
			next, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if next == '\'' {
				l.pos += nextWidth
				continue
			}
			return rawState
		case r == utf8.RuneError && width != replacementcharacterwidth:
			if l.pos > l.start {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
	}
}

// oneLineCommentState consumes a -- comment (the leading -- has already been
// consumed) through the first \n or \r, then returns to rawState. At end of
// input or on an invalid UTF-8 byte it emits the pending raw SQL and stops
// scanning.
//
// PostgreSQL has no escape sequences inside comments. A backslash is an
// ordinary character and cannot hide the newline that ends the comment.
// Treating it as an escape here would make the sanitizer believe the comment
// continues onto the next line while the server runs that line as SQL, so
// placeholders and quotes on it would be misinterpreted.
func oneLineCommentState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += width

		switch r {
		case '\n', '\r':
			return rawState
		case utf8.RuneError:
			if width != replacementcharacterwidth {
				if l.pos-l.start > 0 {
					l.parts = append(l.parts, l.src[l.start:l.pos])
					l.start = l.pos
				}
				return nil
			}
		}
	}
}

// multilineCommentState consumes a /* ... */ comment; the opening /* has
// already been consumed. Comments nest: each inner /* must be matched by its
// own */ before the outermost */ returns the lexer to rawState. At end of
// input or on an invalid UTF-8 byte it emits the pending raw SQL and stops
// scanning.
func multilineCommentState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		l.pos += width

		if r == utf8.RuneError && width != replacementcharacterwidth {
			if l.pos > l.start {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}

		if r != '/' && r != '*' {
			continue
		}

		next, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
		switch {
		case r == '/' && next == '*':
			l.pos += nextWidth
			l.nested++
		case r == '*' && next == '/':
			l.pos += nextWidth
			if l.nested == 0 {
				return rawState
			}
			l.nested--
		}
	}
}

var queryPool = &pool[*Query]{
	new: func() *Query {
		return &Query{}
	},
	reset: func( *Query) bool {
		 := len(.Parts)
		.Parts = .Parts[:0]
		return  < 64 // drop too large queries
	},
}

// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is
// on.
func ( string,  ...any) (string, error) {
	 := queryPool.get()
	.init()
	defer queryPool.put()

	return .Sanitize(...)
}

type pool[ any] struct {
	p     sync.Pool
	new   func() 
	reset func() bool
}

func ( *pool[]) ()  {
	,  := .p.Get().()
	if ! {
		 = .new()
	}

	return 
}

func ( *pool[]) ( ) {
	if .reset() {
		.p.Put()
	}
}