From fb706b397132bb9c5a44fcde461c00beb0af5cb7 Mon Sep 17 00:00:00 2001 From: Maksim An Date: Wed, 6 Apr 2022 18:28:15 -0700 Subject: [PATCH 1/2] Port grantvmgroupaccess code from go-winio repo for further extension Signed-off-by: Maksim An --- computestorage/helpers.go | 3 +- .../security/grantvmgroupaccess.go | 15 +- internal/security/grantvmgroupaccess_test.go | 120 +++ internal/security/mksyscall_windows.go | 958 ++++++++++++++++++ .../security/syscall_windows.go | 6 +- .../security/zsyscall_windows.go | 24 +- internal/tools/grantvmgroupaccess/main.go | 2 +- internal/uvm/scsi.go | 2 +- .../hcsshim/computestorage/helpers.go | 3 +- .../internal}/security/grantvmgroupaccess.go | 15 +- .../internal}/security/syscall_windows.go | 6 +- .../internal}/security/zsyscall_windows.go | 24 +- .../Microsoft/hcsshim/internal/uvm/scsi.go | 2 +- test/vendor/modules.txt | 2 +- vendor/golang.org/x/sys/execabs/execabs.go | 102 ++ vendor/modules.txt | 2 +- 16 files changed, 1233 insertions(+), 53 deletions(-) rename {test/vendor/github.com/Microsoft/go-winio/pkg => internal}/security/grantvmgroupaccess.go (90%) create mode 100644 internal/security/grantvmgroupaccess_test.go create mode 100644 internal/security/mksyscall_windows.go rename {vendor/github.com/Microsoft/go-winio/pkg => internal}/security/syscall_windows.go (61%) rename {vendor/github.com/Microsoft/go-winio/pkg => internal}/security/zsyscall_windows.go (59%) rename {vendor/github.com/Microsoft/go-winio/pkg => test/vendor/github.com/Microsoft/hcsshim/internal}/security/grantvmgroupaccess.go (90%) rename test/vendor/github.com/Microsoft/{go-winio/pkg => hcsshim/internal}/security/syscall_windows.go (61%) rename test/vendor/github.com/Microsoft/{go-winio/pkg => hcsshim/internal}/security/zsyscall_windows.go (59%) create mode 100644 vendor/golang.org/x/sys/execabs/execabs.go diff --git a/computestorage/helpers.go b/computestorage/helpers.go index a0e329edec..c3608dcec8 100644 --- a/computestorage/helpers.go +++ b/computestorage/helpers.go @@ -8,11 +8,12 @@ import ( "path/filepath" "syscall" - "github.com/Microsoft/go-winio/pkg/security" "github.com/Microsoft/go-winio/vhd" "github.com/Microsoft/hcsshim/internal/memory" "github.com/pkg/errors" "golang.org/x/sys/windows" + + "github.com/Microsoft/hcsshim/internal/security" ) const defaultVHDXBlockSizeInMB = 1 diff --git a/test/vendor/github.com/Microsoft/go-winio/pkg/security/grantvmgroupaccess.go b/internal/security/grantvmgroupaccess.go similarity index 90% rename from test/vendor/github.com/Microsoft/go-winio/pkg/security/grantvmgroupaccess.go rename to internal/security/grantvmgroupaccess.go index fca241590c..602920786c 100644 --- a/test/vendor/github.com/Microsoft/go-winio/pkg/security/grantvmgroupaccess.go +++ b/internal/security/grantvmgroupaccess.go @@ -3,11 +3,10 @@ package security import ( + "fmt" "os" "syscall" "unsafe" - - "github.com/pkg/errors" ) type ( @@ -72,7 +71,7 @@ func GrantVmGroupAccess(name string) error { // Stat (to determine if `name` is a directory). s, err := os.Stat(name) if err != nil { - return errors.Wrapf(err, "%s os.Stat %s", gvmga, name) + return fmt.Errorf("%s os.Stat %s: %w", gvmga, name, err) } // Get a handle to the file/directory. Must defer Close on success. @@ -88,7 +87,7 @@ func GrantVmGroupAccess(name string) error { sd := uintptr(0) origDACL := uintptr(0) if err := getSecurityInfo(fd, uint32(ot), uint32(si), nil, nil, &origDACL, nil, &sd); err != nil { - return errors.Wrapf(err, "%s GetSecurityInfo %s", gvmga, name) + return fmt.Errorf("%s GetSecurityInfo %s: %w", gvmga, name, err) } defer syscall.LocalFree((syscall.Handle)(unsafe.Pointer(sd))) @@ -102,7 +101,7 @@ func GrantVmGroupAccess(name string) error { // And finally use SetSecurityInfo to apply the updated DACL. if err := setSecurityInfo(fd, uint32(ot), uint32(si), uintptr(0), uintptr(0), newDACL, uintptr(0)); err != nil { - return errors.Wrapf(err, "%s SetSecurityInfo %s", gvmga, name) + return fmt.Errorf("%s SetSecurityInfo %s: %w", gvmga, name, err) } return nil @@ -120,7 +119,7 @@ func createFile(name string, isDir bool) (syscall.Handle, error) { } fd, err := syscall.CreateFile(&namep[0], da, sm, nil, syscall.OPEN_EXISTING, fa, 0) if err != nil { - return 0, errors.Wrapf(err, "%s syscall.CreateFile %s", gvmga, name) + return 0, fmt.Errorf("%s syscall.CreateFile %s: %w", gvmga, name, err) } return fd, nil } @@ -131,7 +130,7 @@ func generateDACLWithAcesAdded(name string, isDir bool, origDACL uintptr) (uintp // Generate pointers to the SIDs based on the string SIDs sid, err := syscall.StringToSid(sidVmGroup) if err != nil { - return 0, errors.Wrapf(err, "%s syscall.StringToSid %s %s", gvmga, name, sidVmGroup) + return 0, fmt.Errorf("%s syscall.StringToSid %s %s: %w", gvmga, name, sidVmGroup, err) } inheritance := inheritModeNoInheritance @@ -154,7 +153,7 @@ func generateDACLWithAcesAdded(name string, isDir bool, origDACL uintptr) (uintp modifiedDACL := uintptr(0) if err := setEntriesInAcl(uintptr(uint32(1)), uintptr(unsafe.Pointer(&eaArray[0])), origDACL, &modifiedDACL); err != nil { - return 0, errors.Wrapf(err, "%s SetEntriesInAcl %s", gvmga, name) + return 0, fmt.Errorf("%s SetEntriesInAcl %s: %w", gvmga, name, err) } return modifiedDACL, nil diff --git a/internal/security/grantvmgroupaccess_test.go b/internal/security/grantvmgroupaccess_test.go new file mode 100644 index 0000000000..2dffaa1284 --- /dev/null +++ b/internal/security/grantvmgroupaccess_test.go @@ -0,0 +1,120 @@ +//+build windows + +package security + +import ( + "io/ioutil" + "os" + "path/filepath" + "regexp" + "strings" + "testing" + + exec "golang.org/x/sys/execabs" +) + +const ( + vmAccountName = `NT VIRTUAL MACHINE\\Virtual Machines` + vmAccountSID = "S-1-5-83-0" +) + +// TestGrantVmGroupAccess verifies for the three case of a file, a directory, +// and a file in a directory that the appropriate ACEs are set, including +// inheritance in the second two examples. These are the expected ACES. Is +// verified by running icacls and comparing output. +// +// File: +// S-1-15-3-1024-2268835264-3721307629-241982045-173645152-1490879176-104643441-2915960892-1612460704:(R,W) +// S-1-5-83-1-3166535780-1122986932-343720105-43916321:(R,W) +// +// Directory: +// S-1-15-3-1024-2268835264-3721307629-241982045-173645152-1490879176-104643441-2915960892-1612460704:(OI)(CI)(R,W) +// S-1-5-83-1-3166535780-1122986932-343720105-43916321:(OI)(CI)(R,W) +// +// File in directory (inherited): +// S-1-15-3-1024-2268835264-3721307629-241982045-173645152-1490879176-104643441-2915960892-1612460704:(I)(R,W) +// S-1-5-83-1-3166535780-1122986932-343720105-43916321:(I)(R,W) + +func TestGrantVmGroupAccess(t *testing.T) { + f, err := ioutil.TempFile("", "gvmgafile") + if err != nil { + t.Fatal(err) + } + defer func() { + f.Close() + os.Remove(f.Name()) + }() + + d, err := ioutil.TempDir("", "gvmgadir") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(d) + + find, err := os.Create(filepath.Join(d, "find.txt")) + if err != nil { + t.Fatal(err) + } + + if err := GrantVmGroupAccess(f.Name()); err != nil { + t.Fatal(err) + } + + if err := GrantVmGroupAccess(d); err != nil { + t.Fatal(err) + } + + verifyVMAccountDACLs(t, + f.Name(), + []string{`(R)`}, + ) + + // Two items here: + // - One explicit read only. + // - Other applies to this folder, subfolders and files + // (OI): object inherit + // (CI): container inherit + // (IO): inherit only + // (GR): generic read + // + // In properties for the directory, advanced security settings, this will + // show as a single line "Allow/Virtual Machines/Read/Inherited from none/This folder, subfolder and files + verifyVMAccountDACLs(t, + d, + []string{`(R)`, `(OI)(CI)(IO)(GR)`}, + ) + + verifyVMAccountDACLs(t, + find.Name(), + []string{`(I)(R)`}, + ) + +} + +func verifyVMAccountDACLs(t *testing.T, name string, permissions []string) { + cmd := exec.Command("icacls", name) + outb, err := cmd.CombinedOutput() + if err != nil { + t.Fatal(err) + } + out := string(outb) + + for _, p := range permissions { + // Avoid '(' and ')' being part of match groups + p = strings.Replace(p, "(", "\\(", -1) + p = strings.Replace(p, ")", "\\)", -1) + + nameToCheck := vmAccountName + ":" + p + sidToCheck := vmAccountSID + ":" + p + + rxName := regexp.MustCompile(nameToCheck) + rxSID := regexp.MustCompile(sidToCheck) + + matchesName := rxName.FindAllStringIndex(out, -1) + matchesSID := rxSID.FindAllStringIndex(out, -1) + + if len(matchesName) != 1 && len(matchesSID) != 1 { + t.Fatalf("expected one match for %s or %s\n%s", nameToCheck, sidToCheck, out) + } + } +} diff --git a/internal/security/mksyscall_windows.go b/internal/security/mksyscall_windows.go new file mode 100644 index 0000000000..d52b6137c7 --- /dev/null +++ b/internal/security/mksyscall_windows.go @@ -0,0 +1,958 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Hard-coding unicode mode for VHD library. + +// +build ignore + +/* +mksyscall_windows generates windows system call bodies + +It parses all files specified on command line containing function +prototypes (like syscall_windows.go) and prints system call bodies +to standard output. + +The prototypes are marked by lines beginning with "//sys" and read +like func declarations if //sys is replaced by func, but: + +* The parameter lists must give a name for each argument. This + includes return parameters. + +* The parameter lists must give a type for each argument: + the (x, y, z int) shorthand is not allowed. + +* If the return parameter is an error number, it must be named err. + +* If go func name needs to be different from its winapi dll name, + the winapi name could be specified at the end, after "=" sign, like + //sys LoadLibrary(libname string) (handle uint32, err error) = LoadLibraryA + +* Each function that returns err needs to supply a condition, that + return value of winapi will be tested against to detect failure. + This would set err to windows "last-error", otherwise it will be nil. + The value can be provided at end of //sys declaration, like + //sys LoadLibrary(libname string) (handle uint32, err error) [failretval==-1] = LoadLibraryA + and is [failretval==0] by default. + +* If the function name ends in a "?", then the function not existing is non- + fatal, and an error will be returned instead of panicking. + +Usage: + mksyscall_windows [flags] [path ...] + +The flags are: + -output + Specify output file name (outputs to console if blank). + -trace + Generate print statement after every syscall. +*/ +package main + +import ( + "bufio" + "bytes" + "errors" + "flag" + "fmt" + "go/format" + "go/parser" + "go/token" + "io" + "io/ioutil" + "log" + "os" + "path/filepath" + "runtime" + "sort" + "strconv" + "strings" + "text/template" +) + +var ( + filename = flag.String("output", "", "output file name (standard output if omitted)") + printTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall") + systemDLL = flag.Bool("systemdll", true, "whether all DLLs should be loaded from the Windows system directory") +) + +func trim(s string) string { + return strings.Trim(s, " \t") +} + +var packageName string + +func packagename() string { + return packageName +} + +func syscalldot() string { + if packageName == "syscall" { + return "" + } + return "syscall." +} + +// Param is function parameter +type Param struct { + Name string + Type string + fn *Fn + tmpVarIdx int +} + +// tmpVar returns temp variable name that will be used to represent p during syscall. +func (p *Param) tmpVar() string { + if p.tmpVarIdx < 0 { + p.tmpVarIdx = p.fn.curTmpVarIdx + p.fn.curTmpVarIdx++ + } + return fmt.Sprintf("_p%d", p.tmpVarIdx) +} + +// BoolTmpVarCode returns source code for bool temp variable. +func (p *Param) BoolTmpVarCode() string { + const code = `var %[1]s uint32 + if %[2]s { + %[1]s = 1 + }` + return fmt.Sprintf(code, p.tmpVar(), p.Name) +} + +// BoolPointerTmpVarCode returns source code for bool temp variable. +func (p *Param) BoolPointerTmpVarCode() string { + const code = `var %[1]s uint32 + if *%[2]s { + %[1]s = 1 + }` + return fmt.Sprintf(code, p.tmpVar(), p.Name) +} + +// SliceTmpVarCode returns source code for slice temp variable. +func (p *Param) SliceTmpVarCode() string { + const code = `var %s *%s + if len(%s) > 0 { + %s = &%s[0] + }` + tmp := p.tmpVar() + return fmt.Sprintf(code, tmp, p.Type[2:], p.Name, tmp, p.Name) +} + +// StringTmpVarCode returns source code for string temp variable. +func (p *Param) StringTmpVarCode() string { + errvar := p.fn.Rets.ErrorVarName() + if errvar == "" { + errvar = "_" + } + tmp := p.tmpVar() + const code = `var %s %s + %s, %s = %s(%s)` + s := fmt.Sprintf(code, tmp, p.fn.StrconvType(), tmp, errvar, p.fn.StrconvFunc(), p.Name) + if errvar == "-" { + return s + } + const morecode = ` + if %s != nil { + return + }` + return s + fmt.Sprintf(morecode, errvar) +} + +// TmpVarCode returns source code for temp variable. +func (p *Param) TmpVarCode() string { + switch { + case p.Type == "bool": + return p.BoolTmpVarCode() + case p.Type == "*bool": + return p.BoolPointerTmpVarCode() + case strings.HasPrefix(p.Type, "[]"): + return p.SliceTmpVarCode() + default: + return "" + } +} + +// TmpVarReadbackCode returns source code for reading back the temp variable into the original variable. +func (p *Param) TmpVarReadbackCode() string { + switch { + case p.Type == "*bool": + return fmt.Sprintf("*%s = %s != 0", p.Name, p.tmpVar()) + default: + return "" + } +} + +// TmpVarHelperCode returns source code for helper's temp variable. +func (p *Param) TmpVarHelperCode() string { + if p.Type != "string" { + return "" + } + return p.StringTmpVarCode() +} + +// SyscallArgList returns source code fragments representing p parameter +// in syscall. Slices are translated into 2 syscall parameters: pointer to +// the first element and length. +func (p *Param) SyscallArgList() []string { + t := p.HelperType() + var s string + switch { + case t == "*bool": + s = fmt.Sprintf("unsafe.Pointer(&%s)", p.tmpVar()) + case t[0] == '*': + s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name) + case t == "bool": + s = p.tmpVar() + case strings.HasPrefix(t, "[]"): + return []string{ + fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()), + fmt.Sprintf("uintptr(len(%s))", p.Name), + } + default: + s = p.Name + } + return []string{fmt.Sprintf("uintptr(%s)", s)} +} + +// IsError determines if p parameter is used to return error. +func (p *Param) IsError() bool { + return p.Name == "err" && p.Type == "error" +} + +// HelperType returns type of parameter p used in helper function. +func (p *Param) HelperType() string { + if p.Type == "string" { + return p.fn.StrconvType() + } + return p.Type +} + +// join concatenates parameters ps into a string with sep separator. +// Each parameter is converted into string by applying fn to it +// before conversion. +func join(ps []*Param, fn func(*Param) string, sep string) string { + if len(ps) == 0 { + return "" + } + a := make([]string, 0) + for _, p := range ps { + a = append(a, fn(p)) + } + return strings.Join(a, sep) +} + +// Rets describes function return parameters. +type Rets struct { + Name string + Type string + ReturnsError bool + FailCond string + fnMaybeAbsent bool +} + +// ErrorVarName returns error variable name for r. +func (r *Rets) ErrorVarName() string { + if r.ReturnsError { + return "err" + } + if r.Type == "error" { + return r.Name + } + return "" +} + +// ToParams converts r into slice of *Param. +func (r *Rets) ToParams() []*Param { + ps := make([]*Param, 0) + if len(r.Name) > 0 { + ps = append(ps, &Param{Name: r.Name, Type: r.Type}) + } + if r.ReturnsError { + ps = append(ps, &Param{Name: "err", Type: "error"}) + } + return ps +} + +// List returns source code of syscall return parameters. +func (r *Rets) List() string { + s := join(r.ToParams(), func(p *Param) string { return p.Name + " " + p.Type }, ", ") + if len(s) > 0 { + s = "(" + s + ")" + } else if r.fnMaybeAbsent { + s = "(err error)" + } + return s +} + +// PrintList returns source code of trace printing part correspondent +// to syscall return values. +func (r *Rets) PrintList() string { + return join(r.ToParams(), func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `) +} + +// SetReturnValuesCode returns source code that accepts syscall return values. +func (r *Rets) SetReturnValuesCode() string { + if r.Name == "" && !r.ReturnsError { + return "" + } + retvar := "r0" + if r.Name == "" { + retvar = "r1" + } + errvar := "_" + if r.ReturnsError { + errvar = "e1" + } + return fmt.Sprintf("%s, _, %s := ", retvar, errvar) +} + +func (r *Rets) useLongHandleErrorCode(retvar string) string { + const code = `if %s { + err = errnoErr(e1) + }` + cond := retvar + " == 0" + if r.FailCond != "" { + cond = strings.Replace(r.FailCond, "failretval", retvar, 1) + } + return fmt.Sprintf(code, cond) +} + +// SetErrorCode returns source code that sets return parameters. +func (r *Rets) SetErrorCode() string { + const code = `if r0 != 0 { + %s = %sErrno(r0) + }` + if r.Name == "" && !r.ReturnsError { + return "" + } + if r.Name == "" { + return r.useLongHandleErrorCode("r1") + } + if r.Type == "error" { + return fmt.Sprintf(code, r.Name, syscalldot()) + } + s := "" + switch { + case r.Type[0] == '*': + s = fmt.Sprintf("%s = (%s)(unsafe.Pointer(r0))", r.Name, r.Type) + case r.Type == "bool": + s = fmt.Sprintf("%s = r0 != 0", r.Name) + default: + s = fmt.Sprintf("%s = %s(r0)", r.Name, r.Type) + } + if !r.ReturnsError { + return s + } + return s + "\n\t" + r.useLongHandleErrorCode(r.Name) +} + +// Fn describes syscall function. +type Fn struct { + Name string + Params []*Param + Rets *Rets + PrintTrace bool + dllname string + dllfuncname string + src string + // TODO: get rid of this field and just use parameter index instead + curTmpVarIdx int // insure tmp variables have uniq names +} + +// extractParams parses s to extract function parameters. +func extractParams(s string, f *Fn) ([]*Param, error) { + s = trim(s) + if s == "" { + return nil, nil + } + a := strings.Split(s, ",") + ps := make([]*Param, len(a)) + for i := range ps { + s2 := trim(a[i]) + b := strings.Split(s2, " ") + if len(b) != 2 { + b = strings.Split(s2, "\t") + if len(b) != 2 { + return nil, errors.New("Could not extract function parameter from \"" + s2 + "\"") + } + } + ps[i] = &Param{ + Name: trim(b[0]), + Type: trim(b[1]), + fn: f, + tmpVarIdx: -1, + } + } + return ps, nil +} + +// extractSection extracts text out of string s starting after start +// and ending just before end. found return value will indicate success, +// and prefix, body and suffix will contain correspondent parts of string s. +func extractSection(s string, start, end rune) (prefix, body, suffix string, found bool) { + s = trim(s) + if strings.HasPrefix(s, string(start)) { + // no prefix + body = s[1:] + } else { + a := strings.SplitN(s, string(start), 2) + if len(a) != 2 { + return "", "", s, false + } + prefix = a[0] + body = a[1] + } + a := strings.SplitN(body, string(end), 2) + if len(a) != 2 { + return "", "", "", false + } + return prefix, a[0], a[1], true +} + +// newFn parses string s and return created function Fn. +func newFn(s string) (*Fn, error) { + s = trim(s) + f := &Fn{ + Rets: &Rets{}, + src: s, + PrintTrace: *printTraceFlag, + } + // function name and args + prefix, body, s, found := extractSection(s, '(', ')') + if !found || prefix == "" { + return nil, errors.New("Could not extract function name and parameters from \"" + f.src + "\"") + } + f.Name = prefix + var err error + f.Params, err = extractParams(body, f) + if err != nil { + return nil, err + } + // return values + _, body, s, found = extractSection(s, '(', ')') + if found { + r, err := extractParams(body, f) + if err != nil { + return nil, err + } + switch len(r) { + case 0: + case 1: + if r[0].IsError() { + f.Rets.ReturnsError = true + } else { + f.Rets.Name = r[0].Name + f.Rets.Type = r[0].Type + } + case 2: + if !r[1].IsError() { + return nil, errors.New("Only last windows error is allowed as second return value in \"" + f.src + "\"") + } + f.Rets.ReturnsError = true + f.Rets.Name = r[0].Name + f.Rets.Type = r[0].Type + default: + return nil, errors.New("Too many return values in \"" + f.src + "\"") + } + } + // fail condition + _, body, s, found = extractSection(s, '[', ']') + if found { + f.Rets.FailCond = body + } + // dll and dll function names + s = trim(s) + if s == "" { + return f, nil + } + if !strings.HasPrefix(s, "=") { + return nil, errors.New("Could not extract dll name from \"" + f.src + "\"") + } + s = trim(s[1:]) + a := strings.Split(s, ".") + switch len(a) { + case 1: + f.dllfuncname = a[0] + case 2: + f.dllname = a[0] + f.dllfuncname = a[1] + default: + return nil, errors.New("Could not extract dll name from \"" + f.src + "\"") + } + if n := f.dllfuncname; strings.HasSuffix(n, "?") { + f.dllfuncname = n[:len(n)-1] + f.Rets.fnMaybeAbsent = true + } + return f, nil +} + +// DLLName returns DLL name for function f. +func (f *Fn) DLLName() string { + if f.dllname == "" { + return "kernel32" + } + return f.dllname +} + +// DLLName returns DLL function name for function f. +func (f *Fn) DLLFuncName() string { + if f.dllfuncname == "" { + return f.Name + } + return f.dllfuncname +} + +// ParamList returns source code for function f parameters. +func (f *Fn) ParamList() string { + return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ") +} + +// HelperParamList returns source code for helper function f parameters. +func (f *Fn) HelperParamList() string { + return join(f.Params, func(p *Param) string { return p.Name + " " + p.HelperType() }, ", ") +} + +// ParamPrintList returns source code of trace printing part correspondent +// to syscall input parameters. +func (f *Fn) ParamPrintList() string { + return join(f.Params, func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `) +} + +// ParamCount return number of syscall parameters for function f. +func (f *Fn) ParamCount() int { + n := 0 + for _, p := range f.Params { + n += len(p.SyscallArgList()) + } + return n +} + +// SyscallParamCount determines which version of Syscall/Syscall6/Syscall9/... +// to use. It returns parameter count for correspondent SyscallX function. +func (f *Fn) SyscallParamCount() int { + n := f.ParamCount() + switch { + case n <= 3: + return 3 + case n <= 6: + return 6 + case n <= 9: + return 9 + case n <= 12: + return 12 + case n <= 15: + return 15 + default: + panic("too many arguments to system call") + } +} + +// Syscall determines which SyscallX function to use for function f. +func (f *Fn) Syscall() string { + c := f.SyscallParamCount() + if c == 3 { + return syscalldot() + "Syscall" + } + return syscalldot() + "Syscall" + strconv.Itoa(c) +} + +// SyscallParamList returns source code for SyscallX parameters for function f. +func (f *Fn) SyscallParamList() string { + a := make([]string, 0) + for _, p := range f.Params { + a = append(a, p.SyscallArgList()...) + } + for len(a) < f.SyscallParamCount() { + a = append(a, "0") + } + return strings.Join(a, ", ") +} + +// HelperCallParamList returns source code of call into function f helper. +func (f *Fn) HelperCallParamList() string { + a := make([]string, 0, len(f.Params)) + for _, p := range f.Params { + s := p.Name + if p.Type == "string" { + s = p.tmpVar() + } + a = append(a, s) + } + return strings.Join(a, ", ") +} + +// MaybeAbsent returns source code for handling functions that are possibly unavailable. +func (p *Fn) MaybeAbsent() string { + if !p.Rets.fnMaybeAbsent { + return "" + } + const code = `%[1]s = proc%[2]s.Find() + if %[1]s != nil { + return + }` + errorVar := p.Rets.ErrorVarName() + if errorVar == "" { + errorVar = "err" + } + return fmt.Sprintf(code, errorVar, p.DLLFuncName()) +} + +// IsUTF16 is true, if f is W (utf16) function. It is false +// for all A (ascii) functions. +func (_ *Fn) IsUTF16() bool { + return true +} + +// StrconvFunc returns name of Go string to OS string function for f. +func (f *Fn) StrconvFunc() string { + if f.IsUTF16() { + return syscalldot() + "UTF16PtrFromString" + } + return syscalldot() + "BytePtrFromString" +} + +// StrconvType returns Go type name used for OS string for f. +func (f *Fn) StrconvType() string { + if f.IsUTF16() { + return "*uint16" + } + return "*byte" +} + +// HasStringParam is true, if f has at least one string parameter. +// Otherwise it is false. +func (f *Fn) HasStringParam() bool { + for _, p := range f.Params { + if p.Type == "string" { + return true + } + } + return false +} + +// HelperName returns name of function f helper. +func (f *Fn) HelperName() string { + if !f.HasStringParam() { + return f.Name + } + return "_" + f.Name +} + +// Source files and functions. +type Source struct { + Funcs []*Fn + Files []string + StdLibImports []string + ExternalImports []string +} + +func (src *Source) Import(pkg string) { + src.StdLibImports = append(src.StdLibImports, pkg) + sort.Strings(src.StdLibImports) +} + +func (src *Source) ExternalImport(pkg string) { + src.ExternalImports = append(src.ExternalImports, pkg) + sort.Strings(src.ExternalImports) +} + +// ParseFiles parses files listed in fs and extracts all syscall +// functions listed in sys comments. It returns source files +// and functions collection *Source if successful. +func ParseFiles(fs []string) (*Source, error) { + src := &Source{ + Funcs: make([]*Fn, 0), + Files: make([]string, 0), + StdLibImports: []string{ + "unsafe", + }, + ExternalImports: make([]string, 0), + } + for _, file := range fs { + if err := src.ParseFile(file); err != nil { + return nil, err + } + } + return src, nil +} + +// DLLs return dll names for a source set src. +func (src *Source) DLLs() []string { + uniq := make(map[string]bool) + r := make([]string, 0) + for _, f := range src.Funcs { + name := f.DLLName() + if _, found := uniq[name]; !found { + uniq[name] = true + r = append(r, name) + } + } + sort.Strings(r) + return r +} + +// ParseFile adds additional file path to a source set src. +func (src *Source) ParseFile(path string) error { + file, err := os.Open(path) + if err != nil { + return err + } + defer file.Close() + + s := bufio.NewScanner(file) + for s.Scan() { + t := trim(s.Text()) + if len(t) < 7 { + continue + } + if !strings.HasPrefix(t, "//sys") { + continue + } + t = t[5:] + if !(t[0] == ' ' || t[0] == '\t') { + continue + } + f, err := newFn(t[1:]) + if err != nil { + return err + } + src.Funcs = append(src.Funcs, f) + } + if err := s.Err(); err != nil { + return err + } + src.Files = append(src.Files, path) + sort.Slice(src.Funcs, func(i, j int) bool { + fi, fj := src.Funcs[i], src.Funcs[j] + if fi.DLLName() == fj.DLLName() { + return fi.DLLFuncName() < fj.DLLFuncName() + } + return fi.DLLName() < fj.DLLName() + }) + + // get package name + fset := token.NewFileSet() + _, err = file.Seek(0, 0) + if err != nil { + return err + } + pkg, err := parser.ParseFile(fset, "", file, parser.PackageClauseOnly) + if err != nil { + return err + } + packageName = pkg.Name.Name + + return nil +} + +// IsStdRepo reports whether src is part of standard library. +func (src *Source) IsStdRepo() (bool, error) { + if len(src.Files) == 0 { + return false, errors.New("no input files provided") + } + abspath, err := filepath.Abs(src.Files[0]) + if err != nil { + return false, err + } + goroot := runtime.GOROOT() + if runtime.GOOS == "windows" { + abspath = strings.ToLower(abspath) + goroot = strings.ToLower(goroot) + } + sep := string(os.PathSeparator) + if !strings.HasSuffix(goroot, sep) { + goroot += sep + } + return strings.HasPrefix(abspath, goroot), nil +} + +// Generate output source file from a source set src. +func (src *Source) Generate(w io.Writer) error { + const ( + pkgStd = iota // any package in std library + pkgXSysWindows // x/sys/windows package + pkgOther + ) + isStdRepo, err := src.IsStdRepo() + if err != nil { + return err + } + var pkgtype int + switch { + case isStdRepo: + pkgtype = pkgStd + case packageName == "windows": + // TODO: this needs better logic than just using package name + pkgtype = pkgXSysWindows + default: + pkgtype = pkgOther + } + if *systemDLL { + switch pkgtype { + case pkgStd: + src.Import("internal/syscall/windows/sysdll") + case pkgXSysWindows: + default: + src.ExternalImport("golang.org/x/sys/windows") + } + } + if packageName != "syscall" { + src.Import("syscall") + } + funcMap := template.FuncMap{ + "packagename": packagename, + "syscalldot": syscalldot, + "newlazydll": func(dll string) string { + arg := "\"" + dll + ".dll\"" + if !*systemDLL { + return syscalldot() + "NewLazyDLL(" + arg + ")" + } + switch pkgtype { + case pkgStd: + return syscalldot() + "NewLazyDLL(sysdll.Add(" + arg + "))" + case pkgXSysWindows: + return "NewLazySystemDLL(" + arg + ")" + default: + return "windows.NewLazySystemDLL(" + arg + ")" + } + }, + } + t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate)) + err = t.Execute(w, src) + if err != nil { + return errors.New("Failed to execute template: " + err.Error()) + } + return nil +} + +func usage() { + fmt.Fprintf(os.Stderr, "usage: mksyscall_windows [flags] [path ...]\n") + flag.PrintDefaults() + os.Exit(1) +} + +func main() { + flag.Usage = usage + flag.Parse() + if len(flag.Args()) <= 0 { + fmt.Fprintf(os.Stderr, "no files to parse provided\n") + usage() + } + + src, err := ParseFiles(flag.Args()) + if err != nil { + log.Fatal(err) + } + + var buf bytes.Buffer + if err := src.Generate(&buf); err != nil { + log.Fatal(err) + } + + data, err := format.Source(buf.Bytes()) + if err != nil { + log.Fatal(err) + } + if *filename == "" { + _, err = os.Stdout.Write(data) + } else { + err = ioutil.WriteFile(*filename, data, 0644) + } + if err != nil { + log.Fatal(err) + } +} + +// TODO: use println instead to print in the following template +const srcTemplate = ` + +{{define "main"}}// Code generated by 'go generate'; DO NOT EDIT. + +package {{packagename}} + +import ( +{{range .StdLibImports}}"{{.}}" +{{end}} + +{{range .ExternalImports}}"{{.}}" +{{end}} +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = {{syscalldot}}Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = {{syscalldot}}EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e {{syscalldot}}Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( +{{template "dlls" .}} +{{template "funcnames" .}}) +{{range .Funcs}}{{if .HasStringParam}}{{template "helperbody" .}}{{end}}{{template "funcbody" .}}{{end}} +{{end}} + +{{/* help functions */}} + +{{define "dlls"}}{{range .DLLs}} mod{{.}} = {{newlazydll .}} +{{end}}{{end}} + +{{define "funcnames"}}{{range .Funcs}} proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}") +{{end}}{{end}} + +{{define "helperbody"}} +func {{.Name}}({{.ParamList}}) {{template "results" .}}{ +{{template "helpertmpvars" .}} return {{.HelperName}}({{.HelperCallParamList}}) +} +{{end}} + +{{define "funcbody"}} +func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{ +{{template "maybeabsent" .}} {{template "tmpvars" .}} {{template "syscall" .}} {{template "tmpvarsreadback" .}} +{{template "seterror" .}}{{template "printtrace" .}} return +} +{{end}} + +{{define "helpertmpvars"}}{{range .Params}}{{if .TmpVarHelperCode}} {{.TmpVarHelperCode}} +{{end}}{{end}}{{end}} + +{{define "maybeabsent"}}{{if .MaybeAbsent}}{{.MaybeAbsent}} +{{end}}{{end}} + +{{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}} {{.TmpVarCode}} +{{end}}{{end}}{{end}} + +{{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}} + +{{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}} + +{{define "tmpvarsreadback"}}{{range .Params}}{{if .TmpVarReadbackCode}} +{{.TmpVarReadbackCode}}{{end}}{{end}}{{end}} + +{{define "seterror"}}{{if .Rets.SetErrorCode}} {{.Rets.SetErrorCode}} +{{end}}{{end}} + +{{define "printtrace"}}{{if .PrintTrace}} print("SYSCALL: {{.Name}}(", {{.ParamPrintList}}") (", {{.Rets.PrintList}}")\n") +{{end}}{{end}} + +` diff --git a/vendor/github.com/Microsoft/go-winio/pkg/security/syscall_windows.go b/internal/security/syscall_windows.go similarity index 61% rename from vendor/github.com/Microsoft/go-winio/pkg/security/syscall_windows.go rename to internal/security/syscall_windows.go index c40c2739b7..d7096716ce 100644 --- a/vendor/github.com/Microsoft/go-winio/pkg/security/syscall_windows.go +++ b/internal/security/syscall_windows.go @@ -2,6 +2,6 @@ package security //go:generate go run mksyscall_windows.go -output zsyscall_windows.go syscall_windows.go -//sys getSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, ppsidOwner **uintptr, ppsidGroup **uintptr, ppDacl *uintptr, ppSacl *uintptr, ppSecurityDescriptor *uintptr) (err error) [failretval!=0] = advapi32.GetSecurityInfo -//sys setSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, psidOwner uintptr, psidGroup uintptr, pDacl uintptr, pSacl uintptr) (err error) [failretval!=0] = advapi32.SetSecurityInfo -//sys setEntriesInAcl(count uintptr, pListOfEEs uintptr, oldAcl uintptr, newAcl *uintptr) (err error) [failretval!=0] = advapi32.SetEntriesInAclW +//sys getSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, ppsidOwner **uintptr, ppsidGroup **uintptr, ppDacl *uintptr, ppSacl *uintptr, ppSecurityDescriptor *uintptr) (win32err error) = advapi32.GetSecurityInfo +//sys setSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, psidOwner uintptr, psidGroup uintptr, pDacl uintptr, pSacl uintptr) (win32err error) = advapi32.SetSecurityInfo +//sys setEntriesInAcl(count uintptr, pListOfEEs uintptr, oldAcl uintptr, newAcl *uintptr) (win32err error) = advapi32.SetEntriesInAclW diff --git a/vendor/github.com/Microsoft/go-winio/pkg/security/zsyscall_windows.go b/internal/security/zsyscall_windows.go similarity index 59% rename from vendor/github.com/Microsoft/go-winio/pkg/security/zsyscall_windows.go rename to internal/security/zsyscall_windows.go index 4a90cb3cc8..4084680e0f 100644 --- a/vendor/github.com/Microsoft/go-winio/pkg/security/zsyscall_windows.go +++ b/internal/security/zsyscall_windows.go @@ -45,26 +45,26 @@ var ( procSetSecurityInfo = modadvapi32.NewProc("SetSecurityInfo") ) -func getSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, ppsidOwner **uintptr, ppsidGroup **uintptr, ppDacl *uintptr, ppSacl *uintptr, ppSecurityDescriptor *uintptr) (err error) { - r1, _, e1 := syscall.Syscall9(procGetSecurityInfo.Addr(), 8, uintptr(handle), uintptr(objectType), uintptr(si), uintptr(unsafe.Pointer(ppsidOwner)), uintptr(unsafe.Pointer(ppsidGroup)), uintptr(unsafe.Pointer(ppDacl)), uintptr(unsafe.Pointer(ppSacl)), uintptr(unsafe.Pointer(ppSecurityDescriptor)), 0) - if r1 != 0 { - err = errnoErr(e1) +func getSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, ppsidOwner **uintptr, ppsidGroup **uintptr, ppDacl *uintptr, ppSacl *uintptr, ppSecurityDescriptor *uintptr) (win32err error) { + r0, _, _ := syscall.Syscall9(procGetSecurityInfo.Addr(), 8, uintptr(handle), uintptr(objectType), uintptr(si), uintptr(unsafe.Pointer(ppsidOwner)), uintptr(unsafe.Pointer(ppsidGroup)), uintptr(unsafe.Pointer(ppDacl)), uintptr(unsafe.Pointer(ppSacl)), uintptr(unsafe.Pointer(ppSecurityDescriptor)), 0) + if r0 != 0 { + win32err = syscall.Errno(r0) } return } -func setEntriesInAcl(count uintptr, pListOfEEs uintptr, oldAcl uintptr, newAcl *uintptr) (err error) { - r1, _, e1 := syscall.Syscall6(procSetEntriesInAclW.Addr(), 4, uintptr(count), uintptr(pListOfEEs), uintptr(oldAcl), uintptr(unsafe.Pointer(newAcl)), 0, 0) - if r1 != 0 { - err = errnoErr(e1) +func setEntriesInAcl(count uintptr, pListOfEEs uintptr, oldAcl uintptr, newAcl *uintptr) (win32err error) { + r0, _, _ := syscall.Syscall6(procSetEntriesInAclW.Addr(), 4, uintptr(count), uintptr(pListOfEEs), uintptr(oldAcl), uintptr(unsafe.Pointer(newAcl)), 0, 0) + if r0 != 0 { + win32err = syscall.Errno(r0) } return } -func setSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, psidOwner uintptr, psidGroup uintptr, pDacl uintptr, pSacl uintptr) (err error) { - r1, _, e1 := syscall.Syscall9(procSetSecurityInfo.Addr(), 7, uintptr(handle), uintptr(objectType), uintptr(si), uintptr(psidOwner), uintptr(psidGroup), uintptr(pDacl), uintptr(pSacl), 0, 0) - if r1 != 0 { - err = errnoErr(e1) +func setSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, psidOwner uintptr, psidGroup uintptr, pDacl uintptr, pSacl uintptr) (win32err error) { + r0, _, _ := syscall.Syscall9(procSetSecurityInfo.Addr(), 7, uintptr(handle), uintptr(objectType), uintptr(si), uintptr(psidOwner), uintptr(psidGroup), uintptr(pDacl), uintptr(pSacl), 0, 0) + if r0 != 0 { + win32err = syscall.Errno(r0) } return } diff --git a/internal/tools/grantvmgroupaccess/main.go b/internal/tools/grantvmgroupaccess/main.go index 5fa6644d77..887db694d0 100644 --- a/internal/tools/grantvmgroupaccess/main.go +++ b/internal/tools/grantvmgroupaccess/main.go @@ -6,7 +6,7 @@ import ( "fmt" "os" - "github.com/Microsoft/go-winio/pkg/security" + "github.com/Microsoft/hcsshim/internal/security" ) func main() { diff --git a/internal/uvm/scsi.go b/internal/uvm/scsi.go index cc884c3bcd..e21d0ddfe2 100644 --- a/internal/uvm/scsi.go +++ b/internal/uvm/scsi.go @@ -12,7 +12,6 @@ import ( "path/filepath" "strings" - "github.com/Microsoft/go-winio/pkg/security" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -22,6 +21,7 @@ import ( "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + "github.com/Microsoft/hcsshim/internal/security" "github.com/Microsoft/hcsshim/internal/wclayer" ) diff --git a/test/vendor/github.com/Microsoft/hcsshim/computestorage/helpers.go b/test/vendor/github.com/Microsoft/hcsshim/computestorage/helpers.go index a0e329edec..c3608dcec8 100644 --- a/test/vendor/github.com/Microsoft/hcsshim/computestorage/helpers.go +++ b/test/vendor/github.com/Microsoft/hcsshim/computestorage/helpers.go @@ -8,11 +8,12 @@ import ( "path/filepath" "syscall" - "github.com/Microsoft/go-winio/pkg/security" "github.com/Microsoft/go-winio/vhd" "github.com/Microsoft/hcsshim/internal/memory" "github.com/pkg/errors" "golang.org/x/sys/windows" + + "github.com/Microsoft/hcsshim/internal/security" ) const defaultVHDXBlockSizeInMB = 1 diff --git a/vendor/github.com/Microsoft/go-winio/pkg/security/grantvmgroupaccess.go b/test/vendor/github.com/Microsoft/hcsshim/internal/security/grantvmgroupaccess.go similarity index 90% rename from vendor/github.com/Microsoft/go-winio/pkg/security/grantvmgroupaccess.go rename to test/vendor/github.com/Microsoft/hcsshim/internal/security/grantvmgroupaccess.go index fca241590c..602920786c 100644 --- a/vendor/github.com/Microsoft/go-winio/pkg/security/grantvmgroupaccess.go +++ b/test/vendor/github.com/Microsoft/hcsshim/internal/security/grantvmgroupaccess.go @@ -3,11 +3,10 @@ package security import ( + "fmt" "os" "syscall" "unsafe" - - "github.com/pkg/errors" ) type ( @@ -72,7 +71,7 @@ func GrantVmGroupAccess(name string) error { // Stat (to determine if `name` is a directory). s, err := os.Stat(name) if err != nil { - return errors.Wrapf(err, "%s os.Stat %s", gvmga, name) + return fmt.Errorf("%s os.Stat %s: %w", gvmga, name, err) } // Get a handle to the file/directory. Must defer Close on success. @@ -88,7 +87,7 @@ func GrantVmGroupAccess(name string) error { sd := uintptr(0) origDACL := uintptr(0) if err := getSecurityInfo(fd, uint32(ot), uint32(si), nil, nil, &origDACL, nil, &sd); err != nil { - return errors.Wrapf(err, "%s GetSecurityInfo %s", gvmga, name) + return fmt.Errorf("%s GetSecurityInfo %s: %w", gvmga, name, err) } defer syscall.LocalFree((syscall.Handle)(unsafe.Pointer(sd))) @@ -102,7 +101,7 @@ func GrantVmGroupAccess(name string) error { // And finally use SetSecurityInfo to apply the updated DACL. if err := setSecurityInfo(fd, uint32(ot), uint32(si), uintptr(0), uintptr(0), newDACL, uintptr(0)); err != nil { - return errors.Wrapf(err, "%s SetSecurityInfo %s", gvmga, name) + return fmt.Errorf("%s SetSecurityInfo %s: %w", gvmga, name, err) } return nil @@ -120,7 +119,7 @@ func createFile(name string, isDir bool) (syscall.Handle, error) { } fd, err := syscall.CreateFile(&namep[0], da, sm, nil, syscall.OPEN_EXISTING, fa, 0) if err != nil { - return 0, errors.Wrapf(err, "%s syscall.CreateFile %s", gvmga, name) + return 0, fmt.Errorf("%s syscall.CreateFile %s: %w", gvmga, name, err) } return fd, nil } @@ -131,7 +130,7 @@ func generateDACLWithAcesAdded(name string, isDir bool, origDACL uintptr) (uintp // Generate pointers to the SIDs based on the string SIDs sid, err := syscall.StringToSid(sidVmGroup) if err != nil { - return 0, errors.Wrapf(err, "%s syscall.StringToSid %s %s", gvmga, name, sidVmGroup) + return 0, fmt.Errorf("%s syscall.StringToSid %s %s: %w", gvmga, name, sidVmGroup, err) } inheritance := inheritModeNoInheritance @@ -154,7 +153,7 @@ func generateDACLWithAcesAdded(name string, isDir bool, origDACL uintptr) (uintp modifiedDACL := uintptr(0) if err := setEntriesInAcl(uintptr(uint32(1)), uintptr(unsafe.Pointer(&eaArray[0])), origDACL, &modifiedDACL); err != nil { - return 0, errors.Wrapf(err, "%s SetEntriesInAcl %s", gvmga, name) + return 0, fmt.Errorf("%s SetEntriesInAcl %s: %w", gvmga, name, err) } return modifiedDACL, nil diff --git a/test/vendor/github.com/Microsoft/go-winio/pkg/security/syscall_windows.go b/test/vendor/github.com/Microsoft/hcsshim/internal/security/syscall_windows.go similarity index 61% rename from test/vendor/github.com/Microsoft/go-winio/pkg/security/syscall_windows.go rename to test/vendor/github.com/Microsoft/hcsshim/internal/security/syscall_windows.go index c40c2739b7..d7096716ce 100644 --- a/test/vendor/github.com/Microsoft/go-winio/pkg/security/syscall_windows.go +++ b/test/vendor/github.com/Microsoft/hcsshim/internal/security/syscall_windows.go @@ -2,6 +2,6 @@ package security //go:generate go run mksyscall_windows.go -output zsyscall_windows.go syscall_windows.go -//sys getSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, ppsidOwner **uintptr, ppsidGroup **uintptr, ppDacl *uintptr, ppSacl *uintptr, ppSecurityDescriptor *uintptr) (err error) [failretval!=0] = advapi32.GetSecurityInfo -//sys setSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, psidOwner uintptr, psidGroup uintptr, pDacl uintptr, pSacl uintptr) (err error) [failretval!=0] = advapi32.SetSecurityInfo -//sys setEntriesInAcl(count uintptr, pListOfEEs uintptr, oldAcl uintptr, newAcl *uintptr) (err error) [failretval!=0] = advapi32.SetEntriesInAclW +//sys getSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, ppsidOwner **uintptr, ppsidGroup **uintptr, ppDacl *uintptr, ppSacl *uintptr, ppSecurityDescriptor *uintptr) (win32err error) = advapi32.GetSecurityInfo +//sys setSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, psidOwner uintptr, psidGroup uintptr, pDacl uintptr, pSacl uintptr) (win32err error) = advapi32.SetSecurityInfo +//sys setEntriesInAcl(count uintptr, pListOfEEs uintptr, oldAcl uintptr, newAcl *uintptr) (win32err error) = advapi32.SetEntriesInAclW diff --git a/test/vendor/github.com/Microsoft/go-winio/pkg/security/zsyscall_windows.go b/test/vendor/github.com/Microsoft/hcsshim/internal/security/zsyscall_windows.go similarity index 59% rename from test/vendor/github.com/Microsoft/go-winio/pkg/security/zsyscall_windows.go rename to test/vendor/github.com/Microsoft/hcsshim/internal/security/zsyscall_windows.go index 4a90cb3cc8..4084680e0f 100644 --- a/test/vendor/github.com/Microsoft/go-winio/pkg/security/zsyscall_windows.go +++ b/test/vendor/github.com/Microsoft/hcsshim/internal/security/zsyscall_windows.go @@ -45,26 +45,26 @@ var ( procSetSecurityInfo = modadvapi32.NewProc("SetSecurityInfo") ) -func getSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, ppsidOwner **uintptr, ppsidGroup **uintptr, ppDacl *uintptr, ppSacl *uintptr, ppSecurityDescriptor *uintptr) (err error) { - r1, _, e1 := syscall.Syscall9(procGetSecurityInfo.Addr(), 8, uintptr(handle), uintptr(objectType), uintptr(si), uintptr(unsafe.Pointer(ppsidOwner)), uintptr(unsafe.Pointer(ppsidGroup)), uintptr(unsafe.Pointer(ppDacl)), uintptr(unsafe.Pointer(ppSacl)), uintptr(unsafe.Pointer(ppSecurityDescriptor)), 0) - if r1 != 0 { - err = errnoErr(e1) +func getSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, ppsidOwner **uintptr, ppsidGroup **uintptr, ppDacl *uintptr, ppSacl *uintptr, ppSecurityDescriptor *uintptr) (win32err error) { + r0, _, _ := syscall.Syscall9(procGetSecurityInfo.Addr(), 8, uintptr(handle), uintptr(objectType), uintptr(si), uintptr(unsafe.Pointer(ppsidOwner)), uintptr(unsafe.Pointer(ppsidGroup)), uintptr(unsafe.Pointer(ppDacl)), uintptr(unsafe.Pointer(ppSacl)), uintptr(unsafe.Pointer(ppSecurityDescriptor)), 0) + if r0 != 0 { + win32err = syscall.Errno(r0) } return } -func setEntriesInAcl(count uintptr, pListOfEEs uintptr, oldAcl uintptr, newAcl *uintptr) (err error) { - r1, _, e1 := syscall.Syscall6(procSetEntriesInAclW.Addr(), 4, uintptr(count), uintptr(pListOfEEs), uintptr(oldAcl), uintptr(unsafe.Pointer(newAcl)), 0, 0) - if r1 != 0 { - err = errnoErr(e1) +func setEntriesInAcl(count uintptr, pListOfEEs uintptr, oldAcl uintptr, newAcl *uintptr) (win32err error) { + r0, _, _ := syscall.Syscall6(procSetEntriesInAclW.Addr(), 4, uintptr(count), uintptr(pListOfEEs), uintptr(oldAcl), uintptr(unsafe.Pointer(newAcl)), 0, 0) + if r0 != 0 { + win32err = syscall.Errno(r0) } return } -func setSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, psidOwner uintptr, psidGroup uintptr, pDacl uintptr, pSacl uintptr) (err error) { - r1, _, e1 := syscall.Syscall9(procSetSecurityInfo.Addr(), 7, uintptr(handle), uintptr(objectType), uintptr(si), uintptr(psidOwner), uintptr(psidGroup), uintptr(pDacl), uintptr(pSacl), 0, 0) - if r1 != 0 { - err = errnoErr(e1) +func setSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, psidOwner uintptr, psidGroup uintptr, pDacl uintptr, pSacl uintptr) (win32err error) { + r0, _, _ := syscall.Syscall9(procSetSecurityInfo.Addr(), 7, uintptr(handle), uintptr(objectType), uintptr(si), uintptr(psidOwner), uintptr(psidGroup), uintptr(pDacl), uintptr(pSacl), 0, 0) + if r0 != 0 { + win32err = syscall.Errno(r0) } return } diff --git a/test/vendor/github.com/Microsoft/hcsshim/internal/uvm/scsi.go b/test/vendor/github.com/Microsoft/hcsshim/internal/uvm/scsi.go index cc884c3bcd..e21d0ddfe2 100644 --- a/test/vendor/github.com/Microsoft/hcsshim/internal/uvm/scsi.go +++ b/test/vendor/github.com/Microsoft/hcsshim/internal/uvm/scsi.go @@ -12,7 +12,6 @@ import ( "path/filepath" "strings" - "github.com/Microsoft/go-winio/pkg/security" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -22,6 +21,7 @@ import ( "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + "github.com/Microsoft/hcsshim/internal/security" "github.com/Microsoft/hcsshim/internal/wclayer" ) diff --git a/test/vendor/modules.txt b/test/vendor/modules.txt index b0d29873d5..e91c03c2c1 100644 --- a/test/vendor/modules.txt +++ b/test/vendor/modules.txt @@ -4,7 +4,6 @@ github.com/Microsoft/go-winio github.com/Microsoft/go-winio/backuptar github.com/Microsoft/go-winio/pkg/guid github.com/Microsoft/go-winio/pkg/process -github.com/Microsoft/go-winio/pkg/security github.com/Microsoft/go-winio/vhd # github.com/Microsoft/hcsshim v0.8.23 => ../ ## explicit; go 1.17 @@ -59,6 +58,7 @@ github.com/Microsoft/hcsshim/internal/resources github.com/Microsoft/hcsshim/internal/runhcs github.com/Microsoft/hcsshim/internal/safefile github.com/Microsoft/hcsshim/internal/schemaversion +github.com/Microsoft/hcsshim/internal/security github.com/Microsoft/hcsshim/internal/shimdiag github.com/Microsoft/hcsshim/internal/timeout github.com/Microsoft/hcsshim/internal/tools/securitypolicy/helpers diff --git a/vendor/golang.org/x/sys/execabs/execabs.go b/vendor/golang.org/x/sys/execabs/execabs.go new file mode 100644 index 0000000000..78192498db --- /dev/null +++ b/vendor/golang.org/x/sys/execabs/execabs.go @@ -0,0 +1,102 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package execabs is a drop-in replacement for os/exec +// that requires PATH lookups to find absolute paths. +// That is, execabs.Command("cmd") runs the same PATH lookup +// as exec.Command("cmd"), but if the result is a path +// which is relative, the Run and Start methods will report +// an error instead of running the executable. +// +// See https://blog.golang.org/path-security for more information +// about when it may be necessary or appropriate to use this package. +package execabs + +import ( + "context" + "fmt" + "os/exec" + "path/filepath" + "reflect" + "unsafe" +) + +// ErrNotFound is the error resulting if a path search failed to find an executable file. +// It is an alias for exec.ErrNotFound. +var ErrNotFound = exec.ErrNotFound + +// Cmd represents an external command being prepared or run. +// It is an alias for exec.Cmd. +type Cmd = exec.Cmd + +// Error is returned by LookPath when it fails to classify a file as an executable. +// It is an alias for exec.Error. +type Error = exec.Error + +// An ExitError reports an unsuccessful exit by a command. +// It is an alias for exec.ExitError. +type ExitError = exec.ExitError + +func relError(file, path string) error { + return fmt.Errorf("%s resolves to executable in current directory (.%c%s)", file, filepath.Separator, path) +} + +// LookPath searches for an executable named file in the directories +// named by the PATH environment variable. If file contains a slash, +// it is tried directly and the PATH is not consulted. The result will be +// an absolute path. +// +// LookPath differs from exec.LookPath in its handling of PATH lookups, +// which are used for file names without slashes. If exec.LookPath's +// PATH lookup would have returned an executable from the current directory, +// LookPath instead returns an error. +func LookPath(file string) (string, error) { + path, err := exec.LookPath(file) + if err != nil { + return "", err + } + if filepath.Base(file) == file && !filepath.IsAbs(path) { + return "", relError(file, path) + } + return path, nil +} + +func fixCmd(name string, cmd *exec.Cmd) { + if filepath.Base(name) == name && !filepath.IsAbs(cmd.Path) { + // exec.Command was called with a bare binary name and + // exec.LookPath returned a path which is not absolute. + // Set cmd.lookPathErr and clear cmd.Path so that it + // cannot be run. + lookPathErr := (*error)(unsafe.Pointer(reflect.ValueOf(cmd).Elem().FieldByName("lookPathErr").Addr().Pointer())) + if *lookPathErr == nil { + *lookPathErr = relError(name, cmd.Path) + } + cmd.Path = "" + } +} + +// CommandContext is like Command but includes a context. +// +// The provided context is used to kill the process (by calling os.Process.Kill) +// if the context becomes done before the command completes on its own. +func CommandContext(ctx context.Context, name string, arg ...string) *exec.Cmd { + cmd := exec.CommandContext(ctx, name, arg...) + fixCmd(name, cmd) + return cmd + +} + +// Command returns the Cmd struct to execute the named program with the given arguments. +// See exec.Command for most details. +// +// Command differs from exec.Command in its handling of PATH lookups, +// which are used when the program name contains no slashes. +// If exec.Command would have returned an exec.Cmd configured to run an +// executable from the current directory, Command instead +// returns an exec.Cmd that will return an error from Start or Run. +func Command(name string, arg ...string) *exec.Cmd { + cmd := exec.Command(name, arg...) + fixCmd(name, cmd) + return cmd +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 6462b39e28..be16488c02 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -9,7 +9,6 @@ github.com/Microsoft/go-winio/pkg/etw github.com/Microsoft/go-winio/pkg/etwlogrus github.com/Microsoft/go-winio/pkg/guid github.com/Microsoft/go-winio/pkg/process -github.com/Microsoft/go-winio/pkg/security github.com/Microsoft/go-winio/vhd # github.com/cenkalti/backoff/v4 v4.1.1 ## explicit; go 1.13 @@ -208,6 +207,7 @@ golang.org/x/net/trace golang.org/x/sync/errgroup # golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e ## explicit; go 1.17 +golang.org/x/sys/execabs golang.org/x/sys/internal/unsafeheader golang.org/x/sys/unix golang.org/x/sys/windows From 6b920444509e2cd25ff44f5f918cf7fbc8ae438d Mon Sep 17 00:00:00 2001 From: Maksim An Date: Wed, 6 Apr 2022 18:59:51 -0700 Subject: [PATCH 2/2] Extend GrantVmGroupAccess to support write/execute/all permissions Add masks for GENERIC_WRITE, GENERIC_EXECUTE and GENERIC_ALL and update function signatures accordingly. Update grantvmgroupaccess tool to support granting permissions from above. Ignore various linter errors resurfaced after code copy-paste. Remove mksyscall_windows.go and update old unit tests and cleanup helper functions and packages used. Add unit test coverage for new functionality. Signed-off-by: Maksim An --- internal/security/grantvmgroupaccess.go | 82 +- internal/security/grantvmgroupaccess_test.go | 216 +++- internal/security/mksyscall_windows.go | 958 ------------------ internal/security/syscall_windows.go | 2 +- internal/tools/grantvmgroupaccess/main.go | 43 +- .../internal/security/grantvmgroupaccess.go | 82 +- .../internal/security/syscall_windows.go | 2 +- vendor/golang.org/x/sys/execabs/execabs.go | 102 -- vendor/modules.txt | 1 - 9 files changed, 348 insertions(+), 1140 deletions(-) delete mode 100644 internal/security/mksyscall_windows.go delete mode 100644 vendor/golang.org/x/sys/execabs/execabs.go diff --git a/internal/security/grantvmgroupaccess.go b/internal/security/grantvmgroupaccess.go index 602920786c..bfcc157699 100644 --- a/internal/security/grantvmgroupaccess.go +++ b/internal/security/grantvmgroupaccess.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package security @@ -19,25 +20,37 @@ type ( securityInformation uint32 trusteeForm uint32 trusteeType uint32 +) - explicitAccess struct { - accessPermissions accessMask - accessMode accessMode - inheritance inheritMode - trustee trustee - } +type explicitAccess struct { + //nolint:structcheck + accessPermissions accessMask + //nolint:structcheck + accessMode accessMode + //nolint:structcheck + inheritance inheritMode + //nolint:structcheck + trustee trustee +} - trustee struct { - multipleTrustee *trustee - multipleTrusteeOperation int32 - trusteeForm trusteeForm - trusteeType trusteeType - name uintptr - } -) +type trustee struct { + //nolint:unused,structcheck + multipleTrustee *trustee + //nolint:unused,structcheck + multipleTrusteeOperation int32 + trusteeForm trusteeForm + trusteeType trusteeType + name uintptr +} const ( - accessMaskDesiredPermission accessMask = 1 << 31 // GENERIC_READ + AccessMaskNone accessMask = 0 + AccessMaskRead accessMask = 1 << 31 // GENERIC_READ + AccessMaskWrite accessMask = 1 << 30 // GENERIC_WRITE + AccessMaskExecute accessMask = 1 << 29 // GENERIC_EXECUTE + AccessMaskAll accessMask = 1 << 28 // GENERIC_ALL + + accessMaskDesiredPermission = AccessMaskRead accessModeGrant accessMode = 1 @@ -56,6 +69,7 @@ const ( shareModeRead shareMode = 0x1 shareModeWrite shareMode = 0x2 + //nolint:stylecheck // ST1003 sidVmGroup = "S-1-5-83-0" trusteeFormIsSid trusteeForm = 0 @@ -63,11 +77,20 @@ const ( trusteeTypeWellKnownGroup trusteeType = 5 ) -// GrantVMGroupAccess sets the DACL for a specified file or directory to +// GrantVmGroupAccess sets the DACL for a specified file or directory to // include Grant ACE entries for the VM Group SID. This is a golang re- // implementation of the same function in vmcompute, just not exported in // RS5. Which kind of sucks. Sucks a lot :/ -func GrantVmGroupAccess(name string) error { +func GrantVmGroupAccess(name string) error { //nolint:stylecheck // ST1003 + return GrantVmGroupAccessWithMask(name, accessMaskDesiredPermission) +} + +// GrantVmGroupAccessWithMask sets the desired DACL for a specified file or +// directory. +func GrantVmGroupAccessWithMask(name string, access accessMask) error { //nolint:stylecheck // ST1003 + if access == 0 || access<<4 != 0 { + return fmt.Errorf("invalid access mask: 0x%08x", access) + } // Stat (to determine if `name` is a directory). s, err := os.Stat(name) if err != nil { @@ -79,7 +102,9 @@ func GrantVmGroupAccess(name string) error { if err != nil { return err // Already wrapped } - defer syscall.CloseHandle(fd) + defer func() { + _ = syscall.CloseHandle(fd) + }() // Get the current DACL and Security Descriptor. Must defer LocalFree on success. ot := objectTypeFileObject @@ -89,15 +114,19 @@ func GrantVmGroupAccess(name string) error { if err := getSecurityInfo(fd, uint32(ot), uint32(si), nil, nil, &origDACL, nil, &sd); err != nil { return fmt.Errorf("%s GetSecurityInfo %s: %w", gvmga, name, err) } - defer syscall.LocalFree((syscall.Handle)(unsafe.Pointer(sd))) + defer func() { + _, _ = syscall.LocalFree((syscall.Handle)(unsafe.Pointer(sd))) + }() // Generate a new DACL which is the current DACL with the required ACEs added. // Must defer LocalFree on success. - newDACL, err := generateDACLWithAcesAdded(name, s.IsDir(), origDACL) + newDACL, err := generateDACLWithAcesAdded(name, s.IsDir(), access, origDACL) if err != nil { return err // Already wrapped } - defer syscall.LocalFree((syscall.Handle)(unsafe.Pointer(newDACL))) + defer func() { + _, _ = syscall.LocalFree((syscall.Handle)(unsafe.Pointer(newDACL))) + }() // And finally use SetSecurityInfo to apply the updated DACL. if err := setSecurityInfo(fd, uint32(ot), uint32(si), uintptr(0), uintptr(0), newDACL, uintptr(0)); err != nil { @@ -110,7 +139,10 @@ func GrantVmGroupAccess(name string) error { // createFile is a helper function to call [Nt]CreateFile to get a handle to // the file or directory. func createFile(name string, isDir bool) (syscall.Handle, error) { - namep := syscall.StringToUTF16(name) + namep, err := syscall.UTF16FromString(name) + if err != nil { + return 0, fmt.Errorf("syscall.UTF16FromString %s: %w", name, err) + } da := uint32(desiredAccessReadControl | desiredAccessWriteDac) sm := uint32(shareModeRead | shareModeWrite) fa := uint32(syscall.FILE_ATTRIBUTE_NORMAL) @@ -126,7 +158,7 @@ func createFile(name string, isDir bool) (syscall.Handle, error) { // generateDACLWithAcesAdded generates a new DACL with the two needed ACEs added. // The caller is responsible for LocalFree of the returned DACL on success. -func generateDACLWithAcesAdded(name string, isDir bool, origDACL uintptr) (uintptr, error) { +func generateDACLWithAcesAdded(name string, isDir bool, desiredAccess accessMask, origDACL uintptr) (uintptr, error) { // Generate pointers to the SIDs based on the string SIDs sid, err := syscall.StringToSid(sidVmGroup) if err != nil { @@ -139,8 +171,8 @@ func generateDACLWithAcesAdded(name string, isDir bool, origDACL uintptr) (uintp } eaArray := []explicitAccess{ - explicitAccess{ - accessPermissions: accessMaskDesiredPermission, + { + accessPermissions: desiredAccess, accessMode: accessModeGrant, inheritance: inheritance, trustee: trustee{ diff --git a/internal/security/grantvmgroupaccess_test.go b/internal/security/grantvmgroupaccess_test.go index 2dffaa1284..469fdf076f 100644 --- a/internal/security/grantvmgroupaccess_test.go +++ b/internal/security/grantvmgroupaccess_test.go @@ -1,16 +1,15 @@ -//+build windows +//go:build windows +// +build windows package security import ( - "io/ioutil" + "fmt" "os" + "os/exec" "path/filepath" "regexp" - "strings" "testing" - - exec "golang.org/x/sys/execabs" ) const ( @@ -35,37 +34,38 @@ const ( // S-1-15-3-1024-2268835264-3721307629-241982045-173645152-1490879176-104643441-2915960892-1612460704:(I)(R,W) // S-1-5-83-1-3166535780-1122986932-343720105-43916321:(I)(R,W) -func TestGrantVmGroupAccess(t *testing.T) { - f, err := ioutil.TempFile("", "gvmgafile") +func TestGrantVmGroupAccessDefault(t *testing.T) { + f1Path := filepath.Join(t.TempDir(), "gvmgafile") + f, err := os.Create(f1Path) if err != nil { t.Fatal(err) } defer func() { - f.Close() - os.Remove(f.Name()) + _ = f.Close() + _ = os.Remove(f1Path) }() - d, err := ioutil.TempDir("", "gvmgadir") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(d) - - find, err := os.Create(filepath.Join(d, "find.txt")) + dir2 := t.TempDir() + f2Path := filepath.Join(dir2, "find.txt") + find, err := os.Create(f2Path) if err != nil { t.Fatal(err) } + defer func() { + _ = find.Close() + _ = os.Remove(f2Path) + }() - if err := GrantVmGroupAccess(f.Name()); err != nil { + if err := GrantVmGroupAccess(f1Path); err != nil { t.Fatal(err) } - if err := GrantVmGroupAccess(d); err != nil { + if err := GrantVmGroupAccess(dir2); err != nil { t.Fatal(err) } verifyVMAccountDACLs(t, - f.Name(), + f1Path, []string{`(R)`}, ) @@ -80,17 +80,186 @@ func TestGrantVmGroupAccess(t *testing.T) { // In properties for the directory, advanced security settings, this will // show as a single line "Allow/Virtual Machines/Read/Inherited from none/This folder, subfolder and files verifyVMAccountDACLs(t, - d, + dir2, []string{`(R)`, `(OI)(CI)(IO)(GR)`}, ) verifyVMAccountDACLs(t, - find.Name(), + f2Path, []string{`(I)(R)`}, ) } +func TestGrantVMGroupAccess_File_DesiredPermissions(t *testing.T) { + type config struct { + name string + desiredAccess accessMask + expectedPermissions []string + } + + for _, cfg := range []config{ + { + name: "Read", + desiredAccess: AccessMaskRead, + expectedPermissions: []string{`(R)`}, + }, + { + name: "Write", + desiredAccess: AccessMaskWrite, + expectedPermissions: []string{`(W,Rc)`}, + }, + { + name: "Execute", + desiredAccess: AccessMaskExecute, + expectedPermissions: []string{`(Rc,S,X,RA)`}, + }, + { + name: "ReadWrite", + desiredAccess: AccessMaskRead | AccessMaskWrite, + expectedPermissions: []string{`(R,W)`}, + }, + { + name: "ReadExecute", + desiredAccess: AccessMaskRead | AccessMaskExecute, + expectedPermissions: []string{`(RX)`}, + }, + { + name: "WriteExecute", + desiredAccess: AccessMaskWrite | AccessMaskExecute, + expectedPermissions: []string{`(W,Rc,X,RA)`}, + }, + { + name: "ReadWriteExecute", + desiredAccess: AccessMaskRead | AccessMaskWrite | AccessMaskExecute, + expectedPermissions: []string{`(RX,W)`}, + }, + { + name: "All", + desiredAccess: AccessMaskAll, + expectedPermissions: []string{`(F)`}, + }, + } { + t.Run(cfg.name, func(t *testing.T) { + dir := t.TempDir() + fd, err := os.Create(filepath.Join(dir, "test.txt")) + if err != nil { + t.Fatalf("failed to create temporary file: %s", err) + } + defer func() { + _ = fd.Close() + _ = os.Remove(fd.Name()) + }() + + if err := GrantVmGroupAccessWithMask(fd.Name(), cfg.desiredAccess); err != nil { + t.Fatal(err) + } + verifyVMAccountDACLs(t, fd.Name(), cfg.expectedPermissions) + }) + } +} + +func TestGrantVMGroupAccess_Directory_Permissions(t *testing.T) { + type config struct { + name string + access accessMask + filePermissions []string + dirPermissions []string + } + + for _, cfg := range []config{ + { + name: "Read", + access: AccessMaskRead, + filePermissions: []string{`(I)(R)`}, + dirPermissions: []string{`(R)`, `(OI)(CI)(IO)(GR)`}, + }, + { + name: "Write", + access: AccessMaskWrite, + filePermissions: []string{`(I)(W,Rc)`}, + dirPermissions: []string{`(W,Rc)`, `(OI)(CI)(IO)(GW)`}, + }, + { + name: "Execute", + access: AccessMaskExecute, + filePermissions: []string{`(I)(Rc,S,X,RA)`}, + dirPermissions: []string{`(Rc,S,X,RA)`, `(OI)(CI)(IO)(GE)`}, + }, + { + name: "ReadWrite", + access: AccessMaskRead | AccessMaskWrite, + filePermissions: []string{`(I)(R,W)`}, + dirPermissions: []string{`(R,W)`, `(OI)(CI)(IO)(GR,GW)`}, + }, + { + name: "ReadExecute", + access: AccessMaskRead | AccessMaskExecute, + filePermissions: []string{`(I)(RX)`}, + dirPermissions: []string{`(RX)`, `(OI)(CI)(IO)(GR,GE)`}, + }, + { + name: "WriteExecute", + access: AccessMaskWrite | AccessMaskExecute, + filePermissions: []string{`(I)(W,Rc,X,RA)`}, + dirPermissions: []string{`(W,Rc,X,RA)`, `(OI)(CI)(IO)(GW,GE)`}, + }, + { + name: "ReadWriteExecute", + access: AccessMaskRead | AccessMaskWrite | AccessMaskExecute, + filePermissions: []string{`(I)(RX,W)`}, + dirPermissions: []string{`(RX,W)`, `(OI)(CI)(IO)(GR,GW,GE)`}, + }, + { + name: "All", + access: AccessMaskAll, + filePermissions: []string{`(I)(F)`}, + dirPermissions: []string{`(F)`, `(OI)(CI)(IO)(F)`}, + }} { + t.Run(cfg.name, func(t *testing.T) { + dir := t.TempDir() + fd, err := os.Create(filepath.Join(dir, "test.txt")) + if err != nil { + t.Fatalf("failed to create temporary file: %s", err) + } + defer func() { + _ = fd.Close() + _ = os.Remove(fd.Name()) + }() + + if err := GrantVmGroupAccessWithMask(dir, cfg.access); err != nil { + t.Fatal(err) + } + verifyVMAccountDACLs(t, dir, cfg.dirPermissions) + verifyVMAccountDACLs(t, fd.Name(), cfg.filePermissions) + }) + } +} + +func TestGrantVmGroupAccess_Invalid_AccessMask(t *testing.T) { + for _, access := range []accessMask{ + 0, // no bits set + 1, // invalid bit set + 0x02000001, // invalid extra bit set + } { + t.Run(fmt.Sprintf("AccessMask_0x%x", access), func(t *testing.T) { + dir := t.TempDir() + fd, err := os.Create(filepath.Join(dir, "test.txt")) + if err != nil { + t.Fatalf("failed to create temporary file: %s", err) + } + defer func() { + _ = fd.Close() + _ = os.Remove(fd.Name()) + }() + + if err := GrantVmGroupAccessWithMask(fd.Name(), access); err == nil { + t.Fatalf("expected an error for mask: %x", access) + } + }) + } +} + func verifyVMAccountDACLs(t *testing.T, name string, permissions []string) { cmd := exec.Command("icacls", name) outb, err := cmd.CombinedOutput() @@ -101,8 +270,7 @@ func verifyVMAccountDACLs(t *testing.T, name string, permissions []string) { for _, p := range permissions { // Avoid '(' and ')' being part of match groups - p = strings.Replace(p, "(", "\\(", -1) - p = strings.Replace(p, ")", "\\)", -1) + p = regexp.QuoteMeta(p) nameToCheck := vmAccountName + ":" + p sidToCheck := vmAccountSID + ":" + p @@ -114,7 +282,7 @@ func verifyVMAccountDACLs(t *testing.T, name string, permissions []string) { matchesSID := rxSID.FindAllStringIndex(out, -1) if len(matchesName) != 1 && len(matchesSID) != 1 { - t.Fatalf("expected one match for %s or %s\n%s", nameToCheck, sidToCheck, out) + t.Fatalf("expected one match for %s or %s\n%s\n", nameToCheck, sidToCheck, out) } } } diff --git a/internal/security/mksyscall_windows.go b/internal/security/mksyscall_windows.go deleted file mode 100644 index d52b6137c7..0000000000 --- a/internal/security/mksyscall_windows.go +++ /dev/null @@ -1,958 +0,0 @@ -// Copyright 2013 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Hard-coding unicode mode for VHD library. - -// +build ignore - -/* -mksyscall_windows generates windows system call bodies - -It parses all files specified on command line containing function -prototypes (like syscall_windows.go) and prints system call bodies -to standard output. - -The prototypes are marked by lines beginning with "//sys" and read -like func declarations if //sys is replaced by func, but: - -* The parameter lists must give a name for each argument. This - includes return parameters. - -* The parameter lists must give a type for each argument: - the (x, y, z int) shorthand is not allowed. - -* If the return parameter is an error number, it must be named err. - -* If go func name needs to be different from its winapi dll name, - the winapi name could be specified at the end, after "=" sign, like - //sys LoadLibrary(libname string) (handle uint32, err error) = LoadLibraryA - -* Each function that returns err needs to supply a condition, that - return value of winapi will be tested against to detect failure. - This would set err to windows "last-error", otherwise it will be nil. - The value can be provided at end of //sys declaration, like - //sys LoadLibrary(libname string) (handle uint32, err error) [failretval==-1] = LoadLibraryA - and is [failretval==0] by default. - -* If the function name ends in a "?", then the function not existing is non- - fatal, and an error will be returned instead of panicking. - -Usage: - mksyscall_windows [flags] [path ...] - -The flags are: - -output - Specify output file name (outputs to console if blank). - -trace - Generate print statement after every syscall. -*/ -package main - -import ( - "bufio" - "bytes" - "errors" - "flag" - "fmt" - "go/format" - "go/parser" - "go/token" - "io" - "io/ioutil" - "log" - "os" - "path/filepath" - "runtime" - "sort" - "strconv" - "strings" - "text/template" -) - -var ( - filename = flag.String("output", "", "output file name (standard output if omitted)") - printTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall") - systemDLL = flag.Bool("systemdll", true, "whether all DLLs should be loaded from the Windows system directory") -) - -func trim(s string) string { - return strings.Trim(s, " \t") -} - -var packageName string - -func packagename() string { - return packageName -} - -func syscalldot() string { - if packageName == "syscall" { - return "" - } - return "syscall." -} - -// Param is function parameter -type Param struct { - Name string - Type string - fn *Fn - tmpVarIdx int -} - -// tmpVar returns temp variable name that will be used to represent p during syscall. -func (p *Param) tmpVar() string { - if p.tmpVarIdx < 0 { - p.tmpVarIdx = p.fn.curTmpVarIdx - p.fn.curTmpVarIdx++ - } - return fmt.Sprintf("_p%d", p.tmpVarIdx) -} - -// BoolTmpVarCode returns source code for bool temp variable. -func (p *Param) BoolTmpVarCode() string { - const code = `var %[1]s uint32 - if %[2]s { - %[1]s = 1 - }` - return fmt.Sprintf(code, p.tmpVar(), p.Name) -} - -// BoolPointerTmpVarCode returns source code for bool temp variable. -func (p *Param) BoolPointerTmpVarCode() string { - const code = `var %[1]s uint32 - if *%[2]s { - %[1]s = 1 - }` - return fmt.Sprintf(code, p.tmpVar(), p.Name) -} - -// SliceTmpVarCode returns source code for slice temp variable. -func (p *Param) SliceTmpVarCode() string { - const code = `var %s *%s - if len(%s) > 0 { - %s = &%s[0] - }` - tmp := p.tmpVar() - return fmt.Sprintf(code, tmp, p.Type[2:], p.Name, tmp, p.Name) -} - -// StringTmpVarCode returns source code for string temp variable. -func (p *Param) StringTmpVarCode() string { - errvar := p.fn.Rets.ErrorVarName() - if errvar == "" { - errvar = "_" - } - tmp := p.tmpVar() - const code = `var %s %s - %s, %s = %s(%s)` - s := fmt.Sprintf(code, tmp, p.fn.StrconvType(), tmp, errvar, p.fn.StrconvFunc(), p.Name) - if errvar == "-" { - return s - } - const morecode = ` - if %s != nil { - return - }` - return s + fmt.Sprintf(morecode, errvar) -} - -// TmpVarCode returns source code for temp variable. -func (p *Param) TmpVarCode() string { - switch { - case p.Type == "bool": - return p.BoolTmpVarCode() - case p.Type == "*bool": - return p.BoolPointerTmpVarCode() - case strings.HasPrefix(p.Type, "[]"): - return p.SliceTmpVarCode() - default: - return "" - } -} - -// TmpVarReadbackCode returns source code for reading back the temp variable into the original variable. -func (p *Param) TmpVarReadbackCode() string { - switch { - case p.Type == "*bool": - return fmt.Sprintf("*%s = %s != 0", p.Name, p.tmpVar()) - default: - return "" - } -} - -// TmpVarHelperCode returns source code for helper's temp variable. -func (p *Param) TmpVarHelperCode() string { - if p.Type != "string" { - return "" - } - return p.StringTmpVarCode() -} - -// SyscallArgList returns source code fragments representing p parameter -// in syscall. Slices are translated into 2 syscall parameters: pointer to -// the first element and length. -func (p *Param) SyscallArgList() []string { - t := p.HelperType() - var s string - switch { - case t == "*bool": - s = fmt.Sprintf("unsafe.Pointer(&%s)", p.tmpVar()) - case t[0] == '*': - s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name) - case t == "bool": - s = p.tmpVar() - case strings.HasPrefix(t, "[]"): - return []string{ - fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()), - fmt.Sprintf("uintptr(len(%s))", p.Name), - } - default: - s = p.Name - } - return []string{fmt.Sprintf("uintptr(%s)", s)} -} - -// IsError determines if p parameter is used to return error. -func (p *Param) IsError() bool { - return p.Name == "err" && p.Type == "error" -} - -// HelperType returns type of parameter p used in helper function. -func (p *Param) HelperType() string { - if p.Type == "string" { - return p.fn.StrconvType() - } - return p.Type -} - -// join concatenates parameters ps into a string with sep separator. -// Each parameter is converted into string by applying fn to it -// before conversion. -func join(ps []*Param, fn func(*Param) string, sep string) string { - if len(ps) == 0 { - return "" - } - a := make([]string, 0) - for _, p := range ps { - a = append(a, fn(p)) - } - return strings.Join(a, sep) -} - -// Rets describes function return parameters. -type Rets struct { - Name string - Type string - ReturnsError bool - FailCond string - fnMaybeAbsent bool -} - -// ErrorVarName returns error variable name for r. -func (r *Rets) ErrorVarName() string { - if r.ReturnsError { - return "err" - } - if r.Type == "error" { - return r.Name - } - return "" -} - -// ToParams converts r into slice of *Param. -func (r *Rets) ToParams() []*Param { - ps := make([]*Param, 0) - if len(r.Name) > 0 { - ps = append(ps, &Param{Name: r.Name, Type: r.Type}) - } - if r.ReturnsError { - ps = append(ps, &Param{Name: "err", Type: "error"}) - } - return ps -} - -// List returns source code of syscall return parameters. -func (r *Rets) List() string { - s := join(r.ToParams(), func(p *Param) string { return p.Name + " " + p.Type }, ", ") - if len(s) > 0 { - s = "(" + s + ")" - } else if r.fnMaybeAbsent { - s = "(err error)" - } - return s -} - -// PrintList returns source code of trace printing part correspondent -// to syscall return values. -func (r *Rets) PrintList() string { - return join(r.ToParams(), func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `) -} - -// SetReturnValuesCode returns source code that accepts syscall return values. -func (r *Rets) SetReturnValuesCode() string { - if r.Name == "" && !r.ReturnsError { - return "" - } - retvar := "r0" - if r.Name == "" { - retvar = "r1" - } - errvar := "_" - if r.ReturnsError { - errvar = "e1" - } - return fmt.Sprintf("%s, _, %s := ", retvar, errvar) -} - -func (r *Rets) useLongHandleErrorCode(retvar string) string { - const code = `if %s { - err = errnoErr(e1) - }` - cond := retvar + " == 0" - if r.FailCond != "" { - cond = strings.Replace(r.FailCond, "failretval", retvar, 1) - } - return fmt.Sprintf(code, cond) -} - -// SetErrorCode returns source code that sets return parameters. -func (r *Rets) SetErrorCode() string { - const code = `if r0 != 0 { - %s = %sErrno(r0) - }` - if r.Name == "" && !r.ReturnsError { - return "" - } - if r.Name == "" { - return r.useLongHandleErrorCode("r1") - } - if r.Type == "error" { - return fmt.Sprintf(code, r.Name, syscalldot()) - } - s := "" - switch { - case r.Type[0] == '*': - s = fmt.Sprintf("%s = (%s)(unsafe.Pointer(r0))", r.Name, r.Type) - case r.Type == "bool": - s = fmt.Sprintf("%s = r0 != 0", r.Name) - default: - s = fmt.Sprintf("%s = %s(r0)", r.Name, r.Type) - } - if !r.ReturnsError { - return s - } - return s + "\n\t" + r.useLongHandleErrorCode(r.Name) -} - -// Fn describes syscall function. -type Fn struct { - Name string - Params []*Param - Rets *Rets - PrintTrace bool - dllname string - dllfuncname string - src string - // TODO: get rid of this field and just use parameter index instead - curTmpVarIdx int // insure tmp variables have uniq names -} - -// extractParams parses s to extract function parameters. -func extractParams(s string, f *Fn) ([]*Param, error) { - s = trim(s) - if s == "" { - return nil, nil - } - a := strings.Split(s, ",") - ps := make([]*Param, len(a)) - for i := range ps { - s2 := trim(a[i]) - b := strings.Split(s2, " ") - if len(b) != 2 { - b = strings.Split(s2, "\t") - if len(b) != 2 { - return nil, errors.New("Could not extract function parameter from \"" + s2 + "\"") - } - } - ps[i] = &Param{ - Name: trim(b[0]), - Type: trim(b[1]), - fn: f, - tmpVarIdx: -1, - } - } - return ps, nil -} - -// extractSection extracts text out of string s starting after start -// and ending just before end. found return value will indicate success, -// and prefix, body and suffix will contain correspondent parts of string s. -func extractSection(s string, start, end rune) (prefix, body, suffix string, found bool) { - s = trim(s) - if strings.HasPrefix(s, string(start)) { - // no prefix - body = s[1:] - } else { - a := strings.SplitN(s, string(start), 2) - if len(a) != 2 { - return "", "", s, false - } - prefix = a[0] - body = a[1] - } - a := strings.SplitN(body, string(end), 2) - if len(a) != 2 { - return "", "", "", false - } - return prefix, a[0], a[1], true -} - -// newFn parses string s and return created function Fn. -func newFn(s string) (*Fn, error) { - s = trim(s) - f := &Fn{ - Rets: &Rets{}, - src: s, - PrintTrace: *printTraceFlag, - } - // function name and args - prefix, body, s, found := extractSection(s, '(', ')') - if !found || prefix == "" { - return nil, errors.New("Could not extract function name and parameters from \"" + f.src + "\"") - } - f.Name = prefix - var err error - f.Params, err = extractParams(body, f) - if err != nil { - return nil, err - } - // return values - _, body, s, found = extractSection(s, '(', ')') - if found { - r, err := extractParams(body, f) - if err != nil { - return nil, err - } - switch len(r) { - case 0: - case 1: - if r[0].IsError() { - f.Rets.ReturnsError = true - } else { - f.Rets.Name = r[0].Name - f.Rets.Type = r[0].Type - } - case 2: - if !r[1].IsError() { - return nil, errors.New("Only last windows error is allowed as second return value in \"" + f.src + "\"") - } - f.Rets.ReturnsError = true - f.Rets.Name = r[0].Name - f.Rets.Type = r[0].Type - default: - return nil, errors.New("Too many return values in \"" + f.src + "\"") - } - } - // fail condition - _, body, s, found = extractSection(s, '[', ']') - if found { - f.Rets.FailCond = body - } - // dll and dll function names - s = trim(s) - if s == "" { - return f, nil - } - if !strings.HasPrefix(s, "=") { - return nil, errors.New("Could not extract dll name from \"" + f.src + "\"") - } - s = trim(s[1:]) - a := strings.Split(s, ".") - switch len(a) { - case 1: - f.dllfuncname = a[0] - case 2: - f.dllname = a[0] - f.dllfuncname = a[1] - default: - return nil, errors.New("Could not extract dll name from \"" + f.src + "\"") - } - if n := f.dllfuncname; strings.HasSuffix(n, "?") { - f.dllfuncname = n[:len(n)-1] - f.Rets.fnMaybeAbsent = true - } - return f, nil -} - -// DLLName returns DLL name for function f. -func (f *Fn) DLLName() string { - if f.dllname == "" { - return "kernel32" - } - return f.dllname -} - -// DLLName returns DLL function name for function f. -func (f *Fn) DLLFuncName() string { - if f.dllfuncname == "" { - return f.Name - } - return f.dllfuncname -} - -// ParamList returns source code for function f parameters. -func (f *Fn) ParamList() string { - return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ") -} - -// HelperParamList returns source code for helper function f parameters. -func (f *Fn) HelperParamList() string { - return join(f.Params, func(p *Param) string { return p.Name + " " + p.HelperType() }, ", ") -} - -// ParamPrintList returns source code of trace printing part correspondent -// to syscall input parameters. -func (f *Fn) ParamPrintList() string { - return join(f.Params, func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `) -} - -// ParamCount return number of syscall parameters for function f. -func (f *Fn) ParamCount() int { - n := 0 - for _, p := range f.Params { - n += len(p.SyscallArgList()) - } - return n -} - -// SyscallParamCount determines which version of Syscall/Syscall6/Syscall9/... -// to use. It returns parameter count for correspondent SyscallX function. -func (f *Fn) SyscallParamCount() int { - n := f.ParamCount() - switch { - case n <= 3: - return 3 - case n <= 6: - return 6 - case n <= 9: - return 9 - case n <= 12: - return 12 - case n <= 15: - return 15 - default: - panic("too many arguments to system call") - } -} - -// Syscall determines which SyscallX function to use for function f. -func (f *Fn) Syscall() string { - c := f.SyscallParamCount() - if c == 3 { - return syscalldot() + "Syscall" - } - return syscalldot() + "Syscall" + strconv.Itoa(c) -} - -// SyscallParamList returns source code for SyscallX parameters for function f. -func (f *Fn) SyscallParamList() string { - a := make([]string, 0) - for _, p := range f.Params { - a = append(a, p.SyscallArgList()...) - } - for len(a) < f.SyscallParamCount() { - a = append(a, "0") - } - return strings.Join(a, ", ") -} - -// HelperCallParamList returns source code of call into function f helper. -func (f *Fn) HelperCallParamList() string { - a := make([]string, 0, len(f.Params)) - for _, p := range f.Params { - s := p.Name - if p.Type == "string" { - s = p.tmpVar() - } - a = append(a, s) - } - return strings.Join(a, ", ") -} - -// MaybeAbsent returns source code for handling functions that are possibly unavailable. -func (p *Fn) MaybeAbsent() string { - if !p.Rets.fnMaybeAbsent { - return "" - } - const code = `%[1]s = proc%[2]s.Find() - if %[1]s != nil { - return - }` - errorVar := p.Rets.ErrorVarName() - if errorVar == "" { - errorVar = "err" - } - return fmt.Sprintf(code, errorVar, p.DLLFuncName()) -} - -// IsUTF16 is true, if f is W (utf16) function. It is false -// for all A (ascii) functions. -func (_ *Fn) IsUTF16() bool { - return true -} - -// StrconvFunc returns name of Go string to OS string function for f. -func (f *Fn) StrconvFunc() string { - if f.IsUTF16() { - return syscalldot() + "UTF16PtrFromString" - } - return syscalldot() + "BytePtrFromString" -} - -// StrconvType returns Go type name used for OS string for f. -func (f *Fn) StrconvType() string { - if f.IsUTF16() { - return "*uint16" - } - return "*byte" -} - -// HasStringParam is true, if f has at least one string parameter. -// Otherwise it is false. -func (f *Fn) HasStringParam() bool { - for _, p := range f.Params { - if p.Type == "string" { - return true - } - } - return false -} - -// HelperName returns name of function f helper. -func (f *Fn) HelperName() string { - if !f.HasStringParam() { - return f.Name - } - return "_" + f.Name -} - -// Source files and functions. -type Source struct { - Funcs []*Fn - Files []string - StdLibImports []string - ExternalImports []string -} - -func (src *Source) Import(pkg string) { - src.StdLibImports = append(src.StdLibImports, pkg) - sort.Strings(src.StdLibImports) -} - -func (src *Source) ExternalImport(pkg string) { - src.ExternalImports = append(src.ExternalImports, pkg) - sort.Strings(src.ExternalImports) -} - -// ParseFiles parses files listed in fs and extracts all syscall -// functions listed in sys comments. It returns source files -// and functions collection *Source if successful. -func ParseFiles(fs []string) (*Source, error) { - src := &Source{ - Funcs: make([]*Fn, 0), - Files: make([]string, 0), - StdLibImports: []string{ - "unsafe", - }, - ExternalImports: make([]string, 0), - } - for _, file := range fs { - if err := src.ParseFile(file); err != nil { - return nil, err - } - } - return src, nil -} - -// DLLs return dll names for a source set src. -func (src *Source) DLLs() []string { - uniq := make(map[string]bool) - r := make([]string, 0) - for _, f := range src.Funcs { - name := f.DLLName() - if _, found := uniq[name]; !found { - uniq[name] = true - r = append(r, name) - } - } - sort.Strings(r) - return r -} - -// ParseFile adds additional file path to a source set src. -func (src *Source) ParseFile(path string) error { - file, err := os.Open(path) - if err != nil { - return err - } - defer file.Close() - - s := bufio.NewScanner(file) - for s.Scan() { - t := trim(s.Text()) - if len(t) < 7 { - continue - } - if !strings.HasPrefix(t, "//sys") { - continue - } - t = t[5:] - if !(t[0] == ' ' || t[0] == '\t') { - continue - } - f, err := newFn(t[1:]) - if err != nil { - return err - } - src.Funcs = append(src.Funcs, f) - } - if err := s.Err(); err != nil { - return err - } - src.Files = append(src.Files, path) - sort.Slice(src.Funcs, func(i, j int) bool { - fi, fj := src.Funcs[i], src.Funcs[j] - if fi.DLLName() == fj.DLLName() { - return fi.DLLFuncName() < fj.DLLFuncName() - } - return fi.DLLName() < fj.DLLName() - }) - - // get package name - fset := token.NewFileSet() - _, err = file.Seek(0, 0) - if err != nil { - return err - } - pkg, err := parser.ParseFile(fset, "", file, parser.PackageClauseOnly) - if err != nil { - return err - } - packageName = pkg.Name.Name - - return nil -} - -// IsStdRepo reports whether src is part of standard library. -func (src *Source) IsStdRepo() (bool, error) { - if len(src.Files) == 0 { - return false, errors.New("no input files provided") - } - abspath, err := filepath.Abs(src.Files[0]) - if err != nil { - return false, err - } - goroot := runtime.GOROOT() - if runtime.GOOS == "windows" { - abspath = strings.ToLower(abspath) - goroot = strings.ToLower(goroot) - } - sep := string(os.PathSeparator) - if !strings.HasSuffix(goroot, sep) { - goroot += sep - } - return strings.HasPrefix(abspath, goroot), nil -} - -// Generate output source file from a source set src. -func (src *Source) Generate(w io.Writer) error { - const ( - pkgStd = iota // any package in std library - pkgXSysWindows // x/sys/windows package - pkgOther - ) - isStdRepo, err := src.IsStdRepo() - if err != nil { - return err - } - var pkgtype int - switch { - case isStdRepo: - pkgtype = pkgStd - case packageName == "windows": - // TODO: this needs better logic than just using package name - pkgtype = pkgXSysWindows - default: - pkgtype = pkgOther - } - if *systemDLL { - switch pkgtype { - case pkgStd: - src.Import("internal/syscall/windows/sysdll") - case pkgXSysWindows: - default: - src.ExternalImport("golang.org/x/sys/windows") - } - } - if packageName != "syscall" { - src.Import("syscall") - } - funcMap := template.FuncMap{ - "packagename": packagename, - "syscalldot": syscalldot, - "newlazydll": func(dll string) string { - arg := "\"" + dll + ".dll\"" - if !*systemDLL { - return syscalldot() + "NewLazyDLL(" + arg + ")" - } - switch pkgtype { - case pkgStd: - return syscalldot() + "NewLazyDLL(sysdll.Add(" + arg + "))" - case pkgXSysWindows: - return "NewLazySystemDLL(" + arg + ")" - default: - return "windows.NewLazySystemDLL(" + arg + ")" - } - }, - } - t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate)) - err = t.Execute(w, src) - if err != nil { - return errors.New("Failed to execute template: " + err.Error()) - } - return nil -} - -func usage() { - fmt.Fprintf(os.Stderr, "usage: mksyscall_windows [flags] [path ...]\n") - flag.PrintDefaults() - os.Exit(1) -} - -func main() { - flag.Usage = usage - flag.Parse() - if len(flag.Args()) <= 0 { - fmt.Fprintf(os.Stderr, "no files to parse provided\n") - usage() - } - - src, err := ParseFiles(flag.Args()) - if err != nil { - log.Fatal(err) - } - - var buf bytes.Buffer - if err := src.Generate(&buf); err != nil { - log.Fatal(err) - } - - data, err := format.Source(buf.Bytes()) - if err != nil { - log.Fatal(err) - } - if *filename == "" { - _, err = os.Stdout.Write(data) - } else { - err = ioutil.WriteFile(*filename, data, 0644) - } - if err != nil { - log.Fatal(err) - } -} - -// TODO: use println instead to print in the following template -const srcTemplate = ` - -{{define "main"}}// Code generated by 'go generate'; DO NOT EDIT. - -package {{packagename}} - -import ( -{{range .StdLibImports}}"{{.}}" -{{end}} - -{{range .ExternalImports}}"{{.}}" -{{end}} -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = {{syscalldot}}Errno(errnoERROR_IO_PENDING) - errERROR_EINVAL error = {{syscalldot}}EINVAL -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e {{syscalldot}}Errno) error { - switch e { - case 0: - return errERROR_EINVAL - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( -{{template "dlls" .}} -{{template "funcnames" .}}) -{{range .Funcs}}{{if .HasStringParam}}{{template "helperbody" .}}{{end}}{{template "funcbody" .}}{{end}} -{{end}} - -{{/* help functions */}} - -{{define "dlls"}}{{range .DLLs}} mod{{.}} = {{newlazydll .}} -{{end}}{{end}} - -{{define "funcnames"}}{{range .Funcs}} proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}") -{{end}}{{end}} - -{{define "helperbody"}} -func {{.Name}}({{.ParamList}}) {{template "results" .}}{ -{{template "helpertmpvars" .}} return {{.HelperName}}({{.HelperCallParamList}}) -} -{{end}} - -{{define "funcbody"}} -func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{ -{{template "maybeabsent" .}} {{template "tmpvars" .}} {{template "syscall" .}} {{template "tmpvarsreadback" .}} -{{template "seterror" .}}{{template "printtrace" .}} return -} -{{end}} - -{{define "helpertmpvars"}}{{range .Params}}{{if .TmpVarHelperCode}} {{.TmpVarHelperCode}} -{{end}}{{end}}{{end}} - -{{define "maybeabsent"}}{{if .MaybeAbsent}}{{.MaybeAbsent}} -{{end}}{{end}} - -{{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}} {{.TmpVarCode}} -{{end}}{{end}}{{end}} - -{{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}} - -{{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}} - -{{define "tmpvarsreadback"}}{{range .Params}}{{if .TmpVarReadbackCode}} -{{.TmpVarReadbackCode}}{{end}}{{end}}{{end}} - -{{define "seterror"}}{{if .Rets.SetErrorCode}} {{.Rets.SetErrorCode}} -{{end}}{{end}} - -{{define "printtrace"}}{{if .PrintTrace}} print("SYSCALL: {{.Name}}(", {{.ParamPrintList}}") (", {{.Rets.PrintList}}")\n") -{{end}}{{end}} - -` diff --git a/internal/security/syscall_windows.go b/internal/security/syscall_windows.go index d7096716ce..f0cdd7d20c 100644 --- a/internal/security/syscall_windows.go +++ b/internal/security/syscall_windows.go @@ -1,6 +1,6 @@ package security -//go:generate go run mksyscall_windows.go -output zsyscall_windows.go syscall_windows.go +//go:generate go run $GOPATH/src/golang.org/x/sys/windows/mkwinsyscall/mkwinsyscall.go -output zsyscall_windows.go syscall_windows.go //sys getSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, ppsidOwner **uintptr, ppsidGroup **uintptr, ppDacl *uintptr, ppSacl *uintptr, ppSecurityDescriptor *uintptr) (win32err error) = advapi32.GetSecurityInfo //sys setSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, psidOwner uintptr, psidGroup uintptr, pDacl uintptr, pSacl uintptr) (win32err error) = advapi32.SetSecurityInfo diff --git a/internal/tools/grantvmgroupaccess/main.go b/internal/tools/grantvmgroupaccess/main.go index 887db694d0..c18a8f4134 100644 --- a/internal/tools/grantvmgroupaccess/main.go +++ b/internal/tools/grantvmgroupaccess/main.go @@ -3,6 +3,7 @@ package main import ( + "flag" "fmt" "os" @@ -10,11 +11,47 @@ import ( ) func main() { - if len(os.Args) != 2 { - fmt.Fprintln(os.Stderr, "Usage: grantvmgroupaccess.exe file") + readFlag := flag.Bool("read", false, "Grant GENERIC_READ permission (DEFAULT)") + writeFlag := flag.Bool("write", false, "Grant GENERIC_WRITE permission") + executeFlag := flag.Bool("execute", false, "Grant GENERIC_EXECUTE permission") + allFlag := flag.Bool("all", false, "Grant GENERIC_ALL permission") + + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "\nUsage of %s:\n\n", os.Args[0]) + fmt.Fprintf(os.Stderr, " %s [flags] file\n\n", os.Args[0]) + fmt.Fprintf(os.Stderr, "Flags:\n") + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, "Examples:\n") + fmt.Fprintf(os.Stderr, " %s --read myfile.txt\n", os.Args[0]) + fmt.Fprintf(os.Stderr, " %s --read --execute myfile.txt\n", os.Args[0]) + } + + flag.Parse() + + if flag.NArg() == 0 { + flag.Usage() os.Exit(-1) } - if err := security.GrantVmGroupAccess(os.Args[1]); err != nil { + + desiredAccess := security.AccessMaskNone + if flag.NFlag() == 0 { + desiredAccess = security.AccessMaskRead + } + + if *readFlag { + desiredAccess |= security.AccessMaskRead + } + if *writeFlag { + desiredAccess |= security.AccessMaskWrite + } + if *executeFlag { + desiredAccess |= security.AccessMaskExecute + } + if *allFlag { + desiredAccess |= security.AccessMaskAll + } + + if err := security.GrantVmGroupAccessWithMask(flag.Arg(0), desiredAccess); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(-1) } diff --git a/test/vendor/github.com/Microsoft/hcsshim/internal/security/grantvmgroupaccess.go b/test/vendor/github.com/Microsoft/hcsshim/internal/security/grantvmgroupaccess.go index 602920786c..bfcc157699 100644 --- a/test/vendor/github.com/Microsoft/hcsshim/internal/security/grantvmgroupaccess.go +++ b/test/vendor/github.com/Microsoft/hcsshim/internal/security/grantvmgroupaccess.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package security @@ -19,25 +20,37 @@ type ( securityInformation uint32 trusteeForm uint32 trusteeType uint32 +) - explicitAccess struct { - accessPermissions accessMask - accessMode accessMode - inheritance inheritMode - trustee trustee - } +type explicitAccess struct { + //nolint:structcheck + accessPermissions accessMask + //nolint:structcheck + accessMode accessMode + //nolint:structcheck + inheritance inheritMode + //nolint:structcheck + trustee trustee +} - trustee struct { - multipleTrustee *trustee - multipleTrusteeOperation int32 - trusteeForm trusteeForm - trusteeType trusteeType - name uintptr - } -) +type trustee struct { + //nolint:unused,structcheck + multipleTrustee *trustee + //nolint:unused,structcheck + multipleTrusteeOperation int32 + trusteeForm trusteeForm + trusteeType trusteeType + name uintptr +} const ( - accessMaskDesiredPermission accessMask = 1 << 31 // GENERIC_READ + AccessMaskNone accessMask = 0 + AccessMaskRead accessMask = 1 << 31 // GENERIC_READ + AccessMaskWrite accessMask = 1 << 30 // GENERIC_WRITE + AccessMaskExecute accessMask = 1 << 29 // GENERIC_EXECUTE + AccessMaskAll accessMask = 1 << 28 // GENERIC_ALL + + accessMaskDesiredPermission = AccessMaskRead accessModeGrant accessMode = 1 @@ -56,6 +69,7 @@ const ( shareModeRead shareMode = 0x1 shareModeWrite shareMode = 0x2 + //nolint:stylecheck // ST1003 sidVmGroup = "S-1-5-83-0" trusteeFormIsSid trusteeForm = 0 @@ -63,11 +77,20 @@ const ( trusteeTypeWellKnownGroup trusteeType = 5 ) -// GrantVMGroupAccess sets the DACL for a specified file or directory to +// GrantVmGroupAccess sets the DACL for a specified file or directory to // include Grant ACE entries for the VM Group SID. This is a golang re- // implementation of the same function in vmcompute, just not exported in // RS5. Which kind of sucks. Sucks a lot :/ -func GrantVmGroupAccess(name string) error { +func GrantVmGroupAccess(name string) error { //nolint:stylecheck // ST1003 + return GrantVmGroupAccessWithMask(name, accessMaskDesiredPermission) +} + +// GrantVmGroupAccessWithMask sets the desired DACL for a specified file or +// directory. +func GrantVmGroupAccessWithMask(name string, access accessMask) error { //nolint:stylecheck // ST1003 + if access == 0 || access<<4 != 0 { + return fmt.Errorf("invalid access mask: 0x%08x", access) + } // Stat (to determine if `name` is a directory). s, err := os.Stat(name) if err != nil { @@ -79,7 +102,9 @@ func GrantVmGroupAccess(name string) error { if err != nil { return err // Already wrapped } - defer syscall.CloseHandle(fd) + defer func() { + _ = syscall.CloseHandle(fd) + }() // Get the current DACL and Security Descriptor. Must defer LocalFree on success. ot := objectTypeFileObject @@ -89,15 +114,19 @@ func GrantVmGroupAccess(name string) error { if err := getSecurityInfo(fd, uint32(ot), uint32(si), nil, nil, &origDACL, nil, &sd); err != nil { return fmt.Errorf("%s GetSecurityInfo %s: %w", gvmga, name, err) } - defer syscall.LocalFree((syscall.Handle)(unsafe.Pointer(sd))) + defer func() { + _, _ = syscall.LocalFree((syscall.Handle)(unsafe.Pointer(sd))) + }() // Generate a new DACL which is the current DACL with the required ACEs added. // Must defer LocalFree on success. - newDACL, err := generateDACLWithAcesAdded(name, s.IsDir(), origDACL) + newDACL, err := generateDACLWithAcesAdded(name, s.IsDir(), access, origDACL) if err != nil { return err // Already wrapped } - defer syscall.LocalFree((syscall.Handle)(unsafe.Pointer(newDACL))) + defer func() { + _, _ = syscall.LocalFree((syscall.Handle)(unsafe.Pointer(newDACL))) + }() // And finally use SetSecurityInfo to apply the updated DACL. if err := setSecurityInfo(fd, uint32(ot), uint32(si), uintptr(0), uintptr(0), newDACL, uintptr(0)); err != nil { @@ -110,7 +139,10 @@ func GrantVmGroupAccess(name string) error { // createFile is a helper function to call [Nt]CreateFile to get a handle to // the file or directory. func createFile(name string, isDir bool) (syscall.Handle, error) { - namep := syscall.StringToUTF16(name) + namep, err := syscall.UTF16FromString(name) + if err != nil { + return 0, fmt.Errorf("syscall.UTF16FromString %s: %w", name, err) + } da := uint32(desiredAccessReadControl | desiredAccessWriteDac) sm := uint32(shareModeRead | shareModeWrite) fa := uint32(syscall.FILE_ATTRIBUTE_NORMAL) @@ -126,7 +158,7 @@ func createFile(name string, isDir bool) (syscall.Handle, error) { // generateDACLWithAcesAdded generates a new DACL with the two needed ACEs added. // The caller is responsible for LocalFree of the returned DACL on success. -func generateDACLWithAcesAdded(name string, isDir bool, origDACL uintptr) (uintptr, error) { +func generateDACLWithAcesAdded(name string, isDir bool, desiredAccess accessMask, origDACL uintptr) (uintptr, error) { // Generate pointers to the SIDs based on the string SIDs sid, err := syscall.StringToSid(sidVmGroup) if err != nil { @@ -139,8 +171,8 @@ func generateDACLWithAcesAdded(name string, isDir bool, origDACL uintptr) (uintp } eaArray := []explicitAccess{ - explicitAccess{ - accessPermissions: accessMaskDesiredPermission, + { + accessPermissions: desiredAccess, accessMode: accessModeGrant, inheritance: inheritance, trustee: trustee{ diff --git a/test/vendor/github.com/Microsoft/hcsshim/internal/security/syscall_windows.go b/test/vendor/github.com/Microsoft/hcsshim/internal/security/syscall_windows.go index d7096716ce..f0cdd7d20c 100644 --- a/test/vendor/github.com/Microsoft/hcsshim/internal/security/syscall_windows.go +++ b/test/vendor/github.com/Microsoft/hcsshim/internal/security/syscall_windows.go @@ -1,6 +1,6 @@ package security -//go:generate go run mksyscall_windows.go -output zsyscall_windows.go syscall_windows.go +//go:generate go run $GOPATH/src/golang.org/x/sys/windows/mkwinsyscall/mkwinsyscall.go -output zsyscall_windows.go syscall_windows.go //sys getSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, ppsidOwner **uintptr, ppsidGroup **uintptr, ppDacl *uintptr, ppSacl *uintptr, ppSecurityDescriptor *uintptr) (win32err error) = advapi32.GetSecurityInfo //sys setSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, psidOwner uintptr, psidGroup uintptr, pDacl uintptr, pSacl uintptr) (win32err error) = advapi32.SetSecurityInfo diff --git a/vendor/golang.org/x/sys/execabs/execabs.go b/vendor/golang.org/x/sys/execabs/execabs.go deleted file mode 100644 index 78192498db..0000000000 --- a/vendor/golang.org/x/sys/execabs/execabs.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2020 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package execabs is a drop-in replacement for os/exec -// that requires PATH lookups to find absolute paths. -// That is, execabs.Command("cmd") runs the same PATH lookup -// as exec.Command("cmd"), but if the result is a path -// which is relative, the Run and Start methods will report -// an error instead of running the executable. -// -// See https://blog.golang.org/path-security for more information -// about when it may be necessary or appropriate to use this package. -package execabs - -import ( - "context" - "fmt" - "os/exec" - "path/filepath" - "reflect" - "unsafe" -) - -// ErrNotFound is the error resulting if a path search failed to find an executable file. -// It is an alias for exec.ErrNotFound. -var ErrNotFound = exec.ErrNotFound - -// Cmd represents an external command being prepared or run. -// It is an alias for exec.Cmd. -type Cmd = exec.Cmd - -// Error is returned by LookPath when it fails to classify a file as an executable. -// It is an alias for exec.Error. -type Error = exec.Error - -// An ExitError reports an unsuccessful exit by a command. -// It is an alias for exec.ExitError. -type ExitError = exec.ExitError - -func relError(file, path string) error { - return fmt.Errorf("%s resolves to executable in current directory (.%c%s)", file, filepath.Separator, path) -} - -// LookPath searches for an executable named file in the directories -// named by the PATH environment variable. If file contains a slash, -// it is tried directly and the PATH is not consulted. The result will be -// an absolute path. -// -// LookPath differs from exec.LookPath in its handling of PATH lookups, -// which are used for file names without slashes. If exec.LookPath's -// PATH lookup would have returned an executable from the current directory, -// LookPath instead returns an error. -func LookPath(file string) (string, error) { - path, err := exec.LookPath(file) - if err != nil { - return "", err - } - if filepath.Base(file) == file && !filepath.IsAbs(path) { - return "", relError(file, path) - } - return path, nil -} - -func fixCmd(name string, cmd *exec.Cmd) { - if filepath.Base(name) == name && !filepath.IsAbs(cmd.Path) { - // exec.Command was called with a bare binary name and - // exec.LookPath returned a path which is not absolute. - // Set cmd.lookPathErr and clear cmd.Path so that it - // cannot be run. - lookPathErr := (*error)(unsafe.Pointer(reflect.ValueOf(cmd).Elem().FieldByName("lookPathErr").Addr().Pointer())) - if *lookPathErr == nil { - *lookPathErr = relError(name, cmd.Path) - } - cmd.Path = "" - } -} - -// CommandContext is like Command but includes a context. -// -// The provided context is used to kill the process (by calling os.Process.Kill) -// if the context becomes done before the command completes on its own. -func CommandContext(ctx context.Context, name string, arg ...string) *exec.Cmd { - cmd := exec.CommandContext(ctx, name, arg...) - fixCmd(name, cmd) - return cmd - -} - -// Command returns the Cmd struct to execute the named program with the given arguments. -// See exec.Command for most details. -// -// Command differs from exec.Command in its handling of PATH lookups, -// which are used when the program name contains no slashes. -// If exec.Command would have returned an exec.Cmd configured to run an -// executable from the current directory, Command instead -// returns an exec.Cmd that will return an error from Start or Run. -func Command(name string, arg ...string) *exec.Cmd { - cmd := exec.Command(name, arg...) - fixCmd(name, cmd) - return cmd -} diff --git a/vendor/modules.txt b/vendor/modules.txt index be16488c02..a90aef84a6 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -207,7 +207,6 @@ golang.org/x/net/trace golang.org/x/sync/errgroup # golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e ## explicit; go 1.17 -golang.org/x/sys/execabs golang.org/x/sys/internal/unsafeheader golang.org/x/sys/unix golang.org/x/sys/windows