diff --git a/cgosqlite/cgosqlite.go b/cgosqlite/cgosqlite.go index 9e530e0..fbb8b8e 100644 --- a/cgosqlite/cgosqlite.go +++ b/cgosqlite/cgosqlite.go @@ -10,7 +10,6 @@ package cgosqlite #cgo CFLAGS: -DSQLITE_THREADSAFE=2 #cgo CFLAGS: -DSQLITE_DQS=0 -#cgo CFLAGS: -DSQLITE_DEFAULT_MEMSTATUS=0 #cgo CFLAGS: -DSQLITE_DEFAULT_WAL_SYNCHRONOUS=1 #cgo CFLAGS: -DSQLITE_LIKE_DOESNT_MATCH_BLOBS #cgo CFLAGS: -DSQLITE_MAX_EXPR_DEPTH=0 @@ -143,6 +142,10 @@ func (db *DB) TotalChanges() int { return int(C.sqlite3_total_changes(db.db)) } +func (db *DB) MemoryUsed() int64 { + return int64(C.sqlite3_memory_used()) +} + func (db *DB) ExtendedErrCode() sqliteh.Code { return sqliteh.Code(C.sqlite3_extended_errcode(db.db)) } diff --git a/go.mod b/go.mod index f3ecd7d..bda515b 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/tailscale/sqlite -go 1.21 +go 1.25 diff --git a/sqlite_test.go b/sqlite_test.go index a9f85d5..6a4b0f3 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -1540,15 +1540,31 @@ func TestConnLogger_read_tx(t *testing.T) { } } -func TestExpandedSQL(t *testing.T) { +func getTestConn(t testing.TB) *conn { ctx := context.Background() connector := Connector("file:"+t.TempDir()+"/test.db", nil, nil) sqlConn, err := connector.Connect(ctx) if err != nil { t.Fatalf("Connect: %v", err) } - conn := sqlConn.(*conn) + t.Cleanup(func() { sqlConn.Close() }) + return sqlConn.(*conn) +} + +// Verify we didn't build with -DSQLITE_DEFAULT_MEMSTATUS=0. +// We want memory stats. +func TestDBMemoryUsed(t *testing.T) { + conn := getTestConn(t) + mem0 := conn.db.MemoryUsed() + if mem0 == 0 { + t.Error("MemoryUsed=0; want non-zero; did you build with -DSQLITE_DEFAULT_MEMSTATUS=0?") + } +} +func TestExpandedSQL(t *testing.T) { + conn := getTestConn(t) + + ctx := t.Context() sqlStmt, err := conn.PrepareContext(ctx, "SELECT ? + ?") if err != nil { t.Fatalf("PrepareContext: %v", err) @@ -1567,4 +1583,14 @@ func TestExpandedSQL(t *testing.T) { if got, want := stmt.stmt.ExpandedSQL(), "SELECT 6 + 7"; got != want { t.Errorf("wrong sql: got %q, want %q", got, want) } + mem0 := conn.db.MemoryUsed() + for range 100 { + if got, want := stmt.stmt.ExpandedSQL(), "SELECT 6 + 7"; got != want { + t.Errorf("wrong sql: got %q, want %q", got, want) + } + } + mem1 := conn.db.MemoryUsed() + if mem1 > mem0 { + t.Errorf("memory leak detected: before=%v after=%v", mem0, mem1) + } } diff --git a/sqliteh/sqliteh.go b/sqliteh/sqliteh.go index 69b6a79..f4d5ed3 100644 --- a/sqliteh/sqliteh.go +++ b/sqliteh/sqliteh.go @@ -37,6 +37,9 @@ type DB interface { // TotalChanges is sqlite3_total_changes. // https://sqlite.org/c3ref/total_changes.html TotalChanges() int + // MemoryUsed is sqlite3_memory_used. + // https://sqlite.org/c3ref/memory_highwater.html + MemoryUsed() int64 // ExtendedErrCode is sqlite3_extended_errcode. // https://sqlite.org/c3ref/errcode.html ExtendedErrCode() Code