🐞 Fixed Out-of-bag error calculation
📝 Uncommented the single tree training in Forest testing 📝 Silenced the single tree training output in Forest testing
This commit is contained in:
parent
6f4987e179
commit
0f0ae62cd5
2 changed files with 11 additions and 9 deletions
|
@ -26,11 +26,11 @@ class Forest(object):
|
|||
|
||||
oob = set(oob)
|
||||
|
||||
votes = {}
|
||||
successes = 0
|
||||
|
||||
for i in oob:
|
||||
entry = self.dataset[i]
|
||||
votes = {}
|
||||
|
||||
for tree in self.trees:
|
||||
if i not in tree.indices:
|
||||
|
@ -40,6 +40,7 @@ class Forest(object):
|
|||
votes[key] = predict[key]
|
||||
else:
|
||||
votes[key] += predict[key]
|
||||
|
||||
majority = max(votes.items(), key=operator.itemgetter(1))[0]
|
||||
if majority in entry.label:
|
||||
successes += 1
|
||||
|
|
|
@ -2,7 +2,7 @@ import os
|
|||
import random
|
||||
from timeit import default_timer as timer
|
||||
from star_reader import read_stars
|
||||
# from tree_bootstrapped import Tree
|
||||
from tree_bootstrapped import Tree
|
||||
from forest import Forest
|
||||
|
||||
|
||||
|
@ -28,7 +28,7 @@ if __name__ == '__main__':
|
|||
|
||||
random.shuffle(dataset)
|
||||
|
||||
cutoff = 0.4
|
||||
cutoff = 0.25
|
||||
forest_size = 10
|
||||
|
||||
split = int(len(dataset) * cutoff)
|
||||
|
@ -36,7 +36,6 @@ if __name__ == '__main__':
|
|||
|
||||
log("\n----------\n", output)
|
||||
|
||||
"""
|
||||
log("\n-- TREE TRAINING --\n", output)
|
||||
|
||||
log("Training Tree...", output)
|
||||
|
@ -66,7 +65,6 @@ if __name__ == '__main__':
|
|||
log("\nTested {} entries.".format(tested), output)
|
||||
|
||||
log("Accuracy: {:.2f}%\nError: {:.2f}%".format(s_rate, 100-s_rate), output)
|
||||
"""
|
||||
|
||||
log("\n-- FOREST TRAINING --\n", output)
|
||||
|
||||
|
@ -79,6 +77,9 @@ if __name__ == '__main__':
|
|||
|
||||
forest = Forest(fields, training, forest_size)
|
||||
|
||||
t_end = timer()
|
||||
log("Training complete.\nElapsed time: {:.3f}\n".format(t_end - t_start), output)
|
||||
|
||||
log("\n-- FOREST TEST --\n", output)
|
||||
|
||||
total_success = 0
|
||||
|
@ -99,6 +100,6 @@ if __name__ == '__main__':
|
|||
|
||||
error = forest.error_oob()
|
||||
|
||||
log("\nAverage error Out-of-Bag: {:.2f}%".format(error*100), output)
|
||||
log("\nError Out-of-Bag: {:.2f}%".format(error*100), output)
|
||||
|
||||
output.close()
|
||||
|
|
Loading…
Reference in a new issue