Skip to content

Commit 8379d94

Browse files
committed
dm: fix data race in cached WhereHandle
1 parent e613d3b commit 8379d94

4 files changed

Lines changed: 71 additions & 2 deletions

File tree

pkg/sqlmodel/causality.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ func (r *RowChange) getForeignKeyCausalityString(values []interface{}) []string
201201
}
202202

203203
func (r *RowChange) getCausalityString(values []interface{}) []string {
204-
pkAndUks := r.whereHandle.UniqueIdxs
204+
pkAndUks := r.whereHandle.getUniqueIdxs()
205205
if len(pkAndUks) == 0 {
206206
// the table has no PK/UK, all values of the row consists the causality key
207207
return []string{genKeyString(r.sourceTable.String(), r.sourceTableInfo.Columns, values)}

pkg/sqlmodel/reduce.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ func (r *RowChange) IsPrimaryOrUniqueKeyUpdated() bool {
130130
}
131131
}
132132

133-
for _, idx := range r.whereHandle.UniqueIdxs {
133+
for _, idx := range r.whereHandle.getUniqueIdxs() {
134134
if idx == nil || idx == r.whereHandle.UniqueNotNullIdx {
135135
continue
136136
}

pkg/sqlmodel/where_handle.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
package sqlmodel
1515

1616
import (
17+
"sync"
18+
1719
"github.com/pingcap/log"
1820
"github.com/pingcap/tidb/pkg/meta/model"
1921
pmodel "github.com/pingcap/tidb/pkg/parser/ast"
@@ -23,6 +25,8 @@ import (
2325

2426
// WhereHandle is used to generate a WHERE clause in SQL.
2527
type WhereHandle struct {
28+
mu sync.RWMutex
29+
2630
UniqueNotNullIdx *model.IndexInfo
2731
// If the index and columns have no NOT NULL constraint, but all data is NOT
2832
// NULL, we can still use it.
@@ -138,6 +142,9 @@ func (h *WhereHandle) getWhereIdxByData(data []interface{}) *model.IndexInfo {
138142
if h.UniqueNotNullIdx != nil {
139143
return h.UniqueNotNullIdx
140144
}
145+
146+
h.mu.Lock()
147+
defer h.mu.Unlock()
141148
for i, idx := range h.UniqueIdxs {
142149
ok := true
143150
for _, idxCol := range idx.Columns {
@@ -153,3 +160,13 @@ func (h *WhereHandle) getWhereIdxByData(data []interface{}) *model.IndexInfo {
153160
}
154161
return nil
155162
}
163+
164+
func (h *WhereHandle) getUniqueIdxs() []*model.IndexInfo {
165+
if h.UniqueNotNullIdx != nil {
166+
return h.UniqueIdxs
167+
}
168+
169+
h.mu.RLock()
170+
defer h.mu.RUnlock()
171+
return append([]*model.IndexInfo(nil), h.UniqueIdxs...)
172+
}

pkg/sqlmodel/where_handle_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
package sqlmodel
1515

1616
import (
17+
"fmt"
18+
"sync"
1719
"testing"
1820

1921
"github.com/pingcap/tidb/pkg/ddl"
@@ -215,3 +217,53 @@ CREATE TABLE t (
215217
idx = handle.getWhereIdxByData([]interface{}{1, nil, 3, nil})
216218
require.Nil(t, idx)
217219
}
220+
221+
func TestGetWhereIdxByDataNoRace(t *testing.T) {
222+
t.Parallel()
223+
224+
createSQL := `
225+
CREATE TABLE t (
226+
c1 INT,
227+
c2 INT,
228+
UNIQUE INDEX idx1 (c1),
229+
UNIQUE INDEX idx2 (c2)
230+
)`
231+
p := parser.New()
232+
node, err := p.ParseOneStmt(createSQL, "", "")
233+
require.NoError(t, err)
234+
ti, err := ddl.BuildTableInfoFromAST(metabuild.NewContext(), node.(*ast.CreateTableStmt))
235+
require.NoError(t, err)
236+
237+
handle := GetWhereHandle(ti, ti)
238+
checkIndex := func(data []interface{}, expected string) error {
239+
idx := handle.getWhereIdxByData(data)
240+
if idx == nil {
241+
return fmt.Errorf("expected %s, got nil", expected)
242+
}
243+
if idx.Name.L != expected {
244+
return fmt.Errorf("expected %s, got %s", expected, idx.Name.L)
245+
}
246+
return nil
247+
}
248+
249+
const concurrency = 100
250+
var wg sync.WaitGroup
251+
errCh := make(chan error, concurrency*2)
252+
for i := 0; i < concurrency; i++ {
253+
wg.Add(1)
254+
go func() {
255+
defer wg.Done()
256+
if err := checkIndex([]interface{}{1, nil}, "idx1"); err != nil {
257+
errCh <- err
258+
}
259+
if err := checkIndex([]interface{}{nil, 2}, "idx2"); err != nil {
260+
errCh <- err
261+
}
262+
}()
263+
}
264+
wg.Wait()
265+
close(errCh)
266+
for err := range errCh {
267+
require.NoError(t, err)
268+
}
269+
}

0 commit comments

Comments
 (0)