import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
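A quick shape sanity check (a sketch; since batch_first is left at its default of False here, the encoder expects input as (sequence_length, batch_size, d_model)):
src = torch.randn(10, 32, 512)  # (seq_len, batch, d_model)
transformer_encoder(src).shape
>> torch.Size([10, 32, 512])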
We would need to store an $n \times n$ attention matrix over all $n$ pixel values
$(3 \cdot 224 \cdot 224)^2 = 150{,}528^2 \approx 22.7\: billion\: entries$
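The arithmetic, checked in plain Python (nothing ViT-specific here):
n = 3 * 224 * 224  # every pixel value treated as a token
n ** 2             # entries in the full n x n attention matrix
>> 22658678784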
x = torch.randn(1, 3, 224, 224)
# 2D conv with kernel 16 and stride 16: each 16x16 patch becomes one 768-dim vector
conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)
# (1, 768, 14, 14) -> (768, 196) -> (196, 768): 196 patch embeddings
conv(x).reshape(-1, 196).transpose(0, 1).shape
>> torch.Size([196, 768])
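Note that reshape(-1, 196) only works here because the batch size is 1; a batch-safe sketch of the same patch embedding (the batch size of 8 is arbitrary):
x = torch.randn(8, 3, 224, 224)
# (8, 768, 14, 14) -> (8, 768, 196) -> (8, 196, 768)
conv(x).flatten(2).transpose(1, 2).shape
>> torch.Size([8, 196, 768])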
"We use standard learnable 1D position embeddings and the resulting sequence of embedding vectors serves as input to the encoder"
import pytorch_lightning as pl

class ViT(pl.LightningModule):
    def __init__(self, num_transformer_layers, num_classes=1000):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()
        # 16x16 patches, stride 16: a 224x224 image becomes 14x14 = 196 patch embeddings
        self.conv_embedding = nn.Conv2d(3, 768, kernel_size=16, stride=16)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 768))
        # batch_first=True so the encoder accepts (batch_size, seq_len, 768)
        encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=8, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_transformer_layers)
        self.mlp_head = nn.Linear(768, num_classes)
        self.position_embedding_layer = nn.Embedding(197, 768)

    def _assign_positions_to_patches(self, batch_size):
        # (batch_size, 197); each row is 0, 1, ..., 196
        return torch.arange(197, device=self.device).expand(batch_size, -1)

    def forward(self, x):
        batch_size = x.shape[0]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        # (batch_size, 196, 768)
        patches_embedding = self.conv_embedding(x).flatten(2).transpose(1, 2)
        # (batch_size, 197, 768)
        patches_embedding = torch.cat((cls_tokens, patches_embedding), dim=1)
        # (batch_size, 197)
        positions = self._assign_positions_to_patches(batch_size)
        # (batch_size, 197, 768)
        position_embedding = self.position_embedding_layer(positions)
        # (batch_size, 197, 768)
        final_embedding = patches_embedding + position_embedding
        # (batch_size, 197, 768)
        embedding_output = self.transformer_encoder(final_embedding)
        # (batch_size, 768)
        cls_vector = embedding_output[:, 0, :]
        # (batch_size, num_classes)
        return self.mlp_head(cls_vector)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        return loss
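A smoke test of the model (a sketch; the layer count and batch size are arbitrary):
model = ViT(num_transformer_layers=6, num_classes=1000)
model(torch.randn(2, 3, 224, 224)).shape
>> torch.Size([2, 1000])
Note that to actually train this with pl.Trainer you would also need a configure_optimizers method (e.g. one returning torch.optim.Adam(self.parameters())), which is omitted here.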
Transformers were originally seq2seq models, so using ViT for image captioning is a "no-brainer"
BERT-like embeddings for images - prototyping of CV models drastically accelerated
Even bigger models - even better image classification
More sophisticated (yet still efficient) approaches to building the patches
"Feel free to ask any question"
Piotr Mazurek
tugot17.github.io/Vision-Transformer-Presentation/