Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: revamp config #230

Merged
merged 9 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 146 additions & 133 deletions x/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,187 +22,194 @@ import (
"strings"

"github.com/Jigsaw-Code/outline-sdk/transport"
"github.com/Jigsaw-Code/outline-sdk/transport/split"
"github.com/Jigsaw-Code/outline-sdk/transport/tlsfrag"
)

// ConfigParser enables the creation of stream and packet dialers based on a config. The config is
// ConfigToDialer enables the creation of stream and packet dialers based on a config. The config is
// extensible by registering wrappers for config subtypes.
type ConfigParser struct {
sdWrapers map[string]WrapStreamDialerFunc
pdWrappers map[string]WrapPacketDialerFunc
type ConfigToDialer struct {
// Base StreamDialer to create direct stream connections. If you need direct stream connections, this must not be nil.
BaseStreamDialer transport.StreamDialer
// Base PacketDialer to create direct packet connections. If you need direct packet connections, this must not be nil.
BasePacketDialer transport.PacketDialer
sdBuilders map[string]NewStreamDialerFunc
pdBuilders map[string]NewPacketDialerFunc
}

// NewDefaultConfigParser creates a [ConfigParser] with a set of default wrappers already registered.
func NewDefaultConfigParser() *ConfigParser {
p := new(ConfigParser)
// NewStreamDialerFunc wraps a Dialer based on the wrapConfig. The innerSD and innerPD functions can provide a base Stream and Packet Dialers if needed.
type NewStreamDialerFunc func(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), wrapConfig *url.URL) (transport.StreamDialer, error)

// NewPacketDialerFunc wraps a Dialer based on the wrapConfig. The innerSD and innerPD functions can provide a base Stream and Packet Dialers if needed.
type NewPacketDialerFunc func(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), wrapConfig *url.URL) (transport.PacketDialer, error)

