Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ast_macros): raise compile error on invalid generate_derive input. #4766

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/oxc_ast/src/ast/js.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

use std::cell::Cell;

use oxc_allocator::{Box, Vec};
use oxc_allocator::{Box, CloneIn, Vec};
use oxc_ast_macros::ast;
use oxc_span::{Atom, SourceType, Span};
use oxc_span::{Atom, GetSpan, GetSpanMut, SourceType, Span};
use oxc_syntax::{
operator::{
AssignmentOperator, BinaryOperator, LogicalOperator, UnaryOperator, UpdateOperator,
Expand Down
4 changes: 2 additions & 2 deletions crates/oxc_ast/src/ast/jsx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
// Silence erroneous warnings from Rust Analyser for `#[derive(Tsify)]`
#![allow(non_snake_case)]

use oxc_allocator::{Box, Vec};
use oxc_allocator::{Box, CloneIn, Vec};
use oxc_ast_macros::ast;
use oxc_span::{Atom, Span};
use oxc_span::{Atom, GetSpan, GetSpanMut, Span};
#[cfg(feature = "serialize")]
use serde::Serialize;
#[cfg(feature = "serialize")]
Expand Down
3 changes: 2 additions & 1 deletion crates/oxc_ast/src/ast/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
use std::hash::Hash;

use bitflags::bitflags;
use oxc_allocator::CloneIn;
use oxc_ast_macros::{ast, CloneIn};
use oxc_span::{Atom, Span};
use oxc_span::{Atom, GetSpan, GetSpanMut, Span};
use oxc_syntax::number::{BigintBase, NumberBase};
#[cfg(feature = "serialize")]
use serde::Serialize;
Expand Down
4 changes: 2 additions & 2 deletions crates/oxc_ast/src/ast/ts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

use std::{cell::Cell, hash::Hash};

use oxc_allocator::{Box, Vec};
use oxc_allocator::{Box, CloneIn, Vec};
use oxc_ast_macros::ast;
use oxc_span::{Atom, Span};
use oxc_span::{Atom, GetSpan, GetSpanMut, Span};
use oxc_syntax::scope::ScopeId;
#[cfg(feature = "serialize")]
use serde::Serialize;
Expand Down
81 changes: 69 additions & 12 deletions crates/oxc_ast_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,91 @@ fn enum_repr(enum_: &syn::ItemEnum) -> TokenStream2 {
}
}

/// This attribute serves two purposes,
/// First, it is a marker for our codegen to detect AST types. Furthermore.
/// It is also a lightweight macro; All of its computation is cached and
/// it only applies the following changes without any complex operation:
/// Generate assertions that traits used in `#[generate_derive]` are in scope.
///
/// e.g. for `#[generate_derive(GetSpan)]`, it generates:
///
/// ```rs
/// const _: () = {
/// {
/// trait AssertionTrait: ::oxc_span::GetSpan {}
/// impl<T: GetSpan> AssertionTrait for T {}
/// }
/// };
/// ```
///
/// If `GetSpan` is not in scope, or it is not the correct `oxc_span::GetSpan`,
/// this will raise a compilation error.
fn assert_generated_derives(attrs: &[syn::Attribute]) -> TokenStream2 {
#[inline]
fn parse(attr: &syn::Attribute) -> impl Iterator<Item = syn::Ident> {
attr.parse_args_with(
syn::punctuated::Punctuated::<syn::Ident, syn::token::Comma>::parse_terminated,
)
.expect("`generate_derive` only accepts traits as single segment paths, Found an invalid argument")
.into_iter()
}

// TODO: benchmark this to see if a lazy static cell would perform better.
#[inline]
fn abs_trait(
ident: &syn::Ident,
) -> (/* absolute type path */ TokenStream2, /* possible generics */ TokenStream2) {
if ident == "CloneIn" {
(quote!(::oxc_allocator::CloneIn), quote!(<'static>))
} else if ident == "GetSpan" {
(quote!(::oxc_span::GetSpan), TokenStream2::default())
} else if ident == "GetSpanMut" {
(quote!(::oxc_span::GetSpanMut), TokenStream2::default())
} else {
panic!("Invalid derive trait(generate_derive): {ident}");
}
}

// NOTE: At this level we don't care if a trait is derived multiple times, It is the
// responsibility of the codegen to raise errors for those.
let assertion =
attrs.iter().filter(|attr| attr.path().is_ident("generate_derive")).flat_map(parse).map(
|derive| {
let (abs_derive, generics) = abs_trait(&derive);
quote! {{
// NOTE: these are wrapped in a scope to avoid the need for unique identifiers.
trait AssertionTrait: #abs_derive #generics {}
impl<T: #derive #generics> AssertionTrait for T {}
}}
},
);
quote!(const _: () = { #(#assertion)* };)
}

/// This attribute serves two purposes.
/// First, it is a marker for our codegen to detect AST types.
/// Secondly, it generates the following code:
///
/// * Prepend `#[repr(C)]` to structs
/// * Prepend `#[repr(C, u8)]` to fieldful enums e.g. `enum E { X: u32, Y: u8 }`
/// * Prepend `#[repr(u8)]` to unit (fieldless) enums e.g. `enum E { X, Y, Z, }`
/// * Prepend `#[derive(oxc_ast_macros::Ast)]` to all structs and enums
///
/// * Add assertions that traits used in `#[generate_derive(...)]` are in scope.
#[proc_macro_attribute]
#[allow(clippy::missing_panics_doc)]
pub fn ast(_args: TokenStream, input: TokenStream) -> TokenStream {
let input = syn::parse_macro_input!(input as syn::Item);

let repr = match input {
syn::Item::Enum(ref enum_) => enum_repr(enum_),
syn::Item::Struct(_) => quote!(#[repr(C)]),

_ => {
unreachable!()
let (head, tail) = match &input {
syn::Item::Enum(enum_) => (enum_repr(enum_), assert_generated_derives(&enum_.attrs)),
syn::Item::Struct(struct_) => {
(quote!(#[repr(C)]), assert_generated_derives(&struct_.attrs))
}

_ => unreachable!(),
};

let expanded = quote! {
#[derive(::oxc_ast_macros::Ast)]
#repr
#head
#input
#tail
};
TokenStream::from(expanded)
}
Expand Down