diff --git a/src/cmd/compile/internal/compare/compare.go b/src/cmd/compile/internal/compare/compare.go index 0e78013cf3..1674065556 100644 --- a/src/cmd/compile/internal/compare/compare.go +++ b/src/cmd/compile/internal/compare/compare.go @@ -166,7 +166,10 @@ func calculateCostForType(t *types.Type) int64 { // It works by building a list of boolean conditions to satisfy. // Conditions must be evaluated in the returned order and // properly short-circuited by the caller. -func EqStruct(t *types.Type, np, nq ir.Node) []ir.Node { +// The first return value is the flattened list of conditions, +// the second value is a boolean indicating whether any of the +// comparisons could panic. +func EqStruct(t *types.Type, np, nq ir.Node) ([]ir.Node, bool) { // The conditions are a list-of-lists. Conditions are reorderable // within each inner list. The outer lists must be evaluated in order. var conds [][]ir.Node @@ -187,9 +190,11 @@ func EqStruct(t *types.Type, np, nq ir.Node) []ir.Node { continue } + typeCanPanic := EqCanPanic(f.Type) + // Compare non-memory fields with field equality. if !IsRegularMemory(f.Type) { - if EqCanPanic(f.Type) { + if typeCanPanic { // Enforce ordering by starting a new set of reorderable conditions. conds = append(conds, []ir.Node{}) } @@ -203,7 +208,7 @@ func EqStruct(t *types.Type, np, nq ir.Node) []ir.Node { default: and(ir.NewBinaryExpr(base.Pos, ir.OEQ, p, q)) } - if EqCanPanic(f.Type) { + if typeCanPanic { // Also enforce ordering after something that can panic. conds = append(conds, []ir.Node{}) } @@ -238,7 +243,7 @@ func EqStruct(t *types.Type, np, nq ir.Node) []ir.Node { }) flatConds = append(flatConds, c...) } - return flatConds + return flatConds, len(conds) > 1 } // EqString returns the nodes diff --git a/src/cmd/compile/internal/reflectdata/alg.go b/src/cmd/compile/internal/reflectdata/alg.go index 10240b2f1f..69de685ca0 100644 --- a/src/cmd/compile/internal/reflectdata/alg.go +++ b/src/cmd/compile/internal/reflectdata/alg.go @@ -14,6 +14,7 @@ import ( "cmd/compile/internal/typecheck" "cmd/compile/internal/types" "cmd/internal/obj" + "cmd/internal/src" ) // AlgType returns the fixed-width AMEMxx variants instead of the general @@ -507,7 +508,66 @@ func eqFunc(t *types.Type) *ir.Func { // p[i] == q[i] return ir.NewBinaryExpr(base.Pos, ir.OEQ, pi, qi) }) - // TODO: pick apart structs, do them piecemeal too + case types.TSTRUCT: + isCall := func(n ir.Node) bool { + return n.Op() == ir.OCALL || n.Op() == ir.OCALLFUNC + } + var expr ir.Node + var hasCallExprs bool + allCallExprs := true + and := func(cond ir.Node) { + if expr == nil { + expr = cond + } else { + expr = ir.NewLogicalExpr(base.Pos, ir.OANDAND, expr, cond) + } + } + + var tmpPos src.XPos + pi := ir.NewIndexExpr(tmpPos, np, ir.NewInt(tmpPos, 0)) + pi.SetBounded(true) + pi.SetType(t.Elem()) + qi := ir.NewIndexExpr(tmpPos, nq, ir.NewInt(tmpPos, 0)) + qi.SetBounded(true) + qi.SetType(t.Elem()) + flatConds, canPanic := compare.EqStruct(t.Elem(), pi, qi) + for _, c := range flatConds { + if isCall(c) { + hasCallExprs = true + } else { + allCallExprs = false + } + } + if !hasCallExprs || allCallExprs || canPanic { + checkAll(1, true, func(pi, qi ir.Node) ir.Node { + // p[i] == q[i] + return ir.NewBinaryExpr(base.Pos, ir.OEQ, pi, qi) + }) + } else { + checkAll(4, false, func(pi, qi ir.Node) ir.Node { + expr = nil + flatConds, _ := compare.EqStruct(t.Elem(), pi, qi) + if len(flatConds) == 0 { + return ir.NewBool(base.Pos, true) + } + for _, c := range flatConds { + if !isCall(c) { + and(c) + } + } + return expr + }) + checkAll(2, true, func(pi, qi ir.Node) ir.Node { + expr = nil + flatConds, _ := compare.EqStruct(t.Elem(), pi, qi) + for _, c := range flatConds { + if isCall(c) { + and(c) + } + } + return expr + }) + } default: checkAll(1, true, func(pi, qi ir.Node) ir.Node { // p[i] == q[i] @@ -516,7 +576,7 @@ func eqFunc(t *types.Type) *ir.Func { } case types.TSTRUCT: - flatConds := compare.EqStruct(t, np, nq) + flatConds, _ := compare.EqStruct(t, np, nq) if len(flatConds) == 0 { fn.Body.Append(ir.NewAssignStmt(base.Pos, nr, ir.NewBool(base.Pos, true))) } else { diff --git a/src/cmd/compile/internal/reflectdata/alg_test.go b/src/cmd/compile/internal/reflectdata/alg_test.go index a1fc8c590c..38fb974f61 100644 --- a/src/cmd/compile/internal/reflectdata/alg_test.go +++ b/src/cmd/compile/internal/reflectdata/alg_test.go @@ -4,7 +4,9 @@ package reflectdata_test -import "testing" +import ( + "testing" +) func BenchmarkEqArrayOfStrings5(b *testing.B) { var a [5]string @@ -75,6 +77,56 @@ func BenchmarkEqArrayOfFloats1024(b *testing.B) { } } +func BenchmarkEqArrayOfStructsEq(b *testing.B) { + type T2 struct { + a string + b int + } + const size = 1024 + var ( + str1 = "foobar" + + a [size]T2 + c [size]T2 + ) + + for i := 0; i < size; i++ { + a[i].a = str1 + c[i].a = str1 + } + + b.ResetTimer() + for j := 0; j < b.N; j++ { + _ = a == c + } +} + +func BenchmarkEqArrayOfStructsNotEq(b *testing.B) { + type T2 struct { + a string + b int + } + const size = 1024 + var ( + str1 = "foobar" + str2 = "foobarz" + + a [size]T2 + c [size]T2 + ) + + for i := 0; i < size; i++ { + a[i].a = str1 + c[i].a = str1 + } + c[len(c)-1].a = str2 + + b.ResetTimer() + for j := 0; j < b.N; j++ { + _ = a == c + } +} + const size = 16 type T1 struct { diff --git a/src/cmd/compile/internal/walk/compare.go b/src/cmd/compile/internal/walk/compare.go index 58d6b57496..625cfecee0 100644 --- a/src/cmd/compile/internal/walk/compare.go +++ b/src/cmd/compile/internal/walk/compare.go @@ -228,7 +228,7 @@ func walkCompare(n *ir.BinaryExpr, init *ir.Nodes) ir.Node { cmpl = safeExpr(cmpl, init) cmpr = safeExpr(cmpr, init) if t.IsStruct() { - conds := compare.EqStruct(t, cmpl, cmpr) + conds, _ := compare.EqStruct(t, cmpl, cmpr) if n.Op() == ir.OEQ { for _, cond := range conds { and(cond) diff --git a/test/fixedbugs/issue8606.go b/test/fixedbugs/issue8606.go index 8c85069695..6bac02a1da 100644 --- a/test/fixedbugs/issue8606.go +++ b/test/fixedbugs/issue8606.go @@ -30,7 +30,17 @@ func main() { s string j interface{} } + type S3 struct { + f any + i int + } + type S4 struct { + a [1000]byte + b any + } b := []byte{1} + s1 := S3{func() {}, 0} + s2 := S3{func() {}, 1} for _, test := range []struct { panic bool @@ -64,6 +74,9 @@ func main() { {false, T3{s: "foo", j: b}, T3{s: "bar", j: b}}, {true, T3{i: b, s: "fooz"}, T3{i: b, s: "bar"}}, {false, T3{s: "fooz", j: b}, T3{s: "bar", j: b}}, + {true, A{s1, s2}, A{s2, s1}}, + {true, s1, s2}, + {false, S4{[1000]byte{0}, func() {}}, S4{[1000]byte{1}, func() {}}}, } { f := func() { defer func() {