nnlib
GPU-accelerated C/C++ neural network library.
functions.h
1 
10 #ifndef NNLIB_FUNCTIONS_H
11 #define NNLIB_FUNCTIONS_H
12 
#include <algorithm>
#include <cstddef>
#include <tuple>
#include <typeinfo>
#include <utility>
#include <vector>

#include "tensor.h"
#include "tuple_utils.h"
18 
19 
21 public:
22  std::vector<sTensor> parents;
23 
24  virtual std::vector<sTensor> backward(sTensor grad) = 0;
25 
26  virtual ~BackwardFunction() = default;
27 };
28 
29 template<typename... Types>
30 class Function : public BackwardFunction {
31 public:
32  Function() = default;
33 
34  sTensor forward(const Types&... args) {
35  auto tup = getType<sTensor>(std::make_tuple(args...));
36  parents = toVector(tup);
37 
38  sTensor result = forwardFn(args...);
39  for (auto& parent : parents) {
40  if (parent->requiresGrad) {
41  result->requiresGrad = true;
42  break;
43  }
44  }
45  return result;
46  }
47 
48  std::vector<sTensor> backward(sTensor grad) override {
49  return backwardFn(std::move(grad));
50  }
51 
52  virtual sTensor forwardFn(const Types&... args) = 0;
53 
54  virtual std::vector<sTensor> backwardFn(sTensor grad) = 0;
55 
56  ~Function() override = default;
57 };
58 
59 class SumReduce : public Function<sTensor> {
60  std::vector<size_t> shapeCache;
61 
62 public:
63  sTensor forwardFn(const sTensor& a) override;
64 
65  std::vector<sTensor> backwardFn(sTensor grad) override;
66 };
67 
68 class Add : public Function<sTensor, sTensor> {
69 public:
70  sTensor forwardFn(const sTensor& a, const sTensor& b) override;
71 
72  std::vector<sTensor> backwardFn(sTensor grad) override;
73 };
74 
75 class AddBroadcast : public Function<sTensor, sTensor> {
76 public:
77  sTensor forwardFn(const sTensor& a, const sTensor& b) override;
78 
79  std::vector<sTensor> backwardFn(sTensor grad) override;
80 };
81 
82 class Subtract : public Function<sTensor, sTensor> {
83 public:
84  sTensor forwardFn(const sTensor& a, const sTensor& b) override;
85 
86  std::vector<sTensor> backwardFn(sTensor grad) override;
87 };
88 
89 class Hadamard : public Function<sTensor, sTensor> {
90  sTensor cacheA;
91  sTensor cacheB;
92 
93 public:
94  sTensor forwardFn(const sTensor& a, const sTensor& b) override;
95 
96  std::vector<sTensor> backwardFn(sTensor grad) override;
97 };
98 
99 class Divide : public Function<sTensor, sTensor> {
100  sTensor cacheA;
101  sTensor cacheB;
102 
103 public:
104  sTensor forwardFn(const sTensor& a, const sTensor& b) override;
105 
106  std::vector<sTensor> backwardFn(sTensor grad) override;
107 };
108 
109 class Log : public Function<sTensor> {
110  sTensor cacheA;
111 
112 public:
113  sTensor forwardFn(const sTensor& args) override;
114 
115  std::vector<sTensor> backwardFn(sTensor grad) override;
116 };
117 
118 class MulConstant : public Function<sTensor, float> {
119  float constantCache;
120 
121 public:
122  sTensor forwardFn(const sTensor& a, const float& b) override;
123 
124  std::vector<sTensor> backwardFn(sTensor grad) override;
125 };
126 
127 class MatVecMul : public Function<sTensor, sTensor> {
128  sTensor cacheA;
129  sTensor cacheB;
130 
131 public:
132  sTensor forwardFn(const sTensor& a, const sTensor& b) override;
133 
134  std::vector<sTensor> backwardFn(sTensor grad) override;
135 };
136 
137 class Matmul : public Function<sTensor, sTensor> {
138  sTensor cacheA;
139  sTensor cacheB;
140 
141 public:
142  sTensor forwardFn(const sTensor& a, const sTensor& b) override;
143 
144  std::vector<sTensor> backwardFn(sTensor grad) override;
145 
146  ~Matmul() override = default;
147 };
148 
149 class Transpose : public Function<sTensor> {
150  sTensor cacheA;
151 
152 public:
153  sTensor forwardFn(const sTensor& a) override;
154 
155  std::vector<sTensor> backwardFn(sTensor grad) override;
156 };
157 
158 class ReLU : public Function<sTensor> {
159  sTensor cacheA;
160 
161 public:
162  sTensor forwardFn(const sTensor& a) override;
163 
164  std::vector<sTensor> backwardFn(sTensor grad) override;
165 };
166 
167 class Sigmoid : public Function<sTensor> {
168  sTensor cacheA;
169 
170 public:
171  sTensor forwardFn(const sTensor& a) override;
172 
173  std::vector<sTensor> backwardFn(sTensor grad) override;
174 };
175 
176 #endif //NNLIB_FUNCTIONS_H
AddBroadcast — Definition: functions.h:75
Add — Definition: functions.h:68
BackwardFunction — Definition: functions.h:20
Divide — Definition: functions.h:99
Function — Definition: functions.h:30
Hadamard — Definition: functions.h:89
Log — Definition: functions.h:109
MatVecMul — Definition: functions.h:127
Matmul — Definition: functions.h:137
MulConstant — Definition: functions.h:118
ReLU — Definition: functions.h:158
Sigmoid — Definition: functions.h:167
Subtract — Definition: functions.h:82
SumReduce — Definition: functions.h:59
Transpose — Definition: functions.h:149
Header file declaring the Tensor class to represent multidimensional arrays.