204 lines
6.3 KiB
Python
204 lines
6.3 KiB
Python
import multiprocessing as mp
|
|
from question import Question
|
|
|
|
|
|
def unique_vals(dataset, indices, column):
    """Collect the distinct values one column takes over the selected rows.

    Args:
        dataset: Indexable collection whose items expose a ``data`` sequence.
        indices: Iterable of positions into ``dataset`` to inspect.
        column: Position inside each item's ``data`` to read.

    Returns:
        A set holding every distinct ``data[column]`` value among the rows.
    """
    return {dataset[row].data[column] for row in indices}
|
|
|
|
|
|
def count_labels(dataset, indices):
    """Tally how often each label occurs among the selected rows.

    Every selected item's ``label`` attribute is iterated, so an item
    contributes one count per element it yields. A position listed several
    times in ``indices`` is counted that many times.

    Args:
        dataset: Indexable collection whose items expose an iterable ``label``.
        indices: Iterable of positions into ``dataset`` to count.

    Returns:
        A dict mapping each label to its number of occurrences.
    """
    tally = {}
    for row in indices:
        for label in dataset[row].label:
            tally[label] = tally.get(label, 0) + 1
    return tally
|
|
|
|
|
|
def partition(dataset, indices, question):
    """Split row positions by whether their item satisfies ``question``.

    Args:
        dataset: Indexable collection of items.
        indices: Iterable of positions into ``dataset`` to distribute.
        question: Object whose ``match(item)`` result decides the side.

    Returns:
        A ``(matching, non_matching)`` pair of lists of positions, each
        keeping the order in which the positions appeared in ``indices``.
    """
    hits, misses = [], []
    for row in indices:
        bucket = hits if question.match(dataset[row]) else misses
        bucket.append(row)
    return hits, misses
|
|
|
|
|
|
def gini(dataset, indices):
    """Compute the impurity of the selected rows from their label counts.

    Starting from 1, the square of ``count / len(indices)`` is subtracted
    for every label reported by ``count_labels``.

    Args:
        dataset: Indexable collection of labelled items.
        indices: Sized collection of positions into ``dataset``.

    Returns:
        The resulting impurity; the integer ``1`` when no label was counted.
    """
    tally = count_labels(dataset, indices)
    size = float(len(indices))

    remaining = 1
    for occurrences in tally.values():
        share = occurrences / size
        remaining -= share ** 2

    return remaining
|
|
|
|
|
|
def info_gain(dataset, lid, rid, uncertainty):
    """Return how much a split lowers the impurity of its parent.

    The impurity of each side is weighted by that side's share of the rows
    and both weighted impurities are subtracted from ``uncertainty``.

    Args:
        dataset: Indexable collection of labelled items.
        lid: Positions that went to the left side of the split.
        rid: Positions that went to the right side of the split.
        uncertainty: Impurity of the rows before splitting.

    Returns:
        ``uncertainty`` minus the weighted impurities of both sides.
    """
    total = len(lid) + len(rid)
    left_weight = float(len(lid)) / float(total)

    left_term = left_weight * gini(dataset, lid)
    right_term = (1 - left_weight) * gini(dataset, rid)
    return uncertainty - left_term - right_term
|
|
|
|
|
|
def splitter(info):
    """Evaluate one candidate question; usable as a worker-pool task.

    Args:
        info: A ``(question, dataset, indices, uncertainty)`` tuple.

    Returns:
        ``(gain, question, (matching, non_matching))`` when the question
        sends at least one row to each side, otherwise ``None``.
    """
    question, dataset, indices, uncertainty = info
    matching, non_matching = partition(dataset, indices, question)

    if matching and non_matching:
        gain = info_gain(dataset, matching, non_matching, uncertainty)
        return (gain, question, (matching, non_matching))

    # One side is empty, so the question does not actually divide the rows.
    return None
