// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package impl

import (
	
	
	
	
	

	
	
	
	
	
	
	
	
)

// ValidationStatus reports the outcome of checking whether a buffer holds a
// well-formed wire-format encoding of a message.
//
// The defined statuses start at 1, so the zero value of ValidationStatus is
// not one of them.
type ValidationStatus int

// The defined ValidationStatus values.
const (
	// ValidationUnknown means the validator could not reach a verdict:
	// unmarshaling the message may either succeed or fail.
	//
	// This status is only produced when an aberrant message type appears
	// somewhere in the message, or when the extension resolver fails.
	ValidationUnknown ValidationStatus = 1

	// ValidationInvalid means that unmarshaling the message will fail.
	ValidationInvalid ValidationStatus = 2

	// ValidationValid means that unmarshaling the message will succeed.
	ValidationValid ValidationStatus = 3
)

// NOTE(review): the receiver name, the method name, the switch operand and
// the argument of int() are blank in this listing. The signature
// (no parameters, string result) suggests this is the String method that
// satisfies fmt.Stringer — confirm against the upstream file.
//
// The method returns the identifier of each declared ValidationStatus
// constant, and formats any other value as "ValidationStatus(<n>)".
func ( ValidationStatus) () string {
	switch  {
	case ValidationUnknown:
		return "ValidationUnknown"
	case ValidationInvalid:
		return "ValidationInvalid"
	case ValidationValid:
		return "ValidationValid"
	default:
		// Values outside the declared set are rendered numerically.
		return fmt.Sprintf("ValidationStatus(%d)", int())
	}
}

// Validate determines whether the contents of the buffer are a valid wire encoding
// of the message type.
//
// It returns ValidationUnknown without inspecting the buffer when the message
// type is not a *MessageInfo. A nil Resolver in the input is replaced with
// protoregistry.GlobalTypes. The UnmarshalInitialized flag is set on the
// returned output when the internal validator reports the message as
// initialized.
//
// This function is exposed for testing.
//
// NOTE(review): the function name and all parameter/local names are blank in
// this listing; confirm against the upstream file.
func ( protoreflect.MessageType,  protoiface.UnmarshalInput) ( protoiface.UnmarshalOutput,  ValidationStatus) {
	// Only message types backed by *MessageInfo can be validated here.
	,  := .(*MessageInfo)
	if ! {
		return , ValidationUnknown
	}
	// Default to the global registry for extension lookups.
	if .Resolver == nil {
		.Resolver = protoregistry.GlobalTypes
	}
	// A group number of 0 means the buffer is validated as a plain message,
	// not as the body of a group.
	,  := .validate(.Buf, 0, unmarshalOptions{
		flags:    .Flags,
		resolver: .Resolver,
	})
	if .initialized {
		.Flags |= protoiface.UnmarshalInitialized
	}
	return , 
}

// validationInfo is the per-field metadata the wire-format validator consults
// to decide how a field's encoded value must be checked.
type validationInfo struct {
	// mi is the message info of the nested message type for message and
	// group fields (and for map fields whose value is a message). It may be
	// nil when that information could not be determined.
	mi *MessageInfo

	// typ classifies the field itself.
	typ validationType

	// keyType classifies the key of a map field; it is only populated for maps.
	keyType validationType

	// valType classifies the value of a map field; it is only populated for maps.
	valType validationType

	// requiredBit is 0 for fields that are not required.
	//
	// For a required field, exactly one bit is set: the nth bit, where n is
	// an index unique to that field within the range
	// [0, MessageInfo.numRequiredFields).
	//
	// When a message has more than 64 required fields, requiredBit is 0.
	requiredBit uint64
}

// validationType classifies a field for the purposes of wire-format
// validation. It determines which wire type counts as a match for a required
// field and which additional checks (nested validation, packed-element
// checks, UTF-8 checks) are applied to a length-delimited value.
type validationType uint8

