Skip to content

Commit

Permalink
Merge pull request #2872 from kevinGC:ipt-skip-prerouting
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 315041419
  • Loading branch information
gvisor-bot committed Jun 6, 2020
2 parents 21b6bc7 + 74a7d76 commit 427d208
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 56 deletions.
49 changes: 20 additions & 29 deletions pkg/tcpip/network/ipv4/ipv4.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,38 +258,24 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return nil
}

// If the packet is manipulated as per NAT Ouput rules, handle packet
// based on destination address and do not send the packet to link layer.
// TODO(gvisor.dev/issue/170): We should do this for every packet, rather than
// only NATted packets, but removing this check short circuits broadcasts
// before they are sent out to other hosts.
if pkt.NatDone {
// If the packet is manipulated as per NAT Ouput rules, handle packet
// based on destination address and do not send the packet to link layer.
netHeader := header.IPv4(pkt.NetworkHeader)
ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)

views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
views[0] = pkt.Header.View()
views = append(views, pkt.Data.Views()...)
ep.HandlePacket(&route, &stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
handleLoopback(&route, pkt, ep)
return nil
}
}

if r.Loop&stack.PacketLoop != 0 {
// The inbound path expects the network header to still be in
// the PacketBuffer's Data field.
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
views[0] = pkt.Header.View()
views = append(views, pkt.Data.Views()...)
loopedR := r.MakeLoopedRoute()

e.HandlePacket(&loopedR, &stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})

handleLoopback(&loopedR, pkt, e)
loopedR.Release()
}
if r.Loop&stack.PacketOut == 0 {
Expand All @@ -305,6 +291,17 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return nil
}

func handleLoopback(route *stack.Route, pkt *stack.PacketBuffer, ep stack.NetworkEndpoint) {
// The inbound path expects the network header to still be in
// the PacketBuffer's Data field.
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
views[0] = pkt.Header.View()
views = append(views, pkt.Data.Views()...)
ep.HandlePacket(route, &stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
}

// WritePackets implements stack.NetworkEndpoint.WritePackets.
func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
Expand Down Expand Up @@ -347,13 +344,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)

views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
views[0] = pkt.Header.View()
views = append(views, pkt.Data.Views()...)
ep.HandlePacket(&route, &stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
handleLoopback(&route, pkt, ep)
n++
continue
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/tcpip/stack/nic.go
Original file line number Diff line number Diff line change
Expand Up @@ -1229,7 +1229,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}

// TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet.
if protocol == header.IPv4ProtocolNumber {
// Loopback traffic skips the prerouting chain.
if protocol == header.IPv4ProtocolNumber && !n.isLoopback() {
// iptables filtering.
ipt := n.stack.IPTables()
address := n.primaryAddress(protocol)
Expand Down
4 changes: 4 additions & 0 deletions test/iptables/iptables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,10 @@ func TestNATRedirectRequiresProtocol(t *testing.T) {
singleTest(t, NATRedirectRequiresProtocol{})
}

func TestNATLoopbackSkipsPrerouting(t *testing.T) {
singleTest(t, NATLoopbackSkipsPrerouting{})
}

func TestInputSource(t *testing.T) {
singleTest(t, FilterInputSource{})
}
Expand Down
89 changes: 63 additions & 26 deletions test/iptables/nat.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func init() {
RegisterTestCase(NATOutDontRedirectIP{})
RegisterTestCase(NATOutRedirectInvert{})
RegisterTestCase(NATRedirectRequiresProtocol{})
RegisterTestCase(NATLoopbackSkipsPrerouting{})
}

// NATPreRedirectUDPPort tests that packets are redirected to different port.
Expand Down Expand Up @@ -326,32 +327,6 @@ func (NATRedirectRequiresProtocol) LocalAction(ip net.IP) error {
return nil
}

// loopbackTests runs an iptables rule and ensures that packets sent to
// dest:dropPort are received by localhost:acceptPort.
func loopbackTest(dest net.IP, args ...string) error {
if err := natTable(args...); err != nil {
return err
}
sendCh := make(chan error)
listenCh := make(chan error)
go func() {
sendCh <- sendUDPLoop(dest, dropPort, sendloopDuration)
}()
go func() {
listenCh <- listenUDP(acceptPort, sendloopDuration)
}()
select {
case err := <-listenCh:
if err != nil {
return err
}
case <-time.After(sendloopDuration):
return errors.New("timed out")
}
// sendCh will always take the full sendloop time.
return <-sendCh
}

// NATOutRedirectTCPPort tests that connections are redirected on specified ports.
type NATOutRedirectTCPPort struct{}

Expand Down Expand Up @@ -400,3 +375,65 @@ func (NATOutRedirectTCPPort) ContainerAction(ip net.IP) error {
func (NATOutRedirectTCPPort) LocalAction(ip net.IP) error {
return nil
}

// NATLoopbackSkipsPrerouting tests that packets sent via loopback aren't
// affected by PREROUTING rules.
type NATLoopbackSkipsPrerouting struct{}

// Name implements TestCase.Name.
func (NATLoopbackSkipsPrerouting) Name() string {
return "NATLoopbackSkipsPrerouting"
}

// ContainerAction implements TestCase.ContainerAction.
func (NATLoopbackSkipsPrerouting) ContainerAction(ip net.IP) error {
// Redirect anything sent to localhost to an unused port.
dest := []byte{127, 0, 0, 1}
if err := natTable("-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil {
return err
}

// Establish a connection via localhost. If the PREROUTING rule did apply to
// loopback traffic, the connection would fail.
sendCh := make(chan error)
go func() {
sendCh <- connectTCP(dest, acceptPort, sendloopDuration)
}()

if err := listenTCP(acceptPort, sendloopDuration); err != nil {
return err
}
return <-sendCh
}

// LocalAction implements TestCase.LocalAction.
func (NATLoopbackSkipsPrerouting) LocalAction(ip net.IP) error {
// No-op.
return nil
}

// loopbackTests runs an iptables rule and ensures that packets sent to
// dest:dropPort are received by localhost:acceptPort.
func loopbackTest(dest net.IP, args ...string) error {
if err := natTable(args...); err != nil {
return err
}
sendCh := make(chan error)
listenCh := make(chan error)
go func() {
sendCh <- sendUDPLoop(dest, dropPort, sendloopDuration)
}()
go func() {
listenCh <- listenUDP(acceptPort, sendloopDuration)
}()
select {
case err := <-listenCh:
if err != nil {
return err
}
case <-time.After(sendloopDuration):
return errors.New("timed out")
}
// sendCh will always take the full sendloop time.
return <-sendCh
}

0 comments on commit 427d208

Please sign in to comment.