Tuesday, 20 January 2026

This driver script works with both transformer implementations: it trains the model on a toy token-increment task and then runs greedy-decoding inference with it.

# -*- coding: utf-8 -*-

"""

Created on Fri Jul 25 14:45:28 2025


@author: ali.saral

"""


# train_and_infer.py

import torch

import torch.nn as nn

import torch.optim as optim

import random

from my_transformer import Transformer  # import the corrected Transformer class


# --- Config ---

PAD_IDX = 0

SOS_IDX = 1

EOS_IDX = 2

VOCAB_SIZE = 30  # vocab tokens 0..29

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MAX_LEN = 6
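# MAX_LEN counts the SOS and EOS markers, so the payload is at most MAX_LEN - 2 tokens.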


# --- Generate toy data ---

def generate_pair():

    length = random.randint(2, MAX_LEN - 2)

    src_seq = [random.randint(3, 20) for _ in range(length)]

    tgt_seq = [x + 1 for x in src_seq]  # rule: increment each token by 1

    return src_seq, tgt_seq


def pad_sequence(seq, max_len):

    return seq + [PAD_IDX] * (max_len - len(seq))


def prepare_batch(batch_size=32):

    src_batch, tgt_batch = [], []

    for _ in range(batch_size):

        src, tgt = generate_pair()

        src = [SOS_IDX] + src + [EOS_IDX]

        tgt = [SOS_IDX] + tgt + [EOS_IDX]

        src = pad_sequence(src, MAX_LEN)

        tgt = pad_sequence(tgt, MAX_LEN)

        src_batch.append(src)

        tgt_batch.append(tgt)

    return torch.tensor(src_batch, dtype=torch.long, device=DEVICE), torch.tensor(tgt_batch, dtype=torch.long, device=DEVICE)
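
# Example of one padded pair the helpers above produce (illustrative values):
#   src row: [1, 7, 12, 2, 0, 0]   -> SOS, payload, EOS, padding
#   tgt row: [1, 8, 13, 2, 0, 0]   -> each payload token incremented by 1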


# --- Instantiate model, optimizer, criterion ---

model = Transformer(src_vocab=VOCAB_SIZE, tgt_vocab=VOCAB_SIZE, d_model=128, N=2, heads=4, d_ff=512).to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=0.001)

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)


# --- Training loop ---

for epoch in range(1500):

    model.train()

    src, tgt = prepare_batch(batch_size=64)

    

    # Optionally print the first few input/target pairs of the current batch
    # (padding stripped for clarity); disabled here to keep the training log short.
    ###print(f"\nEpoch {epoch} Training Batch (showing first 3 samples):")
    ###for i in range(3):
    ###    input_seq = [x.item() for x in src[i] if x.item() != PAD_IDX]
    ###    target_seq = [x.item() for x in tgt[i] if x.item() != PAD_IDX]
    ###    print(f"  Input: {input_seq}  -> Target: {target_seq}")

    

    optimizer.zero_grad()

    out = model(src, tgt[:, :-1])  # teacher forcing, predict next tokens
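    # Assuming the model returns one set of logits per decoder position,
    # out has shape (batch, MAX_LEN - 1, VOCAB_SIZE); it is compared below
    # against the target shifted left by one position.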

    loss = criterion(out.view(-1, VOCAB_SIZE), tgt[:, 1:].reshape(-1))

    loss.backward()

    optimizer.step()

    

    if epoch % 100 == 0:

        print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

        

# --- Inference function ---

def translate(input_seq):

    model.eval()

    src_seq = [SOS_IDX] + input_seq + [EOS_IDX]

    src_seq = pad_sequence(src_seq, MAX_LEN)

    src = torch.tensor([src_seq], dtype=torch.long, device=DEVICE)


    tgt_seq = [SOS_IDX]

    for _ in range(MAX_LEN - 1):

        tgt_padded = pad_sequence(tgt_seq, MAX_LEN)

        tgt_tensor = torch.tensor([tgt_padded], dtype=torch.long, device=DEVICE)

        with torch.no_grad():

            out = model(src, tgt_tensor)

        # The logits at position len(tgt_seq) - 1 predict the token that follows
        # the ones generated so far; take the argmax (greedy decoding).
        next_token = out[0, len(tgt_seq) - 1].argmax().item()

        if next_token == EOS_IDX:

            break

        tgt_seq.append(next_token)

    return tgt_seq[1:]  # remove SOS


# --- Test the trained model ---

test_input = [3, 7, 2]  # note: 2 is EOS_IDX, so in effect the payload is [3, 7]; expected output: [4, 8]

output = translate(test_input)

print("\nTest Input:", test_input)

print("Predicted Output:", output)

test_input = [3, 7, 3]  # expected output: [4, 8, 4]

output = translate(test_input)

print("\nTest Input:", test_input)

print("Predicted Output:", output)

test_input = [3, 5, 10, 7]  # expected output: [4, 6, 11, 8]

output = translate(test_input)

print("\nTest Input:", test_input)

print("Predicted Output:", output)


"""

Test Input: [3, 7, 2]

Predicted Output: [4, 8]


Test Input: [3, 7, 3]

Predicted Output: [4, 8, 4]


Test Input: [1, 5, 10, 7, 2, 0]

Predicted Output: [6, 6, 11, 8]

"""

"""

%runfile

Epoch 0: Loss = 3.5938

Epoch 100: Loss = 0.0167

Epoch 200: Loss = 0.0104

Epoch 300: Loss = 0.0023

Epoch 400: Loss = 0.0293

Epoch 500: Loss = 0.0125

Epoch 600: Loss = 0.0288

Epoch 700: Loss = 0.0111

Epoch 800: Loss = 0.0054

Epoch 900: Loss = 0.0009

Epoch 1000: Loss = 0.0003

Epoch 1100: Loss = 0.0002

Epoch 1200: Loss = 0.0002

Epoch 1300: Loss = 0.0001

Epoch 1400: Loss = 0.0001


Test Input: [3, 7, 2]

Predicted Output: [4, 8]


Test Input: [3, 7, 3]

Predicted Output: [4, 8, 4]


Test Input: [3, 5, 10, 7]

Predicted Output: [4, 6, 11, 8]

"""