aboutsummaryrefslogtreecommitdiff
path: root/libgo/go/database/sql/fakedb_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/database/sql/fakedb_test.go')
-rw-r--r--libgo/go/database/sql/fakedb_test.go37
1 files changed, 27 insertions, 10 deletions
diff --git a/libgo/go/database/sql/fakedb_test.go b/libgo/go/database/sql/fakedb_test.go
index 889e2a25232..184e7756c51 100644
--- a/libgo/go/database/sql/fakedb_test.go
+++ b/libgo/go/database/sql/fakedb_test.go
@@ -82,6 +82,7 @@ type fakeConn struct {
mu sync.Mutex
stmtsMade int
stmtsClosed int
+ numPrepare int
}
func (c *fakeConn) incrStat(v *int) {
@@ -208,16 +209,19 @@ func (c *fakeConn) Begin() (driver.Tx, error) {
func (c *fakeConn) Close() error {
if c.currTx != nil {
- return errors.New("can't close; in a Transaction")
+ return errors.New("can't close fakeConn; in a Transaction")
}
if c.db == nil {
- return errors.New("can't close; already closed")
+ return errors.New("can't close fakeConn; already closed")
+ }
+ if c.stmtsMade > c.stmtsClosed {
+ return errors.New("can't close; dangling statement(s)")
}
c.db = nil
return nil
}
-func checkSubsetTypes(args []interface{}) error {
+func checkSubsetTypes(args []driver.Value) error {
for n, arg := range args {
switch arg.(type) {
case int64, float64, bool, nil, []byte, string, time.Time:
@@ -228,7 +232,7 @@ func checkSubsetTypes(args []interface{}) error {
return nil
}
-func (c *fakeConn) Exec(query string, args []interface{}) (driver.Result, error) {
+func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
// This is an optional interface, but it's implemented here
// just to check that all the args of of the proper types.
// ErrSkip is returned so the caller acts as if we didn't
@@ -249,6 +253,7 @@ func errf(msg string, args ...interface{}) error {
// just a limitation for fakedb)
func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
if len(parts) != 3 {
+ stmt.Close()
return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
}
stmt.table = parts[0]
@@ -259,14 +264,17 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
}
nameVal := strings.Split(colspec, "=")
if len(nameVal) != 2 {
+ stmt.Close()
return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
}
column, value := nameVal[0], nameVal[1]
_, ok := c.db.columnType(stmt.table, column)
if !ok {
+ stmt.Close()
return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
}
if value != "?" {
+ stmt.Close()
return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
stmt.table, column)
}
@@ -279,12 +287,14 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
// parts are table|col=type,col2=type2
func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
if len(parts) != 2 {
+ stmt.Close()
return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
}
stmt.table = parts[0]
for n, colspec := range strings.Split(parts[1], ",") {
nameType := strings.Split(colspec, "=")
if len(nameType) != 2 {
+ stmt.Close()
return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
}
stmt.colName = append(stmt.colName, nameType[0])
@@ -296,17 +306,20 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, e
// parts are table|col=?,col2=val
func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
if len(parts) != 2 {
+ stmt.Close()
return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
}
stmt.table = parts[0]
for n, colspec := range strings.Split(parts[1], ",") {
nameVal := strings.Split(colspec, "=")
if len(nameVal) != 2 {
+ stmt.Close()
return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
}
column, value := nameVal[0], nameVal[1]
ctype, ok := c.db.columnType(stmt.table, column)
if !ok {
+ stmt.Close()
return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
}
stmt.colName = append(stmt.colName, column)
@@ -322,10 +335,12 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
case "int32":
i, err := strconv.Atoi(value)
if err != nil {
+ stmt.Close()
return nil, errf("invalid conversion to int32 from %q", value)
}
subsetVal = int64(i) // int64 is a subset type, but not int32
default:
+ stmt.Close()
return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
}
stmt.colValue = append(stmt.colValue, subsetVal)
@@ -339,6 +354,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
}
func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
+ c.numPrepare++
if c.db == nil {
panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
}
@@ -360,6 +376,7 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
case "INSERT":
return c.prepareInsert(stmt, parts)
default:
+ stmt.Close()
return nil, errf("unsupported command type %q", cmd)
}
return stmt, nil
@@ -379,7 +396,7 @@ func (s *fakeStmt) Close() error {
var errClosed = errors.New("fakedb: statement has been closed")
-func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) {
+func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
if s.closed {
return nil, errClosed
}
@@ -392,12 +409,12 @@ func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) {
switch s.cmd {
case "WIPE":
db.wipe()
- return driver.DDLSuccess, nil
+ return driver.ResultNoRows, nil
case "CREATE":
if err := db.createTable(s.table, s.colName, s.colType); err != nil {
return nil, err
}
- return driver.DDLSuccess, nil
+ return driver.ResultNoRows, nil
case "INSERT":
return s.execInsert(args)
}
@@ -405,7 +422,7 @@ func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) {
return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
}
-func (s *fakeStmt) execInsert(args []interface{}) (driver.Result, error) {
+func (s *fakeStmt) execInsert(args []driver.Value) (driver.Result, error) {
db := s.c.db
if len(args) != s.placeholders {
panic("error in pkg db; should only get here if size is correct")
@@ -441,7 +458,7 @@ func (s *fakeStmt) execInsert(args []interface{}) (driver.Result, error) {
return driver.RowsAffected(1), nil
}
-func (s *fakeStmt) Query(args []interface{}) (driver.Rows, error) {
+func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
if s.closed {
return nil, errClosed
}
@@ -548,7 +565,7 @@ func (rc *rowsCursor) Columns() []string {
return rc.cols
}
-func (rc *rowsCursor) Next(dest []interface{}) error {
+func (rc *rowsCursor) Next(dest []driver.Value) error {
if rc.closed {
return errors.New("fakedb: cursor is closed")
}