diff --git a/node-graph/node-macro/src/codegen.rs b/node-graph/node-macro/src/codegen.rs index cd22f38ba2..5758c7c0d8 100644 --- a/node-graph/node-macro/src/codegen.rs +++ b/node-graph/node-macro/src/codegen.rs @@ -37,33 +37,102 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn }; let struct_name = format_ident!("{}Node", struct_name); - let struct_generics: Vec = fields.iter().enumerate().map(|(i, _)| format_ident!("Node{}", i)).collect(); + // Separate data fields from regular fields + let (data_fields, regular_fields): (Vec<_>, Vec<_>) = fields.iter().partition(|f| f.is_data_field); + + // Extract function generics used by data fields + let data_field_generics: Vec<_> = fn_generics + .iter() + .filter(|generic| { + let generic_ident = match generic { + syn::GenericParam::Type(type_param) => &type_param.ident, + _ => return false, + }; + + // Check if this generic is used in any data field type + data_fields.iter().any(|field| match &field.ty { + ParsedFieldType::Regular(RegularParsedField { ty, .. }) => type_contains_ident(ty, generic_ident), + _ => false, + }) + }) + .cloned() + .collect(); + + // Node generics for regular fields (Node0, Node1, ...) + let node_generics: Vec = regular_fields.iter().enumerate().map(|(i, _)| format_ident!("Node{}", i)).collect(); + + // Extract just the idents from data_field_generics for struct type parameters + let data_field_generic_idents: Vec = data_field_generics + .iter() + .filter_map(|gp| match gp { + syn::GenericParam::Type(tp) => Some(tp.ident.clone()), + _ => None, + }) + .collect(); + + // Combined struct type parameters: data field generic idents (T, U, ...) + node generics (Node0, Node1, ...) + // For struct type instantiation: MemoNode + let struct_type_params: Vec = data_field_generic_idents.iter().cloned().chain(node_generics.iter().cloned()).collect(); + + // Combined struct generic parameters with bounds for struct definition + // struct MemoNode + let struct_generic_params: Vec = data_field_generics.iter().map(|gp| quote!(#gp)).chain(node_generics.iter().map(|id| quote!(#id))).collect(); let input_ident = &input.pat_ident; let context_features = &input.context_features; - let field_idents: Vec<_> = fields.iter().map(|f| &f.pat_ident).collect(); + // Regular field idents and names (for function parameters) + let field_idents: Vec<_> = regular_fields.iter().map(|f| &f.pat_ident).collect(); let field_names: Vec<_> = field_idents.iter().map(|pat_ident| &pat_ident.ident).collect(); + let regular_field_names: Vec<_> = regular_fields.iter().map(|f| &f.pat_ident.ident).collect(); + let data_field_names: Vec<_> = data_fields.iter().map(|f| &f.pat_ident.ident).collect(); - let input_names: Vec<_> = fields + // Only regular fields have input names/descriptions (for UI) + let input_names: Vec<_> = regular_fields .iter() .map(|f| &f.name) - .zip(field_names.iter()) + .zip(regular_field_names.iter()) .map(|zipped| match zipped { (Some(name), _) => name.value(), (_, name) => name.to_string().to_case(Case::Title), }) .collect(); - let input_descriptions: Vec<_> = fields.iter().map(|f| &f.description).collect(); + let input_descriptions: Vec<_> = regular_fields.iter().map(|f| &f.description).collect(); + + // Generate struct fields: data fields (concrete types) + regular fields (generic types) + let data_field_defs = data_fields.iter().map(|field| { + let name = &field.pat_ident.ident; + let ty = match &field.ty { + ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty, + _ => unreachable!("Data fields must be Regular types, not Node types"), + }; + quote! { pub(super) #name: #ty } + }); - let struct_fields = field_names.iter().zip(struct_generics.iter()).map(|(name, r#gen)| { + let regular_field_defs = regular_field_names.iter().zip(node_generics.iter()).map(|(name, r#gen)| { quote! { pub(super) #name: #r#gen } }); + let struct_fields = data_field_defs.chain(regular_field_defs); + let mut future_idents = Vec::new(); - let field_types: Vec<_> = fields + // Data fields get passed as references to the underlying function + let data_field_idents: Vec<_> = data_fields.iter().map(|f| &f.pat_ident).collect(); + let data_field_types: Vec<_> = data_fields + .iter() + .map(|field| match &field.ty { + ParsedFieldType::Regular(RegularParsedField { ty, .. }) => { + let ty = ty.clone(); + quote!(&#ty) + } + _ => unreachable!("Data fields must be Regular types, not Node types"), + }) + .collect(); + + // Regular fields have types passed to the function + let field_types: Vec<_> = regular_fields .iter() .map(|field| match &field.ty { ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty.clone(), @@ -74,7 +143,8 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn }) .collect(); - let widget_override: Vec<_> = fields + // Only regular fields have UI metadata (data fields are internal state) + let widget_override: Vec<_> = regular_fields .iter() .map(|field| match &field.widget_override { ParsedWidgetOverride::None => quote!(RegistryWidgetOverride::None), @@ -84,7 +154,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn }) .collect(); - let value_sources: Vec<_> = fields + let value_sources: Vec<_> = regular_fields .iter() .map(|field| match &field.ty { ParsedFieldType::Regular(RegularParsedField { value_source, .. }) => match value_source { @@ -104,7 +174,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn }) .collect(); - let default_types: Vec<_> = fields + let default_types: Vec<_> = regular_fields .iter() .map(|field| match &field.ty { ParsedFieldType::Regular(RegularParsedField { implementations, .. }) => match implementations.first() { @@ -115,7 +185,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn }) .collect(); - let number_min_values: Vec<_> = fields + let number_min_values: Vec<_> = regular_fields .iter() .map(|field| match &field.ty { ParsedFieldType::Regular(RegularParsedField { number_soft_min, number_hard_min, .. }) => match (number_soft_min, number_hard_min) { @@ -126,7 +196,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn _ => quote!(None), }) .collect(); - let number_max_values: Vec<_> = fields + let number_max_values: Vec<_> = regular_fields .iter() .map(|field| match &field.ty { ParsedFieldType::Regular(RegularParsedField { number_soft_max, number_hard_max, .. }) => match (number_soft_max, number_hard_max) { @@ -137,7 +207,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn _ => quote!(None), }) .collect(); - let number_mode_range_values: Vec<_> = fields + let number_mode_range_values: Vec<_> = regular_fields .iter() .map(|field| match &field.ty { ParsedFieldType::Regular(RegularParsedField { @@ -147,15 +217,15 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn _ => quote!(None), }) .collect(); - let number_display_decimal_places: Vec<_> = fields + let number_display_decimal_places: Vec<_> = regular_fields .iter() .map(|field| field.number_display_decimal_places.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))) .collect(); - let number_step: Vec<_> = fields.iter().map(|field| field.number_step.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect(); + let number_step: Vec<_> = regular_fields.iter().map(|field| field.number_step.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect(); - let unit_suffix: Vec<_> = fields.iter().map(|field| field.unit.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect(); + let unit_suffix: Vec<_> = regular_fields.iter().map(|field| field.unit.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect(); - let exposed: Vec<_> = fields + let exposed: Vec<_> = regular_fields .iter() .map(|field| match &field.ty { ParsedFieldType::Regular(RegularParsedField { exposed, .. }) => quote!(#exposed), @@ -163,7 +233,8 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn }) .collect(); - let eval_args = fields.iter().map(|field| { + // Only eval regular fields (data fields are accessed directly as self.field_name) + let eval_args = regular_fields.iter().map(|field| { let name = &field.pat_ident.ident; match &field.ty { ParsedFieldType::Regular { .. } => { @@ -175,7 +246,8 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn } }); - let min_max_args = fields.iter().map(|field| match &field.ty { + // Only regular fields can have min/max constraints + let min_max_args = regular_fields.iter().map(|field| match &field.ty { ParsedFieldType::Regular(RegularParsedField { number_hard_min, number_hard_max, .. }) => { let name = &field.pat_ident.ident; let mut tokens = quote!(); @@ -208,7 +280,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn let mut clauses = Vec::new(); let mut clampable_clauses = Vec::new(); - for (field, name) in fields.iter().zip(struct_generics.iter()) { + for (field, name) in regular_fields.iter().zip(node_generics.iter()) { clauses.push(match (&field.ty, *is_async) { ( ParsedFieldType::Regular(RegularParsedField { @@ -259,13 +331,42 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn ); struct_where_clause.predicates.extend(extra_where); - let new_args = struct_generics.iter().zip(field_names.iter()).map(|(r#gen, name)| { + // Only regular fields are parameters to new() + let new_args = node_generics.iter().zip(regular_field_names.iter()).map(|(r#gen, name)| { quote! { #name: #r#gen } }); + // Initialize data fields with Default, regular fields with parameters + let data_inits = data_field_names.iter().map(|name| { + quote! { #name: Default::default() } + }); + let regular_inits = regular_field_names.iter().map(|name| { + quote! { #name } + }); + let all_field_inits = data_inits.chain(regular_inits); + let async_keyword = is_async.then(|| quote!(async)); let await_keyword = is_async.then(|| quote!(.await)); + // Data fields may not implement Copy, PartialEq, etc., so only derive Debug and Clone + let struct_derives = if data_fields.is_empty() { + quote!(#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]) + } else { + quote!(#[derive(Debug, Clone)]) + }; + + // Generate serialize method if serialize attribute is specified + let serialize_impl = if let Some(serialize_fn) = &parsed.attributes.serialize { + let data_field_refs = data_field_names.iter().map(|name| quote!(&self.#name)); + quote! { + fn serialize(&self) -> Option> { + #serialize_fn(#(#data_field_refs),*) + } + } + } else { + quote!() + }; + let eval_impl = quote! { type Output = #core_types::registry::DynFuture<'n, #output_type>; #[inline] @@ -275,9 +376,11 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn #(#eval_args)* #(#min_max_args)* - self::#fn_name(__input #(, #field_names)*) #await_keyword + self::#fn_name(__input #(, &self.#data_field_names)* #(, #regular_field_names)*) #await_keyword }) } + + #serialize_impl }; let identifier = format_ident!("{}_proto_ident", fn_name); @@ -302,11 +405,11 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn /// Underlying implementation for [#struct_name] #[inline] #[allow(clippy::too_many_arguments)] - #vis #async_keyword fn #fn_name <'n, #(#fn_generics,)*> (#input_ident: #input_type #(, #field_idents: #field_types)*) -> #output_type #where_clause #body + #vis #async_keyword fn #fn_name <'n, #(#fn_generics,)*> (#input_ident: #input_type #(, #data_field_idents: #data_field_types)* #(, #field_idents: #field_types)*) -> #output_type #where_clause #body #cfg #[automatically_derived] - impl<'n, #(#fn_generics,)* #(#struct_generics,)* #(#future_idents,)*> #core_types::Node<'n, #input_type> for #mod_name::#struct_name<#(#struct_generics,)*> + impl<'n, #(#fn_generics,)* #(#node_generics,)* #(#future_idents,)*> #core_types::Node<'n, #input_type> for #mod_name::#struct_name<#(#struct_type_params,)*> #struct_where_clause { #eval_impl @@ -340,18 +443,18 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn static #import_name: core::marker::PhantomData<(#(#all_implementation_types,)*)> = core::marker::PhantomData; - #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] - pub struct #struct_name<#(#struct_generics,)*> { + #struct_derives + pub struct #struct_name<#(#struct_generic_params,)*> { #(#struct_fields,)* } #[automatically_derived] - impl<'n, #(#struct_generics,)*> #struct_name<#(#struct_generics,)*> + impl<'n, #(#struct_generic_params,)*> #struct_name<#(#struct_type_params,)*> { #[allow(clippy::too_many_arguments)] pub fn new(#(#new_args,)*) -> Self { Self { - #(#field_names,)* + #(#all_field_inits,)* } } } @@ -493,8 +596,10 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st let mut constructors = Vec::new(); let unit = parse_quote!(gcore::Context); - let parameter_types: Vec<_> = parsed - .fields + + let regular_fields: Vec<_> = parsed.fields.iter().filter(|f| !f.is_data_field).collect(); + + let parameter_types: Vec<_> = regular_fields .iter() .map(|field| { match &field.ty { @@ -535,7 +640,7 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st let field_name = field_names[j]; let (input_type, output_type) = &types[i.min(types.len() - 1)]; - let node = matches!(parsed.fields[j].ty, ParsedFieldType::Node { .. }); + let node = matches!(regular_fields[j].ty, ParsedFieldType::Node { .. }); let downcast_node = quote!( let #field_name: DowncastBothNode<#input_type, #output_type> = DowncastBothNode::new(args[#j].clone()); @@ -712,3 +817,23 @@ impl FilterUsedGenerics { self.used(&*modified).cloned().collect() } } + +/// Check if a type contains a reference to a specific identifier (e.g., a generic type parameter) +fn type_contains_ident(ty: &Type, ident: &Ident) -> bool { + struct IdentChecker<'a> { + target: &'a Ident, + found: bool, + } + + impl<'a, 'ast> syn::visit::Visit<'ast> for IdentChecker<'a> { + fn visit_ident(&mut self, i: &'ast Ident) { + if i == self.target { + self.found = true; + } + } + } + + let mut checker = IdentChecker { target: ident, found: false }; + syn::visit::visit_type(&mut checker, ty); + checker.found +} diff --git a/node-graph/node-macro/src/parsing.rs b/node-graph/node-macro/src/parsing.rs index 74f547b95a..ca58b14988 100644 --- a/node-graph/node-macro/src/parsing.rs +++ b/node-graph/node-macro/src/parsing.rs @@ -50,6 +50,8 @@ pub(crate) struct NodeFnAttributes { pub(crate) cfg: Option, /// if this node should get a gpu implementation, defaults to None pub(crate) shader_node: Option, + /// Custom serialization function path (e.g., "my_module::custom_serialize") + pub(crate) serialize: Option, // Add more attributes as needed } @@ -112,6 +114,7 @@ pub struct ParsedField { pub number_display_decimal_places: Option, pub number_step: Option, pub unit: Option, + pub is_data_field: bool, } #[derive(Clone, Debug)] @@ -201,6 +204,7 @@ impl Parse for NodeFnAttributes { let mut properties_string = None; let mut cfg = None; let mut shader_node = None; + let mut serialize = None; let content = input; // let content; @@ -270,13 +274,23 @@ impl Parse for NodeFnAttributes { let meta = meta.require_list()?; shader_node = Some(syn::parse2(meta.tokens.to_token_stream())?); } + "serialize" => { + let meta = meta.require_list()?; + if serialize.is_some() { + return Err(Error::new_spanned(meta, "Multiple 'serialize' attributes are not allowed")); + } + let parsed_path: Path = meta + .parse_args() + .map_err(|_| Error::new_spanned(meta, "Expected a valid path for 'serialize', e.g., serialize(my_module::custom_serialize)"))?; + serialize = Some(parsed_path); + } _ => { return Err(Error::new_spanned( meta, indoc!( r#" Unsupported attribute in `node`. - Supported attributes are 'category', 'path' 'name', 'skip_impl', 'cfg' and 'properties'. + Supported attributes are 'category', 'path', 'name', 'skip_impl', 'cfg', 'properties', 'serialize', and 'shader_node'. Example usage: #[node_macro::node(category("Value"), name("Test Node"))] @@ -295,6 +309,7 @@ impl Parse for NodeFnAttributes { properties_string, cfg, shader_node, + serialize, }) } } @@ -467,6 +482,9 @@ fn parse_node_implementations(attr: &Attribute, name: &Ident) -> syn:: fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Result { let ident = &pat_ident.ident; + // Check if this is a data field (struct field, not a parameter) + let is_data_field = extract_attribute(attrs, "data").is_some(); + let default_value = extract_attribute(attrs, "default") .map(|attr| attr.parse_args().map_err(|e| Error::new_spanned(attr, format!("Invalid `default` value for argument '{ident}': {e}")))) .transpose()?; @@ -489,6 +507,25 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul let exposed = extract_attribute(attrs, "expose").is_some(); + // Validate data field attributes + if is_data_field { + if default_value.is_some() { + return Err(Error::new_spanned( + &pat_ident, + "Data fields (#[data]) cannot have #[default] attribute. They are automatically initialized with Default::default()", + )); + } + if scope.is_some() { + return Err(Error::new_spanned(&pat_ident, "Data fields (#[data]) cannot have #[scope] attribute")); + } + if exposed { + return Err(Error::new_spanned( + &pat_ident, + "Data fields (#[data]) cannot be exposed (#[expose]). They are internal state, not node parameters", + )); + } + } + let value_source = match (default_value, scope) { (Some(_), Some(_)) => return Err(Error::new_spanned(&pat_ident, "Cannot have both `default` and `scope` attributes")), (Some(default_value), _) => ParsedValueSource::Default(default_value), @@ -586,6 +623,14 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul .fold(String::new(), |acc, b| acc + &b + "\n"); if is_node { + // Data fields cannot be impl Node types + if is_data_field { + return Err(Error::new_spanned( + &ty, + "Data fields (#[data]) cannot be of type `impl Node`. Data fields must be concrete types that implement Default", + )); + } + let (input_type, output_type) = node_input_type .zip(node_output_type) .ok_or_else(|| Error::new_spanned(&ty, "Invalid Node type. Expected `impl Node`"))?; @@ -610,6 +655,7 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul number_display_decimal_places, number_step, unit, + is_data_field, }) } else { let implementations = extract_attribute(attrs, "implementations") @@ -636,6 +682,7 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul number_display_decimal_places, number_step, unit, + is_data_field, }) } } @@ -826,6 +873,7 @@ mod tests { properties_string: None, cfg: None, shader_node: None, + serialize: None, }, fn_name: Ident::new("add", Span::call_site()), struct_name: Ident::new("Add", Span::call_site()), @@ -860,6 +908,7 @@ mod tests { number_display_decimal_places: None, number_step: None, unit: None, + is_data_field: false, }], body: TokenStream2::new(), description: String::from("Multi\nLine\n"), @@ -892,6 +941,7 @@ mod tests { properties_string: None, cfg: None, shader_node: None, + serialize: None, }, fn_name: Ident::new("transform", Span::call_site()), struct_name: Ident::new("Transform", Span::call_site()), @@ -920,6 +970,7 @@ mod tests { number_display_decimal_places: None, number_step: None, unit: None, + is_data_field: false, }, ParsedField { pat_ident: pat_ident("translate"), @@ -941,6 +992,7 @@ mod tests { number_display_decimal_places: None, number_step: None, unit: None, + is_data_field: false, }, ], body: TokenStream2::new(), @@ -971,6 +1023,7 @@ mod tests { properties_string: None, cfg: None, shader_node: None, + serialize: None, }, fn_name: Ident::new("circle", Span::call_site()), struct_name: Ident::new("Circle", Span::call_site()), @@ -1005,6 +1058,7 @@ mod tests { number_display_decimal_places: None, number_step: None, unit: None, + is_data_field: false, }], body: TokenStream2::new(), description: "Test\n".into(), @@ -1033,6 +1087,7 @@ mod tests { properties_string: None, cfg: None, shader_node: None, + serialize: None, }, fn_name: Ident::new("levels", Span::call_site()), struct_name: Ident::new("Levels", Span::call_site()), @@ -1072,6 +1127,7 @@ mod tests { number_display_decimal_places: None, number_step: None, unit: None, + is_data_field: false, }], body: TokenStream2::new(), description: String::new(), @@ -1107,6 +1163,7 @@ mod tests { properties_string: None, cfg: None, shader_node: None, + serialize: None, }, fn_name: Ident::new("add", Span::call_site()), struct_name: Ident::new("Add", Span::call_site()), @@ -1141,6 +1198,7 @@ mod tests { number_display_decimal_places: None, number_step: None, unit: None, + is_data_field: false, }], body: TokenStream2::new(), description: String::new(), @@ -1169,6 +1227,7 @@ mod tests { properties_string: None, cfg: None, shader_node: None, + serialize: None, }, fn_name: Ident::new("load_image", Span::call_site()), struct_name: Ident::new("LoadImage", Span::call_site()), @@ -1203,6 +1262,7 @@ mod tests { number_display_decimal_places: None, number_step: None, unit: None, + is_data_field: false, }], body: TokenStream2::new(), description: String::new(), @@ -1231,6 +1291,7 @@ mod tests { properties_string: None, cfg: None, shader_node: None, + serialize: None, }, fn_name: Ident::new("custom_node", Span::call_site()), struct_name: Ident::new("CustomNode", Span::call_site()), diff --git a/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs b/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs index 2704bdecaf..35fff34d87 100644 --- a/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs +++ b/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs @@ -245,6 +245,7 @@ impl PerPixelAdjustCodegen<'_> { number_display_decimal_places: None, number_step: None, unit: None, + is_data_field: false, }); // find exactly one gpu_image field, runtime doesn't support more than 1 atm diff --git a/node-graph/node-macro/src/validation.rs b/node-graph/node-macro/src/validation.rs index ae60663095..67d02c5e92 100644 --- a/node-graph/node-macro/src/validation.rs +++ b/node-graph/node-macro/src/validation.rs @@ -102,6 +102,11 @@ fn validate_implementations_for_generics(parsed: &ParsedNodeFn) { if !has_skip_impl && !parsed.fn_generics.is_empty() { for field in &parsed.fields { + // Skip validation for data fields - they're internal state and can be generic + if field.is_data_field { + continue; + } + let pat_ident = &field.pat_ident; match &field.ty { ParsedFieldType::Regular(RegularParsedField { ty, implementations, .. }) => { diff --git a/node-graph/nodes/gcore/src/memo.rs b/node-graph/nodes/gcore/src/memo.rs index fcb4b637d3..e57d7755b9 100644 --- a/node-graph/nodes/gcore/src/memo.rs +++ b/node-graph/nodes/gcore/src/memo.rs @@ -1,7 +1,5 @@ +use core_types::WasmNotSend; use core_types::memo::*; -use core_types::{Node, WasmNotSend}; -use dyn_any::DynFuture; -use std::future::Future; use std::hash::DefaultHasher; use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -14,94 +12,38 @@ use std::sync::Mutex; /// A cache hit occurs when the Option is Some and has a stored hash matching the hash of the call argument. In this case, the node returns the cached value without re-evaluating the inner node. /// /// Currently, only one input-output pair is cached. Subsequent calls with different inputs will overwrite the previous cache. -#[derive(Default)] -pub struct MemoNode { - cache: Arc>>, - node: CachedNode, -} -impl<'i, I: Hash + 'i, T: 'i + Clone + WasmNotSend, CachedNode: 'i> Node<'i, I> for MemoNode -where - CachedNode: for<'any_input> Node<'any_input, I>, - for<'a> >::Output: Future + WasmNotSend, -{ - // TODO: This should return a reference to the cached cached_value - // but that requires a lot of lifetime magic <- This was suggested by copilot but is pretty accurate xD - type Output = DynFuture<'i, T>; - fn eval(&'i self, input: I) -> Self::Output { - let mut hasher = DefaultHasher::new(); - input.hash(&mut hasher); - let hash = hasher.finish(); - - if let Some(data) = self.cache.lock().as_ref().unwrap().as_ref().and_then(|data| (data.0 == hash).then_some(data.1.clone())) { - Box::pin(async move { data }) - } else { - let fut = self.node.eval(input); - let cache = self.cache.clone(); - Box::pin(async move { - let value = fut.await; - *cache.lock().unwrap() = Some((hash, value.clone())); - value - }) - } +#[node_macro::node(category(""), path(graphene_core::memo), skip_impl)] +async fn memo(input: I, #[data] cache: Arc>>, node: impl Node) -> T { + let mut hasher = DefaultHasher::new(); + input.hash(&mut hasher); + let hash = hasher.finish(); + + if let Some(data) = cache.lock().as_ref().unwrap().as_ref().and_then(|data| (data.0 == hash).then_some(data.1.clone())) { + return data; } - fn reset(&self) { - self.cache.lock().unwrap().take(); - } -} - -impl MemoNode { - pub fn new(node: CachedNode) -> MemoNode { - MemoNode { cache: Default::default(), node } - } + let value = node.eval(input).await; + *cache.lock().unwrap() = Some((hash, value.clone())); + value } -#[allow(clippy::module_inception)] -pub mod memo { - use core_types::ProtoNodeIdentifier; - - pub const IDENTIFIER: ProtoNodeIdentifier = ProtoNodeIdentifier::new("graphene_core::memo::MemoNode"); -} +type MonitorValue = Arc>>>>; /// Caches the output of the last graph evaluation for introspection. -#[derive(Default)] -pub struct MonitorNode { +#[node_macro::node(category(""), path(graphene_core::memo), serialize(serialize_monitor), skip_impl)] +async fn monitor( + input: I, #[allow(clippy::type_complexity)] - io: Arc>>>>, - node: N, -} - -impl<'i, T, I, N> Node<'i, I> for MonitorNode -where - I: Clone + 'static + Send + Sync, - T: Clone + 'static + Send + Sync, - for<'a> N: Node<'a, I, Output: Future + WasmNotSend> + 'i, -{ - type Output = DynFuture<'i, T>; - fn eval(&'i self, input: I) -> Self::Output { - let io = self.io.clone(); - let output_fut = self.node.eval(input.clone()); - Box::pin(async move { - let output = output_fut.await; - *io.lock().unwrap() = Some(Arc::new(IORecord { input, output: output.clone() })); - output - }) - } - - fn serialize(&self) -> Option> { - let io = self.io.lock().unwrap(); - (io).as_ref().map(|output| output.clone() as Arc) - } -} - -impl MonitorNode { - pub fn new(node: N) -> MonitorNode { - MonitorNode { io: Arc::new(Mutex::new(None)), node } - } -} - -pub mod monitor { - use core_types::ProtoNodeIdentifier; - - pub const IDENTIFIER: ProtoNodeIdentifier = ProtoNodeIdentifier::new("graphene_core::memo::MonitorNode"); + #[data] + io: MonitorValue, + node: impl Node, +) -> T { + let output = node.eval(input.clone()).await; + *io.lock().unwrap() = Some(Arc::new(IORecord { input, output: output.clone() })); + output +} + +fn serialize_monitor(io: &MonitorValue) -> Option> { + let io = io.lock().unwrap(); + io.as_ref().map(|output| output.clone() as Arc) }