// Overloads operator new so heap-allocated instances satisfy Eigen's
// alignment requirements for fixed-size vectorizable members.
21 EIGEN_MAKE_ALIGNED_OPERATOR_NEW
// Scalar type of the model (template parameter _Scalar).
24 typedef _Scalar Scalar;
// Dynamic-size vector/matrix types taken from the MathBase traits.
29 typedef typename MathBase::VectorXs VectorXs;
30 typedef typename MathBase::MatrixXs MatrixXs;
// Constructor initializer list (signature not visible in this excerpt).
// Base is constructed from weights.size(), presumably the residual
// dimension nr. The weights are copied, and new_weights_ starts false.
// NOTE(review): the meaning of new_weights_ is not shown here; it presumably
// tracks whether set_weights() has been called. Confirm in the full class.
33 :
Base(weights.size()), weights_(weights), new_weights_(
false) {};
// Evaluates the weighted quadratic activation a(r) = 0.5 * r^T diag(w) r,
// i.e. 0.5 * sum_i w_i * r_i^2, where w = weights_.
// @param data  Activation data. It is downcast to Data to cache Wr.
// @param r     Residual vector. Its size must equal nr_.
// NOTE(review): this excerpt appears to be missing the line that opens the
// error-raising call (the message below has no visible callee; the embedded
// numbering skips from 38 to 40). It also lacks the block's closing braces.
// Confirm against the full file.
36 virtual void calc(
const std::shared_ptr<ActivationDataAbstract>& data,
37 const Eigen::Ref<const VectorXs>& r)
override {
// Reject residuals whose dimension does not match the model.
38 if (
static_cast<std::size_t
>(r.size()) != nr_) {
40 "Invalid argument: " <<
"r has wrong dimension (it should be " +
41 std::to_string(nr_) +
")");
43 std::shared_ptr<Data> d = std::static_pointer_cast<Data>(data);
// Cache the element-wise product w .* r. Its dot product with r gives the cost.
45 d->Wr = weights_.cwiseProduct(r);
46 data->a_value = Scalar(0.5) * r.dot(d->Wr);
// Computes derivatives of the weighted quadratic activation.
// The visible lines write the Hessian diagonal Arr = diag(weights_).
// @param data  Activation data. It is downcast to Data.
// @param r     Residual vector. Its size must equal nr_.
// NOTE(review): the lines that use `d` (presumably the gradient Ar = Wr) are
// not visible in this excerpt. The callee of the error message and the
// closing braces are also missing. Confirm against the full file.
49 virtual void calcDiff(
const std::shared_ptr<ActivationDataAbstract>& data,
50 const Eigen::Ref<const VectorXs>& r)
override {
// Reject residuals whose dimension does not match the model.
51 if (
static_cast<std::size_t
>(r.size()) != nr_) {
53 "Invalid argument: " <<
"r has wrong dimension (it should be " +
54 std::to_string(nr_) +
")");
57 std::shared_ptr<Data> d = std::static_pointer_cast<Data>(data);
// The Hessian of 0.5 * r^T diag(w) r is diag(w), independent of r.
60 data->Arr.diagonal() = weights_;
// Consistency check against Arr_, which is declared outside this excerpt.
// NOTE(review): Arr_ is presumably a debug-only cached diag(weights_).
// Confirm it is refreshed whenever the weights change.
65 assert_pretty(MatrixXs(data->Arr).isApprox(Arr_),
"Arr has wrong value");
// Creates the activation data for this model.
// Data is allocated with Eigen's aligned allocator, passing `this`, so that
// any fixed-size Eigen members are correctly aligned. The Hessian diagonal
// is pre-filled with the current weights.
// NOTE(review): the return statement and closing brace are not visible in
// this excerpt.
69 virtual std::shared_ptr<ActivationDataAbstract> createData()
override {
70 std::shared_ptr<Data> data =
71 std::allocate_shared<Data>(Eigen::aligned_allocator<Data>(),
this);
72 data->Arr.diagonal() = weights_;
// Scalar-type conversion fragment. It builds a ReturnType model whose weights
// are weights_ cast to NewScalar.
// NOTE(review): the function signature, the ReturnType alias and the return
// are not visible in this excerpt.
81 template <
typename NewScalar>
84 ReturnType res(weights_.template cast<NewScalar>());
// Returns a const reference to the weight vector (no copy).
88 const VectorXs& get_weights()
const {
return weights_; };
// Replaces the weight vector.
// @param weights  New weights. The size must match the current weights_.size().
// Raises via throw_pretty when the dimension differs.
// NOTE(review): the lines that store the new weights (and presumably set
// new_weights_) are not visible in this excerpt. Confirm they exist, and that
// any cached copy such as Arr_ is also updated.
89 void set_weights(
const VectorXs& weights) {
90 if (weights.size() != weights_.size()) {
91 throw_pretty(
"Invalid argument: "
92 <<
"weight vector has wrong dimension (it should be " +
93 std::to_string(weights_.size()) +
")");
// Writes a short description of the model, "ActivationModelQuad {nr=<nr_>}".
// @param os  Output stream.
// NOTE(review): this model holds weights_, yet the label does not mention
// weighting. Confirm whether it should distinguish itself from an unweighted
// quadratic model, e.g. "ActivationModelWeightedQuad".
105 virtual void print(std::ostream& os)
const override {
106 os <<
"ActivationModelQuad {nr=" << nr_ <<
"}";