random-forest/forest.py
Vylion 0f0ae62cd5 🐞 Fixed Out-of-bag error calculation
📝 Uncommented the single tree training in Forest testing
📝 Silenced the single tree training output in Forest testing
2019-04-26 22:30:09 +02:00

60 lines
1.8 KiB
Python

import random
import operator
from tree_bootstrapped import Tree
class Forest(object):
    """Ensemble of ``Tree`` classifiers, each grown on its own bootstrap sample.

    Every tree receives the same ``fields`` and ``dataset`` plus a list of
    ``len(dataset)`` row indices drawn uniformly at random with replacement.
    Predictions are made by summing, per label, the values each tree reports
    in ``tree.classify(entry).predictions`` and returning the label with the
    largest total.

    Attributes:
        fields: Passed through unchanged to every ``Tree``.
        dataset: Indexable collection of entries shared by all trees.
        size: Number of trees requested.
        trees: The ``Tree`` objects, in planting order.
    """

    def __init__(self, fields, dataset, size, tree_out=False, out=True):
        """Plant ``size`` trees, each on a fresh bootstrap sample of ``dataset``.

        Args:
            fields: Forwarded to each ``Tree``.
            dataset: Sized, indexable collection of entries.
            size: Number of trees to build.
            tree_out: Ask each tree to print its own output; only takes effect
                when ``out`` is also true.
            out: Print a progress line after each tree is planted.
        """
        self.fields = fields
        self.dataset = dataset
        self.size = size
        self.trees = []

        # Both values are the same for every tree, so compute them once.
        n = len(dataset)
        verbose_trees = tree_out and out

        for i in range(size):
            # Bootstrap sample: n indices drawn with replacement from [0, n).
            bootstrap = [random.randrange(n) for _ in range(n)]
            tree = Tree(self.fields, self.dataset, bootstrap, verbose_trees)
            self.trees.append(tree)
            if out:
                print("\nPlanted tree {}".format(i))

    @staticmethod
    def _tally_votes(entry, trees):
        """Return ``{label: summed value}`` over the predictions of ``trees``."""
        votes = {}
        for tree in trees:
            predictions = tree.classify(entry).predictions
            for label, weight in predictions.items():
                if label in votes:
                    votes[label] += weight
                else:
                    votes[label] = weight
        return votes

    @staticmethod
    def _majority(votes):
        """Return the label with the largest total in ``votes``.

        Ties go to the label encountered first while iterating ``votes``.

        Raises:
            ValueError: If ``votes`` is empty (no tree produced a prediction).
        """
        if not votes:
            # max() would raise ValueError here anyway; keep the type but
            # say what actually went wrong.
            raise ValueError("no votes were cast: no tree produced a prediction")
        return max(votes.items(), key=operator.itemgetter(1))[0]

    def error_oob(self):
        """Return the out-of-bag error rate of the forest.

        The evaluated indices are the union of every tree's ``oob``
        collection. Each such entry is classified using only the trees whose
        ``indices`` do not contain it, and counts as a success when the
        winning label is ``in`` the entry's ``label``.

        Returns:
            ``1 - successes / evaluated`` as a float.

        Raises:
            ZeroDivisionError: If no tree reports any out-of-bag index.
            ValueError: If an evaluated entry receives no votes at all.
        """
        oob = set()
        for tree in self.trees:
            oob.update(tree.oob)

        if not oob:
            # The rate is 0/0 here. The exception type matches what the bare
            # division used to raise, so existing handlers still apply.
            raise ZeroDivisionError(
                "cannot compute out-of-bag error: no out-of-bag samples"
            )

        successes = 0
        for i in oob:
            entry = self.dataset[i]
            # Only trees that did not train on this entry may vote on it.
            unseen_by = [tree for tree in self.trees if i not in tree.indices]
            majority = self._majority(self._tally_votes(entry, unseen_by))
            # NOTE(review): `in` is a membership test (a substring test if
            # label is a str) - confirm this is intended rather than `==`.
            if majority in entry.label:
                successes += 1

        return 1 - (float(successes) / float(len(oob)))

    def predict(self, entry):
        """Return the label that wins the summed vote of every tree for ``entry``.

        Raises:
            ValueError: If the forest has no trees or no tree casts a vote.
        """
        return self._majority(self._tally_votes(entry, self.trees))