// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package rsa

import (
	
	
	
	
	
)

// GenerateKey generates a new RSA key pair of the given bit size.
// bits must be at least 32.
//
// The public exponent is always 65537. The two primes are obtained from
// randomPrime, the first with (bits+1)/2 bits and the second with bits/2 bits.
// Every call invokes fips140.RecordApproved, and additionally
// fips140.RecordNonApproved when the size is below 2048 or odd.
//
// An error is returned if the size is below 32, if prime generation fails, or
// if one of the internal consistency checks fails. The loop starts over with
// fresh primes when totient reports errDivisorTooLarge or when 65537 has no
// inverse modulo λ(N).
func ( io.Reader,  int) (*PrivateKey, error) {
	// NOTE(review): the identifiers in this block (function, parameter and
	// local names, and several call arguments) are missing in this copy;
	// confirm the comments below against the upstream file.
	if  < 32 {
		return nil, errors.New("rsa: key too small")
	}
	fips140.RecordApproved()
	if  < 2048 || %2 == 1 {
		fips140.RecordNonApproved()
	}

	for {
		// Draw two fresh primes. For an odd key size the first prime gets the
		// extra bit.
		,  := randomPrime(, (+1)/2)
		if  != nil {
			return nil, 
		}
		,  := randomPrime(, /2)
		if  != nil {
			return nil, 
		}

		// Wrap each prime in a bigmod.Modulus for the arithmetic below.
		,  := bigmod.NewModulus()
		if  != nil {
			return nil, 
		}
		,  := bigmod.NewModulus()
		if  != nil {
			return nil, 
		}

		// Equal primes are treated as a broken random source: fail instead of
		// retrying.
		if .Nat().ExpandFor().Equal(.Nat()) == 1 {
			return nil, errors.New("rsa: generated p == q, random source is broken")
		}

		// Build the modulus N and check that it has exactly the requested
		// bit length.
		,  := bigmod.NewModulusProduct(, )
		if  != nil {
			return nil, 
		}
		if .BitLen() !=  {
			return nil, errors.New("rsa: internal error: modulus size incorrect")
		}

		// d can be safely computed as e⁻¹ mod φ(N) where φ(N) = (p-1)(q-1), and
		// indeed that's what both the original RSA paper and the pre-FIPS
		// crypto/rsa implementation did.
		//
		// However, FIPS 186-5, A.1.1(3) requires computing it as e⁻¹ mod λ(N)
		// where λ(N) = lcm(p-1, q-1).
		//
		// This makes d smaller by 1.5 bits on average, which is irrelevant both
		// because we exclusively use the CRT for private operations and because
		// we use constant time windowed exponentiation. On the other hand, it
		// requires computing a GCD of two values that are not coprime, and then
		// a division, both complex variable-time operations.
		,  := totient(, )
		if  == errDivisorTooLarge {
			// The divisor is too large, try again with different primes.
			continue
		}
		if  != nil {
			return nil, 
		}

		// e is the fixed public exponent 65537; d is its inverse modulo λ(N).
		 := bigmod.NewNat().SetUint(65537)
		,  := bigmod.NewNat().InverseVarTime(, )
		if ! {
			// This checks that GCD(e, lcm(p-1, q-1)) = 1, which is equivalent
			// to checking GCD(e, p-1) = 1 and GCD(e, q-1) = 1 separately in
			// FIPS 186-5, Appendix A.1.3, steps 4.5 and 5.6.
			//
			// We waste a prime by retrying the whole process, since 65537 is
			// probably only a factor of one of p-1 or q-1, but the probability
			// of this check failing is only 1/65537, so it doesn't matter.
			continue
		}

		// Sanity-check the inverse: e×d must be 1 modulo λ(N).
		if .ExpandFor().Mul(, ).IsOne() == 0 {
			return nil, errors.New("rsa: internal error: e*d != 1 mod λ(N)")
		}

		// FIPS 186-5, A.1.1(3) requires checking that d > 2^(nlen / 2).
		//
		// The probability of this check failing when d is derived from
		// (e, p, q) is roughly
		//
		//   2^(nlen/2) / 2^nlen = 2^(-nlen/2)
		//
		// so less than 2⁻¹²⁸ for keys larger than 256 bits.
		//
		// We still need to check to comply with FIPS 186-5, but knowing it has
		// negligible chance of failure we can defer the check to the end of key
		// generation and return an error if it fails. See [checkPrivateKey].

		return newPrivateKey(, 65537, , , )
	}
}

// errDivisorTooLarge is returned by [totient] when gcd(p-1, q-1) is wider than
// 32 bits. Key generation treats it as a signal to retry with different primes
// rather than as a failure.
var errDivisorTooLarge = errors.New("divisor too large")

