diff --git a/src/bytes/bytes.go b/src/bytes/bytes.go index 4bc375df19..8198415c3e 100644 --- a/src/bytes/bytes.go +++ b/src/bytes/bytes.go @@ -1192,19 +1192,22 @@ func Replace(s, old, new []byte, n int) []byte { t := make([]byte, len(s)+n*(len(new)-len(old))) w := 0 start := 0 - for i := 0; i < n; i++ { - j := start - if len(old) == 0 { - if i > 0 { - _, wid := utf8.DecodeRune(s[start:]) - j += wid - } - } else { - j += Index(s[start:], old) + if len(old) > 0 { + for range n { + j := start + Index(s[start:], old) + w += copy(t[w:], s[start:j]) + w += copy(t[w:], new) + start = j + len(old) } - w += copy(t[w:], s[start:j]) + } else { // len(old) == 0 w += copy(t[w:], new) - start = j + len(old) + for range n - 1 { + _, wid := utf8.DecodeRune(s[start:]) + j := start + wid + w += copy(t[w:], s[start:j]) + w += copy(t[w:], new) + start = j + } } w += copy(t[w:], s[start:]) return t[0:w] diff --git a/src/bytes/bytes_test.go b/src/bytes/bytes_test.go index ead581718a..14b52a8035 100644 --- a/src/bytes/bytes_test.go +++ b/src/bytes/bytes_test.go @@ -7,6 +7,7 @@ package bytes_test import ( . "bytes" "fmt" + "internal/asan" "internal/testenv" "iter" "math" @@ -1786,9 +1787,20 @@ var ReplaceTests = []ReplaceTest{ func TestReplace(t *testing.T) { for _, tt := range ReplaceTests { - in := append([]byte(tt.in), ""...) + var ( + in = []byte(tt.in) + old = []byte(tt.old) + new = []byte(tt.new) + ) + if !asan.Enabled { + allocs := testing.AllocsPerRun(10, func() { Replace(in, old, new, tt.n) }) + if allocs > 1 { + t.Errorf("Replace(%q, %q, %q, %d) allocates %.2f objects", tt.in, tt.old, tt.new, tt.n, allocs) + } + } + in = append(in, ""...) in = in[:len(tt.in)] - out := Replace(in, []byte(tt.old), []byte(tt.new), tt.n) + out := Replace(in, old, new, tt.n) if s := string(out); s != tt.out { t.Errorf("Replace(%q, %q, %q, %d) = %q, want %q", tt.in, tt.old, tt.new, tt.n, s, tt.out) } @@ -1796,7 +1808,7 @@ func TestReplace(t *testing.T) { t.Errorf("Replace(%q, %q, %q, %d) didn't copy", tt.in, tt.old, tt.new, tt.n) } if tt.n == -1 { - out := ReplaceAll(in, []byte(tt.old), []byte(tt.new)) + out := ReplaceAll(in, old, new) if s := string(out); s != tt.out { t.Errorf("ReplaceAll(%q, %q, %q) = %q, want %q", tt.in, tt.old, tt.new, s, tt.out) } @@ -1804,6 +1816,69 @@ func TestReplace(t *testing.T) { } } +func FuzzReplace(f *testing.F) { + for _, tt := range ReplaceTests { + f.Add([]byte(tt.in), []byte(tt.old), []byte(tt.new), tt.n) + } + f.Fuzz(func(t *testing.T, in, old, new []byte, n int) { + differentImpl := func(in, old, new []byte, n int) []byte { + var out Buffer + if n < 0 { + n = math.MaxInt + } + for i := 0; i < len(in); { + if n == 0 { + out.Write(in[i:]) + break + } + if HasPrefix(in[i:], old) { + out.Write(new) + i += len(old) + n-- + if len(old) != 0 { + continue + } + if i == len(in) { + break + } + } + if len(old) == 0 { + _, length := utf8.DecodeRune(in[i:]) + out.Write(in[i : i+length]) + i += length + } else { + out.WriteByte(in[i]) + i++ + } + } + if len(old) == 0 && n != 0 { + out.Write(new) + } + return out.Bytes() + } + if simple, replace := differentImpl(in, old, new, n), Replace(in, old, new, n); !slices.Equal(simple, replace) { + t.Errorf("The two implementations do not match %q != %q for Replace(%q, %q, %q, %d)", simple, replace, in, old, new, n) + } + }) +} + +func BenchmarkReplace(b *testing.B) { + for _, tt := range ReplaceTests { + desc := fmt.Sprintf("%q %q %q %d", tt.in, tt.old, tt.new, tt.n) + var ( + in = []byte(tt.in) + old = []byte(tt.old) + new = []byte(tt.new) + ) + b.Run(desc, func(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + Replace(in, old, new, tt.n) + } + }) + } +} + type TitleTest struct { in, out string }