Tuesday, 20 January 2026

Custom Transformer implemented in PyTorch

# -*- coding: utf-8 -*-

"""

Created on Fri Jul 25 14:12:32 2025


@author: ali.saral

"""


import torch

import torch.nn as nn

import math


# --- Positional Encoding ---

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=5000):

        super().__init__()

        pe = torch.zeros(max_len, d_model)

        position = torch.arange(0, max_len).unsqueeze(1)

        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)

        pe[:, 1::2] = torch.cos(position * div_term)

        # kept as a non-trainable buffer so it follows the module's device and dtype
        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)  # shape: [1, max_len, d_model]


    def forward(self, x):

        return x + self.pe[:, :x.size(1)].to(x.device)
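

# The pe tensor holds the standard sinusoidal encoding: PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
# and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). Quick illustrative check
# (sizes are arbitrary): adding the encoding leaves the input shape unchanged.
_pe = PositionalEncoding(d_model=16, max_len=50)
_x = torch.zeros(2, 10, 16)                  # [batch, seq_len, d_model]
assert _pe(_x).shape == (2, 10, 16)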


# --- Scaled Dot-Product Attention ---

def scaled_dot_product_attention(q, k, v, mask=None):

    print(f"q.shape = {q.shape}")

    print(f"k.shape = {k.shape}")

    print(f"v.shape = {v.shape}")

    d_k = q.size(-1)

    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:

        scores = scores.masked_fill(mask == 0, -1e9)

    attn = torch.softmax(scores, dim=-1)

    return torch.matmul(attn, v), attn
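

# Note: this standalone helper is not called by the model below (MultiHeadAttention
# re-implements the same computation inline), but it is handy for a quick check.
# Illustrative sketch with arbitrary shapes: each row of the attention matrix is a
# softmax over the key positions, so it sums to 1, and the output follows the
# query length.
_q = torch.randn(1, 3, 8)                    # [batch, q_len, d_k]
_k = torch.randn(1, 4, 8)                    # [batch, k_len, d_k]
_v = torch.randn(1, 4, 8)                    # [batch, k_len, d_k]
_out, _attn = scaled_dot_product_attention(_q, _k, _v)
assert _out.shape == (1, 3, 8) and _attn.shape == (1, 3, 4)
assert torch.allclose(_attn.sum(dim=-1), torch.ones(1, 3), atol=1e-5)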


# --- Multi-Head Attention ---

class MultiHeadAttention(nn.Module):

    def __init__(self, d_model, num_heads):

        super().__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"


        self.d_model = d_model                   # ✅ Needed for final reshaping

        self.num_heads = num_heads

        self.d_k = d_model // num_heads          # Dimension per head


        self.q_linear = nn.Linear(d_model, d_model)

        self.k_linear = nn.Linear(d_model, d_model)

        self.v_linear = nn.Linear(d_model, d_model)

        self.out = nn.Linear(d_model, d_model)


    def forward(self, q, k, v, mask=None):

        batch_size = q.size(0)

        q_len = q.size(1)


        # Linear projection and split into heads

        def transform(x, linear):

            x = linear(x)  # [B, seq_len, d_model]

            B, S, _ = x.shape  # S = actual sequence length of that input (q or k or v)

            x = x.view(B, S, self.num_heads, self.d_k)

            return x.transpose(1, 2)  # [B, heads, S, d_k]


        q = transform(q, self.q_linear)

        k = transform(k, self.k_linear)

        v = transform(v, self.v_linear)


        # Scaled Dot-Product Attention

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)  # [B, heads, q_len, k_len]


        if mask is not None:

            # mask must be broadcastable to [B, heads, q_len, k_len]

            scores = scores.masked_fill(mask == 0, -1e9)


        attn = torch.softmax(scores, dim=-1)     # [B, heads, q_len, k_len]

        output = torch.matmul(attn, v)           # [B, heads, q_len, d_k]


        # Combine heads back: [B, heads, q_len, d_k] → [B, q_len, d_model]

        output = output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)


        return self.out(output)
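

# Illustrative shape check (arbitrary sizes): queries and keys/values may have
# different lengths, as in encoder-decoder cross-attention, and the output always
# follows the query length.
_mha = MultiHeadAttention(d_model=16, num_heads=4)
_q_in = torch.randn(2, 5, 16)                # [batch, q_len, d_model]
_kv_in = torch.randn(2, 7, 16)               # [batch, kv_len, d_model]
assert _mha(_q_in, _kv_in, _kv_in).shape == (2, 5, 16)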

# --- Feed Forward ---

class FeedForward(nn.Module):

    def __init__(self, d_model, d_ff):

        super().__init__()

        self.ff = nn.Sequential(

            nn.Linear(d_model, d_ff),

            nn.ReLU(),

            nn.Linear(d_ff, d_model)

        )


    def forward(self, x):

        return self.ff(x)
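

# Illustrative check (arbitrary sizes): the feed-forward block expands to d_ff and
# projects back to d_model, applied independently at every position, so the shape
# is preserved.
_ffn = FeedForward(d_model=16, d_ff=32)
assert _ffn(torch.randn(2, 5, 16)).shape == (2, 5, 16)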


# --- Encoder Layer ---

class EncoderLayer(nn.Module):

    def __init__(self, d_model, num_heads, d_ff):

        super().__init__()

        self.attn = MultiHeadAttention(d_model, num_heads)

        self.ffn = FeedForward(d_model, d_ff)

        self.norm1 = nn.LayerNorm(d_model)

        self.norm2 = nn.LayerNorm(d_model)


    def forward(self, x, mask):

        x = self.norm1(x + self.attn(x, x, x, mask))

        x = self.norm2(x + self.ffn(x))

        return x
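

# Illustrative check (arbitrary sizes): an encoder layer is shape-preserving, which
# is what allows N of them to be stacked below.
_enc_layer = EncoderLayer(d_model=16, num_heads=4, d_ff=32)
assert _enc_layer(torch.randn(2, 5, 16), mask=None).shape == (2, 5, 16)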


# --- Decoder Layer ---

class DecoderLayer(nn.Module):

    def __init__(self, d_model, num_heads, d_ff):

        super().__init__()

        self.self_attn = MultiHeadAttention(d_model, num_heads)

        self.cross_attn = MultiHeadAttention(d_model, num_heads)

        self.ffn = FeedForward(d_model, d_ff)

        self.norm1 = nn.LayerNorm(d_model)

        self.norm2 = nn.LayerNorm(d_model)

        self.norm3 = nn.LayerNorm(d_model)


    def forward(self, x, enc_out, tgt_mask, memory_mask):

        x = self.norm1(x + self.self_attn(x, x, x, tgt_mask))

        x = self.norm2(x + self.cross_attn(x, enc_out, enc_out, memory_mask))

        x = self.norm3(x + self.ffn(x))

        return x
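

# Illustrative check (arbitrary sizes): a decoder layer consumes the target
# sequence plus the encoder output ("memory") and is likewise shape-preserving.
_dec_layer = DecoderLayer(d_model=16, num_heads=4, d_ff=32)
_tgt_x = torch.randn(2, 4, 16)               # decoder input
_mem = torch.randn(2, 6, 16)                 # stand-in for an encoder output
assert _dec_layer(_tgt_x, _mem, None, None).shape == (2, 4, 16)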


# --- Encoder ---

class Encoder(nn.Module):

    def __init__(self, vocab_size, d_model, N, num_heads, d_ff, max_len=5000):

        super().__init__()

        self.embed = nn.Embedding(vocab_size, d_model)

        self.pe = PositionalEncoding(d_model, max_len)

        self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff) for _ in range(N)])


    def forward(self, x, mask):

        x = self.embed(x)     # shape: [batch_size, seq_len, d_model]

        x = self.pe(x)


        for layer in self.layers:

            x = layer(x, mask)

        return x
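

# Illustrative check (arbitrary vocabulary and sizes): the encoder maps token IDs of
# shape [batch, src_len] to contextual vectors of shape [batch, src_len, d_model].
_enc = Encoder(vocab_size=100, d_model=16, N=2, num_heads=4, d_ff=32)
assert _enc(torch.tensor([[1, 5, 6, 2, 0]]), mask=None).shape == (1, 5, 16)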


# --- Decoder ---

class Decoder(nn.Module):

    def __init__(self, vocab_size, d_model, N, num_heads, d_ff, max_len=5000):

        super().__init__()

        self.embed = nn.Embedding(vocab_size, d_model)

        self.pe = PositionalEncoding(d_model, max_len)

        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff) for _ in range(N)])


    def forward(self, x, enc_out, tgt_mask, memory_mask):

        x = self.pe(self.embed(x))

        for layer in self.layers:

            x = layer(x, enc_out, tgt_mask, memory_mask)

        return x
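

# Illustrative check: the decoder does the same for target token IDs while
# attending to an encoder output (here a random stand-in tensor).
_dec = Decoder(vocab_size=100, d_model=16, N=2, num_heads=4, d_ff=32)
_dec_mem = torch.randn(1, 5, 16)             # stand-in for an encoder output
assert _dec(torch.tensor([[1, 7, 4]]), _dec_mem, None, None).shape == (1, 3, 16)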


# --- Transformer ---

class Transformer(nn.Module):

    def __init__(self, src_vocab, tgt_vocab, d_model=512, N=2, heads=8, d_ff=1024):

        super().__init__()

        self.encoder = Encoder(src_vocab, d_model, N, heads, d_ff)

        self.decoder = Decoder(tgt_vocab, d_model, N, heads, d_ff)

        self.out = nn.Linear(d_model, tgt_vocab)


    def make_pad_mask(self, seq, pad_token=0):

        # True where the token is not padding; shape [B, 1, 1, seq_len]
        return (seq != pad_token).unsqueeze(1).unsqueeze(2)


    def make_look_ahead_mask(self, size):

        # lower-triangular causal mask; shape [1, 1, size, size]
        return torch.tril(torch.ones(size, size)).bool().unsqueeze(0).unsqueeze(1)


    def forward(self, src, tgt):

        src_mask = self.make_pad_mask(src)

        tgt_mask = self.make_pad_mask(tgt) & self.make_look_ahead_mask(tgt.size(1)).to(tgt.device)

        enc = self.encoder(src, src_mask)

        dec = self.decoder(tgt, enc, tgt_mask, src_mask)

        return self.out(dec)
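

# Illustrative check of the masks (arbitrary example; token ID 0 is padding):
# the padding mask is [B, 1, 1, seq_len], the look-ahead mask is [1, 1, size, size],
# and their logical AND broadcasts to the [B, 1, seq_len, seq_len] causal-plus-padding
# mask used for the decoder's self-attention.
_m = Transformer(src_vocab=100, tgt_vocab=100, d_model=16, N=1, heads=4, d_ff=32)
_seq = torch.tensor([[1, 7, 4, 3, 0]])
assert _m.make_pad_mask(_seq).shape == (1, 1, 1, 5)
assert _m.make_look_ahead_mask(5).shape == (1, 1, 5, 5)
assert (_m.make_pad_mask(_seq) & _m.make_look_ahead_mask(5)).shape == (1, 1, 5, 5)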


# --- Main Program ---

if __name__ == "__main__":

    src = torch.tensor([[1, 5, 6, 2, 0]])  # (batch, src_seq_len)

    tgt = torch.tensor([[1, 7, 4, 3, 0]])  # (batch, tgt_seq_len)


    model = Transformer(src_vocab=10000, tgt_vocab=10000)

    out = model(src, tgt)  # (1, 5, 10000)


    print("Output shape:", out.shape)


    # Convert logits to probabilities (note: the model is untrained, so these are effectively random)

    probs = torch.softmax(out, dim=-1)  # (1, 5, 10000)


    # Predicted token IDs (argmax)

    predicted_ids = torch.argmax(probs, dim=-1)

    print("\nPredicted token IDs:", predicted_ids.tolist())


    # Show top-5 token IDs per position

    for pos in range(probs.size(1)):

        top5 = torch.topk(probs[0, pos], 5)

        print(f"\nPosition {pos} Top-5:")

        for prob, idx in zip(top5.values, top5.indices):

            print(f"  Token ID: {idx.item()}, Probability: {prob.item():.4f}")