Parmaklarınızın ucu ile sequence to sequence Pytorch
transformer yapabilirsiniz
Pytorch’un hazır modüllerini kullanarak 120 satır Python kodu ile bir sequence to
sequence
Transformer yapmak mümkün (Bkz. https://github.com/arsaral/test-project_transformer).
https://github.com/arsaral/test-project_transformer/blob/main/my_transformer_torch_version.py
içindeki basit örnek verilen dizi içindeki sayıları 1 arttırıp çıktı dizisi üreten çok basit bir sequence to sequence
transformer. Bu transformer bir tercüme
motoru ile aynı yapıya sahiptir.
Input: [3, 7, 2]
Predicted Output: [4, 8]
2 sayısı özel bir sayı olduğu için
2 eklenmemiş.
PAD_IDX = 0
SOS_IDX = 1
EOS_IDX = 2
Öğrenme süreci çok kısa tutulduğu
için (25 epoch) ve eğitim verileri çok sınırlı (32) olduğu için
çıktılar mükemmel değil:
%runfile
Reloaded modules: my_transformer
Epoch 0: Loss = 3.4394
Epoch 5: Loss = 2.6161
Epoch 10: Loss = 2.2482
Epoch 15: Loss = 1.9300
Epoch 20: Loss = 1.6606
Epoch 25: Loss = 1.5804
Çıktı:
Input: [3, 7, 2]
Predicted Output: [4, 8]
Input: [3, 7, 3]
Predicted Output: [4, 4]
Input: [1, 5, 10, 7, 2, 0]
Predicted Output: [6, 8]
Bu program tamamen hazır
kullanılabilir pytorch kütüphaneleri ile yapılmıştır ve
Eğitim veri ve çevrim sayısı çok sınırlı olduğu için
kolaylıkla çalıştırabilirsiniz.
import torch
import torch.nn as nn
import torch.optim as optim
import random
prepare_batch fonksiyonu generate_pair()
fonksiyonunu kullanarak
src, tgt şeklinde eğitim
verilerini oluşturur.
def prepare_batch(batch_size=32):
src_batch, tgt_batch = [], []
for _ in range(batch_size):
src, tgt = generate_pair()
src = [SOS_IDX] + src + [EOS_IDX]
tgt = [SOS_IDX] + tgt + [EOS_IDX]
src = pad_sequence(src, MAX_LEN)
tgt = pad_sequence(tgt, MAX_LEN)
src_batch.append(src)
tgt_batch.append(tgt)
return torch.tensor(src_batch, dtype=torch.long, device=DEVICE),
torch.tensor(tgt_batch, dtype=torch.long, device=DEVICE)
transformerı aşağıdaki kısım ile
çalıştırabilirsiniz:
# --- Test ---
for test_input in [[3, 7, 2], [3,
7, 3], [4, 5, 10, 7]]:
output = translate(test_input)
print(f"\nInput: {test_input}")
print(f"Predicted Output: {output}")
