package cmp
import (
	"errors"
	"fmt"
	"reflect"
	"regexp"
	"strings"
	// NOTE(review): this file also references packages named cmp (Diff, Option)
	// and format (UnifiedDiff, DiffConfig); their import paths are not visible
	// here and must be restored separately.
)
// Comparison is a function that takes no arguments and returns a Result.
// The constructors in this file capture the values to compare in a closure and
// defer the actual check until the Comparison is called.
type Comparison func() Result
func (, interface{}, ...cmp.Option) Comparison {
return func() ( Result) {
defer func() {
if , := handleCmpPanic(recover()); {
= ResultFailure()
}
}()
:= cmp.Diff(, , ...)
if == "" {
return ResultSuccess
}
return multiLineDiffResult(, , )
}
}
// handleCmpPanic inspects a value returned by recover().
//
// It returns ("", false) when r is nil (no panic occurred). When r is a string
// starting with "cannot handle unexported field" it returns that message and
// true so the caller can report it as a failure. Any other value, including
// non-string values, is re-panicked unchanged.
func handleCmpPanic(r interface{}) (string, bool) {
	if r == nil {
		return "", false
	}
	panicmsg, ok := r.(string)
	if !ok {
		panic(r)
	}
	if strings.HasPrefix(panicmsg, "cannot handle unexported field") {
		return panicmsg, true
	}
	// Not a panic this package knows how to report; propagate it.
	panic(r)
}
func ( bool, string) Result {
if {
return ResultSuccess
}
return ResultFailure()
}
// RegexOrPattern is an empty interface. The comparison in this file that
// consumes it accepts either a *regexp.Regexp or a string to be compiled with
// regexp.Compile, and reports a failure for any other dynamic type.
type RegexOrPattern interface{}
func ( RegexOrPattern, string) Comparison {
:= func( *regexp.Regexp) Result {
return toResult(
.MatchString(),
fmt.Sprintf("value %q does not match regexp %q", , .String()))
}
return func() Result {
switch regex := .(type) {
case *regexp.Regexp:
return ()
case string:
, := regexp.Compile()
if != nil {
return ResultFailure(.Error())
}
return ()
default:
return ResultFailure(fmt.Sprintf("invalid type %T for regex pattern", ))
}
}
}
func (, interface{}) Comparison {
return func() Result {
switch {
case == :
return ResultSuccess
case isMultiLineStringCompare(, ):
:= format.UnifiedDiff(format.DiffConfig{A: .(string), B: .(string)})
return multiLineDiffResult(, , )
}
return ResultFailureTemplate(`
{{- printf "%v" .Data.x}} (
{{- with callArg 0 }}{{ formatNode . }} {{end -}}
{{- printf "%T" .Data.x -}}
) != {{ printf "%v" .Data.y}} (
{{- with callArg 1 }}{{ formatNode . }} {{end -}}
{{- printf "%T" .Data.y -}}
)`,
map[string]interface{}{"x": , "y": })
}
}
// isMultiLineStringCompare reports whether x and y are both strings and at
// least one of them contains a newline.
func isMultiLineStringCompare(x, y interface{}) bool {
	strX, ok := x.(string)
	if !ok {
		return false
	}
	strY, ok := y.(string)
	if !ok {
		return false
	}
	return strings.Contains(strX, "\n") || strings.Contains(strY, "\n")
}
func ( string, , interface{}) Result {
return ResultFailureTemplate(`
--- {{ with callArg 0 }}{{ formatNode . }}{{else}}←{{end}}
+++ {{ with callArg 1 }}{{ formatNode . }}{{else}}→{{end}}
{{ .Data.diff }}`,
map[string]interface{}{"diff": , "x": , "y": })
}
func ( interface{}, int) Comparison {
return func() ( Result) {
defer func() {
if := recover(); != nil {
= ResultFailure(fmt.Sprintf("type %T does not have a length", ))
}
}()
:= reflect.ValueOf()
:= .Len()
if == {
return ResultSuccess
}
:= fmt.Sprintf("expected %s (length %d) to have length %d", , , )
return ResultFailure()
}
}
func ( interface{}, interface{}) Comparison {
return func() Result {
:= reflect.ValueOf()
if !.IsValid() {
return ResultFailure("nil does not contain items")
}
:= fmt.Sprintf("%v does not contain %v", , )
:= reflect.ValueOf()
switch .Type().Kind() {
case reflect.String:
if .Type().Kind() != reflect.String {
return ResultFailure("string may only contain strings")
}
return toResult(
strings.Contains(.String(), .String()),
fmt.Sprintf("string %q does not contain %q", , ))
case reflect.Map:
if .Type() != .Type().Key() {
return ResultFailure(fmt.Sprintf(
"%v can not contain a %v key", .Type(), .Type()))
}
return toResult(.MapIndex().IsValid(), )
case reflect.Slice, reflect.Array:
for := 0; < .Len(); ++ {
if reflect.DeepEqual(.Index().Interface(), ) {
return ResultSuccess
}
}
return ResultFailure()
default:
return ResultFailure(fmt.Sprintf("type %T does not contain items", ))
}
}
}
func ( func()) Comparison {
return func() ( Result) {
defer func() {
if := recover(); != nil {
= ResultSuccess
}
}()
()
return ResultFailure("did not panic")
}
}
func ( error, string) Comparison {
return func() Result {
switch {
case == nil:
return ResultFailure("expected an error, got nil")
case .Error() != :
return ResultFailure(fmt.Sprintf(
"expected error %q, got %s", , formatErrorMessage()))
}
return ResultSuccess
}
}
func ( error, string) Comparison {
return func() Result {
switch {
case == nil:
return ResultFailure("expected an error, got nil")
case !strings.Contains(.Error(), ):
return ResultFailure(fmt.Sprintf(
"expected error to contain %q, got %s", , formatErrorMessage()))
}
return ResultSuccess
}
}
// causer is satisfied by errors that expose an underlying error through a
// Cause method. In this file it is only used as a type check when formatting
// an error for a failure message.
type causer interface {
Cause() error
}
func ( error) string {
if , := .(causer); {
return fmt.Sprintf("%q\n%+v", , )
}
return fmt.Sprintf("%q", )
}
func ( interface{}) Comparison {
:= func( reflect.Value) string {
return fmt.Sprintf("%v (type %s) is not nil", reflect.Indirect(), .Type())
}
return isNil(, )
}
func ( interface{}, func(reflect.Value) string) Comparison {
return func() Result {
if == nil {
return ResultSuccess
}
:= reflect.ValueOf()
:= .Type().Kind()
if >= reflect.Chan && <= reflect.Slice {
if .IsNil() {
return ResultSuccess
}
return ResultFailure(())
}
return ResultFailure(fmt.Sprintf("%v (type %s) can not be nil", , .Type()))
}
}
func ( error, interface{}) Comparison {
return func() Result {
switch expectedType := .(type) {
case func(error) bool:
return cmpErrorTypeFunc(, )
case reflect.Type:
if .Kind() == reflect.Interface {
return cmpErrorTypeImplementsType(, )
}
return cmpErrorTypeEqualType(, )
case nil:
return ResultFailure("invalid type for expected: nil")
}
:= reflect.TypeOf()
switch {
case .Kind() == reflect.Struct, isPtrToStruct():
return cmpErrorTypeEqualType(, )
case isPtrToInterface():
return cmpErrorTypeImplementsType(, .Elem())
}
return ResultFailure(fmt.Sprintf("invalid type for expected: %T", ))
}
}
func ( error, func(error) bool) Result {
if () {
return ResultSuccess
}
:= "nil"
if != nil {
= fmt.Sprintf("%s (%T)", , )
}
return ResultFailureTemplate(`error is {{ .Data.actual }}
{{- with callArg 1 }}, not {{ formatNode . }}{{end -}}`,
map[string]interface{}{"actual": })
}
func ( error, reflect.Type) Result {
if == nil {
return ResultFailure(fmt.Sprintf("error is nil, not %s", ))
}
:= reflect.ValueOf()
if .Type() == {
return ResultSuccess
}
return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", , , ))
}
func ( error, reflect.Type) Result {
if == nil {
return ResultFailure(fmt.Sprintf("error is nil, not %s", ))
}
:= reflect.ValueOf()
if .Type().Implements() {
return ResultSuccess
}
return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", , , ))
}
// isPtrToInterface reports whether typ is a pointer whose element type is an
// interface, e.g. the type of (*error)(nil).
func isPtrToInterface(typ reflect.Type) bool {
	return typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Interface
}
// isPtrToStruct reports whether typ is a pointer whose element type is a
// struct.
func isPtrToStruct(typ reflect.Type) bool {
	return typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Struct
}
// Reflect types of the error values produced by the standard library
// constructors.
// NOTE(review): presumably consulted by the notStdlibErrorType template
// function referenced in this file's failure templates; its definition is not
// visible here — confirm.
var (
// stdlibErrorNewType is the dynamic type of the value returned by errors.New.
stdlibErrorNewType = reflect.TypeOf(errors.New(""))
// stdlibFmtErrorType is the dynamic type of the value returned by fmt.Errorf
// when its format wraps an error with %w.
stdlibFmtErrorType = reflect.TypeOf(fmt.Errorf("%w", fmt.Errorf("")))
)
func ( error, error) Comparison {
return func() Result {
if errors.Is(, ) {
return ResultSuccess
}
return ResultFailureTemplate(`error is
{{- if not .Data.a }} nil,{{ else }}
{{- printf " \"%v\"" .Data.a }}
{{- if notStdlibErrorType .Data.a }} ({{ printf "%T" .Data.a }}){{ end }},
{{- end }} not {{ printf "\"%v\"" .Data.x }} (
{{- with callArg 1 }}{{ formatNode . }}{{ end }}
{{- if notStdlibErrorType .Data.x }}{{ printf " %T" .Data.x }}{{ end }})`,
map[string]interface{}{"a": , "x": })
}
}