"""MiniMind Dataset and DataLoader utilities."""

import json
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from torch.utils.data import DataLoader, Dataset


class TextDataset(Dataset):
    """Simple text dataset for language model training."""

    def __init__(
        self,
        data_path: str,
        tokenizer: Any,
        max_length: int = 2048,
        format_type: str = "jsonl",
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path, format_type)
    def _load_data(self, data_path: str, format_type: str) -> List[str]:
        data = []
        path = Path(data_path)

        if format_type == "jsonl":
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue  # tolerate blank lines instead of crashing json.loads
                    item = json.loads(line)
                    text = item.get("text", item.get("content", ""))
                    if text:
                        data.append(text)
        elif format_type == "txt":
            with open(path, "r", encoding="utf-8") as f:
                data = [line.strip() for line in f if line.strip()]
        else:
            raise ValueError(f"Unsupported format: {format_type}")

        return data

    def __len__(self) -> int:
        return len(self.data)
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        text = self.data[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        # Labels mirror the inputs (an HF-style model shifts them internally
        # for next-token prediction); padding positions are set to -100 so the
        # cross-entropy loss ignores them.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


class DistillationDataset(Dataset):
    """Dataset for knowledge distillation with teacher logits."""

    def __init__(
        self,
        data_path: str,
        tokenizer: Any,
        teacher_logits_path: Optional[str] = None,
        max_length: int = 2048,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)
        self.teacher_logits = (
            self._load_teacher_logits(teacher_logits_path) if teacher_logits_path else None
        )
    def _load_data(self, data_path: str) -> List[str]:
        # Keep one entry per JSONL record so indices stay aligned with any
        # precomputed teacher logits.
        with open(data_path, "r", encoding="utf-8") as f:
            return [json.loads(line).get("text", "") for line in f if line.strip()]
    def _load_teacher_logits(self, path: str) -> Optional[torch.Tensor]:
        # Logits are kept on CPU; batches move to the training device later.
        if Path(path).exists():
            return torch.load(path, map_location="cpu")
        return None

    def __len__(self) -> int:
        return len(self.data)
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        text = self.data[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100  # ignore padding in the loss

        item = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

        if self.teacher_logits is not None:
            item["teacher_logits"] = self.teacher_logits[idx]

        return item
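

# A hedged sketch of how the teacher logits file for DistillationDataset might
# be produced offline. This workflow is an assumption, not part of this module;
# the model name and output path are illustrative. A plain DataLoader with
# shuffle=False (and no drop_last) keeps the saved logits aligned with dataset
# indices.
#
#   teacher = AutoModelForCausalLM.from_pretrained("my-teacher-checkpoint").eval()
#   chunks = []
#   with torch.no_grad():
#       for batch in DataLoader(dataset, batch_size=8, shuffle=False):
#           out = teacher(input_ids=batch["input_ids"],
#                         attention_mask=batch["attention_mask"])
#           chunks.append(out.logits.cpu())
#   torch.save(torch.cat(chunks), "teacher_logits.pt")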


def create_dataloader(
    dataset: Dataset,
    batch_size: int = 8,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
) -> DataLoader:
    """Create a DataLoader with sensible training defaults.

    Drops the last incomplete batch so every training step sees a full batch.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=True,
    )
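

if __name__ == "__main__":
    # Minimal usage sketch, assuming a Hugging Face-style tokenizer and a JSONL
    # corpus at data/pretrain.jsonl; both are illustrative assumptions, not
    # requirements of this module.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        # padding="max_length" in __getitem__ requires a pad token.
        tokenizer.pad_token = tokenizer.eos_token

    dataset = TextDataset("data/pretrain.jsonl", tokenizer, max_length=512)
    loader = create_dataloader(dataset, batch_size=4, num_workers=0)

    batch = next(iter(loader))
    print(batch["input_ids"].shape)       # torch.Size([4, 512])
    print(batch["attention_mask"].shape)  # torch.Size([4, 512])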