Vita
gp/src/search.tcc
1/**
2 * \file
3 * \remark This file is part of VITA.
4 *
5 * \copyright Copyright (C) 2013-2023 EOS di Manlio Morini.
6 *
7 * \license
8 * This Source Code Form is subject to the terms of the Mozilla Public
9 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
10 * You can obtain one at http://mozilla.org/MPL/2.0/
11 */
12
13#if !defined(VITA_SRC_SEARCH_H)
14# error "Don't include this file directly, include the specific .h instead"
15#endif
16
17#if !defined(VITA_SRC_SEARCH_TCC)
18#define VITA_SRC_SEARCH_TCC
19
20constexpr std::underlying_type_t<metric_flags> operator&(metric_flags f1,
21 metric_flags f2)
22{
23 return as_integer(f1) & as_integer(f2);
24}
25
26///
27/// \param[in] p the problem we're working on. The lifetime of `p` must exceed
28/// the lifetime of `this` class
29/// \param[in] m a bit field used to specify matrics we have to calculate while
30/// searching
31///
32template<class T, template<class> class ES>
33src_search<T, ES>::src_search(src_problem &p, metric_flags m)
34 : search<T, ES>(p),
35 p_symre(evaluator_id::rmae), p_class(evaluator_id::gaussian), metrics(m)
36{
37 evaluator(p.classification() ? p_class : p_symre);
38
39 Ensures(this->is_valid());
40}
41
42///
43/// \return a reference to the training set
44///
45template<class T, template<class> class ES>
46dataframe &src_search<T, ES>::training_data() const
47{
48 return prob().data(dataset_t::training);
49}
50
51///
52/// \return a reference to the test set
53///
54template<class T, template<class> class ES>
55dataframe &src_search<T, ES>::test_data() const
56{
57 return prob().data(dataset_t::test);
58}
59
60///
61/// \return a reference to the validation set
62///
63template<class T, template<class> class ES>
64dataframe &src_search<T, ES>::validation_data() const
65{
66 return prob().data(dataset_t::validation);
67}
68
69///
70/// \return a reference to the current problem
71///
72template<class T, template<class> class ES>
73src_problem &src_search<T, ES>::prob() const
74{
75 return static_cast<src_problem &>(this->prob_);
76}
77
78///
79/// Creates a lambda function associated with an individual.
80///
81/// \param[in] ind individual to be transformed in a lambda function
82/// \return the lambda function (`nullptr` in case of errors)
83///
84/// The lambda function depends on the active training evaluator.
85///
86template<class T, template<class> class ES>
87std::unique_ptr<basic_src_lambda_f> src_search<T, ES>::lambdify(
88 const T &ind) const
89{
90 auto l(this->eva1_->lambdify(ind));
91 auto p(static_cast<basic_src_lambda_f *>(l.release()));
92
93 return std::unique_ptr<basic_src_lambda_f>(p);
94}
95
96template<class T, template<class> class ES>
97bool src_search<T, ES>::can_validate() const
98{
99 return search<T, ES>::can_validate() && validation_data().size();
100}
101
102///
103/// Calculates various performance metrics.
104///
105/// \param[out] s update summary of the evolution run just finished
106/// (metrics regarding `s.best.solution`)
107///
108/// Accuracy calculation is performed if AT LEAST ONE of the following
109/// conditions is satisfied:
110///
111/// * the accuracy threshold is defined (`env.threshold.accuracy > 0.0`);
112/// * we explicitly asked for accuracy calculation in the `src_search`
113/// constructor.
114///
115/// Otherwise the function skips accuracy calculation.
116///
117/// \warning Can be very time consuming.
118///
119template<class T, template<class> class ES>
120void src_search<T, ES>::calculate_metrics(summary<T> *s) const
121{
122 if ((metrics & metric_flags::accuracy)
123 || prob().env.threshold.accuracy > 0.0)
124 {
125 const auto model(lambdify(s->best.solution));
126 const auto &d(can_validate() ? validation_data() : training_data());
127 s->best.score.accuracy = model->measure(accuracy_metric(), d);
128 }
129
130 search<T, ES>::calculate_metrics(s);
131}
132
133///
134/// Tries to tune search parameters for the current problem.
135///
136/// Parameter tuning is a typical approach to algorithm design. Such tuning
137/// is done by experimenting with different values and selecting the ones
138/// that give the best results on the test problems at hand.
139///
140/// However, the number of possible parameters and their different values
141/// means that this is a very complex and time-consuming task; it is
142/// something we do not want users to worry about (power users can force many
143/// parameters, but our idea is "simple by default").
144///
145/// So if user sets an environment parameter he will force the search class
146/// to use it as is. Otherwise this function will try to guess a good
147/// starting point and changes its hint after every run. The code is a mix of
148/// black magic, experience, common logic and randomness but it seems
149/// reasonable.
150///
151/// \note
152/// It has been formally proven, in the No-Free-Lunch theorem, that it is
153/// impossible to tune a search algorithm such that it will have optimal
154/// settings for all possible problems, but parameters can be properly
155/// set for a given problem.
156///
157/// \see
158/// * "Parameter Setting in Evolutionary Algorithms" (F.G. Lobo, C.F. Lima,
159/// Z. Michalewicz) - Springer;
160/// - https://github.com/morinim/vita/wiki/bibliography#9
161///
162template<class T, template<class> class ES>
163void src_search<T, ES>::tune_parameters()
164{
165 // The `shape` function modifies the default parameters with
166 // strategy-specific values.
167 const environment dflt(ES<T>::shape(environment().init()));
168
169 environment &env(prob().env);
170
171 // Contains user-specified parameters that will be partly changed by the
172 // `search::tune_parameters` call.
173 const environment constrained(env);
174
175 search<T, ES>::tune_parameters();
176
177 const auto d_size(training_data().size());
178 Expects(d_size);
179
180 if (!constrained.layers)
181 {
182 if (dflt.layers > 1 && d_size > 8)
183 env.layers = static_cast<decltype(dflt.layers)>(std::log(d_size));
184 else
185 env.layers = dflt.layers;
186
187 vitaINFO << "Number of layers set to " << env.layers;
188 }
189
190 // A larger number of training cases requires an increase in the population
191 // size (e.g. https://github.com/morinim/vita/wiki/bibliography#9 suggests
192 // 10 - 1000 individuals for smaller problems; between 1000 and 10000
193 // individuals for complex problem (more than 200 fitness cases).
194 //
195 // We chose a strictly increasing function to link training set size and
196 // population size.
197 if (!constrained.individuals)
198 {
199 if (d_size > 8)
200 {
201 env.individuals = 2
202 * static_cast<decltype(dflt.individuals)>(
203 std::pow(std::log2(d_size), 3))
204 / env.layers;
205 }
206 else
207 env.individuals = dflt.individuals;
208
209 if (env.individuals < 4)
210 env.individuals = 4;
211
212 vitaINFO << "Population size set to " << env.individuals;
213 }
214
215 if (!constrained.dss.has_value() && typeid(this->vs_.get()) == typeid(dss))
216 env.dss = dflt.dss;
217
218 if (!constrained.validation_percentage.has_value()
219 && typeid(this->vs_.get()) == typeid(holdout_validation))
220 env.validation_percentage = dflt.validation_percentage;
221
222 Ensures(env.is_valid(true));
223}
224
225template<class T, template<class> class ES>
226void src_search<T, ES>::after_evolution(const summary<T> &s)
227{
228 search<T, ES>::after_evolution(s);
229}
230
231///
232/// \param[in] m metrics relative to the current run
233///
234template<class T, template<class> class ES>
235void src_search<T, ES>::print_resume(const model_measurements &m) const
236{
237 if (0.0 <= m.accuracy && m.accuracy <= 1.0)
238 {
239 const std::string s(can_validate() ? "Validation " : "Training ");
240 vitaINFO << s << "accuracy: " << 100.0 * m.accuracy << '%';
241 }
242
243 search<T, ES>::print_resume(m);
244}
245
246///
247/// Writes end-of-run logs (run summary, results for test...).
248///
249/// \param[in] s summary information regarding the search
250/// \param[out] d output xml document
251///
252template<class T, template<class> class ES>
253void src_search<T, ES>::log_stats(const search_stats<T> &s,
254 tinyxml2::XMLDocument *d) const
255{
256 Expects(d);
257
258 const auto &stat(prob().env.stat);
259
260 search<T, ES>::log_stats(s, d);
261
262 if (!stat.summary_file.empty())
263 {
264 assert(d->FirstChild());
265 assert(d->FirstChild()->FirstChildElement("summary"));
266
267 auto *e_best(d->FirstChild()->FirstChildElement("summary")
268 ->FirstChildElement("best"));
269 assert(e_best);
270 set_text(e_best, "accuracy", s.overall.best.score.accuracy);
271 }
272
273 // Test set results logging.
274 if (!stat.test_file.empty() && test_data().size())
275 {
276 const auto lambda(lambdify(s.overall.best.solution));
277
278 std::ofstream tf(stat.dir / stat.test_file);
279 for (const auto &example : test_data())
280 tf << lambda->name((*lambda)(example)) << '\n';
281 }
282}
283
284///
285/// Sets the active validation strategy.
286///
287/// \param[in] id numerical id of the validator to be activated
288/// \return a reference to the search class (used for method chaining)
289///
290/// \exception std::invalid_argument unknown validation strategy
291///
292template<class T, template<class> class ES>
293src_search<T, ES> &src_search<T, ES>::validation_strategy(validator_id id)
294{
295 switch (id)
296 {
297 case validator_id::as_is:
298 search<T, ES>::template validation_strategy<as_is_validation>();
299 break;
300
301 case validator_id::dss:
302 assert(this->eva1_);
303 assert(this->eva2_);
304 search<T, ES>::template validation_strategy<dss>(prob(),
305 *this->eva1_, *this->eva2_);
306 break;
307
308 case validator_id::holdout:
309 search<T, ES>::template validation_strategy<holdout_validation>(prob());
310 break;
311
312 default:
313 throw std::invalid_argument("Unknown validation strategy");
314 }
315
316 return *this;
317}
318
319template<class T, template<class> class ES>
320template<class E, class... Args>
321void src_search<T, ES>::set_evaluator(Args && ...args)
322{
323 search<T, ES>::template training_evaluator<E>(
324 training_data(), std::forward<Args>(args)...);
325
326 search<T, ES>::template validation_evaluator<E>(
327 validation_data(), std::forward<Args>(args)...);
328}
329
330///
331/// \param[in] id numerical id of the evaluator to be activated
332/// \param[in] msg input parameters for the evaluator constructor
333/// \return a reference to the search class (used for method chaining)
334///
335/// \exception std::invalid_argument unknown evaluator
336///
337/// \note
338/// If the evaluator `id` is not compatible with the problem type the
339/// function returns `false` and the active evaluator stays the same.
340///
341template<class T, template<class> class ES>
342src_search<T, ES> &src_search<T, ES>::evaluator(evaluator_id id,
343 const std::string &msg)
344{
345 if (training_data().classes() > 1)
346 {
347 switch (id)
348 {
349 case evaluator_id::bin:
350 set_evaluator<binary_evaluator<T>>();
351 break;
352
353 case evaluator_id::dyn_slot:
354 {
355 auto x_slot(static_cast<unsigned>(msg.empty() ? 10ul
356 : std::stoul(msg)));
357 set_evaluator<dyn_slot_evaluator<T>>(x_slot);
358 }
359 break;
360
361 case evaluator_id::gaussian:
362 set_evaluator<gaussian_evaluator<T>>();
363 break;
364
365 default:
366 throw std::invalid_argument("Unknown evaluator");
367 }
368 }
369 else // symbolic regression
370 {
371 switch (id)
372 {
373 case evaluator_id::count:
374 set_evaluator<count_evaluator<T>>();
375 break;
376
377 case evaluator_id::mae:
378 set_evaluator<mae_evaluator<T>>();
379 break;
380
381 case evaluator_id::rmae:
382 set_evaluator<rmae_evaluator<T>>();
383 break;
384
385 case evaluator_id::mse:
386 set_evaluator<mse_evaluator<T>>();
387 break;
388
389 default:
390 throw std::invalid_argument("Unknown evaluator");
391 }
392 }
393
394 return *this;
395}
396
397///
398/// \return `true` if the object passes the internal consistency check
399///
400template<class T, template<class> class ES>
401bool src_search<T, ES>::is_valid() const
402{
403 if (p_symre == evaluator_id::undefined)
404 {
405 vitaERROR << "Undefined ID for preferred sym.reg. evaluator";
406 return false;
407 }
408
409 if (p_class == evaluator_id::undefined)
410 {
411 vitaERROR << "Undefined ID for preferred classification evaluator";
412 return false;
413 }
414
415 return search<T, ES>::is_valid();
416}
417
418#endif // include guard