1 #ifndef _AUTODIFFERENTIATION_HPP_INC_
2 #define _AUTODIFFERENTIATION_HPP_INC_
19 template<
typename RT,
typename LT,
typename OT>
21 std::is_same<RT, decltype(lhs *= rhs)>::value;
24 template<
typename RT,
typename LT,
typename OT>
26 std::is_same<RT, decltype(lhs += rhs)>::value;
29 template<
typename RT,
typename LT,
typename OT>
31 std::is_same<RT, decltype(lhs -= rhs)>::value;
34 template<
typename RT,
typename LT,
typename OT>
36 std::is_same<RT, decltype(lhs /= rhs)>::value;
40 template<
typename ValType,
size_t NumVars>
41 requires(std::is_arithmetic<ValType>::value && NumVars >= 0) struct DiffVar {
42 using ThisDiff = DiffVar<ValType, NumVars>;
45 std::array<ValType, NumVars> diffVars = {0};
47 template<
typename... VariadicType>
48 requires(... && std::is_convertible<VariadicType, ValType>::
49 value) constexpr
explicit DiffVar(ValType _var,
50 VariadicType... _diffVars)
51 :
var(_var), diffVars {{
static_cast<ValType
>(_diffVars)...}} {
54 constexpr DiffVar(ValType _var, std::array<ValType, NumVars> _diffVars)
55 :
var(_var), diffVars(_diffVars) {
58 template<
typename OtherValType>
59 requires(std::is_convertible<OtherValType, ValType>::
60 value) constexpr
explicit DiffVar(
const OtherValType & other) {
61 var =
static_cast<ValType
>(other);
64 constexpr
explicit operator ValType()
const {
68 constexpr ValType & operator[](
size_t index) {
72 return diffVars[index - 1];
76 constexpr
const ValType & operator[](
size_t index)
const {
80 return diffVars[index - 1];
84 template<
typename OtherValType>
85 requires(std::is_convertible<OtherValType, ValType>::value) constexpr
int
86 operator<=>(OtherValType & rhs) {
87 return var <=>
static_cast<ValType
>(rhs);
90 constexpr ThisDiff operator+=(
const ThisDiff & rhs) {
92 auto it1 = diffVars.begin();
93 auto it2 = rhs.diffVars.begin();
94 while (
it1 != diffVars.end() &&
it2 != rhs.diffVars.end()) {
102 template<
typename OtherValType>
103 requires(std::is_convertible<OtherValType, ValType>::value) constexpr ThisDiff
104 operator+=(
const OtherValType & rhs) {
105 var +=
static_cast<ValType
>(rhs);
109 constexpr ThisDiff operator-=(
const ThisDiff & rhs) {
111 auto it1 = diffVars.begin();
112 auto it2 = rhs.diffVars.begin();
113 while (
it1 != diffVars.end() &&
it2 != rhs.diffVars.end()) {
121 template<
typename OtherValType>
122 requires(std::is_convertible<OtherValType, ValType>::value) constexpr ThisDiff
123 operator-=(
const OtherValType & rhs) {
124 var -=
static_cast<ValType
>(rhs);
128 constexpr ThisDiff operator-()
const {
132 while (
it1 !=
toRet.diffVars.end()) {
140 constexpr ThisDiff operator*=(
const ThisDiff & rhs) {
141 auto it1 = diffVars.begin();
142 auto it2 = rhs.diffVars.begin();
143 while (
it1 != diffVars.end() &&
it2 != rhs.diffVars.end()) {
152 template<
typename OtherValType>
153 requires(std::is_convertible<OtherValType, ValType>::value) constexpr ThisDiff
154 operator*=(
const OtherValType & rhs) {
155 auto it1 = diffVars.begin();
156 while (
it1 != diffVars.end()) {
157 *
it1 =
static_cast<ValType
>(rhs) * *
it1;
160 var *=
static_cast<ValType
>(rhs);
164 constexpr ThisDiff operator/=(
const ThisDiff & rhs) {
165 auto it1 = diffVars.begin();
166 auto it2 = rhs.diffVars.begin();
167 while (
it1 != diffVars.end() &&
it2 != rhs.diffVars.end()) {
168 *
it1 = (rhs.var * *
it1 -
var * *
it2) / (rhs.var * rhs.var);
176 template<
typename OtherValType>
177 requires(std::is_convertible<OtherValType, ValType>::value) constexpr ThisDiff
178 operator/=(
const OtherValType & rhs) {
179 auto it1 = diffVars.begin();
180 while (
it1 != diffVars.end()) {
181 *
it1 = *
it1 /
static_cast<ValType
>(rhs);
184 var /=
static_cast<ValType
>(rhs);
191 template<
typename LT,
typename RT>
192 requires AddableResult<ThisDiff, LT, RT> constexpr
friend auto
193 operator+(LT lhs,
const RT & rhs) {
198 template<
typename LT,
typename RT>
199 requires AddableResult<ThisDiff, RT, LT> &&
200 (!AddableResult<ThisDiff, LT, RT>)constexpr
friend auto
201 operator+(
const LT & lhs, RT rhs) {
207 template<
typename LT,
typename RT>
208 requires SubtractableResult<ThisDiff, LT, RT> constexpr
friend auto
209 operator-(LT lhs,
const RT & rhs) {
214 template<
typename LT,
typename RT>
215 requires SubtractableResult<ThisDiff, RT, LT> &&
216 (!SubtractableResult<ThisDiff, LT, RT>)constexpr
friend auto
217 operator-(
const LT & lhs, RT rhs) {
222 template<
typename LT,
typename RT>
223 requires(!SubtractableResult<ThisDiff, LT, RT>) &&
224 AddableResult<ThisDiff, LT, RT> constexpr
friend auto
225 operator-(LT lhs,
const RT & rhs) {
230 template<
typename LT,
typename RT>
231 requires(!SubtractableResult<ThisDiff, RT, LT>) &&
232 AddableResult<ThisDiff, RT, LT> &&
233 (!AddableResult<ThisDiff, LT, RT>)constexpr
friend auto
234 operator-(
const LT & lhs, RT rhs) {
240 template<
typename LT,
typename RT>
241 requires MultipliableResult<ThisDiff, LT, RT> constexpr
friend auto
242 operator*(LT lhs,
const RT & rhs) {
247 template<
typename LT,
typename RT>
248 requires MultipliableResult<ThisDiff, RT, LT> &&
249 (!MultipliableResult<ThisDiff, LT, RT>)constexpr
friend auto
250 operator*(
const LT & lhs, RT rhs) {
256 template<
typename OT>
257 requires(!DivisableResult<ThisDiff, OT, ThisDiff>) &&
258 (std::is_convertible<OT, ValType>::value) constexpr
friend auto
259 operator/(
const OT & lhs,
const ThisDiff & rhs) {
260 auto lhsDiff =
static_cast<ThisDiff
>(lhs);
265 template<
typename OT>
266 requires(std::is_convertible<OT, ValType>::value) constexpr
friend auto
267 operator/(ThisDiff lhs,
const OT & rhs) {
272 constexpr
friend std::ostream &
273 operator<<(std::ostream & output,
const ThisDiff &
v) {
274 output <<
"Value: " <<
v.var;
276 for (
auto diffVar :
v.diffVars) {
277 output << std::endl <<
"Derivative " << n++ <<
": " << diffVar;
285 template<
typename ValType,
size_t NumVars,
typename F1,
typename F2>
286 requires(std::is_arithmetic<ValType>::value && NumVars >=
287 0) constexpr auto diffFunc(const DiffVar<ValType, NumVars> &
arg, F1
func,
302 template<
typename ValType,
size_t NumVars>
303 requires(std::is_arithmetic<ValType>::value && NumVars >=
304 0) auto
sin(const DiffVar<ValType, NumVars> &
arg) {
307 [](ValType
v) {
return std::cos(
v); });
310 template<
typename ValType,
size_t NumVars>
311 requires(std::is_arithmetic<ValType>::value && NumVars >=
312 0) auto cos(const DiffVar<ValType, NumVars> &
arg) {
314 arg, [](ValType
v) {
return std::cos(
v); },
318 template<
typename ValType,
size_t NumVars>
319 requires(std::is_arithmetic<ValType>::value && NumVars >=
320 0) auto tan(const DiffVar<ValType, NumVars> &
arg) {
322 arg, [](ValType
v) {
return std::tan(
v); },
323 [](ValType
v) {
return 1 / (std::cos(
v) * std::cos(
v)); });
326 template<
typename ValType,
size_t NumVars>
327 requires(std::is_arithmetic<ValType>::value && NumVars >=
328 0) auto sinh(const DiffVar<ValType, NumVars> &
arg) {
330 arg, [](ValType
v) {
return std::sinh(
v); },
331 [](ValType
v) {
return std::cosh(
v); });
334 template<
typename ValType,
size_t NumVars>
335 requires(std::is_arithmetic<ValType>::value && NumVars >=
336 0) auto cosh(const DiffVar<ValType, NumVars> &
arg) {
338 arg, [](ValType
v) {
return std::cosh(
v); },
339 [](ValType
v) {
return std::sinh(
v); });
342 template<
typename ValType,
size_t NumVars>
343 requires(std::is_arithmetic<ValType>::value && NumVars >=
344 0) auto tanh(const DiffVar<ValType, NumVars> &
arg) {
346 arg, [](ValType
v) {
return std::tanh(
v); },
347 [](ValType
v) {
return 1 / (std::cosh(
v) * std::cosh(
v)); });
350 template<
typename ValType,
size_t NumVars>
351 requires(std::is_arithmetic<ValType>::value && NumVars >=
352 0) auto exp(const DiffVar<ValType, NumVars> &
arg) {
357 auto it2 =
arg.diffVars.begin();
358 while (
it1 !=
toRet.diffVars.end() &&
it2 !=
arg.diffVars.end()) {
367 template<
typename ValType,
size_t NumVars,
typename OT>
368 requires(std::is_arithmetic<ValType>::value && NumVars >= 0) &&
369 (!std::is_same<OT, DiffVar<ValType, NumVars> >::value)
auto pow(
370 const DiffVar<ValType, NumVars> &
arg, OT
exponent) {
376 template<
typename ValType,
size_t NumVars>
377 requires(std::is_arithmetic<ValType>::value && NumVars >=
378 0) auto pow(const DiffVar<ValType, NumVars> &
arg,
382 auto deriv = [](ValType f, ValType fPrime, ValType g, ValType gPrime) {
383 return std::pow(f, g - 1) * (g * fPrime + f * std::log(f) * gPrime);
389 auto it2 =
arg.diffVars.begin();
391 while (
it1 !=
toRet.diffVars.end() &&
it2 !=
arg.diffVars.end()) {
400 template<
typename ValType,
size_t NumVars>
401 requires(std::is_arithmetic<ValType>::value && NumVars >=
402 0) auto sqrt(const DiffVar<ValType, NumVars> &
arg) {
403 auto funcResult = std::pow(
arg.var, 0.5);
404 DiffVar<ValType, NumVars>
toRet(funcResult);
407 auto it2 =
arg.diffVars.begin();
409 while (
it1 !=
toRet.diffVars.end() &&
it2 !=
arg.diffVars.end()) {
a namespace to hold the messiness of my auto-differentiator
concept SubtractableResult
requires(std::is_arithmetic< ValType >::value &&NumVars >=0) struct DiffVar
NumVars DiffVar< ValType, NumVars > exponent
concept MultipliableResult