Crocoddyl
 
Loading...
Searching...
No Matches
conversions.hpp
1
2// BSD 3-Clause License
3//
4// Copyright (C) 2024-2025, Heriot-Watt University
5// Copyright note valid unless otherwise stated in individual files.
6// All rights reserved.
8
9#ifndef CROCODDYL_UTILS_CONVERSIONS_HPP_
10#define CROCODDYL_UTILS_CONVERSIONS_HPP_
11
#include <memory>
#include <type_traits>
#include <vector>
13
14#ifdef CROCODDYL_WITH_CODEGEN
15#include <cppad/cg/support/cppadcg_eigen.hpp>
16#include <cppad/cppad.hpp>
17#endif
18
19#include "crocoddyl/core/mathbase.hpp"
20
21namespace crocoddyl {
22
23template <typename Scalar>
25 typedef typename std::conditional<std::is_floating_point<Scalar>::value,
26 Scalar, double>::type type;
27};
28
/**
 * @brief Cast a floating-point scalar to another floating-point type.
 *
 * This overload participates in overload resolution only when both
 * `NewScalar` and `Scalar` are floating-point types.
 *
 * @param x  Value to convert
 * @return `x` converted with `static_cast<NewScalar>`
 */
// Declared `inline` (not `static`): a header function with internal linkage
// that is called from inline functions or templates with external linkage
// refers to a different entity in every translation unit, which breaks the
// one-definition rule.
template <typename NewScalar, typename Scalar>
inline typename std::enable_if<std::is_floating_point<NewScalar>::value &&
                                   std::is_floating_point<Scalar>::value,
                               NewScalar>::type
scalar_cast(const Scalar& x) {
  return static_cast<NewScalar>(x);
}
37
/**
 * @brief Convert every element of a vector to a new scalar type.
 *
 * Each element is converted through its `cast<NewScalar>()` member template.
 * The converted objects are returned in a new vector, in the original order.
 *
 * @param in  Vector of objects templated on `Scalar`
 * @return Vector of the converted objects templated on `NewScalar`
 */
template <typename NewScalar, typename Scalar,
          template <typename> class ItemTpl>
std::vector<ItemTpl<NewScalar>> vector_cast(
    const std::vector<ItemTpl<Scalar>>& in) {
  typedef std::vector<ItemTpl<Scalar>> InputVector;
  std::vector<ItemTpl<NewScalar>> converted;
  // Reserve up front so the loop below never reallocates.
  converted.reserve(in.size());
  for (typename InputVector::const_iterator it = in.begin(); it != in.end();
       ++it) {
    converted.push_back(it->template cast<NewScalar>());
  }
  return converted;
}
49
/**
 * @brief Convert a vector of shared pointers to a new scalar type.
 *
 * For each element, `cast<NewScalar>()` is called on the pointee. The result
 * is passed through `std::static_pointer_cast` to obtain a
 * `std::shared_ptr<ItemTpl<NewScalar>>`, so `cast<NewScalar>()` is expected
 * to return a `std::shared_ptr`.
 *
 * Every element must be non-null: each one is dereferenced unconditionally.
 *
 * @param in  Vector of shared pointers to objects templated on `Scalar`
 * @return Vector of shared pointers to the converted objects, in order
 */
template <typename NewScalar, typename Scalar,
          template <typename> class ItemTpl>
std::vector<std::shared_ptr<ItemTpl<NewScalar>>> vector_cast(
    const std::vector<std::shared_ptr<ItemTpl<Scalar>>>& in) {
  typedef std::vector<std::shared_ptr<ItemTpl<Scalar>>> InputVector;
  const typename InputVector::size_type count = in.size();
  std::vector<std::shared_ptr<ItemTpl<NewScalar>>> converted;
  converted.reserve(count);
  for (typename InputVector::size_type i = 0; i < count; ++i) {
    converted.push_back(std::static_pointer_cast<ItemTpl<NewScalar>>(
        in[i]->template cast<NewScalar>()));
  }
  return converted;
}
62
63} // namespace crocoddyl
64
65#ifdef CROCODDYL_WITH_CODEGEN
66
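// Specializations of Eigen's internal cast_impl so that Eigen's scalar casting
// (e.g. .cast<NewScalar>()) can convert between CppADCodeGen AD scalars and
// float/double.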
68namespace Eigen {
69namespace internal {
70
71template <>
72struct cast_impl<CppAD::AD<CppAD::cg::CG<double>>, float> {
73 EIGEN_DEVICE_FUNC static inline float run(
74 const CppAD::AD<CppAD::cg::CG<double>>& x) {
75 // Perform the conversion. This example extracts the value from the AD type.
76 // You might need to adjust this depending on the specific implementation of
77 // CppAD::cg::CG<double>.
78 return static_cast<float>(CppAD::Value(x).getValue());
79 }
80};
81
82template <>
83struct cast_impl<CppAD::AD<CppAD::cg::CG<double>>, double> {
84 EIGEN_DEVICE_FUNC static inline double run(
85 const CppAD::AD<CppAD::cg::CG<double>>& x) {
86 return CppAD::Value(x).getValue();
87 }
88};
89
90template <>
91struct cast_impl<CppAD::AD<CppAD::cg::CG<float>>, float> {
92 EIGEN_DEVICE_FUNC static inline float run(
93 const CppAD::AD<CppAD::cg::CG<float>>& x) {
94 return CppAD::Value(x).getValue();
95 }
96};
97
98template <>
99struct cast_impl<CppAD::AD<CppAD::cg::CG<float>>, double> {
100 EIGEN_DEVICE_FUNC static inline double run(
101 const CppAD::AD<CppAD::cg::CG<float>>& x) {
102 // Perform the conversion. This example extracts the value from the AD type.
103 // You might need to adjust this depending on the specific implementation of
104 // CppAD::cg::CG<float>.
105 return static_cast<float>(CppAD::Value(x).getValue());
106 }
107};
108
109// Convert from CppAD::AD<CppAD::cg::CG<float>> to
110// CppAD::AD<CppAD::cg::CG<double>>
111template <>
112struct cast_impl<CppAD::AD<CppAD::cg::CG<float>>,
113 CppAD::AD<CppAD::cg::CG<double>>> {
114 EIGEN_DEVICE_FUNC static inline CppAD::AD<CppAD::cg::CG<double>> run(
115 const CppAD::AD<CppAD::cg::CG<float>>& x) {
116 return CppAD::AD<CppAD::cg::CG<double>>(
117 CppAD::cg::CG<double>(CppAD::Value(x).getValue()));
118 }
119};
120
121// Convert from CppAD::AD<CppAD::cg::CG<double>> to
122// CppAD::AD<CppAD::cg::CG<float>>
123template <>
124struct cast_impl<CppAD::AD<CppAD::cg::CG<double>>,
125 CppAD::AD<CppAD::cg::CG<float>>> {
126 EIGEN_DEVICE_FUNC static inline CppAD::AD<CppAD::cg::CG<float>> run(
127 const CppAD::AD<CppAD::cg::CG<double>>& x) {
128 return CppAD::AD<CppAD::cg::CG<float>>(
129 CppAD::cg::CG<float>(static_cast<float>(CppAD::Value(x).getValue())));
130 }
131};
132
133} // namespace internal
134} // namespace Eigen
135
136namespace crocoddyl {
137
138// Casting to CppAD types from floating-point types
139template <typename NewScalar, typename Scalar>
140static typename std::enable_if<
141 std::is_floating_point<Scalar>::value &&
142 (std::is_same<NewScalar, CppAD::AD<CppAD::cg::CG<double>>>::value ||
143 std::is_same<NewScalar, CppAD::AD<CppAD::cg::CG<float>>>::value),
144 NewScalar>::type
145scalar_cast(const Scalar& x) {
146 return static_cast<NewScalar>(x);
147}
148
149// Casting to floating-point types from CppAD types
150template <typename NewScalar, typename Scalar>
151static inline typename std::enable_if<std::is_floating_point<Scalar>::value,
152 NewScalar>::type
153scalar_cast(const CppAD::AD<CppAD::cg::CG<Scalar>>& x) {
154 return static_cast<NewScalar>(CppAD::Value(x).getValue());
155}
156
157} // namespace crocoddyl
158
159#endif
160
161#endif // CROCODDYL_UTILS_CONVERSIONS_HPP_