import torch
from torch.utils.data import Dataset
from typing import List, Tuple


class MathDataset(Dataset):
    """
    A custom PyTorch Dataset to handle the encoded math problem sequences.

    It performs the crucial language model shift (X is the input, Y is X
    shifted by one token) and pads/truncates every sample to a fixed
    length so batches have a uniform shape.
    """

    def __init__(self, data: List[str], tokenizer, max_len: int):
        """
        Args:
            data: Raw text examples, one math problem per entry.
            tokenizer: Object exposing ``encode(str) -> List[int]`` and a
                ``pad_token_id`` attribute (assumed interface — confirm
                against the tokenizer implementation).
            max_len: Fixed length of the returned input/target tensors.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        # Cached so __getitem__ does not hit the tokenizer attribute each call.
        self.pad_token_id = tokenizer.pad_token_id

    def __len__(self):
        """Return the number of examples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return an ``(input, target)`` pair, each of shape ``(max_len,)``.

        The target is the input shifted left by one token (next-token
        prediction). Sequences longer than ``max_len + 1`` tokens are
        truncated; shorter ones are right-padded with ``pad_token_id``.
        """
        raw_text = self.data[idx]
        sequence_ids = self.tokenizer.encode(raw_text)

        # BUG FIX: without this truncation, an encoding longer than
        # max_len + 1 made padding_length negative below; the padding
        # then added nothing and the returned tensors exceeded max_len,
        # breaking fixed-shape batching in the DataLoader.
        sequence_ids = sequence_ids[: self.max_len + 1]

        # Language-model shift: predict token t+1 from tokens <= t.
        x = sequence_ids[:-1]
        y = sequence_ids[1:]

        # Right-pad both sequences to exactly max_len tokens.
        padding_length = self.max_len - len(x)
        x_padded = x + [self.pad_token_id] * padding_length
        y_padded = y + [self.pad_token_id] * padding_length

        return torch.tensor(x_padded, dtype=torch.long), torch.tensor(y_padded, dtype=torch.long)