Crocoddyl
conversions.hpp
1 // BSD 3-Clause License
3 //
4 // Copyright (C) 2024-2025, Heriot-Watt University
5 // Copyright note valid unless otherwise stated in individual files.
6 // All rights reserved.
8 
9 #ifndef CROCODDYL_UTILS_CONVERSIONS_HPP_
10 #define CROCODDYL_UTILS_CONVERSIONS_HPP_
11 
12 #include <vector>
13 
14 #ifdef CROCODDYL_WITH_CODEGEN
15 #include <cppad/cg/support/cppadcg_eigen.hpp>
16 #include <cppad/cppad.hpp>
17 #endif
18 
19 #include "crocoddyl/core/mathbase.hpp"
20 
21 namespace crocoddyl {
22 
23 template <typename Scalar>
25  typedef typename std::conditional<std::is_floating_point<Scalar>::value,
26  Scalar, double>::type type;
27 };
28 
// Convert a floating-point value into another floating-point type.
//
// This overload is only enabled when both `Scalar` (deduced from the
// argument) and `NewScalar` (which has to be given explicitly, as it only
// appears in the return type) are floating-point types.
//
// Returns `x` converted to `NewScalar` through a `static_cast`.
template <typename NewScalar, typename Scalar>
static typename std::enable_if<std::is_floating_point<Scalar>::value &&
                                   std::is_floating_point<NewScalar>::value,
                               NewScalar>::type
scalar_cast(const Scalar& x) {
  const NewScalar converted = static_cast<NewScalar>(x);
  return converted;
}
37 
// Cast every item of a vector to a new scalar type.
//
// `ItemTpl` is a class template parametrized by the scalar type; each item
// has to provide a member template `cast<NewScalar>()` whose result can be
// stored as `ItemTpl<NewScalar>`.
//
// Returns a new vector with one casted item per input item, in the same
// order as `in`. The input vector is left untouched.
template <typename NewScalar, typename Scalar,
          template <typename> class ItemTpl>
std::vector<ItemTpl<NewScalar>> vector_cast(
    const std::vector<ItemTpl<Scalar>>& in) {
  typedef std::vector<ItemTpl<NewScalar>> CastedVector;
  CastedVector casted;
  // The final size is known upfront, so allocate only once.
  casted.reserve(in.size());
  for (auto it = in.begin(); it != in.end(); ++it) {
    casted.push_back(it->template cast<NewScalar>());
  }
  return casted;
}
49 
// Cast every item of a vector of shared pointers to a new scalar type.
//
// Each pointee has to provide a member template `cast<NewScalar>()` returning
// a shared pointer, which is then statically down-casted to
// `std::shared_ptr<ItemTpl<NewScalar>>`.
//
// Every entry of `in` is dereferenced, so all pointers must be non-null.
//
// Returns a new vector with one casted pointer per input entry, in the same
// order as `in`.
template <typename NewScalar, typename Scalar,
          template <typename> class ItemTpl>
