package checksum
import (
	"crypto/md5"
	"crypto/sha1"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"hash"
	"hash/crc32"
	"io"
	"strings"
	"sync"
)
// Algorithm names a checksum algorithm by its string identifier.
type Algorithm string

// Checksum algorithm identifiers known to this package.
const (
	// AlgorithmCRC32C is the identifier "CRC32C".
	AlgorithmCRC32C = Algorithm("CRC32C")

	// AlgorithmCRC32 is the identifier "CRC32".
	AlgorithmCRC32 = Algorithm("CRC32")

	// AlgorithmSHA1 is the identifier "SHA1".
	AlgorithmSHA1 = Algorithm("SHA1")

	// AlgorithmSHA256 is the identifier "SHA256".
	AlgorithmSHA256 = Algorithm("SHA256")
)

// supportedAlgorithms lists every identifier declared above. Code that ranges
// over it sees the entries in exactly this order.
var supportedAlgorithms = []Algorithm{AlgorithmCRC32C, AlgorithmCRC32, AlgorithmSHA1, AlgorithmSHA256}
// Value-receiver method on Algorithm that returns the result of a string
// conversion.
// NOTE(review): the receiver name, the method name and the conversion operand
// are missing from this copy of the file. Presumably it converts the receiver
// itself to a string — confirm against the original source.
func ( Algorithm) () string { return string() }
// Looks up a checksum algorithm from a string. It ranges over
// supportedAlgorithms, compares with strings.EqualFold (so letter case is
// ignored), and returns from inside the loop with a nil error on the first
// match. When nothing matches it returns the empty Algorithm and an
// "unknown checksum algorithm" error.
// NOTE(review): the function name, the parameter name and every local
// identifier are missing from this copy. Presumably each supported algorithm
// is compared against the input string and the matching algorithm is
// returned — confirm against the original source.
func ( string) (Algorithm, error) {
for , := range supportedAlgorithms {
if strings.EqualFold(string(), ) {
return , nil
}
}
return "", fmt.Errorf("unknown checksum algorithm, %v", )
}
// Builds a []Algorithm from a []string argument. It keeps a
// map[Algorithm]struct{} as a set and a result slice allocated with length 0
// and capacity len(supportedAlgorithms). For every element of the outer range
// it scans supportedAlgorithms; pairs that are not equal under
// strings.EqualFold are skipped, as are entries already present in the set.
// Anything else is appended to the result and recorded in the set.
// NOTE(review): the function name, the parameter name and every local
// identifier are missing from this copy. Presumably the outer loop ranges over
// the input names and the result holds each matched supported algorithm once,
// in input order — confirm against the original source.
func ( []string) []Algorithm {
:= map[Algorithm]struct{}{}
:= make([]Algorithm, 0, len(supportedAlgorithms))
for , := range {
for , := range supportedAlgorithms {
if !strings.EqualFold(, string()) {
continue
}
if , := []; {
continue
}
= append(, )
[] = struct{}{}
}
}
return
}
func ( Algorithm) (hash.Hash, error) {
switch {
case AlgorithmSHA1:
return sha1.New(), nil
case AlgorithmSHA256:
return sha256.New(), nil
case AlgorithmCRC32:
return crc32.NewIEEE(), nil
case AlgorithmCRC32C:
return crc32.New(crc32.MakeTable(crc32.Castagnoli)), nil
default:
return nil, fmt.Errorf("unknown checksum algorithm, %v", )
}
}
func ( Algorithm) (int, error) {
switch {
case AlgorithmSHA1:
return sha1.Size, nil
case AlgorithmSHA256:
return sha256.Size, nil
case AlgorithmCRC32:
return crc32.Size, nil
case AlgorithmCRC32C:
return crc32.Size, nil
default:
return 0, fmt.Errorf("unknown checksum algorithm, %v", )
}
}
// awsChecksumHeaderPrefix is the literal prefix "x-amz-checksum-". At its only
// use visible in this file it is prepended to a lower-cased string.
const awsChecksumHeaderPrefix = "x-amz-checksum-"
// Takes an Algorithm and returns awsChecksumHeaderPrefix followed by the
// strings.ToLower form of a string conversion.
// NOTE(review): the function name, the parameter name and the conversion
// operand are missing from this copy. Presumably the operand is the algorithm,
// giving values such as "x-amz-checksum-sha256" — confirm against the
// original source.
func ( Algorithm) string {
return awsChecksumHeaderPrefix + strings.ToLower(string())
}
// base64EncodeHashSum returns the current sum of h encoded with standard
// (padded) base64.
//
// h.Sum(nil) is used, so the state of h is not modified and more data may be
// written to it afterwards. The returned slice is newly allocated and exactly
// base64.StdEncoding.EncodedLen(len(sum)) bytes long.
func base64EncodeHashSum(h hash.Hash) []byte {
	sum := h.Sum(nil)
	encoded := make([]byte, base64.StdEncoding.EncodedLen(len(sum)))
	base64.StdEncoding.Encode(encoded, sum)
	return encoded
}
// Takes a hash.Hash and returns a []byte. It calls Sum(nil), allocates a
// buffer sized with hex.EncodedLen, and fills it with hex.Encode.
// NOTE(review): the function name, the parameter name and every local
// identifier are missing from this copy. Presumably it returns the hex
// encoding of the hash's current sum — confirm against the original source.
func ( hash.Hash) []byte {
:= .Sum(nil)
:= make([]byte, hex.EncodedLen(len()))
hex.Encode(, )
return
}
// Takes an io.Reader and returns ([]byte, error). It creates an MD5 hash with
// md5.New and runs io.Copy. If the nil-check after the copy fails, it returns
// nil and a wrapped "failed compute MD5 hash of reader" error. Otherwise it
// returns the output of base64EncodeHashSum with a nil error.
// NOTE(review): the function name, the parameter name and every local
// identifier are missing from this copy. Presumably the reader is copied into
// the MD5 hash and the base64 form of that hash's sum is returned — confirm
// against the original source.
func ( io.Reader) ([]byte, error) {
:= md5.New()
if , := io.Copy(, ); != nil {
return nil, fmt.Errorf("failed compute MD5 hash of reader, %w", )
}
return base64EncodeHashSum(), nil
}
// computeChecksumReader holds the state used to compute a checksum while a
// stream is being read: the stream, the algorithm and its hasher, the base64
// length of the checksum, and a mutex-guarded result (checksum or error).
type computeChecksumReader struct {
// stream is the reader this type reads from; the constructor in this file
// sets it to an io.TeeReader.
stream io.Reader
// algorithm is the checksum algorithm this reader was created with.
algorithm Algorithm
// hasher is the hash whose sum is base64-encoded once the stream reports
// io.EOF.
hasher hash.Hash
// base64ChecksumLen is set by the constructor from
// base64.StdEncoding.EncodedLen.
base64ChecksumLen int
// mux guards lockedChecksum and lockedErr.
mux sync.RWMutex
// lockedChecksum is written under mux after the stream reports io.EOF; it
// is empty until then.
lockedChecksum string
// lockedErr is written under mux when the stream returns an error other
// than io.EOF.
lockedErr error
}
// Constructs a *computeChecksumReader from an io.Reader and an Algorithm.
// It calls NewAlgorithmHash and then AlgorithmChecksumLength; a failed
// nil-check after either call returns a nil reader. On success the stream
// field is an io.TeeReader and base64ChecksumLen comes from
// base64.StdEncoding.EncodedLen.
// NOTE(review): the function name, the parameter names and every local
// identifier are missing from this copy. Presumably the TeeReader feeds the
// input reader into the hasher, the errors from the two calls are passed
// through, and the encoded length is computed from the algorithm's raw
// checksum length — confirm against the original source.
func ( io.Reader, Algorithm) (*computeChecksumReader, error) {
, := NewAlgorithmHash()
if != nil {
return nil,
}
, := AlgorithmChecksumLength()
if != nil {
return nil,
}
return &computeChecksumReader{
stream: io.TeeReader(, ),
algorithm: ,
hasher: ,
base64ChecksumLen: base64.StdEncoding.EncodedLen(),
}, nil
}
// Method on *computeChecksumReader with the signature ([]byte) (int, error).
// It reads from the stream field and then takes one of three paths:
//   - the nil comparison holds: return with a nil error;
//   - the value is not io.EOF: store lockedErr under the write lock and
//     return;
//   - otherwise: base64-encode the hasher's sum and store it in
//     lockedChecksum under the write lock before returning.
// NOTE(review): the method name, the receiver and parameter names and every
// local identifier are missing from this copy. Presumably this is the
// io.Reader Read method and the compared value is the error returned by the
// underlying read — confirm against the original source.
func ( *computeChecksumReader) ( []byte) (int, error) {
, := .stream.Read()
if == nil {
return , nil
} else if != io.EOF {
.mux.Lock()
defer .mux.Unlock()
.lockedErr =
return ,
}
:= base64EncodeHashSum(.hasher)
.mux.Lock()
defer .mux.Unlock()
.lockedChecksum = string()
return ,
}
// Accessor on *computeChecksumReader that returns the algorithm field.
// NOTE(review): the method and receiver names are missing from this copy.
func ( *computeChecksumReader) () Algorithm {
return .algorithm
}
// Accessor on *computeChecksumReader that returns the base64ChecksumLen field.
// NOTE(review): the method and receiver names are missing from this copy.
func ( *computeChecksumReader) () int {
return .base64ChecksumLen
}
// Method on *computeChecksumReader that reports the stored result under the
// read lock. A non-nil lockedErr is returned first. If lockedChecksum is still
// empty it returns a "checksum not available yet" error. Otherwise it returns
// lockedChecksum with a nil error.
// NOTE(review): the method and receiver names are missing from this copy.
func ( *computeChecksumReader) () (string, error) {
.mux.RLock()
defer .mux.RUnlock()
if .lockedErr != nil {
return "", .lockedErr
}
if .lockedChecksum == "" {
return "", fmt.Errorf(
"checksum not available yet, called before reader returns EOF",
)
}
return .lockedChecksum, nil
}
// validateChecksumReader holds the state used to check a body's checksum
// against an expected value: the original body, the reader actually read from,
// the hasher, the algorithm, and the expected checksum string.
type validateChecksumReader struct {
// originalBody is the io.ReadCloser whose Close is called by this type's
// close method.
originalBody io.ReadCloser
// body is the reader that reads are served from; the constructor in this
// file sets it to an io.TeeReader.
body io.Reader
// hasher is the hash whose base64-encoded sum is compared with
// expectChecksum.
hasher hash.Hash
// algorithm is reported in the validationError when the comparison fails.
algorithm Algorithm
// expectChecksum is the checksum string the computed value is compared to.
expectChecksum string
}
// Constructs a *validateChecksumReader from an io.ReadCloser, an Algorithm
// and a string. It calls NewAlgorithmHash; a failed nil-check returns a nil
// reader. On success the body field is an io.TeeReader and the remaining
// fields are filled in from the arguments and the new hash.
// NOTE(review): the function name, the parameter names and every local
// identifier are missing from this copy. Presumably the TeeReader feeds the
// original body into the hasher and the string argument is the expected
// checksum — confirm against the original source.
func (
io.ReadCloser,
Algorithm,
string,
) (*validateChecksumReader, error) {
, := NewAlgorithmHash()
if != nil {
return nil,
}
return &validateChecksumReader{
originalBody: ,
body: io.TeeReader(, ),
hasher: ,
algorithm: ,
expectChecksum: ,
}, nil
}
// Method on *validateChecksumReader taking a []byte and returning
// (int, error) through named results. It reads from the body field. When the
// io.EOF comparison holds it calls validateChecksum, and a non-nil validation
// error is returned from that branch. Otherwise the method returns normally.
// NOTE(review): the method name, the receiver, parameter and result names and
// the local identifiers are missing from this copy. Presumably this is the
// io.Reader Read method and a failed validation replaces io.EOF as the
// returned error — confirm against the original source.
func ( *validateChecksumReader) ( []byte) ( int, error) {
, = .body.Read()
if == io.EOF {
if := .validateChecksum(); != nil {
return ,
}
}
return ,
}
// Method on *validateChecksumReader that returns the result of calling Close
// on originalBody.
// NOTE(review): the method, receiver and result names are missing from this
// copy. Presumably this is the io.Closer Close method — confirm against the
// original source.
func ( *validateChecksumReader) () ( error) {
return .originalBody.Close()
}
func ( *validateChecksumReader) () error {
:= base64EncodeHashSum(.hasher)
if , := .expectChecksum, string(); !strings.EqualFold(, ) {
return validationError{
Algorithm: .algorithm, Expect: , Actual: ,
}
}
return nil
}
type validationError struct {
Algorithm Algorithm
Expect string
Actual string
}
func ( validationError) () string {
return fmt.Sprintf("checksum did not match: algorithm %v, expect %v, actual %v",
.Algorithm, .Expect, .Actual)
}