Skip to content

Commit

Permalink
Merge pull request #311 from Jigsaw-Code/fortuna-split-1
Browse files Browse the repository at this point in the history
feat: advanced split options [BREAKING]
  • Loading branch information
fortuna authored Nov 7, 2024
2 parents ed096db + cae6dd4 commit dc0a0b0
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 41 deletions.
18 changes: 11 additions & 7 deletions transport/split/stream_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,24 @@ import (
"github.com/Jigsaw-Code/outline-sdk/transport"
)

// splitDialer is a [transport.StreamDialer] that implements the split strategy.
// Use [NewStreamDialer] to create new instances.
type splitDialer struct {
dialer transport.StreamDialer
splitPoint int64
dialer transport.StreamDialer
nextSplit SplitIterator
}

var _ transport.StreamDialer = (*splitDialer)(nil)

// NewStreamDialer creates a [transport.StreamDialer] that splits the outgoing stream after writing "prefixBytes" bytes
// using [SplitWriter].
func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int64) (transport.StreamDialer, error) {
// NewStreamDialer creates a [transport.StreamDialer] that splits the outgoing stream according to nextSplit.
func NewStreamDialer(dialer transport.StreamDialer, nextSplit SplitIterator) (transport.StreamDialer, error) {
if dialer == nil {
return nil, errors.New("argument dialer must not be nil")
}
return &splitDialer{dialer: dialer, splitPoint: prefixBytes}, nil
if nextSplit == nil {
return nil, errors.New("argument nextSplit must not be nil")
}
return &splitDialer{dialer: dialer, nextSplit: nextSplit}, nil
}

// DialStream implements [transport.StreamDialer].DialStream.
Expand All @@ -43,5 +47,5 @@ func (d *splitDialer) DialStream(ctx context.Context, remoteAddr string) (transp
if err != nil {
return nil, err
}
return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.splitPoint)), nil
return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.nextSplit)), nil
}
108 changes: 92 additions & 16 deletions transport/split/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ import (
)

type splitWriter struct {
writer io.Writer
prefixBytes int64
writer io.Writer
// Bytes until the next split. This must always be > 0, unless splits are done.
nextSplitBytes int64
nextSegmentLength func() int64
}

var _ io.Writer = (*splitWriter)(nil)
Expand All @@ -32,36 +34,110 @@ type splitWriterReaderFrom struct {

var _ io.ReaderFrom = (*splitWriterReaderFrom)(nil)

// NewWriter creates a [io.Writer] that ensures the byte sequence is split at prefixBytes.
// A write will end right after byte index prefixBytes - 1, before a write starting at byte index prefixBytes.
// For example, if you have a write of [0123456789] and prefixBytes = 3, you will get writes [012] and [3456789].
// If the input writer is a [io.ReaderFrom], the output writer will be too.
func NewWriter(writer io.Writer, prefixBytes int64) io.Writer {
sw := &splitWriter{writer, prefixBytes}
// SplitIterator is a function that returns how many bytes until the next split point, or zero if there are no more splits to do.
type SplitIterator func() int64

// NewFixedSplitIterator is a helper function that returns a [SplitIterator] that returns the input number once, followed by zero.
// This is helpful for when you want to split the stream once in a fixed position.
func NewFixedSplitIterator(n int64) SplitIterator {
return func() int64 {
next := n
n = 0
return next
}
}

// RepeatedSplit represents a split sequence of count segments with bytes length.
type RepeatedSplit struct {
Count int
Bytes int64
}

// NewRepeatedSplitIterator is a helper function that returns a [SplitIterator] that returns split points according to splits.
// The splits input represents pairs of (count, bytes), meaning a sequence of count splits with bytes length.
// This is helpful for when you want to split the stream repeatedly at different positions and lengths.
func NewRepeatedSplitIterator(splits ...RepeatedSplit) SplitIterator {
// Make sure we don't edit the original slice.
cleanSplits := make([]RepeatedSplit, 0, len(splits))
// Remove no-op splits.
for _, split := range splits {
if split.Count > 0 && split.Bytes > 0 {
cleanSplits = append(cleanSplits, split)
}
}
return func() int64 {
if len(cleanSplits) == 0 {
return 0
}
next := cleanSplits[0].Bytes
cleanSplits[0].Count -= 1
if cleanSplits[0].Count == 0 {
cleanSplits = cleanSplits[1:]
}
return next
}
}

// NewWriter creates a split Writer that calls the nextSegmentLength [SplitIterator] to determine the number bytes until the next split
// point until it returns zero.
func NewWriter(writer io.Writer, nextSegmentLength SplitIterator) io.Writer {
sw := &splitWriter{writer: writer, nextSegmentLength: nextSegmentLength}
sw.nextSplitBytes = nextSegmentLength()
if rf, ok := writer.(io.ReaderFrom); ok {
return &splitWriterReaderFrom{sw, rf}
}
return sw
}

// ReadFrom implements io.ReaderFrom.
func (w *splitWriterReaderFrom) ReadFrom(source io.Reader) (int64, error) {
reader := io.MultiReader(io.LimitReader(source, w.prefixBytes), source)
written, err := w.rf.ReadFrom(reader)
w.prefixBytes -= written
var written int64
for w.nextSplitBytes > 0 {
expectedBytes := w.nextSplitBytes
n, err := w.rf.ReadFrom(io.LimitReader(source, expectedBytes))
written += n
w.advance(n)
if err != nil {
return written, err
}
if n < expectedBytes {
// Source is done before the split happened. Return.
return written, err
}
}
n, err := w.rf.ReadFrom(source)
written += n
w.advance(n)
return written, err
}

func (w *splitWriter) advance(n int64) {
if w.nextSplitBytes == 0 {
// Done with splits: return.
return
}
w.nextSplitBytes -= int64(n)
if w.nextSplitBytes > 0 {
return
}
// Split done, set up the next split.
w.nextSplitBytes = w.nextSegmentLength()
}

// Write implements io.Writer.
func (w *splitWriter) Write(data []byte) (written int, err error) {
if 0 < w.prefixBytes && w.prefixBytes < int64(len(data)) {
written, err = w.writer.Write(data[:w.prefixBytes])
w.prefixBytes -= int64(written)
for 0 < w.nextSplitBytes && w.nextSplitBytes < int64(len(data)) {
dataToSend := data[:w.nextSplitBytes]
n, err := w.writer.Write(dataToSend)
written += n
w.advance(int64(n))
if err != nil {
return written, err
}
data = data[written:]
data = data[n:]
}
n, err := w.writer.Write(data)
written += n
w.prefixBytes -= int64(n)
w.advance(int64(n))
return written, err
}
88 changes: 80 additions & 8 deletions transport/split/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,52 @@ func (w *collectWrites) Write(data []byte) (int, error) {

func TestWrite_Split(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 3)
splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(3))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
require.Equal(t, [][]byte{[]byte("Req"), []byte("uest")}, innerWriter.writes)
}

func TestWrite_SplitZero(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 0}, RepeatedSplit{0, 1}, RepeatedSplit{10, 0}, RepeatedSplit{0, 2}))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
require.Equal(t, [][]byte{[]byte("Request")}, innerWriter.writes)
}