const (
	// validationTypeOther is the zero value. No type-specific check is
	// applied to such a field; its value is only skipped according to its
	// wire type.
	validationTypeOther validationType = iota
	// validationTypeMessage is a message field; a length-delimited value is
	// validated recursively as a nested message.
	validationTypeMessage
	// validationTypeGroup is a group field; a start-group tag opens a nested
	// group that must be closed by the matching end-group tag.
	validationTypeGroup
	// validationTypeMap is a map field; each length-delimited value is
	// validated as a map entry using the key and value types.
	validationTypeMap
	// validationTypeRepeatedVarint is a repeated varint field; a
	// length-delimited value is checked as a packed sequence of varints.
	validationTypeRepeatedVarint
	// validationTypeRepeatedFixed32 is a repeated fixed32 field; a
	// length-delimited value must have a length that is a multiple of 4.
	validationTypeRepeatedFixed32
	// validationTypeRepeatedFixed64 is a repeated fixed64 field; a
	// length-delimited value must have a length that is a multiple of 8.
	validationTypeRepeatedFixed64
	// validationTypeVarint is a singular field encoded with the varint wire type.
	validationTypeVarint
	// validationTypeFixed32 is a singular field encoded with the fixed32 wire type.
	validationTypeFixed32
	// validationTypeFixed64 is a singular field encoded with the fixed64 wire type.
	validationTypeFixed64
	// validationTypeBytes is a length-delimited field whose contents are not
	// checked further.
	validationTypeBytes
	// validationTypeUTF8String is a length-delimited field whose contents
	// must be valid UTF-8.
	validationTypeUTF8String
	// validationTypeMessageSetItem is the item group of a MessageSet; it is
	// only assigned when flags.ProtoLegacy is enabled and the containing
	// message is a MessageSet.
	validationTypeMessageSetItem
)

// NOTE(review): the function name and all parameter/local names are blank in
// this listing; confirm against the upstream file.
//
// The function computes the validationInfo for one field of a generated
// message struct, given the owning *MessageInfo, the struct's structInfo,
// the field descriptor and the Go type of the field.
//
// Fields that belong to a real (non-synthetic) oneof are classified here,
// using the oneof wrapper type recorded in oneofWrappersByNumber to find the
// nested message info. All other fields are delegated to newValidationInfo.
//
// As a side effect, every required field increments the owning message's
// numRequiredFields counter and is assigned the next requiredBit.
func ( *MessageInfo,  structInfo,  protoreflect.FieldDescriptor,  reflect.Type) validationInfo {
	var  validationInfo
	switch {
	case .ContainingOneof() != nil && !.ContainingOneof().IsSynthetic():
		// Oneof member: only message, group and string kinds get a specific
		// classification; any other kind keeps the zero value
		// (validationTypeOther).
		switch .Kind() {
		case protoreflect.MessageKind:
			.typ = validationTypeMessage
			// The message type is taken from field 0 of the oneof wrapper
			// struct registered for this field number, if there is one.
			if ,  := .oneofWrappersByNumber[.Number()];  {
				.mi = getMessageInfo(.Field(0).Type)
			}
		case protoreflect.GroupKind:
			.typ = validationTypeGroup
			if ,  := .oneofWrappersByNumber[.Number()];  {
				.mi = getMessageInfo(.Field(0).Type)
			}
		case protoreflect.StringKind:
			// Unlike newValidationInfo, a string that does not enforce UTF-8
			// is left as validationTypeOther here rather than
			// validationTypeBytes.
			if strs.EnforceUTF8() {
				.typ = validationTypeUTF8String
			}
		}
	default:
		 = newValidationInfo(, )
	}
	if .Cardinality() == protoreflect.Required {
		// Avoid overflow. The required field check is done with a 64-bit mask, with
		// any message containing more than 64 required fields always reported as
		// potentially uninitialized, so it is not important to get a precise count
		// of the required fields past 64.
		//
		// The guard below stops the counter at math.MaxUint8 (presumably
		// numRequiredFields is a uint8 — confirm). A shift by 64 or more
		// yields 0, which is how fields past the 64th end up with no bit.
		if .numRequiredFields < math.MaxUint8 {
			.numRequiredFields++
			.requiredBit = 1 << (.numRequiredFields - 1)
		}
	}
	return 
}

