| __author__ = 'Taneem Jan, taneemishere.github.io' | |
| import sys | |
| import numpy as np | |
# Special tokens that Vocabulary.__init__ registers first, in this order.
START_TOKEN = "<START>"
END_TOKEN = "<END>"
PLACEHOLDER = " "

# Delimiter written between a token and its serialized one-hot vector
# in the "words.vocab" file.
SEPARATOR = "->"
class Vocabulary:
    """Token vocabulary with one-hot encodings and a plain-text file format.

    Attributes:
        vocabulary: maps each token to its integer index (assigned in
            insertion order, starting at 0).
        token_lookup: inverse mapping, index -> token.
        binary_vocabulary: maps each token to a one-hot numpy vector; filled
            by ``create_binary_representation`` or ``retrieve``.
        size: number of distinct tokens currently registered.

    Note:
        ``get_serialized_binary_representation`` only builds the one-hot
        table when it is empty, so tokens appended afterwards are not
        reflected until ``create_binary_representation`` is called again.
    """

    def __init__(self):
        """Create a vocabulary pre-seeded with the three special tokens."""
        self.binary_vocabulary = {}
        self.vocabulary = {}
        self.token_lookup = {}
        self.size = 0

        # The special tokens receive indices 0, 1 and 2, in this order.
        self.append(START_TOKEN)
        self.append(END_TOKEN)
        self.append(PLACEHOLDER)

    def append(self, token):
        """Register ``token`` under the next free index; known tokens are ignored."""
        if token not in self.vocabulary:
            self.vocabulary[token] = self.size
            self.token_lookup[self.size] = token
            self.size += 1

    def create_binary_representation(self):
        """(Re)build the one-hot vector of every registered token.

        Each vector has length ``self.size`` with a single 1 at the token's
        index. Existing entries in ``binary_vocabulary`` are overwritten.
        """
        # dict.items() exists on both Python 2 and 3, so no version switch
        # is needed here.
        for token, index in self.vocabulary.items():
            one_hot = np.zeros(self.size)
            one_hot[index] = 1
            self.binary_vocabulary[token] = one_hot

    @staticmethod
    def _format_vector(vector):
        """Return ``vector`` as comma-separated text on a single line."""
        text = np.array2string(
            vector,
            separator=',',
            # Never wrap: retrieve() parses the value from a single line.
            max_line_width=sys.maxsize,
            # Never abbreviate long vectors with "..." (numpy does so above
            # its print threshold, which would corrupt the saved file).
            threshold=vector.size,
        )
        # Drop the surrounding "[" and "]".
        return text[1:-1]

    def get_serialized_binary_representation(self):
        """Return the text form of the vocabulary, one entry per line.

        Every line is ``<token><SEPARATOR><comma-separated one-hot vector>``
        followed by a newline. The one-hot table is built first if it is
        still empty.
        """
        if not self.binary_vocabulary:
            self.create_binary_representation()

        entries = [
            "{}{}{}\n".format(token, SEPARATOR, self._format_vector(one_hot))
            for token, one_hot in self.binary_vocabulary.items()
        ]
        return "".join(entries)

    def save(self, path):
        """Write the serialized vocabulary to ``<path>/words.vocab``."""
        output_file_name = "{}/words.vocab".format(path)
        # Serialize before opening so a failure does not truncate an
        # existing file.
        serialized = self.get_serialized_binary_representation()
        with open(output_file_name, 'w') as output_file:
            output_file.write(serialized)

    def retrieve(self, path):
        """Load entries from ``<path>/words.vocab`` into this vocabulary.

        Lines without SEPARATOR are accumulated and become the leading part
        of the next token, which lets tokens that contain line breaks
        round-trip. The index of a token is the position of the first 1 in
        its vector. ``size`` is updated to the number of known tokens.

        Raises:
            IndexError: if an entry's vector contains no element equal to 1.
        """
        with open("{}/words.vocab".format(path), 'r') as input_file:
            pending = ""
            for line in input_file:
                separator_position = line.find(SEPARATOR)
                if separator_position < 0:
                    # Still inside a multi-line token; keep collecting.
                    pending += line
                    continue

                key = pending + line[:separator_position]
                encoded = line[separator_position + len(SEPARATOR):]
                pending = ""

                value = np.fromstring(encoded, sep=',')
                # Store a plain int rather than a numpy scalar.
                index = int(np.where(value == 1)[0][0])

                self.binary_vocabulary[key] = value
                self.vocabulary[key] = index
                self.token_lookup[index] = key

        self.size = len(self.vocabulary)