Skip to content

Commit 883c78c

Browse files
committed
Enhance interpolateParams to correctly handle placeholders in queries with comments, strings, and backticks.
* Add `findParamPositions` to identify real parameter positions * Update and expand related tests.
1 parent 76c00e3 commit 883c78c

File tree

2 files changed

+154
-30
lines changed

2 files changed

+154
-30
lines changed

‎connection.go‎

Lines changed: 106 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func (mc *mysqlConn) close(){
172172
}
173173

174174
// Closes the network connection and unsets internal variables. Do not call this
175-
// function after successfully authentication, call Close instead. This function
175+
// function after successful authentication, call Close instead. This function
176176
// is called before auth or on auth failure because MySQL will have already
177177
// closed the network connection.
178178
func (mc*mysqlConn) cleanup(){
@@ -245,9 +245,106 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error){
245245
returnstmt, err
246246
}
247247

248+
// findParamPositions returns the positions of real parameter holders ('?') in the query, ignoring those in comments, strings, or backticks.
249+
funcfindParamPositions(querystring) []int{
250+
const (
251+
stateNormal=iota
252+
stateString
253+
stateEscape
254+
stateEOLComment
255+
stateSlashStarComment
256+
stateBacktick
257+
)
258+
259+
var (
260+
QUOTE_BYTE=byte('\'')
261+
DBL_QUOTE_BYTE=byte('"')
262+
BACKSLASH_BYTE=byte('\\')
263+
QUESTION_MARK_BYTE=byte('?')
264+
SLASH_BYTE=byte('/')
265+
STAR_BYTE=byte('*')
266+
HASH_BYTE=byte('#')
267+
MINUS_BYTE=byte('-')
268+
LINE_FEED_BYTE=byte('\n')
269+
RADICAL_BYTE=byte('`')
270+
)
271+
272+
paramPositions:=make([]int, 0)
273+
state:=stateNormal
274+
singleQuotes:=false
275+
lastChar:=byte(0)
276+
lenq:=len(query)
277+
fori:=0; i<lenq; i++{
278+
currentChar:=query[i]
279+
ifstate==stateEscape&&!((currentChar==QUOTE_BYTE&&singleQuotes) || (currentChar==DBL_QUOTE_BYTE&&!singleQuotes)){
280+
state=stateString
281+
lastChar=currentChar
282+
continue
283+
}
284+
switchcurrentChar{
285+
caseSTAR_BYTE:
286+
ifstate==stateNormal&&lastChar==SLASH_BYTE{
287+
state=stateSlashStarComment
288+
}
289+
caseSLASH_BYTE:
290+
ifstate==stateSlashStarComment&&lastChar==STAR_BYTE{
291+
state=stateNormal
292+
} elseifstate==stateNormal&&lastChar==SLASH_BYTE{
293+
state=stateEOLComment
294+
}
295+
caseHASH_BYTE:
296+
ifstate==stateNormal{
297+
state=stateEOLComment
298+
}
299+
caseMINUS_BYTE:
300+
ifstate==stateNormal&&lastChar==MINUS_BYTE{
301+
state=stateEOLComment
302+
}
303+
caseLINE_FEED_BYTE:
304+
ifstate==stateEOLComment{
305+
state=stateNormal
306+
}
307+
caseDBL_QUOTE_BYTE:
308+
ifstate==stateNormal{
309+
state=stateString
310+
singleQuotes=false
311+
} elseifstate==stateString&&!singleQuotes{
312+
state=stateNormal
313+
} elseifstate==stateEscape{
314+
state=stateString
315+
}
316+
caseQUOTE_BYTE:
317+
ifstate==stateNormal{
318+
state=stateString
319+
singleQuotes=true
320+
} elseifstate==stateString&&singleQuotes{
321+
state=stateNormal
322+
} elseifstate==stateEscape{
323+
state=stateString
324+
}
325+
caseBACKSLASH_BYTE:
326+
ifstate==stateString{
327+
state=stateEscape
328+
}
329+
caseQUESTION_MARK_BYTE:
330+
ifstate==stateNormal{
331+
paramPositions=append(paramPositions, i)
332+
}
333+
caseRADICAL_BYTE:
334+
ifstate==stateBacktick{
335+
state=stateNormal
336+
} elseifstate==stateNormal{
337+
state=stateBacktick
338+
}
339+
}
340+
lastChar=currentChar
341+
}
342+
returnparamPositions
343+
}
344+
248345
func (mc*mysqlConn) interpolateParams(querystring, args []driver.Value) (string, error){
249-
// Number of ? should be same to len(args)
250-
ifstrings.Count(query, "?") !=len(args){
346+
paramPositions:=findParamPositions(query)
347+
iflen(paramPositions) !=len(args){
251348
return"", driver.ErrSkip
252349
}
253350

@@ -261,21 +358,16 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
261358
}
262359
buf=buf[:0]
263360
argPos:=0
361+
lastIdx:=0
264362

265-
fori:=0; i<len(query); i++{
266-
q:=strings.IndexByte(query[i:], '?')
267-
ifq==-1{
268-
buf=append(buf, query[i:]...)
269-
break
270-
}
271-
buf=append(buf, query[i:i+q]...)
272-
i+=q
273-
363+
for_, qmIdx:=rangeparamPositions{
364+
buf=append(buf, query[lastIdx:qmIdx]...)
274365
arg:=args[argPos]
275366
argPos++
276367

277368
ifarg==nil{
278369
buf=append(buf, "NULL"...)
370+
lastIdx=qmIdx+1
279371
continue
280372
}
281373

@@ -339,7 +431,9 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
339431
iflen(buf)+4>mc.maxAllowedPacket{
340432
return"", driver.ErrSkip
341433
}
434+
lastIdx=qmIdx+1
342435
}
436+
buf=append(buf, query[lastIdx:]...)
343437
ifargPos!=len(args){
344438
return"", driver.ErrSkip
345439
}

‎connection_test.go‎

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,24 +79,6 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T){
7979
}
8080
}
8181

