Compare commits
2 commits
ad4375240e
...
2ebfe38bac
| Author | SHA1 | Date | |
|---|---|---|---|
| 2ebfe38bac | |||
| 6655499305 |
2 changed files with 79 additions and 0 deletions
78
src/learner.rs
Normal file
78
src/learner.rs
Normal file
|
|
@ -0,0 +1,78 @@
|
||||||
|
use crate::formula::Formula;
|
||||||
|
|
||||||
|
pub struct Learner {
|
||||||
|
best_algorithm: Formula,
|
||||||
|
inputs: Vec<f64>,
|
||||||
|
expected_outputs: Vec<f64>,
|
||||||
|
formulas_per_iteration: usize,
|
||||||
|
iterations: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Learner {
|
||||||
|
pub fn new(
|
||||||
|
inputs: Vec<f64>,
|
||||||
|
expected_outputs: Vec<f64>,
|
||||||
|
formulas_per_iteration: Option<usize>,
|
||||||
|
iterations: Option<usize>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
best_algorithm: Formula::new(),
|
||||||
|
inputs,
|
||||||
|
expected_outputs,
|
||||||
|
formulas_per_iteration: formulas_per_iteration.unwrap_or(200),
|
||||||
|
iterations: iterations.unwrap_or(200),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn iterate(&self) -> Formula {
|
||||||
|
let mut formulas: Vec<(Formula, f64)> = vec![];
|
||||||
|
for _ in 0..self.formulas_per_iteration {
|
||||||
|
let mut formula = self.best_algorithm.clone();
|
||||||
|
Learner::mutate_formula_randomly(&mut formula);
|
||||||
|
let outputs = formula.run(self.inputs.clone());
|
||||||
|
formulas.push((
|
||||||
|
formula,
|
||||||
|
Learner::get_similarity(&self.expected_outputs, &outputs).unwrap(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
formulas
|
||||||
|
.iter()
|
||||||
|
.max_by(|x, y| {
|
||||||
|
if x.1 > y.1 {
|
||||||
|
std::cmp::Ordering::Greater
|
||||||
|
} else if x.1 < y.1 {
|
||||||
|
std::cmp::Ordering::Less
|
||||||
|
} else {
|
||||||
|
std::cmp::Ordering::Equal
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
.0
|
||||||
|
.clone()
|
||||||
|
}
|
||||||
|
fn mutate_formula_randomly(formula: &mut Formula) {}
|
||||||
|
fn get_similarity(expected_output: &Vec<f64>, real_output: &Vec<f64>) -> Result<f64, ()> {
|
||||||
|
if expected_output.len() != real_output.len() {
|
||||||
|
return Err(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut scalar = 0f64;
|
||||||
|
let mut expected_len = 0f64;
|
||||||
|
let mut real_len = 0f64;
|
||||||
|
|
||||||
|
for i in 0..expected_output.len() {
|
||||||
|
expected_len += expected_output[i] * expected_output[i];
|
||||||
|
real_len += real_output[i] * real_output[i];
|
||||||
|
scalar += expected_output[i] * real_output[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
expected_len = expected_len.sqrt();
|
||||||
|
real_len = real_len.sqrt();
|
||||||
|
|
||||||
|
let cos: f64 = scalar / (expected_len * real_len);
|
||||||
|
let len_proportion: f64 = real_len / expected_len;
|
||||||
|
let similarity: f64 =
|
||||||
|
cos * (1f64 - ((len_proportion - 1f64).abs() / (len_proportion + 1f64)));
|
||||||
|
|
||||||
|
Ok(similarity)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
pub mod formula;
|
pub mod formula;
|
||||||
|
pub mod learner;
|
||||||
pub mod node;
|
pub mod node;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue