Directory: | ./ |
---|---|
File: | include/crocoddyl/core/optctrl/shooting.hxx |
Date: | 2025-03-26 19:23:43 |
Exec | Total | Coverage | |
---|---|---|---|
Lines: | 137 | 254 | 53.9% |
Branches: | 104 | 796 | 13.1% |
Line | Branch | Exec | Source |
---|---|---|---|
1 | /////////////////////////////////////////////////////////////////////////////// | ||
2 | // BSD 3-Clause License | ||
3 | // | ||
4 | // Copyright (C) 2019-2025, LAAS-CNRS, University of Edinburgh, | ||
5 | // University of Oxford, Heriot-Watt University | ||
6 | // Copyright note valid unless otherwise stated in individual files. | ||
7 | // All rights reserved. | ||
8 | /////////////////////////////////////////////////////////////////////////////// | ||
9 | |||
10 | #include <iostream> | ||
11 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
12 | #include <omp.h> | ||
13 | #endif // CROCODDYL_WITH_MULTITHREADING | ||
14 | #include "crocoddyl/core/utils/stop-watch.hpp" | ||
15 | |||
16 | namespace crocoddyl { | ||
17 | |||
18 | template <typename Scalar> | ||
19 | 525 | ShootingProblemTpl<Scalar>::ShootingProblemTpl( | |
20 | const VectorXs& x0, | ||
21 | const std::vector<std::shared_ptr<ActionModelAbstract> >& running_models, | ||
22 | std::shared_ptr<ActionModelAbstract> terminal_model) | ||
23 | 525 | : cost_(Scalar(0.)), | |
24 | 525 | T_(running_models.size()), | |
25 | 525 | x0_(x0), | |
26 | 525 | terminal_model_(terminal_model), | |
27 |
1/2✓ Branch 1 taken 525 times.
✗ Branch 2 not taken.
|
525 | running_models_(running_models), |
28 |
2/4✓ Branch 3 taken 525 times.
✗ Branch 4 not taken.
✓ Branch 7 taken 525 times.
✗ Branch 8 not taken.
|
525 | nx_(running_models[0]->get_state()->get_nx()), |
29 |
2/4✓ Branch 3 taken 525 times.
✗ Branch 4 not taken.
✓ Branch 7 taken 525 times.
✗ Branch 8 not taken.
|
525 | ndx_(running_models[0]->get_state()->get_ndx()), |
30 | 525 | nthreads_(1), | |
31 | 1050 | is_updated_(false) { | |
32 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 525 times.
|
525 | if (static_cast<std::size_t>(x0.size()) != nx_) { |
33 | ✗ | throw_pretty( | |
34 | "Invalid argument: " << "x0 has wrong dimension (it should be " + | ||
35 | std::to_string(nx_) + ")"); | ||
36 | } | ||
37 |
2/2✓ Branch 0 taken 9135 times.
✓ Branch 1 taken 525 times.
|
9660 | for (std::size_t i = 1; i < T_; ++i) { |
38 | 9135 | const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; | |
39 |
3/6✓ Branch 2 taken 9135 times.
✗ Branch 3 not taken.
✓ Branch 6 taken 9135 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 9135 times.
|
9135 | if (model->get_state()->get_nx() != nx_) { |
40 | ✗ | throw_pretty("Invalid argument: " | |
41 | << "nx in " << i | ||
42 | << " node is not consistent with the other nodes") | ||
43 | } | ||
44 |
3/6✓ Branch 2 taken 9135 times.
✗ Branch 3 not taken.
✓ Branch 6 taken 9135 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 9135 times.
|
9135 | if (model->get_state()->get_ndx() != ndx_) { |
45 | ✗ | throw_pretty("Invalid argument: " | |
46 | << "ndx in " << i | ||
47 | << " node is not consistent with the other nodes") | ||
48 | } | ||
49 | } | ||
50 |
3/6✓ Branch 2 taken 525 times.
✗ Branch 3 not taken.
✓ Branch 6 taken 525 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 525 times.
|
525 | if (terminal_model_->get_state()->get_nx() != nx_) { |
51 | ✗ | throw_pretty( | |
52 | "Invalid argument: " | ||
53 | << "nx in terminal node is not consistent with the other nodes") | ||
54 | } | ||
55 |
3/6✓ Branch 2 taken 525 times.
✗ Branch 3 not taken.
✓ Branch 6 taken 525 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 525 times.
|
525 | if (terminal_model_->get_state()->get_ndx() != ndx_) { |
56 | ✗ | throw_pretty( | |
57 | "Invalid argument: " | ||
58 | << "ndx in terminal node is not consistent with the other nodes") | ||
59 | } | ||
60 |
1/2✓ Branch 1 taken 525 times.
✗ Branch 2 not taken.
|
525 | allocateData(); |
61 | |||
62 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
63 | if (enableMultithreading()) { | ||
64 | nthreads_ = CROCODDYL_WITH_NTHREADS; | ||
65 | } | ||
66 | #endif | ||
67 | 525 | } | |
68 | |||
69 | template <typename Scalar> | ||
70 | 297 | ShootingProblemTpl<Scalar>::ShootingProblemTpl( | |
71 | const VectorXs& x0, | ||
72 | const std::vector<std::shared_ptr<ActionModelAbstract> >& running_models, | ||
73 | std::shared_ptr<ActionModelAbstract> terminal_model, | ||
74 | const std::vector<std::shared_ptr<ActionDataAbstract> >& running_datas, | ||
75 | std::shared_ptr<ActionDataAbstract> terminal_data) | ||
76 | 297 | : cost_(Scalar(0.)), | |
77 | 297 | T_(running_models.size()), | |
78 | 297 | x0_(x0), | |
79 | 297 | terminal_model_(terminal_model), | |
80 | 297 | terminal_data_(terminal_data), | |
81 |
1/2✓ Branch 1 taken 297 times.
✗ Branch 2 not taken.
|
297 | running_models_(running_models), |
82 |
1/2✓ Branch 1 taken 297 times.
✗ Branch 2 not taken.
|
297 | running_datas_(running_datas), |
83 |
2/4✓ Branch 3 taken 297 times.
✗ Branch 4 not taken.
✓ Branch 7 taken 297 times.
✗ Branch 8 not taken.
|
297 | nx_(running_models[0]->get_state()->get_nx()), |
84 |
2/4✓ Branch 3 taken 297 times.
✗ Branch 4 not taken.
✓ Branch 7 taken 297 times.
✗ Branch 8 not taken.
|
297 | ndx_(running_models[0]->get_state()->get_ndx()), |
85 | 297 | nthreads_(1) { | |
86 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 297 times.
|
297 | if (static_cast<std::size_t>(x0.size()) != nx_) { |
87 | ✗ | throw_pretty( | |
88 | "Invalid argument: " << "x0 has wrong dimension (it should be " + | ||
89 | std::to_string(nx_) + ")"); | ||
90 | } | ||
91 | 297 | const std::size_t Td = running_datas.size(); | |
92 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 297 times.
|
297 | if (Td != T_) { |
93 | ✗ | throw_pretty( | |
94 | "Invalid argument: " | ||
95 | << "the number of running models and datas are not the same (" + | ||
96 | std::to_string(T_) + " != " + std::to_string(Td) + ")") | ||
97 | } | ||
98 |
2/2✓ Branch 0 taken 5940 times.
✓ Branch 1 taken 297 times.
|
6237 | for (std::size_t i = 0; i < T_; ++i) { |
99 | 5940 | const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; | |
100 | 5940 | const std::shared_ptr<ActionDataAbstract>& data = running_datas_[i]; | |
101 |
3/6✓ Branch 2 taken 5940 times.
✗ Branch 3 not taken.
✓ Branch 6 taken 5940 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 5940 times.
|
5940 | if (model->get_state()->get_nx() != nx_) { |
102 | ✗ | throw_pretty("Invalid argument: " | |
103 | << "nx in " << i | ||
104 | << " node is not consistent with the other nodes") | ||
105 | } | ||
106 |
3/6✓ Branch 2 taken 5940 times.
✗ Branch 3 not taken.
✓ Branch 6 taken 5940 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 5940 times.
|
5940 | if (model->get_state()->get_ndx() != ndx_) { |
107 | ✗ | throw_pretty("Invalid argument: " | |
108 | << "ndx in " << i | ||
109 | << " node is not consistent with the other nodes") | ||
110 | } | ||
111 |
2/4✓ Branch 2 taken 5940 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 5940 times.
|
5940 | if (!model->checkData(data)) { |
112 | ✗ | throw_pretty("Invalid argument: " | |
113 | << "action data in " << i | ||
114 | << " node is not consistent with the action model") | ||
115 | } | ||
116 | } | ||
117 |
2/4✓ Branch 2 taken 297 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 297 times.
|
297 | if (!terminal_model->checkData(terminal_data)) { |
118 | ✗ | throw_pretty("Invalid argument: " | |
119 | << "terminal action data is not consistent with the terminal " | ||
120 | "action model") | ||
121 | } | ||
122 | |||
123 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
124 | if (enableMultithreading()) { | ||
125 | nthreads_ = CROCODDYL_WITH_NTHREADS; | ||
126 | } | ||
127 | #endif | ||
128 | 297 | } | |
129 | |||
130 | template <typename Scalar> | ||
131 | 2 | ShootingProblemTpl<Scalar>::ShootingProblemTpl( | |
132 | const ShootingProblemTpl<Scalar>& problem) | ||
133 | 2 | : cost_(Scalar(0.)), | |
134 | 2 | T_(problem.get_T()), | |
135 | 2 | x0_(problem.get_x0()), | |
136 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | terminal_model_(problem.get_terminalModel()), |
137 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | terminal_data_(problem.get_terminalData()), |
138 |
2/4✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
|
2 | running_models_(problem.get_runningModels()), |
139 |
2/4✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
|
2 | running_datas_(problem.get_runningDatas()), |
140 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | nx_(problem.get_nx()), |
141 |
1/2✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
|
2 | ndx_(problem.get_ndx()) {} |
142 | |||
143 | template <typename Scalar> | ||
144 | 824 | ShootingProblemTpl<Scalar>::~ShootingProblemTpl() {} | |
145 | |||
146 | template <typename Scalar> | ||
147 | 619 | Scalar ShootingProblemTpl<Scalar>::calc(const std::vector<VectorXs>& xs, | |
148 | const std::vector<VectorXs>& us) { | ||
149 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 619 times.
|
619 | if (xs.size() != T_ + 1) { |
150 | ✗ | throw_pretty( | |
151 | "Invalid argument: " << "xs has wrong dimension (it should be " + | ||
152 | std::to_string(T_ + 1) + ")"); | ||
153 | } | ||
154 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 619 times.
|
619 | if (us.size() != T_) { |
155 | ✗ | throw_pretty( | |
156 | "Invalid argument: " << "us has wrong dimension (it should be " + | ||
157 | std::to_string(T_) + ")"); | ||
158 | } | ||
159 |
3/16✗ Branch 2 not taken.
✓ Branch 3 taken 619 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 619 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 619 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
|
619 | START_PROFILER("ShootingProblem::calc"); |
160 | |||
161 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
162 | #pragma omp parallel for num_threads(nthreads_) | ||
163 | #endif | ||
164 |
2/2✓ Branch 0 taken 10325 times.
✓ Branch 1 taken 619 times.
|
10944 | for (std::size_t i = 0; i < T_; ++i) { |
165 |
2/4✓ Branch 6 taken 10325 times.
✗ Branch 7 not taken.
✓ Branch 10 taken 10325 times.
✗ Branch 11 not taken.
|
10325 | running_models_[i]->calc(running_datas_[i], xs[i], us[i]); |
166 | } | ||
167 |
1/2✓ Branch 4 taken 619 times.
✗ Branch 5 not taken.
|
619 | terminal_model_->calc(terminal_data_, xs.back()); |
168 | |||
169 | 619 | cost_ = Scalar(0.); | |
170 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
171 | #pragma omp simd reduction(+ : cost_) | ||
172 | #endif | ||
173 |
2/2✓ Branch 0 taken 10325 times.
✓ Branch 1 taken 619 times.
|
10944 | for (std::size_t i = 0; i < T_; ++i) { |
174 | 10325 | cost_ += running_datas_[i]->cost; | |
175 | } | ||
176 | 619 | cost_ += terminal_data_->cost; | |
177 |
3/16✗ Branch 2 not taken.
✓ Branch 3 taken 619 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 619 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 619 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
|
619 | STOP_PROFILER("ShootingProblem::calc"); |
178 | 619 | return cost_; | |
179 | } | ||
180 | |||
181 | template <typename Scalar> | ||
182 | 443 | Scalar ShootingProblemTpl<Scalar>::calcDiff(const std::vector<VectorXs>& xs, | |
183 | const std::vector<VectorXs>& us) { | ||
184 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 443 times.
|
443 | if (xs.size() != T_ + 1) { |
185 | ✗ | throw_pretty( | |
186 | "Invalid argument: " << "xs has wrong dimension (it should be " + | ||
187 | std::to_string(T_ + 1) + ")"); | ||
188 | } | ||
189 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 443 times.
|
443 | if (us.size() != T_) { |
190 | ✗ | throw_pretty( | |
191 | "Invalid argument: " << "us has wrong dimension (it should be " + | ||
192 | std::to_string(T_) + ")"); | ||
193 | } | ||
194 |
3/16✗ Branch 2 not taken.
✓ Branch 3 taken 443 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 443 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 443 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
|
443 | START_PROFILER("ShootingProblem::calcDiff"); |
195 | |||
196 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
197 | #pragma omp parallel for num_threads(nthreads_) | ||
198 | #endif | ||
199 |
2/2✓ Branch 0 taken 6582 times.
✓ Branch 1 taken 443 times.
|
7025 | for (std::size_t i = 0; i < T_; ++i) { |
200 |
2/4✓ Branch 6 taken 6582 times.
✗ Branch 7 not taken.
✓ Branch 10 taken 6582 times.
✗ Branch 11 not taken.
|
6582 | running_models_[i]->calcDiff(running_datas_[i], xs[i], us[i]); |
201 | } | ||
202 |
1/2✓ Branch 4 taken 443 times.
✗ Branch 5 not taken.
|
443 | terminal_model_->calcDiff(terminal_data_, xs.back()); |
203 | |||
204 | 443 | cost_ = Scalar(0.); | |
205 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
206 | #pragma omp simd reduction(+ : cost_) | ||
207 | #endif | ||
208 |
2/2✓ Branch 0 taken 6582 times.
✓ Branch 1 taken 443 times.
|
7025 | for (std::size_t i = 0; i < T_; ++i) { |
209 | 6582 | cost_ += running_datas_[i]->cost; | |
210 | } | ||
211 | 443 | cost_ += terminal_data_->cost; | |
212 | |||
213 |
3/16✗ Branch 2 not taken.
✓ Branch 3 taken 443 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 443 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 443 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
|
443 | STOP_PROFILER("ShootingProblem::calcDiff"); |
214 | 443 | return cost_; | |
215 | } | ||
216 | |||
217 | template <typename Scalar> | ||
218 | 103 | void ShootingProblemTpl<Scalar>::rollout(const std::vector<VectorXs>& us, | |
219 | std::vector<VectorXs>& xs) { | ||
220 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 103 times.
|
103 | if (xs.size() != T_ + 1) { |
221 | ✗ | throw_pretty( | |
222 | "Invalid argument: " << "xs has wrong dimension (it should be " + | ||
223 | std::to_string(T_ + 1) + ")"); | ||
224 | } | ||
225 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 103 times.
|
103 | if (us.size() != T_) { |
226 | ✗ | throw_pretty( | |
227 | "Invalid argument: " << "us has wrong dimension (it should be " + | ||
228 | std::to_string(T_) + ")"); | ||
229 | } | ||
230 |
3/16✗ Branch 2 not taken.
✓ Branch 3 taken 103 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 103 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 103 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
|
103 | START_PROFILER("ShootingProblem::rollout"); |
231 | |||
232 | 103 | xs[0] = x0_; | |
233 |
2/2✓ Branch 0 taken 2196 times.
✓ Branch 1 taken 103 times.
|
2299 | for (std::size_t i = 0; i < T_; ++i) { |
234 | 2196 | const std::shared_ptr<ActionDataAbstract>& data = running_datas_[i]; | |
235 |
2/4✓ Branch 6 taken 2196 times.
✗ Branch 7 not taken.
✓ Branch 9 taken 2196 times.
✗ Branch 10 not taken.
|
2196 | running_models_[i]->calc(data, xs[i], us[i]); |
236 | 2196 | xs[i + 1] = data->xnext; | |
237 | } | ||
238 |
1/2✓ Branch 4 taken 103 times.
✗ Branch 5 not taken.
|
103 | terminal_model_->calc(terminal_data_, xs.back()); |
239 |
3/16✗ Branch 2 not taken.
✓ Branch 3 taken 103 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 103 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 103 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
|
103 | STOP_PROFILER("ShootingProblem::rollout"); |
240 | 103 | } | |
241 | |||
242 | template <typename Scalar> | ||
243 | std::vector<typename MathBaseTpl<Scalar>::VectorXs> | ||
244 | 4 | ShootingProblemTpl<Scalar>::rollout_us(const std::vector<VectorXs>& us) { | |
245 | 4 | std::vector<VectorXs> xs; | |
246 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | xs.resize(T_ + 1); |
247 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | rollout(us, xs); |
248 | 4 | return xs; | |
249 | } | ||
250 | |||
251 | template <typename Scalar> | ||
252 | 198 | void ShootingProblemTpl<Scalar>::quasiStatic(std::vector<VectorXs>& us, | |
253 | const std::vector<VectorXs>& xs) { | ||
254 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 198 times.
|
198 | if (xs.size() != T_) { |
255 | ✗ | throw_pretty( | |
256 | "Invalid argument: " << "xs has wrong dimension (it should be " + | ||
257 | std::to_string(T_) + ")"); | ||
258 | } | ||
259 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 198 times.
|
198 | if (us.size() != T_) { |
260 | ✗ | throw_pretty( | |
261 | "Invalid argument: " << "us has wrong dimension (it should be " + | ||
262 | std::to_string(T_) + ")"); | ||
263 | } | ||
264 | |||
265 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
266 | #pragma omp parallel for num_threads(nthreads_) | ||
267 | #endif | ||
268 |
2/2✓ Branch 0 taken 3960 times.
✓ Branch 1 taken 198 times.
|
4158 | for (std::size_t i = 0; i < T_; ++i) { |
269 |
2/4✓ Branch 6 taken 3960 times.
✗ Branch 7 not taken.
✓ Branch 10 taken 3960 times.
✗ Branch 11 not taken.
|
3960 | running_models_[i]->quasiStatic(running_datas_[i], us[i], xs[i]); |
270 | } | ||
271 | 198 | } | |
272 | |||
273 | template <typename Scalar> | ||
274 | std::vector<typename MathBaseTpl<Scalar>::VectorXs> | ||
275 | ✗ | ShootingProblemTpl<Scalar>::quasiStatic_xs(const std::vector<VectorXs>& xs) { | |
276 | ✗ | std::vector<VectorXs> us; | |
277 | ✗ | us.resize(T_); | |
278 | ✗ | for (std::size_t i = 0; i < T_; ++i) { | |
279 | ✗ | us[i] = VectorXs::Zero(running_models_[i]->get_nu()); | |
280 | } | ||
281 | ✗ | quasiStatic(us, xs); | |
282 | ✗ | return us; | |
283 | } | ||
284 | |||
285 | template <typename Scalar> | ||
286 | ✗ | void ShootingProblemTpl<Scalar>::circularAppend( | |
287 | std::shared_ptr<ActionModelAbstract> model, | ||
288 | std::shared_ptr<ActionDataAbstract> data) { | ||
289 | ✗ | if (!model->checkData(data)) { | |
290 | ✗ | throw_pretty("Invalid argument: " | |
291 | << "action data is not consistent with the action model") | ||
292 | } | ||
293 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
294 | ✗ | throw_pretty( | |
295 | "Invalid argument: " << "nx is not consistent with the other nodes") | ||
296 | } | ||
297 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
298 | ✗ | throw_pretty("Invalid argument: " | |
299 | << "ndx node is not consistent with the other nodes") | ||
300 | } | ||
301 | ✗ | is_updated_ = true; | |
302 | ✗ | for (std::size_t i = 0; i < T_ - 1; ++i) { | |
303 | ✗ | running_models_[i] = running_models_[i + 1]; | |
304 | ✗ | running_datas_[i] = running_datas_[i + 1]; | |
305 | } | ||
306 | ✗ | running_models_.back() = model; | |
307 | ✗ | running_datas_.back() = data; | |
308 | } | ||
309 | |||
310 | template <typename Scalar> | ||
311 | ✗ | void ShootingProblemTpl<Scalar>::circularAppend( | |
312 | std::shared_ptr<ActionModelAbstract> model) { | ||
313 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
314 | ✗ | throw_pretty( | |
315 | "Invalid argument: " << "nx is not consistent with the other nodes") | ||
316 | } | ||
317 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
318 | ✗ | throw_pretty("Invalid argument: " | |
319 | << "ndx node is not consistent with the other nodes") | ||
320 | } | ||
321 | ✗ | is_updated_ = true; | |
322 | ✗ | for (std::size_t i = 0; i < T_ - 1; ++i) { | |
323 | ✗ | running_models_[i] = running_models_[i + 1]; | |
324 | ✗ | running_datas_[i] = running_datas_[i + 1]; | |
325 | } | ||
326 | ✗ | running_models_.back() = model; | |
327 | ✗ | running_datas_.back() = model->createData(); | |
328 | } | ||
329 | |||
330 | template <typename Scalar> | ||
331 | ✗ | void ShootingProblemTpl<Scalar>::updateNode( | |
332 | const std::size_t i, std::shared_ptr<ActionModelAbstract> model, | ||
333 | std::shared_ptr<ActionDataAbstract> data) { | ||
334 | ✗ | if (i >= T_ + 1) { | |
335 | ✗ | throw_pretty("Invalid argument: " | |
336 | << "i is bigger than the allocated horizon (it should be less " | ||
337 | "than or equal to " + | ||
338 | std::to_string(T_ + 1) + ")"); | ||
339 | } | ||
340 | ✗ | if (!model->checkData(data)) { | |
341 | ✗ | throw_pretty("Invalid argument: " | |
342 | << "action data is not consistent with the action model") | ||
343 | } | ||
344 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
345 | ✗ | throw_pretty( | |
346 | "Invalid argument: " << "nx is not consistent with the other nodes") | ||
347 | } | ||
348 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
349 | ✗ | throw_pretty("Invalid argument: " | |
350 | << "ndx node is not consistent with the other nodes") | ||
351 | } | ||
352 | ✗ | is_updated_ = true; | |
353 | ✗ | if (i == T_) { | |
354 | ✗ | terminal_model_ = model; | |
355 | ✗ | terminal_data_ = data; | |
356 | } else { | ||
357 | ✗ | running_models_[i] = model; | |
358 | ✗ | running_datas_[i] = data; | |
359 | } | ||
360 | } | ||
361 | |||
362 | template <typename Scalar> | ||
363 | ✗ | void ShootingProblemTpl<Scalar>::updateModel( | |
364 | const std::size_t i, std::shared_ptr<ActionModelAbstract> model) { | ||
365 | ✗ | if (i >= T_ + 1) { | |
366 | ✗ | throw_pretty( | |
367 | "Invalid argument: " | ||
368 | << "i is bigger than the allocated horizon (it should be lower than " + | ||
369 | std::to_string(T_ + 1) + ")"); | ||
370 | } | ||
371 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
372 | ✗ | throw_pretty( | |
373 | "Invalid argument: " << "nx is not consistent with the other nodes") | ||
374 | } | ||
375 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
376 | ✗ | throw_pretty( | |
377 | "Invalid argument: " << "ndx is not consistent with the other nodes") | ||
378 | } | ||
379 | ✗ | is_updated_ = true; | |
380 | ✗ | if (i == T_) { | |
381 | ✗ | terminal_model_ = model; | |
382 | ✗ | terminal_data_ = terminal_model_->createData(); | |
383 | } else { | ||
384 | ✗ | running_models_[i] = model; | |
385 | ✗ | running_datas_[i] = model->createData(); | |
386 | } | ||
387 | } | ||
388 | |||
389 | template <typename Scalar> | ||
390 | template <typename NewScalar> | ||
391 | ✗ | ShootingProblemTpl<NewScalar> ShootingProblemTpl<Scalar>::cast() const { | |
392 | typedef ShootingProblemTpl<NewScalar> ReturnType; | ||
393 | ✗ | ReturnType ret(x0_.template cast<NewScalar>(), | |
394 | ✗ | vector_cast<NewScalar>(running_models_), | |
395 | ✗ | terminal_model_->template cast<NewScalar>()); | |
396 | ✗ | ret.set_nthreads((int)nthreads_); | |
397 | ✗ | return ret; | |
398 | } | ||
399 | |||
400 | template <typename Scalar> | ||
401 | 5652 | std::size_t ShootingProblemTpl<Scalar>::get_T() const { | |
402 | 5652 | return T_; | |
403 | } | ||
404 | |||
405 | template <typename Scalar> | ||
406 | const typename MathBaseTpl<Scalar>::VectorXs& | ||
407 | 545 | ShootingProblemTpl<Scalar>::get_x0() const { | |
408 | 545 | return x0_; | |
409 | } | ||
410 | |||
411 | template <typename Scalar> | ||
412 | 525 | void ShootingProblemTpl<Scalar>::allocateData() { | |
413 | 525 | running_datas_.resize(T_); | |
414 |
2/2✓ Branch 0 taken 9660 times.
✓ Branch 1 taken 525 times.
|
10185 | for (std::size_t i = 0; i < T_; ++i) { |
415 | 9660 | const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; | |
416 | 9660 | running_datas_[i] = model->createData(); | |
417 | } | ||
418 | 525 | terminal_data_ = terminal_model_->createData(); | |
419 | 525 | } | |
420 | |||
421 | template <typename Scalar> | ||
422 | const std::vector<std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> > >& | ||
423 | 5079 | ShootingProblemTpl<Scalar>::get_runningModels() const { | |
424 | 5079 | return running_models_; | |
425 | } | ||
426 | |||
427 | template <typename Scalar> | ||
428 | const std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> >& | ||
429 | 1552 | ShootingProblemTpl<Scalar>::get_terminalModel() const { | |
430 | 1552 | return terminal_model_; | |
431 | } | ||
432 | |||
433 | template <typename Scalar> | ||
434 | const std::vector<std::shared_ptr<crocoddyl::ActionDataAbstractTpl<Scalar> > >& | ||
435 | 37149 | ShootingProblemTpl<Scalar>::get_runningDatas() const { | |
436 | 37149 | return running_datas_; | |
437 | } | ||
438 | |||
439 | template <typename Scalar> | ||
440 | const std::shared_ptr<crocoddyl::ActionDataAbstractTpl<Scalar> >& | ||
441 | 1585 | ShootingProblemTpl<Scalar>::get_terminalData() const { | |
442 | 1585 | return terminal_data_; | |
443 | } | ||
444 | |||
445 | template <typename Scalar> | ||
446 | ✗ | void ShootingProblemTpl<Scalar>::set_x0(const VectorXs& x0_in) { | |
447 | ✗ | if (x0_in.size() != x0_.size()) { | |
448 | ✗ | throw_pretty("Invalid argument: " | |
449 | << "invalid size of x0 provided: Expected " << x0_.size() | ||
450 | << ", received " << x0_in.size()); | ||
451 | } | ||
452 | ✗ | x0_ = x0_in; | |
453 | } | ||
454 | |||
455 | template <typename Scalar> | ||
456 | ✗ | void ShootingProblemTpl<Scalar>::set_runningModels( | |
457 | const std::vector<std::shared_ptr<ActionModelAbstract> >& models) { | ||
458 | ✗ | for (std::size_t i = 0; i < T_; ++i) { | |
459 | ✗ | const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; | |
460 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
461 | ✗ | throw_pretty("Invalid argument: " | |
462 | << "nx in " << i | ||
463 | << " node is not consistent with the other nodes") | ||
464 | } | ||
465 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
466 | ✗ | throw_pretty("Invalid argument: " | |
467 | << "ndx in " << i | ||
468 | << " node is not consistent with the other nodes") | ||
469 | } | ||
470 | } | ||
471 | ✗ | is_updated_ = true; | |
472 | ✗ | T_ = models.size(); | |
473 | ✗ | running_models_.clear(); | |
474 | ✗ | running_datas_.clear(); | |
475 | ✗ | for (std::size_t i = 0; i < T_; ++i) { | |
476 | ✗ | const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; | |
477 | ✗ | running_datas_.push_back(model->createData()); | |
478 | } | ||
479 | } | ||
480 | |||
481 | template <typename Scalar> | ||
482 | ✗ | void ShootingProblemTpl<Scalar>::set_terminalModel( | |
483 | std::shared_ptr<ActionModelAbstract> model) { | ||
484 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
485 | ✗ | throw_pretty( | |
486 | "Invalid argument: " << "nx is not consistent with the other nodes") | ||
487 | } | ||
488 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
489 | ✗ | throw_pretty( | |
490 | "Invalid argument: " << "ndx is not consistent with the other nodes") | ||
491 | } | ||
492 | ✗ | is_updated_ = true; | |
493 | ✗ | terminal_model_ = model; | |
494 | ✗ | terminal_data_ = terminal_model_->createData(); | |
495 | } | ||
496 | |||
497 | template <typename Scalar> | ||
498 | ✗ | void ShootingProblemTpl<Scalar>::set_nthreads(const int nthreads) { | |
499 | #ifndef CROCODDYL_WITH_MULTITHREADING | ||
500 | (void)nthreads; | ||
501 | ✗ | std::cerr << "Warning: the number of threads won't affect the computational " | |
502 | "performance as multithreading " | ||
503 | "support is not enabled." | ||
504 | ✗ | << std::endl; | |
505 | #else | ||
506 | if (nthreads < 1) { | ||
507 | nthreads_ = CROCODDYL_WITH_NTHREADS; | ||
508 | } else { | ||
509 | nthreads_ = static_cast<std::size_t>(nthreads); | ||
510 | } | ||
511 | if (!enableMultithreading()) { | ||
512 | std::cerr << "Warning: the number of threads won't affect the " | ||
513 | "computational performance as multithreading " | ||
514 | "support is not enabled." | ||
515 | << std::endl; | ||
516 | nthreads_ = 1; | ||
517 | } | ||
518 | #endif | ||
519 | } | ||
520 | |||
521 | template <typename Scalar> | ||
522 | 49 | std::size_t ShootingProblemTpl<Scalar>::get_nx() const { | |
523 | 49 | return nx_; | |
524 | } | ||
525 | |||
526 | template <typename Scalar> | ||
527 | 230 | std::size_t ShootingProblemTpl<Scalar>::get_ndx() const { | |
528 | 230 | return ndx_; | |
529 | } | ||
530 | |||
531 | template <typename Scalar> | ||
532 | ✗ | std::size_t ShootingProblemTpl<Scalar>::get_nthreads() const { | |
533 | #ifndef CROCODDYL_WITH_MULTITHREADING | ||
534 | ✗ | std::cerr << "Warning: the number of threads won't affect the computational " | |
535 | "performance as multithreading " | ||
536 | "support is not enabled." | ||
537 | ✗ | << std::endl; | |
538 | #endif | ||
539 | ✗ | return nthreads_; | |
540 | } | ||
541 | |||
542 | template <typename Scalar> | ||
543 | 32 | bool ShootingProblemTpl<Scalar>::is_updated() { | |
544 | 32 | const bool status = is_updated_; | |
545 | 32 | is_updated_ = false; | |
546 | 32 | return status; | |
547 | } | ||
548 | |||
549 | template <typename Scalar> | ||
550 | 7 | std::ostream& operator<<(std::ostream& os, | |
551 | const ShootingProblemTpl<Scalar>& problem) { | ||
552 | 7 | os << "ShootingProblem (T=" << problem.get_T() << ", nx=" << problem.get_nx() | |
553 | 7 | << ", ndx=" << problem.get_ndx() << ") " << std::endl | |
554 | 7 | << " Models:" << std::endl; | |
555 | const std::vector< | ||
556 | std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> > >& | ||
557 | 7 | runningModels = problem.get_runningModels(); | |
558 |
2/2✓ Branch 1 taken 140 times.
✓ Branch 2 taken 7 times.
|
147 | for (std::size_t t = 0; t < problem.get_T(); ++t) { |
559 | 140 | os << " " << t << ": " << *runningModels[t] << std::endl; | |
560 | } | ||
561 | 7 | os << " " << problem.get_T() << ": " << *problem.get_terminalModel(); | |
562 | 7 | return os; | |
563 | } | ||
564 | |||
565 | } // namespace crocoddyl | ||
566 |