Implementing Seemore: A VLM from Scratch in PyTorch

VLM Implementation
This post is an implementation of Seemore, following the tutorial on the Hugging Face blog.

I am working on a project on “Vision-Language Model (VLM) Unlearning”, and first I need a better understanding of VLMs. I came across an article on the Hugging Face blog and found it interesting.

Here I record my implementation details while following the seemore tutorial and note what is important.

Let’s get started.

Preliminary
#

First we need to do some chores before implementing the core components.

import base64
import io
import pandas as pd

from PIL import Image
import torchvision.transforms as transforms

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init

device = 'cuda' if torch.cuda.is_available() else 'cpu'

We will use the tinyshakespeare dataset to build the encoding and decoding functions. For the sake of brevity, we do not pretrain the decoder model on it.

Here we use a very simple conversion of each character to a corresponding integer. We also construct two mappings stoi and itos.

text_path = './input.txt'
with open(text_path, 'r', encoding='utf-8') as f:
    text = f.read()

# get all unique characters in the text
chars = sorted(list(set(text)))
# create a mapping from unique characters to indices
stoi = {ch: i for i, ch in enumerate(chars)}
# stoi now holds 65 character-to-index entries; add one extra entry,
# keyed by the empty string, at index 65 (used later as the padding token)
stoi[''] = 65
# create a mapping from indices to unique characters
itos = {i: ch for ch, i in stoi.items()}

We then construct two very simple lambda functions, encode and decode, to turn a string into a tensor of integers and a list of integers back into a string.

# encode a string: take a string, output a tensor of integers
encode = lambda s: torch.tensor([stoi[ch] for ch in s], dtype=torch.long)
# decode a list of integers: take a list of integers, output a string
decode = lambda l: ''.join([itos[i] for i in l])
vocab_size = len(stoi)
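
To check that the two mappings are consistent, here is a minimal round-trip sketch. It uses a short made-up sample string instead of input.txt and plain Python lists instead of tensors; the names sample_text, toy_stoi, toy_itos, toy_encode and toy_decode are assumptions for illustration only.

```python
# Toy stand-in for the dataset text (an assumption, not the real input.txt)
sample_text = "hello world"

# Same construction as in the post: sorted unique characters -> integer ids
toy_chars = sorted(set(sample_text))
toy_stoi = {ch: i for i, ch in enumerate(toy_chars)}
toy_itos = {i: ch for ch, i in toy_stoi.items()}

def toy_encode(s):
    """Map each character to its integer id."""
    return [toy_stoi[ch] for ch in s]

def toy_decode(ids):
    """Map integer ids back to characters and join them."""
    return "".join(toy_itos[i] for i in ids)

ids = toy_encode("hello")
round_trip = toy_decode(ids)
print(ids, round_trip)
```

Encoding followed by decoding should give back the original string, which is the only property the rest of the pipeline relies on.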

VLM components
#

There are 3 main components for a VLM:

  • Image Encoder.
    • It is used to extract visual features from images.
    • Seemore uses a from-scratch implementation of the original vision transformer used in CLIP.
  • Vision-Language Projector.
    • Image embeddings do not have the same shape as the text embeddings used by the decoder, so we need to change the dimensionality of the image features extracted by the image encoder to match the text embedding space.
    • The image features then become ‘visual tokens’ for the decoder.
    • This could be a single layer or an MLP.
  • A decoder-only language model.
    • This is the component that ultimately generates text.
    • In this implementation the projection module is incorporated into the decoder.
    • Typically the two are kept separate.

The VLM architecture from seemore.

In summary, an image encoder extracts features from a given image and passes these image embeddings to a vision-language projector, which projects them into the text embedding space. The projected embeddings are then concatenated with the text embeddings of the text input and used by a decoder-only language model to autoregressively generate text.

In practice, however, a pretrained vision transformer from CLIP or SigLIP is usually used. Source: OpenAI

Image Encoder
#

To implement the vision transformer from scratch, we first create a PatchEmbeddings class that takes an image and produces a sequence of patches.

This process is crucial for enabling the transformer architecture to process visual data effectively, specifically using the attention blocks in the subsequent steps of the architecture.

Vision Transformer. Source: Arxiv/2010.11929

Here we create a convolutional layer to extract patch embeddings. We assume the input image has 3 color channels, hidden_dim sets the number of output channels to match the hidden dimension, kernel_size=patch_size and stride=patch_size ensure each patch is separately embedded.

class PatchEmbeddings(nn.Module):
    def __init__(self, img_size=96, patch_size=16, hidden_dim=512):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        # Ensure the convolution outputs a feature map with hidden_dim channels
        self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, X):
        X = self.conv(X)       # [B, hidden_dim, H/patch_size, W/patch_size]
        X = X.flatten(2)       # Flatten the patch dimensions
        X = X.transpose(1, 2)  # [B, num_patches, hidden_dim]
        return X

The convolution layer breaks the input image into \( (\mathrm{imgsize} / \mathrm{patchsize})^2 \) patches (36 for a 96×96 image with 16×16 patches), and each patch is projected into a vector with a channel dimension of 512.
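
As a quick shape check, here is a minimal sketch that runs the same strided convolution on a random image. With a 96×96 input and 16×16 patches there are 6 patches per side, so 36 patches in total. The variable names are illustrative assumptions and the weights are untrained.

```python
import torch
import torch.nn as nn

img_size, patch_size, hidden_dim = 96, 16, 512

# A strided convolution that embeds each non-overlapping patch separately
conv = nn.Conv2d(3, hidden_dim, kernel_size=patch_size, stride=patch_size)

image = torch.randn(1, 3, img_size, img_size)      # [B, C, H, W]
feature_map = conv(image)                          # [1, 512, 6, 6]
patches = feature_map.flatten(2).transpose(1, 2)   # [1, 36, 512]

num_patches = (img_size // patch_size) ** 2
print(feature_map.shape, patches.shape, num_patches)
```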

Most components of the transformer blocks, such as the attention head, multi-head attention, and the MLP in each block, are mostly identical across the ViT we are implementing for visual token generation and the decoder language model for the actual text generation.

GELU is used quite often in ViTs, while ReLU is the traditional choice in text transformers. However, it seems that GELU is increasingly used for both because of the resulting model performance.
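
To make the difference between the two activations concrete, here is a small sketch evaluating both on a few inputs. ReLU zeroes out negative values, while GELU computes x·Φ(x) (with Φ the standard normal CDF) and lets small negative values through.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-1.0, 0.0, 1.0])
relu_out = F.relu(x)   # hard cutoff: negatives become exactly 0
gelu_out = F.gelu(x)   # smooth: x * Phi(x), small negative outputs survive
print(relu_out, gelu_out)
```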

class MLP(nn.Module):
    def __init__(self, n_embd, dropout=0.1, is_decoder=True):
        super().__init__()
        layers = [
            nn.Linear(n_embd, 4 * n_embd),
            # ReLU for the text decoder, GELU for the vision encoder
            nn.ReLU() if is_decoder else nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout)
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

The only key difference is the masking applied in each attention head in the decoder language model. This is done to ensure the integrity of the autoregressive language generation process, particularly in a decoder-only model. This masking technique is crucial as it obscures any information following the current token’s position, thereby directing the model’s attention to only the preceding parts of the sequence. Such an attention mechanism is known as causal self-attention.

The lower triangular mask is only applied in the case of a decoder model. In the vision encoder's attention heads the mask is absent, so picture the bright blue triangle in matrix W removed. Source: AviSoori1x
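
The effect of the causal mask is easy to see on a toy example. The sketch below uses uniform (all-zero) raw scores for a sequence of length 4, an assumption chosen only to keep the numbers readable: after masking and softmax, position i spreads its attention evenly over positions 0..i and gives zero weight to later positions.

```python
import torch
import torch.nn.functional as F

T = 4
scores = torch.zeros(T, T)  # uniform raw attention scores, for illustration

# Lower-triangular mask: position i may attend to positions 0..i only
tril = torch.tril(torch.ones(T, T, dtype=torch.bool))
masked = scores.masked_fill(~tril, float('-inf'))
weights = F.softmax(masked, dim=-1)
print(weights)
```

Without the masked_fill step (the encoder case), every row would be uniform over all four positions.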

class Head(nn.Module):
    def __init__(self, n_embd, head_size, dropout=0.1, is_decoder=False):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # [B, T, head_size]
        q = self.query(x)  # [B, T, head_size]
        v = self.value(x)  # [B, T, head_size]

        # Compute attention scores, scaled by the head dimension
        # (the width of q and k) rather than the embedding width C
        wei = q @ k.transpose(-2, -1) * (k.size(-1) ** -0.5)
        if self.is_decoder:
            # Build a causal mask for the current sequence length,
            # on the same device as the input
            tril = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
            wei = wei.masked_fill(~tril, float('-inf'))

        # Apply softmax to get probabilities
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        # Perform weighted aggregation of values
        out = wei @ v
        return out
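
As a sanity check on the attention math, the sketch below compares a hand-written causal head (without dropout or the learned projections) against PyTorch's built-in F.scaled_dot_product_attention. This assumes PyTorch 2.0 or later, where that function is available and scales by 1/sqrt(head_size) by default. The tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, head_size = 2, 5, 16
q, k, v = (torch.randn(B, T, head_size) for _ in range(3))

# Manual causal attention, scaled by 1/sqrt(head_size)
scores = q @ k.transpose(-2, -1) * head_size ** -0.5
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
scores = scores.masked_fill(~mask, float('-inf'))
manual = F.softmax(scores, dim=-1) @ v

# Built-in fused equivalent (PyTorch 2.0 and later)
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
max_diff = (manual - fused).abs().max().item()
print(max_diff)
```

The two results should agree up to floating-point error, which suggests the fused call could replace the manual computation inside a head if desired.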

The implementation of multihead attention is as follows:

class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, num_heads, dropout=0.1, is_decoder=False):
        super().__init__()
        # Asserting divisibility up front catches configuration errors early
        assert n_embd % num_heads == 0, "n_embd must be divisible by num_heads"
        self.heads = nn.ModuleList([
            Head(n_embd, n_embd // num_heads, dropout, is_decoder)
            for _ in range(num_heads)
        ])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Run every head on the same input and concatenate along the channel axis
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

Each transformer block (the same class serves the encoder and, with is_decoder=True, the decoder) looks as follows:

class Block(nn.Module):
    def __init__(self, n_embd, num_heads, dropout=0.1, is_decoder=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_embd, num_heads, dropout, is_decoder)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        # Pre-norm residual: normalize, attend, then add back the un-normalized input
        x = x + self.attn(self.ln1(x))
        # Same pattern for the feed-forward sublayer, so the residual path skips ln2
        x = x + self.ffn(self.ln2(x))
        return x

Now all this can be put together to implement a Vision Transformer:

class ViT(nn.Module):
    def __init__(self, img_size, patch_size, num_hiddens, num_heads, num_blks, emb_dropout, blk_dropout):
        super().__init__()
        self.patch_embedding = PatchEmbeddings(img_size, patch_size, num_hiddens)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        num_patches = (img_size // patch_size) ** 2
        # One position embedding per patch, plus one for the CLS token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens))
        self.dropout = nn.Dropout(emb_dropout)
        self.blocks = nn.ModuleList([Block(num_hiddens, num_heads, blk_dropout, is_decoder=False) for _ in range(num_blks)])
        self.layer_norm = nn.LayerNorm(num_hiddens)

    def forward(self, X):
        x = self.patch_embedding(X)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x = self.dropout(x)
        for block in self.blocks:
            x = block(x)
        # Keep only the CLS token as the image representation: [B, num_hiddens]
        x = self.layer_norm(x[:, 0])
        return x
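
To follow the shapes through the CLS-token and position-embedding steps, here is a minimal sketch with random stand-in patch embeddings. The sizes assume the 96×96 image and 16×16 patch setting, and the transformer blocks are skipped entirely, so this only traces shapes.

```python
import torch
import torch.nn as nn

B, num_patches, num_hiddens = 2, 36, 512
patch_emb = torch.randn(B, num_patches, num_hiddens)  # stand-in for PatchEmbeddings output

cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens))

# expand() broadcasts the single CLS token across the batch without copying
cls_tokens = cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, patch_emb), dim=1)  # [2, 37, 512]
x = x + pos_embedding                          # broadcast over the batch
cls_out = x[:, 0]                              # [2, 512]
print(x.shape, cls_out.shape)
```

Note that only the CLS position is returned, so the whole image is summarized as one vector per sample.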

Vision-Language Projector
#

We cannot directly concatenate the ViT output with the text embeddings. We need to project it from the dimensionality of the image embeddings produced by the vision transformer to the dimensionality of the text embeddings.

This can be a single learnable layer followed by a nonlinearity, or an MLP. The reason for using an MLP is an interesting current trend of keeping both the pretrained vision encoder and the language decoder frozen during the VLM training phase. Giving the connection module more parameters to learn could therefore improve the ability of the overall VLM to generalize and help in the downstream instruction-tuning process.

Here’s the implementation of this projection module. It’s not too different from the MLP used in the transformer blocks.

class MultiModalProjector(nn.Module):
    def __init__(self, n_embd, image_embed_dim, dropout=0.1):
        super().__init__()
        # Maps image_embed_dim -> n_embd through a 4x wider hidden layer
        self.net = nn.Sequential(
            nn.Linear(image_embed_dim, 4 * image_embed_dim),
            nn.GELU(),
            nn.Linear(4 * image_embed_dim, n_embd),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.net(x)
        return x

Language Decoder
#

The final component we need to look at is the decoder language model. I have integrated the projection module into the decoder model class implementation.

There’s no easy way to directly feed in reshaped embeddings in this implementation, so I’ve had to improvise a little.

However, when using pretrained models through the Hugging Face API, or any other modern library that lets you use pretrained large language models, you can directly feed embeddings as input to the model.

Some interesting exercises are to follow how:

  • the image embeddings are reshaped by the vision-language projector to match the text embeddings;
  • they are then concatenated with the token embeddings;
  • the result is combined with position embeddings and used to calculate a loss function (and eventually generate text).
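
The steps in the list above can be sketched with stand-in tensors. This is a shape trace under stated assumptions, not the decoder itself: a single nn.Linear replaces the projector, another replaces the whole decoder stack and output head, and the sizes (n_embd=128, image_embed_dim=512, vocab of 66) are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, T, n_embd, image_embed_dim, vocab_size = 2, 7, 128, 512, 66

image_embeds = torch.randn(B, image_embed_dim)   # CLS output of the ViT
tok_emb = torch.randn(B, T, n_embd)              # stand-in token embeddings
targets = torch.randint(0, vocab_size, (B, T))

proj = nn.Linear(image_embed_dim, n_embd)        # simplified projector (assumption)
img_tok = proj(image_embeds).unsqueeze(1)        # [B, 1, n_embd]
x = torch.cat([img_tok, tok_emb], dim=1)         # [B, T + 1, n_embd]

# Stand-in for the decoder blocks plus the output head
logits = nn.Linear(n_embd, vocab_size)(x)        # [B, T + 1, vocab_size]

# The image position has no text label, so mark it with ignore_index
pad = torch.full((B, 1), -100, dtype=torch.long)
full_targets = torch.cat([pad, targets], dim=1)  # [B, T + 1]
loss = F.cross_entropy(logits.view(-1, vocab_size),
                       full_targets.view(-1), ignore_index=-100)
print(x.shape, full_targets.shape, loss.item())
```

The key point is that prepending one visual token lengthens the sequence by one, so the targets need a matching placeholder that the loss ignores.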

The crucial parts of this decoder implementation are given below. Note how the is_decoder flag is passed as True to use the masked version of the self-attention blocks, resulting in causal scaled dot-product self-attention in the language decoder.

class DecoderLanguageModel(nn.Module):
    def __init__(self, n_embd, image_embed_dim, vocab_size, num_heads, n_layer, use_images=False):
        super().__init__()
        self.use_images = use_images
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # Learned position embeddings; supports sequences of up to 1000 positions
        self.position_embedding_table = nn.Embedding(1000, n_embd)
        if use_images:
            self.image_projection = MultiModalProjector(n_embd, image_embed_dim)
        self.blocks = nn.Sequential(*[Block(n_embd, num_heads, is_decoder=True) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, image_embeds=None, targets=None):
        tok_emb = self.token_embedding_table(idx)  # [B, T, n_embd]
        if self.use_images and image_embeds is not None:
            # Project the image embedding and prepend it as a single visual token
            img_emb = self.image_projection(image_embeds).unsqueeze(1)  # [B, 1, n_embd]
            tok_emb = torch.cat([img_emb, tok_emb], dim=1)
        pos_emb = self.position_embedding_table(torch.arange(tok_emb.size(1), device=idx.device)).unsqueeze(0)
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        if targets is not None:
            if self.use_images and image_embeds is not None:
                # The visual token has no text label: mark it with -100 so the loss ignores it
                batch_size = idx.size(0)
                targets = torch.cat([torch.full((batch_size, 1), -100, dtype=torch.long, device=targets.device), targets], dim=1)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
            return logits, loss
        return logits

    @torch.no_grad()
    def generate(self, idx, image_embeds, max_new_tokens):
        generated = idx
        for _ in range(max_new_tokens):
            # Re-run the full forward pass on everything generated so far, so that
            # position embeddings are added once and ln_f is applied as in training
            logits = self(generated, image_embeds)
            # Sample the next token from the distribution at the last position
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, idx_next), dim=1)
        return generated
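
The sampling loop itself can be isolated from the model. In the sketch below, toy_logits is a hypothetical stand-in that returns random logits, so the generated ids are meaningless; the point is only the mechanics of taking the last position, sampling with torch.multinomial, and appending the new token.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 66

def toy_logits(idx):
    """Hypothetical stand-in for a decoder: random logits of shape [B, T, vocab]."""
    return torch.randn(idx.size(0), idx.size(1), vocab_size)

idx = torch.zeros(1, 3, dtype=torch.long)  # prompt of 3 tokens
max_new_tokens = 5
for _ in range(max_new_tokens):
    logits = toy_logits(idx)[:, -1, :]                   # only the last position matters
    probs = F.softmax(logits, dim=-1)
    idx_next = torch.multinomial(probs, num_samples=1)   # [B, 1]
    idx = torch.cat((idx, idx_next), dim=1)
print(idx.shape)
```

Each iteration grows the sequence by exactly one token, so the output length is the prompt length plus max_new_tokens.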

Summary
#

Now that we have our three key components, we can put it all together into a Vision Language Model. The full implementation is given below. All that’s happening here is:

  • Get image features from the vision encoder (Here it’s a vision transformer, but it could be any model that could generate features from an image input such as a ResNet or a traditional convolutional neural network (needless to say performance may suffer))
  • A projection module for projecting image tokens to the same embedding space as text embeddings for the decoder (this projector is integrated with the decoder in this implementation)
  • A decoder language model for generating text conditioned on a preceding image.

class VisionLanguageModel(nn.Module):
    def __init__(self, n_embd, image_embed_dim, vocab_size, n_layer, img_size, patch_size, num_heads, num_blks, emb_dropout, blk_dropout):
        super().__init__()
        num_hiddens = image_embed_dim  # Set num_hiddens equal to image_embed_dim
        assert num_hiddens % num_heads == 0, "num_hiddens must be divisible by num_heads"
        self.vision_encoder = ViT(img_size, patch_size, num_hiddens, num_heads, num_blks, emb_dropout, blk_dropout)
        self.decoder = DecoderLanguageModel(n_embd, image_embed_dim, vocab_size, num_heads, n_layer, use_images=True)

    def forward(self, img_array, idx, targets=None):
        image_embeds = self.vision_encoder(img_array)

        if image_embeds.nelement() == 0 or image_embeds.shape[1] == 0:
            raise ValueError("Something is wrong with the ViT model: it returned an empty tensor or an empty embedding dimension")

        if targets is not None:
            logits, loss = self.decoder(idx, image_embeds, targets)
            return logits, loss
        else:
            logits = self.decoder(idx, image_embeds)
            return logits

    def generate(self, img_array, idx, max_new_tokens):
        image_embeds = self.vision_encoder(img_array)

        if image_embeds.nelement() == 0 or image_embeds.shape[1] == 0:
            raise ValueError("Something is wrong with the ViT model: it returned an empty tensor or an empty embedding dimension")

        generated_tokens = self.decoder.generate(idx, image_embeds, max_new_tokens)
        return generated_tokens

Extra code for training
#

def base64_to_tensor(base64_str, img_size=96):
    image = Image.open(io.BytesIO(base64.b64decode(base64_str)))
    if image.mode != 'RGB':
        image = image.convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        # ImageNet channel statistics
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0)  # Add batch dimension

# Adjusting the data loader from makemore for multimodal data
def get_batch(df, batch_size, split='train', img_size=96, val_batch_size=8):
    # Split data into training and validation sets
    n = int(0.9 * len(df))  # first 90% will be train, rest val
    df_train = df.iloc[:n]
    df_val = df.iloc[n:]
    data = df_train if split == 'train' else df_val
    batch_size = batch_size if split == 'train' else val_batch_size
    # Sample with replacement only for validation, where the split may be small
    replace = split != 'train'
    batch = data.sample(n=batch_size, replace=replace)

    images = torch.cat([base64_to_tensor(img, img_size) for img in batch['b64string_images']], dim=0).to(device)
    # encode already returns a long tensor, so no extra torch.tensor() wrapper is needed
    text_indices = [encode(desc) for desc in batch['caption']]
    max_length = max(len(t) for t in text_indices)

    # Right-pad every caption to the longest one in the batch
    padded_text = torch.full((batch_size, max_length), fill_value=stoi[''], dtype=torch.long).to(device)
    for i, text in enumerate(text_indices):
        padded_text[i, :len(text)] = text

    # Targets are the inputs shifted left by one, with a padding token appended;
    # by construction they have the same width as padded_text
    targets = torch.cat([padded_text[:, 1:], torch.full((batch_size, 1), fill_value=stoi[''], dtype=torch.long, device=device)], dim=1)

    return images, padded_text, targets

# Adjusting the training loop from makemore for multimodal data
# Note: learning_rate, max_iters, batch_size, eval_interval and eval_iters
# are read from the global configuration defined further down
def train_model(model, df, epochs, vocab_size, img_size=96):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.to(device)
    for epoch in range(epochs):
        model.train()
        for it in range(max_iters):
            images, idx, targets = get_batch(df, batch_size, 'train', img_size)
            optimizer.zero_grad()
            logits, loss = model(images, idx, targets)
            loss.backward()
            optimizer.step()
            if it % eval_interval == 0:
                print(f"Loss at iteration {it}: {loss.item()}")
        val_loss = estimate_loss(model, df, 'val', img_size, val_batch_size=8)
        print(f"Validation Loss after epoch {epoch}: {val_loss}")

@torch.no_grad()  # evaluation does not need gradients
def estimate_loss(model, df, split, img_size=96, val_batch_size=8):
    losses = []
    model.eval()
    for _ in range(eval_iters):
        images, idx, targets = get_batch(df, batch_size, split, img_size, val_batch_size=val_batch_size)
        _, loss = model(images, idx, targets)
        losses.append(loss.item())
    return sum(losses) / len(losses)

df = pd.read_csv("./inputs.csv")
# Expanding dataframe so that there's enough data to test. This is just duplicating data. A real dataset would have more rows
df = pd.concat([df] * 30)[['b64string_images', 'caption']]
print(df.shape)

batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 100
eval_interval = 10
learning_rate = 1e-3
epochs = 1
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 40
num_blks = 3
head_size = 16
n_embd = 128
n_head = 8
n_layer = 8
dropout = 0.1
img_size = 96
patch_size = 16
image_embed_dim = 512
emb_dropout = blk_dropout = 0.1

# Initialize the model
model = VisionLanguageModel(n_embd, image_embed_dim, vocab_size, n_layer, img_size, patch_size, n_head, num_blks, emb_dropout, blk_dropout)
model.to(device)

# Smoke test: one forward pass on dummy data to check that the shapes line up
dummy_img = torch.randn(1, 3, img_size, img_size).to(device)
dummy_idx = torch.randint(0, vocab_size, (1, block_size)).to(device)
model(dummy_img, dummy_idx)

# Train the model
train_model(model, df, epochs, vocab_size, img_size)

In practice, the commonly observed sequence is:

  1. Get a pretrained vision encoder from SigLIP or CLIP (both come in different sizes). Freeze its weights (i.e. don’t update them during the backward pass in training).
  2. Get a pretrained decoder-only language model, anywhere from TinyLLaMA or Phi-2 up to Llama 3 (or even much bigger, as in the case of GPT-4 and Grok 1.5). Freeze its weights.
  3. Implement a projection module and train a VLM much like what we have here, but only updating the weights of this projection module. This would effectively be the pretraining phase.
  4. Then, during instruction finetuning, keep both the projection module and the decoder language model unfrozen and update the weights of both in the backward pass.
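
The freezing schedule in the steps above can be sketched with requires_grad_. The three modules below are tiny hypothetical stand-ins (plain linear layers), not real pretrained models, and the helper names set_trainable and count_trainable are assumptions for illustration.

```python
import torch.nn as nn

# Hypothetical stand-ins for a pretrained encoder, a projector and a decoder
vision_encoder = nn.Linear(512, 512)
projector = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 128))
decoder = nn.Linear(128, 66)

def set_trainable(module, flag):
    """Freeze (flag=False) or unfreeze (flag=True) every parameter of a module."""
    for p in module.parameters():
        p.requires_grad_(flag)

def count_trainable(*modules):
    """Number of parameters that would receive gradient updates."""
    return sum(p.numel() for m in modules for p in m.parameters() if p.requires_grad)

# Pretraining phase: only the projector is updated
set_trainable(vision_encoder, False)
set_trainable(decoder, False)
pretrain_params = count_trainable(vision_encoder, projector, decoder)

# Instruction finetuning: unfreeze the decoder as well
set_trainable(decoder, True)
finetune_params = count_trainable(vision_encoder, projector, decoder)
print(pretrain_params, finetune_params)
```

When building the optimizer for each phase, passing only the parameters with requires_grad=True keeps frozen weights out of the optimizer state.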
