| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | /////////////////////////////////////////////////////////////////////////////// | ||
| 2 | // BSD 3-Clause License | ||
| 3 | // | ||
| 4 | // Copyright (C) 2019-2025, LAAS-CNRS, University of Edinburgh, | ||
| 5 | // University of Oxford, Heriot-Watt University | ||
| 6 | // Copyright note valid unless otherwise stated in individual files. | ||
| 7 | // All rights reserved. | ||
| 8 | /////////////////////////////////////////////////////////////////////////////// | ||
| 9 | |||
| 10 | #include "crocoddyl/core/utils/stop-watch.hpp" | ||
| 11 | |||
| 12 | namespace crocoddyl { | ||
| 13 | |||
| 14 | template <typename Scalar> | ||
| 15 | ✗ | ShootingProblemTpl<Scalar>::ShootingProblemTpl( | |
| 16 | const VectorXs& x0, | ||
| 17 | const std::vector<std::shared_ptr<ActionModelAbstract> >& running_models, | ||
| 18 | std::shared_ptr<ActionModelAbstract> terminal_model) | ||
| 19 | ✗ | : cost_(Scalar(0.)), | |
| 20 | ✗ | T_(running_models.size()), | |
| 21 | ✗ | x0_(x0), | |
| 22 | ✗ | terminal_model_(terminal_model), | |
| 23 | ✗ | running_models_(running_models), | |
| 24 | ✗ | nx_(running_models[0]->get_state()->get_nx()), | |
| 25 | ✗ | ndx_(running_models[0]->get_state()->get_ndx()), | |
| 26 | ✗ | nthreads_(1), | |
| 27 | ✗ | is_updated_(false) { | |
| 28 | ✗ | if (static_cast<std::size_t>(x0.size()) != nx_) { | |
| 29 | ✗ | throw_pretty( | |
| 30 | "Invalid argument: " << "x0 has wrong dimension (it should be " + | ||
| 31 | std::to_string(nx_) + ")"); | ||
| 32 | } | ||
| 33 | ✗ | for (std::size_t i = 1; i < T_; ++i) { | |
| 34 | ✗ | const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; | |
| 35 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
| 36 | ✗ | throw_pretty("Invalid argument: " | |
| 37 | << "nx in " << i | ||
| 38 | << " node is not consistent with the other nodes") | ||
| 39 | } | ||
| 40 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
| 41 | ✗ | throw_pretty("Invalid argument: " | |
| 42 | << "ndx in " << i | ||
| 43 | << " node is not consistent with the other nodes") | ||
| 44 | } | ||
| 45 | } | ||
| 46 | ✗ | if (terminal_model_->get_state()->get_nx() != nx_) { | |
| 47 | ✗ | throw_pretty( | |
| 48 | "Invalid argument: " | ||
| 49 | << "nx in terminal node is not consistent with the other nodes") | ||
| 50 | } | ||
| 51 | ✗ | if (terminal_model_->get_state()->get_ndx() != ndx_) { | |
| 52 | ✗ | throw_pretty( | |
| 53 | "Invalid argument: " | ||
| 54 | << "ndx in terminal node is not consistent with the other nodes") | ||
| 55 | } | ||
| 56 | ✗ | allocateData(); | |
| 57 | |||
| 58 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
| 59 | if (enableMultithreading()) { | ||
| 60 | nthreads_ = CROCODDYL_WITH_NTHREADS; | ||
| 61 | } | ||
| 62 | #endif | ||
| 63 | ✗ | } | |
| 64 | |||
| 65 | template <typename Scalar> | ||
| 66 | ✗ | ShootingProblemTpl<Scalar>::ShootingProblemTpl( | |
| 67 | const VectorXs& x0, | ||
| 68 | const std::vector<std::shared_ptr<ActionModelAbstract> >& running_models, | ||
| 69 | std::shared_ptr<ActionModelAbstract> terminal_model, | ||
| 70 | const std::vector<std::shared_ptr<ActionDataAbstract> >& running_datas, | ||
| 71 | std::shared_ptr<ActionDataAbstract> terminal_data) | ||
| 72 | ✗ | : cost_(Scalar(0.)), | |
| 73 | ✗ | T_(running_models.size()), | |
| 74 | ✗ | x0_(x0), | |
| 75 | ✗ | terminal_model_(terminal_model), | |
| 76 | ✗ | terminal_data_(terminal_data), | |
| 77 | ✗ | running_models_(running_models), | |
| 78 | ✗ | running_datas_(running_datas), | |
| 79 | ✗ | nx_(running_models[0]->get_state()->get_nx()), | |
| 80 | ✗ | ndx_(running_models[0]->get_state()->get_ndx()), | |
| 81 | ✗ | nthreads_(1) { | |
| 82 | ✗ | if (static_cast<std::size_t>(x0.size()) != nx_) { | |
| 83 | ✗ | throw_pretty( | |
| 84 | "Invalid argument: " << "x0 has wrong dimension (it should be " + | ||
| 85 | std::to_string(nx_) + ")"); | ||
| 86 | } | ||
| 87 | ✗ | const std::size_t Td = running_datas.size(); | |
| 88 | ✗ | if (Td != T_) { | |
| 89 | ✗ | throw_pretty( | |
| 90 | "Invalid argument: " | ||
| 91 | << "the number of running models and datas are not the same (" + | ||
| 92 | std::to_string(T_) + " != " + std::to_string(Td) + ")") | ||
| 93 | } | ||
| 94 | ✗ | for (std::size_t i = 0; i < T_; ++i) { | |
| 95 | ✗ | const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; | |
| 96 | ✗ | const std::shared_ptr<ActionDataAbstract>& data = running_datas_[i]; | |
| 97 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
| 98 | ✗ | throw_pretty("Invalid argument: " | |
| 99 | << "nx in " << i | ||
| 100 | << " node is not consistent with the other nodes") | ||
| 101 | } | ||
| 102 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
| 103 | ✗ | throw_pretty("Invalid argument: " | |
| 104 | << "ndx in " << i | ||
| 105 | << " node is not consistent with the other nodes") | ||
| 106 | } | ||
| 107 | ✗ | if (!model->checkData(data)) { | |
| 108 | ✗ | throw_pretty("Invalid argument: " | |
| 109 | << "action data in " << i | ||
| 110 | << " node is not consistent with the action model") | ||
| 111 | } | ||
| 112 | } | ||
| 113 | ✗ | if (!terminal_model->checkData(terminal_data)) { | |
| 114 | ✗ | throw_pretty("Invalid argument: " | |
| 115 | << "terminal action data is not consistent with the terminal " | ||
| 116 | "action model") | ||
| 117 | } | ||
| 118 | |||
| 119 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
| 120 | if (enableMultithreading()) { | ||
| 121 | nthreads_ = CROCODDYL_WITH_NTHREADS; | ||
| 122 | } | ||
| 123 | #endif | ||
| 124 | ✗ | } | |
| 125 | |||
| 126 | template <typename Scalar> | ||
| 127 | ✗ | ShootingProblemTpl<Scalar>::ShootingProblemTpl( | |
| 128 | const ShootingProblemTpl<Scalar>& problem) | ||
| 129 | ✗ | : cost_(Scalar(0.)), | |
| 130 | ✗ | T_(problem.get_T()), | |
| 131 | ✗ | x0_(problem.get_x0()), | |
| 132 | ✗ | terminal_model_(problem.get_terminalModel()), | |
| 133 | ✗ | terminal_data_(problem.get_terminalData()), | |
| 134 | ✗ | running_models_(problem.get_runningModels()), | |
| 135 | ✗ | running_datas_(problem.get_runningDatas()), | |
| 136 | ✗ | nx_(problem.get_nx()), | |
| 137 | ✗ | ndx_(problem.get_ndx()) {} | |
| 138 | |||
| 139 | template <typename Scalar> | ||
| 140 | ✗ | ShootingProblemTpl<Scalar>::~ShootingProblemTpl() {} | |
| 141 | |||
| 142 | template <typename Scalar> | ||
| 143 | ✗ | Scalar ShootingProblemTpl<Scalar>::calc(const std::vector<VectorXs>& xs, | |
| 144 | const std::vector<VectorXs>& us) { | ||
| 145 | ✗ | if (xs.size() != T_ + 1) { | |
| 146 | ✗ | throw_pretty( | |
| 147 | "Invalid argument: " << "xs has wrong dimension (it should be " + | ||
| 148 | std::to_string(T_ + 1) + ")"); | ||
| 149 | } | ||
| 150 | ✗ | if (us.size() != T_) { | |
| 151 | ✗ | throw_pretty( | |
| 152 | "Invalid argument: " << "us has wrong dimension (it should be " + | ||
| 153 | std::to_string(T_) + ")"); | ||
| 154 | } | ||
| 155 | ✗ | START_PROFILER("ShootingProblem::calc"); | |
| 156 | |||
| 157 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
| 158 | #pragma omp parallel for num_threads(nthreads_) | ||
| 159 | #endif | ||
| 160 | ✗ | for (std::size_t i = 0; i < T_; ++i) { | |
| 161 | ✗ | running_models_[i]->calc(running_datas_[i], xs[i], us[i]); | |
| 162 | } | ||
| 163 | ✗ | terminal_model_->calc(terminal_data_, xs.back()); | |
| 164 | |||
| 165 | ✗ | cost_ = Scalar(0.); | |
| 166 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
| 167 | #pragma omp simd reduction(+ : cost_) | ||
| 168 | #endif | ||
| 169 | ✗ | for (std::size_t i = 0; i < T_; ++i) { | |
| 170 | ✗ | cost_ += running_datas_[i]->cost; | |
| 171 | } | ||
| 172 | ✗ | cost_ += terminal_data_->cost; | |
| 173 | ✗ | STOP_PROFILER("ShootingProblem::calc"); | |
| 174 | ✗ | return cost_; | |
| 175 | } | ||
| 176 | |||
| 177 | template <typename Scalar> | ||
| 178 | ✗ | Scalar ShootingProblemTpl<Scalar>::calcDiff(const std::vector<VectorXs>& xs, | |
| 179 | const std::vector<VectorXs>& us) { | ||
| 180 | ✗ | if (xs.size() != T_ + 1) { | |
| 181 | ✗ | throw_pretty( | |
| 182 | "Invalid argument: " << "xs has wrong dimension (it should be " + | ||
| 183 | std::to_string(T_ + 1) + ")"); | ||
| 184 | } | ||
| 185 | ✗ | if (us.size() != T_) { | |
| 186 | ✗ | throw_pretty( | |
| 187 | "Invalid argument: " << "us has wrong dimension (it should be " + | ||
| 188 | std::to_string(T_) + ")"); | ||
| 189 | } | ||
| 190 | ✗ | START_PROFILER("ShootingProblem::calcDiff"); | |
| 191 | |||
| 192 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
| 193 | #pragma omp parallel for num_threads(nthreads_) | ||
| 194 | #endif | ||
| 195 | ✗ | for (std::size_t i = 0; i < T_; ++i) { | |
| 196 | ✗ | running_models_[i]->calcDiff(running_datas_[i], xs[i], us[i]); | |
| 197 | } | ||
| 198 | ✗ | terminal_model_->calcDiff(terminal_data_, xs.back()); | |
| 199 | |||
| 200 | ✗ | cost_ = Scalar(0.); | |
| 201 | // Apply SIMD only for floating-point types | ||
| 202 | if (std::is_floating_point<Scalar>::value) { | ||
| 203 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
| 204 | #pragma omp simd reduction(+ : cost_) | ||
| 205 | #endif | ||
| 206 | ✗ | for (std::size_t i = 0; i < T_; ++i) { | |
| 207 | ✗ | cost_ += running_datas_[i]->cost; | |
| 208 | } | ||
| 209 | ✗ | cost_ += terminal_data_->cost; | |
| 210 | } else { // For non-floating-point types (e.g., CppAD types), use the normal | ||
| 211 | // loop without SIMD | ||
| 212 | for (std::size_t i = 0; i < T_; ++i) { | ||
| 213 | cost_ += running_datas_[i]->cost; | ||
| 214 | } | ||
| 215 | cost_ += terminal_data_->cost; | ||
| 216 | } | ||
| 217 | ✗ | STOP_PROFILER("ShootingProblem::calcDiff"); | |
| 218 | ✗ | return cost_; | |
| 219 | } | ||
| 220 | |||
| 221 | template <typename Scalar> | ||
| 222 | ✗ | void ShootingProblemTpl<Scalar>::rollout(const std::vector<VectorXs>& us, | |
| 223 | std::vector<VectorXs>& xs) { | ||
| 224 | ✗ | if (xs.size() != T_ + 1) { | |
| 225 | ✗ | throw_pretty( | |
| 226 | "Invalid argument: " << "xs has wrong dimension (it should be " + | ||
| 227 | std::to_string(T_ + 1) + ")"); | ||
| 228 | } | ||
| 229 | ✗ | if (us.size() != T_) { | |
| 230 | ✗ | throw_pretty( | |
| 231 | "Invalid argument: " << "us has wrong dimension (it should be " + | ||
| 232 | std::to_string(T_) + ")"); | ||
| 233 | } | ||
| 234 | ✗ | START_PROFILER("ShootingProblem::rollout"); | |
| 235 | |||
| 236 | ✗ | xs[0] = x0_; | |
| 237 | ✗ | for (std::size_t i = 0; i < T_; ++i) { | |
| 238 | ✗ | const std::shared_ptr<ActionDataAbstract>& data = running_datas_[i]; | |
| 239 | ✗ | running_models_[i]->calc(data, xs[i], us[i]); | |
| 240 | ✗ | xs[i + 1] = data->xnext; | |
| 241 | } | ||
| 242 | ✗ | terminal_model_->calc(terminal_data_, xs.back()); | |
| 243 | ✗ | STOP_PROFILER("ShootingProblem::rollout"); | |
| 244 | ✗ | } | |
| 245 | |||
| 246 | template <typename Scalar> | ||
| 247 | std::vector<typename MathBaseTpl<Scalar>::VectorXs> | ||
| 248 | ✗ | ShootingProblemTpl<Scalar>::rollout_us(const std::vector<VectorXs>& us) { | |
| 249 | ✗ | std::vector<VectorXs> xs; | |
| 250 | ✗ | xs.resize(T_ + 1); | |
| 251 | ✗ | rollout(us, xs); | |
| 252 | ✗ | return xs; | |
| 253 | ✗ | } | |
| 254 | |||
| 255 | template <typename Scalar> | ||
| 256 | ✗ | void ShootingProblemTpl<Scalar>::quasiStatic(std::vector<VectorXs>& us, | |
| 257 | const std::vector<VectorXs>& xs) { | ||
| 258 | ✗ | if (xs.size() != T_) { | |
| 259 | ✗ | throw_pretty( | |
| 260 | "Invalid argument: " << "xs has wrong dimension (it should be " + | ||
| 261 | std::to_string(T_) + ")"); | ||
| 262 | } | ||
| 263 | ✗ | if (us.size() != T_) { | |
| 264 | ✗ | throw_pretty( | |
| 265 | "Invalid argument: " << "us has wrong dimension (it should be " + | ||
| 266 | std::to_string(T_) + ")"); | ||
| 267 | } | ||
| 268 | |||
| 269 | #ifdef CROCODDYL_WITH_MULTITHREADING | ||
| 270 | #pragma omp parallel for num_threads(nthreads_) | ||
| 271 | #endif | ||
| 272 | ✗ | for (std::size_t i = 0; i < T_; ++i) { | |
| 273 | ✗ | running_models_[i]->quasiStatic(running_datas_[i], us[i], xs[i]); | |
| 274 | } | ||
| 275 | ✗ | } | |
| 276 | |||
| 277 | template <typename Scalar> | ||
| 278 | std::vector<typename MathBaseTpl<Scalar>::VectorXs> | ||
| 279 | ✗ | ShootingProblemTpl<Scalar>::quasiStatic_xs(const std::vector<VectorXs>& xs) { | |
| 280 | ✗ | std::vector<VectorXs> us; | |
| 281 | ✗ | us.resize(T_); | |
| 282 | ✗ | for (std::size_t i = 0; i < T_; ++i) { | |
| 283 | ✗ | us[i] = VectorXs::Zero(running_models_[i]->get_nu()); | |
| 284 | } | ||
| 285 | ✗ | quasiStatic(us, xs); | |
| 286 | ✗ | return us; | |
| 287 | ✗ | } | |
| 288 | |||
| 289 | template <typename Scalar> | ||
| 290 | ✗ | void ShootingProblemTpl<Scalar>::circularAppend( | |
| 291 | std::shared_ptr<ActionModelAbstract> model, | ||
| 292 | std::shared_ptr<ActionDataAbstract> data) { | ||
| 293 | ✗ | if (!model->checkData(data)) { | |
| 294 | ✗ | throw_pretty("Invalid argument: " | |
| 295 | << "action data is not consistent with the action model") | ||
| 296 | } | ||
| 297 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
| 298 | ✗ | throw_pretty( | |
| 299 | "Invalid argument: " << "nx is not consistent with the other nodes") | ||
| 300 | } | ||
| 301 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
| 302 | ✗ | throw_pretty("Invalid argument: " | |
| 303 | << "ndx node is not consistent with the other nodes") | ||
| 304 | } | ||
| 305 | ✗ | is_updated_ = true; | |
| 306 | ✗ | for (std::size_t i = 0; i < T_ - 1; ++i) { | |
| 307 | ✗ | running_models_[i] = running_models_[i + 1]; | |
| 308 | ✗ | running_datas_[i] = running_datas_[i + 1]; | |
| 309 | } | ||
| 310 | ✗ | running_models_.back() = model; | |
| 311 | ✗ | running_datas_.back() = data; | |
| 312 | ✗ | } | |
| 313 | |||
| 314 | template <typename Scalar> | ||
| 315 | ✗ | void ShootingProblemTpl<Scalar>::circularAppend( | |
| 316 | std::shared_ptr<ActionModelAbstract> model) { | ||
| 317 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
| 318 | ✗ | throw_pretty( | |
| 319 | "Invalid argument: " << "nx is not consistent with the other nodes") | ||
| 320 | } | ||
| 321 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
| 322 | ✗ | throw_pretty("Invalid argument: " | |
| 323 | << "ndx node is not consistent with the other nodes") | ||
| 324 | } | ||
| 325 | ✗ | is_updated_ = true; | |
| 326 | ✗ | for (std::size_t i = 0; i < T_ - 1; ++i) { | |
| 327 | ✗ | running_models_[i] = running_models_[i + 1]; | |
| 328 | ✗ | running_datas_[i] = running_datas_[i + 1]; | |
| 329 | } | ||
| 330 | ✗ | running_models_.back() = model; | |
| 331 | ✗ | running_datas_.back() = model->createData(); | |
| 332 | ✗ | } | |
| 333 | |||
| 334 | template <typename Scalar> | ||
| 335 | ✗ | void ShootingProblemTpl<Scalar>::updateNode( | |
| 336 | const std::size_t i, std::shared_ptr<ActionModelAbstract> model, | ||
| 337 | std::shared_ptr<ActionDataAbstract> data) { | ||
| 338 | ✗ | if (i >= T_ + 1) { | |
| 339 | ✗ | throw_pretty("Invalid argument: " | |
| 340 | << "i is bigger than the allocated horizon (it should be less " | ||
| 341 | "than or equal to " + | ||
| 342 | std::to_string(T_ + 1) + ")"); | ||
| 343 | } | ||
| 344 | ✗ | if (!model->checkData(data)) { | |
| 345 | ✗ | throw_pretty("Invalid argument: " | |
| 346 | << "action data is not consistent with the action model") | ||
| 347 | } | ||
| 348 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
| 349 | ✗ | throw_pretty( | |
| 350 | "Invalid argument: " << "nx is not consistent with the other nodes") | ||
| 351 | } | ||
| 352 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
| 353 | ✗ | throw_pretty("Invalid argument: " | |
| 354 | << "ndx node is not consistent with the other nodes") | ||
| 355 | } | ||
| 356 | ✗ | is_updated_ = true; | |
| 357 | ✗ | if (i == T_) { | |
| 358 | ✗ | terminal_model_ = model; | |
| 359 | ✗ | terminal_data_ = data; | |
| 360 | } else { | ||
| 361 | ✗ | running_models_[i] = model; | |
| 362 | ✗ | running_datas_[i] = data; | |
| 363 | } | ||
| 364 | ✗ | } | |
| 365 | |||
| 366 | template <typename Scalar> | ||
| 367 | ✗ | void ShootingProblemTpl<Scalar>::updateModel( | |
| 368 | const std::size_t i, std::shared_ptr<ActionModelAbstract> model) { | ||
| 369 | ✗ | if (i >= T_ + 1) { | |
| 370 | ✗ | throw_pretty( | |
| 371 | "Invalid argument: " | ||
| 372 | << "i is bigger than the allocated horizon (it should be lower than " + | ||
| 373 | std::to_string(T_ + 1) + ")"); | ||
| 374 | } | ||
| 375 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
| 376 | ✗ | throw_pretty( | |
| 377 | "Invalid argument: " << "nx is not consistent with the other nodes") | ||
| 378 | } | ||
| 379 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
| 380 | ✗ | throw_pretty( | |
| 381 | "Invalid argument: " << "ndx is not consistent with the other nodes") | ||
| 382 | } | ||
| 383 | ✗ | is_updated_ = true; | |
| 384 | ✗ | if (i == T_) { | |
| 385 | ✗ | terminal_model_ = model; | |
| 386 | ✗ | terminal_data_ = terminal_model_->createData(); | |
| 387 | } else { | ||
| 388 | ✗ | running_models_[i] = model; | |
| 389 | ✗ | running_datas_[i] = model->createData(); | |
| 390 | } | ||
| 391 | ✗ | } | |
| 392 | |||
| 393 | template <typename Scalar> | ||
| 394 | template <typename NewScalar> | ||
| 395 | ✗ | ShootingProblemTpl<NewScalar> ShootingProblemTpl<Scalar>::cast() const { | |
| 396 | typedef ShootingProblemTpl<NewScalar> ReturnType; | ||
| 397 | ✗ | ReturnType ret(x0_.template cast<NewScalar>(), | |
| 398 | ✗ | vector_cast<NewScalar>(running_models_), | |
| 399 | ✗ | terminal_model_->template cast<NewScalar>()); | |
| 400 | ✗ | ret.set_nthreads((int)nthreads_); | |
| 401 | ✗ | return ret; | |
| 402 | ✗ | } | |
| 403 | |||
| 404 | template <typename Scalar> | ||
| 405 | ✗ | std::size_t ShootingProblemTpl<Scalar>::get_T() const { | |
| 406 | ✗ | return T_; | |
| 407 | } | ||
| 408 | |||
| 409 | template <typename Scalar> | ||
| 410 | const typename MathBaseTpl<Scalar>::VectorXs& | ||
| 411 | ✗ | ShootingProblemTpl<Scalar>::get_x0() const { | |
| 412 | ✗ | return x0_; | |
| 413 | } | ||
| 414 | |||
| 415 | template <typename Scalar> | ||
| 416 | ✗ | void ShootingProblemTpl<Scalar>::allocateData() { | |
| 417 | ✗ | running_datas_.resize(T_); | |
| 418 | ✗ | for (std::size_t i = 0; i < T_; ++i) { | |
| 419 | ✗ | const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; | |
| 420 | ✗ | running_datas_[i] = model->createData(); | |
| 421 | } | ||
| 422 | ✗ | terminal_data_ = terminal_model_->createData(); | |
| 423 | ✗ | } | |
| 424 | |||
| 425 | template <typename Scalar> | ||
| 426 | const std::vector<std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> > >& | ||
| 427 | ✗ | ShootingProblemTpl<Scalar>::get_runningModels() const { | |
| 428 | ✗ | return running_models_; | |
| 429 | } | ||
| 430 | |||
| 431 | template <typename Scalar> | ||
| 432 | const std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> >& | ||
| 433 | ✗ | ShootingProblemTpl<Scalar>::get_terminalModel() const { | |
| 434 | ✗ | return terminal_model_; | |
| 435 | } | ||
| 436 | |||
| 437 | template <typename Scalar> | ||
| 438 | const std::vector<std::shared_ptr<crocoddyl::ActionDataAbstractTpl<Scalar> > >& | ||
| 439 | ✗ | ShootingProblemTpl<Scalar>::get_runningDatas() const { | |
| 440 | ✗ | return running_datas_; | |
| 441 | } | ||
| 442 | |||
| 443 | template <typename Scalar> | ||
| 444 | const std::shared_ptr<crocoddyl::ActionDataAbstractTpl<Scalar> >& | ||
| 445 | ✗ | ShootingProblemTpl<Scalar>::get_terminalData() const { | |
| 446 | ✗ | return terminal_data_; | |
| 447 | } | ||
| 448 | |||
| 449 | template <typename Scalar> | ||
| 450 | ✗ | void ShootingProblemTpl<Scalar>::set_x0(const VectorXs& x0_in) { | |
| 451 | ✗ | if (x0_in.size() != x0_.size()) { | |
| 452 | ✗ | throw_pretty("Invalid argument: " | |
| 453 | << "invalid size of x0 provided: Expected " << x0_.size() | ||
| 454 | << ", received " << x0_in.size()); | ||
| 455 | } | ||
| 456 | ✗ | x0_ = x0_in; | |
| 457 | ✗ | } | |
| 458 | |||
| 459 | template <typename Scalar> | ||
| 460 | ✗ | void ShootingProblemTpl<Scalar>::set_runningModels( | |
| 461 | const std::vector<std::shared_ptr<ActionModelAbstract> >& models) { | ||
| 462 | ✗ | for (std::size_t i = 0; i < T_; ++i) { | |
| 463 | ✗ | const std::shared_ptr<ActionModelAbstract>& model = models[i]; | |
| 464 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
| 465 | ✗ | throw_pretty("Invalid argument: " | |
| 466 | << "nx in " << i | ||
| 467 | << " node is not consistent with the other nodes") | ||
| 468 | } | ||
| 469 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
| 470 | ✗ | throw_pretty("Invalid argument: " | |
| 471 | << "ndx in " << i | ||
| 472 | << " node is not consistent with the other nodes") | ||
| 473 | } | ||
| 474 | } | ||
| 475 | ✗ | is_updated_ = true; | |
| 476 | ✗ | T_ = models.size(); | |
| 477 | ✗ | running_models_.clear(); | |
| 478 | ✗ | running_datas_.clear(); | |
| 479 | ✗ | for (std::size_t i = 0; i < T_; ++i) { | |
| 480 | ✗ | const std::shared_ptr<ActionModelAbstract>& model = running_models_[i]; | |
| 481 | ✗ | running_models_.push_back(model); | |
| 482 | ✗ | running_datas_.push_back(model->createData()); | |
| 483 | } | ||
| 484 | ✗ | } | |
| 485 | |||
| 486 | template <typename Scalar> | ||
| 487 | ✗ | void ShootingProblemTpl<Scalar>::set_terminalModel( | |
| 488 | std::shared_ptr<ActionModelAbstract> model) { | ||
| 489 | ✗ | if (model->get_state()->get_nx() != nx_) { | |
| 490 | ✗ | throw_pretty( | |
| 491 | "Invalid argument: " << "nx is not consistent with the other nodes") | ||
| 492 | } | ||
| 493 | ✗ | if (model->get_state()->get_ndx() != ndx_) { | |
| 494 | ✗ | throw_pretty( | |
| 495 | "Invalid argument: " << "ndx is not consistent with the other nodes") | ||
| 496 | } | ||
| 497 | ✗ | is_updated_ = true; | |
| 498 | ✗ | terminal_model_ = model; | |
| 499 | ✗ | terminal_data_ = terminal_model_->createData(); | |
| 500 | ✗ | } | |
| 501 | |||
| 502 | template <typename Scalar> | ||
| 503 | ✗ | void ShootingProblemTpl<Scalar>::set_nthreads(const int nthreads) { | |
| 504 | #ifndef CROCODDYL_WITH_MULTITHREADING | ||
| 505 | (void)nthreads; | ||
| 506 | ✗ | std::cerr << "Warning: the number of threads won't affect the computational " | |
| 507 | "performance as multithreading " | ||
| 508 | "support is not enabled." | ||
| 509 | ✗ | << std::endl; | |
| 510 | #else | ||
| 511 | if (nthreads < 1) { | ||
| 512 | nthreads_ = CROCODDYL_WITH_NTHREADS; | ||
| 513 | } else { | ||
| 514 | nthreads_ = static_cast<std::size_t>(nthreads); | ||
| 515 | } | ||
| 516 | if (!enableMultithreading()) { | ||
| 517 | std::cerr << "Warning: the number of threads won't affect the " | ||
| 518 | "computational performance as multithreading " | ||
| 519 | "support is not enabled." | ||
| 520 | << std::endl; | ||
| 521 | nthreads_ = 1; | ||
| 522 | } | ||
| 523 | #endif | ||
| 524 | ✗ | } | |
| 525 | |||
| 526 | template <typename Scalar> | ||
| 527 | ✗ | std::size_t ShootingProblemTpl<Scalar>::get_nx() const { | |
| 528 | ✗ | return nx_; | |
| 529 | } | ||
| 530 | |||
| 531 | template <typename Scalar> | ||
| 532 | ✗ | std::size_t ShootingProblemTpl<Scalar>::get_ndx() const { | |
| 533 | ✗ | return ndx_; | |
| 534 | } | ||
| 535 | |||
| 536 | template <typename Scalar> | ||
| 537 | ✗ | std::size_t ShootingProblemTpl<Scalar>::get_nthreads() const { | |
| 538 | #ifndef CROCODDYL_WITH_MULTITHREADING | ||
| 539 | ✗ | std::cerr << "Warning: the number of threads won't affect the computational " | |
| 540 | "performance as multithreading " | ||
| 541 | "support is not enabled." | ||
| 542 | ✗ | << std::endl; | |
| 543 | #endif | ||
| 544 | ✗ | return nthreads_; | |
| 545 | } | ||
| 546 | |||
| 547 | template <typename Scalar> | ||
| 548 | ✗ | bool ShootingProblemTpl<Scalar>::is_updated() { | |
| 549 | ✗ | const bool status = is_updated_; | |
| 550 | ✗ | is_updated_ = false; | |
| 551 | ✗ | return status; | |
| 552 | } | ||
| 553 | |||
| 554 | template <typename Scalar> | ||
| 555 | ✗ | std::ostream& operator<<(std::ostream& os, | |
| 556 | const ShootingProblemTpl<Scalar>& problem) { | ||
| 557 | ✗ | os << "ShootingProblem (T=" << problem.get_T() << ", nx=" << problem.get_nx() | |
| 558 | ✗ | << ", ndx=" << problem.get_ndx() << ") " << std::endl | |
| 559 | ✗ | << " Models:" << std::endl; | |
| 560 | const std::vector< | ||
| 561 | std::shared_ptr<crocoddyl::ActionModelAbstractTpl<Scalar> > >& | ||
| 562 | ✗ | runningModels = problem.get_runningModels(); | |
| 563 | ✗ | for (std::size_t t = 0; t < problem.get_T(); ++t) { | |
| 564 | ✗ | os << " " << t << ": " << *runningModels[t] << std::endl; | |
| 565 | } | ||
| 566 | ✗ | os << " " << problem.get_T() << ": " << *problem.get_terminalModel(); | |
| 567 | ✗ | return os; | |
| 568 | } | ||
| 569 | |||
| 570 | } // namespace crocoddyl | ||
| 571 |