package pgconn
import (
)
// Parameters of the SCRAM SASL exchange.
const (
	// clientNonceLen is the number of random bytes drawn for the client nonce
	// before it is base64-encoded.
	clientNonceLen = 18

	// scramSHA256Name is the SASL mechanism name for SCRAM-SHA-256 without
	// channel binding.
	scramSHA256Name = "SCRAM-SHA-256"

	// scramSHA256PlusName is the SASL mechanism name for SCRAM-SHA-256 with
	// tls-server-end-point channel binding.
	scramSHA256PlusName = scramSHA256Name + "-PLUS"
)
func ( *PgConn) ( []string) error {
, := newScramClient(, .config.Password)
if != nil {
return
}
:= slices.Contains(.serverAuthMechanisms, scramSHA256PlusName)
if .config.ChannelBinding == "require" && ! {
return errors.New("channel binding required but server does not support SCRAM-SHA-256-PLUS")
}
if , := .conn.(*tls.Conn); && .config.ChannelBinding != "disable" {
, := getTLSCertificateHash()
if != nil && .config.ChannelBinding == "require" {
return fmt.Errorf("channel binding required but failed to get server certificate hash: %w", )
}
if != nil && {
.authMechanism = scramSHA256PlusName
}
.channelBindingData =
.hasTLS = true
}
if .config.ChannelBinding == "require" && .channelBindingData == nil {
return errors.New("channel binding required but channel binding data is not available")
}
:= &pgproto3.SASLInitialResponse{
AuthMechanism: .authMechanism,
Data: .clientFirstMessage(),
}
.frontend.Send()
= .flushWithPotentialWriteReadDeadlock()
if != nil {
return
}
, := .rxSASLContinue()
if != nil {
return
}
= .recvServerFirstMessage(.Data)
if != nil {
return
}
:= &pgproto3.SASLResponse{
Data: []byte(.clientFinalMessage()),
}
.frontend.Send()
= .flushWithPotentialWriteReadDeadlock()
if != nil {
return
}
, := .rxSASLFinal()
if != nil {
return
}
return .recvServerFinalMessage(.Data)
}
func ( *PgConn) () (*pgproto3.AuthenticationSASLContinue, error) {
, := .receiveMessage()
if != nil {
return nil,
}
switch m := .(type) {
case *pgproto3.AuthenticationSASLContinue:
return , nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError()
}
return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", )
}
func ( *PgConn) () (*pgproto3.AuthenticationSASLFinal, error) {
, := .receiveMessage()
if != nil {
return nil,
}
switch m := .(type) {
case *pgproto3.AuthenticationSASLFinal:
return , nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError()
}
return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", )
}
// scramClient carries the client-side state of a single SCRAM exchange. The
// fields are filled in progressively as the client and server messages are
// built and parsed.
type scramClient struct {
	// Negotiation inputs.
	serverAuthMechanisms []string // SASL mechanisms advertised by the server
	password             string   // password normalized with precis.OpaqueString, or the raw password if that failed
	clientNonce          []byte   // random client nonce, base64-encoded without padding
	authMechanism        string   // mechanism name sent in SASLInitialResponse
	hasTLS               bool     // set for TLS connections when channel binding is not disabled; consulted for the GS2 flag
	channelBindingData   []byte   // tls-server-end-point certificate hash; nil when unavailable

	// Exchange state.
	clientFirstMessageBare []byte // client-first-message without the GS2 header
	clientGS2Header        []byte // GS2 header that precedes the bare client-first-message
	serverFirstMessage     []byte // server-first-message as received
	clientAndServerNonce   []byte // combined nonce from the server's r= attribute
	salt                   []byte // decoded s= attribute
	iterations             int    // i= attribute (PBKDF2 iteration count)
	saltedPassword         []byte // PBKDF2-HMAC-SHA-256 of password with salt and iterations
	authMessage            []byte // AuthMessage signed by both the client proof and the server signature
}
func ( []string, string) (*scramClient, error) {
:= &scramClient{
serverAuthMechanisms: ,
authMechanism: scramSHA256Name,
}
if !slices.Contains(.serverAuthMechanisms, scramSHA256Name) {
return nil, errors.New("server does not support SCRAM-SHA-256")
}
var error
.password, = precis.OpaqueString.String()
if != nil {
.password =
}
:= make([]byte, clientNonceLen)
_, = rand.Read()
if != nil {
return nil,
}
.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len()))
base64.RawStdEncoding.Encode(.clientNonce, )
return , nil
}
func ( *scramClient) () []byte {
.clientFirstMessageBare = fmt.Appendf(nil, "n=,r=%s", .clientNonce)
if .authMechanism == scramSHA256PlusName {
.clientGS2Header = []byte("p=tls-server-end-point,,")
} else if .hasTLS {
.clientGS2Header = []byte("y,,")
} else {
.clientGS2Header = []byte("n,,")
}
return append(.clientGS2Header, .clientFirstMessageBare...)
}
func ( *scramClient) ( []byte) error {
.serverFirstMessage =
:=
if !bytes.HasPrefix(, []byte("r=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
}
= [2:]
:= bytes.IndexByte(, ',')
if == -1 {
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
}
.clientAndServerNonce = [:]
= [+1:]
if !bytes.HasPrefix(, []byte("s=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
}
= [2:]
= bytes.IndexByte(, ',')
if == -1 {
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
}
:= [:]
= [+1:]
if !bytes.HasPrefix(, []byte("i=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
}
= [2:]
:=
var error
.salt, = base64.StdEncoding.DecodeString(string())
if != nil {
return fmt.Errorf("invalid SCRAM salt received from server: %w", )
}
.iterations, = strconv.Atoi(string())
if != nil || .iterations <= 0 {
return fmt.Errorf("invalid SCRAM iteration count received from server: %w", )
}
if !bytes.HasPrefix(.clientAndServerNonce, .clientNonce) {
return errors.New("invalid SCRAM nonce: did not start with client nonce")
}
if len(.clientAndServerNonce) <= len(.clientNonce) {
return errors.New("invalid SCRAM nonce: did not include server nonce")
}
return nil
}
func ( *scramClient) () string {
:= .clientGS2Header
if .authMechanism == scramSHA256PlusName {
= slices.Concat(.clientGS2Header, .channelBindingData)
}
:= base64.StdEncoding.EncodeToString()
:= fmt.Appendf(nil, "c=%s,r=%s", , .clientAndServerNonce)
var error
.saltedPassword, = pbkdf2.Key(sha256.New, .password, .salt, .iterations, 32)
if != nil {
panic()
}
.authMessage = bytes.Join([][]byte{.clientFirstMessageBare, .serverFirstMessage, }, []byte(","))
:= computeClientProof(.saltedPassword, .authMessage)
return fmt.Sprintf("%s,p=%s", , )
}
func ( *scramClient) ( []byte) error {
if !bytes.HasPrefix(, []byte("v=")) {
return errors.New("invalid SCRAM server-final-message received from server")
}
:= [2:]
if !hmac.Equal(, computeServerSignature(.saltedPassword, .authMessage)) {
return errors.New("invalid SCRAM ServerSignature received from server")
}
return nil
}
// computeHMAC returns HMAC-SHA-256 of msg under key.
func computeHMAC(key, msg []byte) []byte {
	h := hmac.New(sha256.New, key)
	_, _ = h.Write(msg) // hash.Hash.Write never returns an error
	return h.Sum(nil)
}
func (, []byte) []byte {
:= computeHMAC(, []byte("Client Key"))
:= sha256.Sum256()
:= computeHMAC([:], )
:= make([]byte, len())
for := range {
[] = [] ^ []
}
:= make([]byte, base64.StdEncoding.EncodedLen(len()))
base64.StdEncoding.Encode(, )
return
}
func (, []byte) []byte {
:= computeHMAC(, []byte("Server Key"))
:= computeHMAC(, )
:= make([]byte, base64.StdEncoding.EncodedLen(len()))
base64.StdEncoding.Encode(, )
return
}
// getTLSCertificateHash returns the tls-server-end-point channel binding data
// for conn: a hash of the DER encoding of the server's leaf certificate (the
// first peer certificate).
//
// Per RFC 5929 section 4.1, the hash function is the one used by the
// certificate's signature algorithm, except that MD5 and SHA-1 are replaced by
// SHA-256. An error is returned when the peer presented no certificate or when
// the signature algorithm has no mapping here.
func getTLSCertificateHash(conn *tls.Conn) ([]byte, error) {
	peers := conn.ConnectionState().PeerCertificates
	if len(peers) == 0 {
		return nil, errors.New("no peer certificates for channel binding")
	}
	leaf := peers[0]

	var newHash func() hash.Hash
	switch leaf.SignatureAlgorithm {
	case x509.MD5WithRSA, x509.SHA1WithRSA, x509.ECDSAWithSHA1,
		x509.SHA256WithRSA, x509.SHA256WithRSAPSS, x509.ECDSAWithSHA256:
		// MD5 and SHA-1 signatures are bound with SHA-256.
		newHash = sha256.New
	case x509.SHA384WithRSA, x509.SHA384WithRSAPSS, x509.ECDSAWithSHA384:
		newHash = sha512.New384
	case x509.SHA512WithRSA, x509.SHA512WithRSAPSS, x509.ECDSAWithSHA512:
		newHash = sha512.New
	default:
		return nil, fmt.Errorf("tls-server-end-point channel binding is undefined for certificate signature algorithm %v", leaf.SignatureAlgorithm)
	}

	h := newHash()
	h.Write(leaf.Raw)
	return h.Sum(nil), nil
}