Vita
symbol_set.cc
Go to the documentation of this file.
1
13#include <set>
14
15#include "kernel/symbol_set.h"
16#include "kernel/log.h"
17#include "kernel/random.h"
18#include "kernel/gp/gene.h"
19
20namespace vita
21{
27symbol_set::symbol_set() : symbols_(), views_()
28{
29 Ensures(is_valid());
30}
31
36{
   // Clears the current symbol set: every stored symbol and every
   // per-category view is dropped.
   //
   // Commented-out alternative kept for reference: reset the object by
   // assigning a default-constructed value.
37 //*this = {};
38
39 symbols_.clear();
40 views_.clear();
41}
42
55symbol *symbol_set::insert(std::unique_ptr<symbol> s, double wr)
56{
57 Expects(s);
58 Expects(wr >= 0.0);
59
60 const auto w(static_cast<weight_t>(wr * w_symbol::base_weight));
61 const w_symbol ws(s.get(), w);
62
63 category_t category(s->category());
64 if (category == undefined_category)
65 {
66 category = views_.size();
67 s->category(category);
68 }
69
70 for (category_t i(views_.size()); i <= category; ++i)
71 views_.emplace_back("Collection " + std::to_string(i));
72 assert(category < views_.size());
73
74 views_[category].all.insert(ws);
75
76 if (s->terminal())
77 views_[category].terminals.insert(ws);
78 else // function
79 views_[category].functions.insert(ws);
80
81 symbols_.push_back(std::move(s));
82 return ws.sym;
83}
84
85template<class F>
86void symbol_set::collection::sum_container::scale_weights(double ratio, F f)
87{
88 for (auto &s : elems_)
89 if (f(s))
90 {
91 sum_ -= s.weight;
92 s.weight = static_cast<weight_t>(s.weight * ratio);
93 sum_ += s.weight;
94 }
95}
96
102{
    // Draws a random function of category `c` using the weighted roulette of
    // that category's `functions` container.
    //
    // Preconditions: `c` is an existing category and it contains at least one
    // function.
103 Expects(c < categories());
104 Expects(views_[c].functions.size());
105
    // Unchecked downcast: the `functions` container is filled with
    // non-terminal symbols only (see `insert`).
106 return static_cast<const function &>(views_[c].functions.roulette());
107}
108
114{
    // Draws a random terminal of category `c` using the weighted roulette of
    // that category's `terminals` container.
    //
    // Preconditions: `c` is an existing category and it contains at least one
    // terminal.
115 Expects(c < categories());
116 Expects(views_[c].terminals.size());
117
    // Unchecked downcast: the `terminals` container is filled with terminal
    // symbols only (see `insert`).
118 return static_cast<const terminal &>(views_[c].terminals.roulette());
119}
120
144{
    // Extracts a random symbol of category `c` in two steps: first choose
    // between functions and terminals, then run the weighted roulette inside
    // the chosen group.
    //
    // Preconditions: `c` is an existing category and it contains at least one
    // terminal (the terminal draw is the fallback path).
145 Expects(c < categories());
146 Expects(views_[c].terminals.size());
147
    // A function is returned only when `random::boolean()` is true AND the
    // category has at least one function; in every other case a terminal is
    // drawn. NOTE(review): presumably `random::boolean()` is a fair coin flip -
    // confirm in kernel/random.h.
148 if (random::boolean() && views_[c].functions.size())
149 return views_[c].functions.roulette();
150
151 return views_[c].terminals.roulette();
152}
153
169{
    // Extracts a random symbol of category `c` from the `all` container, so
    // functions and terminals take part in the same weighted roulette.
    //
    // Precondition: `c` is an existing category.
170 Expects(c < categories());
171 return views_[c].all.roulette();
172}
173
180{
    // Linear search among the stored symbols: returns a pointer to the first
    // one whose opcode equals `opcode`.
181 for (const auto &s : symbols_)
182 if (s->opcode() == opcode)
183 return s.get();
184
    // No symbol has the requested opcode.
185 return nullptr;
186}
187
198symbol *symbol_set::decode(const std::string &dex) const
199{
200 Expects(!dex.empty());
201
202 for (const auto &s : symbols_)
203 if (s->name() == dex)
204 return s.get();
205
206 return nullptr;
207}
208
215{
    // The number of categories is the number of per-category views.
216 return static_cast<category_t>(views_.size());
217}
218
224{
    // Returns how many terminals are stored for category `c`.
    // Precondition: `c` is an existing category.
225 Expects(c < categories());
226 return views_[c].terminals.size();
227}
228
236{
    // Checks that every category used as a function argument exists and has at
    // least one terminal.
    //
    // A set without views passes the check trivially.
237 if (views_.empty())
238 return true;
239
    // Categories required by the arguments of the stored symbols.
240 std::set<category_t> need;
241
242 for (const auto &s : symbols_)
243 {
244 const auto arity(s->arity());
    // The loop body (and thus `function::cast`) runs only for symbols with a
    // nonzero arity.
245 for (auto i(decltype(arity){0}); i < arity; ++i)
246 need.insert(function::cast(s.get())->arg_category(i));
247 }
248
    // A needed category fails when it's out of range or holds no terminal.
249 for (const auto &i : need)
250 if (i >= categories() || !views_[i].terminals.size())
251 return false;
252
253 return true;
254}
255
260symbol_set::weight_t symbol_set::weight(const symbol &s) const
261{
262 for (const auto &ws : views_[s.category()].all)
263 if (ws.sym == &s)
264 return ws.weight;
265
266 return 0;
267}
268
278std::ostream &operator<<(std::ostream &o, const symbol_set &ss)
279{
280 for (const auto &s : ss.symbols_)
281 {
282 o << s->name();
283
284 auto arity(s->arity());
285 if (arity)
286 {
287 o << '(';
288 for (decltype(arity) j(0); j < arity; ++j)
289 o << function::cast(s.get())->arg_category(j)
290 << (j + 1 == arity ? "" : ", ");
291 o << ')';
292 }
293
294 o << " -> " << s->category() << " (opcode " << s->opcode()
295 << ", parametric "
296 << (s->terminal() && terminal::cast(s.get())->parametric())
297 << ", weight "
298 << ss.weight(*s) << ")\n";
299 }
300
301 return o;
302}
303
308{
    // Consistency check of the whole symbol set. Currently the only condition
    // verified is `enough_terminals()`; a failure is reported through the
    // error log before returning `false`.
309 if (!enough_terminals())
310 {
311 vitaERROR << "Symbol set doesn't contain enough symbols";
312 return false;
313 }
314
315 return true;
316}
317
320//
323symbol_set::collection::collection(std::string n)
324 : all("all"), functions("functions"), terminals("terminals"),
325 name_(std::move(n))
326{
327}
328
332bool symbol_set::collection::is_valid() const
333{
334 if (!all.is_valid() || !functions.is_valid() || !terminals.is_valid())
335 {
336 vitaERROR << "(inside " << name_ << ")";
337 return false;
338 }
339
340 if (std::any_of(functions.begin(), functions.end(),
341 [](const auto &s) { return s.sym->terminal(); }))
342 return false;
343
344 if (std::any_of(terminals.begin(), terminals.end(),
345 [](const auto &s) { return !s.sym->terminal(); }))
346 return false;
347
348 for (const auto &s : all)
349 {
350 if (s.sym->terminal())
351 {
352 if (std::find(terminals.begin(), terminals.end(), s) == terminals.end())
353 {
354 vitaERROR << name_ << ": terminal " << s.sym->name()
355 << " badly stored";
356 return false;
357 }
358 }
359 else // function
360 {
361 if (std::find(functions.begin(), functions.end(), s) == functions.end())
362 {
363 vitaERROR << name_ << ": function " << s.sym->name()
364 << " badly stored";
365 return false;
366 }
367 }
368 }
369
370 const auto ssize(all.size());
371
372 // The following condition should be met at the end of the symbol_set
373 // specification.
374 // Since we don't want to enforce a particular insertion order (i.e.
375 // terminals before functions), we cannot perform the check here.
376 //
377 // if (ssize && !terminals.size())
378 // {
379 // vitaERROR << name_ << ": no terminal in the symbol set";
380 // return false;
381 // }
382
383 if (ssize < functions.size())
384 {
385 vitaERROR << name_ << ": wrong function set size (more than symbol set)";
386 return false;
387 }
388
389 if (ssize < terminals.size())
390 {
391 vitaERROR << name_ << ": wrong terminal set size (more than symbol set)";
392 return false;
393 }
394
395 return ssize == functions.size() + terminals.size();
396}
397
407{
    // Inserts a weighted symbol in the container: the element is appended and
    // its weight is added to the cached sum of weights.
408 elems_.push_back(ws);
409 sum_ += ws.weight;
410
    // Keep the elements ordered by decreasing weight.
    // NOTE(review): the whole range is re-sorted at every insertion and the
    // comparator takes its arguments by value; `std::sort` isn't stable, so
    // the relative order of equal-weight elements isn't guaranteed.
411 std::sort(begin(), end(),
412 [](auto s1, auto s2) { return s1.weight > s2.weight; });
413}
414
431{
    // Extracts a random symbol from the container with a roulette wheel over
    // the stored weights.
    //
    // Precondition: the cached sum of weights is nonzero.
432 Expects(sum());
433
    // NOTE(review): presumably `random::sup(x)` returns a value in `[0, x)` -
    // confirm in kernel/random.h. The scan below relies on `slot` being
    // strictly less than the total weight to stop inside `elems_`.
434 const auto slot(random::sup(sum()));
435
    // `wedge` is the running total of the weights of elements `0..i`; the scan
    // stops at the first element whose running total exceeds `slot`.
436 std::size_t i(0);
437 for (auto wedge(elems_[i].weight);
438 wedge <= slot;
439 wedge += elems_[++i].weight)
440 {}
441
442 assert(i < elems_.size());
443 return *elems_[i].sym;
444
445 // The so called roulette-wheel selection via stochastic acceptance:
446 //
447 // for (;;)
448 // {
449 // const symbol *s(random::element(elems));
450 //
451 // if (random::sup(max) < s->weight)
452 // return *s;
453 // }
454 //
455 // Internal tests have proved this is slower for Vita.
456
457 // This is a different approach from Eli Bendersky
458 // (<http://eli.thegreenplace.net/>):
459 //
460 // weight_t total(0);
461 // std::size_t winner(0);
462 //
463 // for (std::size_t i(0); i < size(); ++i)
464 // {
465 // total += elems[i].weight;
466 // if (random::sup(total + 1) < elems[i].weight)
467 // winner = i;
468 // }
469 // return *elems[winner].sym;
470 //
471 // The interesting property of this algorithm is that you don't need to
472 // know the sum of weights in advance in order to use it. The method is
473 // cool, but slower than the standard roulette.
474}
475
480{
    // Consistency check of the container: recomputes the sum of the weights
    // and compares it with the cached value; every failure is reported through
    // the error log.
481 weight_t check_sum(0);
482
483 for (const auto &e : elems_)
484 {
485 check_sum += e.weight;
486
    // A zero weight is tolerated for terminals only.
487 if (e.weight == 0 && !e.sym->terminal())
488 {
489 vitaERROR << name_ << ": null weight for symbol " << e.sym->name();
490 return false;
491 }
492 }
493
    // The cached sum must match the sum just recomputed.
494 if (check_sum != sum())
495 {
496 vitaERROR << name_ << ": incorrect cached sum of weights (stored: "
497 << sum() << ", correct: " << check_sum << ')';
498 return false;
499 }
500
501 return true;
502}
503
504} // namespace vita
A symbol with arity() > 0.
Definition: function.h:35
static const function * cast(const symbol *)
This is a short cut function.
Definition: function.h:101
category_t arg_category(std::size_t) const
Definition: function.h:89
void insert(const w_symbol &)
Inserts a weighted symbol in the container.
Definition: symbol_set.cc:406
const symbol & roulette() const
Extracts a random symbol from the collection.
Definition: symbol_set.cc:430
A container for the symbols used by the GP engine.
Definition: symbol_set.h:37
symbol * decode(opcode_t) const
Definition: symbol_set.cc:179
symbol_set()
Sets up the object.
Definition: symbol_set.cc:27
void clear()
Clears the current symbol set.
Definition: symbol_set.cc:35
const function & roulette_function(category_t) const
Definition: symbol_set.cc:101
bool is_valid() const
Definition: symbol_set.cc:307
const symbol & roulette(category_t) const
Extracts a random symbol from the symbol set without bias between terminals and functions.
Definition: symbol_set.cc:143
bool enough_terminals() const
We want at least one terminal for every used category.
Definition: symbol_set.cc:235
weight_t weight(const symbol &) const
Definition: symbol_set.cc:260
symbol * insert(std::unique_ptr< symbol >, double=1.0)
Adds a new symbol to the set.
Definition: symbol_set.cc:55
friend std::ostream & operator<<(std::ostream &, const symbol_set &)
Prints the symbol set to an output stream.
Definition: symbol_set.cc:278
const symbol & roulette_free(category_t) const
Extracts a random symbol from the symbol set.
Definition: symbol_set.cc:168
const terminal & roulette_terminal(category_t) const
Definition: symbol_set.cc:113
category_t categories() const
Definition: symbol_set.cc:214
std::size_t terminals(category_t) const
Definition: symbol_set.cc:223
Together functions and terminals are referred to as symbols.
Definition: symbol.h:36
category_t category() const
The type (a.k.a.
Definition: symbol.h:99
A symbol with zero-arity.
Definition: terminal.h:27
virtual bool parametric() const
A parametric terminal needs an additional parameter to be evaluated.
Definition: terminal.h:59
static const terminal * cast(const symbol *)
This is a short cut function.
Definition: terminal.h:83
The main namespace for the project.
std::size_t category_t
A category provides operations which supplement or supersede those of the domain but which are restric...
Definition: common.h:44
unsigned opcode_t
This is the type used as key for symbol identification.
Definition: symbol.h:26