From 0f0ae62cd53f2ce349caf4d9827f00127f57103f Mon Sep 17 00:00:00 2001 From: Vylion Date: Fri, 26 Apr 2019 22:30:09 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9E=20Fixed=20Out-of-bag=20error=20cal?= =?UTF-8?q?culation=20=F0=9F=93=9D=20Uncommented=20the=20single=20tree=20t?= =?UTF-8?q?raining=20in=20Forest=20testing=20=F0=9F=93=9D=20Silenced=20the?= =?UTF-8?q?=20single=20tree=20training=20output=20in=20Forest=20testing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- forest.py | 9 +++++---- forest_tester.py | 11 ++++++----- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/forest.py b/forest.py index 5cab1b0..f99b525 100644 --- a/forest.py +++ b/forest.py @@ -26,11 +26,11 @@ class Forest(object): oob = set(oob) - votes = {} successes = 0 for i in oob: entry = self.dataset[i] + votes = {} for tree in self.trees: if i not in tree.indices: @@ -40,9 +40,10 @@ class Forest(object): votes[key] = predict[key] else: votes[key] += predict[key] - majority = max(votes.items(), key=operator.itemgetter(1))[0] - if majority in entry.label: - successes += 1 + + majority = max(votes.items(), key=operator.itemgetter(1))[0] + if majority in entry.label: + successes += 1 return 1-(float(successes)/float(len(oob))) diff --git a/forest_tester.py b/forest_tester.py index df2e530..43a1caf 100644 --- a/forest_tester.py +++ b/forest_tester.py @@ -2,7 +2,7 @@ import os import random from timeit import default_timer as timer from star_reader import read_stars -# from tree_bootstrapped import Tree +from tree_bootstrapped import Tree from forest import Forest @@ -28,7 +28,7 @@ if __name__ == '__main__': random.shuffle(dataset) - cutoff = 0.4 + cutoff = 0.25 forest_size = 10 split = int(len(dataset) * cutoff) @@ -36,7 +36,6 @@ if __name__ == '__main__': log("\n----------\n", output) - """ log("\n-- TREE TRAINING --\n", output) log("Training Tree...", output) @@ -66,7 +65,6 @@ if __name__ == '__main__': log("\nTested {} entries.".format(tested), output) log("Accuracy: {:.2f}%\nError: {:.2f}%".format(s_rate, 100-s_rate), output) - """ log("\n-- FOREST TRAINING --\n", output) @@ -79,6 +77,9 @@ if __name__ == '__main__': forest = Forest(fields, training, forest_size) + t_end = timer() + log("Training complete.\nElapsed time: {:.3f}\n".format(t_end - t_start), output) + log("\n-- FOREST TEST --\n", output) total_success = 0 @@ -99,6 +100,6 @@ if __name__ == '__main__': error = forest.error_oob() - log("\nAverage error Out-of-Bag: {:.2f}%".format(error*100), output) + log("\nError Out-of-Bag: {:.2f}%".format(error*100), output) output.close()