Directory: | ./ |
---|---|
File: | include/crocoddyl/core/numdiff/state.hxx |
Date: | 2025-03-26 19:23:43 |
Exec | Total | Coverage | |
---|---|---|---|
Lines: | 82 | 111 | 73.9% |
Branches: | 125 | 504 | 24.8% |
Line | Branch | Exec | Source |
---|---|---|---|
1 | /////////////////////////////////////////////////////////////////////////////// | ||
2 | // BSD 3-Clause License | ||
3 | // | ||
4 | // Copyright (C) 2019-2025, LAAS-CNRS, University of Edinburgh, | ||
5 | // New York University, Max Planck Gesellschaft, | ||
6 | // Heriot-Watt University | ||
7 | // Copyright note valid unless otherwise stated in individual files. | ||
8 | // All rights reserved. | ||
9 | /////////////////////////////////////////////////////////////////////////////// | ||
10 | |||
11 | namespace crocoddyl { | ||
12 | |||
13 | template <typename Scalar> | ||
14 | 29 | StateNumDiffTpl<Scalar>::StateNumDiffTpl(std::shared_ptr<Base> state) | |
15 | : Base(state->get_nx(), state->get_ndx()), | ||
16 | 29 | state_(state), | |
17 | 29 | e_jac_(sqrt(Scalar(2.0) * std::numeric_limits<Scalar>::epsilon())) {} | |
18 | |||
19 | template <typename Scalar> | ||
20 | 62 | StateNumDiffTpl<Scalar>::~StateNumDiffTpl() {} | |
21 | |||
22 | template <typename Scalar> | ||
23 | ✗ | typename MathBaseTpl<Scalar>::VectorXs StateNumDiffTpl<Scalar>::zero() const { | |
24 | ✗ | return state_->zero(); | |
25 | } | ||
26 | |||
27 | template <typename Scalar> | ||
28 | ✗ | typename MathBaseTpl<Scalar>::VectorXs StateNumDiffTpl<Scalar>::rand() const { | |
29 | ✗ | return state_->rand(); | |
30 | } | ||
31 | |||
32 | template <typename Scalar> | ||
33 | 3580 | void StateNumDiffTpl<Scalar>::diff(const Eigen::Ref<const VectorXs>& x0, | |
34 | const Eigen::Ref<const VectorXs>& x1, | ||
35 | Eigen::Ref<VectorXs> dxout) const { | ||
36 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 3580 times.
|
3580 | if (static_cast<std::size_t>(x0.size()) != nx_) { |
37 | ✗ | throw_pretty( | |
38 | "Invalid argument: " << "x0 has wrong dimension (it should be " + | ||
39 | std::to_string(nx_) + ")"); | ||
40 | } | ||
41 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 3580 times.
|
3580 | if (static_cast<std::size_t>(x1.size()) != nx_) { |
42 | ✗ | throw_pretty( | |
43 | "Invalid argument: " << "x1 has wrong dimension (it should be " + | ||
44 | std::to_string(nx_) + ")"); | ||
45 | } | ||
46 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 3580 times.
|
3580 | if (static_cast<std::size_t>(dxout.size()) != ndx_) { |
47 | ✗ | throw_pretty( | |
48 | "Invalid argument: " << "dxout has wrong dimension (it should be " + | ||
49 | std::to_string(ndx_) + ")"); | ||
50 | } | ||
51 |
1/2✓ Branch 3 taken 3580 times.
✗ Branch 4 not taken.
|
3580 | state_->diff(x0, x1, dxout); |
52 | 3580 | } | |
53 | |||
54 | template <typename Scalar> | ||
55 | 4468 | void StateNumDiffTpl<Scalar>::integrate(const Eigen::Ref<const VectorXs>& x, | |
56 | const Eigen::Ref<const VectorXs>& dx, | ||
57 | Eigen::Ref<VectorXs> xout) const { | ||
58 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 4468 times.
|
4468 | if (static_cast<std::size_t>(x.size()) != nx_) { |
59 | ✗ | throw_pretty( | |
60 | "Invalid argument: " << "x has wrong dimension (it should be " + | ||
61 | std::to_string(nx_) + ")"); | ||
62 | } | ||
63 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 4468 times.
|
4468 | if (static_cast<std::size_t>(dx.size()) != ndx_) { |
64 | ✗ | throw_pretty( | |
65 | "Invalid argument: " << "dx has wrong dimension (it should be " + | ||
66 | std::to_string(ndx_) + ")"); | ||
67 | } | ||
68 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 4468 times.
|
4468 | if (static_cast<std::size_t>(xout.size()) != nx_) { |
69 | ✗ | throw_pretty( | |
70 | "Invalid argument: " << "xout has wrong dimension (it should be " + | ||
71 | std::to_string(nx_) + ")"); | ||
72 | } | ||
73 |
1/2✓ Branch 3 taken 4468 times.
✗ Branch 4 not taken.
|
4468 | state_->integrate(x, dx, xout); |
74 | 4468 | } | |
75 | |||
76 | template <typename Scalar> | ||
77 | 28 | void StateNumDiffTpl<Scalar>::Jdiff(const Eigen::Ref<const VectorXs>& x0, | |
78 | const Eigen::Ref<const VectorXs>& x1, | ||
79 | Eigen::Ref<MatrixXs> Jfirst, | ||
80 | Eigen::Ref<MatrixXs> Jsecond, | ||
81 | Jcomponent firstsecond) const { | ||
82 |
1/10✗ Branch 1 not taken.
✓ Branch 2 taken 28 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
|
28 | assert_pretty( |
83 | is_a_Jcomponent(firstsecond), | ||
84 | ("firstsecond must be one of the Jcomponent {both, first, second}")); | ||
85 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 28 times.
|
28 | if (static_cast<std::size_t>(x0.size()) != nx_) { |
86 | ✗ | throw_pretty( | |
87 | "Invalid argument: " << "x0 has wrong dimension (it should be " + | ||
88 | std::to_string(nx_) + ")"); | ||
89 | } | ||
90 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 28 times.
|
28 | if (static_cast<std::size_t>(x1.size()) != nx_) { |
91 | ✗ | throw_pretty( | |
92 | "Invalid argument: " << "x1 has wrong dimension (it should be " + | ||
93 | std::to_string(nx_) + ")"); | ||
94 | } | ||
95 |
2/4✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 28 times.
✗ Branch 5 not taken.
|
28 | VectorXs tmp_x_ = VectorXs::Zero(nx_); |
96 |
2/4✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 28 times.
✗ Branch 5 not taken.
|
28 | VectorXs dx_ = VectorXs::Zero(ndx_); |
97 |
2/4✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 28 times.
✗ Branch 5 not taken.
|
28 | VectorXs dx0_ = VectorXs::Zero(ndx_); |
98 | |||
99 |
1/2✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
|
28 | dx_.setZero(); |
100 |
2/4✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 28 times.
✗ Branch 5 not taken.
|
28 | diff(x0, x1, dx0_); |
101 |
4/4✓ Branch 0 taken 21 times.
✓ Branch 1 taken 7 times.
✓ Branch 2 taken 14 times.
✓ Branch 3 taken 7 times.
|
28 | if (firstsecond == first || firstsecond == both) { |
102 |
1/2✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
|
21 | const Scalar x0h_jac = e_jac_ * std::max(Scalar(1.), x0.norm()); |
103 |
2/4✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 21 times.
|
42 | if (static_cast<std::size_t>(Jfirst.rows()) != ndx_ || |
104 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 21 times.
|
21 | static_cast<std::size_t>(Jfirst.cols()) != ndx_) { |
105 | ✗ | throw_pretty( | |
106 | "Invalid argument: " << "Jfirst has wrong dimension (it should be " + | ||
107 | std::to_string(ndx_) + "," + | ||
108 | std::to_string(ndx_) + ")"); | ||
109 | } | ||
110 |
1/2✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
|
21 | Jfirst.setZero(); |
111 |
2/2✓ Branch 0 taken 888 times.
✓ Branch 1 taken 21 times.
|
909 | for (std::size_t i = 0; i < ndx_; ++i) { |
112 |
1/2✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
|
888 | dx_(i) = x0h_jac; |
113 |
3/6✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 888 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 888 times.
✗ Branch 8 not taken.
|
888 | integrate(x0, dx_, tmp_x_); |
114 |
4/8✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 888 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 888 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 888 times.
✗ Branch 11 not taken.
|
888 | diff(tmp_x_, x1, Jfirst.col(i)); |
115 |
2/4✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 888 times.
✗ Branch 5 not taken.
|
888 | Jfirst.col(i) -= dx0_; |
116 |
1/2✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
|
888 | dx_(i) = Scalar(0.); |
117 | } | ||
118 |
1/2✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
|
21 | Jfirst /= x0h_jac; |
119 | } | ||
120 |
4/4✓ Branch 0 taken 21 times.
✓ Branch 1 taken 7 times.
✓ Branch 2 taken 14 times.
✓ Branch 3 taken 7 times.
|
28 | if (firstsecond == second || firstsecond == both) { |
121 |
1/2✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
|
21 | const Scalar x1h_jac = e_jac_ * std::max(Scalar(1.), x1.norm()); |
122 |
2/4✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 21 times.
|
42 | if (static_cast<std::size_t>(Jsecond.rows()) != ndx_ || |
123 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 21 times.
|
21 | static_cast<std::size_t>(Jsecond.cols()) != ndx_) { |
124 | ✗ | throw_pretty( | |
125 | "Invalid argument: " << "Jsecond has wrong dimension (it should be " + | ||
126 | std::to_string(ndx_) + "," + | ||
127 | std::to_string(ndx_) + ")"); | ||
128 | } | ||
129 | |||
130 |
1/2✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
|
21 | Jsecond.setZero(); |
131 |
2/2✓ Branch 0 taken 888 times.
✓ Branch 1 taken 21 times.
|
909 | for (std::size_t i = 0; i < ndx_; ++i) { |
132 |
1/2✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
|
888 | dx_(i) = x1h_jac; |
133 |
3/6✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 888 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 888 times.
✗ Branch 8 not taken.
|
888 | integrate(x1, dx_, tmp_x_); |
134 |
4/8✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 888 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 888 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 888 times.
✗ Branch 11 not taken.
|
888 | diff(x0, tmp_x_, Jsecond.col(i)); |
135 |
2/4✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 888 times.
✗ Branch 5 not taken.
|
888 | Jsecond.col(i) -= dx0_; |
136 |
1/2✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
|
888 | dx_(i) = Scalar(0.); |
137 | } | ||
138 |
1/2✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
|
21 | Jsecond /= x1h_jac; |
139 | } | ||
140 | 28 | } | |
141 | |||
142 | template <typename Scalar> | ||
143 | 28 | void StateNumDiffTpl<Scalar>::Jintegrate(const Eigen::Ref<const VectorXs>& x, | |
144 | const Eigen::Ref<const VectorXs>& dx, | ||
145 | Eigen::Ref<MatrixXs> Jfirst, | ||
146 | Eigen::Ref<MatrixXs> Jsecond, | ||
147 | const Jcomponent firstsecond, | ||
148 | const AssignmentOp) const { | ||
149 |
1/10✗ Branch 1 not taken.
✓ Branch 2 taken 28 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
|
28 | assert_pretty( |
150 | is_a_Jcomponent(firstsecond), | ||
151 | ("firstsecond must be one of the Jcomponent {both, first, second}")); | ||
152 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 28 times.
|
28 | if (static_cast<std::size_t>(x.size()) != nx_) { |
153 | ✗ | throw_pretty( | |
154 | "Invalid argument: " << "x has wrong dimension (it should be " + | ||
155 | std::to_string(nx_) + ")"); | ||
156 | } | ||
157 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 28 times.
|
28 | if (static_cast<std::size_t>(dx.size()) != ndx_) { |
158 | ✗ | throw_pretty( | |
159 | "Invalid argument: " << "dx has wrong dimension (it should be " + | ||
160 | std::to_string(ndx_) + ")"); | ||
161 | } | ||
162 |
2/4✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 28 times.
✗ Branch 5 not taken.
|
28 | VectorXs tmp_x_ = VectorXs::Zero(nx_); |
163 |
2/4✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 28 times.
✗ Branch 5 not taken.
|
28 | VectorXs dx_ = VectorXs::Zero(ndx_); |
164 |
2/4✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 28 times.
✗ Branch 5 not taken.
|
28 | VectorXs x0_ = VectorXs::Zero(nx_); |
165 | |||
166 | // x0_ = integrate(x, dx) | ||
167 |
2/4✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 28 times.
✗ Branch 5 not taken.
|
28 | integrate(x, dx, x0_); |
168 | |||
169 |
4/4✓ Branch 0 taken 21 times.
✓ Branch 1 taken 7 times.
✓ Branch 2 taken 14 times.
✓ Branch 3 taken 7 times.
|
28 | if (firstsecond == first || firstsecond == both) { |
170 |
1/2✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
|
21 | const Scalar xh_jac = e_jac_ * std::max(Scalar(1.), x.norm()); |
171 |
2/4✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 21 times.
|
42 | if (static_cast<std::size_t>(Jfirst.rows()) != ndx_ || |
172 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 21 times.
|
21 | static_cast<std::size_t>(Jfirst.cols()) != ndx_) { |
173 | ✗ | throw_pretty( | |
174 | "Invalid argument: " << "Jfirst has wrong dimension (it should be " + | ||
175 | std::to_string(ndx_) + "," + | ||
176 | std::to_string(ndx_) + ")"); | ||
177 | } | ||
178 |
1/2✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
|
21 | Jfirst.setZero(); |
179 |
2/2✓ Branch 0 taken 888 times.
✓ Branch 1 taken 21 times.
|
909 | for (std::size_t i = 0; i < ndx_; ++i) { |
180 |
1/2✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
|
888 | dx_(i) = xh_jac; |
181 |
3/6✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 888 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 888 times.
✗ Branch 8 not taken.
|
888 | integrate(x, dx_, tmp_x_); |
182 |
3/6✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 888 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 888 times.
✗ Branch 8 not taken.
|
888 | integrate(tmp_x_, dx, tmp_x_); |
183 |
5/10✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 888 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 888 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 888 times.
✗ Branch 11 not taken.
✓ Branch 13 taken 888 times.
✗ Branch 14 not taken.
|
888 | diff(x0_, tmp_x_, Jfirst.col(i)); |
184 |
1/2✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
|
888 | dx_(i) = Scalar(0.); |
185 | } | ||
186 |
1/2✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
|
21 | Jfirst /= xh_jac; |
187 | } | ||
188 |
4/4✓ Branch 0 taken 21 times.
✓ Branch 1 taken 7 times.
✓ Branch 2 taken 14 times.
✓ Branch 3 taken 7 times.
|
28 | if (firstsecond == second || firstsecond == both) { |
189 |
1/2✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
|
21 | const Scalar dxh_jac = e_jac_ * std::max(Scalar(1.), dx.norm()); |
190 |
2/4✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 21 times.
|
42 | if (static_cast<std::size_t>(Jsecond.rows()) != ndx_ || |
191 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 21 times.
|
21 | static_cast<std::size_t>(Jsecond.cols()) != ndx_) { |
192 | ✗ | throw_pretty( | |
193 | "Invalid argument: " << "Jsecond has wrong dimension (it should be " + | ||
194 | std::to_string(ndx_) + "," + | ||
195 | std::to_string(ndx_) + ")"); | ||
196 | } | ||
197 |
1/2✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
|
21 | Jsecond.setZero(); |
198 |
2/2✓ Branch 0 taken 888 times.
✓ Branch 1 taken 21 times.
|
909 | for (std::size_t i = 0; i < ndx_; ++i) { |
199 |
1/2✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
|
888 | dx_(i) = dxh_jac; |
200 |
4/8✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 888 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 888 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 888 times.
✗ Branch 11 not taken.
|
888 | integrate(x, dx + dx_, tmp_x_); |
201 |
5/10✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 888 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 888 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 888 times.
✗ Branch 11 not taken.
✓ Branch 13 taken 888 times.
✗ Branch 14 not taken.
|
888 | diff(x0_, tmp_x_, Jsecond.col(i)); |
202 |
1/2✓ Branch 1 taken 888 times.
✗ Branch 2 not taken.
|
888 | dx_(i) = Scalar(0.); |
203 | } | ||
204 |
1/2✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
|
21 | Jsecond /= dxh_jac; |
205 | } | ||
206 | 28 | } | |
207 | |||
208 | template <typename Scalar> | ||
209 | template <typename NewScalar> | ||
210 | ✗ | StateNumDiffTpl<NewScalar> StateNumDiffTpl<Scalar>::cast() const { | |
211 | typedef StateNumDiffTpl<NewScalar> ReturnType; | ||
212 | ✗ | ReturnType res(state_->template cast<NewScalar>()); | |
213 | ✗ | return res; | |
214 | } | ||
215 | |||
216 | template <typename Scalar> | ||
217 | ✗ | void StateNumDiffTpl<Scalar>::JintegrateTransport( | |
218 | const Eigen::Ref<const VectorXs>&, const Eigen::Ref<const VectorXs>&, | ||
219 | ✗ | Eigen::Ref<MatrixXs>, const Jcomponent) const {} | |
220 | |||
221 | template <typename Scalar> | ||
222 | ✗ | const Scalar StateNumDiffTpl<Scalar>::get_disturbance() const { | |
223 | ✗ | return e_jac_; | |
224 | } | ||
225 | |||
226 | template <typename Scalar> | ||
227 | ✗ | void StateNumDiffTpl<Scalar>::set_disturbance(Scalar disturbance) { | |
228 | ✗ | if (disturbance < Scalar(0.)) { | |
229 | ✗ | throw_pretty("Invalid argument: " << "Disturbance constant is positive"); | |
230 | } | ||
231 | ✗ | e_jac_ = disturbance; | |
232 | } | ||
233 | |||
234 | } // namespace crocoddyl | ||
235 |