package mlkem
import (
	"bytes"
	"crypto/internal/fips140"
	"crypto/internal/fips140/drbg"
	"crypto/internal/fips140/sha3"
	"crypto/internal/fips140/subtle"
	"errors"
)
// A DecapsulationKey1024 is the secret key used to decapsulate a shared key
// from a ciphertext. Alongside its seed it caches the expanded key material
// derived from that seed.
type DecapsulationKey1024 struct {
	// d is the key-generation seed; z is the implicit-rejection value that
	// decapsulation hashes with the ciphertext to form its fallback key.
	d, z [32]byte

	ρ [32]byte // seed the matrix a is sampled from
	h [32]byte // SHA3-256 of the encoded encapsulation key

	encryptionKey1024
	decryptionKey1024
}
func ( *DecapsulationKey1024) () []byte {
var [SeedSize]byte
copy([:], .d[:])
copy([32:], .z[:])
return [:]
}
func ( *DecapsulationKey1024) []byte {
:= make([]byte, 0, decapsulationKeySize1024)
for := range .s {
= polyByteEncode(, .s[])
}
for := range .t {
= polyByteEncode(, .t[])
}
= append(, .ρ[:]...)
= append(, .h[:]...)
= append(, .z[:]...)
return
}
func ( *DecapsulationKey1024) () *EncapsulationKey1024 {
return &EncapsulationKey1024{
ρ: .ρ,
h: .h,
encryptionKey1024: .encryptionKey1024,
}
}
// An EncapsulationKey1024 is the public key used to produce ciphertexts that
// the matching DecapsulationKey1024 can decapsulate.
type EncapsulationKey1024 struct {
	// ρ seeds the matrix a; h is SHA3-256 of the encoded key.
	ρ, h [32]byte

	encryptionKey1024
}
func ( *EncapsulationKey1024) () []byte {
:= make([]byte, 0, EncapsulationKeySize1024)
return .bytes()
}
func ( *EncapsulationKey1024) ( []byte) []byte {
for := range .t {
= polyByteEncode(, .t[])
}
= append(, .ρ[:]...)
return
}
// encryptionKey1024 is the parsed and expanded public (K-PKE) key. It is
// embedded in both DecapsulationKey1024 and EncapsulationKey1024.
type encryptionKey1024 struct {
// t is the decoded public vector, held as NTT-domain elements.
t [k1024]nttElement
// a is the k1024×k1024 matrix in row-major order: entry (i, j) is stored at
// a[i*k1024+j], filled with sampleNTT(ρ, j, i) at key generation / parsing.
a [k1024 * k1024]nttElement
}
// decryptionKey1024 is the parsed secret (K-PKE) decryption key.
type decryptionKey1024 struct {
// s is the secret vector, kept in the NTT domain.
s [k1024]nttElement
}
func () (*DecapsulationKey1024, error) {
:= &DecapsulationKey1024{}
return generateKey1024()
}
func ( *DecapsulationKey1024) (*DecapsulationKey1024, error) {
var [32]byte
drbg.Read([:])
var [32]byte
drbg.Read([:])
kemKeyGen1024(, &, &)
if := fips140.PCT("ML-KEM PCT", func() error { return kemPCT1024() }); != nil {
panic()
}
fips140.RecordApproved()
return , nil
}
func (, *[32]byte) *DecapsulationKey1024 {
:= &DecapsulationKey1024{}
kemKeyGen1024(, , )
return
}
func ( []byte) (*DecapsulationKey1024, error) {
:= &DecapsulationKey1024{}
return newKeyFromSeed1024(, )
}
func ( *DecapsulationKey1024, []byte) (*DecapsulationKey1024, error) {
if len() != SeedSize {
return nil, errors.New("mlkem: invalid seed length")
}
:= (*[32]byte)([:32])
:= (*[32]byte)([32:])
kemKeyGen1024(, , )
if := fips140.PCT("ML-KEM PCT", func() error { return kemPCT1024() }); != nil {
panic()
}
fips140.RecordApproved()
return , nil
}
func ( []byte) (*DecapsulationKey1024, error) {
if len() != decapsulationKeySize1024 {
return nil, errors.New("mlkem: invalid NIST decapsulation key length")
}
:= &DecapsulationKey1024{}
for := range .s {
var error
.s[], = polyByteDecode[nttElement]([:encodingSize12])
if != nil {
return nil, errors.New("mlkem: invalid secret key encoding")
}
= [encodingSize12:]
}
, := NewEncapsulationKey1024([:EncapsulationKeySize1024])
if != nil {
return nil,
}
.ρ = .ρ
.h = .h
.encryptionKey1024 = .encryptionKey1024
= [EncapsulationKeySize1024:]
if !bytes.Equal(.h[:], [:32]) {
return nil, errors.New("mlkem: inconsistent H(ek) in encoded bytes")
}
= [32:]
copy(.z[:], )
drbg.Read(.d[:])
return , nil
}
func ( *DecapsulationKey1024, , *[32]byte) {
.d = *
.z = *
:= sha3.New512()
.Write([:])
.Write([]byte{k1024})
:= .Sum(make([]byte, 0, 64))
, := [:32], [32:]
.ρ = [32]byte()
:= &.a
for := byte(0); < k1024; ++ {
for := byte(0); < k1024; ++ {
[*k1024+] = sampleNTT(, , )
}
}
var byte
:= &.s
for := range {
[] = ntt(samplePolyCBD(, ))
++
}
:= make([]nttElement, k1024)
for := range {
[] = ntt(samplePolyCBD(, ))
++
}
:= &.t
for := range {
[] = []
for := range {
[] = polyAdd([], nttMul([*k1024+], []))
}
}
:= sha3.New256()
:= .EncapsulationKey().Bytes()
.Write()
.Sum(.h[:0])
}
func ( *DecapsulationKey1024) error {
:= .EncapsulationKey()
, := .Encapsulate()
, := .Decapsulate()
if != nil {
return
}
if subtle.ConstantTimeCompare(, ) != 1 {
return errors.New("mlkem: PCT failed")
}
return nil
}
func ( *EncapsulationKey1024) () (, []byte) {
var [CiphertextSize1024]byte
return .encapsulate(&)
}
func ( *EncapsulationKey1024) ( *[CiphertextSize1024]byte) (, []byte) {
var [messageSize]byte
drbg.Read([:])
fips140.RecordApproved()
return kemEncaps1024(, , &)
}
func ( *EncapsulationKey1024) ( *[32]byte) (, []byte) {
:= &[CiphertextSize1024]byte{}
return kemEncaps1024(, , )
}
func ( *[CiphertextSize1024]byte, *EncapsulationKey1024, *[messageSize]byte) (, []byte) {
:= sha3.New512()
.Write([:])
.Write(.h[:])
:= .Sum(nil)
, := [:SharedKeySize], [SharedKeySize:]
= pkeEncrypt1024(, &.encryptionKey1024, , )
return ,
}
func ( []byte) (*EncapsulationKey1024, error) {
:= &EncapsulationKey1024{}
return parseEK1024(, )
}
func ( *EncapsulationKey1024, []byte) (*EncapsulationKey1024, error) {
if len() != EncapsulationKeySize1024 {
return nil, errors.New("mlkem: invalid encapsulation key length")
}
:= sha3.New256()
.Write()
.Sum(.h[:0])
for := range .t {
var error
.t[], = polyByteDecode[nttElement]([:encodingSize12])
if != nil {
return nil,
}
= [encodingSize12:]
}
copy(.ρ[:], )
for := byte(0); < k1024; ++ {
for := byte(0); < k1024; ++ {
.a[*k1024+] = sampleNTT(.ρ[:], , )
}
}
return , nil
}
func ( *[CiphertextSize1024]byte, *encryptionKey1024, *[messageSize]byte, []byte) []byte {
var byte
, := make([]nttElement, k1024), make([]ringElement, k1024)
for := range {
[] = ntt(samplePolyCBD(, ))
++
}
for := range {
[] = samplePolyCBD(, )
++
}
:= samplePolyCBD(, )
:= make([]ringElement, k1024)
for := range {
[] = []
for := range {
[] = polyAdd([], inverseNTT(nttMul(.a[*k1024+], [])))
}
}
:= ringDecodeAndDecompress1()
var nttElement
for := range .t {
= polyAdd(, nttMul(.t[], []))
}
:= polyAdd(polyAdd(inverseNTT(), ), )
:= [:0]
for , := range {
= ringCompressAndEncode11(, )
}
= ringCompressAndEncode5(, )
return
}
func ( *DecapsulationKey1024) ( []byte) ( []byte, error) {
if len() != CiphertextSize1024 {
return nil, errors.New("mlkem: invalid ciphertext length")
}
:= (*[CiphertextSize1024]byte)()
return kemDecaps1024(, ), nil
}
func ( *DecapsulationKey1024, *[CiphertextSize1024]byte) ( []byte) {
fips140.RecordApproved()
:= pkeDecrypt1024(&.decryptionKey1024, )
:= sha3.New512()
.Write([:])
.Write(.h[:])
:= .Sum(make([]byte, 0, 64))
, := [:SharedKeySize], [SharedKeySize:]
:= sha3.NewShake256()
.Write(.z[:])
.Write([:])
:= make([]byte, SharedKeySize)
.Read()
var [CiphertextSize1024]byte
:= pkeEncrypt1024(&, &.encryptionKey1024, (*[32]byte)(), )
subtle.ConstantTimeCopy(subtle.ConstantTimeCompare([:], ), , )
return
}
func ( *decryptionKey1024, *[CiphertextSize1024]byte) []byte {
:= make([]ringElement, k1024)
for := range {
:= (*[encodingSize11]byte)([encodingSize11* : encodingSize11*(+1)])
[] = ringDecodeAndDecompress11()
}
:= (*[encodingSize5]byte)([encodingSize11*k1024:])
:= ringDecodeAndDecompress5()
var nttElement
for := range .s {
= polyAdd(, nttMul(.s[], ntt([])))
}
:= polySub(, inverseNTT())
return ringCompressAndEncode1(nil, )
}