diff --git a/include/linux/bpf.h b/include/linux/bpf.h index c87c608a36892..53798ab4f19c4 100644 --- a/include/linux/bpf.h +++ b/include/linux/bpf.h @@ -216,6 +216,7 @@ struct btf_field_graph_root { u32 value_btf_id; u32 node_offset; struct btf_record *value_rec; + bool has_btf_ref; }; struct btf_field { diff --git a/include/linux/btf.h b/include/linux/btf.h index 59d404e22814e..f7d4f65943082 100644 --- a/include/linux/btf.h +++ b/include/linux/btf.h @@ -217,7 +217,7 @@ bool btf_member_is_reg_int(const struct btf *btf, const struct btf_type *s, const struct btf_member *m, u32 expected_offset, u32 expected_size); struct btf_record *btf_parse_fields(const struct btf *btf, const struct btf_type *t, - u32 field_mask, u32 value_size); + u32 field_mask, u32 value_size, bool from_map_check); int btf_check_and_fixup_fields(const struct btf *btf, struct btf_record *rec); bool btf_type_is_void(const struct btf_type *t); s32 btf_find_by_name_kind(const struct btf *btf, const char *name, u8 kind); diff --git a/kernel/bpf/btf.c b/kernel/bpf/btf.c index d56433bf8aba1..ea8fb700f823d 100644 --- a/kernel/bpf/btf.c +++ b/kernel/bpf/btf.c @@ -3665,7 +3665,8 @@ static int btf_parse_graph_root(const struct btf *btf, struct btf_field *field, struct btf_field_info *info, const char *node_type_name, - size_t node_type_align) + size_t node_type_align, + bool from_map_check) { const struct btf_type *t, *n = NULL; const struct btf_member *member; @@ -3696,6 +3697,9 @@ static int btf_parse_graph_root(const struct btf *btf, if (offset % node_type_align) return -EINVAL; + if (from_map_check) + btf_get((struct btf *)btf); + field->graph_root.has_btf_ref = from_map_check; field->graph_root.btf = (struct btf *)btf; field->graph_root.value_btf_id = info->graph_root.value_btf_id; field->graph_root.node_offset = offset; @@ -3706,17 +3710,19 @@ static int btf_parse_graph_root(const struct btf *btf, } static int btf_parse_list_head(const struct btf *btf, struct btf_field *field, - struct btf_field_info *info) + struct btf_field_info *info, bool from_map_check) { return btf_parse_graph_root(btf, field, info, "bpf_list_node", - __alignof__(struct bpf_list_node)); + __alignof__(struct bpf_list_node), + from_map_check); } static int btf_parse_rb_root(const struct btf *btf, struct btf_field *field, - struct btf_field_info *info) + struct btf_field_info *info, bool from_map_check) { return btf_parse_graph_root(btf, field, info, "bpf_rb_node", - __alignof__(struct bpf_rb_node)); + __alignof__(struct bpf_rb_node), + from_map_check); } static int btf_field_cmp(const void *_a, const void *_b, const void *priv) @@ -3732,7 +3738,7 @@ static int btf_field_cmp(const void *_a, const void *_b, const void *priv) } struct btf_record *btf_parse_fields(const struct btf *btf, const struct btf_type *t, - u32 field_mask, u32 value_size) + u32 field_mask, u32 value_size, bool from_map_check) { struct btf_field_info info_arr[BTF_FIELDS_MAX]; u32 next_off = 0, field_type_size; @@ -3798,12 +3804,14 @@ struct btf_record *btf_parse_fields(const struct btf *btf, const struct btf_type goto end; break; case BPF_LIST_HEAD: - ret = btf_parse_list_head(btf, &rec->fields[i], &info_arr[i]); + ret = btf_parse_list_head(btf, &rec->fields[i], &info_arr[i], + from_map_check); if (ret < 0) goto end; break; case BPF_RB_ROOT: - ret = btf_parse_rb_root(btf, &rec->fields[i], &info_arr[i]); + ret = btf_parse_rb_root(btf, &rec->fields[i], &info_arr[i], + from_map_check); if (ret < 0) goto end; break; @@ -5390,7 +5398,8 @@ btf_parse_struct_metas(struct bpf_verifier_log *log, struct btf *btf) type = &tab->types[tab->cnt]; type->btf_id = i; record = btf_parse_fields(btf, t, BPF_SPIN_LOCK | BPF_LIST_HEAD | BPF_LIST_NODE | - BPF_RB_ROOT | BPF_RB_NODE | BPF_REFCOUNT, t->size); + BPF_RB_ROOT | BPF_RB_NODE | BPF_REFCOUNT, t->size, + false); /* The record cannot be unset, treat it as an error if so */ if (IS_ERR_OR_NULL(record)) { ret = PTR_ERR_OR_ZERO(record) ?: -EFAULT; diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c index 06320d9abf339..8eab5ae1871f5 100644 --- a/kernel/bpf/syscall.c +++ b/kernel/bpf/syscall.c @@ -519,8 +519,11 @@ void btf_record_free(struct btf_record *rec) btf_put(rec->fields[i].kptr.btf); break; case BPF_LIST_HEAD: - case BPF_LIST_NODE: case BPF_RB_ROOT: + if (rec->fields[i].graph_root.has_btf_ref) + btf_put(rec->fields[i].graph_root.btf); + break; + case BPF_LIST_NODE: case BPF_RB_NODE: case BPF_SPIN_LOCK: case BPF_TIMER: @@ -568,8 +571,11 @@ struct btf_record *btf_record_dup(const struct btf_record *rec) } break; case BPF_LIST_HEAD: - case BPF_LIST_NODE: case BPF_RB_ROOT: + if (fields[i].graph_root.has_btf_ref) + btf_get(fields[i].graph_root.btf); + break; + case BPF_LIST_NODE: case BPF_RB_NODE: case BPF_SPIN_LOCK: case BPF_TIMER: @@ -1032,7 +1038,7 @@ static int map_check_btf(struct bpf_map *map, struct bpf_token *token, map->record = btf_parse_fields(btf, value_type, BPF_SPIN_LOCK | BPF_TIMER | BPF_KPTR | BPF_LIST_HEAD | BPF_RB_ROOT | BPF_REFCOUNT, - map->value_size); + map->value_size, true); if (!IS_ERR_OR_NULL(map->record)) { int i;