GCC Code Coverage Report


Directory: ./
File: include/crocoddyl/core/optctrl/shooting.hxx
Date: 2025-01-30 11:01:55
Exec Total Coverage
Lines: 151 263 57.4%
Branches: 85 740 11.5%

Line Branch Exec Source
1 ///////////////////////////////////////////////////////////////////////////////
2 // BSD 3-Clause License
3 //
4 // Copyright (C) 2019-2022, LAAS-CNRS, University of Edinburgh,
5 // University of Oxford, Heriot-Watt University
6 // Copyright note valid unless otherwise stated in individual files.
7 // All rights reserved.
8 ///////////////////////////////////////////////////////////////////////////////
9
10 #include <iostream>
11 #ifdef CROCODDYL_WITH_MULTITHREADING
12 #include <omp.h>
13 #endif // CROCODDYL_WITH_MULTITHREADING
14 #include "crocoddyl/core/utils/stop-watch.hpp"
15
16 namespace crocoddyl {
17
18 template <typename Scalar>
19 525 ShootingProblemTpl<Scalar>::ShootingProblemTpl(
20 const VectorXs& x0,
21 const std::vector<std::shared_ptr<ActionModelAbstract> >& running_models,
22 std::shared_ptr<ActionModelAbstract> terminal_model)
23 525 : cost_(Scalar(0.)),
24 525 T_(running_models.size()),
25 525 x0_(x0),
26 525 terminal_model_(terminal_model),
27
1/2
✓ Branch 1 taken 525 times.
✗ Branch 2 not taken.
525 running_models_(running_models),
28 525 nx_(running_models[0]->get_state()->get_nx()),
29 525 ndx_(running_models[0]->get_state()->get_ndx()),
30 525 nu_max_(running_models[0]->get_nu()),
31 525 nthreads_(1),
32 1050 is_updated_(false) {
33
2/2
✓ Branch 0 taken 9507 times.
✓ Branch 1 taken 525 times.
10032 for (std::size_t i = 1; i < T_; ++i) {
34 9507 const std::shared_ptr<ActionModelAbstract>& model = running_models_[i];
35 9507 const std::size_t nu = model->get_nu();
36
2/2
✓ Branch 0 taken 48 times.
✓ Branch 1 taken 9459 times.
9507 if (nu_max_ < nu) {
37 48 nu_max_ = nu;
38 }
39 }
40
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 525 times.
525 if (static_cast<std::size_t>(x0.size()) != nx_) {
41 throw_pretty(
42 "Invalid argument: " << "x0 has wrong dimension (it should be " +
43 std::to_string(nx_) + ")");
44 }
45
2/2
✓ Branch 0 taken 9507 times.
✓ Branch 1 taken 525 times.
10032 for (std::size_t i = 1; i < T_; ++i) {
46 9507 const std::shared_ptr<ActionModelAbstract>& model = running_models_[i];
47
1/2
✗ Branch 4 not taken.
✓ Branch 5 taken 9507 times.
9507 if (model->get_state()->get_nx() != nx_) {
48 throw_pretty("Invalid argument: "
49 << "nx in " << i
50 << " node is not consistent with the other nodes")
51 }
52
1/2
✗ Branch 4 not taken.
✓ Branch 5 taken 9507 times.
9507 if (model->get_state()->get_ndx() != ndx_) {
53 throw_pretty("Invalid argument: "
54 << "ndx in " << i
55 << " node is not consistent with the other nodes")
56 }
57 }
58
1/2
✗ Branch 4 not taken.
✓ Branch 5 taken 525 times.
525 if (terminal_model_->get_state()->get_nx() != nx_) {
59 throw_pretty(
60 "Invalid argument: "
61 << "nx in terminal node is not consistent with the other nodes")
62 }
63
1/2
✗ Branch 4 not taken.
✓ Branch 5 taken 525 times.
525 if (terminal_model_->get_state()->get_ndx() != ndx_) {
64 throw_pretty(
65 "Invalid argument: "
66 << "ndx in terminal node is not consistent with the other nodes")
67 }
68
1/2
✓ Branch 1 taken 525 times.
✗ Branch 2 not taken.
525 allocateData();
69
70 #ifdef CROCODDYL_WITH_MULTITHREADING
71 if (enableMultithreading()) {
72 nthreads_ = CROCODDYL_WITH_NTHREADS;
73 }
74 #endif
75 525 }
76
77 template <typename Scalar>
78 297 ShootingProblemTpl<Scalar>::ShootingProblemTpl(
79 const VectorXs& x0,
80 const std::vector<std::shared_ptr<ActionModelAbstract> >& running_models,
81 std::shared_ptr<ActionModelAbstract> terminal_model,
82 const std::vector<std::shared_ptr<ActionDataAbstract> >& running_datas,
83 std::shared_ptr<ActionDataAbstract> terminal_data)
84 297 : cost_(Scalar(0.)),
85 297 T_(running_models.size()),
86 297 x0_(x0),
87 297 terminal_model_(terminal_model),
88 297 terminal_data_(terminal_data),
89
1/2
✓ Branch 1 taken 297 times.
✗ Branch 2 not taken.
297 running_models_(running_models),
90
1/2
✓ Branch 1 taken 297 times.
✗ Branch 2 not taken.
297 running_datas_(running_datas),
91 297 nx_(running_models[0]->get_state()->get_nx()),
92 297 ndx_(running_models[0]->get_state()->get_ndx()),
93 297 nu_max_(running_models[0]->get_nu()),
94 297 nthreads_(1) {
95
2/2
✓ Branch 0 taken 5643 times.
✓ Branch 1 taken 297 times.
5940 for (std::size_t i = 1; i < T_; ++i) {
96 5643 const std::shared_ptr<ActionModelAbstract>& model = running_models_[i];
97 5643 const std::size_t nu = model->get_nu();
98
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 5643 times.
5643 if (nu_max_ < nu) {
99 nu_max_ = nu;
100 }
101 }
102
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 297 times.
297 if (static_cast<std::size_t>(x0.size()) != nx_) {
103 throw_pretty(
104 "Invalid argument: " << "x0 has wrong dimension (it should be " +
105 std::to_string(nx_) + ")");
106 }
107 297 const std::size_t Td = running_datas.size();
108
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 297 times.
297 if (Td != T_) {
109 throw_pretty(
110 "Invalid argument: "
111 << "the number of running models and datas are not the same (" +
112 std::to_string(T_) + " != " + std::to_string(Td) + ")")
113 }
114
2/2
✓ Branch 0 taken 5940 times.
✓ Branch 1 taken 297 times.
6237 for (std::size_t i = 0; i < T_; ++i) {
115 5940 const std::shared_ptr<ActionModelAbstract>& model = running_models_[i];
116 5940 const std::shared_ptr<ActionDataAbstract>& data = running_datas_[i];
117
1/2
✗ Branch 4 not taken.
✓ Branch 5 taken 5940 times.
5940 if (model->get_state()->get_nx() != nx_) {
118 throw_pretty("Invalid argument: "
119 << "nx in " << i
120 << " node is not consistent with the other nodes")
121 }
122
1/2
✗ Branch 4 not taken.
✓ Branch 5 taken 5940 times.
5940 if (model->get_state()->get_ndx() != ndx_) {
123 throw_pretty("Invalid argument: "
124 << "ndx in " << i
125 << " node is not consistent with the other nodes")
126 }
127
2/4
✓ Branch 2 taken 5940 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 5940 times.
5940 if (!model->checkData(data)) {
128 throw_pretty("Invalid argument: "
129 << "action data in " << i
130 << " node is not consistent with the action model")
131 }
132 }
133
2/4
✓ Branch 2 taken 297 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 297 times.
297 if (!terminal_model->checkData(terminal_data)) {
134 throw_pretty("Invalid argument: "
135 << "terminal action data is not consistent with the terminal "
136 "action model")
137 }
138
139 #ifdef CROCODDYL_WITH_MULTITHREADING
140 if (enableMultithreading()) {
141 nthreads_ = CROCODDYL_WITH_NTHREADS;
142 }
143 #endif
144 297 }
145
146 template <typename Scalar>
147 2 ShootingProblemTpl<Scalar>::ShootingProblemTpl(
148 const ShootingProblemTpl<Scalar>& problem)
149 2 : cost_(Scalar(0.)),
150 2 T_(problem.get_T()),
151 2 x0_(problem.get_x0()),
152 2 terminal_model_(problem.get_terminalModel()),
153 2 terminal_data_(problem.get_terminalData()),
154
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 running_models_(problem.get_runningModels()),
155
1/2
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 running_datas_(problem.get_runningDatas()),
156 2 nx_(problem.get_nx()),
157 2 ndx_(problem.get_ndx()),
158 2 nu_max_(problem.get_nu_max()) {}
159
160 template <typename Scalar>
161 824 ShootingProblemTpl<Scalar>::~ShootingProblemTpl() {}
162
163 template <typename Scalar>
164 619 Scalar ShootingProblemTpl<Scalar>::calc(const std::vector<VectorXs>& xs,
165 const std::vector<VectorXs>& us) {
166
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 619 times.
619 if (xs.size() != T_ + 1) {
167 throw_pretty(
168 "Invalid argument: " << "xs has wrong dimension (it should be " +
169 std::to_string(T_ + 1) + ")");
170 }
171
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 619 times.
619 if (us.size() != T_) {
172 throw_pretty(
173 "Invalid argument: " << "us has wrong dimension (it should be " +
174 std::to_string(T_) + ")");
175 }
176
3/16
✗ Branch 2 not taken.
✓ Branch 3 taken 619 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 619 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 619 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
619 START_PROFILER("ShootingProblem::calc");
177
178 #ifdef CROCODDYL_WITH_MULTITHREADING
179 #pragma omp parallel for num_threads(nthreads_)
180 #endif
181
2/2
✓ Branch 0 taken 10615 times.
✓ Branch 1 taken 619 times.
11234 for (std::size_t i = 0; i < T_; ++i) {
182
2/4
✓ Branch 6 taken 10615 times.
✗ Branch 7 not taken.
✓ Branch 10 taken 10615 times.
✗ Branch 11 not taken.
10615 running_models_[i]->calc(running_datas_[i], xs[i], us[i]);
183 }
184
1/2
✓ Branch 4 taken 619 times.
✗ Branch 5 not taken.
619 terminal_model_->calc(terminal_data_, xs.back());
185
186 619 cost_ = Scalar(0.);
187 #ifdef CROCODDYL_WITH_MULTITHREADING
188 #pragma omp simd reduction(+ : cost_)
189 #endif
190
2/2
✓ Branch 0 taken 10615 times.
✓ Branch 1 taken 619 times.
11234 for (std::size_t i = 0; i < T_; ++i) {
191 10615 cost_ += running_datas_[i]->cost;
192 }
193 619 cost_ += terminal_data_->cost;
194
3/16
✗ Branch 2 not taken.
✓ Branch 3 taken 619 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 619 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 619 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
619 STOP_PROFILER("ShootingProblem::calc");
195 619 return cost_;
196 }
197
198 template <typename Scalar>
199 441 Scalar ShootingProblemTpl<Scalar>::calcDiff(const std::vector<VectorXs>& xs,
200 const std::vector<VectorXs>& us) {
201
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 441 times.
441 if (xs.size() != T_ + 1) {
202 throw_pretty(
203 "Invalid argument: " << "xs has wrong dimension (it should be " +
204 std::to_string(T_ + 1) + ")");
205 }
206
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 441 times.
441 if (us.size() != T_) {
207 throw_pretty(
208 "Invalid argument: " << "us has wrong dimension (it should be " +
209 std::to_string(T_) + ")");
210 }
211
3/16
✗ Branch 2 not taken.
✓ Branch 3 taken 441 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 441 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 441 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
441 START_PROFILER("ShootingProblem::calcDiff");
212
213 #ifdef CROCODDYL_WITH_MULTITHREADING
214 #pragma omp parallel for num_threads(nthreads_)
215 #endif
216
2/2
✓ Branch 0 taken 7064 times.
✓ Branch 1 taken 441 times.
7505 for (std::size_t i = 0; i < T_; ++i) {
217
2/4
✓ Branch 6 taken 7064 times.
✗ Branch 7 not taken.
✓ Branch 10 taken 7064 times.
✗ Branch 11 not taken.
7064 running_models_[i]->calcDiff(running_datas_[i], xs[i], us[i]);
218 }
219
1/2
✓ Branch 4 taken 441 times.
✗ Branch 5 not taken.
441 terminal_model_->calcDiff(terminal_data_, xs.back());
220
221 441 cost_ = Scalar(0.);
222 #ifdef CROCODDYL_WITH_MULTITHREADING
223 #pragma omp simd reduction(+ : cost_)
224 #endif
225
2/2
✓ Branch 0 taken 7064 times.
✓ Branch 1 taken 441 times.
7505 for (std::size_t i = 0; i < T_; ++i) {
226 7064 cost_ += running_datas_[i]->cost;
227 }
228 441 cost_ += terminal_data_->cost;
229
230
3/16
✗ Branch 2 not taken.
✓ Branch 3 taken 441 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 441 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 441 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
441 STOP_PROFILER("ShootingProblem::calcDiff");
231 441 return cost_;
232 }
233
234 template <typename Scalar>
235 103 void ShootingProblemTpl<Scalar>::rollout(const std::vector<VectorXs>& us,
236 std::vector<VectorXs>& xs) {
237
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 103 times.
103 if (xs.size() != T_ + 1) {
238 throw_pretty(
239 "Invalid argument: " << "xs has wrong dimension (it should be " +
240 std::to_string(T_ + 1) + ")");
241 }
242
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 103 times.
103 if (us.size() != T_) {
243 throw_pretty(
244 "Invalid argument: " << "us has wrong dimension (it should be " +
245 std::to_string(T_) + ")");
246 }
247
3/16
✗ Branch 2 not taken.
✓ Branch 3 taken 103 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 103 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 103 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
103 START_PROFILER("ShootingProblem::rollout");
248
249 103 xs[0] = x0_;
250
2/2
✓ Branch 0 taken 2336 times.
✓ Branch 1 taken 103 times.
2439 for (std::size_t i = 0; i < T_; ++i) {
251 2336 const std::shared_ptr<ActionDataAbstract>& data = running_datas_[i];
252
2/4
✓ Branch 6 taken 2336 times.
✗ Branch 7 not taken.
✓ Branch 9 taken 2336 times.
✗ Branch 10 not taken.
2336 running_models_[i]->calc(data, xs[i], us[i]);
253 2336 xs[i + 1] = data->xnext;
254 }
255
1/2
✓ Branch 4 taken 103 times.
✗ Branch 5 not taken.
103 terminal_model_->calc(terminal_data_, xs.back());
256
3/16
✗ Branch 2 not taken.
✓ Branch 3 taken 103 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 103 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 103 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
103 STOP_PROFILER("ShootingProblem::rollout");
257 103 }
258
259 template <typename Scalar>
260 std::vector<typename MathBaseTpl<Scalar>::VectorXs>
261 4 ShootingProblemTpl<Scalar>::rollout_us(const std::vector<VectorXs>& us) {
262 4 std::vector<VectorXs> xs;
263
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 xs.resize(T_ + 1);
264
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 rollout(us, xs);
265 4 return xs;
266 }
267
268 template <typename Scalar>
269 198 void ShootingProblemTpl<Scalar>::quasiStatic(std::vector<VectorXs>& us,
270 const std::vector<VectorXs>& xs) {
271
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 198 times.
198 if (xs.size() != T_) {
272 throw_pretty(
273 "Invalid argument: " << "xs has wrong dimension (it should be " +
274 std::to_string(T_) + ")");
275 }
276
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 198 times.
198 if (us.size() != T_) {
277 throw_pretty(
278 "Invalid argument: " << "us has wrong dimension (it should be " +
279 std::to_string(T_) + ")");
280 }
281
282 #ifdef CROCODDYL_WITH_MULTITHREADING
283 #pragma omp parallel for num_threads(nthreads_)
284 #endif
285
2/2
✓ Branch 0 taken 3960 times.
✓ Branch 1 taken 198 times.
4158 for (std::size_t i = 0; i < T_; ++i) {
286
2/4
✓ Branch 6 taken 3960 times.
✗ Branch 7 not taken.
✓ Branch 10 taken 3960 times.
✗ Branch 11 not taken.
3960 running_models_[i]->quasiStatic(running_datas_[i], us[i], xs[i]);
287 }
288 198 }
289
290 template <typename Scalar>
291 std::vector<typename MathBaseTpl<Scalar>::VectorXs>
292 ShootingProblemTpl<Scalar>::quasiStatic_xs(const std::vector<VectorXs>& xs) {
293 std::vector<VectorXs> us;
294 us.resize(T_);
295 for (std::size_t i = 0; i < T_; ++i) {
296 us[i] = VectorXs::Zero(running_models_[i]->get_nu());
297 }
298 quasiStatic(us, xs);
299 return us;
300 }
301
302 template <typename Scalar>
303 void ShootingProblemTpl<Scalar>::circularAppend(
304 std::shared_ptr<ActionModelAbstract> model,
305 std::shared_ptr<ActionDataAbstract> data) {
306 if (!model->checkData(data)) {
307 throw_pretty("Invalid argument: "
308 << "action data is not consistent with the action model")
309 }
310 if (model->get_state()->get_nx() != nx_) {
311 throw_pretty(
312 "Invalid argument: " << "nx is not consistent with the other nodes")
313 }
314 if (model->get_state()->get_ndx() != ndx_) {
315 throw_pretty("Invalid argument: "
316 << "ndx node is not consistent with the other nodes")
317 }
318 is_updated_ = true;
319 for (std::size_t i = 0; i < T_ - 1; ++i) {
320 running_models_[i] = running_models_[i + 1];
321 running_datas_[i] = running_datas_[i + 1];
322 }
323 running_models_.back() = model;
324 running_datas_.back() = data;
325 }
326
327 template <typename Scalar>
328 void ShootingProblemTpl<Scalar>::circularAppend(
329 std::shared_ptr<ActionModelAbstract> model) {
330 if (model->get_state()->get_nx() != nx_) {
331 throw_pretty(
332 "Invalid argument: " << "nx is not consistent with the other nodes")
333 }
334 if (model->get_state()->get_ndx() != ndx_) {
335 throw_pretty("Invalid argument: "
336 << "ndx node is not consistent with the other nodes")
337 }
338 is_updated_ = true;
339 for (std::size_t i = 0; i < T_ - 1; ++i) {
340 running_models_[i] = running_models_[i + 1];
341 running_datas_[i] = running_datas_[i + 1];
342 }
343 running_models_.back() = model;
344 running_datas_.back() = model->createData();
345 }
346
347 template <typename Scalar>
348 void ShootingProblemTpl<Scalar>::updateNode(
349 const std::size_t i, std::shared_ptr<ActionModelAbstract> model,
350 std::shared_ptr<ActionDataAbstract> data) {
351 if (i >= T_ + 1) {
352 throw_pretty("Invalid argument: "
353 << "i is bigger than the allocated horizon (it should be less "
354 "than or equal to " +
355 std::to_string(T_ + 1) + ")");
356 }
357 if (!model->checkData(data)) {
358 throw_pretty("Invalid argument: "
359 << "action data is not consistent with the action model")
360 }
361 if (model->get_state()->get_nx() != nx_) {
362 throw_pretty(
363 "Invalid argument: " << "nx is not consistent with the other nodes")
364 }
365 if (model->get_state()->get_ndx() != ndx_) {
366 throw_pretty("Invalid argument: "
367 << "ndx node is not consistent with the other nodes")
368 }
369 is_updated_ = true;
370 if (i == T_) {
371 terminal_model_ = model;
372 terminal_data_ = data;
373 } else {
374 running_models_[i] = model;
375 running_datas_[i] = data;
376 }
377 }
378
379 template <typename Scalar>
380 void ShootingProblemTpl<Scalar>::updateModel(
381 const std::size_t i, std::shared_ptr<ActionModelAbstract> model) {
382 if (i >= T_ + 1) {
383 throw_pretty(
384 "Invalid argument: "
385 << "i is bigger than the allocated horizon (it should be lower than " +
386 std::to_string(T_ + 1) + ")");
387 }
388 if (model->get_state()->get_nx() != nx_) {
389 throw_pretty(
390 "Invalid argument: " << "nx is not consistent with the other nodes")
391 }
392 if (model->get_state()->get_ndx() != ndx_) {
393 throw_pretty(
394 "Invalid argument: " << "ndx is not consistent with the other nodes")
395 }
396 is_updated_ = true;
397 if (i == T_) {
398 terminal_model_ = model;
399 terminal_data_ = terminal_model_->createData();
400 } else {
401 running_models_[i] = model;
402 running_datas_[i] = model->createData();
403 }
404 }
405
406 template <typename Scalar>
407 6078 std::size_t ShootingProblemTpl<Scalar>::get_T() const {
408 6078 return T_;
409 }
410
411 template <typename Scalar>
412 const typename MathBaseTpl<Scalar>::VectorXs&
413 538 ShootingProblemTpl<Scalar>::get_x0() const {
414 538 return x0_;
415 }
416
417 template <typename Scalar>
418 525 void ShootingProblemTpl<Scalar>::allocateData() {
419 525 running_datas_.resize(T_);
420
2/2
✓ Branch 0 taken 10032 times.
✓ Branch 1 taken 525 times.
10557 for (std::size_t i = 0; i < T_; ++i) {
421 10032 const std::shared_ptr<ActionModelAbstract>& model = running_models_[i];
422 10032 running_datas_[i] = model->createData();
423 }
424 525 terminal_data_ = terminal_model_->createData();
425 525 }
426
427 template <typename Scalar>
428 const std::vector<std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> > >&
429 5147 ShootingProblemTpl<Scalar>::get_runningModels() const {
430 5147 return running_models_;
431 }
432
433 template <typename Scalar>
434 const std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> >&
435 1536 ShootingProblemTpl<Scalar>::get_terminalModel() const {
436 1536 return terminal_model_;
437 }
438
439 template <typename Scalar>
440 const std::vector<std::shared_ptr<crocoddyl::ActionDataAbstractTpl<Scalar> > >&
441 37135 ShootingProblemTpl<Scalar>::get_runningDatas() const {
442 37135 return running_datas_;
443 }
444
445 template <typename Scalar>
446 const std::shared_ptr<crocoddyl::ActionDataAbstractTpl<Scalar> >&
447 1578 ShootingProblemTpl<Scalar>::get_terminalData() const {
448 1578 return terminal_data_;
449 }
450
451 template <typename Scalar>
452 void ShootingProblemTpl<Scalar>::set_x0(const VectorXs& x0_in) {
453 if (x0_in.size() != x0_.size()) {
454 throw_pretty("Invalid argument: "
455 << "invalid size of x0 provided: Expected " << x0_.size()
456 << ", received " << x0_in.size());
457 }
458 x0_ = x0_in;
459 }
460
461 template <typename Scalar>
462 void ShootingProblemTpl<Scalar>::set_runningModels(
463 const std::vector<std::shared_ptr<ActionModelAbstract> >& models) {
464 for (std::size_t i = 0; i < T_; ++i) {
465 const std::shared_ptr<ActionModelAbstract>& model = running_models_[i];
466 if (model->get_state()->get_nx() != nx_) {
467 throw_pretty("Invalid argument: "
468 << "nx in " << i
469 << " node is not consistent with the other nodes")
470 }
471 if (model->get_state()->get_ndx() != ndx_) {
472 throw_pretty("Invalid argument: "
473 << "ndx in " << i
474 << " node is not consistent with the other nodes")
475 }
476 }
477 is_updated_ = true;
478 T_ = models.size();
479 running_models_.clear();
480 running_datas_.clear();
481 for (std::size_t i = 0; i < T_; ++i) {
482 const std::shared_ptr<ActionModelAbstract>& model = running_models_[i];
483 running_datas_.push_back(model->createData());
484 }
485 }
486
487 template <typename Scalar>
488 void ShootingProblemTpl<Scalar>::set_terminalModel(
489 std::shared_ptr<ActionModelAbstract> model) {
490 if (model->get_state()->get_nx() != nx_) {
491 throw_pretty(
492 "Invalid argument: " << "nx is not consistent with the other nodes")
493 }
494 if (model->get_state()->get_ndx() != ndx_) {
495 throw_pretty(
496 "Invalid argument: " << "ndx is not consistent with the other nodes")
497 }
498 is_updated_ = true;
499 terminal_model_ = model;
500 terminal_data_ = terminal_model_->createData();
501 }
502
503 template <typename Scalar>
504 void ShootingProblemTpl<Scalar>::set_nthreads(const int nthreads) {
505 #ifndef CROCODDYL_WITH_MULTITHREADING
506 (void)nthreads;
507 std::cerr << "Warning: the number of threads won't affect the computational "
508 "performance as multithreading "
509 "support is not enabled."
510 << std::endl;
511 #else
512 if (nthreads < 1) {
513 nthreads_ = CROCODDYL_WITH_NTHREADS;
514 } else {
515 nthreads_ = static_cast<std::size_t>(nthreads);
516 }
517 if (!enableMultithreading()) {
518 std::cerr << "Warning: the number of threads won't affect the "
519 "computational performance as multithreading "
520 "support is not enabled."
521 << std::endl;
522 nthreads_ = 1;
523 }
524 #endif
525 }
526
527 template <typename Scalar>
528 49 std::size_t ShootingProblemTpl<Scalar>::get_nx() const {
529 49 return nx_;
530 }
531
532 template <typename Scalar>
533 230 std::size_t ShootingProblemTpl<Scalar>::get_ndx() const {
534 230 return ndx_;
535 }
536
537 template <typename Scalar>
538 2 std::size_t ShootingProblemTpl<Scalar>::get_nu_max() const {
539 2 return nu_max_;
540 }
541
542 template <typename Scalar>
543 std::size_t ShootingProblemTpl<Scalar>::get_nthreads() const {
544 #ifndef CROCODDYL_WITH_MULTITHREADING
545 std::cerr << "Warning: the number of threads won't affect the computational "
546 "performance as multithreading "
547 "support is not enabled."
548 << std::endl;
549 #endif
550 return nthreads_;
551 }
552
553 template <typename Scalar>
554 32 bool ShootingProblemTpl<Scalar>::is_updated() {
555 32 const bool status = is_updated_;
556 32 is_updated_ = false;
557 32 return status;
558 }
559
560 template <typename Scalar>
561 7 std::ostream& operator<<(std::ostream& os,
562 const ShootingProblemTpl<Scalar>& problem) {
563 7 os << "ShootingProblem (T=" << problem.get_T() << ", nx=" << problem.get_nx()
564 7 << ", ndx=" << problem.get_ndx() << ") " << std::endl
565 7 << " Models:" << std::endl;
566 const std::vector<
567 std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> > >&
568 7 runningModels = problem.get_runningModels();
569
2/2
✓ Branch 1 taken 140 times.
✓ Branch 2 taken 7 times.
147 for (std::size_t t = 0; t < problem.get_T(); ++t) {
570 140 os << " " << t << ": " << *runningModels[t] << std::endl;
571 }
572 7 os << " " << problem.get_T() << ": " << *problem.get_terminalModel();
573 7 return os;
574 }
575
576 } // namespace crocoddyl
577