Custom loss function in TensorFlow 2.0

Source: Deep Learning on Medium


High- and low-level implementations of a loss function.

No escaping version 2.0 in 2020

The first hitch I ran into when I was learning to write my own layers in TensorFlow (TF) was how to write a loss function. TF contains almost all the loss functions you’ll ever need, but sometimes that is not enough. When implementing deep reinforcement learning or constructing your own models, you might need to write your own loss functions. This is exactly what this blog post intends to elucidate. We will write a loss function in two different ways:

  1. For tf.keras model (High Level)
  2. For custom TF models (Low Level)

For both cases, we will construct a simple neural network to learn the squares of numbers. The network takes one input and produces one output. It is by no means successful or complete: it is highly rudimentary and is meant only to demonstrate the different loss function implementations.

1. tf.keras custom loss (High level)

Let’s look at a high-level loss function. We assume that we have already constructed a model using tf.keras. A custom loss function for the model can be implemented in the following way:

High level loss implementation in tf.keras

First things first, a custom loss function ALWAYS requires two arguments. The first is the actual value (y_actual) and the second is the value predicted by the model (y_model). It is important to note that both are TF Tensors, not NumPy arrays. Inside the function you are free to calculate the loss however you want, with the exception that the return value must be a TF Tensor: either a scalar or one loss value per sample, which Keras then averages into a scalar.

In the above example, I have calculated the squared error loss using tensorflow.keras.backend.square(). However, it is not necessary to use the Keras backend; any valid tensor operation will do just fine.
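The original code for this step is not reproduced here. As a minimal sketch, assuming the squared error is averaged over the batch, a loss with this two-argument signature could look like the following. It uses plain TF ops (tf.square, tf.reduce_mean) instead of the Keras backend, and compares its result with Keras's built-in mean squared error on three hand-picked values.

```python
import tensorflow as tf

def custom_loss(y_actual, y_model):
    # Both arguments arrive as TF tensors: targets first, predictions second.
    # Squared error averaged over the batch, using plain TF ops only.
    return tf.reduce_mean(tf.square(y_actual - y_model))

# Hand-picked example values (not from the article)
y_actual = tf.constant([[1.0], [4.0], [9.0]])
y_model = tf.constant([[2.0], [4.0], [6.0]])

value = float(custom_loss(y_actual, y_model))  # (1 + 0 + 9) / 3
reference = float(tf.keras.losses.MeanSquaredError()(y_actual, y_model))
print(value, reference)
```

Returning tf.reduce_mean(tf.square(y_actual - y_model), axis=-1) instead would give one loss value per sample; Keras averages such a vector into a scalar on its own.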

Once the loss function is defined, we pass it to the model's compile call:

model.compile(loss=custom_loss, optimizer=optimizer)
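For context, here is a self-contained sketch of the whole high-level pipeline on the squares data, from data preparation through compile and fit. The layer sizes, the Adam optimizer, the learning rate and the number of epochs are assumptions chosen for illustration, not the author's settings.

```python
import numpy as np
import tensorflow as tf

def custom_loss(y_actual, y_model):
    # Squared error averaged over the batch, written with plain TF ops
    return tf.reduce_mean(tf.square(y_actual - y_model))

tf.keras.utils.set_random_seed(0)  # reproducible initial weights

# Training data: the numbers 1..10 and their squares, shaped (10, 1)
x = np.arange(1, 11, dtype=np.float32).reshape((10, 1))
y = x ** 2

# A tiny one-input, one-output network (sizes are arbitrary assumptions)
model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1),
])
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

# The function itself (not a call to it) is passed, just like a built-in loss
model.compile(loss=custom_loss, optimizer=optimizer)
history = model.fit(x, y, epochs=200, verbose=0)
print(history.history["loss"][0], history.history["loss"][-1])
```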

The complete code can be found here: link

2. Custom TF loss (Low level)

In the previous part, we looked at a tf.keras model. What if we wanted to write a network from scratch in TF? How would we implement the loss function in that case? This will be a low-level implementation of the model. Let us write the same model again. The complete implementation is given below:

Low level implementation of model in TF 2.0

Ufff! That’s a lot of code. Let’s unpack it.

__init__(): The constructor creates the layers of the model (without wrapping them in a tf.keras.Model).

run(): Runs the model for a given input by passing the input manually through layers and returns the output of the final layer.

get_loss(): Computes the loss and returns it as a TF Tensor.
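Since the full implementation is not reproduced here, the three methods above can be sketched as follows. This is a hypothetical version: the hidden-layer size, the ReLU activation and the use of tf.keras.layers.Dense as building blocks are assumptions. The lowercase class name simply matches the model() call in the training loop further down.

```python
import tensorflow as tf

class model:
    """Hypothetical sketch: layers are created by hand, no tf.keras.Model wrapper."""

    def __init__(self):
        # One input, one output; the hidden size of 16 is an assumption
        self.hidden = tf.keras.layers.Dense(16, activation="relu")
        self.out = tf.keras.layers.Dense(1)

    def run(self, x):
        # Pass the input manually through each layer, return the final output
        return self.out(self.hidden(x))

    def get_loss(self, x, y):
        # Mean squared error between targets and predictions, as a TF Tensor
        y_model = self.run(x)
        return tf.reduce_mean(tf.square(y - y_model))

net = model()
x = tf.reshape(tf.range(1, 11, dtype=tf.float32), (10, 1))  # inputs 1..10
y = tf.square(x)                                             # their squares
predictions = net.run(x)
loss = net.get_loss(x, y)
print(predictions.shape, float(loss))
```

Note that Dense layers create their weights lazily, on the first call, so the trainable variables only exist after run() has been called once.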

Up to now, implementing the loss seemed pretty easy, since tf.keras took care of training the model for us. Now, however, we need to perform the learning ourselves using automatic differentiation, which TF 2.0 provides through tf.GradientTape(). The function get_grad() computes the gradient of the loss with respect to the variables of the layers. It is important to note that all the variables and arguments are TF Tensors.
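The mechanics of get_grad() can be seen on a hypothetical one-parameter model y_model = w * x, small enough that the gradient can be checked by hand. With x = (1, 2, 3), targets y = x**2 = (1, 4, 9) and w = 1, the residuals are (0, 2, 6), so the loss is (0 + 4 + 36) / 3 = 40/3 and its derivative is -(2/3) * (1*0 + 2*2 + 3*6) = -44/3. In the article's model, the list passed to tape.gradient would instead hold the layers' trainable variables.

```python
import tensorflow as tf

# Hypothetical one-parameter stand-in for the article's layers
w = tf.Variable(1.0)
x = tf.constant([1.0, 2.0, 3.0])
y = tf.square(x)  # targets 1, 4, 9

def get_loss():
    # Mean squared error of the prediction w * x
    return tf.reduce_mean(tf.square(y - w * x))

def get_grad():
    # Record the forward pass so TF can differentiate it afterwards
    with tf.GradientTape() as tape:
        loss = get_loss()
    # d(loss)/d(var) for every variable in the list (here just w)
    return tape.gradient(loss, [w])

grads = get_grad()
print(float(get_loss()), float(grads[0]))  # 40/3 and -44/3
```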

network_learn(): We apply the gradient descent step in this function using the gradients obtained from the get_grad() function.
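A gradient descent step of this kind can be sketched with the same hypothetical one-parameter model, updating each variable in place as var = var - learning_rate * gradient. The learning rate of 0.01 is an assumption; a tf.keras optimizer's apply_gradients() method is a common alternative to the manual assign_sub() update used here.

```python
import tensorflow as tf

# Hypothetical stand-in for the article's model: y_model = w * x
w = tf.Variable(1.0)
variables = [w]
x = tf.constant([1.0, 2.0, 3.0])
y = tf.square(x)        # targets 1, 4, 9
learning_rate = 0.01    # assumed step size

def get_loss():
    return tf.reduce_mean(tf.square(y - w * x))

def get_grad():
    with tf.GradientTape() as tape:
        loss = get_loss()
    return tape.gradient(loss, variables)

def network_learn():
    # One plain gradient descent step: var <- var - learning_rate * gradient
    grads = get_grad()
    for var, grad in zip(variables, grads):
        var.assign_sub(learning_rate * grad)

before = float(get_loss())
network_learn()
after = float(get_loss())
print(before, after, float(w.numpy()))
```

Because the gradient at w = 1 is -44/3, the step moves w up to about 1.1467, and the loss drops accordingly.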

With everything written out, we can implement the training loop as follows:

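import numpy as np

# Inputs 1..10 and their squares, as float32 column vectors of shape (10, 1)
x = np.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=np.float32).reshape((10, 1))
y = np.asarray([1, 4, 9, 16, 25, 36, 49, 64, 81, 100], dtype=np.float32).reshape((10, 1))

net = model()  # a new name, so the class `model` is not overwritten
for i in range(100):
    net.network_learn(x, y)  # one gradient descent step per iteration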

Full code can be found here: link

Conclusion

In this post, we have seen both the high-level and the low-level implementation of a custom loss function in TensorFlow 2.0. Knowing how to implement a custom loss function is indispensable in reinforcement learning and advanced deep learning, and I hope this short post has made it easier for you to implement your own. For more details on custom loss functions, we refer the reader to the TensorFlow documentation.