package checksum
import (
	"context"
	"crypto/sha256"
	"fmt"
	"hash"
	"io"
	"strconv"

	v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
	"github.com/aws/smithy-go/middleware"
	smithyhttp "github.com/aws/smithy-go/transport/http"
)
const (
	// contentMD5Header is the request header that carries the legacy MD5
	// checksum of the payload when no other checksum algorithm is requested.
	contentMD5Header = "Content-Md5"

	// streamingUnsignedPayloadTrailerPayloadHash is the payload hash placed in
	// the context when the checksum is sent as an aws-chunked trailer rather
	// than as a request header.
	streamingUnsignedPayloadTrailerPayloadHash = "STREAMING-UNSIGNED-PAYLOAD-TRAILER"
)
type computedInputChecksumsKey struct{}
func ( middleware.Metadata) (map[string]string, bool) {
, := .Get(computedInputChecksumsKey{}).(map[string]string)
return ,
}
func ( *middleware.Metadata, map[string]string) {
.Set(computedInputChecksumsKey{}, )
}
// computeInputPayloadChecksum is a middleware that computes the checksum of a
// request's input payload. Its build step sends the checksum as a request
// header. For HTTPS requests with trailing checksums enabled, the work is
// deferred to its finalize step, which sends the checksum as an aws-chunked
// trailer instead.
type computeInputPayloadChecksum struct {
	// EnableTrailingChecksum allows HTTPS requests with a non-nil payload whose
	// length is not known to be zero to have their payload wrapped in
	// aws-chunked encoding, with the checksum sent as a trailer. Nothing is
	// computed if the algorithm's checksum header is already set.
	EnableTrailingChecksum bool

	// RequireChecksum states that the operation requires a checksum. When no
	// checksum algorithm was requested for the input, the build step falls
	// back to setting the Content-Md5 header.
	RequireChecksum bool

	// EnableComputePayloadHash computes the SHA256 payload hash in the same
	// pass as the algorithm's checksum and stores it in the context. It is
	// skipped when the context already holds a payload hash, and for HTTPS
	// requests. When a trailing checksum is used, the context payload hash is
	// set to the streaming-unsigned-payload-trailer value instead.
	EnableComputePayloadHash bool

	// EnableDecodedContentLengthHeader sets the decoded content length header
	// to the payload's original length when a trailing checksum is used and
	// that length is known.
	EnableDecodedContentLengthHeader bool

	// buildHandlerRun records that HandleBuild ran, so HandleFinalize can
	// detect a stack where the build step was removed.
	buildHandlerRun bool

	// deferToFinalizeHandler is set by HandleBuild when the checksum must be
	// sent as a trailer by HandleFinalize.
	deferToFinalizeHandler bool
}

// ID returns the middleware's identifier.
func (m *computeInputPayloadChecksum) ID() string {
	return "AWSChecksum:ComputeInputPayloadChecksum"
}
// computeInputHeaderChecksumError reports a failure in the build step while
// computing or setting the input payload checksum. Msg describes what failed;
// Err, when non-nil, is the underlying cause.
type computeInputHeaderChecksumError struct {
	Msg string
	Err error
}

// Error formats the error as "compute input header checksum failed, <Msg>",
// followed by ", <Err>" when an underlying error is present.
func (e computeInputHeaderChecksumError) Error() string {
	const intro = "compute input header checksum failed"

	if e.Err != nil {
		return fmt.Sprintf("%s, %s, %v", intro, e.Msg, e.Err)
	}
	return fmt.Sprintf("%s, %s", intro, e.Msg)
}

// Unwrap returns the underlying cause, so errors.Is and errors.As can inspect
// it.
func (e computeInputHeaderChecksumError) Unwrap() error { return e.Err }
func ( *computeInputPayloadChecksum) (
context.Context, middleware.BuildInput, middleware.BuildHandler,
) (
middleware.BuildOutput, middleware.Metadata, error,
) {
.buildHandlerRun = true
, := .Request.(*smithyhttp.Request)
if ! {
return , , computeInputHeaderChecksumError{
Msg: fmt.Sprintf("unknown request type %T", ),
}
}
var Algorithm
var string
defer func() {
if == "" || == "" || != nil {
return
}
SetComputedInputChecksums(&, map[string]string{
string(): ,
})
}()
, , = getInputAlgorithm()
if != nil {
return , ,
} else if ! {
if .RequireChecksum {
, = setMD5Checksum(, )
if != nil {
return , , computeInputHeaderChecksumError{
Msg: "failed to compute stream's MD5 checksum",
Err: ,
}
}
= Algorithm("MD5")
}
return .HandleBuild(, )
}
:= AlgorithmHTTPHeader()
if = .Header.Get(); != "" {
return .HandleBuild(, )
}
:= .EnableComputePayloadHash
if := v4.GetPayloadHash(); != "" {
= false
}
:= .GetStream()
, := getRequestStreamLength()
if != nil {
return , , computeInputHeaderChecksumError{
Msg: "failed to determine stream length",
Err: ,
}
}
if .IsHTTPS() {
if != nil && != 0 && .EnableTrailingChecksum {
if .EnableComputePayloadHash {
= v4.SetPayloadHash(, streamingUnsignedPayloadTrailerPayloadHash)
}
.deferToFinalizeHandler = true
return .HandleBuild(, )
}
= false
}
if != nil && !.IsStreamSeekable() {
return , , computeInputHeaderChecksumError{
Msg: "unseekable stream is not supported without TLS and trailing checksum",
}
}
var string
, , = computeStreamChecksum(
, , )
if != nil {
return , , computeInputHeaderChecksumError{
Msg: "failed to compute stream checksum",
Err: ,
}
}
if := .RewindStream(); != nil {
return , , computeInputHeaderChecksumError{
Msg: "failed to rewind stream",
Err: ,
}
}
.Header.Set(, )
if {
= v4.SetPayloadHash(, )
}
return .HandleBuild(, )
}
// computeInputTrailingChecksumError reports a failure in the finalize step
// while preparing or computing the trailing input payload checksum. Msg
// describes what failed; Err, when non-nil, is the underlying cause.
type computeInputTrailingChecksumError struct {
	Msg string
	Err error
}