// NOTE(review): the function name and all parameter/local names are blank in
// this listing. The two-argument call to newValidationInfo elsewhere in this
// file matches this signature, so this is presumably newValidationInfo —
// confirm against the upstream file.
//
// The function classifies a field from its descriptor and Go type, and
// returns a validationInfo with requiredBit left at 0. It distinguishes
// three shapes:
//
//   - list fields: message/group elements, strings (bytes or UTF-8), and
//     scalars classified as the "repeated" variant of their wire type;
//   - map fields: typ is validationTypeMap, with keyType/valType describing
//     the entry's key and value;
//   - everything else: singular message, group, string and scalar fields.
//
// Classifications not assigned by a branch keep the zero value,
// validationTypeOther.
func ( protoreflect.FieldDescriptor,  reflect.Type) validationInfo {
	var  validationInfo
	switch {
	case .IsList():
		switch .Kind() {
		case protoreflect.MessageKind:
			.typ = validationTypeMessage
			// The element's message info is only resolved when the Go type
			// is a slice; otherwise mi stays nil.
			if .Kind() == reflect.Slice {
				.mi = getMessageInfo(.Elem())
			}
		case protoreflect.GroupKind:
			.typ = validationTypeGroup
			if .Kind() == reflect.Slice {
				.mi = getMessageInfo(.Elem())
			}
		case protoreflect.StringKind:
			// Strings are plain bytes unless UTF-8 enforcement applies.
			.typ = validationTypeBytes
			if strs.EnforceUTF8() {
				.typ = validationTypeUTF8String
			}
		default:
			// Repeated scalars are classified by wire type so that a
			// length-delimited value can be checked as a packed encoding.
			switch wireTypes[.Kind()] {
			case protowire.VarintType:
				.typ = validationTypeRepeatedVarint
			case protowire.Fixed32Type:
				.typ = validationTypeRepeatedFixed32
			case protowire.Fixed64Type:
				.typ = validationTypeRepeatedFixed64
			}
		}
	case .IsMap():
		.typ = validationTypeMap
		// Only string keys need a specific check (UTF-8 validity).
		switch .MapKey().Kind() {
		case protoreflect.StringKind:
			if strs.EnforceUTF8() {
				.keyType = validationTypeUTF8String
			}
		}
		switch .MapValue().Kind() {
		case protoreflect.MessageKind:
			.valType = validationTypeMessage
			// The value's message info is only resolved when the Go type is
			// a map; otherwise mi stays nil.
			if .Kind() == reflect.Map {
				.mi = getMessageInfo(.Elem())
			}
		case protoreflect.StringKind:
			if strs.EnforceUTF8() {
				.valType = validationTypeUTF8String
			}
		}
	default:
		switch .Kind() {
		case protoreflect.MessageKind:
			.typ = validationTypeMessage
			// Weak fields are left without message info here.
			if !.IsWeak() {
				.mi = getMessageInfo()
			}
		case protoreflect.GroupKind:
			.typ = validationTypeGroup
			.mi = getMessageInfo()
		case protoreflect.StringKind:
			.typ = validationTypeBytes
			if strs.EnforceUTF8() {
				.typ = validationTypeUTF8String
			}
		default:
			// Singular scalars and bytes are classified by wire type.
			switch wireTypes[.Kind()] {
			case protowire.VarintType:
				.typ = validationTypeVarint
			case protowire.Fixed32Type:
				.typ = validationTypeFixed32
			case protowire.Fixed64Type:
				.typ = validationTypeFixed64
			case protowire.BytesType:
				.typ = validationTypeBytes
			}
		}
	}
	return 
}

