Compare commits

..

3 commits

Author SHA1 Message Date
b521356cc9 it just works 2025-11-09 16:23:24 +05:00
a2f9c95c98 I ain't happy 2025-11-09 15:53:23 +05:00
35d186cf05 Random mutations 2025-11-09 15:43:03 +05:00
6 changed files with 66 additions and 7 deletions

View file

@ -22,11 +22,13 @@ impl Formula {
} }
outputs outputs
} }
pub fn mutate(&mut self) {}
pub fn modify_tree(&mut self) -> NodeModifier { pub fn modify_tree(&mut self) -> NodeModifier {
self.tree.modify_tree() self.tree.modify_tree()
} }
pub fn modify_random_node(&mut self) -> NodeModifier {
self.tree.modify_node()
}
pub fn display_tree(&self) { pub fn display_tree(&self) {
self.display_recursion(0, vec![&self.tree]); self.display_recursion(0, vec![&self.tree]);
} }

View file

@ -1,4 +1,10 @@
use crate::formula::Formula; use crate::formula::Formula;
use rand::random_range;
const ACTION_ADD: u8 = 0;
const ACTION_REMOVE: u8 = 1;
const ACTION_INSERT: u8 = 2;
const ACTION_MUTATE: u8 = 3;
pub struct Learner { pub struct Learner {
best_algorithm: Formula, best_algorithm: Formula,
@ -23,7 +29,13 @@ impl Learner {
iterations: iterations.unwrap_or(200), iterations: iterations.unwrap_or(200),
} }
} }
pub fn iterate(&self) -> Formula { pub fn calculate_formula(&mut self) -> Formula {
for _ in 0..self.iterations {
self.best_algorithm = self.iterate()
}
self.best_algorithm.clone()
}
fn iterate(&self) -> Formula {
let mut formulas: Vec<(Formula, f64)> = vec![]; let mut formulas: Vec<(Formula, f64)> = vec![];
for _ in 0..self.formulas_per_iteration { for _ in 0..self.formulas_per_iteration {
let mut formula = self.best_algorithm.clone(); let mut formula = self.best_algorithm.clone();
@ -49,8 +61,20 @@ impl Learner {
.0 .0
.clone() .clone()
} }
fn mutate_formula_randomly(formula: &mut Formula) {} fn mutate_formula_randomly(formula: &mut Formula) {
fn get_similarity(expected_output: &Vec<f64>, real_output: &Vec<f64>) -> Result<f64, ()> { let mut editor = formula.modify_random_node();
let decided_action = random_range(0..3);
if decided_action == ACTION_ADD {
editor.add_node(editor.get_random_node());
} else if decided_action == ACTION_REMOVE {
editor.remove_node(None);
} else if decided_action == ACTION_INSERT {
editor.insert_node(editor.get_random_node(), None);
} else if decided_action == ACTION_MUTATE {
editor.mutate_node();
}
}
pub fn get_similarity(expected_output: &Vec<f64>, real_output: &Vec<f64>) -> Result<f64, ()> {
if expected_output.len() != real_output.len() { if expected_output.len() != real_output.len() {
return Err(()); return Err(());
} }

View file

@ -1,8 +1,16 @@
use crate::{formula::Formula, node::Node}; use fapprox::learner::Learner;
mod formula; mod formula;
mod node; mod node;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
fn main() {} fn main() {
let mut learner = Learner::new(
vec![0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.],
vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.],
None,
None,
);
println!("{:?}", learner.calculate_formula().as_text());
}

View file

@ -7,6 +7,7 @@ pub enum NodeHandler {
name: String, name: String,
function: fn(Vec<f64>) -> f64, function: fn(Vec<f64>) -> f64,
max_args: Option<usize>, max_args: Option<usize>,
min_args: usize,
}, },
Variable, Variable,
Empty, Empty,
@ -20,7 +21,14 @@ impl NodeHandler {
function, function,
name, name,
max_args, max_args,
} => function(inputs), min_args,
} => {
if &inputs.len() >= min_args {
function(inputs)
} else {
0.
}
}
Self::Variable => passed_x, Self::Variable => passed_x,
Self::Empty => inputs[0], Self::Empty => inputs[0],
} }

View file

