diff --git a/src/node/handler.rs b/src/node/handler.rs index 1c86bc4..fbb8802 100644 --- a/src/node/handler.rs +++ b/src/node/handler.rs @@ -7,6 +7,7 @@ pub enum NodeHandler { name: String, function: fn(Vec) -> f64, max_args: Option, + min_args: usize, }, Variable, Empty, @@ -20,7 +21,14 @@ impl NodeHandler { function, name, max_args, - } => function(inputs), + min_args, + } => { + if &inputs.len() >= min_args { + function(inputs) + } else { + 0. + } + } Self::Variable => passed_x, Self::Empty => inputs[0], } diff --git a/src/node/mod.rs b/src/node/mod.rs index c3bbfef..046369b 100644 --- a/src/node/mod.rs +++ b/src/node/mod.rs @@ -43,6 +43,7 @@ impl Node { func_name: String, func: fn(Vec) -> f64, max_children_count: Option, + min_arguments: usize, ) -> Self { Self { children: vec![], @@ -50,6 +51,7 @@ impl Node { name: func_name, function: func, max_args: max_children_count, + min_args: min_arguments, }, max_children_count, } @@ -93,6 +95,7 @@ impl Node { name, function, max_args, + min_args, } => name.clone() + "(" + children_text.as_str() + ")", NodeHandler::Empty => children_text, _ => self.to_string(), @@ -111,6 +114,7 @@ impl fmt::Display for Node { function, name, max_args, + min_args, } => name.clone(), NodeHandler::Variable => "X".to_string(), NodeHandler::Empty => "".to_string(), diff --git a/src/node/node_modifier/mod.rs b/src/node/node_modifier/mod.rs index f63442e..9a0007a 100644 --- a/src/node/node_modifier/mod.rs +++ b/src/node/node_modifier/mod.rs @@ -22,36 +22,43 @@ impl<'a> NodeModifier<'a> { name: "sin".to_string(), function: |inputs| inputs[0].sin(), max_args: Some(1), + min_args: 1, }, NodeHandler::Function { name: "cos".to_string(), function: |inputs| inputs[0].cos(), max_args: Some(1), + min_args: 1, }, NodeHandler::Function { name: "sum".to_string(), function: |inputs| inputs.iter().sum(), max_args: None, + min_args: 0, }, NodeHandler::Function { name: "-".to_string(), function: |inputs| -inputs[0], max_args: Some(1), + min_args: 1, }, NodeHandler::Function { name: "product".to_string(), function: |inputs| inputs.iter().product(), max_args: None, + min_args: 0, }, NodeHandler::Function { name: "exp".to_string(), function: |inputs| inputs[0].powf(inputs[1]), max_args: Some(2), + min_args: 2, }, NodeHandler::Function { name: "1/".to_string(), function: |inputs| 1f64 / inputs[0], max_args: Some(1), + min_args: 1, }, ]; let standard_number_mutation: Vec f64> = vec![ @@ -142,6 +149,9 @@ impl<'a> NodeModifier<'a> { } } let children_count = self.picked_node.children.len(); + if children_count <= 0 { + return Err(NodeManipulationError::NotEnoughChildren); + } let operated_index = between.unwrap_or(random_range(0..children_count)); let moved = self.picked_node.children.remove(operated_index); node.children.push(moved); @@ -167,6 +177,7 @@ impl<'a> NodeModifier<'a> { function, name, max_args, + min_args, } => { let selected_mutation = &self.function_mutation_pool [random_range(0..self.function_mutation_pool.len())]; @@ -174,6 +185,7 @@ impl<'a> NodeModifier<'a> { name, function, max_args, + min_args, } = &selected_mutation { self.picked_node.max_children_count = max_args.clone(); @@ -211,6 +223,7 @@ impl<'a> NodeModifier<'a> { name, function, max_args, + min_args, } => max_args, NodeHandler::Variable => Some(0), NodeHandler::Empty => Some(1),