package sanitize
import (
	"bytes"
	"encoding/hex"
	"fmt"
	"slices"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode/utf8"
)
// Part is one piece of a lexed query: a string holds raw SQL text to be
// emitted verbatim, while an int holds the 1-based placeholder number N of
// a "$N" token to be replaced by an argument during Sanitize.
type Part any
// Query is a SQL statement decomposed into raw-text and placeholder parts.
type Query struct {
	Parts []Part
}
// utf8.DecodeRuneInString returns utf8.RuneError both for decode failures
// (empty input or invalid UTF-8, reported with width 0 or 1) and for a
// literal U+FFFD replacement character in the input (width 3). The lexer
// states below compare the width against this constant to tell a genuine
// U+FFFD apart from end-of-input.
const replacementcharacterwidth = 3

// maxBufSize caps the size of buffers returned to bufPool; buffers that grew
// beyond this are dropped rather than pooled.
const maxBufSize = 16384 // 16KiB
var bufPool = &pool[*bytes.Buffer]{
new: func() *bytes.Buffer {
return &bytes.Buffer{}
},
reset: func( *bytes.Buffer) bool {
:= .Len()
.Reset()
return < maxBufSize
},
}
// null is the SQL literal emitted for a nil argument.
var null = []byte("null")
func ( *Query) ( ...any) (string, error) {
:= make([]bool, len())
:= bufPool.get()
defer bufPool.put()
for , := range .Parts {
switch part := .(type) {
case string:
.WriteString()
case int:
:= - 1
var []byte
if < 0 {
return "", fmt.Errorf("first sql argument must be > 0")
}
if >= len() {
return "", fmt.Errorf("insufficient arguments")
}
.WriteByte(' ')
:= []
switch arg := .(type) {
case nil:
= null
case int64:
= strconv.AppendInt(.AvailableBuffer(), , 10)
case float64:
= strconv.AppendFloat(.AvailableBuffer(), , 'f', -1, 64)
case bool:
= strconv.AppendBool(.AvailableBuffer(), )
case []byte:
= QuoteBytes(.AvailableBuffer(), )
case string:
= QuoteString(.AvailableBuffer(), )
case time.Time:
= .Truncate(time.Microsecond).
AppendFormat(.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
default:
return "", fmt.Errorf("invalid arg type: %T", )
}
[] = true
.Write()
.WriteByte(' ')
default:
return "", fmt.Errorf("invalid Part type: %T", )
}
}
for , := range {
if ! {
return "", fmt.Errorf("unused argument: %d", )
}
}
return .String(), nil
}
func ( string) (*Query, error) {
:= &Query{}
.init()
return , nil
}
var sqlLexerPool = &pool[*sqlLexer]{
new: func() *sqlLexer {
return &sqlLexer{}
},
reset: func( *sqlLexer) bool {
* = sqlLexer{}
return true
},
}
func ( *Query) ( string) {
:= .Parts[:0]
if == nil {
:= strings.Count(, "$") + strings.Count(, "--") + 1
= make([]Part, 0, )
}
:= sqlLexerPool.get()
defer sqlLexerPool.put()
.src =
.stateFn = rawState
.parts =
for .stateFn != nil {
.stateFn = .stateFn()
}
.Parts = .parts
}
// QuoteString appends str to dst as a single-quoted SQL string literal,
// doubling any embedded single quotes, and returns the extended slice.
func QuoteString(dst []byte, str string) []byte {
	const quote = '\''

	// Worst case: every byte is a quote that must be doubled, plus the two
	// surrounding quotes.
	dst = slices.Grow(dst, len(str)*2+2)

	dst = append(dst, quote)
	// Byte-wise scan is safe: '\'' is ASCII, so the bytes of multi-byte
	// runes never match it and pass through intact.
	for i := 0; i < len(str); i++ {
		if str[i] == quote {
			dst = append(dst, quote, quote)
		} else {
			dst = append(dst, str[i])
		}
	}
	dst = append(dst, quote)

	return dst
}
// QuoteBytes appends buf to dst as a PostgreSQL hex-format bytea literal
// (e.g. '\xdeadbeef') and returns the extended slice.
func QuoteBytes(dst, buf []byte) []byte {
	if len(buf) == 0 {
		return append(dst, `'\x'`...)
	}

	// '\x prefix (3 bytes) + hex digits + closing quote.
	n := 3 + hex.EncodedLen(len(buf)) + 1

	// Grow dst manually so hex.Encode can write in place.
	if cap(dst)-len(dst) < n {
		grown := make([]byte, len(dst), len(dst)+n)
		copy(grown, dst)
		dst = grown
	}

	start := len(dst)
	dst = dst[:start+n]
	dst[start] = '\''
	dst[start+1] = '\\'
	dst[start+2] = 'x'
	// Encode directly into the reserved region, leaving the last byte for
	// the closing quote.
	hex.Encode(dst[start+3:len(dst)-1], buf)
	dst[len(dst)-1] = '\''

	return dst
}
// sqlLexer holds the scanning state for the mini-lexer that splits a SQL
// string into raw-text and placeholder parts.
type sqlLexer struct {
	src     string // input being scanned
	start   int    // byte offset where the current raw-text part began
	pos     int    // byte offset of the next rune to scan
	nested  int    // current /* */ block-comment nesting depth
	stateFn stateFn
	parts   []Part // accumulated output parts
}

// stateFn scans from the lexer's current position and returns the next state
// function, or nil when the input is exhausted.
type stateFn func(*sqlLexer) stateFn
func ( *sqlLexer) stateFn {
for {
, := utf8.DecodeRuneInString(.src[.pos:])
.pos +=
switch {
case 'e', 'E':
, := utf8.DecodeRuneInString(.src[.pos:])
if == '\'' {
.pos +=
return escapeStringState
}
case '\'':
return singleQuoteState
case '"':
return doubleQuoteState
case '$':
, := utf8.DecodeRuneInString(.src[.pos:])
if '0' <= && <= '9' {
if .pos-.start > 0 {
.parts = append(.parts, .src[.start:.pos-])
}
.start = .pos
return placeholderState
}
case '-':
, := utf8.DecodeRuneInString(.src[.pos:])
if == '-' {
.pos +=
return oneLineCommentState
}
case '/':
, := utf8.DecodeRuneInString(.src[.pos:])
if == '*' {
.pos +=
return multilineCommentState
}
case utf8.RuneError:
if != replacementcharacterwidth {
if .pos-.start > 0 {
.parts = append(.parts, .src[.start:.pos])
.start = .pos
}
return nil
}
}
}
}
func ( *sqlLexer) stateFn {
for {
, := utf8.DecodeRuneInString(.src[.pos:])
.pos +=
switch {
case '\'':
, := utf8.DecodeRuneInString(.src[.pos:])
if != '\'' {
return rawState
}
.pos +=
case utf8.RuneError:
if != replacementcharacterwidth {
if .pos-.start > 0 {
.parts = append(.parts, .src[.start:.pos])
.start = .pos
}
return nil
}
}
}
}
func ( *sqlLexer) stateFn {
for {
, := utf8.DecodeRuneInString(.src[.pos:])
.pos +=
switch {
case '"':
, := utf8.DecodeRuneInString(.src[.pos:])
if != '"' {
return rawState
}
.pos +=
case utf8.RuneError:
if != replacementcharacterwidth {
if .pos-.start > 0 {
.parts = append(.parts, .src[.start:.pos])
.start = .pos
}
return nil
}
}
}
}
func ( *sqlLexer) stateFn {
:= 0
for {
, := utf8.DecodeRuneInString(.src[.pos:])
.pos +=
if '0' <= && <= '9' {
*= 10
+= int( - '0')
} else {
.parts = append(.parts, )
.pos -=
.start = .pos
return rawState
}
}
}
func ( *sqlLexer) stateFn {
for {
, := utf8.DecodeRuneInString(.src[.pos:])
.pos +=
switch {
case '\\':
_, = utf8.DecodeRuneInString(.src[.pos:])
.pos +=
case '\'':
, := utf8.DecodeRuneInString(.src[.pos:])
if != '\'' {
return rawState
}
.pos +=
case utf8.RuneError:
if != replacementcharacterwidth {
if .pos-.start > 0 {
.parts = append(.parts, .src[.start:.pos])
.start = .pos
}
return nil
}
}
}
}
func ( *sqlLexer) stateFn {
for {
, := utf8.DecodeRuneInString(.src[.pos:])
.pos +=
switch {
case '\\':
_, = utf8.DecodeRuneInString(.src[.pos:])
.pos +=
case '\n', '\r':
return rawState
case utf8.RuneError:
if != replacementcharacterwidth {
if .pos-.start > 0 {
.parts = append(.parts, .src[.start:.pos])
.start = .pos
}
return nil
}
}
}
}
func ( *sqlLexer) stateFn {
for {
, := utf8.DecodeRuneInString(.src[.pos:])
.pos +=
switch {
case '/':
, := utf8.DecodeRuneInString(.src[.pos:])
if == '*' {
.pos +=
.nested++
}
case '*':
, := utf8.DecodeRuneInString(.src[.pos:])
if != '/' {
continue
}
.pos +=
if .nested == 0 {
return rawState
}
.nested--
case utf8.RuneError:
if != replacementcharacterwidth {
if .pos-.start > 0 {
.parts = append(.parts, .src[.start:.pos])
.start = .pos
}
return nil
}
}
}
}
var queryPool = &pool[*Query]{
new: func() *Query {
return &Query{}
},
reset: func( *Query) bool {
:= len(.Parts)
.Parts = .Parts[:0]
return < 64
},
}
func ( string, ...any) (string, error) {
:= queryPool.get()
.init()
defer queryPool.put()
return .Sanitize(...)
}
// pool is a typed wrapper around sync.Pool. new constructs a value when the
// pool is empty; reset prepares a value for reuse and reports whether it
// should be returned to the pool.
type pool[E any] struct {
	p     sync.Pool
	new   func() E
	reset func(E) bool
}

// get returns a pooled value, constructing a fresh one if the pool is empty.
func (p *pool[E]) get() E {
	v, ok := p.p.Get().(E)
	if !ok {
		// Pool was empty: Get returned nil, which fails the assertion.
		v = p.new()
	}

	return v
}

// put resets v and returns it to the pool unless reset rejects it.
func (p *pool[E]) put(v E) {
	if p.reset(v) {
		p.p.Put(v)
	}
}