# tiny-math-llm / src / tokenizer.py
# Uploaded by anujbhatt4ai
# Initial upload of TinyLLM
# Commit 13c35e3 (verified)
import random
def generate_v1_data(seed=None):
    """Generate every single-digit arithmetic problem with a single-digit answer.

    Enumerates all operand pairs ``a, b`` in 0-9 for ``+``, ``-``, ``*`` and
    ``/`` and keeps a problem only when its answer is a whole number in 0-9.
    Division by zero and divisions that leave a remainder are skipped.

    Args:
        seed: Optional seed for the shuffle. When given, a private
            ``random.Random(seed)`` performs the shuffle, so the returned
            order is reproducible and the module-level RNG state is left
            untouched. When ``None`` (the default) the module-level
            ``random.shuffle`` is used.

    Returns:
        A shuffled list of strings such as ``"3 + 4 = 7<EOS>"``; every entry
        ends with the ``<EOS>`` marker.
    """
    problems = []
    for a in range(10):
        for b in range(10):
            # Candidate answers in a fixed operator order (+, -, *, /) so the
            # pre-shuffle order is deterministic. None marks "no valid answer".
            candidates = (
                ('+', a + b),
                ('-', a - b),
                ('*', a * b),
                # Integer quotient only when the division is defined and
                # exact; this avoids float division followed by int().
                ('/', a // b if b != 0 and a % b == 0 else None),
            )
            for symbol, answer in candidates:
                # Keep only problems whose answer is a single digit (0-9).
                if answer is not None and 0 <= answer <= 9:
                    problems.append(f"{a} {symbol} {b} = {answer}<EOS>")

    # Shuffling does not depend on the list contents, so appending <EOS>
    # before the shuffle yields the same final order as appending it after.
    if seed is None:
        random.shuffle(problems)
    else:
        random.Random(seed).shuffle(problems)
    return problems
class CharacterTokenizer:
    """Character-level tokenizer whose vocabulary is derived from a corpus.

    Attributes set by the constructor:
        stoi: dict mapping each vocabulary symbol to its integer id.
        itos: dict mapping each integer id back to its symbol.
        vocab_size: number of symbols, including the padding token.
        pad_token_id: integer id assigned to the ``'<PAD>'`` symbol.
    """

    def __init__(self, data):
        """Build the vocabulary from an iterable of strings.

        Every distinct character found in ``data`` becomes a symbol, ordered
        by ``sorted``; a ``'<PAD>'`` symbol is then appended after them.
        """
        vocabulary = sorted(set("".join(data)))

        # Reserve a padding symbol for batching; it goes after all characters.
        if '<PAD>' not in vocabulary:
            vocabulary.append('<PAD>')

        # Fill both lookup directions in a single pass over the vocabulary.
        self.stoi = {}
        self.itos = {}
        for index, symbol in enumerate(vocabulary):
            self.stoi[symbol] = index
            self.itos[index] = symbol

        self.vocab_size = len(vocabulary)
        self.pad_token_id = self.stoi['<PAD>']

    def encode(self, s):
        """Return the list of integer ids for the characters of ``s``.

        Raises ``KeyError`` for a character that is not in the vocabulary.
        """
        lookup = self.stoi
        return [lookup[char] for char in s]

    def decode(self, l):
        """Return the string formed by concatenating the symbols for ids ``l``.

        Raises ``KeyError`` for an id that is not in the vocabulary.
        """
        symbols = [self.itos[token_id] for token_id in l]
        return "".join(symbols)