diff --git a/bashplotlib/scatterplot.py b/bashplotlib/scatterplot.py index ddaf3cc..b9a1948 100644 --- a/bashplotlib/scatterplot.py +++ b/bashplotlib/scatterplot.py @@ -17,7 +17,8 @@ def get_scale(series, is_y=False, steps=20): min_val = min(series) max_val = max(series) scaled_series = [] - for x in drange(min_val, max_val, (max_val - min_val) / steps): + for x in drange(min_val, max_val, (max_val - min_val) / steps, + include_stop=True): if x > 0 and scaled_series and max(scaled_series) < 0: scaled_series.append(0.0) scaled_series.append(x) @@ -45,9 +46,13 @@ def plot_scatter(f, xs, ys, size, pch, colour, title): if isinstance(f, str): f = open(f) - data = [tuple(map(float, line.strip().split(','))) for line in f] - xs = [i[0] for i in data] - ys = [i[1] for i in data] + data = [tuple(line.strip().split(',')) for line in f] + xs = [float(i[0]) for i in data] + ys = [float(i[1]) for i in data] + if len(data[0]) > 2: + cs = [i[2].strip() for i in data] + else: + cs = None else: xs = [float(str(row).strip()) for row in open(xs)] ys = [float(str(row).strip()) for row in open(ys)] @@ -65,16 +70,11 @@ def plot_scatter(f, xs, ys, size, pch, colour, title): for (i, (xp, yp)) in enumerate(zip(xs, ys)): if xp <= x and yp >= y and (xp, yp) not in plotted: point = pch - #point = str(i) plotted.add((xp, yp)) - if x == 0 and y == 0: - point = "o" - elif x == 0: - point = "|" - elif y == 0: - point = "-" + if cs: + colour = cs[i] printcolour(point, True, colour) - print("|") + print(" |") print("-" * (2 * len(get_scale(xs, False, size)) + 2))