It doesn't work

This commit is contained in:
Rendo 2025-11-09 18:52:59 +05:00
commit c71ed10e76
5 changed files with 81 additions and 25 deletions

View file

@ -35,8 +35,21 @@ impl Learner {
} }
self.best_algorithm.clone() self.best_algorithm.clone()
} }
pub fn calculate_formula_debug(&mut self) -> Formula {
for _ in 0..self.iterations {
self.best_algorithm = self.iterate();
self.best_algorithm.display_tree();
}
self.best_algorithm.clone()
}
fn iterate(&self) -> Formula { fn iterate(&self) -> Formula {
let mut formulas: Vec<(Formula, f64)> = vec![]; let best_similarity = Learner::get_similarity(
&self.expected_outputs,
&self.best_algorithm.run(self.inputs.clone()),
)
.unwrap();
let mut formulas: Vec<(Formula, f64)> =
vec![(self.best_algorithm.clone(), best_similarity)];
for _ in 0..self.formulas_per_iteration { for _ in 0..self.formulas_per_iteration {
let mut formula = self.best_algorithm.clone(); let mut formula = self.best_algorithm.clone();
Learner::mutate_formula_randomly(&mut formula); Learner::mutate_formula_randomly(&mut formula);
@ -63,7 +76,7 @@ impl Learner {
} }
fn mutate_formula_randomly(formula: &mut Formula) { fn mutate_formula_randomly(formula: &mut Formula) {
let mut editor = formula.modify_random_node(); let mut editor = formula.modify_random_node();
let decided_action = random_range(0..3); let decided_action = random_range(0..4);
if decided_action == ACTION_ADD { if decided_action == ACTION_ADD {
editor.add_node(editor.get_random_node()); editor.add_node(editor.get_random_node());
} else if decided_action == ACTION_REMOVE { } else if decided_action == ACTION_REMOVE {

View file

@ -8,9 +8,11 @@ mod tests;
fn main() { fn main() {
let mut learner = Learner::new( let mut learner = Learner::new(
vec![0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], vec![0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.],
vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], vec![0., 2., 4., 6., 8., 10., 12., 14., 16., 19., 20.],
None,
None, None,
Some(10000),
); );
println!("{:?}", learner.calculate_formula().as_text()); let formula = learner.calculate_formula();
println!("{:?}", formula.as_text());
formula.display_tree();
} }

View file

@ -29,7 +29,7 @@ impl Node {
Self { Self {
children: vec![], children: vec![],
handler: NodeHandler::Empty, handler: NodeHandler::Empty,
max_children_count: None, max_children_count: Some(1),
} }
} }
pub fn number(n: f64) -> Self { pub fn number(n: f64) -> Self {
@ -74,6 +74,33 @@ impl Node {
return self.handler.run(inputs, passed_x); return self.handler.run(inputs, passed_x);
} }
pub fn set_handler(&mut self, to: NodeHandler) {
match &to {
NodeHandler::Number { number } => {
self.max_children_count = Some(0);
self.children.clear();
}
NodeHandler::Function {
name,
function,
max_args,
min_args,
} => {
self.max_children_count = max_args.clone();
if let Some(x) = max_args {
self.children.truncate(x.clone());
}
}
NodeHandler::Variable => {
self.max_children_count = Some(0);
self.children.clear();
}
NodeHandler::Empty => {
self.max_children_count = Some(1);
self.children.truncate(1);
}
}
}
pub fn modify_node(&mut self) -> NodeModifier { pub fn modify_node(&mut self) -> NodeModifier {
NodeModifier::from_random(self, None) NodeModifier::from_random(self, None)
} }

View file

@ -6,7 +6,7 @@ use rand::random_range;
pub mod errors; pub mod errors;
const PICK_STOP_PROBABILITY: u8 = 2; const PICK_STOP_PROBABILITY: u8 = 2;
const TYPE_CHANGE_PROBABILITY: u8 = 10; const TYPE_CHANGE_PROBABILITY: u8 = 2;
pub struct NodeModifier<'a> { pub struct NodeModifier<'a> {
picked_node: &'a mut Node, picked_node: &'a mut Node,
@ -120,6 +120,9 @@ impl<'a> NodeModifier<'a> {
&mut self, &mut self,
specified_index: Option<usize>, specified_index: Option<usize>,
) -> Result<(), NodeManipulationError> { ) -> Result<(), NodeManipulationError> {
if let NodeHandler::Empty = &self.picked_node.handler {
return Err(NodeManipulationError::ProtectedEmpty);
}
if self.picked_node.children.len() == 0 { if self.picked_node.children.len() == 0 {
return Err(NodeManipulationError::NotEnoughChildren); return Err(NodeManipulationError::NotEnoughChildren);
} }
@ -149,6 +152,11 @@ impl<'a> NodeModifier<'a> {
} }
} }
let children_count = self.picked_node.children.len(); let children_count = self.picked_node.children.len();
if let NodeHandler::Empty = &self.picked_node.handler
&& children_count - 1 > 0
{
return Err(NodeManipulationError::TooMuchChildren(node));
}
if children_count <= 0 { if children_count <= 0 {
return Err(NodeManipulationError::NotEnoughChildren); return Err(NodeManipulationError::NotEnoughChildren);
} }
@ -166,12 +174,12 @@ impl<'a> NodeModifier<'a> {
} else { } else {
match &self.picked_node.handler { match &self.picked_node.handler {
NodeHandler::Number { number } => { NodeHandler::Number { number } => {
self.picked_node.handler = NodeHandler::Number { self.picked_node.set_handler(NodeHandler::Number {
number: (self.number_mutation_pool number: (self.number_mutation_pool
[random_range(0..self.number_mutation_pool.len())])( [random_range(0..self.number_mutation_pool.len())])(
number.clone() number.clone()
), ),
}; });
} }
NodeHandler::Function { NodeHandler::Function {
function, function,
@ -181,23 +189,18 @@ impl<'a> NodeModifier<'a> {
} => { } => {
let selected_mutation = &self.function_mutation_pool let selected_mutation = &self.function_mutation_pool
[random_range(0..self.function_mutation_pool.len())]; [random_range(0..self.function_mutation_pool.len())];
if let NodeHandler::Function { self.picked_node.set_handler(selected_mutation.clone());
name,
function,
max_args,
min_args,
} = &selected_mutation
{
self.picked_node.max_children_count = max_args.clone();
self.picked_node.handler = selected_mutation.clone();
}
} }
_ => {} _ => {}
} }
} }
} }
pub fn change_node_type(&mut self, to: Option<NodeHandler>) { pub fn change_node_type(&mut self, to: Option<NodeHandler>) {
self.picked_node.handler = to.unwrap_or(self.get_random_handler()); if let NodeHandler::Empty = &self.picked_node.handler {
return;
}
let next_type = to.unwrap_or(self.get_random_handler());
self.picked_node.set_handler(next_type);
} }
pub fn get_random_handler(&self) -> NodeHandler { pub fn get_random_handler(&self) -> NodeHandler {
let picked = random_range(0..3); let picked = random_range(0..3);

View file

@ -1,4 +1,4 @@
use crate::{formula::Formula, node::Node}; use crate::{formula::Formula, learner::Learner, node::Node};
#[test] #[test]
fn test_node_variable() { fn test_node_variable() {
@ -17,7 +17,8 @@ fn test_plus_one() {
Node::function( Node::function(
"+1".to_string(), "+1".to_string(),
|inputs: Vec<f64>| inputs[0] + 1f64, |inputs: Vec<f64>| inputs[0] + 1f64,
Some(1) Some(1),
1
), ),
None None
) )
@ -35,7 +36,7 @@ fn test_branch_sum() {
formula formula
.modify_tree() .modify_tree()
.insert_node( .insert_node(
Node::function("Sum".to_string(), |inputs| inputs.iter().sum(), Some(2)), Node::function("Sum".to_string(), |inputs| inputs.iter().sum(), Some(2), 0),
None None
) )
.is_err() .is_err()
@ -60,7 +61,7 @@ fn test_display_as_text() {
formula formula
.modify_tree() .modify_tree()
.insert_node( .insert_node(
Node::function("sum".to_string(), |inputs| inputs.iter().sum(), None), Node::function("sum".to_string(), |inputs| inputs.iter().sum(), None, 0),
None None
) )
.is_err() .is_err()
@ -73,7 +74,8 @@ fn test_display_as_text() {
.add_node(Node::function( .add_node(Node::function(
"sin".to_string(), "sin".to_string(),
|inputs| inputs[0].sin(), |inputs| inputs[0].sin(),
Some(1) Some(1),
0
)) ))
.is_err() .is_err()
== false == false
@ -89,3 +91,12 @@ fn test_display_as_text() {
); );
assert_eq!(formula.as_text(), "sum(X,sin(X))".to_string()); assert_eq!(formula.as_text(), "sum(X,sin(X))".to_string());
} }
#[test]
fn test_2x() {
let inputs = vec![0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.];
let outputs = vec![0., 2., 4., 6., 8., 10., 12., 14., 16., 18., 20., 22., 24.];
let mut learner = Learner::new(inputs.clone(), outputs.clone(), None, None);
let formula = learner.calculate_formula();
assert_eq!(formula.run(inputs), outputs);
}