package pgconn

import (
	
	
	
	

	
)

// NOTE(review): the identifiers in this block are missing from SOURCE, so the
// block does not compile as shown. This covers the method name, the receiver,
// the parameter, every local variable, several call arguments and the anonymous
// struct's field names. Restore them from the upstream file before relying on
// this code.
//
// This method performs SASL OAUTHBEARER authentication (RFC 7628) on the
// connection. It obtains a bearer token from config.OAuthTokenProvider and sends
// it to the server in a single SASLInitialResponse. It then interprets the one
// message the server sends back.
//
// It returns nil when the server answers AuthenticationOk. Otherwise it returns
// an error in any of these cases:
//   - no token provider is configured;
//   - the provider fails;
//   - sending or receiving fails;
//   - the server reports an OAuth error via AuthenticationSASLContinue;
//   - the server sends an ErrorResponse (converted by ErrorResponseToPgError);
//   - any other message type arrives.
func ( *PgConn) ( context.Context) error {
	if .config.OAuthTokenProvider == nil {
		return errors.New("OAuth authentication required but no token provider configured")
	}

	// NOTE(review): the argument list of this call is empty in SOURCE. It
	// presumably passed the context parameter — confirm against the declared
	// type of OAuthTokenProvider.
	,  := .config.OAuthTokenProvider()
	if  != nil {
		return fmt.Errorf("failed to obtain OAuth token: %w", )
	}

	// Build the initial client response:
	//   - the GS2 header "n,," (no channel binding, no authorization identity);
	//   - \x01-separated key/value pairs carrying the bearer token;
	//   - a closing \x01\x01.
	// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.1
	// NOTE(review): the token is concatenated verbatim. A token containing \x01
	// would corrupt the key/value framing. Consider validating it against the
	// RFC 6750 b64token syntax before sending.
	 := []byte("n,,\x01auth=Bearer " +  + "\x01\x01")

	 := &pgproto3.SASLInitialResponse{
		AuthMechanism: "OAUTHBEARER",
		Data:          ,
	}
	// Queue the message, then flush it. A flush error is returned unwrapped.
	.frontend.Send()
	 = .flushWithPotentialWriteReadDeadlock()
	if  != nil {
		return 
	}

	// The outcome is decided by the single server message read here.
	,  := .receiveMessage()
	if  != nil {
		return 
	}

	switch m := .(type) {
	case *pgproto3.AuthenticationOk:
		return nil
	case *pgproto3.AuthenticationSASLContinue:
		// Server sent error response in SASL continue
		// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.2
		// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.3
		// NOTE(review): json.Unmarshal only populates exported fields. The
		// stripped field names below must be exported (upper-case) for these
		// tags to take effect.
		 := struct {
			              string `json:"status"`
			               string `json:"scope"`
			 string `json:"openid-configuration"`
		}{}
		 := json.Unmarshal(.Data, &)
		if  != nil {
			return fmt.Errorf("invalid OAuth error response from server: %w", )
		}

		// Per RFC 7628 section 3.2.3, we should send a SASLResponse which only contains \x01.
		// However, since the connection will be closed anyway, we can skip this.
		// NOTE(review): only one field of the decoded response is reported
		// (presumably the "status" value). The other decoded fields are not
		// included in the returned error.
		return fmt.Errorf("OAuth authentication failed: %s", .)

	case *pgproto3.ErrorResponse:
		return ErrorResponseToPgError()

	default:
		return fmt.Errorf("unexpected message type during OAuth auth: %T", )
	}
}