Skip to content
This repository was archived by the owner on Sep 29, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Dockerfile.gguf
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ RUN --mount=type=secret,id=DOCKER_USERNAME,env=DOCKER_USERNAME \
model-distribution-tool package \
--licenses /licenses/LICENSE \
--mmproj /model/model.mmproj \
/model/model.gguf \
$HUB_REPOSITORY:$TAG; \
--tag $HUB_REPOSITORY:$TAG \
/model/model.gguf; \
else \
echo "Packaging without multimodal projector file"; \
model-distribution-tool package \
--licenses /licenses/LICENSE \
/model/model.gguf \
$HUB_REPOSITORY:$TAG; \
--tag $HUB_REPOSITORY:$TAG \
/model/model.gguf; \
fi'
4 changes: 2 additions & 2 deletions Dockerfile.safetensors
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,5 @@ RUN --mount=type=secret,id=DOCKER_USERNAME,env=DOCKER_USERNAME \
--mount=type=secret,id=DOCKER_PASSWORD,env=DOCKER_PASSWORD \
model-distribution-tool package \
--licenses /licenses/LICENSE \
/model/model.gguf \
$HUB_REPOSITORY:$WEIGHTS-$QUANTIZATION
--tag $HUB_REPOSITORY:$WEIGHTS-$QUANTIZATION \
/model/model.gguf
14 changes: 10 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,22 @@ make build
./bin/model-distribution-tool pull registry.example.com/models/llama:v1.0

# Package a model and push to a registry
./bin/model-distribution-tool package ./model.gguf registry.example.com/models/llama:v1.0
./bin/model-distribution-tool package --tag registry.example.com/models/llama:v1.0 ./model.gguf

# Package a model with license files and push to a registry
./bin/model-distribution-tool package --licenses license1.txt --licenses license2.txt ./model.gguf registry.example.com/models/llama:v1.0
./bin/model-distribution-tool package --licenses license1.txt --licenses license2.txt --tag registry.example.com/models/llama:v1.0 ./model.gguf

# Package a model with a default context size and push to a registry
./bin/model-distribution-tool package ./model.gguf --context-size 2048 registry.example.com/models/llama:v1.0
./bin/model-distribution-tool package --context-size 2048 --tag registry.example.com/models/llama:v1.0 ./model.gguf

# Package a model with a multimodal projector file and push to a registry
./bin/model-distribution-tool package ./model.gguf --mmproj ./model.mmproj registry.example.com/models/llama:v1.0
./bin/model-distribution-tool package --mmproj ./model.mmproj --tag registry.example.com/models/llama:v1.0 ./model.gguf

# Package a model and output the result to a file
./bin/model-distribution-tool package --file ./model.tar ./model.gguf

# Load a model from an archive into the local store
./bin/model-distribution-tool load ./model.tar

# Push a model from the content store to the registry
./bin/model-distribution-tool push registry.example.com/models/llama:v1.0
Expand Down
96 changes: 84 additions & 12 deletions cmd/mdltool/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/docker/model-distribution/builder"
"github.com/docker/model-distribution/distribution"
"github.com/docker/model-distribution/registry"
"github.com/docker/model-distribution/tarball"
)

