diff --git a/internal/app/machined/pkg/system/services/registry/registry.go b/internal/app/machined/pkg/system/services/registry/registry.go index 84b06c101a..d07674864f 100644 --- a/internal/app/machined/pkg/system/services/registry/registry.go +++ b/internal/app/machined/pkg/system/services/registry/registry.go @@ -123,7 +123,9 @@ func (svc *Service) handler(w http.ResponseWriter, req *http.Request) error { if p.isBlob { s = &singleFileStore{root: svc.root, path: "blob"} } else { - s = &singleFileStore{root: svc.root, path: filepath.Join("manifests", ref.Name(), "digest")} + refName := handleRegistryWithPort(ref, p) + + s = &singleFileStore{root: svc.root, path: filepath.Join("manifests", refName, "digest")} } info, err := s.Info(req.Context(), ref.Digest()) @@ -181,7 +183,9 @@ func (svc *Service) resolveCanonicalRef(p params) (reference.Canonical, error) { return nil, xerrors.NewTaggedf[internalErrorTag]("incorrect reference type: %T", ref) } - taggedFile := filepath.Join("manifests", namedTagged.Name(), "reference", namedTagged.Tag()) + namedTaggedName := handleRegistryWithPort(namedTagged, p) + + taggedFile := filepath.Join("manifests", namedTaggedName, "reference", namedTagged.Tag()) ntSum, err := hashFile(taggedFile, svc.root) if err != nil { @@ -194,7 +198,7 @@ func (svc *Service) resolveCanonicalRef(p params) (reference.Canonical, error) { digestString := strings.ReplaceAll(digest.NewDigestFromBytes(digest.SHA256, ntSum).String(), "sha256:", "sha256-") - sha256file := filepath.Join("manifests", namedTagged.Name(), "digest", digestString) + sha256file := filepath.Join("manifests", namedTaggedName, "digest", digestString) sSum, err := hashFile(sha256file, svc.root) if err != nil { @@ -211,6 +215,19 @@ func (svc *Service) resolveCanonicalRef(p params) (reference.Canonical, error) { }, nil } +func handleRegistryWithPort(namedTagged reference.Named, p params) string { + namedTaggedName := namedTagged.Name() + + idx := strings.LastIndex(p.registry, ":") + if idx > 0 { + path := strings.TrimPrefix(namedTaggedName, p.registry) + + namedTaggedName = p.registry[:idx] + "_" + p.registry[idx+1:] + "_" + path + } + + return namedTaggedName +} + func hashFile(f string, where fs.FS) (_ []byte, returnErr error) { data, err := where.Open(f) if err != nil { diff --git a/internal/app/machined/pkg/system/services/registry/registry_test.go b/internal/app/machined/pkg/system/services/registry/registry_test.go index d3e3a57620..0dca7c841f 100644 --- a/internal/app/machined/pkg/system/services/registry/registry_test.go +++ b/internal/app/machined/pkg/system/services/registry/registry_test.go @@ -8,13 +8,19 @@ import ( "cmp" "context" "errors" + "fmt" "io" "net/http" + "os" + "path/filepath" "strings" "sync" "testing" "time" + "github.com/google/go-containerregistry/pkg/v1/empty" + "github.com/google/go-containerregistry/pkg/v1/mutate" + "github.com/google/go-containerregistry/pkg/v1/types" "github.com/siderolabs/gen/xiter" "github.com/siderolabs/gen/xtesting/check" "github.com/siderolabs/go-pointer" @@ -54,7 +60,6 @@ func TestRegistry(t *testing.T) { name string path string method string - body io.Reader check check.Check expectedCode int }{ @@ -104,7 +109,7 @@ func TestRegistry(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - req, err := http.NewRequestWithContext(ctx, test.method, "http://"+constants.RegistrydListenAddress+test.path, test.body) + req, err := http.NewRequestWithContext(ctx, test.method, "http://"+constants.RegistrydListenAddress+test.path, nil) require.NoError(t, err) resp, err := http.DefaultClient.Do(req) @@ -145,3 +150,151 @@ func readAll(t *testing.T, resp *http.Response) string { return builder.String() } + +//nolint:gocyclo +func TestRegistryWithFakeFS(t *testing.T) { + cacheDir := t.TempDir() + + logger := zaptest.NewLogger(t) + + it := func(yield func(string) bool) { + if !yield(cacheDir) { + return + } + } + + svc := registry.NewService(registry.NewMultiPathFS(it), logger) + + var wg sync.WaitGroup + + wg.Add(1) + + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(nil) + + go func() { + defer wg.Done() + + cancel(cmp.Or(svc.Run(ctx), errors.New("service exited"))) + }() + + defer wg.Wait() + + type fakeRegistryData struct { + testName string + registryName string + imageName string + blobsData map[string]string + tag string + } + + blobDir := filepath.Join(cacheDir, "blob") + manifestsDir := filepath.Join(cacheDir, "manifests") + + require.NoError(t, os.MkdirAll(blobDir, 0o755)) + require.NoError(t, os.MkdirAll(manifestsDir, 0o755)) + + for _, data := range []fakeRegistryData{ + { + testName: "normalRegistry", + registryName: "ghcr.io", + imageName: "siderolabs/installer", + blobsData: map[string]string{ + "blob-1": "7bab0675f055470848da29f4455f04cd16900bf3ab7b5d1e94585ca82252b6dd", + "blob-2": "0a9202444208bf863ed60cfc4ab1dea7278f1b8d4033acd57cc239db7fc95543", + }, + tag: "v0.1.0", + }, + { + testName: "registryWithPort", + registryName: "127.0.0.1:5010", + imageName: "custom/installer", + blobsData: map[string]string{ + "blob-1": "7bab0675f055470848da29f4455f04cd16900bf3ab7b5d1e94585ca82252b6dd", + "blob-2": "0a9202444208bf863ed60cfc4ab1dea7278f1b8d4033acd57cc239db7fc95543", + }, + tag: "v0.1.5", + }, + } { + for blobData, blobName := range data.blobsData { + blobPath := filepath.Join(blobDir, fmt.Sprintf("sha256-%s", blobName)) + require.NoError(t, os.WriteFile(blobPath, []byte(blobData), 0o644)) + } + + registry := data.registryName + + // convert :port to _port_ to support copying image-cache to vfat filesystems + idx := strings.LastIndex(registry, ":") + if idx > 0 { + registry = registry[:idx] + "_" + registry[idx+1:] + "_" + } + + digestDir := filepath.Join(manifestsDir, registry, data.imageName, "digest") + require.NoError(t, os.MkdirAll(digestDir, 0o755)) + + referenceDir := filepath.Join(manifestsDir, registry, data.imageName, "reference") + require.NoError(t, os.MkdirAll(referenceDir, 0o755)) + + newImg := mutate.MediaType(empty.Image, types.OCIManifestSchema1) + newImg = mutate.ConfigMediaType(newImg, types.OCIConfigJSON) + + manifest, err := newImg.RawManifest() + require.NoError(t, err) + + taggedFile := filepath.Join(referenceDir, data.tag) + require.NoError(t, os.WriteFile(taggedFile, manifest, 0o644)) + + digest, err := newImg.Digest() + require.NoError(t, err) + + digestFile := filepath.Join(digestDir, strings.ReplaceAll(digest.String(), "sha256:", "sha256-")) + require.NoError(t, os.WriteFile(digestFile, manifest, 0o644)) + + t.Run(data.testName, func(t *testing.T) { + for _, path := range []string{ + fmt.Sprintf("/v2/%s/manifests/%s?ns=%s", data.imageName, data.tag, data.registryName), + fmt.Sprintf("/v2/%s/manifests/%s?ns=%s", data.imageName, digest.String(), data.registryName), + } { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://"+constants.RegistrydListenAddress+path, nil) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + + if resp != nil { + defer resp.Body.Close() //nolint:errcheck + } + + require.Equal(t, http.StatusOK, pointer.SafeDeref(resp).StatusCode, "unexpected status code") + require.Equal(t, string(manifest), readAll(t, resp)) + } + + for blobData, blobName := range data.blobsData { + path := fmt.Sprintf("/v2/%s/blobs/sha256:%s?ns=%s", data.imageName, blobName, data.registryName) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://"+constants.RegistrydListenAddress+path, nil) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + + if resp != nil { + defer resp.Body.Close() //nolint:errcheck + } + + require.Equal(t, http.StatusOK, pointer.SafeDeref(resp).StatusCode, "unexpected status code") + require.Equal(t, blobData, readAll(t, resp)) + } + }) + } + + cancel(nil) + wg.Wait() + + err := ctx.Err() + if err == context.Canceled || err == context.DeadlineExceeded { + err = nil + } + + require.NoError(t, err) +} diff --git a/pkg/imager/cache/cache.go b/pkg/imager/cache/cache.go index 88b108e76a..c359e6f81e 100644 --- a/pkg/imager/cache/cache.go +++ b/pkg/imager/cache/cache.go @@ -43,6 +43,12 @@ func rewriteRegistry(registryName, origRef string) string { return "docker.io" } + // convert :port to _port_ to support copying image-cache to vfat filesystems + idx := strings.LastIndex(registryName, ":") + if idx > 0 { + return registryName[:idx] + "_" + registryName[idx+1:] + "_" + } + return registryName }