3 * \remark This file is part of VITA.
5 * \copyright Copyright (C) 2013-2023 EOS di Manlio Morini.
8 * This Source Code Form is subject to the terms of the Mozilla Public
9 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
10 * You can obtain one at http://mozilla.org/MPL/2.0/
13#if !defined(VITA_SEARCH_H)
14# error "Don't include this file directly, include the specific .h instead"
17#if !defined(VITA_SEARCH_TCC)
18#define VITA_SEARCH_TCC
/// Builds a search object bound to the problem `p`.
///
/// The training (`eva1_`) and validation (`eva2_`) evaluators start out unset
/// (`nullptr`), the validation strategy defaults to `as_is_validation` and the
/// after-generation callback is value-initialised.
///
21/// \param[in] p the problem we're working on. The lifetime of `p` must exceed
22/// the lifetime of `this` class
24template<class T, template<class> class ES>
25search<T, ES>::search(problem &p) : eva1_(nullptr), eva2_(nullptr),
26 vs_(std::make_unique<as_is_validation>()),
27 prob_(p), after_generation_callback_()
32template<class T, template<class> class ES>
33search<T, ES> &search<T, ES>::after_generation(
34 typename evolution<T, ES>::after_generation_callback_t f)
36 after_generation_callback_ = std::move(f);
/// \return `true` if the search can validate its results
///
/// NOTE(review): presumably this means that a validation evaluator (`eva2_`)
/// has been set via `validation_evaluator` - confirm against the body.
40template<class T, template<class> class ES>
41bool search<T, ES>::can_validate() const
47/// Calculates and stores the fitness of the best individual so far.
49/// \param[in,out] s summary of the evolution run just finished; the score of its best individual is overwritten (assuming `best` aliases `s->best` - confirm)
51/// Specializations of this method can calculate further / distinct
52/// problem-specific metrics regarding the candidate solution.
54/// If a validation set / simulation is available, it's used for the
/// fitness calculation (through `eva2_`); otherwise the training evaluator
/// (`eva1_`) is used.
57template<class T, template<class> class ES>
58void search<T, ES>::calculate_metrics(summary<T> *s) const
// Validation path: the fitness comes from the validation evaluator.
65 best.score.fitness = (*eva2_)(best.solution);
// Fallback path: the fitness comes from the training evaluator.
70 best.score.fitness = (*eva1_)(best.solution);
73 // We use accuracy or fitness (or both) to identify successful runs.
// The whole score (not just the fitness) is compared with the threshold
// configured in the environment.
74 best.score.is_solution = (best.score >= this->prob_.env.threshold);
78/// Tries to tune search parameters for the current problem.
80template<class T, template<class> class ES>
81void search<T, ES>::tune_parameters()
83 // The `shape` function modifies the default parameters with
84 // strategy-specific values.
85 const environment dflt(ES<T>::shape(environment().init()));
86 const environment constrained(prob_.env);
88 if (!constrained.mep.code_length)
89 prob_.env.mep.code_length = dflt.mep.code_length;
91 if (!constrained.mep.patch_length)
92 prob_.env.mep.patch_length = 1 + prob_.sset.terminals(0) / 2;
94 if (constrained.elitism == trilean::unknown)
95 prob_.env.elitism = dflt.elitism;
97 if (constrained.p_mutation < 0.0)
98 prob_.env.p_mutation = dflt.p_mutation;
100 if (constrained.p_cross < 0.0)
101 prob_.env.p_cross = dflt.p_cross;
103 if (!constrained.brood_recombination)
104 prob_.env.brood_recombination = dflt.brood_recombination;
106 if (!constrained.layers)
107 prob_.env.layers = dflt.layers;
109 if (!constrained.individuals)
110 prob_.env.individuals = dflt.individuals;
112 if (!constrained.min_individuals)
113 prob_.env.min_individuals = dflt.min_individuals;
115 if (!constrained.tournament_size)
116 prob_.env.tournament_size = dflt.tournament_size;
118 if (!constrained.mate_zone)
119 prob_.env.mate_zone = dflt.mate_zone;
121 if (!constrained.generations)
122 prob_.env.generations = dflt.generations;
124 if (!constrained.max_stuck_time.has_value())
125 prob_.env.max_stuck_time = dflt.max_stuck_time;
127 Ensures(prob_.env.is_valid(true));
131/// Performs basic initialization before the search.
133/// The default behaviour involves:
134/// - tuning of the search parameters;
135/// - possibly loading cached values for the training evaluator.
138/// Called at the beginning of the first run (i.e. only one time even for a
139/// multiple-run search).
///
/// NOTE(review): the two steps above presumably map to `tune_parameters()`
/// and `load()` - confirm against the body.
141template<class T, template<class> class ES>
142void search<T, ES>::init()
150/// Performs closing actions at the end of the search.
152/// The default behaviour involves (possibly) caching values of the training
/// evaluator (presumably through `save()` - confirm against the body).
156/// Called at the end of the last run (i.e. only one time even for a
157/// multiple-run search).
///
/// NOTE(review): the remark used to read "beginning of the first run", which
/// contradicts the first sentence and looks copied from `init` - confirm
/// against the call site in `run`.
159template<class T, template<class> class ES>
160void search<T, ES>::close()
166/// Performs after evolution tasks.
168/// The default act is to print the result of the evolutionary run. Derived
169/// classes can change / integrate the base behaviour.
171/// \remark Called at the end of each run.
173template<class T, template<class> class ES>
174void search<T, ES>::after_evolution(const summary<T> &s)
176 print_resume(s.best.score);
/// Performs `n` evolutionary runs and collects their statistics.
///
180/// \param[in] n number of runs
181/// \return a summary of the search
///
/// NOTE(review): `*eva1_` is dereferenced for every run, so a training
/// evaluator must be available at that point (e.g. set through
/// `training_evaluator`) - confirm that callers guarantee it.
183template<class T, template<class> class ES>
184summary<T> search<T, ES>::run(unsigned n)
// Forwards the generation index to the `shake` function of the active
// validation strategy (presumably handed to each evolution run).
188 auto shake([this](unsigned g) { return vs_->shake(g); });
// Statistics accumulated over all the runs.
189 search_stats<T> stats;
191 for (unsigned r(0); r < n; ++r)
// A fresh `evolution` object is built for every run; the user-supplied
// after-generation callback is copied into it.
194 auto run_summary(evolution<T, ES>(prob_, *eva1_)
195 .after_generation(after_generation_callback_)
199 // Possibly calculates additional metrics.
200 calculate_metrics(&run_summary);
// Per-run hook (by default prints a resume of the run).
202 after_evolution(run_summary);
// The statistics are updated only after the metrics have been calculated.
204 stats.update(run_summary);
210 return stats.overall;
/// Updates the search statistics with the summary of a run.
///
/// \param[in] r summary of the run just finished
214void search_stats<T>::update(const summary<T> &r)
// The overall best is replaced on the very first run or when the new run
// has a strictly greater fitness.
216 if (runs == 0 || r.best.score.fitness > overall.best.score.fitness)
218 overall.best = r.best;
// Successful runs only: accumulate `last_imp` and record the index of the
// run among the good ones.
222 if (r.best.score.is_solution)
224 overall.last_imp += r.last_imp;
225 good_runs.insert(good_runs.end(), runs);
// Non-finite fitness values are kept out of the fitness distribution.
// NOTE(review): `isfinite` is unqualified; `std::isfinite` would not depend
// on a global-namespace declaration being available.
228 if (isfinite(r.best.score.fitness))
229 fd.add(r.best.score.fitness);
// Elapsed time and generations are summed over all the runs.
231 overall.elapsed += r.elapsed;
232 overall.gen += r.gen;
// Post-condition: when at least one good run exists, `best_run` must be
// one of them.
236 Ensures(good_runs.empty() || good_runs.count(best_run));
240/// Loads the saved evaluation cache from a file (if available).
///
/// The file name is read from `prob_.env.misc.serialization_file`.
///
242/// \return `true` if the object is correctly loaded
///
/// NOTE(review): the value returned when no serialization file is configured
/// is not documented here - confirm against the body.
244template<class T, template<class> class ES>
245bool search<T, ES>::load()
// Guard on the configured serialization file name.
247 if (prob_.env.misc.serialization_file.empty())
250 std::ifstream in(prob_.env.misc.serialization_file);
// The cache is read only when a cache has been requested (non-zero size).
254 if (prob_.env.cache_size)
// The training evaluator restores its own state from the stream.
256 if (!eva1_->load(in))
258 vitaINFO << "Loading cache";
/// Saves the evaluation cache to a file.
///
/// The file name is read from `prob_.env.misc.serialization_file`.
///
265/// \return `true` if the object was saved correctly
///
/// NOTE(review): the value returned when no serialization file is configured
/// is not documented here - confirm against the body.
267template<class T, template<class> class ES>
268bool search<T, ES>::save() const
// Guard on the configured serialization file name.
270 if (prob_.env.misc.serialization_file.empty())
273 std::ofstream out(prob_.env.misc.serialization_file);
// The cache is written only when a cache has been requested (non-zero size).
277 if (prob_.env.cache_size)
// The training evaluator serializes its own state to the stream.
279 if (!eva1_->save(out))
281 vitaINFO << "Saving cache";
288/// Prints a resume of the evolutionary run.
290/// \param[in] m metrics relative to the current run
292/// Derived classes can add further specific information.
294template<class T, template<class> class ES>
295void search<T, ES>::print_resume(const model_measurements &m) const
297 const std::string s(can_validate() ? "Validation" : "Training");
299 vitaINFO << s << " fitness: " << m.fitness;
303/// Sets the main evaluator (used for training).
305/// \tparam E an evaluator
307/// \param[in] args arguments used to build the `E` evaluator
308/// \return a reference to the search class (used for method chaining)
311/// We assume that the training evaluator could have a cache. This means that
312/// changes in the training simulation / set should invalidate fitness values
313/// stored in that cache.
315template<class T, template<class> class ES>
316template<class E, class... Args>
317search<T, ES> &search<T, ES>::training_evaluator(Args && ...args)
319 if (prob_.env.cache_size)
320 eva1_ = std::make_unique<evaluator_proxy<T, E>>(
321 E(std::forward<Args>(args)...), prob_.env.cache_size);
323 eva1_ = std::make_unique<E>(std::forward<Args>(args)...);
329/// Sets the validation evaluator (used for validation).
331/// \tparam E an evaluator
333/// \param[in] args arguments used to build the `E` evaluator
334/// \return a reference to the search class (used for method chaining)
337/// The validation evaluator cannot have a cache.
339template<class T, template<class> class ES>
340template<class E, class... Args>
341search<T, ES> &search<T, ES>::validation_evaluator(Args && ...args)
343 eva2_ = std::make_unique<E>(std::forward<Args>(args)...);
348/// Sets the active validation strategy.
350/// \param[in] args parameters for the validation strategy
351/// \return a reference to the search class (used for method chaining)
353template<class T, template<class> class ES>
354template<class V, class... Args>
355search<T, ES> &search<T, ES>::validation_strategy(Args && ...args)
357 vs_ = std::make_unique<V>(std::forward<Args>(args)...);
/// Internal consistency check of the search object.
///
362/// \return `true` if the object passes the internal consistency check
364template<class T, template<class> class ES>
365bool search<T, ES>::is_valid() const
371/// Writes end-of-run logs (run summary, results for test...).
373/// \param[in] stats mixed statistics about the search performed so far
374/// \param[out] d output file (XML)
///
/// The document receives a `vita` root element holding a `summary` element
/// with: success rate, elapsed time, mean and standard deviation of the
/// fitness, a `best` element (fitness, run index and code of the best
/// individual), a `solutions` element (indices of the successful runs, their
/// number and the average depth) and an `other` element.
///
/// NOTE(review): presumably nothing is written when no summary file is
/// configured - confirm against the body.
376template<class T, template<class> class ES>
377void search<T, ES>::log_stats(const search_stats<T> &stats,
378 tinyxml2::XMLDocument *d) const
// Guard on the configured summary file name.
382 if (prob_.env.stat.summary_file.empty())
// Skeleton of the document: <vita><summary>...</summary></vita>.
385 auto *root(d->NewElement("vita"));
386 d->InsertFirstChild(root);
388 auto *e_summary(d->NewElement("summary"));
389 root->InsertEndChild(e_summary);
// Number of successful runs and their ratio over the runs performed (the
// ratio is computed in floating point and only when `stats.runs` is not
// zero).
391 const auto solutions(stats.good_runs.size());
392 const auto success_rate(
393 stats.runs ? static_cast<double>(solutions)
394 / static_cast<double>(stats.runs)
397 set_text(e_summary, "success_rate", success_rate);
398 set_text(e_summary, "elapsed_time", stats.overall.elapsed.count());
399 set_text(e_summary, "mean_fitness", stats.fd.mean());
400 set_text(e_summary, "standard_deviation", stats.fd.standard_deviation());
// Best individual found over all the runs.
402 auto *e_best(d->NewElement("best"));
403 e_summary->InsertEndChild(e_best);
404 set_text(e_best, "fitness", stats.overall.best.score.fitness);
405 set_text(e_best, "run", stats.best_run);
// The individual is rendered with the format selected in the environment.
407 std::ostringstream ss;
408 ss << out::print_format(prob_.env.stat.ind_format)
409 << stats.overall.best.solution;
410 set_text(e_best, "code", ss.str());
// Successful runs: one `run` entry per good run plus their count.
412 auto *e_solutions(d->NewElement("solutions"));
413 e_summary->InsertEndChild(e_solutions);
415 auto *e_runs(d->NewElement("runs"));
416 e_solutions->InsertEndChild(e_runs);
417 for (const auto &gr : stats.good_runs)
418 set_text(e_runs, "run", gr);
419 set_text(e_solutions, "found", solutions);
// `last_imp` is accumulated over the successful runs only (see
// `search_stats::update`), hence the division by the number of solutions.
// NOTE(review): this looks like an integer division - confirm it's intended.
421 const auto avg_depth(solutions ? stats.overall.last_imp / solutions
423 set_text(e_solutions, "avg_depth", avg_depth);
// Container for further information.
425 auto *e_other(d->NewElement("other"));
426 e_summary->InsertEndChild(e_other);
431template<class T, template<class> class ES>
432void search<T, ES>::log_stats(const search_stats<T> &stats) const
434 tinyxml2::XMLDocument d(false);
436 log_stats(stats, &d);
438 const auto f_sum(prob_.env.stat.dir / prob_.env.stat.summary_file);
439 d.SaveFile(f_sum.string().c_str());
442#endif // include guard