From d244f56e1a918c38775a23883b5d8ed60cc05e29 Mon Sep 17 00:00:00 2001 From: kj455 Date: Mon, 30 Dec 2024 09:53:49 +0900 Subject: [PATCH] feat: driver --- README.md | 5 + go.mod | 1 + go.sum | 2 + hack/playground/main.go | 80 +++++++++++ pkg/driver/driver.go | 191 +++++++++++++++++++++++++ pkg/metadata/stat_mgr.go | 2 +- pkg/parse/lexer.go | 3 +- pkg/plan/basic_query_planner.go | 16 +-- pkg/plan/planner.go | 4 - pkg/plan/planner_test.go | 87 +++++++++++ pkg/query/product_scan.go | 6 - pkg/query/select_scan.go | 11 +- pkg/record/record.go | 4 +- pkg/record/table_scan.go | 11 +- pkg/record/table_scan_test.go | 2 +- pkg/testutil/util.go | 9 ++ pkg/tx/transaction/transaction_test.go | 33 ++--- 17 files changed, 411 insertions(+), 56 deletions(-) create mode 100644 hack/playground/main.go create mode 100644 pkg/driver/driver.go create mode 100644 pkg/plan/planner_test.go diff --git a/README.md b/README.md index 780a490..ef00518 100644 --- a/README.md +++ b/README.md @@ -1 +1,6 @@ # simple-db-go + +## TODO + +- [ ] Iterator pattern +- [ ] Migration to `SetupDir` diff --git a/go.mod b/go.mod index d24188e..a005fd4 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.22.1 require ( github.com/stretchr/testify v1.9.0 go.uber.org/mock v0.4.0 + golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 ) require ( diff --git a/go.sum b/go.sum index 721144c..87a9c40 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 h1:1UoZQm6f0P/ZO0w1Ri+f+ifG/gXhegadRdwBIXEFWDo= +golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/hack/playground/main.go b/hack/playground/main.go new file mode 100644 index 0000000..9df2963 --- /dev/null +++ b/hack/playground/main.go @@ -0,0 +1,80 @@ +package main + +import ( + "database/sql" + "fmt" + "log" + + _ "github.com/kj455/simple-db/pkg/driver" + "golang.org/x/exp/rand" +) + +func RandomString(n int) string { + const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + b := make([]byte, n) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return string(b) +} + +func main() { + const ( + driverName = "simple" + ) + dataSourceName := RandomString(30) + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + log.Fatalln("Failed to open database:", err) + } + defer db.Close() + + query := "create table T1(A int, B varchar(9))" + if _, err := db.Exec(query); err != nil { + log.Fatalln("Failed to create table:", err) + } + + defer func() { + _, err := db.Exec("delete from T1") + if err != nil { + log.Fatalln("Failed to delete table:", err) + } + log.Println("Successfully deleted table.") + }() + + n := 200 + log.Print("Inserting", n, "random records.") + for i := 0; i < n; i++ { + a := i + b := "rec" + fmt.Sprint(a) + stmt := fmt.Sprintf("insert into T1(A,B) values(%d, '%s')", a, b) + _, err := db.Exec(stmt) + if err != nil { + log.Fatalln("Failed to execute insert:", err) + } + } + log.Println("Inserted", n, "records.") + + query = "select A, B from T1 where A = 100" + rows, err := db.Query(query) + if err != nil { + log.Fatalln("Failed to execute query:", err) + } + defer rows.Close() + + fields, err := rows.Columns() + if err != nil { + log.Fatalln("Failed to get columns:", err) + } + log.Println("Columns:", fields) + for rows.Next() { + var a int + var b string + err = rows.Scan(&a, &b) + if err != nil { + log.Fatalln("Failed to scan row:", err) + } + log.Printf("Matched: A=%d, B=%s\n", a, b) + } + log.Println("Done.") +} diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go new file mode 100644 index 0000000..5087f6e --- /dev/null +++ b/pkg/driver/driver.go @@ -0,0 +1,191 @@ +package driver + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + + "github.com/kj455/simple-db/pkg/buffer" + buffermgr "github.com/kj455/simple-db/pkg/buffer_mgr" + "github.com/kj455/simple-db/pkg/file" + "github.com/kj455/simple-db/pkg/log" + "github.com/kj455/simple-db/pkg/metadata" + "github.com/kj455/simple-db/pkg/plan" + "github.com/kj455/simple-db/pkg/query" + "github.com/kj455/simple-db/pkg/tx" + "github.com/kj455/simple-db/pkg/tx/transaction" +) + +func init() { + sql.Register("simple", NewSimpleDriver()) +} + +type SimpleDriver struct { +} + +func NewSimpleDriver() *SimpleDriver { + return &SimpleDriver{} +} + +func (d *SimpleDriver) Open(name string) (driver.Conn, error) { + return NewConn(name) +} + +type Conn struct { + fileMgr file.FileMgr + bufMgr buffermgr.BufferMgr + logMgr log.LogMgr + tx tx.Transaction + mdMgr metadata.MetadataMgr + planner *plan.Planner +} + +const dir = "./.tmp" + +func NewConn(name string) (*Conn, error) { + const ( + buffNum = 8 + blockSize = 4096 + logFileName = "simple-db-conn-log" + ) + fileMgr := file.NewFileMgr(dir, blockSize) + logMgr, err := log.NewLogMgr(fileMgr, logFileName) + if err != nil { + return nil, fmt.Errorf("driver: failed to create log manager: %v", err) + } + buffs := make([]buffer.Buffer, buffNum) + for i := 0; i < buffNum; i++ { + buffs[i] = buffer.NewBuffer(fileMgr, logMgr, blockSize) + } + bm := buffermgr.NewBufferMgr(buffs) + txNumGen := transaction.NewTxNumberGenerator() + tx, err := transaction.NewTransaction(fileMgr, logMgr, bm, txNumGen) + if err != nil { + return nil, fmt.Errorf("driver: failed to create transaction: %v", err) + } + isNew := fileMgr.IsNew() + if !isNew { + if err := tx.Recover(); err != nil { + return nil, fmt.Errorf("driver: failed to recover transaction: %v", err) + } + } + mdMgr, err := metadata.NewMetadataMgr(tx) + if err != nil { + return nil, fmt.Errorf("driver: failed to create metadata manager: %v", err) + } + qp := plan.NewBasicQueryPlanner(mdMgr) + up := plan.NewBasicUpdatePlanner(mdMgr) + planner := plan.NewPlanner(qp, up) + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("driver: failed to commit transaction: %v", err) + } + return &Conn{ + fileMgr: fileMgr, + bufMgr: bm, + logMgr: logMgr, + tx: tx, + mdMgr: mdMgr, + planner: planner, + }, nil +} + +func (c *Conn) Begin() (driver.Tx, error) { + return nil, errors.New("driver: not implemented") +} + +func (c *Conn) Close() error { + return nil +} + +func (c *Conn) Prepare(query string) (driver.Stmt, error) { + return NewSimpleStmt(query, c), nil +} + +type Stmt struct { + conn *Conn + query string +} + +func NewSimpleStmt(query string, conn *Conn) *Stmt { + return &Stmt{ + query: query, + conn: conn, + } +} + +func (s *Stmt) Close() error { + return s.conn.tx.Commit() +} + +func (s *Stmt) NumInput() int { + return -1 +} + +func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) { + n, err := s.conn.planner.ExecuteUpdate(s.query, s.conn.tx) + if err != nil { + return nil, err + } + return Result{n: n}, nil +} + +type Result struct { + n int +} + +func (r Result) LastInsertId() (int64, error) { + return 0, errors.New("driver: not implemented") +} + +func (r Result) RowsAffected() (int64, error) { + return int64(r.n), nil +} + +func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { + p, err := s.conn.planner.CreateQueryPlan(s.query, s.conn.tx) + if err != nil { + return nil, err + } + scan, err := p.Open() + if err != nil { + return nil, fmt.Errorf("driver: failed to open plan: %v", err) + } + fields := p.Schema().Fields() + return NewSimpleRows(scan, fields), nil +} + +type Rows struct { + scan query.Scan + fields []string +} + +func NewSimpleRows(scan query.Scan, fields []string) *Rows { + return &Rows{ + scan: scan, + fields: fields, + } +} + +func (r *Rows) Columns() []string { + return r.fields +} + +func (r *Rows) Close() error { + r.scan.Close() + return nil +} + +func (r *Rows) Next(dest []driver.Value) error { + if !r.scan.Next() { + return driver.ErrSkip + } + for i, field := range r.fields { + val, err := r.scan.GetVal(field) + if err != nil { + return fmt.Errorf("driver: failed to get value: %v", err) + } + dest[i] = val.AnyValue() + } + return nil +} diff --git a/pkg/metadata/stat_mgr.go b/pkg/metadata/stat_mgr.go index 9a3f92e..248a6b8 100644 --- a/pkg/metadata/stat_mgr.go +++ b/pkg/metadata/stat_mgr.go @@ -99,7 +99,7 @@ func (sm *StatMgrImpl) calcTableStats(tableName string, layout record.Layout, tx var numRecs, numBlocks int for ts.Next() { numRecs++ - numBlocks = ts.GetRid().BlockNumber() + 1 + numBlocks = ts.GetRID().BlockNumber() + 1 } return NewStatInfo(numBlocks, numRecs), nil } diff --git a/pkg/parse/lexer.go b/pkg/parse/lexer.go index 361b3fb..4682929 100644 --- a/pkg/parse/lexer.go +++ b/pkg/parse/lexer.go @@ -125,8 +125,7 @@ func (l *Lexer) matchIntConstant() bool { // matchStringConstant returns true if the current token is a string. func (l *Lexer) MatchStringConstant() bool { - // return l.ttype == 'S' // Assuming 'S' represents a string - return rune(l.strVal[0]) == '\'' + return l.typ == TokenString } // matchKeyword returns true if the current token is the specified keyword. diff --git a/pkg/plan/basic_query_planner.go b/pkg/plan/basic_query_planner.go index 44ddf8c..0984e39 100644 --- a/pkg/plan/basic_query_planner.go +++ b/pkg/plan/basic_query_planner.go @@ -24,13 +24,13 @@ func (bp *BasicQueryPlanner) CreatePlan(data *parse.QueryData, tx tx.Transaction plans := make([]Plan, 0, len(data.Tables)) for _, table := range data.Tables { viewDef, err := bp.mdMgr.GetViewDef(table, tx) - if err != nil && errors.Is(err, metadata.ErrViewNotFound) { - return nil, fmt.Errorf("planner: failed to get view definition for %s: %v", table, err) + if err != nil && !errors.Is(err, metadata.ErrViewNotFound) { + return nil, fmt.Errorf("plan: failed to get view definition for %s: %v", table, err) } if isTable := errors.Is(err, metadata.ErrViewNotFound); isTable { plan, err := NewTablePlan(tx, table, bp.mdMgr) if err != nil { - return nil, fmt.Errorf("planner: failed to create table plan for %s: %v", table, err) + return nil, fmt.Errorf("plan: failed to create table plan for %s: %v", table, err) } plans = append(plans, plan) continue @@ -38,17 +38,17 @@ func (bp *BasicQueryPlanner) CreatePlan(data *parse.QueryData, tx tx.Transaction parser := parse.NewParser(viewDef) viewData, err := parser.Query() if err != nil { - return nil, fmt.Errorf("planner: failed to parse view definition for %s: %v", table, err) + return nil, fmt.Errorf("plan: failed to parse view definition for %s: %v", table, err) } plan, err := bp.CreatePlan(viewData, tx) if err != nil { - return nil, fmt.Errorf("planner: failed to create view plan for %s: %v", table, err) + return nil, fmt.Errorf("plan: failed to create view plan for %s: %v", table, err) } plans = append(plans, plan) } if len(plans) == 0 { - return nil, errors.New("planner: no tables or views in query") + return nil, errors.New("plan: no tables or views in query") } plan := plans[0] @@ -56,13 +56,13 @@ func (bp *BasicQueryPlanner) CreatePlan(data *parse.QueryData, tx tx.Transaction for i := 1; i < len(plans); i++ { plan, err = NewProductPlan(plan, plans[i]) if err != nil { - return nil, fmt.Errorf("planner: failed to create product plan: %v", err) + return nil, fmt.Errorf("plan: failed to create product plan: %v", err) } } plan = NewSelectPlan(plan, data.Pred) plan, err = NewProjectPlan(plan, data.Fields) if err != nil { - return nil, fmt.Errorf("planner: failed to create project plan: %v", err) + return nil, fmt.Errorf("plan: failed to create project plan: %v", err) } return plan, nil diff --git a/pkg/plan/planner.go b/pkg/plan/planner.go index 0971b2e..1d5927a 100644 --- a/pkg/plan/planner.go +++ b/pkg/plan/planner.go @@ -7,10 +7,6 @@ import ( "github.com/kj455/simple-db/pkg/tx" ) -type QueryPlanner interface { - CreatePlan(data parse.QueryData, tx tx.Transaction) (Plan, error) -} - type Planner struct { queryPlanner *BasicQueryPlanner updatePlanner *BasicUpdatePlanner diff --git a/pkg/plan/planner_test.go b/pkg/plan/planner_test.go new file mode 100644 index 0000000..1b8f243 --- /dev/null +++ b/pkg/plan/planner_test.go @@ -0,0 +1,87 @@ +package plan + +import ( + "fmt" + "math" + "math/rand" + "testing" + + "github.com/kj455/simple-db/pkg/buffer" + buffermgr "github.com/kj455/simple-db/pkg/buffer_mgr" + "github.com/kj455/simple-db/pkg/file" + "github.com/kj455/simple-db/pkg/log" + "github.com/kj455/simple-db/pkg/metadata" + "github.com/kj455/simple-db/pkg/testutil" + "github.com/kj455/simple-db/pkg/tx/transaction" + "github.com/stretchr/testify/require" +) + +func TestPlanner(t *testing.T) { + const ( + dirname = "studentdb" + logFileName = "logfile" + blockSize = 400 + ) + dir, cleanup := testutil.SetupDir(dirname) + t.Cleanup(cleanup) + fm := file.NewFileMgr(dir, blockSize) + lm, err := log.NewLogMgr(fm, logFileName) + require.NoError(t, err) + const buffNum = 8 + buffs := make([]buffer.Buffer, buffNum) + for i := 0; i < buffNum; i++ { + buffs[i] = buffer.NewBuffer(fm, lm, blockSize) + } + bm := buffermgr.NewBufferMgr(buffs) + txNumGen := transaction.NewTxNumberGenerator() + tx, err := transaction.NewTransaction(fm, lm, bm, txNumGen) + require.NoError(t, err) + mdm, err := metadata.NewMetadataMgr(tx) + require.NoError(t, err) + qp := NewBasicQueryPlanner(mdm) + up := NewBasicUpdatePlanner(mdm) + planner := NewPlanner(qp, up) + + cmd := "create table student(sname varchar(10), gradyear int, majorid int, studentid int)" + _, err = planner.ExecuteUpdate(cmd, tx) + require.NoError(t, err) + tx.Commit() + + const recordNum = 100 + for i := 0; i < recordNum; i++ { + name := fmt.Sprintf("student%d", i) + gradYear := int(math.Round(rand.Float64() * 50)) + cmd = fmt.Sprintf("insert into student(sname, gradyear, majorid, studentid) values('%s', %d, %d, %d)", name, gradYear, i, i) + _, err = planner.ExecuteUpdate(cmd, tx) + require.NoError(t, err) + } + + qry := "select sname, gradyear from student" + p, err := planner.CreateQueryPlan(qry, tx) + require.NoError(t, err) + s, err := p.Open() + require.NoError(t, err) + + for s.Next() { + sname, err := s.GetString("sname") + require.NoError(t, err) + gradyear, err := s.GetInt("gradyear") + require.NoError(t, err) + t.Logf("sname=%s, gradyear=%d\n", sname, gradyear) + } + s.Close() + + cmd = "delete from STUDENT where majorid = 30" + num, err := planner.ExecuteUpdate(cmd, tx) + require.NoError(t, err) + require.Equal(t, 1, num) + + cmd = "select sname from student where majorid = 30" + p, err = planner.CreateQueryPlan(cmd, tx) + require.NoError(t, err) + s, err = p.Open() + require.NoError(t, err) + require.False(t, s.Next()) + + tx.Commit() +} diff --git a/pkg/query/product_scan.go b/pkg/query/product_scan.go index 87b4a5a..583dbf4 100644 --- a/pkg/query/product_scan.go +++ b/pkg/query/product_scan.go @@ -1,8 +1,6 @@ package query import ( - "fmt" - "github.com/kj455/simple-db/pkg/constant" ) @@ -46,12 +44,8 @@ func (p *ProductScan) Next() bool { // GetInt returns the integer value of the specified field. The value is obtained from whichever scan contains the field. func (p *ProductScan) GetInt(field string) (int, error) { if p.s1.HasField(field) { - v, _ := p.s1.GetInt(field) - fmt.Println("get from s1:", v) return p.s1.GetInt(field) } - v, _ := p.s2.GetInt(field) - fmt.Println("get from s2:", v) return p.s2.GetInt(field) } diff --git a/pkg/query/select_scan.go b/pkg/query/select_scan.go index 71e24d8..7c4150f 100644 --- a/pkg/query/select_scan.go +++ b/pkg/query/select_scan.go @@ -95,15 +95,12 @@ func (s *SelectScan) Insert() error { return us.Insert() } -func (s *SelectScan) GetRid() (record.RID, error) { - us, ok := s.scan.(UpdatableScan) - if !ok { - return nil, fmt.Errorf("query: scan is not an UpdateScan") - } - return us.GetRID(), nil +func (s *SelectScan) GetRID() record.RID { + us := s.scan.(UpdatableScan) + return us.GetRID() } -func (s *SelectScan) MoveToRid(rid record.RID) error { +func (s *SelectScan) MoveToRID(rid record.RID) error { us, ok := s.scan.(UpdatableScan) if !ok { return fmt.Errorf("query: scan is not an UpdateScan") diff --git a/pkg/record/record.go b/pkg/record/record.go index cd7b748..117dc06 100644 --- a/pkg/record/record.go +++ b/pkg/record/record.go @@ -64,6 +64,6 @@ type TableScan interface { Close() Insert() error Delete() error - MoveToRid(rid RID) - GetRid() RID + MoveToRID(rid RID) error + GetRID() RID } diff --git a/pkg/record/table_scan.go b/pkg/record/table_scan.go index bbfccbd..e04aedf 100644 --- a/pkg/record/table_scan.go +++ b/pkg/record/table_scan.go @@ -51,7 +51,6 @@ func (ts *TableScanImpl) Next() bool { } err := ts.moveToBlock(ts.recordPage.Block().Number() + 1) if err != nil { - fmt.Println("record: table scan: next: ", err) return false } ts.curSlot = ts.recordPage.NextAfter(ts.curSlot) @@ -157,14 +156,18 @@ func (ts *TableScanImpl) Delete() error { return ts.recordPage.Delete(ts.curSlot) } -func (ts *TableScanImpl) MoveToRid(rid RID) { +func (ts *TableScanImpl) MoveToRID(rid RID) (err error) { ts.Close() blk := file.NewBlockId(ts.filename, rid.BlockNumber()) - ts.recordPage, _ = NewRecordPage(ts.tx, blk, ts.layout) + ts.recordPage, err = NewRecordPage(ts.tx, blk, ts.layout) + if err != nil { + return fmt.Errorf("record: failed to move to rid: %w", err) + } ts.curSlot = rid.Slot() + return nil } -func (ts *TableScanImpl) GetRid() RID { +func (ts *TableScanImpl) GetRID() RID { return NewRID(ts.recordPage.Block().Number(), ts.curSlot) } diff --git a/pkg/record/table_scan_test.go b/pkg/record/table_scan_test.go index 1624a21..24e4d0a 100644 --- a/pkg/record/table_scan_test.go +++ b/pkg/record/table_scan_test.go @@ -64,7 +64,7 @@ func TestTableScan(t *testing.T) { // Scan the records rid := NewRID(0, -1) - scan.MoveToRid(rid) + scan.MoveToRID(rid) count := 0 for scan.Next() { a, err := scan.GetInt("A") diff --git a/pkg/testutil/util.go b/pkg/testutil/util.go index ef1b14c..5957550 100644 --- a/pkg/testutil/util.go +++ b/pkg/testutil/util.go @@ -7,6 +7,15 @@ import ( const testDir = ".tmp" +func SetupDir(dirname string) (dir string, cleanup func()) { + path := filepath.Join(RootDir(), testDir, dirname) + _ = os.MkdirAll(path, os.ModePerm) + cleanup = func() { + _ = os.RemoveAll(path) + } + return path, cleanup +} + // SetupFile creates a file in the test directory and returns the directory, file, and cleanup function. func SetupFile(filename string) (dir string, f *os.File, cleanup func()) { path := filepath.Join(RootDir(), testDir, filename) diff --git a/pkg/tx/transaction/transaction_test.go b/pkg/tx/transaction/transaction_test.go index 31b2f80..14c2931 100644 --- a/pkg/tx/transaction/transaction_test.go +++ b/pkg/tx/transaction/transaction_test.go @@ -16,17 +16,11 @@ import ( func TestTransaction(t *testing.T) { t.Parallel() - const ( - filename = "test_transaction" - logFilename = "test_transaction_log" - blockSize = 400 - ) - dir, _, cleanup := testutil.SetupFile(filename) + const blockSize = 400 + dir, cleanup := testutil.SetupDir("test_transaction") t.Cleanup(cleanup) - _, _, cleanupLog := testutil.SetupFile(logFilename) - defer cleanupLog() fileMgr := file.NewFileMgr(dir, blockSize) - logMgr, err := log.NewLogMgr(fileMgr, logFilename) + logMgr, err := log.NewLogMgr(fileMgr, "test_transaction_log") assert.NoError(t, err) const buffNum = 2 buffs := make([]buffer.Buffer, buffNum) @@ -41,7 +35,7 @@ func TestTransaction(t *testing.T) { assert.NoError(t, err) assert.Equal(t, buffNum, tx1.AvailableBuffs()) - block := file.NewBlockId(filename, 0) + block := file.NewBlockId("test_transaction", 0) tx1.Pin(block) tx1.SetInt(block, 80, 1, false) tx1.SetString(block, 40, "one", false) @@ -84,14 +78,12 @@ func TestTransaction(t *testing.T) { func TestTransaction_Concurrency(t *testing.T) { t.Parallel() const ( - blockSize = 400 - testFileName = "test_transaction_concurrency" - testLogFileName = "test_transaction_concurrency_log" + blockSize = 400 + dirname = "test_transaction_concurrency" + testFileName = "concurrency" ) - dir, _, cleanup := testutil.SetupFile(testFileName) + dir, cleanup := testutil.SetupDir("test_transaction_concurrency") t.Cleanup(cleanup) - _, _, cleanupLog := testutil.SetupFile(testLogFileName) - defer cleanupLog() fm := file.NewFileMgr(dir, blockSize) lm, _ := log.NewLogMgr(fm, testFileName) buffs := make([]buffer.Buffer, 2) @@ -184,13 +176,12 @@ func TestTransaction_Size(t *testing.T) { t.Parallel() const ( blockSize = 400 - fileName = "test_transaction_size" - logFileName = "test_transaction_size_log" + dirname = "test_transaction_size" + fileName = "file" + logFileName = "log" ) - dir, _, cleanup := testutil.SetupFile(fileName) + dir, cleanup := testutil.SetupDir(dirname) t.Cleanup(cleanup) - _, _, cleanupLog := testutil.SetupFile(logFileName) - defer cleanupLog() fileMgr := file.NewFileMgr(dir, blockSize) logMgr, err := log.NewLogMgr(fileMgr, logFileName) assert.NoError(t, err)