@ -43,6 +43,7 @@ impl Node {
func_name: String, func_name: String,
func: fn(Vec<f64>) -> f64, func: fn(Vec<f64>) -> f64,
max_children_count: Option<usize>, max_children_count: Option<usize>,
min_arguments: usize,
) -> Self { ) -> Self {
Self { Self {
children: vec![], children: vec![],
@ -50,6 +51,7 @@ impl Node {
name: func_name, name: func_name,
function: func, function: func,
max_args: max_children_count, max_args: max_children_count,
min_args: min_arguments,
}, },
max_children_count, max_children_count,
} }
@ -93,6 +95,7 @@ impl Node {
name, name,
function, function,
max_args, max_args,
min_args,
} => name.clone() + "(" + children_text.as_str() + ")", } => name.clone() + "(" + children_text.as_str() + ")",
NodeHandler::Empty => children_text, NodeHandler::Empty => children_text,
_ => self.to_string(), _ => self.to_string(),
@ -111,6 +114,7 @@ impl fmt::Display for Node {
function, function,
name, name,
max_args, max_args,
min_args,
} => name.clone(), } => name.clone(),
NodeHandler::Variable => "X".to_string(), NodeHandler::Variable => "X".to_string(),
NodeHandler::Empty => "".to_string(), NodeHandler::Empty => "".to_string(),

View file

@ -22,36 +22,43 @@ impl<'a> NodeModifier<'a> {
name: "sin".to_string(), name: "sin".to_string(),
function: |inputs| inputs[0].sin(), function: |inputs| inputs[0].sin(),
max_args: Some(1), max_args: Some(1),
min_args: 1,
}, },
NodeHandler::Function { NodeHandler::Function {
name: "cos".to_string(), name: "cos".to_string(),
function: |inputs| inputs[0].cos(), function: |inputs| inputs[0].cos(),
max_args: Some(1), max_args: Some(1),
min_args: 1,
}, },
NodeHandler::Function { NodeHandler::Function {
name: "sum".to_string(), name: "sum".to_string(),
function: |inputs| inputs.iter().sum(), function: |inputs| inputs.iter().sum(),
max_args: None, max_args: None,
min_args: 0,
}, },
NodeHandler::Function { NodeHandler::Function {
name: "-".to_string(), name: "-".to_string(),
function: |inputs| -inputs[0], function: |inputs| -inputs[0],
max_args: Some(1), max_args: Some(1),
min_args: 1,
}, },
NodeHandler::Function { NodeHandler::Function {
name: "product".to_string(), name: "product".to_string(),
function: |inputs| inputs.iter().product(), function: |inputs| inputs.iter().product(),
max_args: None, max_args: None,
min_args: 0,
}, },
NodeHandler::Function { NodeHandler::Function {
name: "exp".to_string(), name: "exp".to_string(),
function: |inputs| inputs[0].powf(inputs[1]), function: |inputs| inputs[0].powf(inputs[1]),
max_args: Some(2), max_args: Some(2),
min_args: 2,
}, },
NodeHandler::Function { NodeHandler::Function {
name: "1/".to_string(), name: "1/".to_string(),
function: |inputs| 1f64 / inputs[0], function: |inputs| 1f64 / inputs[0],
max_args: Some(1), max_args: Some(1),
min_args: 1,
}, },
]; ];
let standard_number_mutation: Vec<fn(f64) -> f64> = vec![ let standard_number_mutation: Vec<fn(f64) -> f64> = vec![
@ -142,6 +149,9 @@ impl<'a> NodeModifier<'a> {
} }
} }
let children_count = self.picked_node.children.len(); let children_count = self.picked_node.children.len();
if children_count <= 0 {
return Err(NodeManipulationError::NotEnoughChildren);
}
let operated_index = between.unwrap_or(random_range(0..children_count)); let operated_index = between.unwrap_or(random_range(0..children_count));
let moved = self.picked_node.children.remove(operated_index); let moved = self.picked_node.children.remove(operated_index);
node.children.push(moved); node.children.push(moved);
@ -167,6 +177,7 @@ impl<'a> NodeModifier<'a> {
function, function,
name, name,
max_args, max_args,
min_args,
} => { } => {
let selected_mutation = &self.function_mutation_pool let selected_mutation = &self.function_mutation_pool
[random_range(0..self.function_mutation_pool.len())]; [random_range(0..self.function_mutation_pool.len())];
@ -174,6 +185,7 @@ impl<'a> NodeModifier<'a> {
name, name,
function, function,
max_args, max_args,
min_args,
} = &selected_mutation } = &selected_mutation
{ {
self.picked_node.max_children_count = max_args.clone(); self.picked_node.max_children_count = max_args.clone();
@ -211,6 +223,7 @@ impl<'a> NodeModifier<'a> {
name, name,
function, function,
max_args, max_args,
min_args,
} => max_args, } => max_args,
NodeHandler::Variable => Some(0), NodeHandler::Variable => Some(0),
NodeHandler::Empty => Some(1), NodeHandler::Empty => Some(1),