# -*- coding: utf-8 -*-
"""
Created on Fri Jul 25 14:45:28 2025
@author: ali.saral
"""
# train_and_infer.py
import torch
import torch.nn as nn
import torch.optim as optim
import random
from my_transformer import Transformer # import the corrected Transformer class
# --- Config ---
PAD_IDX = 0
SOS_IDX = 1
EOS_IDX = 2
VOCAB_SIZE = 30 # vocab tokens 0..29
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LEN = 6
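# Token layout implied by the generator below: 0 = PAD, 1 = SOS, 2 = EOS,
# source tokens are drawn from 3..20 and the "+1" rule produces values up to 21,
# so VOCAB_SIZE = 30 leaves some headroom above the range actually used.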
# --- Generate toy data ---
def generate_pair():
    length = random.randint(2, MAX_LEN - 2)
    src_seq = [random.randint(3, 20) for _ in range(length)]
    tgt_seq = [x + 1 for x in src_seq]  # rule: increment each token by 1
    return src_seq, tgt_seq

def pad_sequence(seq, max_len):
    return seq + [PAD_IDX] * (max_len - len(seq))

def prepare_batch(batch_size=32):
    src_batch, tgt_batch = [], []
    for _ in range(batch_size):
        src, tgt = generate_pair()
        src = [SOS_IDX] + src + [EOS_IDX]
        tgt = [SOS_IDX] + tgt + [EOS_IDX]
        src = pad_sequence(src, MAX_LEN)
        tgt = pad_sequence(tgt, MAX_LEN)
        src_batch.append(src)
        tgt_batch.append(tgt)
    return (torch.tensor(src_batch, dtype=torch.long, device=DEVICE),
            torch.tensor(tgt_batch, dtype=torch.long, device=DEVICE))
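# Illustrative example (made-up values, not from a real run): generate_pair()
# might return ([5, 12, 9], [6, 13, 10]); after adding SOS/EOS and padding to
# MAX_LEN the corresponding batch rows become
#   src: [1, 5, 12, 9, 2, 0]    tgt: [1, 6, 13, 10, 2, 0]
# and prepare_batch() stacks batch_size such rows into (batch_size, MAX_LEN)
# LongTensors on DEVICE.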
# --- Instantiate model, optimizer, criterion ---
model = Transformer(src_vocab=VOCAB_SIZE, tgt_vocab=VOCAB_SIZE, d_model=128, N=2, heads=4, d_ff=512).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
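# Note: ignore_index=PAD_IDX makes CrossEntropyLoss skip every target position
# that holds PAD_IDX, so the model is never penalized for its predictions over
# padding.  E.g. for a target row [6, 13, 10, 2, 0] only the four non-pad
# positions contribute to the loss.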
# --- Training loop ---
for epoch in range(1500):
    model.train()
    src, tgt = prepare_batch(batch_size=64)

    # Print input and target sequences of the current batch (printing disabled;
    # uncomment the print calls below to enable it)
    ###print(f"\nEpoch {epoch} Training Batch (showing first 3 samples):")
    for i in range(3):
        # Convert tensor row to list, remove padding index for clarity
        input_seq = [x.item() for x in src[i] if x.item() != PAD_IDX]
        target_seq = [x.item() for x in tgt[i] if x.item() != PAD_IDX]
        ###print(f"  Input: {input_seq} -> Target: {target_seq}")

    optimizer.zero_grad()
    out = model(src, tgt[:, :-1])  # teacher forcing, predict next tokens
    loss = criterion(out.view(-1, VOCAB_SIZE), tgt[:, 1:].reshape(-1))
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}: Loss = {loss.item():.4f}")
# --- Inference function ---
def translate(input_seq):
    model.eval()
    src_seq = [SOS_IDX] + input_seq + [EOS_IDX]
    src_seq = pad_sequence(src_seq, MAX_LEN)
    src = torch.tensor([src_seq], dtype=torch.long, device=DEVICE)

    tgt_seq = [SOS_IDX]
    for _ in range(MAX_LEN - 1):
        tgt_padded = pad_sequence(tgt_seq, MAX_LEN)
        tgt_tensor = torch.tensor([tgt_padded], dtype=torch.long, device=DEVICE)
        with torch.no_grad():
            out = model(src, tgt_tensor)
        next_token = out[0, len(tgt_seq) - 1].argmax().item()
        if next_token == EOS_IDX:
            break
        tgt_seq.append(next_token)
    return tgt_seq[1:]  # remove SOS
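# A minimal alternative sketch (defined but not called below): greedy decoding
# without re-padding the partial target each step.  It assumes the Transformer
# accepts target sequences shorter than MAX_LEN, which may or may not hold for
# the my_transformer implementation; translate() above is the padded path.
def translate_unpadded(input_seq):
    model.eval()
    src = torch.tensor([pad_sequence([SOS_IDX] + input_seq + [EOS_IDX], MAX_LEN)],
                       dtype=torch.long, device=DEVICE)
    tgt_seq = [SOS_IDX]
    for _ in range(MAX_LEN - 1):
        tgt_tensor = torch.tensor([tgt_seq], dtype=torch.long, device=DEVICE)
        with torch.no_grad():
            out = model(src, tgt_tensor)
        next_token = out[0, -1].argmax().item()  # logits at the last decoded position
        if next_token == EOS_IDX:
            break
        tgt_seq.append(next_token)
    return tgt_seq[1:]  # drop SOS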
# --- Test the trained model ---
test_input = [3, 7, 2]  # under the +1 rule the expected output is [4, 8, 3]; note 2 equals EOS_IDX and lies outside the trained range 3..20, so the model tends to stop after [4, 8]
output = translate(test_input)
print("\nTest Input:", test_input)
print("Predicted Output:", output)
test_input = [3, 7, 3] # expected output: [4, 8, 4]
output = translate(test_input)
print("\nTest Input:", test_input)
print("Predicted Output:", output)
test_input = [3, 5, 10, 7]  # expected output: [4, 6, 11, 8]
output = translate(test_input)
print("\nTest Input:", test_input)
print("Predicted Output:", output)
"""
Test Input: [3, 7, 2]
Predicted Output: [4, 8]
Test Input: [3, 7, 3]
Predicted Output: [4, 8, 4]
Test Input: [1, 5, 10, 7, 2, 0]
Predicted Output: [6, 6, 11, 8]
"""
"""
Epoch 0: Loss = 3.5938
Epoch 100: Loss = 0.0167
Epoch 200: Loss = 0.0104
Epoch 300: Loss = 0.0023
Epoch 400: Loss = 0.0293
Epoch 500: Loss = 0.0125
Epoch 600: Loss = 0.0288
Epoch 700: Loss = 0.0111
Epoch 800: Loss = 0.0054
Epoch 900: Loss = 0.0009
Epoch 1000: Loss = 0.0003
Epoch 1100: Loss = 0.0002
Epoch 1200: Loss = 0.0002
Epoch 1300: Loss = 0.0001
Epoch 1400: Loss = 0.0001
Test Input: [3, 7, 2]
Predicted Output: [4, 8]
Test Input: [3, 7, 3]
Predicted Output: [4, 8, 4]
Test Input: [3, 5, 10, 7]
Predicted Output: [4, 6, 11, 8]
"""
