package pgproto3
import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
)
// Frontend is the client side of a pgproto3 connection: it buffers and writes frontend messages to w
// and reads and decodes backend messages through cr.
type Frontend struct {
	cr *chunkReader // source of backend (server) messages
	w  io.Writer    // destination for frontend (client) messages

	// tracer, when non-nil, records every message sent or received. Outbound messages are traced when
	// they are buffered, not when Flush writes them.
	tracer *tracer

	wbuf        []byte // encoded messages waiting to be written by Flush
	encodeError error  // first encoding error since the last Flush; reported (and cleared) by Flush

	// Reusable storage for decoded backend messages. Received messages are decoded into these fields and
	// returned by pointer, so a message is overwritten by the next received message of the same type.
	authenticationOk                AuthenticationOk
	authenticationCleartextPassword AuthenticationCleartextPassword
	authenticationMD5Password       AuthenticationMD5Password
	authenticationGSS               AuthenticationGSS
	authenticationGSSContinue       AuthenticationGSSContinue
	authenticationSASL              AuthenticationSASL
	authenticationSASLContinue      AuthenticationSASLContinue
	authenticationSASLFinal         AuthenticationSASLFinal
	backendKeyData                  BackendKeyData
	bindComplete                    BindComplete
	closeComplete                   CloseComplete
	commandComplete                 CommandComplete
	copyBothResponse                CopyBothResponse
	copyData                        CopyData
	copyInResponse                  CopyInResponse
	copyOutResponse                 CopyOutResponse
	copyDone                        CopyDone
	dataRow                         DataRow
	emptyQueryResponse              EmptyQueryResponse
	errorResponse                   ErrorResponse
	functionCallResponse            FunctionCallResponse
	noData                          NoData
	noticeResponse                  NoticeResponse
	notificationResponse            NotificationResponse
	parameterDescription            ParameterDescription
	parameterStatus                 ParameterStatus
	parseComplete                   ParseComplete
	readyForQuery                   ReadyForQuery
	rowDescription                  RowDescription
	portalSuspended                 PortalSuspended
	negotiateProtocolVersion        NegotiateProtocolVersion

	bodyLen    int  // body length of the current message, excluding the 5-byte header
	maxBodyLen int  // bodies longer than this are rejected; 0 or less disables the limit
	msgType    byte // type byte of the current message

	partialMsg bool   // the current message's header has been read but its body has not
	authType   uint32 // request code of the most recent authentication ('R') message
}
// NewFrontend returns a Frontend that reads backend messages from r and writes frontend messages to w.
// Messages handed to the Send methods are buffered and only written to w by Flush.
func NewFrontend(r io.Reader, w io.Writer) *Frontend {
	return &Frontend{
		// NOTE(review): 0 presumably selects the chunk reader's default buffer size — confirm.
		cr: newChunkReader(r, 0),
		w:  w,
	}
}
func ( *Frontend) ( FrontendMessage) {
if .encodeError != nil {
return
}
:= len(.wbuf)
, := .Encode(.wbuf)
if != nil {
.encodeError =
return
}
.wbuf =
if .tracer != nil {
.tracer.traceMessage('F', int32(len(.wbuf)-), )
}
}
func ( *Frontend) () error {
if := .encodeError; != nil {
.encodeError = nil
.wbuf = .wbuf[:0]
return &writeError{err: , safeToRetry: true}
}
if len(.wbuf) == 0 {
return nil
}
, := .w.Write(.wbuf)
const = 1024
if len(.wbuf) > {
.wbuf = make([]byte, 0, )
} else {
.wbuf = .wbuf[:0]
}
if != nil {
return &writeError{err: , safeToRetry: == 0}
}
return nil
}
// Trace starts tracing message traffic to w using options. Outbound messages are traced when they are
// buffered by the Send methods, not when Flush writes them. Calling Trace again replaces the current
// tracer.
func (f *Frontend) Trace(w io.Writer, options TracerOptions) {
	f.tracer = &tracer{
		TracerOptions: options,
		buf:           new(bytes.Buffer),
		w:             w,
	}
}
func ( *Frontend) () {
.tracer = nil
}
func ( *Frontend) ( *Bind) {
if .encodeError != nil {
return
}
:= len(.wbuf)
, := .Encode(.wbuf)
if != nil {
.encodeError =
return
}
.wbuf =
if .tracer != nil {
.tracer.traceBind('F', int32(len(.wbuf)-), )
}
}
func ( *Frontend) ( *Parse) {
if .encodeError != nil {
return
}
:= len(.wbuf)
, := .Encode(.wbuf)
if != nil {
.encodeError =
return
}
.wbuf =
if .tracer != nil {
.tracer.traceParse('F', int32(len(.wbuf)-), )
}
}
func ( *Frontend) ( *Close) {
if .encodeError != nil {
return
}
:= len(.wbuf)
, := .Encode(.wbuf)
if != nil {
.encodeError =
return
}
.wbuf =
if .tracer != nil {
.tracer.traceClose('F', int32(len(.wbuf)-), )
}
}
func ( *Frontend) ( *Describe) {
if .encodeError != nil {
return
}
:= len(.wbuf)
, := .Encode(.wbuf)
if != nil {
.encodeError =
return
}
.wbuf =
if .tracer != nil {
.tracer.traceDescribe('F', int32(len(.wbuf)-), )
}
}
func ( *Frontend) ( *Execute) {
if .encodeError != nil {
return
}
:= len(.wbuf)
, := .Encode(.wbuf)
if != nil {
.encodeError =
return
}
.wbuf =
if .tracer != nil {
.tracer.traceExecute('F', int32(len(.wbuf)-), )
}
}
func ( *Frontend) ( *Sync) {
if .encodeError != nil {
return
}
:= len(.wbuf)
, := .Encode(.wbuf)
if != nil {
.encodeError =
return
}
.wbuf =
if .tracer != nil {
.tracer.traceSync('F', int32(len(.wbuf)-), )
}
}
func ( *Frontend) ( *Query) {
if .encodeError != nil {
return
}
:= len(.wbuf)
, := .Encode(.wbuf)
if != nil {
.encodeError =
return
}
.wbuf =
if .tracer != nil {
.tracer.traceQuery('F', int32(len(.wbuf)-), )
}
}
func ( *Frontend) ( []byte) error {
:= .Flush()
if != nil {
return
}
, := .w.Write()
if != nil {
return &writeError{err: , safeToRetry: == 0}
}
if .tracer != nil {
.tracer.traceCopyData('F', int32(len()-1), &CopyData{})
}
return nil
}
// translateEOFtoErrUnexpectedEOF maps io.EOF to io.ErrUnexpectedEOF and returns any other error
// (including nil) unchanged. Hitting EOF while a protocol message is expected means the stream ended
// mid-conversation. The comparison is by identity, matching the io convention that readers return io.EOF
// itself rather than a wrapped copy.
func translateEOFtoErrUnexpectedEOF(err error) error {
	switch err {
	case io.EOF:
		return io.ErrUnexpectedEOF
	default:
		return err
	}
}
// Receive reads the next message from the backend (i.e. the server) and decodes it.
//
// Decoded messages live in per-type fields of f and are returned by pointer. A returned message is
// therefore overwritten by the next message of the same type; treat it as valid only until the next
// call to Receive.
//
// If reading the body fails, the already-parsed header is kept and the next call resumes with the body.
// Read errors from the underlying reader are passed through, with io.EOF translated to
// io.ErrUnexpectedEOF.
func (f *Frontend) Receive() (BackendMessage, error) {
	if !f.partialMsg {
		header, err := f.cr.Next(5)
		if err != nil {
			return nil, translateEOFtoErrUnexpectedEOF(err)
		}
		f.msgType = header[0]

		// The length field counts itself (4 bytes) but not the type byte. Reinterpreting it as int32
		// turns values of 2^31 and above negative, so they fail the minimum-length check.
		length := int(int32(binary.BigEndian.Uint32(header[1:])))
		if length < 4 {
			return nil, fmt.Errorf("invalid message length: %d", length)
		}

		f.bodyLen = length - 4
		if f.maxBodyLen > 0 && f.bodyLen > f.maxBodyLen {
			return nil, &ExceededMaxBodyLenErr{f.maxBodyLen, f.bodyLen}
		}
		// Remember the header in case reading the body fails below.
		f.partialMsg = true
	}

	body, err := f.cr.Next(f.bodyLen)
	if err != nil {
		return nil, translateEOFtoErrUnexpectedEOF(err)
	}
	f.partialMsg = false

	var msg BackendMessage
	switch f.msgType {
	case '1':
		msg = &f.parseComplete
	case '2':
		msg = &f.bindComplete
	case '3':
		msg = &f.closeComplete
	case 'A':
		msg = &f.notificationResponse
	case 'C':
		msg = &f.commandComplete
	case 'D':
		msg = &f.dataRow
	case 'E':
		msg = &f.errorResponse
	case 'G':
		msg = &f.copyInResponse
	case 'H':
		msg = &f.copyOutResponse
	case 'I':
		msg = &f.emptyQueryResponse
	case 'K':
		msg = &f.backendKeyData
	case 'N':
		msg = &f.noticeResponse
	case 'R':
		// All authentication messages share this type byte; the body selects the concrete message.
		authMsg, authErr := f.findAuthenticationMessageType(body)
		if authErr != nil {
			return nil, authErr
		}
		msg = authMsg
	case 'S':
		msg = &f.parameterStatus
	case 'T':
		msg = &f.rowDescription
	case 'V':
		msg = &f.functionCallResponse
	case 'W':
		msg = &f.copyBothResponse
	case 'Z':
		msg = &f.readyForQuery
	case 'c':
		msg = &f.copyDone
	case 'd':
		msg = &f.copyData
	case 'n':
		msg = &f.noData
	case 's':
		msg = &f.portalSuspended
	case 't':
		msg = &f.parameterDescription
	case 'v':
		msg = &f.negotiateProtocolVersion
	default:
		return nil, fmt.Errorf("unknown message type: %c", f.msgType)
	}

	if err := msg.Decode(body); err != nil {
		return nil, err
	}

	if f.tracer != nil {
		f.tracer.traceMessage('B', int32(5+len(body)), msg)
	}
	return msg, nil
}
// Authentication request codes. The frontend compares the first four bytes (big-endian) of an
// authentication ('R') message body against these values to pick the concrete message type.
// AuthTypeSCMCreds and AuthTypeSSPI are recognized but not implemented by this frontend.
const (
	AuthTypeOk                = 0
	AuthTypeCleartextPassword = 3
	AuthTypeMD5Password       = 5
	AuthTypeSCMCreds          = 6
	AuthTypeGSS               = 7
	AuthTypeGSSCont           = 8
	AuthTypeSSPI              = 9
	AuthTypeSASL              = 10
	AuthTypeSASLContinue      = 11
	AuthTypeSASLFinal         = 12
)
func ( *Frontend) ( []byte) (BackendMessage, error) {
if len() < 4 {
return nil, errors.New("authentication message too short")
}
.authType = binary.BigEndian.Uint32([:4])
switch .authType {
case AuthTypeOk:
return &.authenticationOk, nil
case AuthTypeCleartextPassword:
return &.authenticationCleartextPassword, nil
case AuthTypeMD5Password:
return &.authenticationMD5Password, nil
case AuthTypeSCMCreds:
return nil, errors.New("AuthTypeSCMCreds is unimplemented")
case AuthTypeGSS:
return &.authenticationGSS, nil
case AuthTypeGSSCont:
return &.authenticationGSSContinue, nil
case AuthTypeSSPI:
return nil, errors.New("AuthTypeSSPI is unimplemented")
case AuthTypeSASL:
return &.authenticationSASL, nil
case AuthTypeSASLContinue:
return &.authenticationSASLContinue, nil
case AuthTypeSASLFinal:
return &.authenticationSASLFinal, nil
default:
return nil, fmt.Errorf("unknown authentication type: %d", .authType)
}
}
func ( *Frontend) () uint32 {
return .authType
}
func ( *Frontend) () int {
return .cr.wp - .cr.rp
}
func ( *Frontend) ( int) {
.maxBodyLen =
}