From 87ade6cd859ea9dbde6e80b3fcf717ed9a73b4a9 Mon Sep 17 00:00:00 2001 From: Yonghong Song Date: Fri, 15 Mar 2024 11:49:10 -0700 Subject: [PATCH] selftests/bpf: Add a cgroup prog bpf_get_ns_current_pid_tgid() test Add a cgroup bpf program test where the bpf program is running in a pid namespace. The test is successfully: #165/3 ns_current_pid_tgid/new_ns_cgrp:OK Signed-off-by: Yonghong Song Signed-off-by: Andrii Nakryiko Link: https://lore.kernel.org/bpf/20240315184910.2976522-1-yonghong.song@linux.dev --- .../bpf/prog_tests/ns_current_pid_tgid.c | 73 +++++++++++++++++++ .../bpf/progs/test_ns_current_pid_tgid.c | 7 ++ 2 files changed, 80 insertions(+) diff --git a/tools/testing/selftests/bpf/prog_tests/ns_current_pid_tgid.c b/tools/testing/selftests/bpf/prog_tests/ns_current_pid_tgid.c index 847d7b70e2902..fded38d24aae4 100644 --- a/tools/testing/selftests/bpf/prog_tests/ns_current_pid_tgid.c +++ b/tools/testing/selftests/bpf/prog_tests/ns_current_pid_tgid.c @@ -12,6 +12,7 @@ #include #include #include +#include "network_helpers.h" #define STACK_SIZE (1024 * 1024) static char child_stack[STACK_SIZE]; @@ -74,6 +75,50 @@ static int test_current_pid_tgid_tp(void *args) return ret; } +static int test_current_pid_tgid_cgrp(void *args) +{ + struct test_ns_current_pid_tgid__bss *bss; + struct test_ns_current_pid_tgid *skel; + int server_fd = -1, ret = -1, err; + int cgroup_fd = *(int *)args; + pid_t tgid, pid; + + skel = test_ns_current_pid_tgid__open(); + if (!ASSERT_OK_PTR(skel, "test_ns_current_pid_tgid__open")) + return ret; + + bpf_program__set_autoload(skel->progs.cgroup_bind4, true); + + err = test_ns_current_pid_tgid__load(skel); + if (!ASSERT_OK(err, "test_ns_current_pid_tgid__load")) + goto cleanup; + + bss = skel->bss; + if (get_pid_tgid(&pid, &tgid, bss)) + goto cleanup; + + skel->links.cgroup_bind4 = bpf_program__attach_cgroup( + skel->progs.cgroup_bind4, cgroup_fd); + if (!ASSERT_OK_PTR(skel->links.cgroup_bind4, "bpf_program__attach_cgroup")) + goto cleanup; + + server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0); + if (!ASSERT_GE(server_fd, 0, "start_server")) + goto cleanup; + + if (!ASSERT_EQ(bss->user_pid, pid, "pid")) + goto cleanup; + if (!ASSERT_EQ(bss->user_tgid, tgid, "tgid")) + goto cleanup; + ret = 0; + +cleanup: + if (server_fd >= 0) + close(server_fd); + test_ns_current_pid_tgid__destroy(skel); + return ret; +} + static void test_ns_current_pid_tgid_new_ns(int (*fn)(void *), void *arg) { int wstatus; @@ -95,6 +140,25 @@ static void test_ns_current_pid_tgid_new_ns(int (*fn)(void *), void *arg) return; } +static void test_in_netns(int (*fn)(void *), void *arg) +{ + struct nstoken *nstoken = NULL; + + SYS(cleanup, "ip netns add ns_current_pid_tgid"); + SYS(cleanup, "ip -net ns_current_pid_tgid link set dev lo up"); + + nstoken = open_netns("ns_current_pid_tgid"); + if (!ASSERT_OK_PTR(nstoken, "open_netns")) + goto cleanup; + + test_ns_current_pid_tgid_new_ns(fn, arg); + +cleanup: + if (nstoken) + close_netns(nstoken); + SYS_NOFAIL("ip netns del ns_current_pid_tgid"); +} + /* TODO: use a different tracepoint */ void serial_test_ns_current_pid_tgid(void) { @@ -102,4 +166,13 @@ void serial_test_ns_current_pid_tgid(void) test_current_pid_tgid_tp(NULL); if (test__start_subtest("new_ns_tp")) test_ns_current_pid_tgid_new_ns(test_current_pid_tgid_tp, NULL); + if (test__start_subtest("new_ns_cgrp")) { + int cgroup_fd = -1; + + cgroup_fd = test__join_cgroup("/sock_addr"); + if (ASSERT_GE(cgroup_fd, 0, "join_cgroup")) { + test_in_netns(test_current_pid_tgid_cgrp, &cgroup_fd); + close(cgroup_fd); + } + } } diff --git a/tools/testing/selftests/bpf/progs/test_ns_current_pid_tgid.c b/tools/testing/selftests/bpf/progs/test_ns_current_pid_tgid.c index aa3ec7ca16d9b..d0010e698f668 100644 --- a/tools/testing/selftests/bpf/progs/test_ns_current_pid_tgid.c +++ b/tools/testing/selftests/bpf/progs/test_ns_current_pid_tgid.c @@ -28,4 +28,11 @@ int tp_handler(const void *ctx) return 0; } +SEC("?cgroup/bind4") +int cgroup_bind4(struct bpf_sock_addr *ctx) +{ + get_pid_tgid(); + return 1; +} + char _license[] SEC("license") = "GPL";