diff --git a/cmd/shimdiag/lm.go b/cmd/shimdiag/lm.go index 7628b4ab35..f1d411c57a 100644 --- a/cmd/shimdiag/lm.go +++ b/cmd/shimdiag/lm.go @@ -4,279 +4,23 @@ package main import ( "context" - "encoding/json" "fmt" - "io" - "net" "os" - "path/filepath" "strconv" - "sync" - "time" "github.com/Microsoft/go-winio" runhcsopts "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/options" "github.com/Microsoft/hcsshim/internal/appargs" lmproto "github.com/Microsoft/hcsshim/internal/lm/proto" statepkg "github.com/Microsoft/hcsshim/internal/state" - eventtypes "github.com/containerd/containerd/api/events" - "github.com/containerd/containerd/api/runtime/task/v2" - "github.com/containerd/containerd/api/services/ttrpc/events/v1" - "github.com/containerd/containerd/api/types" "github.com/containerd/ttrpc" - "github.com/containerd/typeurl/v2" "github.com/urfave/cli" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/known/anypb" - "google.golang.org/protobuf/types/known/emptypb" ) -var createCommand = cli.Command{ - Name: "create", - Usage: "Creates a task", - ArgsUsage: "[flags]
", - Flags: []cli.Flag{ - cli.StringFlag{ - Name: "stdin", - Usage: "Named pipe path", - }, - cli.StringFlag{ - Name: "stdout", - Usage: "Named pipe path", - }, - cli.StringFlag{ - Name: "stderr", - Usage: "Named pipe path", - }, - cli.BoolFlag{ - Name: "tty", - Usage: "Enable terminal mode for task IO", - }, - cli.StringFlag{ - Name: "rootfs", - Usage: "JSON file to read rootfs from", - }, - cli.StringFlag{ - Name: "options", - Usage: "jsonpb file to read shim options from", - }, - }, - SkipArgReorder: true, - Before: appargs.Validate(appargs.String, appargs.String, appargs.String), - Action: func(clictx *cli.Context) error { - args := clictx.Args() - address := args[0] - id := args[1] - - bundle, err := filepath.Abs(args[2]) - if err != nil { - return err - } - - var rootfs []*types.Mount - if clictx.IsSet("rootfs") { - data, err := os.ReadFile(clictx.String("rootfs")) - if err != nil { - return err - } - if err := json.Unmarshal(data, &rootfs); err != nil { - return err - } - } - - var options *anypb.Any - if clictx.IsSet("options") { - data, err := os.ReadFile(clictx.String("options")) - if err != nil { - return err - } - var opts runhcsopts.Options - if err := (protojson.UnmarshalOptions{}).Unmarshal(data, &opts); err != nil { - return err - } - any, err := typeurl.MarshalAny(&opts) - if err != nil { - return err - } - options = &anypb.Any{TypeUrl: any.GetTypeUrl(), Value: any.GetValue()} - } - - conn, err := winio.DialPipe(address, nil) - if err != nil { - return fmt.Errorf("dial %s: %w", address, err) - } - - client := ttrpc.NewClient(conn) - svc := task.NewTaskClient(client) - - ctx := context.Background() - - { - resp, err := svc.Create(ctx, &task.CreateTaskRequest{ - ID: id, - Bundle: bundle, - Rootfs: rootfs, - Terminal: clictx.Bool("tty"), - Stdin: clictx.String("stdin"), - Stdout: clictx.String("stdout"), - Stderr: clictx.String("stderr"), - Options: options, - }) - if err != nil { - return fmt.Errorf("task.Create: %w", err) - } - fmt.Printf("task pid is %d\n", resp.Pid) - } - return nil - }, -} - -var startCommand = cli.Command{ - Name: "start", - Usage: "", - ArgsUsage: "[flags]
", - Flags: []cli.Flag{}, - SkipArgReorder: true, - Before: appargs.Validate(appargs.String, appargs.String), - Action: func(clictx *cli.Context) error { - args := clictx.Args() - address := args[0] - id := args[1] - - conn, err := winio.DialPipe(address, nil) - if err != nil { - return fmt.Errorf("dial %s: %w", address, err) - } - - client := ttrpc.NewClient(conn) - svc := task.NewTaskClient(client) - - ctx := context.Background() - - { - resp, err := svc.Start(ctx, &task.StartRequest{ - ID: id, - }) - if err != nil { - return fmt.Errorf("task.Start: %w", err) - } - fmt.Printf("task pid is %d\n", resp.Pid) - } - return nil - }, -} - -var pipeCommand = cli.Command{ - Name: "pipe", - Usage: "", - ArgsUsage: "[flags] ", - Flags: []cli.Flag{}, - SkipArgReorder: true, - Action: func(clictx *cli.Context) error { - args := clictx.Args() - - f := func(name, pipe string, wg *sync.WaitGroup, copy func(c net.Conn) (int64, error)) { - defer wg.Done() - l, err := winio.ListenPipe(pipe, nil) - if err != nil { - panic(err) - } - fmt.Printf("%s: listening on %s\n", name, pipe) - c, err := l.Accept() - if err != nil { - panic(err) - } - fmt.Printf("%s: received connection\n", name) - n, err := copy(c) - fmt.Printf("%s: copy completed after %d bytes", name, n) - if err != nil { - fmt.Printf(" and error: %s", err) - } - fmt.Printf("\n") - } - - var wg sync.WaitGroup - if len(args) > 0 { - wg.Add(1) - go f("stdin", args[0], &wg, func(c net.Conn) (int64, error) { return io.Copy(c, os.Stdin) }) - } - if len(args) > 1 { - wg.Add(1) - go f("stdout", args[1], &wg, func(c net.Conn) (int64, error) { return io.Copy(os.Stdout, c) }) - } - if len(args) > 2 { - wg.Add(1) - go f("stderr", args[2], &wg, func(c net.Conn) (int64, error) { return io.Copy(os.Stderr, c) }) - } - wg.Wait() - - return nil - }, -} - -type eventsSvc struct { - m sync.Mutex -} - -func (e *eventsSvc) Forward(ctx context.Context, req *events.ForwardRequest) (*emptypb.Empty, error) { - e.m.Lock() - defer e.m.Unlock() - - fmt.Printf("[%s][%s]: %s\n", req.Envelope.Timestamp.AsTime().Format(time.RFC3339), req.Envelope.Namespace, req.Envelope.Topic) - v, err := typeurl.UnmarshalAny(req.Envelope.Event) - if err != nil { - fmt.Printf("\tunmarshal failed: %s\n", err) - } - switch v := v.(type) { - case *eventtypes.TaskCreate: - fmt.Printf("\tContainerID: %s\n", v.ContainerID) - fmt.Printf("\tBundle: %s\n", v.Bundle) - fmt.Printf("\tPID: %d\n", v.Pid) - case *eventtypes.TaskStart: - fmt.Printf("\tContainerID: %s\n", v.ContainerID) - fmt.Printf("\tPID: %d\n", v.Pid) - case *eventtypes.TaskExit: - fmt.Printf("\tID: %s\n", v.ID) - fmt.Printf("\tContainerID: %s\n", v.ContainerID) - fmt.Printf("\tPID: %d\n", v.Pid) - fmt.Printf("\tExitStatus: %d\n", v.ExitStatus) - fmt.Printf("\tExitedAt: %v\n", v.ExitedAt.AsTime().Format(time.RFC3339)) - default: - fmt.Printf("\tunrecognized event type: %T\n", v) - } - return &emptypb.Empty{}, nil -} - -var eventsCommand = cli.Command{ - Name: "events", - Usage: "", - ArgsUsage: "[flags]
", - Flags: []cli.Flag{}, - SkipArgReorder: true, - Before: appargs.Validate(appargs.String), - Action: func(clictx *cli.Context) error { - args := clictx.Args() - address := args[0] - - l, err := winio.ListenPipe(address, nil) - if err != nil { - return err - } - - server, err := ttrpc.NewServer() - if err != nil { - return err - } - events.RegisterEventsService(server, &eventsSvc{}) - if err := server.Serve(context.Background(), l); err != nil { - return err - } - return nil - }, -} - var lmPrepareCommand = cli.Command{ Name: "lmprepare", Usage: "Prepares the sandbox for migration", @@ -572,57 +316,3 @@ var pb2jsonCommand = cli.Command{ return nil }, } - -var deleteCommand = cli.Command{ - Name: "delete", - ArgsUsage: " ", - SkipArgReorder: true, - Before: appargs.Validate(appargs.String, appargs.String), - Action: func(clictx *cli.Context) error { - args := clictx.Args() - address := args[0] - id := args[1] - - conn, err := winio.DialPipe(address, nil) - if err != nil { - return fmt.Errorf("dial %s: %w", address, err) - } - - client := ttrpc.NewClient(conn) - svc := task.NewTaskClient(client) - - ctx := context.Background() - - if _, err := svc.Delete(ctx, &task.DeleteRequest{ID: id}); err != nil { - return err - } - return nil - }, -} - -var shutdownCommand = cli.Command{ - Name: "shutdown", - ArgsUsage: " ", - SkipArgReorder: true, - Before: appargs.Validate(appargs.String, appargs.String), - Action: func(clictx *cli.Context) error { - args := clictx.Args() - address := args[0] - id := args[1] - - conn, err := winio.DialPipe(address, nil) - if err != nil { - return fmt.Errorf("dial %s: %w", address, err) - } - - client := ttrpc.NewClient(conn) - svc := task.NewTaskClient(client) - - ctx := context.Background() - - if _, err := svc.Shutdown(ctx, &task.ShutdownRequest{ID: id}); err != nil { - return err - } - return nil - }, -} diff --git a/cmd/shimdiag/shimdiag.go b/cmd/shimdiag/shimdiag.go index a7097523bc..ea3720aed4 100644 --- a/cmd/shimdiag/shimdiag.go +++ b/cmd/shimdiag/shimdiag.go @@ -27,14 +27,8 @@ func main() { lmDialCommand, lmTransferCommand, lmFinalizeCommand, - createCommand, - pipeCommand, - startCommand, - eventsCommand, json2pbCommand, pb2jsonCommand, - deleteCommand, - shutdownCommand, } if err := app.Run(os.Args); err != nil { fmt.Fprintln(os.Stderr, err) diff --git a/cmd/task/core.go b/cmd/task/core.go index b70b4e8929..0b20c3c8eb 100644 --- a/cmd/task/core.go +++ b/cmd/task/core.go @@ -200,8 +200,13 @@ var startCommand = cli.Command{ } var deleteCommand = cli.Command{ - Name: "delete", - ArgsUsage: " ", + Name: "delete", + ArgsUsage: " ", + Flags: []cli.Flag{ + cli.StringFlag{ + Name: "execid", + }, + }, SkipArgReorder: true, Before: appargs.Validate(appargs.String, appargs.String), Action: func(clictx *cli.Context) error { @@ -209,6 +214,8 @@ var deleteCommand = cli.Command{ address := args[0] id := args[1] + execID := clictx.String("execid") + conn, err := winio.DialPipe(address, nil) if err != nil { return fmt.Errorf("dial %s: %w", address, err) @@ -219,7 +226,7 @@ var deleteCommand = cli.Command{ ctx := context.Background() - if _, err := svc.Delete(ctx, &task.DeleteRequest{ID: id}); err != nil { + if _, err := svc.Delete(ctx, &task.DeleteRequest{ID: id, ExecID: execID}); err != nil { return err } return nil diff --git a/cmd/task/extra.go b/cmd/task/extra.go index f11d36683c..52652fc865 100644 --- a/cmd/task/extra.go +++ b/cmd/task/extra.go @@ -4,13 +4,15 @@ import ( "context" "fmt" "io" - "net" "os" + "path/filepath" "sync" + "syscall" "time" "github.com/Microsoft/go-winio" "github.com/Microsoft/hcsshim/internal/appargs" + "github.com/containerd/console" eventtypes "github.com/containerd/containerd/api/events" "github.com/containerd/containerd/api/services/ttrpc/events/v1" "github.com/containerd/ttrpc" @@ -19,48 +21,97 @@ import ( "google.golang.org/protobuf/types/known/emptypb" ) -var pipeCommand = cli.Command{ - Name: "pipe", - Usage: "", - ArgsUsage: "[flags] ", - Flags: []cli.Flag{}, +type rawConReader struct { + f *os.File +} + +func (r rawConReader) Read(b []byte) (int, error) { + n, err := syscall.Read(syscall.Handle(r.f.Fd()), b) + if n == 0 && len(b) != 0 && err == nil { + // A zero-byte read on a console indicates that the user wrote Ctrl-Z. + b[0] = 26 + return 1, nil + } + return n, err +} + +func pipeIO(name string, path string, f interface{}, in bool, wg *sync.WaitGroup) error { + l, err := winio.ListenPipe(path, nil) + if err != nil { + return err + } + wg.Add(1) + go func() { + defer wg.Done() + fmt.Printf("%s: listening on %s\n", name, path) + c, err := l.Accept() + if err != nil { + fmt.Printf("%s: connection failed: %s\n", name, err) + return + } + fmt.Printf("%s: received connection\n", name) + var copy func() (int64, error) + if in { + copy = func() (int64, error) { return io.Copy(c, f.(io.Reader)) } + defer c.Close() + } else { + copy = func() (int64, error) { return io.Copy(f.(io.Writer), c) } + } + n, err := copy() + fmt.Printf("%s: copy completed after %d bytes", name, n) + if err != nil { + fmt.Printf(" with error: %s", err) + } + fmt.Printf("\n") + }() + return nil +} + +var ioCommand = cli.Command{ + Name: "io", + Usage: "", + ArgsUsage: "[flags] ", + Flags: []cli.Flag{ + cli.BoolFlag{ + Name: "tty", + }, + }, SkipArgReorder: true, + Before: appargs.Validate(appargs.String), Action: func(clictx *cli.Context) error { - args := clictx.Args() + pipeBase := clictx.Args()[0] - f := func(name, pipe string, wg *sync.WaitGroup, copy func(c net.Conn) (int64, error)) { - defer wg.Done() - l, err := winio.ListenPipe(pipe, nil) + var stdin io.Reader = os.Stdin + if clictx.Bool("tty") { + con, err := console.ConsoleFromFile(os.Stdin) if err != nil { - panic(err) + return err } - fmt.Printf("%s: listening on %s\n", name, pipe) - c, err := l.Accept() - if err != nil { - panic(err) + if err := con.SetRaw(); err != nil { + return err } - fmt.Printf("%s: received connection\n", name) - n, err := copy(c) - fmt.Printf("%s: copy completed after %d bytes", name, n) - if err != nil { - fmt.Printf(" and error: %s", err) - } - fmt.Printf("\n") + defer con.Reset() + stdin = rawConReader{os.Stdin} } var wg sync.WaitGroup - if len(args) > 0 { - wg.Add(1) - go f("stdin", args[0], &wg, func(c net.Conn) (int64, error) { return io.Copy(c, os.Stdin) }) + + stdinPath := filepath.Join(pipeBase, "stdin") + if err := pipeIO("stdin", stdinPath, stdin, true, &wg); err != nil { + return err } - if len(args) > 1 { - wg.Add(1) - go f("stdout", args[1], &wg, func(c net.Conn) (int64, error) { return io.Copy(os.Stdout, c) }) + stdoutPath := filepath.Join(pipeBase, "stdout") + if err := pipeIO("stdout", stdoutPath, os.Stdout, false, &wg); err != nil { + return err } - if len(args) > 2 { - wg.Add(1) - go f("stderr", args[2], &wg, func(c net.Conn) (int64, error) { return io.Copy(os.Stderr, c) }) + var stderrPath string + if !clictx.Bool("tty") { + stderrPath = filepath.Join(pipeBase, "stderr") + if err := pipeIO("stderr", stderrPath, os.Stderr, false, &wg); err != nil { + return err + } } + wg.Wait() return nil diff --git a/cmd/task/main.go b/cmd/task/main.go index 0f0ee213c9..c9b3bae247 100644 --- a/cmd/task/main.go +++ b/cmd/task/main.go @@ -26,7 +26,7 @@ func main() { connectCommand, shutdownCommand, // Extra - pipeCommand, + ioCommand, eventsCommand, } if err := app.Run(os.Args); err != nil { diff --git a/internal/core/linuxvm/migrator.go b/internal/core/linuxvm/migrator.go index 4eb551fde4..d432b22bc9 100644 --- a/internal/core/linuxvm/migrator.go +++ b/internal/core/linuxvm/migrator.go @@ -3,10 +3,12 @@ package linuxvm import ( "context" "fmt" + "io" "net" "github.com/Microsoft/go-winio" "github.com/Microsoft/go-winio/pkg/guid" + "github.com/Microsoft/hcsshim/internal/cmd" "github.com/Microsoft/hcsshim/internal/core" "github.com/Microsoft/hcsshim/internal/guestmanager" "github.com/Microsoft/hcsshim/internal/hns" @@ -65,14 +67,21 @@ func (s *Sandbox) LMPrepare(ctx context.Context) (_ *statepkg.SandboxState, _ *c s.isLMSrc = true vmConfig := s.vm.Config() vmConfig.NICs = nil + containers := make(map[string]*statepkg.Container) + for id, ctr := range s.ctrs { + containers[id] = &statepkg.Container{ + InitPid: uint32(ctr.Pid()), + } + } s.state = &statepkg.SandboxState{ Vm: &statepkg.VMState{ Config: statepkg.VMConfigFromInternal(vmConfig), CompatInfo: compatInfo, Resources: intResources, }, - Agent: s.gm.State(), - Ifaces: s.ifaces, + Agent: s.gm.State(), + Ifaces: s.ifaces, + Containers: containers, } return s.state, resources, nil } @@ -89,17 +98,18 @@ func (s *Sandbox) LMTransfer(ctx context.Context, socket uintptr) (core.Migrated } type migrated struct { - vm *vm.VM - agentConfig *statepkg.GCState - newNetNS string - oldIfaces []*statepkg.GuestInterface + vm *vm.VM + sandboxContainer *statepkg.Container + agentConfig *statepkg.GCState + newNetNS string + oldIfaces []*statepkg.GuestInterface } func (m *migrated) LMComplete(ctx context.Context) (core.Sandbox, error) { if err := m.vm.LMFinalize(ctx, true); err != nil { return nil, err } - return newSandbox(ctx, m.vm, m.agentConfig, m.newNetNS, m.oldIfaces) + return newSandbox(ctx, m.vm, m.sandboxContainer, m.agentConfig, m.newNetNS, m.oldIfaces) } func (m *migrated) LMKill(ctx context.Context) error { @@ -145,14 +155,15 @@ func (m *migrator) LMTransfer(ctx context.Context, socket uintptr) (core.Migrate return nil, err } return &migrated{ - vm: m.vm, - agentConfig: m.sandboxState.Agent, - newNetNS: m.netns, - oldIfaces: m.sandboxState.Ifaces, + vm: m.vm, + sandboxContainer: m.sandboxState.Containers["SANDBOX"], + agentConfig: m.sandboxState.Agent, + newNetNS: m.netns, + oldIfaces: m.sandboxState.Ifaces, }, nil } -func newSandbox(ctx context.Context, vm *vm.VM, agentConfig *statepkg.GCState, newNetNS string, oldIFaces []*statepkg.GuestInterface) (core.Sandbox, error) { +func newSandbox(ctx context.Context, vm *vm.VM, sandboxContainer *statepkg.Container, agentConfig *statepkg.GCState, newNetNS string, oldIFaces []*statepkg.GuestInterface) (core.Sandbox, error) { gm, err := guestmanager.NewLinuxManagerFromState( func(port uint32) (net.Listener, error) { return vm.ListenHVSocket(winio.VsockServiceID(port)) }, agentConfig) @@ -203,10 +214,16 @@ func newSandbox(ctx context.Context, vm *vm.VM, agentConfig *statepkg.GCState, n } waitCtx, waitCancel := context.WithCancel(context.Background()) - return &Sandbox{ + gt := newGuestThing(gm) + pauseCtr, err := restoreContainer(ctx, gt, waitCtx, "SANDBOX", sandboxContainer.InitPid, nil) + if err != nil { + return nil, err + } + + sandbox := &Sandbox{ vm: vm, gm: gm, - gt: newGuestThing(gm), + gt: gt, translator: &translator{ vm: vm, scsiAttacher: nil, @@ -217,5 +234,47 @@ func newSandbox(ctx context.Context, vm *vm.VM, agentConfig *statepkg.GCState, n ifaces: ifaces, waitCtx: waitCtx, waitCancel: waitCancel, - }, nil + pauseCtr: pauseCtr, + } + go sandbox.waitBackground() + return sandbox, nil +} + +func restoreContainer(ctx context.Context, gt *guestThing, waitCtx context.Context, cid string, pid uint32, myIO cmd.UpstreamIO) (*ctr, error) { + innerCtr, err := gt.OpenContainer(ctx, cid) + if err != nil { + return nil, err + } + var ( + stdin io.Reader + stdout, stderr io.Writer + ) + if myIO != nil { + stdin = myIO.Stdin() + stdout = myIO.Stdout() + stderr = myIO.Stderr() + } + cmd, err := cmd.Open(ctx, innerCtr, pid, stdin, stdout, stderr) + if err != nil { + return nil, err + } + p := newProcess(cmd, myIO) + c := &ctr{ + innerCtr: innerCtr, + init: p, + io: myIO, + waitCh: make(chan struct{}), + waitCtx: waitCtx, + } + go p.waitBackground() + go c.waitBackground() + return c, nil +} + +func (s *Sandbox) RestoreLinuxContainer(ctx context.Context, cid string, pid uint32, myIO cmd.UpstreamIO) (core.Ctr, error) { + c, err := restoreContainer(ctx, s.gt, s.waitCtx, cid, pid, myIO) + if err != nil { + return nil, err + } + return c, nil } diff --git a/internal/core/linuxvm/sandbox.go b/internal/core/linuxvm/sandbox.go index 8367216b82..1aa322ab00 100644 --- a/internal/core/linuxvm/sandbox.go +++ b/internal/core/linuxvm/sandbox.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net" "slices" "strconv" @@ -210,7 +209,7 @@ func NewSandbox(ctx context.Context, id string, l *layers.LCOWLayers2, spec *spe return nil, err } ctrConfig := &core.LinuxCtrConfig{ - ID: id, + ID: "SANDBOX", Layers: l, Spec: newSpec, } @@ -236,7 +235,7 @@ func NewSandbox(ctx context.Context, id string, l *layers.LCOWLayers2, spec *spe translator: translator, pauseCtr: pauseCtr, ctrs: map[string]*ctr{ - id: pauseCtr, + "SANDBOX": pauseCtr, }, allowMigration: allowMigration, waitCh: make(chan struct{}), @@ -298,37 +297,6 @@ func (s *Sandbox) CreateLinuxContainer(ctx context.Context, c *core.LinuxCtrConf return ctr, nil } -func (s *Sandbox) RestoreLinuxContainer(ctx context.Context, cid string, pid uint32, myIO cmd.UpstreamIO) (core.Ctr, error) { - innerCtr, err := s.gt.OpenContainer(ctx, cid) - if err != nil { - return nil, err - } - var ( - stdin io.Reader - stdout, stderr io.Writer - ) - if myIO != nil { - stdin = myIO.Stdin() - stdout = myIO.Stdout() - stderr = myIO.Stderr() - } - cmd, err := cmd.Open(ctx, innerCtr, pid, stdin, stdout, stderr) - if err != nil { - return nil, err - } - p := newProcess(cmd, myIO) - c := &ctr{ - innerCtr: innerCtr, - init: p, - io: myIO, - waitCh: make(chan struct{}), - waitCtx: s.waitCtx, - } - go p.waitBackground() - go c.waitBackground() - return c, nil -} - type cleanupSet []resources.ResourceCloser func (cs cleanupSet) Release(ctx context.Context) (retErr error) { diff --git a/internal/guest/bridge/bridge.go b/internal/guest/bridge/bridge.go index f560ea6964..8fdc4efb44 100644 --- a/internal/guest/bridge/bridge.go +++ b/internal/guest/bridge/bridge.go @@ -251,12 +251,12 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser responseErrChan := make(chan error, 1) b.quitChan = make(chan bool) - defer close(b.quitChan) defer bridgeOut.Close() // defer close(responseErrChan) defer close(b.responseChan) // defer close(requestChan) // defer close(requestErrChan) + defer close(b.quitChan) defer bridgeIn.Close() // Receive bridge requests and schedule them to be processed. @@ -387,6 +387,11 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser } br.response = resp select { + case <-b.quitChan: + return + default: + } + select { case b.responseChan <- br: case <-b.quitChan: return diff --git a/internal/taskserver/internal.go b/internal/taskserver/internal.go index db24b637dd..0827969e61 100644 --- a/internal/taskserver/internal.go +++ b/internal/taskserver/internal.go @@ -13,12 +13,18 @@ import ( "github.com/Microsoft/hcsshim/internal/cmd" "github.com/Microsoft/hcsshim/internal/core" "github.com/Microsoft/hcsshim/internal/core/linuxvm" + "github.com/Microsoft/hcsshim/internal/ctrdpub" "github.com/Microsoft/hcsshim/internal/layers" + "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/oci" + "github.com/containerd/containerd/api/events" "github.com/containerd/containerd/api/runtime/task/v2" containerd_v1_types "github.com/containerd/containerd/api/types/task" taskapi "github.com/containerd/containerd/api/types/task" + "github.com/containerd/containerd/protobuf" + "github.com/containerd/containerd/runtime" "github.com/opencontainers/runtime-spec/specs-go" + "github.com/sirupsen/logrus" ) type State struct { @@ -219,3 +225,35 @@ func getOCISpec(ctx context.Context, bundle string, shimOpts *runhcsopts.Options } return &spec, nil } + +func waitContainer(ctx context.Context, c core.GenericCompute, state *State, publisher *ctrdpub.Publisher) { + waitCh := make(chan error) + go func() { + waitCh <- c.Wait(ctx) + }() + select { + case err := <-waitCh: + logrus.WithFields(logrus.Fields{ + "taskID": state.TaskID, + "execID": state.ExecID, + logrus.ErrorKey: err, + }).Error("failed waiting for task exit") + case <-ctx.Done(): + logrus.WithFields(logrus.Fields{ + "taskID": state.TaskID, + "execID": state.ExecID, + }).Info("aborted task wait") + return + } + state.setExited(uint32(c.Status().ExitCode())) + if err := publisher.PublishEvent(ctx, runtime.TaskExitEventTopic, &events.TaskExit{ + ContainerID: state.TaskID, + ID: state.ExecID, + Pid: state.Pid, + ExitStatus: state.ExitStatus, + ExitedAt: protobuf.ToTimestamp(state.ExitedAt), + }); err != nil { + log.G(ctx).WithError(err).Info("PublishEvent failed") + } + close(state.waitCh) +} diff --git a/internal/taskserver/migration.go b/internal/taskserver/migration.go index 33faa0d990..6c1201d0a3 100644 --- a/internal/taskserver/migration.go +++ b/internal/taskserver/migration.go @@ -13,8 +13,11 @@ import ( "github.com/Microsoft/hcsshim/internal/core" "github.com/Microsoft/hcsshim/internal/core/linuxvm" lmproto "github.com/Microsoft/hcsshim/internal/lm/proto" + "github.com/Microsoft/hcsshim/internal/log" statepkg "github.com/Microsoft/hcsshim/internal/state" + "github.com/containerd/containerd/api/events" "github.com/containerd/containerd/api/runtime/task/v2" + "github.com/containerd/containerd/runtime" "github.com/containerd/typeurl/v2" "github.com/sirupsen/logrus" "golang.org/x/sys/windows" @@ -213,7 +216,7 @@ func (s *service) TransferSandbox(ctx context.Context, req *lmproto.TransferSand func (s *service) FinalizeSandbox(ctx context.Context, req *lmproto.FinalizeSandboxRequest) (*lmproto.FinalizeSandboxResponse, error) { if s.migState.migrated == nil { - return nil, fmt.Errorf("No migrated sandbox is present") + return nil, fmt.Errorf("no migrated sandbox is present") } switch req.Action { case lmproto.FinalizeSandboxRequest_ACTION_RESUME: @@ -221,17 +224,47 @@ func (s *service) FinalizeSandbox(ctx context.Context, req *lmproto.FinalizeSand if err != nil { return nil, err } + waitCtx, waitCancel := context.WithCancel(context.Background()) s.sandbox = &Sandbox{ State: &State{ TaskID: s.migState.newID, + waitCh: make(chan struct{}), }, - Sandbox: sandbox, - Tasks: make(map[string]*Task), + Sandbox: sandbox, + Tasks: make(map[string]*Task), + waitCtx: waitCtx, + waitCancel: waitCancel, } + go waitContainer(s.sandbox.waitCtx, s.sandbox.Sandbox, s.sandbox.State, s.publisher) case lmproto.FinalizeSandboxRequest_ACTION_STOP: if err := s.migState.migrated.LMKill(ctx); err != nil { return nil, err } + for _, t := range s.sandbox.Tasks { + t.setExited(255) + if err := s.publisher.PublishEvent(ctx, runtime.TaskExitEventTopic, &events.TaskExit{ + ContainerID: t.TaskID, + ID: t.ExecID, + Pid: t.Pid, + ExitStatus: t.ExitStatus, + ExitedAt: timestamppb.New(t.ExitedAt), + }); err != nil { + log.G(ctx).WithError(err).Info("PublishEvent failed") + } + } + s.sandbox.setExited(255) + if err := s.publisher.PublishEvent(ctx, runtime.TaskExitEventTopic, &events.TaskExit{ + ContainerID: s.sandbox.TaskID, + Pid: s.sandbox.Pid, + ExitStatus: s.sandbox.ExitStatus, + ExitedAt: timestamppb.New(s.sandbox.ExitedAt), + }); err != nil { + log.G(ctx).WithError(err).Info("PublishEvent failed") + } + s.sandbox = nil + // We should do this for resume at some point as well, but can't do it right away, + // since we need the info in migState for container restore. + s.migState = nil default: return nil, fmt.Errorf("unsupported action: %v", req.Action) } @@ -339,12 +372,14 @@ func (s *service) newRestoreContainer(ctx context.Context, shimOpts *runhcsopts. if err != nil { return err } - - s.sandbox.Tasks[req.ID] = &Task{ + t := &Task{ State: newTaskState(req), Ctr: ctr, Execs: make(map[string]*Exec), } + s.sandbox.Tasks[req.ID] = t + + go waitContainer(s.sandbox.waitCtx, ctr, t.State, s.publisher) return nil } diff --git a/internal/taskserver/service.go b/internal/taskserver/service.go index 249e365a61..99a7de2f5e 100644 --- a/internal/taskserver/service.go +++ b/internal/taskserver/service.go @@ -276,37 +276,7 @@ func (s *service) Start(ctx context.Context, req *task.StartRequest) (*task.Star log.G(ctx).WithError(err).Info("PublishEvent failed") } } - go func() { - waitCh := make(chan error) - go func() { - waitCh <- c.Wait(context.Background()) - }() - select { - case err := <-waitCh: - logrus.WithFields(logrus.Fields{ - "taskID": req.ID, - "execID": req.ExecID, - logrus.ErrorKey: err, - }).Error("failed waiting for task exit") - case <-s.sandbox.waitCtx.Done(): - logrus.WithFields(logrus.Fields{ - "taskID": req.ID, - "execID": req.ExecID, - }).Info("aborted task wait") - return - } - state.setExited(uint32(c.Status().ExitCode())) - if err := s.publisher.PublishEvent(ctx, runtime.TaskExitEventTopic, &events.TaskExit{ - ContainerID: state.TaskID, - ID: req.ExecID, - Pid: state.Pid, - ExitStatus: state.ExitStatus, - ExitedAt: protobuf.ToTimestamp(state.ExitedAt), - }); err != nil { - log.G(ctx).WithError(err).Info("PublishEvent failed") - } - close(state.waitCh) - }() + go waitContainer(s.sandbox.waitCtx, c, state, s.publisher) return &task.StartResponse{Pid: pid}, nil }