🐞 Fixed Out-of-bag error calculation
📝 Uncommented the single tree training in Forest testing 📝 Silenced the single tree training output in Forest testing
This commit is contained in:
parent
6f4987e179
commit
0f0ae62cd5
2 changed files with 11 additions and 9 deletions
|
@ -26,11 +26,11 @@ class Forest(object):
|
||||||
|
|
||||||
oob = set(oob)
|
oob = set(oob)
|
||||||
|
|
||||||
votes = {}
|
|
||||||
successes = 0
|
successes = 0
|
||||||
|
|
||||||
for i in oob:
|
for i in oob:
|
||||||
entry = self.dataset[i]
|
entry = self.dataset[i]
|
||||||
|
votes = {}
|
||||||
|
|
||||||
for tree in self.trees:
|
for tree in self.trees:
|
||||||
if i not in tree.indices:
|
if i not in tree.indices:
|
||||||
|
@ -40,9 +40,10 @@ class Forest(object):
|
||||||
votes[key] = predict[key]
|
votes[key] = predict[key]
|
||||||
else:
|
else:
|
||||||
votes[key] += predict[key]
|
votes[key] += predict[key]
|
||||||
majority = max(votes.items(), key=operator.itemgetter(1))[0]
|
|
||||||
if majority in entry.label:
|
majority = max(votes.items(), key=operator.itemgetter(1))[0]
|
||||||
successes += 1
|
if majority in entry.label:
|
||||||
|
successes += 1
|
||||||
|
|
||||||
return 1-(float(successes)/float(len(oob)))
|
return 1-(float(successes)/float(len(oob)))
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ import os
|
||||||
import random
|
import random
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
from star_reader import read_stars
|
from star_reader import read_stars
|
||||||
# from tree_bootstrapped import Tree
|
from tree_bootstrapped import Tree
|
||||||
from forest import Forest
|
from forest import Forest
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
random.shuffle(dataset)
|
random.shuffle(dataset)
|
||||||
|
|
||||||
cutoff = 0.4
|
cutoff = 0.25
|
||||||
forest_size = 10
|
forest_size = 10
|
||||||
|
|
||||||
split = int(len(dataset) * cutoff)
|
split = int(len(dataset) * cutoff)
|
||||||
|
@ -36,7 +36,6 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
log("\n----------\n", output)
|
log("\n----------\n", output)
|
||||||
|
|
||||||
"""
|
|
||||||
log("\n-- TREE TRAINING --\n", output)
|
log("\n-- TREE TRAINING --\n", output)
|
||||||
|
|
||||||
log("Training Tree...", output)
|
log("Training Tree...", output)
|
||||||
|
@ -66,7 +65,6 @@ if __name__ == '__main__':
|
||||||
log("\nTested {} entries.".format(tested), output)
|
log("\nTested {} entries.".format(tested), output)
|
||||||
|
|
||||||
log("Accuracy: {:.2f}%\nError: {:.2f}%".format(s_rate, 100-s_rate), output)
|
log("Accuracy: {:.2f}%\nError: {:.2f}%".format(s_rate, 100-s_rate), output)
|
||||||
"""
|
|
||||||
|
|
||||||
log("\n-- FOREST TRAINING --\n", output)
|
log("\n-- FOREST TRAINING --\n", output)
|
||||||
|
|
||||||
|
@ -79,6 +77,9 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
forest = Forest(fields, training, forest_size)
|
forest = Forest(fields, training, forest_size)
|
||||||
|
|
||||||
|
t_end = timer()
|
||||||
|
log("Training complete.\nElapsed time: {:.3f}\n".format(t_end - t_start), output)
|
||||||
|
|
||||||
log("\n-- FOREST TEST --\n", output)
|
log("\n-- FOREST TEST --\n", output)
|
||||||
|
|
||||||
total_success = 0
|
total_success = 0
|
||||||
|
@ -99,6 +100,6 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
error = forest.error_oob()
|
error = forest.error_oob()
|
||||||
|
|
||||||
log("\nAverage error Out-of-Bag: {:.2f}%".format(error*100), output)
|
log("\nError Out-of-Bag: {:.2f}%".format(error*100), output)
|
||||||
|
|
||||||
output.close()
|
output.close()
|
||||||
|
|
Loading…
Reference in a new issue