Tuesday, 20 January 2026

A PyTorch-based sequence-to-sequence Transformer, trained on a toy task: the target sequence is the source sequence with every token incremented by 1.

# -*- coding: utf-8 -*-
"""
Created on Fri Jul 25 18:04:19 2025

@author: ali.saral
"""


# train_seq2seq_transformer.py

import torch
import torch.nn as nn
import torch.optim as optim
import random


# --- Configuration ---

PAD_IDX = 0   # padding token
SOS_IDX = 1   # start-of-sequence token
EOS_IDX = 2   # end-of-sequence token
VOCAB_SIZE = 30
MAX_LEN = 6   # fixed sequence length, including SOS and EOS
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
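
# Optional: fixing the seeds makes runs repeatable (the losses in the sample
# run quoted at the bottom come from one particular run, so exact numbers
# will differ in any case).
# random.seed(0)
# torch.manual_seed(0)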


# --- Data Utilities ---

def generate_pair():
    # Source: 2 to MAX_LEN-2 random tokens from the "payload" range 3..20,
    # leaving room for SOS and EOS within MAX_LEN.
    length = random.randint(2, MAX_LEN - 2)
    src_seq = [random.randint(3, 20) for _ in range(length)]
    # Toy task: the target is the source with every token incremented by 1.
    tgt_seq = [x + 1 for x in src_seq]
    return src_seq, tgt_seq


def pad_sequence(seq, max_len):
    # Right-pad with PAD_IDX up to the fixed length.
    return seq + [PAD_IDX] * (max_len - len(seq))


def prepare_batch(batch_size=32):
    # Wrap each pair in SOS/EOS, pad to MAX_LEN, and stack into (batch, MAX_LEN) tensors.
    src_batch, tgt_batch = [], []
    for _ in range(batch_size):
        src, tgt = generate_pair()
        src = pad_sequence([SOS_IDX] + src + [EOS_IDX], MAX_LEN)
        tgt = pad_sequence([SOS_IDX] + tgt + [EOS_IDX], MAX_LEN)
        src_batch.append(src)
        tgt_batch.append(tgt)
    return (torch.tensor(src_batch, dtype=torch.long, device=DEVICE),
            torch.tensor(tgt_batch, dtype=torch.long, device=DEVICE))
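
# A quick sanity check of the batch helper: with MAX_LEN = 6 both tensors come
# out as (batch_size, 6), each row being SOS, the sequence, EOS, then PAD.
_src, _tgt = prepare_batch(4)
print(_src.shape, _tgt.shape)  # torch.Size([4, 6]) torch.Size([4, 6])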


# --- Positional Encoding ---

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Standard sinusoidal table: even columns sine, odd columns cosine.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * -(torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Register as a buffer so the table moves with the module (e.g. .to(DEVICE))
        # and is not treated as a trainable parameter.
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model); add the first seq_len rows of the table.
        return x + self.pe[:, :x.size(1)]
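
# A minimal shape check for the positional encoding: the (1, max_len, d_model)
# table broadcasts over the batch, so the input shape is preserved.
print(PositionalEncoding(d_model=8)(torch.zeros(2, 5, 8)).shape)  # torch.Size([2, 5, 8])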


# --- Transformer Model ---

class Seq2SeqTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=128, nhead=4, num_layers=2, dim_feedforward=512):
        super().__init__()
        self.d_model = d_model
        # One embedding table shared by encoder and decoder inputs.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=0.1,
            batch_first=True,  # inputs are (batch, seq, d_model)
        )
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        # Causal mask so each decoder position can only attend to earlier positions.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        # Scale embeddings by sqrt(d_model) before adding positional encodings,
        # as in the original Transformer paper.
        src_emb = self.pos_enc(self.embedding(src) * (self.d_model ** 0.5))
        tgt_emb = self.pos_enc(self.embedding(tgt) * (self.d_model ** 0.5))
        out = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask)
        return self.fc_out(out)  # (batch, tgt_len, vocab_size) logits
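
# Padding positions are not masked out in forward(); the toy task trains anyway.
# If you wanted the attention to ignore PAD tokens, nn.Transformer also accepts
# key-padding masks. A minimal sketch (True marks positions to ignore):
def make_padding_mask(seq):
    return seq == PAD_IDX  # (batch, seq_len) boolean mask
# which could be passed inside forward() as, e.g.:
#   self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask,
#                    src_key_padding_mask=make_padding_mask(src),
#                    tgt_key_padding_mask=make_padding_mask(tgt),
#                    memory_key_padding_mask=make_padding_mask(src))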


# --- Instantiate ---

model = Seq2SeqTransformer(vocab_size=VOCAB_SIZE).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)  # don't score padded positions
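
# Purely informational: the size of the model being trained.
print(sum(p.numel() for p in model.parameters() if p.requires_grad), "trainable parameters")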


# --- Training Loop ---

for epoch in range(2000):
    model.train()
    src, tgt = prepare_batch(64)
    optimizer.zero_grad()
    # Teacher forcing: the decoder sees tgt[:, :-1] and is trained to predict
    # the sequence shifted by one position, tgt[:, 1:].
    output = model(src, tgt[:, :-1])
    loss = criterion(output.reshape(-1, VOCAB_SIZE), tgt[:, 1:].reshape(-1))
    loss.backward()
    optimizer.step()
    if epoch % 100 == 0:
        print(f"Epoch {epoch}: Loss = {loss.item():.4f}")


# --- Inference ---

def translate(input_seq):
    # Greedy decoding: encode the source once, then extend the target one token
    # at a time, always taking the argmax at the last generated position.
    model.eval()
    src_seq = pad_sequence([SOS_IDX] + input_seq + [EOS_IDX], MAX_LEN)
    src = torch.tensor([src_seq], dtype=torch.long, device=DEVICE)

    tgt_seq = [SOS_IDX]
    for _ in range(MAX_LEN - 1):
        # Padding the partial target is harmless here: the causal mask keeps
        # position len(tgt_seq)-1 from attending to the PAD tokens after it.
        tgt_padded = pad_sequence(tgt_seq, MAX_LEN)
        tgt_tensor = torch.tensor([tgt_padded], dtype=torch.long, device=DEVICE)
        with torch.no_grad():
            output = model(src, tgt_tensor)
        next_token = output[0, len(tgt_seq) - 1].argmax().item()
        if next_token == EOS_IDX:
            break
        tgt_seq.append(next_token)
    return tgt_seq[1:]  # drop the leading SOS
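
# Since the toy rule is "add 1 to every token", a small self-check could look
# like this (a sketch; inputs containing the reserved ids 0-2 will cut
# decoding short, as the sample run below shows):
def matches_rule(input_seq):
    return translate(input_seq) == [x + 1 for x in input_seq]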


# --- Test ---

for test_input in [[3, 7, 2], [3, 7, 3], [4, 5, 10, 7]]:
    # Note that 2 is the reserved EOS id, so a 2 in the input collides with the
    # special symbols of this toy vocabulary.
    output = translate(test_input)
    print(f"\nInput: {test_input}")
    print(f"Predicted Output: {output}")

    

""""

%runfile 
Reloaded modules: my_transformer
Epoch 0: Loss = 3.4394
Epoch 5: Loss = 2.6161
Epoch 10: Loss = 2.2482
Epoch 15: Loss = 1.9300
Epoch 20: Loss = 1.6606
Epoch 25: Loss = 1.5804

Input: [3, 7, 2]
Predicted Output: [4, 8]

Input: [3, 7, 3]
Predicted Output: [4, 4]

Input: [1, 5, 10, 7, 2, 0]
Predicted Output: [6, 8]

"""