JAX-ALFA Performance Benchmarks: NVIDIA A100 (80 GB)

Last updated: May 2026

This notebook reports wall-clock performance of JAX-ALFA across a range of cases, grid resolutions, SGS models, and floating-point precisions. All timings are expressed as time per iteration (ms/iter), which normalizes by the number of time steps in each completed run. This makes matched configurations directly comparable, while cross-case comparisons should still be interpreted in light of the active physics and configuration choices.
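As a quick illustration of the normalization, the helper below (a hypothetical name, not part of JAX-ALFA) mirrors the arithmetic in the data-loading cell: total wall-clock seconds times 1000, divided by the number of completed iterations. Two runs with the same per-step cost but different lengths map to the same ms/iter value.

```python
def time_per_iter_ms(total_seconds: float, iterations: int) -> float:
    """Convert a run's total wall-clock time to milliseconds per iteration."""
    if iterations <= 0:
        raise ValueError("iterations must be positive")
    return total_seconds * 1000.0 / iterations


# Two hypothetical runs with identical per-step cost but different lengths.
short_run = time_per_iter_ms(total_seconds=12.0, iterations=1_000)
long_run = time_per_iter_ms(total_seconds=120.0, iterations=10_000)
print(short_run, long_run)
```

The notebook additionally rounds the stored value to four decimals.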

The notebook is structured around three questions:

  1. Resolution scaling — how does cost grow with grid size?

  2. Precision impact — how much does single precision (SP) accelerate runs vs. double (DP)?

  3. Model overhead — what is the extra cost of LASDD vs. LAD, and WL vs. SM?
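For question 2, one contributing factor (among others, such as arithmetic throughput) is that an SP array occupies half the memory of a DP array, so memory-bound kernels move half as many bytes. The sketch below only computes the storage of a single 3D scalar field at the benchmarked resolutions; it makes no assumption about how many fields JAX-ALFA actually keeps resident.

```python
import numpy as np


def field_mib(n: int, dtype) -> float:
    """Storage in MiB of one n x n x n array of the given dtype."""
    return n**3 * np.dtype(dtype).itemsize / 2**20


for n in (64, 128, 256, 384):
    dp = field_mib(n, np.float64)
    sp = field_mib(n, np.float32)
    print(f"{n:>3}^3: DP {dp:8.1f} MiB, SP {sp:8.1f} MiB")
```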

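The data-loading cells scan run.log files automatically. To add a new GPU, create a platform directory with the same <case>/runs/<run>/run.log layout (e.g. examples_A6000ada/) and register it in PLATFORM_DIRS below; the tables will then include its runs. The resolution-scaling and cross-case plotting cells currently filter on the 'A100 (80 GB)' label, so that label must be changed there to plot the new platform.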
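The scan expects run logs at <examples_dir>/<case>/runs/<run>/run.log, with run names such as 64x64x64_LAD_SM_DP. The sketch below builds a throwaway tree in a temporary directory (the directory, case, and run names are illustrative) and applies the same glob pattern used by scan_platform, which is a quick way to check that a new platform directory will be picked up.

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    examples_dir = Path(tmp) / "examples_new_gpu"   # hypothetical platform directory
    run_dir = examples_dir / "CBL_N91" / "runs" / "64x64x64_LAD_SM_DP"
    run_dir.mkdir(parents=True)
    (run_dir / "run.log").write_text("placeholder\n")

    # A misplaced log (missing the runs/ level) is not matched by the pattern.
    (examples_dir / "CBL_N91" / "run.log").write_text("placeholder\n")

    found = sorted(examples_dir.glob("*/runs/*/run.log"))
    # log -> run dir -> runs/ -> case dir, as in scan_platform
    cases = [p.parent.parent.parent.name for p in found]
    print(len(found), cases)
```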

Hardware

| Platform label | System | GPU | VRAM |
|---|---|---|---|
| A100 (80 GB) | NVIDIA DGX Cloud | NVIDIA A100 Tensor Core GPU | 80 GB HBM2e |

Case Summary

| Case | Type | Geostrophic wind | Surface BC | Moisture | Coriolis |
|---|---|---|---|---|---|
| CBL_N91 | Convective BL | \(U_g = V_g = 0\) | Constant heat flux | No | Yes |
| SBL_GABLS1 | Stable BL | \(U_g = 8\) m/s | Time-varying \(T_s\) | No | Yes |

SGS model naming convention

| Label | Dynamic procedure | Base model | Scale assumption |
|---|---|---|---|
| LAD-SM | Locally Averaged Dynamic | Smagorinsky (SM) | Scale invariant |
| LAD-WL | Locally Averaged Dynamic | Wong-Lilly (WL) | Scale invariant |
| LASDD-SM | Locally Averaged Scale-Dependent Dynamic | Smagorinsky (SM) | Scale dependent |
| LASDD-WL | Locally Averaged Scale-Dependent Dynamic | Wong-Lilly (WL) | Scale dependent |

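Run matrix for each case: resolutions \(\times\) SGS models (LAD-SM, LAD-WL, LASDD-SM, LASDD-WL) \(\times\) precisions (SP, DP). Configurations without a completed run.log (no timing line, or a final iteration count below the target) are left out; they were not yet submitted, are still running, or did not finish. The scanned data also include NBL_A94 and SBL_GABLS3, which are not yet described in the case summary above.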
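Each resolution therefore contributes 2 dynamic procedures × 2 base models × 2 precisions = 8 runs. The sketch below enumerates the expected run-directory names, assuming cubic grids and the naming pattern parsed by _DIR_RE in the data-loading cell; comparing this list with what exists on disk shows which entries are still missing. The helper name run_names is hypothetical.

```python
import re
from itertools import product

# Same pattern as _DIR_RE in the data-loading cell.
DIR_RE = re.compile(
    r'(?P<nx>\d+)x(?P<ny>\d+)x(?P<nz>\d+)'
    r'_(?P<dyn>LAD|LASDD)'
    r'_(?P<base>SM|WL)'
    r'_(?P<prec>SP|DP)$'
)


def run_names(n: int) -> list[str]:
    """Expected run-directory names for one cubic resolution n^3."""
    return [
        f"{n}x{n}x{n}_{dyn}_{base}_{prec}"
        for dyn, base, prec in product(("LAD", "LASDD"), ("SM", "WL"), ("SP", "DP"))
    ]


names = run_names(128)
print(len(names), names[:2])
# Every generated name must be recognized by the scanner's regex.
assert all(DIR_RE.match(name) for name in names)
```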

Setup

[1]:
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from pathlib import Path

plt.rcParams.update({
    'text.usetex': True,   # requires a working LaTeX installation
    'font.size': 14,
    'axes.labelsize': 16,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
})

Repository root and platform directories

[2]:
def find_repo_root(start=None):
    path = Path(start or ('__file__' in globals() and __file__) or Path.cwd()).resolve()
    for candidate in (path, *path.parents):
        if (candidate / 'examples').is_dir() and (candidate / 'docs').is_dir():
            return candidate
    raise FileNotFoundError('Could not locate JAXALFA0.1 repository root')

BaseDir = find_repo_root()
print(f'Repository root: {BaseDir}')

# Register platform directories here.
# Key = human-readable label shown in plots and tables.
# Value = path to the examples directory for that platform (relative to BaseDir).
PLATFORM_DIRS = {
    'A100 (80 GB)': BaseDir / 'examples',
    # 'RTX A6000 Ada': BaseDir / 'examples_A6000ada',   # uncomment when available
}
Repository root: /Users/sukantabasu/Dropbox/Codes/LES/JAX-ALFA/JAXALFA0.1

Data loading

[3]:
_DIR_RE = re.compile(
    r'(?P<nx>\d+)x(?P<ny>\d+)x(?P<nz>\d+)'
    r'_(?P<dyn>LAD|LASDD)'
    r'_(?P<base>SM|WL)'
    r'_(?P<prec>SP|DP)$'
)
_TIME_RE = re.compile(r'Total Elapsed Time:\s+([\d.]+)\s+seconds')
_ITER_RE = re.compile(r'Finished Iteration\s+(\d+)\s*/\s*(\d+)')


def parse_run_log(log_path):
    """Return (total_seconds, total_iterations) from a run.log, or None.

    Some logs contain more than one completed timing block from repeated runs.
    Use the last Total Elapsed Time and last Finished Iteration entry, which
    correspond to the most recent completed block in the file.
    """
    text = log_path.read_text(errors='replace')
    all_times = _TIME_RE.findall(text)
    if not all_times:
        return None
    total_sec = float(all_times[-1])
    iters = _ITER_RE.findall(text)
    if not iters:
        return None
    last_done, last_total = int(iters[-1][0]), int(iters[-1][1])
    if last_done < last_total:
        return None  # incomplete run
    return total_sec, last_total


def scan_platform(examples_dir, platform_label):
    """Walk examples_dir and return a list of record dicts."""
    records = []
    for log_path in sorted(examples_dir.glob('*/runs/*/run.log')):
        run_dir  = log_path.parent
        run_name = run_dir.name
        case     = run_dir.parent.parent.name
        m = _DIR_RE.match(run_name)
        if not m:
            continue
        result = parse_run_log(log_path)
        if result is None:
            continue
        total_sec, total_iters = result
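        nx        = int(m.group('nx'))
        ny        = int(m.group('ny'))
        nz        = int(m.group('nz'))
        dyn_key   = m.group('dyn')    # LAD | LASDD  (dynamic procedure)
        base_key  = m.group('base')   # SM  | WL     (base SGS model)
        prec_key  = m.group('prec')   # SP  | DP
        sgs_label = f"{dyn_key}-{base_key}"   # e.g. LAD-SM
        # Label cubic grids as N³; fall back to the full shape otherwise.
        res_label = f'{nx}³' if nx == ny == nz else f'{nx}x{ny}x{nz}'
        records.append({
            'Platform':         platform_label,
            'Case':             case,
            'Resolution':       res_label,
            'N':                nx,
            'Grid points':      nx * ny * nz,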
            'Dynamic proc.':    dyn_key,
            'Base model':       base_key,
            'SGS model':        sgs_label,
            'Precision':        prec_key,
            'Total iterations': total_iters,
            'Total time (s)':   round(total_sec, 3),
            'Time/iter (ms)':   round(total_sec * 1000 / total_iters, 4),
        })
    return records


all_records = []
for label, path in PLATFORM_DIRS.items():
    if path.is_dir():
        recs = scan_platform(path, label)
        all_records.extend(recs)
        print(f'{label}: {len(recs)} completed runs found in {path}')
    else:
        print(f'{label}: directory not found — {path}')

df = pd.DataFrame(all_records)

sort_keys = ['Platform', 'Case', 'N', 'Dynamic proc.', 'Base model', 'Precision']
df_sorted = df.sort_values(sort_keys).reset_index(drop=True)
A100 (80 GB): 57 completed runs found in /Users/sukantabasu/Dropbox/Codes/LES/JAX-ALFA/JAXALFA0.1/examples
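To make the parsing rules concrete, the sketch below applies the same two regular expressions to in-memory text. parse_log_text is a hypothetical text-based variant of parse_run_log, and the log lines are synthetic, constructed only to match the notebook's regexes rather than copied from a real JAX-ALFA log. A file with two completed blocks yields the last block; a run that stopped early yields None and is skipped.

```python
import re

# Same patterns as _TIME_RE and _ITER_RE in the data-loading cell.
TIME_RE = re.compile(r'Total Elapsed Time:\s+([\d.]+)\s+seconds')
ITER_RE = re.compile(r'Finished Iteration\s+(\d+)\s*/\s*(\d+)')


def parse_log_text(text):
    """Return (seconds, iterations) from log text, or None if incomplete."""
    times = TIME_RE.findall(text)
    iters = ITER_RE.findall(text)
    if not times or not iters:
        return None
    done, total = int(iters[-1][0]), int(iters[-1][1])
    if done < total:
        return None  # last block did not reach its final iteration
    return float(times[-1]), total


# Synthetic log with two completed blocks from repeated runs.
repeated = """Finished Iteration 1000 / 1000
Total Elapsed Time: 15.5 seconds
Finished Iteration 1000 / 1000
Total Elapsed Time: 14.2 seconds
"""
# Synthetic log from a run that stopped early.
partial = "Finished Iteration 400 / 1000\n"

print(parse_log_text(repeated))  # the last block wins
print(parse_log_text(partial))
```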

Headline Metrics

A compact snapshot of the completed benchmark set before the detailed plots.

[4]:
completed_runs = len(df_sorted)
completed_cases = df_sorted['Case'].nunique()
completed_resolutions = df_sorted[['Case', 'N']].drop_duplicates().shape[0]

headline = pd.DataFrame([
    {'Metric': 'Completed runs', 'Value': completed_runs},
    {'Metric': 'Cases with completed runs', 'Value': completed_cases},
    {'Metric': 'Case-resolution pairs', 'Value': completed_resolutions},
    {'Metric': 'Fastest run (ms/iter)', 'Value': df_sorted['Time/iter (ms)'].min()},
    {'Metric': 'Slowest run (ms/iter)', 'Value': df_sorted['Time/iter (ms)'].max()},
    {'Metric': 'Median run (ms/iter)', 'Value': df_sorted['Time/iter (ms)'].median()},
])

case_summary = (
    df_sorted
    .groupby(['Case', 'N', 'Resolution'], as_index=False)
    .agg(
        **{
            'Completed runs': ('Time/iter (ms)', 'size'),
            'Min ms/iter': ('Time/iter (ms)', 'min'),
            'Median ms/iter': ('Time/iter (ms)', 'median'),
            'Max ms/iter': ('Time/iter (ms)', 'max'),
        }
    )
    .sort_values(['Case', 'N'])
    .reset_index(drop=True)
)

display(headline)
display(case_summary)
Metric Value
0 Completed runs 57.0000
1 Cases with completed runs 4.0000
2 Case-resolution pairs 10.0000
3 Fastest run (ms/iter) 8.3434
4 Slowest run (ms/iter) 2580.1850
5 Median run (ms/iter) 45.0507
Case N Resolution Completed runs Min ms/iter Median ms/iter Max ms/iter
0 CBL_N91 64 64³ 8 11.2704 15.01100 18.7421
1 CBL_N91 128 128³ 8 42.0543 70.10770 101.5644
2 CBL_N91 256 256³ 8 481.1481 743.39035 1027.9178
3 CBL_N91 384 384³ 1 2580.1850 2580.18500 2580.1850
4 NBL_A94 64 64³ 1 12.6894 12.68940 12.6894
5 NBL_A94 128 128³ 4 41.8231 43.18860 45.0180
6 SBL_GABLS1 64 64³ 8 8.3434 12.25935 16.6933
7 SBL_GABLS1 128 128³ 8 32.9590 56.12020 81.9103
8 SBL_GABLS1 256 256³ 7 399.5710 425.48760 839.0768
9 SBL_GABLS3 256 256³ 4 411.8750 418.97910 438.0652

Completion Matrix

Completed run.log coverage for the case-resolution combinations represented in the benchmark data. Blank cells indicate configurations with no completed timing in the scanned logs.
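The matrix counts completed runs per (case, resolution) row and SGS-precision column. The cell below does this with pivot_table(..., aggfunc='size'); pd.crosstab should give the same counts more directly. A sketch on toy records (case names reused for readability, rows not taken from the benchmark data):

```python
import pandas as pd

# Toy records standing in for df_sorted (illustrative only).
runs = pd.DataFrame({
    "Case":      ["CBL_N91", "CBL_N91", "CBL_N91", "SBL_GABLS1"],
    "N":         [64, 64, 128, 64],
    "SGS model": ["LAD-SM", "LAD-SM", "LASDD-WL", "LAD-WL"],
    "Precision": ["DP", "SP", "DP", "SP"],
})

config_order = [f"{sgs}-{prec}"
                for sgs in ["LAD-SM", "LAD-WL", "LASDD-SM", "LASDD-WL"]
                for prec in ["DP", "SP"]]

# Count runs per (Case, N) and configuration; absent configurations become 0.
coverage = (
    pd.crosstab(index=[runs["Case"], runs["N"]],
                columns=runs["SGS model"] + "-" + runs["Precision"])
    .reindex(columns=config_order, fill_value=0)
)
print(coverage)
```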

[5]:
config_order = [
    f'{sgs}-{prec}'
    for sgs in ['LAD-SM', 'LAD-WL', 'LASDD-SM', 'LASDD-WL']
    for prec in ['DP', 'SP']
]

coverage = (
    df_sorted
    .assign(Config=df_sorted['SGS model'] + '-' + df_sorted['Precision'])
    .pivot_table(index=['Case', 'N'], columns='Config', values='Time/iter (ms)', aggfunc='size')
    .reindex(columns=config_order)
    .fillna(0)
    .astype(int)
    .sort_index()
)

fig, ax = plt.subplots(figsize=(10, 0.6 * len(coverage) + 2), constrained_layout=True)
complete = coverage.values > 0
escape_case = lambda s: s.replace('_', r'\_')
im = ax.imshow(complete, aspect='auto', cmap='Greens', vmin=0, vmax=1)

ax.set_xticks(np.arange(len(coverage.columns)))
ax.set_xticklabels(coverage.columns, rotation=45, ha='right')
ax.set_yticks(np.arange(len(coverage.index)))
ax.set_yticklabels([f'{escape_case(case)}  {n}$^3$' for case, n in coverage.index])

for i in range(coverage.shape[0]):
    for j in range(coverage.shape[1]):
        ax.text(j, i, 'x' if complete[i, j] else '', ha='center', va='center', color='black')

ax.set_title('Completed Benchmark Coverage')
ax.set_xlabel('Configuration')
ax.set_ylabel('Case and resolution')
ax.grid(False)
plt.show()
[Figure: Completed Benchmark Coverage]

Compact Results Table

A summarized view of completed timings. The full per-run table remains available as df_sorted for audit or export.

[6]:
compact_cols = ['Case', 'Resolution', 'Completed runs', 'Min ms/iter', 'Median ms/iter', 'Max ms/iter']

compact_results = case_summary[compact_cols].copy()
display(compact_results)
Case Resolution Completed runs Min ms/iter Median ms/iter Max ms/iter
0 CBL_N91 64³ 8 11.2704 15.01100 18.7421
1 CBL_N91 128³ 8 42.0543 70.10770 101.5644
2 CBL_N91 256³ 8 481.1481 743.39035 1027.9178
3 CBL_N91 384³ 1 2580.1850 2580.18500 2580.1850
4 NBL_A94 64³ 1 12.6894 12.68940 12.6894
5 NBL_A94 128³ 4 41.8231 43.18860 45.0180
6 SBL_GABLS1 64³ 8 8.3434 12.25935 16.6933
7 SBL_GABLS1 128³ 8 32.9590 56.12020 81.9103
8 SBL_GABLS1 256³ 7 399.5710 425.48760 839.0768
9 SBL_GABLS3 256³ 4 411.8750 418.97910 438.0652

Resolution Scaling

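Time per iteration as a function of grid size \(N\) for all available cases on A100 (80 GB). Each row is one case; columns show DP (left) and SP (right), and lines are colored by SGS model. The dashed reference line shows ideal \(N^3\) scaling, i.e. cost proportional to the number of grid points, anchored to the mean 64\(^3\) DP time of the alphabetically first case (CBL_N91 in the current data). The same reference line is drawn in every panel, including the SP column.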
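Slopes on the log-log axes can be made quantitative: between two resolutions the effective exponent is \(p = \ln(t_2/t_1)/\ln(N_2/N_1)\), and \(p = 3\) corresponds to the dashed line. The helpers below (hypothetical names) compute \(p\) and the \(N^3\) reference curve the same way as the plotting cell; the timings are illustrative, not benchmark values.

```python
import numpy as np


def scaling_exponent(n1, t1, n2, t2):
    """Effective exponent p in t ~ N**p between two resolutions."""
    return np.log(t2 / t1) / np.log(n2 / n1)


def n3_reference(n, n0, t0):
    """Ideal cubic reference: t0 at n0, scaled by (n / n0)**3."""
    return t0 * (np.asarray(n, dtype=float) / n0) ** 3


# Illustrative timings: 10 ms/iter at 64^3 and 80 ms/iter at 128^3.
p = scaling_exponent(64, 10.0, 128, 80.0)
ref = n3_reference([64, 128, 256], n0=64, t0=10.0)
print(p, ref)
```

Applied to the CBL_N91 medians in the summary table (15.0 ms/iter at 64³, 70.1 ms/iter at 128³), this gives p ≈ 2.2, although those medians pool different SGS models and precisions.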

[7]:
sgs_styles = {
    'LAD-SM':   {'color': '#1f77b4', 'marker': 'o'},
    'LAD-WL':   {'color': '#ff7f0e', 'marker': 's'},
    'LASDD-SM': {'color': '#2ca02c', 'marker': '^'},
    'LASDD-WL': {'color': '#d62728', 'marker': 'D'},
}

cases_in_data = sorted(df_sorted[df_sorted['Platform'] == 'A100 (80 GB)']['Case'].unique())
N_vals_all    = sorted(df_sorted[df_sorted['Platform'] == 'A100 (80 GB)']['N'].unique())
n_cases       = len(cases_in_data)

# Reference line anchored to 64³ DP of the first available case
ref_case   = cases_in_data[0]
sub64_dp   = df_sorted[
    (df_sorted['Case'] == ref_case) &
    (df_sorted['Precision'] == 'DP') &
    (df_sorted['N'] == 64) &
    (df_sorted['Platform'] == 'A100 (80 GB)')
]['Time/iter (ms)'].mean()

fig, axs = plt.subplots(n_cases, 2,
                         figsize=(12, 4 * n_cases + 1),
                         constrained_layout=True,
                         sharey=True)
if n_cases == 1:
    axs = axs[np.newaxis, :]

for row, case in enumerate(cases_in_data):
    escape_label = case.replace('_', r'\_')
    for col, (prec, prec_title) in enumerate(zip(
            ['DP', 'SP'], ['Double Precision (DP)', 'Single Precision (SP)'])):
        ax = axs[row, col]
        sub = df_sorted[
            (df_sorted['Case'] == case) &
            (df_sorted['Precision'] == prec) &
            (df_sorted['Platform'] == 'A100 (80 GB)')
        ].copy()

        for sgs_label, style in sgs_styles.items():
            grp = sub[sub['SGS model'] == sgs_label].sort_values('N')
            if grp.empty:
                continue
            ax.plot(grp['N'], grp['Time/iter (ms)'],
                    color=style['color'], marker=style['marker'],
                    linewidth=2, markersize=7, label=sgs_label)

        N_ref = np.array(N_vals_all)
        t_ref = sub64_dp * (N_ref / 64) ** 3
        ax.plot(N_ref, t_ref, 'k--', linewidth=1.2, label=r'$\propto N^3$')

        ax.set_xscale('log', base=2)
        ax.set_yscale('log')
        ax.set_xticks(N_vals_all)
        ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
        ax.set_title(fr'{escape_label} --- {prec_title}')
        ax.legend(frameon=False, fontsize=10)
        ax.grid(True, which='both', alpha=0.3)

        if row == n_cases - 1:
            ax.set_xlabel(r'$N$ (grid points per dimension)')
        if col == 0:
            ax.set_ylabel(r'Time per iteration (ms)')

fig.suptitle(r'Resolution Scaling --- A100 (80 GB)', fontsize=16)
plt.show()
[Figure: Resolution Scaling, A100 (80 GB)]

Performance Ratio Summary

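Three matched-pair diagnostics in one figure: SP speedup \(t_{DP}/t_{SP}\), LASDD dynamic-procedure overhead \(t_{LASDD}/t_{LAD}\), and WL base-model overhead \(t_{WL}/t_{SM}\). Each ratio compares two runs that differ only in the attribute being varied (same platform, case, resolution, and remaining settings). Speedups above 1 mean SP is faster; overheads above 1 mean LASDD or WL is more expensive. Dashed lines mark a 2× reference for the speedup panel and parity (1.0) for the two overhead panels. Only completed matched pairs are shown.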
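Each ratio comes from an inner join on every key except the attribute being varied, so a run whose partner has not completed simply drops out. A minimal sketch of the SP-speedup pairing on toy data (values illustrative, not benchmark results); the LAD-WL DP row has no SP partner and is excluded:

```python
import pandas as pd

# Toy per-run table (illustrative values only).
runs = pd.DataFrame({
    "Case":           ["CBL_N91"] * 5,
    "N":              [64] * 5,
    "Dynamic proc.":  ["LAD", "LAD", "LASDD", "LASDD", "LAD"],
    "Base model":     ["SM", "SM", "SM", "SM", "WL"],
    "Precision":      ["DP", "SP", "DP", "SP", "DP"],
    "Time/iter (ms)": [16.0, 10.0, 18.0, 12.0, 20.0],
})

# Match on everything except Precision; unmatched rows are dropped.
keys = ["Case", "N", "Dynamic proc.", "Base model"]
pairs = runs[runs["Precision"] == "DP"].merge(
    runs[runs["Precision"] == "SP"], on=keys, suffixes=("_DP", "_SP")
)
pairs["Speedup"] = pairs["Time/iter (ms)_DP"] / pairs["Time/iter (ms)_SP"]
print(pairs[keys + ["Speedup"]])
```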

[8]:
escape_case = lambda s: s.replace('_', r'\_')

merge_keys = ['Platform', 'Case', 'N', 'Dynamic proc.', 'Base model']

df_dp = df_sorted[df_sorted['Precision'] == 'DP'].copy()
df_sp = df_sorted[df_sorted['Precision'] == 'SP'].copy()

df_speedup = df_dp[merge_keys + ['SGS model', 'Time/iter (ms)', 'Resolution']].merge(
    df_sp[merge_keys + ['Time/iter (ms)']],
    on=merge_keys, suffixes=('_DP', '_SP')
)
df_speedup['Ratio'] = df_speedup['Time/iter (ms)_DP'] / df_speedup['Time/iter (ms)_SP']
df_speedup['Label'] = df_speedup.apply(
    lambda r: f"{escape_case(r['Case'])} {r['Resolution']} {r['SGS model']}", axis=1
)
df_speedup['Metric'] = 'SP speedup'

dyn_keys = ['Platform', 'Case', 'N', 'Base model', 'Precision', 'Resolution']
df_lad = df_sorted[df_sorted['Dynamic proc.'] == 'LAD'].copy()
df_lasdd = df_sorted[df_sorted['Dynamic proc.'] == 'LASDD'].copy()

df_dyn_ratio = df_lad[dyn_keys + ['Time/iter (ms)']].merge(
    df_lasdd[dyn_keys + ['Time/iter (ms)']],
    on=dyn_keys, suffixes=('_LAD', '_LASDD')
)
df_dyn_ratio['Ratio'] = df_dyn_ratio['Time/iter (ms)_LASDD'] / df_dyn_ratio['Time/iter (ms)_LAD']
df_dyn_ratio['Label'] = df_dyn_ratio.apply(
    lambda r: f"{escape_case(r['Case'])} {r['Resolution']} {r['Base model']}-{r['Precision']}", axis=1
)
df_dyn_ratio['Metric'] = 'LASDD/LAD overhead'

base_keys = ['Platform', 'Case', 'N', 'Dynamic proc.', 'Precision', 'Resolution']
df_sm = df_sorted[df_sorted['Base model'] == 'SM'].copy()
df_wl = df_sorted[df_sorted['Base model'] == 'WL'].copy()

df_base_ratio = df_sm[base_keys + ['Time/iter (ms)']].merge(
    df_wl[base_keys + ['Time/iter (ms)']],
    on=base_keys, suffixes=('_SM', '_WL')
)
df_base_ratio['Ratio'] = df_base_ratio['Time/iter (ms)_WL'] / df_base_ratio['Time/iter (ms)_SM']
df_base_ratio['Label'] = df_base_ratio.apply(
    lambda r: f"{escape_case(r['Case'])} {r['Resolution']} {r['Dynamic proc.']}-{r['Precision']}", axis=1
)
df_base_ratio['Metric'] = 'WL/SM overhead'

ratio_panels = [
    (df_speedup.sort_values(['Case', 'N', 'SGS model']), 'SP speedup: $t_{DP}/t_{SP}$', 2.0, '#4e79a7'),
    (df_dyn_ratio.sort_values(['Case', 'N', 'Base model', 'Precision']), 'LASDD overhead: $t_{LASDD}/t_{LAD}$', 1.0, '#f28e2b'),
    (df_base_ratio.sort_values(['Case', 'N', 'Dynamic proc.', 'Precision']), 'WL overhead: $t_{WL}/t_{SM}$', 1.0, '#59a14f'),
]

fig, axs = plt.subplots(3, 1, figsize=(12, 12), constrained_layout=True)

for ax, (sub, title, reference, color) in zip(axs, ratio_panels):
    y = np.arange(len(sub))
    ax.barh(y, sub['Ratio'], color=color, edgecolor='black', linewidth=0.5)
    ax.axvline(reference, color='k', linestyle='--', linewidth=1.1)
    ax.set_yticks(y)
    ax.set_yticklabels(sub['Label'], fontsize=9)
    ax.invert_yaxis()
    ax.set_xlabel('Ratio')
    ax.set_title(title)
    ax.grid(axis='x', alpha=0.25)
    xmax = max(sub['Ratio'].max() * 1.15, reference * 1.2)
    ax.set_xlim(0, xmax)
    for yi, value in zip(y, sub['Ratio']):
        ax.text(value + 0.02 * xmax, yi, f'{value:.2f}', va='center', fontsize=8)

fig.suptitle(r'Matched-Pair Performance Ratios --- A100 (80 GB)', fontsize=16)
plt.show()
[Figure: Matched-Pair Performance Ratios, A100 (80 GB)]

Cross-Case Diagnostic

Time per iteration at \(64^3\) for the available completed runs. This is a diagnostic comparison rather than a normalized physics benchmark; differences can reflect active physics, configuration choices, and algorithmic work per solver iteration.
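In the grouped bar chart, the bars for the different cases are spread symmetrically around each SGS-model tick with offset \((i - n/2 + 0.5)\,w\), where \(w = 0.8/n\) and \(n\) is the number of cases. The helper below (a hypothetical name) isolates that arithmetic so the centering is easy to check.

```python
import numpy as np


def group_offsets(n_bars: int, group_width: float = 0.8) -> np.ndarray:
    """Center offsets for n_bars side-by-side bars within one x-tick group."""
    bar_width = group_width / max(n_bars, 1)
    i = np.arange(n_bars)
    return (i - n_bars / 2 + 0.5) * bar_width


offsets3 = group_offsets(3)
offsets4 = group_offsets(4)
print(offsets3, offsets4)
```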

[9]:
df_64 = df_sorted[
    (df_sorted['N'] == 64) &
    (df_sorted['Platform'] == 'A100 (80 GB)')
].copy()

sgs_models = sorted(df_64['SGS model'].unique())
cases = sorted(df_64['Case'].unique())

case_colors = {
    'CBL_N91':    '#4e79a7',
    'SBL_GABLS1': '#f28e2b',
    'NBL_A94':    '#59a14f',
    'DC_Wangara': '#e15759',
}

fig, axs = plt.subplots(1, 2, figsize=(13, 5), constrained_layout=True, sharey=True)

for ax, prec in zip(axs, ['DP', 'SP']):
    sub = df_64[df_64['Precision'] == prec].sort_values(['Case', 'SGS model'])
    sgs_labels_avail = sorted(sub['SGS model'].unique())
    n_sgs = len(sgs_labels_avail)
    n_cases_avail = len(cases)
    group_width = 0.8
    bar_width = group_width / max(n_cases_avail, 1)

    x = np.arange(n_sgs)
    for i, case in enumerate(cases):
        vals = []
        for sgs in sgs_labels_avail:
            row = sub[(sub['Case'] == case) & (sub['SGS model'] == sgs)]
            vals.append(row['Time/iter (ms)'].values[0] if not row.empty else np.nan)
        offset = (i - n_cases_avail / 2 + 0.5) * bar_width
        ax.bar(x + offset, vals, width=bar_width * 0.9,
               color=case_colors.get(case, '#999'),
               edgecolor='black', linewidth=0.6, label=case.replace('_', r'\_'))

    ax.set_xticks(x)
    ax.set_xticklabels(sgs_labels_avail, fontsize=11)
    ax.set_xlabel('SGS model')
    ax.set_ylabel(r'Time per iteration (ms)')
    ax.set_title(f'{prec}')
    ax.legend(frameon=False, fontsize=10)
    ax.grid(axis='y', alpha=0.3)

fig.suptitle(r'Cross-case comparison at $64^3$ --- A100 (80 GB)', fontsize=16)
plt.show()
[Figure: Cross-case comparison at 64³, A100 (80 GB)]