diff --git a/src/io/io.go b/src/io/io.go index 28dab08e46..86710ed6f3 100644 --- a/src/io/io.go +++ b/src/io/io.go @@ -335,7 +335,7 @@ func ReadFull(r Reader, buf []byte) (n int, err error) { // If dst implements the ReaderFrom interface, // the copy is implemented using it. func CopyN(dst Writer, src Reader, n int64) (written int64, err error) { - written, err = Copy(dst, LimitReader(src, n)) + written, err = copyN(dst, src, n) if written == n { return n, nil } @@ -346,6 +346,55 @@ func CopyN(dst Writer, src Reader, n int64) (written int64, err error) { return } +// copyN copies n bytes (or until an error) from src to dst. +// It returns the number of bytes copied and the earliest +// error encountered while copying. +// +// If dst implements the ReaderFrom interface, +// the copy is implemented using it. +func copyN(dst Writer, src Reader, n int64) (int64, error) { + // If the writer has a ReadFrom method, use it to do the copy. + if rt, ok := dst.(ReaderFrom); ok { + return rt.ReadFrom(LimitReader(src, n)) + } + + l := 32 * 1024 // same size as in copyBuffer + if n < int64(l) { + l = int(n) + } + buf := make([]byte, l) + + var written int64 + for n > 0 { + if n < int64(len(buf)) { + buf = buf[:n] + } + + nr, errR := src.Read(buf) + if nr > 0 { + n -= int64(nr) + nw, errW := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if errW != nil { + return written, errW + } + if nr != nw { + return written, ErrShortWrite + } + } + + if errR != nil { + if errR != EOF { + return written, errR + } + return written, nil + } + } + return written, nil +} + // Copy copies from src to dst until either EOF is reached // on src or an error occurs. It returns the number of bytes // copied and the first error encountered while copying, if any.