Line |
Branch |
Exec |
Source |
1 |
|
|
/* |
2 |
|
|
* Copyright 2018, |
3 |
|
|
* Julian Viereck |
4 |
|
|
* |
5 |
|
|
* CNRS/AIST |
6 |
|
|
* |
7 |
|
|
*/ |
8 |
|
|
|
9 |
|
|
#include <dynamic-graph/all-commands.h> |
10 |
|
|
#include <dynamic-graph/factory.h> |
11 |
|
|
|
12 |
|
|
#include <boost/function.hpp> |
13 |
|
|
#include <sot/core/factory.hh> |
14 |
|
|
#include <sot/core/gradient-ascent.hh> |
15 |
|
|
|
16 |
|
|
namespace dg = ::dynamicgraph; |
17 |
|
|
|
18 |
|
|
/* ---------------------------------------------------------------------------*/ |
19 |
|
|
/* ------- GENERIC HELPERS -------------------------------------------------- */ |
20 |
|
|
/* ---------------------------------------------------------------------------*/ |
21 |
|
|
|
22 |
|
|
namespace dynamicgraph { |
23 |
|
|
namespace sot { |
24 |
|
|
|
25 |
|
✗ |
DYNAMICGRAPH_FACTORY_ENTITY_PLUGIN(GradientAscent, "GradientAscent"); |
26 |
|
|
|
27 |
|
|
/* --------------------------------------------------------------------- */ |
28 |
|
|
/* --- CLASS ----------------------------------------------------------- */ |
29 |
|
|
/* --------------------------------------------------------------------- */ |
30 |
|
|
|
31 |
|
✗ |
GradientAscent::GradientAscent(const std::string &n) |
32 |
|
|
: Entity(n), |
33 |
|
✗ |
gradientSIN(NULL, "GradientAscent(" + n + ")::input(vector)::gradient"), |
34 |
|
✗ |
learningRateSIN(NULL, |
35 |
|
✗ |
"GradientAscent(" + n + ")::input(double)::learningRate"), |
36 |
|
✗ |
refresherSINTERN("GradientAscent(" + n + ")::intern(dummy)::refresher"), |
37 |
|
✗ |
valueSOUT(boost::bind(&GradientAscent::update, this, _1, _2), |
38 |
|
✗ |
gradientSIN << refresherSINTERN, |
39 |
|
✗ |
"GradientAscent(" + n + ")::output(vector)::value"), |
40 |
|
✗ |
init(false) { |
41 |
|
|
// Register signals into the entity. |
42 |
|
✗ |
signalRegistration(gradientSIN << learningRateSIN << valueSOUT); |
43 |
|
✗ |
refresherSINTERN.setDependencyType(TimeDependency<int>::ALWAYS_READY); |
44 |
|
|
} |
45 |
|
|
|
46 |
|
✗ |
GradientAscent::~GradientAscent() {} |
47 |
|
|
|
48 |
|
|
/* --- COMPUTE ----------------------------------------------------------- */ |
49 |
|
|
/* --- COMPUTE ----------------------------------------------------------- */ |
50 |
|
|
/* --- COMPUTE ----------------------------------------------------------- */ |
51 |
|
|
|
52 |
|
✗ |
dynamicgraph::Vector &GradientAscent::update(dynamicgraph::Vector &res, |
53 |
|
|
const int &inTime) { |
54 |
|
✗ |
const dynamicgraph::Vector &gradient = gradientSIN(inTime); |
55 |
|
✗ |
const double &learningRate = learningRateSIN(inTime); |
56 |
|
|
|
57 |
|
✗ |
if (init == false) { |
58 |
|
✗ |
init = true; |
59 |
|
✗ |
value = gradient; |
60 |
|
✗ |
value.setZero(); |
61 |
|
✗ |
res.resize(value.size()); |
62 |
|
|
} |
63 |
|
|
|
64 |
|
✗ |
value += learningRate * gradient; |
65 |
|
✗ |
res = value; |
66 |
|
✗ |
return res; |
67 |
|
|
} |
68 |
|
|
|
69 |
|
|
} /* namespace sot */ |
70 |
|
|
} /* namespace dynamicgraph */ |
71 |
|
|
|