diff --git a/internal/guest/bridge/bridge_v2.go b/internal/guest/bridge/bridge_v2.go index f7924a2576..c0b9b55546 100644 --- a/internal/guest/bridge/bridge_v2.go +++ b/internal/guest/bridge/bridge_v2.go @@ -199,7 +199,7 @@ func (b *Bridge) execProcessV2(r *Request) (_ RequestResponse, err error) { var c *hcsv2.Container if params.IsExternal || request.ContainerID == hcsv2.UVMContainerID { pid, err = b.hostState.RunExternalProcess(ctx, params, conSettings) - } else if c, err = b.hostState.GetContainer(request.ContainerID); err == nil { + } else if c, err = b.hostState.GetRunningContainer(request.ContainerID); err == nil { // We found a V2 container. Treat this as a V2 process. if params.OCIProcess == nil { pid, err = c.Start(ctx, conSettings) @@ -267,7 +267,7 @@ func (b *Bridge) signalContainerV2(ctx context.Context, span *trace.Span, r *Req b.quitChan <- true b.hostState.Shutdown() } else { - c, err := b.hostState.GetContainer(request.ContainerID) + c, err := b.hostState.GetRunningContainer(request.ContainerID) if err != nil { return nil, err } @@ -296,7 +296,7 @@ func (b *Bridge) signalProcessV2(r *Request) (_ RequestResponse, err error) { trace.Int64Attribute("pid", int64(request.ProcessID)), trace.Int64Attribute("signal", int64(request.Options.Signal))) - c, err := b.hostState.GetContainer(request.ContainerID) + c, err := b.hostState.GetRunningContainer(request.ContainerID) if err != nil { return nil, err } @@ -344,7 +344,7 @@ func (b *Bridge) getPropertiesV2(r *Request) (_ RequestResponse, err error) { return nil, errors.New("getPropertiesV2 is not supported against the UVM") } - c, err := b.hostState.GetContainer(request.ContainerID) + c, err := b.hostState.GetRunningContainer(request.ContainerID) if err != nil { return nil, err } @@ -407,7 +407,7 @@ func (b *Bridge) waitOnProcessV2(r *Request) (_ RequestResponse, err error) { } exitCodeChan, doneChan = p.Wait() } else { - c, err := b.hostState.GetContainer(request.ContainerID) + c, err := b.hostState.GetRunningContainer(request.ContainerID) if err != nil { return nil, err } @@ -453,7 +453,7 @@ func (b *Bridge) resizeConsoleV2(r *Request) (_ RequestResponse, err error) { trace.Int64Attribute("height", int64(request.Height)), trace.Int64Attribute("width", int64(request.Width))) - c, err := b.hostState.GetContainer(request.ContainerID) + c, err := b.hostState.GetRunningContainer(request.ContainerID) if err != nil { return nil, err } @@ -514,7 +514,7 @@ func (b *Bridge) deleteContainerStateV2(r *Request) (_ RequestResponse, err erro return nil, errors.Wrapf(err, "failed to unmarshal JSON in message \"%s\"", r.Message) } - c, err := b.hostState.GetContainer(request.ContainerID) + c, err := b.hostState.GetRunningContainer(request.ContainerID) if err != nil { return nil, err } diff --git a/internal/guest/gcserr/errors.go b/internal/guest/gcserr/errors.go index 0e13e21ef1..b042776800 100644 --- a/internal/guest/gcserr/errors.go +++ b/internal/guest/gcserr/errors.go @@ -87,14 +87,14 @@ func BaseStackTrace(e error) errors.StackTrace { return tracer.StackTrace() } -type baseHresultError struct { +type BaseHresultError struct { hresult Hresult } -func (e *baseHresultError) Error() string { +func (e *BaseHresultError) Error() string { return fmt.Sprintf("HRESULT: 0x%x", uint32(e.Hresult())) } -func (e *baseHresultError) Hresult() Hresult { +func (e *BaseHresultError) Hresult() Hresult { return e.hresult } @@ -139,7 +139,7 @@ func (e *wrappingHresultError) StackTrace() errors.StackTrace { // NewHresultError produces a new error with the given HRESULT. func NewHresultError(hresult Hresult) error { - return &baseHresultError{hresult: hresult} + return &BaseHresultError{hresult: hresult} } // WrapHresult produces a new error with the given HRESULT and wrapping the diff --git a/internal/guest/runtime/hcsv2/container.go b/internal/guest/runtime/hcsv2/container.go index d890d0c2b0..5004f14313 100644 --- a/internal/guest/runtime/hcsv2/container.go +++ b/internal/guest/runtime/hcsv2/container.go @@ -28,6 +28,13 @@ import ( "github.com/Microsoft/hcsshim/internal/protocol/guestresource" ) +type containerStatus string + +const ( + containerCreating containerStatus = "Creating" + containerRunning containerStatus = "Running" +) + type Container struct { id string vsock transport.Transport @@ -43,6 +50,8 @@ type Container struct { processesMutex sync.Mutex processes map[uint32]*containerProcess + + Status containerStatus } func (c *Container) Start(ctx context.Context, conSettings stdio.ConnectionSettings) (int, error) { diff --git a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go index 568f0ed448..6a50456651 100644 --- a/internal/guest/runtime/hcsv2/uvm.go +++ b/internal/guest/runtime/hcsv2/uvm.go @@ -123,19 +123,29 @@ func (h *Host) RemoveContainer(id string) { delete(h.containers, id) } -func (h *Host) getContainerLocked(id string) (*Container, error) { +func (h *Host) GetRunningContainer(id string) (*Container, error) { + h.containersMutex.Lock() + defer h.containersMutex.Unlock() + if c, ok := h.containers[id]; !ok { return nil, gcserr.NewHresultError(gcserr.HrVmcomputeSystemNotFound) } else { + if c.Status != containerRunning { + return nil, gcserr.NewHresultError(gcserr.HrVmcomputeInvalidState) + } return c, nil } } -func (h *Host) GetContainer(id string) (*Container, error) { +func (h *Host) AddContainer(id string, c *Container) error { h.containersMutex.Lock() defer h.containersMutex.Unlock() - return h.getContainerLocked(id) + if _, ok := h.containers[id]; ok { + return gcserr.NewHresultError(gcserr.HrVmcomputeSystemAlreadyExists) + } + h.containers[id] = c + return nil } func setupSandboxMountsPath(id string) (err error) { @@ -162,11 +172,13 @@ func setupSandboxHugePageMountsPath(id string) error { } func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VMHostedContainerSettingsV2) (_ *Container, err error) { - h.containersMutex.Lock() - defer h.containersMutex.Unlock() - - if _, ok := h.containers[id]; ok { + if _, err := h.GetRunningContainer(id); err == nil { return nil, gcserr.NewHresultError(gcserr.HrVmcomputeSystemAlreadyExists) + } else { + herr := err.(*gcserr.BaseHresultError) + if herr.Hresult() == gcserr.HrVmcomputeInvalidState { + return nil, gcserr.NewHresultError(gcserr.HrVmcomputeSystemAlreadyExists) + } } err = h.securityPolicyEnforcer.EnforceCreateContainerPolicy( @@ -180,8 +192,26 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM return nil, errors.Wrapf(err, "container creation denied due to policy") } - var namespaceID string criType, isCRI := settings.OCISpecification.Annotations[annotations.KubernetesContainerType] + c := &Container{ + id: id, + vsock: h.vsock, + spec: settings.OCISpecification, + isSandbox: criType == "sandbox", + exitType: prot.NtUnexpectedExit, + processes: make(map[uint32]*containerProcess), + Status: containerCreating, + } + if err := h.AddContainer(id, c); err != nil { + return nil, err + } + defer func() { + if err != nil { + h.RemoveContainer(id) + } + }() + + var namespaceID string // for sandbox container sandboxID is same as container id sandboxID := id if isCRI { @@ -290,15 +320,7 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM return nil, errors.Wrapf(err, "failed to get container init process") } - c := &Container{ - id: id, - vsock: h.vsock, - spec: settings.OCISpecification, - isSandbox: criType == "sandbox", - container: con, - exitType: prot.NtUnexpectedExit, - processes: make(map[uint32]*containerProcess), - } + c.container = con c.initProcess = newProcess(c, settings.OCISpecification.Process, init, uint32(c.container.Pid()), true) // Sandbox or standalone, move the networks to the container namespace @@ -318,7 +340,7 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM } } - h.containers[id] = c + c.Status = containerRunning return c, nil } @@ -337,7 +359,7 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * case guestresource.ResourceTypeVPCIDevice: return modifyMappedVPCIDevice(ctx, req.RequestType, req.Settings.(*guestresource.LCOWMappedVPCIDevice)) case guestresource.ResourceTypeContainerConstraints: - c, err := h.GetContainer(containerID) + c, err := h.GetRunningContainer(containerID) if err != nil { return err } @@ -355,7 +377,7 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * } func (h *Host) modifyContainerSettings(ctx context.Context, containerID string, req *guestrequest.ModificationRequest) error { - c, err := h.GetContainer(containerID) + c, err := h.GetRunningContainer(containerID) if err != nil { return err } diff --git a/test/cri-containerd/policy_test.go b/test/cri-containerd/policy_test.go index 8ab5a6e6e3..ead1e844bb 100644 --- a/test/cri-containerd/policy_test.go +++ b/test/cri-containerd/policy_test.go @@ -232,13 +232,13 @@ func Test_RunContainers_WithSyncHooks_ValidWaitPath(t *testing.T) { cidWriter := createContainer(t, client, ctx, writerReq) cidWaiter := createContainer(t, client, ctx, waiterReq) - startContainer(t, client, ctx, cidWriter) - defer removeContainer(t, client, ctx, cidWriter) - defer stopContainer(t, client, ctx, cidWriter) - startContainer(t, client, ctx, cidWaiter) defer removeContainer(t, client, ctx, cidWaiter) defer stopContainer(t, client, ctx, cidWaiter) + + startContainer(t, client, ctx, cidWriter) + defer removeContainer(t, client, ctx, cidWriter) + defer stopContainer(t, client, ctx, cidWriter) } func Test_RunContainers_WithSyncHooks_InvalidWaitPath(t *testing.T) {