Skip to content

Commit

Permalink
feat(cmd/rofl): Add support for building TDX ROFL apps
Browse files Browse the repository at this point in the history
  • Loading branch information
kostko committed Oct 25, 2024
1 parent ca765b9 commit 01e0f4e
Show file tree
Hide file tree
Showing 7 changed files with 614 additions and 4 deletions.
304 changes: 304 additions & 0 deletions cmd/rofl/build/artifacts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
package build

import (
"archive/tar"
"compress/bzip2"
"crypto/sha256"
"errors"
"fmt"
"io"
"io/fs"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"strings"
"time"

"github.com/adrg/xdg"
"github.com/spf13/cobra"

"github.com/oasisprotocol/oasis-core/go/common/crypto/hash"
)

const artifactCacheDir = "build_cache"

// maybeDownloadArtifact downloads the given artifact and optionally verifies its integrity against
// the provided hash.
func maybeDownloadArtifact(kind, uri, knownHash string) string {
fmt.Printf("Downloading %s artifact...\n", kind)
fmt.Printf(" URI: %s\n", uri)
if knownHash != "" {
fmt.Printf(" Hash: %s\n", knownHash)
}

url, err := url.Parse(uri)
if err != nil {
cobra.CheckErr(fmt.Errorf("failed to parse %s artifact URL: %w", kind, err))
}

// In case the URI represents a local file, just return it.
if url.Host == "" {
return url.Path
}

// TODO: Prune cache.
cacheHash := hash.NewFromBytes([]byte(uri)).Hex()
cacheFn, err := xdg.CacheFile(filepath.Join("oasis", artifactCacheDir, cacheHash))
if err != nil {
cobra.CheckErr(fmt.Errorf("failed to create cache directory for %s artifact: %w", kind, err))
}

f, err := os.Create(cacheFn)
if err != nil {
cobra.CheckErr(fmt.Errorf("failed to create file for %s artifact: %w", kind, err))
}
defer f.Close()

// Download the remote artifact.
res, err := http.Get(uri) //nolint:gosec,noctx
if err != nil {
cobra.CheckErr(fmt.Errorf("failed to download %s artifact: %w", kind, err))
}
defer res.Body.Close()

// Compute the SHA256 hash while downloading the artifact.
h := sha256.New()
rd := io.TeeReader(res.Body, h)

if _, err = io.Copy(f, rd); err != nil {
cobra.CheckErr(fmt.Errorf("failed to download %s artifact: %w", kind, err))
}

// Verify integrity if available.
if knownHash != "" {
artifactHash := fmt.Sprintf("%x", h.Sum(nil))
if artifactHash != knownHash {
cobra.CheckErr(fmt.Errorf("hash mismatch for %s artifact (expected: %s got: %s)", kind, knownHash, artifactHash))
}
}

return cacheFn
}

// extractArchive extracts the given tar.bz2 archive into the target output directory.
func extractArchive(fn, outputDir string) error {
f, err := os.Open(fn)
if err != nil {
return fmt.Errorf("failed to open archive: %w", err)
}
defer f.Close()

rd := tar.NewReader(bzip2.NewReader(f))

existingPaths := make(map[string]struct{})
cleanupPath := func(path string) (string, error) {
// Sanitize path to ensure it doesn't escape to any parent directories.
path = filepath.Clean(filepath.Join(outputDir, path))
if !strings.HasPrefix(path, outputDir) {
return "", fmt.Errorf("malformed path in archive")
}
return path, nil
}

modTimes := make(map[string]time.Time)

FILES:
for {
var header *tar.Header
header, err = rd.Next()
switch {
case errors.Is(err, io.EOF):
// We are done.
break FILES
case err != nil:
// Failed to read archive.
return fmt.Errorf("error reading archive: %w", err)
case header == nil:
// Bad archive.
return fmt.Errorf("malformed archive")
}

var path string
path, err = cleanupPath(header.Name)
if err != nil {
return err
}
if _, ok := existingPaths[path]; ok {
continue // Make sure we never handle a path twice.
}
existingPaths[path] = struct{}{}
modTimes[path] = header.ModTime

switch header.Typeflag {
case tar.TypeDir:
// Directory.
if err = os.MkdirAll(path, header.FileInfo().Mode()); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
case tar.TypeLink:
// Hard link.
var linkPath string
linkPath, err = cleanupPath(header.Linkname)
if err != nil {
return err
}

if err = os.Link(linkPath, path); err != nil {
return fmt.Errorf("failed to create hard link: %w", err)
}
case tar.TypeSymlink:
// Symbolic link.
if err = os.Symlink(header.Linkname, path); err != nil {
return fmt.Errorf("failed to create soft link: %w", err)
}
case tar.TypeChar, tar.TypeBlock, tar.TypeFifo:
// Device or FIFO node.
if err = extractHandleSpecialNode(path, header); err != nil {
return err
}
case tar.TypeReg:
// Regular file.
if err = os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
return fmt.Errorf("failed to create parent directory: %w", err)
}

var fh *os.File
fh, err = os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, header.FileInfo().Mode())
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
if _, err = io.Copy(fh, rd); err != nil { //nolint:gosec
fh.Close()
return fmt.Errorf("failed to copy data: %w", err)
}
fh.Close()
default:
// Skip unsupported types.
continue
}
}

