diff --git a/Makefile b/Makefile index c0986be..c04c6cc 100644 --- a/Makefile +++ b/Makefile @@ -3,9 +3,9 @@ PKG := $(shell go list ./... | grep -v '/mock$$') print-pkg: @echo $(PKG) test: - go test -v $(PKG) + gotestsum --format testname $(PKG) testw: - gow test -timeout 5s $(PKG) + gotestsum --watch --format testname $(PKG) lint: golangci-lint run clean: diff --git a/pkg/constant/constant.go b/pkg/constant/constant.go new file mode 100644 index 0000000..2dc8225 --- /dev/null +++ b/pkg/constant/constant.go @@ -0,0 +1,104 @@ +package constant + +import "fmt" + +type Kind string + +const ( + KIND_INT Kind = "int" + KIND_STR Kind = "string" +) + +// Const denotes values stored in the database. +type Const struct { + val any + kind Kind +} + +func NewConstant(kind Kind, val any) (*Const, error) { + switch kind { + case KIND_INT: + if _, ok := val.(int); !ok { + return nil, fmt.Errorf("constant: value is not an integer") + } + case KIND_STR: + if _, ok := val.(string); !ok { + return nil, fmt.Errorf("constant: value is not a string") + } + default: + } + return &Const{ + val: val, + kind: kind, + }, nil +} + +// AsInt returns the integer value of the constant. +func (c *Const) AsInt() int { + if c.kind == KIND_INT { + return c.val.(int) + } + return 0 // or panic/error if you want to handle it strictly +} + +// AsString returns the string value of the constant. +func (c *Const) AsString() string { + if c.kind == KIND_STR { + return c.val.(string) + } + return "" // or panic/error if you want to handle it strictly +} + +// Equals checks if two constants are equal. +func (c *Const) Equals(other *Const) bool { + if c.kind != other.kind { + return false + } + return c.val == other.val +} + +// CompareTo returns 0 if two constants are equal, -1 if the receiver is less than the other, and 1 if the receiver is greater than the other. +func (c *Const) CompareTo(other *Const) int { + if c.kind != other.kind { + return 0 // or panic/error if you want to handle it strictly + } + if c.val == other.val { + return 0 + } + // TODO: Implement comparison for other types + if c.val.(int) < other.val.(int) { + return -1 + } + return 1 +} + +// HashCode returns the hash code of the constant. +func (c *Const) HashCode() int { + // TODO: Implement more valid hash code + switch c.kind { + case KIND_INT: + return c.AsInt() + case KIND_STR: + return len(c.AsString()) + } + return 0 +} + +// ToString returns the string representation of the constant. +func (c *Const) ToString() string { + switch c.kind { + case KIND_INT: + return fmt.Sprint(c.AsInt()) + case KIND_STR: + return c.AsString() + } + return "" +} + +func (c *Const) Kind() Kind { + return c.kind +} + +func (c *Const) AnyValue() any { + return c.val +} diff --git a/pkg/constant/constant_test.go b/pkg/constant/constant_test.go new file mode 100644 index 0000000..552e9fb --- /dev/null +++ b/pkg/constant/constant_test.go @@ -0,0 +1,37 @@ +package constant + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConstant(t *testing.T) { + t.Run("NewConstant", func(t *testing.T) { + c, _ := NewConstant(KIND_INT, 42) + assert.Equal(t, 42, c.AsInt()) + c, _ = NewConstant(KIND_STR, "hello") + assert.Equal(t, "hello", c.AsString()) + _, err := NewConstant(KIND_INT, "hello") + assert.Error(t, err) + }) + t.Run("Equals", func(t *testing.T) { + c1, _ := NewConstant(KIND_INT, 42) + c2, _ := NewConstant(KIND_INT, 42) + c3, _ := NewConstant(KIND_INT, 43) + c4, _ := NewConstant(KIND_STR, "hello") + assert.True(t, c1.Equals(c2)) + assert.False(t, c1.Equals(c3)) + assert.False(t, c1.Equals(c4)) + }) + t.Run("CompareTo", func(t *testing.T) { + c1, _ := NewConstant(KIND_INT, 42) + c2, _ := NewConstant(KIND_INT, 42) + c3, _ := NewConstant(KIND_INT, 43) + c4, _ := NewConstant(KIND_STR, "hello") + assert.Equal(t, 0, c1.CompareTo(c2)) + assert.Equal(t, -1, c1.CompareTo(c3)) + assert.Equal(t, 1, c3.CompareTo(c1)) + assert.Equal(t, 0, c4.CompareTo(c4)) + }) +} diff --git a/pkg/query/expression.go b/pkg/query/expression.go new file mode 100644 index 0000000..0ac7184 --- /dev/null +++ b/pkg/query/expression.go @@ -0,0 +1,61 @@ +package query + +import ( + "github.com/kj455/db/pkg/constant" + "github.com/kj455/db/pkg/record" +) + +// Expression corresponds to SQL expressions. +type ExpressionImpl struct { + val *constant.Const + field string +} + +// NewConstantExpression creates a new expression with a constant value. +func NewConstantExpression(val *constant.Const) Expression { + return &ExpressionImpl{val: val} +} + +// NewFieldExpression creates a new expression with a field name. +func NewFieldExpression(field string) Expression { + return &ExpressionImpl{field: field} +} + +// Evaluate evaluates the expression with respect to the current constant of the specified scan. +func (e *ExpressionImpl) Evaluate(s Scan) (*constant.Const, error) { + if e.val != nil { + return e.val, nil + } + return s.GetVal(e.field) +} + +// IsFieldName returns true if the expression is a field reference. +func (e *ExpressionImpl) IsFieldName() bool { + return e.field != "" +} + +// AsConstant returns the constant corresponding to a constant expression, or nil if the expression does not denote a constant. +func (e *ExpressionImpl) AsConstant() *constant.Const { + return e.val +} + +// AsFieldName returns the field name corresponding to a constant expression, or an empty string if the expression does not denote a field. +func (e *ExpressionImpl) AsFieldName() string { + return e.field +} + +// AppliesTo determines if all of the fields mentioned in this expression are contained in the specified schema. +func (e *ExpressionImpl) AppliesTo(sch record.Schema) bool { + if e.val != nil { + return true + } + return sch.HasField(e.field) +} + +// ToString returns the string representation of the expression. +func (e *ExpressionImpl) ToString() string { + if e.val != nil { + return e.val.ToString() + } + return e.field +} diff --git a/pkg/query/interface.go b/pkg/query/interface.go new file mode 100644 index 0000000..26097b2 --- /dev/null +++ b/pkg/query/interface.go @@ -0,0 +1,62 @@ +package query + +import ( + "github.com/kj455/db/pkg/constant" + "github.com/kj455/db/pkg/record" +) + +type Scan interface { + // BeforeFirst positions the scan before its first record. + // A subsequent call to Next will return the first record. + BeforeFirst() error + + // Next moves the scan to the next record. + // Returns false if there is no next record. + Next() bool + + // GetInt returns the value of the specified integer field in the current record. + GetInt(field string) (int, error) + + // GetString returns the value of the specified string field in the current record. + GetString(field string) (string, error) + + // GetVal returns the value of the specified field in the current record, expressed as a Constant. + GetVal(field string) (*constant.Const, error) + + // HasField checks if the scan has the specified field. + // The field parameter represents the name of the field. + // Returns true if the scan has that field. + HasField(field string) bool + + // Close closes the scan and its subscans, if any. + Close() +} + +type UpdateScan interface { + Scan + SetInt(field string, val int) error + SetString(field string, val string) error + SetVal(field string, val *constant.Const) error + Insert() error + Delete() error + + GetRID() record.RID + MoveToRID(rid record.RID) error +} + +type Predicate interface { + IsSatisfied(s Scan) (bool, error) +} + +type Expression interface { + Evaluate(s Scan) (*constant.Const, error) + IsFieldName() bool + AsConstant() *constant.Const + AsFieldName() string + AppliesTo(sch record.Schema) bool + ToString() string +} + +type PlanInfo interface { + DistinctValues(field string) int +} diff --git a/pkg/query/predicate.go b/pkg/query/predicate.go new file mode 100644 index 0000000..e896f2b --- /dev/null +++ b/pkg/query/predicate.go @@ -0,0 +1,121 @@ +package query + +import ( + "fmt" + "strings" + + "github.com/kj455/db/pkg/constant" + "github.com/kj455/db/pkg/record" +) + +// PredicateImpl is a Boolean combination of terms. +type PredicateImpl struct { + terms []*Term +} + +const MAX_TERMS = 10 + +// NewPredicate creates an empty predicate, corresponding to "true". +func NewPredicate(term *Term) *PredicateImpl { + pr := &PredicateImpl{ + terms: make([]*Term, 0, MAX_TERMS), + } + if term != nil { + // pr.terms[0] = term + pr.terms = append(pr.terms, term) + } + return pr +} + +// NewPredicateWithTerm creates a predicate containing a single term. +func NewPredicateWithTerm(t *Term) *PredicateImpl { + return &PredicateImpl{terms: []*Term{t}} +} + +// ConjoinWith modifies the predicate to be the conjunction of itself and the specified predicate. +func (p *PredicateImpl) ConjoinWith(pred *PredicateImpl) { + p.terms = append(p.terms, pred.terms...) +} + +// IsSatisfied returns true if the predicate evaluates to true with respect to the specified scan. +func (p *PredicateImpl) IsSatisfied(s Scan) (bool, error) { + for _, t := range p.terms { + if ok, err := t.IsSatisfied(s); !ok || err != nil { + return false, err + } + } + return true, nil +} + +// ReductionFactor calculates the extent to which selecting on the predicate reduces the number of records output by a query. +func (p *PredicateImpl) ReductionFactor(plan PlanInfo) int { + factor := 1 + for _, t := range p.terms { + factor *= t.ReductionFactor(plan) + } + return factor +} + +// SelectSubPred returns the subpredicate that applies to the specified schema. +func (p *PredicateImpl) SelectSubPred(sch record.Schema) *PredicateImpl { + result := NewPredicate(nil) + for _, t := range p.terms { + if t.AppliesTo(sch) { + result.terms = append(result.terms, t) + } + } + if len(result.terms) == 0 { + return nil + } + return result +} + +// JoinSubPred returns the subpredicate consisting of terms that apply to the union of the two specified schemas, but not to either schema separately. +func (p *PredicateImpl) JoinSubPred(sch1, sch2 record.Schema) (*PredicateImpl, error) { + result := NewPredicate(nil) + newsch := record.NewSchema() + if err := newsch.AddAll(sch1); err != nil { + return nil, fmt.Errorf("error adding schema1: %v", err) + } + if err := newsch.AddAll(sch2); err != nil { + return nil, fmt.Errorf("error adding schema2: %v", err) + } + for _, t := range p.terms { + if !t.AppliesTo(sch1) && !t.AppliesTo(sch2) && t.AppliesTo(newsch) { + result.terms = append(result.terms, t) + } + } + // TODO: return nil or empty predicate? + if len(result.terms) == 0 { + return nil, fmt.Errorf("no join subpredicate") + } + return result, nil +} + +// EquatesWithConstant determines if there is a term of the form "F=c" where F is the specified field and c is some constant. +func (p *PredicateImpl) EquatesWithConstant(field string) (*constant.Const, bool) { + for _, t := range p.terms { + if c, ok := t.EquatesWithConstant(field); ok { + return c, true + } + } + return nil, false +} + +// EquatesWithField determines if there is a term of the form "F1=F2" where F1 is the specified field and F2 is another field. +func (p *PredicateImpl) EquatesWithField(field string) (string, bool) { + for _, t := range p.terms { + if s, ok := t.EquatesWithField(field); ok { + return s, true + } + } + return "", false +} + +func (p *PredicateImpl) String() string { + var terms []string + for _, t := range p.terms { + terms = append(terms, t.String()) + } + return strings.Join(terms, " and ") +} diff --git a/pkg/query/product_scan.go b/pkg/query/product_scan.go new file mode 100644 index 0000000..3fec44e --- /dev/null +++ b/pkg/query/product_scan.go @@ -0,0 +1,75 @@ +package query + +import "github.com/kj455/db/pkg/constant" + +// ProductScan corresponds to the product relational algebra operator. +type ProductScan struct { + s1, s2 Scan +} + +// NewProductScan creates a product scan having the two underlying scans. +func NewProductScan(s1, s2 Scan) (*ProductScan, error) { + p := &ProductScan{ + s1: s1, + s2: s2, + } + if err := p.BeforeFirst(); err != nil { + return nil, err + } + return p, nil +} + +// BeforeFirst positions the scan before its first record. In particular, the LHS scan is positioned at its first record, and the RHS scan is positioned before its first record. +func (p *ProductScan) BeforeFirst() error { + if err := p.s1.BeforeFirst(); err != nil { + return err + } + p.s1.Next() + return p.s2.BeforeFirst() +} + +// Next moves the scan to the next record. The method moves to the next RHS record, if possible. Otherwise, it moves to the next LHS record and the first RHS record. If there are no more LHS records, the method returns false. +func (p *ProductScan) Next() bool { + if p.s2.Next() { + return true + } + if err := p.s2.BeforeFirst(); err != nil { + return false + } + return p.s2.Next() && p.s1.Next() +} + +// GetInt returns the integer value of the specified field. The value is obtained from whichever scan contains the field. +func (p *ProductScan) GetInt(field string) (int, error) { + if p.s1.HasField(field) { + return p.s1.GetInt(field) + } + return p.s2.GetInt(field) +} + +// GetString returns the string value of the specified field. The value is obtained from whichever scan contains the field. +func (p *ProductScan) GetString(field string) (string, error) { + if p.s1.HasField(field) { + return p.s1.GetString(field) + } + return p.s2.GetString(field) +} + +// GetVal returns the value of the specified field. The value is obtained from whichever scan contains the field. +func (p *ProductScan) GetVal(field string) (*constant.Const, error) { + if p.s1.HasField(field) { + return p.s1.GetVal(field) + } + return p.s2.GetVal(field) +} + +// HasField returns true if the specified field is in either of the underlying scans. +func (p *ProductScan) HasField(field string) bool { + return p.s1.HasField(field) || p.s2.HasField(field) +} + +// Close closes both underlying scans. +func (p *ProductScan) Close() { + p.s1.Close() + p.s2.Close() +} diff --git a/pkg/query/project_scan.go b/pkg/query/project_scan.go new file mode 100644 index 0000000..2c313d2 --- /dev/null +++ b/pkg/query/project_scan.go @@ -0,0 +1,61 @@ +package query + +import ( + "fmt" + + "github.com/kj455/db/pkg/constant" +) + +type ProjectScan struct { + s Scan + fields []string +} + +func NewProjectScan(s Scan, fields []string) *ProjectScan { + return &ProjectScan{ + s: s, + fields: fields, + } +} + +func (ps *ProjectScan) BeforeFirst() error { + return ps.s.BeforeFirst() +} + +func (ps *ProjectScan) Next() bool { + return ps.s.Next() +} + +func (ps *ProjectScan) GetInt(fldname string) (int, error) { + if ps.HasField(fldname) { + return ps.s.GetInt(fldname) + } + return 0, fmt.Errorf("query: field %s not found", fldname) +} + +func (ps *ProjectScan) GetString(fldname string) (string, error) { + if ps.HasField(fldname) { + return ps.s.GetString(fldname) + } + return "", fmt.Errorf("query: field %s not found", fldname) +} + +func (ps *ProjectScan) GetVal(fldname string) (*constant.Const, error) { + if ps.HasField(fldname) { + return ps.s.GetVal(fldname) + } + return nil, fmt.Errorf("query: field %s not found", fldname) +} + +func (ps *ProjectScan) HasField(fldname string) bool { + for _, field := range ps.fields { + if field == fldname { + return true + } + } + return false +} + +func (ps *ProjectScan) Close() { + ps.s.Close() +} diff --git a/pkg/query/scan_test.go b/pkg/query/scan_test.go new file mode 100644 index 0000000..52ce3ee --- /dev/null +++ b/pkg/query/scan_test.go @@ -0,0 +1,150 @@ +package query + +import ( + "fmt" + "math/rand/v2" + "testing" + + "github.com/kj455/db/pkg/buffer" + buffermgr "github.com/kj455/db/pkg/buffer_mgr" + "github.com/kj455/db/pkg/constant" + "github.com/kj455/db/pkg/file" + "github.com/kj455/db/pkg/log" + "github.com/kj455/db/pkg/metadata" + "github.com/kj455/db/pkg/record" + "github.com/kj455/db/pkg/testutil" + "github.com/kj455/db/pkg/tx/transaction" + "github.com/stretchr/testify/assert" +) + +func TestScan1(t *testing.T) { + const blockSize = 400 + rootDir := testutil.ProjectRootDir() + dir := rootDir + "/.tmp/scan1" + defer testutil.CleanupDir(dir) + fm := file.NewFileMgr(dir, blockSize) + lm, _ := log.NewLogMgr(fm, "testlogfile") + buffs := make([]buffer.Buffer, 2) + for i := range buffs { + buffs[i] = buffer.NewBuffer(fm, lm, blockSize) + } + bm := buffermgr.NewBufferMgr(buffs) + txNumGen := transaction.NewTxNumberGenerator() + tx, _ := transaction.NewTransaction(fm, lm, bm, txNumGen) + _, _ = metadata.NewMetadataMgr(true, tx) + tx.Commit() + sch1 := record.NewSchema() + sch1.AddIntField("A") + sch1.AddStringField("B", 9) + layout, _ := record.NewLayoutFromSchema(sch1) + tableSc1, _ := record.NewTableScan(tx, "T", layout) + tableSc1.BeforeFirst() + const ( + length = 100 + targetIdx = 50 + ) + ints := newShuffledInts(length) + for i := 0; i < length; i++ { + err := tableSc1.Insert() + assert.NoError(t, err) + tableSc1.SetInt("A", ints[i]) + tableSc1.SetString("B", "rec"+fmt.Sprint(ints[i])) + } + tableSc1.Close() + + tableSc2, _ := record.NewTableScan(tx, "T", layout) + c, _ := constant.NewConstant(constant.KIND_INT, targetIdx) + term := NewTerm(NewFieldExpression("A"), NewConstantExpression(c)) + pred := NewPredicate(term) + t.Logf("The predicate is %v", pred) + selectScan := NewSelectScan(tableSc2, pred) + s4 := NewProjectScan(selectScan, []string{"A", "B"}) + for s4.Next() { + iVal, err := s4.GetInt("A") + assert.NoError(t, err) + assert.Equal(t, targetIdx, iVal) + t.Logf("A: %d", iVal) + sVal, err := s4.GetString("B") + assert.NoError(t, err) + t.Logf("B: %s", sVal) + } + s4.Close() + tx.Commit() +} + +func TestScan2(t *testing.T) { + const blockSize = 400 + rootDir := testutil.ProjectRootDir() + dir := rootDir + "/.tmp/scan2" + defer testutil.CleanupDir(dir) + fm := file.NewFileMgr(dir, blockSize) + lm, _ := log.NewLogMgr(fm, "testlogfile") + buffs := make([]buffer.Buffer, 2) + for i := range buffs { + buffs[i] = buffer.NewBuffer(fm, lm, blockSize) + } + bm := buffermgr.NewBufferMgr(buffs) + txNumGen := transaction.NewTxNumberGenerator() + tx, _ := transaction.NewTransaction(fm, lm, bm, txNumGen) + _, _ = metadata.NewMetadataMgr(true, tx) + + sch1 := record.NewSchema() + sch1.AddIntField("A") + sch1.AddStringField("B", 9) + layout1, _ := record.NewLayoutFromSchema(sch1) + us1, _ := record.NewTableScan(tx, "T1", layout1) + us1.BeforeFirst() + n := 200 + t.Logf("Inserting %d records into T1.\n", n) + for i := 0; i < n; i++ { + us1.Insert() + us1.SetInt("A", i) + us1.SetString("B", "bbb"+fmt.Sprint(i)) + } + us1.Close() + + sch2 := record.NewSchema() + sch2.AddIntField("C") + sch2.AddStringField("D", 9) + layout2, _ := record.NewLayoutFromSchema(sch2) + us2, _ := record.NewTableScan(tx, "T2", layout2) + us2.BeforeFirst() + t.Logf("Inserting %d records into T2.\n", n) + for i := 0; i < n; i++ { + us2.Insert() + us2.SetInt("C", n-i-1) + us2.SetString("D", "ddd"+fmt.Sprint(n-i-1)) + } + us2.Close() + + s1, _ := record.NewTableScan(tx, "T1", layout1) + s2, _ := record.NewTableScan(tx, "T2", layout2) + s3, _ := NewProductScan(s1, s2) + term := NewTerm(NewFieldExpression("A"), NewFieldExpression("C")) + pred := NewPredicate(term) + t.Log("The predicate is", pred) + s4 := NewSelectScan(s3, pred) + + fields := []string{"B", "D"} + s5 := NewProjectScan(s4, fields) + for s5.Next() { + bVal, _ := s5.GetString("B") + dVal, _ := s5.GetString("D") + + t.Logf("%s %s\n", bVal, dVal) + assert.Equal(t, bVal[3:], dVal[3:]) + } + s5.Close() + tx.Commit() +} + +func newShuffledInts(length int) []int { + ints := make([]int, length) + for i := 0; i < length; i++ { + ints[i] = i + } + rand.Shuffle(length, func(i, j int) { + ints[i], ints[j] = ints[j], ints[i] + }) + return ints +} diff --git a/pkg/query/select_scan.go b/pkg/query/select_scan.go new file mode 100644 index 0000000..3554096 --- /dev/null +++ b/pkg/query/select_scan.go @@ -0,0 +1,94 @@ +package query + +import ( + "fmt" + + "github.com/kj455/db/pkg/constant" + "github.com/kj455/db/pkg/record" +) + +type SelectScan struct { + scan Scan + pred Predicate +} + +// NewSelectScan creates a new SelectScan instance +func NewSelectScan(s Scan, pred Predicate) *SelectScan { + return &SelectScan{ + scan: s, + pred: pred, + } +} + +// BeforeFirst positions the scan before its first record +func (s *SelectScan) BeforeFirst() error { + return s.scan.BeforeFirst() +} + +// Next moves the scan to the next record and returns true if there is such a record +func (s *SelectScan) Next() bool { + for s.scan.Next() { + if ok, err := s.pred.IsSatisfied(s.scan); ok && err == nil { + return true + } + } + return false +} + +func (s *SelectScan) GetInt(fldname string) (int, error) { + return s.scan.GetInt(fldname) +} + +func (s *SelectScan) GetString(fldname string) (string, error) { + return s.scan.GetString(fldname) +} + +func (s *SelectScan) GetVal(fldname string) (*constant.Const, error) { + return s.scan.GetVal(fldname) +} + +func (s *SelectScan) HasField(fldname string) bool { + return s.scan.HasField(fldname) +} + +func (s *SelectScan) Close() { + s.scan.Close() +} + +func (s *SelectScan) SetInt(fldname string, val int) error { + us := s.scan.(UpdateScan) + return us.SetInt(fldname, val) +} + +func (s *SelectScan) SetString(fldname string, val string) error { + us := s.scan.(UpdateScan) + return us.SetString(fldname, val) +} + +func (s *SelectScan) SetVal(fldname string, val *constant.Const) error { + us := s.scan.(UpdateScan) + return us.SetVal(fldname, val) +} + +func (s *SelectScan) Delete() error { + us := s.scan.(UpdateScan) + return us.Delete() +} + +func (s *SelectScan) Insert() error { + us := s.scan.(UpdateScan) + return us.Insert() +} + +func (s *SelectScan) GetRid() (record.RID, error) { + us, ok := s.scan.(UpdateScan) + if !ok { + return nil, fmt.Errorf("query: scan is not an UpdateScan") + } + return us.GetRID(), nil +} + +func (s *SelectScan) MoveToRid(rid record.RID) error { + us := s.scan.(UpdateScan) + return us.MoveToRID(rid) +} diff --git a/pkg/query/term.go b/pkg/query/term.go new file mode 100644 index 0000000..44cc27d --- /dev/null +++ b/pkg/query/term.go @@ -0,0 +1,79 @@ +package query + +import ( + "github.com/kj455/db/pkg/constant" + "github.com/kj455/db/pkg/record" +) + +type Term struct { + lhs, rhs Expression +} + +// NewTerm creates a new Term instance with two expressions +func NewTerm(lhs, rhs Expression) *Term { + return &Term{ + lhs: lhs, + rhs: rhs, + } +} + +func (t *Term) IsSatisfied(s Scan) (bool, error) { + lhsVal, err := t.lhs.Evaluate(s) + if err != nil { + return false, err + } + rhsVal, err := t.rhs.Evaluate(s) + if err != nil { + return false, err + } + return lhsVal.Equals(rhsVal), nil +} + +func (t *Term) ReductionFactor(p PlanInfo) int { + var lhsName, rhsName string + if t.lhs.IsFieldName() && t.rhs.IsFieldName() { + lhsName = t.lhs.AsFieldName() + rhsName = t.rhs.AsFieldName() + return max(p.DistinctValues(lhsName), p.DistinctValues(rhsName)) + } + if t.lhs.IsFieldName() { + lhsName = t.lhs.AsFieldName() + return p.DistinctValues(lhsName) + } + if t.rhs.IsFieldName() { + rhsName = t.rhs.AsFieldName() + return p.DistinctValues(rhsName) + } + if t.lhs.AsConstant().Equals(t.rhs.AsConstant()) { + return 1 + } + return int(^uint(0) >> 1) // Max int value +} + +func (t *Term) EquatesWithConstant(field string) (*constant.Const, bool) { + if t.lhs.IsFieldName() && t.lhs.AsFieldName() == field && !t.rhs.IsFieldName() { + return t.rhs.AsConstant(), true + } + if t.rhs.IsFieldName() && t.rhs.AsFieldName() == field && !t.lhs.IsFieldName() { + return t.lhs.AsConstant(), true + } + return nil, false +} + +func (t *Term) EquatesWithField(field string) (string, bool) { + if t.lhs.IsFieldName() && t.lhs.AsFieldName() == field && t.rhs.IsFieldName() { + return t.rhs.AsFieldName(), true + } + if t.rhs.IsFieldName() && t.rhs.AsFieldName() == field && t.lhs.IsFieldName() { + return t.lhs.AsFieldName(), true + } + return "", false +} + +func (t *Term) AppliesTo(sch record.Schema) bool { + return t.lhs.AppliesTo(sch) && t.rhs.AppliesTo(sch) +} + +func (t *Term) String() string { + return t.lhs.ToString() + "=" + t.rhs.ToString() +} diff --git a/pkg/record/table_scan.go b/pkg/record/table_scan.go index 5cec9a1..a73bdb0 100644 --- a/pkg/record/table_scan.go +++ b/pkg/record/table_scan.go @@ -3,6 +3,7 @@ package record import ( "fmt" + "github.com/kj455/db/pkg/constant" "github.com/kj455/db/pkg/file" "github.com/kj455/db/pkg/tx" ) @@ -17,7 +18,7 @@ type TableScanImpl struct { const TABLE_SUFFIX = ".tbl" -func NewTableScan(tx tx.Transaction, table string, layout Layout) (TableScan, error) { +func NewTableScan(tx tx.Transaction, table string, layout Layout) (*TableScanImpl, error) { ts := &TableScanImpl{ tx: tx, layout: layout, @@ -64,16 +65,24 @@ func (ts *TableScanImpl) GetString(field string) (string, error) { return ts.rp.GetString(ts.curSlot, field) } -func (ts *TableScanImpl) GetVal(field string) (any, error) { +func (ts *TableScanImpl) GetVal(field string) (*constant.Const, error) { schemaType, err := ts.layout.Schema().Type(field) if err != nil { return nil, err } switch schemaType { case SCHEMA_TYPE_INTEGER: - return ts.GetInt(field) + v, err := ts.GetInt(field) + if err != nil { + return nil, err + } + return constant.NewConstant(constant.KIND_INT, v) case SCHEMA_TYPE_VARCHAR: - return ts.GetString(field) + v, err := ts.GetString(field) + if err != nil { + return nil, err + } + return constant.NewConstant(constant.KIND_STR, v) default: return nil, fmt.Errorf("record: table scan: get val: unknown type %v", schemaType) } @@ -97,24 +106,22 @@ func (ts *TableScanImpl) SetString(field string, val string) error { return ts.rp.SetString(ts.curSlot, field, val) } -func (ts *TableScanImpl) SetVal(field string, val any) error { +func (ts *TableScanImpl) SetVal(field string, val *constant.Const) error { schemaType, err := ts.layout.Schema().Type(field) if err != nil { return err } switch schemaType { case SCHEMA_TYPE_INTEGER: - v, ok := val.(int) - if !ok { + if val.Kind() != constant.KIND_INT { return fmt.Errorf("record: table scan: set val: expected int, got %v", val) } - return ts.SetInt(field, v) + return ts.SetInt(field, val.AsInt()) case SCHEMA_TYPE_VARCHAR: - v, ok := val.(string) - if !ok { + if val.Kind() != constant.KIND_STR { return fmt.Errorf("record: table scan: set val: expected string, got %v", val) } - return ts.SetString(field, v) + return ts.SetString(field, val.AsString()) } return nil } diff --git a/pkg/testutil/util.go b/pkg/testutil/util.go index b83b81f..5aa06df 100644 --- a/pkg/testutil/util.go +++ b/pkg/testutil/util.go @@ -26,3 +26,8 @@ func ProjectRootDir() string { } return currentDir } + +func CleanupDir(dir string) { + _ = os.RemoveAll(dir) + _ = os.MkdirAll(dir, os.ModePerm) +}