// stringSliceFlag is a flag that can be specified multiple times to collect multiple string values
Expand Down Expand Up @@ -103,6 +104,8 @@ func main() {
exitCode = cmdRm(client, args)
case "tag":
exitCode = cmdTag(client, args)
case "load":
exitCode = cmdLoad(client, args)
default:
fmt.Fprintf(os.Stderr, "Unknown command: %s\n", command)
printUsage()
Expand Down Expand Up @@ -154,15 +157,23 @@ func cmdPull(client *distribution.Client, args []string) int {

func cmdPackage(args []string) int {
fs := flag.NewFlagSet("package", flag.ExitOnError)
var licensePaths stringSliceFlag
var contextSize uint64
var mmproj string
var (
licensePaths stringSliceFlag
contextSize uint64
file string
tag string
mmproj string
)

fs.Var(&licensePaths, "licenses", "Paths to license files (can be specified multiple times)")
fs.Uint64Var(&contextSize, "context-size", 0, "Context size in tokens")
fs.StringVar(&mmproj, "mmproj", "", "Path to Multimodal Projector file")
fs.StringVar(&file, "file", "", "Write archived model to the given file")
fs.StringVar(&tag, "tag", "", "Push model to the given registry tag")
fs.Parse(args)

fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool package [OPTIONS] <path-to-gguf> <reference>\n\n")
fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool package [OPTIONS] <path-to-gguf>\n\n")
fmt.Fprintf(os.Stderr, "Options:\n")
fs.PrintDefaults()
}
Expand All @@ -173,14 +184,18 @@ func cmdPackage(args []string) int {
}
args = fs.Args()

if len(args) < 2 {
if len(args) < 1 {
fmt.Fprintf(os.Stderr, "Error: missing arguments\n")
fs.Usage()
return 1
}
if file == "" && tag == "" {
fmt.Fprintf(os.Stderr, "Error: one of --file or --tag is required\n")
fs.Usage()
return 1
}

source := args[0]
reference := args[1]
ctx := context.Background()

// Check if source file exists
Expand Down Expand Up @@ -210,11 +225,18 @@ func cmdPackage(args []string) int {
// Create registry client once with all options
registryClient := registry.NewClient(registryClientOpts...)

// Parse the reference
target, err := registryClient.NewTarget(reference)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing reference: %v\n", err)
return 1
var (
target builder.Target
err error
)
if file != "" {
target = tarball.NewFileTarget(file)
} else {
target, err = registryClient.NewTarget(tag)
if err != nil {
fmt.Fprintf(os.Stderr, "Create packaging target: %v\n", err)
return 1
}
}

// Create image with layer
Expand Down Expand Up @@ -250,9 +272,59 @@ func cmdPackage(args []string) int {

// Push the image
if err := builder.Build(ctx, target, os.Stdout); err != nil {
fmt.Fprintf(os.Stderr, "Error writing model %q to registry: %v\n", reference, err)
fmt.Fprintf(os.Stderr, "Error writing model to registry: %v\n", err)
return 1
}
if tag != "" {
fmt.Printf("Successfully packaged and pushed model: %s\n", tag)
} else {
fmt.Printf("Successfully packaged model to file: %s\n", file)
}
return 0
}

// cmdLoad implements the "load" subcommand. It opens the archive named by the
// first positional argument, loads it into the local store via
// client.LoadModel, and, when --tag is given, applies that tag to the loaded
// model.
//
// It returns the process exit code: 0 on success, 1 on any failure
// (bad arguments, unreadable archive, load error, or tagging error).
func cmdLoad(client *distribution.Client, args []string) int {
	fs := flag.NewFlagSet("load", flag.ExitOnError)
	var tag string
	fs.StringVar(&tag, "tag", "", "Apply tag to the loaded model")
	fs.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool load [OPTIONS] <path-to-archive>\n\n")
		fmt.Fprintf(os.Stderr, "Options:\n")
		fs.PrintDefaults()
	}

	if err := fs.Parse(args); err != nil {
		fmt.Fprintf(os.Stderr, "Error parsing flags: %v\n", err)
		return 1
	}
	args = fs.Args()

	if len(args) < 1 {
		fmt.Fprintf(os.Stderr, "Error: missing required argument\n")
		fs.Usage()
		return 1
	}
	path := args[0]

	f, err := os.Open(path)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error opening model file: %v\n", err)
		return 1
	}
	defer f.Close()

	id, err := client.LoadModel(f, os.Stdout)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
		return 1
	}
	fmt.Fprintln(os.Stdout, "Loaded model:", id)

	// --tag is optional; without it the model stays addressable by its ID only.
	if tag == "" {
		return 0
	}
	if err := client.Tag(id, tag); err != nil {
		fmt.Fprintf(os.Stderr, "Error tagging model: %v\n", err)
		// A failed tag must not be reported as success.
		return 1
	}
	fmt.Fprintln(os.Stdout, "Tagged model:", tag)
	return 0
}

Expand Down
47 changes: 46 additions & 1 deletion distribution/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package distribution

import (
"context"
"errors"
"fmt"
"io"
"net/http"
Expand All @@ -12,6 +13,7 @@ import (
"github.com/docker/model-distribution/internal/progress"
"github.com/docker/model-distribution/internal/store"
"github.com/docker/model-distribution/registry"
"github.com/docker/model-distribution/tarball"
"github.com/docker/model-distribution/types"
)

Expand Down Expand Up @@ -179,7 +181,7 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter

// Ensure model has the correct tag
if err := c.store.AddTags(remoteDigest.String(), []string{reference}); err != nil {
return fmt.Errorf("tagging modle: %w", err)
return fmt.Errorf("tagging model: %w", err)
}
return nil
} else {
Expand All @@ -206,6 +208,49 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter
return nil
}

// LoadModel reads a model archive from r and writes it into the client's store.
//
// Every blob yielded by the tarball reader is written to the store under its
// diff ID; once the stream is exhausted the manifest is read from the reader
// and written as well. A success message is then emitted to progressWriter on
// a best-effort basis.
//
// It returns the manifest digest as a string, which serves as the model ID.
// On error the returned ID is empty. An io.ErrUnexpectedEOF from the stream is
// reported as an interrupted load.
func (c *Client) LoadModel(r io.Reader, progressWriter io.Writer) (string, error) {
	c.log.Infoln("Starting model load")

	tr := tarball.NewReader(r)
	for {
		diffID, err := tr.Next()
		if err == io.EOF {
			// Clean end of the archive: all blobs have been consumed.
			break
		}
		if err != nil {
			if errors.Is(err, io.ErrUnexpectedEOF) {
				c.log.Infof("Model load interrupted (likely cancelled): %v", err)
				return "", fmt.Errorf("model load interrupted: %w", err)
			}
			return "", fmt.Errorf("reading blob from stream: %w", err)
		}
		c.log.Infoln("Loading blob:", diffID)
		// tr is passed as the blob content source for the entry just returned by Next.
		if err := c.store.WriteBlob(diffID, tr); err != nil {
			return "", fmt.Errorf("writing blob: %w", err)
		}
		c.log.Infoln("Loaded blob:", diffID)
	}

	manifest, digest, err := tr.Manifest()
	if err != nil {
		return "", fmt.Errorf("read manifest: %w", err)
	}
	c.log.Infoln("Loading manifest:", digest.String())
	if err := c.store.WriteManifest(digest, manifest); err != nil {
		return "", fmt.Errorf("write manifest: %w", err)
	}
	c.log.Infoln("Loaded model with ID:", digest.String())

	// Reporting success is best-effort: the model is already stored, so a
	// failed progress write is logged and otherwise ignored.
	if err := progress.WriteSuccess(progressWriter, "Model loaded successfully"); err != nil {
		c.log.Warnf("Failed to write success message: %v", err)
	}

	return digest.String(), nil
}

// ListModels returns all available models
func (c *Client) ListModels() ([]types.Model, error) {
c.log.Infoln("Listing available models")
Expand Down
55 changes: 55 additions & 0 deletions distribution/load_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package distribution

import (
"io"
"os"
"testing"

"github.com/docker/model-distribution/builder"
"github.com/docker/model-distribution/tarball"
)

// TestLoadModel streams a model built from the test GGUF file through an
// in-memory pipe into Client.LoadModel and checks that the resulting model ID
// can be resolved from the store afterwards.
func TestLoadModel(t *testing.T) {
	// Scratch directory that backs the store for this test.
	storeDir, err := os.MkdirTemp("", "model-distribution-test-*")
	if err != nil {
		t.Fatalf("Failed to create temp directory: %v", err)
	}
	defer os.RemoveAll(storeDir)

	c, err := NewClient(WithStoreRootPath(storeDir))
	if err != nil {
		t.Fatalf("Failed to create client: %v", err)
	}

	// The builder writes the archive into one end of the pipe while
	// LoadModel consumes it from the other end.
	reader, writer := io.Pipe()
	tarTarget, err := tarball.NewTarget(writer)
	if err != nil {
		t.Fatalf("Failed to create target: %v", err)
	}

	var modelID string
	loadErr := make(chan error)
	go func() {
		loaded, err := c.LoadModel(reader, nil)
		modelID = loaded
		loadErr <- err
	}()

	ggufBuilder, err := builder.FromGGUF(testGGUFFile)
	if err != nil {
		t.Fatalf("Failed to create builder: %v", err)
	}
	if err := ggufBuilder.Build(t.Context(), tarTarget, nil); err != nil {
		t.Fatalf("Failed to build model: %v", err)
	}

	// Receiving from loadErr also orders the goroutine's write to modelID
	// before the read below.
	if err := <-loadErr; err != nil {
		t.Fatalf("LoadModel exited with error: %v", err)
	}

	// The loaded model must now be retrievable by its ID.
	if _, err := c.GetModel(modelID); err != nil {
		t.Fatalf("Failed to get model: %v", err)
	}
}
39 changes: 39 additions & 0 deletions internal/progress/reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package progress

import (
"io"

"github.com/google/go-containerregistry/pkg/v1"
)

// Reader wraps an io.Reader and reports the cumulative number of bytes read
// as v1.Update values on a progress channel.
type Reader struct {
// Reader is the underlying source that Read delegates to.
Reader io.Reader
// ProgressChan receives an update carrying the running total after reads.
ProgressChan chan<- v1.Update
// Total is the number of bytes read so far; Read adds to it on every call.
Total int64
}

// NewReader wraps r so that reads are reported on updates. When updates is
// nil there is nowhere to report to, so r itself is returned unwrapped.
func NewReader(r io.Reader, updates chan<- v1.Update) io.Reader {
	if updates != nil {
		return &Reader{Reader: r, ProgressChan: updates}
	}
	return r
}

// Read reads from the wrapped reader, adds the byte count to Total and
// publishes the new total on ProgressChan.
//
// The update sent when the underlying reader returns io.EOF is a blocking
// send. Updates for ordinary reads that returned data are sent without
// blocking and are dropped if the channel cannot accept them immediately.
func (pr *Reader) Read(p []byte) (int, error) {
	n, err := pr.Reader.Read(p)
	pr.Total += int64(n)

	switch {
	case err == io.EOF:
		pr.ProgressChan <- v1.Update{Complete: pr.Total}
	case n > 0:
		select {
		case pr.ProgressChan <- v1.Update{Complete: pr.Total}:
		default:
			// Receiver is not ready; skip this update instead of stalling the read.
		}
	}
	return n, err
}
Loading
Loading