package pgx
import (
	"errors"
)
// ConnConfig contains everything needed to open a Conn: the embedded
// pgconn.Config with the low-level connection settings, plus pgx-level options.
// Build it with ParseConfig or ParseConfigWithOptions; connecting with a
// ConnConfig that was not produced by them panics (see createdByParseConfig).
type ConnConfig struct {
pgconn.Config
// Tracer becomes the connection's query tracer. If it also implements
// ConnectTracer, BatchTracer, CopyFromTracer or PrepareTracer, it is used for
// those events as well.
Tracer QueryTracer
// connString is the connection string this config was parsed from.
connString string
// StatementCacheCapacity is the size of the per-connection LRU cache of named
// prepared statements (default 512, settable with the
// "statement_cache_capacity" connection-string parameter). A value <= 0
// disables the cache, and QueryExecModeCacheStatement then fails with an error.
StatementCacheCapacity int
// DescriptionCacheCapacity is the size of the per-connection LRU cache of
// statement descriptions (default 512, settable with
// "description_cache_capacity"). A value <= 0 disables the cache, and
// QueryExecModeCacheDescribe then fails with an error.
DescriptionCacheCapacity int
// DefaultQueryExecMode is used when a query does not pass a QueryExecMode
// explicitly (default QueryExecModeCacheStatement, settable with
// "default_query_exec_mode").
DefaultQueryExecMode QueryExecMode
// createdByParseConfig is set only by the parsing function. Connecting with a
// config that lacks it panics, presumably so the defaults filled in during
// parsing are always present.
createdByParseConfig bool
}
// ParseConfigOptions holds options for ParseConfigWithOptions. It embeds
// pgconn.ParseConfigOptions, which is passed through to
// pgconn.ParseConfigWithOptions.
type ParseConfigOptions struct {
pgconn.ParseConfigOptions
}
// NOTE(review): the receiver, method and local identifiers are missing from this
// copy of the source, so it does not compile as shown. This is presumably
// ConnConfig.Copy — confirm.
//
// The method returns a newly allocated *ConnConfig that is a shallow copy of the
// receiver. The one exception is the embedded pgconn.Config, which is replaced by
// the value of its own Copy method. Its nested data is therefore copied according
// to pgconn.Config.Copy instead of being shared by the plain struct assignment.
// All other fields (e.g. Tracer) are copied by plain assignment.
func ( *ConnConfig) () *ConnConfig {
:= new(ConnConfig)
* = *
.Config = *.Config.Copy()
return
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// ConnConfig.ConnString — confirm.
//
// The method returns the unexported connString field. The parsing function sets
// that field to the connection string the config was parsed from.
func ( *ConnConfig) () string { return .connString }
// Conn is a PostgreSQL connection built on a *pgconn.PgConn. On top of the
// low-level connection it adds:
//   - prepared-statement bookkeeping and statement/description caches;
//   - tracing;
//   - buffered notifications;
//   - type mapping.
//
// NOTE(review): nothing in this type synchronizes access to its maps and slices.
// Presumably a Conn is not meant for concurrent use — confirm against the
// package documentation.
type Conn struct {
// pgConn is the underlying low-level connection.
pgConn *pgconn.PgConn
// config is the config the connection was opened with.
config *ConnConfig
// preparedStatements maps the key used when preparing to the statement
// description. The key is the name given to the prepare method, which may be
// the SQL text itself.
preparedStatements map[string]*pgconn.StatementDescription
// failedDescribeStatement is the key of a statement whose prepare failed with a
// *pgconn.PrepareError. It is deallocated at the start of the next prepare.
failedDescribeStatement string
// statementCache caches named prepared statements (nil when disabled).
statementCache stmtcache.Cache
// descriptionCache caches statement descriptions only (nil when disabled).
descriptionCache stmtcache.Cache
// queryTracer is config.Tracer. Each of the three tracers below is set only
// when queryTracer also implements that interface.
queryTracer QueryTracer
batchTracer BatchTracer
copyFromTracer CopyFromTracer
prepareTracer PrepareTracer
// notifications buffers notifications received by the default notification
// handler until they are consumed.
notifications []*pgconn.Notification
// doneChan and closedChan are created on connect. They are not otherwise used
// in this part of the file.
doneChan chan struct{}
closedChan chan error
// typeMap maps PostgreSQL type OIDs to the types/codecs used to encode and
// decode values.
typeMap *pgtype.Map
// wbuf is allocated on connect with capacity 1024. It is presumably a reusable
// write buffer; it is not used in this part of the file.
wbuf []byte
// eqb builds parameter values and formats for extended-protocol queries. It is
// reused across queries and reset after use.
eqb ExtendedQueryBuilder
}
// Identifier is a possibly qualified SQL identifier, with one element per
// dot-separated part. Its sanitizing method renders it as a quoted, dot-joined
// string.
type Identifier []string
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Identifier.Sanitize — confirm.
//
// The method renders the identifier for inclusion in SQL text:
//   - NUL bytes are stripped from every part;
//   - embedded double quotes are doubled;
//   - each part is wrapped in double quotes;
//   - the parts are joined with ".".
func ( Identifier) () string {
:= make([]string, len())
for := range {
// Strip NUL bytes first, presumably because PostgreSQL text cannot contain them.
:= strings.ReplaceAll([], string([]byte{0}), "")
[] = `"` + strings.ReplaceAll(, `"`, `""`) + `"`
}
return strings.Join(, ".")
}
// ErrNoRows reports an empty result set ("no rows in result set"). It is a proxy
// error whose background error is sql.ErrNoRows. The proxy type exposes that
// error through its unwrap-style method, so presumably errors.Is(err,
// sql.ErrNoRows) also matches — confirm.
var ErrNoRows = newProxyErr(sql.ErrNoRows, "no rows in result set")

// ErrTooManyRows reports that more rows were returned than allowed
// ("too many rows in result set").
var ErrTooManyRows = errors.New("too many rows in result set")
// NOTE(review): identifiers are missing from this copy of the source; presumably
// newProxyErr(background error, msg string) error — confirm.
//
// It returns a *proxyError. The error's text is the string argument, and it
// carries the error argument as its background (underlying) error.
func ( error, string) error {
return &proxyError{
msg: ,
background: ,
}
}
// proxyError is an error with its own message that also carries an underlying
// ("background") error.
type proxyError struct {
// msg is the text returned by the error-message method.
msg string
// background is the underlying error returned by the unwrap-style method.
background error
}
// NOTE(review): the method name is missing from this copy of the source. Because
// *proxyError is returned as an error, this is presumably Error — confirm.
// It returns only msg; the background error is not included in the text.
func ( *proxyError) () string { return .msg }
// NOTE(review): the method name is missing from this copy of the source;
// presumably Unwrap, which would let errors.Is / errors.As reach the background
// error — confirm.
func ( *proxyError) () error { return .background }
// Errors returned when a query asks for a cache-backed QueryExecMode but the
// corresponding cache is disabled. A cache is disabled when its configured
// capacity is 0 or less, so it was never created.
//
// The messages contain no formatting verbs and wrap no error, so errors.New is
// used. fmt.Errorf without verbs or %w produces an identical error value and
// only adds formatting overhead.
var (
	errDisabledStatementCache   = errors.New("cannot use QueryExecModeCacheStatement with disabled statement cache")
	errDisabledDescriptionCache = errors.New("cannot use QueryExecModeCacheDescribe with disabled description cache")
)
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Connect(ctx context.Context, connString string) (*Conn, error) — confirm.
//
// It parses the connection string with ParseConfig and then opens a connection
// with connect. A parse error is returned as-is with a nil *Conn.
func ( context.Context, string) (*Conn, error) {
, := ParseConfig()
if != nil {
return nil,
}
return connect(, )
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// ConnectWithOptions(ctx context.Context, connString string, options
// ParseConfigOptions) (*Conn, error) — confirm.
//
// It works like the plain connection-string variant, except that the string is
// parsed with ParseConfigWithOptions using the given options. A parse error is
// returned as-is with a nil *Conn.
func ( context.Context, string, ParseConfigOptions) (*Conn, error) {
, := ParseConfigWithOptions(, )
if != nil {
return nil,
}
return connect(, )
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error)
// — confirm.
//
// It opens a connection from an existing config. The config must have been
// produced by ParseConfig or ParseConfigWithOptions, otherwise connect panics.
func ( context.Context, *ConnConfig) (*Conn, error) {
// Work on a copy. connect may install a notification handler into the config
// it receives, and that change must not leak into the caller's config.
= .Copy()
return connect(, )
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// ParseConfigWithOptions(connString string, options ParseConfigOptions)
// (*ConnConfig, error) — confirm.
//
// It parses a connection string with pgconn.ParseConfigWithOptions. It then
// extracts the pgx-specific runtime parameters, removing each one from
// RuntimeParams once consumed:
//
//   - statement_cache_capacity (integer, default 512)
//   - description_cache_capacity (integer, default 512)
//   - default_query_exec_mode: cache_statement (default), cache_describe,
//     describe_exec, exec or simple_protocol
//
// Integer values must fit in 32 bits. Zero or negative capacities are accepted
// and disable the corresponding cache. Invalid values produce an error built with
// pgconn.NewParseConfigError. The returned config is marked as created by
// parsing, which connect requires.
func ( string, ParseConfigOptions) (*ConnConfig, error) {
, := pgconn.ParseConfigWithOptions(, .ParseConfigOptions)
if != nil {
return nil,
}
:= 512
if , := .RuntimeParams["statement_cache_capacity"]; {
// Consumed here. Presumably removing it keeps it from being treated as an
// ordinary server runtime parameter — confirm.
delete(.RuntimeParams, "statement_cache_capacity")
, := strconv.ParseInt(, 10, 32)
if != nil {
return nil, pgconn.NewParseConfigError(, "cannot parse statement_cache_capacity", )
}
= int()
}
:= 512
if , := .RuntimeParams["description_cache_capacity"]; {
delete(.RuntimeParams, "description_cache_capacity")
, := strconv.ParseInt(, 10, 32)
if != nil {
return nil, pgconn.NewParseConfigError(, "cannot parse description_cache_capacity", )
}
= int()
}
:= QueryExecModeCacheStatement
if , := .RuntimeParams["default_query_exec_mode"]; {
delete(.RuntimeParams, "default_query_exec_mode")
switch {
case "cache_statement":
= QueryExecModeCacheStatement
case "cache_describe":
= QueryExecModeCacheDescribe
case "describe_exec":
= QueryExecModeDescribeExec
case "exec":
= QueryExecModeExec
case "simple_protocol":
= QueryExecModeSimpleProtocol
default:
return nil, pgconn.NewParseConfigError(
, "invalid default_query_exec_mode", fmt.Errorf("unknown value %q", ),
)
}
}
:= &ConnConfig{
Config: *,
createdByParseConfig: true,
StatementCacheCapacity: ,
DescriptionCacheCapacity: ,
DefaultQueryExecMode: ,
connString: ,
}
return , nil
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// ParseConfig(connString string) (*ConnConfig, error) — confirm.
//
// It is ParseConfigWithOptions with zero-value ParseConfigOptions.
func ( string) (*ConnConfig, error) {
return ParseConfigWithOptions(, ParseConfigOptions{})
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) — confirm.
//
// connect opens the connection described by config and initializes the Conn's
// state:
//   - tracers;
//   - the notification buffer handler;
//   - the prepared-statement map;
//   - channels and the write buffer;
//   - the statement and description caches, when their capacities are > 0.
//
// It panics if config was not created by ParseConfig or ParseConfigWithOptions.
// It may modify config (see the OnNotification handling below), so callers that
// must not see that change pass a copy.
func ( context.Context, *ConnConfig) ( *Conn, error) {
if , := .Tracer.(ConnectTracer); {
= .TraceConnectStart(, TraceConnectStartData{ConnConfig: })
// The closure reads the named results when it runs, so the tracer sees the
// final connection and error.
defer func() {
.TraceConnectEnd(, TraceConnectEndData{Conn: , Err: })
}()
}
// NOTE(review): this check runs after TraceConnectStart. If it panics, the
// deferred TraceConnectEnd above still runs during unwinding and reports a nil
// Conn with a nil Err. Consider doing this check before the tracer setup.
if !.createdByParseConfig {
panic("config must be created by ParseConfig")
}
= &Conn{
config: ,
typeMap: pgtype.NewMap(),
queryTracer: .Tracer,
}
// Enable whichever optional tracer interfaces the query tracer also implements.
if , := .queryTracer.(BatchTracer); {
.batchTracer =
}
if , := .queryTracer.(CopyFromTracer); {
.copyFromTracer =
}
if , := .queryTracer.(PrepareTracer); {
.prepareTracer =
}
// Buffer notifications on the Conn unless a handler is already configured.
// This writes into the config passed in.
if .Config.OnNotification == nil {
.Config.OnNotification = .bufferNotifications
}
.pgConn, = pgconn.ConnectConfig(, &.Config)
if != nil {
return nil,
}
.preparedStatements = make(map[string]*pgconn.StatementDescription)
.doneChan = make(chan struct{})
.closedChan = make(chan error)
.wbuf = make([]byte, 0, 1024)
// A capacity <= 0 leaves the cache nil, which makes the matching exec mode fail.
if .config.StatementCacheCapacity > 0 {
.statementCache = stmtcache.NewLRUCache(.config.StatementCacheCapacity)
}
if .config.DescriptionCacheCapacity > 0 {
.descriptionCache = stmtcache.NewLRUCache(.config.DescriptionCacheCapacity)
}
return , nil
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.Close(ctx context.Context) error, with ctx passed to pgConn.Close — confirm.
//
// It closes the underlying connection and returns the error from doing so.
// Closing an already-closed connection does nothing and returns nil.
func ( *Conn) ( context.Context) error {
if .IsClosed() {
return nil
}
:= .pgConn.Close()
return
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.Prepare(ctx context.Context, name, sql string)
// (sd *pgconn.StatementDescription, err error) — confirm.
//
// Prepare creates a prepared statement for sql on the server and returns its
// description. It records the description in preparedStatements under a key:
//   - If name equals sql, the server-side name is "stmt_" followed by the hex
//     of the first 24 bytes of sql's SHA-256 digest, and the key is sql.
//   - Otherwise the server-side name and the key are both name. An empty name
//     is sent as-is and the result is not recorded.
//
// If name is non-empty and already recorded with the same SQL, the recorded
// description is returned without contacting the server.
//
// If the previous prepare failed with a *pgconn.PrepareError, that statement is
// deallocated first, and any error from doing so is returned.
//
// With a PrepareTracer configured, the prepare is traced. The end event reads
// the final error through the named result.
func ( *Conn) ( context.Context, , string) ( *pgconn.StatementDescription, error) {
if .failedDescribeStatement != "" {
= .Deallocate(, .failedDescribeStatement)
if != nil {
return nil, fmt.Errorf("failed to deallocate previously failed statement %q: %w", .failedDescribeStatement, )
}
.failedDescribeStatement = ""
}
if .prepareTracer != nil {
= .prepareTracer.TracePrepareStart(, , TracePrepareStartData{Name: , SQL: })
}
if != "" {
var bool
if , = .preparedStatements[]; && .SQL == {
if .prepareTracer != nil {
.prepareTracer.TracePrepareEnd(, , TracePrepareEndData{AlreadyPrepared: true})
}
return , nil
}
}
if .prepareTracer != nil {
defer func() {
.prepareTracer.TracePrepareEnd(, , TracePrepareEndData{Err: })
}()
}
var , string
// A name equal to the SQL text asks for a generated server-side name.
if == {
:= sha256.Sum256([]byte())
// 24 of the 32 digest bytes give a 53-character name. Presumably this keeps it
// under PostgreSQL's 63-byte identifier limit — confirm.
= "stmt_" + hex.EncodeToString([0:24])
=
} else {
=
=
}
, = .pgConn.Prepare(, , , nil)
if != nil {
var *pgconn.PrepareError
if errors.As(, &) {
// Remember the statement so the next prepare deallocates it first.
// Presumably a PrepareError means the server may still hold it — confirm.
.failedDescribeStatement =
}
return nil,
}
if != "" {
.preparedStatements[] =
}
return , nil
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.Deallocate(ctx context.Context, name string) error — confirm.
//
// It deallocates a prepared statement on the server. The argument is first looked
// up in preparedStatements:
//   - If found, the description's server-side Name is deallocated. That name can
//     differ from the key when the statement was prepared with name == sql.
//   - Otherwise the argument is sent to the server unchanged.
//
// The local map entry is removed only after the server call succeeds.
func ( *Conn) ( context.Context, string) error {
var string
:= .preparedStatements[]
if != nil {
= .Name
} else {
=
}
:= .pgConn.Deallocate(, )
if != nil {
return
}
if != nil {
delete(.preparedStatements, )
}
return nil
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.DeallocateAll(ctx context.Context) error — confirm.
//
// It does the following, in order:
//   - forgets every prepared statement known to the Conn;
//   - replaces the statement and description caches with fresh empty ones (when
//     their capacities are > 0);
//   - sends "deallocate all" to the server and returns the resulting error.
//
// The local state is reset before the command is sent, so it is cleared even if
// the command fails.
func ( *Conn) ( context.Context) error {
.preparedStatements = map[string]*pgconn.StatementDescription{}
if .config.StatementCacheCapacity > 0 {
.statementCache = stmtcache.NewLRUCache(.config.StatementCacheCapacity)
}
if .config.DescriptionCacheCapacity > 0 {
.descriptionCache = stmtcache.NewLRUCache(.config.DescriptionCacheCapacity)
}
, := .pgConn.Exec(, "deallocate all").ReadAll()
return
}
// NOTE(review): identifiers are missing from this copy of the source. This is
// presumably Conn.bufferNotifications, the handler connect installs as
// OnNotification when none is configured — confirm.
//
// It appends the notification to the Conn's buffer. The *pgconn.PgConn argument
// is ignored.
func ( *Conn) ( *pgconn.PgConn, *pgconn.Notification) {
.notifications = append(.notifications, )
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.WaitForNotification(ctx context.Context) (*pgconn.Notification, error)
// — confirm.
//
// If a notification is already buffered, it returns the oldest one immediately.
// Otherwise it waits on the underlying connection. It then returns the oldest
// notification buffered during the wait (or nil), together with the wait's error.
//
// NOTE(review): notifications are buffered only by the handler connect installs
// when OnNotification was unset. With a custom handler, a successful wait
// presumably returns (nil, nil) — confirm.
func ( *Conn) ( context.Context) (*pgconn.Notification, error) {
var *pgconn.Notification
if len(.notifications) > 0 {
= .notifications[0]
.notifications = .notifications[1:]
return , nil
}
:= .pgConn.WaitForNotification()
if len(.notifications) > 0 {
= .notifications[0]
.notifications = .notifications[1:]
}
return ,
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.IsClosed() bool — confirm.
// It reports whether the underlying pgconn connection is closed.
func ( *Conn) () bool {
return .pgConn.IsClosed()
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.die — confirm. The argument to pgConn.Close is also missing. Presumably it
// is the already-cancelled context created here, so the close does not wait
// — confirm.
//
// It closes the underlying connection unless it is already closed. Any result
// of the close is ignored.
func ( *Conn) () {
if .IsClosed() {
return
}
, := context.WithCancel(context.Background())
// Cancelled immediately, before the close below.
()
.pgConn.Close()
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// quoteIdentifier(s string) string — confirm.
//
// It wraps the string in double quotes and doubles any embedded double quote.
// NOTE(review): unlike the Identifier sanitizing method, this does not strip NUL
// bytes. Consider making the two consistent.
func ( string) string {
return `"` + strings.ReplaceAll(, `"`, `""`) + `"`
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.Ping(ctx context.Context) error, with ctx passed to pgConn.Ping — confirm.
// It delegates directly to the underlying connection's Ping.
func ( *Conn) ( context.Context) error {
return .pgConn.Ping()
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.PgConn — confirm. It returns the underlying *pgconn.PgConn itself, not a copy.
func ( *Conn) () *pgconn.PgConn { return .pgConn }
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.TypeMap — confirm. It returns the connection's own *pgtype.Map, not a copy,
// so registrations made through it affect this connection.
func ( *Conn) () *pgtype.Map { return .typeMap }
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.Config — confirm. It returns a Copy of the connection's config, so changes
// to the returned value do not modify the config the Conn holds.
func ( *Conn) () *ConnConfig { return .config.Copy() }
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.Exec(ctx context.Context, sql string, arguments ...any)
// (pgconn.CommandTag, error) — confirm.
//
// Exec runs sql with the given arguments and returns its command tag. It first
// deallocates statements that the statement cache has invalidated, then
// delegates to exec, which handles leading QueryExecMode / QueryRewriter
// arguments. When a query tracer is configured, the call is traced.
func ( *Conn) ( context.Context, string, ...any) (pgconn.CommandTag, error) {
if .queryTracer != nil {
= .queryTracer.TraceQueryStart(, , TraceQueryStartData{SQL: , Args: })
}
if := .deallocateInvalidatedCachedStatements(); != nil {
// NOTE(review): this return skips TraceQueryEnd, so a tracer that received
// TraceQueryStart never gets the matching end event. The Query method reports
// the error to the tracer in the same situation; consider doing the same here.
return pgconn.CommandTag{},
}
, := .exec(, , ...)
if .queryTracer != nil {
.queryTracer.TraceQueryEnd(, , TraceQueryEndData{CommandTag: , Err: })
}
return ,
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.exec(ctx context.Context, sql string, arguments ...any)
// (commandTag pgconn.CommandTag, err error) — confirm.
//
// exec executes sql using the connection's default query exec mode unless a
// leading argument overrides it:
//   - Leading arguments of type QueryExecMode or QueryRewriter are consumed as
//     options. The first argument of any other type ends option processing.
//   - A rewriter, if given, may replace the SQL and the arguments.
//   - Queries with no remaining arguments always use the simple protocol.
//   - SQL explicitly prepared on this Conn runs through its prepared statement,
//     regardless of mode.
//
// On any error, both caches are asked to invalidate their entry for this query,
// so a stale statement or description is not reused.
func ( *Conn) ( context.Context, string, ...any) ( pgconn.CommandTag, error) {
:= .config.DefaultQueryExecMode
var QueryRewriter
// NOTE(review): the loop label on the next line has lost its name in this copy.
// The `break` in the default case must name that label. A bare `break` only
// leaves the switch, so the loop would never end for a non-option argument.
:
for len() > 0 {
switch arg := [0].(type) {
case QueryExecMode:
=
= [1:]
case QueryRewriter:
=
= [1:]
default:
break
}
}
if != nil {
, , = .RewriteQuery(, , , )
if != nil {
return pgconn.CommandTag{}, fmt.Errorf("rewrite query failed: %w", )
}
}
// Always use the simple protocol when there are no arguments.
if len() == 0 {
= QueryExecModeSimpleProtocol
}
// On failure, drop this query from both caches. The Invalidate argument is
// missing in this copy; presumably it is the SQL.
defer func() {
if != nil {
if := .statementCache; != nil {
.Invalidate()
}
if := .descriptionCache; != nil {
.Invalidate()
}
}
}()
if , := .preparedStatements[]; {
return .execPrepared(, , )
}
switch {
case QueryExecModeCacheStatement:
if .statementCache == nil {
return pgconn.CommandTag{}, errDisabledStatementCache
}
// On a miss, prepare under the cache's generated name and remember it.
:= .statementCache.Get()
if == nil {
, = .Prepare(, stmtcache.StatementName(), )
if != nil {
return pgconn.CommandTag{},
}
.statementCache.Put()
}
return .execPrepared(, , )
case QueryExecModeCacheDescribe:
if .descriptionCache == nil {
return pgconn.CommandTag{}, errDisabledDescriptionCache
}
// Only the description is cached. It is obtained through an unnamed statement,
// and the SQL is sent again with parameters on every execution.
:= .descriptionCache.Get()
if == nil {
, = .Prepare(, "", )
if != nil {
return pgconn.CommandTag{},
}
.descriptionCache.Put()
}
return .execParams(, , )
case QueryExecModeDescribeExec:
// Describe on every call through an unnamed statement; nothing is cached.
, := .Prepare(, "", )
if != nil {
return pgconn.CommandTag{},
}
return .execPrepared(, , )
case QueryExecModeExec:
return .execSQLParams(, , )
case QueryExecModeSimpleProtocol:
return .execSimpleProtocol(, , )
default:
return pgconn.CommandTag{}, fmt.Errorf("unknown QueryExecMode: %v", )
}
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.execSimpleProtocol(ctx context.Context, sql string, arguments []any)
// (commandTag pgconn.CommandTag, err error) — confirm.
//
// It executes sql with the simple query protocol. Any arguments are first
// interpolated into the SQL text client-side. Every result is read and closed.
// Presumably the returned tag is the last result's tag. The error is the one
// returned when closing the multi-result reader.
//
// NOTE(review): per-result errors are discarded in the loop. Presumably the
// multi-result reader's Close reports them — confirm.
func ( *Conn) ( context.Context, string, []any) ( pgconn.CommandTag, error) {
if len() > 0 {
, = .sanitizeForSimpleQuery(, ...)
if != nil {
return pgconn.CommandTag{},
}
}
:= .pgConn.Exec(, )
for .NextResult() {
, _ = .ResultReader().Close()
}
= .Close()
return ,
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.execParams(ctx context.Context, sd *pgconn.StatementDescription,
// arguments []any) (pgconn.CommandTag, error) — confirm.
//
// It encodes the arguments according to the statement description and runs the
// description's SQL with pgConn.ExecParams, supplying the description's parameter
// OIDs. The builder is reset after the query has been read.
// NOTE(review): when Build fails, the builder is not reset here. Presumably Build
// leaves no state that matters, or a later caller resets it — confirm.
func ( *Conn) ( context.Context, *pgconn.StatementDescription, []any) (pgconn.CommandTag, error) {
:= .eqb.Build(.typeMap, , )
if != nil {
return pgconn.CommandTag{},
}
:= .pgConn.ExecParams(, .SQL, .eqb.ParamValues, .ParamOIDs, .eqb.ParamFormats, .eqb.ResultFormats).Read()
.eqb.reset()
return .CommandTag, .Err
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.execPrepared(ctx context.Context, sd *pgconn.StatementDescription,
// arguments []any) (pgconn.CommandTag, error), with sd as ExecStatement's second
// argument — confirm.
//
// It encodes the arguments according to the statement description and executes
// the prepared statement with pgConn.ExecStatement. The builder is reset after
// the result has been read.
func ( *Conn) ( context.Context, *pgconn.StatementDescription, []any) (pgconn.CommandTag, error) {
:= .eqb.Build(.typeMap, , )
if != nil {
return pgconn.CommandTag{},
}
:= .pgConn.ExecStatement(, , .eqb.ParamValues, .eqb.ParamFormats, .eqb.ResultFormats).Read()
.eqb.reset()
return .CommandTag, .Err
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.execSQLParams(ctx context.Context, sql string, args []any)
// (pgconn.CommandTag, error) — confirm.
//
// It encodes the arguments without a statement description and runs the SQL with
// pgConn.ExecParams, passing nil parameter OIDs. Presumably this leaves the server
// to infer parameter types — confirm.
func ( *Conn) ( context.Context, string, []any) (pgconn.CommandTag, error) {
:= .eqb.Build(.typeMap, nil, )
if != nil {
return pgconn.CommandTag{},
}
:= .pgConn.ExecParams(, , .eqb.ParamValues, nil, .eqb.ParamFormats, .eqb.ResultFormats).Read()
.eqb.reset()
return .CommandTag, .Err
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.getRows(ctx context.Context, sql string, args []any) *baseRows — confirm.
//
// It allocates the *baseRows for a query. It records the context, query tracer,
// type map, start time (time.Now), SQL, arguments and connection. It does not
// execute anything.
func ( *Conn) ( context.Context, string, []any) *baseRows {
:= &baseRows{}
.ctx =
.queryTracer = .queryTracer
.typeMap = .typeMap
.startTime = time.Now()
.sql =
.args =
.conn =
return
}
// QueryExecMode selects how a query is sent to the server. The zero value is not
// a valid mode: numbering starts at 1.
type QueryExecMode int32

const (
	// QueryExecModeCacheStatement prepares the query as a named statement and keeps
	// it in the connection's statement cache.
	QueryExecModeCacheStatement QueryExecMode = iota + 1
	// QueryExecModeCacheDescribe describes the query through an unnamed statement
	// and caches only the description. Each execution sends the SQL with its
	// parameters.
	QueryExecModeCacheDescribe
	// QueryExecModeDescribeExec describes the query on every execution, without
	// caching.
	QueryExecModeDescribeExec
	// QueryExecModeExec sends the SQL and encoded parameters without describing the
	// statement first.
	QueryExecModeExec
	// QueryExecModeSimpleProtocol interpolates the arguments client-side and uses
	// the simple query protocol.
	QueryExecModeSimpleProtocol
)
// NOTE(review): identifiers are missing from this copy of the source. This is
// presumably QueryExecMode.String, which makes %v print these names — confirm.
//
// It returns a human-readable name for the mode. Any other value, including the
// zero value, yields "invalid".
func ( QueryExecMode) () string {
switch {
case QueryExecModeCacheStatement:
return "cache statement"
case QueryExecModeCacheDescribe:
return "cache describe"
case QueryExecModeDescribeExec:
return "describe exec"
case QueryExecModeExec:
return "exec"
case QueryExecModeSimpleProtocol:
return "simple protocol"
default:
return "invalid"
}
}
// Result-format options that can be passed among the leading arguments of a query.
type (
	// QueryResultFormats lists the result format code to request for each result
	// column, in column order.
	QueryResultFormats []int16
	// QueryResultFormatsByOID maps a column's data type OID to the result format
	// code to request for columns of that type.
	QueryResultFormatsByOID map[uint32]int16
)
// QueryRewriter is an option that can be passed as a leading argument to Conn's
// query-executing methods and to queued batch queries. RewriteQuery receives the
// SQL and the remaining arguments, and returns the SQL and arguments to execute
// instead. A non-nil error aborts the query and is returned wrapped as
// "rewrite query failed: ...".
type QueryRewriter interface {
RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error)
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.Query(ctx context.Context, sql string, args ...any) (Rows, error) — confirm.
//
// Query sends sql and returns the rows to read the result from.
//
// Leading arguments of type QueryResultFormats, QueryResultFormatsByOID,
// QueryExecMode and QueryRewriter are consumed as options. The first argument of
// any other type ends option processing.
//
// The returned Rows is always non-nil; on failure it carries the error.
//
// How the query is sent depends on the mode:
//   - SQL explicitly prepared on this Conn, and the cache_statement,
//     cache_describe and describe_exec modes, use a statement description. The
//     argument count is checked against it.
//   - exec sends the SQL and encoded arguments without a description.
//   - simple_protocol interpolates the arguments client-side. It is forced when
//     the SQL is empty (the compared name is missing here; presumably sql).
func ( *Conn) ( context.Context, string, ...any) (Rows, error) {
if .queryTracer != nil {
= .queryTracer.TraceQueryStart(, , TraceQueryStartData{SQL: , Args: })
}
if := .deallocateInvalidatedCachedStatements(); != nil {
// Report the failure to the tracer before returning closed rows.
if .queryTracer != nil {
.queryTracer.TraceQueryEnd(, , TraceQueryEndData{Err: })
}
return &baseRows{err: , closed: true},
}
var QueryResultFormats
var QueryResultFormatsByOID
:= .config.DefaultQueryExecMode
var QueryRewriter
// NOTE(review): the loop label on the next line has lost its name in this copy.
// The `break` in the default case must name that label. A bare `break` only
// leaves the switch, so the loop would never end for a non-option argument.
:
for len() > 0 {
switch arg := [0].(type) {
case QueryResultFormats:
=
= [1:]
case QueryResultFormatsByOID:
=
= [1:]
case QueryExecMode:
=
= [1:]
case QueryRewriter:
=
= [1:]
default:
break
}
}
if != nil {
var error
// Presumably these save the original SQL and arguments, so the error rows
// describe the query as the caller sent it — confirm.
:=
:=
, , = .RewriteQuery(, , , )
if != nil {
:= .getRows(, , )
= fmt.Errorf("rewrite query failed: %w", )
.fatal()
return ,
}
}
if == "" {
= QueryExecModeSimpleProtocol
}
// Start from a clean builder for this query.
.eqb.reset()
:= .getRows(, , )
var error
// SQL explicitly prepared on this Conn always uses its statement.
, := .preparedStatements[]
if != nil || == QueryExecModeCacheStatement || == QueryExecModeCacheDescribe || == QueryExecModeDescribeExec {
if == nil {
, = .getStatementDescription(, , )
if != nil {
.fatal()
return ,
}
}
if len(.ParamOIDs) != len() {
.fatal(fmt.Errorf("expected %d arguments, got %d", len(.ParamOIDs), len()))
return , .err
}
.sql = .SQL
= .eqb.Build(.typeMap, , )
if != nil {
.fatal()
return , .err
}
// A per-OID format map overrides any QueryResultFormats option.
if != nil {
= make([]int16, len(.Fields))
for := range {
[] = [uint32(.Fields[].DataTypeOID)]
}
}
// Otherwise, without explicit formats, use the ones chosen by the builder.
if == nil {
= .eqb.ResultFormats
}
// A cache_describe description that was not explicitly prepared came from an
// unnamed statement, so the SQL is sent with the parameter OIDs. Otherwise the
// prepared statement is executed.
if ! && == QueryExecModeCacheDescribe {
.resultReader = .pgConn.ExecParams(, , .eqb.ParamValues, .ParamOIDs, .eqb.ParamFormats, )
} else {
.resultReader = .pgConn.ExecStatement(, , .eqb.ParamValues, .eqb.ParamFormats, )
}
} else if == QueryExecModeExec {
:= .eqb.Build(.typeMap, nil, )
if != nil {
.fatal()
return , .err
}
.resultReader = .pgConn.ExecParams(, , .eqb.ParamValues, nil, .eqb.ParamFormats, .eqb.ResultFormats)
} else if == QueryExecModeSimpleProtocol {
, = .sanitizeForSimpleQuery(, ...)
if != nil {
.fatal()
return ,
}
// Only the first result is exposed through the rows. The multi-result reader
// is kept as well, presumably so the remaining results are consumed when the
// rows are closed — confirm.
:= .pgConn.Exec(, )
if .NextResult() {
.resultReader = .ResultReader()
.multiResultReader =
} else {
= .Close()
.fatal()
return ,
}
return , nil
} else {
= fmt.Errorf("unknown QueryExecMode: %v", )
.fatal()
return , .err
}
.eqb.reset()
return , .err
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.getStatementDescription(ctx context.Context, mode QueryExecMode, sql string)
// (sd *pgconn.StatementDescription, err error) — confirm.
//
// It returns a statement description for the SQL, according to mode:
//   - cache_statement: from the statement cache. On a miss it prepares a named
//     statement, named by stmtcache.StatementName.
//   - cache_describe: from the description cache. On a miss it describes the
//     query through an unnamed statement.
//   - describe_exec: always describes through an unnamed statement.
//
// A disabled cache yields errDisabledStatementCache or
// errDisabledDescriptionCache.
//
// NOTE(review): any other mode falls through and returns (nil, nil). The only
// caller in this file restricts mode to the three above, but a default case that
// returns an error would make the contract explicit.
func ( *Conn) (
context.Context,
QueryExecMode,
string,
) ( *pgconn.StatementDescription, error) {
switch {
case QueryExecModeCacheStatement:
if .statementCache == nil {
return nil, errDisabledStatementCache
}
= .statementCache.Get()
if == nil {
, = .Prepare(, stmtcache.StatementName(), )
if != nil {
return nil,
}
.statementCache.Put()
}
case QueryExecModeCacheDescribe:
if .descriptionCache == nil {
return nil, errDisabledDescriptionCache
}
= .descriptionCache.Get()
if == nil {
, = .Prepare(, "", )
if != nil {
return nil,
}
.descriptionCache.Put()
}
case QueryExecModeDescribeExec:
return .Prepare(, "", )
}
return ,
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.QueryRow(ctx context.Context, sql string, args ...any) Row — confirm.
//
// QueryRow runs Query and wraps the returned rows as a Row. Query's error result
// is not used here: every return path of Query yields rows that carry the error,
// presumably surfaced when the row is scanned — confirm. The type assertion
// relies on Query always returning *baseRows, which holds for every return path
// in Query.
func ( *Conn) ( context.Context, string, ...any) Row {
, := .Query(, , ...)
return (*connRow)(.(*baseRows))
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.SendBatch(ctx context.Context, b *Batch) (br BatchResults) — confirm.
//
// SendBatch sends all queued queries of the batch and returns a reader for the
// results:
//   - An empty batch returns results without contacting the server.
//   - Statements invalidated in the statement cache are deallocated first.
//   - Each queued query may start with a QueryRewriter argument. It is applied and
//     removed, and the rewritten SQL and arguments are written back into the
//     queued query. Presumably QueuedQueries holds pointers, so the Batch itself
//     is modified — confirm.
//   - The connection's DefaultQueryExecMode selects how the batch is sent; there
//     is no per-batch mode. An unknown mode panics.
//   - Queries whose SQL was explicitly prepared on this Conn use that prepared
//     statement.
func ( *Conn) ( context.Context, *Batch) ( BatchResults) {
if len(.QueuedQueries) == 0 {
return &emptyBatchResults{conn: }
}
if .batchTracer != nil {
= .batchTracer.TraceBatchStart(, , TraceBatchStartData{Batch: })
// Only an early failure ends the trace here. The method name in the assertion
// is missing in this copy. Presumably the results object reports the end event
// itself otherwise — confirm.
defer func() {
:= .(interface{ () error }).()
if != nil {
.batchTracer.TraceBatchEnd(, , TraceBatchEndData{Err: })
}
}()
}
if := .deallocateInvalidatedCachedStatements(); != nil {
return &batchResults{ctx: , conn: , err: }
}
for , := range .QueuedQueries {
var QueryRewriter
:= .SQL
:= .Arguments
// NOTE(review): the loop label on the next line has lost its name in this
// copy. The `break` in the default case must name that label, otherwise the
// loop never ends for a non-option argument.
:
for len() > 0 {
switch arg := [0].(type) {
case QueryRewriter:
=
= [1:]
default:
break
}
}
if != nil {
var error
, , = .RewriteQuery(, , , )
if != nil {
return &batchResults{ctx: , conn: , err: fmt.Errorf("rewrite query failed: %w", )}
}
}
.SQL =
.Arguments =
}
:= .config.DefaultQueryExecMode
// The simple protocol sends the whole batch as one query string.
if == QueryExecModeSimpleProtocol {
return .sendBatchQueryExecModeSimpleProtocol(, )
}
// The other modes use the extended protocol, so they can reuse explicitly
// prepared statements.
for , := range .QueuedQueries {
if , := .preparedStatements[.SQL]; {
.sd =
}
}
switch {
case QueryExecModeExec:
return .sendBatchQueryExecModeExec(, )
case QueryExecModeCacheStatement:
return .sendBatchQueryExecModeCacheStatement(, )
case QueryExecModeCacheDescribe:
return .sendBatchQueryExecModeCacheDescribe(, )
case QueryExecModeDescribeExec:
return .sendBatchQueryExecModeDescribeExec(, )
default:
panic("unknown QueryExecMode")
}
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch)
// *batchResults — confirm.
//
// It interpolates each queued query's arguments client-side and joins the
// resulting statements with ";". It sends them as a single simple-protocol query.
// A sanitizing error is returned inside the batch results.
//
// NOTE(review): PostgreSQL runs a multi-statement simple query as one implicit
// transaction unless it contains explicit transaction commands, so one failing
// statement affects the others. Confirm these are the intended batch semantics.
func ( *Conn) ( context.Context, *Batch) *batchResults {
var strings.Builder
for , := range .QueuedQueries {
if > 0 {
.WriteByte(';')
}
, := .sanitizeForSimpleQuery(.SQL, .Arguments...)
if != nil {
return &batchResults{ctx: , conn: , err: }
}
.WriteString()
}
:= .pgConn.Exec(, .String())
return &batchResults{
ctx: ,
conn: ,
mrr: ,
b: ,
qqIdx: 0,
}
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults
// — confirm.
//
// It queues each query in a pgconn.Batch and sends the batch with
// pgConn.ExecBatch:
//   - Queries with an attached statement description execute that prepared
//     statement by name.
//   - The rest send their SQL and parameters with no parameter OIDs.
//
// NOTE(review): the builder's slices are reused on every iteration, so this relies
// on pgconn.Batch consuming them when each query is queued — confirm. The builder
// is also not reset on the error returns.
func ( *Conn) ( context.Context, *Batch) *batchResults {
:= &pgconn.Batch{}
for , := range .QueuedQueries {
:= .sd
if != nil {
:= .eqb.Build(.typeMap, , .Arguments)
if != nil {
return &batchResults{ctx: , conn: , err: }
}
.ExecPrepared(.Name, .eqb.ParamValues, .eqb.ParamFormats, .eqb.ResultFormats)
} else {
:= .eqb.Build(.typeMap, nil, .Arguments)
if != nil {
return &batchResults{ctx: , conn: , err: }
}
.ExecParams(.SQL, .eqb.ParamValues, nil, .eqb.ParamFormats, .eqb.ResultFormats)
}
}
.eqb.reset()
:= .pgConn.ExecBatch(, )
return &batchResults{
ctx: ,
conn: ,
mrr: ,
b: ,
qqIdx: 0,
}
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch)
// (pbr *pipelineBatchResults) — confirm.
//
// It gives every queued query without a statement description a description:
//   - from the statement cache, when present;
//   - otherwise a new description named by stmtcache.StatementName. Identical SQL
//     within the batch shares one new description.
//
// sendBatchExtendedWithDescription then prepares the new descriptions, adds them
// to the statement cache, and sends the batch. A disabled statement cache yields
// closed results carrying errDisabledStatementCache.
func ( *Conn) ( context.Context, *Batch) ( *pipelineBatchResults) {
if .statementCache == nil {
return &pipelineBatchResults{ctx: , conn: , err: errDisabledStatementCache, closed: true}
}
:= []*pgconn.StatementDescription{}
:= make(map[string]int)
for , := range .QueuedQueries {
if .sd == nil {
:= .statementCache.Get(.SQL)
if != nil {
.sd =
} else {
if , := [.SQL]; {
.sd = []
} else {
= &pgconn.StatementDescription{
Name: stmtcache.StatementName(.SQL),
SQL: .SQL,
}
[.SQL] = len()
= append(, )
.sd =
}
}
}
}
return .sendBatchExtendedWithDescription(, , , .statementCache)
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch)
// (pbr *pipelineBatchResults) — confirm.
//
// It works like the cache-statement variant, but uses the description cache and
// creates descriptions with no Name. Because the new descriptions are unnamed,
// sendBatchExtendedWithDescription sends those queries as SQL with parameter OIDs
// rather than executing a statement by name. A disabled description cache yields
// closed results carrying errDisabledDescriptionCache.
func ( *Conn) ( context.Context, *Batch) ( *pipelineBatchResults) {
if .descriptionCache == nil {
return &pipelineBatchResults{ctx: , conn: , err: errDisabledDescriptionCache, closed: true}
}
:= []*pgconn.StatementDescription{}
:= make(map[string]int)
for , := range .QueuedQueries {
if .sd == nil {
:= .descriptionCache.Get(.SQL)
if != nil {
.sd =
} else {
if , := [.SQL]; {
.sd = []
} else {
= &pgconn.StatementDescription{
SQL: .SQL,
}
[.SQL] = len()
= append(, )
.sd =
}
}
}
}
return .sendBatchExtendedWithDescription(, , , .descriptionCache)
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch)
// (pbr *pipelineBatchResults) — confirm.
//
// It creates an unnamed description for every queued query without one.
// Identical SQL within the batch shares one description. It then sends the batch
// with sendBatchExtendedWithDescription, with no cache: the descriptions are
// obtained for this batch only.
func ( *Conn) ( context.Context, *Batch) ( *pipelineBatchResults) {
:= []*pgconn.StatementDescription{}
:= make(map[string]int)
for , := range .QueuedQueries {
if .sd == nil {
if , := [.SQL]; {
.sd = []
} else {
:= &pgconn.StatementDescription{
SQL: .SQL,
}
[.SQL] = len()
= append(, )
.sd =
}
}
}
return .sendBatchExtendedWithDescription(, , , nil)
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.sendBatchExtendedWithDescription(ctx context.Context, b *Batch,
// distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache)
// (pbr *pipelineBatchResults) — confirm.
//
// It sends the batch over a pipeline in two phases.
//
// Phase 1 runs only if there are new statement descriptions:
//   - Each one is prepared with SendPrepare and added to the cache, when one is
//     given. The pipeline is then synced.
//   - The returned parameter OIDs and fields are copied into the descriptions.
//   - If this phase fails, the cache entries just added are invalidated.
//
// Phase 2 encodes and queues every queued query:
//   - Unnamed descriptions send their SQL with parameter OIDs.
//   - Named descriptions execute the prepared statement.
//
// The pipeline is then synced again and the results are returned for reading.
// Whenever the returned results carry an error, the pipeline is closed.
func ( *Conn) ( context.Context, *Batch, []*pgconn.StatementDescription, stmtcache.Cache) ( *pipelineBatchResults) {
:= .pgConn.StartPipeline()
defer func() {
if != nil && .err != nil {
.Close()
}
}()
if len() > 0 {
:= func() ( error) {
for , := range {
.SendPrepare(.Name, .SQL, nil)
}
// Cache entries are added before the server confirms them. The deferred
// block below invalidates them if preparing fails.
if != nil {
for , := range {
.Put()
}
}
defer func() {
if != nil && != nil {
for , := range {
.Invalidate(.SQL)
}
}
}()
= .Sync()
if != nil {
return
}
for , := range {
, := .GetResults()
if != nil {
return newErrPreprocessingBatch("prepare", .SQL, )
}
, := .(*pgconn.StatementDescription)
if ! {
return fmt.Errorf("expected statement description, got %T", )
}
// Copy the server-provided parameter and result metadata into our description.
.ParamOIDs = .ParamOIDs
.Fields = .Fields
}
, := .GetResults()
if != nil {
return
}
, := .(*pgconn.PipelineSync)
if ! {
return fmt.Errorf("expected sync, got %T", )
}
return nil
}()
if != nil {
return &pipelineBatchResults{ctx: , conn: , err: , closed: true}
}
}
for , := range .QueuedQueries {
:= .eqb.Build(.typeMap, .sd, .Arguments)
if != nil {
= newErrPreprocessingBatch("build", .SQL, )
return &pipelineBatchResults{ctx: , conn: , err: , closed: true}
}
if .sd.Name == "" {
.SendQueryParams(.sd.SQL, .eqb.ParamValues, .sd.ParamOIDs, .eqb.ParamFormats, .eqb.ResultFormats)
} else {
// Copy the result formats. Presumably the pipeline keeps the slice until the
// statement is sent, while eqb's slice is reused for the next query — confirm.
:= make([]int16, len(.eqb.ResultFormats))
copy(, .eqb.ResultFormats)
.SendQueryStatement(.sd, .eqb.ParamValues, .eqb.ParamFormats, )
}
}
:= .Sync()
if != nil {
return &pipelineBatchResults{ctx: , conn: , err: , closed: true}
}
return &pipelineBatchResults{
ctx: ,
conn: ,
pipeline: ,
b: ,
}
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.sanitizeForSimpleQuery(sql string, args ...any) (string, error) — confirm.
//
// It returns the SQL with the arguments interpolated as literals, for use with the
// simple protocol. It refuses to run unless the server reports
// standard_conforming_strings=on and client_encoding=UTF8. Presumably the
// sanitizer's quoting is only safe under those settings — confirm. Each argument
// is first converted with convertSimpleArgument using the connection's type map,
// and the result is produced by sanitize.SanitizeSQL.
func ( *Conn) ( string, ...any) (string, error) {
if .pgConn.ParameterStatus("standard_conforming_strings") != "on" {
return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
}
if .pgConn.ParameterStatus("client_encoding") != "UTF8" {
return "", errors.New("simple protocol queries must be run with client_encoding=UTF8")
}
var error
:= make([]any, len())
for , := range {
[], = convertSimpleArgument(.typeMap, )
if != nil {
return "",
}
}
return sanitize.SanitizeSQL(, ...)
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.LoadType(ctx context.Context, typeName string) (*pgtype.Type, error)
// — confirm.
//
// It resolves a type name to its OID on the server, reads pg_type.typtype and
// typbasetype, and builds a *pgtype.Type:
//   - "b": an array type whose element type comes from typelem;
//   - "c": a composite type built from the table's attributes;
//   - "d": a domain that reuses its base type's codec;
//   - "e": an enum;
//   - "r": a range over rngsubtype;
//   - "m": a multirange over its range type.
//
// Element, base and field types must already be registered in the connection's
// type map. The returned type is not registered here.
func ( *Conn) ( context.Context, string) (*pgtype.Type, error) {
var uint32
:= .QueryRow(, "select $1::text::regtype::oid;", ).Scan(&)
if != nil {
return nil,
}
var string
var uint32
= .QueryRow(, "select typtype::text, typbasetype from pg_type where oid=$1", ).Scan(&, &)
if != nil {
return nil,
}
switch {
case "b":
// NOTE(review): typtype "b" covers every base type, not only arrays. A non-array
// base type presumably has typelem 0 and fails below with "array element OID
// not registered" — confirm.
, := .getArrayElementOID(, )
if != nil {
return nil,
}
, := .TypeMap().TypeForOID()
if ! {
return nil, errors.New("array element OID not registered")
}
return &pgtype.Type{Name: , OID: , Codec: &pgtype.ArrayCodec{ElementType: }}, nil
case "c":
, := .getCompositeFields(, )
if != nil {
return nil,
}
return &pgtype.Type{Name: , OID: , Codec: &pgtype.CompositeCodec{Fields: }}, nil
case "d":
, := .TypeMap().TypeForOID()
if ! {
return nil, errors.New("domain base type OID not registered")
}
return &pgtype.Type{Name: , OID: , Codec: .Codec}, nil
case "e":
return &pgtype.Type{Name: , OID: , Codec: &pgtype.EnumCodec{}}, nil
case "r":
, := .getRangeElementOID(, )
if != nil {
return nil,
}
, := .TypeMap().TypeForOID()
if ! {
return nil, errors.New("range element OID not registered")
}
return &pgtype.Type{Name: , OID: , Codec: &pgtype.RangeCodec{ElementType: }}, nil
case "m":
, := .getMultiRangeElementOID(, )
if != nil {
return nil,
}
, := .TypeMap().TypeForOID()
if ! {
return nil, errors.New("multirange element OID not registered")
}
return &pgtype.Type{Name: , OID: , Codec: &pgtype.MultirangeCodec{ElementType: }}, nil
default:
// NOTE(review): this returns a non-nil empty *pgtype.Type with the error, while
// every other error path returns nil. Callers should not rely on the value.
return &pgtype.Type{}, errors.New("unknown typtype")
}
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.getArrayElementOID(ctx context.Context, oid uint32) (uint32, error)
// — confirm.
// It returns pg_type.typelem for the given type OID, or 0 with the query error.
func ( *Conn) ( context.Context, uint32) (uint32, error) {
var uint32
:= .QueryRow(, "select typelem from pg_type where oid=$1", ).Scan(&)
if != nil {
return 0,
}
return , nil
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.getRangeElementOID(ctx context.Context, oid uint32) (uint32, error)
// — confirm.
// It returns pg_range.rngsubtype for the range type with the given OID, or 0 with
// the query error.
func ( *Conn) ( context.Context, uint32) (uint32, error) {
var uint32
:= .QueryRow(, "select rngsubtype from pg_range where rngtypid=$1", ).Scan(&)
if != nil {
return 0,
}
return , nil
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.getMultiRangeElementOID(ctx context.Context, oid uint32) (uint32, error)
// — confirm.
// It returns the range type (pg_range.rngtypid) whose multirange type has the given
// OID, or 0 with the query error. This is the range type itself, not its subtype.
func ( *Conn) ( context.Context, uint32) (uint32, error) {
var uint32
:= .QueryRow(, "select rngtypid from pg_range where rngmultitypid=$1", ).Scan(&)
if != nil {
return 0,
}
return , nil
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.getCompositeFields(ctx context.Context, oid uint32)
// ([]pgtype.CompositeCodecField, error) — confirm.
//
// It reads the composite type's relation (pg_type.typrelid), then lists that
// relation's attributes in attnum order, skipping dropped and system columns.
// Each attribute's type must already be registered in the connection's type map;
// otherwise an "unknown composite type field OID" error is returned.
//
// NOTE(review): the Query call's results are assigned to names missing in this
// copy. Confirm that its error is either checked or deliberately left for
// ForEachRow to surface.
func ( *Conn) ( context.Context, uint32) ([]pgtype.CompositeCodecField, error) {
var uint32
:= .QueryRow(, "select typrelid from pg_type where oid=$1", ).Scan(&)
if != nil {
return nil,
}
var []pgtype.CompositeCodecField
var string
var uint32
, := .Query(, `select attname, atttypid
from pg_attribute
where attrelid=$1
and not attisdropped
and attnum > 0
order by attnum`,
,
)
_, = ForEachRow(, []any{&, &}, func() error {
, := .TypeMap().TypeForOID()
if ! {
return fmt.Errorf("unknown composite type field OID: %v", )
}
= append(, pgtype.CompositeCodecField{Name: , Type: })
return nil
})
if != nil {
return nil,
}
return , nil
}
// NOTE(review): identifiers are missing from this copy of the source; presumably
// Conn.deallocateInvalidatedCachedStatements(ctx context.Context) error — confirm.
//
// It cleans up cache entries that were invalidated since the last call:
//   - Invalidated description-cache entries are simply removed, with no server
//     call.
//   - Invalidated statement-cache entries are deallocated on the server in a single
//     pipeline. Only after that succeeds are they removed from the cache and from
//     preparedStatements.
//
// If the pipeline fails, the entries stay invalidated, so the deallocation is
// attempted again on the next call.
func ( *Conn) ( context.Context) error {
// Presumably 'I' and 'T' are the idle and in-transaction ReadyForQuery statuses.
// In any other state (e.g. a failed transaction) the cleanup is postponed
// — confirm.
if := .pgConn.TxStatus(); != 'I' && != 'T' {
return nil
}
if .descriptionCache != nil {
.descriptionCache.RemoveInvalidated()
}
var []*pgconn.StatementDescription
if .statementCache != nil {
= .statementCache.GetInvalidated()
}
if len() == 0 {
return nil
}
// NOTE(review): the pipeline is closed both by this defer and explicitly below.
// Presumably a second Close is harmless — confirm.
:= .pgConn.StartPipeline()
defer .Close()
for , := range {
.SendDeallocate(.Name)
}
:= .Sync()
if != nil {
return fmt.Errorf("failed to deallocate cached statement(s): %w", )
}
= .Close()
if != nil {
return fmt.Errorf("failed to deallocate cached statement(s): %w", )
}
// statementCache is non-nil here: the list above can only be non-empty when it exists.
.statementCache.RemoveInvalidated()
for , := range {
delete(.preparedStatements, .Name)
}
return nil
}