// Update all modification times at the end to ensure they are correct.
for path, mtime := range modTimes {
if err = extractChtimes(path, mtime, mtime); err != nil {
return fmt.Errorf("failed to change file '%s' timestamps: %w", path, err)
}
}

return nil
}

// copyFile copies the file at path src to a file at path dst using the given mode.
func copyFile(src, dst string, mode os.FileMode) error {
sf, err := os.Open(src)
if err != nil {
return fmt.Errorf("failed to open '%s': %w", src, err)
}
defer sf.Close()

df, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode)
if err != nil {
return fmt.Errorf("failed to create '%s': %w", dst, err)
}
defer df.Close()

_, err = io.Copy(df, sf)
return err
}

// computeDirSize computes the size of the given directory.
func computeDirSize(path string) (int64, error) {
var size int64
err := filepath.WalkDir(path, func(path string, d fs.DirEntry, derr error) error {
if derr != nil {
return derr
}
fi, err := d.Info()
if err != nil {
return err
}
size += fi.Size()
return nil
})
if err != nil {
return 0, err
}
return size, nil
}

// createExt4Fs creates an ext4 filesystem in the given file using directory dir to populate it.
//
// Returns the size of the created filesystem image in bytes.
func createExt4Fs(fn, dir string) (int64, error) {
// Compute filesystem size in bytes.
fsSize, err := computeDirSize(dir)
if err != nil {
return 0, err
}
fsSize /= 1024 // Convert to kilobytes.
fsSize = (fsSize * 150) / 100 // Scale by overhead factor of 1.5.

// Execute mkfs.ext4.
cmd := exec.Command( //nolint:gosec
"mkfs.ext4",
"-E", "root_owner=0:0",
"-d", dir,
fn,
fmt.Sprintf("%dK", fsSize),
)
var out strings.Builder
cmd.Stderr = &out
if err = cmd.Run(); err != nil {
return 0, fmt.Errorf("%w\n%s", err, out.String())
}

// Measure the size of the resulting image.
fi, err := os.Stat(fn)
if err != nil {
return 0, err
}
return fi.Size(), nil
}

// createVerityHashTree creates the verity Merkle hash tree and returns the root hash.
func createVerityHashTree(fsFn, hashFn string) (string, error) {
rootHashFn := hashFn + ".roothash"

cmd := exec.Command( //nolint:gosec
"veritysetup", "format",
"--data-block-size=4096",
"--hash-block-size=4096",
"--root-hash-file="+rootHashFn,
fsFn,
hashFn,
)
if err := cmd.Run(); err != nil {
return "", err
}

data, err := os.ReadFile(rootHashFn)
if err != nil {
return "", fmt.Errorf("")
}
return string(data), nil
}

// concatFiles appends the contents of file b to a.
func concatFiles(a, b string) error {
df, err := os.OpenFile(a, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return err
}
defer df.Close()

sf, err := os.Open(b)
if err != nil {
return err
}
defer sf.Close()

_, err = io.Copy(df, sf)
return err
}
17 changes: 17 additions & 0 deletions cmd/rofl/build/artifacts_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
//go:build !unix

package build

import (
"archive/tar"
"os"
"time"
)

func extractHandleSpecialNode(path string, header *tar.Header) error {
return nil
}

func extractChtimes(path string, atime, mtime time.Time) error {
return os.Chtimes(path, atime, mtime)
}
30 changes: 30 additions & 0 deletions cmd/rofl/build/artifacts_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//go:build unix

package build

import (
"archive/tar"
"time"

"golang.org/x/sys/unix"
)

func extractHandleSpecialNode(path string, header *tar.Header) error {
mode := uint32(header.Mode & 0o7777)
switch header.Typeflag {
case tar.TypeBlock:
mode |= unix.S_IFBLK
case tar.TypeChar:
mode |= unix.S_IFCHR
case tar.TypeFifo:
mode |= unix.S_IFIFO
}

return unix.Mknod(path, mode, int(unix.Mkdev(uint32(header.Devmajor), uint32(header.Devminor))))
}

func extractChtimes(path string, atime, mtime time.Time) error {
atv := unix.NsecToTimeval(atime.UnixNano())
mtv := unix.NsecToTimeval(mtime.UnixNano())
return unix.Lutimes(path, []unix.Timeval{atv, mtv})
}
1 change: 1 addition & 0 deletions cmd/rofl/build/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,5 @@ func init() {

Cmd.PersistentFlags().AddFlagSet(globalFlags)
Cmd.AddCommand(sgxCmd)
Cmd.AddCommand(tdxCmd)
}
Loading

0 comments on commit 01e0f4e

Please sign in to comment.