// Copyright 2016 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package sasl

import (
	
	
	
	
	
	
	
	

	
)

// Parameters handed to the TLS keying material exporter when building
// tls-exporter channel binding data: the exporter label and the number of
// bytes of keying material to request.
const (
	exporterLabel = "EXPORTER-Channel-Binding"
	exporterLen   = 32
)

// Channel binding flags that open the GS2 header of the client-first message.
// "n," means the client does not use channel binding, "y," means the client
// supports it but the server did not advertise it, and the "p=" forms name the
// channel binding type both sides agreed on.
const (
	gs2HeaderNoCBSupport       = "n,"
	gs2HeaderNoServerCBSupport = "y,"
	gs2HeaderCBSupportUnique   = "p=tls-unique,"
	gs2HeaderCBSupportExporter = "p=tls-exporter,"
)

// noncerandlen is the number of random bytes to generate for a nonce.
const noncerandlen = 16

// HMAC messages used to derive the server and client keys from the salted
// password.
var (
	serverKeyInput = []byte("Server Key")
	clientKeyInput = []byte("Client Key")
)

func ( string,  *Negotiator) ( []byte) {
	, ,  := .Credentials()
	 := .TLSState()
	switch {
	case  == nil || !strings.HasSuffix(, "-PLUS"):
		// We do not support channel binding
		 = []byte(gs2HeaderNoCBSupport)
	case .State()&RemoteCB == RemoteCB:
		// We support channel binding and the server does too
		if .Version >= tls.VersionTLS13 {
			 = []byte(gs2HeaderCBSupportExporter)
		} else {
			 = []byte(gs2HeaderCBSupportUnique)
		}
	case .State()&RemoteCB != RemoteCB:
		// We support channel binding but the server does not
		 = []byte(gs2HeaderNoServerCBSupport)
	}
	if len() > 0 {
		 = append(, []byte(`a=`)...)
		 = append(, ...)
	}
	 = append(, ',')
	return
}

func ( string,  func() hash.Hash) Mechanism {
	// BUG(ssw): We need a way to cache the SCRAM client and server key
	// calculations.
	return Mechanism{
		Name: ,
		Start: func( *Negotiator) (bool, []byte, interface{}, error) {
			, ,  := .Credentials()

			// Escape "=" and ",". This is mostly the same as bytes.Replace but
			// faster because we can do both replacements in a single pass.
			 := bytes.Count(, []byte{'='}) + bytes.Count(, []byte{','})
			 := make([]byte, len()+(*2))
			 := 0
			 := 0
			for  := 0;  < ; ++ {
				 := 
				 += bytes.IndexAny([:], "=,")
				 += copy([:], [:])
				switch [] {
				case '=':
					 += copy([:], "=3D")
				case ',':
					 += copy([:], "=2C")
				}
				 =  + 1
			}
			copy([:], [:])

			 := make([]byte, 5+len(.Nonce())+len())
			copy(, "n=")
			copy([2:], )
			copy([2+len():], ",r=")
			copy([5+len():], .Nonce())

			return true, append(getGS2Header(, ), ...), , nil
		},
		Next: func( *Negotiator,  []byte,  interface{}) ( bool,  []byte,  interface{},  error) {
			if len() == 0 {
				return , , , ErrInvalidChallenge
			}

			if .State()&Receiving == Receiving {
				panic("not yet implemented")
			}
			return scramClientNext(, , , , )
		},
	}
}

