nnlib
GPU-accelerated, C/C++ neural network library.
metric.h
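#ifndef NNLIB_METRIC_H
#define NNLIB_METRIC_H

#include <cstddef>
#include <string>

#include "tensor.h"

/// An abstract class to represent metrics.
class Metric {
protected:
    /// The number of samples processed so far.
    size_t numSamples;

    /// The current total value of the metric.
    float currentTotalMetric;

public:
    /// Constructor for the Metric class.
    Metric();

    /// Virtual destructor, so that subclasses can be deleted through a Metric pointer.
    virtual ~Metric() = default;

    /// Reset the metric, i.e. set numSamples and currentTotalMetric to 0.
    void reset();

    /// Calculate the current value of the metric given the new batches of targets and predictions.
    virtual float calculateMetric(const sTensor& targets, const sTensor& predictions) = 0;

    /// Short string identifier of the metric.
    [[nodiscard]] virtual std::string getShortName() const = 0;
};

/// The implementation of categorical accuracy.
class CategoricalAccuracy : public Metric {
public:
    CategoricalAccuracy();

    float calculateMetric(const sTensor& targets, const sTensor& predictions) override;

    [[nodiscard]] std::string getShortName() const override;
};

/// The implementation of binary accuracy.
class BinaryAccuracy : public Metric {
public:
    BinaryAccuracy();

    float calculateMetric(const sTensor& targets, const sTensor& predictions) override;

    [[nodiscard]] std::string getShortName() const override;
};

#endif // NNLIB_METRIC_H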
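The header only declares the interface. As a hedged illustration of how calling code might use it, the sketch below re-declares a stand-in Metric with the same members, using a hypothetical Batch type (rows of floats) in place of sTensor, whose real definition lives in tensor.h and is not shown on this page. It assumes calculateMetric adds a batch to the running totals and returns currentTotalMetric / numSamples. ExactMatch and evaluate are hypothetical names used only for illustration; the stand-in declares a virtual destructor so that owning subclasses through std::unique_ptr<Metric> is well defined.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for sTensor: one row of floats per sample.
using Batch = std::vector<std::vector<float>>;

// Stand-in with the same members as the Metric class declared above.
class Metric {
protected:
    std::size_t numSamples = 0;
    float currentTotalMetric = 0.0f;

public:
    virtual ~Metric() = default;  // safe deletion through Metric* or unique_ptr<Metric>

    void reset() {
        numSamples = 0;
        currentTotalMetric = 0.0f;
    }

    virtual float calculateMetric(const Batch& targets, const Batch& predictions) = 0;
    [[nodiscard]] virtual std::string getShortName() const = 0;
};

// Hypothetical metric: fraction of samples whose prediction row equals the target row.
class ExactMatch : public Metric {
public:
    float calculateMetric(const Batch& targets, const Batch& predictions) override {
        assert(targets.size() == predictions.size());
        for (std::size_t i = 0; i < targets.size(); ++i) {
            if (targets[i] == predictions[i]) {
                currentTotalMetric += 1.0f;
            }
        }
        numSamples += targets.size();
        // Assumed semantics: running value over every sample since reset().
        return numSamples == 0 ? 0.0f : currentTotalMetric / static_cast<float>(numSamples);
    }

    [[nodiscard]] std::string getShortName() const override { return "exact"; }
};

// Hypothetical helper: resets each metric, feeds it every batch through the base
// class, and returns a log line built from getShortName(), e.g. "exact=0.750000".
std::string evaluate(const std::vector<std::unique_ptr<Metric>>& metrics,
                     const std::vector<Batch>& targets,
                     const std::vector<Batch>& predictions) {
    assert(targets.size() == predictions.size());
    std::string log;
    for (const auto& metric : metrics) {
        metric->reset();
        float value = 0.0f;
        for (std::size_t b = 0; b < targets.size(); ++b) {
            value = metric->calculateMetric(targets[b], predictions[b]);
        }
        if (!log.empty()) {
            log += ' ';
        }
        log += metric->getShortName() + "=" + std::to_string(value);
    }
    return log;
}
```

Because every metric reports through getShortName(), a training loop written this way can log any combination of metrics without knowing their concrete types.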
BinaryAccuracy: The implementation of binary accuracy. Definition: metric.h:97
- BinaryAccuracy()
  Constructor of BinaryAccuracy. Definition: binary_accuracy.cpp:11
- float calculateMetric(const sTensor &targets, const sTensor &predictions) override
  Calculate the current value of the metric given the new batches of targets and predictions. Definition: binary_accuracy.cpp:27
- std::string getShortName() const override
  Short string identifier of the metric. Definition: binary_accuracy.cpp:47
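A possible shape for the binary-accuracy computation, assuming single-output predictions in [0, 1] thresholded at 0.5, targets that are exactly 0 or 1, and a running average over all samples since the last reset. The threshold actually used in binary_accuracy.cpp is not shown on this page; BinaryAccuracySketch and Column are hypothetical stand-ins, with a plain vector in place of sTensor.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for sTensor: one value per sample.
using Column = std::vector<float>;

// Sketch of a running binary accuracy. Assumes predictions are probabilities in
// [0, 1], a 0.5 decision threshold, and targets that are exactly 0 or 1.
class BinaryAccuracySketch {
    std::size_t numSamples = 0;       // samples seen since the last reset()
    float currentTotalMetric = 0.0f;  // correct predictions since the last reset()

public:
    void reset() {
        numSamples = 0;
        currentTotalMetric = 0.0f;
    }

    // Adds one batch and returns the accuracy over every sample seen since reset().
    float calculateMetric(const Column& targets, const Column& predictions) {
        assert(targets.size() == predictions.size());
        for (std::size_t i = 0; i < targets.size(); ++i) {
            const float predictedLabel = predictions[i] >= 0.5f ? 1.0f : 0.0f;
            if (predictedLabel == targets[i]) {
                currentTotalMetric += 1.0f;
            }
        }
        numSamples += targets.size();
        return numSamples == 0 ? 0.0f : currentTotalMetric / static_cast<float>(numSamples);
    }
};
```

Calling reset() between epochs starts the count afresh; without it, calculateMetric keeps averaging over every batch seen so far.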
CategoricalAccuracy: The implementation of categorical accuracy. Definition: metric.h:71
- CategoricalAccuracy()
  Constructor of CategoricalAccuracy. Definition: categorical_accuracy.cpp:13
- float calculateMetric(const sTensor &targets, const sTensor &predictions) override
  Calculate the current value of the metric given the new batches of targets and predictions. Definition: categorical_accuracy.cpp:16
- std::string getShortName() const override
  Short string identifier of the metric. Definition: categorical_accuracy.cpp:46
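For categorical accuracy, a common definition (assumed here, not confirmed by categorical_accuracy.cpp) counts a sample as correct when the index of the largest prediction matches the index of the 1 in its one-hot target row. The sketch keeps the same running totals as Metric; Matrix, argmax and CategoricalAccuracySketch are hypothetical names, and nested vectors replace sTensor.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical stand-in for sTensor: one row of class scores per sample.
using Matrix = std::vector<std::vector<float>>;

// Index of the largest entry in a row; std::max_element returns the first on ties.
std::size_t argmax(const std::vector<float>& row) {
    return static_cast<std::size_t>(
        std::distance(row.begin(), std::max_element(row.begin(), row.end())));
}

// Sketch of a running categorical accuracy, assuming one-hot target rows.
class CategoricalAccuracySketch {
    std::size_t numSamples = 0;       // samples seen since the last reset()
    float currentTotalMetric = 0.0f;  // correctly classified samples since the last reset()

public:
    void reset() {
        numSamples = 0;
        currentTotalMetric = 0.0f;
    }

    // Adds one batch and returns the accuracy over every sample seen since reset().
    float calculateMetric(const Matrix& targets, const Matrix& predictions) {
        assert(targets.size() == predictions.size());
        for (std::size_t i = 0; i < targets.size(); ++i) {
            if (argmax(predictions[i]) == argmax(targets[i])) {
                currentTotalMetric += 1.0f;
            }
        }
        numSamples += targets.size();
        return numSamples == 0 ? 0.0f : currentTotalMetric / static_cast<float>(numSamples);
    }
};
```

Ties go to the lower class index, since std::max_element returns the first of several equal maxima.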
Metric: An abstract class to represent metrics. Definition: metric.h:19
- Metric()
  Constructor for the Metric class. Definition: metric.cpp:10
- void reset()
  Reset the metric, i.e. set numSamples and currentTotalMetric to 0. Definition: metric.cpp:13
- virtual float calculateMetric(const sTensor &targets, const sTensor &predictions) = 0
  Calculate the current value of the metric given the new batches of targets and predictions.
- virtual std::string getShortName() const = 0
  Short string identifier of the metric.
- size_t numSamples
  The number of samples processed so far. Definition: metric.h:25
- float currentTotalMetric
  The current total value of the metric. Definition: metric.h:30
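Reading numSamples and currentTotalMetric together with the brief for calculateMetric suggests, though this page does not confirm it, that the returned value is the running ratio of the two. Under that assumption, let batch j hold n_j samples of which c_j count toward the metric (for accuracy, the correctly classified ones); after k calls since the last reset():

```latex
\begin{aligned}
\mathrm{numSamples}_k &= \sum_{j=1}^{k} n_j, \qquad
\mathrm{currentTotalMetric}_k = \sum_{j=1}^{k} c_j, \\
m_k &= \frac{\mathrm{currentTotalMetric}_k}{\mathrm{numSamples}_k}
     = \frac{\sum_{j=1}^{k} c_j}{\sum_{j=1}^{k} n_j}, \\
n = (4, 2),\ c = (2, 2):\quad m_2 &= \frac{2 + 2}{4 + 2} = \frac{2}{3} \approx 0.667
\;\neq\; \frac{1}{2}\left(\frac{2}{4} + \frac{2}{2}\right) = 0.75.
\end{aligned}
```

Under this reading every sample carries equal weight, so the result generally differs from the mean of per-batch values whenever batch sizes differ, as with a smaller final batch.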
tensor.h: Header file declaring the Tensor class to represent multidimensional arrays.