📻 Added parallelization to best question search
🚧 Created new tree to enable bootstrapping with indices (to avoid making a whole new bootstrapped database per tree)
This commit is contained in:
parent
3b6e5f642e
commit
9946ca10f9
6 changed files with 267 additions and 23 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,3 +1,4 @@
|
||||||
__pycache__/*
|
__pycache__/*
|
||||||
output/*
|
output/*
|
||||||
.vscode/*
|
.vscode/*
|
||||||
|
hygdata_v3.csv
|
3
star.py
3
star.py
|
@ -12,4 +12,5 @@ class Star(object):
|
||||||
classification = ' or '.join(self.label)
|
classification = ' or '.join(self.label)
|
||||||
else:
|
else:
|
||||||
classification = self.label
|
classification = self.label
|
||||||
return 'Star {} {} of spectral type {}'.format(self.name, self.data, classification)
|
return ("Star {} {} of spectral type {}"
|
||||||
|
.format(self.name, self.data, classification))
|
||||||
|
|
|
@ -15,6 +15,8 @@ def make_star(header, row, fields=None):
|
||||||
num = float(value)
|
num = float(value)
|
||||||
if num == int(num):
|
if num == int(num):
|
||||||
num = int(num)
|
num = int(num)
|
||||||
|
else:
|
||||||
|
num = round(num, 2)
|
||||||
value = num
|
value = num
|
||||||
except ValueError:
|
except ValueError:
|
||||||
if value == '':
|
if value == '':
|
||||||
|
@ -31,10 +33,13 @@ def make_star(header, row, fields=None):
|
||||||
|
|
||||||
type_list = value.split('/')
|
type_list = value.split('/')
|
||||||
types = []
|
types = []
|
||||||
for sp_type in type_list:
|
for star_type in type_list:
|
||||||
if sp_type and sp_type[0] in STAR_CLASSES:
|
for sp_type in STAR_CLASSES:
|
||||||
types.append(sp_type[0])
|
if star_type and sp_type in star_type.upper():
|
||||||
|
types.append(sp_type)
|
||||||
value = ''.join(set(types))
|
value = ''.join(set(types))
|
||||||
|
if value == '':
|
||||||
|
return None
|
||||||
|
|
||||||
data[field] = value
|
data[field] = value
|
||||||
|
|
||||||
|
|
51
tree.py
51
tree.py
|
@ -1,4 +1,5 @@
|
||||||
|
import random
|
||||||
|
import multiprocessing as mp
|
||||||
from question import Question
|
from question import Question
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,7 +47,17 @@ def info_gain(left_set, right_set, uncertainty):
|
||||||
return uncertainty - p * gini(left_set) - (1-p) * gini(right_set)
|
return uncertainty - p * gini(left_set) - (1-p) * gini(right_set)
|
||||||
|
|
||||||
|
|
||||||
|
def splitter(info):
    """Pool worker: score a single candidate question over the dataset.

    *info* packs (question, dataset, uncertainty).  Returns None when the
    question fails to divide the data, otherwise a
    (gain, question, (matching, non_matching)) tuple.
    """
    question, dataset, uncertainty = info
    matching, non_matching = partition(dataset, question)
    if matching and non_matching:
        gain = info_gain(matching, non_matching, uncertainty)
        return (gain, question, (matching, non_matching))
    return None
|
||||||
|
|
||||||
|
|
||||||
def find_best_split(fields, dataset, uncertainty=None):
|
def find_best_split(fields, dataset, uncertainty=None):
|
||||||
|
print("Splitting {} entries.".format(len(dataset)))
|
||||||
best_gain, best_question, best_split = 0, None, None
|
best_gain, best_question, best_split = 0, None, None
|
||||||
|
|
||||||
uncertainty = uncertainty or gini(dataset)
|
uncertainty = uncertainty or gini(dataset)
|
||||||
|
@ -55,6 +66,28 @@ def find_best_split(fields, dataset, uncertainty=None):
|
||||||
|
|
||||||
for i in range(columns):
|
for i in range(columns):
|
||||||
values = unique_vals(dataset, i)
|
values = unique_vals(dataset, i)
|
||||||
|
|
||||||
|
if len(dataset) > 400:
|
||||||
|
# Parallelize best split search
|
||||||
|
cpus = mp.cpu_count()
|
||||||
|
if i == 0:
|
||||||
|
print("-- Using {} CPUs to parallelize the split search."
|
||||||
|
.format(cpus))
|
||||||
|
splits = []
|
||||||
|
for value in values:
|
||||||
|
question = Question(fields, i, value)
|
||||||
|
splits.append((question, dataset, uncertainty))
|
||||||
|
|
||||||
|
chunk = max(int(len(splits)/(cpus*4)), 1)
|
||||||
|
with mp.Pool(cpus) as p:
|
||||||
|
for split in p.imap_unordered(splitter, splits,
|
||||||
|
chunksize=chunk):
|
||||||
|
if split is not None:
|
||||||
|
gain, question, branches = split
|
||||||
|
if gain > best_gain:
|
||||||
|
best_gain, best_question, best_split = \
|
||||||
|
gain, question, branches
|
||||||
|
else:
|
||||||
for value in values:
|
for value in values:
|
||||||
question = Question(fields, i, value)
|
question = Question(fields, i, value)
|
||||||
|
|
||||||
|
@ -110,6 +143,13 @@ class Node(object):
|
||||||
else:
|
else:
|
||||||
return self.right_branch.classify(entry)
|
return self.right_branch.classify(entry)
|
||||||
|
|
||||||
|
def predict(self, entry):
|
||||||
|
predict = self.classify(entry).predictions
|
||||||
|
choices = []
|
||||||
|
for label, count in predict.items():
|
||||||
|
choices.extend([label]*count)
|
||||||
|
return random.choice(choices)
|
||||||
|
|
||||||
def print(self, spacing=''):
|
def print(self, spacing=''):
|
||||||
if self.is_leaf:
|
if self.is_leaf:
|
||||||
s = spacing + "Predict: "
|
s = spacing + "Predict: "
|
||||||
|
@ -121,9 +161,9 @@ class Node(object):
|
||||||
return s + str(probs)
|
return s + str(probs)
|
||||||
|
|
||||||
s = spacing + str(self.question) + '\n'
|
s = spacing + str(self.question) + '\n'
|
||||||
s += spacing + "-> True:\n"
|
s += spacing + "├─ True:\n"
|
||||||
s += self.left_branch.print(spacing + " ") + '\n'
|
s += self.left_branch.print(spacing + "│ ") + '\n'
|
||||||
s += spacing + "-> False:\n"
|
s += spacing + "└─ False:\n"
|
||||||
s += self.right_branch.print(spacing + " ")
|
s += self.right_branch.print(spacing + " ")
|
||||||
|
|
||||||
return s
|
return s
|
||||||
|
@ -141,5 +181,8 @@ class Tree(object):
|
||||||
def classify(self, entry):
|
def classify(self, entry):
|
||||||
return self.root.classify(entry)
|
return self.root.classify(entry)
|
||||||
|
|
||||||
|
def predict(self, entry):
|
||||||
|
return self.root.predict(entry)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return str(self.root)
|
return str(self.root)
|
||||||
|
|
182
tree_bootstrapped.py
Normal file
182
tree_bootstrapped.py
Normal file
|
@ -0,0 +1,182 @@
|
||||||
|
import multiprocessing as mp
|
||||||
|
from question import Question
|
||||||
|
|
||||||
|
|
||||||
|
def unique_vals(dataset, indices, column):
    """Collect the distinct values of *column* among the bootstrap-selected rows."""
    return {dataset[i].data[column] for i in indices}
|
||||||
|
|
||||||
|
|
||||||
|
def count_labels(dataset, indices):
    """Tally how often each label occurs among the rows selected by *indices*.

    A row may carry several labels; each one is counted individually.
    Returns a {label: count} dict.
    """
    counts = {}
    for idx in indices:
        for lab in dataset[idx].label:
            counts[lab] = counts.get(lab, 0) + 1
    return counts
|
||||||
|
|
||||||
|
|
||||||
|
def partition(dataset, indices, question):
    """Split the index list by whether each referenced row matches *question*.

    Returns (matching, non_matching) — two lists of indices into *dataset*;
    the rows themselves are never copied.
    """
    matching, non_matching = [], []
    for idx in indices:
        bucket = matching if question.match(dataset[idx]) else non_matching
        bucket.append(idx)
    return matching, non_matching
|
||||||
|
|
||||||
|
|
||||||
|
def gini(dataset, indices):
    """Gini impurity of the bootstrap subset selected by *indices*.

    Bug fix: label probabilities are computed relative to the subset size
    (``len(indices)``), not the full dataset size — the old divisor
    deflated every probability and overstated impurity for any proper
    subset (a pure subset no longer scored as impure).

    An empty subset yields impurity 1, matching the original behavior.
    """
    from collections import Counter  # local: module has no top-level collections import

    counts = Counter(label for i in indices
                     for label in dataset[i].label)
    subset_size = float(len(indices))

    impurity = 1
    for count in counts.values():
        impurity -= (count / subset_size) ** 2

    return impurity
|
||||||
|
|
||||||
|
|
||||||
|
def info_gain(dataset, lid, rid, uncertainty):
    """Information gain of splitting the parent subset into *lid* / *rid*.

    *lid* and *rid* are index lists; *uncertainty* is the parent's Gini
    impurity.  Gain is the parent impurity minus the size-weighted child
    impurities.
    """
    left_weight = float(len(lid)) / float(len(lid) + len(rid))
    right_weight = 1 - left_weight
    return uncertainty - left_weight * gini(dataset, lid) \
        - right_weight * gini(dataset, rid)
|
||||||
|
|
||||||
|
|
||||||
|
def splitter(info):
    """Pool worker: evaluate one candidate question on the bootstrap subset.

    *info* packs (question, dataset, indices, uncertainty).  Returns None
    when the question does not split the subset, otherwise a
    (gain, question, (matching, non_matching)) tuple of index lists.
    """
    question, dataset, indices, uncertainty = info
    matching, non_matching = partition(dataset, indices, question)
    if matching and non_matching:
        gain = info_gain(dataset, matching, non_matching, uncertainty)
        return (gain, question, (matching, non_matching))
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def find_best_split(fields, dataset, indices, uncertainty=None):
    """Search every (column, value) question for the highest-gain split
    of the bootstrap subset selected by *indices*.

    Serially scans small subsets; for subsets larger than 400 rows the
    candidate questions are evaluated in parallel across all CPUs.

    Returns (best_gain, best_question, best_split) where best_split is a
    (matching_indices, non_matching_indices) pair, or (0, None, None)
    when no question yields positive gain.
    """
    # Bug fix: report the size of the subset actually being split,
    # not the size of the whole shared dataset.
    print("Splitting {} entries.".format(len(indices)))
    best_gain, best_question, best_split = 0, None, None

    # Bug fix: the fallback must pass the bootstrap indices — the old
    # `gini(dataset)` call raised TypeError (missing argument).
    uncertainty = uncertainty or gini(dataset, indices)

    columns = len(fields)

    for i in range(columns):
        values = unique_vals(dataset, indices, i)

        if len(indices) > 400:
            # Parallelize best split search
            cpus = mp.cpu_count()
            if i == 0:
                print("-- Using {} CPUs to parallelize the split search."
                      .format(cpus))
            splits = []
            for value in values:
                question = Question(fields, i, value)
                splits.append((question, dataset, indices, uncertainty))

            # A few chunks per worker keeps the pool busy without
            # excessive IPC overhead.
            chunk = max(int(len(splits) / (cpus * 4)), 1)
            with mp.Pool(cpus) as p:
                for split in p.imap_unordered(splitter, splits,
                                              chunksize=chunk):
                    if split is not None:
                        gain, question, branches = split
                        if gain > best_gain:
                            best_gain, best_question, best_split = \
                                gain, question, branches
        else:
            for value in values:
                question = Question(fields, i, value)

                matching, non_matching = partition(dataset, indices, question)

                if not matching or not non_matching:
                    continue

                gain = info_gain(dataset, matching, non_matching, uncertainty)

                if gain > best_gain:
                    best_gain, best_question = gain, question
                    best_split = (matching, non_matching)

    return best_gain, best_question, best_split
|
||||||
|
|
||||||
|
|
||||||
|
class Node(object):
    """A decision-tree node built over bootstrap indices into a shared dataset.

    Instead of copying rows, every node keeps the full *dataset* plus the
    list of row indices (*bootstrap*) it was built from, so many trees can
    share a single dataset without per-tree copies.
    """

    def __init__(self, fields, dataset, bootstrap, level=0):
        self.fields = fields
        self.dataset = dataset
        self.bootstrap = bootstrap
        # Impurity of this node's subset; reused as the parent uncertainty
        # when searching for the best split.
        self.gini = gini(dataset, self.bootstrap)
        self.build(level)

    def build(self, level):
        """Recursively grow the subtree; become a leaf when no split gains."""
        best_split = find_best_split(self.fields, self.dataset,
                                     self.bootstrap, self.gini)
        gain, question, branches = best_split

        if not branches:
            # Means we got 0 gain
            print("Found a leaf at level {}".format(level))
            self.predictions = count_labels(self.dataset, self.bootstrap)
            self.is_leaf = True
            return

        left, right = branches

        print("Found a level {} split:".format(level))
        print(question)
        print("Matching: {} entries\tNon-matching: {} entries".format(len(left), len(right)))  # noqa

        self.left_branch = Node(self.fields, self.dataset, left, level + 1)
        self.right_branch = Node(self.fields, self.dataset, right, level + 1)
        self.question = question
        self.is_leaf = False
        return

    def classify(self, entry):
        """Route *entry* down the tree; return the leaf node it lands on."""
        if self.is_leaf:
            return self

        if self.question.match(entry):
            return self.left_branch.classify(entry)
        else:
            return self.right_branch.classify(entry)

    def predict(self, entry):
        """Sample one label in proportion to the reached leaf's label counts.

        Added for consistency with tree.Node.predict (introduced in the
        same commit) so both tree implementations expose the same API.
        """
        import random  # local: this module has no top-level random import
        counts = self.classify(entry).predictions
        choices = []
        for label, count in counts.items():
            choices.extend([label] * count)
        return random.choice(choices)

    def print(self, spacing=''):
        """Render the subtree as an indented box-drawing outline."""
        if self.is_leaf:
            s = spacing + "Predict: "
            total = float(sum(self.predictions.values()))
            probs = {}
            for label in self.predictions:
                prob = self.predictions[label] * 100 / total
                probs[label] = "{:.2f}%".format(prob)
            return s + str(probs)

        s = spacing + ("(Gini: {:.2f}) {}\n"
                       .format(self.gini, str(self.question)))
        s += spacing + "├─ True:\n"
        s += self.left_branch.print(spacing + "│  ") + '\n'
        s += spacing + "└─ False:\n"
        # Consistency fix: the False branch is the last child, so no
        # continuation bar is drawn under it (matches tree.Node.print).
        s += self.right_branch.print(spacing + "   ")

        return s

    def __str__(self):
        return self.print()
|
||||||
|
|
||||||
|
|
||||||
|
class Tree(object):
    """Decision tree owning a root Node built from bootstrap indices.

    *bootstrap* is a list of indices into the shared *dataset*, which
    avoids materializing a separate bootstrapped copy per tree.
    """

    def __init__(self, fields, dataset, bootstrap):
        self.fields = fields
        self.dataset = dataset
        self.bootstrap = bootstrap
        # Building the root recursively builds the whole tree.
        self.root = Node(self.fields, self.dataset, self.bootstrap)

    def classify(self, entry):
        """Delegate classification of *entry* to the root node."""
        return self.root.classify(entry)

    def __str__(self):
        return str(self.root)
|
|
@ -14,9 +14,9 @@ if __name__ == '__main__':
|
||||||
os.mkdir("output")
|
os.mkdir("output")
|
||||||
|
|
||||||
if not os.path.exists("output/tree_testing.txt"):
|
if not os.path.exists("output/tree_testing.txt"):
|
||||||
output = open("output/tree_testing.txt", 'w')
|
output = open("output/tree_testing.txt", 'w', encoding="utf-8")
|
||||||
else:
|
else:
|
||||||
output = open("output/tree_testing.txt", 'a')
|
output = open("output/tree_testing.txt", 'a', encoding="utf-8")
|
||||||
|
|
||||||
dataset, fields = read_stars()
|
dataset, fields = read_stars()
|
||||||
|
|
||||||
|
@ -26,7 +26,6 @@ if __name__ == '__main__':
|
||||||
t_start = timer()
|
t_start = timer()
|
||||||
|
|
||||||
split = int(len(dataset) * 0.65)
|
split = int(len(dataset) * 0.65)
|
||||||
split = 500
|
|
||||||
training, testing = dataset[:split], dataset[split + 1:]
|
training, testing = dataset[:split], dataset[split + 1:]
|
||||||
log("Training set: {} entries.".format(len(training)), output)
|
log("Training set: {} entries.".format(len(training)), output)
|
||||||
log("Testing set: {} entries.".format(len(testing)), output)
|
log("Testing set: {} entries.".format(len(testing)), output)
|
||||||
|
@ -41,9 +40,22 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
log("\n-- TEST --\n", output)
|
log("\n-- TEST --\n", output)
|
||||||
|
|
||||||
|
failures = 0
|
||||||
|
|
||||||
for entry in testing:
|
for entry in testing:
|
||||||
label = entry.label
|
label = entry.label
|
||||||
predict = tree.classify(entry)
|
predict = tree.predict(entry)
|
||||||
log("Actual: {}\tPredicted: {}".format(label, predict), output)
|
if predict not in label:
|
||||||
|
print("Actual: {}\tPredicted: {}".format(label, predict))
|
||||||
|
failures += 1
|
||||||
|
|
||||||
|
tested = len(testing)
|
||||||
|
success = tested - failures
|
||||||
|
s_rate = float(success)*100/float(tested)
|
||||||
|
|
||||||
|
log("\nSuccessfully predicted {} out of {} entries."
|
||||||
|
.format(success, tested), output)
|
||||||
|
|
||||||
|
log("Accuracy: {:.2f}%\nError: {:.2f}%".format(s_rate, 100-s_rate), output)
|
||||||
|
|
||||||
output.close()
|
output.close()
|
||||||
|
|
Loading…
Reference in a new issue