Skip to content

Commit 95f9aa3

Browse files
committed
feat(semantic): add ability to lookup if AST contains any node kinds
1 parent 700b412 commit 95f9aa3

File tree

5 files changed

+180
-1
lines changed

5 files changed

+180
-1
lines changed

crates/oxc_ast/src/generated/ast_kind.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ use oxc_span::{GetSpan, Span};
1111

1212
use crate::ast::*;
1313

14+
/// The largest integer value that can be mapped to an `AstType`/`AstKind` enum variant.
15+
pub const AST_TYPE_MAX: u8 = 186;
16+
1417
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
1518
#[repr(u8)]
1619
pub enum AstType {
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
use oxc_ast::{AstType, ast_kind::AST_TYPE_MAX};
2+
3+
const USIZE_BITS: usize = usize::BITS as usize;
4+
5+
/// Number of bytes required for bit set which can represent all [`AstType`]s.
6+
// Need to add plus one here because 0 is a possible value, but requires at least one bit to represent it.
7+
const NUM_USIZES: usize = (AST_TYPE_MAX as usize + 1).div_ceil(USIZE_BITS);
8+
9+
/// Bit set with a bit for each [`AstType`].
10+
#[derive(Debug, Clone)]
11+
pub struct AstTypesBitset([usize; NUM_USIZES]);
12+
13+
impl AstTypesBitset {
14+
/// Create empty [`AstTypesBitset`] with no bits set.
15+
pub const fn new() -> Self {
16+
Self([0; NUM_USIZES])
17+
}
18+
19+
/// Create a new [`AstTypesBitset`] from a slice of [`AstType`].
20+
pub fn from_types(types: &[AstType]) -> Self {
21+
let mut bitset = Self::new();
22+
for &ty in types {
23+
bitset.set(ty);
24+
}
25+
bitset
26+
}
27+
28+
/// Returns `true` if bit is set for provided [`AstType`].
29+
pub const fn has(&self, ty: AstType) -> bool {
30+
let (index, mask) = Self::index_and_mask(ty);
31+
(self.0[index] & mask) != 0
32+
}
33+
34+
/// Set bit for provided [`AstType`].
35+
pub const fn set(&mut self, ty: AstType) {
36+
let (index, mask) = Self::index_and_mask(ty);
37+
self.0[index] |= mask;
38+
}
39+
40+
/// Returns `true` if any bit is set in both `self` and `other`.
41+
pub fn intersects(&self, other: &Self) -> bool {
42+
let mut intersection = 0;
43+
for (&a, &b) in self.0.iter().zip(other.0.iter()) {
44+
intersection |= a & b;
45+
}
46+
intersection != 0
47+
}
48+
49+
/// Returns `true` if all bits in `other` are set in `self`.
50+
pub fn contains(&self, other: &Self) -> bool {
51+
let mut mismatches = 0;
52+
for (&a, &b) in self.0.iter().zip(other.0.iter()) {
53+
let set_in_both = a & b;
54+
// 0 if `set_in_both == b`
55+
let mismatch = set_in_both ^ b;
56+
mismatches |= mismatch;
57+
}
58+
mismatches == 0
59+
}
60+
61+
/// Get index and mask for an [`AstType`].
62+
/// Returned `index` is guaranteed not to be out of bounds of the array.
63+
const fn index_and_mask(ty: AstType) -> (usize, usize) {
64+
let n = ty as usize;
65+
let index = n / USIZE_BITS;
66+
let mask = 1usize << (n % USIZE_BITS);
67+
(index, mask)
68+
}
69+
}
70+
71+
impl Default for AstTypesBitset {
72+
fn default() -> Self {
73+
Self::new()
74+
}
75+
}
76+
77+
#[cfg(test)]
78+
mod tests {
79+
use super::*;
80+
use oxc_ast::AstType;
81+
82+
#[test]
83+
fn empty_bitset_has_no_bits_and_contains_empty() {
84+
let bs = AstTypesBitset::new();
85+
assert!(!bs.has(AstType::Program));
86+
let other = AstTypesBitset::new();
87+
assert!(bs.contains(&other), "Empty bitset should contain empty bitset");
88+
assert!(!bs.intersects(&other));
89+
}
90+
91+
#[test]
92+
fn intersects_and_contains() {
93+
let mut a = AstTypesBitset::from_types(&[AstType::Program, AstType::AssignmentPattern]);
94+
let b = AstTypesBitset::from_types(&[AstType::TSTupleType]);
95+
assert!(!a.intersects(&b));
96+
a.set(AstType::TSTupleType);
97+
assert!(a.intersects(&b));
98+
99+
let c = AstTypesBitset::from_types(&[AstType::Program]);
100+
assert!(a.contains(&c));
101+
assert!(!c.contains(&a));
102+
103+
// a should contain union of subset bits
104+
let subset = AstTypesBitset::from_types(&[AstType::Program, AstType::AssignmentPattern]);
105+
assert!(a.contains(&subset));
106+
// subset does not contain a (missing TSTupleType)
107+
assert!(!subset.contains(&a));
108+
}
109+
110+
#[test]
111+
fn contains_empty_is_true() {
112+
let non_empty = AstTypesBitset::from_types(&[AstType::IdentifierName]);
113+
let empty = AstTypesBitset::new();
114+
assert!(non_empty.contains(&empty));
115+
}
116+
}

crates/oxc_semantic/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pub use oxc_syntax::{
2222

2323
pub mod dot;
2424

25+
mod ast_types_bitset;
2526
mod binder;
2627
mod builder;
2728
mod checker;

crates/oxc_semantic/src/node.rs

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::iter::FusedIterator;
22

33
use oxc_allocator::{Address, GetAddress};
4-
use oxc_ast::{AstKind, ast::Program};
4+
use oxc_ast::{AstKind, AstType, ast::Program};
55
use oxc_cfg::BlockNodeId;
66
use oxc_index::{IndexSlice, IndexVec};
77
use oxc_span::{GetSpan, Span};
@@ -10,6 +10,8 @@ use oxc_syntax::{
1010
scope::ScopeId,
1111
};
1212

13+
use crate::ast_types_bitset::AstTypesBitset;
14+
1315
/// Semantic node contains all the semantic information about an ast node.
1416
#[derive(Debug, Clone, Copy)]
1517
pub struct AstNode<'a> {
@@ -101,6 +103,10 @@ pub struct AstNodes<'a> {
101103
nodes: IndexVec<NodeId, AstNode<'a>>,
102104
/// `node` -> `parent`
103105
parent_ids: IndexVec<NodeId, NodeId>,
106+
/// Stores a set of bits of a fixed size, where each bit represents a single [`AstKind`]. If the bit is set (1),
107+
/// then the AST contains at least one node of that kind. If the bit is not set (0), then the AST does not contain
108+
/// any nodes of that kind.
109+
pub node_kinds_set: AstTypesBitset,
104110
}
105111

106112
impl<'a> AstNodes<'a> {
@@ -212,6 +218,7 @@ impl<'a> AstNodes<'a> {
212218
let node_id = self.parent_ids.push(parent_node_id);
213219
let node = AstNode::new(kind, scope_id, cfg_id, flags, node_id);
214220
self.nodes.push(node);
221+
self.node_kinds_set.set(kind.ty());
215222
node_id
216223
}
217224

@@ -234,6 +241,7 @@ impl<'a> AstNodes<'a> {
234241
);
235242
self.parent_ids.push(NodeId::ROOT);
236243
self.nodes.push(AstNode::new(kind, scope_id, cfg_id, flags, NodeId::ROOT));
244+
self.node_kinds_set.set(AstType::Program);
237245
NodeId::ROOT
238246
}
239247

@@ -242,6 +250,51 @@ impl<'a> AstNodes<'a> {
242250
self.nodes.reserve(additional);
243251
self.parent_ids.reserve(additional);
244252
}
253+
254+
/// Checks if the AST contains any nodes of the given types.
255+
///
256+
/// Example:
257+
/// ```ignore
258+
/// let for_stmt = AstTypesBitset::from_types(&[AstType::ForStatement]);
259+
/// let import_export_decl = AstTypesBitset::from_types(&[AstType::ImportDeclaration, AstType::ExportDeclaration]);
260+
///
261+
/// // returns true if there is a `for` loop anywhere in the AST.
262+
/// nodes.contains_any(&for_stmt)
263+
/// // returns true if there is at least one import OR one export in the AST.
264+
/// nodes.contains_any(&import_export_decl)
265+
/// ```
266+
pub fn contains_any(&self, bitset: &AstTypesBitset) -> bool {
267+
self.node_kinds_set.intersects(bitset)
268+
}
269+
270+
/// Checks if the AST contains all of the given types.
271+
///
272+
/// Example:
273+
/// ```ignore
274+
/// let for_stmt = AstTypesBitset::from_types(&[AstType::ForStatement]);
275+
/// let import_export_decl = AstTypesBitset::from_types(&[AstType::ImportDeclaration, AstType::ExportDeclaration]);
276+
///
277+
/// // returns true if there is a `for` loop anywhere in the AST.
278+
/// nodes.contains_all(&for_stmt)
279+
/// // returns true only if there is at least one import AND one export in the AST.
280+
/// nodes.contains_all(&import_export_decl)
281+
/// ```
282+
pub fn contains_all(&self, bitset: &AstTypesBitset) -> bool {
283+
self.node_kinds_set.contains(bitset)
284+
}
285+
286+
/// Checks if the AST contains a node of the given type.
287+
///
288+
/// Example:
289+
/// ```ignore
290+
/// // returns true if there is a `for` loop anywhere in the AST.
291+
/// nodes.contains(AstType::ForStatement)
292+
/// // returns true if there is an `ImportDeclaration` anywhere in the AST.
293+
/// nodes.contains(AstType::ImportDeclaration)
294+
/// ```
295+
pub fn contains(&self, ty: AstType) -> bool {
296+
self.node_kinds_set.has(ty)
297+
}
245298
}
246299

247300
impl<'a, 'n> IntoIterator for &'n AstNodes<'a> {

tasks/ast_tools/src/generators/ast_kind.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ impl Generator for AstKindGenerator {
149149
next_index += 1;
150150
}
151151

152+
let ast_type_max = number_lit(next_index - 1);
153+
152154
let output = quote! {
153155
#![expect(missing_docs)] ///@ FIXME (in ast_tools/src/generators/ast_kind.rs)
154156
@@ -162,6 +164,10 @@ impl Generator for AstKindGenerator {
162164
///@@line_break
163165
use crate::ast::*;
164166

167+
///@@line_break
168+
/// The largest integer value that can be mapped to an `AstType`/`AstKind` enum variant.
169+
pub const AST_TYPE_MAX: u8 = #ast_type_max;
170+
165171
///@@line_break
166172
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
167173
#[repr(u8)]

0 commit comments

Comments
 (0)