Skip to content

Commit

Permalink
Merge pull request #19 from dipdup-io/feature/new-orm
Browse files Browse the repository at this point in the history
Feature: add bun as ORM and test
  • Loading branch information
aopoltorzhicky committed Aug 14, 2023
2 parents e9794d1 + 29ec6a2 commit f789138
Show file tree
Hide file tree
Showing 17 changed files with 1,342 additions and 364 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
steps:
- uses: actions/setup-go@v3
with:
go-version: '1.19'
go-version: '1.20'
- uses: actions/checkout@v3
- name: golangci-lint
uses: golangci/golangci-lint-action@v3
Expand All @@ -21,7 +21,7 @@ jobs:
- name: install Go
uses: actions/setup-go@v2
with:
go-version: 1.18.x
go-version: 1.20.x
- name: checkout code
uses: actions/checkout@v2
- uses: actions/cache@v2
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
lint:
golangci-lint run --go=1.18
golangci-lint run

test:
go test ./...
164 changes: 164 additions & 0 deletions database/bun.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package database

import (
"context"
"database/sql"
"fmt"
"runtime"

"github.com/dipdup-net/go-lib/config"
"github.com/pkg/errors"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
)

// Bun -
type Bun struct {
sqldb *sql.DB
conn *bun.DB
}

// NewBun -
func NewBun() *Bun {
return new(Bun)
}

// DB -
func (db *Bun) DB() *bun.DB {
return db.conn
}

// Connect -
func (db *Bun) Connect(ctx context.Context, cfg config.Database) error {
if cfg.Kind != config.DBKindPostgres {
return errors.Wrap(ErrUnsupportedDatabaseType, cfg.Kind)
}
if cfg.Path != "" {
conn, err := sql.Open("postgres", cfg.Path)
if err != nil {
return err
}
db.sqldb = conn
db.conn = bun.NewDB(db.sqldb, pgdialect.New())
} else {
connStr := fmt.Sprintf(
"postgres://%s:%s@%s:%d/%s?sslmode=disable",
cfg.User,
cfg.Password,
cfg.Host,
cfg.Port,
cfg.Database,
)
conn, err := sql.Open("postgres", connStr)
if err != nil {
return err
}
db.sqldb = conn
db.conn = bun.NewDB(db.sqldb, pgdialect.New())
}
maxOpenConns := 4 * runtime.GOMAXPROCS(0)
db.sqldb.SetMaxOpenConns(maxOpenConns)
db.sqldb.SetMaxIdleConns(maxOpenConns)
return nil
}

// Close -
func (db *Bun) Close() error {
if err := db.conn.Close(); err != nil {
return err
}
return db.sqldb.Close()
}

// Exec -
func (db *Bun) Exec(ctx context.Context, query string, args ...any) (int64, error) {
result, err := db.conn.ExecContext(ctx, query, args...)
if err != nil {
return 0, err
}
return result.RowsAffected()
}

// Ping -
func (db *Bun) Ping(ctx context.Context) error {
if db.conn == nil {
return ErrConnectionIsNotInitialized
}
return db.conn.PingContext(ctx)
}

// State -
func (db *Bun) State(ctx context.Context, indexName string) (*State, error) {
var s State
err := db.conn.NewSelect().Model(&s).Where("index_name = ?", indexName).Limit(1).Scan(ctx)
return &s, err
}

// CreateState -
func (db *Bun) CreateState(ctx context.Context, s *State) error {
_, err := db.conn.NewInsert().Model(s).Exec(ctx)
return err
}

// UpdateState -
func (db *Bun) UpdateState(ctx context.Context, s *State) error {
_, err := db.conn.NewUpdate().Model(s).Where("index_name = ?", s.IndexName).Exec(ctx)
return err
}

// DeleteState -
func (db *Bun) DeleteState(ctx context.Context, s *State) error {
_, err := db.conn.NewDelete().Model(s).Where("index_name = ?", s.IndexName).Exec(ctx)
return err
}

