package impl
import (
	"fmt"
	"math"
	"math/bits"
	"reflect"
	"unicode/utf8"

	"google.golang.org/protobuf/encoding/protowire"
	"google.golang.org/protobuf/internal/encoding/messageset"
	"google.golang.org/protobuf/internal/flags"
	"google.golang.org/protobuf/internal/genid"
	"google.golang.org/protobuf/internal/strs"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/reflect/protoregistry"
	"google.golang.org/protobuf/runtime/protoiface"
)
// ValidationStatus is the result of validating the wire-format encoding of a
// message.
type ValidationStatus int

const (
	// ValidationUnknown indicates that the validator could not reach a
	// verdict, for example because message type information needed to walk a
	// nested value was unavailable, or an extension resolver failed.
	ValidationUnknown ValidationStatus = iota + 1

	// ValidationInvalid indicates that the input is not a well-formed wire
	// encoding of the message type.
	ValidationInvalid

	// ValidationValid indicates that the validator found the input to be a
	// well-formed wire encoding of the message type.
	ValidationValid
)

// String returns the constant's name, or "ValidationStatus(N)" for a value
// outside the defined set.
func (v ValidationStatus) String() string {
	switch v {
	case ValidationUnknown:
		return "ValidationUnknown"
	case ValidationInvalid:
		return "ValidationInvalid"
	case ValidationValid:
		return "ValidationValid"
	default:
		return fmt.Sprintf("ValidationStatus(%d)", int(v))
	}
}
func ( protoreflect.MessageType, protoiface.UnmarshalInput) ( protoiface.UnmarshalOutput, ValidationStatus) {
, := .(*MessageInfo)
if ! {
return , ValidationUnknown
}
if .Resolver == nil {
.Resolver = protoregistry.GlobalTypes
}
, := .validate(.Buf, 0, unmarshalOptions{
flags: .Flags,
resolver: .Resolver,
})
if .initialized {
.Flags |= protoiface.UnmarshalInitialized
}
return ,
}
// validationInfo carries what validate needs to know about a single field in
// order to check its encoded value.
type validationInfo struct {
	// mi is the message type used to walk a nested message, group, or map
	// value; it may be nil when that information is not available.
	mi *MessageInfo

	// typ selects how the field's encoded value is checked.
	typ validationType

	// keyType is the check applied to a map entry's key (map fields only).
	keyType validationType
	// valType is the check applied to a map entry's value (map fields only).
	valType validationType

	// requiredBit is zero for fields that are not required. For a required
	// field it holds a single bit identifying the field within its message's
	// required-field mask. Past the 64th required field of a message the
	// shift overflows and the bit is zero.
	requiredBit uint64
}
// validationType selects how validate treats a field's encoded value.
type validationType uint8

