Skip to content

Commit

Permalink
feat,fix: add task manager
Browse files Browse the repository at this point in the history
Implement a task manager that can run different tasks given a unique ID.

This is needed to accommodate expensive tasks like importing a large
repository. The current behavior uses the connection's context (the SSH
connection) to import the repository. However, if the server has defined
an SSH `idle_timeout`, `max_timeout`, and/or the connection drops,
Soft Serve cancels the git clone process and aborts importing the
repository.

Instead, we add the import task to the "task manager" and wait on the
connection context. If a task already exists for the same repository,
return `Error: import already in progress`.

Fixes: #348
  • Loading branch information
aymanbagabas committed Aug 4, 2023
1 parent c7829a3 commit c4dde1c
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 76 deletions.
25 changes: 14 additions & 11 deletions server/backend/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,32 @@ import (
"github.com/charmbracelet/soft-serve/server/config"
"github.com/charmbracelet/soft-serve/server/db"
"github.com/charmbracelet/soft-serve/server/store"
"github.com/charmbracelet/soft-serve/server/task"
)

// Backend is the Soft Serve backend that handles users, repositories, and
// server settings management and operations.
type Backend struct {
ctx context.Context
cfg *config.Config
db *db.DB
store store.Store
logger *log.Logger
cache *cache
ctx context.Context
cfg *config.Config
db *db.DB
store store.Store
logger *log.Logger
cache *cache
manager *task.Manager
}

// New returns a new Soft Serve backend.
func New(ctx context.Context, cfg *config.Config, db *db.DB) *Backend {
dbstore := store.FromContext(ctx)
logger := log.FromContext(ctx).WithPrefix("backend")
b := &Backend{
ctx: ctx,
cfg: cfg,
db: db,
store: dbstore,
logger: logger,
ctx: ctx,
cfg: cfg,
db: db,
store: dbstore,
logger: logger,
manager: task.NewManager(ctx),
}

// TODO: implement a proper caching interface
Expand Down
151 changes: 86 additions & 65 deletions server/backend/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/charmbracelet/soft-serve/server/lfs"
"github.com/charmbracelet/soft-serve/server/proto"
"github.com/charmbracelet/soft-serve/server/storage"
"github.com/charmbracelet/soft-serve/server/task"
"github.com/charmbracelet/soft-serve/server/utils"
)

Expand Down Expand Up @@ -91,7 +92,8 @@ func (d *Backend) CreateRepository(ctx context.Context, name string, user proto.
}

// ImportRepository imports a repository from remote.
func (d *Backend) ImportRepository(ctx context.Context, name string, user proto.User, remote string, opts proto.RepositoryOptions) (proto.Repository, error) {
// XXX: This a expensive operation and should be run in a goroutine.
func (d *Backend) ImportRepository(_ context.Context, name string, user proto.User, remote string, opts proto.RepositoryOptions) (proto.Repository, error) {
name = utils.SanitizeRepo(name)
if err := utils.ValidateRepo(name); err != nil {
return nil, err
Expand All @@ -100,91 +102,110 @@ func (d *Backend) ImportRepository(ctx context.Context, name string, user proto.
repo := name + ".git"
rp := filepath.Join(d.reposPath(), repo)

tid := "import:" + name
if d.manager.Exists(tid) {
return nil, task.ErrAlreadyStarted
}

if _, err := os.Stat(rp); err == nil || os.IsExist(err) {
return nil, proto.ErrRepoExist
}

copts := git.CloneOptions{
Bare: true,
Mirror: opts.Mirror,
Quiet: true,
CommandOptions: git.CommandOptions{
Timeout: -1,
Context: ctx,
Envs: []string{
fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`,
filepath.Join(d.cfg.DataPath, "ssh", "known_hosts"),
d.cfg.SSH.ClientKeyPath,
),
done := make(chan error, 1)
repoc := make(chan proto.Repository, 1)
d.logger.Info("importing repository", "name", name, "remote", remote, "path", rp)
d.manager.Add(tid, func(ctx context.Context) (err error) {
copts := git.CloneOptions{
Bare: true,
Mirror: opts.Mirror,
Quiet: true,
CommandOptions: git.CommandOptions{
Timeout: -1,
Context: ctx,
Envs: []string{
fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`,
filepath.Join(d.cfg.DataPath, "ssh", "known_hosts"),
d.cfg.SSH.ClientKeyPath,
),
},
},
},
}

if err := git.Clone(remote, rp, copts); err != nil {
d.logger.Error("failed to clone repository", "err", err, "mirror", opts.Mirror, "remote", remote, "path", rp)
// Cleanup the mess!
if rerr := os.RemoveAll(rp); rerr != nil {
err = errors.Join(err, rerr)
}

return nil, err
}
if err := git.Clone(remote, rp, copts); err != nil {
d.logger.Error("failed to clone repository", "err", err, "mirror", opts.Mirror, "remote", remote, "path", rp)
// Cleanup the mess!
if rerr := os.RemoveAll(rp); rerr != nil {
err = errors.Join(err, rerr)
}

r, err := d.CreateRepository(ctx, name, user, opts)
if err != nil {
d.logger.Error("failed to create repository", "err", err, "name", name)
return nil, err
}
return err
}

defer func() {
r, err := d.CreateRepository(ctx, name, user, opts)
if err != nil {
if rerr := d.DeleteRepository(ctx, name); rerr != nil {
d.logger.Error("failed to delete repository", "err", rerr, "name", name)
d.logger.Error("failed to create repository", "err", err, "name", name)
return err
}

defer func() {
if err != nil {
if rerr := d.DeleteRepository(ctx, name); rerr != nil {
d.logger.Error("failed to delete repository", "err", rerr, "name", name)
}
}
}()

rr, err := r.Open()
if err != nil {
d.logger.Error("failed to open repository", "err", err, "path", rp)
return err
}
}()

rr, err := r.Open()
if err != nil {
d.logger.Error("failed to open repository", "err", err, "path", rp)
return nil, err
}
repoc <- r

rcfg, err := rr.Config()
if err != nil {
d.logger.Error("failed to get repository config", "err", err, "path", rp)
return nil, err
}
rcfg, err := rr.Config()
if err != nil {
d.logger.Error("failed to get repository config", "err", err, "path", rp)
return err
}

endpoint := remote
if opts.LFSEndpoint != "" {
endpoint = opts.LFSEndpoint
}
endpoint := remote
if opts.LFSEndpoint != "" {
endpoint = opts.LFSEndpoint
}

rcfg.Section("lfs").SetOption("url", endpoint)
rcfg.Section("lfs").SetOption("url", endpoint)

if err := rr.SetConfig(rcfg); err != nil {
d.logger.Error("failed to set repository config", "err", err, "path", rp)
return nil, err
}
if err := rr.SetConfig(rcfg); err != nil {
d.logger.Error("failed to set repository config", "err", err, "path", rp)
return err
}

ep, err := lfs.NewEndpoint(endpoint)
if err != nil {
d.logger.Error("failed to create lfs endpoint", "err", err, "path", rp)
return nil, err
}
ep, err := lfs.NewEndpoint(endpoint)
if err != nil {
d.logger.Error("failed to create lfs endpoint", "err", err, "path", rp)
return err
}

client := lfs.NewClient(ep)
if client == nil {
return nil, fmt.Errorf("failed to create lfs client: unsupported endpoint %s", endpoint)
}
client := lfs.NewClient(ep)
if client == nil {
return fmt.Errorf("failed to create lfs client: unsupported endpoint %s", endpoint)
}

if err := StoreRepoMissingLFSObjects(ctx, r, d.db, d.store, client); err != nil {
d.logger.Error("failed to store missing lfs objects", "err", err, "path", rp)
return nil, err
}
if err := StoreRepoMissingLFSObjects(ctx, r, d.db, d.store, client); err != nil {
d.logger.Error("failed to store missing lfs objects", "err", err, "path", rp)
return err
}

return r, nil
return nil
})

go func() {
d.logger.Info("running import", "name", name)
d.manager.Run(tid, done)
}()

return <-repoc, <-done
}

