Source code for src.utilities.Utilities

# Copyright (C) 2025 Sukanta Basu
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

"""
File: Utilities.py
===========================

:Author: Sukanta Basu
:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
                code restructuring, and performance optimization
:Date: 2025-10-20
:Description: miscellaneous functions
"""


# ============================================================
# Imports
# ============================================================

import jax
import jax.numpy as jnp


# ============================================================
#  Compute planar averaged values of a 3D field
# ============================================================

[docs] @jax.jit def PlanarMean(F): """ Computes horizontal average of a 3D field at each vertical level. Parameters: ----------- F : jax.numpy.ndarray 3D array with shape (nx, ny, nz) Returns: -------- jax.numpy.ndarray 1D array of length nz containing planar-averaged values """ return jnp.mean(F, axis=(0, 1))
# ============================================================ # Compute averaging of a 3D array along the z-direction # ============================================================
[docs] @jax.jit def StagGridAvg(F): """ Computes averaging of a 3D array along the z-direction. Parameters: ----------- F : jax.numpy.ndarray 3D array with shape (nx, ny, nz) Returns: -------- jax.numpy.ndarray Averaged field with shape (nx, ny, nz-1) """ return 0.5 * (F[:, :, :-1] + F[:, :, 1:])
# ============================================================ # Compute moving average filtering of 2D fields # ============================================================
[docs] @jax.jit def Imfilter(x): """ Applies a 3x3 box filter to a 2D field with periodic boundaries. Parameters: ----------- x : jax.numpy.ndarray 2D array representing a horizontal slice Returns: -------- jax.numpy.ndarray Filtered 2D array with the same shape as input """ kernel = jnp.ones((3, 3)) / 9.0 # 3x3 mean filter (box filter) # Apply periodic padding (wrap around) x_padded = jnp.pad(x, ((1, 1), (1, 1)), mode='wrap') # Perform 2D convolution x_filtered = jax.lax.conv_general_dilated( x_padded[None, None, :, :], # Add batch & channel dims kernel[None, None, :, :], # Kernel shape: [Out_C, In_C, H, W] (1, 1), # Stride (1,1) ensures all pixels are processed 'VALID' # No extra padding applied beyond what we added ) return x_filtered[0, 0, :, :] # Remove batch & channel dims
# ============================================================ # Compute roots of a polynomial using Laguerre's method # ============================================================
[docs] @jax.jit def Roots(coeffs, init_guess=1.0 + 0j, tol=1e-6, max_iter=20): """ Find a root of a polynomial using Laguerre's method. Laguerre's method is a root-finding algorithm that works well for polynomials and converges cubically for simple roots. It works in the complex plane and can find complex roots. Parameters: ----------- coeffs : ndarray Polynomial coefficients in descending order of degree. For a 5th degree polynomial: [a5, a4, a3, a2, a1, a0] represents: a5*x^5 + a4*x^4 + a3*x^3 + a2*x^2 + a1*x + a0 init_guess : complex, optional Initial guess for the root. Default: 1.0+0j For real polynomials, can use real guess, but complex arithmetic is used internally to handle complex roots. tol : float, optional Convergence tolerance. Iteration stops when: |x_new - x_old| < tol * (1 + |x_old|) or |f(x)| < tol Default: 1e-6 max_iter : int, optional Maximum number of iterations. Default: 20 Returns: -------- complex Root of the polynomial if converged, or nan+0j if not converged within max_iter iterations. Notes: ------ - Uses complex arithmetic internally to handle complex roots - Includes numerical safeguards against division by zero - Uses relative tolerance for better handling of roots at different scales - The algorithm is JIT-compiled for performance """ # Convert to complex arithmetic to handle complex roots cdtype = jnp.complex128 if jax.config.jax_enable_x64 else jnp.complex64 coeffs = coeffs.astype(cdtype) x = jnp.asarray(init_guess, dtype=cdtype) # Polynomial degree n = coeffs.shape[0] - 1 # Helper functions for polynomial evaluation def polynomial(p, x): """Evaluate polynomial at x using Horner's method.""" return jnp.polyval(p, x) def derivative(p, x): """Evaluate first derivative at x.""" dp = jnp.polyder(p) return jnp.polyval(dp, x) def second_derivative(p, x): """Evaluate second derivative at x.""" d2p = jnp.polyder(jnp.polyder(p)) return jnp.polyval(d2p, x) # Loop continuation condition def cond_fn(state): x, iteration, converged = state return (iteration < max_iter) & (~converged) # Laguerre iteration step def body_fn(state): x, iteration, _ = state # Evaluate polynomial and derivatives f_x = polynomial(coeffs, x) df_x = derivative(coeffs, x) d2f_x = second_derivative(coeffs, x) # Numerical safeguard: small epsilon for complex arithmetic eps = 1e-16 + 0j # Protect against division by zero f_safe = jnp.where(jnp.abs(f_x) < eps, eps, f_x) # Laguerre's method formulas G = df_x / f_safe H = G ** 2 - d2f_x / f_safe # Compute discriminant: (n-1) * (n*H - G²) # This is the CORRECTED formula discriminant = (n - 1) * (n * H - G ** 2) sqrt_disc = jnp.sqrt(discriminant) # Choose denominator with larger absolute value for stability denom1 = G + sqrt_disc denom2 = G - sqrt_disc denom = jnp.where(jnp.abs(denom1) > jnp.abs(denom2), denom1, denom2) # Protect against division by zero denom = jnp.where(jnp.abs(denom) < eps, eps, denom) # Laguerre update step x_new = x - n / denom # Check convergence using both relative step size and function value # This handles both large and small roots effectively converged_step = jnp.abs(x_new - x) < tol * (1 + jnp.abs(x)) converged_value = jnp.abs(f_x) < tol new_converged = converged_step | converged_value return x_new, iteration + 1, new_converged # Initialize state and run iteration loop init_state = (x, jnp.array(0), jnp.array(False)) final_x, final_iter, converged = jax.lax.while_loop( cond_fn, body_fn, init_state ) # Return root if converged, otherwise return nan return jnp.where(converged, final_x, jnp.nan + 0j)
# ============================================================ # Measure memory usage # ============================================================
[docs] @jax.jit def LogMemory(): """ Prints memory usage statistics for all available JAX devices. Converts memory values to MB for readability. """ devices = jax.devices() for device in devices: print(f"\nDevice: {device}") stats = device.memory_stats() if stats is None: print("Memory stats not available for this device") continue # Convert to MB for readability bytes_in_use = stats.get('bytes_in_use', 0) / (1024 * 1024) peak_bytes = stats.get('peak_bytes_in_use', 0) / (1024 * 1024) allocated_bytes = stats.get('bytes_allocated', 0) / (1024 * 1024) print(f"Current memory usage: {bytes_in_use:.2f} MB") print(f"Peak memory usage: {peak_bytes:.2f} MB") print(f"Allocated memory: {allocated_bytes:.2f} MB") # Print all available stats for debugging print("\nAll available memory stats:") for key, value in stats.items(): print(f"{key}: {value / (1024 * 1024):.2f} MB")