
Word Embeddings

In our previous notebook, we built a solid sentiment classifier using logistic regression and a simple count vectorizer. While that model performed well by focusing on word counts and phrases, it treated each word largely as an independent entity, without understanding its underlying meaning or its relationship to other words.

This time, we’ll delve into word embeddings – a technique that represents words as dense vectors, capturing their semantic relationships and context within a large body of text. Our goal today is to integrate them into our logistic regression pipeline and see how they affect our sentiment analysis score.

Data Preparation

We are going to use the same dataset as before.

from datasets import load_dataset
ds = load_dataset('stanfordnlp/imdb', split='train+test')
train, test = ds.train_test_split(test_size=0.2, seed=0).values()
display(train.to_pandas())

x_train = train['text']
y_train = train['label']
x_test = test['text']
y_test = test['label']

This time, we will do some data preprocessing. For each text sample, we are going to apply a technique called semantic vectorization: instead of turning each word into a simple index, we will turn it into a word embedding.

What is a word embedding? Essentially, it is a mathematical representation of a word (or phrase) as a vector (a numerical array) in a high-dimensional space. These vectors capture the semantic meaning of a word through its relationships to other words in a corpus of text, which means that semantically similar words end up close to each other in this high-dimensional space.
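
To make "close" a bit more concrete: closeness between embeddings is usually measured with cosine similarity, the cosine of the angle between two vectors. Below is a minimal sketch using made-up 3-dimensional toy vectors; real embeddings have hundreds of dimensions and the numbers here are purely illustrative.

import numpy as np

def cosine_similarity(a, b):
    # cosine of the angle between two vectors: ~1 means same direction, ~0 means unrelated
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

# toy "embeddings" - invented numbers, only to illustrate the idea
cat = np.array([0.9, 0.1, 0.3])
dog = np.array([0.8, 0.2, 0.35])
car = np.array([0.1, 0.9, 0.2])

print(cosine_similarity(cat, dog))  # high - the toy vectors point in a similar direction
print(cosine_similarity(cat, car))  # noticeably lower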

This approach requires a separate vectorization model - we could start with a pre-trained model known as GoogleNews300 (word2vec-google-news-300), a word2vec model trained on the Google News corpus. It contains 300-dimensional vectors for 3 million words and phrases.

from os import path
from huggingface_hub import snapshot_download
model_path = path.join(snapshot_download('fse/word2vec-google-news-300'), 'word2vec-google-news-300.model')
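
As a side note, gensim ships its own downloader that can fetch the same vectors in a single call - a sketch assuming the gensim-data mirror is reachable (the archive is large, so the first call takes a while):

import gensim.downloader as api
wv_alt = api.load('word2vec-google-news-300')  # returns a ready-to-use KeyedVectors instance
print(wv_alt.vector_size)  # 300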

Let’s see it in action by loading it and comparing the similarities of different words.

from gensim.models import KeyedVectors
wv = KeyedVectors.load(model_path)

pairs = [
    ('car', 'minivan'),   # a minivan is a kind of car
    ('car', 'bicycle'),   # still a wheeled vehicle
    ('car', 'airplane'),  # ok, no wheels, but still a vehicle
    ('car', 'cereal'),    # ... and so on
    ('car', 'communism'),
]

for w1, w2 in pairs:
    print('%r\t%r\t%.2f' % (w1, w2, wv.similarity(w1, w2)))
'car'	'minivan'	0.69
'car'	'bicycle'	0.54
'car'	'airplane'	0.42
'car'	'cereal'	0.14
'car'	'communism'	0.06

That makes some sense, right? Communism has little to do with cars.
But what do those vectors look like?

display(wv.get_vector('hello'))
array([-0.05419922, 0.01708984, -0.00527954, 0.33203125, -0.25 , -0.01397705, -0.15039062, -0.265625 , 0.01647949, 0.3828125 , -0.03295898, -0.09716797, -0.16308594, -0.04443359, 0.00946045, 0.18457031, 0.03637695, 0.16601562, 0.36328125, -0.25585938, 0.375 , 0.171875 , 0.21386719, -0.19921875, 0.13085938, -0.07275391, -0.02819824, 0.11621094, 0.15332031, 0.09082031, 0.06787109, -0.0300293 , -0.16894531, -0.20800781, -0.03710938, -0.22753906, 0.26367188, 0.012146 , 0.18359375, 0.31054688, -0.10791016, -0.19140625, 0.21582031, 0.13183594, -0.03515625, 0.18554688, -0.30859375, 0.04785156, -0.10986328, 0.14355469, -0.43554688, -0.0378418 , 0.10839844, 0.140625 , -0.10595703, 0.26171875, -0.17089844, 0.39453125, 0.12597656, -0.27734375, -0.28125 , 0.14746094, -0.20996094, 0.02355957, 0.18457031, 0.00445557, -0.27929688, -0.03637695, -0.29296875, 0.19628906, 0.20703125, 0.2890625 , -0.20507812, 0.06787109, -0.43164062, -0.10986328, -0.2578125 , -0.02331543, 0.11328125, 0.23144531, -0.04418945, 0.10839844, -0.2890625 , -0.09521484, -0.10351562, -0.0324707 , 0.07763672, -0.13378906, 0.22949219, 0.06298828, 0.08349609, 0.02929688, -0.11474609, 0.00534058, -0.12988281, 0.02514648, 0.08789062, 0.24511719, -0.11474609, -0.296875 , -0.59375 , -0.29492188, -0.13378906, 0.27734375, -0.04174805, 0.11621094, 0.28320312, 0.00241089, 0.13867188, -0.00683594, -0.30078125, 0.16210938, 0.01171875, -0.13867188, 0.48828125, 0.02880859, 0.02416992, 0.04736328, 0.05859375, -0.23828125, 0.02758789, 0.05981445, -0.03857422, 0.06933594, 0.14941406, -0.10888672, -0.07324219, 0.08789062, 0.27148438, 0.06591797, -0.37890625, -0.26171875, -0.13183594, 0.09570312, -0.3125 , 0.10205078, 0.03063965, 0.23632812, 0.00582886, 0.27734375, 0.20507812, -0.17871094, -0.31445312, -0.01586914, 0.13964844, 0.13574219, 0.0390625 , -0.29296875, 0.234375 , -0.33984375, -0.11816406, 0.10644531, -0.18457031, -0.02099609, 0.02563477, 0.25390625, 0.07275391, 0.13574219, -0.00138092, -0.2578125 , -0.2890625 , 0.10107422, 0.19238281, -0.04882812, 0.27929688, -0.3359375 , -0.07373047, 0.01879883, -0.10986328, -0.04614258, 0.15722656, 0.06689453, -0.03417969, 0.16308594, 0.08642578, 0.44726562, 0.02026367, -0.01977539, 0.07958984, 0.17773438, -0.04370117, -0.00952148, 0.16503906, 0.17285156, 0.23144531, -0.04272461, 0.02355957, 0.18359375, -0.41601562, -0.01745605, 0.16796875, 0.04736328, 0.14257812, 0.08496094, 0.33984375, 0.1484375 , -0.34375 , -0.14160156, -0.06835938, -0.14648438, -0.02844238, 0.07421875, -0.07666016, 0.12695312, 0.05859375, -0.07568359, -0.03344727, 0.23632812, -0.16308594, 0.16503906, 0.1484375 , -0.2421875 , -0.3515625 , -0.30664062, 0.00491333, 0.17675781, 0.46289062, 0.14257812, -0.25 , -0.25976562, 0.04370117, 0.34960938, 0.05957031, 0.07617188, -0.02868652, -0.09667969, -0.01281738, 0.05859375, -0.22949219, -0.1953125 , -0.12207031, 0.20117188, -0.42382812, 0.06005859, 0.50390625, 0.20898438, 0.11230469, -0.06054688, 0.33203125, 0.07421875, -0.05786133, 0.11083984, -0.06494141, 0.05639648, 0.01757812, 0.08398438, 0.13769531, 0.2578125 , 0.16796875, -0.16894531, 0.01794434, 0.16015625, 0.26171875, 0.31640625, -0.24804688, 0.05371094, -0.0859375 , 0.17089844, -0.39453125, -0.00156403, -0.07324219, -0.04614258, -0.16210938, -0.15722656, 0.21289062, -0.15820312, 0.04394531, 0.28515625, 0.01196289, -0.26953125, -0.04370117, 0.37109375, 0.04663086, -0.19726562, 0.3046875 , -0.36523438, -0.23632812, 0.08056641, -0.04248047, -0.14648438, -0.06225586, -0.0534668 , -0.05664062, 0.18945312, 0.37109375, -0.22070312, 0.04638672, 
0.02612305, -0.11474609, 0.265625 , -0.02453613, 0.11083984, -0.02514648, -0.12060547, 0.05297852, 0.07128906, 0.00063705, -0.36523438, -0.13769531, -0.12890625], dtype=float32)
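
A single vector is not very interpretable on its own, but these embeddings are famous for supporting simple vector arithmetic: the classic example asks which word relates to 'woman' the way 'king' relates to 'man'. Here is a small sketch using the model loaded above (output not shown here):

# "king" - "man" + "woman" should land near "queen" in the embedding space
display(wv.most_similar(positive=['king', 'woman'], negative=['man'], topn=3))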

Now we need to build a vectorization routine. For each text we will perform simple tokenization, look up the embedding of each token, and then average those embeddings into a single fixed-size vector.

But why?

The reason is simple - most traditional classifiers (like logistic regression) are fundamentally designed to work with fixed-size, flat feature vectors. They don’t have an inherent mechanism to understand or process sequences of varying lengths or the temporal relationships within those sequences.

from gensim.utils import simple_preprocess
import numpy as np

def vectorize(text):
    # tokenize and lowercase; deacc=True also strips accents
    tokens = simple_preprocess(text.lower(), deacc=True)
    # keep only the tokens that exist in the embedding vocabulary
    token_vectors = [wv.get_vector(x) for x in tokens if x in wv]
    if token_vectors:
        # average the token embeddings into one fixed-size document vector
        return np.mean(token_vectors, axis=0)
    else:
        # no known tokens - fall back to an all-zero vector of the same size
        return np.zeros(wv.vector_size)

display(vectorize('Hello World!'))
array([-5.90820312e-02, 4.27246094e-02, 1.09664917e-01, 2.31933594e-01, -1.54785156e-01, 1.24206543e-02, -3.73535156e-02, -2.03613281e-01, 4.36401367e-02, 2.67089844e-01, -6.06689453e-02, -8.30078125e-02, 2.09960938e-02, -1.28173828e-01, -5.04455566e-02, 1.40136719e-01, -2.25830078e-02, 1.89941406e-01, 2.62695312e-01, -1.08276367e-01, 1.48437500e-01, 1.24267578e-01, 1.52343750e-01, -1.44531250e-01, 1.11083984e-01, -2.56347656e-02, -1.44958496e-01, 8.83789062e-02, 4.87060547e-02, -2.24609375e-02, 2.72521973e-02, -4.12597656e-02, -1.77246094e-01, -6.59179688e-02, 3.29589844e-02, -1.29028320e-01, 5.17578125e-02, -3.44543457e-02, 1.60156250e-01, 2.27050781e-01, -6.68334961e-02, -1.36718750e-02, 8.53271484e-02, 2.10449219e-01, 8.93554688e-02, 2.57812500e-01, -1.57012939e-01, 2.33230591e-02, -9.74121094e-02, 1.65039062e-01, -3.05175781e-01, 1.96533203e-02, 5.90209961e-02, 3.19824219e-02, 3.24707031e-02, 2.53906250e-01, -5.61523438e-02, 7.32421875e-02, 4.04052734e-02, -1.12304688e-01, -8.22753906e-02, 9.57031250e-02, -4.71191406e-02, -1.59301758e-02, 4.32128906e-02, -1.97448730e-02, -1.67480469e-01, -5.18798828e-03, -1.45042419e-01, 7.05566406e-02, -1.02539062e-02, 1.98486328e-01, -2.39257812e-02, 1.11572266e-01, -1.80175781e-01, -9.91210938e-02, -1.17004395e-01, 3.52172852e-02, 6.72607422e-02, 7.32421875e-02, 8.30078125e-03, 3.17382812e-03, -1.28295898e-01, -1.10595703e-01, -3.41796875e-02, -6.35986328e-02, 1.33056641e-02, -1.57226562e-01, 1.14189148e-01, 1.51123047e-01, 1.39892578e-01, -1.34277344e-02, -1.20849609e-01, -7.00836182e-02, -1.95800781e-01, -7.18994141e-02, -6.34765625e-03, 6.64062500e-02, 1.12548828e-01, -2.15820312e-01, -3.43017578e-01, -4.88281250e-03, -1.03759766e-01, 1.63696289e-01, -1.07788086e-01, 8.75244141e-02, 1.62109375e-01, 3.31878662e-02, 4.82177734e-02, -3.38134766e-02, -1.20117188e-01, 9.24072266e-02, -2.44140625e-02, -1.14746094e-02, 3.06640625e-01, -3.00292969e-02, -9.77783203e-02, -1.56005859e-01, 7.37304688e-02, -1.15051270e-01, -1.74560547e-02, -4.08935547e-02, -5.81054688e-02, 9.39941406e-02, 8.69140625e-02, -1.48681641e-01, 3.41796875e-02, -4.88281250e-03, 1.95312500e-03, -5.85937500e-03, -1.93054199e-01, -7.91015625e-02, -1.08642578e-01, 1.04492188e-01, -1.24511719e-01, -4.61425781e-02, 1.21276855e-01, 1.46606445e-01, -3.76129150e-02, 1.06445312e-01, 2.43164062e-01, -1.54296875e-01, -1.29882812e-01, 2.45361328e-02, 1.64550781e-01, 1.15478516e-01, 6.00585938e-02, -1.12060547e-01, 8.74023438e-02, -2.19970703e-01, -1.95312500e-03, 7.12890625e-02, -7.79418945e-02, -6.07910156e-02, -1.91650391e-02, 8.56933594e-02, -2.34375000e-02, 7.37304688e-02, 6.22978210e-02, -2.02636719e-01, -2.17285156e-01, 8.83789062e-02, 1.32568359e-01, -1.63085938e-01, 1.70410156e-01, -2.20703125e-01, 2.44140625e-04, -9.06982422e-02, -1.97753906e-02, 1.42822266e-02, 2.46582031e-02, 1.06201172e-01, 6.15234375e-02, -1.22070312e-02, 9.64355469e-02, 1.06933594e-01, 6.43310547e-02, -6.92138672e-02, -4.41894531e-02, -3.90625000e-03, -6.04248047e-02, -1.42456055e-01, 1.19873047e-01, -3.17382812e-02, 1.66992188e-01, -7.43408203e-02, -1.15356445e-02, 6.03027344e-02, -1.52587891e-01, 1.22131348e-01, 1.86523438e-01, 2.29835510e-02, 7.76062012e-02, -5.51757812e-02, 1.24267578e-01, 1.01440430e-01, -2.29003906e-01, -1.04003906e-01, -1.11816406e-01, -3.38867188e-01, -1.94091797e-02, 4.80957031e-02, -3.40576172e-02, 2.05078125e-02, -2.36816406e-02, 7.81250000e-03, -1.15844727e-01, 9.93652344e-02, -2.17285156e-01, 1.08276367e-01, 2.51464844e-02, -8.83789062e-02, -1.65710449e-01, -9.69238281e-02, 
-3.29437256e-02, 1.10595703e-01, 1.30371094e-01, 1.27685547e-01, -1.16149902e-01, -1.01440430e-01, -2.35595703e-02, 1.32812500e-01, 3.12194824e-02, -1.48925781e-02, -7.29370117e-02, -2.74658203e-02, -2.26440430e-02, -8.83789062e-02, -8.63037109e-02, -1.81152344e-01, -2.83203125e-02, 6.49414062e-02, -1.32812500e-01, -4.76074219e-02, 2.96142578e-01, 1.68945312e-01, -3.41796875e-03, 7.08007812e-03, 1.59393311e-01, -1.46484375e-02, 5.55419922e-02, -1.70898438e-03, -5.24902344e-02, -3.30810547e-02, 7.05566406e-02, 6.15234375e-02, 6.89225197e-02, 2.92968750e-03, 2.34375000e-02, -1.65039062e-01, 9.05151367e-02, 1.12304688e-02, 1.82861328e-01, 1.69982910e-01, -1.36596680e-01, -8.69140625e-02, -2.72216797e-02, 9.24377441e-02, -1.98440552e-01, -7.79304504e-02, -2.85644531e-02, -1.05102539e-01, -8.14247131e-02, -3.54003906e-02, 1.58691406e-01, -5.13916016e-02, 8.03222656e-02, 1.70776367e-01, -5.35888672e-02, -8.37402344e-02, -1.28784180e-01, 2.54394531e-01, 6.09130859e-02, -1.72851562e-01, 1.42822266e-01, -1.80419922e-01, -2.84179688e-01, 3.82690430e-02, 8.66699219e-03, -9.24072266e-02, -4.00390625e-02, -1.54663086e-01, -2.57873535e-02, 6.32324219e-02, 2.27050781e-01, -5.83496094e-02, -3.78417969e-02, -2.11181641e-02, 1.24511719e-02, 2.14355469e-01, -1.83105469e-02, 4.78820801e-02, -1.53198242e-01, -4.26025391e-02, -1.89208984e-02, -3.51562500e-02, -3.55701447e-02, -1.53686523e-01, 1.07421875e-02, -9.54589844e-02], dtype=float32)
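
A quick sanity check, assuming the cells above have been executed: regardless of the length of the input text, vectorize always produces a vector of the same fixed size - exactly what a classifier like logistic regression expects.

for sample in ['Great!', 'An unexpectedly touching film with a terrific cast and a clever script.']:
    print(len(sample), vectorize(sample).shape)  # the shape is always (300,)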

Building and Training the Model

Our final pipeline will remain almost the same - a vectorizer followed by a logistic regression classifier. Defining a simple cross-validation grid is a good idea as well, but it now contains only one parameter (the regularisation strength C).

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn.linear_model import LogisticRegression

# wrap our vectorize() function so it can be used as a Pipeline step,
# stacking the per-document vectors into a single 2D feature matrix
vectorizer = FunctionTransformer(lambda x: np.vstack([vectorize(seq) for seq in x]))
classifier = LogisticRegression()

pipeline = Pipeline([
    ('vectorizer', vectorizer),
    ('classifier', classifier),
])

param_grid = {
    'classifier__C': [0.1, 1, 10],
}

The model is ready to be trained. This time we can use grid search - our parameter grid is so tiny that we can afford an exhaustive search over every combination instead of a randomized one.

%%capture --no-stdout
from sklearn.model_selection import GridSearchCV

# 5-fold cross-validation over the grid; n_jobs=-1 parallelizes across all available cores
cv = GridSearchCV(pipeline, param_grid, n_jobs=-1, verbose=3)
cv.fit(x_train, y_train)
Fitting 5 folds for each of 3 candidates, totalling 15 fits
[CV 1/5] END .................classifier__C=0.1;, score=0.831 total time= 1.2min
[CV 2/5] END .................classifier__C=0.1;, score=0.831 total time= 1.2min
[CV 3/5] END .................classifier__C=0.1;, score=0.824 total time= 1.2min
[CV 4/5] END .................classifier__C=0.1;, score=0.820 total time= 1.1min
[CV 5/5] END .................classifier__C=0.1;, score=0.821 total time= 1.1min
[CV 1/5] END ...................classifier__C=1;, score=0.854 total time= 1.1min
[CV 2/5] END ...................classifier__C=1;, score=0.855 total time= 1.1min
[CV 3/5] END ...................classifier__C=1;, score=0.852 total time= 1.1min
[CV 4/5] END ...................classifier__C=1;, score=0.845 total time=  59.1s
[CV 5/5] END ...................classifier__C=1;, score=0.848 total time=  58.5s
[CV 1/5] END ..................classifier__C=10;, score=0.864 total time=  57.7s
[CV 2/5] END ..................classifier__C=10;, score=0.862 total time=  54.9s
[CV 3/5] END ..................classifier__C=10;, score=0.861 total time=  54.7s
[CV 4/5] END ..................classifier__C=10;, score=0.855 total time=  52.0s
[CV 5/5] END ..................classifier__C=10;, score=0.861 total time=  50.1s
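
Before scoring on the held-out test set, we can peek at which regularisation strength won the cross-validation using GridSearchCV's standard attributes (the values will match the log above):

print(cv.best_params_)          # the winning value of classifier__C
print('%.3f' % cv.best_score_)  # its mean cross-validated accuracy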

Result

from sklearn.metrics import classification_report
prediction = cv.best_estimator_.predict(x_test)
print(classification_report(y_test, prediction, target_names=ds.features['label'].names))
              precision    recall  f1-score   support

         neg       0.86      0.85      0.86      5025
         pos       0.86      0.87      0.86      4975

    accuracy                           0.86     10000
   macro avg       0.86      0.86      0.86     10000
weighted avg       0.86      0.86      0.86     10000

Conclusion

The model achieved a final accuracy of 86%. While this demonstrates the basic application of word embeddings, it significantly underperforms compared to the previous approach using a simple count vectorizer with n-grams (which reached 90%).

This suggests that for this specific dataset and task, the simple averaging of word embeddings, which loses word order and contextual nuances, is less effective than a feature representation that explicitly captures local phrases (like n-grams).
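
The order-invariance is easy to demonstrate with the vectorize function defined earlier (assuming the cells above have been run): two texts containing the same words in a different order - even with opposite sentiment - map to the same averaged vector.

a = vectorize('not a good movie, a bad one')
b = vectorize('a good movie, not a bad one')
print(np.allclose(a, b))  # True - opposite sentiment, identical features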

To overcome this limitation and fully leverage the semantic power of word embeddings without discarding sequential context, we need models that can learn the relationships between words in a sentence.