Miscellaneous Terms

File: Utilities.py

Author:

Sukanta Basu

AI Assistance:

Claude Code (Anthropic) and Codex (OpenAI) are used for documentation, code restructuring, and performance optimization

Date:

2025-10-20

Description:

miscellaneous functions

Imfilter(x)[source]

Applies a 3x3 box filter to a 2D field with periodic boundaries.

Parameters:

xjax.numpy.ndarray

2D array representing a horizontal slice

Returns:

jax.numpy.ndarray

Filtered 2D array with the same shape as input

LogMemory()[source]

Prints memory usage statistics for all available JAX devices. Converts memory values to MB for readability.

PlanarMean(F)[source]

Computes horizontal average of a 3D field at each vertical level.

Parameters:

Fjax.numpy.ndarray

3D array with shape (nx, ny, nz)

Returns:

jax.numpy.ndarray

1D array of length nz containing planar-averaged values

Roots(coeffs, init_guess=1 + 0j, tol=1e-06, max_iter=20)[source]

Find a root of a polynomial using Laguerre’s method.

Laguerre’s method is a root-finding algorithm that works well for polynomials and converges cubically for simple roots. It works in the complex plane and can find complex roots.

Parameters:

coeffsndarray

Polynomial coefficients in descending order of degree. For a 5th degree polynomial: [a5, a4, a3, a2, a1, a0] represents: a5*x^5 + a4*x^4 + a3*x^3 + a2*x^2 + a1*x + a0

init_guesscomplex, optional

Initial guess for the root. Default: 1.0+0j For real polynomials, can use real guess, but complex arithmetic is used internally to handle complex roots.

tolfloat, optional

Convergence tolerance. Iteration stops when: |x_new - x_old| < tol * (1 + |x_old|) or |f(x)| < tol Default: 1e-6

max_iterint, optional

Maximum number of iterations. Default: 20

Returns:

complex

Root of the polynomial if converged, or nan+0j if not converged within max_iter iterations.

Notes:

  • Uses complex arithmetic internally to handle complex roots

  • Includes numerical safeguards against division by zero

  • Uses relative tolerance for better handling of roots at different scales

  • The algorithm is JIT-compiled for performance

StagGridAvg(F)[source]

Computes averaging of a 3D array along the z-direction.

Parameters:

Fjax.numpy.ndarray

3D array with shape (nx, ny, nz)

Returns:

jax.numpy.ndarray

Averaged field with shape (nx, ny, nz-1)

Source Code: Utilities

Utilities.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: Utilities.py
 18===========================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2025-10-20
 24:Description: miscellaneous functions
 25"""
 26
 27
 28# ============================================================
 29# Imports
 30# ============================================================
 31
 32import jax
 33import jax.numpy as jnp
 34
 35
 36# ============================================================
 37#  Compute planar averaged values of a 3D field
 38# ============================================================
 39
 40@jax.jit
 41def PlanarMean(F):
 42    """
 43    Computes horizontal average of a 3D field at each vertical level.
 44
 45    Parameters:
 46    -----------
 47    F : jax.numpy.ndarray
 48        3D array with shape (nx, ny, nz)
 49
 50    Returns:
 51    --------
 52    jax.numpy.ndarray
 53        1D array of length nz containing planar-averaged values
 54    """
 55
 56    return jnp.mean(F, axis=(0, 1))
 57
 58
 59# ============================================================
 60#  Compute averaging of a 3D array along the z-direction
 61# ============================================================
 62
 63@jax.jit
 64def StagGridAvg(F):
 65    """
 66    Computes averaging of a 3D array along the z-direction.
 67
 68    Parameters:
 69    -----------
 70    F : jax.numpy.ndarray
 71        3D array with shape (nx, ny, nz)
 72
 73    Returns:
 74    --------
 75    jax.numpy.ndarray
 76        Averaged field with shape (nx, ny, nz-1)
 77    """
 78
 79    return 0.5 * (F[:, :, :-1] + F[:, :, 1:])
 80
 81
 82# ============================================================
 83# Compute moving average filtering of 2D fields
 84# ============================================================
 85
 86@jax.jit
 87def Imfilter(x):
 88    """
 89    Applies a 3x3 box filter to a 2D field with periodic boundaries.
 90
 91    Parameters:
 92    -----------
 93    x : jax.numpy.ndarray
 94        2D array representing a horizontal slice
 95
 96    Returns:
 97    --------
 98    jax.numpy.ndarray
 99        Filtered 2D array with the same shape as input
