Segmentation through Q-Learning

Source: Deep Learning on Medium



There are already many great approaches to text segmentation, and they perform really well. However, I couldn’t help but think about trying semantic text segmentation through reinforcement learning. So I will try to explain my approach and my reasoning for why it could be great at it.

We already know that a neural network can approximate a non-linear function remarkably well. If you are not already familiar with Deep Q-Learning, I would suggest looking into that first.

As a quick introduction to Deep Q-Learning: we are basically trying to approximate, with a neural network, a function that maps states to action values (Q-values), and the agent acts on those values.
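To make that concrete, here is a minimal sketch of the two core pieces of Q-learning that the network plugs into: epsilon-greedy action selection over the predicted Q-values, and the one-step temporal-difference target used to train them. This is generic Q-learning machinery, not code from the linked repository.

```python
import random

def epsilon_greedy(q_values, epsilon=0.1):
    """Pick a random action with probability epsilon, else the greedy one.

    q_values: list of estimated action values Q(s, a), as produced by
    the network for the current state.
    """
    if random.random() < epsilon:
        return random.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])

def q_target(reward, next_q_values, gamma=0.99, done=False):
    """One-step TD target: r + gamma * max_a' Q(s', a')."""
    if done:
        return reward
    return reward + gamma * max(next_q_values)
```

The network is trained by regressing its prediction Q(s, a) toward `q_target`; any function approximator works here, which is what makes the "Deep" part possible.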

One thing we know about reinforcement learning agents built on function approximation is that they generalize over the state space, meaning they tend to act the same way when they see similar states.

That is a key concept for semantic segmentation: we want to place segment boundaries where the context changes, and as we know, “change is constant”.

So even though we may not be able to differentiate one particular kind of change from another, we can learn to recognize that a change has occurred, whatever its type.

And thus we have a working idea, and more importantly, a motivation to experiment with the thought that a reinforcement learning agent can generalize well enough to learn to segment text, at least within a specific domain.

So now diving into my approach and the implementation:

We have something I like to call a “window”, which is the agent’s context of sentences. With a window size of 12, the agent observes a state of 12 sentences and then chooses where a new segment starts.
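A windowing step like this might look as follows. This is my own sketch, not the repository’s code; the `size` and `stride` values are illustrative, and in practice each sentence would be embedded (e.g. as averaged word vectors) before being fed to the network.

```python
def windows(sentences, size=12, stride=12):
    """Split a document (a list of sentences) into fixed-size context windows.

    Each window becomes one observation (state) for the agent.
    The final window may be shorter if the document length is not
    a multiple of the stride.
    """
    out = []
    for start in range(0, len(sentences), stride):
        w = sentences[start:start + size]
        if w:
            out.append(w)
    return out
```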

By feeding in a window of sentences we have formalized a state representation, and this does not violate the Markov property, so we can apply our fancy reinforcement learning algorithms to find the optimal policy.

So basically what is happening is that we feed in a context window of n sentences, and the agent decides at which index (0 to n-1) a new segment begins, or whether there is no segment boundary in the window at all.
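One natural way to encode that decision, assuming a discrete action space of n + 1 actions (n boundary positions plus one “no boundary” action), is a small decoder like this. The encoding is my assumption for illustration, not necessarily the one used in the repository.

```python
def decode_action(action, window_size):
    """Map a discrete action index to a boundary decision for one window.

    Actions 0 .. window_size-1 mean "a new segment starts at this
    sentence index"; the extra action (== window_size) means
    "no boundary in this window".
    """
    if action == window_size:
        return None  # no segment boundary predicted
    if 0 <= action < window_size:
        return action
    raise ValueError("action index out of range")
```

With this scheme the Q-network simply has `window_size + 1` outputs, and the usual argmax over Q-values yields the segmentation decision.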

We reward it for correctly identifying segment boundaries, and we penalize it for inserting spurious boundaries or missing actual ones.
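A reward scheme of that shape can be sketched as below. The +1/-1 values are placeholders of my own, not the exact values from the repository; only the structure (reward correct decisions, penalize insertions and misses) follows the text.

```python
def reward(predicted, actual):
    """Hypothetical reward for one window (illustrative values).

    predicted / actual: the boundary's sentence index within the window,
    or None for "no segment boundary".
    """
    if predicted == actual:
        return 1.0   # correct boundary, or correctly predicting no boundary
    return -1.0      # spurious boundary, missed boundary, or wrong position
```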

If you are interested in knowing how it was implemented, check out the code on my repository:

https://github.com/thehawkgriffith/dqn-segmentation

I do not yet know how well the model scales, and it still requires a lot of optimization. I will keep updating the repository as I find improvements.

Thank you for the read!