📻 Initial commit
This commit is contained in:
commit
3b6e5f642e
7 changed files with 119926 additions and 0 deletions
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
__pycache__/*
|
||||||
|
output/*
|
||||||
|
.vscode/*
|
119615
hygdata_v3.csv
Normal file
119615
hygdata_v3.csv
Normal file
File diff suppressed because it is too large
Load diff
26
question.py
Normal file
26
question.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
|
||||||
|
def is_numeric(value):
    """Tell whether *value* is an ``int`` or a ``float``.

    ``bool`` is a subclass of ``int``, so booleans are reported as numeric too.
    """
    return isinstance(value, (int, float))
|
||||||
|
|
||||||
|
|
||||||
|
class Question(object):
    """A yes/no test applied to one column of an entry.

    For a numeric reference value the question is "is the entry's value
    strictly greater than the reference?"; for any other reference value it
    is "is the entry's value equal to the reference?".

    Attributes:
        fields: sequence of column names, indexed by ``pos`` when printing.
        pos: index of the tested column inside ``entry.data``.
        value: reference value the entry is compared against.
        numeric: True when ``value`` is an int or a float.
    """

    def __init__(self, fields, pos, value):
        self.fields = fields
        self.pos = pos
        self.value = value
        self.numeric = is_numeric(value)

    def match(self, entry):
        """Return True when ``entry`` answers this question with "yes".

        ``entry`` must expose a ``data`` sequence indexable by ``self.pos``.
        A missing value (``None``) never matches a numeric question.
        """
        val = entry.data[self.pos]

        if self.numeric:
            # Compare against None explicitly: a plain truthiness test would
            # also reject a legitimate value of 0 (e.g. 0 > -1 must match).
            return val is not None and val > self.value

        return val == self.value

    def __str__(self):
        condition = ">" if self.numeric else "="
        field = self.fields[self.pos]

        return "Is {f} {cond} {val}?".format(f=field, cond=condition, val=self.value)  # noqa
|
15
star.py
Normal file
15
star.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
class Star(object):
    """One dataset entry: a label, a display name and a list of feature values.

    ``data`` is built by looking up each name of ``fields`` in the mapping
    passed to the constructor, in the order given by ``fields``.
    """

    def __init__(self, label, display_name, data, fields):
        self.label = label
        self.name = display_name
        # Keep only the requested fields, as a positional list.
        self.data = [data[field] for field in fields]

    def __str__(self):
        classification = self.label
        if len(classification) > 1:
            # Several candidate classes: list them as alternatives.
            classification = ' or '.join(classification)
        return 'Star {} {} of spectral type {}'.format(self.name, self.data, classification)
|
73
star_reader.py
Normal file
73
star_reader.py
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
import csv
|
||||||
|
from timeit import default_timer as timer
|
||||||
|
from star import Star
|
||||||
|
|
||||||
|
STAR_CLASSES = 'OBAFGKMC'
|
||||||
|
KEPT_DATA = ['rv', 'absmag', 'ci', 'lum']
|
||||||
|
|
||||||
|
|
||||||
|
def make_star(header, row, fields=None):
    """Build a ``Star`` from one CSV row, or return ``None`` to discard the row.

    Each cell is converted to an ``int`` (whole numbers) or a ``float`` when
    it parses as a number, to ``None`` when it is empty, and is otherwise
    kept as text.

    A row is discarded when its ``dist`` value is a number >= 100000 or when
    its ``spect`` cell is empty.

    The ``spect`` cell is reduced to the distinct leading letters of its
    '/'-separated parts that belong to ``STAR_CLASSES``, in order of first
    appearance; that string becomes the star's label.

    Args:
        header: column names, aligned with ``row``.
        row: raw cell strings for one star.
        fields: names of the columns to keep as the star's data; all of
            ``header`` when omitted or empty.

    Returns:
        A ``Star`` instance, or ``None`` for a discarded row.
    """
    data = {}

    for field, raw in zip(header, row):
        value = raw
        try:
            num = float(raw)
            value = int(num) if num == int(num) else num
        except (ValueError, OverflowError):
            # Not a usable number: int() raises ValueError for NaN and
            # OverflowError for infinity, float() raises ValueError for text.
            if raw == '':
                value = None

        # Only compare real numbers: an empty or textual distance must not
        # raise TypeError on the ordering comparison.
        if field == 'dist' and isinstance(value, (int, float)) and value >= 100000:
            # Discarding star with dubious value
            return None

        if field == 'spect':
            if value is None:
                # Discarding unclassified star
                return None

            # Work on the raw text so a numeric-looking cell cannot break the
            # split, and deduplicate in first-seen order so the label is the
            # same on every run (set iteration order is not).
            classes = []
            for sp_type in raw.split('/'):
                letter = sp_type[:1]
                if letter and letter in STAR_CLASSES and letter not in classes:
                    classes.append(letter)
            value = ''.join(classes)

        data[field] = value

    display_name = data['proper'] or data['bf'] or ('ID ' + str(data['id']))
    fields = fields or header

    return Star(data['spect'], display_name, data, fields)
|
||||||
|
|
||||||
|
|
||||||
|
def read_stars(fields=KEPT_DATA, path='hygdata_v3.csv'):
    """Parse the star catalogue CSV into a list of ``Star`` objects.

    The first CSV row is used as the header; every following row goes
    through ``make_star`` and rows it rejects (``None``) are skipped.
    Progress and elapsed time are printed to standard output.

    Args:
        fields: names of the columns to keep for each star; when empty or
            ``None`` the full header is used.
        path: location of the CSV file to read.

    Returns:
        A ``(star_list, field_names)`` tuple, where ``field_names`` is
        ``fields`` or, when that is empty, the CSV header.
    """
    print("Parsing stars...")

    t_start = timer()

    # newline='' is what the csv module expects so that it can handle line
    # endings (including ones embedded in quoted cells) by itself.
    with open(path, 'r', newline='') as csv_file:
        reader = csv.reader(csv_file)
        # An empty file yields an empty header and no stars instead of
        # raising StopIteration.
        header = next(reader, [])

        candidates = (make_star(header, row, fields) for row in reader)
        star_list = [star for star in candidates if star is not None]

    t_end = timer()

    print("Parsed {} stars.\nElapsed time: {:.3f}\n".format(len(star_list), t_end-t_start))  # noqa

    return star_list, fields or header
|
||||||
|
|
||||||
|
|
||||||
|
# Running this module directly parses the catalogue once and prints the stats.
if "__main__" == __name__:
    read_stars()
|
145
tree.py
Normal file
145
tree.py
Normal file
|
@ -0,0 +1,145 @@
|
||||||
|
|
||||||
|
from question import Question
|
||||||
|
|
||||||
|
|
||||||
|
def unique_vals(dataset, column):
    """Return the set of distinct values found at index *column* of each entry's ``data``."""
    return {entry.data[column] for entry in dataset}
|
||||||
|
|
||||||
|
|
||||||
|
def count_labels(dataset):
    """Count how many times each label occurs across the dataset.

    Every entry's ``label`` is iterated, so an entry carrying several labels
    contributes one count to each of them. Returns a plain ``dict`` mapping
    label to count.
    """
    tally = {}
    for entry in dataset:
        for label in entry.label:
            tally[label] = tally.get(label, 0) + 1
    return tally
|
||||||
|
|
||||||
|
|
||||||
|
def partition(dataset, question):
    """Split *dataset* into the entries that match *question* and those that do not.

    Returns a ``(matching, non_matching)`` pair of lists, each preserving the
    original order of the entries.
    """
    yes, no = [], []

    for entry in dataset:
        bucket = yes if question.match(entry) else no
        bucket.append(entry)

    return yes, no
|
||||||
|
|
||||||
|
|
||||||
|
def gini(dataset):
    """Compute the impurity of *dataset* from its label counts.

    Starts at 1 and subtracts the squared share (count / dataset size) of
    every label reported by ``count_labels``.
    """
    impurity = 1

    for occurrences in count_labels(dataset).values():
        share = occurrences / len(dataset)
        impurity -= share ** 2

    return impurity
|
||||||
|
|
||||||
|
|
||||||
|
def info_gain(left_set, right_set, uncertainty):
    """Return how much a split reduces *uncertainty*.

    The impurity of each side is weighted by that side's share of the
    combined entries and subtracted from the starting uncertainty.
    """
    left_size = len(left_set)
    weight = float(left_size) / float(left_size + len(right_set))

    return uncertainty - weight * gini(left_set) - (1 - weight) * gini(right_set)
|
||||||
|
|
||||||
|
|
||||||
|
def find_best_split(fields, dataset, uncertainty=None):
    """Search every column/value pair for the question with the highest gain.

    Args:
        fields: column names, handed to each candidate ``Question``.
        dataset: list of entries exposing ``data`` (and labels for ``gini``).
        uncertainty: impurity of ``dataset``; computed with ``gini`` when
            not supplied.

    Returns:
        A ``(best_gain, best_question, best_split)`` tuple where
        ``best_split`` is ``(matching, non_matching)``. When no question
        yields a positive gain (or the dataset is empty) the result is
        ``(0, None, None)``.
    """
    best_gain, best_question, best_split = 0, None, None

    # Nothing to split: avoid indexing dataset[0] below.
    if not dataset:
        return best_gain, best_question, best_split

    # An impurity of 0 is a valid value, so test for None rather than
    # truthiness to avoid recomputing it.
    if uncertainty is None:
        uncertainty = gini(dataset)

    for column in range(len(dataset[0].data)):
        for value in unique_vals(dataset, column):
            question = Question(fields, column, value)

            matching, non_matching = partition(dataset, question)

            # A question that leaves one side empty does not split anything.
            if not matching or not non_matching:
                continue

            gain = info_gain(matching, non_matching, uncertainty)

            if gain > best_gain:
                best_gain, best_question = gain, question
                best_split = (matching, non_matching)

    return best_gain, best_question, best_split
|
||||||
|
|
||||||
|
|
||||||
|
class Node(object):
    """One node of the decision tree, built recursively at construction time.

    A node is either a leaf (``is_leaf`` True, with ``predictions`` holding
    the label counts of its dataset) or a split (``is_leaf`` False, with a
    ``question`` plus ``left_branch`` for matching entries and
    ``right_branch`` for the others).
    """

    def __init__(self, fields, dataset, level=0):
        self.fields = fields
        self.gini = gini(dataset)
        self.build(dataset, level)

    def build(self, dataset, level):
        """Turn this node into a split or a leaf, recursing into children."""
        gain, question, branches = find_best_split(self.fields, dataset, self.gini)

        if branches:
            matching, non_matching = branches

            print("Found a level {} split:".format(level))
            print(question)
            print("Matching: {} entries\tNon-matching: {} entries".format(len(matching), len(non_matching)))  # noqa

            self.left_branch = Node(self.fields, matching, level + 1)
            self.right_branch = Node(self.fields, non_matching, level + 1)
            self.question = question
            self.is_leaf = False
            return

        # No usable split was found (zero gain): this node is a leaf.
        print("Found a leaf at level {}".format(level))
        self.predictions = count_labels(dataset)
        self.is_leaf = True

    def classify(self, entry):
        """Walk down the tree and return the leaf node reached by *entry*."""
        node = self
        while not node.is_leaf:
            if node.question.match(entry):
                node = node.left_branch
            else:
                node = node.right_branch
        return node

    def print(self, spacing=''):
        """Return a text rendering of this subtree, each line prefixed by *spacing*."""
        if not self.is_leaf:
            lines = [
                spacing + str(self.question),
                spacing + "-> True:",
                self.left_branch.print(spacing + " "),
                spacing + "-> False:",
                self.right_branch.print(spacing + " "),
            ]
            return '\n'.join(lines)

        # Leaf: show each label's share of the leaf's label counts.
        total = float(sum(self.predictions.values()))
        probs = {
            label: "{:.2f}%".format(count * 100 / total)
            for label, count in self.predictions.items()
        }
        return spacing + "Predict: " + str(probs)

    def __str__(self):
        return self.print()
|
||||||
|
|
||||||
|
|
||||||
|
class Tree(object):
    """Decision tree wrapper: builds the root ``Node`` on construction and delegates to it."""

    def __init__(self, fields, dataset):
        self.fields, self.dataset = fields, dataset
        # Building the root node builds the whole tree.
        self.root = Node(self.fields, self.dataset)

    def classify(self, entry):
        """Return whatever the root node's ``classify`` returns for *entry*."""
        root = self.root
        return root.classify(entry)

    def __str__(self):
        root = self.root
        return str(root)
|
49
tree_tester.py
Normal file
49
tree_tester.py
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
import os
|
||||||
|
from timeit import default_timer as timer
|
||||||
|
from star_reader import read_stars
|
||||||
|
from tree import Tree
|
||||||
|
|
||||||
|
|
||||||
|
def log(s, open_file):
    """Print *s* to standard output and append its text form, plus a newline, to *open_file*."""
    print(s)
    line = '%s\n' % (s,)
    open_file.write(line)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Script entry point: read the stars, train a tree on the first part of
    # the dataset, then log the tree and its prediction for every remaining
    # entry, both to stdout and to output/tree_testing.txt.

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs("output", exist_ok=True)

    # Append mode creates the file when it is missing, so no existence check
    # is needed; the with-block closes the file even if training fails.
    with open("output/tree_testing.txt", 'a') as output:
        dataset, fields = read_stars()

        log("\n----------\n", output)

        log("Training Tree...", output)
        t_start = timer()

        # Train on 65% of the data, capped at 500 entries (the cap the script
        # was hard-coding after computing the 65% split).
        split = min(int(len(dataset) * 0.65), 500)
        # Test on everything after the training slice; starting at split + 1
        # would silently drop one entry from both sets.
        training, testing = dataset[:split], dataset[split:]
        log("Training set: {} entries.".format(len(training)), output)
        log("Testing set: {} entries.".format(len(testing)), output)

        tree = Tree(fields, training)

        t_end = timer()
        timestamp = "Training complete.\nElapsed time: {:.3f}\n"
        log(timestamp.format(t_end - t_start), output)

        log(tree, output)

        log("\n-- TEST --\n", output)

        for entry in testing:
            label = entry.label
            predict = tree.classify(entry)
            log("Actual: {}\tPredicted: {}".format(label, predict), output)
|
Loading…
Reference in a new issue