// totient computes the Carmichael totient function λ(N) = lcm(p-1, q-1).
//
// The result is returned as a [bigmod.Modulus]. totient returns
// errDivisorTooLarge when gcd(p-1, q-1) is wider than 32 bits, and other
// errors when the GCD computation or one of the internal sanity checks fails.
//
// totient uses variable-time operations.
func (,  *bigmod.Modulus) (*bigmod.Modulus, error) {
	// NOTE(review): the identifiers in this block (function, parameter and
	// local names, and several call arguments) are missing in this copy, so
	// which operand each step acts on cannot be confirmed from here.
	,  := .Nat().SubOne(), .Nat().SubOne()

	// lcm(a, b) = a×b / gcd(a, b) = a × (b / gcd(a, b))

	// Our GCD requires at least one of the numbers to be odd. For LCM we only
	// need to preserve the larger prime power of each prime factor, so we can
	// right-shift the number with the fewest trailing zeros until it's odd.
	// For odd a, b and m >= n, lcm(a×2ᵐ, b×2ⁿ) = lcm(a×2ᵐ, b).
	,  := .TrailingZeroBitsVarTime(), .TrailingZeroBitsVarTime()
	if  <  {
		 = .ShiftRightVarTime()
	} else {
		 = .ShiftRightVarTime()
	}

	,  := bigmod.NewNat().GCDVarTime(, )
	if  != nil {
		return nil, 
	}
	// One operand was made odd above, so the GCD has to be odd as well.
	if .IsOdd() == 0 {
		return nil, errors.New("rsa: internal error: gcd(a, b) is even")
	}

	// To avoid implementing multiple-precision division, we just try again if
	// the divisor doesn't fit in a single word. This would have a chance of
	// 2⁻⁶⁴ on 64-bit platforms, and 2⁻³² on 32-bit platforms, but testing 2⁻⁶⁴
	// edge cases is impractical, and we'd rather not behave differently on
	// different platforms, so we reject divisors above 2³²-1.
	if .BitLenVarTime() > 32 {
		return nil, errDivisorTooLarge
	}
	// The lowest word is used as the single-word divisor below, so it must
	// not be zero.
	if .IsZero() == 1 || .Bits()[0] == 0 {
		return nil, errors.New("rsa: internal error: gcd(a, b) is zero")
	}
	// Divide by the single-word GCD; any remainder indicates an internal
	// error.
	if  := .DivShortVarTime(.Bits()[0]);  != 0 {
		return nil, errors.New("rsa: internal error: b is not divisible by gcd(a, b)")
	}

	// Multiply the two remaining factors to obtain the lcm, per the identity
	// above.
	return bigmod.NewModulusProduct(.Bytes(), .Bytes())
}

// randomPrime returns a random prime number of the given bit size following
// the process in FIPS 186-5, Appendix A.1.3.
//
// Candidates are read through drbg.ReadWithReader into a buffer of
// (bits+7)/8 bytes, forced to have their two most significant bits and their
// least significant bit set, and tested with isPrime until one passes. The
// accepted candidate is returned as that byte slice, most significant byte
// first.
//
// randomPrime returns an error if the size is below 16 bits or if reading
// random bytes fails.
func ( io.Reader,  int) ([]byte, error) {
	// NOTE(review): the identifiers in this block (function, parameter and
	// local names, and several call arguments) are missing in this copy;
	// confirm the comments below against the upstream file.
	if  < 16 {
		return nil, errors.New("rsa: prime size must be at least 16 bits")
	}

	// The same buffer is refilled for every candidate.
	 := make([]byte, (+7)/8)
	for {
		if  := drbg.ReadWithReader(, );  != nil {
			return nil, 
		}
		// Clear the excess high bits of the first byte so the value has at
		// most the requested number of bits.
		if  := len()*8 - ;  != 0 {
			[0] >>= 
		}

		// Don't let the value be too small: set the most significant two bits.
		// Setting the top two bits, rather than just the top bit, means that
		// when two of these values are multiplied together, the result isn't
		// ever one bit short.
		//
		// With seven excess bits only one bit of the value lives in the first
		// byte, so the second-highest bit is the top bit of the second byte.
		// The second byte exists because the size is at least 16 bits.
		if  := len()*8 - ;  < 7 {
			[0] |= 0b1100_0000 >> 
		} else {
			[0] |= 0b0000_0001
			[1] |= 0b1000_0000
		}

		// Make the value odd since an even number certainly isn't prime.
		[len()-1] |= 1

		// We don't need to check for p >= √2 × 2^(bits-1) (steps 4.4 and 5.4)
		// because we set the top two bits above, so
		//
		//   p > 2^(bits-1) + 2^(bits-2) = 3⁄2 × 2^(bits-1) > √2 × 2^(bits-1)
		//

		// Step 5.5 requires checking that |p - q| > 2^(nlen/2 - 100).
		//
		// The probability of |p - q| ≤ k where p and q are uniformly random in
		// the range (a, b) is 1 - (b-a-k)^2 / (b-a)^2, so the probability of
		// this check failing during key generation is 2⁻⁹⁷.
		//
		// We still need to check to comply with FIPS 186-5, but knowing it has
		// negligible chance of failure we can defer the check to the end of key
		// generation and return an error if it fails. See [checkPrivateKey].

		if isPrime() {
			return , nil
		}
	}
}

