GCC Code Coverage Report


Directory: ./
File: include/crocoddyl/core/optctrl/shooting.hxx
Date: 2025-03-26 19:23:43
Exec Total Coverage
Lines: 137 254 53.9%
Branches: 104 796 13.1%

Line Branch Exec Source
1 ///////////////////////////////////////////////////////////////////////////////
2 // BSD 3-Clause License
3 //
4 // Copyright (C) 2019-2025, LAAS-CNRS, University of Edinburgh,
5 // University of Oxford, Heriot-Watt University
6 // Copyright note valid unless otherwise stated in individual files.
7 // All rights reserved.
8 ///////////////////////////////////////////////////////////////////////////////
9
10 #include <iostream>
11 #ifdef CROCODDYL_WITH_MULTITHREADING
12 #include <omp.h>
13 #endif // CROCODDYL_WITH_MULTITHREADING
14 #include "crocoddyl/core/utils/stop-watch.hpp"
15
16 namespace crocoddyl {
17
18 template <typename Scalar>
19 525 ShootingProblemTpl<Scalar>::ShootingProblemTpl(
20 const VectorXs& x0,
21 const std::vector<std::shared_ptr<ActionModelAbstract> >& running_models,
22 std::shared_ptr<ActionModelAbstract> terminal_model)
23 525 : cost_(Scalar(0.)),
24 525 T_(running_models.size()),
25 525 x0_(x0),
26 525 terminal_model_(terminal_model),
27
1/2
✓ Branch 1 taken 525 times.
✗ Branch 2 not taken.
525 running_models_(running_models),
28
2/4
✓ Branch 3 taken 525 times.
✗ Branch 4 not taken.
✓ Branch 7 taken 525 times.
✗ Branch 8 not taken.
525 nx_(running_models[0]->get_state()->get_nx()),
29
2/4
✓ Branch 3 taken 525 times.
✗ Branch 4 not taken.
✓ Branch 7 taken 525 times.
✗ Branch 8 not taken.
525 ndx_(running_models[0]->get_state()->get_ndx()),
30 525 nthreads_(1),
31 1050 is_updated_(false) {
32
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 525 times.
525 if (static_cast<std::size_t>(x0.size()) != nx_) {
33 throw_pretty(
34 "Invalid argument: " << "x0 has wrong dimension (it should be " +
35 std::to_string(nx_) + ")");
36 }
37
2/2
✓ Branch 0 taken 9135 times.
✓ Branch 1 taken 525 times.
9660 for (std::size_t i = 1; i < T_; ++i) {
38 9135 const std::shared_ptr<ActionModelAbstract>& model = running_models_[i];
39
3/6
✓ Branch 2 taken 9135 times.
✗ Branch 3 not taken.
✓ Branch 6 taken 9135 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 9135 times.
9135 if (model->get_state()->get_nx() != nx_) {
40 throw_pretty("Invalid argument: "
41 << "nx in " << i
42 << " node is not consistent with the other nodes")
43 }
44
3/6
✓ Branch 2 taken 9135 times.
✗ Branch 3 not taken.
✓ Branch 6 taken 9135 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 9135 times.
9135 if (model->get_state()->get_ndx() != ndx_) {
45 throw_pretty("Invalid argument: "
46 << "ndx in " << i
47 << " node is not consistent with the other nodes")
48 }
49 }
50
3/6
✓ Branch 2 taken 525 times.
✗ Branch 3 not taken.
✓ Branch 6 taken 525 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 525 times.
525 if (terminal_model_->get_state()->get_nx() != nx_) {
51 throw_pretty(
52 "Invalid argument: "
53 << "nx in terminal node is not consistent with the other nodes")
54 }
55
3/6
✓ Branch 2 taken 525 times.
✗ Branch 3 not taken.
✓ Branch 6 taken 525 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 525 times.
525 if (terminal_model_->get_state()->get_ndx() != ndx_) {
56 throw_pretty(
57 "Invalid argument: "
58 << "ndx in terminal node is not consistent with the other nodes")
59 }
60
1/2
✓ Branch 1 taken 525 times.
✗ Branch 2 not taken.
525 allocateData();
61
62 #ifdef CROCODDYL_WITH_MULTITHREADING
63 if (enableMultithreading()) {
64 nthreads_ = CROCODDYL_WITH_NTHREADS;
65 }
66 #endif
67 525 }
68
69 template <typename Scalar>
70 297 ShootingProblemTpl<Scalar>::ShootingProblemTpl(
71 const VectorXs& x0,
72 const std::vector<std::shared_ptr<ActionModelAbstract> >& running_models,
73 std::shared_ptr<ActionModelAbstract> terminal_model,
74 const std::vector<std::shared_ptr<ActionDataAbstract> >& running_datas,
75 std::shared_ptr<ActionDataAbstract> terminal_data)
76 297 : cost_(Scalar(0.)),
77 297 T_(running_models.size()),
78 297 x0_(x0),
79 297 terminal_model_(terminal_model),
80 297 terminal_data_(terminal_data),
81
1/2
✓ Branch 1 taken 297 times.
✗ Branch 2 not taken.
297 running_models_(running_models),
82
1/2
✓ Branch 1 taken 297 times.
✗ Branch 2 not taken.
297 running_datas_(running_datas),
83
2/4
✓ Branch 3 taken 297 times.
✗ Branch 4 not taken.
✓ Branch 7 taken 297 times.
✗ Branch 8 not taken.
297 nx_(running_models[0]->get_state()->get_nx()),
84
2/4
✓ Branch 3 taken 297 times.
✗ Branch 4 not taken.
✓ Branch 7 taken 297 times.
✗ Branch 8 not taken.
297 ndx_(running_models[0]->get_state()->get_ndx()),
85 297 nthreads_(1) {
86
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 297 times.
297 if (static_cast<std::size_t>(x0.size()) != nx_) {
87 throw_pretty(
88 "Invalid argument: " << "x0 has wrong dimension (it should be " +
89 std::to_string(nx_) + ")");
90 }
91 297 const std::size_t Td = running_datas.size();
92
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 297 times.
297 if (Td != T_) {
93 throw_pretty(
94 "Invalid argument: "
95 << "the number of running models and datas are not the same (" +
96 std::to_string(T_) + " != " + std::to_string(Td) + ")")
97 }
98
2/2
✓ Branch 0 taken 5940 times.
✓ Branch 1 taken 297 times.
6237 for (std::size_t i = 0; i < T_; ++i) {
99 5940 const std::shared_ptr<ActionModelAbstract>& model = running_models_[i];
100 5940 const std::shared_ptr<ActionDataAbstract>& data = running_datas_[i];
101
3/6
✓ Branch 2 taken 5940 times.
✗ Branch 3 not taken.
✓ Branch 6 taken 5940 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 5940 times.
5940 if (model->get_state()->get_nx() != nx_) {
102 throw_pretty("Invalid argument: "
103 << "nx in " << i
104 << " node is not consistent with the other nodes")
105 }
106
3/6
✓ Branch 2 taken 5940 times.
✗ Branch 3 not taken.
✓ Branch 6 taken 5940 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 5940 times.
5940 if (model->get_state()->get_ndx() != ndx_) {
107 throw_pretty("Invalid argument: "
108 << "ndx in " << i
109 << " node is not consistent with the other nodes")
110 }
111
2/4
✓ Branch 2 taken 5940 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 5940 times.
5940 if (!model->checkData(data)) {
112 throw_pretty("Invalid argument: "
113 << "action data in " << i
114 << " node is not consistent with the action model")
115 }
116 }
117
2/4
✓ Branch 2 taken 297 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 297 times.
297 if (!terminal_model->checkData(terminal_data)) {
118 throw_pretty("Invalid argument: "
119 << "terminal action data is not consistent with the terminal "
120 "action model")
121 }
122
123 #ifdef CROCODDYL_WITH_MULTITHREADING
124 if (enableMultithreading()) {
125 nthreads_ = CROCODDYL_WITH_NTHREADS;
126 }
127 #endif
128 297 }
129
130 template <typename Scalar>
131 2 ShootingProblemTpl<Scalar>::ShootingProblemTpl(
132 const ShootingProblemTpl<Scalar>& problem)
133 2 : cost_(Scalar(0.)),
134 2 T_(problem.get_T()),
135 2 x0_(problem.get_x0()),
136
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 terminal_model_(problem.get_terminalModel()),
137
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 terminal_data_(problem.get_terminalData()),
138
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
2 running_models_(problem.get_runningModels()),
139
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
2 running_datas_(problem.get_runningDatas()),
140
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 nx_(problem.get_nx()),
141
1/2
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
2 ndx_(problem.get_ndx()) {}
142
143 template <typename Scalar>
144 824 ShootingProblemTpl<Scalar>::~ShootingProblemTpl() {}
145
146 template <typename Scalar>
147 619 Scalar ShootingProblemTpl<Scalar>::calc(const std::vector<VectorXs>& xs,
148 const std::vector<VectorXs>& us) {
149
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 619 times.
619 if (xs.size() != T_ + 1) {
150 throw_pretty(
151 "Invalid argument: " << "xs has wrong dimension (it should be " +
152 std::to_string(T_ + 1) + ")");
153 }
154
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 619 times.
619 if (us.size() != T_) {
155 throw_pretty(
156 "Invalid argument: " << "us has wrong dimension (it should be " +
157 std::to_string(T_) + ")");
158 }
159
3/16
✗ Branch 2 not taken.
✓ Branch 3 taken 619 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 619 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 619 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
619 START_PROFILER("ShootingProblem::calc");
160
161 #ifdef CROCODDYL_WITH_MULTITHREADING
162 #pragma omp parallel for num_threads(nthreads_)
163 #endif
164
2/2
✓ Branch 0 taken 10325 times.
✓ Branch 1 taken 619 times.
10944 for (std::size_t i = 0; i < T_; ++i) {
165
2/4
✓ Branch 6 taken 10325 times.
✗ Branch 7 not taken.
✓ Branch 10 taken 10325 times.
✗ Branch 11 not taken.
10325 running_models_[i]->calc(running_datas_[i], xs[i], us[i]);
166 }
167
1/2
✓ Branch 4 taken 619 times.
✗ Branch 5 not taken.
619 terminal_model_->calc(terminal_data_, xs.back());
168
169 619 cost_ = Scalar(0.);
170 #ifdef CROCODDYL_WITH_MULTITHREADING
171 #pragma omp simd reduction(+ : cost_)
172 #endif
173
2/2
✓ Branch 0 taken 10325 times.
✓ Branch 1 taken 619 times.
10944 for (std::size_t i = 0; i < T_; ++i) {
174 10325 cost_ += running_datas_[i]->cost;
175 }
176 619 cost_ += terminal_data_->cost;
177
3/16
✗ Branch 2 not taken.
✓ Branch 3 taken 619 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 619 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 619 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
619 STOP_PROFILER("ShootingProblem::calc");
178 619 return cost_;
179 }
180
181 template <typename Scalar>
182 443 Scalar ShootingProblemTpl<Scalar>::calcDiff(const std::vector<VectorXs>& xs,
183 const std::vector<VectorXs>& us) {
184
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 443 times.
443 if (xs.size() != T_ + 1) {
185 throw_pretty(
186 "Invalid argument: " << "xs has wrong dimension (it should be " +
187 std::to_string(T_ + 1) + ")");
188 }
189
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 443 times.
443 if (us.size() != T_) {
190 throw_pretty(
191 "Invalid argument: " << "us has wrong dimension (it should be " +
192 std::to_string(T_) + ")");
193 }
194
3/16
✗ Branch 2 not taken.
✓ Branch 3 taken 443 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 443 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 443 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
443 START_PROFILER("ShootingProblem::calcDiff");
195
196 #ifdef CROCODDYL_WITH_MULTITHREADING
197 #pragma omp parallel for num_threads(nthreads_)
198 #endif
199
2/2
✓ Branch 0 taken 6582 times.
✓ Branch 1 taken 443 times.
7025 for (std::size_t i = 0; i < T_; ++i) {
200
2/4
✓ Branch 6 taken 6582 times.
✗ Branch 7 not taken.
✓ Branch 10 taken 6582 times.
✗ Branch 11 not taken.
6582 running_models_[i]->calcDiff(running_datas_[i], xs[i], us[i]);
201 }
202
1/2
✓ Branch 4 taken 443 times.
✗ Branch 5 not taken.
443 terminal_model_->calcDiff(terminal_data_, xs.back());
203
204 443 cost_ = Scalar(0.);
205 #ifdef CROCODDYL_WITH_MULTITHREADING
206 #pragma omp simd reduction(+ : cost_)
207 #endif
208
2/2
✓ Branch 0 taken 6582 times.
✓ Branch 1 taken 443 times.
7025 for (std::size_t i = 0; i < T_; ++i) {
209 6582 cost_ += running_datas_[i]->cost;
210 }
211 443 cost_ += terminal_data_->cost;
212
213
3/16
✗ Branch 2 not taken.
✓ Branch 3 taken 443 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 443 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 443 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
443 STOP_PROFILER("ShootingProblem::calcDiff");
214 443 return cost_;
215 }
216
217 template <typename Scalar>
218 103 void ShootingProblemTpl<Scalar>::rollout(const std::vector<VectorXs>& us,
219 std::vector<VectorXs>& xs) {
220
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 103 times.
103 if (xs.size() != T_ + 1) {
221 throw_pretty(
222 "Invalid argument: " << "xs has wrong dimension (it should be " +
223 std::to_string(T_ + 1) + ")");
224 }
225
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 103 times.
103 if (us.size() != T_) {
226 throw_pretty(
227 "Invalid argument: " << "us has wrong dimension (it should be " +
228 std::to_string(T_) + ")");
229 }
230
3/16
✗ Branch 2 not taken.
✓ Branch 3 taken 103 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 103 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 103 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
103 START_PROFILER("ShootingProblem::rollout");
231
232 103 xs[0] = x0_;
233
2/2
✓ Branch 0 taken 2196 times.
✓ Branch 1 taken 103 times.
2299 for (std::size_t i = 0; i < T_; ++i) {
234 2196 const std::shared_ptr<ActionDataAbstract>& data = running_datas_[i];
235
2/4
✓ Branch 6 taken 2196 times.
✗ Branch 7 not taken.
✓ Branch 9 taken 2196 times.
✗ Branch 10 not taken.
2196 running_models_[i]->calc(data, xs[i], us[i]);
236 2196 xs[i + 1] = data->xnext;
237 }
238
1/2
✓ Branch 4 taken 103 times.
✗ Branch 5 not taken.
103 terminal_model_->calc(terminal_data_, xs.back());
239
3/16
✗ Branch 2 not taken.
✓ Branch 3 taken 103 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 103 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 103 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
103 STOP_PROFILER("ShootingProblem::rollout");
240 103 }
241
242 template <typename Scalar>
243 std::vector<typename MathBaseTpl<Scalar>::VectorXs>
244 4 ShootingProblemTpl<Scalar>::rollout_us(const std::vector<VectorXs>& us) {
245 4 std::vector<VectorXs> xs;
246
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 xs.resize(T_ + 1);
247
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 rollout(us, xs);
248 4 return xs;
249 }
250
251 template <typename Scalar>
252 198 void ShootingProblemTpl<Scalar>::quasiStatic(std::vector<VectorXs>& us,
253 const std::vector<VectorXs>& xs) {
254
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 198 times.
198 if (xs.size() != T_) {
255 throw_pretty(
256 "Invalid argument: " << "xs has wrong dimension (it should be " +
257 std::to_string(T_) + ")");
258 }
259
1/2
✗ Branch 1 not taken.
✓ Branch 2 taken 198 times.
198 if (us.size() != T_) {
260 throw_pretty(
261 "Invalid argument: " << "us has wrong dimension (it should be " +
262 std::to_string(T_) + ")");
263 }
264
265 #ifdef CROCODDYL_WITH_MULTITHREADING
266 #pragma omp parallel for num_threads(nthreads_)
267 #endif
268
2/2
✓ Branch 0 taken 3960 times.
✓ Branch 1 taken 198 times.
4158 for (std::size_t i = 0; i < T_; ++i) {
269
2/4
✓ Branch 6 taken 3960 times.
✗ Branch 7 not taken.
✓ Branch 10 taken 3960 times.
✗ Branch 11 not taken.
3960 running_models_[i]->quasiStatic(running_datas_[i], us[i], xs[i]);
270 }
271 198 }
272
273 template <typename Scalar>
274 std::vector<typename MathBaseTpl<Scalar>::VectorXs>
275 ShootingProblemTpl<Scalar>::quasiStatic_xs(const std::vector<VectorXs>& xs) {
276 std::vector<VectorXs> us;
277 us.resize(T_);
278 for (std::size_t i = 0; i < T_; ++i) {
279 us[i] = VectorXs::Zero(running_models_[i]->get_nu());
280 }
281 quasiStatic(us, xs);
282 return us;
283 }
284
285 template <typename Scalar>
286 void ShootingProblemTpl<Scalar>::circularAppend(
287 std::shared_ptr<ActionModelAbstract> model,
288 std::shared_ptr<ActionDataAbstract> data) {
289 if (!model->checkData(data)) {
290 throw_pretty("Invalid argument: "
291 << "action data is not consistent with the action model")
292 }
293 if (model->get_state()->get_nx() != nx_) {
294 throw_pretty(
295 "Invalid argument: " << "nx is not consistent with the other nodes")
296 }
297 if (model->get_state()->get_ndx() != ndx_) {
298 throw_pretty("Invalid argument: "
299 << "ndx node is not consistent with the other nodes")
300 }
301 is_updated_ = true;
302 for (std::size_t i = 0; i < T_ - 1; ++i) {
303 running_models_[i] = running_models_[i + 1];
304 running_datas_[i] = running_datas_[i + 1];
305 }
306 running_models_.back() = model;
307 running_datas_.back() = data;
308 }
309
310 template <typename Scalar>
311 void ShootingProblemTpl<Scalar>::circularAppend(
312 std::shared_ptr<ActionModelAbstract> model) {
313 if (model->get_state()->get_nx() != nx_) {
314 throw_pretty(
315 "Invalid argument: " << "nx is not consistent with the other nodes")
316 }
317 if (model->get_state()->get_ndx() != ndx_) {
318 throw_pretty("Invalid argument: "
319 << "ndx node is not consistent with the other nodes")
320 }
321 is_updated_ = true;
322 for (std::size_t i = 0; i < T_ - 1; ++i) {
323 running_models_[i] = running_models_[i + 1];
324 running_datas_[i] = running_datas_[i + 1];
325 }
326 running_models_.back() = model;
327 running_datas_.back() = model->createData();
328 }
329
330 template <typename Scalar>
331 void ShootingProblemTpl<Scalar>::updateNode(
332 const std::size_t i, std::shared_ptr<ActionModelAbstract> model,
333 std::shared_ptr<ActionDataAbstract> data) {
334 if (i >= T_ + 1) {
335 throw_pretty("Invalid argument: "
336 << "i is bigger than the allocated horizon (it should be less "
337 "than or equal to " +
338 std::to_string(T_ + 1) + ")");
339 }
340 if (!model->checkData(data)) {
341 throw_pretty("Invalid argument: "
342 << "action data is not consistent with the action model")
343 }
344 if (model->get_state()->get_nx() != nx_) {
345 throw_pretty(
346 "Invalid argument: " << "nx is not consistent with the other nodes")
347 }
348 if (model->get_state()->get_ndx() != ndx_) {
349 throw_pretty("Invalid argument: "
350 << "ndx node is not consistent with the other nodes")
351 }
352 is_updated_ = true;
353 if (i == T_) {
354 terminal_model_ = model;
355 terminal_data_ = data;
356 } else {
357 running_models_[i] = model;
358 running_datas_[i] = data;
359 }
360 }
361
362 template <typename Scalar>
363 void ShootingProblemTpl<Scalar>::updateModel(
364 const std::size_t i, std::shared_ptr<ActionModelAbstract> model) {
365 if (i >= T_ + 1) {
366 throw_pretty(
367 "Invalid argument: "
368 << "i is bigger than the allocated horizon (it should be lower than " +
369 std::to_string(T_ + 1) + ")");
370 }
371 if (model->get_state()->get_nx() != nx_) {
372 throw_pretty(
373 "Invalid argument: " << "nx is not consistent with the other nodes")
374 }
375 if (model->get_state()->get_ndx() != ndx_) {
376 throw_pretty(
377 "Invalid argument: " << "ndx is not consistent with the other nodes")
378 }
379 is_updated_ = true;
380 if (i == T_) {
381 terminal_model_ = model;
382 terminal_data_ = terminal_model_->createData();
383 } else {
384 running_models_[i] = model;
385 running_datas_[i] = model->createData();
386 }
387 }
388
389 template <typename Scalar>
390 template <typename NewScalar>
391 ShootingProblemTpl<NewScalar> ShootingProblemTpl<Scalar>::cast() const {
392 typedef ShootingProblemTpl<NewScalar> ReturnType;
393 ReturnType ret(x0_.template cast<NewScalar>(),
394 vector_cast<NewScalar>(running_models_),
395 terminal_model_->template cast<NewScalar>());
396 ret.set_nthreads((int)nthreads_);
397 return ret;
398 }
399
400 template <typename Scalar>
401 5652 std::size_t ShootingProblemTpl<Scalar>::get_T() const {
402 5652 return T_;
403 }
404
405 template <typename Scalar>
406 const typename MathBaseTpl<Scalar>::VectorXs&
407 545 ShootingProblemTpl<Scalar>::get_x0() const {
408 545 return x0_;
409 }
410
411 template <typename Scalar>
412 525 void ShootingProblemTpl<Scalar>::allocateData() {
413 525 running_datas_.resize(T_);
414
2/2
✓ Branch 0 taken 9660 times.
✓ Branch 1 taken 525 times.
10185 for (std::size_t i = 0; i < T_; ++i) {
415 9660 const std::shared_ptr<ActionModelAbstract>& model = running_models_[i];
416 9660 running_datas_[i] = model->createData();
417 }
418 525 terminal_data_ = terminal_model_->createData();
419 525 }
420
421 template <typename Scalar>
422 const std::vector<std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> > >&
423 5079 ShootingProblemTpl<Scalar>::get_runningModels() const {
424 5079 return running_models_;
425 }
426
427 template <typename Scalar>
428 const std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> >&
429 1552 ShootingProblemTpl<Scalar>::get_terminalModel() const {
430 1552 return terminal_model_;
431 }
432
433 template <typename Scalar>
434 const std::vector<std::shared_ptr<crocoddyl::ActionDataAbstractTpl<Scalar> > >&
435 37149 ShootingProblemTpl<Scalar>::get_runningDatas() const {
436 37149 return running_datas_;
437 }
438
439 template <typename Scalar>
440 const std::shared_ptr<crocoddyl::ActionDataAbstractTpl<Scalar> >&
441 1585 ShootingProblemTpl<Scalar>::get_terminalData() const {
442 1585 return terminal_data_;
443 }
444
445 template <typename Scalar>
446 void ShootingProblemTpl<Scalar>::set_x0(const VectorXs& x0_in) {
447 if (x0_in.size() != x0_.size()) {
448 throw_pretty("Invalid argument: "
449 << "invalid size of x0 provided: Expected " << x0_.size()
450 << ", received " << x0_in.size());
451 }
452 x0_ = x0_in;
453 }
454
455 template <typename Scalar>
456 void ShootingProblemTpl<Scalar>::set_runningModels(
457 const std::vector<std::shared_ptr<ActionModelAbstract> >& models) {
458 for (std::size_t i = 0; i < T_; ++i) {
459 const std::shared_ptr<ActionModelAbstract>& model = running_models_[i];
460 if (model->get_state()->get_nx() != nx_) {
461 throw_pretty("Invalid argument: "
462 << "nx in " << i
463 << " node is not consistent with the other nodes")
464 }
465 if (model->get_state()->get_ndx() != ndx_) {
466 throw_pretty("Invalid argument: "
467 << "ndx in " << i
468 << " node is not consistent with the other nodes")
469 }
470 }
471 is_updated_ = true;
472 T_ = models.size();
473 running_models_.clear();
474 running_datas_.clear();
475 for (std::size_t i = 0; i < T_; ++i) {
476 const std::shared_ptr<ActionModelAbstract>& model = running_models_[i];
477 running_datas_.push_back(model->createData());
478 }
479 }
480
481 template <typename Scalar>
482 void ShootingProblemTpl<Scalar>::set_terminalModel(
483 std::shared_ptr<ActionModelAbstract> model) {
484 if (model->get_state()->get_nx() != nx_) {
485 throw_pretty(
486 "Invalid argument: " << "nx is not consistent with the other nodes")
487 }
488 if (model->get_state()->get_ndx() != ndx_) {
489 throw_pretty(
490 "Invalid argument: " << "ndx is not consistent with the other nodes")
491 }
492 is_updated_ = true;
493 terminal_model_ = model;
494 terminal_data_ = terminal_model_->createData();
495 }
496
497 template <typename Scalar>
498 void ShootingProblemTpl<Scalar>::set_nthreads(const int nthreads) {
499 #ifndef CROCODDYL_WITH_MULTITHREADING
500 (void)nthreads;
501 std::cerr << "Warning: the number of threads won't affect the computational "
502 "performance as multithreading "
503 "support is not enabled."
504 << std::endl;
505 #else
506 if (nthreads < 1) {
507 nthreads_ = CROCODDYL_WITH_NTHREADS;
508 } else {
509 nthreads_ = static_cast<std::size_t>(nthreads);
510 }
511 if (!enableMultithreading()) {
512 std::cerr << "Warning: the number of threads won't affect the "
513 "computational performance as multithreading "
514 "support is not enabled."
515 << std::endl;
516 nthreads_ = 1;
517 }
518 #endif
519 }
520
521 template <typename Scalar>
522 49 std::size_t ShootingProblemTpl<Scalar>::get_nx() const {
523 49 return nx_;
524 }
525
526 template <typename Scalar>
527 230 std::size_t ShootingProblemTpl<Scalar>::get_ndx() const {
528 230 return ndx_;
529 }
530
531 template <typename Scalar>
532 std::size_t ShootingProblemTpl<Scalar>::get_nthreads() const {
533 #ifndef CROCODDYL_WITH_MULTITHREADING
534 std::cerr << "Warning: the number of threads won't affect the computational "
535 "performance as multithreading "
536 "support is not enabled."
537 << std::endl;
538 #endif
539 return nthreads_;
540 }
541
542 template <typename Scalar>
543 32 bool ShootingProblemTpl<Scalar>::is_updated() {
544 32 const bool status = is_updated_;
545 32 is_updated_ = false;
546 32 return status;
547 }
548
549 template <typename Scalar>
550 7 std::ostream& operator<<(std::ostream& os,
551 const ShootingProblemTpl<Scalar>& problem) {
552 7 os << "ShootingProblem (T=" << problem.get_T() << ", nx=" << problem.get_nx()
553 7 << ", ndx=" << problem.get_ndx() << ") " << std::endl
554 7 << " Models:" << std::endl;
555 const std::vector<
556 std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> > >&
557 7 runningModels = problem.get_runningModels();
558
2/2
✓ Branch 1 taken 140 times.
✓ Branch 2 taken 7 times.
147 for (std::size_t t = 0; t < problem.get_T(); ++t) {
559 140 os << " " << t << ": " << *runningModels[t] << std::endl;
560 }
561 7 os << " " << problem.get_T() << ": " << *problem.get_terminalModel();
562 7 return os;
563 }
564
565 } // namespace crocoddyl
566