diff --git a/src/learner.rs b/src/learner.rs index 1a9f002..67754c7 100644 --- a/src/learner.rs +++ b/src/learner.rs @@ -35,8 +35,21 @@ impl Learner { } self.best_algorithm.clone() } + pub fn calculate_formula_debug(&mut self) -> Formula { + for _ in 0..self.iterations { + self.best_algorithm = self.iterate(); + self.best_algorithm.display_tree(); + } + self.best_algorithm.clone() + } fn iterate(&self) -> Formula { - let mut formulas: Vec<(Formula, f64)> = vec![]; + let best_similarity = Learner::get_similarity( + &self.expected_outputs, + &self.best_algorithm.run(self.inputs.clone()), + ) + .unwrap(); + let mut formulas: Vec<(Formula, f64)> = + vec![(self.best_algorithm.clone(), best_similarity)]; for _ in 0..self.formulas_per_iteration { let mut formula = self.best_algorithm.clone(); Learner::mutate_formula_randomly(&mut formula); @@ -63,7 +76,7 @@ impl Learner { } fn mutate_formula_randomly(formula: &mut Formula) { let mut editor = formula.modify_random_node(); - let decided_action = random_range(0..3); + let decided_action = random_range(0..4); if decided_action == ACTION_ADD { editor.add_node(editor.get_random_node()); } else if decided_action == ACTION_REMOVE { diff --git a/src/main.rs b/src/main.rs index 797f633..aa03f81 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,9 +8,11 @@ mod tests; fn main() { let mut learner = Learner::new( vec![0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], - vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], - None, + vec![0., 2., 4., 6., 8., 10., 12., 14., 16., 19., 20.], None, + Some(10000), ); - println!("{:?}", learner.calculate_formula().as_text()); + let formula = learner.calculate_formula(); + println!("{:?}", formula.as_text()); + formula.display_tree(); } diff --git a/src/node/mod.rs b/src/node/mod.rs index 046369b..7763d21 100644 --- a/src/node/mod.rs +++ b/src/node/mod.rs @@ -29,7 +29,7 @@ impl Node { Self { children: vec![], handler: NodeHandler::Empty, - max_children_count: None, + max_children_count: Some(1), } } pub fn number(n: f64) -> Self { @@ -74,6 +74,33 @@ impl Node { return self.handler.run(inputs, passed_x); } + pub fn set_handler(&mut self, to: NodeHandler) { + match &to { + NodeHandler::Number { number } => { + self.max_children_count = Some(0); + self.children.clear(); + } + NodeHandler::Function { + name, + function, + max_args, + min_args, + } => { + self.max_children_count = max_args.clone(); + if let Some(x) = max_args { + self.children.truncate(x.clone()); + } + } + NodeHandler::Variable => { + self.max_children_count = Some(0); + self.children.clear(); + } + NodeHandler::Empty => { + self.max_children_count = Some(1); + self.children.truncate(1); + } + } + } pub fn modify_node(&mut self) -> NodeModifier { NodeModifier::from_random(self, None) } diff --git a/src/node/node_modifier/mod.rs b/src/node/node_modifier/mod.rs index 9a0007a..dba5d2f 100644 --- a/src/node/node_modifier/mod.rs +++ b/src/node/node_modifier/mod.rs @@ -6,7 +6,7 @@ use rand::random_range; pub mod errors; const PICK_STOP_PROBABILITY: u8 = 2; -const TYPE_CHANGE_PROBABILITY: u8 = 10; +const TYPE_CHANGE_PROBABILITY: u8 = 2; pub struct NodeModifier<'a> { picked_node: &'a mut Node, @@ -120,6 +120,9 @@ impl<'a> NodeModifier<'a> { &mut self, specified_index: Option, ) -> Result<(), NodeManipulationError> { + if let NodeHandler::Empty = &self.picked_node.handler { + return Err(NodeManipulationError::ProtectedEmpty); + } if self.picked_node.children.len() == 0 { return Err(NodeManipulationError::NotEnoughChildren); } @@ -149,6 +152,11 @@ impl<'a> NodeModifier<'a> { } } let children_count = self.picked_node.children.len(); + if let NodeHandler::Empty = &self.picked_node.handler + && children_count - 1 > 0 + { + return Err(NodeManipulationError::TooMuchChildren(node)); + } if children_count <= 0 { return Err(NodeManipulationError::NotEnoughChildren); } @@ -166,12 +174,12 @@ impl<'a> NodeModifier<'a> { } else { match &self.picked_node.handler { NodeHandler::Number { number } => { - self.picked_node.handler = NodeHandler::Number { + self.picked_node.set_handler(NodeHandler::Number { number: (self.number_mutation_pool [random_range(0..self.number_mutation_pool.len())])( number.clone() ), - }; + }); } NodeHandler::Function { function, @@ -181,23 +189,18 @@ impl<'a> NodeModifier<'a> { } => { let selected_mutation = &self.function_mutation_pool [random_range(0..self.function_mutation_pool.len())]; - if let NodeHandler::Function { - name, - function, - max_args, - min_args, - } = &selected_mutation - { - self.picked_node.max_children_count = max_args.clone(); - self.picked_node.handler = selected_mutation.clone(); - } + self.picked_node.set_handler(selected_mutation.clone()); } _ => {} } } } pub fn change_node_type(&mut self, to: Option) { - self.picked_node.handler = to.unwrap_or(self.get_random_handler()); + if let NodeHandler::Empty = &self.picked_node.handler { + return; + } + let next_type = to.unwrap_or(self.get_random_handler()); + self.picked_node.set_handler(next_type); } pub fn get_random_handler(&self) -> NodeHandler { let picked = random_range(0..3); diff --git a/src/tests.rs b/src/tests.rs index bfe4de7..82ca988 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,4 +1,4 @@ -use crate::{formula::Formula, node::Node}; +use crate::{formula::Formula, learner::Learner, node::Node}; #[test] fn test_node_variable() { @@ -17,7 +17,8 @@ fn test_plus_one() { Node::function( "+1".to_string(), |inputs: Vec| inputs[0] + 1f64, - Some(1) + Some(1), + 1 ), None ) @@ -35,7 +36,7 @@ fn test_branch_sum() { formula .modify_tree() .insert_node( - Node::function("Sum".to_string(), |inputs| inputs.iter().sum(), Some(2)), + Node::function("Sum".to_string(), |inputs| inputs.iter().sum(), Some(2), 0), None ) .is_err() @@ -60,7 +61,7 @@ fn test_display_as_text() { formula .modify_tree() .insert_node( - Node::function("sum".to_string(), |inputs| inputs.iter().sum(), None), + Node::function("sum".to_string(), |inputs| inputs.iter().sum(), None, 0), None ) .is_err() @@ -73,7 +74,8 @@ fn test_display_as_text() { .add_node(Node::function( "sin".to_string(), |inputs| inputs[0].sin(), - Some(1) + Some(1), + 0 )) .is_err() == false @@ -89,3 +91,12 @@ fn test_display_as_text() { ); assert_eq!(formula.as_text(), "sum(X,sin(X))".to_string()); } + +#[test] +fn test_2x() { + let inputs = vec![0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.]; + let outputs = vec![0., 2., 4., 6., 8., 10., 12., 14., 16., 18., 20., 22., 24.]; + let mut learner = Learner::new(inputs.clone(), outputs.clone(), None, None); + let formula = learner.calculate_formula(); + assert_eq!(formula.run(inputs), outputs); +}