// isPrime runs the Miller-Rabin Probabilistic Primality Test from
// FIPS 186-5, Appendix B.3.1.
//
// w must be a random odd integer greater than three in big-endian order.
// isPrime might return false positives for adversarially chosen values.
//
// Before the Miller-Rabin rounds, w is checked against productOfPrimes to
// discard candidates with a small prime factor. The number of rounds depends
// on the bit length of w; each round uses a base read from drbg.Read.
//
// isPrime is not constant-time.
func ( []byte) bool {
	// NOTE(review): the identifiers in this block (function, parameter and
	// local names, and several call arguments) are missing in this copy;
	// confirm the comments below against the upstream file.
	,  := millerRabinSetup()
	if  != nil {
		// w is zero, one, or even.
		return false
	}

	// Cheap pre-filter: load productOfPrimes as a value modulo w and try to
	// invert it.
	,  := bigmod.NewNat().SetBytes(productOfPrimes, .w)
	// If w is too small for productOfPrimes, key generation is
	// going to be fast enough anyway.
	if  == nil {
		,  := .InverseVarTime(, .w)
		if ! {
			// productOfPrimes doesn't have an inverse mod w,
			// so w is divisible by at least one of the primes.
			return false
		}
	}

	// iterations is the number of Miller-Rabin rounds, each with a
	// randomly-selected base.
	//
	// The worst case false positive rate for a single iteration is 1/4 per
	// https://eprint.iacr.org/2018/749, so if w were selected adversarially, we
	// would need up to 64 iterations to get to a negligible (2⁻¹²⁸) chance of
	// false positive.
	//
	// However, since this function is only used for randomly-selected w in the
	// context of RSA key generation, we can use a smaller number of iterations.
	// The exact number depends on the size of the prime (and the implied
	// security level). See BoringSSL for the full formula.
	// https://cs.opensource.google/boringssl/boringssl/+/master:crypto/fipsmodule/bn/prime.c.inc;l=208-283;drc=3a138e43
	 := .w.BitLen()
	var  int
	switch {
	case  >= 3747:
		 = 3
	case  >= 1345:
		 = 4
	case  >= 476:
		 = 5
	case  >= 400:
		 = 6
	case  >= 347:
		 = 7
	case  >= 308:
		 = 8
	case  >= 55:
		 = 27
	default:
		 = 34
	}

	// Each round fills this buffer with random bytes and clears the excess
	// high bits of the first byte, so the base has at most as many bits as w.
	 := make([]byte, (+7)/8)
	for {
		drbg.Read()
		if  := len()*8 - ;  != 0 {
			[0] >>= 
		}
		,  := millerRabinIteration(, )
		if  != nil {
			// b was rejected.
			continue
		}
		if  == millerRabinCOMPOSITE {
			return false
		}
		// Only rounds that actually ran (base accepted) count towards the
		// total.
		--
		if  == 0 {
			return true
		}
	}
}

// productOfPrimes is the product of the first 74 primes higher than 2,
// stored as a 64-byte slice.
//
// The number of primes was selected to be the highest such that the product
// fits in 512 bits, so that it is usable for 1024-bit RSA keys.
//
// Higher values cause fewer Miller-Rabin tests of composites (nothing can help
// with the final test on the actual prime) but make InverseVarTime take longer.
//
// isPrime uses it to reject candidates that share a factor with the product
// before running any Miller-Rabin round.
var productOfPrimes = []byte{
	0x10, 0x6a, 0xa9, 0xfb, 0x76, 0x46, 0xfa, 0x6e, 0xb0, 0x81, 0x3c, 0x28, 0xc5, 0xd5, 0xf0, 0x9f,
	0x07, 0x7e, 0xc3, 0xba, 0x23, 0x8b, 0xfb, 0x99, 0xc1, 0xb6, 0x31, 0xa2, 0x03, 0xe8, 0x11, 0x87,
	0x23, 0x3d, 0xb1, 0x17, 0xcb, 0xc3, 0x84, 0x05, 0x6e, 0xf0, 0x46, 0x59, 0xa4, 0xa1, 0x1d, 0xe4,
	0x9f, 0x7e, 0xcb, 0x29, 0xba, 0xda, 0x8f, 0x98, 0x0d, 0xec, 0xec, 0xe9, 0x2e, 0x30, 0xc4, 0x8f,
}

