# -*- coding: utf-8 -*-
"""
Created on Fri Jul 25 18:04:19 2025
@author: ali.saral
"""
# train_seq2seq_transformer.py
import torch
import torch.nn as nn
import torch.optim as optim
import random
# --- Configuration ---
PAD_IDX = 0
SOS_IDX = 1
EOS_IDX = 2
VOCAB_SIZE = 30
MAX_LEN = 6
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- Data Utilities ---
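# Toy task: the target sequence is the source sequence with every token
# incremented by 1. Source tokens are drawn from 3..20, so targets stay well
# below VOCAB_SIZE; indices 0-2 are reserved for PAD/SOS/EOS.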
def generate_pair():
    length = random.randint(2, MAX_LEN - 2)
    src_seq = [random.randint(3, 20) for _ in range(length)]
    tgt_seq = [x + 1 for x in src_seq]
    return src_seq, tgt_seq
def pad_sequence(seq, max_len):
    return seq + [PAD_IDX] * (max_len - len(seq))
def prepare_batch(batch_size=32):
    src_batch, tgt_batch = [], []
    for _ in range(batch_size):
        src, tgt = generate_pair()
        src = [SOS_IDX] + src + [EOS_IDX]
        tgt = [SOS_IDX] + tgt + [EOS_IDX]
        src = pad_sequence(src, MAX_LEN)
        tgt = pad_sequence(tgt, MAX_LEN)
        src_batch.append(src)
        tgt_batch.append(tgt)
    return (torch.tensor(src_batch, dtype=torch.long, device=DEVICE),
            torch.tensor(tgt_batch, dtype=torch.long, device=DEVICE))
# --- Positional Encoding ---
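# Standard sinusoidal encoding from "Attention Is All You Need":
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# It is added to the token embeddings so the model can distinguish positions.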
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)  # (1, max_len, d_model)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)].to(x.device)
        return x
# --- Transformer Model ---
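# Thin wrapper around nn.Transformer with batch_first=True, so all tensors are
# (batch, seq_len, d_model). Embeddings are scaled by sqrt(d_model) before the
# positional encoding, and the decoder uses a square subsequent (causal) mask.
# Note: no padding masks are passed here; that is acceptable for this toy
# setup, but for variable-length real data one would normally also supply
# src_key_padding_mask / tgt_key_padding_mask.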
class Seq2SeqTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=128, nhead=4, num_layers=2, dim_feedforward=512):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=0.1,
            batch_first=True
        )
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        src_mask = None
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        src_emb = self.pos_enc(self.embedding(src) * (self.d_model ** 0.5))
        tgt_emb = self.pos_enc(self.embedding(tgt) * (self.d_model ** 0.5))
        out = self.transformer(src_emb, tgt_emb, src_mask=src_mask, tgt_mask=tgt_mask)
        return self.fc_out(out)
# --- Instantiate ---
model = Seq2SeqTransformer(vocab_size=VOCAB_SIZE).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
# --- Training Loop ---
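# Teacher forcing: the decoder input is tgt[:, :-1] (shifted right, starting
# at SOS) and the labels are tgt[:, 1:], so the model is trained to predict
# the next token at every position. CrossEntropyLoss skips PAD positions via
# ignore_index=PAD_IDX.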
for epoch in range(2000):
    model.train()
    src, tgt = prepare_batch(64)
    optimizer.zero_grad()
    output = model(src, tgt[:, :-1])  # predict next tokens
    loss = criterion(output.reshape(-1, VOCAB_SIZE), tgt[:, 1:].reshape(-1))
    loss.backward()
    optimizer.step()
    if epoch % 100 == 0:
        print(f"Epoch {epoch}: Loss = {loss.item():.4f}")
# --- Inference ---
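# Greedy decoding: start from SOS, rerun the model on the growing target,
# take the argmax logit at the last generated position, and stop on EOS or
# after MAX_LEN - 1 steps. The causal mask keeps the PAD tokens appended by
# pad_sequence from influencing the position that is read out.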
def translate(input_seq):
    model.eval()
    src_seq = [SOS_IDX] + input_seq + [EOS_IDX]
    src_seq = pad_sequence(src_seq, MAX_LEN)
    src = torch.tensor([src_seq], dtype=torch.long, device=DEVICE)
    tgt_seq = [SOS_IDX]
    for _ in range(MAX_LEN - 1):
        tgt_padded = pad_sequence(tgt_seq, MAX_LEN)
        tgt_tensor = torch.tensor([tgt_padded], dtype=torch.long, device=DEVICE)
        with torch.no_grad():
            output = model(src, tgt_tensor)
        next_token = output[0, len(tgt_seq) - 1].argmax().item()
        if next_token == EOS_IDX:
            break
        tgt_seq.append(next_token)
    return tgt_seq[1:]
# --- Test ---
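# Ideally each prediction is the input with every token incremented by 1,
# truncated where the model emits EOS or hits the length limit.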
for test_input in [[3, 7, 2], [3, 7, 3], [4, 5, 10, 7]]:
    output = translate(test_input)
    print(f"\nInput: {test_input}")
    print(f"Predicted Output: {output}")
""""
Epoch 0: Loss = 3.4394
Epoch 5: Loss = 2.6161
Epoch 10: Loss = 2.2482
Epoch 15: Loss = 1.9300
Epoch 20: Loss = 1.6606
Epoch 25: Loss = 1.5804
Input: [3, 7, 2]
Predicted Output: [4, 8]
Input: [3, 7, 3]
Predicted Output: [4, 4]
Input: [1, 5, 10, 7, 2, 0]
Predicted Output: [6, 8]
"""
