diff --git a/cmd/shimdiag/lm.go b/cmd/shimdiag/lm.go
index 7628b4ab35..f1d411c57a 100644
--- a/cmd/shimdiag/lm.go
+++ b/cmd/shimdiag/lm.go
@@ -4,279 +4,23 @@ package main
import (
"context"
- "encoding/json"
"fmt"
- "io"
- "net"
"os"
- "path/filepath"
"strconv"
- "sync"
- "time"
"github.com/Microsoft/go-winio"
runhcsopts "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/options"
"github.com/Microsoft/hcsshim/internal/appargs"
lmproto "github.com/Microsoft/hcsshim/internal/lm/proto"
statepkg "github.com/Microsoft/hcsshim/internal/state"
- eventtypes "github.com/containerd/containerd/api/events"
- "github.com/containerd/containerd/api/runtime/task/v2"
- "github.com/containerd/containerd/api/services/ttrpc/events/v1"
- "github.com/containerd/containerd/api/types"
"github.com/containerd/ttrpc"
- "github.com/containerd/typeurl/v2"
"github.com/urfave/cli"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/anypb"
- "google.golang.org/protobuf/types/known/emptypb"
)
-var createCommand = cli.Command{
- Name: "create",
- Usage: "Creates a task",
- ArgsUsage: "[flags]
",
- Flags: []cli.Flag{
- cli.StringFlag{
- Name: "stdin",
- Usage: "Named pipe path",
- },
- cli.StringFlag{
- Name: "stdout",
- Usage: "Named pipe path",
- },
- cli.StringFlag{
- Name: "stderr",
- Usage: "Named pipe path",
- },
- cli.BoolFlag{
- Name: "tty",
- Usage: "Enable terminal mode for task IO",
- },
- cli.StringFlag{
- Name: "rootfs",
- Usage: "JSON file to read rootfs from",
- },
- cli.StringFlag{
- Name: "options",
- Usage: "jsonpb file to read shim options from",
- },
- },
- SkipArgReorder: true,
- Before: appargs.Validate(appargs.String, appargs.String, appargs.String),
- Action: func(clictx *cli.Context) error {
- args := clictx.Args()
- address := args[0]
- id := args[1]
-
- bundle, err := filepath.Abs(args[2])
- if err != nil {
- return err
- }
-
- var rootfs []*types.Mount
- if clictx.IsSet("rootfs") {
- data, err := os.ReadFile(clictx.String("rootfs"))
- if err != nil {
- return err
- }
- if err := json.Unmarshal(data, &rootfs); err != nil {
- return err
- }
- }
-
- var options *anypb.Any
- if clictx.IsSet("options") {
- data, err := os.ReadFile(clictx.String("options"))
- if err != nil {
- return err
- }
- var opts runhcsopts.Options
- if err := (protojson.UnmarshalOptions{}).Unmarshal(data, &opts); err != nil {
- return err
- }
- any, err := typeurl.MarshalAny(&opts)
- if err != nil {
- return err
- }
- options = &anypb.Any{TypeUrl: any.GetTypeUrl(), Value: any.GetValue()}
- }
-
- conn, err := winio.DialPipe(address, nil)
- if err != nil {
- return fmt.Errorf("dial %s: %w", address, err)
- }
-
- client := ttrpc.NewClient(conn)
- svc := task.NewTaskClient(client)
-
- ctx := context.Background()
-
- {
- resp, err := svc.Create(ctx, &task.CreateTaskRequest{
- ID: id,
- Bundle: bundle,
- Rootfs: rootfs,
- Terminal: clictx.Bool("tty"),
- Stdin: clictx.String("stdin"),
- Stdout: clictx.String("stdout"),
- Stderr: clictx.String("stderr"),
- Options: options,
- })
- if err != nil {
- return fmt.Errorf("task.Create: %w", err)
- }
- fmt.Printf("task pid is %d\n", resp.Pid)
- }
- return nil
- },
-}
-
-var startCommand = cli.Command{
- Name: "start",
- Usage: "",
- ArgsUsage: "[flags] ",
- Flags: []cli.Flag{},
- SkipArgReorder: true,
- Before: appargs.Validate(appargs.String, appargs.String),
- Action: func(clictx *cli.Context) error {
- args := clictx.Args()
- address := args[0]
- id := args[1]
-
- conn, err := winio.DialPipe(address, nil)
- if err != nil {
- return fmt.Errorf("dial %s: %w", address, err)
- }
-
- client := ttrpc.NewClient(conn)
- svc := task.NewTaskClient(client)
-
- ctx := context.Background()
-
- {
- resp, err := svc.Start(ctx, &task.StartRequest{
- ID: id,
- })
- if err != nil {
- return fmt.Errorf("task.Start: %w", err)
- }
- fmt.Printf("task pid is %d\n", resp.Pid)
- }
- return nil
- },
-}
-
-var pipeCommand = cli.Command{
- Name: "pipe",
- Usage: "",
- ArgsUsage: "[flags] ",
- Flags: []cli.Flag{},
- SkipArgReorder: true,
- Action: func(clictx *cli.Context) error {
- args := clictx.Args()
-
- f := func(name, pipe string, wg *sync.WaitGroup, copy func(c net.Conn) (int64, error)) {
- defer wg.Done()
- l, err := winio.ListenPipe(pipe, nil)
- if err != nil {
- panic(err)
- }
- fmt.Printf("%s: listening on %s\n", name, pipe)
- c, err := l.Accept()
- if err != nil {
- panic(err)
- }
- fmt.Printf("%s: received connection\n", name)
- n, err := copy(c)
- fmt.Printf("%s: copy completed after %d bytes", name, n)
- if err != nil {
- fmt.Printf(" and error: %s", err)
- }
- fmt.Printf("\n")
- }
-
- var wg sync.WaitGroup
- if len(args) > 0 {
- wg.Add(1)
- go f("stdin", args[0], &wg, func(c net.Conn) (int64, error) { return io.Copy(c, os.Stdin) })
- }
- if len(args) > 1 {
- wg.Add(1)
- go f("stdout", args[1], &wg, func(c net.Conn) (int64, error) { return io.Copy(os.Stdout, c) })
- }
- if len(args) > 2 {
- wg.Add(1)
- go f("stderr", args[2], &wg, func(c net.Conn) (int64, error) { return io.Copy(os.Stderr, c) })
- }
- wg.Wait()
-
- return nil
- },
-}
-
-type eventsSvc struct {
- m sync.Mutex
-}
-
-func (e *eventsSvc) Forward(ctx context.Context, req *events.ForwardRequest) (*emptypb.Empty, error) {
- e.m.Lock()
- defer e.m.Unlock()
-
- fmt.Printf("[%s][%s]: %s\n", req.Envelope.Timestamp.AsTime().Format(time.RFC3339), req.Envelope.Namespace, req.Envelope.Topic)
- v, err := typeurl.UnmarshalAny(req.Envelope.Event)
- if err != nil {
- fmt.Printf("\tunmarshal failed: %s\n", err)
- }
- switch v := v.(type) {
- case *eventtypes.TaskCreate:
- fmt.Printf("\tContainerID: %s\n", v.ContainerID)
- fmt.Printf("\tBundle: %s\n", v.Bundle)
- fmt.Printf("\tPID: %d\n", v.Pid)
- case *eventtypes.TaskStart:
- fmt.Printf("\tContainerID: %s\n", v.ContainerID)
- fmt.Printf("\tPID: %d\n", v.Pid)
- case *eventtypes.TaskExit:
- fmt.Printf("\tID: %s\n", v.ID)
- fmt.Printf("\tContainerID: %s\n", v.ContainerID)
- fmt.Printf("\tPID: %d\n", v.Pid)
- fmt.Printf("\tExitStatus: %d\n", v.ExitStatus)
- fmt.Printf("\tExitedAt: %v\n", v.ExitedAt.AsTime().Format(time.RFC3339))
- default:
- fmt.Printf("\tunrecognized event type: %T\n", v)
- }
- return &emptypb.Empty{}, nil
-}
-
-var eventsCommand = cli.Command{
- Name: "events",
- Usage: "",
- ArgsUsage: "[flags] ",
- Flags: []cli.Flag{},
- SkipArgReorder: true,
- Before: appargs.Validate(appargs.String),
- Action: func(clictx *cli.Context) error {
- args := clictx.Args()
- address := args[0]
-
- l, err := winio.ListenPipe(address, nil)
- if err != nil {
- return err
- }
-
- server, err := ttrpc.NewServer()
- if err != nil {
- return err
- }
- events.RegisterEventsService(server, &eventsSvc{})
- if err := server.Serve(context.Background(), l); err != nil {
- return err
- }
- return nil
- },
-}
-
var lmPrepareCommand = cli.Command{
Name: "lmprepare",
Usage: "Prepares the sandbox for migration",
@@ -572,57 +316,3 @@ var pb2jsonCommand = cli.Command{
return nil
},
}
-
-var deleteCommand = cli.Command{
- Name: "delete",
- ArgsUsage: " ",
- SkipArgReorder: true,
- Before: appargs.Validate(appargs.String, appargs.String),
- Action: func(clictx *cli.Context) error {
- args := clictx.Args()
- address := args[0]
- id := args[1]
-
- conn, err := winio.DialPipe(address, nil)
- if err != nil {
- return fmt.Errorf("dial %s: %w", address, err)
- }
-
- client := ttrpc.NewClient(conn)
- svc := task.NewTaskClient(client)
-
- ctx := context.Background()
-
- if _, err := svc.Delete(ctx, &task.DeleteRequest{ID: id}); err != nil {
- return err
- }
- return nil
- },
-}
-
-var shutdownCommand = cli.Command{
- Name: "shutdown",
- ArgsUsage: " ",
- SkipArgReorder: true,
- Before: appargs.Validate(appargs.String, appargs.String),
- Action: func(clictx *cli.Context) error {
- args := clictx.Args()
- address := args[0]
- id := args[1]
-
- conn, err := winio.DialPipe(address, nil)
- if err != nil {
- return fmt.Errorf("dial %s: %w", address, err)
- }
-
- client := ttrpc.NewClient(conn)
- svc := task.NewTaskClient(client)
-
- ctx := context.Background()
-
- if _, err := svc.Shutdown(ctx, &task.ShutdownRequest{ID: id}); err != nil {
- return err
- }
- return nil
- },
-}
diff --git a/cmd/shimdiag/shimdiag.go b/cmd/shimdiag/shimdiag.go
index a7097523bc..ea3720aed4 100644
--- a/cmd/shimdiag/shimdiag.go
+++ b/cmd/shimdiag/shimdiag.go
@@ -27,14 +27,8 @@ func main() {
lmDialCommand,
lmTransferCommand,
lmFinalizeCommand,
- createCommand,
- pipeCommand,
- startCommand,
- eventsCommand,
json2pbCommand,
pb2jsonCommand,
- deleteCommand,
- shutdownCommand,
}
if err := app.Run(os.Args); err != nil {
fmt.Fprintln(os.Stderr, err)
diff --git a/cmd/task/core.go b/cmd/task/core.go
index b70b4e8929..0b20c3c8eb 100644
--- a/cmd/task/core.go
+++ b/cmd/task/core.go
@@ -200,8 +200,13 @@ var startCommand = cli.Command{
}
var deleteCommand = cli.Command{
- Name: "delete",
- ArgsUsage: " ",
+ Name: "delete",
+ ArgsUsage: " ",
+ Flags: []cli.Flag{
+ cli.StringFlag{
+ Name: "execid",
+ },
+ },
SkipArgReorder: true,
Before: appargs.Validate(appargs.String, appargs.String),
Action: func(clictx *cli.Context) error {
@@ -209,6 +214,8 @@ var deleteCommand = cli.Command{
address := args[0]
id := args[1]
+ execID := clictx.String("execid")
+
conn, err := winio.DialPipe(address, nil)
if err != nil {
return fmt.Errorf("dial %s: %w", address, err)
@@ -219,7 +226,7 @@ var deleteCommand = cli.Command{
ctx := context.Background()
- if _, err := svc.Delete(ctx, &task.DeleteRequest{ID: id}); err != nil {
+ if _, err := svc.Delete(ctx, &task.DeleteRequest{ID: id, ExecID: execID}); err != nil {
return err
}
return nil
diff --git a/cmd/task/extra.go b/cmd/task/extra.go
index f11d36683c..52652fc865 100644
--- a/cmd/task/extra.go
+++ b/cmd/task/extra.go
@@ -4,13 +4,15 @@ import (
"context"
"fmt"
"io"
- "net"
"os"
+ "path/filepath"
"sync"
+ "syscall"
"time"
"github.com/Microsoft/go-winio"
"github.com/Microsoft/hcsshim/internal/appargs"
+ "github.com/containerd/console"
eventtypes "github.com/containerd/containerd/api/events"
"github.com/containerd/containerd/api/services/ttrpc/events/v1"
"github.com/containerd/ttrpc"
@@ -19,48 +21,97 @@ import (
"google.golang.org/protobuf/types/known/emptypb"
)
-var pipeCommand = cli.Command{
- Name: "pipe",
- Usage: "",
- ArgsUsage: "[flags] ",
- Flags: []cli.Flag{},
+type rawConReader struct {
+ f *os.File
+}
+
+func (r rawConReader) Read(b []byte) (int, error) {
+ n, err := syscall.Read(syscall.Handle(r.f.Fd()), b)
+ if n == 0 && len(b) != 0 && err == nil {
+ // A zero-byte read on a console indicates that the user wrote Ctrl-Z.
+ b[0] = 26
+ return 1, nil
+ }
+ return n, err
+}
+
+func pipeIO(name string, path string, f interface{}, in bool, wg *sync.WaitGroup) error {
+ l, err := winio.ListenPipe(path, nil)
+ if err != nil {
+ return err
+ }
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ fmt.Printf("%s: listening on %s\n", name, path)
+ c, err := l.Accept()
+ if err != nil {
+ fmt.Printf("%s: connection failed: %s\n", name, err)
+ return
+ }
+ fmt.Printf("%s: received connection\n", name)
+ var copy func() (int64, error)
+ if in {
+ copy = func() (int64, error) { return io.Copy(c, f.(io.Reader)) }
+ defer c.Close()
+ } else {
+ copy = func() (int64, error) { return io.Copy(f.(io.Writer), c) }
+ }
+ n, err := copy()
+ fmt.Printf("%s: copy completed after %d bytes", name, n)
+ if err != nil {
+ fmt.Printf(" with error: %s", err)
+ }
+ fmt.Printf("\n")
+ }()
+ return nil
+}
+
+var ioCommand = cli.Command{
+ Name: "io",
+ Usage: "",
+ ArgsUsage: "[flags] ",
+ Flags: []cli.Flag{
+ cli.BoolFlag{
+ Name: "tty",
+ },
+ },
SkipArgReorder: true,
+ Before: appargs.Validate(appargs.String),
Action: func(clictx *cli.Context) error {
- args := clictx.Args()
+ pipeBase := clictx.Args()[0]
- f := func(name, pipe string, wg *sync.WaitGroup, copy func(c net.Conn) (int64, error)) {
- defer wg.Done()
- l, err := winio.ListenPipe(pipe, nil)
+ var stdin io.Reader = os.Stdin
+ if clictx.Bool("tty") {
+ con, err := console.ConsoleFromFile(os.Stdin)
if err != nil {
- panic(err)
+ return err
}
- fmt.Printf("%s: listening on %s\n", name, pipe)
- c, err := l.Accept()
- if err != nil {
- panic(err)
+ if err := con.SetRaw(); err != nil {
+ return err
}
- fmt.Printf("%s: received connection\n", name)
- n, err := copy(c)
- fmt.Printf("%s: copy completed after %d bytes", name, n)
- if err != nil {
- fmt.Printf(" and error: %s", err)
- }
- fmt.Printf("\n")
+ defer con.Reset()
+ stdin = rawConReader{os.Stdin}
}
var wg sync.WaitGroup
- if len(args) > 0 {
- wg.Add(1)
- go f("stdin", args[0], &wg, func(c net.Conn) (int64, error) { return io.Copy(c, os.Stdin) })
+
+ stdinPath := filepath.Join(pipeBase, "stdin")
+ if err := pipeIO("stdin", stdinPath, stdin, true, &wg); err != nil {
+ return err
}
- if len(args) > 1 {
- wg.Add(1)
- go f("stdout", args[1], &wg, func(c net.Conn) (int64, error) { return io.Copy(os.Stdout, c) })
+ stdoutPath := filepath.Join(pipeBase, "stdout")
+ if err := pipeIO("stdout", stdoutPath, os.Stdout, false, &wg); err != nil {
+ return err
}
- if len(args) > 2 {
- wg.Add(1)
- go f("stderr", args[2], &wg, func(c net.Conn) (int64, error) { return io.Copy(os.Stderr, c) })
+ var stderrPath string
+ if !clictx.Bool("tty") {
+ stderrPath = filepath.Join(pipeBase, "stderr")
+ if err := pipeIO("stderr", stderrPath, os.Stderr, false, &wg); err != nil {
+ return err
+ }
}
+
wg.Wait()
return nil
diff --git a/cmd/task/main.go b/cmd/task/main.go
index 0f0ee213c9..c9b3bae247 100644
--- a/cmd/task/main.go
+++ b/cmd/task/main.go
@@ -26,7 +26,7 @@ func main() {
connectCommand,
shutdownCommand,
// Extra
- pipeCommand,
+ ioCommand,
eventsCommand,
}
if err := app.Run(os.Args); err != nil {
diff --git a/internal/core/linuxvm/migrator.go b/internal/core/linuxvm/migrator.go
index 4eb551fde4..d432b22bc9 100644
--- a/internal/core/linuxvm/migrator.go
+++ b/internal/core/linuxvm/migrator.go
@@ -3,10 +3,12 @@ package linuxvm
import (
"context"
"fmt"
+ "io"
"net"
"github.com/Microsoft/go-winio"
"github.com/Microsoft/go-winio/pkg/guid"
+ "github.com/Microsoft/hcsshim/internal/cmd"
"github.com/Microsoft/hcsshim/internal/core"
"github.com/Microsoft/hcsshim/internal/guestmanager"
"github.com/Microsoft/hcsshim/internal/hns"
@@ -65,14 +67,21 @@ func (s *Sandbox) LMPrepare(ctx context.Context) (_ *statepkg.SandboxState, _ *c
s.isLMSrc = true
vmConfig := s.vm.Config()
vmConfig.NICs = nil
+ containers := make(map[string]*statepkg.Container)
+ for id, ctr := range s.ctrs {
+ containers[id] = &statepkg.Container{
+ InitPid: uint32(ctr.Pid()),
+ }
+ }
s.state = &statepkg.SandboxState{
Vm: &statepkg.VMState{
Config: statepkg.VMConfigFromInternal(vmConfig),
CompatInfo: compatInfo,
Resources: intResources,
},
- Agent: s.gm.State(),
- Ifaces: s.ifaces,
+ Agent: s.gm.State(),
+ Ifaces: s.ifaces,
+ Containers: containers,
}
return s.state, resources, nil
}
@@ -89,17 +98,18 @@ func (s *Sandbox) LMTransfer(ctx context.Context, socket uintptr) (core.Migrated
}
type migrated struct {
- vm *vm.VM
- agentConfig *statepkg.GCState
- newNetNS string
- oldIfaces []*statepkg.GuestInterface
+ vm *vm.VM
+ sandboxContainer *statepkg.Container
+ agentConfig *statepkg.GCState
+ newNetNS string
+ oldIfaces []*statepkg.GuestInterface
}
func (m *migrated) LMComplete(ctx context.Context) (core.Sandbox, error) {
if err := m.vm.LMFinalize(ctx, true); err != nil {
return nil, err
}
- return newSandbox(ctx, m.vm, m.agentConfig, m.newNetNS, m.oldIfaces)
+ return newSandbox(ctx, m.vm, m.sandboxContainer, m.agentConfig, m.newNetNS, m.oldIfaces)
}
func (m *migrated) LMKill(ctx context.Context) error {
@@ -145,14 +155,15 @@ func (m *migrator) LMTransfer(ctx context.Context, socket uintptr) (core.Migrate
return nil, err
}
return &migrated{
- vm: m.vm,
- agentConfig: m.sandboxState.Agent,
- newNetNS: m.netns,
- oldIfaces: m.sandboxState.Ifaces,
+ vm: m.vm,
+ sandboxContainer: m.sandboxState.Containers["SANDBOX"],
+ agentConfig: m.sandboxState.Agent,
+ newNetNS: m.netns,
+ oldIfaces: m.sandboxState.Ifaces,
}, nil
}
-func newSandbox(ctx context.Context, vm *vm.VM, agentConfig *statepkg.GCState, newNetNS string, oldIFaces []*statepkg.GuestInterface) (core.Sandbox, error) {
+func newSandbox(ctx context.Context, vm *vm.VM, sandboxContainer *statepkg.Container, agentConfig *statepkg.GCState, newNetNS string, oldIFaces []*statepkg.GuestInterface) (core.Sandbox, error) {
gm, err := guestmanager.NewLinuxManagerFromState(
func(port uint32) (net.Listener, error) { return vm.ListenHVSocket(winio.VsockServiceID(port)) },
agentConfig)
@@ -203,10 +214,16 @@ func newSandbox(ctx context.Context, vm *vm.VM, agentConfig *statepkg.GCState, n
}
waitCtx, waitCancel := context.WithCancel(context.Background())
- return &Sandbox{
+ gt := newGuestThing(gm)
+ pauseCtr, err := restoreContainer(ctx, gt, waitCtx, "SANDBOX", sandboxContainer.InitPid, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ sandbox := &Sandbox{
vm: vm,
gm: gm,
- gt: newGuestThing(gm),
+ gt: gt,
translator: &translator{
vm: vm,
scsiAttacher: nil,
@@ -217,5 +234,47 @@ func newSandbox(ctx context.Context, vm *vm.VM, agentConfig *statepkg.GCState, n
ifaces: ifaces,
waitCtx: waitCtx,
waitCancel: waitCancel,
- }, nil
+ pauseCtr: pauseCtr,
+ }
+ go sandbox.waitBackground()
+ return sandbox, nil
+}
+
+func restoreContainer(ctx context.Context, gt *guestThing, waitCtx context.Context, cid string, pid uint32, myIO cmd.UpstreamIO) (*ctr, error) {
+ innerCtr, err := gt.OpenContainer(ctx, cid)
+ if err != nil {
+ return nil, err
+ }
+ var (
+ stdin io.Reader
+ stdout, stderr io.Writer
+ )
+ if myIO != nil {
+ stdin = myIO.Stdin()
+ stdout = myIO.Stdout()
+ stderr = myIO.Stderr()
+ }
+ cmd, err := cmd.Open(ctx, innerCtr, pid, stdin, stdout, stderr)
+ if err != nil {
+ return nil, err
+ }
+ p := newProcess(cmd, myIO)
+ c := &ctr{
+ innerCtr: innerCtr,
+ init: p,
+ io: myIO,
+ waitCh: make(chan struct{}),
+ waitCtx: waitCtx,
+ }
+ go p.waitBackground()
+ go c.waitBackground()
+ return c, nil
+}
+
+func (s *Sandbox) RestoreLinuxContainer(ctx context.Context, cid string, pid uint32, myIO cmd.UpstreamIO) (core.Ctr, error) {
+ c, err := restoreContainer(ctx, s.gt, s.waitCtx, cid, pid, myIO)
+ if err != nil {
+ return nil, err
+ }
+ return c, nil
}
diff --git a/internal/core/linuxvm/sandbox.go b/internal/core/linuxvm/sandbox.go
index 8367216b82..1aa322ab00 100644
--- a/internal/core/linuxvm/sandbox.go
+++ b/internal/core/linuxvm/sandbox.go
@@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
- "io"
"net"
"slices"
"strconv"
@@ -210,7 +209,7 @@ func NewSandbox(ctx context.Context, id string, l *layers.LCOWLayers2, spec *spe
return nil, err
}
ctrConfig := &core.LinuxCtrConfig{
- ID: id,
+ ID: "SANDBOX",
Layers: l,
Spec: newSpec,
}
@@ -236,7 +235,7 @@ func NewSandbox(ctx context.Context, id string, l *layers.LCOWLayers2, spec *spe
translator: translator,
pauseCtr: pauseCtr,
ctrs: map[string]*ctr{
- id: pauseCtr,
+ "SANDBOX": pauseCtr,
},
allowMigration: allowMigration,
waitCh: make(chan struct{}),
@@ -298,37 +297,6 @@ func (s *Sandbox) CreateLinuxContainer(ctx context.Context, c *core.LinuxCtrConf
return ctr, nil
}
-func (s *Sandbox) RestoreLinuxContainer(ctx context.Context, cid string, pid uint32, myIO cmd.UpstreamIO) (core.Ctr, error) {
- innerCtr, err := s.gt.OpenContainer(ctx, cid)
- if err != nil {
- return nil, err
- }
- var (
- stdin io.Reader
- stdout, stderr io.Writer
- )
- if myIO != nil {
- stdin = myIO.Stdin()
- stdout = myIO.Stdout()
- stderr = myIO.Stderr()
- }
- cmd, err := cmd.Open(ctx, innerCtr, pid, stdin, stdout, stderr)
- if err != nil {
- return nil, err
- }
- p := newProcess(cmd, myIO)
- c := &ctr{
- innerCtr: innerCtr,
- init: p,
- io: myIO,
- waitCh: make(chan struct{}),
- waitCtx: s.waitCtx,
- }
- go p.waitBackground()
- go c.waitBackground()
- return c, nil
-}
-
type cleanupSet []resources.ResourceCloser
func (cs cleanupSet) Release(ctx context.Context) (retErr error) {
diff --git a/internal/guest/bridge/bridge.go b/internal/guest/bridge/bridge.go
index f560ea6964..8fdc4efb44 100644
--- a/internal/guest/bridge/bridge.go
+++ b/internal/guest/bridge/bridge.go
@@ -251,12 +251,12 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser
responseErrChan := make(chan error, 1)
b.quitChan = make(chan bool)
- defer close(b.quitChan)
defer bridgeOut.Close()
// defer close(responseErrChan)
defer close(b.responseChan)
// defer close(requestChan)
// defer close(requestErrChan)
+ defer close(b.quitChan)
defer bridgeIn.Close()
// Receive bridge requests and schedule them to be processed.
@@ -387,6 +387,11 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser
}
br.response = resp
select {
+ case <-b.quitChan:
+ return
+ default:
+ }
+ select {
case b.responseChan <- br:
case <-b.quitChan:
return
diff --git a/internal/taskserver/internal.go b/internal/taskserver/internal.go
index db24b637dd..0827969e61 100644
--- a/internal/taskserver/internal.go
+++ b/internal/taskserver/internal.go
@@ -13,12 +13,18 @@ import (
"github.com/Microsoft/hcsshim/internal/cmd"
"github.com/Microsoft/hcsshim/internal/core"
"github.com/Microsoft/hcsshim/internal/core/linuxvm"
+ "github.com/Microsoft/hcsshim/internal/ctrdpub"
"github.com/Microsoft/hcsshim/internal/layers"
+ "github.com/Microsoft/hcsshim/internal/log"
"github.com/Microsoft/hcsshim/internal/oci"
+ "github.com/containerd/containerd/api/events"
"github.com/containerd/containerd/api/runtime/task/v2"
containerd_v1_types "github.com/containerd/containerd/api/types/task"
taskapi "github.com/containerd/containerd/api/types/task"
+ "github.com/containerd/containerd/protobuf"
+ "github.com/containerd/containerd/runtime"
"github.com/opencontainers/runtime-spec/specs-go"
+ "github.com/sirupsen/logrus"
)
type State struct {
@@ -219,3 +225,35 @@ func getOCISpec(ctx context.Context, bundle string, shimOpts *runhcsopts.Options
}
return &spec, nil
}
+
+func waitContainer(ctx context.Context, c core.GenericCompute, state *State, publisher *ctrdpub.Publisher) {
+ waitCh := make(chan error)
+ go func() {
+ waitCh <- c.Wait(ctx)
+ }()
+ select {
+ case err := <-waitCh:
+ logrus.WithFields(logrus.Fields{
+ "taskID": state.TaskID,
+ "execID": state.ExecID,
+ logrus.ErrorKey: err,
+ }).Error("failed waiting for task exit")
+ case <-ctx.Done():
+ logrus.WithFields(logrus.Fields{
+ "taskID": state.TaskID,
+ "execID": state.ExecID,
+ }).Info("aborted task wait")
+ return
+ }
+ state.setExited(uint32(c.Status().ExitCode()))
+ if err := publisher.PublishEvent(ctx, runtime.TaskExitEventTopic, &events.TaskExit{
+ ContainerID: state.TaskID,
+ ID: state.ExecID,
+ Pid: state.Pid,
+ ExitStatus: state.ExitStatus,
+ ExitedAt: protobuf.ToTimestamp(state.ExitedAt),
+ }); err != nil {
+ log.G(ctx).WithError(err).Info("PublishEvent failed")
+ }
+ close(state.waitCh)
+}
diff --git a/internal/taskserver/migration.go b/internal/taskserver/migration.go
index 33faa0d990..6c1201d0a3 100644
--- a/internal/taskserver/migration.go
+++ b/internal/taskserver/migration.go
@@ -13,8 +13,11 @@ import (
"github.com/Microsoft/hcsshim/internal/core"
"github.com/Microsoft/hcsshim/internal/core/linuxvm"
lmproto "github.com/Microsoft/hcsshim/internal/lm/proto"
+ "github.com/Microsoft/hcsshim/internal/log"
statepkg "github.com/Microsoft/hcsshim/internal/state"
+ "github.com/containerd/containerd/api/events"
"github.com/containerd/containerd/api/runtime/task/v2"
+ "github.com/containerd/containerd/runtime"
"github.com/containerd/typeurl/v2"
"github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
@@ -213,7 +216,7 @@ func (s *service) TransferSandbox(ctx context.Context, req *lmproto.TransferSand
func (s *service) FinalizeSandbox(ctx context.Context, req *lmproto.FinalizeSandboxRequest) (*lmproto.FinalizeSandboxResponse, error) {
if s.migState.migrated == nil {
- return nil, fmt.Errorf("No migrated sandbox is present")
+ return nil, fmt.Errorf("no migrated sandbox is present")
}
switch req.Action {
case lmproto.FinalizeSandboxRequest_ACTION_RESUME:
@@ -221,17 +224,47 @@ func (s *service) FinalizeSandbox(ctx context.Context, req *lmproto.FinalizeSand
if err != nil {
return nil, err
}
+ waitCtx, waitCancel := context.WithCancel(context.Background())
s.sandbox = &Sandbox{
State: &State{
TaskID: s.migState.newID,
+ waitCh: make(chan struct{}),
},
- Sandbox: sandbox,
- Tasks: make(map[string]*Task),
+ Sandbox: sandbox,
+ Tasks: make(map[string]*Task),
+ waitCtx: waitCtx,
+ waitCancel: waitCancel,
}
+ go waitContainer(s.sandbox.waitCtx, s.sandbox.Sandbox, s.sandbox.State, s.publisher)
case lmproto.FinalizeSandboxRequest_ACTION_STOP:
if err := s.migState.migrated.LMKill(ctx); err != nil {
return nil, err
}
+ for _, t := range s.sandbox.Tasks {
+ t.setExited(255)
+ if err := s.publisher.PublishEvent(ctx, runtime.TaskExitEventTopic, &events.TaskExit{
+ ContainerID: t.TaskID,
+ ID: t.ExecID,
+ Pid: t.Pid,
+ ExitStatus: t.ExitStatus,
+ ExitedAt: timestamppb.New(t.ExitedAt),
+ }); err != nil {
+ log.G(ctx).WithError(err).Info("PublishEvent failed")
+ }
+ }
+ s.sandbox.setExited(255)
+ if err := s.publisher.PublishEvent(ctx, runtime.TaskExitEventTopic, &events.TaskExit{
+ ContainerID: s.sandbox.TaskID,
+ Pid: s.sandbox.Pid,
+ ExitStatus: s.sandbox.ExitStatus,
+ ExitedAt: timestamppb.New(s.sandbox.ExitedAt),
+ }); err != nil {
+ log.G(ctx).WithError(err).Info("PublishEvent failed")
+ }
+ s.sandbox = nil
+ // We should do this for resume at some point as well, but can't do it right away,
+ // since we need the info in migState for container restore.
+ s.migState = nil
default:
return nil, fmt.Errorf("unsupported action: %v", req.Action)
}
@@ -339,12 +372,14 @@ func (s *service) newRestoreContainer(ctx context.Context, shimOpts *runhcsopts.
if err != nil {
return err
}
-
- s.sandbox.Tasks[req.ID] = &Task{
+ t := &Task{
State: newTaskState(req),
Ctr: ctr,
Execs: make(map[string]*Exec),
}
+ s.sandbox.Tasks[req.ID] = t
+
+ go waitContainer(s.sandbox.waitCtx, ctr, t.State, s.publisher)
return nil
}
diff --git a/internal/taskserver/service.go b/internal/taskserver/service.go
index 249e365a61..99a7de2f5e 100644
--- a/internal/taskserver/service.go
+++ b/internal/taskserver/service.go
@@ -276,37 +276,7 @@ func (s *service) Start(ctx context.Context, req *task.StartRequest) (*task.Star
log.G(ctx).WithError(err).Info("PublishEvent failed")
}
}
- go func() {
- waitCh := make(chan error)
- go func() {
- waitCh <- c.Wait(context.Background())
- }()
- select {
- case err := <-waitCh:
- logrus.WithFields(logrus.Fields{
- "taskID": req.ID,
- "execID": req.ExecID,
- logrus.ErrorKey: err,
- }).Error("failed waiting for task exit")
- case <-s.sandbox.waitCtx.Done():
- logrus.WithFields(logrus.Fields{
- "taskID": req.ID,
- "execID": req.ExecID,
- }).Info("aborted task wait")
- return
- }
- state.setExited(uint32(c.Status().ExitCode()))
- if err := s.publisher.PublishEvent(ctx, runtime.TaskExitEventTopic, &events.TaskExit{
- ContainerID: state.TaskID,
- ID: req.ExecID,
- Pid: state.Pid,
- ExitStatus: state.ExitStatus,
- ExitedAt: protobuf.ToTimestamp(state.ExitedAt),
- }); err != nil {
- log.G(ctx).WithError(err).Info("PublishEvent failed")
- }
- close(state.waitCh)
- }()
+ go waitContainer(s.sandbox.waitCtx, c, state, s.publisher)
return &task.StartResponse{Pid: pid}, nil
}