21#include "tinyxml2/tinyxml2.h"
39 case d_int:
return std::stoi(s);
40 case d_double:
return std::stod(s);
41 case d_string:
return s;
55 static const std::map<const std::string, domain_t> map(
60 {
"numeric", d_double},
64 {
"nominal", d_string},
73 const auto &i(map.find(n));
74 return i == map.end() ? d_void : i->second;
101 cols_.insert(
begin(), v);
124 const auto set_domain(
127 const std::string &value(trim(r[idx]));
131 const bool number(is_number(value));
132 const bool classification(idx == 0 && !
number);
135 if (cols_[idx].domain == d_void)
138 cols_[idx].domain =
number || classification ? d_double : d_string;
141 const auto fields(r.size());
145 cols_.reserve(fields);
149 std::transform(r.begin(), r.end(),
150 std::back_inserter(cols_),
153 return column_info{trim(name), d_void, {}};
159 std::fill_n(std::back_inserter(cols_), fields,
column_info());
162 assert(
size() == r.size());
164 for (std::size_t field(0); field < fields; ++field)
173 return std::none_of(begin(), end(),
175 {
return c.domain == d_void && !c.states.empty(); });
213 Expects(!fn.empty());
237 return dataset_.begin();
245 return dataset_.begin();
253 return dataset_.end();
261 return dataset_.end();
273 return dataset_.front();
285 return dataset_.front();
293 return dataset_.size();
310 return static_cast<class_t>(classes_map_.size());
321 const auto n(
empty() ? 0u :
static_cast<unsigned>(
begin()->input.size()));
323 Ensures(
empty() || n + 1 == columns.size());
334 dataset_.push_back(e);
343 if (classes_map_.find(
label) == classes_map_.end())
346 classes_map_[
label] = n;
349 return classes_map_[
label];
363dataframe::example dataframe::to_example(
const record_t &v,
bool add_instance)
366 Expects(v.size() == columns.size());
370 for (std::size_t i(0); i < v.size(); ++i)
371 if (
const auto domain = columns[i].domain; domain != d_void)
373 const auto feature(trim(v[i]));
377 const bool classification(!is_number(v.front()));
382 ret.output =
static_cast<D_INT
>(encode(feature));
384 ret.output = convert(feature, domain);
387 ret.input.push_back(convert(feature, domain));
389 if (add_instance && domain == d_string)
390 columns[i].states.insert(feature);
402bool dataframe::read_record(
const record_t &r,
bool add_instance)
406 if (r.size() != columns.size())
408 vitaWARNING <<
"Malformed exampled " <<
size() <<
" skipped";
412 const auto instance(to_example(r, add_instance));
425 for (
const auto &p : classes_map_)
443std::size_t dataframe::read_xrff(
const std::filesystem::path &fn,
446 tinyxml2::XMLDocument doc;
447 if (doc.LoadFile(fn.string().c_str()) != tinyxml2::XML_SUCCESS)
450 return read_xrff(doc, p);
464std::size_t dataframe::read_xrff(std::istream &in,
const params &p)
466 std::ostringstream ss;
469 tinyxml2::XMLDocument doc;
470 if (doc.Parse(ss.str().c_str()) != tinyxml2::XML_SUCCESS)
473 return read_xrff(doc, p);
475std::size_t dataframe::read_xrff(std::istream &in)
477 return read_xrff(in, {});
497std::size_t dataframe::read_xrff(tinyxml2::XMLDocument &doc,
const params &p)
501 tinyxml2::XMLHandle handle(&doc);
502 auto *attributes(handle.FirstChildElement(
"dataset")
503 .FirstChildElement(
"header")
504 .FirstChildElement(
"attributes").ToElement());
506 throw exception::data_format(
"Missing `attributes` element in XRFF file");
510 unsigned n_output(0), output_index(0), index(0);
512 for (
auto *attribute(attributes->FirstChildElement(
"attribute"));
514 attribute = attribute->NextSiblingElement(
"attribute"), ++index)
516 columns_info::column_info a;
518 const char *s(attribute->Attribute(
"name"));
524 const bool output(attribute->Attribute(
"class",
"yes"));
526 s = attribute->Attribute(
"type");
527 std::string xml_type(s ? s :
"");
533 output_index = index;
537 throw exception::data_format(
"Multiple output columns in XRFF file");
541 if (xml_type ==
"nominal" || xml_type ==
"string")
542 xml_type =
"numeric";
548 if (xml_type ==
"nominal")
549 for (
auto *l(attribute->FirstChildElement(
"label"));
551 l = l->NextSiblingElement(
"label"))
553 const std::string
label(l->GetText() ? l->GetText() :
"");
554 a.states.insert(
label);
566 throw exception::data_format(
"Missing column information in XRFF file");
574 output_index = index - 1;
577 if (
auto *instances = handle.FirstChildElement(
"dataset")
578 .FirstChildElement(
"body")
579 .FirstChildElement(
"instances").ToElement())
581 for (
auto *i(instances->FirstChildElement(
"instance"));
583 i = i->NextSiblingElement(
"instance"))
587 for (
auto *v(i->FirstChildElement(
"value"));
589 v = v->NextSiblingElement(
"value"))
590 record.push_back(v->GetText() ? v->GetText() :
"");
592 if (p.filter && p.filter(record) ==
false)
595 std::rotate(record.begin(),
596 std::next(record.begin(), output_index),
597 std::next(record.begin(), output_index + 1));
599 read_record(record,
false);
603 throw exception::data_format(
"Missing `instances` element in XRFF file");
605 return is_valid() ?
size() : static_cast<std::size_t>(0);
619std::size_t dataframe::read_csv(
const std::filesystem::path &fn,
622 std::ifstream
in(fn);
624 throw std::runtime_error(
"Cannot read CSV data file");
626 return read_csv(in, p);
677std::size_t dataframe::read_csv(std::istream &from,
params p)
681 if (p.
dialect.has_header == pocket_csv::dialect::GUESS_HEADER
684 const auto sniff(pocket_csv::sniffer(from));
686 if (p.
dialect.has_header == pocket_csv::dialect::GUESS_HEADER)
687 p.
dialect.has_header = sniff.has_header;
689 p.
dialect.delimiter = sniff.delimiter;
692 std::size_t count(0);
693 for (
auto record : pocket_csv::parser(from, p.
dialect).filter_hook(p.
filter))
700 std::rotate(record.begin(),
708 record.insert(record.begin(),
"");
711 const bool has_header(p.
dialect.has_header
712 == pocket_csv::dialect::HAS_HEADER);
714 columns.
build(record, has_header);
715 if (has_header ==
false || count)
716 read_record(record,
true);
726std::size_t dataframe::read_csv(std::istream &from)
728 return read_csv(from, {});
742std::size_t dataframe::read(
const std::filesystem::path &fn,
const params &p)
745 throw std::invalid_argument(
"Missing dataset filename");
747 const auto ext(fn.extension().string());
748 const bool xrff(iequals(ext,
".xrff") || iequals(ext,
".xml"));
750 return xrff ? read_xrff(fn, p) : read_csv(fn, p);
752std::size_t dataframe::read(
const std::filesystem::path &fn)
774 return dataset_.erase(first, last);
791 const auto in_size(
front().input.size());
793 for (
const auto &e : *
this)
795 if (e.input.size() != in_size)
798 if (cl_size &&
label(e) >= cl_size)
columns_info()
Constructs a new empty columns_info object.
void push_front(const column_info &)
Adds a new column at the front of the column list.
void build(const record_t &, bool)
Given an example compiles information about the columns of the dataframe.
void push_back(const column_info &)
Adds a new column at the end of the column list.
std::optional< std::size_t > output_index
Index of the column containing the output value (label).
filter_hook_t filter
A filter and transform function applied when reading data.
pocket_csv::dialect dialect
A 2-dimensional labeled data structure with columns of potentially different types.
void push_back(const example &)
Appends the given element to the end of the active dataset.
std::string class_name(class_t) const
unsigned variables() const
std::vector< std::string > record_t
Raw input record.
iterator erase(iterator, iterator)
Removes specified elements from the dataframe.
value_type front() const
Returns a constant reference to the first element in the dataframe.
void clear()
Removes all elements from the container.
dataframe()
New empty data instance.
The main namespace for the project.
class_t label(const dataframe::example &e)
Gets the class_t ID (aka label) for a given example.
domain_t from_weka(const std::string &n)
std::size_t class_t
The type used as class ID in classification tasks.
domain_t
In an environment where a symbol such as '+' may have many different meanings, it's useful to specify...
D_DOUBLE number
This is the return type of the src_interpreter::run method.
std::variant< D_VOID, D_INT, D_DOUBLE, D_STRING > value_t
A variant containing the data types used by the interpreter for internal calculations / output value ...
T in(range_t< T > r)
Uniformly extracts a random value in a range.
Information about a single column of the dataset.
Stores a single element (row) of the dataset.