diff --git a/types/project.go b/types/project.go index afd787ef..9675840e 100644 --- a/types/project.go +++ b/types/project.go @@ -18,6 +18,7 @@ package types import ( "bytes" + "context" "encoding/json" "fmt" "os" @@ -536,39 +537,29 @@ func (p *Project) WithServicesDisabled(names ...string) *Project { // WithImagesResolved updates services images to include digest computed by a resolver function // It returns a new Project instance with the changes and keep the original Project unchanged func (p *Project) WithImagesResolved(resolver func(named reference.Named) (godigest.Digest, error)) (*Project, error) { - newProject := p.deepCopy() - eg := errgroup.Group{} - for i, s := range newProject.Services { - idx := i - service := s - + return p.WithServicesTransform(func(name string, service ServiceConfig) (ServiceConfig, error) { if service.Image == "" { - continue + return service, nil } - eg.Go(func() error { - named, err := reference.ParseDockerRef(service.Image) + named, err := reference.ParseDockerRef(service.Image) + if err != nil { + return service, err + } + + if _, ok := named.(reference.Canonical); !ok { + // image is named but not digested reference + digest, err := resolver(named) if err != nil { - return err + return service, err } - - if _, ok := named.(reference.Canonical); !ok { - // image is named but not digested reference - digest, err := resolver(named) - if err != nil { - return err - } - named, err = reference.WithDigest(named, digest) - if err != nil { - return err - } + named, err = reference.WithDigest(named, digest) + if err != nil { + return service, err } - - service.Image = named.String() - newProject.Services[idx] = service - return nil - }) - } - return newProject, eg.Wait() + } + service.Image = named.String() + return service, nil + }) } // MarshalYAML marshal Project into a yaml tree @@ -662,3 +653,47 @@ func (p *Project) deepCopy() *Project { } return instance.(*Project) } + +// WithServicesTransform applies a transformation to project services and return a new project with transformation results +func (p *Project) WithServicesTransform(fn func(name string, s ServiceConfig) (ServiceConfig, error)) (*Project, error) { + type result struct { + name string + service ServiceConfig + } + resultCh := make(chan result) + newProject := p.deepCopy() + + eg, ctx := errgroup.WithContext(context.Background()) + eg.Go(func() error { + expect := len(newProject.Services) + s := Services{} + for expect > 0 { + select { + case <-ctx.Done(): + // interrupted as some goroutine returned an error + return nil + case r := <-resultCh: + s[r.name] = r.service + expect-- + } + } + newProject.Services = s + return nil + }) + for n, s := range newProject.Services { + name := n + service := s + eg.Go(func() error { + updated, err := fn(name, service) + if err != nil { + return err + } + resultCh <- result{ + name: name, + service: updated, + } + return nil + }) + } + return newProject, eg.Wait() +} diff --git a/types/project_test.go b/types/project_test.go index 910ec577..c2c0fbfe 100644 --- a/types/project_test.go +++ b/types/project_test.go @@ -18,6 +18,8 @@ package types import ( _ "crypto/sha256" + "errors" + "fmt" "testing" "github.com/compose-spec/compose-go/v2/utils" @@ -206,6 +208,43 @@ func Test_ResolveImages(t *testing.T) { } } +func Test_ResolveImages_concurrent(t *testing.T) { + const garfield = "sha256:1234567890123456789012345678901234567890123456789012345678901234" + resolver := func(named reference.Named) (digest.Digest, error) { + return garfield, nil + } + p := &Project{ + Services: Services{}, + } + for i := 0; i < 1000; i++ { + p.Services[fmt.Sprintf("service_%d", i)] = ServiceConfig{ + Image: fmt.Sprintf("image_%d", i), + } + } + p, err := p.WithImagesResolved(resolver) + assert.NilError(t, err) + for i := 0; i < 1000; i++ { + assert.Equal(t, p.Services[fmt.Sprintf("service_%d", i)].Image, + fmt.Sprintf("docker.io/library/image_%d:latest@%s", i, garfield)) + } +} + +func Test_ResolveImages_concurrent_interrupted(t *testing.T) { + resolver := func(named reference.Named) (digest.Digest, error) { + return "", errors.New("something went wrong") + } + p := Project{ + Services: Services{}, + } + for i := 0; i < 10; i++ { + p.Services[fmt.Sprintf("service_%d", i)] = ServiceConfig{ + Image: fmt.Sprintf("image_%d", i), + } + } + _, err := p.WithImagesResolved(resolver) + assert.Error(t, err, "something went wrong") +} + func TestWithServices(t *testing.T) { p := makeProject() var seen []string