Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(x): make config support PacketListeners and make dependencies explicit and decoupled #304

Merged
merged 31 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 54 additions & 214 deletions x/configurl/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,108 +15,84 @@
package configurl

import (
"context"
"errors"
"fmt"
"net/url"
"strconv"
"strings"

"github.com/Jigsaw-Code/outline-sdk/transport"
"github.com/Jigsaw-Code/outline-sdk/transport/tlsfrag"
)

// ConfigToDialer enables the creation of stream and packet dialers based on a config. The config is
// extensible by registering wrappers for config subtypes.
type ConfigToDialer struct {
// Base StreamDialer to create direct stream connections. If you need direct stream connections, this must not be nil.
BaseStreamDialer transport.StreamDialer
// Base PacketDialer to create direct packet connections. If you need direct packet connections, this must not be nil.
BasePacketDialer transport.PacketDialer
sdBuilders map[string]NewStreamDialerFunc
pdBuilders map[string]NewPacketDialerFunc
// Config is a pre-parsed generic config created from pipe-separated URLs.
type Config struct {
URL url.URL
BaseConfig *Config
}

// NewStreamDialerFunc wraps a Dialer based on the wrapConfig. The innerSD and innerPD functions can provide a base Stream and Packet Dialers if needed.
type NewStreamDialerFunc func(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), wrapConfig *url.URL) (transport.StreamDialer, error)

// NewPacketDialerFunc wraps a Dialer based on the wrapConfig. The innerSD and innerPD functions can provide a base Stream and Packet Dialers if needed.
type NewPacketDialerFunc func(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), wrapConfig *url.URL) (transport.PacketDialer, error)

// NewDefaultConfigToDialer creates a [ConfigToDialer] with a set of default wrappers already registered.
func NewDefaultConfigToDialer() *ConfigToDialer {
p := new(ConfigToDialer)
p.BaseStreamDialer = &transport.TCPDialer{}
p.BasePacketDialer = &transport.UDPDialer{}

// Please keep the list in alphabetical order.
p.RegisterStreamDialerType("do53", wrapStreamDialerWithDO53)

p.RegisterStreamDialerType("doh", wrapStreamDialerWithDOH)

p.RegisterStreamDialerType("override", wrapStreamDialerWithOverride)
p.RegisterPacketDialerType("override", wrapPacketDialerWithOverride)

p.RegisterStreamDialerType("socks5", wrapStreamDialerWithSOCKS5)
p.RegisterPacketDialerType("socks5", wrapPacketDialerWithSOCKS5)

p.RegisterStreamDialerType("split", wrapStreamDialerWithSplit)
// BuildFunc is a function that creates an instance of ObjectType given a [Config].
type BuildFunc[ObjectType any] func(ctx context.Context, config *Config) (ObjectType, error)

p.RegisterStreamDialerType("ss", wrapStreamDialerWithShadowsocks)
p.RegisterPacketDialerType("ss", wrapPacketDialerWithShadowsocks)

p.RegisterStreamDialerType("tls", wrapStreamDialerWithTLS)
// TypeRegistry registers config types.
type TypeRegistry[ObjectType any] interface {
RegisterType(subtype string, newInstance BuildFunc[ObjectType])
}

p.RegisterStreamDialerType("tlsfrag", func(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), wrapConfig *url.URL) (transport.StreamDialer, error) {
sd, err := innerSD()
if err != nil {
return nil, err
}
lenStr := wrapConfig.Opaque
fixedLen, err := strconv.Atoi(lenStr)
if err != nil {
return nil, fmt.Errorf("invalid tlsfrag option: %v. It should be in tlsfrag:<number> format", lenStr)
}
return tlsfrag.NewFixedLenStreamDialer(sd, fixedLen)
})
// ExtensibleProvider creates instances of ObjectType in a way that can be extended via its [TypeRegistry] interface.
type ExtensibleProvider[ObjectType comparable] struct {
// Instance to return when config is nil.
BaseInstance ObjectType
builders map[string]BuildFunc[ObjectType]
}

p.RegisterStreamDialerType("ws", wrapStreamDialerWithWebSocket)
p.RegisterPacketDialerType("ws", wrapPacketDialerWithWebSocket)
var (
_ BuildFunc[any] = (*ExtensibleProvider[any])(nil).NewInstance
_ TypeRegistry[any] = (*ExtensibleProvider[any])(nil)
)

return p
// NewExtensibleProvider creates an [ExtensibleProvider] with the given base instance.
func NewExtensibleProvider[ObjectType comparable](baseInstance ObjectType) ExtensibleProvider[ObjectType] {
return ExtensibleProvider[ObjectType]{
BaseInstance: baseInstance,
builders: make(map[string]BuildFunc[ObjectType]),
}
}

