Line |
Branch |
Exec |
Source |
1 |
|
|
/////////////////////////////////////////////////////////////////////////////// |
2 |
|
|
// BSD 3-Clause License |
3 |
|
|
// |
4 |
|
|
// Copyright (C) 2019-2025, LAAS-CNRS, University of Edinburgh, |
5 |
|
|
// University of Oxford, Heriot-Watt University |
6 |
|
|
// Copyright note valid unless otherwise stated in individual files. |
7 |
|
|
// All rights reserved. |
8 |
|
|
/////////////////////////////////////////////////////////////////////////////// |
9 |
|
|
|
10 |
|
|
#include "crocoddyl/core/utils/stop-watch.hpp" |
11 |
|
|
|
12 |
|
|
namespace crocoddyl { |
13 |
|
|
|
14 |
|
|
template <typename Scalar> |
15 |
|
✗ |
ShootingProblemTpl<Scalar>::ShootingProblemTpl( |
16 |
|
|
const VectorXs& x0, |
17 |
|
|
const std::vector<std::shared_ptr<ActionModelAbstract> >& running_models, |
18 |
|
|
std::shared_ptr<ActionModelAbstract> terminal_model) |
19 |
|
✗ |
: cost_(Scalar(0.)), |
20 |
|
✗ |
T_(running_models.size()), |
21 |
|
✗ |
x0_(x0), |
22 |
|
✗ |
terminal_model_(terminal_model), |
23 |
|
✗ |
running_models_(running_models), |
24 |
|
✗ |
nx_(running_models[0]->get_state()->get_nx()), |
25 |
|
✗ |
ndx_(running_models[0]->get_state()->get_ndx()), |
26 |
|
✗ |
nthreads_(1), |
27 |
|
✗ |
is_updated_(false) { |
28 |
|
✗ |
if (static_cast<std::size_t>(x0.size()) != nx_) { |
29 |
|
✗ |
throw_pretty( |
30 |
|
|
"Invalid argument: " << "x0 has wrong dimension (it should be " + |
31 |
|
|
std::to_string(nx_) + ")"); |
32 |
|
|
} |
33 |
|
✗ |
for (std::size_t i = 1; i < T_; ++i) { |
34 |
|
✗ |
const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; |
35 |
|
✗ |
if (model->get_state()->get_nx() != nx_) { |
36 |
|
✗ |
throw_pretty("Invalid argument: " |
37 |
|
|
<< "nx in " << i |
38 |
|
|
<< " node is not consistent with the other nodes") |
39 |
|
|
} |
40 |
|
✗ |
if (model->get_state()->get_ndx() != ndx_) { |
41 |
|
✗ |
throw_pretty("Invalid argument: " |
42 |
|
|
<< "ndx in " << i |
43 |
|
|
<< " node is not consistent with the other nodes") |
44 |
|
|
} |
45 |
|
|
} |
46 |
|
✗ |
if (terminal_model_->get_state()->get_nx() != nx_) { |
47 |
|
✗ |
throw_pretty( |
48 |
|
|
"Invalid argument: " |
49 |
|
|
<< "nx in terminal node is not consistent with the other nodes") |
50 |
|
|
} |
51 |
|
✗ |
if (terminal_model_->get_state()->get_ndx() != ndx_) { |
52 |
|
✗ |
throw_pretty( |
53 |
|
|
"Invalid argument: " |
54 |
|
|
<< "ndx in terminal node is not consistent with the other nodes") |
55 |
|
|
} |
56 |
|
✗ |
allocateData(); |
57 |
|
|
|
58 |
|
|
#ifdef CROCODDYL_WITH_MULTITHREADING |
59 |
|
|
if (enableMultithreading()) { |
60 |
|
|
nthreads_ = CROCODDYL_WITH_NTHREADS; |
61 |
|
|
} |
62 |
|
|
#endif |
63 |
|
|
} |
64 |
|
|
|
65 |
|
|
template <typename Scalar> |
66 |
|
✗ |
ShootingProblemTpl<Scalar>::ShootingProblemTpl( |
67 |
|
|
const VectorXs& x0, |
68 |
|
|
const std::vector<std::shared_ptr<ActionModelAbstract> >& running_models, |
69 |
|
|
std::shared_ptr<ActionModelAbstract> terminal_model, |
70 |
|
|
const std::vector<std::shared_ptr<ActionDataAbstract> >& running_datas, |
71 |
|
|
std::shared_ptr<ActionDataAbstract> terminal_data) |
72 |
|
✗ |
: cost_(Scalar(0.)), |
73 |
|
✗ |
T_(running_models.size()), |
74 |
|
✗ |
x0_(x0), |
75 |
|
✗ |
terminal_model_(terminal_model), |
76 |
|
✗ |
terminal_data_(terminal_data), |
77 |
|
✗ |
running_models_(running_models), |
78 |
|
✗ |
running_datas_(running_datas), |
79 |
|
✗ |
nx_(running_models[0]->get_state()->get_nx()), |
80 |
|
✗ |
ndx_(running_models[0]->get_state()->get_ndx()), |
81 |
|
✗ |
nthreads_(1) { |
82 |
|
✗ |
if (static_cast<std::size_t>(x0.size()) != nx_) { |
83 |
|
✗ |
throw_pretty( |
84 |
|
|
"Invalid argument: " << "x0 has wrong dimension (it should be " + |
85 |
|
|
std::to_string(nx_) + ")"); |
86 |
|
|
} |
87 |
|
✗ |
const std::size_t Td = running_datas.size(); |
88 |
|
✗ |
if (Td != T_) { |
89 |
|
✗ |
throw_pretty( |
90 |
|
|
"Invalid argument: " |
91 |
|
|
<< "the number of running models and datas are not the same (" + |
92 |
|
|
std::to_string(T_) + " != " + std::to_string(Td) + ")") |
93 |
|
|
} |
94 |
|
✗ |
for (std::size_t i = 0; i < T_; ++i) { |
95 |
|
✗ |
const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; |
96 |
|
✗ |
const std::shared_ptr<ActionDataAbstract>& data = running_datas_[i]; |
97 |
|
✗ |
if (model->get_state()->get_nx() != nx_) { |
98 |
|
✗ |
throw_pretty("Invalid argument: " |
99 |
|
|
<< "nx in " << i |
100 |
|
|
<< " node is not consistent with the other nodes") |
101 |
|
|
} |
102 |
|
✗ |
if (model->get_state()->get_ndx() != ndx_) { |
103 |
|
✗ |
throw_pretty("Invalid argument: " |
104 |
|
|
<< "ndx in " << i |
105 |
|
|
<< " node is not consistent with the other nodes") |
106 |
|
|
} |
107 |
|
✗ |
if (!model->checkData(data)) { |
108 |
|
✗ |
throw_pretty("Invalid argument: " |
109 |
|
|
<< "action data in " << i |
110 |
|
|
<< " node is not consistent with the action model") |
111 |
|
|
} |
112 |
|
|
} |
113 |
|
✗ |
if (!terminal_model->checkData(terminal_data)) { |
114 |
|
✗ |
throw_pretty("Invalid argument: " |
115 |
|
|
<< "terminal action data is not consistent with the terminal " |
116 |
|
|
"action model") |
117 |
|
|
} |
118 |
|
|
|
119 |
|
|
#ifdef CROCODDYL_WITH_MULTITHREADING |
120 |
|
|
if (enableMultithreading()) { |
121 |
|
|
nthreads_ = CROCODDYL_WITH_NTHREADS; |
122 |
|
|
} |
123 |
|
|
#endif |
124 |
|
|
} |
125 |
|
|
|
126 |
|
|
template <typename Scalar> |
127 |
|
✗ |
ShootingProblemTpl<Scalar>::ShootingProblemTpl( |
128 |
|
|
const ShootingProblemTpl<Scalar>& problem) |
129 |
|
✗ |
: cost_(Scalar(0.)), |
130 |
|
✗ |
T_(problem.get_T()), |
131 |
|
✗ |
x0_(problem.get_x0()), |
132 |
|
✗ |
terminal_model_(problem.get_terminalModel()), |
133 |
|
✗ |
terminal_data_(problem.get_terminalData()), |
134 |
|
✗ |
running_models_(problem.get_runningModels()), |
135 |
|
✗ |
running_datas_(problem.get_runningDatas()), |
136 |
|
✗ |
nx_(problem.get_nx()), |
137 |
|
✗ |
ndx_(problem.get_ndx()) {} |
138 |
|
|
|
139 |
|
|
template <typename Scalar> |
140 |
|
✗ |
ShootingProblemTpl<Scalar>::~ShootingProblemTpl() {} |
141 |
|
|
|
142 |
|
|
template <typename Scalar> |
143 |
|
✗ |
Scalar ShootingProblemTpl<Scalar>::calc(const std::vector<VectorXs>& xs, |
144 |
|
|
const std::vector<VectorXs>& us) { |
145 |
|
✗ |
if (xs.size() != T_ + 1) { |
146 |
|
✗ |
throw_pretty( |
147 |
|
|
"Invalid argument: " << "xs has wrong dimension (it should be " + |
148 |
|
|
std::to_string(T_ + 1) + ")"); |
149 |
|
|
} |
150 |
|
✗ |
if (us.size() != T_) { |
151 |
|
✗ |
throw_pretty( |
152 |
|
|
"Invalid argument: " << "us has wrong dimension (it should be " + |
153 |
|
|
std::to_string(T_) + ")"); |
154 |
|
|
} |
155 |
|
✗ |
START_PROFILER("ShootingProblem::calc"); |
156 |
|
|
|
157 |
|
|
#ifdef CROCODDYL_WITH_MULTITHREADING |
158 |
|
|
#pragma omp parallel for num_threads(nthreads_) |
159 |
|
|
#endif |
160 |
|
✗ |
for (std::size_t i = 0; i < T_; ++i) { |
161 |
|
✗ |
running_models_[i]->calc(running_datas_[i], xs[i], us[i]); |
162 |
|
|
} |
163 |
|
✗ |
terminal_model_->calc(terminal_data_, xs.back()); |
164 |
|
|
|
165 |
|
✗ |
cost_ = Scalar(0.); |
166 |
|
|
#ifdef CROCODDYL_WITH_MULTITHREADING |
167 |
|
|
#pragma omp simd reduction(+ : cost_) |
168 |
|
|
#endif |
169 |
|
✗ |
for (std::size_t i = 0; i < T_; ++i) { |
170 |
|
✗ |
cost_ += running_datas_[i]->cost; |
171 |
|
|
} |
172 |
|
✗ |
cost_ += terminal_data_->cost; |
173 |
|
✗ |
STOP_PROFILER("ShootingProblem::calc"); |
174 |
|
✗ |
return cost_; |
175 |
|
|
} |
176 |
|
|
|
177 |
|
|
template <typename Scalar> |
178 |
|
✗ |
Scalar ShootingProblemTpl<Scalar>::calcDiff(const std::vector<VectorXs>& xs, |
179 |
|
|
const std::vector<VectorXs>& us) { |
180 |
|
✗ |
if (xs.size() != T_ + 1) { |
181 |
|
✗ |
throw_pretty( |
182 |
|
|
"Invalid argument: " << "xs has wrong dimension (it should be " + |
183 |
|
|
std::to_string(T_ + 1) + ")"); |
184 |
|
|
} |
185 |
|
✗ |
if (us.size() != T_) { |
186 |
|
✗ |
throw_pretty( |
187 |
|
|
"Invalid argument: " << "us has wrong dimension (it should be " + |
188 |
|
|
std::to_string(T_) + ")"); |
189 |
|
|
} |
190 |
|
✗ |
START_PROFILER("ShootingProblem::calcDiff"); |
191 |
|
|
|
192 |
|
|
#ifdef CROCODDYL_WITH_MULTITHREADING |
193 |
|
|
#pragma omp parallel for num_threads(nthreads_) |
194 |
|
|
#endif |
195 |
|
✗ |
for (std::size_t i = 0; i < T_; ++i) { |
196 |
|
✗ |
running_models_[i]->calcDiff(running_datas_[i], xs[i], us[i]); |
197 |
|
|
} |
198 |
|
✗ |
terminal_model_->calcDiff(terminal_data_, xs.back()); |
199 |
|
|
|
200 |
|
✗ |
cost_ = Scalar(0.); |
201 |
|
|
// Apply SIMD only for floating-point types |
202 |
|
|
if (std::is_floating_point<Scalar>::value) { |
203 |
|
|
#ifdef CROCODDYL_WITH_MULTITHREADING |
204 |
|
|
#pragma omp simd reduction(+ : cost_) |
205 |
|
|
#endif |
206 |
|
✗ |
for (std::size_t i = 0; i < T_; ++i) { |
207 |
|
✗ |
cost_ += running_datas_[i]->cost; |
208 |
|
|
} |
209 |
|
✗ |
cost_ += terminal_data_->cost; |
210 |
|
|
} else { // For non-floating-point types (e.g., CppAD types), use the normal |
211 |
|
|
// loop without SIMD |
212 |
|
|
for (std::size_t i = 0; i < T_; ++i) { |
213 |
|
|
cost_ += running_datas_[i]->cost; |
214 |
|
|
} |
215 |
|
|
cost_ += terminal_data_->cost; |
216 |
|
|
} |
217 |
|
✗ |
STOP_PROFILER("ShootingProblem::calcDiff"); |
218 |
|
✗ |
return cost_; |
219 |
|
|
} |
220 |
|
|
|
221 |
|
|
template <typename Scalar> |
222 |
|
✗ |
void ShootingProblemTpl<Scalar>::rollout(const std::vector<VectorXs>& us, |
223 |
|
|
std::vector<VectorXs>& xs) { |
224 |
|
✗ |
if (xs.size() != T_ + 1) { |
225 |
|
✗ |
throw_pretty( |
226 |
|
|
"Invalid argument: " << "xs has wrong dimension (it should be " + |
227 |
|
|
std::to_string(T_ + 1) + ")"); |
228 |
|
|
} |
229 |
|
✗ |
if (us.size() != T_) { |
230 |
|
✗ |
throw_pretty( |
231 |
|
|
"Invalid argument: " << "us has wrong dimension (it should be " + |
232 |
|
|
std::to_string(T_) + ")"); |
233 |
|
|
} |
234 |
|
✗ |
START_PROFILER("ShootingProblem::rollout"); |
235 |
|
|
|
236 |
|
✗ |
xs[0] = x0_; |
237 |
|
✗ |
for (std::size_t i = 0; i < T_; ++i) { |
238 |
|
✗ |
const std::shared_ptr<ActionDataAbstract>& data = running_datas_[i]; |
239 |
|
✗ |
running_models_[i]->calc(data, xs[i], us[i]); |
240 |
|
✗ |
xs[i + 1] = data->xnext; |
241 |
|
|
} |
242 |
|
✗ |
terminal_model_->calc(terminal_data_, xs.back()); |
243 |
|
✗ |
STOP_PROFILER("ShootingProblem::rollout"); |
244 |
|
|
} |
245 |
|
|
|
246 |
|
|
template <typename Scalar> |
247 |
|
|
std::vector<typename MathBaseTpl<Scalar>::VectorXs> |
248 |
|
✗ |
ShootingProblemTpl<Scalar>::rollout_us(const std::vector<VectorXs>& us) { |
249 |
|
✗ |
std::vector<VectorXs> xs; |
250 |
|
✗ |
xs.resize(T_ + 1); |
251 |
|
✗ |
rollout(us, xs); |
252 |
|
✗ |
return xs; |
253 |
|
|
} |
254 |
|
|
|
255 |
|
|
template <typename Scalar> |
256 |
|
✗ |
void ShootingProblemTpl<Scalar>::quasiStatic(std::vector<VectorXs>& us, |
257 |
|
|
const std::vector<VectorXs>& xs) { |
258 |
|
✗ |
if (xs.size() != T_) { |
259 |
|
✗ |
throw_pretty( |
260 |
|
|
"Invalid argument: " << "xs has wrong dimension (it should be " + |
261 |
|
|
std::to_string(T_) + ")"); |
262 |
|
|
} |
263 |
|
✗ |
if (us.size() != T_) { |
264 |
|
✗ |
throw_pretty( |
265 |
|
|
"Invalid argument: " << "us has wrong dimension (it should be " + |
266 |
|
|
std::to_string(T_) + ")"); |
267 |
|
|
} |
268 |
|
|
|
269 |
|
|
#ifdef CROCODDYL_WITH_MULTITHREADING |
270 |
|
|
#pragma omp parallel for num_threads(nthreads_) |
271 |
|
|
#endif |
272 |
|
✗ |
for (std::size_t i = 0; i < T_; ++i) { |
273 |
|
✗ |
running_models_[i]->quasiStatic(running_datas_[i], us[i], xs[i]); |
274 |
|
|
} |
275 |
|
|
} |
276 |
|
|
|
277 |
|
|
template <typename Scalar> |
278 |
|
|
std::vector<typename MathBaseTpl<Scalar>::VectorXs> |
279 |
|
✗ |
ShootingProblemTpl<Scalar>::quasiStatic_xs(const std::vector<VectorXs>& xs) { |
280 |
|
✗ |
std::vector<VectorXs> us; |
281 |
|
✗ |
us.resize(T_); |
282 |
|
✗ |
for (std::size_t i = 0; i < T_; ++i) { |
283 |
|
✗ |
us[i] = VectorXs::Zero(running_models_[i]->get_nu()); |
284 |
|
|
} |
285 |
|
✗ |
quasiStatic(us, xs); |
286 |
|
✗ |
return us; |
287 |
|
|
} |
288 |
|
|
|
289 |
|
|
template <typename Scalar> |
290 |
|
✗ |
void ShootingProblemTpl<Scalar>::circularAppend( |
291 |
|
|
std::shared_ptr<ActionModelAbstract> model, |
292 |
|
|
std::shared_ptr<ActionDataAbstract> data) { |
293 |
|
✗ |
if (!model->checkData(data)) { |
294 |
|
✗ |
throw_pretty("Invalid argument: " |
295 |
|
|
<< "action data is not consistent with the action model") |
296 |
|
|
} |
297 |
|
✗ |
if (model->get_state()->get_nx() != nx_) { |
298 |
|
✗ |
throw_pretty( |
299 |
|
|
"Invalid argument: " << "nx is not consistent with the other nodes") |
300 |
|
|
} |
301 |
|
✗ |
if (model->get_state()->get_ndx() != ndx_) { |
302 |
|
✗ |
throw_pretty("Invalid argument: " |
303 |
|
|
<< "ndx node is not consistent with the other nodes") |
304 |
|
|
} |
305 |
|
✗ |
is_updated_ = true; |
306 |
|
✗ |
for (std::size_t i = 0; i < T_ - 1; ++i) { |
307 |
|
✗ |
running_models_[i] = running_models_[i + 1]; |
308 |
|
✗ |
running_datas_[i] = running_datas_[i + 1]; |
309 |
|
|
} |
310 |
|
✗ |
running_models_.back() = model; |
311 |
|
✗ |
running_datas_.back() = data; |
312 |
|
|
} |
313 |
|
|
|
314 |
|
|
template <typename Scalar> |
315 |
|
✗ |
void ShootingProblemTpl<Scalar>::circularAppend( |
316 |
|
|
std::shared_ptr<ActionModelAbstract> model) { |
317 |
|
✗ |
if (model->get_state()->get_nx() != nx_) { |
318 |
|
✗ |
throw_pretty( |
319 |
|
|
"Invalid argument: " << "nx is not consistent with the other nodes") |
320 |
|
|
} |
321 |
|
✗ |
if (model->get_state()->get_ndx() != ndx_) { |
322 |
|
✗ |
throw_pretty("Invalid argument: " |
323 |
|
|
<< "ndx node is not consistent with the other nodes") |
324 |
|
|
} |
325 |
|
✗ |
is_updated_ = true; |
326 |
|
✗ |
for (std::size_t i = 0; i < T_ - 1; ++i) { |
327 |
|
✗ |
running_models_[i] = running_models_[i + 1]; |
328 |
|
✗ |
running_datas_[i] = running_datas_[i + 1]; |
329 |
|
|
} |
330 |
|
✗ |
running_models_.back() = model; |
331 |
|
✗ |
running_datas_.back() = model->createData(); |
332 |
|
|
} |
333 |
|
|
|
334 |
|
|
template <typename Scalar> |
335 |
|
✗ |
void ShootingProblemTpl<Scalar>::updateNode( |
336 |
|
|
const std::size_t i, std::shared_ptr<ActionModelAbstract> model, |
337 |
|
|
std::shared_ptr<ActionDataAbstract> data) { |
338 |
|
✗ |
if (i >= T_ + 1) { |
339 |
|
✗ |
throw_pretty("Invalid argument: " |
340 |
|
|
<< "i is bigger than the allocated horizon (it should be less " |
341 |
|
|
"than or equal to " + |
342 |
|
|
std::to_string(T_ + 1) + ")"); |
343 |
|
|
} |
344 |
|
✗ |
if (!model->checkData(data)) { |
345 |
|
✗ |
throw_pretty("Invalid argument: " |
346 |
|
|
<< "action data is not consistent with the action model") |
347 |
|
|
} |
348 |
|
✗ |
if (model->get_state()->get_nx() != nx_) { |
349 |
|
✗ |
throw_pretty( |
350 |
|
|
"Invalid argument: " << "nx is not consistent with the other nodes") |
351 |
|
|
} |
352 |
|
✗ |
if (model->get_state()->get_ndx() != ndx_) { |
353 |
|
✗ |
throw_pretty("Invalid argument: " |
354 |
|
|
<< "ndx node is not consistent with the other nodes") |
355 |
|
|
} |
356 |
|
✗ |
is_updated_ = true; |
357 |
|
✗ |
if (i == T_) { |
358 |
|
✗ |
terminal_model_ = model; |
359 |
|
✗ |
terminal_data_ = data; |
360 |
|
|
} else { |
361 |
|
✗ |
running_models_[i] = model; |
362 |
|
✗ |
running_datas_[i] = data; |
363 |
|
|
} |
364 |
|
|
} |
365 |
|
|
|
366 |
|
|
template <typename Scalar> |
367 |
|
✗ |
void ShootingProblemTpl<Scalar>::updateModel( |
368 |
|
|
const std::size_t i, std::shared_ptr<ActionModelAbstract> model) { |
369 |
|
✗ |
if (i >= T_ + 1) { |
370 |
|
✗ |
throw_pretty( |
371 |
|
|
"Invalid argument: " |
372 |
|
|
<< "i is bigger than the allocated horizon (it should be lower than " + |
373 |
|
|
std::to_string(T_ + 1) + ")"); |
374 |
|
|
} |
375 |
|
✗ |
if (model->get_state()->get_nx() != nx_) { |
376 |
|
✗ |
throw_pretty( |
377 |
|
|
"Invalid argument: " << "nx is not consistent with the other nodes") |
378 |
|
|
} |
379 |
|
✗ |
if (model->get_state()->get_ndx() != ndx_) { |
380 |
|
✗ |
throw_pretty( |
381 |
|
|
"Invalid argument: " << "ndx is not consistent with the other nodes") |
382 |
|
|
} |
383 |
|
✗ |
is_updated_ = true; |
384 |
|
✗ |
if (i == T_) { |
385 |
|
✗ |
terminal_model_ = model; |
386 |
|
✗ |
terminal_data_ = terminal_model_->createData(); |
387 |
|
|
} else { |
388 |
|
✗ |
running_models_[i] = model; |
389 |
|
✗ |
running_datas_[i] = model->createData(); |
390 |
|
|
} |
391 |
|
|
} |
392 |
|
|
|
393 |
|
|
template <typename Scalar> |
394 |
|
|
template <typename NewScalar> |
395 |
|
✗ |
ShootingProblemTpl<NewScalar> ShootingProblemTpl<Scalar>::cast() const { |
396 |
|
|
typedef ShootingProblemTpl<NewScalar> ReturnType; |
397 |
|
✗ |
ReturnType ret(x0_.template cast<NewScalar>(), |
398 |
|
✗ |
vector_cast<NewScalar>(running_models_), |
399 |
|
✗ |
terminal_model_->template cast<NewScalar>()); |
400 |
|
✗ |
ret.set_nthreads((int)nthreads_); |
401 |
|
✗ |
return ret; |
402 |
|
|
} |
403 |
|
|
|
404 |
|
|
template <typename Scalar> |
405 |
|
✗ |
std::size_t ShootingProblemTpl<Scalar>::get_T() const { |
406 |
|
✗ |
return T_; |
407 |
|
|
} |
408 |
|
|
|
409 |
|
|
template <typename Scalar> |
410 |
|
|
const typename MathBaseTpl<Scalar>::VectorXs& |
411 |
|
✗ |
ShootingProblemTpl<Scalar>::get_x0() const { |
412 |
|
✗ |
return x0_; |
413 |
|
|
} |
414 |
|
|
|
415 |
|
|
template <typename Scalar> |
416 |
|
✗ |
void ShootingProblemTpl<Scalar>::allocateData() { |
417 |
|
✗ |
running_datas_.resize(T_); |
418 |
|
✗ |
for (std::size_t i = 0; i < T_; ++i) { |
419 |
|
✗ |
const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; |
420 |
|
✗ |
running_datas_[i] = model->createData(); |
421 |
|
|
} |
422 |
|
✗ |
terminal_data_ = terminal_model_->createData(); |
423 |
|
|
} |
424 |
|
|
|
425 |
|
|
template <typename Scalar> |
426 |
|
|
const std::vector<std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> > >& |
427 |
|
✗ |
ShootingProblemTpl<Scalar>::get_runningModels() const { |
428 |
|
✗ |
return running_models_; |
429 |
|
|
} |
430 |
|
|
|
431 |
|
|
template <typename Scalar> |
432 |
|
|
const std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> >& |
433 |
|
✗ |
ShootingProblemTpl<Scalar>::get_terminalModel() const { |
434 |
|
✗ |
return terminal_model_; |
435 |
|
|
} |
436 |
|
|
|
437 |
|
|
template <typename Scalar> |
438 |
|
|
const std::vector<std::shared_ptr<crocoddyl::ActionDataAbstractTpl<Scalar> > >& |
439 |
|
✗ |
ShootingProblemTpl<Scalar>::get_runningDatas() const { |
440 |
|
✗ |
return running_datas_; |
441 |
|
|
} |
442 |
|
|
|
443 |
|
|
template <typename Scalar> |
444 |
|
|
const std::shared_ptr<crocoddyl::ActionDataAbstractTpl<Scalar> >& |
445 |
|
✗ |
ShootingProblemTpl<Scalar>::get_terminalData() const { |
446 |
|
✗ |
return terminal_data_; |
447 |
|
|
} |
448 |
|
|
|
449 |
|
|
template <typename Scalar> |
450 |
|
✗ |
void ShootingProblemTpl<Scalar>::set_x0(const VectorXs& x0_in) { |
451 |
|
✗ |
if (x0_in.size() != x0_.size()) { |
452 |
|
✗ |
throw_pretty("Invalid argument: " |
453 |
|
|
<< "invalid size of x0 provided: Expected " << x0_.size() |
454 |
|
|
<< ", received " << x0_in.size()); |
455 |
|
|
} |
456 |
|
✗ |
x0_ = x0_in; |
457 |
|
|
} |
458 |
|
|
|
459 |
|
|
template <typename Scalar> |
460 |
|
✗ |
void ShootingProblemTpl<Scalar>::set_runningModels( |
461 |
|
|
const std::vector<std::shared_ptr<ActionModelAbstract> >& models) { |
462 |
|
✗ |
for (std::size_t i = 0; i < T_; ++i) { |
463 |
|
✗ |
const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; |
464 |
|
✗ |
if (model->get_state()->get_nx() != nx_) { |
465 |
|
✗ |
throw_pretty("Invalid argument: " |
466 |
|
|
<< "nx in " << i |
467 |
|
|
<< " node is not consistent with the other nodes") |
468 |
|
|
} |
469 |
|
✗ |
if (model->get_state()->get_ndx() != ndx_) { |
470 |
|
✗ |
throw_pretty("Invalid argument: " |
471 |
|
|
<< "ndx in " << i |
472 |
|
|
<< " node is not consistent with the other nodes") |
473 |
|
|
} |
474 |
|
|
} |
475 |
|
✗ |
is_updated_ = true; |
476 |
|
✗ |
T_ = models.size(); |
477 |
|
✗ |
running_models_.clear(); |
478 |
|
✗ |
running_datas_.clear(); |
479 |
|
✗ |
for (std::size_t i = 0; i < T_; ++i) { |
480 |
|
✗ |
const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; |
481 |
|
✗ |
running_datas_.push_back(model->createData()); |
482 |
|
|
} |
483 |
|
|
} |
484 |
|
|
|
485 |
|
|
template <typename Scalar> |
486 |
|
✗ |
void ShootingProblemTpl<Scalar>::set_terminalModel( |
487 |
|
|
std::shared_ptr<ActionModelAbstract> model) { |
488 |
|
✗ |
if (model->get_state()->get_nx() != nx_) { |
489 |
|
✗ |
throw_pretty( |
490 |
|
|
"Invalid argument: " << "nx is not consistent with the other nodes") |
491 |
|
|
} |
492 |
|
✗ |
if (model->get_state()->get_ndx() != ndx_) { |
493 |
|
✗ |
throw_pretty( |
494 |
|
|
"Invalid argument: " << "ndx is not consistent with the other nodes") |
495 |
|
|
} |
496 |
|
✗ |
is_updated_ = true; |
497 |
|
✗ |
terminal_model_ = model; |
498 |
|
✗ |
terminal_data_ = terminal_model_->createData(); |
499 |
|
|
} |
500 |
|
|
|
501 |
|
|
template <typename Scalar> |
502 |
|
✗ |
void ShootingProblemTpl<Scalar>::set_nthreads(const int nthreads) { |
503 |
|
|
#ifndef CROCODDYL_WITH_MULTITHREADING |
504 |
|
|
(void)nthreads; |
505 |
|
✗ |
std::cerr << "Warning: the number of threads won't affect the computational " |
506 |
|
|
"performance as multithreading " |
507 |
|
|
"support is not enabled." |
508 |
|
✗ |
<< std::endl; |
509 |
|
|
#else |
510 |
|
|
if (nthreads < 1) { |
511 |
|
|
nthreads_ = CROCODDYL_WITH_NTHREADS; |
512 |
|
|
} else { |
513 |
|
|
nthreads_ = static_cast<std::size_t>(nthreads); |
514 |
|
|
} |
515 |
|
|
if (!enableMultithreading()) { |
516 |
|
|
std::cerr << "Warning: the number of threads won't affect the " |
517 |
|
|
"computational performance as multithreading " |
518 |
|
|
"support is not enabled." |
519 |
|
|
<< std::endl; |
520 |
|
|
nthreads_ = 1; |
521 |
|
|
} |
522 |
|
|
#endif |
523 |
|
|
} |
524 |
|
|
|
525 |
|
|
template <typename Scalar> |
526 |
|
✗ |
std::size_t ShootingProblemTpl<Scalar>::get_nx() const { |
527 |
|
✗ |
return nx_; |
528 |
|
|
} |
529 |
|
|
|
530 |
|
|
template <typename Scalar> |
531 |
|
✗ |
std::size_t ShootingProblemTpl<Scalar>::get_ndx() const { |
532 |
|
✗ |
return ndx_; |
533 |
|
|
} |
534 |
|
|
|
535 |
|
|
template <typename Scalar> |
536 |
|
✗ |
std::size_t ShootingProblemTpl<Scalar>::get_nthreads() const { |
537 |
|
|
#ifndef CROCODDYL_WITH_MULTITHREADING |
538 |
|
✗ |
std::cerr << "Warning: the number of threads won't affect the computational " |
539 |
|
|
"performance as multithreading " |
540 |
|
|
"support is not enabled." |
541 |
|
✗ |
<< std::endl; |
542 |
|
|
#endif |
543 |
|
✗ |
return nthreads_; |
544 |
|
|
} |
545 |
|
|
|
546 |
|
|
template <typename Scalar> |
547 |
|
✗ |
bool ShootingProblemTpl<Scalar>::is_updated() { |
548 |
|
✗ |
const bool status = is_updated_; |
549 |
|
✗ |
is_updated_ = false; |
550 |
|
✗ |
return status; |
551 |
|
|
} |
552 |
|
|
|
553 |
|
|
template <typename Scalar> |
554 |
|
✗ |
std::ostream& operator<<(std::ostream& os, |
555 |
|
|
const ShootingProblemTpl<Scalar>& problem) { |
556 |
|
✗ |
os << "ShootingProblem (T=" << problem.get_T() << ", nx=" << problem.get_nx() |
557 |
|
✗ |
<< ", ndx=" << problem.get_ndx() << ") " << std::endl |
558 |
|
✗ |
<< " Models:" << std::endl; |
559 |
|
|
const std::vector< |
560 |
|
|
std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> > >& |
561 |
|
✗ |
runningModels = problem.get_runningModels(); |
562 |
|
✗ |
for (std::size_t t = 0; t < problem.get_T(); ++t) { |
563 |
|
✗ |
os << " " << t << ": " << *runningModels[t] << std::endl; |
564 |
|
|
} |
565 |
|
✗ |
os << " " << problem.get_T() << ": " << *problem.get_terminalModel(); |
566 |
|
✗ |
return os; |
567 |
|
|
} |
568 |
|
|
|
569 |
|
|
} // namespace crocoddyl |
570 |
|
|
|