Use go-cmp.Equal instead of reflect.DeepEqual by posener · Pull Request #579 · stretchr/testify · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Use go-cmp.Equal instead of reflect.DeepEqual #579

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 9 additions & 74 deletions assert/assertions.go
8000
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"unicode/utf8"

"github.com/davecgh/go-spew/spew"
"github.com/pmezard/go-difflib/difflib"
yaml "gopkg.in/yaml.v2"
)

Expand Down Expand Up @@ -59,20 +58,7 @@ func ObjectsAreEqual(expected, actual interface{}) bool {
if expected == nil || actual == nil {
return expected == actual
}

exp, ok := expected.([]byte)
if !ok {
return reflect.DeepEqual(expected, actual)
}

act, ok := actual.([]byte)
if !ok {
return false
}
if exp == nil || act == nil {
return exp == nil && act == nil
}
return bytes.Equal(exp, act)
return cmpEqual(expected, actual)
}

// ObjectsAreEqualValues gets whether two objects are equal, or if their
Expand All @@ -89,7 +75,7 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool {
expectedValue := reflect.ValueOf(expected)
if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) {
// Attempt comparison after type conversion
return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual)
return cmpEqual(expectedValue.Convert(actualType).Interface(), actual)
}

return false
Expand All @@ -110,7 +96,7 @@ func CallerInfo() []string {
var line int
var name string

callers := []string{}
var callers []string
for i := 0; ; i++ {
pc, file, line, ok = runtime.Caller(i)
if !ok {
Expand Down Expand Up @@ -341,7 +327,7 @@ func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{})
}

if !ObjectsAreEqual(expected, actual) {
diff := diff(expected, actual)
diff := cmpDiff(expected, actual)
expected, actual = formatUnequalValues(expected, actual)
return Fail(t, fmt.Sprintf("Not equal: \n"+
"expected: %s\n"+
Expand Down Expand Up @@ -462,7 +448,7 @@ func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interfa
}

if !ObjectsAreEqualValues(expected, actual) {
diff := diff(expected, actual)
diff := cmpDiff(expected, actual)
expected, actual = formatUnequalValues(expected, actual)
return Fail(t, fmt.Sprintf("Not equal: \n"+
"expected: %s\n"+
Expand Down Expand Up @@ -575,7 +561,7 @@ func isEmpty(object interface{}) bool {
// for all other types, compare against the zero value
default:
zero := reflect.Zero(objValue.Type())
return reflect.DeepEqual(object, zero.Interface())
return cmpEqual(object, zero.Interface())
}
}

Expand Down Expand Up @@ -1369,7 +1355,7 @@ func matchRegexp(rx interface{}, str interface{}) bool {
r = regexp.MustCompile(fmt.Sprint(rx))
}

return (r.FindStringIndex(fmt.Sprint(str)) != nil)
return r.FindStringIndex(fmt.Sprint(str)) != nil

}

Expand Down Expand Up @@ -1414,7 +1400,7 @@ func Zero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if i != nil && !reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) {
if i != nil && !cmpEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) {
return Fail(t, fmt.Sprintf("Should be zero, but was %v", i), msgAndArgs...)
}
return true
Expand All @@ -1425,7 +1411,7 @@ func NotZero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if i == nil || reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) {
if i == nil || cmpEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) {
return Fail(t, fmt.Sprintf("Should not be zero, but was %v", i), msgAndArgs...)
}
return true
Expand Down Expand Up @@ -1542,57 +1528,6 @@ func YAMLEq(t TestingT, expected string, actual string, msgAndArgs ...interface{
return Equal(t, expectedYAMLAsInterface, actualYAMLAsInterface, msgAndArgs...)
}

// typeAndKind returns the dynamic type of v and that type's kind, looking
// through one level of pointer indirection: for a *T it reports T and T's
// kind. A nil interface value yields (nil, reflect.Invalid) instead of
// panicking.
func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) {
	t := reflect.TypeOf(v)
	if t == nil {
		// reflect.TypeOf(nil) is a nil Type; calling Kind on it would panic.
		return nil, reflect.Invalid
	}

	k := t.Kind()
	if k == reflect.Ptr {
		t = t.Elem()
		k = t.Kind()
	}
	return t, k
}

// diff returns a unified diff of expected and actual, prefixed with
// "\n\nDiff:\n", when both values have the same type and that type is a
// struct, map, slice, array or string. One level of pointer indirection is
// looked through when comparing types, so a *T is treated like a T.
// In every other case, including when either value is nil, it returns an
// empty string.
func diff(expected interface{}, actual interface{}) string {
	if expected == nil || actual == nil {
		return ""
	}

	et, ek := typeAndKind(expected)
	at, _ := typeAndKind(actual)

	if et != at {
		return ""
	}

	if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array && ek != reflect.String {
		return ""
	}

	var e, a string
	// Diff the raw text only when both values really are plain strings.
	// Checking et is not enough: typeAndKind reports string for *string as
	// well, and reflect.Value.String on a pointer returns a placeholder such
	// as "<*string Value>" instead of the pointed-to text, which would hide
	// the difference. Everything else is rendered with spew.
	stringType := reflect.TypeOf("")
	if reflect.TypeOf(expected) == stringType && reflect.TypeOf(actual) == stringType {
		e = reflect.ValueOf(expected).String()
		a = reflect.ValueOf(actual).String()
	} else {
		e = spewConfig.Sdump(expected)
		a = spewConfig.Sdump(actual)
	}

	// NOTE(review): the error is discarded, as before; presumably rendering
	// into difflib's in-memory result cannot fail — confirm against difflib.
	unified, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
		A:        difflib.SplitLines(e),
		B:        difflib.SplitLines(a),
		FromFile: "Expected",
		FromDate: "",
		ToFile:   "Actual",
		ToDate:   "",
		Context:  1,
	})

	return "\n\nDiff:\n" + unified
}

