3 * \remark This file is part of VITA.
5 * \copyright Copyright (C) 2013-2023 EOS di Manlio Morini.
8 * This Source Code Form is subject to the terms of the Mozilla Public
9 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
10 * You can obtain one at http://mozilla.org/MPL/2.0/
13#if !defined(VITA_SRC_SEARCH_H)
14# error "Don't include this file directly, include the specific .h instead"
17#if !defined(VITA_SRC_SEARCH_TCC)
18#define VITA_SRC_SEARCH_TCC
20constexpr std::underlying_type_t<metric_flags> operator&(metric_flags f1,
23 return as_integer(f1) & as_integer(f2);
///
/// Builds a search object for a symbolic regression / classification problem.
///
27/// \param[in] p the problem we're working on. The lifetime of `p` must exceed
28/// the lifetime of `this` class
29/// \param[in] m a bit field used to specify metrics we have to calculate while
///              searching (see `metric_flags` and `calculate_metrics`)
///
/// The preferred evaluators are `evaluator_id::rmae` (symbolic regression) and
/// `evaluator_id::gaussian` (classification); the one matching
/// `p.classification()` is activated immediately.
///
32template<class T, template<class> class ES>
33src_search<T, ES>::src_search(src_problem &p, metric_flags m)
35 p_symre(evaluator_id::rmae), p_class(evaluator_id::gaussian), metrics(m)
// Activate the preferred evaluator for the kind of problem at hand.
37 evaluator(p.classification() ? p_class : p_symre);
39 Ensures(this->is_valid());
43/// \return a reference to the training set
45template<class T, template<class> class ES>
46dataframe &src_search<T, ES>::training_data() const
48 return prob().data(dataset_t::training);
52/// \return a reference to the test set
54template<class T, template<class> class ES>
55dataframe &src_search<T, ES>::test_data() const
57 return prob().data(dataset_t::test);
61/// \return a reference to the validation set
63template<class T, template<class> class ES>
64dataframe &src_search<T, ES>::validation_data() const
66 return prob().data(dataset_t::validation);
70/// \return a reference to the current problem
72template<class T, template<class> class ES>
73src_problem &src_search<T, ES>::prob() const
75 return static_cast<src_problem &>(this->prob_);
///
79/// Creates a lambda function associated with an individual.
///
81/// \param[in] ind individual to be transformed into a lambda function
82/// \return the lambda function (`nullptr` in case of errors)
///
84/// The lambda function depends on the active training evaluator.
///
86template<class T, template<class> class ES>
87std::unique_ptr<basic_src_lambda_f> src_search<T, ES>::lambdify(
// Ask the active training evaluator (`eva1_`) for a generic lambda function.
90 auto l(this->eva1_->lambdify(ind));
// Take the raw pointer out of the generic owner and downcast it. A null result
// stays null, so errors propagate as an empty `unique_ptr`.
// NOTE(review): the `static_cast` is unchecked; it assumes the evaluator always
// builds a `basic_src_lambda_f`-derived object for src problems — confirm.
91 auto p(static_cast<basic_src_lambda_f *>(l.release()));
// Re-acquire ownership immediately (nothing between release and wrap can throw).
93 return std::unique_ptr<basic_src_lambda_f>(p);
96template<class T, template<class> class ES>
97bool src_search<T, ES>::can_validate() const
99 return search<T, ES>::can_validate() && validation_data().size();
103/// Calculates various performance metrics.
105/// \param[out] s update summary of the evolution run just finished
106/// (metrics regarding `s.best.solution`)
108/// Accuracy calculation is performed if AT LEAST ONE of the following
109/// conditions is satisfied:
111/// * the accuracy threshold is defined (`env.threshold.accuracy > 0.0`);
112/// * we explicitly asked for accuracy calculation in the `src_search`
115/// Otherwise the function skips accuracy calculation.
117/// \warning Can be very time consuming.
119template<class T, template<class> class ES>
120void src_search<T, ES>::calculate_metrics(summary<T> *s) const
122 if ((metrics & metric_flags::accuracy)
123 || prob().env.threshold.accuracy > 0.0)
125 const auto model(lambdify(s->best.solution));
126 const auto &d(can_validate() ? validation_data() : training_data());
127 s->best.score.accuracy = model->measure(accuracy_metric(), d);
130 search<T, ES>::calculate_metrics(s);
///
134/// Tries to tune search parameters for the current problem.
///
136/// Parameter tuning is a typical approach to algorithm design. Such tuning
137/// is done by experimenting with different values and selecting the ones
138/// that give the best results on the test problems at hand.
///
140/// However, the number of possible parameters and their different values
141/// means that this is a very complex and time-consuming task; it is
142/// something we do not want users to worry about (power users can force many
143/// parameters, but our idea is "simple by default").
///
145/// So if the user sets an environment parameter they force the search class
146/// to use it as is. Otherwise this function will try to guess a good
147/// starting point and change its hint after every run. The code is a mix of
148/// black magic, experience, common logic and randomness but it seems
///
152/// It has been formally proven, in the No-Free-Lunch theorem, that it is
153/// impossible to tune a search algorithm such that it will have optimal
154/// settings for all possible problems, but parameters can be properly
155/// set for a given problem.
///
158/// * "Parameter Setting in Evolutionary Algorithms" (F.G. Lobo, C.F. Lima,
159/// Z. Michalewicz) - Springer;
160/// - https://github.com/morinim/vita/wiki/bibliography#9
///
/// Besides the generic tuning performed by the base class, this function
/// adjusts the number of layers, the population size and the validation
/// parameters (`dss`, `validation_percentage`), each of them only when the
/// user didn't set it (`constrained` is a copy of the environment taken before
/// any tuning).
///
162template<class T, template<class> class ES>
163void src_search<T, ES>::tune_parameters()
165 // The `shape` function modifies the default parameters with
166 // strategy-specific values.
167 const environment dflt(ES<T>::shape(environment().init()));
// Tuned values are written directly into the problem's environment.
169 environment &env(prob().env);
171 // Contains user-specified parameters that will be partly changed by the
172 // `search::tune_parameters` call.
173 const environment constrained(env);
// Generic, problem-independent tuning is delegated to the base class.
175 search<T, ES>::tune_parameters();
// The training set size drives the layer and population heuristics below.
177 const auto d_size(training_data().size());
// Number of layers (only if not set by the user; a zero value presumably
// means "unset"): natural logarithm of the training set size when the
// strategy default allows more than one layer and there are more than 8
// training examples, otherwise the strategy default.
180 if (!constrained.layers)
182 if (dflt.layers > 1 && d_size > 8)
183 env.layers = static_cast<decltype(dflt.layers)>(std::log(d_size));
185 env.layers = dflt.layers;
187 vitaINFO << "Number of layers set to " << env.layers;
190 // A larger number of training cases requires an increase in the population
191 // size (e.g. https://github.com/morinim/vita/wiki/bibliography#9 suggests
192 // 10 - 1000 individuals for smaller problems; between 1000 and 10000
193 // individuals for complex problems (more than 200 fitness cases)).
195 // We chose a strictly increasing function to link training set size and
// population size.
// Population size (only if not set by the user): scaled by
// `pow(log2(d_size), 3)`, falling back to the strategy default in the other
// case. NOTE(review): the full expression and the condition selecting the
// default branch are not visible in this copy.
197 if (!constrained.individuals)
202 * static_cast<decltype(dflt.individuals)>(
203 std::pow(std::log2(d_size), 3))
207 env.individuals = dflt.individuals;
// Guard against very small populations (fewer than 4 individuals).
209 if (env.individuals < 4)
212 vitaINFO << "Population size set to " << env.individuals;
// DSS parameters: only relevant when the DSS validation strategy is active.
// NOTE(review): if `vs_` is a smart pointer (e.g. `std::unique_ptr`),
// `vs_.get()` is a raw pointer and `typeid` yields its static pointer type,
// which can never equal `typeid(dss)`; `typeid(*this->vs_)` was probably
// intended — confirm. The same applies to the holdout check below.
215 if (!constrained.dss.has_value() && typeid(this->vs_.get()) == typeid(dss))
// Validation percentage: restored to the strategy default when the user
// didn't set it and the holdout validation strategy is active.
218 if (!constrained.validation_percentage.has_value()
219 && typeid(this->vs_.get()) == typeid(holdout_validation))
220 env.validation_percentage = dflt.validation_percentage;
// Post-condition: the tuned environment must pass `is_valid(true)`.
222 Ensures(env.is_valid(true));
225template<class T, template<class> class ES>
226void src_search<T, ES>::after_evolution(const summary<T> &s)
228 search<T, ES>::after_evolution(s);
232/// \param[in] m metrics relative to the current run
234template<class T, template<class> class ES>
235void src_search<T, ES>::print_resume(const model_measurements &m) const
237 if (0.0 <= m.accuracy && m.accuracy <= 1.0)
239 const std::string s(can_validate() ? "Validation " : "Training ");
240 vitaINFO << s << "accuracy: " << 100.0 * m.accuracy << '%';
243 search<T, ES>::print_resume(m);
///
247/// Writes end-of-run logs (run summary, results for test...).
///
249/// \param[in] s summary information regarding the search
250/// \param[out] d output xml document
///
/// On top of what the base class writes, this adds:
/// * the accuracy of the overall best individual to the `summary/best`
///   element of `d` (when a summary file is configured);
/// * when a test file is configured and the test set isn't empty, one line
///   per test example, written to `stat.dir / stat.test_file`, holding the
///   textual form (`lambda->name(...)`) of the best individual's output.
///
252template<class T, template<class> class ES>
253void src_search<T, ES>::log_stats(const search_stats<T> &s,
254 tinyxml2::XMLDocument *d) const
258 const auto &stat(prob().env.stat);
// Let the base class build the common part of the XML document first.
260 search<T, ES>::log_stats(s, d);
262 if (!stat.summary_file.empty())
// The base class is expected to have created the `summary` element.
264 assert(d->FirstChild());
265 assert(d->FirstChild()->FirstChildElement("summary"));
267 auto *e_best(d->FirstChild()->FirstChildElement("summary")
268 ->FirstChildElement("best"));
270 set_text(e_best, "accuracy", s.overall.best.score.accuracy);
273 // Test set results logging.
274 if (!stat.test_file.empty() && test_data().size())
// NOTE(review): `lambdify` is documented to return `nullptr` on errors, but
// `lambda` is dereferenced below without a check.
276 const auto lambda(lambdify(s.overall.best.solution));
// NOTE(review): the stream state isn't checked; a failed open silently
// produces no output.
278 std::ofstream tf(stat.dir / stat.test_file);
279 for (const auto &example : test_data())
280 tf << lambda->name((*lambda)(example)) << '\n';
///
285/// Sets the active validation strategy.
///
287/// \param[in] id numerical id of the validator to be activated
288/// \return a reference to the search class (used for method chaining)
///
290/// \exception std::invalid_argument unknown validation strategy
///
292template<class T, template<class> class ES>
293src_search<T, ES> &src_search<T, ES>::validation_strategy(validator_id id)
// `as_is`: strategy built without extra arguments.
297 case validator_id::as_is:
298 search<T, ES>::template validation_strategy<as_is_validation>();
// `dss`: needs the problem and both the active evaluators (`eva1_` is the
// training evaluator; `eva2_` is presumably the validation one).
301 case validator_id::dss:
304 search<T, ES>::template validation_strategy<dss>(prob(),
305 *this->eva1_, *this->eva2_);
// `holdout`: needs the problem (to split its data).
308 case validator_id::holdout:
309 search<T, ES>::template validation_strategy<holdout_validation>(prob());
// Any other identifier is rejected.
313 throw std::invalid_argument("Unknown validation strategy");
319template<class T, template<class> class ES>
320template<class E, class... Args>
321void src_search<T, ES>::set_evaluator(Args && ...args)
323 search<T, ES>::template training_evaluator<E>(
324 training_data(), std::forward<Args>(args)...);
326 search<T, ES>::template validation_evaluator<E>(
327 validation_data(), std::forward<Args>(args)...);
///
/// Sets the active evaluator.
///
331/// \param[in] id numerical id of the evaluator to be activated
332/// \param[in] msg input parameters for the evaluator constructor
333/// \return a reference to the search class (used for method chaining)
///
335/// \exception std::invalid_argument unknown evaluator
///
338/// If the evaluator `id` is not compatible with the problem type (e.g. `mae`
339/// for a classification task) `std::invalid_argument` is thrown and the active evaluator stays the same.
///
341template<class T, template<class> class ES>
342src_search<T, ES> &src_search<T, ES>::evaluator(evaluator_id id,
343 const std::string &msg)
// Classification when the training set has more than one class, symbolic
// regression otherwise.
// NOTE(review): the constructor chooses the preferred evaluator via
// `p.classification()`; check that the two criteria always agree.
345 if (training_data().classes() > 1)
349 case evaluator_id::bin:
350 set_evaluator<binary_evaluator<T>>();
353 case evaluator_id::dyn_slot:
// Slot count: 10 when `msg` is empty, otherwise derived from `msg` (the
// conversion is not visible in this copy).
355 auto x_slot(static_cast<unsigned>(msg.empty() ? 10ul
357 set_evaluator<dyn_slot_evaluator<T>>(x_slot);
361 case evaluator_id::gaussian:
362 set_evaluator<gaussian_evaluator<T>>();
// Evaluator not suitable for classification.
366 throw std::invalid_argument("Unknown evaluator");
369 else // symbolic regression
373 case evaluator_id::count:
374 set_evaluator<count_evaluator<T>>();
377 case evaluator_id::mae:
378 set_evaluator<mae_evaluator<T>>();
381 case evaluator_id::rmae:
382 set_evaluator<rmae_evaluator<T>>();
385 case evaluator_id::mse:
386 set_evaluator<mse_evaluator<T>>();
// Evaluator not suitable for symbolic regression.
390 throw std::invalid_argument("Unknown evaluator");
///
/// Internal consistency check.
///
/// Both preferred evaluator IDs (symbolic regression and classification) must
/// be defined; each failure is reported via `vitaERROR`. The remaining checks
/// are delegated to the base class.
///
398/// \return `true` if the object passes the internal consistency check
///
400template<class T, template<class> class ES>
401bool src_search<T, ES>::is_valid() const
403 if (p_symre == evaluator_id::undefined)
405 vitaERROR << "Undefined ID for preferred sym.reg. evaluator";
409 if (p_class == evaluator_id::undefined)
411 vitaERROR << "Undefined ID for preferred classification evaluator";
415 return search<T, ES>::is_valid();
418#endif // include guard