From 42490f80b816491799b241b8f3921693a91615f6 Mon Sep 17 00:00:00 2001 From: Nicolas Kruse Date: Sat, 13 Dec 2025 20:50:33 +0100 Subject: [PATCH] benchmark script updated to generate dark/bright mode svgs --- tests/benchmark.py | 81 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 74 insertions(+), 7 deletions(-) diff --git a/tests/benchmark.py b/tests/benchmark.py index 647afd5..804500f 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -117,7 +117,6 @@ def cp_vs_python_sparse(path: str = 'benchmark_results_001_sparse.json'): results.append({'benchmark': 'Copapy', 'iter_size': iter_size, 'elapsed_time': elapsed_cp, 'sum_size': sum_size, 'v_size': v_size}) - v1 = cp.vector(float(v) for v in range(v_size)) v2 = cp.vector(float(v) for v in [5]*v_size) @@ -158,6 +157,7 @@ def plot_results(path: str): import matplotlib.pyplot as plt import numpy as np from collections import defaultdict + import matplotlib as mpl # Load the benchmark results with open(path, 'r') as f: @@ -185,30 +185,97 @@ def plot_results(path: str): v_sizes_set = sorted(set(v for benchmark_data in medians_by_benchmark.values() for v in benchmark_data.keys())) # Create the plot - plt.figure(figsize=(10, 6)) + plt.figure(figsize=(6, 4)) for benchmark in benchmarks: if benchmark != 'Python': v_sizes = sorted(medians_by_benchmark[benchmark].keys()) elapsed_times = [medians_by_benchmark[benchmark][v] for v in v_sizes] - plt.plot(v_sizes, elapsed_times, '.', label=benchmark) + plt.plot(v_sizes, elapsed_times, '.', label=benchmark, markersize=10) plt.xlabel('Vector Size (v_size)') plt.ylabel('Elapsed Time (seconds)') #plt.title('Benchmark Results: Elapsed Time vs Vector Size') - plt.legend() + plt.legend(frameon=False) #plt.grid(True, alpha=0.3) plt.ylim(bottom=0) plt.tight_layout() # Save to PNG - plt.savefig(path.replace('.json', '') + '.png', dpi=300) + mpl.rcParams['svg.fonttype'] = 'none' + save_svg_with_theme_styles(plt, path.replace('.json', '') + '.svg') print("Plot saved") +def save_svg_with_theme_styles(pyplot_obj, path): + import io + import re + """ + Takes a pyplot object (typically `plt`) or a figure, captures its SVG output, + injects theme-based CSS, and writes to disk. + """ + + # --- Step 1: Capture SVG to memory --- + buf = io.StringIO() + + # pyplot_obj can be a module (plt) or a Figure instance + if hasattr(pyplot_obj, "gcf"): + fig = pyplot_obj.gcf() + else: + fig = pyplot_obj + + fig.savefig(buf, format="svg", dpi=150, transparent=True) + svg_data = buf.getvalue() + buf.close() + + # --- Step 2: Theme CSS to inject --- + theme_css = """ + + """ + + # --- Step 3: Inject CSS right after tag --- + # Find the first > after the opening tag + modified_svg = re.sub( + r"(]*>)", + r"\1\n" + theme_css, + svg_data, + count=1 + ) + + # --- Step 4: Write final output to disk --- + with open(path, "w", encoding="utf-8") as f: + f.write(modified_svg) + + if __name__ == "__main__": - path1 = 'benchmark_results_001.json' - path2 = 'benchmark_results_001_sparse.json' + path1 = 'docs/source/media/benchmark_results_001.json' + path2 = 'docs/source/media/benchmark_results_001_sparse.json' if 'no_simd' in sys.argv[1:]: os.environ["NPY_DISABLE_CPU_FEATURES"] = CPU_SIMD_FEATURES