diff --git a/internal/app/machined/pkg/system/services/registry/registry.go b/internal/app/machined/pkg/system/services/registry/registry.go index 84b06c101a..d07674864f 100644 --- a/internal/app/machined/pkg/system/services/registry/registry.go +++ b/internal/app/machined/pkg/system/services/registry/registry.go @@ -123,7 +123,9 @@ func (svc *Service) handler(w http.ResponseWriter, req *http.Request) error { if p.isBlob { s = &singleFileStore{root: svc.root, path: "blob"} } else { - s = &singleFileStore{root: svc.root, path: filepath.Join("manifests", ref.Name(), "digest")} + refName := handleRegistryWithPort(ref, p) + + s = &singleFileStore{root: svc.root, path: filepath.Join("manifests", refName, "digest")} } info, err := s.Info(req.Context(), ref.Digest()) @@ -181,7 +183,9 @@ func (svc *Service) resolveCanonicalRef(p params) (reference.Canonical, error) { return nil, xerrors.NewTaggedf[internalErrorTag]("incorrect reference type: %T", ref) } - taggedFile := filepath.Join("manifests", namedTagged.Name(), "reference", namedTagged.Tag()) + namedTaggedName := handleRegistryWithPort(namedTagged, p) + + taggedFile := filepath.Join("manifests", namedTaggedName, "reference", namedTagged.Tag()) ntSum, err := hashFile(taggedFile, svc.root) if err != nil { @@ -194,7 +198,7 @@ func (svc *Service) resolveCanonicalRef(p params) (reference.Canonical, error) { digestString := strings.ReplaceAll(digest.NewDigestFromBytes(digest.SHA256, ntSum).String(), "sha256:", "sha256-") - sha256file := filepath.Join("manifests", namedTagged.Name(), "digest", digestString) + sha256file := filepath.Join("manifests", namedTaggedName, "digest", digestString) sSum, err := hashFile(sha256file, svc.root) if err != nil { @@ -211,6 +215,19 @@ func (svc *Service) resolveCanonicalRef(p params) (reference.Canonical, error) { }, nil } +func handleRegistryWithPort(namedTagged reference.Named, p params) string { + namedTaggedName := namedTagged.Name() + + idx := strings.LastIndex(p.registry, ":") + if idx > 0 { + path := strings.TrimPrefix(namedTaggedName, p.registry) + + namedTaggedName = p.registry[:idx] + "_" + p.registry[idx+1:] + "_" + path + } + + return namedTaggedName +} + func hashFile(f string, where fs.FS) (_ []byte, returnErr error) { data, err := where.Open(f) if err != nil { diff --git a/internal/app/machined/pkg/system/services/registry/registry_test.go b/internal/app/machined/pkg/system/services/registry/registry_test.go index d3e3a57620..e1a39a14a4 100644 --- a/internal/app/machined/pkg/system/services/registry/registry_test.go +++ b/internal/app/machined/pkg/system/services/registry/registry_test.go @@ -5,23 +5,31 @@ package registry_test import ( + "archive/tar" "cmp" "context" "errors" + "fmt" "io" "net/http" + "os" + "path/filepath" "strings" "sync" "testing" - "time" - "github.com/siderolabs/gen/xiter" + "github.com/google/go-containerregistry/pkg/crane" + "github.com/google/go-containerregistry/pkg/name" + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/layout" + "github.com/google/go-containerregistry/pkg/v1/remote" "github.com/siderolabs/gen/xtesting/check" "github.com/siderolabs/go-pointer" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" "go.uber.org/zap/zaptest" "github.com/siderolabs/talos/internal/app/machined/pkg/system/services/registry" + "github.com/siderolabs/talos/pkg/imager/cache" "github.com/siderolabs/talos/pkg/machinery/constants" ) @@ -30,8 +38,40 @@ func TestRegistry(t *testing.T) { t.Skip("skipping test in short mode.") } + cacheDir := t.TempDir() + + images := []string{ + fmt.Sprintf("%s:%s", constants.CoreDNSImage, constants.DefaultCoreDNSVersion), + fmt.Sprintf("%s:%s", strings.ReplaceAll(constants.CoreDNSImage, "registry.k8s.io", "registry.k8s.io:443"), constants.DefaultCoreDNSVersion), + } + + platform, err := v1.ParsePlatform("linux/amd64") + assert.NoError(t, err) + + assert.NoError(t, cache.Generate(images, platform.String(), false, "", cacheDir)) + + l, err := layout.ImageIndexFromPath(cacheDir) + assert.NoError(t, err) + + m, err := l.IndexManifest() + assert.NoError(t, err) + + image, err := l.Image(m.Manifests[0].Digest) + assert.NoError(t, err) + + registryRoot := t.TempDir() + + tarExtract(t, image, registryRoot) + logger := zaptest.NewLogger(t) - svc := registry.NewService(registry.NewMultiPathFS(xiter.Single("test")), logger) + + it := func(yield func(string) bool) { + if !yield(registryRoot) { + return + } + } + + svc := registry.NewService(registry.NewMultiPathFS(it), logger) var wg sync.WaitGroup @@ -48,13 +88,51 @@ func TestRegistry(t *testing.T) { defer wg.Wait() - time.Sleep(100 * time.Millisecond) + for _, image := range images { + t.Run(image, func(t *testing.T) { + ref, err := name.ParseReference(image) + assert.NoError(t, err) + + manifest, err := crane.Manifest(ref.String()) + assert.NoError(t, err) + + rmt, err := remote.Get(ref, remote.WithPlatform(*platform)) + assert.NoError(t, err) + + for _, path := range []string{ + fmt.Sprintf("/v2/%s/manifests/%s?ns=%s", ref.Context().RepositoryStr(), constants.DefaultCoreDNSVersion, ref.Context().RegistryStr()), + fmt.Sprintf("/v2/%s/manifests/%s?ns=%s", ref.Context().RepositoryStr(), rmt.Digest.String(), ref.Context().RegistryStr()), + } { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://"+constants.RegistrydListenAddress+path, nil) + assert.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + + if resp != nil { + t.Cleanup(func() { + assert.NoError(t, resp.Body.Close()) + }) + } + + assert.Equal(t, http.StatusOK, pointer.SafeDeref(resp).StatusCode, "unexpected status code") + assert.Equal(t, string(manifest), readAll(t, resp)) + } + + img, err := rmt.Image() + assert.NoError(t, err) + + layers, err := img.Layers() + assert.NoError(t, err) + + handleLayers(ctx, t, layers, ref) + }) + } tests := []struct { name string path string method string - body io.Reader check check.Check expectedCode int }{ @@ -104,29 +182,86 @@ func TestRegistry(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - req, err := http.NewRequestWithContext(ctx, test.method, "http://"+constants.RegistrydListenAddress+test.path, test.body) - require.NoError(t, err) + req, err := http.NewRequestWithContext(ctx, test.method, "http://"+constants.RegistrydListenAddress+test.path, nil) + assert.NoError(t, err) resp, err := http.DefaultClient.Do(req) test.check(t, err) if resp != nil { - defer resp.Body.Close() //nolint:errcheck + t.Cleanup(func() { + assert.NoError(t, resp.Body.Close()) + }) } - require.Equal(t, test.expectedCode, pointer.SafeDeref(resp).StatusCode, "unexpected status code, body is %s", readAll(t, resp)) + assert.Equal(t, test.expectedCode, pointer.SafeDeref(resp).StatusCode, "unexpected status code, body is %s", readAll(t, resp)) }) } cancel(nil) wg.Wait() - err := ctx.Err() + err = ctx.Err() if err == context.Canceled || err == context.DeadlineExceeded { err = nil } - require.NoError(t, err) + assert.NoError(t, err) +} + +func handleLayers(ctx context.Context, t *testing.T, layers []v1.Layer, ref name.Reference) { + for _, layer := range layers { + dig, err := layer.Digest() + assert.NoError(t, err) + + path := fmt.Sprintf("/v2/%s/blobs/%s?ns=%s", ref.Context().RepositoryStr(), dig, ref.Context().RegistryStr()) + + req, err := http.NewRequestWithContext(ctx, http.MethodHead, "http://"+constants.RegistrydListenAddress+path, nil) + assert.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + + if resp != nil { + t.Cleanup(func() { + assert.NoError(t, resp.Body.Close()) + }) + } + + assert.Equal(t, http.StatusOK, pointer.SafeDeref(resp).StatusCode, "unexpected status code") + } +} + +func tarExtract(t *testing.T, img v1.Image, dest string) { + pipeReader, pipeWriter := io.Pipe() + + go func() { + pipeWriter.CloseWithError(crane.Export(img, pipeWriter)) + }() + + tr := tar.NewReader(pipeReader) + + for { + header, err := tr.Next() + if err == io.EOF { + break + } + + assert.NoError(t, err) + + switch header.Typeflag { + case tar.TypeDir: + assert.NoError(t, os.MkdirAll(filepath.Join(dest, header.Name), 0o755)) + case tar.TypeReg: + f, err := os.Create(filepath.Join(dest, header.Name)) + assert.NoError(t, err) + + _, err = io.Copy(f, tr) + assert.NoError(t, err) + default: + assert.Failf(t, "unexpected tar entry type", "type: %v", header.Typeflag) + } + } } func readAll(t *testing.T, resp *http.Response) string { @@ -137,7 +272,7 @@ func readAll(t *testing.T, resp *http.Response) string { var builder strings.Builder _, err := io.Copy(&builder, resp.Body) - require.NoError(t, err) + assert.NoError(t, err) if builder.String() == "" { return "" diff --git a/pkg/imager/cache/cache.go b/pkg/imager/cache/cache.go index 88b108e76a..c359e6f81e 100644 --- a/pkg/imager/cache/cache.go +++ b/pkg/imager/cache/cache.go @@ -43,6 +43,12 @@ func rewriteRegistry(registryName, origRef string) string { return "docker.io" } + // convert :port to _port_ to support copying image-cache to vfat filesystems + idx := strings.LastIndex(registryName, ":") + if idx > 0 { + return registryName[:idx] + "_" + registryName[idx+1:] + "_" + } + return registryName }