// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_Hh_
#define DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_Hh_
#include "structural_sequence_labeling_trainer_abstract.h"
#include "../algs.h"
#include "../optimization.h"
#include "structural_svm_sequence_labeling_problem.h"
#include "num_nonnegative_weights.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
    typename feature_extractor
    >
class structural_sequence_labeling_trainer
{
    /*!
        Trains a sequence_labeler<feature_extractor>.  train() builds a
        structural_svm_sequence_labeling_problem from the training data and the
        parameters held by this object, runs the stored oca solver on it, and
        wraps the resulting weight vector in a sequence_labeler.
        See structural_sequence_labeling_trainer_abstract.h for the full contract.
    !*/
public:
    typedef typename feature_extractor::sequence_type sample_sequence_type;
    typedef std::vector<unsigned long> labeled_sequence_type;

    typedef sequence_labeler<feature_extractor> trained_function_type;

    // Copies fe_ and resets every training parameter to its default.
    explicit structural_sequence_labeling_trainer (
        const feature_extractor& fe_
    ) : fe(fe_)
    {
        // fe is fully constructed by the time the body runs, so set_defaults()
        // may safely query fe.num_labels() to size loss_values.
        set_defaults();
    }

    // Default-constructs the feature extractor and resets every training
    // parameter to its default.
    structural_sequence_labeling_trainer (
    )
    {
        set_defaults();
    }

    const feature_extractor& get_feature_extractor (
    ) const { return fe; }

    // Number of labels, as reported by the feature extractor.
    unsigned long num_labels (
    ) const { return fe.num_labels(); }

    // Thread count passed to structural_svm_sequence_labeling_problem in train().
    void set_num_threads (
        unsigned long num
    )
    {
        num_threads = num;
    }

    unsigned long get_num_threads (
    ) const
    {
        return num_threads;
    }

    // Stopping tolerance forwarded to the problem via set_epsilon() in train().
    // requires: eps_ > 0
    void set_epsilon (
        double eps_
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(eps_ > 0,
            "\t void structural_sequence_labeling_trainer::set_epsilon()"
            << "\n\t eps_ must be greater than 0"
            << "\n\t eps_: " << eps_ 
            << "\n\t this: " << this
            );

        eps = eps_;
    }

    double get_epsilon (
    ) const { return eps; }

    // Iteration limit forwarded to the problem via set_max_iterations() in train().
    unsigned long get_max_iterations (
    ) const { return max_iterations; }

    void set_max_iterations (
        unsigned long max_iter
    ) 
    {
        max_iterations = max_iter;
    }

    // Cache size forwarded to the problem via set_max_cache_size() in train().
    void set_max_cache_size (
        unsigned long max_size
    )
    {
        max_cache_size = max_size;
    }

    unsigned long get_max_cache_size (
    ) const
    {
        return max_cache_size; 
    }

    // When verbose, train() calls be_verbose() on the optimization problem.
    void be_verbose (
    )
    {
        verbose = true;
    }

    void be_quiet (
    )
    {
        verbose = false;
    }

    // The solver object that train() invokes on the optimization problem.
    void set_oca (
        const oca& item
    )
    {
        solver = item;
    }

    const oca get_oca (
    ) const
    {
        return solver;
    }

    // SVM regularization parameter forwarded to the problem via set_c() in train().
    // requires: C_ > 0
    void set_c (
        double C_ 
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(C_ > 0,
            "\t void structural_sequence_labeling_trainer::set_c()"
            << "\n\t C_ must be greater than 0"
            << "\n\t C_: " << C_ 
            << "\n\t this: " << this
            );

        C = C_;
    }

    double get_c (
    ) const
    {
        return C;
    }

    // Returns the loss associated with the given label (forwarded to the
    // problem via set_loss() in train()).
    // requires: label < num_labels()
    double get_loss (
        unsigned long label
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(label < num_labels(),
                    "\t double structural_sequence_labeling_trainer::get_loss()"
                    << "\n\t invalid inputs were given to this function"
                    << "\n\t label:        " << label 
                    << "\n\t num_labels(): " << num_labels() 
                    << "\n\t this:         " << this
        );

        return loss_values[label];
    }