// millerRabin holds the per-candidate state that is reused across the rounds
// of the Miller-Rabin test. It is populated by millerRabinSetup.
type millerRabin struct {
	// w is the odd candidate under test, as a modulus.
	w *bigmod.Modulus
	// a is the number of trailing zero bits of w-1, so that w-1 = m × 2^a.
	a uint
	// m is the odd part of w-1, as bytes with leading zero bytes removed.
	m []byte
}

// millerRabinSetup prepares state that's reused across multiple iterations of
// the Miller-Rabin test.
//
// It takes the candidate as a byte slice, turns it into a modulus, and
// decomposes w-1 into m × 2^a with m odd. It returns an error if
// bigmod.NewModulus rejects the candidate, if the candidate is even, or if it
// is one.
func ( []byte) (*millerRabin, error) {
	// NOTE(review): the identifiers in this block (function, parameter and
	// local names, and several call arguments) are missing in this copy;
	// confirm the comments below against the upstream file.
	 := &millerRabin{}

	// Check that w is odd, and precompute Montgomery parameters.
	,  := bigmod.NewModulus()
	if  != nil {
		return nil, 
	}
	if .Nat().IsOdd() == 0 {
		return nil, errors.New("candidate is even")
	}
	.w = 

	// Compute m = (w-1)/2^a, where m is odd.
	 := .w.Nat().SubOne(.w)
	if .IsZero() == 1 {
		return nil, errors.New("candidate is one")
	}
	.a = .TrailingZeroBitsVarTime()

	// Store mr.m as a big-endian byte slice with leading zero bytes removed,
	// for use with [bigmod.Nat.Exp].
	 := .ShiftRightVarTime(.a)
	.m = .Bytes(.w)
	// m is odd and therefore non-zero, so a non-zero byte is reached before
	// the slice runs out.
	for .m[0] == 0 {
		.m = .m[1:]
	}

	return , nil
}

// Possible verdicts of a single Miller-Rabin round.
const (
	// millerRabinCOMPOSITE reports that the candidate is composite.
	millerRabinCOMPOSITE = false
	// millerRabinPOSSIBLYPRIME reports that the round did not show the
	// candidate to be composite.
	millerRabinPOSSIBLYPRIME = true
)

// millerRabinIteration runs a single Miller-Rabin round for the candidate w
// held in the millerRabin state, using the byte slice b as the base.
//
// It returns millerRabinPOSSIBLYPRIME or millerRabinCOMPOSITE as the verdict
// on w. A non-nil error means the base was rejected (its length does not
// match the byte length of w, SetBytes refused it, or it equals 0, 1 or w-1)
// and no verdict was produced; the boolean is false in that case and must be
// ignored.
func ( *millerRabin,  []byte) (bool, error) {
	// NOTE(review): the identifiers in this block (function, parameter and
	// local names, and several call arguments) are missing in this copy;
	// confirm the comments below against the upstream file.

	// Reject b ≤ 1 or b ≥ w − 1.
	if len() != (.w.BitLen()+7)/8 {
		return false, errors.New("incorrect length")
	}
	 := bigmod.NewNat()
	if ,  := .SetBytes(, .w);  != nil {
		return false, 
	}
	if .IsZero() == 1 || .IsOne() == 1 || .IsMinusOne(.w) == 1 {
		return false, errors.New("out-of-range candidate")
	}

	// Compute b^(m*2^i) mod w for successive i.
	// If b^m mod w = 1, w is a possible prime.
	// If b^(m*2^i) mod w = -1 for some 0 <= i < a, w is a possible prime.
	// Otherwise w is composite.

	// Start by computing and checking b^m mod w (also the i = 0 case).
	 := bigmod.NewNat().Exp(, .m, .w)
	if .IsOne() == 1 || .IsMinusOne(.w) == 1 {
		return millerRabinPOSSIBLYPRIME, nil
	}

	// Check b^(m*2^i) mod w = -1 for 0 < i < a.
	//
	// w is odd, so w-1 is even and a ≥ 1: the unsigned a - 1 cannot wrap
	// around.
	for range .a - 1 {
		// Square z modulo w.
		.Mul(, .w)
		if .IsMinusOne(.w) == 1 {
			return millerRabinPOSSIBLYPRIME, nil
		}
		if .IsOne() == 1 {
			// Future squaring will not turn z == 1 into -1.
			break
		}
	}

	return millerRabinCOMPOSITE, nil
}