82-
// We don't support placeholder in string literal for now.
83-
// https://github.com/go-sql-driver/mysql/pull/490
84-
funcTestInterpolateParamsPlaceholderInString(t*testing.T){
85-
mc:=&mysqlConn{
86-
buf: newBuffer(),
87-
maxAllowedPacket: maxPacketSize,
88-
cfg: &Config{
89-
InterpolateParams: true,
90-
},
91-
}
92-
93-
q, err:=mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
94-
// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
95-
iferr!=driver.ErrSkip{
96-
t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
97-
}
98-
}
99-
10082
funcTestInterpolateParamsUint64(t*testing.T){
10183
mc:=&mysqlConn{
10284
buf: newBuffer(),
@@ -204,3 +186,51 @@ func (bc badConnection) Write(b []byte) (n int, err error){
204186
func (bcbadConnection) Close() error{
205187
returnnil
206188
}
189+
190+
funcTestInterpolateParamsWithComments(t*testing.T){
191+
mc:=&mysqlConn{
192+
buf: newBuffer(),
193+
maxAllowedPacket: maxPacketSize,
194+
cfg: &Config{
195+
InterpolateParams: true,
196+
},
197+
}
198+
199+
tests:= []struct{
200+
querystring
201+
args []driver.Value
202+
expectedstring
203+
shouldSkipbool
204+
}{
205+
// ? in single-line comment (--) should not be replaced
206+
{"SELECT 1 -- ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 -- ?\n, 42", false},
207+
// ? in single-line comment (#) should not be replaced
208+
{"SELECT 1 # ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 # ?\n, 42", false},
209+
// ? in multi-line comment should not be replaced
210+
{"SELECT /* ? */ ?", []driver.Value{int64(42)}, "SELECT /* ? */ 42", false},
211+
// ? in string literal should not be replaced
212+
{"SELECT '?', ?", []driver.Value{int64(42)}, "SELECT '?', 42", false},
213+
// ? in backtick identifier should not be replaced
214+
{"SELECT `?`, ?", []driver.Value{int64(42)}, "SELECT `?`, 42", false},
215+
// Multiple comments and real placeholders
216+
{"SELECT ? -- comment ?\n, ? /* ? */ , ? # ?\n, ?", []driver.Value{int64(1), int64(2), int64(3)}, "SELECT 1 -- comment ?\n, 2 /* ? */ , 3 # ?\n, ?", true},
217+
}
218+
219+
fori, test:=rangetests{
220+
221+
q, err:=mc.interpolateParams(test.query, test.args)
222+
iftest.shouldSkip{
223+
iferr!=driver.ErrSkip{
224+
t.Errorf("Test %d: Expected driver.ErrSkip, got err=%#v, q=%#v", i, err, q)
225+
}
226+
continue
227+
}
228+
iferr!=nil{
229+
t.Errorf("Test %d: Expected err=nil, got %#v", i, err)
230+
continue
231+
}
232+
ifq!=test.expected{
233+
t.Errorf("Test %d: Expected: %q\nGot: %q", i, test.expected, q)
234+
}
235+
}
236+
}

0 commit comments

Comments
(0)