diff --git a/x/configurl/fake.go b/x/configurl/fake.go new file mode 100644 index 00000000..66077a87 --- /dev/null +++ b/x/configurl/fake.go @@ -0,0 +1,80 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package configurl + +import ( + "context" + "fmt" + "github.com/Jigsaw-Code/outline-sdk/x/fake" + "strconv" + + "github.com/Jigsaw-Code/outline-sdk/transport" +) + +// tls wikipedia request +var tlsData = [517]byte{ + 0x16, 0x03, 0x01, 0x02, 0x00, 0x01, 0x00, 0x01, 0xfc, 0x03, 0x03, 0x03, 0x5f, + 0x6f, 0x2c, 0xed, 0x13, 0x22, 0xf8, 0xdc, 0xb2, 0xf2, 0x60, 0x48, 0x2d, 0x72, + 0x66, 0x6f, 0x57, 0xdd, 0x13, 0x9d, 0x1b, 0x37, 0xdc, 0xfa, 0x36, 0x2e, 0xba, + 0xf9, 0x92, 0x99, 0x3a, 0x20, 0xf9, 0xdf, 0x0c, 0x2e, 0x8a, 0x55, 0x89, 0x82, + 0x31, 0x63, 0x1a, 0xef, 0xa8, 0xbe, 0x08, 0x58, 0xa7, 0xa3, 0x5a, 0x18, 0xd3, + 0x96, 0x5f, 0x04, 0x5c, 0xb4, 0x62, 0xaf, 0x89, 0xd7, 0x0f, 0x8b, 0x00, 0x3e, + 0x13, 0x02, 0x13, 0x03, 0x13, 0x01, 0xc0, 0x2c, 0xc0, 0x30, 0x00, 0x9f, 0xcc, + 0xa9, 0xcc, 0xa8, 0xcc, 0xaa, 0xc0, 0x2b, 0xc0, 0x2f, 0x00, 0x9e, 0xc0, 0x24, + 0xc0, 0x28, 0x00, 0x6b, 0xc0, 0x23, 0xc0, 0x27, 0x00, 0x67, 0xc0, 0x0a, 0xc0, + 0x14, 0x00, 0x39, 0xc0, 0x09, 0xc0, 0x13, 0x00, 0x33, 0x00, 0x9d, 0x00, 0x9c, + 0x00, 0x3d, 0x00, 0x3c, 0x00, 0x35, 0x00, 0x2f, 0x00, 0xff, 0x01, 0x00, 0x01, + 0x75, 0x00, 0x00, 0x00, 0x16, 0x00, 0x14, 0x00, 0x00, 0x11, 0x77, 0x77, 0x77, + 0x2e, 0x77, 0x69, 0x6b, 0x69, 0x70, 0x65, 0x64, 0x69, 0x61, 0x2e, 0x6f, 0x72, + 0x67, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02, 0x00, 0x0a, 0x00, 0x16, + 0x00, 0x14, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, 0x00, 0x19, 0x00, 0x18, 0x01, + 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x10, 0x00, 0x0e, + 0x00, 0x0c, 0x02, 0x68, 0x32, 0x08, 0x68, 0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, + 0x31, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x31, 0x00, 0x00, + 0x00, 0x0d, 0x00, 0x2a, 0x00, 0x28, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, + 0x07, 0x08, 0x08, 0x08, 0x09, 0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, + 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, 0x06, 0x01, 0x03, 0x03, 0x03, 0x01, 0x03, + 0x02, 0x04, 0x02, 0x05, 0x02, 0x06, 0x02, 0x00, 0x2b, 0x00, 0x09, 0x08, 0x03, + 0x04, 0x03, 0x03, 0x03, 0x02, 0x03, 0x01, 0x00, 0x2d, 0x00, 0x02, 0x01, 0x01, + 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x11, 0x8c, 0xb8, + 0x8c, 0xe8, 0x8a, 0x08, 0x90, 0x1e, 0xee, 0x19, 0xd9, 0xdd, 0xe8, 0xd4, 0x06, + 0xb1, 0xd1, 0xe2, 0xab, 0xe0, 0x16, 0x63, 0xd6, 0xdc, 0xda, 0x84, 0xa4, 0xb8, + 0x4b, 0xfb, 0x0e, 0x00, 0x15, 0x00, 0xac, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +} +var httpData = []byte("GET / HTTP/1.1\r\nHost: www.wikipedia.org\r\n\r\n") +var udpData [64]byte + +func registerFakeStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + prefixBytesStr := config.URL.Opaque + prefixBytes, err := strconv.Atoi(prefixBytesStr) + if err != nil { + return nil, fmt.Errorf("prefixBytes is not a number: %v. Fake config should be in fake: format", prefixBytesStr) + } + // TODO: Read fake data from the CLI or use a default value (depending on the protocol). + var fakeData []byte + // TODO: Read fake offset from the CLI + var fakeOffset int64 = 0 + // TODO: Read fake TTL from the CLI or use a default value (8) + var fakeTtl int = 8 + // TODO: Read md5 signature from the CLI or use a default value (false). + var md5Sig bool = false + return fake.NewStreamDialer(sd, int64(prefixBytes), fakeData, fakeOffset, fakeTtl, md5Sig) + }) +} diff --git a/x/configurl/module.go b/x/configurl/module.go index 83e14b89..2ebac2b5 100644 --- a/x/configurl/module.go +++ b/x/configurl/module.go @@ -53,6 +53,7 @@ func RegisterDefaultProviders(c *ProviderContainer) *ProviderContainer { registerSOCKS5PacketListener(&c.PacketListeners, "socks5", c.StreamDialers.NewInstance, c.PacketDialers.NewInstance) registerSplitStreamDialer(&c.StreamDialers, "split", c.StreamDialers.NewInstance) + registerFakeStreamDialer(&c.StreamDialers, "fake", c.StreamDialers.NewInstance) registerShadowsocksStreamDialer(&c.StreamDialers, "ss", c.StreamDialers.NewInstance) registerShadowsocksPacketDialer(&c.PacketDialers, "ss", c.PacketDialers.NewInstance) diff --git a/x/fake/signature/signature.go b/x/fake/signature/signature.go new file mode 100644 index 00000000..0abc3166 --- /dev/null +++ b/x/fake/signature/signature.go @@ -0,0 +1,79 @@ +package signature + +import ( + "fmt" + "golang.org/x/sys/unix" + "net" + "unsafe" +) + +const socketFlag = 14 + +type signature struct { + Addr [16]byte + Len uint16 + Flags uint16 + Key [80]byte +} + +func Add(conn *net.TCPConn, remoteAddr string, data string) error { + ip := net.ParseIP(remoteAddr) + if ip == nil { + return fmt.Errorf("invalid remote IP address: %s", remoteAddr) + } + + address, err := ip.To16().MarshalText() + if err != nil { + return fmt.Errorf("failed to marshal IP address: %w", err) + } + + key := []byte(data) + + sig := signature{ + Addr: [16]byte(address), + Len: uint16(len(data)), + Key: [80]byte(key), + } + + if err := setOption(conn, sig); err != nil { + return fmt.Errorf("failed to set socket option: %w", err) + } + + return nil +} + +func setOption(conn *net.TCPConn, md5sig signature) error { + file, err := conn.File() + if err != nil { + return fmt.Errorf("failed to get file descriptor: %w", err) + } + defer file.Close() + + size := unsafe.Sizeof(md5sig) + buffer := (*[unsafe.Sizeof(md5sig)]byte)(unsafe.Pointer(&md5sig))[:size] + fd := int(file.Fd()) + + err = unix.SetsockoptString(fd, unix.IPPROTO_TCP, socketFlag, string(buffer)) + if err != nil { + return fmt.Errorf("failed to set TCP_MD5SIG: %w", err) + } + + return nil +} + +func Remove(conn *net.TCPConn) error { + file, err := conn.File() + if err != nil { + return fmt.Errorf("failed to get underlying file descriptor: %w", err) + } + defer file.Close() + + fd := int(file.Fd()) + + err = unix.SetsockoptString(fd, unix.IPPROTO_TCP, socketFlag, "") + if err != nil { + return fmt.Errorf("failed to clear TCP_MD5SIG: %w", err) + } + + return nil +} diff --git a/x/fake/stream_dialer.go b/x/fake/stream_dialer.go new file mode 100644 index 00000000..e3ceb81c --- /dev/null +++ b/x/fake/stream_dialer.go @@ -0,0 +1,74 @@ +// Copyright 2023 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fake + +import ( + "context" + "errors" + "fmt" + "github.com/Jigsaw-Code/outline-sdk/x/fake/signature" + "net" + + "github.com/Jigsaw-Code/outline-sdk/transport" +) + +type fakeDialer struct { + dialer transport.StreamDialer + splitPoint int64 + fakeData []byte + fakeOffset int64 + fakeTtl int + md5Sig bool +} + +var _ transport.StreamDialer = (*fakeDialer)(nil) + +// NewStreamDialer creates a [transport.StreamDialer] that writes "fakeData" in the beginning of the stream and +// then splits the outgoing stream after writing "fakeBytes" bytes using [FakeWriter]. +func NewStreamDialer( + dialer transport.StreamDialer, + prefixBytes int64, + fakeData []byte, + fakeOffset int64, + fakeTtl int, + md5Sig bool, +) (transport.StreamDialer, error) { + if dialer == nil { + return nil, errors.New("argument dialer must not be nil") + } + return &fakeDialer{ + dialer: dialer, + splitPoint: prefixBytes, + fakeData: fakeData, + fakeOffset: fakeOffset, + fakeTtl: fakeTtl, + md5Sig: md5Sig, + }, nil +} + +// DialStream implements [transport.StreamDialer].DialStream. +func (d *fakeDialer) DialStream(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { + innerConn, err := d.dialer.DialStream(ctx, remoteAddr) + if err != nil { + return nil, err + } + if tcpInnerConn, isTcp := innerConn.(*net.TCPConn); isTcp && d.md5Sig { + err := signature.Add(tcpInnerConn, remoteAddr, tcpInnerConn.RemoteAddr().String()) + if err != nil { + return nil, fmt.Errorf("failed to add MD5 signature: %w", err) + } + } + return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.splitPoint, d.fakeData, d.fakeOffset, d.fakeTtl)), nil +} diff --git a/x/fake/writer.go b/x/fake/writer.go new file mode 100644 index 00000000..d72539bb --- /dev/null +++ b/x/fake/writer.go @@ -0,0 +1,105 @@ +package fake + +import ( + "bytes" + "fmt" + "github.com/Jigsaw-Code/outline-sdk/x/ttl" + "io" + "net" +) + +type fakeWriter struct { + writer io.Writer + fakeBytes int64 + fakeData []byte + fakeOffset int64 + ttl int +} + +var _ io.Writer = (*fakeWriter)(nil) + +type fakeWriterReaderFrom struct { + *fakeWriter + rf io.ReaderFrom +} + +var _ io.ReaderFrom = (*fakeWriterReaderFrom)(nil) + +// NewWriter creates a [io.Writer] that ensures the fake data is written before the real data. +// A write will end right after byte index fakeBytes - 1, before a write starting at byte index fakeBytes. +// For example, if you have a write of [0123456789], fakeData = [abc], fakeOffset = 1, and fakeBytes = 3, +// you will get writes [bc] and [0123456789]. If the input writer is a [io.ReaderFrom], the output writer will be too. +func NewWriter(writer io.Writer, fakeBytes int64, fakeData []byte, fakeOffset int64, fakeTtl int) io.Writer { + sw := &fakeWriter{writer, fakeBytes, fakeData, fakeOffset, fakeTtl} + if rf, ok := writer.(io.ReaderFrom); ok { + return &fakeWriterReaderFrom{sw, rf} + } + return sw +} + +func (w *fakeWriterReaderFrom) ReadFrom(source io.Reader) (written int64, err error) { + conn, isNetConn := w.writer.(net.Conn) + fakeData := w.getFakeData() + if fakeData != nil { + if isNetConn { + oldTtl, err := ttl.Set(conn, w.ttl) + if err != nil { + return written, fmt.Errorf("failed to set TTL before writing fake data: %w", err) + } + defer func() { + if _, err = ttl.Set(conn, oldTtl); err != nil { + err = fmt.Errorf("failed to restore TTL after writing fake data: %w", err) + } + }() + } + fakeN, err := w.rf.ReadFrom(bytes.NewReader(fakeData)) + written += fakeN + if err != nil { + return written, err + } + } + reader := io.MultiReader(io.LimitReader(source, w.fakeBytes), source) + n, err := w.rf.ReadFrom(reader) + written += n + return written, err +} + +func (w *fakeWriter) Write(data []byte) (written int, err error) { + conn, isNetConn := w.writer.(net.Conn) + fakeData := w.getFakeData() + if fakeData != nil { + if isNetConn { + oldTtl, err := ttl.Set(conn, w.ttl) + if err != nil { + return written, fmt.Errorf("failed to set TTL before writing fake data: %w", err) + } + defer func() { + if _, err = ttl.Set(conn, oldTtl); err != nil { + err = fmt.Errorf("failed to restore TTL after writing fake data: %w", err) + } + }() + } + fakeN, err := w.writer.Write(fakeData) + written += fakeN + if err != nil { + return written, err + } + } + n, err := w.writer.Write(data) + written += n + return written, err +} + +func (w *fakeWriter) getFakeData() []byte { + if w.fakeOffset >= int64(len(w.fakeData)) { + return nil + } + data := w.fakeData[w.fakeOffset:] + if w.fakeBytes < int64(len(data)) { + data = data[:w.fakeBytes] + } + if len(data) == 0 { + return nil + } + return data +} diff --git a/x/fake/writer_test.go b/x/fake/writer_test.go new file mode 100644 index 00000000..8d048e11 --- /dev/null +++ b/x/fake/writer_test.go @@ -0,0 +1,268 @@ +package fake + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +// collectWrites is an [io.Writer] that appends each write to the writes slice. +type collectWrites struct { + writes [][]byte +} + +var _ io.Writer = (*collectWrites)(nil) + +// Write appends a copy of the data to the writes slice. +func (w *collectWrites) Write(data []byte) (int, error) { + dataCopy := make([]byte, len(data)) + copy(dataCopy, data) + w.writes = append(w.writes, dataCopy) + return len(data), nil +} + +// collectReader is an [io.Reader] that appends each Read from the Reader to the reads slice. +type collectReader struct { + io.Reader + reads [][]byte +} + +func (r *collectReader) Read(buf []byte) (int, error) { + n, err := r.Reader.Read(buf) + if n > 0 { + read := make([]byte, n) + copy(read, buf[:n]) + r.reads = append(r.reads, read) + } + return n, err +} + +func TestWrite_FullFake(t *testing.T) { + var innerWriter collectWrites + fakeData := []byte("Fake data") // 9 bytes + fakeBytes := int64(len(fakeData)) // 9 + fakeOffset := int64(0) + fakeWriter := NewWriter(&innerWriter, fakeBytes, fakeData, fakeOffset, 0) + n, err := fakeWriter.Write([]byte("Request")) // 7 bytes + require.NoError(t, err) + require.Equal(t, 16, n) // 9 fake + 7 real + require.Equal(t, [][]byte{[]byte("Fake data"), []byte("Request")}, innerWriter.writes) +} + +func TestWrite_PartialFake(t *testing.T) { + var innerWriter collectWrites + fakeData := []byte("Fake data") // 9 bytes + fakeBytes := int64(5) // Inject first 5 bytes: "Fake " + fakeOffset := int64(0) + fakeWriter := NewWriter(&innerWriter, fakeBytes, fakeData, fakeOffset, 0) + n, err := fakeWriter.Write([]byte("Request")) // 7 bytes + require.NoError(t, err) + require.Equal(t, 12, n) // 5 fake + 7 real + require.Equal(t, [][]byte{[]byte("Fake "), []byte("Request")}, innerWriter.writes) +} + +func TestWrite_NoFake(t *testing.T) { + var innerWriter collectWrites + fakeData := []byte("Fake data") // 9 bytes + fakeBytes := int64(0) // No fake data + fakeOffset := int64(0) + fakeWriter := NewWriter(&innerWriter, fakeBytes, fakeData, fakeOffset, 0) + n, err := fakeWriter.Write([]byte("Request")) // 7 bytes + require.NoError(t, err) + require.Equal(t, 7, n) // 0 fake + 7 real + require.Equal(t, [][]byte{[]byte("Request")}, innerWriter.writes) +} + +func TestWrite_WithOffset(t *testing.T) { + var innerWriter collectWrites + fakeData := []byte("Fake data") // 9 bytes + fakeBytes := int64(4) // Inject 4 bytes starting from offset + fakeOffset := int64(5) // fakeData[5:] = "data" + fakeWriter := NewWriter(&innerWriter, fakeBytes, fakeData, fakeOffset, 0) + n, err := fakeWriter.Write([]byte("Request")) // 7 bytes + require.NoError(t, err) + require.Equal(t, 11, n) // 4 fake + 7 real + require.Equal(t, [][]byte{[]byte("data"), []byte("Request")}, innerWriter.writes) +} + +func TestWrite_NeedsTwoWrites(t *testing.T) { + var innerWriter collectWrites + fakeData := []byte("Fake data") // 9 bytes + fakeBytes := int64(6) // Inject first 6 bytes: "Fake d" + fakeOffset := int64(0) + fakeWriter := NewWriter(&innerWriter, fakeBytes, fakeData, fakeOffset, 0) + n, err := fakeWriter.Write([]byte("Request")) // 7 bytes + require.NoError(t, err) + require.Equal(t, 13, n) // 6 fake + 7 real + require.Equal(t, [][]byte{[]byte("Fake d"), []byte("Request")}, innerWriter.writes) +} + +func TestWrite_Compound(t *testing.T) { + var innerWriter collectWrites + // First fakeWriter: fakeBytes=1, fakeData="F" + fakeData1 := []byte("F") + fakeBytes1 := int64(1) + fakeOffset1 := int64(0) + fakeTtl1 := 0 + writer1 := NewWriter(&innerWriter, fakeBytes1, fakeData1, fakeOffset1, fakeTtl1) + + // Second fakeWriter: fakeBytes=3, fakeData="ake d", fakeOffset=0 + fakeData2 := []byte("ake") // Total fakeData now: "Fake d" + fakeBytes2 := int64(3) + fakeOffset2 := int64(0) + fakeTtl2 := 0 + fakeWriter := NewWriter(writer1, fakeBytes2, fakeData2, fakeOffset2, fakeTtl2) + + // Write "Request" + n, err := fakeWriter.Write([]byte("Request")) // 7 bytes + require.NoError(t, err) + require.Equal(t, 12, n) // 1 fake + 3 fake + 1 fake + 7 real (Note: total fake data is 5, real data is 7) + require.Equal(t, [][]byte{[]byte("F"), []byte("ake"), []byte("F"), []byte("Request")}, innerWriter.writes) +} + +func TestReadFrom_FullFake(t *testing.T) { + fakeData := []byte("Fake data") // 9 bytes + fakeBytes := int64(9) // Inject all fake data + fakeOffset := int64(0) + fakeWriter := NewWriter(&bytes.Buffer{}, fakeBytes, fakeData, fakeOffset, 0) + rf, ok := fakeWriter.(io.ReaderFrom) + require.True(t, ok) + + cr := &collectReader{Reader: bytes.NewReader([]byte("Request"))} // 7 bytes + n, err := rf.ReadFrom(cr) + require.NoError(t, err) + require.Equal(t, int64(16), n) // 9 fake + 7 real +} + +func TestReadFrom_PartialFake(t *testing.T) { + fakeData := []byte("Fake data") // 9 bytes + fakeBytes := int64(5) // Inject first 5 bytes: "Fake " + fakeOffset := int64(0) + fakeWriter := NewWriter(&bytes.Buffer{}, fakeBytes, fakeData, fakeOffset, 0) + rf, ok := fakeWriter.(io.ReaderFrom) + require.True(t, ok) + + cr := &collectReader{Reader: bytes.NewReader([]byte("Request"))} // 7 bytes + n, err := rf.ReadFrom(cr) + require.NoError(t, err) + require.Equal(t, int64(12), n) // 5 fake + 7 real +} + +func TestReadFrom_NoFake(t *testing.T) { + fakeData := []byte("Fake data") // 9 bytes + fakeBytes := int64(0) // No fake data + fakeOffset := int64(0) + fakeWriter := NewWriter(&bytes.Buffer{}, fakeBytes, fakeData, fakeOffset, 0) + rf, ok := fakeWriter.(io.ReaderFrom) + require.True(t, ok) + + cr := &collectReader{Reader: bytes.NewReader([]byte("Request"))} // 7 bytes + n, err := rf.ReadFrom(cr) + require.NoError(t, err) + require.Equal(t, int64(7), n) // 0 fake + 7 real +} + +func TestReadFrom_WithOffset(t *testing.T) { + fakeData := []byte("Fake data") // 9 bytes + fakeBytes := int64(4) // Inject 4 bytes starting from offset + fakeOffset := int64(5) // fakeData[5:] = "data" + fakeWriter := NewWriter(&bytes.Buffer{}, fakeBytes, fakeData, fakeOffset, 0) + rf, ok := fakeWriter.(io.ReaderFrom) + require.True(t, ok) + + cr := &collectReader{Reader: bytes.NewReader([]byte("Request"))} // 7 bytes + n, err := rf.ReadFrom(cr) + require.NoError(t, err) + require.Equal(t, int64(11), n) // 4 fake + 7 real +} + +func TestReadFrom_NeedsTwoReads(t *testing.T) { + fakeData := []byte("Fake data") // 9 bytes + fakeBytes := int64(6) // Inject first 6 bytes: "Fake d" + fakeOffset := int64(0) + fakeWriter := NewWriter(&bytes.Buffer{}, fakeBytes, fakeData, fakeOffset, 0) + rf, ok := fakeWriter.(io.ReaderFrom) + require.True(t, ok) + + // First ReadFrom with "Request1" (8 bytes) + cr1 := &collectReader{Reader: bytes.NewReader([]byte("Request1"))} // 8 bytes + n1, err1 := rf.ReadFrom(cr1) + require.NoError(t, err1) + require.Equal(t, int64(6+8), n1) // 6 fake + 8 real + + // Second ReadFrom with "Request2" (8 bytes) + cr2 := &collectReader{Reader: bytes.NewReader([]byte("Request2"))} // 8 bytes + n2, err2 := rf.ReadFrom(cr2) + require.NoError(t, err2) + require.Equal(t, int64(6+8), n2) // 6 fake + 8 real +} + +func TestReadFrom_Compound(t *testing.T) { + var innerWriter collectWrites + // First fakeWriter: fakeBytes=3, fakeData="Fake " + fakeData1 := []byte("Fake ") + fakeBytes1 := int64(3) + fakeOffset1 := int64(0) + fakeTtl1 := 0 + writer1 := NewWriter(&innerWriter, fakeBytes1, fakeData1, fakeOffset1, fakeTtl1) + + // Second fakeWriter: fakeBytes=5, fakeData="data", fakeOffset=0 + fakeData2 := []byte("data") + fakeBytes2 := int64(5) + fakeOffset2 := int64(0) + fakeTtl2 := 0 + writer2 := NewWriter(writer1, fakeBytes2, fakeData2, fakeOffset2, fakeTtl2) + + n, err := writer2.Write([]byte("Request")) + require.NoError(t, err) + require.Equal(t, 17, n) // 3 fake + 4 fake + 3 fake + 7 real + require.Equal(t, [][]byte{[]byte("Fak"), []byte("data"), []byte("Fak"), []byte("Request")}, innerWriter.writes) +} + +func TestWrite_WithOffsetBeyondFakeData(t *testing.T) { + var innerWriter collectWrites + fakeData := []byte("Fake data") // 9 bytes + fakeBytes := int64(4) // Attempt to inject 4 bytes + fakeOffset := int64(10) // Offset beyond fakeData length + fakeWriter := NewWriter(&innerWriter, fakeBytes, fakeData, fakeOffset, 0) + n, err := fakeWriter.Write([]byte("Request")) // 7 bytes + require.NoError(t, err) + require.Equal(t, 7, n) // 0 fake + 7 real + require.Equal(t, [][]byte{[]byte("Request")}, innerWriter.writes) +} + +func TestReadFrom_WithOffsetBeyondFakeData(t *testing.T) { + fakeData := []byte("Fake data") // 9 bytes + fakeBytes := int64(5) // Attempt to inject 5 bytes + fakeOffset := int64(10) // Offset beyond fakeData length + var buffer bytes.Buffer + fakeWriter := NewWriter(&buffer, fakeBytes, fakeData, fakeOffset, 0) + rf, ok := fakeWriter.(io.ReaderFrom) + require.True(t, ok) + + cr := &collectReader{Reader: bytes.NewReader([]byte("Request"))} // 7 bytes + n, err := rf.ReadFrom(cr) + require.NoError(t, err) + require.Equal(t, int64(7), n) // 0 fake + 7 real +} + +func BenchmarkReadFrom(b *testing.B) { + fakeData := []byte("Fake data") // 9 bytes + fakeBytes := int64(5) // Inject first 5 bytes: "Fake " + fakeOffset := int64(0) + for n := 0; n < b.N; n++ { + reader := bytes.NewReader([]byte("Request")) + var buffer bytes.Buffer + fakeWriter := NewWriter(&buffer, fakeBytes, fakeData, fakeOffset, 0) + rf, ok := fakeWriter.(io.ReaderFrom) + if !ok { + b.Fatalf("Writer does not implement io.ReaderFrom") + } + _, err := rf.ReadFrom(reader) + if err != nil && err != io.EOF { + b.Fatalf("ReadFrom failed: %v", err) + } + } +} diff --git a/x/ttl/ttl.go b/x/ttl/ttl.go new file mode 100644 index 00000000..60817220 --- /dev/null +++ b/x/ttl/ttl.go @@ -0,0 +1,33 @@ +package ttl + +import ( + "fmt" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "net" + "net/netip" +) + +func Set(conn net.Conn, ttl int) (old int, err error) { + addr, err := netip.ParseAddrPort(conn.RemoteAddr().String()) + if err != nil { + return 0, err + } + + switch { + case addr.Addr().Is4(): + conn := ipv4.NewConn(conn) + old, _ = conn.TTL() + if err := conn.SetTTL(ttl); err != nil { + return 0, fmt.Errorf("failed to set TTL: %w", err) + } + case addr.Addr().Is6(): + conn := ipv6.NewConn(conn) + old, _ = conn.HopLimit() + if err := conn.SetHopLimit(ttl); err != nil { + return 0, fmt.Errorf("failed to set hop limit: %w", err) + } + } + + return +}