Vita
dss.cc
Go to the documentation of this file.
1
13#include "kernel/gp/src/dss.h"
14#include "kernel/random.h"
15
16namespace vita
17{
18
19namespace
20{
21
22auto weight(const dataframe::example &v)
23{
24 return static_cast<std::uintmax_t>(v.difficulty)
25 + static_cast<std::uintmax_t>(v.age) * v.age * v.age;
26}
27
28} // unnamed namespace
29
42 : training_(prob.data(dataset_t::training)),
43 validation_(prob.data(dataset_t::validation)),
44 eva_t_(eva_t), eva_v_(eva_v),
45 env_(prob.env)
46{
47 // Here `!env_.dss.has_value()` could be true. Validation strategy is set
48 // before parameters are tuned.
49}
50
51void dss::reset_age_difficulty(dataframe &d)
52{
53 std::for_each(d.begin(), d.end(),
54 [](auto &example)
55 {
56 example.difficulty = 0;
57 example.age = 1;
58 });
59}
60
61std::pair<std::uintmax_t, std::uintmax_t> dss::average_age_difficulty(
62 dataframe &d) const
63{
64 constexpr std::pair<std::uintmax_t, std::uintmax_t> zero(0, 0);
65
66 const auto s(d.size());
67 if (!s)
68 return zero;
69
70 auto avg(std::accumulate(d.begin(), d.end(), zero,
71 [](const auto &p, const dataframe::example &e)
72 {
73 return std::pair<std::uintmax_t, std::uintmax_t>(
74 p.first + e.age, p.second + e.difficulty);
75 }));
76
77 avg.first /= s;
78 avg.second /= s;
79
80 return avg;
81}
82
83void dss::clear_evaluators()
84{
85 eva_t_.clear();
86 eva_v_.clear();
87}
88
89void dss::move_to_validation()
90{
91 std::move(training_.begin(), training_.end(),
92 std::back_inserter(validation_));
93 training_.clear();
94
95 Ensures(training_.empty());
96}
97
104void dss::init(unsigned)
105{
106 Expects(env_.dss.value_or(0) > 0);
107
108 reset_age_difficulty(training_);
109 reset_age_difficulty(validation_);
110
111 shake_impl();
112 clear_evaluators();
113}
114
115void dss::shake_impl()
116{
117 Expects(training_.size() + validation_.size() >= 2);
118
119 move_to_validation();
120
121 const auto avg_v(average_age_difficulty(validation_));
122 vitaDEBUG << "DSS average validation difficulty " << avg_v.second
123 << ", age " << avg_v.first;
124
125 const auto weight_sum(
126 std::accumulate(validation_.begin(), validation_.end(), std::uintmax_t(0),
127 [](const std::uintmax_t &s, const dataframe::example &e)
128 {
129 return s + weight(e);
130 }));
131
132 assert(weight_sum);
133
134 // Move a subset of the available examples (initially contained in the
135 // validation set) into the training set.
136 // Note that the actual size of the selected subset is not fixed and, in
137 // fact, it averages slightly above `target_size` (Gathercole and Ross felt
138 // it might improve performance).
139 const auto s(static_cast<double>(validation_.size()));
140 const double ratio(std::min(0.6, 0.2 + 100.0 / (s + 100.0)));
141 assert(0.2 <= ratio && ratio <= 0.6);
142 const double target_size(std::max(1.0, s * ratio));
143 assert(1.0 <= target_size && target_size <= s);
144 const double k(target_size / static_cast<double>(weight_sum));
145
146 auto pivot(
147 std::partition(validation_.begin(), validation_.end(),
148 [k](const auto &e)
149 {
150 const auto p1(static_cast<double>(weight(e)) * k);
151 const auto prob(std::min(p1, 1.0));
152
153 return random::boolean(prob) == false;
154 }));
155
156 if (pivot == validation_.begin() || pivot == validation_.end())
157 pivot = std::next(validation_.begin(),
158 static_cast<std::ptrdiff_t>(target_size));
159
160 assert(validation_.size() == static_cast<size_t>(s));
161 std::move(pivot, validation_.end(), std::back_inserter(training_));
162 validation_.erase(pivot, validation_.end());
163
164 vitaDEBUG << "DSS SHAKE (weight sum: " << weight_sum << ", training with: "
165 << training_.size() << ')';
166 assert(static_cast<size_t>(s) == training_.size() + validation_.size());
167
168 reset_age_difficulty(training_);
169
170 Ensures(!training_.empty());
171 Ensures(!validation_.empty());
172}
173
174bool dss::shake(unsigned generation)
175{
176 Expects(env_.dss.value_or(0) > 0);
177
178 const auto gap(*env_.dss);
179
180 if (generation == 0 // already handled by init()
181 || generation % gap)
182 {
183 assert(!training_.empty());
184 assert(!validation_.empty());
185 return false;
186 }
187
188 vitaDEBUG << "DSS shaking generation " << generation;
189
190 const auto avg_t(average_age_difficulty(training_));
191 vitaDEBUG << "DSS average training difficulty " << avg_t.second;
192 assert(avg_t.first == 1);
193
194 const auto inc_age([](dataframe::example &e) { ++e.age; });
195 std::for_each(training_.begin(), training_.end(), inc_age);
196 std::for_each(validation_.begin(), validation_.end(), inc_age);
197
198 shake_impl();
199 clear_evaluators();
200
201 return true;
202}
203
207void dss::close(unsigned)
208{
209 move_to_validation();
210 clear_evaluators();
211}
212
213} // namespace vita
virtual void clear()
Clear possible cached values.
Definition: evaluator.h:28
A 2-dimensional labeled data structure with columns of potentially different types.
Definition: dataframe.h:48
iterator erase(iterator, iterator)
Removes specified elements from the dataframe.
Definition: dataframe.cc:772
std::size_t size() const
Definition: dataframe.cc:291
void clear()
Removes all elements from the container.
Definition: dataframe.cc:227
iterator begin()
Definition: dataframe.cc:235
bool empty() const
Definition: dataframe.cc:299
iterator end()
Definition: dataframe.cc:251
void close(unsigned) override
Moves all the examples into the validation set.
Definition: dss.cc:207
bool shake(unsigned) override
Changes the training environment.
Definition: dss.cc:174
void init(unsigned) override
Available examples are randomly partitioned into two independent sets according to a given percentage...
Definition: dss.cc:104
dss(src_problem &, cached_evaluator &, cached_evaluator &)
Sets up a DSS validator.
Definition: dss.cc:41
facultative< unsigned > dss
Enables Dynamic Subset Selection every dss generations.
Definition: environment.h:182
Provides a GP-specific interface to the generic problem class.
The main namespace for the project.
dataset_t
Data/simulations are categorised in three sets:
Stores a single element (row) of the dataset.
Definition: dataframe.h:194