func TestWrite_SplitZeroLong(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 0}, RepeatedSplit{1_000_000_000_000_000_000, 0}))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
require.Equal(t, [][]byte{[]byte("Request")}, innerWriter.writes)
}

func TestWrite_SplitZeroPrefix(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 0}, RepeatedSplit{3, 2}))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
require.Equal(t, [][]byte{[]byte("Re"), []byte("qu"), []byte("es"), []byte("t")}, innerWriter.writes)
}

func TestWrite_SplitMulti(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{3, 2}, RepeatedSplit{2, 3}))
n, err := splitWriter.Write([]byte("RequestRequestRequest"))
require.NoError(t, err)
require.Equal(t, 21, n)
require.Equal(t, [][]byte{[]byte("R"), []byte("eq"), []byte("ue"), []byte("st"), []byte("Req"), []byte("ues"), []byte("tRequest")}, innerWriter.writes)
}

func TestWrite_ShortWrite(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 10)
splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(10))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
Expand All @@ -56,7 +92,7 @@ func TestWrite_ShortWrite(t *testing.T) {

func TestWrite_Zero(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 0)
splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(0))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
Expand All @@ -65,7 +101,7 @@ func TestWrite_Zero(t *testing.T) {

func TestWrite_NeedsTwoWrites(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 5)
splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(5))
n, err := splitWriter.Write([]byte("Re"))
require.NoError(t, err)
require.Equal(t, 2, n)
Expand All @@ -77,13 +113,37 @@ func TestWrite_NeedsTwoWrites(t *testing.T) {

func TestWrite_Compound(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(NewWriter(&innerWriter, 4), 1)
splitWriter := NewWriter(NewWriter(&innerWriter, NewFixedSplitIterator(4)), NewFixedSplitIterator(1))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
require.Equal(t, [][]byte{[]byte("R"), []byte("equ"), []byte("est")}, innerWriter.writes)
}

func TestWrite_RepeatNumber3_SkipBytes5(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{3, 5}))
n, err := splitWriter.Write([]byte("RequestRequestRequest."))
require.NoError(t, err)
require.Equal(t, 7*3+1, n)
require.Equal(t, [][]byte{
[]byte("R"), // prefix
[]byte("eques"), // split 1
[]byte("tRequ"), // split 2
[]byte("estRe"), // split 3
[]byte("quest."), // tail
}, innerWriter.writes)
}

func TestWrite_RepeatNumber3_SkipBytes0(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{0, 3}))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
require.Equal(t, [][]byte{[]byte("R"), []byte("equest")}, innerWriter.writes)
}