// DeleteRepository deletes a repository.
Expand Down
8 changes: 8 additions & 0 deletions server/ssh/cmd/import.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package cmd

import (
"errors"

"github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/proto"
"github.com/charmbracelet/soft-serve/server/task"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -36,8 +39,13 @@ func importCommand() *cobra.Command {
LFS: lfs,
LFSEndpoint: lfsEndpoint,
}); err != nil {
if errors.Is(err, task.ErrAlreadyStarted) {
return errors.New("import already in progress")
}

return err
}

return nil
},
}
Expand Down
116 changes: 116 additions & 0 deletions server/task/manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package task

import (
"context"
"errors"
"sync"
"sync/atomic"
)

var (
// ErrNotFound is returned when a process is not found.
ErrNotFound = errors.New("task not found")

// ErrAlreadyStarted is returned when a process is already started.
ErrAlreadyStarted = errors.New("task already started")
)

// Task is a task that can be started and stopped.
type Task struct {
id string
fn func(context.Context) error
started atomic.Bool
ctx context.Context
cancel context.CancelFunc
err error
}

// Manager manages tasks.
type Manager struct {
m sync.Map
ctx context.Context
}

// NewManager returns a new task manager.
func NewManager(ctx context.Context) *Manager {
return &Manager{
m: sync.Map{},
ctx: ctx,
}
}

// Add adds a task to the manager.
// If the process already exists, it is a no-op.
func (m *Manager) Add(id string, fn func(context.Context) error) {
if m.Exists(id) {
return
}

ctx, cancel := context.WithCancel(m.ctx)
m.m.Store(id, &Task{
id: id,
fn: fn,
ctx: ctx,
cancel: cancel,
})
}

// Stop stops the task and removes it from the manager.
func (m *Manager) Stop(id string) error {
v, ok := m.m.Load(id)
if !ok {
return ErrNotFound
}

p := v.(*Task)
p.cancel()

m.m.Delete(id)
return nil
}

// Exists checks if a task exists.
func (m *Manager) Exists(id string) bool {
_, ok := m.m.Load(id)
return ok
}

// Run starts the task if it exists.
// Otherwise, it waits for the process to finish.
func (m *Manager) Run(id string, done chan<- error) {
v, ok := m.m.Load(id)
if !ok {
done <- ErrNotFound
return
}

p := v.(*Task)
if p.started.Load() {
<-p.ctx.Done()
if p.err != nil {
done <- p.err
return
}

done <- p.ctx.Err()
}

p.started.Store(true)
m.m.Store(id, p)
defer p.cancel()
defer m.m.Delete(id)

errc := make(chan error, 1)
go func(ctx context.Context) {
errc <- p.fn(ctx)
}(p.ctx)

select {
case <-m.ctx.Done():
done <- m.ctx.Err()
case err := <-errc:
p.err = err
m.m.Store(id, p)
done <- err
}
}

0 comments on commit c4dde1c

Please sign in to comment.