100    """
101
102    kernel = jnp.ones((3, 3)) / 9.0  # 3x3 mean filter (box filter)
103
104    # Apply periodic padding (wrap around)
105    x_padded = jnp.pad(x, ((1, 1), (1, 1)), mode='wrap')
106
107    # Perform 2D convolution
108    x_filtered = jax.lax.conv_general_dilated(
109        x_padded[None, None, :, :],  # Add batch & channel dims
110        kernel[None, None, :, :],    # Kernel shape: [Out_C, In_C, H, W]
111        (1, 1),  # Stride (1,1) ensures all pixels are processed
112        'VALID'  # No extra padding applied beyond what we added
113    )
114
115    return x_filtered[0, 0, :, :]  # Remove batch & channel dims
116
117
118# ============================================================
119# Compute roots of a polynomial using Laguerre's method
120# ============================================================
121
122@jax.jit
123def Roots(coeffs, init_guess=1.0 + 0j, tol=1e-6, max_iter=20):
124    """
125    Find a root of a polynomial using Laguerre's method.
126
127    Laguerre's method is a root-finding algorithm that works well for
128    polynomials and converges cubically for simple roots. It works in
129    the complex plane and can find complex roots.
130
131    Parameters:
132    -----------
133    coeffs : ndarray
134        Polynomial coefficients in descending order of degree.
135        For a 5th degree polynomial: [a5, a4, a3, a2, a1, a0]
136        represents: a5*x^5 + a4*x^4 + a3*x^3 + a2*x^2 + a1*x + a0
137    init_guess : complex, optional
138        Initial guess for the root. Default: 1.0+0j
139        For real polynomials, can use real guess, but complex arithmetic
140        is used internally to handle complex roots.
141    tol : float, optional
142        Convergence tolerance. Iteration stops when:
143        |x_new - x_old| < tol * (1 + |x_old|) or |f(x)| < tol
144        Default: 1e-6
145    max_iter : int, optional
146        Maximum number of iterations. Default: 20
147
148    Returns:
149    --------
150    complex
151        Root of the polynomial if converged, or nan+0j if not converged
152        within max_iter iterations.
153
154    Notes:
155    ------
156    - Uses complex arithmetic internally to handle complex roots
157    - Includes numerical safeguards against division by zero
158    - Uses relative tolerance for better handling of roots at different scales
159    - The algorithm is JIT-compiled for performance
160
161    """
162
163    # Convert to complex arithmetic to handle complex roots
164    cdtype = jnp.complex128 if jax.config.jax_enable_x64 else jnp.complex64
165    coeffs = coeffs.astype(cdtype)
166    x = jnp.asarray(init_guess, dtype=cdtype)
167
168    # Polynomial degree
169    n = coeffs.shape[0] - 1
170
171    # Helper functions for polynomial evaluation
172    def polynomial(p, x):
173        """Evaluate polynomial at x using Horner's method."""
174        return jnp.polyval(p, x)
175
176    def derivative(p, x):
177        """Evaluate first derivative at x."""
178        dp = jnp.polyder(p)
179        return jnp.polyval(dp, x)
180
181    def second_derivative(p, x):
182        """Evaluate second derivative at x."""
183        d2p = jnp.polyder(jnp.polyder(p))
184        return jnp.polyval(d2p, x)
185
186    # Loop continuation condition
187    def cond_fn(state):
188        x, iteration, converged = state
189        return (iteration < max_iter) & (~converged)
190
191    # Laguerre iteration step
192    def body_fn(state):
193        x, iteration, _ = state
194
195        # Evaluate polynomial and derivatives
196        f_x = polynomial(coeffs, x)
197        df_x = derivative(coeffs, x)
198        d2f_x = second_derivative(coeffs, x)
199
200        # Numerical safeguard: small epsilon for complex arithmetic
201        eps = 1e-16 + 0j
202
203        # Protect against division by zero
204        f_safe = jnp.where(jnp.abs(f_x) < eps, eps, f_x)
205
206        # Laguerre's method formulas
207        G = df_x / f_safe
208        H = G ** 2 - d2f_x / f_safe
209
210        # Compute discriminant: (n-1) * (n*H - G²)
211        # This is the CORRECTED formula
212        discriminant = (n - 1) * (n * H - G ** 2)
213        sqrt_disc = jnp.sqrt(discriminant)
214
215        # Choose denominator with larger absolute value for stability
216        denom1 = G + sqrt_disc
217        denom2 = G - sqrt_disc
218        denom = jnp.where(jnp.abs(denom1) > jnp.abs(denom2), denom1, denom2)
219
220        # Protect against division by zero
221        denom = jnp.where(jnp.abs(denom) < eps, eps, denom)
222
223        # Laguerre update step
224        x_new = x - n / denom
225
226        # Check convergence using both relative step size and function value
227        # This handles both large and small roots effectively
228        converged_step = jnp.abs(x_new - x) < tol * (1 + jnp.abs(x))
229        converged_value = jnp.abs(f_x) < tol
230        new_converged = converged_step | converged_value
231
232        return x_new, iteration + 1, new_converged
233
234    # Initialize state and run iteration loop
235    init_state = (x, jnp.array(0), jnp.array(False))
236    final_x, final_iter, converged = jax.lax.while_loop(
237        cond_fn, body_fn, init_state
238    )
239
240    # Return root if converged, otherwise return nan
241    return jnp.where(converged, final_x, jnp.nan + 0j)
242
243
244# ============================================================
245#  Measure memory usage
246# ============================================================
247
248@jax.jit
249def LogMemory():
250    """
251    Prints memory usage statistics for all available JAX devices.
252    Converts memory values to MB for readability.
253    """
254
255    devices = jax.devices()
256    for device in devices:
257        print(f"\nDevice: {device}")
258        stats = device.memory_stats()
259
260        if stats is None:
261            print("Memory stats not available for this device")
262            continue
263
264        # Convert to MB for readability
265        bytes_in_use = stats.get('bytes_in_use', 0) / (1024 * 1024)
266        peak_bytes = stats.get('peak_bytes_in_use', 0) / (1024 * 1024)
267        allocated_bytes = stats.get('bytes_allocated', 0) / (1024 * 1024)
268
269        print(f"Current memory usage: {bytes_in_use:.2f} MB")
270        print(f"Peak memory usage: {peak_bytes:.2f} MB")
271        print(f"Allocated memory: {allocated_bytes:.2f} MB")
272
273        # Print all available stats for debugging
274        print("\nAll available memory stats:")
275        for key, value in stats.items():
276            print(f"{key}: {value / (1024 * 1024):.2f} MB")