// collectReader is a [io.Reader] that appends each Read from the Reader to the reads slice.
type collectReader struct {
io.Reader
Expand All @@ -101,7 +161,7 @@ func (r *collectReader) Read(buf []byte) (int, error) {
}

func TestReadFrom(t *testing.T) {
splitWriter := NewWriter(&bytes.Buffer{}, 3)
splitWriter := NewWriter(&bytes.Buffer{}, NewFixedSplitIterator(3))
rf, ok := splitWriter.(io.ReaderFrom)
require.True(t, ok)

Expand All @@ -118,8 +178,20 @@ func TestReadFrom(t *testing.T) {
require.Equal(t, [][]byte{[]byte("Request2")}, cr.reads)
}

func TestReadFrom_Multi(t *testing.T) {
splitWriter := NewWriter(&bytes.Buffer{}, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{3, 2}, RepeatedSplit{2, 3}))
rf, ok := splitWriter.(io.ReaderFrom)
require.True(t, ok)

cr := &collectReader{Reader: bytes.NewReader([]byte("RequestRequestRequest"))}
n, err := rf.ReadFrom(cr)
require.NoError(t, err)
require.Equal(t, int64(21), n)
require.Equal(t, [][]byte{[]byte("R"), []byte("eq"), []byte("ue"), []byte("st"), []byte("Req"), []byte("ues"), []byte("tRequest")}, cr.reads)
}

func TestReadFrom_ShortRead(t *testing.T) {
splitWriter := NewWriter(&bytes.Buffer{}, 10)
splitWriter := NewWriter(&bytes.Buffer{}, NewFixedSplitIterator(10))
rf, ok := splitWriter.(io.ReaderFrom)
require.True(t, ok)
cr := &collectReader{Reader: bytes.NewReader([]byte("Request1"))}
Expand All @@ -138,7 +210,7 @@ func TestReadFrom_ShortRead(t *testing.T) {
func BenchmarkReadFrom(b *testing.B) {
for n := 0; n < b.N; n++ {
reader := bytes.NewReader(make([]byte, n))
writer := NewWriter(io.Discard, 10)
writer := NewWriter(io.Discard, NewFixedSplitIterator(10))
rf, ok := writer.(io.ReaderFrom)
require.True(b, ok)
_, err := rf.ReadFrom(reader)
Expand Down
4 changes: 2 additions & 2 deletions x/configurl/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ These strategies manipulate packets to bypass SNI-based blocking.
Stream split transport (streams only, package [github.com/Jigsaw-Code/outline-sdk/transport/split])
It takes the length of the prefix. The stream will be split when PREFIX_LENGTH bytes are first written.
It takes a list of count*length pairs meaning splitting the sequence in count segments of the given length. If you omit "[COUNT]*", it's assumed to be 1.
split:[PREFIX_LENGTH]
split:[COUNT1]*[LENGTH1],[COUNT2]*[LENGTH2],...
TLS fragmentation (streams only, package [github.com/Jigsaw-Code/outline-sdk/transport/tlsfrag]).
Expand Down
33 changes: 28 additions & 5 deletions x/configurl/split.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"fmt"
"strconv"
"strings"

"github.com/Jigsaw-Code/outline-sdk/transport"
"github.com/Jigsaw-Code/outline-sdk/transport/split"
Expand All @@ -29,11 +30,33 @@ func registerSplitStreamDialer(r TypeRegistry[transport.StreamDialer], typeID st
if err != nil {
return nil, err
}
prefixBytesStr := config.URL.Opaque
prefixBytes, err := strconv.Atoi(prefixBytesStr)
if err != nil {
return nil, fmt.Errorf("prefixBytes is not a number: %v. Split config should be in split:<number> format", prefixBytesStr)
configText := config.URL.Opaque
splits := make([]split.RepeatedSplit, 0)
for _, part := range strings.Split(configText, ",") {
var count int
var bytes int64
subparts := strings.Split(strings.TrimSpace(part), "*")
switch len(subparts) {
case 1:
count = 1
bytes, err = strconv.ParseInt(subparts[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("bytes is not a number: %v", subparts[0])
}
case 2:
count, err = strconv.Atoi(subparts[0])
if err != nil {
return nil, fmt.Errorf("count is not a number: %v", subparts[0])
}
bytes, err = strconv.ParseInt(subparts[1], 10, 64)
if err != nil {
return nil, fmt.Errorf("bytes is not a number: %v", subparts[1])
}
default:
return nil, fmt.Errorf("split format must be a comma-separated list of '[$COUNT*]$BYTES' (e.g. '100,5*2'). Got %v", part)
}
splits = append(splits, split.RepeatedSplit{Count: count, Bytes: bytes})
}
return split.NewStreamDialer(sd, int64(prefixBytes))
return split.NewStreamDialer(sd, split.NewRepeatedSplitIterator(splits...))
})
}
2 changes: 1 addition & 1 deletion x/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module github.com/Jigsaw-Code/outline-sdk/x
go 1.22

require (
github.com/Jigsaw-Code/outline-sdk v0.0.17
github.com/Jigsaw-Code/outline-sdk v0.0.18-0.20241106233708-faffebb12629
// Use github.com/Psiphon-Labs/psiphon-tunnel-core@staging-client as per
// https://github.com/Psiphon-Labs/psiphon-tunnel-core/?tab=readme-ov-file#using-psiphon-with-go-modules
github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240619172145-03cade11f647
Expand Down
Loading

0 comments on commit dc0a0b0

Please sign in to comment.