package pgproto3
import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)
// Backend is the server side of a PostgreSQL frontend/backend protocol
// connection: it decodes FrontendMessages read from the client and encodes
// BackendMessages to be written back to it.
type Backend struct {
	// Underlying connection: cr supplies incoming bytes, w receives outgoing ones.
	cr *chunkReader
	w  io.Writer

	// tracer, when non-nil, is handed each traced message.
	tracer *tracer

	// wbuf holds encoded outgoing messages that have not been written to w yet.
	wbuf []byte

	// encodeError latches the first encoding failure until the buffer is flushed.
	encodeError error

	// Message values reused across receives; decoded messages are returned as
	// pointers to these fields, so each receive overwrites the previous value.
	bind           Bind
	cancelRequest  CancelRequest
	_close         Close
	copyFail       CopyFail
	copyData       CopyData
	copyDone       CopyDone
	describe       Describe
	execute        Execute
	flush          Flush
	functionCall   FunctionCall
	gssEncRequest  GSSEncRequest
	parse          Parse
	query          Query
	sslRequest     SSLRequest
	startupMessage StartupMessage
	sync           Sync
	terminate      Terminate

	// State of the regular message currently being received. partialMsg is
	// true once its header has been consumed but its body has not yet been
	// read successfully. maxBodyLen <= 0 means bodies are not size-limited.
	bodyLen    int
	maxBodyLen int
	msgType    byte
	partialMsg bool

	// authType is the authentication method in effect; it selects the
	// concrete type used to decode 'p' messages.
	authType uint32
}
// Bounds on the body of a startup packet, i.e. the bytes that follow its
// 4-byte length prefix.
const (
	// minStartupPacketLen is the smallest acceptable body: a single 32-bit
	// protocol version or request code.
	minStartupPacketLen = 4

	// maxStartupPacketLen caps the body size accepted before it is read.
	maxStartupPacketLen = 10000
)
func ( io.Reader, io.Writer) *Backend {
:= newChunkReader(, 0)
return &Backend{cr: , w: }
}
func ( *Backend) ( BackendMessage) {
if .encodeError != nil {
return
}
:= len(.wbuf)
, := .Encode(.wbuf)
if != nil {
.encodeError =
return
}
.wbuf =
if .tracer != nil {
.tracer.traceMessage('B', int32(len(.wbuf)-), )
}
}
func ( *Backend) () error {
if := .encodeError; != nil {
.encodeError = nil
.wbuf = .wbuf[:0]
return &writeError{err: , safeToRetry: true}
}
, := .w.Write(.wbuf)
const = 1024
if len(.wbuf) > {
.wbuf = make([]byte, 0, )
} else {
.wbuf = .wbuf[:0]
}
if != nil {
return &writeError{err: , safeToRetry: == 0}
}
return nil
}
func ( *Backend) ( io.Writer, TracerOptions) {
.tracer = &tracer{
w: ,
buf: &bytes.Buffer{},
TracerOptions: ,
}
}
func ( *Backend) () {
.tracer = nil
}
func ( *Backend) () (FrontendMessage, error) {
, := .cr.Next(4)
if != nil {
return nil,
}
:= int(int32(binary.BigEndian.Uint32()) - 4)
if < minStartupPacketLen || > maxStartupPacketLen {
return nil, fmt.Errorf("invalid length of startup packet: %d", )
}
, = .cr.Next()
if != nil {
return nil, translateEOFtoErrUnexpectedEOF()
}
:= binary.BigEndian.Uint32()
switch {
case ProtocolVersion30, ProtocolVersion32:
= .startupMessage.Decode()
if != nil {
return nil,
}
return &.startupMessage, nil
case sslRequestNumber:
= .sslRequest.Decode()
if != nil {
return nil,
}
return &.sslRequest, nil
case cancelRequestCode:
= .cancelRequest.Decode()
if != nil {
return nil,
}
return &.cancelRequest, nil
case gssEncReqNumber:
= .gssEncRequest.Decode()
if != nil {
return nil,
}
return &.gssEncRequest, nil
default:
return nil, fmt.Errorf("unknown startup message code: %d", )
}
}
func ( *Backend) () (FrontendMessage, error) {
if !.partialMsg {
, := .cr.Next(5)
if != nil {
return nil, translateEOFtoErrUnexpectedEOF()
}
.msgType = [0]
:= int(int32(binary.BigEndian.Uint32([1:])))
if < 4 {
return nil, fmt.Errorf("invalid message length: %d", )
}
.bodyLen = - 4
if .maxBodyLen > 0 && .bodyLen > .maxBodyLen {
return nil, &ExceededMaxBodyLenErr{.maxBodyLen, .bodyLen}
}
.partialMsg = true
}
var FrontendMessage
switch .msgType {
case 'B':
= &.bind
case 'C':
= &._close
case 'D':
= &.describe
case 'E':
= &.execute
case 'F':
= &.functionCall
case 'f':
= &.copyFail
case 'd':
= &.copyData
case 'c':
= &.copyDone
case 'H':
= &.flush
case 'P':
= &.parse
case 'p':
switch .authType {
case AuthTypeSASL:
= &SASLInitialResponse{}
case AuthTypeSASLContinue:
= &SASLResponse{}
case AuthTypeSASLFinal:
= &SASLResponse{}
case AuthTypeGSS, AuthTypeGSSCont:
= &GSSResponse{}
case AuthTypeCleartextPassword, AuthTypeMD5Password:
fallthrough
default:
= &PasswordMessage{}
}
case 'Q':
= &.query
case 'S':
= &.sync
case 'X':
= &.terminate
default:
return nil, fmt.Errorf("unknown message type: %c", .msgType)
}
, := .cr.Next(.bodyLen)
if != nil {
return nil, translateEOFtoErrUnexpectedEOF()
}
.partialMsg = false
= .Decode()
if != nil {
return nil,
}
if .tracer != nil {
.tracer.traceMessage('F', int32(5+len()), )
}
return , nil
}
func ( *Backend) ( uint32) error {
switch {
case AuthTypeOk,
AuthTypeCleartextPassword,
AuthTypeMD5Password,
AuthTypeSCMCreds,
AuthTypeGSS,
AuthTypeGSSCont,
AuthTypeSSPI,
AuthTypeSASL,
AuthTypeSASLContinue,
AuthTypeSASLFinal:
.authType =
default:
return fmt.Errorf("authType not recognized: %d", )
}
return nil
}
func ( *Backend) ( int) {
.maxBodyLen =
}