Line |
Branch |
Exec |
Source |
1 |
|
|
/////////////////////////////////////////////////////////////////////////////// |
2 |
|
|
// BSD 3-Clause License |
3 |
|
|
// |
4 |
|
|
// Copyright (C) 2019-2025, LAAS-CNRS, University of Edinburgh, |
5 |
|
|
// New York University, Max Planck Gesellschaft, |
6 |
|
|
// Heriot-Watt University |
7 |
|
|
// Copyright note valid unless otherwise stated in individual files. |
8 |
|
|
// All rights reserved. |
9 |
|
|
/////////////////////////////////////////////////////////////////////////////// |
10 |
|
|
|
11 |
|
|
namespace crocoddyl { |
12 |
|
|
|
13 |
|
|
template <typename Scalar> |
14 |
|
✗ |
StateNumDiffTpl<Scalar>::StateNumDiffTpl(std::shared_ptr<Base> state) |
15 |
|
|
: Base(state->get_nx(), state->get_ndx()), |
16 |
|
✗ |
state_(state), |
17 |
|
✗ |
e_jac_(sqrt(Scalar(2.0) * std::numeric_limits<Scalar>::epsilon())) {} |
18 |
|
|
|
19 |
|
|
template <typename Scalar> |
20 |
|
✗ |
StateNumDiffTpl<Scalar>::~StateNumDiffTpl() {} |
21 |
|
|
|
22 |
|
|
template <typename Scalar> |
23 |
|
✗ |
typename MathBaseTpl<Scalar>::VectorXs StateNumDiffTpl<Scalar>::zero() const { |
24 |
|
✗ |
return state_->zero(); |
25 |
|
|
} |
26 |
|
|
|
27 |
|
|
template <typename Scalar> |
28 |
|
✗ |
typename MathBaseTpl<Scalar>::VectorXs StateNumDiffTpl<Scalar>::rand() const { |
29 |
|
✗ |
return state_->rand(); |
30 |
|
|
} |
31 |
|
|
|
32 |
|
|
template <typename Scalar> |
33 |
|
✗ |
void StateNumDiffTpl<Scalar>::diff(const Eigen::Ref<const VectorXs>& x0, |
34 |
|
|
const Eigen::Ref<const VectorXs>& x1, |
35 |
|
|
Eigen::Ref<VectorXs> dxout) const { |
36 |
|
✗ |
if (static_cast<std::size_t>(x0.size()) != nx_) { |
37 |
|
✗ |
throw_pretty( |
38 |
|
|
"Invalid argument: " << "x0 has wrong dimension (it should be " + |
39 |
|
|
std::to_string(nx_) + ")"); |
40 |
|
|
} |
41 |
|
✗ |
if (static_cast<std::size_t>(x1.size()) != nx_) { |
42 |
|
✗ |
throw_pretty( |
43 |
|
|
"Invalid argument: " << "x1 has wrong dimension (it should be " + |
44 |
|
|
std::to_string(nx_) + ")"); |
45 |
|
|
} |
46 |
|
✗ |
if (static_cast<std::size_t>(dxout.size()) != ndx_) { |
47 |
|
✗ |
throw_pretty( |
48 |
|
|
"Invalid argument: " << "dxout has wrong dimension (it should be " + |
49 |
|
|
std::to_string(ndx_) + ")"); |
50 |
|
|
} |
51 |
|
✗ |
state_->diff(x0, x1, dxout); |
52 |
|
|
} |
53 |
|
|
|
54 |
|
|
template <typename Scalar> |
55 |
|
✗ |
void StateNumDiffTpl<Scalar>::integrate(const Eigen::Ref<const VectorXs>& x, |
56 |
|
|
const Eigen::Ref<const VectorXs>& dx, |
57 |
|
|
Eigen::Ref<VectorXs> xout) const { |
58 |
|
✗ |
if (static_cast<std::size_t>(x.size()) != nx_) { |
59 |
|
✗ |
throw_pretty( |
60 |
|
|
"Invalid argument: " << "x has wrong dimension (it should be " + |
61 |
|
|
std::to_string(nx_) + ")"); |
62 |
|
|
} |
63 |
|
✗ |
if (static_cast<std::size_t>(dx.size()) != ndx_) { |
64 |
|
✗ |
throw_pretty( |
65 |
|
|
"Invalid argument: " << "dx has wrong dimension (it should be " + |
66 |
|
|
std::to_string(ndx_) + ")"); |
67 |
|
|
} |
68 |
|
✗ |
if (static_cast<std::size_t>(xout.size()) != nx_) { |
69 |
|
✗ |
throw_pretty( |
70 |
|
|
"Invalid argument: " << "xout has wrong dimension (it should be " + |
71 |
|
|
std::to_string(nx_) + ")"); |
72 |
|
|
} |
73 |
|
✗ |
state_->integrate(x, dx, xout); |
74 |
|
|
} |
75 |
|
|
|
76 |
|
|
template <typename Scalar> |
77 |
|
✗ |
void StateNumDiffTpl<Scalar>::Jdiff(const Eigen::Ref<const VectorXs>& x0, |
78 |
|
|
const Eigen::Ref<const VectorXs>& x1, |
79 |
|
|
Eigen::Ref<MatrixXs> Jfirst, |
80 |
|
|
Eigen::Ref<MatrixXs> Jsecond, |
81 |
|
|
Jcomponent firstsecond) const { |
82 |
|
✗ |
assert_pretty( |
83 |
|
|
is_a_Jcomponent(firstsecond), |
84 |
|
|
("firstsecond must be one of the Jcomponent {both, first, second}")); |
85 |
|
✗ |
if (static_cast<std::size_t>(x0.size()) != nx_) { |
86 |
|
✗ |
throw_pretty( |
87 |
|
|
"Invalid argument: " << "x0 has wrong dimension (it should be " + |
88 |
|
|
std::to_string(nx_) + ")"); |
89 |
|
|
} |
90 |
|
✗ |
if (static_cast<std::size_t>(x1.size()) != nx_) { |
91 |
|
✗ |
throw_pretty( |
92 |
|
|
"Invalid argument: " << "x1 has wrong dimension (it should be " + |
93 |
|
|
std::to_string(nx_) + ")"); |
94 |
|
|
} |
95 |
|
✗ |
VectorXs tmp_x_ = VectorXs::Zero(nx_); |
96 |
|
✗ |
VectorXs dx_ = VectorXs::Zero(ndx_); |
97 |
|
✗ |
VectorXs dx0_ = VectorXs::Zero(ndx_); |
98 |
|
|
|
99 |
|
✗ |
dx_.setZero(); |
100 |
|
✗ |
diff(x0, x1, dx0_); |
101 |
|
✗ |
if (firstsecond == first || firstsecond == both) { |
102 |
|
✗ |
const Scalar x0h_jac = e_jac_ * std::max(Scalar(1.), x0.norm()); |
103 |
|
✗ |
if (static_cast<std::size_t>(Jfirst.rows()) != ndx_ || |
104 |
|
✗ |
static_cast<std::size_t>(Jfirst.cols()) != ndx_) { |
105 |
|
✗ |
throw_pretty( |
106 |
|
|
"Invalid argument: " << "Jfirst has wrong dimension (it should be " + |
107 |
|
|
std::to_string(ndx_) + "," + |
108 |
|
|
std::to_string(ndx_) + ")"); |
109 |
|
|
} |
110 |
|
✗ |
Jfirst.setZero(); |
111 |
|
✗ |
for (std::size_t i = 0; i < ndx_; ++i) { |
112 |
|
✗ |
dx_(i) = x0h_jac; |
113 |
|
✗ |
integrate(x0, dx_, tmp_x_); |
114 |
|
✗ |
diff(tmp_x_, x1, Jfirst.col(i)); |
115 |
|
✗ |
Jfirst.col(i) -= dx0_; |
116 |
|
✗ |
dx_(i) = Scalar(0.); |
117 |
|
|
} |
118 |
|
✗ |
Jfirst /= x0h_jac; |
119 |
|
|
} |
120 |
|
✗ |
if (firstsecond == second || firstsecond == both) { |
121 |
|
✗ |
const Scalar x1h_jac = e_jac_ * std::max(Scalar(1.), x1.norm()); |
122 |
|
✗ |
if (static_cast<std::size_t>(Jsecond.rows()) != ndx_ || |
123 |
|
✗ |
static_cast<std::size_t>(Jsecond.cols()) != ndx_) { |
124 |
|
✗ |
throw_pretty( |
125 |
|
|
"Invalid argument: " << "Jsecond has wrong dimension (it should be " + |
126 |
|
|
std::to_string(ndx_) + "," + |
127 |
|
|
std::to_string(ndx_) + ")"); |
128 |
|
|
} |
129 |
|
|
|
130 |
|
✗ |
Jsecond.setZero(); |
131 |
|
✗ |
for (std::size_t i = 0; i < ndx_; ++i) { |
132 |
|
✗ |
dx_(i) = x1h_jac; |
133 |
|
✗ |
integrate(x1, dx_, tmp_x_); |
134 |
|
✗ |
diff(x0, tmp_x_, Jsecond.col(i)); |
135 |
|
✗ |
Jsecond.col(i) -= dx0_; |
136 |
|
✗ |
dx_(i) = Scalar(0.); |
137 |
|
|
} |
138 |
|
✗ |
Jsecond /= x1h_jac; |
139 |
|
|
} |
140 |
|
|
} |
141 |
|
|
|
142 |
|
|
template <typename Scalar> |
143 |
|
✗ |
void StateNumDiffTpl<Scalar>::Jintegrate(const Eigen::Ref<const VectorXs>& x, |
144 |
|
|
const Eigen::Ref<const VectorXs>& dx, |
145 |
|
|
Eigen::Ref<MatrixXs> Jfirst, |
146 |
|
|
Eigen::Ref<MatrixXs> Jsecond, |
147 |
|
|
const Jcomponent firstsecond, |
148 |
|
|
const AssignmentOp) const { |
149 |
|
✗ |
assert_pretty( |
150 |
|
|
is_a_Jcomponent(firstsecond), |
151 |
|
|
("firstsecond must be one of the Jcomponent {both, first, second}")); |
152 |
|
✗ |
if (static_cast<std::size_t>(x.size()) != nx_) { |
153 |
|
✗ |
throw_pretty( |
154 |
|
|
"Invalid argument: " << "x has wrong dimension (it should be " + |
155 |
|
|
std::to_string(nx_) + ")"); |
156 |
|
|
} |
157 |
|
✗ |
if (static_cast<std::size_t>(dx.size()) != ndx_) { |
158 |
|
✗ |
throw_pretty( |
159 |
|
|
"Invalid argument: " << "dx has wrong dimension (it should be " + |
160 |
|
|
std::to_string(ndx_) + ")"); |
161 |
|
|
} |
162 |
|
✗ |
VectorXs tmp_x_ = VectorXs::Zero(nx_); |
163 |
|
✗ |
VectorXs dx_ = VectorXs::Zero(ndx_); |
164 |
|
✗ |
VectorXs x0_ = VectorXs::Zero(nx_); |
165 |
|
|
|
166 |
|
|
// x0_ = integrate(x, dx) |
167 |
|
✗ |
integrate(x, dx, x0_); |
168 |
|
|
|
169 |
|
✗ |
if (firstsecond == first || firstsecond == both) { |
170 |
|
✗ |
const Scalar xh_jac = e_jac_ * std::max(Scalar(1.), x.norm()); |
171 |
|
✗ |
if (static_cast<std::size_t>(Jfirst.rows()) != ndx_ || |
172 |
|
✗ |
static_cast<std::size_t>(Jfirst.cols()) != ndx_) { |
173 |
|
✗ |
throw_pretty( |
174 |
|
|
"Invalid argument: " << "Jfirst has wrong dimension (it should be " + |
175 |
|
|
std::to_string(ndx_) + "," + |
176 |
|
|
std::to_string(ndx_) + ")"); |
177 |
|
|
} |
178 |
|
✗ |
Jfirst.setZero(); |
179 |
|
✗ |
for (std::size_t i = 0; i < ndx_; ++i) { |
180 |
|
✗ |
dx_(i) = xh_jac; |
181 |
|
✗ |
integrate(x, dx_, tmp_x_); |
182 |
|
✗ |
integrate(tmp_x_, dx, tmp_x_); |
183 |
|
✗ |
diff(x0_, tmp_x_, Jfirst.col(i)); |
184 |
|
✗ |
dx_(i) = Scalar(0.); |
185 |
|
|
} |
186 |
|
✗ |
Jfirst /= xh_jac; |
187 |
|
|
} |
188 |
|
✗ |
if (firstsecond == second || firstsecond == both) { |
189 |
|
✗ |
const Scalar dxh_jac = e_jac_ * std::max(Scalar(1.), dx.norm()); |
190 |
|
✗ |
if (static_cast<std::size_t>(Jsecond.rows()) != ndx_ || |
191 |
|
✗ |
static_cast<std::size_t>(Jsecond.cols()) != ndx_) { |
192 |
|
✗ |
throw_pretty( |
193 |
|
|
"Invalid argument: " << "Jsecond has wrong dimension (it should be " + |
194 |
|
|
std::to_string(ndx_) + "," + |
195 |
|
|
std::to_string(ndx_) + ")"); |
196 |
|
|
} |
197 |
|
✗ |
Jsecond.setZero(); |
198 |
|
✗ |
for (std::size_t i = 0; i < ndx_; ++i) { |
199 |
|
✗ |
dx_(i) = dxh_jac; |
200 |
|
✗ |
integrate(x, dx + dx_, tmp_x_); |
201 |
|
✗ |
diff(x0_, tmp_x_, Jsecond.col(i)); |
202 |
|
✗ |
dx_(i) = Scalar(0.); |
203 |
|
|
} |
204 |
|
✗ |
Jsecond /= dxh_jac; |
205 |
|
|
} |
206 |
|
|
} |
207 |
|
|
|
208 |
|
|
template <typename Scalar> |
209 |
|
|
template <typename NewScalar> |
210 |
|
✗ |
StateNumDiffTpl<NewScalar> StateNumDiffTpl<Scalar>::cast() const { |
211 |
|
|
typedef StateNumDiffTpl<NewScalar> ReturnType; |
212 |
|
✗ |
ReturnType res(state_->template cast<NewScalar>()); |
213 |
|
✗ |
return res; |
214 |
|
|
} |
215 |
|
|
|
216 |
|
|
template <typename Scalar> |
217 |
|
✗ |
void StateNumDiffTpl<Scalar>::JintegrateTransport( |
218 |
|
|
const Eigen::Ref<const VectorXs>&, const Eigen::Ref<const VectorXs>&, |
219 |
|
✗ |
Eigen::Ref<MatrixXs>, const Jcomponent) const {} |
220 |
|
|
|
221 |
|
|
template <typename Scalar> |
222 |
|
✗ |
const Scalar StateNumDiffTpl<Scalar>::get_disturbance() const { |
223 |
|
✗ |
return e_jac_; |
224 |
|
|
} |
225 |
|
|
|
226 |
|
|
template <typename Scalar> |
227 |
|
✗ |
void StateNumDiffTpl<Scalar>::set_disturbance(Scalar disturbance) { |
228 |
|
✗ |
if (disturbance < Scalar(0.)) { |
229 |
|
✗ |
throw_pretty("Invalid argument: " << "Disturbance constant is positive"); |
230 |
|
|
} |
231 |
|
✗ |
e_jac_ = disturbance; |
232 |
|
|
} |
233 |
|
|
|
234 |
|
|
template <typename Scalar> |
235 |
|
✗ |
void StateNumDiffTpl<Scalar>::print(std::ostream& os) const { |
236 |
|
✗ |
os << "StateNumDiff {state=" << *state_ << "}"; |
237 |
|
|
} |
238 |
|
|
|
239 |
|
|
} // namespace crocoddyl |
240 |
|
|
|