diff --git a/tests/benchmark.py b/tests/benchmark.py
index 647afd5..804500f 100644
--- a/tests/benchmark.py
+++ b/tests/benchmark.py
@@ -117,7 +117,6 @@ def cp_vs_python_sparse(path: str = 'benchmark_results_001_sparse.json'):
results.append({'benchmark': 'Copapy', 'iter_size': iter_size, 'elapsed_time': elapsed_cp, 'sum_size': sum_size, 'v_size': v_size})
-
v1 = cp.vector(float(v) for v in range(v_size))
v2 = cp.vector(float(v) for v in [5]*v_size)
@@ -158,6 +157,7 @@ def plot_results(path: str):
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
+ import matplotlib as mpl
# Load the benchmark results
with open(path, 'r') as f:
@@ -185,30 +185,97 @@ def plot_results(path: str):
v_sizes_set = sorted(set(v for benchmark_data in medians_by_benchmark.values() for v in benchmark_data.keys()))
# Create the plot
- plt.figure(figsize=(10, 6))
+ plt.figure(figsize=(6, 4))
for benchmark in benchmarks:
if benchmark != 'Python':
v_sizes = sorted(medians_by_benchmark[benchmark].keys())
elapsed_times = [medians_by_benchmark[benchmark][v] for v in v_sizes]
- plt.plot(v_sizes, elapsed_times, '.', label=benchmark)
+ plt.plot(v_sizes, elapsed_times, '.', label=benchmark, markersize=10)
plt.xlabel('Vector Size (v_size)')
plt.ylabel('Elapsed Time (seconds)')
#plt.title('Benchmark Results: Elapsed Time vs Vector Size')
- plt.legend()
+ plt.legend(frameon=False)
#plt.grid(True, alpha=0.3)
plt.ylim(bottom=0)
plt.tight_layout()
# Save to PNG
- plt.savefig(path.replace('.json', '') + '.png', dpi=300)
+ mpl.rcParams['svg.fonttype'] = 'none'
+ save_svg_with_theme_styles(plt, path.replace('.json', '') + '.svg')
print("Plot saved")
+def save_svg_with_theme_styles(pyplot_obj, path):
+ import io
+ import re
+ """
+ Takes a pyplot object (typically `plt`) or a figure, captures its SVG output,
+ injects theme-based CSS, and writes to disk.
+ """
+
+ # --- Step 1: Capture SVG to memory ---
+ buf = io.StringIO()
+
+ # pyplot_obj can be a module (plt) or a Figure instance
+ if hasattr(pyplot_obj, "gcf"):
+ fig = pyplot_obj.gcf()
+ else:
+ fig = pyplot_obj
+
+ fig.savefig(buf, format="svg", dpi=150, transparent=True)
+ svg_data = buf.getvalue()
+ buf.close()
+
+ # --- Step 2: Theme CSS to inject ---
+ theme_css = """
+
+ """
+
+ # --- Step 3: Inject CSS right after