#### Custom loss function does not converge – TensorFlow

I’m training a model to predict an angle in the [0, 2π) interval, and I want a loss function that understands we’re working on a circle.
I defined the loss function `cos_loss` below; the idea of the penalty term is to encourage the model to keep its predictions inside the desired interval.

``````
import numpy as np
import tensorflow as tf

def cos_loss(y_true, y_pred):
    # Circular distance: 0 when the angles match, maximal (4) when they are pi apart.
    loss = 2.0 * (1.0 - tf.math.cos(y_true - y_pred))
    # Penalise predictions above 2*pi (values below 0 are left unpenalised).
    penalty = tf.math.maximum(0.0, y_pred - 2.0 * np.pi)
    return tf.reduce_mean(loss + penalty, axis=-1)

# root_mean_squared_error is a custom metric defined elsewhere in my code.
metrics = ["mae", "mse", root_mean_squared_error]
self.model.compile(
    loss=cos_loss,
    optimizer=optimizer,
    metrics=metrics,
)
``````
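Before blaming the optimizer, it can help to sanity-check the loss itself in plain NumPy (a sketch added here for illustration, mirroring the TensorFlow code above). Two properties stand out: the circular term is 2π-periodic, so a prediction a full turn below the target costs nothing, and its gradient magnitude is 2·|sin(y_true − y_pred)|, which vanishes not only at the minimum but also at the loss maximum, when a prediction sits exactly π away from the target:

``````
import numpy as np

def cos_loss_np(y_true, y_pred):
    # NumPy mirror of the TensorFlow loss above.
    loss = 2.0 * (1.0 - np.cos(y_true - y_pred))
    penalty = np.maximum(0.0, y_pred - 2.0 * np.pi)
    return np.mean(loss + penalty, axis=-1)

y_true = np.array([0.5])
same = cos_loss_np(y_true, y_true)                 # identical angles: loss is 0
wrapped = cos_loss_np(y_true, y_true - 2 * np.pi)  # one full turn below: also 0 (no penalty below 0)
opposite = cos_loss_np(np.array([0.0]), np.array([np.pi]))  # pi apart: maximal loss of 4, flat gradient
``````

Where the network's outputs sit relative to these flat points at initialisation is worth checking before tuning anything else.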

But training fails: both the training and validation loss remain essentially constant (see the log below). If I instead use a standard loss function such as MSE, the model trains normally. I also tried several other “circular” loss functions, and none of them converged either.

``````Epoch 1/100
2024-08-02 14:12:12,975 1814 [callbacks.py:61] :    INFO: Epoch: 0 - loss: 1.99 val_loss: 2.39
17/17 [==============================] - 12s 398ms/step - loss: 1.9908 - mae: 2.9680 - mse: 12.1402 - root_mean_squared_error: 2.9680 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 2/100
2024-08-02 14:12:18,107 1814 [callbacks.py:61] :    INFO: Epoch: 1 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 5s 300ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 3/100
2024-08-02 14:12:23,498 1814 [callbacks.py:61] :    INFO: Epoch: 2 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 5s 325ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 4/100
2024-08-02 14:12:28,983 1814 [callbacks.py:61] :    INFO: Epoch: 3 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 5s 329ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 5/100
2024-08-02 14:12:34,369 1814 [callbacks.py:61] :    INFO: Epoch: 4 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 10s 585ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 6/100
2024-08-02 14:12:44,881 1814 [callbacks.py:61] :    INFO: Epoch: 5 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 6s 372ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 7/100
2024-08-02 14:12:50,794 1814 [callbacks.py:61] :    INFO: Epoch: 6 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 6s 357ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 8/100
2024-08-02 14:12:56,460 1814 [callbacks.py:61] :    INFO: Epoch: 7 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 6s 337ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 9/100
2024-08-02 14:13:01,771 1814 [callbacks.py:61] :    INFO: Epoch: 8 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 10s 577ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 10/100
2024-08-02 14:13:11,118 1814 [callbacks.py:61] :    INFO: Epoch: 9 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 5s 301ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 11/100
2024-08-02 14:13:16,100 1814 [callbacks.py:61] :    INFO: Epoch: 10 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 5s 293ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 12/100
17/17 [==============================] - ETA: 0s - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828
``````

The model is a convolutional network + fully connected layer.

``````>>> model.summary()
_________________________________________________________________
Layer (type)                Output Shape              Param #
=================================================================
layer_0 (Reshape)           (None, 64, 100, 4)        0

layer_1 (Conv2D)            (None, 32, 34, 2)         51200

layer_2 (BatchNormalization) (None, 32, 34, 2)       8

layer_3 (Activation)        (None, 32, 34, 2)         0

layer_4 (Conv2D)            (None, 64, 17, 1)         18432

layer_5 (BatchNormalization) (None, 64, 17, 1)       4

layer_6 (Activation)        (None, 64, 17, 1)         0

layer_7 (Conv2D)            (None, 96, 9, 1)          55296

layer_8 (BatchNormalization) (None, 96, 9, 1)        4

layer_9 (Activation)        (None, 96, 9, 1)          0

layer_10 (Conv2D)           (None, 128, 5, 1)         110592

layer_11 (BatchNormalization) (None, 128, 5, 1)      4

layer_12 (Activation)       (None, 128, 5, 1)         0

layer_13 (Flatten)          (None, 640)               0

layer_14 (Dense)            (None, 1024)              655360

layer_15 (Dense)            (None, 512)               524288

layer_16 (Dense)            (None, 256)               131072

layer_17 (Dense)            (None, 128)               32768

layer_18 (Dense)            (None, 1)                 128
``````
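For what it’s worth, a common alternative for angular targets (not what I use above — just a standard technique I considered) is to give the final Dense layer two outputs interpreted as (sin θ, cos θ) and recover the angle with atan2. That removes both the 0/2π discontinuity and the need for a range penalty. The recovery step, as a plain-NumPy sketch with a hypothetical angle:

``````
import numpy as np

def to_angle(s, c):
    # Map a (sin, cos) prediction pair back to an angle in [0, 2*pi).
    # arctan2 returns values in (-pi, pi]; the modulo shifts them into [0, 2*pi).
    return np.mod(np.arctan2(s, c), 2.0 * np.pi)

theta = 5.5  # hypothetical ground-truth angle
recovered = to_angle(np.sin(theta), np.cos(theta))  # round-trips back to 5.5
``````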
