# MiniMind/training/dataset.py
"""
MiniMind Dataset and DataLoader utilities
"""
import json
from typing import Optional, List, Dict, Any
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
class TextDataset(Dataset):
"""Simple text dataset for language model training."""
def __init__(
self,
data_path: str,
tokenizer: Any,
max_length: int = 2048,
        format_type: str = "jsonl",  # "jsonl" or "txt"
):
self.tokenizer = tokenizer
self.max_length = max_length
self.data = self._load_data(data_path, format_type)
def _load_data(self, data_path: str, format_type: str) -> List[str]:
data = []
path = Path(data_path)
if format_type == "jsonl":
with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    # Skip blank lines so stray whitespace does not crash json.loads.
                    if not line:
                        continue
                    item = json.loads(line)
                    # Accept either a "text" or "content" field.
                    text = item.get("text", item.get("content", ""))
                    if text:
                        data.append(text)
elif format_type == "txt":
with open(path, "r", encoding="utf-8") as f:
data = [line.strip() for line in f if line.strip()]
else:
raise ValueError(f"Unsupported format: {format_type}")
return data
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
text = self.data[idx]
encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
padding="max_length",
return_tensors="pt",
)
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        # Mask padding positions so they are ignored by the cross-entropy loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
class DistillationDataset(Dataset):
"""Dataset for knowledge distillation with teacher logits."""
def __init__(
self,
data_path: str,
tokenizer: Any,
teacher_logits_path: Optional[str] = None,
max_length: int = 2048,
):
self.tokenizer = tokenizer
self.max_length = max_length
self.data = self._load_data(data_path)
self.teacher_logits = self._load_teacher_logits(teacher_logits_path) if teacher_logits_path else None
    def _load_data(self, data_path: str) -> List[str]:
        # Keep one entry per line (including empty "text" fields) so indices
        # stay aligned with any precomputed teacher logits.
        with open(data_path, "r", encoding="utf-8") as f:
            return [json.loads(line).get("text", "") for line in f if line.strip()]
def _load_teacher_logits(self, path: str) -> Optional[torch.Tensor]:
if Path(path).exists():
return torch.load(path, map_location="cpu")
return None
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
text = self.data[idx]
encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
padding="max_length",
return_tensors="pt",
)
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        # Mask padding positions so they are ignored by the cross-entropy loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        item = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
if self.teacher_logits is not None:
item["teacher_logits"] = self.teacher_logits[idx]
return item
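
# The sketch below shows one way the teacher logits consumed by
# DistillationDataset might be produced offline. It is illustrative only:
# the name `export_teacher_logits` is an assumption, the teacher is assumed
# to be a Hugging Face-style model whose forward pass returns an object with
# a `.logits` attribute, and storing full-vocabulary logits for every example
# is memory-hungry (real pipelines often keep only top-k logits instead).
@torch.no_grad()
def export_teacher_logits(
    teacher: torch.nn.Module,
    dataset: Dataset,
    output_path: str,
    batch_size: int = 8,
    device: str = "cuda",
) -> None:
    teacher.eval().to(device)
    # shuffle=False keeps logits aligned with dataset indices.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    chunks = []
    for batch in loader:
        logits = teacher(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
        ).logits
        chunks.append(logits.float().cpu())
    # One (num_examples, seq_len, vocab_size) tensor, indexable by example idx.
    torch.save(torch.cat(chunks, dim=0), output_path)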
def create_dataloader(
dataset: Dataset,
batch_size: int = 8,
shuffle: bool = True,
num_workers: int = 4,
    pin_memory: bool = True,
    drop_last: bool = True,
) -> DataLoader:
    """Create a DataLoader with sensible defaults for GPU training."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        # Dropping the last incomplete batch keeps batch shapes uniform,
        # which helps training loops that assume a fixed batch size.
        drop_last=drop_last,
    )
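
# Minimal usage sketch, assuming a Hugging Face tokenizer and a JSONL corpus
# at data/train.jsonl (both placeholders, not files shipped with this repo):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token
#     dataset = TextDataset("data/train.jsonl", tokenizer, max_length=512)
#     loader = create_dataloader(dataset, batch_size=8)
#     batch = next(iter(loader))
#     print(batch["input_ids"].shape)  # torch.Size([8, 512])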