Utility functions
I’ve spent the past week writing some desperately needed utility functions to instrument the code I’ve built out as part of the scholar’s program. Largely, I’ve been trying to gain a better understanding of exactly what’s going on inside these models as they operate over data continuously. One lesson learned is the importance of writing utility functions early: not only does it save valuable time as deadlines approach, but being really explicit up front about the kind of data you’d like to collect also informs the kind of experiments you’re likely to run.
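To make that concrete, here’s a rough sketch of the kind of helper I mean: a forward-hook utility that records activation statistics while a model runs. The names here (attachActivationLogger, the log dict) are illustrative rather than the actual code from my project.

import torch
import torch.nn as nn

def attachActivationLogger(model, log):
    """Register forward hooks that record per-module activation statistics.
    `log` is a plain dict; each call appends a (mean, std) tuple per module."""
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.LayerNorm)):
            def hook(mod, inputs, output, name=name):
                with torch.no_grad():
                    log.setdefault(name, []).append(
                        (output.mean().item(), output.std().item()))
            handles.append(module.register_forward_hook(hook))
    return handles  # call handle.remove() on each one to detach the hooks later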
Feedback Transformer
Last blog post, I mentioned I’ve been playing around with implementing the feedback transformer. The core idea in the feedback transformer is to allow low-level representations in transformers to attend to higher-level representations from previous timesteps. This modifies the computational path of the traditional transformer architecture and turns it into something functionally resembling an autoregressive RNN. I wrote out a quick and dirty implementation for this (below), which I’ll clean up and post on github at some point, along with the million other things I have to catch up on.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadedAttn(nn.Module):
    def __init__(self, config):
        super(MultiHeadedAttn, self).__init__()
        self.c = c = config
        # Generates queries; keys and values are projected once per step
        # by the shared ffkv layer and passed into forward()
        self.fc1 = nn.Linear(c.embdSize, c.qkvSize * c.numAttnHeads)
        self.fc2 = nn.Linear(c.embdSize, c.embdSize)
        # Create masks (the pad mask is a default and gets overwritten each batch)
        self.register_buffer(
            "causalMask",
            torch.tril(torch.ones((c.blockSize, c.blockSize))))
        self.register_buffer(
            "padMasks",
            torch.ones(1, c.blockSize, c.blockSize))
    def forward(self, x, k, v):
        B, T, embdSize = x.shape  # B = batch size, T = num tokens
        h = self.c.numAttnHeads
        # Split queries, keys, and values into heads: (B, h, T, qkvSize)
        q = self.fc1(x).reshape(B, T, h, -1).transpose(1, 2)
        k = k.reshape(B, T, h, -1).transpose(1, 2)
        v = v.reshape(B, T, h, -1).transpose(1, 2)
        # God bless einsum
        attn = torch.einsum('bhij,bhkj->bhik', q, k)
        mask = torch.unsqueeze(self.padMasks, 1).repeat(
            1, h, 1, 1)[..., :T, :T] * self.causalMask[:T, :T]
        attn = attn.masked_fill(mask == 0., float('-inf'))
        scores = F.softmax(attn / np.sqrt(self.c.qkvSize), -1)
        outpt = torch.einsum('bhij,bhjk->bhik', scores, v)
        # Merge the heads back into the embedding dimension
        outpt = outpt.transpose(1, 2).reshape(B, T, embdSize)
        return self.fc2(outpt)
class DecoderBlock(nn.Module):
    def __init__(self, config):
        super(DecoderBlock, self).__init__()
        c = config
        self.ln1 = nn.LayerNorm(c.embdSize)
        self.attn = MultiHeadedAttn(config)
        self.mlp = nn.Sequential(
            nn.Linear(c.embdSize, c.embdSize * 4),
            nn.GELU(),
            nn.Linear(c.embdSize * 4, c.embdSize)
        )
        self.ln2 = nn.LayerNorm(c.embdSize)

    def _setPadMasks(self, padMasks):
        self.attn.padMasks = padMasks

    def forward(self, x, k, v):
        # Post-LayerNorm residual connections around attention and the MLP
        x = self.ln1(x + self.attn(x, k, v))
        x = self.ln2(x + self.mlp(x))
        return x
class TinyFeedbackTransformer(nn.Module):
    def __init__(self, config):
        super(TinyFeedbackTransformer, self).__init__()
        # Size assertion: the embedding must split evenly across heads
        assert config.embdSize % config.numAttnHeads == 0
        # Configuration stuff
        c = config
        c.qkvSize = c.embdSize // c.numAttnHeads
        self.config = c
        # The padding token is the last index in the vocabulary
        self.wordEmbedding = nn.Embedding(
            c.paddingIndx + 1, c.embdSize, padding_idx=c.paddingIndx)
        self.posEmbedding = nn.Parameter(
            torch.zeros(1, c.blockSize, c.embdSize))
        # New learnable parameters for the feedback transformer:
        # per-layer memory mixing weights and a shared key/value projection
        self.memoryCoeff = nn.Parameter(torch.ones(1, c.numLayers))
        self.ffkv = nn.Linear(c.embdSize, (c.qkvSize * 2) * c.numAttnHeads)
        self.blocks = nn.ModuleList(
            [DecoderBlock(c) for _ in range(c.numLayers)]
        )
        self.ln1 = nn.LayerNorm(c.embdSize)
        self.head = nn.Linear(c.embdSize, c.paddingIndx, bias=False)
        self.apply(self._init_weights)
    def forward(self, indxs, padMasks):
        # Hand the per-batch padding masks to every block's attention module
        for mod in self.blocks:
            mod._setPadMasks(padMasks)
        numTokens = indxs.shape[1]
        # Combine word and position embeddings
        x = self.wordEmbedding(indxs)
        pos = self.posEmbedding[:, :numTokens, :]
        x = x + pos
        batchSize = x.shape[0]
        maxMemSize = self.config.memorySize
        embdSize = self.config.embdSize
        # Initialize the (empty) memory tensor and the output buffer
        memory = torch.empty(batchSize, 0, embdSize, device=x.device)
        finalOutputs = torch.empty(batchSize, 0, embdSize, device=x.device)
        # Pass through the transformer one token at a time
        for indx in range(x.shape[1]):
            currSlice = x[:, indx, ...].view(batchSize, 1, -1)
            inpt = torch.cat((memory, currSlice), dim=1)
            inpt = inpt[:, -maxMemSize:, ...]
            # Project (memory + current token) into the keys and values
            # that every layer attends over
            wk, wv = self.ffkv(inpt).chunk(2, 2)
            outputs = torch.empty(0, batchSize, embdSize, device=x.device)
            for decoderBlock in self.blocks:
                inpt = decoderBlock(inpt, wk, wv)
                outputs = torch.cat((outputs, torch.unsqueeze(inpt[:, -1, :], 0)))
            # The memory for this step is a learned softmax-weighted mix of all layers
            currMemory = torch.einsum('il,lbd->bd',
                torch.softmax(self.memoryCoeff, -1), outputs)
            memory = torch.cat((memory, torch.unsqueeze(currMemory, 1)), dim=1)
            memory = memory[:, -maxMemSize:, ...]
            finalOutputs = torch.cat(
                (finalOutputs, torch.unsqueeze(inpt[:, -1, :], 1)), dim=1)
        finalOutputs = self.head(self.ln1(finalOutputs))
        return finalOutputs
    def _init_weights(self, module):
        if isinstance(module, nn.Embedding):
            d = module.embedding_dim ** 0.5
            module.weight.data.normal_(mean=0.0, std=0.125 / d)
        if isinstance(module, nn.Linear):
            d = module.in_features ** 0.5
            module.weight.data.normal_(mean=0.0, std=0.125 / d)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
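For reference, here’s roughly how I’ve been smoke-testing it. The Config values below are placeholders that just show the fields the code above expects, not my actual training setup.

# A minimal smoke test; the numbers are placeholders, not real hyperparameters.
class Config:
    embdSize = 64
    numAttnHeads = 4
    numLayers = 2
    blockSize = 32
    memorySize = 16
    paddingIndx = 100   # vocab size is paddingIndx + 1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TinyFeedbackTransformer(Config).to(device)

batchSize, numTokens = 2, 8
indxs = torch.randint(0, Config.paddingIndx, (batchSize, numTokens), device=device)
padMasks = torch.ones(batchSize, Config.blockSize, Config.blockSize, device=device)
logits = model(indxs, padMasks)   # (batchSize, numTokens, paddingIndx)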
Anywho, that’s it for now. More to come in the coming weeks if I ever find time to organize my thoughts and my work. TTYL