diff --git a/bin/main.go b/bin/main.go index f2644f9..2e02753 100644 --- a/bin/main.go +++ b/bin/main.go @@ -1,8 +1,11 @@ package main import ( + "fmt" "log" "os" + "os/signal" + "syscall" "github.com/jessevdk/go-flags" "github.com/nknorg/nconnect" @@ -16,13 +19,39 @@ func main() { } }() - var opts config.Opts - _, err := flags.Parse(&opts) + var opts = &config.Opts{} + _, err := flags.Parse(opts) if err != nil { if flagsErr, ok := err.(*flags.Error); ok && flagsErr.Type == flags.ErrHelp { os.Exit(0) } log.Fatal(err) } - nconnect.Run(&opts) + + if opts.Version { + fmt.Println(config.Version) + os.Exit(0) + } + + nc, err := nconnect.NewNconnect(opts) + if err != nil { + log.Fatal(err) + } + + if opts.Client { + err = nc.StartClient() + if err != nil { + log.Fatal(err) + } + } + if opts.Server { + err = nc.StartServer() + if err != nil { + log.Fatal(err) + } + } + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + <-sigs } diff --git a/config/config.go b/config/config.go index 4df98a2..c2964ec 100644 --- a/config/config.go +++ b/config/config.go @@ -15,7 +15,6 @@ import ( "github.com/nknorg/nconnect/util" "github.com/nknorg/nkn/v2/common" - "github.com/nknorg/tuna/types" ) const ( @@ -45,8 +44,6 @@ type Opts struct { Address bool `long:"address" description:"Print client address (client mode) or admin address (server mode)"` WalletAddress bool `long:"wallet-address" description:"Print wallet address (server only)"` Version bool `long:"version" description:"Print version"` - - TunaNode *types.Node } type Config struct { diff --git a/nconnect.go b/nconnect.go index cccb92e..6e46fdc 100644 --- a/nconnect.go +++ b/nconnect.go @@ -7,10 +7,8 @@ import ( "log" "net" "os" - "os/signal" "strconv" "strings" - "syscall" "time" "github.com/eycorsican/go-tun2socks/core" @@ -30,6 +28,7 @@ import ( "github.com/nknorg/nkngomobile" "github.com/nknorg/tuna/filter" "github.com/nknorg/tuna/geo" + "github.com/nknorg/tuna/types" "gopkg.in/natefinch/lumberjack.v2" ) @@ -37,15 +36,27 @@ const ( mtu = 1500 ) -func Run(opts *config.Opts) { +type nconnect struct { + opts *config.Opts + account *nkn.Account + + walletConfig *nkn.WalletConfig + clientConfig *nkn.ClientConfig + tunnelConfig *tunnel.Config + ssConfig *ss.Config + persistConf *config.Config + + adminClientCache *admin.Client + remoteInfoCache *admin.GetInfoJSON + + tunnel *tunnel.Tunnel + tunaNode *types.Node // It is used to connect specific tuna node, mainly is for testing. +} + +func NewNconnect(opts *config.Opts) (*nconnect, error) { err := (&opts.Config).SetPlatformSpecificDefaultValues() if err != nil { - log.Fatal(err) - } - - if opts.Version { - fmt.Println(config.Version) - os.Exit(0) + return nil, err } if opts.Client == opts.Server { @@ -54,12 +65,12 @@ func Run(opts *config.Opts) { persistConf, err := config.LoadOrNewConfig(opts.ConfigFile) if err != nil { - log.Fatal(err) + return nil, err } err = mergo.Merge(&opts.Config, persistConf) if err != nil { - log.Fatal(err) + return nil, err } if len(opts.LogFileName) > 0 { @@ -72,12 +83,12 @@ func Run(opts *config.Opts) { seed, err := hex.DecodeString(opts.Seed) if err != nil { - log.Fatal(err) + return nil, err } account, err := nkn.NewAccount(seed) if err != nil { - log.Fatal(err) + return nil, err } shouldSave := false @@ -96,7 +107,7 @@ func Run(opts *config.Opts) { if shouldSave { err = persistConf.Save() if err != nil { - log.Fatal(err) + return nil, err } } @@ -221,16 +232,8 @@ func Run(opts *config.Opts) { Verbose: opts.Verbose, UDP: opts.UDP, UDPIdleTime: opts.UDPIdleTime, - TunaNode: opts.TunaNode, - } - - port, err := util.GetFreePort() - if err != nil { - log.Fatal(err) } - ssAddr := "127.0.0.1:" + strconv.Itoa(port) - ssConfig := &ss.Config{ TCP: true, Cipher: opts.Cipher, @@ -245,219 +248,255 @@ func Run(opts *config.Opts) { ssConfig.UDPSocks = true } - var tun *tunnel.Tunnel + nc := &nconnect{ + opts: opts, + account: account, + clientConfig: clientConfig, + tunnelConfig: tunnelConfig, + ssConfig: ssConfig, + walletConfig: walletConfig, + persistConf: persistConf, + } - if opts.Client { - err = (&opts.Config).VerifyClient() - if err != nil { - log.Fatal(err) - } + return nc, nil +} - // Lazy create admin client to avoid unnecessary client creation. - var adminClientCache *admin.Client - getAdminClient := func() (*admin.Client, error) { - if adminClientCache != nil { - return adminClientCache, nil - } - c, err := admin.NewClient(account, clientConfig) - if err != nil { - return nil, err - } - // Wait for more sub-clients to connect - time.Sleep(time.Second) - adminClientCache = c - return adminClientCache, nil - } +// Lazy create admin client to avoid unnecessary client creation. +func (nc *nconnect) getAdminClient() (*admin.Client, error) { + if nc.adminClientCache != nil { + return nc.adminClientCache, nil + } + c, err := admin.NewClient(nc.account, nc.clientConfig) + if err != nil { + return nil, err + } + // Wait for more sub-clients to connect + time.Sleep(time.Second) + nc.adminClientCache = c - // Lazy get remote info to avoid unnecessary rpc call. - var remoteInfoCache *admin.GetInfoJSON - getRemoteInfo := func() (*admin.GetInfoJSON, error) { - if remoteInfoCache != nil { - return remoteInfoCache, nil - } - c, err := getAdminClient() - if err != nil { - return nil, err - } - remoteInfoCache, err = c.GetInfo(opts.RemoteAdminAddr) - if err != nil { - return nil, fmt.Errorf("get remote server info error: %v. Please make sure server is online and accepting connections from this client address", err) - } - return remoteInfoCache, nil + return nc.adminClientCache, nil +} + +// Lazy get remote info to avoid unnecessary rpc call. +func (nc *nconnect) getRemoteInfo() (*admin.GetInfoJSON, error) { + if nc.remoteInfoCache != nil { + return nc.remoteInfoCache, nil + } + c, err := nc.getAdminClient() + if err != nil { + return nil, err + } + nc.remoteInfoCache, err = c.GetInfo(nc.opts.RemoteAdminAddr) + if err != nil { + return nil, fmt.Errorf("get remote server info error: %v. make sure server is online and accepting connections from this client address", err) + } + + return nc.remoteInfoCache, nil +} + +func (nc *nconnect) StartClient() error { + err := nc.opts.VerifyClient() + if err != nil { + return err + } + + remoteTunnelAddr := nc.opts.RemoteTunnelAddr + if len(remoteTunnelAddr) == 0 { + remoteInfo, err := nc.getRemoteInfo() + if err != nil { + return err } + remoteTunnelAddr = remoteInfo.Addr + } - remoteTunnelAddr := opts.RemoteTunnelAddr - if len(remoteTunnelAddr) == 0 { - remoteInfo, err := getRemoteInfo() + var vpnCIDR []*net.IPNet + if nc.opts.VPN { + vpnRoutes := nc.opts.VPNRoute + if len(vpnRoutes) == 0 { + remoteInfo, err := nc.getRemoteInfo() if err != nil { - log.Fatal(err) + return err } - remoteTunnelAddr = remoteInfo.Addr - } - - var vpnCIDR []*net.IPNet - if opts.VPN { - vpnRoutes := opts.VPNRoute - if len(vpnRoutes) == 0 { - remoteInfo, err := getRemoteInfo() - if err != nil { - log.Fatal(err) - } - if len(remoteInfo.LocalIP.Ipv4) > 0 { - vpnRoutes = make([]string, 0, len(remoteInfo.LocalIP.Ipv4)) - for _, ip := range remoteInfo.LocalIP.Ipv4 { - if ip == opts.TunAddr || ip == opts.TunGateway { - log.Printf("Skipping server's local IP %s in routes", ip) - continue - } - vpnRoutes = append(vpnRoutes, fmt.Sprintf("%s/32", ip)) + if len(remoteInfo.LocalIP.Ipv4) > 0 { + vpnRoutes = make([]string, 0, len(remoteInfo.LocalIP.Ipv4)) + for _, ip := range remoteInfo.LocalIP.Ipv4 { + if ip == nc.opts.TunAddr || ip == nc.opts.TunGateway { + log.Printf("Skipping server's local IP %s in routes", ip) + continue } + vpnRoutes = append(vpnRoutes, fmt.Sprintf("%s/32", ip)) } } - if len(vpnRoutes) > 0 { - vpnCIDR = make([]*net.IPNet, len(vpnRoutes)) - for i, cidr := range vpnRoutes { - _, cidr, err := net.ParseCIDR(cidr) - if err != nil { - log.Fatalf("Parse CIDR %s error: %v", cidr, err) - } - vpnCIDR[i] = cidr + } + if len(vpnRoutes) > 0 { + vpnCIDR = make([]*net.IPNet, len(vpnRoutes)) + for i, cidr := range vpnRoutes { + _, cidr, err := net.ParseCIDR(cidr) + if err != nil { + return fmt.Errorf("parse CIDR %s error: %v", cidr, err) } + vpnCIDR[i] = cidr } } + } - proxyAddr, err := net.ResolveTCPAddr("tcp", opts.LocalSocksAddr) - if err != nil { - log.Fatalf("Invalid proxy server address: %v", err) - } - proxyHost := proxyAddr.IP.String() - proxyPort := uint16(proxyAddr.Port) + proxyAddr, err := net.ResolveTCPAddr("tcp", nc.opts.LocalSocksAddr) + if err != nil { + return fmt.Errorf("invalid proxy server address: %v", err) + } + proxyHost := proxyAddr.IP.String() + proxyPort := uint16(proxyAddr.Port) - ssConfig.Client = ssAddr - ssConfig.Socks = opts.LocalSocksAddr + port, err := util.GetFreePort() + if err != nil { + return err + } + + ssAddr := "127.0.0.1:" + strconv.Itoa(port) + nc.ssConfig.Client = ssAddr + nc.ssConfig.Socks = nc.opts.LocalSocksAddr + + tunnel, err := tunnel.NewTunnel(nc.account, nc.opts.Identifier, ssAddr, remoteTunnelAddr, nc.opts.Tuna, nc.tunnelConfig) + if err != nil { + return err + } + nc.tunnel = tunnel - tun, err = tunnel.NewTunnel(account, opts.Identifier, ssAddr, remoteTunnelAddr, opts.Tuna, tunnelConfig) + log.Println("Client NKN address:", tunnel.Addr().String()) + log.Println("Client socks proxy listen address:", nc.opts.LocalSocksAddr) + + if nc.opts.Tun || nc.opts.VPN { + tunDevice, err := arch.OpenTunDevice(nc.opts.TunName, nc.opts.TunAddr, nc.opts.TunGateway, nc.opts.TunMask, nc.opts.TunDNS, true) if err != nil { - log.Fatal(err) + return fmt.Errorf("failed to open TUN device: %v", err) } - log.Println("Client NKN address:", tun.Addr().String()) - log.Println("Client socks proxy listen address:", opts.LocalSocksAddr) + core.RegisterOutputFn(tunDevice.Write) - if opts.Tun || opts.VPN { - tunDevice, err := arch.OpenTunDevice(opts.TunName, opts.TunAddr, opts.TunGateway, opts.TunMask, opts.TunDNS, true) - if err != nil { - log.Fatalf("Failed to open TUN device: %v", err) - } + core.RegisterTCPConnHandler(socks.NewTCPHandler(proxyHost, proxyPort)) + core.RegisterUDPConnHandler(socks.NewUDPHandler(proxyHost, proxyPort, 30*time.Second)) - core.RegisterOutputFn(tunDevice.Write) + lwipWriter := core.NewLWIPStack() - core.RegisterTCPConnHandler(socks.NewTCPHandler(proxyHost, proxyPort)) - core.RegisterUDPConnHandler(socks.NewUDPHandler(proxyHost, proxyPort, 30*time.Second)) + go func() { + _, err := io.CopyBuffer(lwipWriter, tunDevice, make([]byte, mtu)) + if err != nil { + log.Fatalf("Failed to write data to network stack: %v", err) + } + }() - lwipWriter := core.NewLWIPStack() + log.Println("Started tun2socks") - go func() { - _, err := io.CopyBuffer(lwipWriter, tunDevice, make([]byte, mtu)) + if nc.opts.VPN { + for _, dest := range vpnCIDR { + log.Printf("Adding route %s", dest) + out, err := arch.AddRouteCmd(dest, nc.opts.TunGateway, nc.opts.TunName) + if len(out) > 0 { + os.Stdout.Write(out) + } if err != nil { - log.Fatalf("Failed to write data to network stack: %v", err) + os.Stdout.Write([]byte(util.ParseExecError(err))) + os.Exit(1) } - }() - - log.Println("Started tun2socks") - - if opts.VPN { - for _, dest := range vpnCIDR { - log.Printf("Adding route %s", dest) - out, err := arch.AddRouteCmd(dest, opts.TunGateway, opts.TunName) + defer func(dest *net.IPNet) { + log.Printf("Deleting route %s", dest) + out, err := arch.DeleteRouteCmd(dest, nc.opts.TunGateway, nc.opts.TunName) if len(out) > 0 { os.Stdout.Write(out) } if err != nil { os.Stdout.Write([]byte(util.ParseExecError(err))) - os.Exit(1) } - defer func(dest *net.IPNet) { - log.Printf("Deleting route %s", dest) - out, err := arch.DeleteRouteCmd(dest, opts.TunGateway, opts.TunName) - if len(out) > 0 { - os.Stdout.Write(out) - } - if err != nil { - os.Stdout.Write([]byte(util.ParseExecError(err))) - } - }(dest) - } + }(dest) } } } - if opts.Server { - err = (&opts.Config).VerifyServer() + nc.startSSAndTunnel() + + return nil +} + +func (nc *nconnect) StartServer() error { + err := nc.opts.VerifyServer() + if err != nil { + return err + } + + port, err := util.GetFreePort() + if err != nil { + return err + } + ssAddr := "127.0.0.1:" + strconv.Itoa(port) + nc.ssConfig.Server = ssAddr + + if nc.opts.Tuna { + minBalance, err := common.StringToFixed64(nc.opts.TunaMinBalance) if err != nil { - log.Fatal(err) + return err } - ssConfig.Server = ssAddr - - if opts.Tuna { - minBalance, err := common.StringToFixed64(opts.TunaMinBalance) + if minBalance > 0 { + w, err := nkn.NewWallet(nc.account, nc.walletConfig) if err != nil { - log.Fatal(err) + return err } - if minBalance > 0 { - w, err := nkn.NewWallet(account, walletConfig) - if err != nil { - log.Fatal(err) - } - - balance, err := w.Balance() - if err != nil { - log.Println("Fetch balance error:", err) - } else if balance.ToFixed64() < minBalance { - log.Printf("Wallet balance %s is less than minimal balance to enable tuna %s, tuna will not be enabled", balance.String(), opts.TunaMinBalance) - opts.Tuna = false - } + balance, err := w.Balance() + if err != nil { + log.Println("Fetch balance error:", err) + } else if balance.ToFixed64() < minBalance { + log.Printf("Wallet balance %s is less than minimal balance to enable tuna %s, tuna will not be enabled", + balance.String(), nc.opts.TunaMinBalance) + nc.opts.Tuna = false } } + } - tun, err = tunnel.NewTunnel(account, opts.Identifier, "", ssAddr, opts.Tuna, tunnelConfig) - if err != nil { - log.Fatal(err) - } + if nc.tunaNode != nil { + nc.tunnelConfig.TunaNode = nc.tunaNode + } + tunnel, err := tunnel.NewTunnel(nc.account, nc.opts.Identifier, "", ssAddr, nc.opts.Tuna, nc.tunnelConfig) + if err != nil { + return err + } + nc.tunnel = tunnel + log.Println("Tunnel listen address:", tunnel.FromAddr()) + + if len(nc.opts.AdminIdentifier) > 0 { + go func() { + identifier := nc.opts.AdminIdentifier + if len(nc.opts.Identifier) > 0 { + identifier += "." + nc.opts.Identifier + } + err := admin.StartNKNServer(nc.account, identifier, nc.clientConfig, tunnel, nc.persistConf, &nc.opts.Config) + if err != nil { + log.Fatal(err) + } + os.Exit(0) + }() + log.Println("Admin listening address:", nc.opts.AdminIdentifier+"."+tunnel.FromAddr()) + } - log.Println("Tunnel listen address:", tun.FromAddr()) + if len(nc.opts.AdminHTTPAddr) > 0 { + go func() { + err := admin.StartWebServer(nc.opts.AdminHTTPAddr, tunnel, nc.persistConf, &nc.opts.Config) + if err != nil { + log.Fatal(err) + } + os.Exit(0) + }() + log.Println("Admin web dashboard listening address:", nc.opts.AdminHTTPAddr) + } - if len(opts.AdminIdentifier) > 0 { - go func() { - identifier := opts.AdminIdentifier - if len(opts.Identifier) > 0 { - identifier += "." + opts.Identifier - } - err := admin.StartNKNServer(account, identifier, clientConfig, tun, persistConf, &opts.Config) - if err != nil { - log.Fatal(err) - } - os.Exit(0) - }() - log.Println("Admin listening address:", opts.AdminIdentifier+"."+tun.FromAddr()) - } + nc.startSSAndTunnel() - if len(opts.AdminHTTPAddr) > 0 { - go func() { - err := admin.StartWebServer(opts.AdminHTTPAddr, tun, persistConf, &opts.Config) - if err != nil { - log.Fatal(err) - } - os.Exit(0) - }() - log.Println("Admin web dashboard listening address:", opts.AdminHTTPAddr) - } - } + return nil +} +func (nc *nconnect) startSSAndTunnel() { go func() { - err := ss.Start(ssConfig) + err := ss.Start(nc.ssConfig) if err != nil { log.Fatal(err) } @@ -465,14 +504,14 @@ func Run(opts *config.Opts) { }() go func() { - err := tun.Start() + err := nc.tunnel.Start() if err != nil { log.Fatal(err) } os.Exit(0) }() +} - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) - <-sigs +func (nc *nconnect) SetTunaNode(node *types.Node) { + nc.tunaNode = node } diff --git a/tests/pub.go b/tests/pub.go index dae786c..9bd6ff9 100644 --- a/tests/pub.go +++ b/tests/pub.go @@ -31,15 +31,18 @@ func startNconnect(configFile string, n *types.Node) error { return err } - // fmt.Printf("opts: %+v\n", opts) - opts.TunaNode = n + nc, _ := nconnect.NewNconnect(opts) + if opts.Server { + nc.SetTunaNode(n) + nc.StartServer() + } else { + nc.StartClient() + } - nconnect.Run(opts) return nil } func startTunaNode() (*types.Node, error) { - // Set up tuna tunaSeed, _ := hex.DecodeString(seedHex) acc, err := nkn.NewAccount(tunaSeed) if err != nil { @@ -80,11 +83,13 @@ func runReverseEntry(seed []byte) error { if err != nil { return err } + entryConfig := new(tuna.EntryConfiguration) err = util.ReadJSON("config.reverse.entry.json", entryConfig) if err != nil { return err } + err = tuna.StartReverse(entryConfig, entryWallet) if err != nil { return err diff --git a/tests/socks5_proxy_test.go b/tests/socks5_proxy_test.go index 420db2d..e11ee18 100644 --- a/tests/socks5_proxy_test.go +++ b/tests/socks5_proxy_test.go @@ -85,11 +85,11 @@ func TestTCPSocks5Proxy(t *testing.T) { // go test -v -run=TestUDPSocks5Proxy func TestUDPSocks5Proxy(t *testing.T) { - for i := 0; i < 5; i++ { + for i := 1; i <= 5; i++ { err := brook.Socks5Test("127.0.0.1:1080", "", "", "http3.ooo", "137.184.237.95", "8.8.8.8:53") if err != nil { fmt.Printf("TestUDPSocks5Proxy try %v err: %v\n", i, err) - time.Sleep(time.Second) + time.Sleep(time.Duration(i) * time.Second) } else { break }