From ebc8013edfc009b1190c656e738b15fe9729cc89 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 17 Jan 2012 10:44:35 -0800 Subject: [PATCH] exp/sql: copy when scanning into []byte by default Fixes #2698 R=rsc CC=golang-dev https://golang.org/cl/5539060 --- src/pkg/exp/sql/sql.go | 30 ++++++++++++++++++++++++--- src/pkg/exp/sql/sql_test.go | 41 +++++++++++++++++++++++++++++++++++-- 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/src/pkg/exp/sql/sql.go b/src/pkg/exp/sql/sql.go index 4e68c3ee09..cba7e9ebe5 100644 --- a/src/pkg/exp/sql/sql.go +++ b/src/pkg/exp/sql/sql.go @@ -30,6 +30,11 @@ func Register(name string, driver driver.Driver) { drivers[name] = driver } +// RawBytes is a byte slice that holds a reference to memory owned by +// the database itself. After a Scan into a RawBytes, the slice is only +// valid until the next call to Next, Scan, or Close. +type RawBytes []byte + // NullableString represents a string that may be null. // NullableString implements the ScannerInto interface so // it can be used as a scan destination: @@ -760,9 +765,13 @@ func (rs *Rows) Columns() ([]string, error) { } // Scan copies the columns in the current row into the values pointed -// at by dest. If dest contains pointers to []byte, the slices should -// not be modified and should only be considered valid until the next -// call to Next or Scan. +// at by dest. +// +// If an argument has type *[]byte, Scan saves in that argument a copy +// of the corresponding data. The copy is owned by the caller and can +// be modified and held indefinitely. The copy can be avoided by using +// an argument of type *RawBytes instead; see the documentation for +// RawBytes for restrictions on its use. func (rs *Rows) Scan(dest ...interface{}) error { if rs.closed { return errors.New("sql: Rows closed") @@ -782,6 +791,18 @@ func (rs *Rows) Scan(dest ...interface{}) error { return fmt.Errorf("sql: Scan error on column index %d: %v", i, err) } } + for _, dp := range dest { + b, ok := dp.(*[]byte) + if !ok { + continue + } + if _, ok = dp.(*RawBytes); ok { + continue + } + clone := make([]byte, len(*b)) + copy(clone, *b) + *b = clone + } return nil } @@ -838,6 +859,9 @@ func (r *Row) Scan(dest ...interface{}) error { // they were obtained from the network anyway) But for now we // don't care. for _, dp := range dest { + if _, ok := dp.(*RawBytes); ok { + return errors.New("sql: RawBytes isn't allowed on Row.Scan") + } b, ok := dp.(*[]byte) if !ok { continue diff --git a/src/pkg/exp/sql/sql_test.go b/src/pkg/exp/sql/sql_test.go index 3f98a8cd9f..30cd97d176 100644 --- a/src/pkg/exp/sql/sql_test.go +++ b/src/pkg/exp/sql/sql_test.go @@ -76,7 +76,7 @@ func TestQuery(t *testing.T) { {age: 3, name: "Chris"}, } if !reflect.DeepEqual(got, want) { - t.Logf(" got: %#v\nwant: %#v", got, want) + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) } // And verify that the final rows.Next() call, which hit EOF, @@ -86,6 +86,43 @@ func TestQuery(t *testing.T) { } } +func TestByteOwnership(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + rows, err := db.Query("SELECT|people|name,photo|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + name []byte + photo RawBytes + } + got := []row{} + for rows.Next() { + var r row + err = rows.Scan(&r.name, &r.photo) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got = append(got, r) + } + corruptMemory := []byte("\xffPHOTO") + want := []row{ + {name: []byte("Alice"), photo: corruptMemory}, + {name: []byte("Bob"), photo: corruptMemory}, + {name: []byte("Chris"), photo: corruptMemory}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) + } + + var photo RawBytes + err = db.QueryRow("SELECT|people|photo|name=?", "Alice").Scan(&photo) + if err == nil { + t.Error("want error scanning into RawBytes from QueryRow") + } +} + func TestRowsColumns(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db) @@ -300,6 +337,6 @@ func TestQueryRowClosingStmt(t *testing.T) { } fakeConn := db.freeConn[0].(*fakeConn) if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed { - t.Logf("statement close mismatch: made %d, closed %d", made, closed) + t.Errorf("statement close mismatch: made %d, closed %d", made, closed) } }