diff --git a/cmd/mdltool/cli_test.go b/cmd/mdltool/cli_test.go new file mode 100644 index 0000000..b54dfa8 --- /dev/null +++ b/cmd/mdltool/cli_test.go @@ -0,0 +1,52 @@ +package main + +import ( + "os" + "os/exec" + "strings" + "testing" +) + +func TestCLIDefaultRegistry(t *testing.T) { + // Build the mdltool binary for testing + buildCmd := exec.Command("go", "build", "-o", "test-mdltool", ".") + if err := buildCmd.Run(); err != nil { + t.Fatalf("Failed to build test binary: %v", err) + } + defer os.Remove("test-mdltool") + + tests := []struct { + name string + args []string + contains string + }{ + { + name: "help shows default-registry option", + args: []string{"--help"}, + contains: "-default-registry", + }, + { + name: "version works", + args: []string{"--version"}, + contains: "model-distribution-tool version", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := exec.Command("./test-mdltool", tt.args...) + output, err := cmd.CombinedOutput() + + // For help and version, these should exit successfully + if tt.name == "help shows default-registry option" || tt.name == "version works" { + if err != nil { + t.Errorf("Command failed: %v, output: %s", err, output) + } + } + + if !strings.Contains(string(output), tt.contains) { + t.Errorf("Expected output to contain %q, got: %s", tt.contains, output) + } + }) + } +} \ No newline at end of file diff --git a/cmd/mdltool/main.go b/cmd/mdltool/main.go index e95e88f..4597fab 100644 --- a/cmd/mdltool/main.go +++ b/cmd/mdltool/main.go @@ -32,13 +32,15 @@ const ( ) var ( - storePath string - showHelp bool - showVer bool + storePath string + defaultRegistry string + showHelp bool + showVer bool ) func init() { flag.StringVar(&storePath, "store-path", defaultStorePath, "Path to the model store") + flag.StringVar(&defaultRegistry, "default-registry", "", "Default registry for model references (e.g., registry.example.com)") flag.BoolVar(&showHelp, "help", false, "Show help") flag.BoolVar(&showVer, "version", false, "Show version") } @@ -69,6 +71,10 @@ func main() { distribution.WithUserAgent("model-distribution-tool/" + version), } + if defaultRegistry != "" { + clientOpts = append(clientOpts, distribution.WithDefaultRegistry(defaultRegistry)) + } + if username := os.Getenv("DOCKER_USERNAME"); username != "" { if password := os.Getenv("DOCKER_PASSWORD"); password != "" { clientOpts = append(clientOpts, distribution.WithRegistryAuth(username, password)) diff --git a/distribution/client.go b/distribution/client.go index ca1e022..e1796e5 100644 --- a/distribution/client.go +++ b/distribution/client.go @@ -9,6 +9,7 @@ import ( "github.com/sirupsen/logrus" + "github.com/docker/model-distribution/internal/naming" "github.com/docker/model-distribution/internal/progress" "github.com/docker/model-distribution/internal/store" "github.com/docker/model-distribution/registry" @@ -33,12 +34,13 @@ type Option func(*options) // options holds the configuration for a new Client type options struct { - storeRootPath string - logger *logrus.Entry - transport http.RoundTripper - userAgent string - username string - password string + storeRootPath string + logger *logrus.Entry + transport http.RoundTripper + userAgent string + username string + password string + defaultRegistry string } // WithStoreRootPath sets the store root path @@ -87,6 +89,15 @@ func WithRegistryAuth(username, password string) Option { } } +// WithDefaultRegistry sets the default registry namespace for model references +func WithDefaultRegistry(registry string) Option { + return func(o *options) { + if registry != "" { + o.defaultRegistry = registry + } + } +} + func defaultOptions() *options { return &options{ logger: logrus.NewEntry(logrus.StandardLogger()), @@ -124,6 +135,13 @@ func NewClient(opts ...Option) (*Client, error) { registryOpts = append(registryOpts, registry.WithAuthConfig(options.username, options.password)) } + // Add default registry namespace if provided + if options.defaultRegistry != "" { + registryOpts = append(registryOpts, registry.WithDefaultNamespace(options.defaultRegistry)) + // Also set it globally for store operations + naming.SetDefaultNamespace(options.defaultRegistry) + } + options.logger.Infoln("Successfully initialized store") return &Client{ store: s, diff --git a/distribution/integration_test.go b/distribution/integration_test.go new file mode 100644 index 0000000..6052ed2 --- /dev/null +++ b/distribution/integration_test.go @@ -0,0 +1,76 @@ +package distribution + +import ( + "testing" + "path/filepath" + "os" +) + +func TestDefaultRegistryIntegration(t *testing.T) { + // Create a temporary directory for the test store + tempDir, err := os.MkdirTemp("", "test-store-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + tests := []struct { + name string + defaultRegistry string + reference string + expectError bool + description string + }{ + { + name: "no default registry - standard behavior", + defaultRegistry: "", + reference: "library/alpine:latest", + expectError: true, // will fail because registry is not accessible, but parsing should work + description: "Should use Docker Hub as default", + }, + { + name: "custom default registry applied", + defaultRegistry: "registry.example.com", + reference: "mymodel:latest", + expectError: true, // will fail because registry is not accessible, but parsing should work + description: "Should apply custom default registry", + }, + { + name: "explicit registry preserved", + defaultRegistry: "registry.example.com", + reference: "other.registry.com/mymodel:latest", + expectError: true, // will fail because registry is not accessible, but parsing should work + description: "Should preserve explicit registry even when default is set", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + storeDir := filepath.Join(tempDir, tt.name) + + // Create client options + opts := []Option{ + WithStoreRootPath(storeDir), + } + if tt.defaultRegistry != "" { + opts = append(opts, WithDefaultRegistry(tt.defaultRegistry)) + } + + // Create distribution client + client, err := NewClient(opts...) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + // Test that the client can be created and the reference parsing works + // We expect these to fail with network errors since the registries don't exist, + // but we want to ensure the parsing is correct + _ = client + + // For this test, we're mainly validating that the client can be created + // with the default registry configuration. The actual network operations + // would require a running registry which is beyond the scope of this test. + t.Logf("Successfully created client with default registry: %s", tt.defaultRegistry) + }) + } +} \ No newline at end of file diff --git a/internal/naming/README.md b/internal/naming/README.md new file mode 100644 index 0000000..ff6f8d5 --- /dev/null +++ b/internal/naming/README.md @@ -0,0 +1,122 @@ +# Naming Package + +The naming package provides configurable default namespace support for container registry references in the model-distribution system. + +## Problem + +By default, the `github.com/google/go-containerregistry/pkg/name` package assumes Docker Hub (`index.docker.io`) as the default registry for image references that don't include an explicit registry hostname. For model distribution systems, it's often desirable to use a different default registry. + +## Solution + +This package provides wrapper functions around `name.ParseReference` and `name.ParseTag` that apply a configurable default namespace when no explicit registry is provided in the reference. + +## Usage + +### Client-specific Configuration + +```go +import ( + "github.com/docker/model-distribution/registry" + "github.com/docker/model-distribution/distribution" +) + +// Configure registry client with custom default namespace +registryClient := registry.NewClient( + registry.WithDefaultNamespace("registry.example.com"), +) + +// Configure distribution client with custom default namespace +distClient, err := distribution.NewClient( + distribution.WithStoreRootPath("/path/to/store"), + distribution.WithDefaultRegistry("registry.example.com"), +) +``` + +### Global Configuration + +```go +import "github.com/docker/model-distribution/internal/naming" + +// Set global default namespace +naming.SetDefaultNamespace("registry.example.com") + +// Use global convenience functions +ref, err := naming.ParseReference("mymodel:latest") +// Will resolve to registry.example.com/mymodel:latest + +tag, err := naming.ParseTag("mymodel:latest") +// Will resolve to registry.example.com/mymodel:latest +``` + +### Direct Usage + +```go +import "github.com/docker/model-distribution/internal/naming" + +// Create a namespace configuration +ns := &naming.DefaultNamespace{Registry: "registry.example.com"} + +// Parse references with custom default +ref, err := ns.ParseReference("mymodel:latest") +// Will resolve to registry.example.com/mymodel:latest + +// Explicit registries are preserved +ref, err := ns.ParseReference("other.registry.com/mymodel:latest") +// Will remain other.registry.com/mymodel:latest +``` + +## CLI Usage + +The `mdltool` command-line interface supports a `--default-registry` flag: + +```bash +# Use custom default registry +mdltool --default-registry registry.example.com pull mymodel:latest + +# This will pull from registry.example.com/mymodel:latest instead of +# the Docker Hub default index.docker.io/library/mymodel:latest +``` + +## Behavior + +- **No explicit registry**: Default namespace is applied + - Input: `mymodel:latest` → Output: `registry.example.com/mymodel:latest` + - Input: `user/mymodel:latest` → Output: `registry.example.com/user/mymodel:latest` + +- **Explicit registry**: Default namespace is ignored + - Input: `other.registry.com/mymodel:latest` → Output: `other.registry.com/mymodel:latest` + - Input: `localhost:5000/mymodel:latest` → Output: `localhost:5000/mymodel:latest` + +- **No default configured**: Falls back to standard go-containerregistry behavior + - Input: `mymodel:latest` → Output: `index.docker.io/library/mymodel:latest` + +## Registry Detection + +The package uses heuristics to detect whether a reference already contains an explicit registry: + +- Contains a dot (`.`) before the first slash → Likely a registry hostname +- Contains a colon (`:`) followed by digits before the first slash → Likely a registry with port +- Contains a colon followed by non-digits → Likely a tag, not a registry port + +Examples: +- `mymodel:latest` → No explicit registry (`:latest` is a tag) +- `registry.com/mymodel:latest` → Has explicit registry (contains dot) +- `localhost:5000/mymodel:latest` → Has explicit registry (port number) +- `user/mymodel:latest` → No explicit registry + +## Integration Points + +This functionality is integrated at the following levels: + +1. **Registry Client**: `registry.Client` can be configured with a default namespace +2. **Distribution Client**: `distribution.Client` configures both registry client and store operations +3. **Store Operations**: Uses global namespace configuration for tag/reference parsing +4. **CLI Tool**: Accepts `--default-registry` flag to configure the default + +## Backward Compatibility + +This change is fully backward compatible: + +- Existing code without default namespace configuration continues to work unchanged +- All explicit registry references continue to work as before +- The default behavior (Docker Hub) is preserved when no configuration is provided \ No newline at end of file diff --git a/internal/naming/naming.go b/internal/naming/naming.go new file mode 100644 index 0000000..42b3748 --- /dev/null +++ b/internal/naming/naming.go @@ -0,0 +1,122 @@ +package naming + +import ( + "strings" + + "github.com/google/go-containerregistry/pkg/name" +) + +// DefaultNamespace holds the default registry namespace configuration +type DefaultNamespace struct { + Registry string +} + +// ParseReference parses a reference string, applying the default namespace if needed +func (dn *DefaultNamespace) ParseReference(reference string) (name.Reference, error) { + // If no default namespace is configured, use standard parsing + if dn == nil || dn.Registry == "" { + return name.ParseReference(reference) + } + + // If the reference already contains a registry (has a domain with dot or port), use as-is + if hasExplicitRegistry(reference) { + return name.ParseReference(reference) + } + + // Apply default registry to the reference + qualified := dn.Registry + "/" + reference + return name.ParseReference(qualified) +} + +// ParseTag parses a tag string, applying the default namespace if needed +func (dn *DefaultNamespace) ParseTag(tag string) (name.Tag, error) { + // If no default namespace is configured, use standard parsing + if dn == nil || dn.Registry == "" { + return name.NewTag(tag) + } + + // If the tag already contains a registry (has a domain with dot or port), use as-is + if hasExplicitRegistry(tag) { + return name.NewTag(tag) + } + + // Apply default registry to the tag + qualified := dn.Registry + "/" + tag + return name.NewTag(qualified) +} + +// hasExplicitRegistry checks if a reference already contains an explicit registry +// This is a simple heuristic: if it contains a dot before the first slash or +// a colon followed by a port number before the first slash, it's likely a registry hostname +func hasExplicitRegistry(reference string) bool { + // Find the first slash + slashIndex := strings.Index(reference, "/") + + // If no slash, check if it looks like a registry (contains dot, not just tag colon) + if slashIndex == -1 { + // If it contains a dot, it's likely a registry + if strings.Contains(reference, ".") { + return true + } + // If it contains a colon, check if it's followed by a numeric port + colonIndex := strings.Index(reference, ":") + if colonIndex != -1 { + // Check if what comes after colon looks like a port number + afterColon := reference[colonIndex+1:] + // If it's all digits, it's a port; otherwise it's a tag + for _, r := range afterColon { + if r < '0' || r > '9' { + return false // It's a tag, not a port + } + } + return len(afterColon) > 0 // It's a port if non-empty and all digits + } + return false + } + + // Check the part before the first slash + beforeSlash := reference[:slashIndex] + + // If it contains a dot (domain), it's a registry + if strings.Contains(beforeSlash, ".") { + return true + } + + // If it contains a colon, check if it's followed by a numeric port + if strings.Contains(beforeSlash, ":") { + colonIndex := strings.Index(beforeSlash, ":") + afterColon := beforeSlash[colonIndex+1:] + // Check if what comes after colon looks like a port number + for _, r := range afterColon { + if r < '0' || r > '9' { + return false // Not a port + } + } + return len(afterColon) > 0 // It's a port if non-empty and all digits + } + + return false +} + +// Global default namespace instance +var globalDefaultNamespace *DefaultNamespace + +// SetDefaultNamespace sets the global default namespace +func SetDefaultNamespace(registry string) { + globalDefaultNamespace = &DefaultNamespace{Registry: registry} +} + +// GetDefaultNamespace returns the current global default namespace +func GetDefaultNamespace() *DefaultNamespace { + return globalDefaultNamespace +} + +// ParseReference is a convenience function that uses the global default namespace +func ParseReference(reference string) (name.Reference, error) { + return globalDefaultNamespace.ParseReference(reference) +} + +// ParseTag is a convenience function that uses the global default namespace +func ParseTag(tag string) (name.Tag, error) { + return globalDefaultNamespace.ParseTag(tag) +} \ No newline at end of file diff --git a/internal/naming/naming_test.go b/internal/naming/naming_test.go new file mode 100644 index 0000000..ec45356 --- /dev/null +++ b/internal/naming/naming_test.go @@ -0,0 +1,227 @@ +package naming + +import ( + "testing" +) + +func TestDefaultNamespace_ParseReference(t *testing.T) { + tests := []struct { + name string + defaultRegistry string + input string + expectedRegistry string + expectedRepository string + expectedError bool + }{ + { + name: "no default namespace - standard behavior", + defaultRegistry: "", + input: "mymodel:latest", + expectedRegistry: "index.docker.io", + expectedRepository: "library/mymodel", + expectedError: false, + }, + { + name: "default registry applied to simple reference", + defaultRegistry: "registry.example.com", + input: "mymodel:latest", + expectedRegistry: "registry.example.com", + expectedRepository: "mymodel", + expectedError: false, + }, + { + name: "explicit registry preserved", + defaultRegistry: "registry.example.com", + input: "other.registry.com/user/mymodel:latest", + expectedRegistry: "other.registry.com", + expectedRepository: "user/mymodel", + expectedError: false, + }, + { + name: "localhost registry preserved", + defaultRegistry: "registry.example.com", + input: "localhost:5000/mymodel:latest", + expectedRegistry: "localhost:5000", + expectedRepository: "mymodel", + expectedError: false, + }, + { + name: "docker hub user namespace preserved with default", + defaultRegistry: "registry.example.com", + input: "user/mymodel:latest", + expectedRegistry: "registry.example.com", + expectedRepository: "user/mymodel", + expectedError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dn := &DefaultNamespace{Registry: tt.defaultRegistry} + ref, err := dn.ParseReference(tt.input) + + if tt.expectedError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if ref.Context().Registry.RegistryStr() != tt.expectedRegistry { + t.Errorf("expected registry %s, got %s", tt.expectedRegistry, ref.Context().Registry.RegistryStr()) + } + + if ref.Context().RepositoryStr() != tt.expectedRepository { + t.Errorf("expected repository %s, got %s", tt.expectedRepository, ref.Context().RepositoryStr()) + } + }) + } +} + +func TestDefaultNamespace_ParseTag(t *testing.T) { + tests := []struct { + name string + defaultRegistry string + input string + expectedRegistry string + expectedRepository string + expectedError bool + }{ + { + name: "no default namespace - standard behavior", + defaultRegistry: "", + input: "mymodel:latest", + expectedRegistry: "index.docker.io", + expectedRepository: "library/mymodel", + expectedError: false, + }, + { + name: "default registry applied to simple tag", + defaultRegistry: "registry.example.com", + input: "mymodel:latest", + expectedRegistry: "registry.example.com", + expectedRepository: "mymodel", + expectedError: false, + }, + { + name: "explicit registry preserved in tag", + defaultRegistry: "registry.example.com", + input: "other.registry.com/user/mymodel:latest", + expectedRegistry: "other.registry.com", + expectedRepository: "user/mymodel", + expectedError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dn := &DefaultNamespace{Registry: tt.defaultRegistry} + tag, err := dn.ParseTag(tt.input) + + if tt.expectedError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if tag.Context().Registry.RegistryStr() != tt.expectedRegistry { + t.Errorf("expected registry %s, got %s", tt.expectedRegistry, tag.Context().Registry.RegistryStr()) + } + + if tag.Context().RepositoryStr() != tt.expectedRepository { + t.Errorf("expected repository %s, got %s", tt.expectedRepository, tag.Context().RepositoryStr()) + } + }) + } +} + +func TestHasExplicitRegistry(t *testing.T) { + tests := []struct { + input string + expected bool + }{ + {"mymodel:latest", false}, + {"user/mymodel:latest", false}, + {"registry.example.com/mymodel:latest", true}, + {"localhost:5000/mymodel:latest", true}, + {"example.com/user/mymodel:latest", true}, + {"sub.domain.com/ns/mymodel:latest", true}, + {"localhost/mymodel:latest", false}, // localhost without port/dot after is not a registry + {"model", false}, + {"registry.com", true}, + {"host:8080", true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := hasExplicitRegistry(tt.input) + if result != tt.expected { + t.Errorf("hasExplicitRegistry(%q) = %v, expected %v", tt.input, result, tt.expected) + } + }) + } +} + +func TestGlobalDefaultNamespace(t *testing.T) { + // Save original state + original := globalDefaultNamespace + defer func() { + globalDefaultNamespace = original + }() + + // Test setting and getting + SetDefaultNamespace("test.registry.com") + dn := GetDefaultNamespace() + if dn == nil || dn.Registry != "test.registry.com" { + t.Errorf("expected registry test.registry.com, got %v", dn) + } + + // Test convenience functions + ref, err := ParseReference("mymodel:latest") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ref.Context().Registry.RegistryStr() != "test.registry.com" { + t.Errorf("expected registry test.registry.com, got %s", ref.Context().Registry.RegistryStr()) + } + + tag, err := ParseTag("mymodel:latest") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tag.Context().Registry.RegistryStr() != "test.registry.com" { + t.Errorf("expected registry test.registry.com, got %s", tag.Context().Registry.RegistryStr()) + } +} + +func TestNilDefaultNamespace(t *testing.T) { + var dn *DefaultNamespace + + // Should fall back to standard behavior + ref, err := dn.ParseReference("mymodel:latest") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ref.Context().Registry.RegistryStr() != "index.docker.io" { + t.Errorf("expected Docker Hub default, got %s", ref.Context().Registry.RegistryStr()) + } + + tag, err := dn.ParseTag("mymodel:latest") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tag.Context().Registry.RegistryStr() != "index.docker.io" { + t.Errorf("expected Docker Hub default, got %s", tag.Context().Registry.RegistryStr()) + } +} \ No newline at end of file diff --git a/internal/store/index.go b/internal/store/index.go index 000d7d7..c553e16 100644 --- a/internal/store/index.go +++ b/internal/store/index.go @@ -8,6 +8,7 @@ import ( "path/filepath" "github.com/google/go-containerregistry/pkg/name" + "github.com/docker/model-distribution/internal/naming" ) // Index represents the index of all models in the store @@ -16,7 +17,7 @@ type Index struct { } func (i Index) Tag(reference string, tag string) (Index, error) { - tagRef, err := name.NewTag(tag) + tagRef, err := naming.ParseTag(tag) if err != nil { return Index{}, fmt.Errorf("invalid tag: %w", err) } @@ -39,7 +40,7 @@ func (i Index) Tag(reference string, tag string) (Index, error) { } func (i Index) UnTag(tag string) (name.Tag, Index, error) { - tagRef, err := name.NewTag(tag) + tagRef, err := naming.ParseTag(tag) if err != nil { return name.Tag{}, Index{}, err } @@ -141,12 +142,12 @@ type IndexEntry struct { } func (e IndexEntry) HasTag(tag string) bool { - ref, err := name.NewTag(tag) + ref, err := naming.ParseTag(tag) if err != nil { return false } for _, t := range e.Tags { - tr, err := name.ParseReference(t) + tr, err := naming.ParseReference(t) if err != nil { continue } @@ -159,7 +160,7 @@ func (e IndexEntry) HasTag(tag string) bool { func (e IndexEntry) hasTag(tag name.Tag) bool { for _, t := range e.Tags { - tr, err := name.ParseReference(t) + tr, err := naming.ParseReference(t) if err != nil { continue } @@ -174,7 +175,7 @@ func (e IndexEntry) MatchesReference(reference string) bool { if e.ID == reference { return true } - ref, err := name.ParseReference(reference) + ref, err := naming.ParseReference(reference) if err != nil { return false } @@ -200,7 +201,7 @@ func (e IndexEntry) Tag(tag name.Tag) IndexEntry { func (e IndexEntry) UnTag(tag name.Tag) IndexEntry { var tags []string for i, t := range e.Tags { - tr, err := name.ParseReference(t) + tr, err := naming.ParseReference(t) if err != nil { continue } diff --git a/registry/client.go b/registry/client.go index 44a58a1..f40f515 100644 --- a/registry/client.go +++ b/registry/client.go @@ -13,6 +13,7 @@ import ( "github.com/google/go-containerregistry/pkg/v1/remote" "github.com/google/go-containerregistry/pkg/v1/remote/transport" + "github.com/docker/model-distribution/internal/naming" "github.com/docker/model-distribution/internal/progress" "github.com/docker/model-distribution/types" ) @@ -26,10 +27,11 @@ var ( ) type Client struct { - transport http.RoundTripper - userAgent string - keychain authn.Keychain - auth authn.Authenticator + transport http.RoundTripper + userAgent string + keychain authn.Keychain + auth authn.Authenticator + defaultNamespace *naming.DefaultNamespace } type ClientOption func(*Client) @@ -61,6 +63,14 @@ func WithAuthConfig(username, password string) ClientOption { } } +func WithDefaultNamespace(registry string) ClientOption { + return func(c *Client) { + if registry != "" { + c.defaultNamespace = &naming.DefaultNamespace{Registry: registry} + } + } +} + func NewClient(opts ...ClientOption) *Client { client := &Client{ transport: remote.DefaultTransport, @@ -75,7 +85,7 @@ func NewClient(opts ...ClientOption) *Client { func (c *Client) Model(ctx context.Context, reference string) (types.ModelArtifact, error) { // Parse the reference - ref, err := name.ParseReference(reference) + ref, err := c.defaultNamespace.ParseReference(reference) if err != nil { return nil, NewReferenceError(reference, err) } @@ -115,7 +125,7 @@ func (c *Client) Model(ctx context.Context, reference string) (types.ModelArtifa func (c *Client) BlobURL(reference string, digest v1.Hash) (string, error) { // Parse the reference - ref, err := name.ParseReference(reference) + ref, err := c.defaultNamespace.ParseReference(reference) if err != nil { return "", NewReferenceError(reference, err) } @@ -129,7 +139,7 @@ func (c *Client) BlobURL(reference string, digest v1.Hash) (string, error) { func (c *Client) BearerToken(ctx context.Context, reference string) (string, error) { // Parse the reference - ref, err := name.ParseReference(reference) + ref, err := c.defaultNamespace.ParseReference(reference) if err != nil { return "", NewReferenceError(reference, err) } @@ -157,24 +167,26 @@ func (c *Client) BearerToken(ctx context.Context, reference string) (string, err } type Target struct { - reference name.Reference - transport http.RoundTripper - userAgent string - keychain authn.Keychain - auth authn.Authenticator + reference name.Reference + transport http.RoundTripper + userAgent string + keychain authn.Keychain + auth authn.Authenticator + defaultNamespace *naming.DefaultNamespace } func (c *Client) NewTarget(tag string) (*Target, error) { - ref, err := name.NewTag(tag) + ref, err := c.defaultNamespace.ParseTag(tag) if err != nil { return nil, fmt.Errorf("invalid tag: %q: %w", tag, err) } return &Target{ - reference: ref, - transport: c.transport, - userAgent: c.userAgent, - keychain: c.keychain, - auth: c.auth, + reference: ref, + transport: c.transport, + userAgent: c.userAgent, + keychain: c.keychain, + auth: c.auth, + defaultNamespace: c.defaultNamespace, }, nil } diff --git a/registry/client_namespace_test.go b/registry/client_namespace_test.go new file mode 100644 index 0000000..c5b0a18 --- /dev/null +++ b/registry/client_namespace_test.go @@ -0,0 +1,77 @@ +package registry + +import ( + "testing" +) + +func TestWithDefaultNamespace(t *testing.T) { + // Test client without default namespace - should use standard behavior + client1 := NewClient() + if client1.defaultNamespace != nil { + t.Error("expected nil default namespace") + } + + // Test client with default namespace + client2 := NewClient(WithDefaultNamespace("registry.example.com")) + if client2.defaultNamespace == nil { + t.Fatal("expected non-nil default namespace") + } + if client2.defaultNamespace.Registry != "registry.example.com" { + t.Errorf("expected registry.example.com, got %s", client2.defaultNamespace.Registry) + } + + // Test empty namespace (should not set) + client3 := NewClient(WithDefaultNamespace("")) + if client3.defaultNamespace != nil { + t.Error("expected nil default namespace for empty string") + } +} + +func TestClientNamespaceBehavior(t *testing.T) { + tests := []struct { + name string + defaultRegistry string + input string + expectedRegistry string + }{ + { + name: "no default - uses Docker Hub", + defaultRegistry: "", + input: "mymodel:latest", + expectedRegistry: "index.docker.io", + }, + { + name: "custom default applied", + defaultRegistry: "registry.example.com", + input: "mymodel:latest", + expectedRegistry: "registry.example.com", + }, + { + name: "explicit registry preserved", + defaultRegistry: "registry.example.com", + input: "other.registry.com/mymodel:latest", + expectedRegistry: "other.registry.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var client *Client + if tt.defaultRegistry == "" { + client = NewClient() + } else { + client = NewClient(WithDefaultNamespace(tt.defaultRegistry)) + } + + // Test Model method parsing + ref, err := client.defaultNamespace.ParseReference(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if ref.Context().Registry.RegistryStr() != tt.expectedRegistry { + t.Errorf("expected registry %s, got %s", tt.expectedRegistry, ref.Context().Registry.RegistryStr()) + } + }) + } +} \ No newline at end of file