// SCRAM-SHA-256 and SCRAM-SHA-256-PLUS authentication
//
// Resources:
//   https://tools.ietf.org/html/rfc5802
//   https://tools.ietf.org/html/rfc5929
//   https://tools.ietf.org/html/rfc8265
//   https://www.postgresql.org/docs/current/sasl-authentication.html
//
// Inspiration drawn from other implementations:
//   https://github.com/lib/pq/pull/608
//   https://github.com/lib/pq/pull/788
//   https://github.com/lib/pq/pull/833

package pgconn

import (
	"bytes"
	"crypto/hmac"
	"crypto/pbkdf2"
	"crypto/rand"
	"crypto/sha256"
	"crypto/sha512"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"errors"
	"fmt"
	"hash"
	"slices"
	"strconv"

	"github.com/jackc/pgx/v5/pgproto3"
	"golang.org/x/text/secure/precis"
)

// SCRAM protocol constants.
const (
	// clientNonceLen is the number of random bytes drawn for the client
	// nonce; the nonce is base64-encoded before it is sent.
	clientNonceLen = 18

	// SASL mechanism names as they appear on the wire. The -PLUS variant is
	// the channel binding form of the base mechanism.
	scramSHA256Name     = "SCRAM-SHA-256"
	scramSHA256PlusName = "SCRAM-SHA-256-PLUS"
)

// Perform SCRAM authentication.
func ( *PgConn) ( []string) error {
	,  := newScramClient(, .config.Password)
	if  != nil {
		return 
	}

	 := slices.Contains(.serverAuthMechanisms, scramSHA256PlusName)
	if .config.ChannelBinding == "require" && ! {
		return errors.New("channel binding required but server does not support SCRAM-SHA-256-PLUS")
	}

	// If we have a TLS connection and channel binding is not disabled, attempt to
	// extract the server certificate hash for tls-server-end-point channel binding.
	if ,  := .conn.(*tls.Conn);  && .config.ChannelBinding != "disable" {
		,  := getTLSCertificateHash()
		if  != nil && .config.ChannelBinding == "require" {
			return fmt.Errorf("channel binding required but failed to get server certificate hash: %w", )
		}

		// Upgrade to SCRAM-SHA-256-PLUS if we have binding data and the server supports it.
		if  != nil &&  {
			.authMechanism = scramSHA256PlusName
		}

		.channelBindingData = 
		.hasTLS = true
	}

	if .config.ChannelBinding == "require" && .channelBindingData == nil {
		return errors.New("channel binding required but channel binding data is not available")
	}

	// Send client-first-message in a SASLInitialResponse
	 := &pgproto3.SASLInitialResponse{
		AuthMechanism: .authMechanism,
		Data:          .clientFirstMessage(),
	}
	.frontend.Send()
	 = .flushWithPotentialWriteReadDeadlock()
	if  != nil {
		return 
	}

	// Receive server-first-message payload in an AuthenticationSASLContinue.
	,  := .rxSASLContinue()
	if  != nil {
		return 
	}
	 = .recvServerFirstMessage(.Data)
	if  != nil {
		return 
	}

	// Send client-final-message in a SASLResponse
	 := &pgproto3.SASLResponse{
		Data: []byte(.clientFinalMessage()),
	}
	.frontend.Send()
	 = .flushWithPotentialWriteReadDeadlock()
	if  != nil {
		return 
	}

	// Receive server-final-message payload in an AuthenticationSASLFinal.
	,  := .rxSASLFinal()
	if  != nil {
		return 
	}
	return .recvServerFinalMessage(.Data)
}

func ( *PgConn) () (*pgproto3.AuthenticationSASLContinue, error) {
	,  := .receiveMessage()
	if  != nil {
		return nil, 
	}
	switch m := .(type) {
	case *pgproto3.AuthenticationSASLContinue:
		return , nil
	case *pgproto3.ErrorResponse:
		return nil, ErrorResponseToPgError()
	}

	return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", )
}

func ( *PgConn) () (*pgproto3.AuthenticationSASLFinal, error) {
	,  := .receiveMessage()
	if  != nil {
		return nil, 
	}
	switch m := .(type) {
	case *pgproto3.AuthenticationSASLFinal:
		return , nil
	case *pgproto3.ErrorResponse:
		return nil, ErrorResponseToPgError()
	}

	return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", )
}