// NewDefaultConfigToDialer creates a [ConfigToDialer] with a set of default wrappers already registered.
func NewDefaultConfigToDialer() *ConfigToDialer {
p := new(ConfigToDialer)
p.BaseStreamDialer = &transport.TCPDialer{}
p.BasePacketDialer = &transport.UDPDialer{}

// Please keep the list in alphabetical order.
p.RegisterStreamDialerWrapper("doh", wrapStreamDialerWithDOH)
p.RegisterPacketDialerWrapper("doh", func(baseDialer transport.PacketDialer, wrapConfig *url.URL) (transport.PacketDialer, error) {
return nil, errors.New("doh is not supported for PacketDialers")
})
p.RegisterStreamDialerType("do53", wrapStreamDialerWithDO53)

p.RegisterStreamDialerWrapper("override", wrapStreamDialerWithOverride)
p.RegisterPacketDialerWrapper("override", wrapPacketDialerWithOverride)
p.RegisterStreamDialerType("doh", wrapStreamDialerWithDOH)

p.RegisterStreamDialerWrapper("socks5", wrapStreamDialerWithSOCKS5)
p.RegisterPacketDialerWrapper("socks5", func(baseDialer transport.PacketDialer, wrapConfig *url.URL) (transport.PacketDialer, error) {
return nil, errors.New("socks5 is not supported for PacketDialers")
})
p.RegisterStreamDialerType("override", wrapStreamDialerWithOverride)
p.RegisterPacketDialerType("override", wrapPacketDialerWithOverride)

p.RegisterStreamDialerWrapper("split", func(baseDialer transport.StreamDialer, wrapConfig *url.URL) (transport.StreamDialer, error) {
prefixBytesStr := wrapConfig.Opaque
prefixBytes, err := strconv.Atoi(prefixBytesStr)
if err != nil {
return nil, fmt.Errorf("prefixBytes is not a number: %v. Split config should be in split:<number> format", prefixBytesStr)
}
return split.NewStreamDialer(baseDialer, int64(prefixBytes))
})
p.RegisterPacketDialerWrapper("split", func(baseDialer transport.PacketDialer, wrapConfig *url.URL) (transport.PacketDialer, error) {
return nil, errors.New("split is not supported for PacketDialers")
})
p.RegisterStreamDialerType("socks5", wrapStreamDialerWithSOCKS5)

p.RegisterStreamDialerWrapper("ss", wrapStreamDialerWithShadowsocks)
p.RegisterPacketDialerWrapper("ss", wrapPacketDialerWithShadowsocks)
p.RegisterStreamDialerType("split", wrapStreamDialerWithSplit)

p.RegisterStreamDialerWrapper("tls", wrapStreamDialerWithTLS)
p.RegisterPacketDialerWrapper("tls", func(baseDialer transport.PacketDialer, wrapConfig *url.URL) (transport.PacketDialer, error) {
return nil, errors.New("tls is not supported for PacketDialers")
})
p.RegisterStreamDialerType("ss", wrapStreamDialerWithShadowsocks)
p.RegisterPacketDialerType("ss", wrapPacketDialerWithShadowsocks)

p.RegisterStreamDialerWrapper("tlsfrag", func(baseDialer transport.StreamDialer, wrapConfig *url.URL) (transport.StreamDialer, error) {
p.RegisterStreamDialerType("tls", wrapStreamDialerWithTLS)

p.RegisterStreamDialerType("tlsfrag", func(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), wrapConfig *url.URL) (transport.StreamDialer, error) {
sd, err := innerSD()
if err != nil {
return nil, err
}
lenStr := wrapConfig.Opaque
fixedLen, err := strconv.Atoi(lenStr)
if err != nil {
return nil, fmt.Errorf("invalid tlsfrag option: %v. It should be in tlsfrag:<number> format", lenStr)
}
return tlsfrag.NewFixedLenStreamDialer(baseDialer, fixedLen)
})
p.RegisterPacketDialerWrapper("tlsfrag", func(baseDialer transport.PacketDialer, wrapConfig *url.URL) (transport.PacketDialer, error) {
return nil, errors.New("tlsfrag is not supported for PacketDialers")
return tlsfrag.NewFixedLenStreamDialer(sd, fixedLen)
})

return p
}

// WrapStreamDialerFunc wraps a [transport.StreamDialer] based on the wrapConfig.
type WrapStreamDialerFunc func(dialer transport.StreamDialer, wrapConfig *url.URL) (transport.StreamDialer, error)

// RegisterStreamDialerWrapper will register a wrapper for stream dialers under the given subtype.
func (p *ConfigParser) RegisterStreamDialerWrapper(subtype string, wrapper WrapStreamDialerFunc) error {
if p.sdWrapers == nil {
p.sdWrapers = make(map[string]WrapStreamDialerFunc)
// RegisterStreamDialerType will register a wrapper for stream dialers under the given subtype.
func (p *ConfigToDialer) RegisterStreamDialerType(subtype string, newDialer NewStreamDialerFunc) error {
if p.sdBuilders == nil {
p.sdBuilders = make(map[string]NewStreamDialerFunc)
}

if _, found := p.sdWrapers[subtype]; found {
if _, found := p.sdBuilders[subtype]; found {
return fmt.Errorf("config parser %v for StreamDialer added twice", subtype)
}
p.sdWrapers[subtype] = wrapper
p.sdBuilders[subtype] = newDialer
return nil
}

// WrapPacketDialerFunc wraps a [transport.PacketDialer] based on the wrapConfig.
type WrapPacketDialerFunc func(dialer transport.PacketDialer, wrapConfig *url.URL) (transport.PacketDialer, error)

// RegisterPacketDialerWrapper will register a wrapper for packet dialers under the given subtype.
func (p *ConfigParser) RegisterPacketDialerWrapper(subtype string, wrapper WrapPacketDialerFunc) error {
if p.pdWrappers == nil {
p.pdWrappers = make(map[string]WrapPacketDialerFunc)
// RegisterPacketDialerType will register a wrapper for packet dialers under the given subtype.
func (p *ConfigToDialer) RegisterPacketDialerType(subtype string, newDialer NewPacketDialerFunc) error {
if p.pdBuilders == nil {
p.pdBuilders = make(map[string]NewPacketDialerFunc)
}

if _, found := p.pdWrappers[subtype]; found {
return fmt.Errorf("config parser %v for PacketDialer added twice", subtype)
if _, found := p.pdBuilders[subtype]; found {
return fmt.Errorf("config parser %v for StreamDialer added twice", subtype)
}
p.pdWrappers[subtype] = wrapper
p.pdBuilders[subtype] = newDialer
return nil
}

func parseConfigPart(oneDialerConfig string) (*url.URL, error) {
oneDialerConfig = strings.TrimSpace(oneDialerConfig)
if oneDialerConfig == "" {
return nil, errors.New("empty config part")
func parseConfig(configText string) ([]*url.URL, error) {
parts := strings.Split(strings.TrimSpace(configText), "|")
if len(parts) == 1 && parts[0] == "" {
return []*url.URL{}, nil
}
// Make it "<scheme>:" if it's only "<scheme>" to parse as a URL.
if !strings.Contains(oneDialerConfig, ":") {
oneDialerConfig += ":"
urls := make([]*url.URL, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
return nil, errors.New("empty config part")
}
// Make it "<scheme>:" if it's only "<scheme>" to parse as a URL.
if !strings.Contains(part, ":") {
part += ":"
}
url, err := url.Parse(part)
if err != nil {
return nil, fmt.Errorf("part is not a valid URL: %w", err)
}
urls = append(urls, url)
}
url, err := url.Parse(oneDialerConfig)
return urls, nil
}

// NewStreamDialer creates a [Dialer] according to transportConfig, using dialer as the
// base [Dialer]. The given dialer must not be nil.
func (p *ConfigToDialer) NewStreamDialer(transportConfig string) (transport.StreamDialer, error) {
parts, err := parseConfig(transportConfig)
if err != nil {
return nil, fmt.Errorf("part is not a valid URL: %w", err)
return nil, err
}
return url, nil
return p.newStreamDialer(parts)
}

// WrapStreamDialer creates a [transport.StreamDialer] according to transportConfig, using dialer as the
// base [transport.StreamDialer]. The given dialer must not be nil.
func (p *ConfigParser) WrapStreamDialer(dialer transport.StreamDialer, transportConfig string) (transport.StreamDialer, error) {
if dialer == nil {
return nil, errors.New("base dialer must not be nil")
}
transportConfig = strings.TrimSpace(transportConfig)
if transportConfig == "" {
return dialer, nil
// NewPacketDialer creates a [Dialer] according to transportConfig, using dialer as the
// base [Dialer]. The given dialer must not be nil.
func (p *ConfigToDialer) NewPacketDialer(transportConfig string) (transport.PacketDialer, error) {
parts, err := parseConfig(transportConfig)
if err != nil {
return nil, err
}
for _, part := range strings.Split(transportConfig, "|") {
url, err := parseConfigPart(part)
if err != nil {
return nil, err
}
w, ok := p.sdWrapers[url.Scheme]
if !ok {
return nil, fmt.Errorf("config scheme '%v' is not supported", url.Scheme)
}
dialer, err = w(dialer, url)
if err != nil {
return nil, err
return p.newPacketDialer(parts)
}

func (p *ConfigToDialer) newStreamDialer(configParts []*url.URL) (transport.StreamDialer, error) {
if len(configParts) == 0 {
if p.BaseStreamDialer == nil {
return nil, fmt.Errorf("base StreamDialer must not be nil")
}
return p.BaseStreamDialer, nil
}
thisURL := configParts[len(configParts)-1]
innerConfig := configParts[:len(configParts)-1]
newDialer, ok := p.sdBuilders[thisURL.Scheme]
if !ok {
return nil, fmt.Errorf("config scheme '%v' is not supported for Stream Dialers", thisURL.Scheme)
}
return dialer, nil
newSD := func() (transport.StreamDialer, error) {
return p.newStreamDialer(innerConfig)
}
newPD := func() (transport.PacketDialer, error) {
return p.newPacketDialer(innerConfig)
}
return newDialer(newSD, newPD, thisURL)
}

// WrapPacketDialer creates a [transport.PacketDialer] according to transportConfig, using dialer as the
// base [transport.PacketDialer]. The given dialer must not be nil.
func (p *ConfigParser) WrapPacketDialer(dialer transport.PacketDialer, transportConfig string) (transport.PacketDialer, error) {
if dialer == nil {
return nil, errors.New("base dialer must not be nil")
func (p *ConfigToDialer) newPacketDialer(configParts []*url.URL) (transport.PacketDialer, error) {
if len(configParts) == 0 {
if p.BasePacketDialer == nil {
return nil, fmt.Errorf("base PacketDialer must not be nil")
}
return p.BasePacketDialer, nil
}
transportConfig = strings.TrimSpace(transportConfig)
if transportConfig == "" {
return dialer, nil
thisURL := configParts[len(configParts)-1]
innerConfig := configParts[:len(configParts)-1]
newDialer, ok := p.pdBuilders[thisURL.Scheme]
if !ok {
return nil, fmt.Errorf("config scheme '%v' is not supported for Packet Dialers", thisURL.Scheme)
}
for _, part := range strings.Split(transportConfig, "|") {
url, err := parseConfigPart(part)
if err != nil {
return nil, err
}
w, ok := p.pdWrappers[url.Scheme]
if !ok {
return nil, fmt.Errorf("config scheme '%v' is not supported", url.Scheme)
}
dialer, err = w(dialer, url)
if err != nil {
return nil, err
}
newSD := func() (transport.StreamDialer, error) {
return p.newStreamDialer(innerConfig)
}
newPD := func() (transport.PacketDialer, error) {
return p.newPacketDialer(innerConfig)
}
return dialer, nil
return newDialer(newSD, newPD, thisURL)
}

// NewpacketListener creates a new [transport.PacketListener] according to the given config,
// the config must contain only one "ss://" segment.
// TODO: make NewPacketListener configurable.
func NewPacketListener(transportConfig string) (transport.PacketListener, error) {
if transportConfig = strings.TrimSpace(transportConfig); transportConfig == "" {
parts, err := parseConfig(transportConfig)
if err != nil {
return nil, err
}
if len(parts) == 0 {
return nil, errors.New("config is required")
}
if strings.Contains(transportConfig, "|") {
if len(parts) > 1 {
return nil, errors.New("multi-part config is not supported")
}

url, err := parseConfigPart(transportConfig)
if err != nil {
return nil, fmt.Errorf("failed to parse config: %w", err)
}
url := parts[0]
// Please keep scheme list sorted.
switch strings.ToLower(url.Scheme) {
case "ss":
Expand All @@ -214,34 +221,40 @@ func NewPacketListener(transportConfig string) (transport.PacketListener, error)
}

func SanitizeConfig(transportConfig string) (string, error) {
parts, err := parseConfig(transportConfig)
if err != nil {
return "", err
}

// Do nothing if the config is empty
if transportConfig == "" {
if len(parts) == 0 {
return "", nil
}
// Split the string into parts
parts := strings.Split(transportConfig, "|")

// Iterate through each part
for i, part := range parts {
u, err := parseConfigPart(part)
if err != nil {
return "", fmt.Errorf("failed to parse config part: %w", err)
}
textParts := make([]string, len(parts))
for i, u := range parts {
scheme := strings.ToLower(u.Scheme)
switch scheme {
case "ss":
parts[i], _ = sanitizeShadowsocksURL(u)
textParts[i], err = sanitizeShadowsocksURL(u)
if err != nil {
return "", err
}
case "socks5":
parts[i], _ = sanitizeSocks5URL(u)
textParts[i], err = sanitizeSocks5URL(u)
if err != nil {
return "", err
}
case "override", "split", "tls", "tlsfrag":
// No sanitization needed
parts[i] = u.String()
textParts[i] = u.String()
default:
parts[i] = scheme + "://UNKNOWN"
textParts[i] = scheme + "://UNKNOWN"
}
}
// Join the parts back into a string
return strings.Join(parts, "|"), nil
return strings.Join(textParts, "|"), nil
}

func sanitizeSocks5URL(u *url.URL) (string, error) {
Expand Down
Loading
Loading