Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 156 additions & 31 deletions node-graph/node-macro/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,33 +37,102 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
};
let struct_name = format_ident!("{}Node", struct_name);

let struct_generics: Vec<Ident> = fields.iter().enumerate().map(|(i, _)| format_ident!("Node{}", i)).collect();
// Separate data fields from regular fields
let (data_fields, regular_fields): (Vec<_>, Vec<_>) = fields.iter().partition(|f| f.is_data_field);

// Extract function generics used by data fields
let data_field_generics: Vec<_> = fn_generics
.iter()
.filter(|generic| {
let generic_ident = match generic {
syn::GenericParam::Type(type_param) => &type_param.ident,
_ => return false,
};

// Check if this generic is used in any data field type
data_fields.iter().any(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => type_contains_ident(ty, generic_ident),
_ => false,
})
})
.cloned()
.collect();

// Node generics for regular fields (Node0, Node1, ...)
let node_generics: Vec<Ident> = regular_fields.iter().enumerate().map(|(i, _)| format_ident!("Node{}", i)).collect();

// Extract just the idents from data_field_generics for struct type parameters
let data_field_generic_idents: Vec<Ident> = data_field_generics
.iter()
.filter_map(|gp| match gp {
syn::GenericParam::Type(tp) => Some(tp.ident.clone()),
_ => None,
})
.collect();

// Combined struct type parameters: data field generic idents (T, U, ...) + node generics (Node0, Node1, ...)
// For struct type instantiation: MemoNode<T, Node0>
let struct_type_params: Vec<Ident> = data_field_generic_idents.iter().cloned().chain(node_generics.iter().cloned()).collect();

// Combined struct generic parameters with bounds for struct definition
// struct MemoNode<T: Clone, Node0>
let struct_generic_params: Vec<TokenStream2> = data_field_generics.iter().map(|gp| quote!(#gp)).chain(node_generics.iter().map(|id| quote!(#id))).collect();
let input_ident = &input.pat_ident;

let context_features = &input.context_features;

let field_idents: Vec<_> = fields.iter().map(|f| &f.pat_ident).collect();
// Regular field idents and names (for function parameters)
let field_idents: Vec<_> = regular_fields.iter().map(|f| &f.pat_ident).collect();
let field_names: Vec<_> = field_idents.iter().map(|pat_ident| &pat_ident.ident).collect();
let regular_field_names: Vec<_> = regular_fields.iter().map(|f| &f.pat_ident.ident).collect();
let data_field_names: Vec<_> = data_fields.iter().map(|f| &f.pat_ident.ident).collect();

let input_names: Vec<_> = fields
// Only regular fields have input names/descriptions (for UI)
let input_names: Vec<_> = regular_fields
.iter()
.map(|f| &f.name)
.zip(field_names.iter())
.zip(regular_field_names.iter())
.map(|zipped| match zipped {
(Some(name), _) => name.value(),
(_, name) => name.to_string().to_case(Case::Title),
})
.collect();

let input_descriptions: Vec<_> = fields.iter().map(|f| &f.description).collect();
let input_descriptions: Vec<_> = regular_fields.iter().map(|f| &f.description).collect();

// Generate struct fields: data fields (concrete types) + regular fields (generic types)
let data_field_defs = data_fields.iter().map(|field| {
let name = &field.pat_ident.ident;
let ty = match &field.ty {
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty,
_ => unreachable!("Data fields must be Regular types, not Node types"),
};
quote! { pub(super) #name: #ty }
});

let struct_fields = field_names.iter().zip(struct_generics.iter()).map(|(name, r#gen)| {
let regular_field_defs = regular_field_names.iter().zip(node_generics.iter()).map(|(name, r#gen)| {
quote! { pub(super) #name: #r#gen }
});

let struct_fields = data_field_defs.chain(regular_field_defs);

let mut future_idents = Vec::new();

let field_types: Vec<_> = fields
// Data fields get passed as references to the underlying function
let data_field_idents: Vec<_> = data_fields.iter().map(|f| &f.pat_ident).collect();
let data_field_types: Vec<_> = data_fields
.iter()
.map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => {
let ty = ty.clone();
quote!(&#ty)
}
_ => unreachable!("Data fields must be Regular types, not Node types"),
})
.collect();

// Regular fields have types passed to the function
let field_types: Vec<_> = regular_fields
.iter()
.map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty.clone(),
Expand All @@ -74,7 +143,8 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
})
.collect();

let widget_override: Vec<_> = fields
// Only regular fields have UI metadata (data fields are internal state)
let widget_override: Vec<_> = regular_fields
.iter()
.map(|field| match &field.widget_override {
ParsedWidgetOverride::None => quote!(RegistryWidgetOverride::None),
Expand All @@ -84,7 +154,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
})
.collect();

let value_sources: Vec<_> = fields
let value_sources: Vec<_> = regular_fields
.iter()
.map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { value_source, .. }) => match value_source {
Expand All @@ -104,7 +174,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
})
.collect();

let default_types: Vec<_> = fields
let default_types: Vec<_> = regular_fields
.iter()
.map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { implementations, .. }) => match implementations.first() {
Expand All @@ -115,7 +185,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
})
.collect();

