package pgconn
import (
)
const (
connStatusUninitialized = iota
connStatusConnecting
connStatusClosed
connStatusIdle
connStatusBusy
)
type Notice PgError
type Notification struct {
PID uint32
Channel string
Payload string
}
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
type LookupFunc func(ctx context.Context, host string) (addrs []string, err error)
type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend
type PgErrorHandler func(*PgConn, *PgError) bool
type NoticeHandler func(*PgConn, *Notice)
type NotificationHandler func(*PgConn, *Notification)
type PgConn struct {
conn net.Conn
pid uint32
secretKey []byte
parameterStatuses map[string]string
txStatus byte
frontend *pgproto3.Frontend
bgReader *bgreader.BGReader
slowWriteTimer *time.Timer
bgReaderStarted chan struct{}
customData map[string]any
config *Config
status byte
bufferingReceive bool
bufferingReceiveMux sync.Mutex
bufferingReceiveMsg pgproto3.BackendMessage
bufferingReceiveErr error
peekedMsg pgproto3.BackendMessage
resultReader ResultReader
multiResultReader MultiResultReader
pipeline Pipeline
contextWatcher *ctxwatch.ContextWatcher
fieldDescriptions [16]FieldDescription
cleanupDone chan struct{}
}
func ( context.Context, string) (*PgConn, error) {
, := ParseConfig()
if != nil {
return nil,
}
return ConnectConfig(, )
}
func ( context.Context, string, ParseConfigOptions) (*PgConn, error) {
, := ParseConfigWithOptions(, )
if != nil {
return nil,
}
return ConnectConfig(, )
}
func ( context.Context, *Config) (*PgConn, error) {
if !.createdByParseConfig {
panic("config must be created by ParseConfig")
}
var []error
, := buildConnectOneConfigs(, )
if len() > 0 {
= append(, ...)
}
if len() == 0 {
return nil, &ConnectError{Config: , err: fmt.Errorf("hostname resolving error: %w", errors.Join(...))}
}
, := connectPreferred(, , )
if len() > 0 {
= append(, ...)
return nil, &ConnectError{Config: , err: errors.Join(...)}
}
if .AfterConnect != nil {
:= .AfterConnect(, )
if != nil {
.conn.Close()
return nil, &ConnectError{Config: , err: fmt.Errorf("AfterConnect error: %w", )}
}
}
return , nil
}
func ( context.Context, *Config) ([]*connectOneConfig, []error) {
:= []*FallbackConfig{
{
Host: .Host,
Port: .Port,
TLSConfig: .TLSConfig,
},
}
= append(, .Fallbacks...)
var []*connectOneConfig
var []error
for , := range {
if isAbsolutePath(.Host) {
, := NetworkAddress(.Host, .Port)
= append(, &connectOneConfig{
network: ,
address: ,
originalHostname: .Host,
tlsConfig: .TLSConfig,
})
continue
}
, := .LookupFunc(, .Host)
if != nil {
= append(, )
continue
}
for , := range {
, , := net.SplitHostPort()
if == nil {
, := strconv.ParseUint(, 10, 16)
if != nil {
return nil, []error{fmt.Errorf("error parsing port (%s) from lookup: %w", , )}
}
, := NetworkAddress(, uint16())
= append(, &connectOneConfig{
network: ,
address: ,
originalHostname: .Host,
tlsConfig: .TLSConfig,
})
} else {
, := NetworkAddress(, .Port)
= append(, &connectOneConfig{
network: ,
address: ,
originalHostname: .Host,
tlsConfig: .TLSConfig,
})
}
}
}
return ,
}
func ( context.Context, *Config, []*connectOneConfig) (*PgConn, []error) {
:=
var []error
var *connectOneConfig
for , := range {
if .ConnectTimeout != 0 {
if == 0 || ([].address != [-1].address) {
var context.CancelFunc
, = context.WithTimeout(, .ConnectTimeout)
defer ()
}
} else {
=
}
, := connectOne(, , , false)
if != nil {
return , nil
}
= append(, )
var *PgError
if errors.As(, &) {
const = "28P01"
const = "3D000"
const = "42501"
if .Code == ||
.Code == ||
.Code == {
return nil,
}
}
var *NotPreferredError
if errors.As(, &) {
=
}
}
if != nil {
, := connectOne(, , , true)
if == nil {
return , nil
}
= append(, )
}
return nil,
}
func ( context.Context, *Config, *connectOneConfig,
bool,
) (*PgConn, error) {
:= new(PgConn)
.config =
.cleanupDone = make(chan struct{})
.customData = make(map[string]any)
var error
:= func( string, error) *perDialConnectError {
= normalizeTimeoutError(, )
:= &perDialConnectError{address: .address, originalHostname: .originalHostname, err: fmt.Errorf("%s: %w", , )}
return
}
, := parseProtocolVersion(.MaxProtocolVersion)
if != nil {
return nil, ("invalid max_protocol_version", )
}
, := parseProtocolVersion(.MinProtocolVersion)
if != nil {
return nil, ("invalid min_protocol_version", )
}
.conn, = .DialFunc(, .network, .address)
if != nil {
return nil, ("dial error", )
}
if .tlsConfig != nil {
.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: .conn})
.contextWatcher.Watch()
var (
net.Conn
error
)
if .SSLNegotiation == "direct" {
= tls.Client(.conn, .tlsConfig)
} else {
, = startTLS(.conn, .tlsConfig)
}
.contextWatcher.Unwatch()
if != nil {
.conn.Close()
return nil, ("tls error", )
}
.conn =
}
if .AfterNetConnect != nil {
.conn, = .AfterNetConnect(, , .conn)
if != nil {
.conn.Close()
return nil, ("AfterNetConnect failed", )
}
}
.contextWatcher = ctxwatch.NewContextWatcher(.BuildContextWatcherHandler())
.contextWatcher.Watch()
defer .contextWatcher.Unwatch()
.parameterStatuses = make(map[string]string)
.status = connStatusConnecting
.bgReader = bgreader.New(.conn)
.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64),
func() {
.bgReader.Start()
.bgReaderStarted <- struct{}{}
},
)
.slowWriteTimer.Stop()
.bgReaderStarted = make(chan struct{})
.frontend = .BuildFrontend(.bgReader, .conn)
:= pgproto3.StartupMessage{
ProtocolVersion: ,
Parameters: make(map[string]string),
}
maps.Copy(.Parameters, .RuntimeParams)
.Parameters["user"] = .User
if .Database != "" {
.Parameters["database"] = .Database
}
.frontend.Send(&)
if := .flushWithPotentialWriteReadDeadlock(); != nil {
.conn.Close()
return nil, ("failed to write startup message", )
}
for {
, := .receiveMessage()
if != nil {
.conn.Close()
if , := .(*PgError); {
return nil, ("server error", )
}
return nil, ("failed to receive message", )
}
switch msg := .(type) {
case *pgproto3.BackendKeyData:
.pid = .ProcessID
.secretKey = .SecretKey
case *pgproto3.AuthenticationOk:
case *pgproto3.AuthenticationCleartextPassword:
= .txPasswordMessage(.config.Password)
if != nil {
.conn.Close()
return nil, ("failed to write password message", )
}
case *pgproto3.AuthenticationMD5Password:
:= "md5" + hexMD5(hexMD5(.config.Password+.config.User)+string(.Salt[:]))
= .txPasswordMessage()
if != nil {
.conn.Close()
return nil, ("failed to write password message", )
}
case *pgproto3.AuthenticationSASL:
:= false
for , := range .AuthMechanisms {
if == "OAUTHBEARER" {
= true
break
}
}
if && .config.OAuthTokenProvider != nil {
= .oauthAuth()
} else {
= .scramAuth(.AuthMechanisms)
}
if != nil {
.conn.Close()
return nil, ("failed SASL auth", )
}
case *pgproto3.AuthenticationGSS:
= .gssAuth()
if != nil {
.conn.Close()
return nil, ("failed GSS auth", )
}
case *pgproto3.ReadyForQuery:
.status = connStatusIdle
if .ValidateConnect != nil {
.contextWatcher.Unwatch()
:= .ValidateConnect(, )
if != nil {
if , := .(*NotPreferredError); && {
return , nil
}
.conn.Close()
return nil, ("ValidateConnect failed", )
}
}
return , nil
case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse:
case *pgproto3.NegotiateProtocolVersion:
:= pgproto3.ProtocolVersion30&0xFFFF0000 | uint32(.NewestMinorProtocol)
if < {
.conn.Close()
return nil, ("server protocol version too low", nil)
}
case *pgproto3.ErrorResponse:
.conn.Close()
return nil, ("server error", ErrorResponseToPgError())
default:
.conn.Close()
return nil, ("received unexpected message", )
}
}
}
func ( net.Conn, *tls.Config) (net.Conn, error) {
:= binary.Write(, binary.BigEndian, []int32{8, 80877103})
if != nil {
return nil,
}
:= make([]byte, 1)
if _, = io.ReadFull(, ); != nil {
return nil,
}
if [0] != 'S' {
return nil, errors.New("server refused TLS connection")
}
return tls.Client(, ), nil
}
func ( *PgConn) ( string) ( error) {
.frontend.Send(&pgproto3.PasswordMessage{Password: })
return .flushWithPotentialWriteReadDeadlock()
}
func ( string) string {
:= md5.New()
io.WriteString(, )
return hex.EncodeToString(.Sum(nil))
}
func ( *PgConn) () chan struct{} {
if .bufferingReceive {
panic("BUG: signalMessage when already in progress")
}
.bufferingReceive = true
.bufferingReceiveMux.Lock()
:= make(chan struct{})
go func() {
.bufferingReceiveMsg, .bufferingReceiveErr = .frontend.Receive()
.bufferingReceiveMux.Unlock()
close()
}()
return
}
func ( *PgConn) ( context.Context) (pgproto3.BackendMessage, error) {
if := .lock(); != nil {
return nil,
}
defer .unlock()
if != context.Background() {
select {
case <-.Done():
return nil, newContextAlreadyDoneError()
default:
}
.contextWatcher.Watch()
defer .contextWatcher.Unwatch()
}
, := .receiveMessage()
if != nil {
= &pgconnError{
msg: "receive message failed",
err: normalizeTimeoutError(, ),
safeToRetry: true,
}
}
return ,
}
func ( *PgConn) () (pgproto3.BackendMessage, error) {
if .peekedMsg != nil {
return .peekedMsg, nil
}
var pgproto3.BackendMessage
var error
if .bufferingReceive {
.bufferingReceiveMux.Lock()
= .bufferingReceiveMsg
= .bufferingReceiveErr
.bufferingReceiveMux.Unlock()
.bufferingReceive = false
var net.Error
if errors.As(, &) && .Timeout() {
, = .frontend.Receive()
}
} else {
, = .frontend.Receive()
}
if != nil {
var net.Error
:= errors.As(, &)
if !( && .Timeout()) {
.asyncClose()
}
return nil,
}
.peekedMsg =
return , nil
}
func ( *PgConn) () (pgproto3.BackendMessage, error) {
if .status == connStatusClosed {
return nil, &connLockError{status: "conn closed"}
}
, := .peekMessage()
if != nil {
return nil,
}
.peekedMsg = nil
switch msg := .(type) {
case *pgproto3.ReadyForQuery:
.txStatus = .TxStatus
case *pgproto3.ParameterStatus:
.parameterStatuses[.Name] = .Value
case *pgproto3.ErrorResponse:
:= ErrorResponseToPgError()
if .config.OnPgError != nil && !.config.OnPgError(, ) {
.status = connStatusClosed
.conn.Close()
close(.cleanupDone)
return nil,
}
case *pgproto3.NoticeResponse:
if .config.OnNotice != nil {
.config.OnNotice(, noticeResponseToNotice())
}
case *pgproto3.NotificationResponse:
if .config.OnNotification != nil {
.config.OnNotification(, &Notification{PID: .PID, Channel: .Channel, Payload: .Payload})
}
}
return , nil
}
func ( *PgConn) () net.Conn {
return .conn
}
func ( *PgConn) () uint32 {
return .pid
}
func ( *PgConn) () byte {
return .txStatus
}
func ( *PgConn) () []byte {
return .secretKey
}
func ( *PgConn) () *pgproto3.Frontend {
return .frontend
}
func ( *PgConn) ( context.Context) error {
if .status == connStatusClosed {
return nil
}
.status = connStatusClosed
defer close(.cleanupDone)
defer .conn.Close()
if != context.Background() {
.contextWatcher.Unwatch()
.contextWatcher.Watch()
defer .contextWatcher.Unwatch()
}
.frontend.Send(&pgproto3.Terminate{})
.flushWithPotentialWriteReadDeadlock()
return .conn.Close()
}
func ( *PgConn) () {
if .status == connStatusClosed {
return
}
.status = connStatusClosed
go func() {
defer close(.cleanupDone)
defer .conn.Close()
:= time.Now().Add(time.Second * 15)
, := context.WithDeadline(context.Background(), )
defer ()
.CancelRequest()
.conn.SetDeadline()
.frontend.Send(&pgproto3.Terminate{})
.flushWithPotentialWriteReadDeadlock()
}()
}
func ( *PgConn) () chan (struct{}) {
return .cleanupDone
}
func ( *PgConn) () bool {
return .status < connStatusIdle
}
func ( *PgConn) () bool {
return .status == connStatusBusy
}
func ( *PgConn) () error {
switch .status {
case connStatusBusy:
return &connLockError{status: "conn busy"}
case connStatusClosed:
return &connLockError{status: "conn closed"}
case connStatusUninitialized:
return &connLockError{status: "conn uninitialized"}
}
.status = connStatusBusy
return nil
}
func ( *PgConn) () {
switch .status {
case connStatusBusy:
.status = connStatusIdle
case connStatusClosed:
default:
panic("BUG: cannot unlock unlocked connection")
}
}
func ( *PgConn) ( string) string {
return .parameterStatuses[]
}
type CommandTag struct {
s string
}
func ( string) CommandTag {
return CommandTag{s: }
}
func ( CommandTag) () int64 {
var int64
var int64 = 1
for := len(.s) - 1; >= 0; -- {
:= .s[]
if >= '0' && <= '9' {
+= int64(-'0') *
*= 10
} else {
break
}
}
return
}
func ( CommandTag) () string {
return .s
}
func ( CommandTag) () bool {
return strings.HasPrefix(.s, "INSERT")
}
func ( CommandTag) () bool {
return strings.HasPrefix(.s, "UPDATE")
}
func ( CommandTag) () bool {
return strings.HasPrefix(.s, "DELETE")
}
func ( CommandTag) () bool {
return strings.HasPrefix(.s, "SELECT")
}
type FieldDescription struct {
Name string
TableOID uint32
TableAttributeNumber uint16
DataTypeOID uint32
DataTypeSize int16
TypeModifier int32
Format int16
}
func ( *PgConn) ( int) []FieldDescription {
if cap(.fieldDescriptions) >= {
return .fieldDescriptions[::]
} else {
return make([]FieldDescription, )
}
}
func ( []FieldDescription, *pgproto3.RowDescription) {
for := range .Fields {
[].Name = string(.Fields[].Name)
[].TableOID = .Fields[].TableOID
[].TableAttributeNumber = .Fields[].TableAttributeNumber
[].DataTypeOID = .Fields[].DataTypeOID
[].DataTypeSize = .Fields[].DataTypeSize
[].TypeModifier = .Fields[].TypeModifier
[].Format = .Fields[].Format
}
}
type StatementDescription struct {
Name string
SQL string
ParamOIDs []uint32
Fields []FieldDescription
}
func ( *PgConn) ( context.Context, , string, []uint32) (*StatementDescription, error) {
if := .lock(); != nil {
return nil,
}
defer .unlock()
if != context.Background() {
select {
case <-.Done():
return nil, newContextAlreadyDoneError()
default:
}
.contextWatcher.Watch()
defer .contextWatcher.Unwatch()
}
.frontend.SendParse(&pgproto3.Parse{Name: , Query: , ParameterOIDs: })
.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: })
.frontend.SendSync(&pgproto3.Sync{})
:= .flushWithPotentialWriteReadDeadlock()
if != nil {
.asyncClose()
return nil,
}
:= &StatementDescription{Name: , SQL: }
var bool
var *PgError
:
for {
, := .receiveMessage()
if != nil {
.asyncClose()
return nil, normalizeTimeoutError(, )
}
switch msg := .(type) {
case *pgproto3.ParseComplete:
= true
case *pgproto3.ParameterDescription:
.ParamOIDs = make([]uint32, len(.ParameterOIDs))
copy(.ParamOIDs, .ParameterOIDs)
case *pgproto3.RowDescription:
.Fields = make([]FieldDescription, len(.Fields))
convertRowDescription(.Fields, )
case *pgproto3.ErrorResponse:
= ErrorResponseToPgError()
case *pgproto3.ReadyForQuery:
break
}
}
if != nil {
return nil, &PrepareError{err: , ParseComplete: }
}
return , nil
}
func ( *PgConn) ( context.Context, string) error {
if := .lock(); != nil {
return
}
defer .unlock()
if != context.Background() {
select {
case <-.Done():
return newContextAlreadyDoneError()
default:
}
.contextWatcher.Watch()
defer .contextWatcher.Unwatch()
}
.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: })
.frontend.SendSync(&pgproto3.Sync{})
:= .flushWithPotentialWriteReadDeadlock()
if != nil {
.asyncClose()
return
}
for {
, := .receiveMessage()
if != nil {
.asyncClose()
return normalizeTimeoutError(, )
}
switch msg := .(type) {
case *pgproto3.ErrorResponse:
return ErrorResponseToPgError()
case *pgproto3.ReadyForQuery:
return nil
}
}
}
func ( *pgproto3.ErrorResponse) *PgError {
return &PgError{
Severity: .Severity,
SeverityUnlocalized: .SeverityUnlocalized,
Code: string(.Code),
Message: string(.Message),
Detail: string(.Detail),
Hint: .Hint,
Position: .Position,
InternalPosition: .InternalPosition,
InternalQuery: string(.InternalQuery),
Where: string(.Where),
SchemaName: string(.SchemaName),
TableName: string(.TableName),
ColumnName: string(.ColumnName),
DataTypeName: string(.DataTypeName),
ConstraintName: .ConstraintName,
File: string(.File),
Line: .Line,
Routine: string(.Routine),
}
}
func ( *pgproto3.NoticeResponse) *Notice {
:= ErrorResponseToPgError((*pgproto3.ErrorResponse)())
return (*Notice)()
}
func ( *PgConn) ( context.Context) error {
:= .conn.RemoteAddr()
var string
var string
if .Network() == "unix" {
, = NetworkAddress(.config.Host, .config.Port)
} else {
, = .Network(), .String()
}
, := .config.DialFunc(, , )
if != nil {
if .Network() != "unix" {
return
}
, := NetworkAddress(.config.Host, .config.Port)
, = .config.DialFunc(, , )
if != nil {
return
}
}
defer .Close()
if != context.Background() {
:= ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: })
.Watch()
defer .Unwatch()
}
:= make([]byte, 12+len(.secretKey))
binary.BigEndian.PutUint32([0:4], uint32(len()))
binary.BigEndian.PutUint32([4:8], 80877102)
binary.BigEndian.PutUint32([8:12], .pid)
copy([12:], .secretKey)
if , := .Write(); != nil {
return fmt.Errorf("write to connection for cancellation: %w", )
}
_, _ = .Read()
return nil
}
func ( *PgConn) ( context.Context) error {
if := .lock(); != nil {
return
}
defer .unlock()
if != context.Background() {
select {
case <-.Done():
return newContextAlreadyDoneError()
default:
}
.contextWatcher.Watch()
defer .contextWatcher.Unwatch()
}
for {
, := .receiveMessage()
if != nil {
return normalizeTimeoutError(, )
}
switch .(type) {
case *pgproto3.NotificationResponse:
return nil
}
}
}
func ( *PgConn) ( context.Context, string) *MultiResultReader {
if := .lock(); != nil {
return &MultiResultReader{
closed: true,
err: ,
}
}
.multiResultReader = MultiResultReader{
pgConn: ,
ctx: ,
}
:= &.multiResultReader
if != context.Background() {
select {
case <-.Done():
.closed = true
.err = newContextAlreadyDoneError()
.unlock()
return
default:
}
.contextWatcher.Watch()
}
.frontend.SendQuery(&pgproto3.Query{String: })
:= .flushWithPotentialWriteReadDeadlock()
if != nil {
.asyncClose()
.contextWatcher.Unwatch()
.closed = true
.err =
.unlock()
return
}
return
}
func ( *PgConn) ( context.Context, string, [][]byte, []uint32, , []int16) *ResultReader {
:= .execExtendedPrefix(, )
if .closed {
return
}
.frontend.SendParse(&pgproto3.Parse{Query: , ParameterOIDs: })
.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: , Parameters: , ResultFormatCodes: })
.execExtendedSuffix(, nil, nil)
return
}
func ( *PgConn) ( context.Context, string, [][]byte, , []int16) *ResultReader {
:= .execExtendedPrefix(, )
if .closed {
return
}
.frontend.SendBind(&pgproto3.Bind{PreparedStatement: , ParameterFormatCodes: , Parameters: , ResultFormatCodes: })
.execExtendedSuffix(, nil, nil)
return
}
func ( *PgConn) ( context.Context, *StatementDescription, [][]byte, , []int16) *ResultReader {
:= .execExtendedPrefix(, )
if .closed {
return
}
.frontend.SendBind(&pgproto3.Bind{PreparedStatement: .Name, ParameterFormatCodes: , Parameters: , ResultFormatCodes: })
.execExtendedSuffix(, , )
return
}
func ( *PgConn) ( context.Context, [][]byte) *ResultReader {
.resultReader = ResultReader{
pgConn: ,
ctx: ,
}
:= &.resultReader
if := .lock(); != nil {
.concludeCommand(CommandTag{}, )
.closed = true
return
}
if len() > math.MaxUint16 {
.concludeCommand(CommandTag{}, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16))
.closed = true
.unlock()
return
}
if != context.Background() {
select {
case <-.Done():
.concludeCommand(CommandTag{}, newContextAlreadyDoneError())
.closed = true
.unlock()
return
default:
}
.contextWatcher.Watch()
}
return
}
func ( *PgConn) ( *ResultReader, *StatementDescription, []int16) {
if == nil {
.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
}
.frontend.SendExecute(&pgproto3.Execute{})
.frontend.SendSync(&pgproto3.Sync{})
:= .flushWithPotentialWriteReadDeadlock()
if != nil {
.asyncClose()
.concludeCommand(CommandTag{}, )
.contextWatcher.Unwatch()
.closed = true
.unlock()
return
}
.readUntilRowDescription(, )
}
func ( *PgConn) ( context.Context, io.Writer, string) (CommandTag, error) {
if := .lock(); != nil {
return CommandTag{},
}
if != context.Background() {
select {
case <-.Done():
.unlock()
return CommandTag{}, newContextAlreadyDoneError()
default:
}
.contextWatcher.Watch()
defer .contextWatcher.Unwatch()
}
.frontend.SendQuery(&pgproto3.Query{String: })
:= .flushWithPotentialWriteReadDeadlock()
if != nil {
.asyncClose()
.unlock()
return CommandTag{},
}
var CommandTag
var error
for {
, := .receiveMessage()
if != nil {
.asyncClose()
return CommandTag{}, normalizeTimeoutError(, )
}
switch msg := .(type) {
case *pgproto3.CopyDone:
case *pgproto3.CopyData:
, := .Write(.Data)
if != nil {
.asyncClose()
return CommandTag{},
}
case *pgproto3.ReadyForQuery:
.unlock()
return ,
case *pgproto3.CommandComplete:
= .makeCommandTag(.CommandTag)
case *pgproto3.ErrorResponse:
= ErrorResponseToPgError()
}
}
}
func ( *PgConn) ( context.Context, io.Reader, string) (CommandTag, error) {
if := .lock(); != nil {
return CommandTag{},
}
defer .unlock()
if != context.Background() {
select {
case <-.Done():
return CommandTag{}, newContextAlreadyDoneError()
default:
}
.contextWatcher.Watch()
defer .contextWatcher.Unwatch()
}
.frontend.SendQuery(&pgproto3.Query{String: })
:= .flushWithPotentialWriteReadDeadlock()
if != nil {
.asyncClose()
return CommandTag{},
}
:= make(chan struct{})
:= make(chan error, 1)
:= .signalMessage()
var sync.WaitGroup
.Go(func() {
:= iobufpool.Get(65536)
defer iobufpool.Put()
(*)[0] = 'd'
for {
, := .Read((*)[5:cap(*)])
if > 0 {
* = (*)[0 : +5]
pgio.SetInt32((*)[1:], int32(+4))
:= .frontend.SendUnbufferedEncodedCopyData(*)
if != nil {
.conn.Close()
<-
return
}
}
if != nil {
<-
return
}
select {
case <-:
return
default:
}
}
})
var error
var error
for == nil && == nil {
select {
case = <-:
case <-:
if := .bufferingReceiveErr; != nil {
.status = connStatusClosed
.conn.Close()
close(.cleanupDone)
return CommandTag{}, normalizeTimeoutError(, )
}
, := .receiveMessage()
if != nil {
close()
return CommandTag{},
}
switch msg := .(type) {
case *pgproto3.ErrorResponse:
= ErrorResponseToPgError()
default:
= .signalMessage()
}
}
}
close()
.Wait()
if == io.EOF || != nil {
.frontend.Send(&pgproto3.CopyDone{})
} else {
.frontend.Send(&pgproto3.CopyFail{Message: .Error()})
}
= .flushWithPotentialWriteReadDeadlock()
if != nil {
.asyncClose()
return CommandTag{},
}
var CommandTag
for {
, := .receiveMessage()
if != nil {
.asyncClose()
return CommandTag{}, normalizeTimeoutError(, )
}
switch msg := .(type) {
case *pgproto3.ReadyForQuery:
return ,
case *pgproto3.CommandComplete:
= .makeCommandTag(.CommandTag)
case *pgproto3.ErrorResponse:
= ErrorResponseToPgError()
}
}
}
type MultiResultReader struct {
pgConn *PgConn
ctx context.Context
rr *ResultReader
statementDescriptions []*StatementDescription
resultFormats [][]int16
closed bool
err error
}
func ( *MultiResultReader) () ([]*Result, error) {
var []*Result
for .NextResult() {
= append(, .ResultReader().Read())
}
:= .Close()
return ,
}
func ( *MultiResultReader) () (pgproto3.BackendMessage, error) {
, := .pgConn.receiveMessage()
if != nil {
.pgConn.contextWatcher.Unwatch()
.err = normalizeTimeoutError(.ctx, )
.closed = true
.pgConn.asyncClose()
return nil, .err
}
switch msg := .(type) {
case *pgproto3.ReadyForQuery:
.closed = true
.pgConn.contextWatcher.Unwatch()
.pgConn.unlock()
case *pgproto3.ErrorResponse:
.err = ErrorResponseToPgError()
}
return , nil
}
func ( *MultiResultReader) () bool {
for !.closed && .err == nil {
, := .pgConn.peekMessage()
if , := .(*pgproto3.DataRow); {
if len(.statementDescriptions) > 0 {
:= ResultReader{
pgConn: .pgConn,
multiResultReader: ,
ctx: .ctx,
}
:= .statementDescriptions[0]
.statementDescriptions = .statementDescriptions[1:]
:= .resultFormats[0]
.resultFormats = .resultFormats[1:]
:= .Fields
.fieldDescriptions = .pgConn.getFieldDescriptionSlice(len())
:= combineFieldDescriptionsAndResultFormats(.fieldDescriptions, , )
if != nil {
.concludeCommand(CommandTag{}, )
}
.pgConn.resultReader =
.rr = &.pgConn.resultReader
return true
}
.err = fmt.Errorf("unexpected DataRow message without preceding RowDescription")
return false
}
, := .receiveMessage()
if != nil {
return false
}
switch msg := .(type) {
case *pgproto3.RowDescription:
.pgConn.resultReader = ResultReader{
pgConn: .pgConn,
multiResultReader: ,
ctx: .ctx,
fieldDescriptions: .pgConn.getFieldDescriptionSlice(len(.Fields)),
}
convertRowDescription(.pgConn.resultReader.fieldDescriptions, )
.rr = &.pgConn.resultReader
return true
case *pgproto3.CommandComplete:
.pgConn.resultReader = ResultReader{
commandTag: .pgConn.makeCommandTag(.CommandTag),
commandConcluded: true,
closed: true,
}
.rr = &.pgConn.resultReader
return true
case *pgproto3.EmptyQueryResponse:
.pgConn.resultReader = ResultReader{
commandConcluded: true,
closed: true,
}
.rr = &.pgConn.resultReader
return true
}
}
return false
}
func ( *MultiResultReader) () *ResultReader {
return .rr
}
func ( *MultiResultReader) () error {
for !.closed {
, := .receiveMessage()
if != nil {
return .err
}
}
return .err
}
type ResultReader struct {
pgConn *PgConn
multiResultReader *MultiResultReader
pipeline *Pipeline
ctx context.Context
fieldDescriptions []FieldDescription
rowValues [][]byte
commandTag CommandTag
preloaded bool
commandConcluded bool
closed bool
err error
}
type Result struct {
FieldDescriptions []FieldDescription
Rows [][][]byte
CommandTag CommandTag
Err error
}
func ( *ResultReader) () *Result {
:= &Result{}
for .NextRow() {
if .FieldDescriptions == nil {
.FieldDescriptions = make([]FieldDescription, len(.FieldDescriptions()))
copy(.FieldDescriptions, .FieldDescriptions())
}
:= .Values()
:= make([][]byte, len())
for := range {
if [] != nil {
[] = make([]byte, len([]))
copy([], [])
}
}
.Rows = append(.Rows, )
}
.CommandTag, .Err = .Close()
return
}
func ( *ResultReader) () bool {
if .preloaded {
.preloaded = false
return true
}
for !.commandConcluded {
, := .receiveMessage()
if != nil {
return false
}
switch msg := .(type) {
case *pgproto3.DataRow:
.rowValues = .Values
return true
}
}
return false
}
func ( *ResultReader) ( [][]byte) {
.rowValues =
.preloaded = true
}
func ( *ResultReader) () []FieldDescription {
return .fieldDescriptions
}
func ( *ResultReader) () [][]byte {
return .rowValues
}
func ( *ResultReader) () (CommandTag, error) {
if .closed {
return .commandTag, .err
}
.closed = true
for !.commandConcluded {
, := .receiveMessage()
if != nil {
return CommandTag{}, .err
}
}
if .multiResultReader == nil && .pipeline == nil {
for {
, := .receiveMessage()
if != nil {
return CommandTag{}, .err
}
switch msg := .(type) {
case *pgproto3.ErrorResponse:
.err = ErrorResponseToPgError()
case *pgproto3.ReadyForQuery:
.pgConn.contextWatcher.Unwatch()
.pgConn.unlock()
return .commandTag, .err
}
}
}
return .commandTag, .err
}
func ( *ResultReader) ( *StatementDescription, []int16) {
for !.commandConcluded {
, := .receiveMessage()
switch msg := .(type) {
case *pgproto3.RowDescription:
return
case *pgproto3.DataRow:
.preloadRowValues(.Values)
if != nil {
:= .Fields
.fieldDescriptions = .pgConn.getFieldDescriptionSlice(len())
:= combineFieldDescriptionsAndResultFormats(.fieldDescriptions, , )
if != nil {
.concludeCommand(CommandTag{}, )
}
}
return
case *pgproto3.CommandComplete:
if != nil {
:= .Fields
.fieldDescriptions = .pgConn.getFieldDescriptionSlice(len())
:= combineFieldDescriptionsAndResultFormats(.fieldDescriptions, , )
if != nil {
.concludeCommand(CommandTag{}, )
}
}
return
}
}
}
func ( *ResultReader) () ( pgproto3.BackendMessage, error) {
if .multiResultReader == nil {
, = .pgConn.receiveMessage()
} else {
, = .multiResultReader.receiveMessage()
}
if != nil {
= normalizeTimeoutError(.ctx, )
.concludeCommand(CommandTag{}, )
.pgConn.contextWatcher.Unwatch()
.closed = true
if .multiResultReader == nil {
.pgConn.asyncClose()
}
return nil, .err
}
switch msg := .(type) {
case *pgproto3.RowDescription:
.fieldDescriptions = .pgConn.getFieldDescriptionSlice(len(.Fields))
convertRowDescription(.fieldDescriptions, )
case *pgproto3.CommandComplete:
.concludeCommand(.pgConn.makeCommandTag(.CommandTag), nil)
case *pgproto3.EmptyQueryResponse:
.concludeCommand(CommandTag{}, nil)
case *pgproto3.ErrorResponse:
:= ErrorResponseToPgError()
if .pipeline != nil {
.pipeline.state.HandleError()
}
.concludeCommand(CommandTag{}, )
}
return , nil
}
func ( *ResultReader) ( CommandTag, error) {
if != nil && .err == nil {
.err =
}
if .commandConcluded {
return
}
.commandTag =
.rowValues = nil
.commandConcluded = true
}
type Batch struct {
buf []byte
statementDescriptions []*StatementDescription
resultFormats [][]int16
err error
}
func ( *Batch) ( string, [][]byte, []uint32, , []int16) {
if .err != nil {
return
}
.buf, .err = (&pgproto3.Parse{Query: , ParameterOIDs: }).Encode(.buf)
if .err != nil {
return
}
.ExecPrepared("", , , )
}
func ( *Batch) ( string, [][]byte, , []int16) {
if .err != nil {
return
}
.buf, .err = (&pgproto3.Bind{PreparedStatement: , ParameterFormatCodes: , Parameters: , ResultFormatCodes: }).Encode(.buf)
if .err != nil {
return
}
.buf, .err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(.buf)
if .err != nil {
return
}
.buf, .err = (&pgproto3.Execute{}).Encode(.buf)
if .err != nil {
return
}
}
func ( *Batch) ( *StatementDescription, [][]byte, , []int16) {
if .err != nil {
return
}
.buf, .err = (&pgproto3.Bind{PreparedStatement: .Name, ParameterFormatCodes: , Parameters: , ResultFormatCodes: }).Encode(.buf)
if .err != nil {
return
}
.statementDescriptions = append(.statementDescriptions, )
.resultFormats = append(.resultFormats, )
.buf, .err = (&pgproto3.Execute{}).Encode(.buf)
if .err != nil {
return
}
}
func ( *PgConn) ( context.Context, *Batch) *MultiResultReader {
if .err != nil {
return &MultiResultReader{
closed: true,
err: .err,
}
}
if := .lock(); != nil {
return &MultiResultReader{
closed: true,
err: ,
}
}
.multiResultReader = MultiResultReader{
pgConn: ,
ctx: ,
statementDescriptions: .statementDescriptions,
resultFormats: .resultFormats,
}
:= &.multiResultReader
if != context.Background() {
select {
case <-.Done():
.closed = true
.err = newContextAlreadyDoneError()
.unlock()
return
default:
}
.contextWatcher.Watch()
}
.buf, .err = (&pgproto3.Sync{}).Encode(.buf)
if .err != nil {
.contextWatcher.Unwatch()
.err = normalizeTimeoutError(.ctx, .err)
.closed = true
.asyncClose()
return
}
, := func( []byte) (int, error) {
.enterPotentialWriteReadDeadlock()
defer .exitPotentialWriteReadDeadlock()
return .conn.Write()
}(.buf)
if != nil {
.contextWatcher.Unwatch()
.err = normalizeTimeoutError(.ctx, )
.closed = true
.asyncClose()
return
}
return
}
func ( *PgConn) ( string) (string, error) {
if .ParameterStatus("standard_conforming_strings") != "on" {
return "", errors.New("EscapeString must be run with standard_conforming_strings=on")
}
if .ParameterStatus("client_encoding") != "UTF8" {
return "", errors.New("EscapeString must be run with client_encoding=UTF8")
}
return strings.Replace(, "'", "''", -1), nil
}
func ( *PgConn) () error {
, := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer ()
, := .ReceiveMessage()
if != nil {
if !Timeout() {
return
}
}
return nil
}
func ( *PgConn) ( context.Context) error {
return .Exec(, "-- ping").Close()
}
func ( *PgConn) ( []byte) CommandTag {
return CommandTag{s: string()}
}
func ( *PgConn) () {
if .slowWriteTimer.Reset(15 * time.Millisecond) {
panic("BUG: slow write timer already active")
}
}
func ( *PgConn) () {
if !.slowWriteTimer.Stop() {
<-.bgReaderStarted
.bgReader.Stop()
}
}
func ( *PgConn) () error {
.enterPotentialWriteReadDeadlock()
defer .exitPotentialWriteReadDeadlock()
:= .frontend.Flush()
return
}
func ( *PgConn) ( context.Context) error {
for range 10 {
if .bgReader.Status() == bgreader.StatusStopped && .frontend.ReadBufferLen() == 0 {
return nil
}
:= .Ping()
if != nil {
return fmt.Errorf("SyncConn: Ping failed while syncing conn: %w", )
}
}
return errors.New("SyncConn: conn never synchronized")
}
func ( *PgConn) () map[string]any {
return .customData
}
type HijackedConn struct {
Conn net.Conn
PID uint32
SecretKey []byte
ParameterStatuses map[string]string
TxStatus byte
Frontend *pgproto3.Frontend
Config *Config
CustomData map[string]any
}
func ( *PgConn) () (*HijackedConn, error) {
if := .lock(); != nil {
return nil,
}
.status = connStatusClosed
return &HijackedConn{
Conn: .conn,
PID: .pid,
SecretKey: .secretKey,
ParameterStatuses: .parameterStatuses,
TxStatus: .txStatus,
Frontend: .frontend,
Config: .config,
CustomData: .customData,
}, nil
}
// Builds a usable PgConn from a HijackedConn. The fields are copied, the status
// is set to idle, and fresh per-connection helpers are created: a context
// watcher, a background reader over the net.Conn, the slow-write timer, and a
// frontend. It always returns a nil error.
// NOTE(review): frontend is first set from hc.Frontend and then replaced by a new
// frontend built from hc.Config, so hc.Frontend is effectively unused and any data
// buffered in it does not carry over — confirm this is intended.
func ( *HijackedConn) (*PgConn, error) {
:= &PgConn{
conn: .Conn,
pid: .PID,
secretKey: .SecretKey,
parameterStatuses: .ParameterStatuses,
txStatus: .TxStatus,
frontend: .Frontend,
config: .Config,
customData: .CustomData,
status: connStatusIdle,
cleanupDone: make(chan struct{}),
}
.contextWatcher = ctxwatch.NewContextWatcher(.Config.BuildContextWatcherHandler())
.bgReader = bgreader.New(.conn)
// The timer is created with an effectively infinite duration and stopped right
// away, so its callback only runs after a later Reset. The callback starts the
// background reader and then signals on bgReaderStarted.
.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64),
func() {
.bgReader.Start()
.bgReaderStarted <- struct{}{}
},
)
.slowWriteTimer.Stop()
// bgReaderStarted is assigned after the timer is created. This is safe only
// because the timer was stopped above and cannot have fired yet.
.bgReaderStarted = make(chan struct{})
// The frontend reads through the background reader and writes to the net.Conn.
.frontend = .Config.BuildFrontend(.bgReader, .conn)
return , nil
}
// Pipeline is a request pipeline running on a PgConn. Requests are queued,
// flushed, and their results read back in order.
type Pipeline struct {
conn *PgConn
// Context the pipeline was started with; used for context watching and for
// normalizing timeout errors.
ctx context.Context
// Client-side bookkeeping of queued requests and expected replies.
state pipelineState
// Error that ended the pipeline, if any.
err error
// Set once the pipeline can no longer be used.
closed bool
}
// PipelineSync is the result produced when the ReadyForQuery for a pipeline
// sync request has been received.
type PipelineSync struct{}
// CloseComplete is the result produced when a pipelined statement deallocation
// (a Close of a prepared statement) completes.
type CloseComplete struct{}
// pipelineRequestType identifies the kind of a queued pipeline request, which
// determines how its replies are read back.
type pipelineRequestType int
const (
// pipelineNil is the zero value. When returned from the queue it means no
// request is ready to be read.
pipelineNil pipelineRequestType = iota
// Parse plus statement Describe; yields a *StatementDescription.
pipelinePrepare
// Parse, Bind, portal Describe and Execute of an unnamed statement.
pipelineQueryParams
// Bind, portal Describe and Execute of a named prepared statement.
pipelineQueryPrepared
// Bind and Execute using an already known statement description.
pipelineQueryStatement
// Close of a prepared statement.
pipelineDeallocate
// Sync; yields a *PipelineSync when its ReadyForQuery arrives.
pipelineSyncRequest
// Flush; recorded but never kept in the request queue.
pipelineFlushRequest
)
// pipelineRequestEvent is one queued pipeline request and how far it has progressed.
type pipelineRequestEvent struct {
RequestType pipelineRequestType
// Set once a successful pipeline flush has written the request out.
WasSentToServer bool
// Set once a flush or sync request has been queued after this request.
BeforeFlushOrSync bool
}
// pipelineState is the client-side bookkeeping of a Pipeline.
type pipelineState struct {
// Queued pipelineRequestEvent values, oldest at the front.
requestEventQueue list.List
// Statement descriptions and result formats of QueryStatement requests, in
// the order those requests were queued.
statementDescriptionsQueue list.List
resultFormatsQueue list.List
// Most recently pushed request type, including flush requests.
lastRequestType pipelineRequestType
// Server error after which queued requests are skipped until the next sync.
pgErr *PgError
// Number of queued syncs whose ReadyForQuery has not been handled yet.
expectedReadyForQueryCount int
}
func ( *pipelineState) () {
.requestEventQueue.Init()
.statementDescriptionsQueue.Init()
.resultFormatsQueue.Init()
.lastRequestType = pipelineNil
}
func ( *pipelineState) () {
for := .requestEventQueue.Back(); != nil; = .Prev() {
:= .Value.(pipelineRequestEvent)
if .WasSentToServer {
return
}
.WasSentToServer = true
.Value =
}
}
func ( *pipelineState) () {
for := .requestEventQueue.Back(); != nil; = .Prev() {
:= .Value.(pipelineRequestEvent)
if .BeforeFlushOrSync {
return
}
.BeforeFlushOrSync = true
.Value =
}
}
func ( *pipelineState) ( pipelineRequestType) {
if == pipelineNil {
return
}
if != pipelineFlushRequest {
.requestEventQueue.PushBack(pipelineRequestEvent{RequestType: })
}
if == pipelineFlushRequest || == pipelineSyncRequest {
.registerFlushingBufferOnServer()
}
.lastRequestType =
if == pipelineSyncRequest {
.expectedReadyForQueryCount++
}
}
func ( *pipelineState) () pipelineRequestType {
for {
:= .requestEventQueue.Front()
if == nil {
return pipelineNil
}
:= .Value.(pipelineRequestEvent)
if !(.WasSentToServer && .BeforeFlushOrSync) {
return pipelineNil
}
.requestEventQueue.Remove()
if .RequestType == pipelineSyncRequest {
.pgErr = nil
}
if .pgErr == nil {
return .RequestType
}
}
}
func ( *pipelineState) ( *StatementDescription, []int16) {
.statementDescriptionsQueue.PushBack()
.resultFormatsQueue.PushBack()
}
func ( *pipelineState) () (*StatementDescription, []int16) {
:= .statementDescriptionsQueue.Front()
var *StatementDescription
if != nil {
.statementDescriptionsQueue.Remove()
= .Value.(*StatementDescription)
}
:= .resultFormatsQueue.Front()
var []int16
if != nil {
.resultFormatsQueue.Remove()
= .Value.([]int16)
}
return ,
}
func ( *pipelineState) ( *PgError) {
.pgErr =
}
func ( *pipelineState) () {
.expectedReadyForQueryCount--
}
func ( *pipelineState) () bool {
var bool
if := .requestEventQueue.Back(); != nil {
:= .Value.(pipelineRequestEvent)
= (.RequestType == pipelineSyncRequest) && .WasSentToServer
} else {
= (.lastRequestType == pipelineSyncRequest) || (.lastRequestType == pipelineNil)
}
return !
}
func ( *pipelineState) () int {
return .expectedReadyForQueryCount
}
// Starts a pipeline on the connection.
// If the connection cannot be locked, a new, already closed Pipeline carrying
// the lock error is returned. Otherwise the connection stays locked for the
// life of the pipeline. The connection's resultReader is reset to a closed
// reader, and the Pipeline stored in the PgConn itself is reinitialized and
// returned by pointer.
func ( *PgConn) ( context.Context) *Pipeline {
if := .lock(); != nil {
:= &Pipeline{
closed: true,
err: ,
}
.state.Init()
return
}
.resultReader = ResultReader{closed: true}
.pipeline = Pipeline{
conn: ,
ctx: ,
}
.pipeline.state.Init()
:= &.pipeline
// Only context.Background() itself skips context handling.
if != context.Background() {
select {
case <-.Done():
// Context already finished: return a closed pipeline and release the lock.
.closed = true
.err = newContextAlreadyDoneError()
.unlock()
return
default:
}
.contextWatcher.Watch()
}
return
}
// Queues a Parse of a statement (first parameter: statement name; second: SQL
// text; third: parameter OIDs) followed by a Describe of that statement, and
// records a pipelinePrepare request. The messages are handed to the frontend
// here; the pipeline's flush is what marks them as sent. No-op on a closed pipeline.
// NOTE(review): the name and SQL are not recorded in the pipeline state, so the
// *StatementDescription later read back for this request has only ParamOIDs and
// Fields set — confirm callers do not rely on Name/SQL.
func ( *Pipeline) (, string, []uint32) {
if .closed {
return
}
.conn.frontend.SendParse(&pgproto3.Parse{Name: , Query: , ParameterOIDs: })
.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: })
.state.PushBackRequestType(pipelinePrepare)
}
// Queues a Close of the named prepared statement (ObjectType 'S') and records a
// pipelineDeallocate request. No-op on a closed pipeline.
func ( *Pipeline) ( string) {
if .closed {
return
}
.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: })
.state.PushBackRequestType(pipelineDeallocate)
}
// Queues an ad-hoc parameterized query through the unnamed statement and portal:
// Parse (SQL text and parameter OIDs), Bind (parameter format codes, parameter
// values, result format codes), portal Describe and Execute. Records a
// pipelineQueryParams request. No-op on a closed pipeline.
func ( *Pipeline) ( string, [][]byte, []uint32, , []int16) {
if .closed {
return
}
.conn.frontend.SendParse(&pgproto3.Parse{Query: , ParameterOIDs: })
.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: , Parameters: , ResultFormatCodes: })
.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
.conn.frontend.SendExecute(&pgproto3.Execute{})
.state.PushBackRequestType(pipelineQueryParams)
}
// Queues execution of an already prepared statement (first parameter: its name):
// Bind into the unnamed portal, portal Describe, and Execute. Records a
// pipelineQueryPrepared request. No-op on a closed pipeline.
func ( *Pipeline) ( string, [][]byte, , []int16) {
if .closed {
return
}
.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: , ParameterFormatCodes: , Parameters: , ResultFormatCodes: })
.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
.conn.frontend.SendExecute(&pgproto3.Execute{})
.state.PushBackRequestType(pipelineQueryPrepared)
}
// Queues execution of a prepared statement whose description is already known:
// Bind (using the description's Name) and Execute, with no Describe. The
// description and result formats are queued so that the result fields can be
// built from them when the replies are read. No-op on a closed pipeline.
func ( *Pipeline) ( *StatementDescription, [][]byte, , []int16) {
if .closed {
return
}
.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: .Name, ParameterFormatCodes: , Parameters: , ResultFormatCodes: })
.conn.frontend.SendExecute(&pgproto3.Execute{})
.state.PushBackRequestType(pipelineQueryStatement)
.state.PushBackStatementData(, )
}
// Queues a protocol Flush message and records a flush request. A flush request
// is not kept in the request queue; it marks the previously queued requests as
// followed by a flush so their replies can be read. No-op on a closed pipeline.
func ( *Pipeline) () {
if .closed {
return
}
.conn.frontend.Send(&pgproto3.Flush{})
.state.PushBackRequestType(pipelineFlushRequest)
}
func ( *Pipeline) () {
if .closed {
return
}
.conn.frontend.SendSync(&pgproto3.Sync{})
.state.PushBackRequestType(pipelineSyncRequest)
}
func ( *Pipeline) () error {
if .closed {
if .err != nil {
return .err
}
return errors.New("pipeline closed")
}
:= .conn.flushWithPotentialWriteReadDeadlock()
if != nil {
= normalizeTimeoutError(.ctx, )
.conn.asyncClose()
.conn.contextWatcher.Unwatch()
.conn.unlock()
.closed = true
.err =
return
}
.state.RegisterSendingToServer()
return nil
}
// Queues a sync request and flushes, returning the flush error. On a closed
// pipeline the sync is not queued and the flush returns the pipeline's error.
func ( *Pipeline) () error {
.SendPipelineSync()
return .Flush()
}
// Returns the next available result of the pipeline. Depending on the request
// it belongs to, a successful result is a *StatementDescription, a
// *ResultReader, a *CloseComplete or a *PipelineSync. (nil, nil) means no
// result is ready yet. On a closed pipeline the stored error, or a generic
// "pipeline closed" error, is returned. When err is non-nil, do not use results.
func ( *Pipeline) () ( any, error) {
if .closed {
if .err != nil {
return nil, .err
}
return nil, errors.New("pipeline closed")
}
return .getResults()
}
func ( *Pipeline) () ( any, error) {
if !.conn.resultReader.closed {
, := .conn.resultReader.Close()
if != nil {
return nil,
}
}
:= .state.ExtractFrontRequestType()
switch {
case pipelineNil:
return nil, nil
case pipelinePrepare:
return .getResultsPrepare()
case pipelineQueryParams:
return .getResultsQueryParams()
case pipelineQueryPrepared:
return .getResultsQueryPrepared()
case pipelineQueryStatement:
return .getResultsQueryStatement()
case pipelineDeallocate:
return .getResultsDeallocate()
case pipelineSyncRequest:
return .getResultsSync()
case pipelineFlushRequest:
return nil, errors.New("BUG: pipelineFlushRequest should not be in request queue")
default:
return nil, errors.New("BUG: unknown pipeline request type")
}
}
func ( *Pipeline) () (*StatementDescription, error) {
:= .receiveParseComplete("Prepare")
if != nil {
return nil,
}
:= &StatementDescription{}
, := .receiveMessage()
if != nil {
return nil,
}
switch msg := .(type) {
case *pgproto3.ParameterDescription:
.ParamOIDs = make([]uint32, len(.ParameterOIDs))
copy(.ParamOIDs, .ParameterOIDs)
case *pgproto3.ErrorResponse:
:= ErrorResponseToPgError()
.state.HandleError()
return nil,
default:
return nil, .handleUnexpectedMessage("Prepare ParameterDescription", )
}
, = .receiveMessage()
if != nil {
return nil,
}
switch msg := .(type) {
case *pgproto3.RowDescription:
.Fields = make([]FieldDescription, len(.Fields))
convertRowDescription(.Fields, )
return , nil
case *pgproto3.NoData:
return , nil
case *pgproto3.ErrorResponse:
:= ErrorResponseToPgError()
.state.HandleError()
return nil,
default:
return nil, .handleUnexpectedMessage("Prepare RowDescription", )
}
}
func ( *Pipeline) () (*ResultReader, error) {
:= .receiveParseComplete("QueryParams")
if != nil {
return nil,
}
= .receiveBindComplete("QueryParams")
if != nil {
return nil,
}
return .receiveDescribedResultReader("QueryParams")
}
func ( *Pipeline) () (*ResultReader, error) {
:= .receiveBindComplete("QueryPrepared")
if != nil {
return nil,
}
return .receiveDescribedResultReader("QueryPrepared")
}
func ( *Pipeline) () (*ResultReader, error) {
:= .receiveBindComplete("QueryStatement")
if != nil {
return nil,
}
, := .receiveMessage()
if != nil {
return nil,
}
, := .state.ExtractFrontStatementData()
if == nil {
return nil, errors.New("BUG: missing statement description or result formats for QueryStatement")
}
:= .Fields
:= .conn.getFieldDescriptionSlice(len())
= combineFieldDescriptionsAndResultFormats(, , )
if != nil {
return nil,
}
switch msg := .(type) {
case *pgproto3.DataRow:
:= ResultReader{
pgConn: .conn,
pipeline: ,
ctx: .ctx,
fieldDescriptions: ,
}
.preloadRowValues(.Values)
.conn.resultReader =
return &.conn.resultReader, nil
case *pgproto3.CommandComplete:
.conn.resultReader = ResultReader{
commandTag: .conn.makeCommandTag(.CommandTag),
commandConcluded: true,
closed: true,
fieldDescriptions: ,
}
return &.conn.resultReader, nil
case *pgproto3.ErrorResponse:
:= ErrorResponseToPgError()
.state.HandleError()
.conn.resultReader.closed = true
return nil,
default:
return nil, .handleUnexpectedMessage("QueryStatement", )
}
}
func ( *Pipeline) () (*CloseComplete, error) {
, := .receiveMessage()
if != nil {
return nil,
}
switch msg := .(type) {
case *pgproto3.CloseComplete:
return &CloseComplete{}, nil
case *pgproto3.ErrorResponse:
:= ErrorResponseToPgError()
.state.HandleError()
.conn.resultReader.closed = true
return nil,
default:
return nil, .handleUnexpectedMessage("Deallocate", )
}
}
func ( *Pipeline) () (*PipelineSync, error) {
, := .receiveMessage()
if != nil {
return nil,
}
switch msg := .(type) {
case *pgproto3.ReadyForQuery:
.state.HandleReadyForQuery()
return &PipelineSync{}, nil
case *pgproto3.ErrorResponse:
.state.requestEventQueue.PushFront(pipelineRequestEvent{RequestType: pipelineSyncRequest, WasSentToServer: true, BeforeFlushOrSync: true})
:= ErrorResponseToPgError()
.state.HandleError()
.conn.resultReader.closed = true
return nil,
default:
return nil, .handleUnexpectedMessage("Sync", )
}
}
func ( *Pipeline) ( string) error {
, := .receiveMessage()
if != nil {
return
}
switch msg := .(type) {
case *pgproto3.ParseComplete:
return nil
case *pgproto3.ErrorResponse:
:= ErrorResponseToPgError()
.state.HandleError()
return
default:
return .handleUnexpectedMessage(fmt.Sprintf("%s Parse", ), )
}
}
func ( *Pipeline) ( string) error {
, := .receiveMessage()
if != nil {
return
}
switch msg := .(type) {
case *pgproto3.BindComplete:
return nil
case *pgproto3.ErrorResponse:
:= ErrorResponseToPgError()
.state.HandleError()
return
default:
return .handleUnexpectedMessage(fmt.Sprintf("%s Bind", ), )
}
}
func ( *Pipeline) ( string) (*ResultReader, error) {
, := .receiveMessage()
if != nil {
return nil,
}
switch msg := .(type) {
case *pgproto3.RowDescription:
.conn.resultReader = ResultReader{
pgConn: .conn,
pipeline: ,
ctx: .ctx,
fieldDescriptions: .conn.getFieldDescriptionSlice(len(.Fields)),
}
convertRowDescription(.conn.resultReader.fieldDescriptions, )
return &.conn.resultReader, nil
case *pgproto3.NoData:
case *pgproto3.ErrorResponse:
:= ErrorResponseToPgError()
.state.HandleError()
.conn.resultReader.closed = true
return nil,
default:
return nil, .handleUnexpectedMessage(fmt.Sprintf("%s RowDescription or NoData", ), )
}
, = .receiveMessage()
if != nil {
return nil,
}
switch msg := .(type) {
case *pgproto3.CommandComplete:
.conn.resultReader = ResultReader{
commandTag: .conn.makeCommandTag(.CommandTag),
commandConcluded: true,
closed: true,
}
return &.conn.resultReader, nil
case *pgproto3.ErrorResponse:
:= ErrorResponseToPgError()
.state.HandleError()
.conn.resultReader.closed = true
return nil,
default:
return nil, .handleUnexpectedMessage(fmt.Sprintf("%s CommandComplete", ), )
}
}
func ( *Pipeline) () (pgproto3.BackendMessage, error) {
for {
, := .conn.receiveMessage()
if != nil {
.err =
.conn.asyncClose()
return nil, normalizeTimeoutError(.ctx, )
}
switch msg := .(type) {
case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse, *pgproto3.NotificationResponse:
default:
return , nil
}
}
}
func ( *Pipeline) ( string, pgproto3.BackendMessage) error {
.err = fmt.Errorf("pipeline: %s: received unexpected message type %T", , )
.conn.asyncClose()
return .err
}
// Ends the pipeline, releases the connection lock, and returns the pipeline's
// final error. Calling it on an already closed pipeline just returns the
// stored error.
func ( *Pipeline) () error {
if .closed {
return .err
}
.closed = true
// Requests queued without a sync that reached the server: the remaining
// replies are not drained. Instead the connection is closed asynchronously
// and an error is returned.
if .state.PendingSync() {
.conn.asyncClose()
.err = errors.New("pipeline has unsynced requests")
.conn.contextWatcher.Unwatch()
.conn.unlock()
return .err
}
// Drain replies until every queued sync's ReadyForQuery has been handled.
for .state.ExpectedReadyForQuery() > 0 {
, := .getResults()
if != nil {
// Keep the latest error. Server errors (*PgError) let draining continue;
// any other error closes the connection and stops.
.err =
var *PgError
if !errors.As(, &) {
.conn.asyncClose()
break
}
} else if == nil {
// Nothing ready although a ReadyForQuery is still expected: give up on
// the connection. An earlier recorded error takes precedence.
.conn.asyncClose()
if .err == nil {
.err = errors.New("pipeline: no more results but expected ReadyForQuery")
}
break
}
}
.conn.contextWatcher.Unwatch()
.conn.unlock()
return .err
}
// DeadlineContextWatcherHandler handles context cancellation by setting a
// deadline on Conn, DeadlineDelay after the cancellation is handled. Per net.Conn
// semantics, pending and future I/O on Conn then fails once the deadline passes.
type DeadlineContextWatcherHandler struct {
Conn net.Conn
DeadlineDelay time.Duration
}
// Sets Conn's read and write deadline to now + DeadlineDelay. The context
// argument is not used, and the SetDeadline error is ignored.
func ( *DeadlineContextWatcherHandler) ( context.Context) {
.Conn.SetDeadline(time.Now().Add(.DeadlineDelay))
}
// Clears Conn's deadline; a zero time.Time means no deadline. The SetDeadline
// error is ignored.
func ( *DeadlineContextWatcherHandler) () {
.Conn.SetDeadline(time.Time{})
}
// CancelRequestContextWatcherHandler handles context cancellation by
// immediately setting a deadline DeadlineDelay ahead on the PgConn's net.Conn.
// After CancelRequestDelay it also sends a cancel request through the PgConn.
type CancelRequestContextWatcherHandler struct {
Conn *PgConn
CancelRequestDelay time.Duration
DeadlineDelay time.Duration
// Closed when the goroutine started by the cancel handler exits.
cancelFinishedChan chan struct{}
// Cancels the context that goroutine waits on; set by the cancel handler.
handleUnwatchAfterCancelCalled func()
}
// Handles cancellation. It sets the connection deadline to now + DeadlineDelay,
// then starts a goroutine. Unless the unwatch handler runs first, the goroutine
// waits CancelRequestDelay, sends a cancel request, and sleeps 100ms.
// cancelFinishedChan is closed when the goroutine exits.
func ( *CancelRequestContextWatcherHandler) (context.Context) {
.cancelFinishedChan = make(chan struct{})
// Context cancelled by the unwatch handler to abort the pending cancel request.
var context.Context
, .handleUnwatchAfterCancelCalled = context.WithCancel(context.Background())
:= time.Now().Add(.DeadlineDelay)
.Conn.conn.SetDeadline()
go func() {
defer close(.cancelFinishedChan)
select {
case <-.Done():
return
// NOTE(review): time.After's timer is not stopped when the unwatch path wins.
case <-time.After(.CancelRequestDelay):
}
// Presumably bounds the cancel request by the unwatch context and the
// deadline computed above — confirm against the stripped arguments.
, := context.WithDeadline(, )
defer ()
// The cancel request's result is not checked.
.Conn.CancelRequest()
// NOTE(review): presumably gives the server time to act on the cancel
// before the connection is reused — confirm the reason for 100ms.
time.Sleep(100 * time.Millisecond)
}()
}
// Undoes the cancel handler's effects. It stops the goroutine's wait if it is
// still waiting, blocks until the goroutine has exited (possibly after sending
// the cancel request and sleeping), then clears the connection deadline.
// Must only run after the cancel handler: both fields used here are set there,
// and calling this first would call a nil func.
func ( *CancelRequestContextWatcherHandler) () {
.handleUnwatchAfterCancelCalled()
<-.cancelFinishedChan
.Conn.conn.SetDeadline(time.Time{})
}
// Copies field descriptions into the destination slice (first parameter) from
// the source slice (second) and sets each Format from the result format codes:
//   - no codes: every column is text;
//   - one code: it applies to every column;
//   - otherwise: one code per column is required.
// Any other count returns an error. The caller in this file sizes the
// destination from the source field count.
// NOTE(review): the range target and length operand are elided in this view;
// assumes both slices have equal length — confirm.
func (, []FieldDescription, []int16) error {
switch {
case len() == 0:
for := range {
[] = []
[].Format = pgtype.TextFormatCode
}
case len() == 1:
:= [0]
for := range {
[] = []
[].Format =
}
case len() == len():
for := range {
[] = []
[].Format = []
}
default:
return fmt.Errorf("result format codes length %d does not match field count %d", len(), len())
}
return nil
}