func isFunction(arg interface{}) bool {
if arg == nil {
return false
Expand Down
81 changes: 36 additions & 45 deletions assert/assertions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1419,7 +1419,7 @@ func TestInDeltaMapValues(t *testing.T) {
f: False,
},
} {
tc.f(t, InDeltaMapValues(mockT, tc.expect, tc.actual, tc.delta), tc.title+"\n"+diff(tc.expect, tc.actual))
tc.f(t, InDeltaMapValues(mockT, tc.expect, tc.actual, tc.delta), tc.title+"\n"+cmpDiff(tc.expect, tc.actual))
}
}

Expand Down Expand Up @@ -1855,13 +1855,11 @@ func TestDiff(t *testing.T) {
Diff:
--- Expected
+++ Actual
@@ -1,3 +1,3 @@
(struct { foo string }) {
- foo: (string) (len=5) "hello"
+ foo: (string) (len=3) "bar"
}
root.foo:
-: "hello"
+: "bar"
`
actual := diff(
actual := cmpDiff(
struct{ foo string }{"hello"},
struct{ foo string }{"bar"},
)
Expand All @@ -1872,16 +1870,11 @@ Diff:
Diff:
--- Expected
+++ Actual
@@ -2,5 +2,5 @@
(int) 1,
- (int) 2,
(int) 3,
- (int) 4
+ (int) 5,
+ (int) 7
}
{[]int}:
-: []int{1, 2, 3, 4}
+: []int{1, 3, 5, 7}
`
actual = diff(
actual = cmpDiff(
[]int{1, 2, 3, 4},
[]int{1, 3, 5, 7},
)
Expand All @@ -1892,15 +1885,11 @@ Diff:
Diff:
--- Expected
+++ Actual
@@ -2,4 +2,4 @@
(int) 1,
- (int) 2,
- (int) 3
+ (int) 3,
+ (int) 5
}
{[]int}:
-: []int{1, 2, 3}
+: []int{1, 3, 5}
`
actual = diff(
actual = cmpDiff(
[]int{1, 2, 3, 4}[0:3],
[]int{1, 3, 5, 7}[0:3],
)
Expand All @@ -1911,19 +1900,21 @@ Diff:
Diff:
--- Expected
+++ Actual
@@ -1,6 +1,6 @@
(map[string]int) (len=4) {
- (string) (len=4) "four": (int) 4,
+ (string) (len=4) "five": (int) 5,
(string) (len=3) "one": (int) 1,
- (string) (len=5) "three": (int) 3,
- (string) (len=3) "two": (int) 2
+ (string) (len=5) "seven": (int) 7,
+ (string) (len=5) "three": (int) 3
}
{map[string]int}["five"]:
-: <non-existent>
+: 5
{map[string]int}["four"]:
-: 4
+: <non-existent>
{map[string]int}["seven"]:
-: <non-existent>
+: 7
{map[string]int}["two"]:
-: 2
+: <non-existent>
`

actual = diff(
actual = cmpDiff(
map[string]int{"one": 1, "two": 2, "three": 3, "four": 4},
map[string]int{"one": 1, "three": 3, "five": 5, "seven": 7},
)
Expand All @@ -1941,7 +1932,7 @@ Diff:
})
`

actual = diff(
actual = cmpDiff(
errors.New("some expected error"),
errors.New("actual error"),
)
Expand All @@ -1959,7 +1950,7 @@ Diff:
}
`

actual = diff(
actual = cmpDiff(
diffTestingStruct{A: "some string", B: 10},
diffTestingStruct{A: "some string", B: 15},
)
Expand All @@ -1976,12 +1967,12 @@ func TestTimeEqualityErrorFormatting(t *testing.T) {
}

// TestDiffEmptyCases checks that the diff helpers produce an empty string
// when either operand is nil, when the operands are ints, and when the
// operands have different types ([]int vs []bool).
//
// NOTE(review): this view appears to interleave the PR's removed diff(...)
// assertions with the added cmpDiff(...) ones; the (1, 2) assertion is
// repeated in both forms and looks like an accidental duplicate.
func TestDiffEmptyCases(t *testing.T) {
Equal(t, "", diff(nil, nil))
Equal(t, "", diff(struct{ foo string }{}, nil))
Equal(t, "", diff(nil, struct{ foo string }{}))
Equal(t, "", diff(1, 2))
Equal(t, "", diff(1, 2))
Equal(t, "", diff([]int{1}, []bool{true}))
Equal(t, "", cmpDiff(nil, nil))
Equal(t, "", cmpDiff(struct{ foo string }{}, nil))
Equal(t, "", cmpDiff(nil, struct{ foo string }{}))
Equal(t, "", cmpDiff(1, 2))
Equal(t, "", cmpDiff(1, 2))
Equal(t, "", cmpDiff([]int{1}, []bool{true}))
}

// Ensure there are no data races
Expand All @@ -2007,7 +1998,7 @@ func TestDiffRace(t *testing.T) {
rChans[idx] = make(chan string)
go func(ch chan string) {
defer close(ch)
ch <- diff(expected, actual)
ch <- cmpDiff(expected, actual)
}(rChans[idx])
}

Expand Down Expand Up @@ -2064,7 +2055,7 @@ func TestBytesEqual(t *testing.T) {
{nil, make([]byte, 0)},
}
for i, c := range cases {
Equal(t, reflect.DeepEqual(c.a, c.b), ObjectsAreEqual(c.a, c.b), "case %d failed", i+1)
Equal(t, cmpEqual(c.a, c.b), ObjectsAreEqual(c.a, c.b), "case %d failed", i+1)
}
}

Expand Down
Loading
0