diff --git a/backend/api/controller/asset.go b/backend/api/controller/asset.go index 8a37343..ae0afe8 100644 --- a/backend/api/controller/asset.go +++ b/backend/api/controller/asset.go @@ -109,13 +109,17 @@ func (c *Controller) GetAssets(ctx *gin.Context) { db = db.Where("parent_id IN ?", parentIds) } - if info && !acl.IsAdmin(currentUser) { - ids, err := GetAssetIdsByAuthorization(ctx) - if err != nil { - ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}}) - return + if info { + db = db.Select("id", "parent_id", "name", "ip", "protocols", "connectable", "authorization") + + if !acl.IsAdmin(currentUser) { + ids, err := GetAssetIdsByAuthorization(ctx) + if err != nil { + ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}}) + return + } + db = db.Where("id IN ?", ids) } - db = db.Where("id IN ?", ids) } db = db.Order("name") diff --git a/backend/api/controller/connect.go b/backend/api/controller/connect.go index 5e2b38c..fdf029a 100644 --- a/backend/api/controller/connect.go +++ b/backend/api/controller/connect.go @@ -346,7 +346,7 @@ func connectSsh(ctx *gin.Context, sess *gsession.Session, asset *model.Asset, ac } }() - ip, port, err := util.Proxy(sess.SessionId, "ssh", asset, gateway) + ip, port, err := util.Proxy(false, sess.SessionId, "ssh", asset, gateway) if err != nil { return } diff --git a/backend/api/controller/controller.go b/backend/api/controller/controller.go index e54fa33..12b5ea4 100644 --- a/backend/api/controller/controller.go +++ b/backend/api/controller/controller.go @@ -266,8 +266,8 @@ func doUpdate[T model.Model](ctx *gin.Context, needAcl bool, md T, resourceType } if needAcl { md.SetResourceId(old.GetResourceId()) - fmt.Printf("%+v\n", old) - fmt.Printf("%+v\n", md) + // fmt.Printf("%+v\n", old) + // fmt.Printf("%+v\n", md) if !hasPerm(ctx, md, resourceType, acl.WRITE) { ctx.AbortWithError(http.StatusForbidden, &ApiError{Code: ErrNoPerm, Data: map[string]any{"perm": acl.WRITE}}) return diff --git a/backend/api/controller/node.go b/backend/api/controller/node.go index 5599162..32d1e77 100644 --- a/backend/api/controller/node.go +++ b/backend/api/controller/node.go @@ -98,7 +98,7 @@ func (c *Controller) GetNodes(ctx *gin.Context) { if err != nil { return } - db = db.Where("id NOT IN ?", ids) + db = db.Where("id IN ?", ids) } if id, ok := ctx.GetQuery("self_parent"); ok { @@ -109,15 +109,18 @@ func (c *Controller) GetNodes(ctx *gin.Context) { db = db.Where("id IN ?", ids) } - if info && !acl.IsAdmin(currentUser) { - ids, err := GetNodeIdsByAuthorization(ctx) - if err != nil { - return - } - if ids, err = handleSelfParent(ctx, ids...); err != nil { - return + if info { + db = db.Select("id", "parent_id", "name") + if !acl.IsAdmin(currentUser) { + ids, err := GetNodeIdsByAuthorization(ctx) + if err != nil { + return + } + if ids, err = handleSelfParent(ctx, ids...); err != nil { + return + } + db = db.Where("id IN ?", ids) } - db = db.Where("id IN ?", ids) } doGet(ctx, !info, db, conf.RESOURCE_NODE, nodePostHooks...) @@ -253,7 +256,7 @@ func nodeDelHook(ctx *gin.Context, id int) { ctx.AbortWithError(http.StatusBadRequest, err) } -func handleNoSelfChild(ctx context.Context, ids ...int) (res []int, err error) { +func handleSelfChild(ctx context.Context, ids ...int) (res []int, err error) { nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) if err != nil { return @@ -274,10 +277,12 @@ func handleNoSelfChild(ctx context.Context, ids ...int) (res []int, err error) { } dfs(0, false) + res = lo.Uniq(append(res, ids...)) + return } -func handleNoSelfParent(ctx context.Context, ids ...int) (res []int, err error) { +func handleSelfParent(ctx context.Context, ids ...int) (res []int, err error) { nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) if err != nil { return @@ -290,10 +295,10 @@ func handleNoSelfParent(ctx context.Context, ids ...int) (res []int, err error) t := make([]int, 0) var dfs func(int) dfs = func(x int) { + t = append(t, x) if lo.Contains(ids, x) { res = append(res, t...) } - t = append(t, x) for _, y := range g[x] { dfs(y) } @@ -301,27 +306,39 @@ func handleNoSelfParent(ctx context.Context, ids ...int) (res []int, err error) } dfs(0) - res = lo.Uniq(res) + res = lo.Uniq(append(res, ids...)) return } -func handleSelfParent(ctx context.Context, ids ...int) (res []int, err error) { - res, err = handleNoSelfParent(ctx, ids...) +func handleNoSelfChild(ctx context.Context, ids ...int) (res []int, err error) { + nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) if err != nil { return } - res = lo.Uniq(append(res, ids...)) + allids := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id }) + + res, err = handleSelfChild(ctx, ids...) + if err != nil { + return + } + res = lo.Uniq(lo.Without(allids, res...)) return } -func handleSelfChild(ctx context.Context, ids ...int) (res []int, err error) { - res, err = handleNoSelfChild(ctx, ids...) +func handleNoSelfParent(ctx context.Context, ids ...int) (res []int, err error) { + nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) if err != nil { return } - res = lo.Uniq(append(res, ids...)) + allids := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id }) + + res, err = handleSelfParent(ctx, ids...) + if err != nil { + return + } + res = lo.Uniq(lo.Without(allids, res...)) return } diff --git a/backend/api/file/file.go b/backend/api/file/file.go index eb7d205..f09faef 100644 --- a/backend/api/file/file.go +++ b/backend/api/file/file.go @@ -75,7 +75,7 @@ func (fm *FileManager) GetFileClient(assetId, accountId int) (cli *sftp.Client, return } - ip, port, err := util.Proxy(uuid.New().String(), "sftp,ssh", asset, gateway) + ip, port, err := util.Proxy(false, uuid.New().String(), "sftp,ssh", asset, gateway) if err != nil { return } diff --git a/backend/api/guacd/conn.go b/backend/api/guacd/conn.go index 3ad6db5..16e3bcf 100644 --- a/backend/api/guacd/conn.go +++ b/backend/api/guacd/conn.go @@ -92,7 +92,7 @@ func NewTunnel(connectionId, sessionId string, w, h, dpi int, protocol string, a t.Config.Parameters["recording-name"] = t.SessionId } if gateway != nil && gateway.Id != 0 && t.ConnectionId == "" { - t.gw, err = ggateway.GetGatewayManager().Open(t.SessionId, asset.Ip, cast.ToInt(port), gateway) + t.gw, err = ggateway.GetGatewayManager().Open(false, t.SessionId, asset.Ip, cast.ToInt(port), gateway) if err != nil { return t, err } diff --git a/backend/gateway/gateway.go b/backend/gateway/gateway.go index 88af75b..ea7e461 100644 --- a/backend/gateway/gateway.go +++ b/backend/gateway/gateway.go @@ -17,7 +17,7 @@ import ( var ( manager = &GateWayManager{ - gateways: map[string]*GatewayTunnel{}, + gatewayTunnels: map[string]*GatewayTunnel{}, sshClients: map[int]*ssh.Client{}, sshClientsCount: map[int]int{}, mtx: sync.Mutex{}, @@ -28,8 +28,8 @@ func GetGatewayManager() *GateWayManager { return manager } -func GetGatewayBySessionId(sessionId string) *GatewayTunnel { - return manager.gateways[sessionId] +func GetGatewayTunnelBySessionId(sessionId string) *GatewayTunnel { + return manager.gatewayTunnels[sessionId] } type GatewayTunnel struct { @@ -45,13 +45,14 @@ type GatewayTunnel struct { Opened chan error } -func (gt *GatewayTunnel) Open() (err error) { +func (gt *GatewayTunnel) Open(isConnectable bool) (err error) { go func() { <-time.After(time.Second * 3) logger.L().Debug("timeout 3 second close listener", zap.String("sessionId", gt.SessionId)) gt.listener.Close() }() defer func() { + logger.L().Debug("close listener", zap.String("sessionId", gt.SessionId), zap.Error(err)) gt.Opened <- err }() gt.Opened <- nil @@ -65,11 +66,20 @@ func (gt *GatewayTunnel) Open() (err error) { defer cancel() gt.RemoteConn, err = manager.sshClients[gt.GatewayId].DialContext(ctx, "tcp", remoteAddr) if err != nil { - defer gt.LocalConn.Close() - defer gt.RemoteConn.Close() + defer func() { + if gt.LocalConn != nil { + defer gt.LocalConn.Close() + } + if gt.RemoteConn != nil { + defer gt.RemoteConn.Close() + } + }() logger.L().Error("dial remote failed", zap.String("sessionId", gt.SessionId), zap.Error(err)) return } + if isConnectable { + return + } go io.Copy(gt.LocalConn, gt.RemoteConn) go io.Copy(gt.RemoteConn, gt.LocalConn) @@ -77,13 +87,13 @@ func (gt *GatewayTunnel) Open() (err error) { } type GateWayManager struct { - gateways map[string]*GatewayTunnel + gatewayTunnels map[string]*GatewayTunnel sshClients map[int]*ssh.Client sshClientsCount map[int]int mtx sync.Mutex } -func (gm *GateWayManager) Open(sessionId, remoteIp string, remotePort int, gateway *model.Gateway) (g *GatewayTunnel, err error) { +func (gm *GateWayManager) Open(isConnectable bool, sessionId, remoteIp string, remotePort int, gateway *model.Gateway) (g *GatewayTunnel, err error) { if gateway == nil { err = fmt.Errorf("gateway is nil") return @@ -109,7 +119,7 @@ func (gm *GateWayManager) Open(sessionId, remoteIp string, remotePort int, gatew return } go func() { - logger.L().Debug("ssh client closed", zap.Int("gatewayId", gateway.Id), zap.Error(sshCli.Wait())) + logger.L().Debug("ssh proxy wait closed", zap.Int("gatewayId", gateway.Id), zap.Error(sshCli.Wait())) delete(gm.sshClients, gateway.Id) }() } @@ -133,8 +143,8 @@ func (gm *GateWayManager) Open(sessionId, remoteIp string, remotePort int, gatew RemotePort: remotePort, Opened: make(chan error), } - gm.gateways[sessionId] = g - go g.Open() + gm.gatewayTunnels[sessionId] = g + go g.Open(isConnectable) logger.L().Debug("opening gateway", zap.Any("sessionId", sessionId)) <-g.Opened @@ -147,13 +157,15 @@ func (gm *GateWayManager) Close(sessionIds ...string) { gm.mtx.Lock() defer gm.mtx.Unlock() for _, sid := range sessionIds { - gt, ok := gm.gateways[sid] + gt, ok := gm.gatewayTunnels[sid] if !ok { return } gm.sshClientsCount[gt.GatewayId] -= 1 if gm.sshClientsCount[gt.GatewayId] <= 0 { - gm.sshClients[gt.GatewayId].Close() + if g := gm.sshClients[gt.GatewayId]; g != nil { + g.Close() + } delete(gm.sshClients, gt.GatewayId) delete(gm.sshClientsCount, gt.GatewayId) } diff --git a/backend/schedule/connectable.go b/backend/schedule/connectable.go index c09c7a4..51764c1 100644 --- a/backend/schedule/connectable.go +++ b/backend/schedule/connectable.go @@ -8,7 +8,6 @@ import ( "github.com/google/uuid" "github.com/samber/lo" - "github.com/spf13/cast" "go.uber.org/zap" mysql "github.com/veops/oneterm/db" @@ -80,31 +79,28 @@ func UpdateConnectables(ids ...int) (err error) { func updateConnectable(asset *model.Asset, gateway *model.Gateway) (sid string, ok bool) { sid = uuid.New().String() - for _, p := range asset.Protocols { - ip, port := asset.Ip, cast.ToInt(strings.Split(p, ":")[1]) - var ( - gt *ggateway.GatewayTunnel - err error - ) - if asset.GatewayId != 0 { - gt, err = ggateway.GetGatewayManager().Open(sid, ip, port, gateway) - if err != nil { - logger.L().Debug("open gateway failed", zap.Error(err)) - continue - } - ip, port = gt.LocalIp, gt.LocalPort - <-gt.Opened + ps := strings.Join(lo.Map(asset.Protocols, func(p string, _ int) string { return strings.Split(p, ":")[0] }), ",") + ip, port, err := util.Proxy(true, sid, ps, asset, gateway) + if err != nil { + logger.L().Debug("connectable proxy failed", zap.String("protocol", ps), zap.Error(err)) + return + } + addr := fmt.Sprintf("%s:%d", ip, port) + conn, err := net.DialTimeout("tcp", addr, time.Second) + if err != nil { + logger.L().Debug("dail failed", zap.String("addr", addr), zap.Error(err)) + return + } + defer conn.Close() + if asset.GatewayId != 0 { + t := ggateway.GetGatewayTunnelBySessionId(sid) + if t == nil { + return } - addr := fmt.Sprintf("%s:%d", ip, port) - net, err := net.DialTimeout("tcp", addr, time.Second*3) - if err != nil { - logger.L().Debug("dail failed", zap.String("addr", addr), zap.Error(err)) - continue + if err = <-t.Opened; err != nil { + return } - defer net.Close() - - ok = true - return } + ok = true return } diff --git a/backend/util/ssh.go b/backend/util/ssh.go index 171e280..b390ac5 100644 --- a/backend/util/ssh.go +++ b/backend/util/ssh.go @@ -58,7 +58,7 @@ func GetAuth(account *model.Account) (ssh.AuthMethod, error) { } } -func Proxy(sessionId string, protocol string, asset *model.Asset, gateway *model.Gateway) (ip string, port int, err error) { +func Proxy(isConnectable bool, sessionId string, protocol string, asset *model.Asset, gateway *model.Gateway) (ip string, port int, err error) { ip, port = asset.Ip, 0 for _, tp := range strings.Split(protocol, ",") { for _, p := range asset.Protocols { @@ -74,8 +74,7 @@ func Proxy(sessionId string, protocol string, asset *model.Asset, gateway *model return } - - g, err := ggateway.GetGatewayManager().Open(sessionId, ip, port, gateway) + g, err := ggateway.GetGatewayManager().Open(isConnectable, sessionId, ip, port, gateway) if err != nil { return }