// scramClient carries the client-side state of one SCRAM exchange, from the
// mechanism negotiation through verification of the server signature.
type scramClient struct {
	// Inputs fixed when the client is created: the mechanisms the server
	// advertised, the password (normalized when possible) and the
	// base64-encoded random client nonce.
	serverAuthMechanisms []string
	password             string
	clientNonce          []byte

	// authMechanism is the SASL mechanism announced to the server: either
	// SCRAM-SHA-256 (the initial value) or SCRAM-SHA-256-PLUS.
	//
	// It is switched to SCRAM-SHA-256-PLUS during authentication only when
	// channel binding is not disabled, the connection is TLS with an
	// obtainable server certificate hash, and the server advertises
	// SCRAM-SHA-256-PLUS.
	authMechanism string

	// hasTLS records that the connection is a TLS connection. The GS2 header
	// depends on it: "y,," says the client could do channel binding but is
	// not using it, while "n,," says the client cannot do it at all.
	hasTLS bool

	// channelBindingData is the server certificate hash for the
	// tls-server-end-point channel binding type (RFC 5929). It is the binding
	// input for SCRAM-SHA-256-PLUS and is nil when channel binding is unused.
	channelBindingData []byte

	// Pieces of the client-first-message, kept because later messages and the
	// auth message are derived from them.
	clientFirstMessageBare []byte
	clientGS2Header        []byte

	// Values taken from the server-first-message: the raw message, the
	// combined nonce, the decoded salt and the iteration count.
	serverFirstMessage   []byte
	clientAndServerNonce []byte
	salt                 []byte
	iterations           int

	// Values derived while building the client-final-message and reused to
	// verify the server signature.
	saltedPassword []byte
	authMessage    []byte
}

func ( []string,  string) (*scramClient, error) {
	 := &scramClient{
		serverAuthMechanisms: ,
		authMechanism:        scramSHA256Name,
	}

	// Ensure the server supports SCRAM-SHA-256. SCRAM-SHA-256-PLUS is the
	// channel binding variant and is only advertised when the server supports
	// SSL. PostgreSQL always advertises the base SCRAM-SHA-256 mechanism
	// regardless of SSL.
	if !slices.Contains(.serverAuthMechanisms, scramSHA256Name) {
		return nil, errors.New("server does not support SCRAM-SHA-256")
	}

	// precis.OpaqueString is equivalent to SASLprep for password.
	var  error
	.password,  = precis.OpaqueString.String()
	if  != nil {
		// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
		.password = 
	}

	 := make([]byte, clientNonceLen)
	_,  = rand.Read()
	if  != nil {
		return nil, 
	}
	.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len()))
	base64.RawStdEncoding.Encode(.clientNonce, )

	return , nil
}

func ( *scramClient) () []byte {
	// The client-first-message is the GS2 header concatenated with the bare
	// message (username + client nonce). The GS2 header communicates the
	// client's channel binding capability to the server:
	//
	//   "n,,"                      - client is not using TLS (channel binding not possible)
	//   "y,,"                      - client is using TLS but channel binding is not
	//                                in use (e.g., server did not advertise SCRAM-SHA-256-PLUS
	//                                or the server certificate hash was not obtainable)
	//   "p=tls-server-end-point,," - channel binding is active via SCRAM-SHA-256-PLUS
	//
	// See:
	//   https://www.rfc-editor.org/rfc/rfc5802#section-6
	//   https://www.rfc-editor.org/rfc/rfc5929#section-4
	//   https://www.postgresql.org/docs/current/sasl-authentication.html#SASL-SCRAM-SHA-256

	.clientFirstMessageBare = fmt.Appendf(nil, "n=,r=%s", .clientNonce)

	if .authMechanism == scramSHA256PlusName {
		.clientGS2Header = []byte("p=tls-server-end-point,,")
	} else if .hasTLS {
		.clientGS2Header = []byte("y,,")
	} else {
		.clientGS2Header = []byte("n,,")
	}

	return append(.clientGS2Header, .clientFirstMessageBare...)
}

func ( *scramClient) ( []byte) error {
	.serverFirstMessage = 
	 := 
	if !bytes.HasPrefix(, []byte("r=")) {
		return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
	}
	 = [2:]

	 := bytes.IndexByte(, ',')
	if  == -1 {
		return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
	}
	.clientAndServerNonce = [:]
	 = [+1:]

	if !bytes.HasPrefix(, []byte("s=")) {
		return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
	}
	 = [2:]

	 = bytes.IndexByte(, ',')
	if  == -1 {
		return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
	}
	 := [:]
	 = [+1:]

	if !bytes.HasPrefix(, []byte("i=")) {
		return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
	}
	 = [2:]
	 := 

	var  error
	.salt,  = base64.StdEncoding.DecodeString(string())
	if  != nil {
		return fmt.Errorf("invalid SCRAM salt received from server: %w", )
	}

	.iterations,  = strconv.Atoi(string())
	if  != nil || .iterations <= 0 {
		return fmt.Errorf("invalid SCRAM iteration count received from server: %w", )
	}

	if !bytes.HasPrefix(.clientAndServerNonce, .clientNonce) {
		return errors.New("invalid SCRAM nonce: did not start with client nonce")
	}

	if len(.clientAndServerNonce) <= len(.clientNonce) {
		return errors.New("invalid SCRAM nonce: did not include server nonce")
	}

	return nil
}

