diff --git a/container/intsets/sparse.go b/container/intsets/sparse.go index 8847febf1d..abe172d34f 100644 --- a/container/intsets/sparse.go +++ b/container/intsets/sparse.go @@ -21,10 +21,6 @@ package intsets // import "golang.org/x/tools/container/intsets" // The space usage would be proportional to Max(), not Len(), and the // implementation would be based upon big.Int. // -// TODO(adonovan): experiment with making the root block indirect (nil -// iff IsEmpty). This would reduce the memory usage when empty and -// might simplify the aliasing invariants. -// // TODO(adonovan): opt: make UnionWith and Difference faster. // These are the hot-spots for go/pointer. @@ -45,9 +41,10 @@ type Sparse struct { // An uninitialized Sparse represents an empty set. // An empty set may also be represented by // root.next == root.prev == &root. - // In a non-empty set, root.next points to the first block and - // root.prev to the last. - // root.offset and root.bits are unused. + // + // The root is always the block with the smallest offset. + // It can be empty, but only if it is the only block; in that case, offset is + // MaxInt (which is not a valid offset). root block } @@ -144,7 +141,6 @@ func (b *block) len() int { // max returns the maximum element of the block. // The block must not be empty. -// func (b *block) max() int { bi := b.offset + bitsPerBlock // Decrement bi by number of high zeros in last.bits. @@ -161,7 +157,6 @@ func (b *block) max() int { // and also removes it if take is set. // The block must not be initially empty. // NB: may leave the block empty. -// func (b *block) min(take bool) int { for i, w := range b.bits { if w != 0 { @@ -204,14 +199,20 @@ func offsetAndBitIndex(x int) (int, uint) { // -- Sparse -------------------------------------------------------------- -// start returns the root's next block, which is the root block -// (if s.IsEmpty()) or the first true block otherwise. 
-// start has the side effect of ensuring that s is properly -// initialized. -// -func (s *Sparse) start() *block { +// none is a shared, empty, sentinel block that indicates the end of a block +// list. +var none block + +// Dummy type used to generate an implicit panic. This must be defined at the +// package level; if it is defined inside a function, it prevents the inlining +// of that function. +type to_copy_a_sparse_you_must_call_its_Copy_method struct{} + +// init ensures s is properly initialized. +func (s *Sparse) init() { root := &s.root if root.next == nil { + root.offset = MaxInt root.next = root root.prev = root } else if root.next.prev != root { @@ -219,21 +220,45 @@ func (s *Sparse) start() *block { // new Sparse y shares the old linked list, but iteration // on y will never encounter &y.root so it goes into a // loop. Fail fast before this occurs. - panic("A Sparse has been copied without (*Sparse).Copy()") + // We don't want to call panic here because it prevents the + // inlining of this function. + _ = (interface{}(nil)).(to_copy_a_sparse_you_must_call_its_Copy_method) } +} - return root.next +func (s *Sparse) first() *block { + s.init() + if s.root.offset == MaxInt { + return &none + } + return &s.root +} + +// next returns the next block in the list, or &none if b is the last block. +func (s *Sparse) next(b *block) *block { + if b.next == &s.root { + return &none + } + return b.next +} + +// prev returns the previous block in the list, or &none if b is the first block. +func (s *Sparse) prev(b *block) *block { + if b == &s.root { + return &none + } + return b.prev } // IsEmpty reports whether the set s is empty. func (s *Sparse) IsEmpty() bool { - return s.start() == &s.root + return s.root.next == nil || s.root.offset == MaxInt } // Len returns the number of elements in the set s.
func (s *Sparse) Len() int { var l int - for b := s.start(); b != &s.root; b = b.next { + for b := s.first(); b != &none; b = s.next(b) { l += b.len() } return l @@ -252,19 +277,16 @@ func (s *Sparse) Min() int { if s.IsEmpty() { return MaxInt } - return s.root.next.min(false) + return s.root.min(false) } // block returns the block that would contain offset, // or nil if s contains no such block. -// func (s *Sparse) block(offset int) *block { - b := s.start() - for b != &s.root && b.offset <= offset { + for b := s.first(); b != &none && b.offset <= offset; b = s.next(b) { if b.offset == offset { return b } - b = b.next } return nil } @@ -272,26 +294,49 @@ func (s *Sparse) block(offset int) *block { // Insert adds x to the set s, and reports whether the set grew. func (s *Sparse) Insert(x int) bool { offset, i := offsetAndBitIndex(x) - b := s.start() - for b != &s.root && b.offset <= offset { + + b := s.first() + for ; b != &none && b.offset <= offset; b = s.next(b) { if b.offset == offset { return b.insert(i) } - b = b.next } // Insert new block before b. - new := &block{offset: offset} - new.next = b - new.prev = b.prev - new.prev.next = new - new.next.prev = new + new := s.insertBlockBefore(b) + new.offset = offset return new.insert(i) } -func (s *Sparse) removeBlock(b *block) { - b.prev.next = b.next - b.next.prev = b.prev +// removeBlock removes a block and returns the block that followed it (or &none if +// it was the last block). +func (s *Sparse) removeBlock(b *block) *block { + if b != &s.root { + b.prev.next = b.next + b.next.prev = b.prev + if b.next == &s.root { + return &none + } + return b.next + } + + first := s.root.next + if first == &s.root { + // This was the only block. + s.Clear() + return &none + } + s.root.offset = first.offset + s.root.bits = first.bits + if first.next == &s.root { + // Single block remaining.
+ s.root.next = &s.root + s.root.prev = &s.root + } else { + s.root.next = first.next + first.next.prev = &s.root + } + return &s.root } // Remove removes x from the set s, and reports whether the set shrank. @@ -311,8 +356,11 @@ func (s *Sparse) Remove(x int) bool { // Clear removes all elements from the set s. func (s *Sparse) Clear() { - s.root.next = &s.root - s.root.prev = &s.root + s.root = block{ + offset: MaxInt, + next: &s.root, + prev: &s.root, + } } // If set s is non-empty, TakeMin sets *p to the minimum element of @@ -325,13 +373,12 @@ func (s *Sparse) Clear() { // for worklist.TakeMin(&x) { use(x) } // func (s *Sparse) TakeMin(p *int) bool { - head := s.start() - if head == &s.root { + if s.IsEmpty() { return false } - *p = head.min(true) - if head.empty() { - s.removeBlock(head) + *p = s.root.min(true) + if s.root.empty() { + s.removeBlock(&s.root) } return true } @@ -352,7 +399,7 @@ func (s *Sparse) Has(x int) bool { // natural control flow with continue/break/return. // func (s *Sparse) forEach(f func(int)) { - for b := s.start(); b != &s.root; b = b.next { + for b := s.first(); b != &none; b = s.next(b) { b.forEach(f) } } @@ -363,22 +410,51 @@ func (s *Sparse) Copy(x *Sparse) { return } - xb := x.start() - sb := s.start() - for xb != &x.root { - if sb == &s.root { + xb := x.first() + sb := s.first() + for xb != &none { + if sb == &none { sb = s.insertBlockBefore(sb) } sb.offset = xb.offset sb.bits = xb.bits - xb = xb.next - sb = sb.next + xb = x.next(xb) + sb = s.next(sb) } s.discardTail(sb) } // insertBlockBefore returns a new block, inserting it before next. +// If next is the root, the root is replaced. If next is &none, the block is +// inserted at the end.
func (s *Sparse) insertBlockBefore(next *block) *block { + if s.IsEmpty() { + if next != &none { + panic("BUG: passed block with empty set") + } + return &s.root + } + + if next == &s.root { + // Special case: we need to create a new block that will become the root + // block. The old root block becomes the second block. + second := s.root + s.root = block{ + next: &second, + } + if second.next == &s.root { + s.root.prev = &second + } else { + s.root.prev = second.prev + second.next.prev = &second + second.prev = &s.root + } + return &s.root + } + if next == &none { + // Insert before root. + next = &s.root + } b := new(block) b.next = next b.prev = next.prev @@ -389,9 +465,13 @@ func (s *Sparse) insertBlockBefore(next *block) *block { // discardTail removes block b and all its successors from s. func (s *Sparse) discardTail(b *block) { - if b != &s.root { - b.prev.next = &s.root - s.root.prev = b.prev + if b != &none { + if b == &s.root { + s.Clear() + } else { + b.prev.next = &s.root + s.root.prev = b.prev + } } } @@ -401,16 +481,15 @@ func (s *Sparse) IntersectionWith(x *Sparse) { return } - xb := x.start() - sb := s.start() - for xb != &x.root && sb != &s.root { + xb := x.first() + sb := s.first() + for xb != &none && sb != &none { switch { case xb.offset < sb.offset: - xb = xb.next + xb = x.next(xb) case xb.offset > sb.offset: - sb = sb.next - s.removeBlock(sb.prev) + sb = s.removeBlock(sb) default: var sum word @@ -420,12 +499,12 @@ func (s *Sparse) IntersectionWith(x *Sparse) { sum |= r } if sum != 0 { - sb = sb.next + sb = s.next(sb) } else { // sb will be overwritten or removed } - xb = xb.next + xb = x.next(xb) } } @@ -446,20 +525,20 @@ func (s *Sparse) Intersection(x, y *Sparse) { return } - xb := x.start() - yb := y.start() - sb := s.start() - for xb != &x.root && yb != &y.root { + xb := x.first() + yb := y.first() + sb := s.first() + for xb != &none && yb != &none { switch { case xb.offset < yb.offset: - xb = xb.next + xb = x.next(xb) continue case
xb.offset > yb.offset: - yb = yb.next + yb = y.next(yb) continue } - if sb == &s.root { + if sb == &none { sb = s.insertBlockBefore(sb) } sb.offset = xb.offset @@ -471,13 +550,13 @@ func (s *Sparse) Intersection(x, y *Sparse) { sum |= r } if sum != 0 { - sb = sb.next + sb = s.next(sb) } else { // sb will be overwritten or removed } - xb = xb.next - yb = yb.next + xb = x.next(xb) + yb = y.next(yb) } s.discardTail(sb) @@ -485,22 +564,22 @@ func (s *Sparse) Intersection(x, y *Sparse) { // Intersects reports whether s ∩ x ≠ ∅. func (s *Sparse) Intersects(x *Sparse) bool { - sb := s.start() - xb := x.start() - for sb != &s.root && xb != &x.root { + sb := s.first() + xb := x.first() + for sb != &none && xb != &none { switch { case xb.offset < sb.offset: - xb = xb.next + xb = x.next(xb) case xb.offset > sb.offset: - sb = sb.next + sb = s.next(sb) default: for i := range sb.bits { if sb.bits[i]&xb.bits[i] != 0 { return true } } - sb = sb.next - xb = xb.next + sb = s.next(sb) + xb = x.next(xb) } } return false @@ -513,26 +592,26 @@ func (s *Sparse) UnionWith(x *Sparse) bool { } var changed bool - xb := x.start() - sb := s.start() - for xb != &x.root { - if sb != &s.root && sb.offset == xb.offset { + xb := x.first() + sb := s.first() + for xb != &none { + if sb != &none && sb.offset == xb.offset { for i := range xb.bits { if sb.bits[i] != xb.bits[i] { sb.bits[i] |= xb.bits[i] changed = true } } - xb = xb.next - } else if sb == &s.root || sb.offset > xb.offset { + xb = x.next(xb) + } else if sb == &none || sb.offset > xb.offset { sb = s.insertBlockBefore(sb) sb.offset = xb.offset sb.bits = xb.bits changed = true - xb = xb.next + xb = x.next(xb) } - sb = sb.next + sb = s.next(sb) } return changed } @@ -551,33 +630,33 @@ func (s *Sparse) Union(x, y *Sparse) { return } - xb := x.start() - yb := y.start() - sb := s.start() - for xb != &x.root || yb != &y.root { - if sb == &s.root { + xb := x.first() + yb := y.first() + sb := s.first() + for xb != &none || yb != &none { + if sb == 
&none { sb = s.insertBlockBefore(sb) } switch { - case yb == &y.root || (xb != &x.root && xb.offset < yb.offset): + case yb == &none || (xb != &none && xb.offset < yb.offset): sb.offset = xb.offset sb.bits = xb.bits - xb = xb.next + xb = x.next(xb) - case xb == &x.root || (yb != &y.root && yb.offset < xb.offset): + case xb == &none || (yb != &none && yb.offset < xb.offset): sb.offset = yb.offset sb.bits = yb.bits - yb = yb.next + yb = y.next(yb) default: sb.offset = xb.offset for i := range xb.bits { sb.bits[i] = xb.bits[i] | yb.bits[i] } - xb = xb.next - yb = yb.next + xb = x.next(xb) + yb = y.next(yb) } - sb = sb.next + sb = s.next(sb) } s.discardTail(sb) @@ -590,15 +669,15 @@ func (s *Sparse) DifferenceWith(x *Sparse) { return } - xb := x.start() - sb := s.start() - for xb != &x.root && sb != &s.root { + xb := x.first() + sb := s.first() + for xb != &none && sb != &none { switch { case xb.offset > sb.offset: - sb = sb.next + sb = s.next(sb) case xb.offset < sb.offset: - xb = xb.next + xb = x.next(xb) default: var sum word @@ -607,12 +686,12 @@ func (s *Sparse) DifferenceWith(x *Sparse) { sb.bits[i] = r sum |= r } - sb = sb.next - xb = xb.next - if sum == 0 { - s.removeBlock(sb.prev) + sb = s.removeBlock(sb) + } else { + sb = s.next(sb) } + xb = x.next(xb) } } } @@ -633,27 +712,27 @@ func (s *Sparse) Difference(x, y *Sparse) { return } - xb := x.start() - yb := y.start() - sb := s.start() - for xb != &x.root && yb != &y.root { + xb := x.first() + yb := y.first() + sb := s.first() + for xb != &none && yb != &none { if xb.offset > yb.offset { - // y has block, x has none - yb = yb.next + // y has block, x has no block + yb = y.next(yb) continue } - if sb == &s.root { + if sb == &none { sb = s.insertBlockBefore(sb) } sb.offset = xb.offset switch { case xb.offset < yb.offset: - // x has block, y has none + // x has block, y has no block sb.bits = xb.bits - sb = sb.next + sb = s.next(sb) default: // x and y have corresponding blocks @@ -664,25 +743,25 @@ func (s *Sparse)
Difference(x, y *Sparse) { sum |= r } if sum != 0 { - sb = sb.next + sb = s.next(sb) } else { // sb will be overwritten or removed } - yb = yb.next + yb = y.next(yb) } - xb = xb.next + xb = x.next(xb) } - for xb != &x.root { - if sb == &s.root { + for xb != &none { + if sb == &none { sb = s.insertBlockBefore(sb) } sb.offset = xb.offset sb.bits = xb.bits - sb = sb.next + sb = s.next(sb) - xb = xb.next + xb = x.next(xb) } s.discardTail(sb) @@ -695,17 +774,17 @@ func (s *Sparse) SymmetricDifferenceWith(x *Sparse) { return } - sb := s.start() - xb := x.start() - for xb != &x.root && sb != &s.root { + sb := s.first() + xb := x.first() + for xb != &none && sb != &none { switch { case sb.offset < xb.offset: - sb = sb.next + sb = s.next(sb) case xb.offset < sb.offset: nb := s.insertBlockBefore(sb) nb.offset = xb.offset nb.bits = xb.bits - xb = xb.next + xb = x.next(xb) default: var sum word for i := range sb.bits { @@ -713,20 +792,21 @@ func (s *Sparse) SymmetricDifferenceWith(x *Sparse) { sb.bits[i] = r sum |= r } - sb = sb.next - xb = xb.next if sum == 0 { - s.removeBlock(sb.prev) + sb = s.removeBlock(sb) + } else { + sb = s.next(sb) } + xb = x.next(xb) } } - for xb != &x.root { // append the tail of x to s + for xb != &none { // append the tail of x to s sb = s.insertBlockBefore(sb) sb.offset = xb.offset sb.bits = xb.bits - sb = sb.next - xb = xb.next + sb = s.next(sb) + xb = x.next(xb) } } @@ -744,24 +824,24 @@ func (s *Sparse) SymmetricDifference(x, y *Sparse) { return } - sb := s.start() - xb := x.start() - yb := y.start() - for xb != &x.root && yb != &y.root { - if sb == &s.root { + sb := s.first() + xb := x.first() + yb := y.first() + for xb != &none && yb != &none { + if sb == &none { sb = s.insertBlockBefore(sb) } switch { case yb.offset < xb.offset: sb.offset = yb.offset sb.bits = yb.bits - sb = sb.next - yb = yb.next + sb = s.next(sb) + yb = y.next(yb) case xb.offset < yb.offset: sb.offset = xb.offset sb.bits = xb.bits - sb = sb.next - xb = xb.next + sb = 
s.next(sb) + xb = x.next(xb) default: var sum word for i := range sb.bits { @@ -771,31 +851,31 @@ func (s *Sparse) SymmetricDifference(x, y *Sparse) { } if sum != 0 { sb.offset = xb.offset - sb = sb.next + sb = s.next(sb) } - xb = xb.next - yb = yb.next + xb = x.next(xb) + yb = y.next(yb) } } - for xb != &x.root { // append the tail of x to s - if sb == &s.root { + for xb != &none { // append the tail of x to s + if sb == &none { sb = s.insertBlockBefore(sb) } sb.offset = xb.offset sb.bits = xb.bits - sb = sb.next - xb = xb.next + sb = s.next(sb) + xb = x.next(xb) } - for yb != &y.root { // append the tail of y to s - if sb == &s.root { + for yb != &none { // append the tail of y to s + if sb == &none { sb = s.insertBlockBefore(sb) } sb.offset = yb.offset sb.bits = yb.bits - sb = sb.next - yb = yb.next + sb = s.next(sb) + yb = y.next(yb) } s.discardTail(sb) @@ -807,22 +887,22 @@ func (s *Sparse) SubsetOf(x *Sparse) bool { return true } - sb := s.start() - xb := x.start() - for sb != &s.root { + sb := s.first() + xb := x.first() + for sb != &none { switch { - case xb == &x.root || xb.offset > sb.offset: + case xb == &none || xb.offset > sb.offset: return false case xb.offset < sb.offset: - xb = xb.next + xb = x.next(xb) default: for i := range sb.bits { if sb.bits[i]&^xb.bits[i] != 0 { return false } } - sb = sb.next - xb = xb.next + sb = s.next(sb) + xb = x.next(xb) } } return true @@ -833,13 +913,13 @@ func (s *Sparse) Equals(t *Sparse) bool { if s == t { return true } - sb := s.start() - tb := t.start() + sb := s.first() + tb := t.first() for { switch { - case sb == &s.root && tb == &t.root: + case sb == &none && tb == &none: return true - case sb == &s.root || tb == &t.root: + case sb == &none || tb == &none: return false case sb.offset != tb.offset: return false @@ -847,8 +927,8 @@ func (s *Sparse) Equals(t *Sparse) bool { return false } - sb = sb.next - tb = tb.next + sb = s.next(sb) + tb = t.next(tb) } } @@ -913,7 +993,7 @@ func (s *Sparse) BitString() string 
{ // func (s *Sparse) GoString() string { var buf bytes.Buffer - for b := s.start(); b != &s.root; b = b.next { + for b := s.first(); b != &none; b = s.next(b) { fmt.Fprintf(&buf, "block %p {offset=%d next=%p prev=%p", b, b.offset, b.next, b.prev) for _, w := range b.bits { @@ -937,13 +1017,18 @@ func (s *Sparse) AppendTo(slice []int) []int { // check returns an error if the representation invariants of s are violated. func (s *Sparse) check() error { - if !s.root.empty() { - return fmt.Errorf("non-empty root block") + s.init() + if s.root.empty() { + // An empty set must have only the root block with offset MaxInt. + if s.root.next != &s.root { + return fmt.Errorf("multiple blocks with empty root block") + } + if s.root.offset != MaxInt { + return fmt.Errorf("empty set has offset %d, should be MaxInt", s.root.offset) + } + return nil } - if s.root.offset != 0 { - return fmt.Errorf("root block has non-zero offset %d", s.root.offset) - } - for b := s.start(); b != &s.root; b = b.next { + for b := s.first(); ; b = s.next(b) { if b.offset%bitsPerBlock != 0 { return fmt.Errorf("bad offset modulo: %d", b.offset) } @@ -956,11 +1041,12 @@ func (s *Sparse) check() error { if b.next.prev != b { return fmt.Errorf("bad next.prev link") } - if b.prev != &s.root { - if b.offset <= b.prev.offset { - return fmt.Errorf("bad offset order: b.offset=%d, prev.offset=%d", - b.offset, b.prev.offset) - } + if b.next == &s.root { + break + } + if b.offset >= b.next.offset { + return fmt.Errorf("bad offset order: b.offset=%d, b.next.offset=%d", + b.offset, b.next.offset) } } return nil diff --git a/container/intsets/sparse_test.go b/container/intsets/sparse_test.go index 34b9a4e7f7..d9d4036aee 100644 --- a/container/intsets/sparse_test.go +++ b/container/intsets/sparse_test.go @@ -471,7 +471,7 @@ func TestIntersects(t *testing.T) { z.IntersectionWith(y) if got, want := x.Intersects(y), !z.IsEmpty(); got != want { - t.Errorf("Intersects: got %v, want %v", got, want) + 
t.Errorf("Intersects(%s, %s): got %v, want %v (%s)", x, y, got, want, &z) } // make it false @@ -563,7 +563,7 @@ func TestFailFastOnShallowCopy(t *testing.T) { y := x // shallow copy (breaks representation invariants) defer func() { got := fmt.Sprint(recover()) - want := "A Sparse has been copied without (*Sparse).Copy()" + want := "interface conversion: interface {} is nil, not intsets.to_copy_a_sparse_you_must_call_its_Copy_method" if got != want { t.Errorf("shallow copy: recover() = %q, want %q", got, want) } @@ -579,7 +579,60 @@ func TestFailFastOnShallowCopy(t *testing.T) { // - Gather set distributions from pointer analysis. // - Measure memory usage. -func BenchmarkSparseBitVector(b *testing.B) { +func benchmarkInsertProbeSparse(b *testing.B, size, spread int) { + prng := rand.New(rand.NewSource(0)) + // Generate our insertions and probes beforehand (we don't want to benchmark + // the prng). + insert := make([]int, size) + probe := make([]int, size*2) + for i := range insert { + insert[i] = prng.Int() % spread + } + for i := range probe { + probe[i] = prng.Int() % spread + } + + b.ResetTimer() + var x intsets.Sparse + for tries := 0; tries < b.N; tries++ { + x.Clear() + for _, n := range insert { + x.Insert(n) + } + hits := 0 + for _, n := range probe { + if x.Has(n) { + hits++ + } + } + // Use the variable so it doesn't get optimized away. 
+ if hits > len(probe) { + b.Fatalf("%d hits, only %d probes", hits, len(probe)) + } + } +} + +func BenchmarkInsertProbeSparse_2_10(b *testing.B) { + benchmarkInsertProbeSparse(b, 2, 10) +} + +func BenchmarkInsertProbeSparse_10_10(b *testing.B) { + benchmarkInsertProbeSparse(b, 10, 10) +} + +func BenchmarkInsertProbeSparse_10_1000(b *testing.B) { + benchmarkInsertProbeSparse(b, 10, 1000) +} + +func BenchmarkInsertProbeSparse_100_100(b *testing.B) { + benchmarkInsertProbeSparse(b, 100, 100) +} + +func BenchmarkInsertProbeSparse_100_10000(b *testing.B) { + benchmarkInsertProbeSparse(b, 100, 10000) +} + +func BenchmarkUnionDifferenceSparse(b *testing.B) { prng := rand.New(rand.NewSource(0)) for tries := 0; tries < b.N; tries++ { var x, y, z intsets.Sparse @@ -596,7 +649,7 @@ func BenchmarkSparseBitVector(b *testing.B) { } } -func BenchmarkHashTable(b *testing.B) { +func BenchmarkUnionDifferenceHashTable(b *testing.B) { prng := rand.New(rand.NewSource(0)) for tries := 0; tries < b.N; tries++ { x, y, z := make(map[int]bool), make(map[int]bool), make(map[int]bool)