func ( string,  func() hash.Hash,  *Negotiator,  []byte,  interface{}) ( bool,  []byte,  interface{},  error) {
	, ,  := .Credentials()
	 := .State()

	switch  & StepMask {
	case AuthTextSent:
		 := -1
		var ,  []byte
		 := 
		for {
			var  []byte
			,  = nextParam()
			if len() < 3 || (len() >= 2 && [1] != '=') {
				continue
			}
			switch [0] {
			case 'i':
				 := string(bytes.TrimRight([2:], "\x00"))

				if ,  = strconv.Atoi();  != nil {
					return
				}
			case 's':
				 = make([]byte, base64.StdEncoding.DecodedLen(len()-2))
				var  int
				,  = base64.StdEncoding.Decode(, [2:])
				 = [:]
				if  != nil {
					return
				}
			case 'r':
				 = [2:]
			case 'm':
				// RFC 5802:
				// m: This attribute is reserved for future extensibility.  In this
				// version of SCRAM, its presence in a client or a server message
				// MUST cause authentication failure when the attribute is parsed by
				// the other end.
				 = errors.New("server sent reserved attribute `m'")
				return
			}
			if  == nil {
				break
			}
		}

		switch {
		case  < 0:
			 = errors.New("iteration count is invalid")
			return
		case  == nil || !bytes.HasPrefix(, .Nonce()):
			 = errors.New("server nonce does not match client nonce")
			return
		case  == nil:
			 = errors.New("server sent empty salt")
			return
		}

		 := getGS2Header(, )
		 := .TLSState()
		var  []byte
		switch  := strings.HasSuffix(, "-PLUS"); {
		case  &&  == nil:
			 = errors.New("sasl: SCRAM with channel binding requires a TLS connection")
			return
		case bytes.Contains(, []byte(gs2HeaderCBSupportExporter)):
			,  := .ExportKeyingMaterial(exporterLabel, nil, exporterLen)
			if  != nil {
				return false, nil, nil, 
			}
			if len() == 0 {
				 = errors.New("sasl: SCRAM with channel binding requires valid TLS keying material")
				return false, nil, nil, 
			}
			 = make([]byte, 2+base64.StdEncoding.EncodedLen(len()+len()))
			[0] = 'c'
			[1] = '='
			base64.StdEncoding.Encode([2:], append(, ...))
		case bytes.Contains(, []byte(gs2HeaderCBSupportUnique)):
			//lint:ignore SA1019 TLS unique must be supported by SCRAM
			if len(.TLSUnique) == 0 {
				 = errors.New("sasl: SCRAM with channel binding requires valid tls-unique data")
				return false, nil, nil, 
			}
			 = make(
				[]byte,
				//lint:ignore SA1019 TLS unique must be supported by SCRAM
				2+base64.StdEncoding.EncodedLen(len()+len(.TLSUnique)),
			)
			[0] = 'c'
			[1] = '='
			//lint:ignore SA1019 TLS unique must be supported by SCRAM
			base64.StdEncoding.Encode([2:], append(, .TLSUnique...))
		default:
			 = make(
				[]byte,
				2+base64.StdEncoding.EncodedLen(len()),
			)
			[0] = 'c'
			[1] = '='
			base64.StdEncoding.Encode([2:], )
		}
		 := append(, []byte(",r=")...)
		 = append(, ...)

		 := .([]byte)
		 := append(, ',')
		 = append(, ...)
		 = append(, ',')
		 = append(, ...)

		 := pbkdf2.Key(, , , ().Size(), )

		 := hmac.New(, )
		_,  = .Write(serverKeyInput)
		if  != nil {
			return
		}
		 := .Sum(nil)
		.Reset()

		_,  = .Write(clientKeyInput)
		if  != nil {
			return
		}
		 := .Sum(nil)

		 = hmac.New(, )
		_,  = .Write()
		if  != nil {
			return
		}
		 := .Sum(nil)

		 = ()
		_,  = .Write()
		if  != nil {
			return
		}
		 := .Sum(nil)
		 = hmac.New(, )
		_,  = .Write()
		if  != nil {
			return
		}
		 := .Sum(nil)
		 := make([]byte, len())
		goXORBytes(, , )

		 := make([]byte, base64.StdEncoding.EncodedLen(len()))
		base64.StdEncoding.Encode(, )
		 := append(, []byte(",p=")...)
		 = append(, ...)

		return true, , , nil
	case ResponseSent:
		 := "v=" + base64.StdEncoding.EncodeToString(.([]byte))
		if  != string() {
			 = ErrAuthn
			return
		}
		// Success!
		return false, nil, nil, nil
	}
	 = ErrInvalidState
	return
}

// nextParam splits params at its first comma. It returns the text before the
// comma and the text after it; when params contains no comma the whole input
// is returned as the first value and the remainder is nil.
func nextParam(params []byte) ([]byte, []byte) {
	if comma := bytes.IndexByte(params, ','); comma >= 0 {
		return params[:comma], params[comma+1:]
	}
	return params, nil
}