10 #ifndef NNLIB_FUNCTIONS_H
11 #define NNLIB_FUNCTIONS_H
22 std::vector<sTensor> parents;
24 virtual std::vector<sTensor> backward(sTensor grad) = 0;
29 template<
typename... Types>
34 sTensor forward(
const Types&... args) {
35 auto tup = getType<sTensor>(std::make_tuple(args...));
36 parents = toVector(tup);
38 sTensor result = forwardFn(args...);
39 for (
auto& parent : parents) {
40 if (parent->requiresGrad) {
41 result->requiresGrad =
true;
48 std::vector<sTensor> backward(sTensor grad)
override {
49 return backwardFn(std::move(grad));
52 virtual sTensor forwardFn(
const Types&... args) = 0;
54 virtual std::vector<sTensor> backwardFn(sTensor grad) = 0;
60 std::vector<size_t> shapeCache;
63 sTensor forwardFn(
const sTensor& a)
override;
65 std::vector<sTensor> backwardFn(sTensor grad)
override;
70 sTensor forwardFn(
const sTensor& a,
const sTensor& b)
override;
72 std::vector<sTensor> backwardFn(sTensor grad)
override;
77 sTensor forwardFn(
const sTensor& a,
const sTensor& b)
override;
79 std::vector<sTensor> backwardFn(sTensor grad)
override;
84 sTensor forwardFn(
const sTensor& a,
const sTensor& b)
override;
86 std::vector<sTensor> backwardFn(sTensor grad)
override;
94 sTensor forwardFn(
const sTensor& a,
const sTensor& b)
override;
96 std::vector<sTensor> backwardFn(sTensor grad)
override;
104 sTensor forwardFn(
const sTensor& a,
const sTensor& b)
override;
106 std::vector<sTensor> backwardFn(sTensor grad)
override;
113 sTensor forwardFn(
const sTensor& args)
override;
115 std::vector<sTensor> backwardFn(sTensor grad)
override;
122 sTensor forwardFn(
const sTensor& a,
const float& b)
override;
124 std::vector<sTensor> backwardFn(sTensor grad)
override;
132 sTensor forwardFn(
const sTensor& a,
const sTensor& b)
override;
134 std::vector<sTensor> backwardFn(sTensor grad)
override;
142 sTensor forwardFn(
const sTensor& a,
const sTensor& b)
override;
144 std::vector<sTensor> backwardFn(sTensor grad)
override;
146 ~
Matmul()
override =
default;
153 sTensor forwardFn(
const sTensor& a)
override;
155 std::vector<sTensor> backwardFn(sTensor grad)
override;
162 sTensor forwardFn(
const sTensor& a)
override;
164 std::vector<sTensor> backwardFn(sTensor grad)
override;
171 sTensor forwardFn(
const sTensor& a)
override;
173 std::vector<sTensor> backwardFn(sTensor grad)
override;
Definition: functions.h:75
Definition: functions.h:68
Definition: functions.h:20
Definition: functions.h:99
Definition: functions.h:30
Definition: functions.h:89
Definition: functions.h:109
Definition: functions.h:127
Definition: functions.h:137
Definition: functions.h:118
Definition: functions.h:158
Definition: functions.h:167
Definition: functions.h:82
Definition: functions.h:59
Definition: functions.h:149
Header file declaring the Tensor class to represent multidimensional arrays.