@@ -21,13 +21,14 @@ import (
2121
2222// UpStream creates upstream handler struct
2323type UpStream struct {
24- Name string
25- proxy http.Handler
24+ Name string
25+ proxy http.Handler
2626 // TODO: Kick out separat config options and use more generic one
27- allowed []* regexp.Regexp
28- bindMounts []string
29- devMappings []string
27+ allowed []* regexp.Regexp
28+ bindMounts []string
29+ devMappings []string
3030 gpu bool
31+ pinUser string
3132}
3233
3334// UnixSocket just provides the path, so that I can test it
@@ -64,21 +65,23 @@ func newReverseProxy(dial func(network, addr string) (net.Conn, error)) *httputi
6465}
6566
6667// NewUpstream returns a new socket (magic)
67- func NewUpstream (socket string , regs []string , binds []string , devs []string , gpu bool ) * UpStream {
68+ func NewUpstream (socket string , regs []string , binds []string , devs []string , gpu bool , pinUser string ) * UpStream {
6869 us := NewUnixSocket (socket )
6970 a := []* regexp.Regexp {}
7071 for _ , r := range regs {
7172 p , _ := regexp .Compile (r )
7273 a = append (a , p )
7374 }
74- return & UpStream {
75+ upstream := & UpStream {
7576 Name : socket ,
7677 proxy : newReverseProxy (us .connectSocket ),
7778 allowed : a ,
7879 bindMounts : binds ,
7980 devMappings : devs ,
8081 gpu : gpu ,
82+ pinUser : pinUser ,
8183 }
84+ return upstream
8285}
8386
8487
@@ -97,11 +100,24 @@ func (u *UpStream) ServeHTTP(w http.ResponseWriter, req *http.Request) {
97100 http.Error(w, fmt.Sprintf("Only GET requests are allowed, req.Method: %s", req.Method), 400)
98101 return
99102 }*/
103+ /*
104+ // Hijack the connection to inspect who called it
105+ hj, ok := w.(http.Hijacker)
106+ if !ok {
107+ http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError)
108+ return
109+ }
110+ conn, _, err := hj.Hijack()
111+ if err != nil {
112+ http.Error(w, err.Error(), http.StatusInternalServerError)
113+ return
114+ }*/
100115 // Read the body
101116 body , err := ioutil .ReadAll (req .Body )
102117 if err != nil {
103118 fmt .Println (err .Error ())
104119 }
120+ //syscall.GetsockoptUcred(int(fd), syscall.SOL_SOCKET, syscall.SO_PEERCRED)
105121 //fmt.Printf("%v\n", hostConfig.Mounts)
106122 // And now set a new body, which will simulate the same data we read:
107123 req .Body = ioutil .NopCloser (bytes .NewBuffer (body ))
@@ -125,6 +141,15 @@ func (u *UpStream) ServeHTTP(w http.ResponseWriter, req *http.Request) {
125141 config .Env = append (config .Env , "PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" )
126142 config .Env = append (config .Env , "LD_LIBRARY_PATH=/usr/local/nvidia/" )
127143 }
144+ if u .pinUser != "" {
145+ // TODO: Should depend on calling user from syscall.GetsockoptUcred()
146+ if config .User != "" {
147+ fmt .Printf ("Overwrite User with '%s', was '%s'\n " , u .pinUser , config .User )
148+ } else {
149+ fmt .Printf ("Overwrite User with '%s'\n " , u .pinUser )
150+ }
151+ config .User = u .pinUser
152+ }
128153 for _ , bMount := range u .bindMounts {
129154 if bMount == "" {
130155 continue
0 commit comments