Attention Visualization in Text Classification

2019-04-25

This post gives a brief introduction to the attention mechanism in the seq2seq model, explains how attention works in text classification, and finally visualizes the attention weights of an RNN-based text classification model. The model is implemented in Keras.

Ref:

  1. GitHub: DeepMoji
  2. Attention in Keras

Thanks to the above two wonderful works.

1. About attention

Attention originally comes from the way human beings observe things. Our brains use a visual attention mechanism to process the visual information we receive: when looking at a picture, we first take a quick overview, and then focus our attention on several particular parts to gather further information.
Beyond computer vision, the attention mechanism is also useful in many NLP tasks, such as machine translation and text classification. You may be familiar with the well-known seq2seq + attention model in machine translation, shown below, where $a(s_{t-1}, h_j)$ is a function that computes the correlation between the decoder state $s_{t-1}$ and each encoder state $h_j$.

The attention mechanism can be expressed as follows.
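In the standard Bahdanau-style formulation (consistent with the notation above), the decoder at step $t$ scores each encoder state $h_j$ against the previous decoder state $s_{t-1}$, normalizes the scores with a softmax, and takes the weighted sum as the context vector:

$$ e_{tj} = a(s_{t-1}, h_j) $$

$$ \alpha_{tj} = \frac{\exp(e_{tj})}{\sum_{k} \exp(e_{tk})} $$

$$ c_t = \sum_{j} \alpha_{tj} h_j $$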

2. Attention in text classification

In text classification only one label is output, so there is no decoder and no decoder state. The correlation function $a(s_{t-1}, h_j)$ is therefore reduced to $a(h_j)$, which computes the correlation between the label and each encoder state $h_j$. The details of the function $a(\cdot)$, as well as the reduction $a_j = \max(e_j)$ that keeps only the most significant correlation value, can be designed to suit your needs. The resulting attention weights represent the relationship between the output label and the input tokens.
$$ e_j = a(h_j) $$
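The rest of the computation mirrors the seq2seq case and matches the layer implementation shown later: the scores are normalized with a softmax over the timesteps, and the encoder states are averaged with these weights before the final classifier:

$$ \alpha_j = \frac{\exp(e_j)}{\sum_k \exp(e_k)} $$

$$ c = \sum_j \alpha_j h_j $$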

3. Visualize attention weights

This section shows how to visualize the attention weights in a text classification task.
We need the words, their corresponding attention weights, and the final output label. The final result looks like this.

The main steps are as follows.

  1. Train a usual text classification model (mdl) and save its weights (using checkpoints).
  2. Give the attention layer (better designed as an independent class) a switch that controls whether it returns only its output or also the attention weights.
  3. Build two models and load the saved weights into both: one to predict the label (pred_mdl), the other to get the attention weights (attn_mdl).

First, here is an implementation of an attention layer with such a switch.
# -*- coding: utf-8 -*-
"""
File: attnlayer.py
Date: 2019-03-30 15:46
Author: amy
"""
import logging

from keras import initializers, layers
from keras.engine import InputSpec, Layer
from keras import backend as K

import config as conf

logger = logging.getLogger("console_file")


class AttentionWeightedAverage(Layer):
    """
    Computes a weighted average of the different channels across timesteps.
    Uses one parameter per channel to compute the attention value for a single timestep.
    """

    def __init__(self, return_attention=False, **kwargs):
        self.init = initializers.get('glorot_uniform')
        # the switch that controls the return values
        self.return_attention = return_attention
        super(AttentionWeightedAverage, self).__init__(**kwargs)

    def get_config(self):
        config = {
            'return_attention': self.return_attention,
        }
        base_config = super(AttentionWeightedAverage, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def build(self, input_shape):
        ''' define your attention layer params here '''
        self.input_spec = [InputSpec(ndim=3)]
        assert len(input_shape) == 3
        self.W = self.add_weight(name="Wa",
                                 shape=(input_shape[-1], conf.n_classes),
                                 initializer=self.init,
                                 trainable=True)
        super(AttentionWeightedAverage, self).build(input_shape)

    def call(self, x):
        ''' define your attention function a(.) here '''
        # ej: (batch, timesteps, n_classes); a simple linear score is used here,
        # but you can design a(.) to suit your own needs.
        ej = K.dot(x, self.W)
        # keep only the most significant correlation value per timestep
        # ai: (batch, timesteps)
        ai = K.max(ej, axis=-1)
        # softmax over timesteps to get the attention weights
        att_weights = layers.Softmax(axis=-1)(ai)
        # weighted average of the hidden states x
        weighted_inputs = x * K.expand_dims(att_weights)
        result = K.sum(weighted_inputs, axis=1)
        if self.return_attention:
            # return both the attention output and the attention weights
            return [result, att_weights]
        # only return the attention layer output: the weighted average of hidden states
        return result

    def get_output_shape_for(self, input_shape):
        return self.compute_output_shape(input_shape)

    def compute_output_shape(self, input_shape):
        output_len = input_shape[2]
        if self.return_attention:
            return [(input_shape[0], output_len), (input_shape[0], input_shape[1])]
        return input_shape[0], output_len
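As a quick sanity check of the two output modes, a hypothetical snippet like the following (assuming conf.n_classes is set and the hidden dimensions are made up) builds the layer on a dummy input and prints the output shapes:

from keras.layers import Input
from keras.models import Model

# hypothetical shapes: sequences of 50 timesteps with 128-dim hidden states
h = Input(shape=(50, 128))
out, attn = AttentionWeightedAverage(return_attention=True)(h)
m = Model(inputs=h, outputs=[out, attn])
m.summary()  # out: (None, 128), attn: (None, 50)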

Second, here is an RNN model that uses the attention layer defined above.

from keras.models import Model
from keras.layers import Input, Embedding, SpatialDropout1D, LSTM, Dense
from keras.initializers import random_uniform

# p is a dict of hyper-parameters; w_vocab_size is the word vocabulary size.

def birnn_model(return_attention=False):
    logger.info("using [birnn] model...")
    # word embedding layer
    w_inp = Input(shape=(p['w_maxlen'],), name='word_input')
    w_emb = Embedding(w_vocab_size, p['w_embdim'],
                      embeddings_initializer=random_uniform(minval=-1., maxval=1.),
                      name='word_embedding')(w_inp)
    w_emb = SpatialDropout1D(p['w_embdrop'])(w_emb)
    # rnn cell type
    rnn = LSTM
    # birnn
    w_fw = rnn(p['w_featdim'], name='word_fw_rnn', return_sequences=True)(w_emb)
    w_bw = rnn(p['w_featdim'], go_backwards=True, return_sequences=True, name='word_bw_rnn')(w_emb)
    w_feat = concatenate([w_fw, w_bw])
    # attention
    ha = AttentionWeightedAverage(name="attn", return_attention=return_attention)(w_feat)
    if return_attention:
        # attn_mdl: outputs [weighted features, attention weights]
        m = Model(inputs=[w_inp], outputs=ha)
    else:
        # pred_mdl: outputs class probabilities
        emo = Dense(conf.n_classes, activation='softmax', name='emoji')(ha)
        m = Model(inputs=[w_inp], outputs=[emo])
    return m
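As mentioned in step 1, the prediction model is first trained as a usual classifier and its weights are saved with checkpoints. A minimal training sketch might look like the following; x_train, y_train and the compile settings are assumptions (the post itself compiles through a build_model helper), and the checkpoint path just follows the naming used later.

from keras.callbacks import ModelCheckpoint

# a minimal training sketch; x_train / y_train and the compile settings are assumptions
pred_model = birnn_model(return_attention=False)
pred_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

ckpt = ModelCheckpoint("checkpoints/attn_vis.epo_{epoch:02d}.hdf5",
                       save_weights_only=True)
pred_model.fit(x_train, y_train, batch_size=128, epochs=10, callbacks=[ckpt])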

Finally, here is how to obtain and visualize the attention weights.

# -*- coding: utf-8 -*-
import logging

import numpy as np

import config as conf
# birnn_model, build_model (compiles a model) and load_data come from the training
# code above; adjust the import to your project layout (the module name is an assumption)
from model import birnn_model, build_model, load_data

logger = logging.getLogger("console_file")


class Visualizer(object):
    def __init__(self,
                 word_info_file="data/word_info.txt",
                 sample_file="data/test.txt"):
        """
        Visualizes attention weights.
        Args:
            word_info_file: idx2word dict file.
            sample_file: input data file to visualize.
        """
        self.word_info_file = word_info_file
        self.sample_file = sample_file
        self.w_inps, self.labels = None, None
        self.word2idx, self.w_maxlen = None, None
        self.idx2word = None
        self.w_vocab_size = None
        self.word_lists = []
        self.texts = None
        self.pred_model = None
        self.attn_model = None
        self.weights_file = "checkpoints/attn_vis.epo_10.hdf5"
        self.n_sample = None
        self.correct_file = "./attention_maps/correct.txt"
        self.wrong_file = "./attention_maps/wrong.txt"

    def set_models(self, pred_model, attn_model):
        """
        Sets the models to use.
        Args:
            pred_model: the prediction model
            attn_model: the model that outputs the attention maps
        """
        self.pred_model = pred_model
        self.attn_model = attn_model

    def attention_map(self):
        """
        Predicts labels and writes the per-word attention weights to files.
        """
        # predict labels
        logger.info("start attention mapping...")
        y_preds_onehot = self.pred_model.predict(x={'word_input': self.w_inps}, batch_size=128)
        y_preds = np.argmax(y_preds_onehot, axis=-1)
        y_trues = self.labels
        # get the attention weights (second output of the attention model)
        ret = self.attn_model.predict(x={'word_input': self.w_inps}, batch_size=128)
        attn_wgts = ret[1]
        # save to correct_file and wrong_file
        logger.info("saving to files ...")
        with open(self.correct_file, "w", encoding=conf.enc) as c_fw, \
                open(self.wrong_file, "w", encoding=conf.enc) as w_fw:
            for i in range(self.n_sample):
                wl = self.word_lists[i]
                text = self.texts[i]
                y_pred = y_preds[i]
                y_true = y_trues[i]
                wgts = attn_wgts[i]
                assert len(wl) == len(wgts)
                ss = ""
                for k in range(len(wl)):
                    w, wgt = wl[k], wgts[k]
                    if w not in ["PAD"]:  # ignore special tokens
                        ss += "({}, {:.4f}) ".format(w, wgt)
                if y_pred == y_true:
                    c_fw.write("[{}]: {}\n".format(i, text))
                    c_fw.write("true: {}, pred: {} || {}\n".format(y_true, y_pred, ss))
                else:
                    w_fw.write("[{}]: {}\n".format(i, text))
                    w_fw.write("true: {}, pred: {} || {}\n".format(y_true, y_pred, ss))
        logger.info("finish attention mapping...")

    def do_visualize(self):
        # preprocess data
        w_tst, w_vocab, tst_labels, tst_docs, word2idx = load_data(self.sample_file)
        self.w_inps = w_tst
        self.labels = np.argmax(tst_labels, axis=-1)
        self.word2idx = word2idx
        self.idx2word = {v: k for k, v in self.word2idx.items()}
        self.w_vocab_size = len(self.word2idx)
        # map indexes to word lists
        for seq in self.w_inps:
            wl = []
            for idx in seq:
                if idx in self.idx2word:
                    wl.append(self.idx2word[idx])
                else:
                    wl.append("OOV")
            self.word_lists.append(wl)
        self.texts = tst_docs
        self.n_sample = len(self.texts)
        # build 2 models and load the same saved weights into both
        pred_model = birnn_model(return_attention=False)
        build_model(pred_model)
        pred_model.load_weights(self.weights_file, by_name=True)
        pred_model.summary()
        attn_model = birnn_model(return_attention=True)
        build_model(attn_model)
        attn_model.load_weights(self.weights_file, by_name=True)
        attn_model.summary()
        self.set_models(pred_model, attn_model)
        # visualize
        self.attention_map()


if __name__ == '__main__':
    vis = Visualizer()
    vis.do_visualize()

By the way, once you have the words, the predicted label and the attention weights, you can draw the above picture in different ways: the larger the attention weight, the darker the text background color should be.
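For example, a minimal sketch that renders one sample as HTML, encoding the weight as background opacity (the function name and the normalization choice are just illustrative assumptions):

def render_html(words, weights):
    """Render (word, weight) pairs as one HTML line; darker background = larger weight."""
    max_w = max(weights) or 1.0
    spans = []
    for word, wgt in zip(words, weights):
        alpha = wgt / max_w  # normalize to [0, 1] for the background opacity
        spans.append('<span style="background-color: rgba(255, 0, 0, {:.2f})">{}</span>'
                     .format(alpha, word))
    return " ".join(spans)

# usage: write the rendered lines to an .html file and open it in a browser
# html_line = render_html(wl, wgts)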