Skip to content

Commit

Permalink
add tls fallback proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
p4gefau1t committed Mar 23, 2020
1 parent 037b698 commit 76fd836
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 48 deletions.
2 changes: 1 addition & 1 deletion common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func HumanFriendlyTraffic(bytes int) string {
if bytes <= GiB {
return fmt.Sprintf("%.2f MiB", float32(bytes)/MiB)
}
return fmt.Sprintf("%.2f TiB", float32(bytes)/GiB)
return fmt.Sprintf("%.2f GiB", float32(bytes)/GiB)
}

func ConnectDatabase(driverName, username, password, ip string, port int, dbName string) (*sql.DB, error) {
Expand Down
6 changes: 5 additions & 1 deletion conf/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@ type TLSConfig struct {
KeyPassword string `json:"key_password"`
Cipher string `json:"cipher"`
CipherTLS13 string `json:"cipher_tls13"`
HTTPFile string `json:"plain_http_response"`
PreferServerCipher bool `json:"prefer_server_cipher"`
SNI string `json:"sni"`

HTTPFile string `json:"plain_http_response"`
FallbackHost string `json:"fallback_addr"`
FallbackPort uint16 `json:"fallback_port"`
FallbackAddr net.Addr

CertPool *x509.CertPool
KeyPair []tls.Certificate
HTTPResponse []byte
Expand Down
66 changes: 27 additions & 39 deletions conf/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"io/ioutil"
"net"
"os"
Expand All @@ -16,19 +17,18 @@ import (

var logger = log.New(os.Stdout)

func ConvertToIP(s string) ([]net.IP, error) {
ip := net.ParseIP(s)
if ip == nil {
ips, err := net.LookupIP(s)
if err != nil {
return nil, err
}
if len(ips) == 0 {
return nil, common.NewError("cannot resolve host:" + s)
}
return ips, nil
func convertToAddr(preferV4 bool, host string, port uint16) (*net.TCPAddr, error) {
ip := net.ParseIP(host)
if ip != nil {
return &net.TCPAddr{
IP: ip,
Port: int(port),
}, nil
}
return []net.IP{ip}, nil
if preferV4 {
return net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", host, port))
}
return net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", host, port))
}

func ParseJSON(data []byte) (*GlobalConfig, error) {
Expand Down Expand Up @@ -108,39 +108,27 @@ func ParseJSON(data []byte) (*GlobalConfig, error) {
default:
return nil, common.NewError("invalid run type")
}
localIPs, err := ConvertToIP(config.LocalHost)

localAddr, err := convertToAddr(config.TCP.PreferIPV4, config.LocalHost, config.LocalPort)
if err != nil {
return nil, err
return nil, common.NewError("invalid local address").Base(err)
}
remoteIPs, err := ConvertToIP(config.RemoteHost)
config.LocalAddr = localAddr
config.LocalIP = localAddr.IP

remoteAddr, err := convertToAddr(config.TCP.PreferIPV4, config.RemoteHost, config.RemotePort)
if err != nil {
return nil, err
return nil, common.NewError("invalid remote address").Base(err)
}
config.RemoteAddr = remoteAddr
config.RemoteIP = remoteAddr.IP

config.LocalIP = localIPs[0]
config.RemoteIP = remoteIPs[0]

if config.TCP.PreferIPV4 {
for _, ip := range localIPs {
if ip.To4() != nil {
config.LocalIP = ip
break
}
}
for _, ip := range remoteIPs {
if ip.To4() != nil {
config.RemoteIP = ip
break
}
if config.TLS.FallbackHost != "" {
fallbackAddr, err := convertToAddr(config.TCP.PreferIPV4, config.TLS.FallbackHost, config.TLS.FallbackPort)
if err != nil {
return nil, common.NewError("invalid tls fallback address").Base(err)
}
}
config.LocalAddr = &net.TCPAddr{
IP: config.LocalIP,
Port: int(config.LocalPort),
}
config.RemoteAddr = &net.TCPAddr{
IP: config.RemoteIP,
Port: int(config.RemotePort),
config.TLS.FallbackAddr = fallbackAddr
}

if config.TLS.Cipher != "" || config.TLS.CipherTLS13 != "" {
Expand Down
2 changes: 1 addition & 1 deletion protocol/trojan/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (i *TrojanInboundConnSession) parseRequest() error {
Port: i.config.RemotePort,
NetworkType: "tcp",
}
logger.Warn("invalid hash or other protocol:", string(userHash))
logger.Warn("remote", i.conn.RemoteAddr(), "invalid hash or other protocol:", string(userHash))
return nil
}
i.passwordHash = string(userHash)
Expand Down
30 changes: 26 additions & 4 deletions proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"database/sql"
"net"
"reflect"

"github.com/p4gefau1t/trojan-go/common"
"github.com/p4gefau1t/trojan-go/conf"
Expand All @@ -28,12 +29,12 @@ type Server struct {

func (s *Server) handleMuxConn(stream *smux.Stream, passwordHash string) {
inboundConn, err := mux.NewInboundMuxConnSession(stream, passwordHash)
inboundConn.(protocol.NeedMeter).SetMeter(s.meter)
if err != nil {
stream.Close()
logger.Error(common.NewError("cannot start inbound session").Base(err))
return
}
inboundConn.(protocol.NeedMeter).SetMeter(s.meter)
defer inboundConn.Close()
req := inboundConn.GetRequest()
if req.Command != protocol.Connect {
Expand All @@ -52,11 +53,11 @@ func (s *Server) handleMuxConn(stream *smux.Stream, passwordHash string) {

func (s *Server) handleConn(conn net.Conn) {
inboundConn, err := trojan.NewInboundConnSession(conn, s.config, s.auth)

if err != nil {
logger.Error(err)
logger.Error(common.NewError("failed to start inbound session, remote:" + conn.RemoteAddr().String()).Base(err))
return
}

req := inboundConn.GetRequest()
hash := inboundConn.(protocol.HasHash).GetHash()

Expand Down Expand Up @@ -191,10 +192,31 @@ func (s *Server) Run() error {
tlsConn := tls.Server(conn, tlsConfig)
err = tlsConn.Handshake()
if err != nil {
logger.Warn(common.NewError("failed to perform handshake, responsing http payload").Base(err))
logger.Warn(common.NewError("failed to perform tls handshake, remote:" + conn.RemoteAddr().String()).Base(err))

if len(s.config.TLS.HTTPResponse) > 0 {
logger.Warn("trying to response a plain http response")
conn.Write(s.config.TLS.HTTPResponse)
continue
}

if s.config.TLS.FallbackAddr != nil {
//HACK
//obtain the bytes buffered by the tls conn
v := reflect.ValueOf(*tlsConn)
buf := v.FieldByName("rawInput").FieldByName("buf").Bytes()
logger.Debug("payload:" + string(buf))

remote, err := net.Dial("tcp", s.config.TLS.FallbackAddr.String())
if err != nil {
logger.Warn(common.NewError("failed to dial to tls fallback server").Base(err))
}
logger.Warn("proxying this invalid tls conn to the tls fallback server")
remote.Write(buf)
go proxyConn(conn, remote)
continue
}

conn.Close()
continue
}
Expand Down
5 changes: 3 additions & 2 deletions proxy/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,10 @@ func TestServerTCPRedirecting(t *testing.T) {
}
config.TLS.KeyPair = []tls.Certificate{key}
config.TLS.SNI = "localhost"
payload, err := ioutil.ReadFile("http.txt")
addr, err := net.ResolveTCPAddr("tcp", "localhost:443")
common.Must(err)
config.TLS.HTTPResponse = payload
config.TLS.FallbackAddr = addr

server := Server{
config: config,
}
Expand Down

0 comments on commit 76fd836

Please sign in to comment.