let number_min_values: Vec<_> = fields
let number_min_values: Vec<_> = regular_fields
.iter()
.map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { number_soft_min, number_hard_min, .. }) => match (number_soft_min, number_hard_min) {
Expand All @@ -126,7 +196,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
_ => quote!(None),
})
.collect();
let number_max_values: Vec<_> = fields
let number_max_values: Vec<_> = regular_fields
.iter()
.map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { number_soft_max, number_hard_max, .. }) => match (number_soft_max, number_hard_max) {
Expand All @@ -137,7 +207,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
_ => quote!(None),
})
.collect();
let number_mode_range_values: Vec<_> = fields
let number_mode_range_values: Vec<_> = regular_fields
.iter()
.map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField {
Expand All @@ -147,23 +217,24 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
_ => quote!(None),
})
.collect();
let number_display_decimal_places: Vec<_> = fields
let number_display_decimal_places: Vec<_> = regular_fields
.iter()
.map(|field| field.number_display_decimal_places.as_ref().map_or(quote!(None), |i| quote!(Some(#i))))
.collect();
let number_step: Vec<_> = fields.iter().map(|field| field.number_step.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
let number_step: Vec<_> = regular_fields.iter().map(|field| field.number_step.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();

let unit_suffix: Vec<_> = fields.iter().map(|field| field.unit.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
let unit_suffix: Vec<_> = regular_fields.iter().map(|field| field.unit.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();

let exposed: Vec<_> = fields
let exposed: Vec<_> = regular_fields
.iter()
.map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { exposed, .. }) => quote!(#exposed),
_ => quote!(true),
})
.collect();

let eval_args = fields.iter().map(|field| {
// Only eval regular fields (data fields are accessed directly as self.field_name)
let eval_args = regular_fields.iter().map(|field| {
let name = &field.pat_ident.ident;
match &field.ty {
ParsedFieldType::Regular { .. } => {
Expand All @@ -175,7 +246,8 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
}
});

let min_max_args = fields.iter().map(|field| match &field.ty {
// Only regular fields can have min/max constraints
let min_max_args = regular_fields.iter().map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { number_hard_min, number_hard_max, .. }) => {
let name = &field.pat_ident.ident;
let mut tokens = quote!();
Expand Down Expand Up @@ -208,7 +280,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
let mut clauses = Vec::new();
let mut clampable_clauses = Vec::new();

for (field, name) in fields.iter().zip(struct_generics.iter()) {
for (field, name) in regular_fields.iter().zip(node_generics.iter()) {
clauses.push(match (&field.ty, *is_async) {
(
ParsedFieldType::Regular(RegularParsedField {
Expand Down Expand Up @@ -259,13 +331,42 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
);
struct_where_clause.predicates.extend(extra_where);

let new_args = struct_generics.iter().zip(field_names.iter()).map(|(r#gen, name)| {
// Only regular fields are parameters to new()
let new_args = node_generics.iter().zip(regular_field_names.iter()).map(|(r#gen, name)| {
quote! { #name: #r#gen }
});

// Initialize data fields with Default, regular fields with parameters
let data_inits = data_field_names.iter().map(|name| {
quote! { #name: Default::default() }
});
let regular_inits = regular_field_names.iter().map(|name| {
quote! { #name }
});
let all_field_inits = data_inits.chain(regular_inits);

let async_keyword = is_async.then(|| quote!(async));
let await_keyword = is_async.then(|| quote!(.await));

// Data fields may not implement Copy, PartialEq, etc., so only derive Debug and Clone
let struct_derives = if data_fields.is_empty() {
quote!(#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)])
} else {
quote!(#[derive(Debug, Clone)])
};

// Generate serialize method if serialize attribute is specified
let serialize_impl = if let Some(serialize_fn) = &parsed.attributes.serialize {
let data_field_refs = data_field_names.iter().map(|name| quote!(&self.#name));
quote! {
fn serialize(&self) -> Option<std::sync::Arc<dyn std::any::Any + Send + Sync>> {
#serialize_fn(#(#data_field_refs),*)
}
}
} else {
quote!()
};

let eval_impl = quote! {
type Output = #core_types::registry::DynFuture<'n, #output_type>;
#[inline]
Expand All @@ -275,9 +376,11 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn

#(#eval_args)*
#(#min_max_args)*
self::#fn_name(__input #(, #field_names)*) #await_keyword
self::#fn_name(__input #(, &self.#data_field_names)* #(, #regular_field_names)*) #await_keyword
})
}

#serialize_impl
};

let identifier = format_ident!("{}_proto_ident", fn_name);
Expand All @@ -302,11 +405,11 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
/// Underlying implementation for [#struct_name]
#[inline]
#[allow(clippy::too_many_arguments)]
#vis #async_keyword fn #fn_name <'n, #(#fn_generics,)*> (#input_ident: #input_type #(, #field_idents: #field_types)*) -> #output_type #where_clause #body
#vis #async_keyword fn #fn_name <'n, #(#fn_generics,)*> (#input_ident: #input_type #(, #data_field_idents: #data_field_types)* #(, #field_idents: #field_types)*) -> #output_type #where_clause #body

#cfg
#[automatically_derived]
impl<'n, #(#fn_generics,)* #(#struct_generics,)* #(#future_idents,)*> #core_types::Node<'n, #input_type> for #mod_name::#struct_name<#(#struct_generics,)*>
impl<'n, #(#fn_generics,)* #(#node_generics,)* #(#future_idents,)*> #core_types::Node<'n, #input_type> for #mod_name::#struct_name<#(#struct_type_params,)*>
#struct_where_clause
{
#eval_impl
Expand Down Expand Up @@ -340,18 +443,18 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn

static #import_name: core::marker::PhantomData<(#(#all_implementation_types,)*)> = core::marker::PhantomData;

#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct #struct_name<#(#struct_generics,)*> {
#struct_derives
pub struct #struct_name<#(#struct_generic_params,)*> {
#(#struct_fields,)*
}

#[automatically_derived]
impl<'n, #(#struct_generics,)*> #struct_name<#(#struct_generics,)*>
impl<'n, #(#struct_generic_params,)*> #struct_name<#(#struct_type_params,)*>
{
#[allow(clippy::too_many_arguments)]
pub fn new(#(#new_args,)*) -> Self {
Self {
#(#field_names,)*
#(#all_field_inits,)*
}
}
}
Expand Down Expand Up @@ -493,8 +596,10 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st

let mut constructors = Vec::new();
let unit = parse_quote!(gcore::Context);
let parameter_types: Vec<_> = parsed
.fields

let regular_fields: Vec<_> = parsed.fields.iter().filter(|f| !f.is_data_field).collect();

let parameter_types: Vec<_> = regular_fields
.iter()
.map(|field| {
match &field.ty {
Expand Down Expand Up @@ -535,7 +640,7 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st
let field_name = field_names[j];
let (input_type, output_type) = &types[i.min(types.len() - 1)];

let node = matches!(parsed.fields[j].ty, ParsedFieldType::Node { .. });
let node = matches!(regular_fields[j].ty, ParsedFieldType::Node { .. });

let downcast_node = quote!(
let #field_name: DowncastBothNode<#input_type, #output_type> = DowncastBothNode::new(args[#j].clone());
Expand Down Expand Up @@ -712,3 +817,23 @@ impl FilterUsedGenerics {
self.used(&*modified).cloned().collect()
}
}

/// Check if a type contains a reference to a specific identifier (e.g., a generic type parameter)
fn type_contains_ident(ty: &Type, ident: &Ident) -> bool {
struct IdentChecker<'a> {
target: &'a Ident,
found: bool,
}

impl<'a, 'ast> syn::visit::Visit<'ast> for IdentChecker<'a> {
fn visit_ident(&mut self, i: &'ast Ident) {
if i == self.target {
self.found = true;
}
}
}

let mut checker = IdentChecker { target: ident, found: false };
syn::visit::visit_type(&mut checker, ty);
checker.found
}
Loading