func ( *scramClient) () string {
	// The c= attribute carries the base64-encoded channel binding input.
	//
	// Without channel binding this is just the GS2 header alone ("biws" for
	// "n,," or "eSws" for "y,,").
	//
	// With channel binding, this is the GS2 header with the channel binding data
	// (certificate hash) appended.
	 := .clientGS2Header
	if .authMechanism == scramSHA256PlusName {
		 = slices.Concat(.clientGS2Header, .channelBindingData)
	}
	 := base64.StdEncoding.EncodeToString()
	 := fmt.Appendf(nil, "c=%s,r=%s", , .clientAndServerNonce)

	var  error
	.saltedPassword,  = pbkdf2.Key(sha256.New, .password, .salt, .iterations, 32)
	if  != nil {
		panic() // This should never happen.
	}
	.authMessage = bytes.Join([][]byte{.clientFirstMessageBare, .serverFirstMessage, }, []byte(","))

	 := computeClientProof(.saltedPassword, .authMessage)

	return fmt.Sprintf("%s,p=%s", , )
}

func ( *scramClient) ( []byte) error {
	if !bytes.HasPrefix(, []byte("v=")) {
		return errors.New("invalid SCRAM server-final-message received from server")
	}

	 := [2:]

	if !hmac.Equal(, computeServerSignature(.saltedPassword, .authMessage)) {
		return errors.New("invalid SCRAM ServerSignature received from server")
	}

	return nil
}

// computeHMAC returns the HMAC-SHA-256 of msg under key.
func computeHMAC(key, msg []byte) []byte {
	mac := hmac.New(sha256.New, key)
	mac.Write(msg)
	return mac.Sum(nil)
}

// computeClientProof returns the base64-encoded SCRAM ClientProof for the
// given salted password and auth message (RFC 5802 section 3):
//
//	ClientKey       := HMAC(SaltedPassword, "Client Key")
//	StoredKey       := H(ClientKey)
//	ClientSignature := HMAC(StoredKey, AuthMessage)
//	ClientProof     := ClientKey XOR ClientSignature
func computeClientProof(saltedPassword, authMessage []byte) []byte {
	clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
	storedKey := sha256.Sum256(clientKey)
	clientSignature := computeHMAC(storedKey[:], authMessage)

	// Both values are SHA-256 sized, so indexing clientKey by the signature's
	// indices is in range.
	clientProof := make([]byte, len(clientSignature))
	for i := range clientSignature {
		clientProof[i] = clientKey[i] ^ clientSignature[i]
	}

	encoded := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
	base64.StdEncoding.Encode(encoded, clientProof)
	return encoded
}

// computeServerSignature returns the base64-encoded SCRAM ServerSignature for
// the given salted password and auth message (RFC 5802 section 3):
//
//	ServerKey       := HMAC(SaltedPassword, "Server Key")
//	ServerSignature := HMAC(ServerKey, AuthMessage)
func computeServerSignature(saltedPassword, authMessage []byte) []byte {
	serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
	serverSignature := computeHMAC(serverKey, authMessage)
	encoded := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
	base64.StdEncoding.Encode(encoded, serverSignature)
	return encoded
}

// getTLSCertificateHash returns the hash of the server's leaf certificate for
// the SCRAM channel binding type tls-server-end-point.
//
// The hash is computed over the raw DER bytes of the first peer certificate of
// conn. An error is returned when the connection has no peer certificates or
// when the certificate's signature algorithm is not one of the RSA, RSA-PSS
// or ECDSA algorithms handled below.
func getTLSCertificateHash(conn *tls.Conn) ([]byte, error) {
	state := conn.ConnectionState()
	if len(state.PeerCertificates) == 0 {
		return nil, errors.New("no peer certificates for channel binding")
	}

	cert := state.PeerCertificates[0]

	// Per RFC 5929 section 4.1: If the certificate's signatureAlgorithm uses
	// MD5 or SHA-1, use SHA-256. Otherwise use the hash from the signature
	// algorithm.
	//
	// See: https://www.rfc-editor.org/rfc/rfc5929.html#section-4.1
	var h hash.Hash
	switch cert.SignatureAlgorithm {
	case x509.MD5WithRSA, x509.SHA1WithRSA, x509.ECDSAWithSHA1:
		h = sha256.New()
	case x509.SHA256WithRSA, x509.SHA256WithRSAPSS, x509.ECDSAWithSHA256:
		h = sha256.New()
	case x509.SHA384WithRSA, x509.SHA384WithRSAPSS, x509.ECDSAWithSHA384:
		h = sha512.New384()
	case x509.SHA512WithRSA, x509.SHA512WithRSAPSS, x509.ECDSAWithSHA512:
		h = sha512.New()
	default:
		return nil, fmt.Errorf("tls-server-end-point channel binding is undefined for certificate signature algorithm %v", cert.SignatureAlgorithm)
	}

	h.Write(cert.Raw)
	return h.Sum(nil), nil
}