// NOTE(review): the method name, receiver, parameters, locals and labels are
// blank in this listing. Validate invokes `.validate(buf, 0,
// unmarshalOptions{...})` on a *MessageInfo, which matches this signature, so
// this is presumably (*MessageInfo).validate — confirm against the upstream
// file. Comments below describe only what the remaining tokens show.
//
// The method scans a wire-format buffer and reports whether it is a valid
// encoding (ValidationValid), definitely invalid (ValidationInvalid), or
// undecidable (ValidationUnknown).
//
// The protowire.Number parameter selects group mode: when it is greater than
// zero the top-level state is a group, and an end-group tag is matched
// against that number.
//
// Nesting (messages, groups, map entries, MessageSet items) is handled
// iteratively with an explicit stack of states rather than by recursion.
//
// On ValidationValid, the output's n is computed as the initial buffer length
// minus the length remaining, and initialized is set only if no popped state
// failed its required-field check.
func ( *MessageInfo) ( []byte,  protowire.Number,  unmarshalOptions) ( unmarshalOutput,  ValidationStatus) {
	.init()
	// One stack entry per enclosing message, group, or map entry currently
	// being validated.
	type  struct {
		              validationType
		,  validationType
		         protowire.Number
		               *MessageInfo
		             []byte
		     uint64
	}

	// Pre-allocate some slots to avoid repeated slice reallocation.
	 := make([], 0, 16)
	 = append(, {
		: validationTypeMessage,
		:  ,
	})
	// A positive group number turns the top-level state into a group.
	if  > 0 {
		[0]. = validationTypeGroup
		[0]. = 
	}
	 := true
	 := len()
	// Outer loop: runs once per (re)entry into the state on top of the stack.
:
	for len() > 0 {
		 := &[len()-1]
		for len() > 0 {
			// Parse the tag (field number and wire type).
			// One- and two-byte tags are decoded inline; longer tags fall
			// back to protowire.ConsumeVarint.
			var  uint64
			if [0] < 0x80 {
				 = uint64([0])
				 = [1:]
			} else if len() >= 2 && [1] < 128 {
				 = uint64([0]&0x7f) + uint64([1])<<7
				 = [2:]
			} else {
				var  int
				,  = protowire.ConsumeVarint()
				if  < 0 {
					return , ValidationInvalid
				}
				 = [:]
			}
			// The field number is the tag shifted right by 3; it must lie in
			// [protowire.MinValidNumber, protowire.MaxValidNumber].
			var  protowire.Number
			if  :=  >> 3;  < uint64(protowire.MinValidNumber) ||  > uint64(protowire.MaxValidNumber) {
				return , ValidationInvalid
			} else {
				 = protowire.Number()
			}
			// The wire type is the low three bits of the tag.
			 := protowire.Type( & 7)

			// An end-group tag is accepted only when it matches the value
			// recorded in the current state; any other end-group tag is invalid.
			if  == protowire.EndGroupType {
				if . ==  {
					goto 
				}
				return , ValidationInvalid
			}
			// Determine how this field's value must be validated. A field
			// that matches nothing keeps the zero validationInfo.
			var  validationInfo
			switch {
			case . == validationTypeMap:
				// Inside a map entry only the key and value field numbers
				// carry type information.
				switch  {
				case genid.MapEntry_Key_field_number:
					.typ = .
				case genid.MapEntry_Value_field_number:
					.typ = .
					.mi = .
					// The value is given required bit 1, which the pop logic
					// below uses when the value message has required fields.
					.requiredBit = 1
				}
			case flags.ProtoLegacy && ..isMessageSet:
				switch  {
				case messageset.FieldItem:
					.typ = validationTypeMessageSetItem
				}
			default:
				// Known field lookup: numbers below len(denseCoderFields)
				// index denseCoderFields, others are looked up in coderFields.
				var  *coderFieldInfo
				if int() < len(..denseCoderFields) {
					 = ..denseCoderFields[]
				} else {
					 = ..coderFields[]
				}
				if  != nil {
					 = .validation
					if .typ == validationTypeMessage && .mi == nil {
						// Probable weak field.
						//
						// TODO: Consider storing the results of this lookup somewhere
						// rather than recomputing it on every validation.
						 := ..Desc.Fields().ByNumber()
						if  == nil || !.IsWeak() {
							// Not a weak field: leave the enclosing switch with
							// a nil message info.
							break
						}
						 := .Message().FullName()
						,  := protoregistry.GlobalTypes.FindMessageByName()
						switch  {
						case nil:
							.mi, _ = .(*MessageInfo)
						case protoregistry.NotFound:
							// The weak message type is not registered: treat
							// the value as opaque bytes.
							.typ = validationTypeBytes
						default:
							return , ValidationUnknown
						}
					}
					// Known field handled: leave the enclosing switch and skip
					// the extension lookup.
					break
				}
				// Possible extension field.
				//
				// TODO: We should return ValidationUnknown when:
				//   1. The resolver is not frozen. (More extensions may be added to it.)
				//   2. The resolver returns preg.NotFound.
				// In this case, a type added to the resolver in the future could cause
				// unmarshaling to begin failing. Supporting this requires some way to
				// determine if the resolver is frozen.
				,  := .resolver.FindExtensionByNumber(..Desc.FullName(), )
				if  != nil &&  != protoregistry.NotFound {
					return , ValidationUnknown
				}
				if  == nil {
					 = getExtensionFieldInfo().validation
				}
			}
			if .requiredBit != 0 {
				// Check that the field has a compatible wire type.
				// We only need to consider non-repeated field types,
				// since repeated fields (and maps) can never be required.
				 := false
				switch .typ {
				case validationTypeVarint:
					 =  == protowire.VarintType
				case validationTypeFixed32:
					 =  == protowire.Fixed32Type
				case validationTypeFixed64:
					 =  == protowire.Fixed64Type
				case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
					 =  == protowire.BytesType
				case validationTypeGroup:
					 =  == protowire.StartGroupType
				}
				// Record the required field as seen in the current state.
				if  {
					. |= .requiredBit
				}
			}

			// Consume the field's value according to its wire type.
			switch  {
			case protowire.VarintType:
				// With at least 10 bytes remaining, the varint is scanned
				// without per-byte length checks.
				if len() >= 10 {
					switch {
					case [0] < 0x80:
						 = [1:]
					case [1] < 0x80:
						 = [2:]
					case [2] < 0x80:
						 = [3:]
					case [3] < 0x80:
						 = [4:]
					case [4] < 0x80:
						 = [5:]
					case [5] < 0x80:
						 = [6:]
					case [6] < 0x80:
						 = [7:]
					case [7] < 0x80:
						 = [8:]
					case [8] < 0x80:
						 = [9:]
					// A tenth byte is accepted only if it is 0 or 1; anything
					// larger would overflow a 64-bit value.
					case [9] < 0x80 && [9] < 2:
						 = [10:]
					default:
						return , ValidationInvalid
					}
				} else {
					// Fewer than 10 bytes remain: same scan, with a length
					// check before each byte is read.
					switch {
					case len() > 0 && [0] < 0x80:
						 = [1:]
					case len() > 1 && [1] < 0x80:
						 = [2:]
					case len() > 2 && [2] < 0x80:
						 = [3:]
					case len() > 3 && [3] < 0x80:
						 = [4:]
					case len() > 4 && [4] < 0x80:
						 = [5:]
					case len() > 5 && [5] < 0x80:
						 = [6:]
					case len() > 6 && [6] < 0x80:
						 = [7:]
					case len() > 7 && [7] < 0x80:
						 = [8:]
					case len() > 8 && [8] < 0x80:
						 = [9:]
					case len() > 9 && [9] < 2:
						 = [10:]
					default:
						return , ValidationInvalid
					}
				}
				continue 
			case protowire.BytesType:
				// Decode the length prefix, with inline fast paths for one-
				// and two-byte lengths.
				var  uint64
				if len() >= 1 && [0] < 0x80 {
					 = uint64([0])
					 = [1:]
				} else if len() >= 2 && [1] < 128 {
					 = uint64([0]&0x7f) + uint64([1])<<7
					 = [2:]
				} else {
					var  int
					,  = protowire.ConsumeVarint()
					if  < 0 {
						return , ValidationInvalid
					}
					 = [:]
				}
				// The declared length must not exceed the remaining bytes.
				if  > uint64(len()) {
					return , ValidationInvalid
				}
				// Split off the payload from the remainder of the buffer.
				 := [:]
				 = [:]
				switch .typ {
				case validationTypeMessage:
					// Without message info a nested message cannot be judged.
					if .mi == nil {
						return , ValidationUnknown
					}
					.mi.init()
					// Messages share the state-push logic with map entries.
					fallthrough
				case validationTypeMap:
					if .mi != nil {
						.mi.init()
					}
					// Push a state for the nested payload. NOTE(review): the
					// field names are blank here; presumably the last field
					// stores the bytes following the payload so scanning can
					// resume there once the nested state is popped — confirm.
					 = append(, {
						:     .typ,
						: .keyType,
						: .valType,
						:      .mi,
						:    ,
					})
					 = 
					continue 
				case validationTypeRepeatedVarint:
					// Packed field.
					for len() > 0 {
						,  := protowire.ConsumeVarint()
						if  < 0 {
							return , ValidationInvalid
						}
						 = [:]
					}
				case validationTypeRepeatedFixed32:
					// Packed field.
					if len()%4 != 0 {
						return , ValidationInvalid
					}
				case validationTypeRepeatedFixed64:
					// Packed field.
					if len()%8 != 0 {
						return , ValidationInvalid
					}
				case validationTypeUTF8String:
					if !utf8.Valid() {
						return , ValidationInvalid
					}
				}
			case protowire.Fixed32Type:
				// A fixed32 value occupies exactly 4 bytes.
				if len() < 4 {
					return , ValidationInvalid
				}
				 = [4:]
			case protowire.Fixed64Type:
				// A fixed64 value occupies exactly 8 bytes.
				if len() < 8 {
					return , ValidationInvalid
				}
				 = [8:]
			case protowire.StartGroupType:
				switch {
				case .typ == validationTypeGroup:
					// Known group field: push a group state for its body.
					if .mi == nil {
						return , ValidationUnknown
					}
					.mi.init()
					 = append(, {
						:      validationTypeGroup,
						:       .mi,
						: ,
					})
					continue 
				case flags.ProtoLegacy && .typ == validationTypeMessageSetItem:
					// MessageSet item: extract the type ID and payload, then
					// validate the payload as the resolved extension's type.
					, , ,  := messageset.ConsumeFieldValue(, false)
					if  != nil {
						return , ValidationInvalid
					}
					,  := .resolver.FindExtensionByNumber(..Desc.FullName(), )
					switch {
					case  == protoregistry.NotFound:
						// Unknown extension: skip the whole item.
						 = [:]
					case  != nil:
						return , ValidationUnknown
					default:
						 := getExtensionFieldInfo().validation
						if .mi != nil {
							.mi.init()
						}
						 = append(, {
							:  .typ,
							:   .mi,
							: [:],
						})
						 = 
						continue 
					}
				default:
					// Start-group tag for a field not classified as a group:
					// let protowire.ConsumeFieldValue measure and skip it.
					 := protowire.ConsumeFieldValue(, , )
					if  < 0 {
						return , ValidationInvalid
					}
					 = [:]
				}
			default:
				// Any other wire type is rejected.
				return , ValidationInvalid
			}
		}
		// The input for this state is exhausted. NOTE(review): the field
		// checked below is blank; presumably it is the pending end-group
		// number, making an unterminated group invalid — confirm.
		if . != 0 {
			return , ValidationInvalid
		}
		if len() != 0 {
			return , ValidationInvalid
		}
		 = .
		// Pop the finished state, checking its required fields first. The
		// end-group branch above jumps directly here.
	:
		 := 0
		switch . {
		case validationTypeMessage, validationTypeGroup:
			 = int(..numRequiredFields)
		case validationTypeMap:
			// If this is a map field with a message value that contains
			// required fields, require that the value be present.
			if . != nil && ..numRequiredFields > 0 {
				 = 1
			}
		}
		// If there are more than 64 required fields, this check will
		// always fail and we will report that the message is potentially
		// uninitialized.
		if  > 0 && bits.OnesCount64(.) !=  {
			 = false
		}
		 = [:len()-1]
	}
	// Report the number of bytes consumed and the initialization status.
	.n =  - len()
	if  {
		.initialized = true
	}
	return , ValidationValid
}