JUK1
AutoDifferentiation.hpp
Go to the documentation of this file.
1 #ifndef _AUTODIFFERENTIATION_HPP_INC_
2 #define _AUTODIFFERENTIATION_HPP_INC_
#include <array>
#include <cmath>
#include <compare>
#include <concepts>
#include <functional>
#include <iostream>
#include <tuple>
#include <type_traits>
9 
18 
19 template<typename RT, typename LT, typename OT>
20 concept MultipliableResult = requires(LT lhs, OT rhs) {
21  std::is_same<RT, decltype(lhs *= rhs)>::value;
22 };
23 
24 template<typename RT, typename LT, typename OT>
25 concept AddableResult = requires(LT lhs, OT rhs) {
26  std::is_same<RT, decltype(lhs += rhs)>::value;
27 };
28 
29 template<typename RT, typename LT, typename OT>
30 concept SubtractableResult = requires(LT lhs, OT rhs) {
31  std::is_same<RT, decltype(lhs -= rhs)>::value;
32 };
33 
34 template<typename RT, typename LT, typename OT>
35 concept DivisableResult = requires(LT lhs, OT rhs) {
36  std::is_same<RT, decltype(lhs /= rhs)>::value;
37 };
38 
39 
40 template<typename ValType, size_t NumVars>
41 requires(std::is_arithmetic<ValType>::value && NumVars >= 0) struct DiffVar {
42  using ThisDiff = DiffVar<ValType, NumVars>;
43 
44  ValType var = 0;
45  std::array<ValType, NumVars> diffVars = {0};
46 
47  template<typename... VariadicType>
48  requires(... && std::is_convertible<VariadicType, ValType>::
49  value) constexpr explicit DiffVar(ValType _var,
50  VariadicType... _diffVars)
51  : var(_var), diffVars {{static_cast<ValType>(_diffVars)...}} {
52  }
53 
54  constexpr DiffVar(ValType _var, std::array<ValType, NumVars> _diffVars)
55  : var(_var), diffVars(_diffVars) {
56  }
57 
58  template<typename OtherValType>
59  requires(std::is_convertible<OtherValType, ValType>::
60  value) constexpr explicit DiffVar(const OtherValType & other) {
61  var = static_cast<ValType>(other);
62  }
63 
64  constexpr explicit operator ValType() const {
65  return var;
66  }
67 
68  constexpr ValType & operator[](size_t index) {
69  if (index == 0) {
70  return var;
71  } else {
72  return diffVars[index - 1];
73  }
74  }
75 
76  constexpr const ValType & operator[](size_t index) const {
77  if (index == 0) {
78  return var;
79  } else {
80  return diffVars[index - 1];
81  }
82  }
83 
84  template<typename OtherValType>
85  requires(std::is_convertible<OtherValType, ValType>::value) constexpr int
86  operator<=>(OtherValType & rhs) {
87  return var <=> static_cast<ValType>(rhs);
88  }
89 
90  constexpr ThisDiff operator+=(const ThisDiff & rhs) {
91  var += rhs.var;
92  auto it1 = diffVars.begin();
93  auto it2 = rhs.diffVars.begin();
94  while (it1 != diffVars.end() && it2 != rhs.diffVars.end()) {
95  *it1 += *it2;
96  ++it1;
97  ++it2;
98  }
99  return *this;
100  }
101 
102  template<typename OtherValType>
103  requires(std::is_convertible<OtherValType, ValType>::value) constexpr ThisDiff
104  operator+=(const OtherValType & rhs) {
105  var += static_cast<ValType>(rhs);
106  return *this;
107  }
108 
109  constexpr ThisDiff operator-=(const ThisDiff & rhs) {
110  var -= rhs.var;
111  auto it1 = diffVars.begin();
112  auto it2 = rhs.diffVars.begin();
113  while (it1 != diffVars.end() && it2 != rhs.diffVars.end()) {
114  *it1 -= *it2;
115  ++it1;
116  ++it2;
117  }
118  return *this;
119  }
120 
121  template<typename OtherValType>
122  requires(std::is_convertible<OtherValType, ValType>::value) constexpr ThisDiff
123  operator-=(const OtherValType & rhs) {
124  var -= static_cast<ValType>(rhs);
125  return *this;
126  }
127 
128  constexpr ThisDiff operator-() const {
129  auto toRet = *this;
130  toRet.var = -var;
131  auto it1 = toRet.diffVars.begin();
132  while (it1 != toRet.diffVars.end()) {
133  *it1 = -*it1;
134  ++it1;
135  }
136  return toRet;
137  }
138 
139 
140  constexpr ThisDiff operator*=(const ThisDiff & rhs) {
141  auto it1 = diffVars.begin();
142  auto it2 = rhs.diffVars.begin();
143  while (it1 != diffVars.end() && it2 != rhs.diffVars.end()) {
144  *it1 = (rhs.var * *it1 + var * *it2);
145  ++it1;
146  ++it2;
147  }
148  var *= rhs.var;
149  return *this;
150  }
151 
152  template<typename OtherValType>
153  requires(std::is_convertible<OtherValType, ValType>::value) constexpr ThisDiff
154  operator*=(const OtherValType & rhs) {
155  auto it1 = diffVars.begin();
156  while (it1 != diffVars.end()) {
157  *it1 = static_cast<ValType>(rhs) * *it1;
158  ++it1;
159  }
160  var *= static_cast<ValType>(rhs);
161  return *this;
162  }
163 
164  constexpr ThisDiff operator/=(const ThisDiff & rhs) {
165  auto it1 = diffVars.begin();
166  auto it2 = rhs.diffVars.begin();
167  while (it1 != diffVars.end() && it2 != rhs.diffVars.end()) {
168  *it1 = (rhs.var * *it1 - var * *it2) / (rhs.var * rhs.var);
169  ++it1;
170  ++it2;
171  }
172  var /= rhs.var;
173  return *this;
174  }
175 
176  template<typename OtherValType>
177  requires(std::is_convertible<OtherValType, ValType>::value) constexpr ThisDiff
178  operator/=(const OtherValType & rhs) {
179  auto it1 = diffVars.begin();
180  while (it1 != diffVars.end()) {
181  *it1 = *it1 / static_cast<ValType>(rhs);
182  ++it1;
183  }
184  var /= static_cast<ValType>(rhs);
185  return *this;
186  }
187 
188  // Friends
189  // Binary Operators
190  // Addition
191  template<typename LT, typename RT>
192  requires AddableResult<ThisDiff, LT, RT> constexpr friend auto
193  operator+(LT lhs, const RT & rhs) {
194  lhs += rhs;
195  return lhs;
196  }
197 
198  template<typename LT, typename RT>
199  requires AddableResult<ThisDiff, RT, LT> &&
200  (!AddableResult<ThisDiff, LT, RT>)constexpr friend auto
201  operator+(const LT & lhs, RT rhs) {
202  rhs += lhs;
203  return rhs;
204  }
205 
206  // Subtraction
207  template<typename LT, typename RT>
208  requires SubtractableResult<ThisDiff, LT, RT> constexpr friend auto
209  operator-(LT lhs, const RT & rhs) {
210  lhs -= rhs;
211  return lhs;
212  }
213 
214  template<typename LT, typename RT>
215  requires SubtractableResult<ThisDiff, RT, LT> &&
216  (!SubtractableResult<ThisDiff, LT, RT>)constexpr friend auto
217  operator-(const LT & lhs, RT rhs) {
218  rhs -= lhs;
219  return -rhs;
220  }
221 
222  template<typename LT, typename RT>
223  requires(!SubtractableResult<ThisDiff, LT, RT>) &&
224  AddableResult<ThisDiff, LT, RT> constexpr friend auto
225  operator-(LT lhs, const RT & rhs) {
226  lhs += -rhs;
227  return lhs;
228  }
229 
230  template<typename LT, typename RT>
231  requires(!SubtractableResult<ThisDiff, RT, LT>) &&
232  AddableResult<ThisDiff, RT, LT> &&
233  (!AddableResult<ThisDiff, LT, RT>)constexpr friend auto
234  operator-(const LT & lhs, RT rhs) {
235  rhs += -lhs;
236  return -rhs;
237  }
238 
239  // Multiplication
240  template<typename LT, typename RT>
241  requires MultipliableResult<ThisDiff, LT, RT> constexpr friend auto
242  operator*(LT lhs, const RT & rhs) {
243  lhs *= rhs;
244  return lhs;
245  }
246 
247  template<typename LT, typename RT>
248  requires MultipliableResult<ThisDiff, RT, LT> &&
249  (!MultipliableResult<ThisDiff, LT, RT>)constexpr friend auto
250  operator*(const LT & lhs, RT rhs) {
251  rhs *= lhs;
252  return rhs;
253  }
254 
255  // Division
256  template<typename OT>
257  requires(!DivisableResult<ThisDiff, OT, ThisDiff>) &&
258  (std::is_convertible<OT, ValType>::value) constexpr friend auto
259  operator/(const OT & lhs, const ThisDiff & rhs) {
260  auto lhsDiff = static_cast<ThisDiff>(lhs);
261  lhsDiff /= rhs;
262  return lhsDiff;
263  }
264 
265  template<typename OT>
266  requires(std::is_convertible<OT, ValType>::value) constexpr friend auto
267  operator/(ThisDiff lhs, const OT & rhs) {
268  lhs /= rhs;
269  return lhs;
270  }
271 
272  constexpr friend std::ostream &
273  operator<<(std::ostream & output, const ThisDiff & v) {
274  output << "Value: " << v.var;
275  size_t n = 0;
276  for (auto diffVar : v.diffVars) {
277  output << std::endl << "Derivative " << n++ << ": " << diffVar;
278  }
279 
280  return output;
281  }
282 };
283 
284 
285 template<typename ValType, size_t NumVars, typename F1, typename F2>
286 requires(std::is_arithmetic<ValType>::value && NumVars >=
287  0) constexpr auto diffFunc(const DiffVar<ValType, NumVars> & arg, F1 func,
288  F2 deriv) {
289  DiffVar<ValType, NumVars> toRet(func(arg.var));
290 
291  auto it1 = toRet.diffVars.begin();
292  auto it2 = arg.diffVars.begin();
293  auto derivEval = deriv(arg.var);
294  while (it1 != toRet.diffVars.end() && it2 != arg.diffVars.end()) {
295  *it1 = *it2 * derivEval;
296  ++it1;
297  ++it2;
298  }
299  return toRet;
300 }
301 
302 template<typename ValType, size_t NumVars>
303 requires(std::is_arithmetic<ValType>::value && NumVars >=
304  0) auto sin(const DiffVar<ValType, NumVars> & arg) {
305  return diffFunc(
306  arg, [](ValType v) { return std::sin(v); },
307  [](ValType v) { return std::cos(v); });
308 }
309 
310 template<typename ValType, size_t NumVars>
311 requires(std::is_arithmetic<ValType>::value && NumVars >=
312  0) auto cos(const DiffVar<ValType, NumVars> & arg) {
313  return diffFunc(
314  arg, [](ValType v) { return std::cos(v); },
315  [](ValType v) { return -std::sin(v); });
316 }
317 
318 template<typename ValType, size_t NumVars>
319 requires(std::is_arithmetic<ValType>::value && NumVars >=
320  0) auto tan(const DiffVar<ValType, NumVars> & arg) {
321  return diffFunc(
322  arg, [](ValType v) { return std::tan(v); },
323  [](ValType v) { return 1 / (std::cos(v) * std::cos(v)); });
324 }
325 
326 template<typename ValType, size_t NumVars>
327 requires(std::is_arithmetic<ValType>::value && NumVars >=
328  0) auto sinh(const DiffVar<ValType, NumVars> & arg) {
329  return diffFunc(
330  arg, [](ValType v) { return std::sinh(v); },
331  [](ValType v) { return std::cosh(v); });
332 }
333 
334 template<typename ValType, size_t NumVars>
335 requires(std::is_arithmetic<ValType>::value && NumVars >=
336  0) auto cosh(const DiffVar<ValType, NumVars> & arg) {
337  return diffFunc(
338  arg, [](ValType v) { return std::cosh(v); },
339  [](ValType v) { return std::sinh(v); });
340 }
341 
342 template<typename ValType, size_t NumVars>
343 requires(std::is_arithmetic<ValType>::value && NumVars >=
344  0) auto tanh(const DiffVar<ValType, NumVars> & arg) {
345  return diffFunc(
346  arg, [](ValType v) { return std::tanh(v); },
347  [](ValType v) { return 1 / (std::cosh(v) * std::cosh(v)); });
348 }
349 
350 template<typename ValType, size_t NumVars>
351 requires(std::is_arithmetic<ValType>::value && NumVars >=
352  0) auto exp(const DiffVar<ValType, NumVars> & arg) {
353  auto derivEval = std::exp(arg.var);
354  DiffVar<ValType, NumVars> toRet(derivEval);
355 
356  auto it1 = toRet.diffVars.begin();
357  auto it2 = arg.diffVars.begin();
358  while (it1 != toRet.diffVars.end() && it2 != arg.diffVars.end()) {
359  *it1 = *it2 * derivEval;
360  ++it1;
361  ++it2;
362  }
364  return toRet;
365 }
366 
367 template<typename ValType, size_t NumVars, typename OT>
368  requires(std::is_arithmetic<ValType>::value && NumVars >= 0) &&
369  (!std::is_same<OT, DiffVar<ValType, NumVars> >::value) auto pow(
370  const DiffVar<ValType, NumVars> & arg, OT exponent) {
371  return diffFunc(
372  arg, [&exponent](ValType v) { return std::pow(v, exponent); },
373  [&exponent](ValType v) { return exponent * std::pow(v, exponent - 1); });
374 }
375 
376 template<typename ValType, size_t NumVars>
377 requires(std::is_arithmetic<ValType>::value && NumVars >=
378  0) auto pow(const DiffVar<ValType, NumVars> & arg,
379  DiffVar<ValType, NumVars> exponent) {
380  auto func = [](ValType v, ValType exponent) { return std::pow(v, exponent); };
381 
382  auto deriv = [](ValType f, ValType fPrime, ValType g, ValType gPrime) {
383  return std::pow(f, g - 1) * (g * fPrime + f * std::log(f) * gPrime);
384  };
385 
386  DiffVar<ValType, NumVars> toRet(func(arg.var, exponent.var));
387 
388  auto it1 = toRet.diffVars.begin();
389  auto it2 = arg.diffVars.begin();
390  auto it3 = exponent.diffVars.begin();
391  while (it1 != toRet.diffVars.end() && it2 != arg.diffVars.end()) {
392  *it1 = deriv(arg.var, *it2, exponent.var, *it3);
393  ++it1;
394  ++it2;
395  ++it3;
396  }
397  return toRet;
398 }
399 
400 template<typename ValType, size_t NumVars>
401 requires(std::is_arithmetic<ValType>::value && NumVars >=
402  0) auto sqrt(const DiffVar<ValType, NumVars> & arg) {
403  auto funcResult = std::pow(arg.var, 0.5);
404  DiffVar<ValType, NumVars> toRet(funcResult);
405 
406  auto it1 = toRet.diffVars.begin();
407  auto it2 = arg.diffVars.begin();
408  auto derivEval = 0.5 / funcResult;
409  while (it1 != toRet.diffVars.end() && it2 != arg.diffVars.end()) {
410  *it1 = *it2 * derivEval;
411  ++it1;
412  ++it2;
413  }
414  return toRet;
415 }
416 
417 
418 } // namespace AutoDifferentiation
419 #endif
v
Definition: comp.m:4
A namespace that holds the implementation details of the auto-differentiator.
requires(std::is_arithmetic<ValType>::value && NumVars >= 0) struct DiffVar
NumVars DiffVar< ValType, NumVars > exponent
ang(col sin()