Vita
search.tcc
1/**
2 * \file
3 * \remark This file is part of VITA.
4 *
5 * \copyright Copyright (C) 2013-2023 EOS di Manlio Morini.
6 *
7 * \license
8 * This Source Code Form is subject to the terms of the Mozilla Public
9 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
10 * You can obtain one at http://mozilla.org/MPL/2.0/
11 */
12
13#if !defined(VITA_SEARCH_H)
14# error "Don't include this file directly, include the specific .h instead"
15#endif
16
17#if !defined(VITA_SEARCH_TCC)
18#define VITA_SEARCH_TCC
19
20///
21/// \param[in] p the problem we're working on. The lifetime of `p` must exceed
22/// the lifetime of `this` class
23///
24template<class T, template<class> class ES>
25search<T, ES>::search(problem &p) : eva1_(nullptr), eva2_(nullptr),
26 vs_(std::make_unique<as_is_validation>()),
27 prob_(p), after_generation_callback_()
28{
29 Ensures(is_valid());
30}
31
32template<class T, template<class> class ES>
33search<T, ES> &search<T, ES>::after_generation(
34 typename evolution<T, ES>::after_generation_callback_t f)
35{
36 after_generation_callback_ = std::move(f);
37 return *this;
38}
39
40template<class T, template<class> class ES>
41bool search<T, ES>::can_validate() const
42{
43 return eva2_.get();
44}
45
46///
47/// Calculates and stores the fitness of the best individual so far.
48///
49/// \param[in] s summary of the evolution run just finished
50///
51/// Specializations of this method can calculate further / distinct
52/// problem-specific metrics regarding the candidate solution.
53///
54/// If a validation set / simulation is available, it's used for the
55/// calculations.
56///
57template<class T, template<class> class ES>
58void search<T, ES>::calculate_metrics(summary<T> *s) const
59{
60 auto &best(s->best);
61
62 if (can_validate())
63 {
64 assert(eva2_);
65 best.score.fitness = (*eva2_)(best.solution);
66 }
67 else
68 {
69 assert(eva1_);
70 best.score.fitness = (*eva1_)(best.solution);
71 }
72
73 // We use accuracy or fitness (or both) to identify successful runs.
74 best.score.is_solution = (best.score >= this->prob_.env.threshold);
75}
76
77///
78/// Tries to tune search parameters for the current problem.
79///
80template<class T, template<class> class ES>
81void search<T, ES>::tune_parameters()
82{
83 // The `shape` function modifies the default parameters with
84 // strategy-specific values.
85 const environment dflt(ES<T>::shape(environment().init()));
86 const environment constrained(prob_.env);
87
88 if (!constrained.mep.code_length)
89 prob_.env.mep.code_length = dflt.mep.code_length;
90
91 if (!constrained.mep.patch_length)
92 prob_.env.mep.patch_length = 1 + prob_.sset.terminals(0) / 2;
93
94 if (constrained.elitism == trilean::unknown)
95 prob_.env.elitism = dflt.elitism;
96
97 if (constrained.p_mutation < 0.0)
98 prob_.env.p_mutation = dflt.p_mutation;
99
100 if (constrained.p_cross < 0.0)
101 prob_.env.p_cross = dflt.p_cross;
102
103 if (!constrained.brood_recombination)
104 prob_.env.brood_recombination = dflt.brood_recombination;
105
106 if (!constrained.layers)
107 prob_.env.layers = dflt.layers;
108
109 if (!constrained.individuals)
110 prob_.env.individuals = dflt.individuals;
111
112 if (!constrained.min_individuals)
113 prob_.env.min_individuals = dflt.min_individuals;
114
115 if (!constrained.tournament_size)
116 prob_.env.tournament_size = dflt.tournament_size;
117
118 if (!constrained.mate_zone)
119 prob_.env.mate_zone = dflt.mate_zone;
120
121 if (!constrained.generations)
122 prob_.env.generations = dflt.generations;
123
124 if (!constrained.max_stuck_time.has_value())
125 prob_.env.max_stuck_time = dflt.max_stuck_time;
126
127 Ensures(prob_.env.is_valid(true));
128}
129
130///
131/// Performs basic initialization before the search.
132///
133/// The default behaviour involve:
134/// - tuning of the search parameters;
135/// - possibly loading cached value for the training evaluator.
136///
137/// \remark
138/// Called at the beginning of the first run (i.e. only one time even for a
139/// multiple-run search).
140///
141template<class T, template<class> class ES>
142void search<T, ES>::init()
143{
144 tune_parameters();
145
146 load();
147}
148
149///
150/// Performs closing actions at the end of the search.
151///
152/// The default behaviour involve (possibly) caching values of the training
153/// evaluator.
154///
155/// \remark
156/// Called at the beginning of the first run (i.e. only one time even for a
157/// multiple-run search).
158///
159template<class T, template<class> class ES>
160void search<T, ES>::close()
161{
162 save();
163}
164
165///
166/// Performs after evolution tasks.
167///
168/// The default act is to print the result of the evolutionary run. Derived
169/// classes can change / integrate the base behaviour.
170///
171/// \remark Called at the end of each run.
172///
173template<class T, template<class> class ES>
174void search<T, ES>::after_evolution(const summary<T> &s)
175{
176 print_resume(s.best.score);
177}
178
179///
180/// \param[in] n number of runs
181/// \return a summary of the search
182///
183template<class T, template<class> class ES>
184summary<T> search<T, ES>::run(unsigned n)
185{
186 init();
187
188 auto shake([this](unsigned g) { return vs_->shake(g); });
189 search_stats<T> stats;
190
191 for (unsigned r(0); r < n; ++r)
192 {
193 vs_->init(r);
194 auto run_summary(evolution<T, ES>(prob_, *eva1_)
195 .after_generation(after_generation_callback_)
196 .run(r, shake));
197 vs_->close(r);
198
199 // Possibly calculates additional metrics.
200 calculate_metrics(&run_summary);
201
202 after_evolution(run_summary);
203
204 stats.update(run_summary);
205 log_stats(stats);
206 }
207
208 close();
209
210 return stats.overall;
211}
212
213template<class T>
214void search_stats<T>::update(const summary<T> &r)
215{
216 if (runs == 0 || r.best.score.fitness > overall.best.score.fitness)
217 {
218 overall.best = r.best;
219 best_run = runs;
220 }
221
222 if (r.best.score.is_solution)
223 {
224 overall.last_imp += r.last_imp;
225 good_runs.insert(good_runs.end(), runs);
226 }
227
228 if (isfinite(r.best.score.fitness))
229 fd.add(r.best.score.fitness);
230
231 overall.elapsed += r.elapsed;
232 overall.gen += r.gen;
233
234 ++runs;
235
236 Ensures(good_runs.empty() || good_runs.count(best_run));
237}
238
239///
240/// Loads the saved evaluation cache from a file (if available).
241///
242/// \return `true` if the object is correctly loaded
243///
244template<class T, template<class> class ES>
245bool search<T, ES>::load()
246{
247 if (prob_.env.misc.serialization_file.empty())
248 return true;
249
250 std::ifstream in(prob_.env.misc.serialization_file);
251 if (!in)
252 return false;
253
254 if (prob_.env.cache_size)
255 {
256 if (!eva1_->load(in))
257 return false;
258 vitaINFO << "Loading cache";
259 }
260
261 return true;
262}
263
264///
265/// \return `true` if the object was saved correctly
266///
267template<class T, template<class> class ES>
268bool search<T, ES>::save() const
269{
270 if (prob_.env.misc.serialization_file.empty())
271 return true;
272
273 std::ofstream out(prob_.env.misc.serialization_file);
274 if (!out)
275 return false;
276
277 if (prob_.env.cache_size)
278 {
279 if (!eva1_->save(out))
280 return false;
281 vitaINFO << "Saving cache";
282 }
283
284 return true;
285}
286
287///
288/// Prints a resume of the evolutionary run.
289///
290/// \param[in] m metrics relative to the current run
291///
292/// Derived classes can add further specific information.
293///
294template<class T, template<class> class ES>
295void search<T, ES>::print_resume(const model_measurements &m) const
296{
297 const std::string s(can_validate() ? "Validation" : "Training");
298
299 vitaINFO << s << " fitness: " << m.fitness;
300}
301
302///
303/// Sets the main evaluator (used for training).
304///
305/// \tparam E an evaluator
306///
307/// \param[in] args arguments used to build the `E` evaluator
308/// \return a reference to the search class (used for method chaining)
309///
310/// \warning
311/// We assume that the training evaluator could have a cache. This means that
312/// changes in the training simulation / set should invalidate fitness values
313/// stored in that cache.
314///
315template<class T, template<class> class ES>
316template<class E, class... Args>
317search<T, ES> &search<T, ES>::training_evaluator(Args && ...args)
318{
319 if (prob_.env.cache_size)
320 eva1_ = std::make_unique<evaluator_proxy<T, E>>(
321 E(std::forward<Args>(args)...), prob_.env.cache_size);
322 else
323 eva1_ = std::make_unique<E>(std::forward<Args>(args)...);
324
325 return *this;
326}
327
328///
329/// Sets the validation evaluator (used for validation).
330///
331/// \tparam E an evaluator
332///
333/// \param[in] args arguments used to build the `E` evaluator
334/// \return a reference to the search class (used for method chaining)
335///
336/// \warning
337/// The validation evaluator cannot have a cache.
338///
339template<class T, template<class> class ES>
340template<class E, class... Args>
341search<T, ES> &search<T, ES>::validation_evaluator(Args && ...args)
342{
343 eva2_ = std::make_unique<E>(std::forward<Args>(args)...);
344 return *this;
345}
346
347///
348/// Sets the active validation strategy.
349///
350/// \param[in] args parameters for the validation strategy
351/// \return a reference to the search class (used for method chaining)
352///
353template<class T, template<class> class ES>
354template<class V, class... Args>
355search<T, ES> &search<T, ES>::validation_strategy(Args && ...args)
356{
357 vs_ = std::make_unique<V>(std::forward<Args>(args)...);
358 return *this;
359}
360
361///
362/// \return `true` if the object passes the internal consistency check
363///
364template<class T, template<class> class ES>
365bool search<T, ES>::is_valid() const
366{
367 return true;
368}
369
370///
371/// Writes end-of-run logs (run summary, results for test...).
372///
373/// \param[in] stats mixed statistics about the search performed so far
374/// \param[out] d output file (XML)
375///
376template<class T, template<class> class ES>
377void search<T, ES>::log_stats(const search_stats<T> &stats,
378 tinyxml2::XMLDocument *d) const
379{
380 Expects(d);
381
382 if (prob_.env.stat.summary_file.empty())
383 return;
384
385 auto *root(d->NewElement("vita"));
386 d->InsertFirstChild(root);
387
388 auto *e_summary(d->NewElement("summary"));
389 root->InsertEndChild(e_summary);
390
391 const auto solutions(stats.good_runs.size());
392 const auto success_rate(
393 stats.runs ? static_cast<double>(solutions)
394 / static_cast<double>(stats.runs)
395 : 0);
396
397 set_text(e_summary, "success_rate", success_rate);
398 set_text(e_summary, "elapsed_time", stats.overall.elapsed.count());
399 set_text(e_summary, "mean_fitness", stats.fd.mean());
400 set_text(e_summary, "standard_deviation", stats.fd.standard_deviation());
401
402 auto *e_best(d->NewElement("best"));
403 e_summary->InsertEndChild(e_best);
404 set_text(e_best, "fitness", stats.overall.best.score.fitness);
405 set_text(e_best, "run", stats.best_run);
406
407 std::ostringstream ss;
408 ss << out::print_format(prob_.env.stat.ind_format)
409 << stats.overall.best.solution;
410 set_text(e_best, "code", ss.str());
411
412 auto *e_solutions(d->NewElement("solutions"));
413 e_summary->InsertEndChild(e_solutions);
414
415 auto *e_runs(d->NewElement("runs"));
416 e_solutions->InsertEndChild(e_runs);
417 for (const auto &gr : stats.good_runs)
418 set_text(e_runs, "run", gr);
419 set_text(e_solutions, "found", solutions);
420
421 const auto avg_depth(solutions ? stats.overall.last_imp / solutions
422 : 0);
423 set_text(e_solutions, "avg_depth", avg_depth);
424
425 auto *e_other(d->NewElement("other"));
426 e_summary->InsertEndChild(e_other);
427
428 prob_.env.xml(d);
429}
430
431template<class T, template<class> class ES>
432void search<T, ES>::log_stats(const search_stats<T> &stats) const
433{
434 tinyxml2::XMLDocument d(false);
435
436 log_stats(stats, &d);
437
438 const auto f_sum(prob_.env.stat.dir / prob_.env.stat.summary_file);
439 d.SaveFile(f_sum.string().c_str());
440}
441
442#endif // include guard