// RegisterStreamDialerType will register a wrapper for stream dialers under the given subtype.
func (p *ConfigToDialer) RegisterStreamDialerType(subtype string, newDialer NewStreamDialerFunc) error {
if p.sdBuilders == nil {
p.sdBuilders = make(map[string]NewStreamDialerFunc)
func (p *ExtensibleProvider[ObjectType]) ensureBuildersMap() map[string]BuildFunc[ObjectType] {
if p.builders == nil {
p.builders = make(map[string]BuildFunc[ObjectType])
}
return p.builders
}

if _, found := p.sdBuilders[subtype]; found {
return fmt.Errorf("config parser %v for StreamDialer added twice", subtype)
}
p.sdBuilders[subtype] = newDialer
return nil
// RegisterType will register a factory for the given subtype.
func (p *ExtensibleProvider[ObjectType]) RegisterType(subtype string, newInstance BuildFunc[ObjectType]) {
p.ensureBuildersMap()[subtype] = newInstance
}

// RegisterPacketDialerType will register a wrapper for packet dialers under the given subtype.
func (p *ConfigToDialer) RegisterPacketDialerType(subtype string, newDialer NewPacketDialerFunc) error {
if p.pdBuilders == nil {
p.pdBuilders = make(map[string]NewPacketDialerFunc)
// NewInstance creates a new instance of ObjectType according to the config.
func (p *ExtensibleProvider[ObjectType]) NewInstance(ctx context.Context, config *Config) (ObjectType, error) {
var zero ObjectType
if config == nil {
if p.BaseInstance == zero {
return zero, errors.New("base instance is not configured")
}
return p.BaseInstance, nil
}

if _, found := p.pdBuilders[subtype]; found {
return fmt.Errorf("config parser %v for StreamDialer added twice", subtype)
newInstance, ok := p.ensureBuildersMap()[config.URL.Scheme]
if !ok {
return zero, fmt.Errorf("config type '%v' is not registered", config.URL.Scheme)
}
p.pdBuilders[subtype] = newDialer
return nil
return newInstance(ctx, config)
}

func parseConfig(configText string) ([]*url.URL, error) {
// ParseConfig will parse a config given as a string and return the structured [Config].
func ParseConfig(configText string) (*Config, error) {
parts := strings.Split(strings.TrimSpace(configText), "|")
if len(parts) == 1 && parts[0] == "" {
return []*url.URL{}, nil
return nil, nil
}
urls := make([]*url.URL, 0, len(parts))

var config *Config = nil
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
Expand All @@ -130,143 +106,7 @@ func parseConfig(configText string) ([]*url.URL, error) {
if err != nil {
return nil, fmt.Errorf("part is not a valid URL: %w", err)
}
urls = append(urls, url)
}
return urls, nil
}

// NewStreamDialer creates a [Dialer] according to transportConfig, using dialer as the
// base [Dialer]. The given dialer must not be nil.
func (p *ConfigToDialer) NewStreamDialer(transportConfig string) (transport.StreamDialer, error) {
parts, err := parseConfig(transportConfig)
if err != nil {
return nil, err
}
return p.newStreamDialer(parts)
}

// NewPacketDialer creates a [Dialer] according to transportConfig, using dialer as the
// base [Dialer]. The given dialer must not be nil.
func (p *ConfigToDialer) NewPacketDialer(transportConfig string) (transport.PacketDialer, error) {
parts, err := parseConfig(transportConfig)
if err != nil {
return nil, err
}
return p.newPacketDialer(parts)
}

func (p *ConfigToDialer) newStreamDialer(configParts []*url.URL) (transport.StreamDialer, error) {
if len(configParts) == 0 {
if p.BaseStreamDialer == nil {
return nil, fmt.Errorf("base StreamDialer must not be nil")
}
return p.BaseStreamDialer, nil
}
thisURL := configParts[len(configParts)-1]
innerConfig := configParts[:len(configParts)-1]
newDialer, ok := p.sdBuilders[thisURL.Scheme]
if !ok {
return nil, fmt.Errorf("config scheme '%v' is not supported for Stream Dialers", thisURL.Scheme)
}
newSD := func() (transport.StreamDialer, error) {
return p.newStreamDialer(innerConfig)
}
newPD := func() (transport.PacketDialer, error) {
return p.newPacketDialer(innerConfig)
}
return newDialer(newSD, newPD, thisURL)
}

func (p *ConfigToDialer) newPacketDialer(configParts []*url.URL) (transport.PacketDialer, error) {
if len(configParts) == 0 {
if p.BasePacketDialer == nil {
return nil, fmt.Errorf("base PacketDialer must not be nil")
}
return p.BasePacketDialer, nil
}
thisURL := configParts[len(configParts)-1]
innerConfig := configParts[:len(configParts)-1]
newDialer, ok := p.pdBuilders[thisURL.Scheme]
if !ok {
return nil, fmt.Errorf("config scheme '%v' is not supported for Packet Dialers", thisURL.Scheme)
}
newSD := func() (transport.StreamDialer, error) {
return p.newStreamDialer(innerConfig)
}
newPD := func() (transport.PacketDialer, error) {
return p.newPacketDialer(innerConfig)
}
return newDialer(newSD, newPD, thisURL)
}

// NewpacketListener creates a new [transport.PacketListener] according to the given config,
// the config must contain only one "ss://" segment.
// TODO: make NewPacketListener configurable.
func NewPacketListener(transportConfig string) (transport.PacketListener, error) {
parts, err := parseConfig(transportConfig)
if err != nil {
return nil, err
}
if len(parts) == 0 {
return nil, errors.New("config is required")
}
if len(parts) > 1 {
return nil, errors.New("multi-part config is not supported")
}

url := parts[0]
// Please keep scheme list sorted.
switch strings.ToLower(url.Scheme) {
case "ss":
// TODO: support nested dialer, the last part must be "ss://"
return newShadowsocksPacketListenerFromURL(url)
default:
return nil, fmt.Errorf("config scheme '%v' is not supported", url.Scheme)
}
}

func SanitizeConfig(transportConfig string) (string, error) {
parts, err := parseConfig(transportConfig)
if err != nil {
return "", err
}

// Do nothing if the config is empty
if len(parts) == 0 {
return "", nil
}

// Iterate through each part
textParts := make([]string, len(parts))
for i, u := range parts {
scheme := strings.ToLower(u.Scheme)
switch scheme {
case "ss":
textParts[i], err = sanitizeShadowsocksURL(u)
if err != nil {
return "", err
}
case "socks5":
textParts[i], err = sanitizeSocks5URL(u)
if err != nil {
return "", err
}
case "override", "split", "tls", "tlsfrag":
// No sanitization needed
textParts[i] = u.String()
default:
textParts[i] = scheme + "://UNKNOWN"
}
}
// Join the parts back into a string
return strings.Join(textParts, "|"), nil
}

func sanitizeSocks5URL(u *url.URL) (string, error) {
const redactedPlaceholder = "REDACTED"
if u.User != nil {
u.User = url.User(redactedPlaceholder)
return u.String(), nil
config = &Config{URL: *url, BaseConfig: config}
}
return u.String(), nil
return config, nil
}
63 changes: 44 additions & 19 deletions x/configurl/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,46 @@ import (
"golang.org/x/net/dns/dnsmessage"
)

func wrapStreamDialerWithDO53(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) {
sd, err := innerSD()
if err != nil {
return nil, err
}
pd, err := innerPD()
if err != nil {
return nil, err
}
query := configURL.Opaque
func registerDO53StreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer], newPD BuildFunc[transport.PacketDialer]) {
r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) {
if config == nil {
return nil, fmt.Errorf("emtpy do53 config")
}
sd, err := newSD(ctx, config.BaseConfig)
if err != nil {
return nil, err
}
pd, err := newPD(ctx, config.BaseConfig)
if err != nil {
return nil, err
}
resolver, err := newDO53Resolver(config.URL, sd, pd)
if err != nil {
return nil, err
}
return dns.NewStreamDialer(resolver, sd)
})
}

func registerDOHStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) {
r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) {
if config == nil {
return nil, fmt.Errorf("emtpy doh config")
}
sd, err := newSD(ctx, config.BaseConfig)
if err != nil {
return nil, err
}
resolver, err := newDOHResolver(config.URL, sd)
if err != nil {
return nil, err
}
return dns.NewStreamDialer(resolver, sd)
})
}

