package pgx
import (
	"context"
	"fmt"
	"strconv"
	"strings"
	"unicode/utf8"
)
type NamedArgs map[string]any
func ( NamedArgs) ( context.Context, *Conn, string, []any) ( string, []any, error) {
return rewriteQuery(, , false)
}
type StrictNamedArgs map[string]any
func ( StrictNamedArgs) ( context.Context, *Conn, string, []any) ( string, []any, error) {
return rewriteQuery(, , true)
}
// Types backing the small state-machine lexer used by rewriteQuery.
type (
	// namedArg is a placeholder name as written in the SQL, without the
	// leading '@'.
	namedArg string

	// sqlLexer holds the scanning state. States consume runes from src and
	// append to parts either raw SQL text (string) or a placeholder (namedArg).
	sqlLexer struct {
		src           string           // SQL text being scanned
		start         int              // byte offset where the pending, not-yet-emitted segment begins
		pos           int              // byte offset of the next rune to read
		nested        int              // depth of /* */ comments nested inside the outermost one
		stateFn       stateFn          // state to run next; nil once scanning is finished
		parts         []any            // emitted segments: string for raw SQL, namedArg for placeholders
		nameToOrdinal map[namedArg]int // 1-based $N ordinal assigned to each distinct name
	}

	// stateFn is one lexer state: it scans some input and returns the state
	// to run next, or nil to stop.
	stateFn func(*sqlLexer) stateFn
)
func ( map[string]any, string, bool) ( string, []any, error) {
:= &sqlLexer{
src: ,
stateFn: rawState,
nameToOrdinal: make(map[namedArg]int, len()),
}
for .stateFn != nil {
.stateFn = .stateFn()
}
:= strings.Builder{}
for , := range .parts {
switch p := .(type) {
case string:
.WriteString()
case namedArg:
.WriteRune('$')
.WriteString(strconv.Itoa(.nameToOrdinal[]))
}
}
= make([]any, len(.nameToOrdinal))
for , := range .nameToOrdinal {
var bool
[-1], = [string()]
if && ! {
return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", )
}
}
if {
for := range {
if , := .nameToOrdinal[namedArg()]; ! {
return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", )
}
}
}
return .String(), , nil
}
func ( *sqlLexer) stateFn {
for {
, := utf8.DecodeRuneInString(.src[.pos:])
.pos +=
switch {
case 'e', 'E':
, := utf8.DecodeRuneInString(.src[.pos:])
if == '\'' {
.pos +=
return escapeStringState
}
case '\'':
return singleQuoteState
case '"':
return doubleQuoteState
case '@':
, := utf8.DecodeRuneInString(.src[.pos:])
if isLetter() || == '_' {
if .pos-.start > 0 {
.parts = append(.parts, .src[.start:.pos-])
}
.start = .pos
return namedArgState
}
case '-':
, := utf8.DecodeRuneInString(.src[.pos:])
if == '-' {
.pos +=
return oneLineCommentState
}
case '/':
, := utf8.DecodeRuneInString(.src[.pos:])
if == '*' {
.pos +=
return multilineCommentState
}
case utf8.RuneError:
if .pos-.start > 0 {
.parts = append(.parts, .src[.start:.pos])
.start = .pos
}
return nil
}
}
}
// isLetter reports whether r is an ASCII letter (a-z or A-Z). Non-ASCII
// letters are deliberately not accepted.
func isLetter(r rune) bool {
	return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z')
}
func ( *sqlLexer) stateFn {
for {
, := utf8.DecodeRuneInString(.src[.pos:])
.pos +=
if == utf8.RuneError {
if .pos-.start > 0 {
:= namedArg(.src[.start:.pos])
if , := .nameToOrdinal[]; ! {
.nameToOrdinal[] = len(.nameToOrdinal) + 1
}
.parts = append(.parts, )
.start = .pos
}
return nil
} else if !(isLetter() || ( >= '0' && <= '9') || == '_') {
.pos -=
:= namedArg(.src[.start:.pos])
if , := .nameToOrdinal[]; ! {
.nameToOrdinal[] = len(.nameToOrdinal) + 1
}
.parts = append(.parts, namedArg())
.start = .pos
return rawState
}
}
}
func ( *sqlLexer) stateFn {
for {
, := utf8.DecodeRuneInString(.src[.pos:])
.pos +=
switch {
case '\'':
, := utf8.DecodeRuneInString(.src[.pos:])
if != '\'' {
return rawState
}
.pos +=
case utf8.RuneError:
if .pos-.start > 0 {
.parts = append(.parts, .src[.start:.pos])
.start = .pos
}
return nil
}
}
}
func ( *sqlLexer) stateFn {
for {
, := utf8.DecodeRuneInString(.src[.pos:])
.pos +=
switch {
case '"':
, := utf8.DecodeRuneInString(.src[.pos:])
if != '"' {
return rawState
}
.pos +=
case utf8.RuneError:
if .pos-.start > 0 {
.parts = append(.parts, .src[.start:.pos])
.start = .pos
}
return nil
}
}
}
func ( *sqlLexer) stateFn {
for {
, := utf8.DecodeRuneInString(.src[.pos:])
.pos +=
switch {
case '\\':
_, = utf8.DecodeRuneInString(.src[.pos:])
.pos +=
case '\'':
, := utf8.DecodeRuneInString(.src[.pos:])
if != '\'' {
return rawState
}
.pos +=
case utf8.RuneError:
if .pos-.start > 0 {
.parts = append(.parts, .src[.start:.pos])
.start = .pos
}
return nil
}
}
}
func ( *sqlLexer) stateFn {
for {
, := utf8.DecodeRuneInString(.src[.pos:])
.pos +=
switch {
case '\\':
_, = utf8.DecodeRuneInString(.src[.pos:])
.pos +=
case '\n', '\r':
return rawState
case utf8.RuneError:
if .pos-.start > 0 {
.parts = append(.parts, .src[.start:.pos])
.start = .pos
}
return nil
}
}
}
func ( *sqlLexer) stateFn {
for {
, := utf8.DecodeRuneInString(.src[.pos:])
.pos +=
switch {
case '/':
, := utf8.DecodeRuneInString(.src[.pos:])
if == '*' {
.pos +=
.nested++
}
case '*':
, := utf8.DecodeRuneInString(.src[.pos:])
if != '/' {
continue
}
.pos +=
if .nested == 0 {
return rawState
}
.nested--
case utf8.RuneError:
if .pos-.start > 0 {
.parts = append(.parts, .src[.start:.pos])
.start = .pos
}
return nil
}
}
}