// Error formats the error as "compute input trailing checksum failed, <Msg>",
// followed by ", <Err>" when an underlying error is present.
func (e computeInputTrailingChecksumError) Error() string {
	const intro = "compute input trailing checksum failed"

	if e.Err != nil {
		return fmt.Sprintf("%s, %s, %v", intro, e.Msg, e.Err)
	}
	return fmt.Sprintf("%s, %s", intro, e.Msg)
}

// Unwrap returns the underlying cause, so errors.Is and errors.As can inspect
// it.
func (e computeInputTrailingChecksumError) Unwrap() error { return e.Err }
func ( *computeInputPayloadChecksum) (
context.Context, middleware.FinalizeInput, middleware.FinalizeHandler,
) (
middleware.FinalizeOutput, middleware.Metadata, error,
) {
if !.deferToFinalizeHandler {
if !.buildHandlerRun {
return , , computeInputTrailingChecksumError{
Msg: "build handler was removed without also removing finalize handler",
}
}
return .HandleFinalize(, )
}
, := .Request.(*smithyhttp.Request)
if ! {
return , , computeInputTrailingChecksumError{
Msg: fmt.Sprintf("unknown request type %T", ),
}
}
if !.IsHTTPS() {
return , , computeInputTrailingChecksumError{
Msg: "HTTPS required",
}
}
, , := getInputAlgorithm()
if != nil {
return , , computeInputTrailingChecksumError{
Msg: "failed to get algorithm",
Err: ,
}
} else if ! {
return , , computeInputTrailingChecksumError{
Msg: "no algorithm specified",
}
}
:= AlgorithmHTTPHeader()
if .Header.Get() != "" {
return .HandleFinalize(, )
}
:= .GetStream()
, := getRequestStreamLength()
if != nil {
return , , computeInputTrailingChecksumError{
Msg: "failed to determine stream length",
Err: ,
}
}
if == nil || == 0 {
return , , computeInputTrailingChecksumError{
Msg: "nil or empty streams are not supported",
}
}
, := newComputeChecksumReader(, )
if != nil {
return , , computeInputTrailingChecksumError{
Msg: "failed to created checksum reader",
Err: ,
}
}
:= newUnsignedAWSChunkedEncoding(,
func( *awsChunkedEncodingOptions) {
.Trailers[AlgorithmHTTPHeader(.Algorithm())] = awsChunkedTrailerValue{
Get: .Base64Checksum,
Length: .Base64ChecksumLength(),
}
.StreamLength =
})
for , := range .HTTPHeaders() {
for , := range {
.Header.Add(, )
}
}
, = .SetStream()
if != nil {
return , , computeInputTrailingChecksumError{
Msg: "failed updating request to trailing checksum wrapped stream",
Err: ,
}
}
.ContentLength = .EncodedLength()
.Request =
if != -1 && .EnableDecodedContentLengthHeader {
.Header.Set(decodedContentLengthHeaderName, strconv.FormatInt(, 10))
}
, , = .HandleFinalize(, )
if == nil {
, := .Base64Checksum()
if != nil {
return , , fmt.Errorf("failed to get computed checksum, %w", )
}
SetComputedInputChecksums(&, map[string]string{
string(): ,
})
}
return , ,
}
func ( context.Context) (Algorithm, bool, error) {
:= getContextInputAlgorithm()
if == "" {
return "", false, nil
}
, := ParseAlgorithm()
if != nil {
return "", false, fmt.Errorf(
"failed to parse algorithm, %w", )
}
return , true, nil
}
func ( Algorithm, io.Reader, bool) (
string, string, error,
) {
, := NewAlgorithmHash()
if != nil {
return "", "", fmt.Errorf(
"failed to get hasher for checksum algorithm, %w", )
}
var hash.Hash
var io.Writer =
if && != AlgorithmSHA256 {
= sha256.New()
= io.MultiWriter(, )
}
if != nil {
if _, = io.Copy(, ); != nil {
return "", "", fmt.Errorf(
"failed to read stream to compute hash, %w", )
}
}
= string(base64EncodeHashSum())
if {
if != AlgorithmSHA256 {
= string(hexEncodeHashSum())
} else {
= string(hexEncodeHashSum())
}
}
return , , nil
}
func ( *smithyhttp.Request) (int64, error) {
if := .ContentLength; > 0 {
return , nil
}
if , , := .StreamLength(); != nil {
return 0, fmt.Errorf("failed getting request stream's length, %w", )
} else if {
return , nil
}
return -1, nil
}
func ( context.Context, *smithyhttp.Request) (string, error) {
if := .Header.Get(contentMD5Header); len() != 0 {
return , nil
}
:= .GetStream()
if == nil {
return "", nil
}
if !.IsStreamSeekable() {
return "", fmt.Errorf(
"unseekable stream is not supported for computing md5 checksum")
}
, := computeMD5Checksum()
if != nil {
return "",
}
if := .RewindStream(); != nil {
return "", fmt.Errorf("failed to rewind stream after computing MD5 checksum, %w", )
}
.Header.Set(contentMD5Header, string())
return string(), nil
}