nnlib
GPU-accelerated, C/C++ neural network library.
loss.h
1 
8 #ifndef NNLIB_LOSS_H
9 #define NNLIB_LOSS_H
10 
11 #include "metric.h"
12 #include "tensor.h"
13 
21 class Loss : public Metric {
22 
28 public:
29  Loss();
30 
40  float calculateMetric(const sTensor& targets, const sTensor& predictions) override;
41 
42  virtual sTensor calculateLoss(const sTensor& targets, const sTensor& predictions) = 0;
43 };
44 
48 class MeanSquaredError : public Loss {
49 public:
50  sTensor calculateLoss(const sTensor& targets, const sTensor& predictions) override;
51 
52  [[nodiscard]] std::string getShortName() const override;
53 };
54 
60 class BinaryCrossEntropy : public Loss {
61 public:
62  sTensor calculateLoss(const sTensor& targets, const sTensor& predictions) override;
63 
64  [[nodiscard]] std::string getShortName() const override;
65 };
66 
76 class CategoricalCrossEntropy : public Loss {
77 public:
78  sTensor calculateLoss(const sTensor& targets, const sTensor& predictions) override;
79 
80  [[nodiscard]] std::string getShortName() const override;
81 };
82 
83 
84 #endif //NNLIB_LOSS_H
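The header only declares the interface, so the following is a minimal sketch of how the pattern might fit together, not nnlib's actual implementation. Everything inside the `sketch` namespace is hypothetical: `Tensor` is a `std::vector<float>` stand-in for `sTensor`, `Metric` is reduced to the two methods visible in this listing, and `calculateMetric` is assumed to return the mean of the per-element losses produced by `calculateLoss`. The listing does not confirm that reduction.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

namespace sketch {

// Hypothetical stand-in for nnlib's sTensor (the real type is declared in tensor.h).
using Tensor = std::vector<float>;

// Reduced stand-in for the Metric interface declared in metric.h.
class Metric {
public:
    virtual ~Metric() = default;
    virtual float calculateMetric(const Tensor& targets, const Tensor& predictions) = 0;
    [[nodiscard]] virtual std::string getShortName() const = 0;
};

// Abstract loss: subclasses supply per-element losses, the base class reduces them.
class Loss : public Metric {
public:
    // Assumption: the scalar metric is the mean of the per-element losses.
    float calculateMetric(const Tensor& targets, const Tensor& predictions) override {
        const Tensor losses = calculateLoss(targets, predictions);
        assert(!losses.empty());
        float sum = 0.0f;
        for (float value : losses) {
            sum += value;
        }
        return sum / static_cast<float>(losses.size());
    }

    virtual Tensor calculateLoss(const Tensor& targets, const Tensor& predictions) = 0;
};

// Hypothetical concrete loss: element-wise squared error.
class SquaredError : public Loss {
public:
    Tensor calculateLoss(const Tensor& targets, const Tensor& predictions) override {
        assert(targets.size() == predictions.size());
        Tensor out(targets.size());
        for (std::size_t i = 0; i < targets.size(); ++i) {
            const float diff = predictions[i] - targets[i];
            out[i] = diff * diff;
        }
        return out;
    }

    [[nodiscard]] std::string getShortName() const override { return "squared_error"; }
};

}  // namespace sketch
```

Under these assumptions, a new loss only has to override `calculateLoss` and `getShortName`; the scalar value reported as a metric comes from the shared `calculateMetric` in the base class.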
BinaryCrossEntropy
Class representing the Binary Cross Entropy.
Definition: loss.h:60

BinaryCrossEntropy::getShortName
std::string getShortName() const override
Short string identifier of the metric.
Definition: binary_cross_entropy.cpp:28

CategoricalCrossEntropy
Class representing the Categorical Cross Entropy.
Definition: loss.h:76

CategoricalCrossEntropy::getShortName
std::string getShortName() const override
Short string identifier of the metric.
Definition: categorical_cross_entropy.cpp:10

Loss
Abstract class representing a loss function.
Definition: loss.h:21

Loss::calculateMetric
float calculateMetric(const sTensor &targets, const sTensor &predictions) override
Defines the method inherited from abstract Metric parent.
Definition: loss.cpp:13

Loss::Loss
Loss()
Constructor for the Loss class.
Definition: loss.cpp:10

MeanSquaredError
Class representing the Mean Squared Error.
Definition: loss.h:48

MeanSquaredError::getShortName
std::string getShortName() const override
Short string identifier of the metric.
Definition: mean_squared_error.cpp:10

Metric
An abstract class to represent metrics.
Definition: metric.h:19

metric.h
Header file declaring different metrics.

tensor.h
Header file declaring the Tensor class to represent multidimensional arrays.
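
For reference, the two cross-entropy losses named above can be sketched with their textbook definitions. This is an illustration under stated assumptions rather than nnlib's code: the free functions, the `sketch` namespace, the `std::vector<float>` inputs, and the clipping constant `kEps` are all hypothetical. Whether the library clips probabilities or averages over samples is not visible in this listing.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

namespace sketch {

// Hypothetical clipping constant that keeps log() away from zero.
constexpr float kEps = 1e-7f;

// Clamp a probability into [kEps, 1 - kEps].
inline float clampProbability(float p) {
    return std::min(std::max(p, kEps), 1.0f - kEps);
}

// Binary cross entropy averaged over all elements:
//   -(1/N) * sum_i [ t_i * log(p_i) + (1 - t_i) * log(1 - p_i) ]
float binaryCrossEntropy(const std::vector<float>& targets,
                         const std::vector<float>& predictions) {
    assert(targets.size() == predictions.size() && !targets.empty());
    float sum = 0.0f;
    for (std::size_t i = 0; i < targets.size(); ++i) {
        const float p = clampProbability(predictions[i]);
        sum -= targets[i] * std::log(p) + (1.0f - targets[i]) * std::log(1.0f - p);
    }
    return sum / static_cast<float>(targets.size());
}

// Categorical cross entropy for a single sample:
//   -sum_c t_c * log(p_c), with one-hot targets and a probability vector.
float categoricalCrossEntropy(const std::vector<float>& targets,
                              const std::vector<float>& predictions) {
    assert(targets.size() == predictions.size() && !targets.empty());
    float sum = 0.0f;
    for (std::size_t c = 0; c < targets.size(); ++c) {
        sum -= targets[c] * std::log(clampProbability(predictions[c]));
    }
    return sum;
}

}  // namespace sketch
```

In both functions an uninformative prediction of 0.5 for the true class yields a loss of ln 2, which is a convenient sanity check when comparing against another implementation.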