diff --git a/src/learner.rs b/src/learner.rs new file mode 100644 index 0000000..fb14647 --- /dev/null +++ b/src/learner.rs @@ -0,0 +1,78 @@ +use crate::formula::Formula; + +pub struct Learner { + best_algorithm: Formula, + inputs: Vec, + expected_outputs: Vec, + formulas_per_iteration: usize, + iterations: usize, +} + +impl Learner { + pub fn new( + inputs: Vec, + expected_outputs: Vec, + formulas_per_iteration: Option, + iterations: Option, + ) -> Self { + Self { + best_algorithm: Formula::new(), + inputs, + expected_outputs, + formulas_per_iteration: formulas_per_iteration.unwrap_or(200), + iterations: iterations.unwrap_or(200), + } + } + pub fn iterate(&self) -> Formula { + let mut formulas: Vec<(Formula, f64)> = vec![]; + for _ in 0..self.formulas_per_iteration { + let mut formula = self.best_algorithm.clone(); + Learner::mutate_formula_randomly(&mut formula); + let outputs = formula.run(self.inputs.clone()); + formulas.push(( + formula, + Learner::get_similarity(&self.expected_outputs, &outputs).unwrap(), + )); + } + formulas + .iter() + .max_by(|x, y| { + if x.1 > y.1 { + std::cmp::Ordering::Greater + } else if x.1 < y.1 { + std::cmp::Ordering::Less + } else { + std::cmp::Ordering::Equal + } + }) + .unwrap() + .0 + .clone() + } + fn mutate_formula_randomly(formula: &mut Formula) {} + fn get_similarity(expected_output: &Vec, real_output: &Vec) -> Result { + if expected_output.len() != real_output.len() { + return Err(()); + } + + let mut scalar = 0f64; + let mut expected_len = 0f64; + let mut real_len = 0f64; + + for i in 0..expected_output.len() { + expected_len += expected_output[i] * expected_output[i]; + real_len += real_output[i] * real_output[i]; + scalar += expected_output[i] * real_output[i]; + } + + expected_len = expected_len.sqrt(); + real_len = real_len.sqrt(); + + let cos: f64 = scalar / (expected_len * real_len); + let len_proportion: f64 = real_len / expected_len; + let similarity: f64 = + cos * (1f64 - ((len_proportion - 1f64).abs() / (len_proportion + 1f64))); + + Ok(similarity) + } +} diff --git a/src/lib.rs b/src/lib.rs index 97ffda3..76516c8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod formula; +pub mod learner; pub mod node; #[cfg(test)]