// MakeTableComment -
func (db *Bun) MakeTableComment(ctx context.Context, name string, comment string) error {
_, err := db.conn.ExecContext(ctx,
`COMMENT ON TABLE ? IS ?`,
bun.Ident(name),
comment)

return err
}

// MakeColumnComment -
func (db *Bun) MakeColumnComment(ctx context.Context, tableName string, columnName string, comment string) error {
_, err := db.conn.ExecContext(ctx,
`COMMENT ON COLUMN ?.? IS ?`,
bun.Ident(tableName),
bun.Ident(columnName),
comment)

return err
}

// CreateTable -
func (db *Bun) CreateTable(ctx context.Context, model any, opts ...CreateTableOption) error {
if model == nil {
return nil
}
var options CreateTableOptions
for i := range opts {
opts[i](&options)
}

query := db.DB().
NewCreateTable().
Model(model)

if options.ifNotExists {
query = query.IfNotExists()
}

if options.partitionBy != "" {
query = query.PartitionBy(options.partitionBy)
}

if options.temporary {
query = query.Temp()
}

_, err := query.Exec(ctx)
return err
}
43 changes: 31 additions & 12 deletions database/comment.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@ package database

import (
"context"
"github.com/dipdup-net/go-lib/hasura"
"github.com/pkg/errors"
"reflect"
"strings"

"github.com/dipdup-net/go-lib/hasura"
"github.com/pkg/errors"
)

const (
fieldTableName = "tableName"
fieldBaseModel = "BaseModel"
)

// MakeComments -
func MakeComments(ctx context.Context, sc SchemeCommenter, models ...interface{}) error {
if models == nil {
if len(models) == 0 {
return nil
}

Expand All @@ -29,9 +36,9 @@ func MakeComments(ctx context.Context, sc SchemeCommenter, models ...interface{}
for i := 0; i < modelType.NumField(); i++ {
fieldType := modelType.Field(i)

if fieldType.Name == "tableName" {
if fieldType.Name == fieldTableName || fieldType.Name == fieldBaseModel {
var ok bool
tableName, ok = getPgName(fieldType)
tableName, ok = getDatabaseTagName(fieldType)
if !ok {
tableName = hasura.ToSnakeCase(modelType.Name())
}
Expand Down Expand Up @@ -75,7 +82,7 @@ func makeEmbeddedComments(ctx context.Context, sc SchemeCommenter, tableName str
continue
}

if fieldType.Name == "tableName" {
if fieldType.Name == fieldTableName {
return errors.New("Embedded type must not have tableName field.")
}

Expand All @@ -93,7 +100,7 @@ func makeFieldComment(ctx context.Context, sc SchemeCommenter, tableName string,
return nil
}

columnName, ok := getPgName(fieldType)
columnName, ok := getDatabaseTagName(fieldType)

if columnName == "-" {
return nil
Expand All @@ -110,13 +117,26 @@ func makeFieldComment(ctx context.Context, sc SchemeCommenter, tableName string,
return nil
}

func getPgName(fieldType reflect.StructField) (name string, ok bool) {
pgTag, ok := fieldType.Tag.Lookup("pg")
if !ok {
func getDatabaseTagName(fieldType reflect.StructField) (name string, ok bool) {
pgTag, pgOk := fieldType.Tag.Lookup("pg")
bunTag, bunOk := fieldType.Tag.Lookup("bun")
ok = pgOk || bunOk

var tag string
switch {
case !pgOk && !bunOk:
return "", false
case pgOk && pgTag != "-":
tag = pgTag
case bunOk && bunTag != "-":
tag = strings.TrimPrefix(bunTag, "table:")
case pgOk:
tag = pgTag
case bunOk:
tag = bunTag
}

tags := strings.Split(pgTag, ",")
tags := strings.Split(tag, ",")

if tags[0] != "" {
name = tags[0]
Expand All @@ -128,7 +148,6 @@ func getPgName(fieldType reflect.StructField) (name string, ok bool) {

func getComment(fieldType reflect.StructField) (string, bool) {
commentTag, ok := fieldType.Tag.Lookup("comment")

if ok {
return commentTag, ok
}
Expand Down
Loading

0 comments on commit f789138

Please sign in to comment.