func newDO53Resolver(config url.URL, sd transport.StreamDialer, pd transport.PacketDialer) (dns.Resolver, error) {
query := config.Opaque
values, err := url.ParseQuery(query)
if err != nil {
return nil, err
Expand Down Expand Up @@ -75,19 +105,15 @@ func wrapStreamDialerWithDO53(innerSD func() (transport.StreamDialer, error), in
// See https://datatracker.ietf.org/doc/html/rfc1123#page-75.
return tcpResolver.Query(ctx, q)
})
return dns.NewStreamDialer(resolver, sd)
return resolver, nil
}

func wrapStreamDialerWithDOH(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) {
query := configURL.Opaque
func newDOHResolver(config url.URL, sd transport.StreamDialer) (dns.Resolver, error) {
query := config.Opaque
values, err := url.ParseQuery(query)
if err != nil {
return nil, err
}
sd, err := innerSD()
if err != nil {
return nil, err
}

var name, address string
for key, values := range values {
Expand Down Expand Up @@ -119,6 +145,5 @@ func wrapStreamDialerWithDOH(innerSD func() (transport.StreamDialer, error), inn
port = "443"
}
dohURL := url.URL{Scheme: "https", Host: net.JoinHostPort(name, port), Path: "/dns-query"}
resolver := dns.NewHTTPSResolver(sd, address, dohURL.String())
return dns.NewStreamDialer(resolver, sd)
return dns.NewHTTPSResolver(sd, address, dohURL.String()), nil
}
Loading
Loading