|
|
|
|
|
|
class Node(object):
    """A decision-tree node that grows its own subtree on construction.

    After construction a node is either a leaf (``is_leaf`` is True and
    ``predictions`` holds the label counts of its rows) or an inner node
    (``is_leaf`` is False and ``question``, ``left_branch`` and
    ``right_branch`` are set).
    """

    def __init__(self, fields, dataset, bootstrap, level=0, out=True):
        """Store the node's rows and immediately build the subtree below it.

        Args:
            fields: Column descriptors; their count is the number of columns
                searched, and they are handed to every ``Question``.
            dataset: Indexable collection of items shared by all nodes.
            bootstrap: Positions into ``dataset`` that belong to this node.
            level: Depth of this node, used only in progress messages.
            out: When true, progress messages are printed.
        """
        self.fields = fields
        self.dataset = dataset
        self.indices = bootstrap
        self.out = out
        self.gini = gini(dataset, self.indices)
        self.build(level, out)

    def build(self, level, out=True):
        """Turn this node into a leaf or split it and recurse into children.

        Args:
            level: Depth of this node, used in progress messages and passed
                on (plus one) to the children.
            out: When true, progress messages are printed.
        """
        _, question, branches = self.split(out)

        if not branches:
            # No candidate question produced a positive gain.
            if out:
                print("Found a leaf at level {}".format(level))
            self.predictions = count_labels(self.dataset, self.indices)
            self.is_leaf = True
            return

        left, right = branches

        if out:
            print("Found a level {} split:".format(level))
            print(question)
            print("Matching: {} entries\tNon-matching: {} entries".format(len(left), len(right)))

        self.left_branch = Node(self.fields, self.dataset, left, level + 1, out)
        self.right_branch = Node(self.fields, self.dataset, right, level + 1, out)
        self.question = question
        self.is_leaf = False
        return

    def split(self, out=True):
        """Search every column/value pair for the question with the best gain.

        Candidates are evaluated through ``splitter``. When the node holds
        more than 1000 rows the evaluation is spread over a worker pool that
        is created once for the whole search. Results are consumed in
        submission order, so ties between equally good questions are
        resolved the same way with and without the pool.

        Args:
            out: When true, progress messages are printed.

        Returns:
            ``(best_gain, best_question, best_split)`` where ``best_split``
            is a ``(matching, non_matching)`` pair of position lists, or
            ``(0, None, None)`` when no question has a gain above zero.
        """
        if out:
            print("Splitting {} entries.".format(len(self.indices)))

        # Falls back to recomputing when the stored impurity is falsy.
        uncertainty = self.gini or gini(self.dataset, self.indices)

        cpus = mp.cpu_count()
        parallelize = len(self.indices) > 1000

        if parallelize and out:
            print("\n-- Using {} CPUs to parallelize the split search\n".format(cpus))

        best_gain, best_question, best_split = 0, None, None

        # One pool serves all columns instead of being rebuilt per column.
        pool = mp.Pool(cpus) if parallelize else None
        try:
            for column in range(len(self.fields)):
                values = unique_vals(self.dataset, self.indices, column)
                tasks = [
                    (Question(self.fields, column, value),
                     self.dataset, self.indices, uncertainty)
                    for value in values
                ]

                if pool is None:
                    results = map(splitter, tasks)
                else:
                    chunk = max(len(tasks) // (cpus * 4), 1)
                    # Ordered imap keeps tie-breaking independent of which
                    # worker happens to finish first.
                    results = pool.imap(splitter, tasks, chunksize=chunk)

                for result in results:
                    if result is None:
                        continue
                    gain, question, branches = result
                    if gain > best_gain:
                        best_gain, best_question, best_split = gain, question, branches
        finally:
            if pool is not None:
                # All results have been consumed (or the search failed), so
                # the workers can be stopped right away.
                pool.terminate()

        return best_gain, best_question, best_split

    def classify(self, entry):
        """Return the leaf node that ``entry`` ends up in.

        Args:
            entry: Item passed to each inner node's ``question.match``.

        Returns:
            The leaf ``Node`` reached by following matching questions to the
            left branch and non-matching ones to the right branch.
        """
        node = self
        while not node.is_leaf:
            if node.question.match(entry):
                node = node.left_branch
            else:
                node = node.right_branch
        return node

    def predict(self, entry):
        """Score the entry's own labels against the leaf it falls into.

        Args:
            entry: Item with an iterable ``label`` attribute that can be
                routed through the tree.

        Returns:
            A ``(score, distribution)`` pair: ``distribution`` maps each label
            of the reached leaf to its share of that leaf's label counts, and
            ``score`` is the sum of the shares of the entry's labels that
            appear in the distribution.
        """
        counts = self.classify(entry).predictions
        total = float(sum(counts.values()))
        distribution = {label: float(count) / total for label, count in counts.items()}

        score = sum(distribution[label] for label in entry.label if label in distribution)
        return score, distribution

    def print(self, spacing=''):
        """Render the subtree rooted at this node as indented text.

        Args:
            spacing: Prefix prepended to every line this node emits.

        Returns:
            For a leaf, one line with the percentage of each label; for an
            inner node, its impurity and question followed by the rendering
            of the true and false branches.
        """
        if self.is_leaf:
            total = float(sum(self.predictions.values()))
            probs = {
                label: "{:.2f}%".format(count * 100 / total)
                for label, count in self.predictions.items()
            }
            return spacing + "Predict: " + str(probs)

        pieces = [
            spacing + "(Gini: {:.2f}) {}\n".format(self.gini, str(self.question)),
            spacing + "├─ True:\n",
            self.left_branch.print(spacing + "│ ") + '\n',
            spacing + "└─ False:\n",
            self.right_branch.print(spacing + " "),
        ]
        return "".join(pieces)

    def __str__(self):
        """Return the text rendering produced by ``print``."""
        return self.print()
|
|
|
|
|
|
class Tree(object):
    """A decision tree built from a bootstrap sample of a dataset.

    Attributes are set in ``__init__``: ``fields``, ``dataset`` and
    ``indices`` keep the constructor arguments, ``oob`` lists the dataset
    positions absent from the bootstrap sample, and ``root`` is the top
    ``Node`` of the tree.
    """

    def __init__(self, fields, dataset, bootstrap, out=True):
        """Record the sample, derive the out-of-bag rows and grow the tree.

        Args:
            fields: Column descriptors forwarded to the root ``Node``.
            dataset: Sized, indexable collection of items.
            bootstrap: Positions into ``dataset`` used to grow the tree.
            out: When true, the nodes print progress messages.
        """
        self.fields = fields
        self.dataset = dataset
        self.indices = bootstrap
        # Out of bag
        self.oob = self._out_of_bag(len(dataset), bootstrap)

        self.root = Node(self.fields, self.dataset, self.indices, out=out)

    @staticmethod
    def _out_of_bag(size, bootstrap):
        """Return the positions in ``range(size)`` missing from ``bootstrap``."""
        # A set makes each membership test constant-time instead of a scan
        # over the whole bootstrap sample.
        sampled = set(bootstrap)
        return [i for i in range(size) if i not in sampled]

    def classify(self, entry):
        """Return whatever the root node's ``classify`` yields for ``entry``."""
        return self.root.classify(entry)

    def predict(self, entry):
        """Return whatever the root node's ``predict`` yields for ``entry``."""
        return self.root.predict(entry)

    def __str__(self):
        """Return the string form of the root node."""
        return str(self.root)
|