Skip to content

Commit

Permalink
Support mask for XfrmState's output mark
Browse files Browse the repository at this point in the history
XfrmState currently doesn't allow setting the mask for the output mark.
As a result, setting an output mark always clears all bits. This commit
adds support for the mask value.

Signed-off-by: Paul Chaignon <paul@cilium.io>
  • Loading branch information
pchaigno committed Dec 14, 2020
1 parent 88079d9 commit b600385
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 9 deletions.
4 changes: 2 additions & 2 deletions xfrm_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ type XfrmState struct {
Limits XfrmStateLimits
Statistics XfrmStateStats
Mark *XfrmMark
OutputMark int
OutputMark *XfrmMark
Ifid int
Auth *XfrmStateAlgo
Crypt *XfrmStateAlgo
Expand All @@ -104,7 +104,7 @@ type XfrmState struct {
}

func (sa XfrmState) String() string {
return fmt.Sprintf("Dst: %v, Src: %v, Proto: %s, Mode: %s, SPI: 0x%x, ReqID: 0x%x, ReplayWindow: %d, Mark: %v, OutputMark: %d, Ifid: %d, Auth: %v, Crypt: %v, Aead: %v, Encap: %v, ESN: %t",
return fmt.Sprintf("Dst: %v, Src: %v, Proto: %s, Mode: %s, SPI: 0x%x, ReqID: 0x%x, ReplayWindow: %d, Mark: %v, OutputMark: %v, Ifid: %d, Auth: %v, Crypt: %v, Aead: %v, Encap: %v, ESN: %t",
sa.Dst, sa.Src, sa.Proto, sa.Mode, sa.Spi, sa.Reqid, sa.ReplayWindow, sa.Mark, sa.OutputMark, sa.Ifid, sa.Auth, sa.Crypt, sa.Aead, sa.Encap, sa.ESN)
}
func (sa XfrmState) Print(stats bool) string {
Expand Down
23 changes: 19 additions & 4 deletions xfrm_state_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,13 @@ func (h *Handle) xfrmStateAddOrUpdate(state *XfrmState, nlProto int) error {
out := nl.NewRtAttr(nl.XFRMA_REPLAY_ESN_VAL, writeReplayEsn(state.ReplayWindow))
req.AddData(out)
}
if state.OutputMark != 0 {
out := nl.NewRtAttr(nl.XFRMA_OUTPUT_MARK, nl.Uint32Attr(uint32(state.OutputMark)))
if state.OutputMark != nil {
out := nl.NewRtAttr(nl.XFRMA_SET_MARK, nl.Uint32Attr(uint32(state.OutputMark.Value)))
req.AddData(out)
if state.OutputMark.Mask != 0 {
out = nl.NewRtAttr(nl.XFRMA_SET_MARK_MASK, nl.Uint32Attr(uint32(state.OutputMark.Mask)))
req.AddData(out)
}
}

ifId := nl.NewRtAttr(nl.XFRMA_IF_ID, nl.Uint32Attr(uint32(state.Ifid)))
Expand Down Expand Up @@ -377,8 +381,19 @@ func parseXfrmState(m []byte, family int) (*XfrmState, error) {
state.Mark = new(XfrmMark)
state.Mark.Value = mark.Value
state.Mark.Mask = mark.Mask
case nl.XFRMA_OUTPUT_MARK:
state.OutputMark = int(native.Uint32(attr.Value))
case nl.XFRMA_SET_MARK:
if state.OutputMark == nil {
state.OutputMark = new(XfrmMark)
}
state.OutputMark.Value = native.Uint32(attr.Value)
case nl.XFRMA_SET_MARK_MASK:
if state.OutputMark == nil {
state.OutputMark = new(XfrmMark)
}
state.OutputMark.Mask = native.Uint32(attr.Value)
if state.OutputMark.Mask == 0xffffffff {
state.OutputMark.Mask = 0
}
case nl.XFRMA_IF_ID:
state.Ifid = int(native.Uint32(attr.Value))
}
Expand Down
32 changes: 29 additions & 3 deletions xfrm_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,33 @@ func TestXfrmStateWithOutputMark(t *testing.T) {
defer setUpNetlinkTest(t)()

state := getBaseState()
state.OutputMark = 10
state.OutputMark = &XfrmMark{
Value: 0x0000000a,
}
if err := XfrmStateAdd(state); err != nil {
t.Fatal(err)
}
s, err := XfrmStateGet(state)
if err != nil {
t.Fatal(err)
}
if !compareStates(state, s) {
t.Fatalf("unexpected state returned.\nExpected: %v.\nGot %v", state, s)
}
if err = XfrmStateDel(s); err != nil {
t.Fatal(err)
}
}

func TestXfrmStateWithOutputMarkAndMask(t *testing.T) {
minKernelRequired(t, 4, 19)
defer setUpNetlinkTest(t)()

state := getBaseState()
state.OutputMark = &XfrmMark{
Value: 0x0000000a,
Mask: 0x0000000f,
}
if err := XfrmStateAdd(state); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -294,11 +320,11 @@ func compareStates(a, b *XfrmState) bool {
return a.Src.Equal(b.Src) && a.Dst.Equal(b.Dst) &&
a.Mode == b.Mode && a.Spi == b.Spi && a.Proto == b.Proto &&
a.Ifid == b.Ifid &&
a.OutputMark == b.OutputMark &&
compareAlgo(a.Auth, b.Auth) &&
compareAlgo(a.Crypt, b.Crypt) &&
compareAlgo(a.Aead, b.Aead) &&
compareMarks(a.Mark, b.Mark)
compareMarks(a.Mark, b.Mark) &&
compareMarks(a.OutputMark, b.OutputMark)
}

func compareLimits(a, b *XfrmState) bool {
Expand Down

0 comments on commit b600385

Please sign in to comment.