Compare commits

...

4 commits

5 changed files with 56 additions and 49 deletions

View file

@ -31,25 +31,40 @@ impl Learner {
} }
pub fn calculate_formula(&mut self) -> Formula { pub fn calculate_formula(&mut self) -> Formula {
for _ in 0..self.iterations { for _ in 0..self.iterations {
self.best_algorithm = self.iterate() let current_best = Learner::get_similarity(
&self.expected_outputs,
&self.best_algorithm.run(self.inputs.clone()),
);
let found_best = self.iterate();
if found_best.1 > current_best.unwrap_or(0.) {
self.best_algorithm = found_best.0;
}
} }
self.best_algorithm.clone() self.best_algorithm.clone()
} }
pub fn calculate_formula_debug(&mut self) -> Formula { pub fn calculate_formula_debug(&mut self, tree: bool, sim: bool) -> Formula {
for _ in 0..self.iterations { for _ in 0..self.iterations {
self.best_algorithm = self.iterate(); let current_best = Learner::get_similarity(
self.best_algorithm.display_tree(); &self.expected_outputs,
&self.best_algorithm.run(self.inputs.clone()),
);
let found_best = self.iterate();
if sim {
println!("{:?}", &found_best.1);
}
if tree {
self.best_algorithm.display_tree();
}
if found_best.1 > current_best.unwrap_or(0.) {
self.best_algorithm = found_best.0;
}
} }
self.best_algorithm.clone() self.best_algorithm.clone()
} }
fn iterate(&self) -> Formula { fn iterate(&self) -> (Formula, f64) {
let best_similarity = Learner::get_similarity( let mut formulas: Vec<(Formula, f64)> = vec![];
&self.expected_outputs,
&self.best_algorithm.run(self.inputs.clone()),
)
.unwrap();
let mut formulas: Vec<(Formula, f64)> =
vec![(self.best_algorithm.clone(), best_similarity)];
for _ in 0..self.formulas_per_iteration { for _ in 0..self.formulas_per_iteration {
let mut formula = self.best_algorithm.clone(); let mut formula = self.best_algorithm.clone();
Learner::mutate_formula_randomly(&mut formula); Learner::mutate_formula_randomly(&mut formula);
@ -71,20 +86,23 @@ impl Learner {
} }
}) })
.unwrap() .unwrap()
.0
.clone() .clone()
} }
fn mutate_formula_randomly(formula: &mut Formula) { fn mutate_formula_randomly(formula: &mut Formula) {
let mut editor = formula.modify_random_node(); let amount_of_mutations = random_range(1..4);
let decided_action = random_range(0..4); for _ in 0..amount_of_mutations {
if decided_action == ACTION_ADD { let mut editor = formula.modify_random_node();
editor.add_node(editor.get_random_node()); let decided_action = random_range(0..4);
} else if decided_action == ACTION_REMOVE {
editor.remove_node(None); if decided_action == ACTION_ADD {
} else if decided_action == ACTION_INSERT { editor.add_node(editor.get_random_node());
editor.insert_node(editor.get_random_node(), None); } else if decided_action == ACTION_REMOVE {
} else if decided_action == ACTION_MUTATE { editor.remove_node(None);
editor.mutate_node(); } else if decided_action == ACTION_INSERT {
editor.insert_node(editor.get_random_node(), None);
} else if decided_action == ACTION_MUTATE {
editor.mutate_node();
}
} }
} }
pub fn get_similarity(expected_output: &Vec<f64>, real_output: &Vec<f64>) -> Result<f64, ()> { pub fn get_similarity(expected_output: &Vec<f64>, real_output: &Vec<f64>) -> Result<f64, ()> {

View file

@ -9,11 +9,11 @@ mod tests;
fn main() { fn main() {
let mut learner = Learner::new( let mut learner = Learner::new(
vec![0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], vec![0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.],
vec![0., 2., 4., 6., 8., 10., 12., 14., 16., 19., 20.], vec![0., 1., 4., 9., 16., 25., 36., 49., 64., 81., 100.],
None,
None, None,
Some(10000),
); );
let formula = learner.calculate_formula_debug(); let formula = learner.calculate_formula();
println!("{:?}", formula.as_text()); println!("{:?}", formula.as_text());
formula.display_tree(); formula.display_tree();
} }

View file

@ -88,7 +88,7 @@ impl Node {
} => { } => {
self.max_children_count = max_args.clone(); self.max_children_count = max_args.clone();
if let Some(x) = max_args { if let Some(x) = max_args {
self.children.truncate(x.clone()); self.children.resize(x.clone(), Node::number(0.));
} }
} }
NodeHandler::Variable => { NodeHandler::Variable => {

View file

@ -146,17 +146,7 @@ impl<'a> NodeModifier<'a> {
mut node: Node, mut node: Node,
between: Option<usize>, between: Option<usize>,
) -> Result<(), NodeManipulationError> { ) -> Result<(), NodeManipulationError> {
if let Some(x) = self.picked_node.max_children_count {
if self.picked_node.children.len() + 1 > x {
return Err(NodeManipulationError::TooMuchChildren(node));
}
}
let children_count = self.picked_node.children.len(); let children_count = self.picked_node.children.len();
if let NodeHandler::Empty = &self.picked_node.handler
&& children_count - 1 > 0
{
return Err(NodeManipulationError::TooMuchChildren(node));
}
if children_count <= 0 { if children_count <= 0 {
return Err(NodeManipulationError::NotEnoughChildren); return Err(NodeManipulationError::NotEnoughChildren);
} }

View file

@ -22,8 +22,8 @@ fn test_plus_one() {
), ),
None None
) )
.is_err() .inspect_err(|x| println!("{x}"))
== false .is_ok()
); );
let results = formula.run(vec![0f64, 1f64, 2f64, 3f64, 4f64, 5f64]); let results = formula.run(vec![0f64, 1f64, 2f64, 3f64, 4f64, 5f64]);
assert_eq!(results, vec![1f64, 2f64, 3f64, 4f64, 5f64, 6f64]) assert_eq!(results, vec![1f64, 2f64, 3f64, 4f64, 5f64, 6f64])
@ -39,16 +39,15 @@ fn test_branch_sum() {
Node::function("Sum".to_string(), |inputs| inputs.iter().sum(), Some(2), 0), Node::function("Sum".to_string(), |inputs| inputs.iter().sum(), Some(2), 0),
None None
) )
.is_err() .inspect_err(|x| println!("{x}"))
== false .is_ok()
); );
assert!( assert!(
formula formula
.modify_tree() .modify_tree()
.go_down(0) .go_down(0)
.add_node(Node::number(1f64)) .add_node(Node::number(1f64))
.is_err() .is_ok()
== false
); );
let results = formula.run(vec![0f64, 1f64, 2f64, 3f64, 4f64, 5f64]); let results = formula.run(vec![0f64, 1f64, 2f64, 3f64, 4f64, 5f64]);
assert_eq!(results, vec![1f64, 2f64, 3f64, 4f64, 5f64, 6f64]) assert_eq!(results, vec![1f64, 2f64, 3f64, 4f64, 5f64, 6f64])
@ -64,8 +63,8 @@ fn test_display_as_text() {
Node::function("sum".to_string(), |inputs| inputs.iter().sum(), None, 0), Node::function("sum".to_string(), |inputs| inputs.iter().sum(), None, 0),
None None
) )
.is_err() .inspect_err(|x| println!("{x}"))
== false .is_ok()
); );
assert!( assert!(
formula formula
@ -77,8 +76,8 @@ fn test_display_as_text() {
Some(1), Some(1),
0 0
)) ))
.is_err() .inspect_err(|x| println!("{x}"))
== false .is_ok()
); );
assert!( assert!(
formula formula
@ -86,8 +85,8 @@ fn test_display_as_text() {
.go_down(0) .go_down(0)
.go_down(1) .go_down(1)
.add_node(Node::variable()) .add_node(Node::variable())
.is_err() .inspect_err(|x| println!("{x}"))
== false .is_ok()
); );
assert_eq!(formula.as_text(), "sum(X,sin(X))".to_string()); assert_eq!(formula.as_text(), "sum(X,sin(X))".to_string());
} }