Skip to content

Commit

Permalink
Refactor database initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
MHSanaei committed Jul 13, 2024
1 parent 60cb328 commit dfe0bbd
Showing 1 changed file with 48 additions and 38 deletions.
86 changes: 48 additions & 38 deletions database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"io"
"io/fs"
"log"
"os"
"path"

Expand All @@ -18,54 +19,51 @@ import (

var db *gorm.DB

var initializers = []func() error{
initUser,
initInbound,
initOutbound,
initSetting,
initInboundClientIps,
initClientTraffic,
const (
defaultUsername = "admin"
defaultPassword = "admin"
defaultSecret = ""
)

func initModels() error {
models := []interface{}{
&model.User{},
&model.Inbound{},
&model.OutboundTraffics{},
&model.Setting{},
&model.InboundClientIps{},
&xray.ClientTraffic{},
}
for _, model := range models {
if err := db.AutoMigrate(model); err != nil {
log.Printf("Error auto migrating model: %v", err)
return err
}
}
return nil
}

func initUser() error {
err := db.AutoMigrate(&model.User{})
empty, err := isTableEmpty("users")
if err != nil {
log.Printf("Error checking if users table is empty: %v", err)
return err
}
var count int64
err = db.Model(&model.User{}).Count(&count).Error
if err != nil {
return err
}
if count == 0 {
if empty {
user := &model.User{
Username: "admin",
Password: "admin",
LoginSecret: "",
Username: defaultUsername,
Password: defaultPassword,
LoginSecret: defaultSecret,
}
return db.Create(user).Error
}
return nil
}

func initInbound() error {
return db.AutoMigrate(&model.Inbound{})
}

func initOutbound() error {
return db.AutoMigrate(&model.OutboundTraffics{})
}

func initSetting() error {
return db.AutoMigrate(&model.Setting{})
}

func initInboundClientIps() error {
return db.AutoMigrate(&model.InboundClientIps{})
}

func initClientTraffic() error {
return db.AutoMigrate(&xray.ClientTraffic{})
func isTableEmpty(tableName string) (bool, error) {
var count int64
err := db.Table(tableName).Count(&count).Error
return count == 0, err
}

func InitDB(dbPath string) error {
Expand All @@ -91,12 +89,24 @@ func InitDB(dbPath string) error {
return err
}

for _, initialize := range initializers {
if err := initialize(); err != nil {
if err := initModels(); err != nil {
return err
}
if err := initUser(); err != nil {
return err
}

return nil
}

func CloseDB() error {
if db != nil {
sqlDB, err := db.DB()
if err != nil {
return err
}
return sqlDB.Close()
}

return nil
}

Expand Down

0 comments on commit dfe0bbd

Please sign in to comment.