Directory: | ./ |
---|---|
File: | include/crocoddyl/core/optctrl/shooting.hxx |
Date: | 2025-01-30 11:01:55 |
Exec | Total | Coverage | |
---|---|---|---|
Lines: | 151 | 263 | 57.4% |
Branches: | 85 | 740 | 11.5% |
Line | Branch | Exec | Source |
---|---|---|---|
1 | /////////////////////////////////////////////////////////////////////////////// | ||
2 | // BSD 3-Clause License | ||
3 | // | ||
4 | // Copyright (C) 2019-2022, LAAS-CNRS, University of Edinburgh, | ||
5 | // University of Oxford, Heriot-Watt University | ||
6 | // Copyright note valid unless otherwise stated in individual files. | ||
7 | // All rights reserved. | ||
8 | /////////////////////////////////////////////////////////////////////////////// | ||
9 | |||
10 | #include <iostream> | ||
11 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
12 | #include <omp.h> | ||
13 | #endif // CROCODDYL_WITH_MULTITHREADING | ||
14 | #include "crocoddyl/core/utils/stop-watch.hpp" | ||
15 | |||
16 | namespace crocoddyl { | ||
17 | |||
18 | template <typename Scalar> | ||
19 | 525 | ShootingProblemTpl<Scalar>::ShootingProblemTpl( | |
20 | const VectorXs& x0, | ||
21 | const std::vector<std::shared_ptr<ActionModelAbstract> >& running_models, | ||
22 | std::shared_ptr<ActionModelAbstract> terminal_model) | ||
23 | 525 | : cost_(Scalar(0.)), | |
24 | 525 | T_(running_models.size()), | |
25 | 525 | x0_(x0), | |
26 | 525 | terminal_model_(terminal_model), | |
27 |
1/2✓ Branch 1 taken 525 times.
✗ Branch 2 not taken.
|
525 | running_models_(running_models), |
28 | 525 | nx_(running_models[0]->get_state()->get_nx()), | |
29 | 525 | ndx_(running_models[0]->get_state()->get_ndx()), | |
30 | 525 | nu_max_(running_models[0]->get_nu()), | |
31 | 525 | nthreads_(1), | |
32 | 1050 | is_updated_(false) { | |
33 |
2/2✓ Branch 0 taken 9507 times.
✓ Branch 1 taken 525 times.
|
10032 | for (std::size_t i = 1; i < T_; ++i) { |
34 | 9507 | const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; | |
35 | 9507 | const std::size_t nu = model->get_nu(); | |
36 |
2/2✓ Branch 0 taken 48 times.
✓ Branch 1 taken 9459 times.
|
9507 | if (nu_max_ < nu) { |
37 | 48 | nu_max_ = nu; | |
38 | } | ||
39 | } | ||
40 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 525 times.
|
525 | if (static_cast<std::size_t>(x0.size()) != nx_) { |
41 | ✗ | throw_pretty( | |
42 | "Invalid argument: " << "x0 has wrong dimension (it should be " + | ||
43 | std::to_string(nx_) + ")"); | ||
44 | } | ||
45 |
2/2✓ Branch 0 taken 9507 times.
✓ Branch 1 taken 525 times.
|
10032 | for (std::size_t i = 1; i < T_; ++i) { |
46 | 9507 | const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; | |
47 |
1/2✗ Branch 4 not taken.
✓ Branch 5 taken 9507 times.
|
9507 | if (model->get_state()->get_nx() != nx_) { |
48 | ✗ | throw_pretty("Invalid argument: " | |
49 | << "nx in " << i | ||
50 | << " node is not consistent with the other nodes") | ||
51 | } | ||
52 |
1/2✗ Branch 4 not taken.
✓ Branch 5 taken 9507 times.
|
9507 | if (model->get_state()->get_ndx() != ndx_) { |
53 | ✗ | throw_pretty("Invalid argument: " | |
54 | << "ndx in " << i | ||
55 | << " node is not consistent with the other nodes") | ||
56 | } | ||
57 | } | ||
58 |
1/2✗ Branch 4 not taken.
✓ Branch 5 taken 525 times.
|
525 | if (terminal_model_->get_state()->get_nx() != nx_) { |
59 | ✗ | throw_pretty( | |
60 | "Invalid argument: " | ||
61 | << "nx in terminal node is not consistent with the other nodes") | ||
62 | } | ||
63 |
1/2✗ Branch 4 not taken.
✓ Branch 5 taken 525 times.
|
525 | if (terminal_model_->get_state()->get_ndx() != ndx_) { |
64 | ✗ | throw_pretty( | |
65 | "Invalid argument: " | ||
66 | << "ndx in terminal node is not consistent with the other nodes") | ||
67 | } | ||
68 |
1/2✓ Branch 1 taken 525 times.
✗ Branch 2 not taken.
|
525 | allocateData(); |
69 | |||
70 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
71 | if (enableMultithreading()) { | ||
72 | nthreads_ = CROCODDYL_WITH_NTHREADS; | ||
73 | } | ||
74 | #endif | ||
75 | 525 | } | |
76 | |||
77 | template <typename Scalar> | ||
78 | 297 | ShootingProblemTpl<Scalar>::ShootingProblemTpl( | |
79 | const VectorXs& x0, | ||
80 | const std::vector<std::shared_ptr<ActionModelAbstract> >& running_models, | ||
81 | std::shared_ptr<ActionModelAbstract> terminal_model, | ||
82 | const std::vector<std::shared_ptr<ActionDataAbstract> >& running_datas, | ||
83 | std::shared_ptr<ActionDataAbstract> terminal_data) | ||
84 | 297 | : cost_(Scalar(0.)), | |
85 | 297 | T_(running_models.size()), | |
86 | 297 | x0_(x0), | |
87 | 297 | terminal_model_(terminal_model), | |
88 | 297 | terminal_data_(terminal_data), | |
89 |
1/2✓ Branch 1 taken 297 times.
✗ Branch 2 not taken.
|
297 | running_models_(running_models), |
90 |
1/2✓ Branch 1 taken 297 times.
✗ Branch 2 not taken.
|
297 | running_datas_(running_datas), |
91 | 297 | nx_(running_models[0]->get_state()->get_nx()), | |
92 | 297 | ndx_(running_models[0]->get_state()->get_ndx()), | |
93 | 297 | nu_max_(running_models[0]->get_nu()), | |
94 | 297 | nthreads_(1) { | |
95 |
2/2✓ Branch 0 taken 5643 times.
✓ Branch 1 taken 297 times.
|
5940 | for (std::size_t i = 1; i < T_; ++i) { |
96 | 5643 | const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; | |
97 | 5643 | const std::size_t nu = model->get_nu(); | |
98 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 5643 times.
|
5643 | if (nu_max_ < nu) { |
99 | ✗ | nu_max_ = nu; | |
100 | } | ||
101 | } | ||
102 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 297 times.
|
297 | if (static_cast<std::size_t>(x0.size()) != nx_) { |
103 | ✗ | throw_pretty( | |
104 | "Invalid argument: " << "x0 has wrong dimension (it should be " + | ||
105 | std::to_string(nx_) + ")"); | ||
106 | } | ||
107 | 297 | const std::size_t Td = running_datas.size(); | |
108 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 297 times.
|
297 | if (Td != T_) { |
109 | ✗ | throw_pretty( | |
110 | "Invalid argument: " | ||
111 | << "the number of running models and datas are not the same (" + | ||
112 | std::to_string(T_) + " != " + std::to_string(Td) + ")") | ||
113 | } | ||
114 |
2/2✓ Branch 0 taken 5940 times.
✓ Branch 1 taken 297 times.
|
6237 | for (std::size_t i = 0; i < T_; ++i) { |
115 | 5940 | const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; | |
116 | 5940 | const std::shared_ptr<ActionDataAbstract>& data = running_datas_[i]; | |
117 |
1/2✗ Branch 4 not taken.
✓ Branch 5 taken 5940 times.
|
5940 | if (model->get_state()->get_nx() != nx_) { |
118 | ✗ | throw_pretty("Invalid argument: " | |
119 | << "nx in " << i | ||
120 | << " node is not consistent with the other nodes") | ||
121 | } | ||
122 |
1/2✗ Branch 4 not taken.
✓ Branch 5 taken 5940 times.
|
5940 | if (model->get_state()->get_ndx() != ndx_) { |
123 | ✗ | throw_pretty("Invalid argument: " | |
124 | << "ndx in " << i | ||
125 | << " node is not consistent with the other nodes") | ||
126 | } | ||
127 |
2/4✓ Branch 2 taken 5940 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 5940 times.
|
5940 | if (!model->checkData(data)) { |
128 | ✗ | throw_pretty("Invalid argument: " | |
129 | << "action data in " << i | ||
130 | << " node is not consistent with the action model") | ||
131 | } | ||
132 | } | ||
133 |
2/4✓ Branch 2 taken 297 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 297 times.
|
297 | if (!terminal_model->checkData(terminal_data)) { |
134 | ✗ | throw_pretty("Invalid argument: " | |
135 | << "terminal action data is not consistent with the terminal " | ||
136 | "action model") | ||
137 | } | ||
138 | |||
139 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
140 | if (enableMultithreading()) { | ||
141 | nthreads_ = CROCODDYL_WITH_NTHREADS; | ||
142 | } | ||
143 | #endif | ||
144 | 297 | } | |
145 | |||
146 | template <typename Scalar> | ||
147 | 2 | ShootingProblemTpl<Scalar>::ShootingProblemTpl( | |
148 | const ShootingProblemTpl<Scalar>& problem) | ||
149 | 2 | : cost_(Scalar(0.)), | |
150 | 2 | T_(problem.get_T()), | |
151 | 2 | x0_(problem.get_x0()), | |
152 | 2 | terminal_model_(problem.get_terminalModel()), | |
153 | 2 | terminal_data_(problem.get_terminalData()), | |
154 |
1/2✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
2 | running_models_(problem.get_runningModels()), |
155 |
1/2✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
2 | running_datas_(problem.get_runningDatas()), |
156 | 2 | nx_(problem.get_nx()), | |
157 | 2 | ndx_(problem.get_ndx()), | |
158 | 2 | nu_max_(problem.get_nu_max()) {} | |
159 | |||
160 | template <typename Scalar> | ||
161 | 824 | ShootingProblemTpl<Scalar>::~ShootingProblemTpl() {} | |
162 | |||
163 | template <typename Scalar> | ||
164 | 619 | Scalar ShootingProblemTpl<Scalar>::calc(const std::vector<VectorXs>& xs, | |
165 | const std::vector<VectorXs>& us) { | ||
166 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 619 times.
|
619 | if (xs.size() != T_ + 1) { |
167 | ✗ | throw_pretty( | |
168 | "Invalid argument: " << "xs has wrong dimension (it should be " + | ||
169 | std::to_string(T_ + 1) + ")"); | ||
170 | } | ||
171 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 619 times.
|
619 | if (us.size() != T_) { |
172 | ✗ | throw_pretty( | |
173 | "Invalid argument: " << "us has wrong dimension (it should be " + | ||
174 | std::to_string(T_) + ")"); | ||
175 | } | ||
176 |
3/16✗ Branch 2 not taken.
✓ Branch 3 taken 619 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 619 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 619 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
|
619 | START_PROFILER("ShootingProblem::calc"); |
177 | |||
178 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
179 | #pragma omp parallel for num_threads(nthreads_) | ||
180 | #endif | ||
181 |
2/2✓ Branch 0 taken 10615 times.
✓ Branch 1 taken 619 times.
|
11234 | for (std::size_t i = 0; i < T_; ++i) { |
182 |
2/4✓ Branch 6 taken 10615 times.
✗ Branch 7 not taken.
✓ Branch 10 taken 10615 times.
✗ Branch 11 not taken.
|
10615 | running_models_[i]->calc(running_datas_[i], xs[i], us[i]); |
183 | } | ||
184 |
1/2✓ Branch 4 taken 619 times.
✗ Branch 5 not taken.
|
619 | terminal_model_->calc(terminal_data_, xs.back()); |
185 | |||
186 | 619 | cost_ = Scalar(0.); | |
187 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
188 | #pragma omp simd reduction(+ : cost_) | ||
189 | #endif | ||
190 |
2/2✓ Branch 0 taken 10615 times.
✓ Branch 1 taken 619 times.
|
11234 | for (std::size_t i = 0; i < T_; ++i) { |
191 | 10615 | cost_ += running_datas_[i]->cost; | |
192 | } | ||
193 | 619 | cost_ += terminal_data_->cost; | |
194 |
3/16✗ Branch 2 not taken.
✓ Branch 3 taken 619 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 619 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 619 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
|
619 | STOP_PROFILER("ShootingProblem::calc"); |
195 | 619 | return cost_; | |
196 | } | ||
197 | |||
198 | template <typename Scalar> | ||
199 | 441 | Scalar ShootingProblemTpl<Scalar>::calcDiff(const std::vector<VectorXs>& xs, | |
200 | const std::vector<VectorXs>& us) { | ||
201 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 441 times.
|
441 | if (xs.size() != T_ + 1) { |
202 | ✗ | throw_pretty( | |
203 | "Invalid argument: " << "xs has wrong dimension (it should be " + | ||
204 | std::to_string(T_ + 1) + ")"); | ||
205 | } | ||
206 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 441 times.
|
441 | if (us.size() != T_) { |
207 | ✗ | throw_pretty( | |
208 | "Invalid argument: " << "us has wrong dimension (it should be " + | ||
209 | std::to_string(T_) + ")"); | ||
210 | } | ||
211 |
3/16✗ Branch 2 not taken.
✓ Branch 3 taken 441 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 441 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 441 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
|
441 | START_PROFILER("ShootingProblem::calcDiff"); |
212 | |||
213 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
214 | #pragma omp parallel for num_threads(nthreads_) | ||
215 | #endif | ||
216 |
2/2✓ Branch 0 taken 7064 times.
✓ Branch 1 taken 441 times.
|
7505 | for (std::size_t i = 0; i < T_; ++i) { |
217 |
2/4✓ Branch 6 taken 7064 times.
✗ Branch 7 not taken.
✓ Branch 10 taken 7064 times.
✗ Branch 11 not taken.
|
7064 | running_models_[i]->calcDiff(running_datas_[i], xs[i], us[i]); |
218 | } | ||
219 |
1/2✓ Branch 4 taken 441 times.
✗ Branch 5 not taken.
|
441 | terminal_model_->calcDiff(terminal_data_, xs.back()); |
220 | |||
221 | 441 | cost_ = Scalar(0.); | |
222 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
223 | #pragma omp simd reduction(+ : cost_) | ||
224 | #endif | ||
225 |
2/2✓ Branch 0 taken 7064 times.
✓ Branch 1 taken 441 times.
|
7505 | for (std::size_t i = 0; i < T_; ++i) { |
226 | 7064 | cost_ += running_datas_[i]->cost; | |
227 | } | ||
228 | 441 | cost_ += terminal_data_->cost; | |
229 | |||
230 |
3/16✗ Branch 2 not taken.
✓ Branch 3 taken 441 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 441 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 441 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
|
441 | STOP_PROFILER("ShootingProblem::calcDiff"); |
231 | 441 | return cost_; | |
232 | } | ||
233 | |||
234 | template <typename Scalar> | ||
235 | 103 | void ShootingProblemTpl<Scalar>::rollout(const std::vector<VectorXs>& us, | |
236 | std::vector<VectorXs>& xs) { | ||
237 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 103 times.
|
103 | if (xs.size() != T_ + 1) { |
238 | ✗ | throw_pretty( | |
239 | "Invalid argument: " << "xs has wrong dimension (it should be " + | ||
240 | std::to_string(T_ + 1) + ")"); | ||
241 | } | ||
242 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 103 times.
|
103 | if (us.size() != T_) { |
243 | ✗ | throw_pretty( | |
244 | "Invalid argument: " << "us has wrong dimension (it should be " + | ||
245 | std::to_string(T_) + ")"); | ||
246 | } | ||
247 |
3/16✗ Branch 2 not taken.
✓ Branch 3 taken 103 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 103 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 103 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
|
103 | START_PROFILER("ShootingProblem::rollout"); |
248 | |||
249 | 103 | xs[0] = x0_; | |
250 |
2/2✓ Branch 0 taken 2336 times.
✓ Branch 1 taken 103 times.
|
2439 | for (std::size_t i = 0; i < T_; ++i) { |
251 | 2336 | const std::shared_ptr<ActionDataAbstract>& data = running_datas_[i]; | |
252 |
2/4✓ Branch 6 taken 2336 times.
✗ Branch 7 not taken.
✓ Branch 9 taken 2336 times.
✗ Branch 10 not taken.
|
2336 | running_models_[i]->calc(data, xs[i], us[i]); |
253 | 2336 | xs[i + 1] = data->xnext; | |
254 | } | ||
255 |
1/2✓ Branch 4 taken 103 times.
✗ Branch 5 not taken.
|
103 | terminal_model_->calc(terminal_data_, xs.back()); |
256 |
3/16✗ Branch 2 not taken.
✓ Branch 3 taken 103 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 103 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 103 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
|
103 | STOP_PROFILER("ShootingProblem::rollout"); |
257 | 103 | } | |
258 | |||
259 | template <typename Scalar> | ||
260 | std::vector<typename MathBaseTpl<Scalar>::VectorXs> | ||
261 | 4 | ShootingProblemTpl<Scalar>::rollout_us(const std::vector<VectorXs>& us) { | |
262 | 4 | std::vector<VectorXs> xs; | |
263 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | xs.resize(T_ + 1); |
264 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | rollout(us, xs); |
265 | 4 | return xs; | |
266 | } | ||
267 | |||
268 | template <typename Scalar> | ||
269 | 198 | void ShootingProblemTpl<Scalar>::quasiStatic(std::vector<VectorXs>& us, | |
270 | const std::vector<VectorXs>& xs) { | ||
271 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 198 times.
|
198 | if (xs.size() != T_) { |
272 | ✗ | throw_pretty( | |
273 | "Invalid argument: " << "xs has wrong dimension (it should be " + | ||
274 | std::to_string(T_) + ")"); | ||
275 | } | ||
276 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 198 times.
|
198 | if (us.size() != T_) { |
277 | ✗ | throw_pretty( | |
278 | "Invalid argument: " << "us has wrong dimension (it should be " + | ||
279 | std::to_string(T_) + ")"); | ||
280 | } | ||
281 | |||
282 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
283 | #pragma omp parallel for num_threads(nthreads_) | ||
284 | #endif | ||
285 |
2/2✓ Branch 0 taken 3960 times.
✓ Branch 1 taken 198 times.
|
4158 | for (std::size_t i = 0; i < T_; ++i) { |
286 |
2/4✓ Branch 6 taken 3960 times.
✗ Branch 7 not taken.
✓ Branch 10 taken 3960 times.
✗ Branch 11 not taken.
|
3960 | running_models_[i]->quasiStatic(running_datas_[i], us[i], xs[i]); |
287 | } | ||
288 | 198 | } | |
289 | |||
290 | template <typename Scalar> | ||
291 | std::vector<typename MathBaseTpl<Scalar>::VectorXs> | ||
292 | ✗ | ShootingProblemTpl<Scalar>::quasiStatic_xs(const std::vector<VectorXs>& xs) { | |
293 | ✗ | std::vector<VectorXs> us; | |
294 | ✗ | us.resize(T_); | |
295 | ✗ | for (std::size_t i = 0; i < T_; ++i) { | |
296 | ✗ | us[i] = VectorXs::Zero(running_models_[i]->get_nu()); | |
297 | } | ||
298 | ✗ | quasiStatic(us, xs); | |
299 | ✗ | return us; | |
300 | } | ||
301 | |||
302 | template <typename Scalar> | ||
303 | ✗ | void ShootingProblemTpl<Scalar>::circularAppend( | |
304 | std::shared_ptr<ActionModelAbstract> model, | ||
305 | std::shared_ptr<ActionDataAbstract> data) { | ||
306 | ✗ | if (!model->checkData(data)) { | |
307 | ✗ | throw_pretty("Invalid argument: " | |
308 | << "action data is not consistent with the action model") | ||
309 | } | ||
310 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
311 | ✗ | throw_pretty( | |
312 | "Invalid argument: " << "nx is not consistent with the other nodes") | ||
313 | } | ||
314 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
315 | ✗ | throw_pretty("Invalid argument: " | |
316 | << "ndx node is not consistent with the other nodes") | ||
317 | } | ||
318 | ✗ | is_updated_ = true; | |
319 | ✗ | for (std::size_t i = 0; i < T_ - 1; ++i) { | |
320 | ✗ | running_models_[i] = running_models_[i + 1]; | |
321 | ✗ | running_datas_[i] = running_datas_[i + 1]; | |
322 | } | ||
323 | ✗ | running_models_.back() = model; | |
324 | ✗ | running_datas_.back() = data; | |
325 | } | ||
326 | |||
327 | template <typename Scalar> | ||
328 | ✗ | void ShootingProblemTpl<Scalar>::circularAppend( | |
329 | std::shared_ptr<ActionModelAbstract> model) { | ||
330 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
331 | ✗ | throw_pretty( | |
332 | "Invalid argument: " << "nx is not consistent with the other nodes") | ||
333 | } | ||
334 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
335 | ✗ | throw_pretty("Invalid argument: " | |
336 | << "ndx node is not consistent with the other nodes") | ||
337 | } | ||
338 | ✗ | is_updated_ = true; | |
339 | ✗ | for (std::size_t i = 0; i < T_ - 1; ++i) { | |
340 | ✗ | running_models_[i] = running_models_[i + 1]; | |
341 | ✗ | running_datas_[i] = running_datas_[i + 1]; | |
342 | } | ||
343 | ✗ | running_models_.back() = model; | |
344 | ✗ | running_datas_.back() = model->createData(); | |
345 | } | ||
346 | |||
347 | template <typename Scalar> | ||
348 | ✗ | void ShootingProblemTpl<Scalar>::updateNode( | |
349 | const std::size_t i, std::shared_ptr<ActionModelAbstract> model, | ||
350 | std::shared_ptr<ActionDataAbstract> data) { | ||
351 | ✗ | if (i >= T_ + 1) { | |
352 | ✗ | throw_pretty("Invalid argument: " | |
353 | << "i is bigger than the allocated horizon (it should be less " | ||
354 | "than or equal to " + | ||
355 | std::to_string(T_ + 1) + ")"); | ||
356 | } | ||
357 | ✗ | if (!model->checkData(data)) { | |
358 | ✗ | throw_pretty("Invalid argument: " | |
359 | << "action data is not consistent with the action model") | ||
360 | } | ||
361 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
362 | ✗ | throw_pretty( | |
363 | "Invalid argument: " << "nx is not consistent with the other nodes") | ||
364 | } | ||
365 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
366 | ✗ | throw_pretty("Invalid argument: " | |
367 | << "ndx node is not consistent with the other nodes") | ||
368 | } | ||
369 | ✗ | is_updated_ = true; | |
370 | ✗ | if (i == T_) { | |
371 | ✗ | terminal_model_ = model; | |
372 | ✗ | terminal_data_ = data; | |
373 | } else { | ||
374 | ✗ | running_models_[i] = model; | |
375 | ✗ | running_datas_[i] = data; | |
376 | } | ||
377 | } | ||
378 | |||
379 | template <typename Scalar> | ||
380 | ✗ | void ShootingProblemTpl<Scalar>::updateModel( | |
381 | const std::size_t i, std::shared_ptr<ActionModelAbstract> model) { | ||
382 | ✗ | if (i >= T_ + 1) { | |
383 | ✗ | throw_pretty( | |
384 | "Invalid argument: " | ||
385 | << "i is bigger than the allocated horizon (it should be lower than " + | ||
386 | std::to_string(T_ + 1) + ")"); | ||
387 | } | ||
388 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
389 | ✗ | throw_pretty( | |
390 | "Invalid argument: " << "nx is not consistent with the other nodes") | ||
391 | } | ||
392 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
393 | ✗ | throw_pretty( | |
394 | "Invalid argument: " << "ndx is not consistent with the other nodes") | ||
395 | } | ||
396 | ✗ | is_updated_ = true; | |
397 | ✗ | if (i == T_) { | |
398 | ✗ | terminal_model_ = model; | |
399 | ✗ | terminal_data_ = terminal_model_->createData(); | |
400 | } else { | ||
401 | ✗ | running_models_[i] = model; | |
402 | ✗ | running_datas_[i] = model->createData(); | |
403 | } | ||
404 | } | ||
405 | |||
406 | template <typename Scalar> | ||
407 | 6078 | std::size_t ShootingProblemTpl<Scalar>::get_T() const { | |
408 | 6078 | return T_; | |
409 | } | ||
410 | |||
411 | template <typename Scalar> | ||
412 | const typename MathBaseTpl<Scalar>::VectorXs& | ||
413 | 538 | ShootingProblemTpl<Scalar>::get_x0() const { | |
414 | 538 | return x0_; | |
415 | } | ||
416 | |||
417 | template <typename Scalar> | ||
418 | 525 | void ShootingProblemTpl<Scalar>::allocateData() { | |
419 | 525 | running_datas_.resize(T_); | |
420 |
2/2✓ Branch 0 taken 10032 times.
✓ Branch 1 taken 525 times.
|
10557 | for (std::size_t i = 0; i < T_; ++i) { |
421 | 10032 | const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; | |
422 | 10032 | running_datas_[i] = model->createData(); | |
423 | } | ||
424 | 525 | terminal_data_ = terminal_model_->createData(); | |
425 | 525 | } | |
426 | |||
427 | template <typename Scalar> | ||
428 | const std::vector<std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> > >& | ||
429 | 5147 | ShootingProblemTpl<Scalar>::get_runningModels() const { | |
430 | 5147 | return running_models_; | |
431 | } | ||
432 | |||
433 | template <typename Scalar> | ||
434 | const std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> >& | ||
435 | 1536 | ShootingProblemTpl<Scalar>::get_terminalModel() const { | |
436 | 1536 | return terminal_model_; | |
437 | } | ||
438 | |||
439 | template <typename Scalar> | ||
440 | const std::vector<std::shared_ptr<crocoddyl::ActionDataAbstractTpl<Scalar> > >& | ||
441 | 37135 | ShootingProblemTpl<Scalar>::get_runningDatas() const { | |
442 | 37135 | return running_datas_; | |
443 | } | ||
444 | |||
445 | template <typename Scalar> | ||
446 | const std::shared_ptr<crocoddyl::ActionDataAbstractTpl<Scalar> >& | ||
447 | 1578 | ShootingProblemTpl<Scalar>::get_terminalData() const { | |
448 | 1578 | return terminal_data_; | |
449 | } | ||
450 | |||
451 | template <typename Scalar> | ||
452 | ✗ | void ShootingProblemTpl<Scalar>::set_x0(const VectorXs& x0_in) { | |
453 | ✗ | if (x0_in.size() != x0_.size()) { | |
454 | ✗ | throw_pretty("Invalid argument: " | |
455 | << "invalid size of x0 provided: Expected " << x0_.size() | ||
456 | << ", received " << x0_in.size()); | ||
457 | } | ||
458 | ✗ | x0_ = x0_in; | |
459 | } | ||
460 | |||
461 | template <typename Scalar> | ||
462 | ✗ | void ShootingProblemTpl<Scalar>::set_runningModels( | |
463 | const std::vector<std::shared_ptr<ActionModelAbstract> >& models) { | ||
464 | ✗ | for (std::size_t i = 0; i < T_; ++i) { | |
465 | ✗ | const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; | |
466 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
467 | ✗ | throw_pretty("Invalid argument: " | |
468 | << "nx in " << i | ||
469 | << " node is not consistent with the other nodes") | ||
470 | } | ||
471 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
472 | ✗ | throw_pretty("Invalid argument: " | |
473 | << "ndx in " << i | ||
474 | << " node is not consistent with the other nodes") | ||
475 | } | ||
476 | } | ||
477 | ✗ | is_updated_ = true; | |
478 | ✗ | T_ = models.size(); | |
479 | ✗ | running_models_.clear(); | |
480 | ✗ | running_datas_.clear(); | |
481 | ✗ | for (std::size_t i = 0; i < T_; ++i) { | |
482 | ✗ | const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; | |
483 | ✗ | running_datas_.push_back(model->createData()); | |
484 | } | ||
485 | } | ||
486 | |||
487 | template <typename Scalar> | ||
488 | ✗ | void ShootingProblemTpl<Scalar>::set_terminalModel( | |
489 | std::shared_ptr<ActionModelAbstract> model) { | ||
490 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
491 | ✗ | throw_pretty( | |
492 | "Invalid argument: " << "nx is not consistent with the other nodes") | ||
493 | } | ||
494 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
495 | ✗ | throw_pretty( | |
496 | "Invalid argument: " << "ndx is not consistent with the other nodes") | ||
497 | } | ||
498 | ✗ | is_updated_ = true; | |
499 | ✗ | terminal_model_ = model; | |
500 | ✗ | terminal_data_ = terminal_model_->createData(); | |
501 | } | ||
502 | |||
503 | template <typename Scalar> | ||
504 | ✗ | void ShootingProblemTpl<Scalar>::set_nthreads(const int nthreads) { | |
505 | #ifndef CROCODDYL_WITH_MULTITHREADING | ||
506 | (void)nthreads; | ||
507 | ✗ | std::cerr << "Warning: the number of threads won't affect the computational " | |
508 | "performance as multithreading " | ||
509 | "support is not enabled." | ||
510 | ✗ | << std::endl; | |
511 | #else | ||
512 | if (nthreads < 1) { | ||
513 | nthreads_ = CROCODDYL_WITH_NTHREADS; | ||
514 | } else { | ||
515 | nthreads_ = static_cast<std::size_t>(nthreads); | ||
516 | } | ||
517 | if (!enableMultithreading()) { | ||
518 | std::cerr << "Warning: the number of threads won't affect the " | ||
519 | "computational performance as multithreading " | ||
520 | "support is not enabled." | ||
521 | << std::endl; | ||
522 | nthreads_ = 1; | ||
523 | } | ||
524 | #endif | ||
525 | } | ||
526 | |||
527 | template <typename Scalar> | ||
528 | 49 | std::size_t ShootingProblemTpl<Scalar>::get_nx() const { | |
529 | 49 | return nx_; | |
530 | } | ||
531 | |||
532 | template <typename Scalar> | ||
533 | 230 | std::size_t ShootingProblemTpl<Scalar>::get_ndx() const { | |
534 | 230 | return ndx_; | |
535 | } | ||
536 | |||
537 | template <typename Scalar> | ||
538 | 2 | std::size_t ShootingProblemTpl<Scalar>::get_nu_max() const { | |
539 | 2 | return nu_max_; | |
540 | } | ||
541 | |||
542 | template <typename Scalar> | ||
543 | ✗ | std::size_t ShootingProblemTpl<Scalar>::get_nthreads() const { | |
544 | #ifndef CROCODDYL_WITH_MULTITHREADING | ||
545 | ✗ | std::cerr << "Warning: the number of threads won't affect the computational " | |
546 | "performance as multithreading " | ||
547 | "support is not enabled." | ||
548 | ✗ | << std::endl; | |
549 | #endif | ||
550 | ✗ | return nthreads_; | |
551 | } | ||
552 | |||
553 | template <typename Scalar> | ||
554 | 32 | bool ShootingProblemTpl<Scalar>::is_updated() { | |
555 | 32 | const bool status = is_updated_; | |
556 | 32 | is_updated_ = false; | |
557 | 32 | return status; | |
558 | } | ||
559 | |||
560 | template <typename Scalar> | ||
561 | 7 | std::ostream& operator<<(std::ostream& os, | |
562 | const ShootingProblemTpl<Scalar>& problem) { | ||
563 | 7 | os << "ShootingProblem (T=" << problem.get_T() << ", nx=" << problem.get_nx() | |
564 | 7 | << ", ndx=" << problem.get_ndx() << ") " << std::endl | |
565 | 7 | << " Models:" << std::endl; | |
566 | const std::vector< | ||
567 | std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> > >& | ||
568 | 7 | runningModels = problem.get_runningModels(); | |
569 |
2/2✓ Branch 1 taken 140 times.
✓ Branch 2 taken 7 times.
|
147 | for (std::size_t t = 0; t < problem.get_T(); ++t) { |
570 | 140 | os << " " << t << ": " << *runningModels[t] << std::endl; | |
571 | } | ||
572 | 7 | os << " " << problem.get_T() << ": " << *problem.get_terminalModel(); | |
573 | 7 | return os; | |
574 | } | ||
575 | |||
576 | } // namespace crocoddyl | ||
577 |