I am doing my project on “Vision-Language Model (VLM) Unlearning”, and first I need to get a better understanding of VLMs. I just came across an article on the Hugging Face blog and found it interesting.
Here I will record my implementation details with the seemore tutorial and note what is important.
Let’s get started.
Preliminary#
First we need to do some chores before we really implement the core components.
import base64
import io
import pandas as pd

from PIL import Image
import torchvision.transforms as transforms

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init

device = 'cuda' if torch.cuda.is_available() else 'cpu'
We will use the tinyshakespeare dataset to build the encoding and decoding functions. For the sake of brevity we do not pretrain the decoder model on it. Here we use a very simple conversion of each character to a corresponding integer, and we construct two mappings, `stoi` and `itos`.
text_path = './input.txt'
with open(text_path, 'r', encoding='utf-8') as f:
    text = f.read()

# get all unique characters in the text
chars = sorted(list(set(text)))
# create a mapping from unique characters to indices
stoi = {ch: i for i, ch in enumerate(chars)}
# stoi contains 65 characters; add an extra '' as a padding token at index 65
stoi[''] = 65
# create a mapping from indices to unique characters
itos = {i: ch for ch, i in stoi.items()}
And we construct two very simple `encode` and `decode` lambda functions to get a tensor of integers from a string, and vice versa.
# encode a string: take a string, output a tensor of integers
encode = lambda s: torch.tensor([stoi[ch] for ch in s], dtype=torch.long)
# decode a list of integers: take a list of integers, output a string
decode = lambda l: ''.join([itos[i] for i in l])
vocab_size = len(stoi)
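As a quick sanity check (my own snippet, assuming `input.txt` is the tiny Shakespeare file), encoding a string and then decoding it back should reproduce the original:

```python
sample = "To be, or not to be"
ids = encode(sample)              # LongTensor of shape [len(sample)]
roundtrip = decode(ids.tolist())  # decode expects a list of ints
assert roundtrip == sample
```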
VLM components#
There are 3 main components of a VLM:
- Image Encoder.
  - It is used to extract visual features from images.
  - Here it is a from-scratch implementation of the vision transformer used in CLIP.
- Vision-Language Projector.
  - Image embeddings are not of the same shape as the text embeddings used by the decoder, so we need to change the dimensionality of the image features extracted by the image encoder to match the text embedding space.
  - The image features thus become 'visual tokens' for the decoder.
  - This could be a single layer MLP.
- A decoder-only language model.
  - This is the component that ultimately generates text.
  - In this implementation the projection module is incorporated into the decoder, which is not how it is typically done.
In summary, an image encoder extracts features from a given image and passes these image embeddings to a vision-language projector, which projects them into the text embedding space. The projected embeddings are then concatenated with the text embeddings from the text input and used by a decoder-only language model to autoregressively generate text.
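To make the data flow concrete, here is a rough shape walkthrough using the dimensions that appear later in this post (`img_size=96`, `image_embed_dim=512`, `n_embd=128`); this is just a sketch, not runnable code:

```python
# image:                           [B, 3, 96, 96]
# image encoder (ViT)           -> [B, 512]               image embedding (CLS token)
# vision-language projector     -> [B, 128]               projected to the text embedding dim
# unsqueeze into 'visual token'    [B, 1, 128]
# text token embeddings         -> [B, T, 128]
# concatenate                   -> [B, 1 + T, 128]        decoder input
# decoder blocks + lm_head      -> [B, 1 + T, vocab_size] logits
```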
Image Encoder#
To implement the vision transformer from scratch, we will create a `PatchEmbeddings` class that takes an image and creates a sequence of patches.
This step is crucial for enabling the transformer architecture to process visual data effectively, specifically through the attention blocks in the subsequent steps of the architecture.
Here we create a convolutional layer to extract patch embeddings. We assume the input image has 3 color channels; `hidden_dim` sets the number of output channels to match the hidden dimension, and `kernel_size=patch_size` together with `stride=patch_size` ensures each patch is embedded separately.
class PatchEmbeddings(nn.Module):
    def __init__(self, img_size=96, patch_size=16, hidden_dim=512):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        # Ensure the convolution outputs a feature map with hidden_dim channels
        self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, X):
        X = self.conv(X)
        X = X.flatten(2)       # Flatten the patch dimensions
        X = X.transpose(1, 2)  # [B, num_patches, hidden_dim]
        return X
The input image is broken down into \( (\mathrm{img\_size} / \mathrm{patch\_size})^2 \) patches (with the defaults, \( (96/16)^2 = 36 \)) by the convolution layer, and each patch is projected into a vector with a channel dimension of 512.
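A quick shape check with the default arguments (my own snippet, using a random image tensor):

```python
patch_embed = PatchEmbeddings(img_size=96, patch_size=16, hidden_dim=512)
dummy_img = torch.randn(1, 3, 96, 96)  # [B, C, H, W]
print(patch_embed(dummy_img).shape)    # torch.Size([1, 36, 512]) -> [B, num_patches, hidden_dim]
```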
Most components in the transformer blocks, such as the attention head, multi-head attention, and the MLP, are largely identical between the ViT we are implementing for visual token generation and the decoder language model for the actual text output generation.
GELU is used quite often in ViTs and ReLU in text transformers; however, GELU is increasingly used for both due to the resulting model performance.
class MLP(nn.Module):
    def __init__(self, n_embd, dropout=0.1, is_decoder=True):
        super().__init__()
        layers = [
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU() if is_decoder else nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout)
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
The only key difference is the masking applied in each attention head in the decoder language model. This is done to ensure the integrity of the autoregressive language generation process, particularly in a decoder-only model. This masking technique is crucial as it obscures any information following the current token’s position, thereby directing the model’s attention to only the preceding parts of the sequence. Such an attention mechanism is known as causal self-attention.
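To make the masking concrete, here is a small illustration (my own snippet, not from the tutorial) of what the causal mask does to a 4x4 grid of attention scores:

```python
T = 4
scores = torch.zeros(T, T)                             # stand-in attention scores
tril = torch.tril(torch.ones(T, T, dtype=torch.bool))  # lower-triangular causal mask
print(scores.masked_fill(tril == 0, float('-inf')))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
# After softmax, each position only attends to itself and earlier positions.
```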
class Head(nn.Module):
    def __init__(self, n_embd, head_size, dropout=0.1, is_decoder=False):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # Compute attention scores
        wei = q @ k.transpose(-2, -1) * (C**-0.5)
        if self.is_decoder:
            # Ensure the mask is the correct size for the current sequence length
            tril = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device))
            wei = wei.masked_fill(tril == 0, float('-inf'))

        # Apply softmax to get probabilities
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        # Perform weighted aggregation of values
        out = wei @ v
        return out
The implementation of multi-head attention is as follows:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, num_heads, dropout=0.1, is_decoder=False):
        super().__init__()
        # Using assert statements for this kind of check is a good idea in general
        assert n_embd % num_heads == 0, "n_embd must be divisible by num_heads"
        self.heads = nn.ModuleList([
            Head(n_embd, n_embd // num_heads, dropout, is_decoder)
            for _ in range(num_heads)
        ])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
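A small shape check (my own snippet): both the encoder-style and decoder-style variants preserve the [B, T, n_embd] shape; only the masking differs.

```python
x = torch.randn(2, 10, 128, device=device)  # [B, T, n_embd]
enc_attn = MultiHeadAttention(n_embd=128, num_heads=8, is_decoder=False).to(device)
dec_attn = MultiHeadAttention(n_embd=128, num_heads=8, is_decoder=True).to(device)
print(enc_attn(x).shape, dec_attn(x).shape)  # both torch.Size([2, 10, 128])
```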
Each encoder transformer block looks as follows:
class Block(nn.Module):
    def __init__(self, n_embd, num_heads, dropout=0.1, is_decoder=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_embd, num_heads, dropout, is_decoder)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        original_x = x  # Save for residual connection
        x = self.ln1(x)
        attn_output = self.attn(x)
        x = original_x + attn_output
        x = self.ln2(x)
        ffn_output = self.ffn(x)
        x = x + ffn_output
        return x
Now all of this can be put together to implement a Vision Transformer:
class ViT(nn.Module):
    def __init__(self, img_size, patch_size, num_hiddens, num_heads, num_blks, emb_dropout, blk_dropout):
        super().__init__()
        self.patch_embedding = PatchEmbeddings(img_size, patch_size, num_hiddens)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        num_patches = (img_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens))
        self.dropout = nn.Dropout(emb_dropout)
        self.blocks = nn.ModuleList([Block(num_hiddens, num_heads, blk_dropout, is_decoder=False) for _ in range(num_blks)])
        self.layer_norm = nn.LayerNorm(num_hiddens)

    def forward(self, X):
        x = self.patch_embedding(X)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x = self.dropout(x)
        for block in self.blocks:
            x = block(x)
        x = self.layer_norm(x[:, 0])
        return x
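A quick sanity check (my own snippet) that the encoder returns a single embedding per image, taken from the [CLS] token position:

```python
vit = ViT(img_size=96, patch_size=16, num_hiddens=512, num_heads=8,
          num_blks=3, emb_dropout=0.1, blk_dropout=0.1).to(device)
imgs = torch.randn(2, 3, 96, 96, device=device)
print(vit(imgs).shape)  # torch.Size([2, 512]) -> one image embedding per image
```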
Vision-Language Projector#
We cannot directly concatenate this image embedding to the text embeddings. We need to project it from the dimensionality of the image embeddings produced by the vision transformer to the dimensionality of the text embeddings.
This can be a single learnable layer followed by a nonlinearity, or an MLP. The reason for using an MLP here is an interesting current trend of keeping both the pretrained vision encoder and language decoder frozen during the VLM training phase; giving more learnable parameters to this connection module can improve the ability of the overall VLM to generalize and help in the downstream instruction-tuning process.
Here’s the implementation of this projection module. It’s not too different from the MLP used in the transformer blocks.
class MultiModalProjector(nn.Module):
    def __init__(self, n_embd, image_embed_dim, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(image_embed_dim, 4 * image_embed_dim),
            nn.GELU(),
            nn.Linear(4 * image_embed_dim, n_embd),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.net(x)
        return x
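For example (my own snippet, using the dimensions that appear later in this post), the projector maps a 512-dimensional image embedding into the 128-dimensional text embedding space:

```python
projector = MultiModalProjector(n_embd=128, image_embed_dim=512)
img_emb = torch.randn(2, 512)    # [B, image_embed_dim] from the ViT
print(projector(img_emb).shape)  # torch.Size([2, 128]) -> matches the text embedding dimension
```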
Language Decoder#
The final component we need to look at is the decoder language model. I have integrated the projection module into the decoder model class implementation.
There’s no easy way to directly feed in reshaped embeddings in this implementation, so I’ve had to improvise a little.
However, when using pretrained models through the Hugging Face API (or any other modern library that lets you use pretrained large language models), you can directly feed embeddings as input to the model, as sketched below.
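For instance, here is a minimal sketch (not part of this implementation; the checkpoint name is just an example) of feeding projected image embeddings to a Hugging Face causal LM through `inputs_embeds`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # example checkpoint
tok = AutoTokenizer.from_pretrained(model_name)
lm = AutoModelForCausalLM.from_pretrained(model_name)

text_ids = tok("Describe the image:", return_tensors="pt").input_ids
text_emb = lm.get_input_embeddings()(text_ids)               # [1, T, hidden]
visual_tokens = torch.randn(1, 1, text_emb.size(-1))         # stand-in for projected image features
inputs_embeds = torch.cat([visual_tokens, text_emb], dim=1)  # prepend the visual token(s)

out = lm(inputs_embeds=inputs_embeds)                        # out.logits: [1, 1 + T, vocab]
```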
Some interesting exercises:
- How the image embeddings are reshaped using the vision-language projector to match the text embeddings.
- How they are then concatenated with the token embeddings.
- How the result is subsequently combined with position embeddings and used to calculate the loss (and eventually to generate text).
The crucial parts of this decoder implementation are given below. Note how the `is_decoder` flag is passed as `True` to use the masked version of the self-attention blocks, resulting in causal scaled dot-product self-attention in the language decoder.
class DecoderLanguageModel(nn.Module):
    def __init__(self, n_embd, image_embed_dim, vocab_size, num_heads, n_layer, use_images=False):
        super().__init__()
        self.use_images = use_images
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(1000, n_embd)
        if use_images:
            self.image_projection = MultiModalProjector(n_embd, image_embed_dim)
        self.blocks = nn.Sequential(*[Block(n_embd, num_heads, is_decoder=True) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, image_embeds=None, targets=None):
        tok_emb = self.token_embedding_table(idx)
        if self.use_images and image_embeds is not None:
            img_emb = self.image_projection(image_embeds).unsqueeze(1)
            tok_emb = torch.cat([img_emb, tok_emb], dim=1)
        pos_emb = self.position_embedding_table(torch.arange(tok_emb.size(1), device=device)).unsqueeze(0)
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        if targets is not None:
            if self.use_images and image_embeds is not None:
                batch_size = idx.size(0)
                targets = torch.cat([torch.full((batch_size, 1), -100, dtype=torch.long, device=device), targets], dim=1)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
            return logits, loss
        return logits

    def generate(self, idx, image_embeds, max_new_tokens):
        B, T = idx.shape
        generated = idx

        # current_output holds the raw embedding sequence: visual token (if any) + text token embeddings
        if self.use_images and image_embeds is not None:
            img_emb = self.image_projection(image_embeds).unsqueeze(1)
            current_output = torch.cat([img_emb, self.token_embedding_table(idx)], dim=1)
        else:
            current_output = self.token_embedding_table(idx)

        for i in range(max_new_tokens):
            T_current = current_output.size(1)
            current_pos_emb = self.position_embedding_table(torch.arange(T_current, device=device)).unsqueeze(0)
            # Add position embeddings out of place so they are not accumulated across iterations,
            # and run the blocks on the raw embeddings (mirroring forward) rather than on the
            # previous iteration's block outputs
            x = current_output + current_pos_emb
            x = self.blocks(x)
            x = self.ln_f(x)

            logits = self.lm_head(x[:, -1, :])
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, idx_next), dim=1)
            idx_next_emb = self.token_embedding_table(idx_next)
            current_output = torch.cat((current_output, idx_next_emb), dim=1)

        return generated
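A quick smoke test of the untrained decoder on dummy inputs (my own snippet; shapes only, the sampled text is meaningless):

```python
decoder = DecoderLanguageModel(n_embd=128, image_embed_dim=512, vocab_size=vocab_size,
                               num_heads=8, n_layer=2, use_images=True).to(device)
dummy_idx = torch.randint(0, vocab_size, (2, 10), device=device)
dummy_img_emb = torch.randn(2, 512, device=device)
print(decoder(dummy_idx, dummy_img_emb).shape)                             # [2, 11, vocab_size]: one extra slot for the visual token
print(decoder.generate(dummy_idx, dummy_img_emb, max_new_tokens=5).shape)  # [2, 15]
```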
Summary#
Now that we have our three key components, we can put them all together into a Vision Language Model. The full implementation is given below. All that’s happening here is:
- Get image features from the vision encoder (here it’s a vision transformer, but it could be any model that can generate features from an image input, such as a ResNet or a traditional convolutional neural network, though performance may suffer).
- A projection module that projects image tokens to the same embedding space as the text embeddings for the decoder (this projector is integrated with the decoder in this implementation).
- A decoder language model for generating text conditioned on a preceding image.
class VisionLanguageModel(nn.Module):
    def __init__(self, n_embd, image_embed_dim, vocab_size, n_layer, img_size, patch_size, num_heads, num_blks, emb_dropout, blk_dropout):
        super().__init__()
        num_hiddens = image_embed_dim  # Set num_hiddens equal to image_embed_dim
        assert num_hiddens % num_heads == 0, "num_hiddens must be divisible by num_heads"
        self.vision_encoder = ViT(img_size, patch_size, num_hiddens, num_heads, num_blks, emb_dropout, blk_dropout)
        self.decoder = DecoderLanguageModel(n_embd, image_embed_dim, vocab_size, num_heads, n_layer, use_images=True)

    def forward(self, img_array, idx, targets=None):
        image_embeds = self.vision_encoder(img_array)

        if image_embeds.nelement() == 0 or image_embeds.shape[1] == 0:
            raise ValueError("Something is wrong with the ViT model: it returned an empty tensor or the embedding dimension is empty")

        if targets is not None:
            logits, loss = self.decoder(idx, image_embeds, targets)
            return logits, loss
        else:
            logits = self.decoder(idx, image_embeds)
            return logits

    def generate(self, img_array, idx, max_new_tokens):
        image_embeds = self.vision_encoder(img_array)

        if image_embeds.nelement() == 0 or image_embeds.shape[1] == 0:
            raise ValueError("Something is wrong with the ViT model: it returned an empty tensor or the embedding dimension is empty")

        generated_tokens = self.decoder.generate(idx, image_embeds, max_new_tokens)
        return generated_tokens
Extra code for training#
def base64_to_tensor(base64_str, img_size=96):
    image = Image.open(io.BytesIO(base64.b64decode(base64_str)))
    if image.mode != 'RGB':
        image = image.convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0)  # Add batch dimension

# Adjusting the data loader from makemore for multimodal data
def get_batch(df, batch_size, split='train', img_size=96, val_batch_size=8):
    # Split data into training and validation sets
    n = int(0.9 * len(df))  # first 90% will be train, rest val
    df_train = df.iloc[:n]
    df_val = df.iloc[n:]
    data = df_train if split == 'train' else df_val
    batch_size = batch_size if split == 'train' else val_batch_size
    replace = False if split == 'train' else True
    batch = data.sample(n=batch_size, replace=replace)

    images = torch.cat([base64_to_tensor(img, img_size) for img in batch['b64string_images']], dim=0).to(device)
    # encode already returns a LongTensor, so no extra wrapping is needed
    text_indices = [encode(desc) for desc in batch['caption']]
    max_length = max(len(t) for t in text_indices)

    padded_text = torch.full((batch_size, max_length), fill_value=stoi[''], dtype=torch.long).to(device)
    for i, text in enumerate(text_indices):
        padded_text[i, :len(text)] = text

    targets = torch.cat([padded_text[:, 1:], torch.full((batch_size, 1), fill_value=stoi[''], dtype=torch.long, device=device)], dim=1)

    # Truncate or pad targets to match the length of padded_text
    if targets.size(1) > padded_text.size(1):
        targets = targets[:, :padded_text.size(1)]
    elif targets.size(1) < padded_text.size(1):
        targets = torch.cat([targets, torch.full((batch_size, padded_text.size(1) - targets.size(1)), fill_value=stoi[''], dtype=torch.long, device=device)], dim=1)

    return images, padded_text, targets

# Adjusting the training loop from makemore for multimodal data
def train_model(model, df, epochs, vocab_size, img_size=96):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.to(device)
    for epoch in range(epochs):
        model.train()
        for step in range(max_iters):
            images, idx, targets = get_batch(df, batch_size, 'train', img_size)
            optimizer.zero_grad()
            logits, loss = model(images, idx, targets)
            loss.backward()
            optimizer.step()
            if step % eval_interval == 0:
                print(f"Loss at iteration {step}: {loss.item()}")
        val_loss = estimate_loss(model, df, 'val', img_size, val_batch_size=8)
        print(f"Validation Loss after epoch {epoch}: {val_loss}")

def estimate_loss(model, df, split, img_size=96, val_batch_size=8):
    losses = []
    model.eval()
    with torch.no_grad():
        for _ in range(eval_iters):
            images, idx, targets = get_batch(df, batch_size, split, img_size, val_batch_size=val_batch_size)
            _, loss = model(images, idx, targets)
            losses.append(loss.item())
    model.train()
    return sum(losses) / len(losses)

df = pd.read_csv("./inputs.csv")
# Expanding the dataframe so that there's enough data to test. This is just duplicating data; a real dataset would have more rows
df = pd.concat([df] * 30)[['b64string_images', 'caption']]
df.shape

batch_size = 16  # how many independent sequences will we process in parallel?
block_size = 32  # what is the maximum context length for predictions?
max_iters = 100
eval_interval = 10
learning_rate = 1e-3
epochs = 1
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 40
num_blks = 3
head_size = 16
n_embd = 128
n_head = 8
n_layer = 8
dropout = 0.1
img_size = 96
patch_size = 16
image_embed_dim = 512
emb_dropout = blk_dropout = 0.1

# Initialize the model
model = VisionLanguageModel(n_embd, image_embed_dim, vocab_size, n_layer, img_size, patch_size, n_head, num_blks, emb_dropout, blk_dropout)
model.to(device)

# Dummy data to initialize lazy modules
dummy_img = torch.randn(1, 3, img_size, img_size).to(device)
dummy_idx = torch.randint(0, vocab_size, (1, block_size)).to(device)
model(dummy_img, dummy_idx)  # Forward pass to initialize all parameters

# Train the model
train_model(model, df, epochs, vocab_size, img_size)
In practice, the commonly observed sequence is:
- Get a pretrained vision encoder from SigLIP or CLIP (both come in different sizes). Freeze the weights (i.e. don’t update them during the backward pass in training).
- Get a pretrained decoder-only language model, e.g. anything from TinyLLaMA or Phi-2 all the way to Llama 3 (or even much bigger in the case of GPT-4, Grok 1.5, etc.). Freeze the weights.
- Implement a projection module and train a VLM much like the one we have here, but only updating the weights of this projection module. This is effectively the pretraining phase.
- Then, during instruction finetuning, keep both the projection module and the decoder language model unfrozen and update the weights of both in the backward pass (a minimal freezing/unfreezing sketch is given below).
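As a rough sketch of what that staged training could look like on top of the classes in this post (my own snippet; the learning rates are illustrative, and unlike the common setup, the encoder and decoder here are trained from scratch rather than pretrained):

```python
# Stage 1 (pretraining): freeze the vision encoder and decoder, train only the projector
for p in model.vision_encoder.parameters():
    p.requires_grad_(False)
for p in model.decoder.parameters():
    p.requires_grad_(False)
for p in model.decoder.image_projection.parameters():  # the projector lives inside the decoder here
    p.requires_grad_(True)
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-3)

# Stage 2 (instruction finetuning): unfreeze the decoder and keep the projector trainable
for p in model.decoder.parameters():
    p.requires_grad_(True)
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-4)
```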