🐞 Fixed out-of-bag error calculation (votes are now reset for each out-of-bag entry instead of accumulating across entries)

📝 Uncommented the single tree training in Forest testing
📝 Silenced the single tree training output in Forest testing
This commit is contained in:
Vylion 2019-04-26 22:30:09 +02:00
parent 6f4987e179
commit 0f0ae62cd5
2 changed files with 11 additions and 9 deletions

View file

@ -26,11 +26,11 @@ class Forest(object):
oob = set(oob) oob = set(oob)
votes = {}
successes = 0 successes = 0
for i in oob: for i in oob:
entry = self.dataset[i] entry = self.dataset[i]
votes = {}
for tree in self.trees: for tree in self.trees:
if i not in tree.indices: if i not in tree.indices:
@ -40,9 +40,10 @@ class Forest(object):
votes[key] = predict[key] votes[key] = predict[key]
else: else:
votes[key] += predict[key] votes[key] += predict[key]
majority = max(votes.items(), key=operator.itemgetter(1))[0]
if majority in entry.label: majority = max(votes.items(), key=operator.itemgetter(1))[0]
successes += 1 if majority in entry.label:
successes += 1
return 1-(float(successes)/float(len(oob))) return 1-(float(successes)/float(len(oob)))

View file

@ -2,7 +2,7 @@ import os
import random import random
from timeit import default_timer as timer from timeit import default_timer as timer
from star_reader import read_stars from star_reader import read_stars
# from tree_bootstrapped import Tree from tree_bootstrapped import Tree
from forest import Forest from forest import Forest
@ -28,7 +28,7 @@ if __name__ == '__main__':
random.shuffle(dataset) random.shuffle(dataset)
cutoff = 0.4 cutoff = 0.25
forest_size = 10 forest_size = 10
split = int(len(dataset) * cutoff) split = int(len(dataset) * cutoff)
@ -36,7 +36,6 @@ if __name__ == '__main__':
log("\n----------\n", output) log("\n----------\n", output)
"""
log("\n-- TREE TRAINING --\n", output) log("\n-- TREE TRAINING --\n", output)
log("Training Tree...", output) log("Training Tree...", output)
@ -66,7 +65,6 @@ if __name__ == '__main__':
log("\nTested {} entries.".format(tested), output) log("\nTested {} entries.".format(tested), output)
log("Accuracy: {:.2f}%\nError: {:.2f}%".format(s_rate, 100-s_rate), output) log("Accuracy: {:.2f}%\nError: {:.2f}%".format(s_rate, 100-s_rate), output)
"""
log("\n-- FOREST TRAINING --\n", output) log("\n-- FOREST TRAINING --\n", output)
@ -79,6 +77,9 @@ if __name__ == '__main__':
forest = Forest(fields, training, forest_size) forest = Forest(fields, training, forest_size)
t_end = timer()
log("Training complete.\nElapsed time: {:.3f}\n".format(t_end - t_start), output)
log("\n-- FOREST TEST --\n", output) log("\n-- FOREST TEST --\n", output)
total_success = 0 total_success = 0
@ -99,6 +100,6 @@ if __name__ == '__main__':
error = forest.error_oob() error = forest.error_oob()
log("\nAverage error Out-of-Bag: {:.2f}%".format(error*100), output) log("\nError Out-of-Bag: {:.2f}%".format(error*100), output)
output.close() output.close()