Skip to content

Commit 397e2f5

Browse files
Exec() now provides access to status of multiple statements. (#1309)
It now reports the last inserted ID and affected row count for all statements, not just the last one. This is useful to execute batches of statements such as UPDATE with minimal roundtrips. Co-authored-by: Inada Naoki <[email protected]>
1 parent f43effa commit 397e2f5

File tree

9 files changed

+259
-50
lines changed

9 files changed

+259
-50
lines changed

‎README.md‎

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,22 @@ Allow multiple statements in one query. This can be used to bach multiple querie
305305

306306
When `multiStatements` is used, `?` parameters must only be used in the first statement. [interpolateParams](#interpolateparams) can be used to avoid this limitation unless prepared statement is used explicitly.
307307

308+
It's possible to access the last inserted ID and number of affected rows for multiple statements by using `sql.Conn.Raw()` and the `mysql.Result`. For example:
309+
310+
```go
311+
conn, _:= db.Conn(ctx)
312+
conn.Raw(func(conn interface{}) error{
313+
ex:= conn.(driver.Execer)
314+
res, err:= ex.Exec(`
315+
UPDATE point SET x = 1 WHERE y = 2;
316+
UPDATE point SET x = 2 WHERE y = 3;
317+
`, nil)
318+
// Both slices have 2 elements.
319+
log.Print(res.(mysql.Result).AllRowsAffected())
320+
log.Print(res.(mysql.Result).AllLastInsertIds())
321+
})
322+
```
323+
308324
##### `parseTime`
309325

310326
```

‎auth.go‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error{
346346
case1:
347347
switchauthData[0]{
348348
casecachingSha2PasswordFastAuthSuccess:
349-
iferr=mc.readResultOK(); err==nil{
349+
iferr=mc.resultUnchanged().readResultOK(); err==nil{
350350
returnnil// auth successful
351351
}
352352

@@ -397,7 +397,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error{
397397
returnerr
398398
}
399399
}
400-
returnmc.readResultOK()
400+
returnmc.resultUnchanged().readResultOK()
401401

402402
default:
403403
returnErrMalformPkt
@@ -426,7 +426,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error{
426426
iferr!=nil{
427427
returnerr
428428
}
429-
returnmc.readResultOK()
429+
returnmc.resultUnchanged().readResultOK()
430430
}
431431

432432
default:

‎connection.go‎

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ import (
2323
typemysqlConnstruct{
2424
bufbuffer
2525
netConn net.Conn
26-
rawConn net.Conn// underlying connection when netConn is TLS connection.
27-
affectedRowsuint64
28-
insertIduint64
26+
rawConn net.Conn// underlying connection when netConn is TLS connection.
27+
resultmysqlResult// managed by clearResult() and handleOkPacket().
2928
cfg*Config
3029
connector*connector
3130
maxAllowedPacketint
@@ -155,6 +154,7 @@ func (mc *mysqlConn) cleanup(){
155154
iferr:=mc.netConn.Close(); err!=nil{
156155
mc.cfg.Logger.Print(err)
157156
}
157+
mc.clearResult()
158158
}
159159

160160
func (mc*mysqlConn) error() error{
@@ -316,28 +316,25 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
316316
}
317317
query=prepared
318318
}
319-
mc.affectedRows=0
320-
mc.insertId=0
321319

322320
err:=mc.exec(query)
323321
iferr==nil{
324-
return&mysqlResult{
325-
affectedRows: int64(mc.affectedRows),
326-
insertId: int64(mc.insertId),
327-
}, err
322+
copied:=mc.result
323+
return&copied, err
328324
}
329325
returnnil, mc.markBadConn(err)
330326
}
331327

332328
// Internal function to execute commands
333329
func (mc*mysqlConn) exec(querystring) error{
330+
handleOk:=mc.clearResult()
334331
// Send command
335332
iferr:=mc.writeCommandPacketStr(comQuery, query); err!=nil{
336333
returnmc.markBadConn(err)
337334
}
338335

339336
// Read Result
340-
resLen, err:=mc.readResultSetHeaderPacket()
337+
resLen, err:=handleOk.readResultSetHeaderPacket()
341338
iferr!=nil{
342339
returnerr
343340
}
@@ -354,14 +351,16 @@ func (mc *mysqlConn) exec(query string) error{
354351
}
355352
}
356353

357-
returnmc.discardResults()
354+
returnhandleOk.discardResults()
358355
}
359356

360357
func (mc*mysqlConn) Query(querystring, args []driver.Value) (driver.Rows, error){
361358
returnmc.query(query, args)
362359
}
363360

364361
func (mc*mysqlConn) query(querystring, args []driver.Value) (*textRows, error){
362+
handleOk:=mc.clearResult()
363+
365364
ifmc.closed.Load(){
366365
mc.cfg.Logger.Print(ErrInvalidConn)
367366
returnnil, driver.ErrBadConn
@@ -382,7 +381,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
382381
iferr==nil{
383382
// Read Result
384383
varresLenint
385-
resLen, err=mc.readResultSetHeaderPacket()
384+
resLen, err=handleOk.readResultSetHeaderPacket()
386385
iferr==nil{
387386
rows:=new(textRows)
388387
rows.mc=mc
@@ -410,12 +409,13 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
410409
// The returned byte slice is only valid until the next read
411410
func (mc*mysqlConn) getSystemVar(namestring) ([]byte, error){
412411
// Send command
412+
handleOk:=mc.clearResult()
413413
iferr:=mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err!=nil{
414414
returnnil, err
415415
}
416416

417417
// Read Result
418-
resLen, err:=mc.readResultSetHeaderPacket()
418+
resLen, err:=handleOk.readResultSetHeaderPacket()
419419
iferr==nil{
420420
rows:=new(textRows)
421421
rows.mc=mc
@@ -466,11 +466,12 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error){
466466
}
467467
defermc.finish()
468468

469+
handleOk:=mc.clearResult()
469470
iferr=mc.writeCommandPacket(comPing); err!=nil{
470471
returnmc.markBadConn(err)
471472
}
472473

473-
returnmc.readResultOK()
474+
returnhandleOk.readResultOK()
474475
}
475476

476477
// BeginTx implements driver.ConnBeginTx interface

‎driver_test.go‎

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2154,11 +2154,51 @@ func TestRejectReadOnly(t *testing.T){
21542154
}
21552155

21562156
funcTestPing(t*testing.T){
2157+
ctx:=context.Background()
21572158
runTests(t, dsn, func(dbt*DBTest){
21582159
iferr:=dbt.db.Ping(); err!=nil{
21592160
dbt.fail("Ping", "Ping", err)
21602161
}
21612162
})
2163+
2164+
runTests(t, dsn, func(dbt*DBTest){
2165+
conn, err:=dbt.db.Conn(ctx)
2166+
iferr!=nil{
2167+
dbt.fail("db", "Conn", err)
2168+
}
2169+
2170+
// Check that affectedRows and insertIds are cleared after each call.
2171+
conn.Raw(func(conninterface{}) error{
2172+
c:=conn.(*mysqlConn)
2173+
2174+
// Issue a query that sets affectedRows and insertIds.
2175+
q, err:=c.Query(`SELECT 1`, nil)
2176+
iferr!=nil{
2177+
dbt.fail("Conn", "Query", err)
2178+
}
2179+
ifgot, want:=c.result.affectedRows, []int64{0}; !reflect.DeepEqual(got, want){
2180+
dbt.Fatalf("bad affectedRows: got %v, want=%v", got, want)
2181+
}
2182+
ifgot, want:=c.result.insertIds, []int64{0}; !reflect.DeepEqual(got, want){
2183+
dbt.Fatalf("bad insertIds: got %v, want=%v", got, want)
2184+
}
2185+
q.Close()
2186+
2187+
// Verify that Ping() clears both fields.
2188+
fori:=0; i<2; i++{
2189+
iferr:=c.Ping(ctx); err!=nil{
2190+
dbt.fail("Pinger", "Ping", err)
2191+
}
2192+
ifgot, want:=c.result.affectedRows, []int64(nil); !reflect.DeepEqual(got, want){
2193+
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
2194+
}
2195+
ifgot, want:=c.result.insertIds, []int64(nil); !reflect.DeepEqual(got, want){
2196+
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
2197+
}
2198+
}
2199+
returnnil
2200+
})
2201+
})
21622202
}
21632203

21642204
// See Issue #799
@@ -2378,6 +2418,42 @@ func TestMultiResultSetNoSelect(t *testing.T){
23782418
})
23792419
}
23802420

2421+
funcTestExecMultipleResults(t*testing.T){
2422+
ctx:=context.Background()
2423+
runTestsWithMultiStatement(t, dsn, func(dbt*DBTest){
2424+
dbt.mustExec(`
2425+
CREATE TABLE test (
2426+
id INT NOT NULL AUTO_INCREMENT,
2427+
value VARCHAR(255),
2428+
PRIMARY KEY (id)
2429+
)`)
2430+
conn, err:=dbt.db.Conn(ctx)
2431+
iferr!=nil{
2432+
t.Fatalf("failed to connect: %v", err)
2433+
}
2434+
conn.Raw(func(conninterface{}) error{
2435+
ex:=conn.(driver.Execer)
2436+
res, err:=ex.Exec(`
2437+
INSERT INTO test (value) VALUES ('a'), ('b');
2438+
INSERT INTO test (value) VALUES ('c'), ('d'), ('e');
2439+
`, nil)
2440+
iferr!=nil{
2441+
t.Fatalf("insert statements failed: %v", err)
2442+
}
2443+
mres:=res.(Result)
2444+
ifgot, want:=mres.AllRowsAffected(), []int64{2, 3}; !reflect.DeepEqual(got, want){
2445+
t.Errorf("bad AllRowsAffected: got %v, want=%v", got, want)
2446+
}
2447+
// For INSERTs containing multiple rows, LAST_INSERT_ID() returns the
2448+
// first inserted ID, not the last.
2449+
ifgot, want:=mres.AllLastInsertIds(), []int64{1, 3}; !reflect.DeepEqual(got, want){
2450+
t.Errorf("bad AllLastInsertIds: got %v, want %v", got, want)
2451+
}
2452+
returnnil
2453+
})
2454+
})
2455+
}
2456+
23812457
// tests if rows are set in a proper state if some results were ignored before
23822458
// calling rows.NextResultSet.
23832459
funcTestSkipResults(t*testing.T){
@@ -2399,6 +2475,42 @@ func TestSkipResults(t *testing.T){
23992475
})
24002476
}
24012477

2478+
funcTestQueryMultipleResults(t*testing.T){
2479+
ctx:=context.Background()
2480+
runTestsWithMultiStatement(t, dsn, func(dbt*DBTest){
2481+
dbt.mustExec(`
2482+
CREATE TABLE test (
2483+
id INT NOT NULL AUTO_INCREMENT,
2484+
value VARCHAR(255),
2485+
PRIMARY KEY (id)
2486+
)`)
2487+
conn, err:=dbt.db.Conn(ctx)
2488+
iferr!=nil{
2489+
t.Fatalf("failed to connect: %v", err)
2490+
}
2491+
conn.Raw(func(conninterface{}) error{
2492+
qr:=conn.(driver.Queryer)
2493+
2494+
c:=conn.(*mysqlConn)
2495+
2496+
// Demonstrate that repeated queries reset the affectedRows
2497+
fori:=0; i<2; i++{
2498+
_, err:=qr.Query(`
2499+
INSERT INTO test (value) VALUES ('a'), ('b');
2500+
INSERT INTO test (value) VALUES ('c'), ('d'), ('e');
2501+
`, nil)
2502+
iferr!=nil{
2503+
t.Fatalf("insert statements failed: %v", err)
2504+
}
2505+
ifgot, want:=c.result.affectedRows, []int64{2, 3}; !reflect.DeepEqual(got, want){
2506+
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
2507+
}
2508+
}
2509+
returnnil
2510+
})
2511+
})
2512+
}
2513+
24022514
funcTestPingContext(t*testing.T){
24032515
runTests(t, dsn, func(dbt*DBTest){
24042516
ctx, cancel:=context.WithCancel(context.Background())

‎infile.go‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func deferredClose(err *error, closer io.Closer){
9393

9494
constdefaultPacketSize=16*1024// 16KB is small enough for disk readahead and large enough for TCP
9595

96-
func (mc*mysqlConn) handleInFileRequest(namestring) (errerror){
96+
func (mc*okHandler) handleInFileRequest(namestring) (errerror){
9797
varrdr io.Reader
9898
vardata []byte
9999
packetSize:=defaultPacketSize
@@ -154,7 +154,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error){
154154
forerr==nil{
155155
n, err=rdr.Read(data[4:])
156156
ifn>0{
157-
ifioErr:=mc.writePacket(data[:4+n]); ioErr!=nil{
157+
ifioErr:=mc.conn().writePacket(data[:4+n]); ioErr!=nil{
158158
returnioErr
159159
}
160160
}
@@ -168,7 +168,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error){
168168
ifdata==nil{
169169
data=make([]byte, 4)
170170
}
171-
ifioErr:=mc.writePacket(data[:4]); ioErr!=nil{
171+
ifioErr:=mc.conn().writePacket(data[:4]); ioErr!=nil{
172172
returnioErr
173173
}
174174

@@ -177,6 +177,6 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error){
177177
returnmc.readResultOK()
178178
}
179179

180-
mc.readPacket()
180+
mc.conn().readPacket()
181181
returnerr
182182
}

0 commit comments

Comments
(0)