Skip to content

Commit

Permalink
refactor: refactor mvp version
Browse files Browse the repository at this point in the history
  • Loading branch information
ICKelin committed Aug 15, 2024
1 parent 0ae6b4d commit 77ca8e0
Show file tree
Hide file tree
Showing 12 changed files with 242 additions and 53 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@ go.work.sum

# env file
.env

release

.idea
6 changes: 6 additions & 0 deletions build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
rm -r release
GOOS=linux GOARCH=amd64 go build -o release/zta-gw_linux_amd64 gateway/*.go
GOOS=linux GOARCH=amd64 go build -o release/zta-client_linux_amd64 client/*.go

GOOS=darwin GOARCH=amd64 go build -o release/zta-gw_darwin_amd64 gateway/*.go
GOOS=darwin GOARCH=amd64 go build -o release/zta-client_darwin_amd64 client/*.go
12 changes: 5 additions & 7 deletions common/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ import (
)

const (
version = 0
cmdPP = 0x0
cmdHandshake = 0x1
)

// 私有协议头部

type ClientInfo struct {
ClientID string
PublicProtocol string
Expand All @@ -32,14 +35,9 @@ type ProxyProtocol struct {
InternalPort uint16
}

// 1byte version
// 1byte cmd
// 2bytes length
// length body

func (pp *ProxyProtocol) Encode() ([]byte, error) {
hdr := make([]byte, 4)
hdr[0] = 0x0
hdr[0] = version
hdr[1] = cmdPP

body, err := json.Marshal(pp)
Expand Down Expand Up @@ -86,7 +84,7 @@ type HandshakeReq struct {

func (req *HandshakeReq) Encode() ([]byte, error) {
hdr := make([]byte, 4)
hdr[0] = 0x0
hdr[0] = version
hdr[1] = cmdHandshake

body, err := json.Marshal(req)
Expand Down
4 changes: 3 additions & 1 deletion etc/gateway.yaml
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
listener_file: ./proxy.json
gateway:
listen_addr: ":12359"
listener_file: ./listener.json
File renamed without changes.
7 changes: 6 additions & 1 deletion gateway/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ import (
)

type Config struct {
ListenerFile string `yaml:"listener_file"`
GatewayConfig *GatewayConfig `yaml:"gateway"`
ListenerFile string `yaml:"listener_file"`
}

type GatewayConfig struct {
ListenAddr string `yaml:"listen_addr"`
}

func ParseConfig(confFile string) (*Config, error) {
Expand Down
22 changes: 18 additions & 4 deletions gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,30 @@ import (
)

type Gateway struct {
ListenAddr string
conf *GatewayConfig
clientIDs map[string]struct{}
sessionMgr *SessionManager
}

func NewGateway(listenAddr string, sessionMgr *SessionManager) *Gateway {
func NewGateway(conf *GatewayConfig, sessionMgr *SessionManager) *Gateway {
gw := &Gateway{
ListenAddr: listenAddr,
conf: conf,
sessionMgr: sessionMgr,
}
go gw.checkOnlineInterval()
return gw
}

func (gw *Gateway) SetAvailableClientIDs(clientIDs []string) {
clientIDsMap := make(map[string]struct{})
for _, clientID := range clientIDs {
clientIDsMap[clientID] = struct{}{}
}
gw.clientIDs = clientIDsMap
}

func (gw *Gateway) ListenAndServe() error {
listener, err := net.Listen("tcp", gw.ListenAddr)
listener, err := net.Listen("tcp", gw.conf.ListenAddr)
if err != nil {
return err
}
Expand All @@ -46,6 +55,11 @@ func (gw *Gateway) handleConn(conn net.Conn) {
return
}

if _, ok := gw.clientIDs[handshakeReq.ClientID]; !ok {
logs.Warn("client %s is not configured", handshakeReq.ClientID)
return
}

logs.Debug("handshake from %s", handshakeReq.ClientID)

// 创建session
Expand Down
43 changes: 26 additions & 17 deletions gateway/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,23 @@ var (
)

type Listener struct {
pp *common.ProxyProtocol
sessionMgr *SessionManager
closeOnce sync.Once
close chan struct{}
tcpListener net.Listener
listenerConfig *ListenerConfig
sessionMgr *SessionManager
closeOnce sync.Once
close chan struct{}
tcpListener net.Listener
}

func NewListener(pp *common.ProxyProtocol, sessionMgr *SessionManager) *Listener {
func NewListener(listenerConfig *ListenerConfig, sessionMgr *SessionManager) *Listener {
return &Listener{
pp: pp,
close: make(chan struct{}),
sessionMgr: sessionMgr,
listenerConfig: listenerConfig,
close: make(chan struct{}),
sessionMgr: sessionMgr,
}
}

func (l *Listener) ListenAndServe() error {
switch l.pp.PublicProtocol {
switch l.listenerConfig.PublicProtocol {
case "tcp":
return l.listenAndServeTCP()
default:
Expand All @@ -40,7 +40,7 @@ func (l *Listener) ListenAndServe() error {
}

func (l *Listener) listenAndServeTCP() error {
listenAddr := fmt.Sprintf("%s:%d", l.pp.PublicIP, l.pp.PublicPort)
listenAddr := fmt.Sprintf("%s:%d", l.listenerConfig.PublicIP, l.listenerConfig.PublicPort)
listener, err := net.Listen("tcp", listenAddr)
if err != nil {
return err
Expand All @@ -62,25 +62,34 @@ func (l *Listener) handleConn(conn net.Conn) {
defer conn.Close()

// 查询session
tunnelConn, err := l.sessionMgr.GetSessionByClientID(l.pp.ClientID)
tunnelConn, err := l.sessionMgr.GetSessionByClientID(l.listenerConfig.ClientID)
if err != nil {
logs.Warn("get session for client %s fail", l.pp.ClientID)
logs.Warn("get session for client %s fail", l.listenerConfig.ClientID)
return
}
defer tunnelConn.Close()

// 封装proxyprotocol
ppbody, err := l.pp.Encode()
pp := &common.ProxyProtocol{
ClientID: l.listenerConfig.ClientID,
PublicProtocol: l.listenerConfig.PublicProtocol,
PublicIP: l.listenerConfig.PublicIP,
PublicPort: l.listenerConfig.PublicPort,
InternalProtocol: l.listenerConfig.InternalProtocol,
InternalIP: l.listenerConfig.InternalIP,
InternalPort: l.listenerConfig.InternalPort,
}
ppBody, err := pp.Encode()
if err != nil {
logs.Warn("encode pp fail: %v ", err)
logs.Warn("encode listenerConfig fail: %v ", err)
return
}

tunnelConn.SetWriteDeadline(time.Now().Add(writeTimeout))
_, err = tunnelConn.Write(ppbody)
_, err = tunnelConn.Write(ppBody)
tunnelConn.SetWriteDeadline(time.Time{})
if err != nil {
logs.Warn("write pp body fail: %v", err)
logs.Warn("write listenerConfig body fail: %v", err)
return
}

Expand Down
16 changes: 5 additions & 11 deletions gateway/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"flag"
"github.com/ICKelin/zta/common"
)

func main() {
Expand All @@ -22,26 +21,21 @@ func main() {

sessionMgr := NewSessionManager()

clientIDs := make([]string, 0)
for _, listenerConfig := range listenerConfigs {
listener := NewListener(&common.ProxyProtocol{
ClientID: listenerConfig.ClientID,
PublicProtocol: listenerConfig.PublicProtocol,
PublicIP: listenerConfig.PublicIP,
PublicPort: listenerConfig.PublicPort,
InternalProtocol: listenerConfig.InternalProtocol,
InternalIP: listenerConfig.InternalIP,
InternalPort: listenerConfig.InternalPort,
}, sessionMgr)
listener := NewListener(listenerConfig, sessionMgr)
go func() {
defer listener.Close()
err := listener.ListenAndServe()
if err != nil {
panic(err)
}
}()
clientIDs = append(clientIDs, listenerConfig.ClientID)
}

gw := NewGateway(":12359", sessionMgr)
gw := NewGateway(conf.GatewayConfig, sessionMgr)
gw.SetAvailableClientIDs(clientIDs)
err = gw.ListenAndServe()
if err != nil {
panic(err)
Expand Down
12 changes: 0 additions & 12 deletions gateway/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,6 @@ func (mgr *SessionManager) CreateSession(clientID string, conn net.Conn) (*Sessi
return sess, nil
}

func (mgr *SessionManager) CloseSession(clientID string) {
mgr.sessionsMu.Lock()
defer mgr.sessionsMu.Unlock()
sess := mgr.sessions[clientID]
if sess == nil {
return
}

sess.Connection.Close()
delete(mgr.sessions, clientID)
}

func (mgr *SessionManager) Range(f func(k string, v *Session) bool) {
mgr.sessionsMu.Lock()
defer mgr.sessionsMu.Unlock()
Expand Down
11 changes: 11 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module github.com/ICKelin/zta

go 1.22.2

require (
github.com/alecthomas/gometalinter v3.0.0+incompatible
github.com/astaxie/beego v1.12.3
github.com/xtaci/smux v1.5.27
)

require github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644 // indirect
Loading

0 comments on commit 77ca8e0

Please sign in to comment.