Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow multiple CreateContainer operations at the same time. #1355

Merged
merged 7 commits into from
Apr 22, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions internal/guest/bridge/bridge_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func (b *Bridge) execProcessV2(r *Request) (_ RequestResponse, err error) {
var c *hcsv2.Container
if params.IsExternal || request.ContainerID == hcsv2.UVMContainerID {
pid, err = b.hostState.RunExternalProcess(ctx, params, conSettings)
} else if c, err = b.hostState.GetContainer(request.ContainerID); err == nil {
} else if c, err = b.hostState.GetCreatedContainer(request.ContainerID); err == nil {
// We found a V2 container. Treat this as a V2 process.
if params.OCIProcess == nil {
pid, err = c.Start(ctx, conSettings)
Expand Down Expand Up @@ -267,7 +267,7 @@ func (b *Bridge) signalContainerV2(ctx context.Context, span *trace.Span, r *Req
b.quitChan <- true
b.hostState.Shutdown()
} else {
c, err := b.hostState.GetContainer(request.ContainerID)
c, err := b.hostState.GetCreatedContainer(request.ContainerID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -296,7 +296,7 @@ func (b *Bridge) signalProcessV2(r *Request) (_ RequestResponse, err error) {
trace.Int64Attribute("pid", int64(request.ProcessID)),
trace.Int64Attribute("signal", int64(request.Options.Signal)))

c, err := b.hostState.GetContainer(request.ContainerID)
c, err := b.hostState.GetCreatedContainer(request.ContainerID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -344,7 +344,7 @@ func (b *Bridge) getPropertiesV2(r *Request) (_ RequestResponse, err error) {
return nil, errors.New("getPropertiesV2 is not supported against the UVM")
}

c, err := b.hostState.GetContainer(request.ContainerID)
c, err := b.hostState.GetCreatedContainer(request.ContainerID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -407,7 +407,7 @@ func (b *Bridge) waitOnProcessV2(r *Request) (_ RequestResponse, err error) {
}
exitCodeChan, doneChan = p.Wait()
} else {
c, err := b.hostState.GetContainer(request.ContainerID)
c, err := b.hostState.GetCreatedContainer(request.ContainerID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -453,7 +453,7 @@ func (b *Bridge) resizeConsoleV2(r *Request) (_ RequestResponse, err error) {
trace.Int64Attribute("height", int64(request.Height)),
trace.Int64Attribute("width", int64(request.Width)))

c, err := b.hostState.GetContainer(request.ContainerID)
c, err := b.hostState.GetCreatedContainer(request.ContainerID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -514,7 +514,7 @@ func (b *Bridge) deleteContainerStateV2(r *Request) (_ RequestResponse, err erro
return nil, errors.Wrapf(err, "failed to unmarshal JSON in message \"%s\"", r.Message)
}

c, err := b.hostState.GetContainer(request.ContainerID)
c, err := b.hostState.GetCreatedContainer(request.ContainerID)
if err != nil {
return nil, err
}
Expand Down
32 changes: 32 additions & 0 deletions internal/guest/runtime/hcsv2/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ package hcsv2

import (
"context"
"fmt"
"sync"
"sync/atomic"
"syscall"

"github.com/containerd/cgroups"
Expand All @@ -28,6 +30,18 @@ import (
"github.com/Microsoft/hcsshim/internal/protocol/guestresource"
)

// containerStatus has been introduced to enable parallel container creation
type containerStatus uint32

const (
// containerCreating is the default status set on a Container object, when
// no underlying runtime container or init process has been assigned
containerCreating containerStatus = iota
// containerCreated is the status when a runtime container and init process
// have been assigned, but runtime start command has not been issued yet
containerCreated
)

type Container struct {
id string
vsock transport.Transport
Expand All @@ -43,6 +57,8 @@ type Container struct {

processesMutex sync.Mutex
processes map[uint32]*containerProcess

status containerStatus
anmaxvl marked this conversation as resolved.
Show resolved Hide resolved
}

func (c *Container) Start(ctx context.Context, conSettings stdio.ConnectionSettings) (int, error) {
Expand Down Expand Up @@ -220,3 +236,19 @@ func (c *Container) GetStats(ctx context.Context) (*v1.Metrics, error) {
func (c *Container) modifyContainerConstraints(ctx context.Context, rt guestrequest.RequestType, cc *guestresource.LCOWContainerConstraints) (err error) {
return c.Update(ctx, cc.Linux)
}

func (c *Container) getStatus() containerStatus {
val := atomic.LoadUint32((*uint32)(&c.status))
return containerStatus(val)
}

func (c *Container) setStatus(st containerStatus) error {
switch st {
anmaxvl marked this conversation as resolved.
Show resolved Hide resolved
case containerCreating, containerCreated:
break
default:
return fmt.Errorf("unknown status: %d", st)
}
atomic.StoreUint32((*uint32)(&c.status), uint32(st))
return nil
}
65 changes: 40 additions & 25 deletions internal/guest/runtime/hcsv2/uvm.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,8 @@ import (
"syscall"
"time"

"github.com/Microsoft/hcsshim/internal/guest/policy"
"github.com/mattn/go-shellwords"
"github.com/pkg/errors"

"github.com/Microsoft/hcsshim/internal/guest/gcserr"
"github.com/Microsoft/hcsshim/internal/guest/policy"
"github.com/Microsoft/hcsshim/internal/guest/prot"
"github.com/Microsoft/hcsshim/internal/guest/runtime"
"github.com/Microsoft/hcsshim/internal/guest/spec"
Expand All @@ -36,6 +33,8 @@ import (
"github.com/Microsoft/hcsshim/internal/protocol/guestresource"
"github.com/Microsoft/hcsshim/pkg/annotations"
"github.com/Microsoft/hcsshim/pkg/securitypolicy"
"github.com/mattn/go-shellwords"
"github.com/pkg/errors"
)

// UVMContainerID is the ContainerID that will be sent on any prot.MessageBase
Expand Down Expand Up @@ -123,19 +122,30 @@ func (h *Host) RemoveContainer(id string) {
delete(h.containers, id)
}

func (h *Host) getContainerLocked(id string) (*Container, error) {
func (h *Host) GetCreatedContainer(id string) (*Container, error) {
h.containersMutex.Lock()
defer h.containersMutex.Unlock()

if c, ok := h.containers[id]; !ok {
return nil, gcserr.NewHresultError(gcserr.HrVmcomputeSystemNotFound)
} else {
if c.getStatus() != containerCreated {
return nil, fmt.Errorf("container is not in state \"created\": %w",
gcserr.NewHresultError(gcserr.HrVmcomputeInvalidState))
}
return c, nil
}
}

func (h *Host) GetContainer(id string) (*Container, error) {
func (h *Host) AddContainer(id string, c *Container) error {
h.containersMutex.Lock()
defer h.containersMutex.Unlock()

return h.getContainerLocked(id)
if _, ok := h.containers[id]; ok {
return gcserr.NewHresultError(gcserr.HrVmcomputeSystemAlreadyExists)
}
h.containers[id] = c
return nil
}

func setupSandboxMountsPath(id string) (err error) {
Expand All @@ -162,26 +172,37 @@ func setupSandboxHugePageMountsPath(id string) error {
}

func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VMHostedContainerSettingsV2) (_ *Container, err error) {
h.containersMutex.Lock()
defer h.containersMutex.Unlock()
criType, isCRI := settings.OCISpecification.Annotations[annotations.KubernetesContainerType]
c := &Container{
id: id,
vsock: h.vsock,
spec: settings.OCISpecification,
isSandbox: criType == "sandbox",
exitType: prot.NtUnexpectedExit,
processes: make(map[uint32]*containerProcess),
status: containerCreating,
}

if _, ok := h.containers[id]; ok {
return nil, gcserr.NewHresultError(gcserr.HrVmcomputeSystemAlreadyExists)
if err := h.AddContainer(id, c); err != nil {
return nil, err
}
defer func() {
if err != nil {
h.RemoveContainer(id)
}
}()

err = h.securityPolicyEnforcer.EnforceCreateContainerPolicy(
id,
settings.OCISpecification.Process.Args,
settings.OCISpecification.Process.Env,
settings.OCISpecification.Process.Cwd,
)

if err != nil {
return nil, errors.Wrapf(err, "container creation denied due to policy")
}

var namespaceID string
criType, isCRI := settings.OCISpecification.Annotations[annotations.KubernetesContainerType]
// for sandbox container sandboxID is same as container id
sandboxID := id
if isCRI {
Expand Down Expand Up @@ -290,15 +311,7 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM
return nil, errors.Wrapf(err, "failed to get container init process")
}

c := &Container{
id: id,
vsock: h.vsock,
spec: settings.OCISpecification,
isSandbox: criType == "sandbox",
container: con,
exitType: prot.NtUnexpectedExit,
processes: make(map[uint32]*containerProcess),
}
c.container = con
c.initProcess = newProcess(c, settings.OCISpecification.Process, init, uint32(c.container.Pid()), true)

// Sandbox or standalone, move the networks to the container namespace
Expand All @@ -318,7 +331,9 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM
}
}

h.containers[id] = c
if err := c.setStatus(containerCreated); err != nil {
return nil, err
}
return c, nil
}

Expand All @@ -337,7 +352,7 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req *
case guestresource.ResourceTypeVPCIDevice:
return modifyMappedVPCIDevice(ctx, req.RequestType, req.Settings.(*guestresource.LCOWMappedVPCIDevice))
case guestresource.ResourceTypeContainerConstraints:
c, err := h.GetContainer(containerID)
c, err := h.GetCreatedContainer(containerID)
if err != nil {
return err
}
Expand All @@ -355,7 +370,7 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req *
}

func (h *Host) modifyContainerSettings(ctx context.Context, containerID string, req *guestrequest.ModificationRequest) error {
c, err := h.GetContainer(containerID)
c, err := h.GetCreatedContainer(containerID)
if err != nil {
return err
}
Expand Down
32 changes: 25 additions & 7 deletions test/cri-containerd/policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"strings"
"testing"
"time"

runtime "k8s.io/cri-api/pkg/apis/runtime/v1alpha2"

Expand Down Expand Up @@ -232,13 +233,30 @@ func Test_RunContainers_WithSyncHooks_ValidWaitPath(t *testing.T) {
cidWriter := createContainer(t, client, ctx, writerReq)
cidWaiter := createContainer(t, client, ctx, waiterReq)

startContainer(t, client, ctx, cidWriter)
defer removeContainer(t, client, ctx, cidWriter)
defer stopContainer(t, client, ctx, cidWriter)

startContainer(t, client, ctx, cidWaiter)
defer removeContainer(t, client, ctx, cidWaiter)
defer stopContainer(t, client, ctx, cidWaiter)
errChan := make(chan error)
go func() {
_, err := client.StartContainer(ctx, &runtime.StartContainerRequest{ContainerId: cidWaiter})
errChan <- err
defer removeContainer(t, client, ctx, cidWaiter)
defer stopContainer(t, client, ctx, cidWaiter)
}()

// give some time for the first go routine to kick in.
time.Sleep(time.Second)

go func() {
_, err := client.StartContainer(ctx, &runtime.StartContainerRequest{ContainerId: cidWriter})
errChan <- err
defer removeContainer(t, client, ctx, cidWriter)
defer stopContainer(t, client, ctx, cidWriter)
}()

for i := 0; i < 2; i++ {
if err := <-errChan; err != nil {
close(errChan)
t.Fatalf("failed to start container: %s", err)
}
}
}

func Test_RunContainers_WithSyncHooks_InvalidWaitPath(t *testing.T) {
Expand Down