Embeddings with Word2Vec in non-NLP Contexts — Details

Source: Deep Learning on Medium

Dual Embedding Space

SGNS learns two different embedding vectors for each object: the In and Out vectors (also referred to as Target and Context vectors). Many people use a single vector in downstream tasks[26], either by averaging the In and Out vectors or by discarding the Out vectors and keeping only the In vectors. You can also learn a single embedding vector by sharing one embedding layer in your model during training (a Siamese network with shared parameters[25]).

So why create two separate vectors for each object? Let’s look at the technical and the logical reasoning.

Technical: Let’s shift our mindset from NLP to the Instacart dataset: “words” become “products” and “sentences” become “order baskets”. The context for a product is the other products within the same order basket. For the product “Bag of bananas”, consider the case where we use the same vector v for both the Target (In) and the Context (Out). “Bag of bananas” does not semantically occur within the context of itself (the context being the order basket), so the model should assign a low probability to p(“Bag of bananas”|“Bag of bananas”). With a single shared vector this is impossible: the score is v · v, the squared norm of v, which can never be negative, so the model cannot push it low.
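The argument can be checked numerically. The sketch below assumes the usual SGNS scoring, where the probability of a (target, context) pair is the sigmoid of the dot product of the two vectors. The vectors are made-up values for illustration, not trained embeddings.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Hypothetical vector shared by "Bag of bananas" as both target and context.
v = [0.8, -1.2, 0.5]

# v . v is a squared norm, so it is never negative
# and the sigmoid score never drops below 0.5.
tied_score = sigmoid(dot(v, v))

# With separate In and Out vectors, the dot product can be strongly negative,
# so the model is free to assign a low probability.
v_in = [0.8, -1.2, 0.5]
v_out = [-0.9, 1.1, -0.4]
dual_score = sigmoid(dot(v_in, v_out))

print(f"tied: {tied_score:.3f}, dual: {dual_score:.3f}")
```

Whatever values a tied vector takes, its self-score stays at or above 0.5; only the dual setup lets the score fall below that.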

Logical: Using dual In-Out vectors enables us to evaluate the probability of a product appearing in a target role or in a context role. Thus we can calculate product similarity (cosine similarity between In vectors) or product complementarity (cosine similarity between In and Out vectors). This “dual embedding space” architecture forms the foundation of more advanced models for similarity/complementarity prediction that are used in production[16,17,18,19,20].
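As a minimal sketch of the two measures, the snippet below compares In vectors with In vectors for similarity and In vectors with Out vectors for complementarity. The product names and two-dimensional vectors are hypothetical toy values; real ones would come from a trained SGNS model.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

# Hypothetical toy embeddings (assumption, not trained values).
in_vecs = {
    "bananas": [0.9, 0.1],
    "plantains": [0.8, 0.2],
    "peanut butter": [0.1, 0.9],
}
out_vecs = {
    "bananas": [0.2, 0.8],
    "plantains": [0.3, 0.7],
    "peanut butter": [0.9, 0.2],
}

def similarity(a, b):
    # Substitutes: compare the In vector of a with the In vector of b.
    return cosine(in_vecs[a], in_vecs[b])

def complementarity(a, b):
    # Complements: compare the In vector of a with the Out vector of b.
    return cosine(in_vecs[a], out_vecs[b])

print(similarity("bananas", "plantains"), complementarity("bananas", "peanut butter"))
```

In this toy setup, plantains come out as similar to bananas, while peanut butter comes out as complementary to them.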

Word2Vec Model in TensorFlow (also referred to as Dual Encoder Model, Siamese Network or Dual Tower Network)

Model Parameters

Let’s evaluate the SGNS parameters.

Window size: Setting the window size is task dependent. In the Airbnb case[11], where listing embeddings are generated from users’ listing click sessions, listings that are clicked consecutively in a session might be more semantically related than the first and last listings clicked in that session. So a small window size (3–6) might be appropriate to narrow the relatedness window in a sequence. The more data you have, the smaller the window size you can use.

However, in the Instacart case, a product in an order basket is related to all the other products in the basket, because our objective is similarity/complementarity prediction at the basket level. So our window size is the number of products in each order basket. As an additional theoretical note, if your dataset is large enough and you shuffle the products within each order basket for every epoch, you can use a smaller window size and may achieve the same results as with a larger one.

Dataset generation: Target-Context (In-Out) data pairs are constructed from your dataset using the window size parameter. For each target, you can add additional data pairs to your dataset for the following objectives:

  • Adding target metadata for better generalization (Meta-Prod2Vec)[8]. E.g. Target-Product Category
  • Embedding other objects, such as brands, into the same embedding space[8]. E.g. Target-Brand
  • Adding extra target-context pairs to influence the embedding vectors or to add associations to them[11]
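A minimal sketch of this pair generation, assuming the basket-level window described for the Instacart case (every other product in the basket is context) and one category label per product. The function names, the basket and the `cat:` prefix are hypothetical, chosen only for illustration.

```python
from itertools import permutations

def basket_pairs(basket):
    """Every ordered (target, context) pair in one basket (window = basket size)."""
    return list(permutations(basket, 2))

def metadata_pairs(basket, category_of):
    """Extra (target, category) pairs, in the spirit of Meta-Prod2Vec."""
    return [(product, category_of[product]) for product in basket if product in category_of]

# Hypothetical order basket and category lookup.
basket = ["bananas", "milk", "granola"]
category_of = {
    "bananas": "cat:produce",
    "milk": "cat:dairy",
    "granola": "cat:breakfast",
}

# Category tokens share the vocabulary with products,
# so they end up in the same embedding space.
pairs = basket_pairs(basket) + metadata_pairs(basket, category_of)
print(len(pairs), pairs[:3])
```

Prefixing metadata tokens keeps them from colliding with product IDs; brands or other objects could be added the same way.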

Epoch: The number of epochs has only a marginal effect on the outcome, and you can easily choose it through offline convergence evaluation. However, be aware that the original Word2Vec code[36] and libraries like Gensim do not use mini-batching (without mini-batching, model parameters are updated after every training example), so increasing the number of epochs will not have the same effect as it would in a model that uses mini-batching.

Candidate Sampling: Candidate sampling algorithms enable efficient training without calculating the full softmax over all label classes[28,29]. Since SGNS uses the negative sampling method[2], the sampling distribution and its associated parameters play a crucial role in setting up successful SGNS models. So, how do you set up your negative sampling architecture?

  • Generic sampling: Your negative samples are drawn from the same input dataset using a sampling distribution parameter (more on that below).
  • Context specific sampling: You select your negative samples using the context of your target. In the Instacart case, for a particular product, you can select negative samples from the same product category/aisle. This ‘hard negatives’ technique enables the model to converge faster and better. However, it requires extra resources, since you need to be able to select negatives for each target. The negatives can be retrieved during mini-batch training, or you can generate a static dataset of negatives beforehand. The choice depends on your training hardware, distributed training architecture and costs.
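Context specific sampling can be sketched as follows. The catalogue, the function name and the rule of skipping products already in the basket are assumptions made for this example; the approach above only requires that negatives come from the same category/aisle as the target.

```python
import random

def sample_hard_negatives(target, basket, aisle_of, products_by_aisle, k, rng):
    """Draw up to k negatives from the target's aisle, skipping products in the basket."""
    candidates = [
        product
        for product in products_by_aisle[aisle_of[target]]
        if product != target and product not in basket
    ]
    return rng.sample(candidates, min(k, len(candidates)))

# Hypothetical catalogue.
aisle_of = {
    "bananas": "fruit", "apples": "fruit", "pears": "fruit", "kiwis": "fruit",
    "milk": "dairy",
}
products_by_aisle = {
    "fruit": ["bananas", "apples", "pears", "kiwis"],
    "dairy": ["milk"],
}

rng = random.Random(0)  # fixed seed so the sketch is repeatable
negatives = sample_hard_negatives(
    "bananas", {"bananas", "milk", "pears"}, aisle_of, products_by_aisle, k=2, rng=rng
)
print(negatives)
```

A lookup like this can run inside the mini-batch loop, or its output can be precomputed into a static dataset of negatives, which is the trade-off described above.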

Negative sampling noise distribution parameter (α): Negative samples are drawn from a noise distribution built with a frequency smoothing parameter α, where the frequency of each item is raised to the power of α. With α, you can adjust the probability of selecting popular or rare items as negatives.

  • α=1 is the unigram distribution: the original item frequencies in the dataset are used.
  • 0<α<1: high-frequency items are smoothed down.
  • α=0 is the uniform distribution: every item gets a frequency of 1, so all items are equally likely.
  • α<0: low-frequency items are weighted up.
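The effect of α is easy to see by computing the noise distribution directly. This is a minimal sketch over a hypothetical set of four baskets; the function name is an assumption.

```python
from collections import Counter

def noise_distribution(baskets, alpha):
    """Sampling probability of each item: count ** alpha, normalized to sum to 1."""
    counts = Counter(item for basket in baskets for item in basket)
    weights = {item: count ** alpha for item, count in counts.items()}
    total = sum(weights.values())
    return {item: weight / total for item, weight in weights.items()}

# Hypothetical baskets: bananas appear 4 times, milk 2, bread 1, saffron 1.
baskets = [
    ["bananas", "milk"],
    ["bananas", "bread"],
    ["bananas", "milk"],
    ["bananas", "saffron"],
]

for alpha in (1.0, 0.75, 0.0, -0.5):
    dist = noise_distribution(baskets, alpha)
    summary = ", ".join(f"{item}={prob:.2f}" for item, prob in dist.items())
    print(f"alpha={alpha}: {summary}")
```

With α=1 the raw frequencies are kept, with α=0 every item is equally likely, and with a negative α the rare items are drawn more often than the popular ones. In Gensim's `Word2Vec`, the corresponding arguments are `ns_exponent` for α and `negative` for k.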

Number of negative samples (k): For each target, k negatives are drawn from the sampling distribution. In the next section, we will evaluate the correlation between k and α.

Evaluation: Next Product Prediction in Order Basket

We will evaluate the model parameters (k, α) on the Instacart dataset by predicting the next item in a current order basket.

Sample code[31]. (For clarity, Gensim is used.)

After the model is trained on the training dataset, we take the test set, hide a random product in each customer’s order basket, and predict the hidden item from the other products in the basket. We calculate a “basket vector” by averaging the Target (In) embeddings of the products in the basket. Then we search for the items nearest to the basket vector in the Context (Out) vector space and present them as recommendations. This recommendation is basically “Here are the products recommended for you, calculated by what you have already put in your basket”. Below is the Hitrate@10 analysis with varying k and α.
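The evaluation procedure can be sketched in a few functions. This is not the sample code from [31]; it is a simplified illustration with hypothetical two-dimensional embeddings, and all function names are assumptions.

```python
import math

def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def basket_vector(basket, in_vecs):
    """Average the In vectors of the visible products."""
    dim = len(next(iter(in_vecs.values())))
    return [sum(in_vecs[p][i] for p in basket) / len(basket) for i in range(dim)]

def recommend(basket, in_vecs, out_vecs, top_n=10):
    """Rank products outside the basket by cosine similarity in the Out space."""
    query = normalize(basket_vector(basket, in_vecs))
    scores = {
        product: sum(q * x for q, x in zip(query, normalize(vec)))
        for product, vec in out_vecs.items()
        if product not in basket
    }
    return sorted(scores, key=scores.get, reverse=True)[:top_n]

def hit_rate(test_cases, in_vecs, out_vecs, top_n=10):
    """Share of (visible basket, hidden item) cases with the hidden item in the top N."""
    hits = sum(
        hidden in recommend(visible, in_vecs, out_vecs, top_n)
        for visible, hidden in test_cases
    )
    return hits / len(test_cases)

# Hypothetical toy embeddings and a single test case.
in_vecs = {"bananas": [0.9, 0.1], "peanut butter": [0.1, 0.9], "bread": [0.5, 0.5]}
out_vecs = {"bananas": [0.2, 0.8], "peanut butter": [0.9, 0.2], "bread": [0.6, 0.6]}
test_cases = [(["bananas"], "peanut butter")]

print(hit_rate(test_cases, in_vecs, out_vecs, top_n=1))
```

On a real catalogue the linear scan in `recommend` would normally be replaced by a vectorized or approximate nearest neighbor search.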

Hitrate@10

We see that the accuracy is low with high α (α=0.75, α=1). The intuition is that with high α, popular high-frequency items/users dominate the distribution, which decreases the model’s generalization capability.

With decreasing α, we predict more of the less frequent products, and this results in a better model score, which peaks at α=0. So which α would you choose for this model? I would choose α=-0.5: although it scores lower than α=0, I would argue that it will score better in online evaluation, assuming that it offers customers more diverse recommendations (Serendipity, Novelty).

Correlation between α and k:

With high α (α>0), increasing k decreases accuracy. The intuition is that with high α you are recommending popular/high-frequency items, which does not reflect the real distribution. If you increase k in this situation, you select even more popular/high-frequency items as negatives, which results in further overfitting to an incorrect distribution. The bottom line: if you want to recommend popular items with high α, there is no need to increase k further.

With low α (α<0), increasing k increases accuracy. The intuition is that with low α you are recommending more diverse items from the tail (infrequent items). In this regime, if you increase k, you select more negative samples from the tail, enabling the model to see more distinct items and thus fit a diverse distribution well. The bottom line: if you want to recommend diverse items (Novelty, Serendipity) with low α, you can increase k up to the point where the model starts to overfit.

Note: Don’t forget to use cosine similarity for your evaluation (the dot product of normalized vectors). If you use Euclidean distance, you can overemphasize frequent items in your dataset, and you may also misrepresent infrequent items because of your embedding initialization values.
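A small illustration of why the choice of metric matters: Euclidean distance is sensitive to vector length, while cosine similarity only looks at direction. The vectors below are hypothetical values picked to make the two rankings disagree.

```python
import math

def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

def euclidean_distance(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

query = [1.0, 1.0]
same_direction_long = [5.0, 5.0]    # same direction as the query, much larger norm
other_direction_short = [1.0, 0.5]  # different direction, norm close to the query

# Cosine similarity prefers the vector that points the same way.
cos_long = cosine_similarity(query, same_direction_long)
cos_short = cosine_similarity(query, other_direction_short)

# Euclidean distance prefers the vector with a similar length instead.
dist_long = euclidean_distance(query, same_direction_long)
dist_short = euclidean_distance(query, other_direction_short)

print(cos_long, cos_short, dist_long, dist_short)
```

The two metrics pick different nearest neighbors here, so the evaluation metric should match the similarity you intend to measure.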

Further analysis

So far, we have analyzed Word2Vec. Its strength is its simplicity and scalability, but that is also its downfall. With Word2Vec, it is not easy to incorporate item properties or item-item relationship properties (graph structure). These additional item/interaction features are required for better model generalization/accuracy, as well as for alleviating the cold start problem.

Yes, you can add item features with Meta-Prod2Vec[8], and with Graph Random Walk models[9,10] you can add node features as well. But the problem is that all these added ‘features’ are IDs, not vectors; namely, they are the IDs of your features stored in your backend. For state-of-the-art models, you may want to use vectors as item features (e.g. an encoded item image vector). Also, the added item features are not weighted, so the effect each item property has on the item and on the model output is not learned as a parameter.

One other drawback of Word2Vec is that it is transductive, meaning that when a new item is added to the available item list, we have to re-train the whole model (or continue training the model).

With the relatively new Inductive Graph Neural models[34,35], we can add item/relationship features as learnable vectors, and also get embeddings of new items unseen to the trained model.

Graph Neural Network research is an exciting area for the coming years.

Conclusion

The main takeaways of this document:

  • The analysis of Word2Vec based models in Non-NLP Context.
  • Model parameter analysis, mainly the correlation between α and k for the business context of similarity/complementary predictions.

Future documents will continue to explore more advanced embedding models and downstream deep models.