const (
	// validationTypeOther is the zero value: no type-specific check beyond
	// skipping the value according to its wire type.
	validationTypeOther = validationType(iota)
	// validationTypeMessage is a length-delimited nested message.
	validationTypeMessage
	// validationTypeGroup is a nested message encoded as a group.
	validationTypeGroup
	// validationTypeMap is a map field; each entry is walked as a nested
	// key/value message.
	validationTypeMap
	// Repeated scalars. A length-delimited payload for these is a packed
	// encoding whose contents are checked element-wise.
	validationTypeRepeatedVarint
	validationTypeRepeatedFixed32
	validationTypeRepeatedFixed64
	// Singular scalars. In validate they are consulted only to confirm that a
	// required field arrived with a compatible wire type.
	validationTypeVarint
	validationTypeFixed32
	validationTypeFixed64
	// validationTypeBytes is a length-delimited value with no content check.
	validationTypeBytes
	// validationTypeUTF8String is a string whose bytes must be valid UTF-8.
	validationTypeUTF8String
	// validationTypeMessageSetItem is an item of a legacy MessageSet.
	validationTypeMessageSetItem
)
func ( *MessageInfo, structInfo, protoreflect.FieldDescriptor, reflect.Type) validationInfo {
var validationInfo
switch {
case .ContainingOneof() != nil && !.ContainingOneof().IsSynthetic():
switch .Kind() {
case protoreflect.MessageKind:
.typ = validationTypeMessage
if , := .oneofWrappersByNumber[.Number()]; {
.mi = getMessageInfo(.Field(0).Type)
}
case protoreflect.GroupKind:
.typ = validationTypeGroup
if , := .oneofWrappersByNumber[.Number()]; {
.mi = getMessageInfo(.Field(0).Type)
}
case protoreflect.StringKind:
if strs.EnforceUTF8() {
.typ = validationTypeUTF8String
}
}
default:
= newValidationInfo(, )
}
if .Cardinality() == protoreflect.Required {
if .numRequiredFields < math.MaxUint8 {
.numRequiredFields++
.requiredBit = 1 << (.numRequiredFields - 1)
}
}
return
}
func ( protoreflect.FieldDescriptor, reflect.Type) validationInfo {
var validationInfo
switch {
case .IsList():
switch .Kind() {
case protoreflect.MessageKind:
.typ = validationTypeMessage
if .Kind() == reflect.Slice {
.mi = getMessageInfo(.Elem())
}
case protoreflect.GroupKind:
.typ = validationTypeGroup
if .Kind() == reflect.Slice {
.mi = getMessageInfo(.Elem())
}
case protoreflect.StringKind:
.typ = validationTypeBytes
if strs.EnforceUTF8() {
.typ = validationTypeUTF8String
}
default:
switch wireTypes[.Kind()] {
case protowire.VarintType:
.typ = validationTypeRepeatedVarint
case protowire.Fixed32Type:
.typ = validationTypeRepeatedFixed32
case protowire.Fixed64Type:
.typ = validationTypeRepeatedFixed64
}
}
case .IsMap():
.typ = validationTypeMap
switch .MapKey().Kind() {
case protoreflect.StringKind:
if strs.EnforceUTF8() {
.keyType = validationTypeUTF8String
}
}
switch .MapValue().Kind() {
case protoreflect.MessageKind:
.valType = validationTypeMessage
if .Kind() == reflect.Map {
.mi = getMessageInfo(.Elem())
}
case protoreflect.StringKind:
if strs.EnforceUTF8() {
.valType = validationTypeUTF8String
}
}
default:
switch .Kind() {
case protoreflect.MessageKind:
.typ = validationTypeMessage
if !.IsWeak() {
.mi = getMessageInfo()
}
case protoreflect.GroupKind:
.typ = validationTypeGroup
.mi = getMessageInfo()
case protoreflect.StringKind:
.typ = validationTypeBytes
if strs.EnforceUTF8() {
.typ = validationTypeUTF8String
}
default:
switch wireTypes[.Kind()] {
case protowire.VarintType:
.typ = validationTypeVarint
case protowire.Fixed32Type:
.typ = validationTypeFixed32
case protowire.Fixed64Type:
.typ = validationTypeFixed64
case protowire.BytesType:
.typ = validationTypeBytes
}
}
}
return
}
func ( *MessageInfo) ( []byte, protowire.Number, unmarshalOptions) ( unmarshalOutput, ValidationStatus) {
.init()
type struct {
validationType
, validationType
protowire.Number
*MessageInfo
[]byte
uint64
}
:= make([], 0, 16)
= append(, {
: validationTypeMessage,
: ,
})
if > 0 {
[0]. = validationTypeGroup
[0]. =
}
:= true
:= len()
:
for len() > 0 {
:= &[len()-1]
for len() > 0 {
var uint64
if [0] < 0x80 {
= uint64([0])
= [1:]
} else if len() >= 2 && [1] < 128 {
= uint64([0]&0x7f) + uint64([1])<<7
= [2:]
} else {
var int
, = protowire.ConsumeVarint()
if < 0 {
return , ValidationInvalid
}
= [:]
}
var protowire.Number
if := >> 3; < uint64(protowire.MinValidNumber) || > uint64(protowire.MaxValidNumber) {
return , ValidationInvalid
} else {
= protowire.Number()
}
:= protowire.Type( & 7)
if == protowire.EndGroupType {
if . == {
goto
}
return , ValidationInvalid
}
var validationInfo
switch {
case . == validationTypeMap:
switch {
case genid.MapEntry_Key_field_number:
.typ = .
case genid.MapEntry_Value_field_number:
.typ = .
.mi = .
.requiredBit = 1
}
case flags.ProtoLegacy && ..isMessageSet:
switch {
case messageset.FieldItem:
.typ = validationTypeMessageSetItem
}
default:
var *coderFieldInfo
if int() < len(..denseCoderFields) {
= ..denseCoderFields[]
} else {
= ..coderFields[]
}
if != nil {
= .validation
if .typ == validationTypeMessage && .mi == nil {
:= ..Desc.Fields().ByNumber()
if == nil || !.IsWeak() {
break
}
:= .Message().FullName()
, := protoregistry.GlobalTypes.FindMessageByName()
switch {
case nil:
.mi, _ = .(*MessageInfo)
case protoregistry.NotFound:
.typ = validationTypeBytes
default:
return , ValidationUnknown
}
}
break
}
, := .resolver.FindExtensionByNumber(..Desc.FullName(), )
if != nil && != protoregistry.NotFound {
return , ValidationUnknown
}
if == nil {
= getExtensionFieldInfo().validation
}
}
if .requiredBit != 0 {
:= false
switch .typ {
case validationTypeVarint:
= == protowire.VarintType
case validationTypeFixed32:
= == protowire.Fixed32Type
case validationTypeFixed64:
= == protowire.Fixed64Type
case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
= == protowire.BytesType
case validationTypeGroup:
= == protowire.StartGroupType
}
if {
. |= .requiredBit
}
}
switch {
case protowire.VarintType:
if len() >= 10 {
switch {
case [0] < 0x80:
= [1:]
case [1] < 0x80:
= [2:]
case [2] < 0x80:
= [3:]
case [3] < 0x80:
= [4:]
case [4] < 0x80:
= [5:]
case [5] < 0x80:
= [6:]
case [6] < 0x80:
= [7:]
case [7] < 0x80:
= [8:]
case [8] < 0x80:
= [9:]
case [9] < 0x80 && [9] < 2:
= [10:]
default:
return , ValidationInvalid
}
} else {
switch {
case len() > 0 && [0] < 0x80:
= [1:]
case len() > 1 && [1] < 0x80:
= [2:]
case len() > 2 && [2] < 0x80:
= [3:]
case len() > 3 && [3] < 0x80:
= [4:]
case len() > 4 && [4] < 0x80:
= [5:]
case len() > 5 && [5] < 0x80:
= [6:]
case len() > 6 && [6] < 0x80:
= [7:]
case len() > 7 && [7] < 0x80:
= [8:]
case len() > 8 && [8] < 0x80:
= [9:]
case len() > 9 && [9] < 2:
= [10:]
default:
return , ValidationInvalid
}
}
continue
case protowire.BytesType:
var uint64
if len() >= 1 && [0] < 0x80 {
= uint64([0])
= [1:]
} else if len() >= 2 && [1] < 128 {
= uint64([0]&0x7f) + uint64([1])<<7
= [2:]
} else {
var int
, = protowire.ConsumeVarint()
if < 0 {
return , ValidationInvalid
}
= [:]
}
if > uint64(len()) {
return , ValidationInvalid
}
:= [:]
= [:]
switch .typ {
case validationTypeMessage:
if .mi == nil {
return , ValidationUnknown
}
.mi.init()
fallthrough
case validationTypeMap:
if .mi != nil {
.mi.init()
}
= append(, {
: .typ,
: .keyType,
: .valType,
: .mi,
: ,
})
=
continue
case validationTypeRepeatedVarint:
for len() > 0 {
, := protowire.ConsumeVarint()
if < 0 {
return , ValidationInvalid
}
= [:]
}
case validationTypeRepeatedFixed32:
if len()%4 != 0 {
return , ValidationInvalid
}
case validationTypeRepeatedFixed64:
if len()%8 != 0 {
return , ValidationInvalid
}
case validationTypeUTF8String:
if !utf8.Valid() {
return , ValidationInvalid
}
}
case protowire.Fixed32Type:
if len() < 4 {
return , ValidationInvalid
}
= [4:]
case protowire.Fixed64Type:
if len() < 8 {
return , ValidationInvalid
}
= [8:]
case protowire.StartGroupType:
switch {
case .typ == validationTypeGroup:
if .mi == nil {
return , ValidationUnknown
}
.mi.init()
= append(, {
: validationTypeGroup,
: .mi,
: ,
})
continue
case flags.ProtoLegacy && .typ == validationTypeMessageSetItem:
, , , := messageset.ConsumeFieldValue(, false)
if != nil {
return , ValidationInvalid
}
, := .resolver.FindExtensionByNumber(..Desc.FullName(), )
switch {
case == protoregistry.NotFound:
= [:]
case != nil:
return , ValidationUnknown
default:
:= getExtensionFieldInfo().validation
if .mi != nil {
.mi.init()
}
= append(, {
: .typ,
: .mi,
: [:],
})
=
continue
}
default:
:= protowire.ConsumeFieldValue(, , )
if < 0 {
return , ValidationInvalid
}
= [:]
}
default:
return , ValidationInvalid
}
}
if . != 0 {
return , ValidationInvalid
}
if len() != 0 {
return , ValidationInvalid
}
= .
:
:= 0
switch . {
case validationTypeMessage, validationTypeGroup:
= int(..numRequiredFields)
case validationTypeMap:
if . != nil && ..numRequiredFields > 0 {
= 1
}
}
if > 0 && bits.OnesCount64(.) != {
= false
}
= [:len()-1]
}
.n = - len()
if {
.initialized = true
}
return , ValidationValid
}