# -*- coding: utf-8 -*-
"""
Created on Fri Jul 25 14:12:32 2025
@author: ali.saral
"""
import torch
import torch.nn as nn
import math
# --- Positional Encoding ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Register as a buffer so the table moves with the module (.to / .cuda)
        # and is saved in the state_dict without becoming a trainable parameter.
        self.register_buffer("pe", pe.unsqueeze(0))  # shape: [1, max_len, d_model]

    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        return x + self.pe[:, :x.size(1)]
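# The table built above follows the fixed sinusoids of "Attention Is All You Need":
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# Each dimension is a sinusoid of a different wavelength, so relative positions
# can be expressed as simple functions of the encodings.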
# --- Scaled Dot-Product Attention ---
def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: [..., seq_len, d_k]; mask must be broadcastable to the score shape.
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    attn = torch.softmax(scores, dim=-1)
    return torch.matmul(attn, v), attn
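# Reference formula implemented by the helper above:
#   Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
# The helper is kept as a standalone reference; MultiHeadAttention below
# computes the same attention inline over the per-head tensors.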
# --- Multi-Head Attention ---
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model           # needed for the final reshape
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # dimension per head
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        q_len = q.size(1)

        # Linear projection and split into heads
        def transform(x, linear):
            x = linear(x)                 # [B, seq_len, d_model]
            B, S, _ = x.shape             # S = actual sequence length of that input (q, k or v)
            x = x.view(B, S, self.num_heads, self.d_k)
            return x.transpose(1, 2)      # [B, heads, S, d_k]

        q = transform(q, self.q_linear)
        k = transform(k, self.k_linear)
        v = transform(v, self.v_linear)

        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)  # [B, heads, q_len, k_len]
        if mask is not None:
            # mask must be broadcastable to [B, heads, q_len, k_len]
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = torch.softmax(scores, dim=-1)  # [B, heads, q_len, k_len]
        output = torch.matmul(attn, v)        # [B, heads, q_len, d_k]

        # Combine heads back: [B, heads, q_len, d_k] -> [B, q_len, d_model]
        output = output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
        return self.out(output)
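# Shape walk-through for the forward pass above (batch B, query length Lq,
# key/value length Lk):
#   projections: [B, Lq, d_model] -> [B, heads, Lq, d_k]
#   scores:      [B, heads, Lq, Lk]
#   output:      [B, heads, Lq, d_k] -> merge heads -> [B, Lq, d_model]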
# --- Feed Forward ---
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        return self.ff(x)
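# Position-wise feed-forward network: applied independently at every position,
# expanding d_model -> d_ff and projecting back to d_model.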
# --- Encoder Layer ---
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        x = self.norm1(x + self.attn(x, x, x, mask))
        x = self.norm2(x + self.ffn(x))
        return x
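# Note: the residual connections use the post-norm arrangement of the original
# paper (LayerNorm applied to x + sublayer(x)); dropout is omitted to keep the
# example minimal.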
# --- Decoder Layer ---
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out, tgt_mask, memory_mask):
        x = self.norm1(x + self.self_attn(x, x, x, tgt_mask))
        x = self.norm2(x + self.cross_attn(x, enc_out, enc_out, memory_mask))
        x = self.norm3(x + self.ffn(x))
        return x
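# The decoder layer stacks masked self-attention over the target prefix,
# cross-attention in which enc_out supplies the keys and values, and the
# position-wise feed-forward network.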
# --- Encoder ---
class Encoder(nn.Module):
    def __init__(self, vocab_size, d_model, N, num_heads, d_ff, max_len=5000):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pe = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff) for _ in range(N)])

    def forward(self, x, mask):
        x = self.embed(x)  # [batch_size, seq_len, d_model]
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x, mask)
        return x
# --- Decoder ---
class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model, N, num_heads, d_ff, max_len=5000):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pe = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff) for _ in range(N)])

    def forward(self, x, enc_out, tgt_mask, memory_mask):
        x = self.pe(self.embed(x))
        for layer in self.layers:
            x = layer(x, enc_out, tgt_mask, memory_mask)
        return x
# --- Transformer ---
class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=512, N=2, heads=8, d_ff=1024):
        super().__init__()
        self.encoder = Encoder(src_vocab, d_model, N, heads, d_ff)
        self.decoder = Decoder(tgt_vocab, d_model, N, heads, d_ff)
        self.out = nn.Linear(d_model, tgt_vocab)

    def make_pad_mask(self, seq, pad_token=0):
        # [B, S] -> [B, 1, 1, S]; broadcasts over heads and query positions.
        return (seq != pad_token).unsqueeze(1).unsqueeze(2)

    def make_look_ahead_mask(self, size):
        # [size, size] lower-triangular matrix -> [1, 1, size, size]
        return torch.tril(torch.ones(size, size)).bool().unsqueeze(0).unsqueeze(1)
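    # Worked example: make_look_ahead_mask(3), before the unsqueezes, is
    #   [[1, 0, 0],
    #    [1, 1, 0],
    #    [1, 1, 1]]
    # so target position i can only attend to positions <= i.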
    def forward(self, src, tgt):
        src_mask = self.make_pad_mask(src)
        tgt_mask = self.make_pad_mask(tgt) & self.make_look_ahead_mask(tgt.size(1)).to(tgt.device)
        enc = self.encoder(src, src_mask)
        dec = self.decoder(tgt, enc, tgt_mask, src_mask)
        return self.out(dec)  # logits: [B, tgt_len, tgt_vocab]
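# A minimal greedy-decoding sketch showing how the model above could be used at
# inference time. The BOS/EOS ids (1 and 2) are illustrative assumptions, not
# something the model itself defines.
@torch.no_grad()
def greedy_decode(model, src, max_len=20, bos_id=1, eos_id=2):
    model.eval()
    # Start every sequence in the batch with the assumed BOS token.
    ys = torch.full((src.size(0), 1), bos_id, dtype=torch.long, device=src.device)
    for _ in range(max_len - 1):
        logits = model(src, ys)                                  # [B, cur_len, tgt_vocab]
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # most likely next id
        ys = torch.cat([ys, next_token], dim=1)
        if (next_token == eos_id).all():
            break
    return ys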
# --- Main Program ---
if __name__ == "__main__":
    src = torch.tensor([[1, 5, 6, 2, 0]])  # (batch, src_seq_len)
    tgt = torch.tensor([[1, 7, 4, 3, 0]])  # (batch, tgt_seq_len)
    model = Transformer(src_vocab=10000, tgt_vocab=10000)
    out = model(src, tgt)  # (1, 5, 10000)
    print("Output shape:", out.shape)

    # Convert logits to probabilities. The model is untrained, so the
    # predictions below are effectively random.
    probs = torch.softmax(out, dim=-1)  # (1, 5, 10000)

    # Predicted token IDs (argmax)
    predicted_ids = torch.argmax(probs, dim=-1)
    print("\nPredicted token IDs:", predicted_ids.tolist())

    # Show top-5 token IDs per position
    for pos in range(probs.size(1)):
        top5 = torch.topk(probs[0, pos], 5)
        print(f"\nPosition {pos} Top-5:")
        for prob, idx in zip(top5.values, top5.indices):
            print(f"  Token ID: {idx.item()}, Probability: {prob.item():.4f}")
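    # Try the greedy-decoding sketch defined above (again, untrained weights,
    # so the generated ids are arbitrary).
    generated = greedy_decode(model, src, max_len=6)
    print("\nGreedy-decoded ids:", generated.tolist())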
