diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go index 8a17302a01..172d210216 100644 --- a/src/cmd/compile/internal/ssa/prove.go +++ b/src/cmd/compile/internal/ssa/prove.go @@ -12,7 +12,7 @@ import ( type branch int const ( - unknown = iota + unknown branch = iota positive negative ) @@ -149,6 +149,14 @@ type limitFact struct { // by the facts table is effective for real code while remaining very // efficient. type factsTable struct { + // unsat is true if facts contains a contradiction. + // + // Note that the factsTable logic is incomplete, so if unsat + // is false, the assertions in factsTable could be satisfiable + // *or* unsatisfiable. + unsat bool // true if facts contains a contradiction + unsatDepth int // number of unsat checkpoints + facts map[pair]relation // current known set of relation stack []fact // previous sets of relations @@ -177,89 +185,6 @@ func newFactsTable() *factsTable { return ft } -// get returns the known possible relations between v and w. -// If v and w are not in the map it returns lt|eq|gt, i.e. any order. -func (ft *factsTable) get(v, w *Value, d domain) relation { - if v.isGenericIntConst() || w.isGenericIntConst() { - reversed := false - if v.isGenericIntConst() { - v, w = w, v - reversed = true - } - r := lt | eq | gt - lim, ok := ft.limits[v.ID] - if !ok { - return r - } - c := w.AuxInt - switch d { - case signed: - switch { - case c < lim.min: - r = gt - case c > lim.max: - r = lt - case c == lim.min && c == lim.max: - r = eq - case c == lim.min: - r = gt | eq - case c == lim.max: - r = lt | eq - } - case unsigned: - // TODO: also use signed data if lim.min >= 0? - var uc uint64 - switch w.Op { - case OpConst64: - uc = uint64(c) - case OpConst32: - uc = uint64(uint32(c)) - case OpConst16: - uc = uint64(uint16(c)) - case OpConst8: - uc = uint64(uint8(c)) - } - switch { - case uc < lim.umin: - r = gt - case uc > lim.umax: - r = lt - case uc == lim.umin && uc == lim.umax: - r = eq - case uc == lim.umin: - r = gt | eq - case uc == lim.umax: - r = lt | eq - } - } - if reversed { - return reverseBits[r] - } - return r - } - - reversed := false - if lessByID(w, v) { - v, w = w, v - reversed = !reversed - } - - p := pair{v, w, d} - r, ok := ft.facts[p] - if !ok { - if p.v == p.w { - r = eq - } else { - r = lt | eq | gt - } - } - - if reversed { - return reverseBits[r] - } - return r -} - // update updates the set of relations between v and w in domain d // restricting it to r. func (ft *factsTable) update(parent *Block, v, w *Value, d domain, r relation) { @@ -269,9 +194,19 @@ func (ft *factsTable) update(parent *Block, v, w *Value, d domain, r relation) { } p := pair{v, w, d} - oldR := ft.get(v, w, d) + oldR, ok := ft.facts[p] + if !ok { + if v == w { + oldR = eq + } else { + oldR = lt | eq | gt + } + } ft.stack = append(ft.stack, fact{p, oldR}) ft.facts[p] = oldR & r + if oldR&r == 0 { + ft.unsat = true + } // Extract bounds when comparing against constants if v.isGenericIntConst() { @@ -350,6 +285,9 @@ func (ft *factsTable) update(parent *Block, v, w *Value, d domain, r relation) { ft.limitStack = append(ft.limitStack, limitFact{v.ID, old}) lim = old.intersect(lim) ft.limits[v.ID] = lim + if lim.min > lim.max || lim.umin > lim.umax { + ft.unsat = true + } if v.Block.Func.pass.debug > 2 { v.Block.Func.Warnl(parent.Pos, "parent=%s, new limits %s %s %s", parent, v, w, lim.String()) } @@ -368,6 +306,9 @@ func (ft *factsTable) isNonNegative(v *Value) bool { // checkpoint saves the current state of known relations. // Called when descending on a branch. func (ft *factsTable) checkpoint() { + if ft.unsat { + ft.unsatDepth++ + } ft.stack = append(ft.stack, checkpointFact) ft.limitStack = append(ft.limitStack, checkpointBound) } @@ -376,6 +317,11 @@ func (ft *factsTable) checkpoint() { // before the previous checkpoint. // Called when backing up on a branch. func (ft *factsTable) restore() { + if ft.unsatDepth > 0 { + ft.unsatDepth-- + } else { + ft.unsat = false + } for { old := ft.stack[len(ft.stack)-1] ft.stack = ft.stack[:len(ft.stack)-1] @@ -505,6 +451,14 @@ var ( // The second comparison i >= len(a) is clearly redundant because if the // else branch of the first comparison is executed, we already know that i < len(a). // The code for the second panic can be removed. +// +// prove works by finding contradictions and trimming branches whose +// conditions are unsatisfiable given the branches leading up to them. +// It tracks a "fact table" of branch conditions. For each branching +// block, it asserts the branch conditions that uniquely dominate that +// block, and then separately asserts the block's branch condition and +// its negation. If either leads to a contradiction, it can trim that +// successor. func prove(f *Func) { ft := newFactsTable() @@ -552,6 +506,15 @@ func prove(f *Func) { sdom := f.sdom() // DFS on the dominator tree. + // + // For efficiency, we consider only the dominator tree rather + // than the entire flow graph. On the way down, we consider + // incoming branches and accumulate conditions that uniquely + // dominate the current block. If we discover a contradiction, + // we can eliminate the entire block and all of its children. + // On the way back up, we consider outgoing branches that + // haven't already been considered. This way we consider each + // branch condition only once. for len(work) > 0 { node := work[len(work)-1] work = work[:len(work)-1] @@ -561,14 +524,16 @@ func prove(f *Func) { switch node.state { case descend: if branch != unknown { - ft.checkpoint() - c := parent.Control - updateRestrictions(parent, ft, boolean, nil, c, lt|gt, branch) - if tr, has := domainRelationTable[parent.Control.Op]; has { - // When we branched from parent we learned a new set of - // restrictions. Update the factsTable accordingly. - updateRestrictions(parent, ft, tr.d, c.Args[0], c.Args[1], tr.r, branch) + if !tryPushBranch(ft, parent, branch) { + // node.block is unreachable. + // Remove it and don't visit + // its children. + removeBranch(parent, branch) + break } + // Otherwise, we can now commit to + // taking this branch. We'll restore + // ft when we unwind. } work = append(work, bp{ @@ -583,18 +548,10 @@ func prove(f *Func) { } case simplify: - succ := simplifyBlock(ft, node.block) - if succ != unknown { - b := node.block - b.Kind = BlockFirst - b.SetControl(nil) - if succ == negative { - b.swapSuccessors() - } - } + simplifyBlock(sdom, ft, node.block) if branch != unknown { - ft.restore() + popBranch(ft) } } } @@ -621,6 +578,38 @@ func getBranch(sdom SparseTree, p *Block, b *Block) branch { return unknown } +// tryPushBranch tests whether it is possible to branch from Block b +// in direction br and, if so, pushes the branch conditions in the +// factsTable and returns true. A successful tryPushBranch must be +// paired with a popBranch. +func tryPushBranch(ft *factsTable, b *Block, br branch) bool { + ft.checkpoint() + c := b.Control + updateRestrictions(b, ft, boolean, nil, c, lt|gt, br) + if tr, has := domainRelationTable[b.Control.Op]; has { + // When we branched from parent we learned a new set of + // restrictions. Update the factsTable accordingly. + updateRestrictions(b, ft, tr.d, c.Args[0], c.Args[1], tr.r, br) + } + if ft.unsat { + // This branch's conditions contradict some known + // fact, so it cannot be taken. Unwind the facts. + // + // (Since we never checkpoint an unsat factsTable, we + // don't really need factsTable.unsatDepth, but + // there's no cost to keeping checkpoint/restore more + // general.) + ft.restore() + return false + } + return true +} + +// popBranch undoes the effects of a successful tryPushBranch. +func popBranch(ft *factsTable) { + ft.restore() +} + // updateRestrictions updates restrictions from the immediate // dominating block (p) using r. r is adjusted according to the branch taken. func updateRestrictions(parent *Block, ft *factsTable, t domain, v, w *Value, r relation, branch branch) { @@ -639,6 +628,31 @@ func updateRestrictions(parent *Block, ft *factsTable, t domain, v, w *Value, r } ft.update(parent, v, w, i, r) + if i == boolean && v == nil && w != nil && (w.Op == OpIsInBounds || w.Op == OpIsSliceInBounds) { + // 0 <= a0 < a1 (or 0 <= a0 <= a1) + // + // domainRelationTable handles the a0 / a1 + // relation, but not the 0 / a0 relation. + // + // On the positive branch we learn 0 <= a0, + // but this turns out never to be useful. + // + // On the negative branch we learn (0 > a0 || + // a0 >= a1) (or (0 > a0 || a0 > a1)). We + // can't express an || condition, but we learn + // something if we can disprove the LHS. + if r == eq && ft.isNonNegative(w.Args[0]) { + // false == w, so we're on the + // negative branch. a0 >= 0, so the + // LHS is false. Thus, the RHS holds. + opr := eq | gt + if w.Op == OpIsSliceInBounds { + opr = gt + } + ft.update(parent, w.Args[0], w.Args[1], signed, opr) + } + } + // Additional facts we know given the relationship between len and cap. if i != signed && i != unsigned { continue @@ -666,9 +680,9 @@ func updateRestrictions(parent *Block, ft *factsTable, t domain, v, w *Value, r } } -// simplifyBlock simplifies block known the restrictions in ft. -// Returns which branch must always be taken. -func simplifyBlock(ft *factsTable, b *Block) branch { +// simplifyBlock simplifies some constant values in b and evaluates +// branches to non-uniquely dominated successors of b. +func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) { // Replace OpSlicemask operations in b with constants where possible. for _, v := range b.Values { if v.Op != OpSlicemask { @@ -709,94 +723,53 @@ func simplifyBlock(ft *factsTable, b *Block) branch { } if b.Kind != BlockIf { - return unknown + return } - // First, checks if the condition itself is redundant. - m := ft.get(nil, b.Control, boolean) - if m == lt|gt { - if b.Func.pass.debug > 0 { - if b.Func.pass.debug > 1 { - b.Func.Warnl(b.Pos, "Proved boolean %s (%s)", b.Control.Op, b.Control) - } else { - b.Func.Warnl(b.Pos, "Proved boolean %s", b.Control.Op) - } - } - return positive - } - if m == eq { - if b.Func.pass.debug > 0 { - if b.Func.pass.debug > 1 { - b.Func.Warnl(b.Pos, "Disproved boolean %s (%s)", b.Control.Op, b.Control) - } else { - b.Func.Warnl(b.Pos, "Disproved boolean %s", b.Control.Op) - } - } - return negative - } - - // Next look check equalities. - c := b.Control - tr, has := domainRelationTable[c.Op] - if !has { - return unknown - } - - a0, a1 := c.Args[0], c.Args[1] - for d := domain(1); d <= tr.d; d <<= 1 { - if d&tr.d == 0 { + // Consider outgoing edges from this block. + parent := b + for i, branch := range [...]branch{positive, negative} { + child := parent.Succs[i].b + if getBranch(sdom, parent, child) != unknown { + // For edges to uniquely dominated blocks, we + // already did this when we visited the child. continue } - - // tr.r represents in which case the positive branch is taken. - // m represents which cases are possible because of previous relations. - // If the set of possible relations m is included in the set of relations - // need to take the positive branch (or negative) then that branch will - // always be taken. - // For shortcut, if m == 0 then this block is dead code. - m := ft.get(a0, a1, d) - if m != 0 && tr.r&m == m { - if b.Func.pass.debug > 0 { - if b.Func.pass.debug > 1 { - b.Func.Warnl(b.Pos, "Proved %s (%s)", c.Op, c) - } else { - b.Func.Warnl(b.Pos, "Proved %s", c.Op) - } - } - return positive + // For edges to other blocks, this can trim a branch + // even if we couldn't get rid of the child itself. + if !tryPushBranch(ft, parent, branch) { + // This branch is impossible, so remove it + // from the block. + removeBranch(parent, branch) + // No point in considering the other branch. + // (It *is* possible for both to be + // unsatisfiable since the fact table is + // incomplete. We could turn this into a + // BlockExit, but it doesn't seem worth it.) + break } - if m != 0 && ((lt|eq|gt)^tr.r)&m == m { - if b.Func.pass.debug > 0 { - if b.Func.pass.debug > 1 { - b.Func.Warnl(b.Pos, "Disproved %s (%s)", c.Op, c) - } else { - b.Func.Warnl(b.Pos, "Disproved %s", c.Op) - } - } - return negative + popBranch(ft) + } +} + +func removeBranch(b *Block, branch branch) { + if b.Func.pass.debug > 0 { + verb := "Proved" + if branch == positive { + verb = "Disproved" + } + c := b.Control + if b.Func.pass.debug > 1 { + b.Func.Warnl(b.Pos, "%s %s (%s)", verb, c.Op, c) + } else { + b.Func.Warnl(b.Pos, "%s %s", verb, c.Op) } } - - // HACK: If the first argument of IsInBounds or IsSliceInBounds - // is a constant and we already know that constant is smaller (or equal) - // to the upper bound than this is proven. Most useful in cases such as: - // if len(a) <= 1 { return } - // do something with a[1] - if (c.Op == OpIsInBounds || c.Op == OpIsSliceInBounds) && ft.isNonNegative(c.Args[0]) { - m := ft.get(a0, a1, signed) - if m != 0 && tr.r&m == m { - if b.Func.pass.debug > 0 { - if b.Func.pass.debug > 1 { - b.Func.Warnl(b.Pos, "Proved non-negative bounds %s (%s)", c.Op, c) - } else { - b.Func.Warnl(b.Pos, "Proved non-negative bounds %s", c.Op) - } - } - return positive - } + b.Kind = BlockFirst + b.SetControl(nil) + if branch == positive { + b.swapSuccessors() } - - return unknown } // isNonNegative returns true is v is known to be greater or equal to zero. diff --git a/test/prove.go b/test/prove.go index e89ab3f8d8..2f4fa5d308 100644 --- a/test/prove.go +++ b/test/prove.go @@ -11,11 +11,11 @@ import "math" func f0(a []int) int { a[0] = 1 - a[0] = 1 // ERROR "Proved boolean IsInBounds$" + a[0] = 1 // ERROR "Proved IsInBounds$" a[6] = 1 - a[6] = 1 // ERROR "Proved boolean IsInBounds$" + a[6] = 1 // ERROR "Proved IsInBounds$" + a[5] = 1 // ERROR "Proved IsInBounds$" a[5] = 1 // ERROR "Proved IsInBounds$" - a[5] = 1 // ERROR "Proved boolean IsInBounds$" return 13 } @@ -23,24 +23,24 @@ func f1(a []int) int { if len(a) <= 5 { return 18 } - a[0] = 1 // ERROR "Proved non-negative bounds IsInBounds$" - a[0] = 1 // ERROR "Proved boolean IsInBounds$" + a[0] = 1 // ERROR "Proved IsInBounds$" + a[0] = 1 // ERROR "Proved IsInBounds$" a[6] = 1 - a[6] = 1 // ERROR "Proved boolean IsInBounds$" + a[6] = 1 // ERROR "Proved IsInBounds$" + a[5] = 1 // ERROR "Proved IsInBounds$" a[5] = 1 // ERROR "Proved IsInBounds$" - a[5] = 1 // ERROR "Proved boolean IsInBounds$" return 26 } func f1b(a []int, i int, j uint) int { if i >= 0 && i < len(a) { - return a[i] // ERROR "Proved non-negative bounds IsInBounds$" + return a[i] // ERROR "Proved IsInBounds$" } if i >= 10 && i < len(a) { - return a[i] // ERROR "Proved non-negative bounds IsInBounds$" + return a[i] // ERROR "Proved IsInBounds$" } if i >= 10 && i < len(a) { - return a[i] // ERROR "Proved non-negative bounds IsInBounds$" + return a[i] // ERROR "Proved IsInBounds$" } if i >= 10 && i < len(a) { // todo: handle this case return a[i-10] @@ -64,7 +64,7 @@ func f1c(a []int, i int64) int { func f2(a []int) int { for i := range a { a[i+1] = i - a[i+1] = i // ERROR "Proved boolean IsInBounds$" + a[i+1] = i // ERROR "Proved IsInBounds$" } return 34 } @@ -84,15 +84,14 @@ func f4a(a, b, c int) int { if a > b { // ERROR "Disproved Greater64$" return 50 } - if a < b { // ERROR "Proved boolean Less64$" + if a < b { // ERROR "Proved Less64$" return 53 } - if a == b { // ERROR "Disproved boolean Eq64$" + // We can't get to this point and prove knows that, so + // there's no message for the next (obvious) branch. + if a != a { return 56 } - if a > b { // ERROR "Disproved boolean Greater64$" - return 59 - } return 61 } return 63 @@ -127,8 +126,8 @@ func f4c(a, b, c int) int { func f4d(a, b, c int) int { if a < b { if a < c { - if a < b { // ERROR "Proved boolean Less64$" - if a < c { // ERROR "Proved boolean Less64$" + if a < b { // ERROR "Proved Less64$" + if a < c { // ERROR "Proved Less64$" return 87 } return 89 @@ -218,8 +217,8 @@ func f6e(a uint8) int { func f7(a []int, b int) int { if b < len(a) { a[b] = 3 - if b < len(a) { // ERROR "Proved boolean Less64$" - a[b] = 5 // ERROR "Proved boolean IsInBounds$" + if b < len(a) { // ERROR "Proved Less64$" + a[b] = 5 // ERROR "Proved IsInBounds$" } } return 161 @@ -242,7 +241,7 @@ func f9(a, b bool) int { if a { return 1 } - if a || b { // ERROR "Disproved boolean Arg$" + if a || b { // ERROR "Disproved Arg$" return 2 } return 3 @@ -260,22 +259,22 @@ func f10(a string) int { func f11a(a []int, i int) { useInt(a[i]) - useInt(a[i]) // ERROR "Proved boolean IsInBounds$" + useInt(a[i]) // ERROR "Proved IsInBounds$" } func f11b(a []int, i int) { useSlice(a[i:]) - useSlice(a[i:]) // ERROR "Proved boolean IsSliceInBounds$" + useSlice(a[i:]) // ERROR "Proved IsSliceInBounds$" } func f11c(a []int, i int) { useSlice(a[:i]) - useSlice(a[:i]) // ERROR "Proved boolean IsSliceInBounds$" + useSlice(a[:i]) // ERROR "Proved IsSliceInBounds$" } func f11d(a []int, i int) { useInt(a[2*i+7]) - useInt(a[2*i+7]) // ERROR "Proved boolean IsInBounds$" + useInt(a[2*i+7]) // ERROR "Proved IsInBounds$" } func f12(a []int, b int) { @@ -305,7 +304,7 @@ func f13a(a, b, c int, x bool) int { } } if x { - if a > 12 { // ERROR "Proved boolean Greater64$" + if a > 12 { // ERROR "Proved Greater64$" return 5 } } @@ -327,7 +326,7 @@ func f13b(a int, x bool) int { } } if x { - if a == -9 { // ERROR "Proved boolean Eq64$" + if a == -9 { // ERROR "Proved Eq64$" return 9 } } @@ -349,7 +348,7 @@ func f13b(a int, x bool) int { func f13c(a int, x bool) int { if a < 90 { if x { - if a < 90 { // ERROR "Proved boolean Less64$" + if a < 90 { // ERROR "Proved Less64$" return 13 } } @@ -450,7 +449,7 @@ func f14(p, q *int, a []int) { j := *q i2 := *p useInt(a[i1+j]) - useInt(a[i2+j]) // ERROR "Proved boolean IsInBounds$" + useInt(a[i2+j]) // ERROR "Proved IsInBounds$" } func f15(s []int, x int) { @@ -460,11 +459,32 @@ func f15(s []int, x int) { func f16(s []int) []int { if len(s) >= 10 { - return s[:10] // ERROR "Proved non-negative bounds IsSliceInBounds$" + return s[:10] // ERROR "Proved IsSliceInBounds$" } return nil } +func f17(b []int) { + for i := 0; i < len(b); i++ { + useSlice(b[i:]) // Learns i <= len + // This tests for i <= cap, which we can only prove + // using the derived relation between len and cap. + // This depends on finding the contradiction, since we + // don't query this condition directly. + useSlice(b[:i]) // ERROR "Proved IsSliceInBounds$" + } +} + +func sm1(b []int, x int) { + // Test constant argument to slicemask. + useSlice(b[2:8]) // ERROR "Proved slicemask not needed$" + // Test non-constant argument with known limits. + // Right now prove only uses the unsigned limit. + if uint(cap(b)) > 10 { + useSlice(b[2:]) // ERROR "Proved slicemask not needed$" + } +} + //go:noinline func useInt(a int) { } diff --git a/test/sliceopt.go b/test/sliceopt.go index eb24701f31..b8b947229c 100644 --- a/test/sliceopt.go +++ b/test/sliceopt.go @@ -1,4 +1,4 @@ -// errorcheck -0 -d=append,slice,ssa/prove/debug=1 +// errorcheck -0 -d=append,slice // Copyright 2015 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style @@ -21,51 +21,12 @@ func a3(x *[]int, y int) { *x = append(*x, y) // ERROR "append: len-only update$" } -// s1_if_false_then_anything -func s1_if_false_then_anything(x **[]int, xs **string, i, j int) { - z := (**x)[0:i] - z = z[i : i+1] - println(z) // if we get here, then we have proven that i==i+1 (this cannot happen, but the program is still being analyzed...) - - zs := (**xs)[0:i] // since i=i+1 is proven, i+1 is "in bounds", ha-ha - zs = zs[i : i+1] // ERROR "Proved boolean IsSliceInBounds$" - println(zs) -} - func s1(x **[]int, xs **string, i, j int) { var z []int - z = (**x)[2:] - z = (**x)[2:len(**x)] // ERROR "Proved boolean IsSliceInBounds$" - z = (**x)[2:cap(**x)] // ERROR "Proved IsSliceInBounds$" - z = (**x)[i:i] // -ERROR "Proved IsSliceInBounds" - z = (**x)[1:i:i] // ERROR "Proved boolean IsSliceInBounds$" - z = (**x)[i:j:0] - z = (**x)[i:0:j] // ERROR "Disproved IsSliceInBounds$" - z = (**x)[0:i:j] // ERROR "Proved boolean IsSliceInBounds$" - z = (**x)[0:] // ERROR "slice: omit slice operation$" - z = (**x)[2:8] // ERROR "Proved slicemask not needed$" - println(z) - z = (**x)[2:2] - z = (**x)[0:i] - z = (**x)[2:i:8] // ERROR "Disproved IsSliceInBounds$" "Proved IsSliceInBounds$" - z = (**x)[i:2:i] // ERROR "Proved IsSliceInBounds$" "Proved boolean IsSliceInBounds$" - - z = z[0:i] // ERROR "Proved boolean IsSliceInBounds" - z = z[0:i : i+1] - z = z[i : i+1] // ERROR "Proved boolean IsSliceInBounds$" - + z = (**x)[0:] // ERROR "slice: omit slice operation$" println(z) var zs string - zs = (**xs)[2:] - zs = (**xs)[2:len(**xs)] // ERROR "Proved IsSliceInBounds$" "Proved boolean IsSliceInBounds$" - zs = (**xs)[i:i] // -ERROR "Proved boolean IsSliceInBounds" - zs = (**xs)[0:] // ERROR "slice: omit slice operation$" - zs = (**xs)[2:8] - zs = (**xs)[2:2] // ERROR "Proved boolean IsSliceInBounds$" - zs = (**xs)[0:i] // ERROR "Proved boolean IsSliceInBounds$" - - zs = zs[0:i] // See s1_if_false_then_anything above to explain the counterfactual bounds check result below - zs = zs[i : i+1] // ERROR "Proved boolean IsSliceInBounds$" + zs = (**xs)[0:] // ERROR "slice: omit slice operation$" println(zs) }