std::vector<std::shared_ptr<ItemTpl<NewScalar>>> vector_cast(
    const std::vector<std::shared_ptr<ItemTpl<Scalar>>>& in) {
  typedef std::shared_ptr<ItemTpl<NewScalar>> CastedPtr;
  std::vector<CastedPtr> casted;
  // The final size is known upfront, so allocate only once.
  casted.reserve(in.size());
  for (auto it = in.begin(); it != in.end(); ++it) {
    casted.push_back(std::static_pointer_cast<ItemTpl<NewScalar>>(
        (*it)->template cast<NewScalar>()));
  }
  return casted;
}
62 
63 } // namespace crocoddyl
64 
65 #ifdef CROCODDYL_WITH_CODEGEN
66 
67 // Specialize Eigen's internal cast_impl for the CppADCodeGen scalar types
68 namespace Eigen {
69 namespace internal {
70 
71 template <>
72 struct cast_impl<CppAD::AD<CppAD::cg::CG<double>>, float> {
73  EIGEN_DEVICE_FUNC static inline float run(
74  const CppAD::AD<CppAD::cg::CG<double>>& x) {
75  // Perform the conversion. This example extracts the value from the AD type.
76  // You might need to adjust this depending on the specific implementation of
77  // CppAD::cg::CG<double>.
78  return static_cast<float>(CppAD::Value(x).getValue());
79  }
80 };
81 
82 template <>
83 struct cast_impl<CppAD::AD<CppAD::cg::CG<double>>, double> {
84  EIGEN_DEVICE_FUNC static inline double run(
85  const CppAD::AD<CppAD::cg::CG<double>>& x) {
86  return CppAD::Value(x).getValue();
87  }
88 };
89 
90 template <>
91 struct cast_impl<CppAD::AD<CppAD::cg::CG<float>>, float> {
92  EIGEN_DEVICE_FUNC static inline float run(
93  const CppAD::AD<CppAD::cg::CG<float>>& x) {
94  return CppAD::Value(x).getValue();
95  }
96 };
97 
98 template <>
99 struct cast_impl<CppAD::AD<CppAD::cg::CG<float>>, double> {
100  EIGEN_DEVICE_FUNC static inline double run(
101  const CppAD::AD<CppAD::cg::CG<float>>& x) {
102  // Perform the conversion. This example extracts the value from the AD type.
103  // You might need to adjust this depending on the specific implementation of
104  // CppAD::cg::CG<float>.
105  return static_cast<float>(CppAD::Value(x).getValue());
106  }
107 };
108 
109 // Convert from CppAD::AD<CppAD::cg::CG<float>> to
110 // CppAD::AD<CppAD::cg::CG<double>>
111 template <>
112 struct cast_impl<CppAD::AD<CppAD::cg::CG<float>>,
113  CppAD::AD<CppAD::cg::CG<double>>> {
114  EIGEN_DEVICE_FUNC static inline CppAD::AD<CppAD::cg::CG<double>> run(
115  const CppAD::AD<CppAD::cg::CG<float>>& x) {
116  return CppAD::AD<CppAD::cg::CG<double>>(
117  CppAD::cg::CG<double>(CppAD::Value(x).getValue()));
118  }
119 };
120 
121 // Convert from CppAD::AD<CppAD::cg::CG<double>> to
122 // CppAD::AD<CppAD::cg::CG<float>>
123 template <>
124 struct cast_impl<CppAD::AD<CppAD::cg::CG<double>>,
125  CppAD::AD<CppAD::cg::CG<float>>> {
126  EIGEN_DEVICE_FUNC static inline CppAD::AD<CppAD::cg::CG<float>> run(
127  const CppAD::AD<CppAD::cg::CG<double>>& x) {
128  return CppAD::AD<CppAD::cg::CG<float>>(
129  CppAD::cg::CG<float>(static_cast<float>(CppAD::Value(x).getValue())));
130  }
131 };
132 
133 } // namespace internal
134 } // namespace Eigen
135 
136 namespace crocoddyl {
137 
138 // Casting to CppAD types from floating-point types
139 template <typename NewScalar, typename Scalar>
140 static typename std::enable_if<
141  std::is_floating_point<Scalar>::value &&
142  (std::is_same<NewScalar, CppAD::AD<CppAD::cg::CG<double>>>::value ||
143  std::is_same<NewScalar, CppAD::AD<CppAD::cg::CG<float>>>::value),
144  NewScalar>::type
145 scalar_cast(const Scalar& x) {
146  return static_cast<NewScalar>(x);
147 }
148 
149 // Casting to floating-point types from CppAD types
150 template <typename NewScalar, typename Scalar>
151 static inline typename std::enable_if<std::is_floating_point<Scalar>::value,
152  NewScalar>::type
153 scalar_cast(const CppAD::AD<CppAD::cg::CG<Scalar>>& x) {
154  return static_cast<NewScalar>(CppAD::Value(x).getValue());
155 }
156 
157 } // namespace crocoddyl
158 
159 #endif
160 
161 #endif // CROCODDYL_UTILS_CONVERSIONS_HPP_