    // Sets the loss associated with the given label.
    // requires: label < num_labels() and value >= 0
    void set_loss (
        unsigned long label,
        double value
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(label < num_labels() && value >= 0,
                    "\t void structural_sequence_labeling_trainer::set_loss()"
                    << "\n\t invalid inputs were given to this function"
                    << "\n\t label:        " << label 
                    << "\n\t num_labels(): " << num_labels() 
                    << "\n\t value:        " << value 
                    << "\n\t this:         " << this
        );

        loss_values[label] = value;
    }

    /*!
        Learns a sequence_labeler from the samples x and their label sequences y.
        y[i][j] is the label of element j of x[i].
        requires (checked only when ENABLE_ASSERTS is defined):
            - every y[i][j] < num_labels()
            - is_sequence_labeling_problem(x,y) == true
            - contains_invalid_labeling(get_feature_extractor(), x, y) == false
        Copies this object's parameters and per-label losses into a
        structural_svm_sequence_labeling_problem and optimizes it with the
        stored oca solver.
    !*/
    const sequence_labeler<feature_extractor> train(
        const std::vector<sample_sequence_type>& x,
        const std::vector<labeled_sequence_type>& y
    ) const
    {
#ifdef ENABLE_ASSERTS
        // Validate the label range before anything else inspects y.
        // contains_invalid_labeling() below presumably hands these labels to
        // the feature extractor's reject_labeling(), which expects
        // labels < num_labels().  NOTE(review): confirm against the feature
        // extractor contract.  The loop only walks y, so it is safe even when
        // x and y do not match in shape.
        for (unsigned long i = 0; i < y.size(); ++i)
        {
            for (unsigned long j = 0; j < y[i].size(); ++j)
            {
                // make sure requires clause is not broken
                DLIB_ASSERT(y[i][j] < num_labels(),
                            "\t sequence_labeler structural_sequence_labeling_trainer::train(x,y)"
                            << "\n\t The given labels in y are invalid."
                            << "\n\t y[i][j]: " << y[i][j] 
                            << "\n\t num_labels(): " << num_labels()
                            << "\n\t i: " << i 
                            << "\n\t j: " << j 
                            << "\n\t this: " << this
                );
            }
        }
#endif

        // make sure requires clause is not broken
        DLIB_ASSERT(is_sequence_labeling_problem(x,y) == true &&
                    contains_invalid_labeling(get_feature_extractor(), x, y) == false,
                    "\t sequence_labeler structural_sequence_labeling_trainer::train(x,y)"
                    << "\n\t invalid inputs were given to this function"
                    << "\n\t x.size(): " << x.size() 
                    << "\n\t is_sequence_labeling_problem(x,y): " << is_sequence_labeling_problem(x,y)
                    << "\n\t contains_invalid_labeling(get_feature_extractor(),x,y): " << contains_invalid_labeling(get_feature_extractor(),x,y)
                    << "\n\t this: " << this
        );

        structural_svm_sequence_labeling_problem<feature_extractor> prob(x, y, fe, num_threads);
        matrix<double,0,1> weights; 
        if (verbose)
            prob.be_verbose();

        // Transfer this trainer's configuration onto the optimization problem.
        prob.set_epsilon(eps);
        prob.set_max_iterations(max_iterations);
        prob.set_c(C);
        prob.set_max_cache_size(max_cache_size);
        for (unsigned long i = 0; i < loss_values.size(); ++i)
            prob.set_loss(i,loss_values[i]);

        // The third argument is the count from num_nonnegative_weights(fe).
        // NOTE(review): presumably the number of leading weights the solver
        // should constrain to be non-negative; confirm against the oca docs.
        solver(prob, weights, num_nonnegative_weights(fe));

        return sequence_labeler<feature_extractor>(weights,fe);
    }

private:

    double C;
    oca solver;
    double eps;
    unsigned long max_iterations;
    bool verbose;
    unsigned long num_threads;
    unsigned long max_cache_size;
    // One loss value per label, indexed by label id.
    std::vector<double> loss_values;

    // Default parameters: C = 100, eps = 0.1, max_iterations = 10000,
    // num_threads = 2, max_cache_size = 5, quiet, and a loss of 1 for
    // every label.
    void set_defaults ()
    {
        C = 100;
        verbose = false;
        eps = 0.1;
        max_iterations = 10000;
        num_threads = 2;
        max_cache_size = 5;
        loss_values.assign(num_labels(), 1);
    }

    feature_extractor fe;
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_Hh_