diff options
Diffstat (limited to 'libgo/go/database/sql/fakedb_test.go')
-rw-r--r-- | libgo/go/database/sql/fakedb_test.go | 37 |
1 files changed, 27 insertions, 10 deletions
diff --git a/libgo/go/database/sql/fakedb_test.go b/libgo/go/database/sql/fakedb_test.go index 889e2a25232..184e7756c51 100644 --- a/libgo/go/database/sql/fakedb_test.go +++ b/libgo/go/database/sql/fakedb_test.go @@ -82,6 +82,7 @@ type fakeConn struct { mu sync.Mutex stmtsMade int stmtsClosed int + numPrepare int } func (c *fakeConn) incrStat(v *int) { @@ -208,16 +209,19 @@ func (c *fakeConn) Begin() (driver.Tx, error) { func (c *fakeConn) Close() error { if c.currTx != nil { - return errors.New("can't close; in a Transaction") + return errors.New("can't close fakeConn; in a Transaction") } if c.db == nil { - return errors.New("can't close; already closed") + return errors.New("can't close fakeConn; already closed") + } + if c.stmtsMade > c.stmtsClosed { + return errors.New("can't close; dangling statement(s)") } c.db = nil return nil } -func checkSubsetTypes(args []interface{}) error { +func checkSubsetTypes(args []driver.Value) error { for n, arg := range args { switch arg.(type) { case int64, float64, bool, nil, []byte, string, time.Time: @@ -228,7 +232,7 @@ func checkSubsetTypes(args []interface{}) error { return nil } -func (c *fakeConn) Exec(query string, args []interface{}) (driver.Result, error) { +func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { // This is an optional interface, but it's implemented here // just to check that all the args of of the proper types. // ErrSkip is returned so the caller acts as if we didn't @@ -249,6 +253,7 @@ func errf(msg string, args ...interface{}) error { // just a limitation for fakedb) func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) { if len(parts) != 3 { + stmt.Close() return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) } stmt.table = parts[0] @@ -259,14 +264,17 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e } nameVal := strings.Split(colspec, "=") if len(nameVal) != 2 { + stmt.Close() return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) } column, value := nameVal[0], nameVal[1] _, ok := c.db.columnType(stmt.table, column) if !ok { + stmt.Close() return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) } if value != "?" { + stmt.Close() return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", stmt.table, column) } @@ -279,12 +287,14 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e // parts are table|col=type,col2=type2 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) { if len(parts) != 2 { + stmt.Close() return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) } stmt.table = parts[0] for n, colspec := range strings.Split(parts[1], ",") { nameType := strings.Split(colspec, "=") if len(nameType) != 2 { + stmt.Close() return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) } stmt.colName = append(stmt.colName, nameType[0]) @@ -296,17 +306,20 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, e // parts are table|col=?,col2=val func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) { if len(parts) != 2 { + stmt.Close() return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) } stmt.table = parts[0] for n, colspec := range strings.Split(parts[1], ",") { nameVal := strings.Split(colspec, "=") if len(nameVal) != 2 { + stmt.Close() return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) } column, value := nameVal[0], nameVal[1] ctype, ok := c.db.columnType(stmt.table, column) if !ok { + stmt.Close() return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) } stmt.colName = append(stmt.colName, column) @@ -322,10 +335,12 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e case "int32": i, err := strconv.Atoi(value) if err != nil { + stmt.Close() return nil, errf("invalid conversion to int32 from %q", value) } subsetVal = int64(i) // int64 is a subset type, but not int32 default: + stmt.Close() return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) } stmt.colValue = append(stmt.colValue, subsetVal) @@ -339,6 +354,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e } func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { + c.numPrepare++ if c.db == nil { panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) } @@ -360,6 +376,7 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { case "INSERT": return c.prepareInsert(stmt, parts) default: + stmt.Close() return nil, errf("unsupported command type %q", cmd) } return stmt, nil @@ -379,7 +396,7 @@ func (s *fakeStmt) Close() error { var errClosed = errors.New("fakedb: statement has been closed") -func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) { +func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { if s.closed { return nil, errClosed } @@ -392,12 +409,12 @@ func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) { switch s.cmd { case "WIPE": db.wipe() - return driver.DDLSuccess, nil + return driver.ResultNoRows, nil case "CREATE": if err := db.createTable(s.table, s.colName, s.colType); err != nil { return nil, err } - return driver.DDLSuccess, nil + return driver.ResultNoRows, nil case "INSERT": return s.execInsert(args) } @@ -405,7 +422,7 @@ func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) { return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd) } -func (s *fakeStmt) execInsert(args []interface{}) (driver.Result, error) { +func (s *fakeStmt) execInsert(args []driver.Value) (driver.Result, error) { db := s.c.db if len(args) != s.placeholders { panic("error in pkg db; should only get here if size is correct") @@ -441,7 +458,7 @@ func (s *fakeStmt) execInsert(args []interface{}) (driver.Result, error) { return driver.RowsAffected(1), nil } -func (s *fakeStmt) Query(args []interface{}) (driver.Rows, error) { +func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { if s.closed { return nil, errClosed } @@ -548,7 +565,7 @@ func (rc *rowsCursor) Columns() []string { return rc.cols } -func (rc *rowsCursor) Next(dest []interface{}) error { +func (rc *rowsCursor) Next(dest []driver.Value) error { if rc.closed { return errors.New("fakedb: cursor is closed") } |