Miscellaneous Terms
File: Utilities.py
- Author:
Sukanta Basu
- AI Assistance:
Claude Code (Anthropic) and Codex (OpenAI) are used for documentation, code restructuring, and performance optimization
- Date:
2025-10-20
- Description:
miscellaneous functions
- Imfilter(x)[source]
Applies a 3x3 box filter to a 2D field with periodic boundaries.
Parameters:
- x : jax.numpy.ndarray
2D array representing a horizontal slice
Returns:
- jax.numpy.ndarray
Filtered 2D array with the same shape as input
- LogMemory()[source]
Prints memory usage statistics for all available JAX devices. Converts memory values to MB for readability.
- PlanarMean(F)[source]
Computes horizontal average of a 3D field at each vertical level.
Parameters:
- F : jax.numpy.ndarray
3D array with shape (nx, ny, nz)
Returns:
- jax.numpy.ndarray
1D array of length nz containing planar-averaged values
- Roots(coeffs, init_guess=1 + 0j, tol=1e-06, max_iter=20)[source]
Find a root of a polynomial using Laguerre’s method.
Laguerre’s method is a root-finding algorithm that works well for polynomials and converges cubically for simple roots. It works in the complex plane and can find complex roots.
Parameters:
- coeffs : ndarray
Polynomial coefficients in descending order of degree. For a 5th degree polynomial: [a5, a4, a3, a2, a1, a0] represents: a5*x^5 + a4*x^4 + a3*x^3 + a2*x^2 + a1*x + a0
- init_guess : complex, optional
Initial guess for the root. Default: 1.0+0j For real polynomials, can use real guess, but complex arithmetic is used internally to handle complex roots.
- tol : float, optional
Convergence tolerance. Iteration stops when: |x_new - x_old| < tol * (1 + |x_old|) or |f(x)| < tol Default: 1e-6
- max_iter : int, optional
Maximum number of iterations. Default: 20
Returns:
- complex
Root of the polynomial if converged, or nan+0j if not converged within max_iter iterations.
Notes:
Uses complex arithmetic internally to handle complex roots
Includes numerical safeguards against division by zero
Uses relative tolerance for better handling of roots at different scales
The algorithm is JIT-compiled for performance
- StagGridAvg(F)[source]
Computes averaging of a 3D array along the z-direction.
Parameters:
- F : jax.numpy.ndarray
3D array with shape (nx, ny, nz)
Returns:
- jax.numpy.ndarray
Averaged field with shape (nx, ny, nz-1)
Source Code: Utilities
1# Copyright (C) 2025 Sukanta Basu
2#
3# This program is free software: you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation, either version 3 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16"""
17File: Utilities.py
18===========================
19
20:Author: Sukanta Basu
21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
22 code restructuring, and performance optimization
23:Date: 2025-10-20
24:Description: miscellaneous functions
25"""
26
27
28# ============================================================
29# Imports
30# ============================================================
31
32import jax
33import jax.numpy as jnp
34
35
36# ============================================================
37# Compute planar averaged values of a 3D field
38# ============================================================
39
@jax.jit
def PlanarMean(F):
    """
    Average a 3D field over its two horizontal axes, level by level.

    Parameters:
    -----------
    F : jax.numpy.ndarray
        3D array with shape (nx, ny, nz)

    Returns:
    --------
    jax.numpy.ndarray
        1D array of length nz whose k-th entry is the mean of F[:, :, k]
    """

    # Axes 0 and 1 are the horizontal (nx, ny) directions
    horizontal_axes = (0, 1)
    return F.mean(axis=horizontal_axes)
57
58
59# ============================================================
60# Compute averaging of a 3D array along the z-direction
61# ============================================================
62
@jax.jit
def StagGridAvg(F):
    """
    Average each pair of vertically adjacent levels of a 3D array.

    Output level k is the arithmetic mean of input levels k and k+1.

    Parameters:
    -----------
    F : jax.numpy.ndarray
        3D array with shape (nx, ny, nz)

    Returns:
    --------
    jax.numpy.ndarray
        Averaged field with shape (nx, ny, nz-1)
    """

    lower_levels = F[:, :, :-1]  # levels 0 .. nz-2
    upper_levels = F[:, :, 1:]   # levels 1 .. nz-1
    return 0.5 * (lower_levels + upper_levels)
80
81
82# ============================================================
83# Compute moving average filtering of 2D fields
84# ============================================================
85
@jax.jit
def Imfilter(x):
    """
    Applies a 3x3 box filter to a 2D field with periodic boundaries.

    Each output value is the mean of the corresponding input value and its
    eight neighbours. Indices wrap around the edges of the first two axes,
    which gives periodic boundary conditions.

    Parameters:
    -----------
    x : jax.numpy.ndarray
        2D array representing a horizontal slice. Integer inputs are
        accepted; the result is floating point. If x has more than two
        axes, every slice along the trailing axes is filtered independently
        over axes 0 and 1.

    Returns:
    --------
    jax.numpy.ndarray
        Filtered array with the same shape as input
    """

    # Summing the nine periodically shifted copies is equivalent to a
    # wrap-padded 3x3 box convolution. Unlike lax.conv_general_dilated,
    # it does not require the input and kernel dtypes to match, so integer
    # inputs and float32 inputs under x64 mode no longer raise TypeError.
    shifts = (-1, 0, 1)
    window_sum = sum(
        jnp.roll(x, shift=(di, dj), axis=(0, 1))
        for di in shifts
        for dj in shifts
    )

    # Divide by the number of cells in the 3x3 window
    return window_sum / 9.0
116
117
118# ============================================================
119# Compute roots of a polynomial using Laguerre's method
120# ============================================================
121
@jax.jit
def Roots(coeffs, init_guess=1.0 + 0j, tol=1e-6, max_iter=20):
    """
    Find a root of a polynomial using Laguerre's method.

    Laguerre's method is a root-finding algorithm that works well for
    polynomials and converges cubically for simple roots. It works in
    the complex plane and can find complex roots.

    Parameters:
    -----------
    coeffs : ndarray
        Polynomial coefficients in descending order of degree.
        For a 5th degree polynomial: [a5, a4, a3, a2, a1, a0]
        represents: a5*x^5 + a4*x^4 + a3*x^3 + a2*x^2 + a1*x + a0
    init_guess : complex, optional
        Initial guess for the root. Default: 1.0+0j
        For real polynomials, can use real guess, but complex arithmetic
        is used internally to handle complex roots.
    tol : float, optional
        Convergence tolerance. Iteration stops when:
        |x_new - x_old| < tol * (1 + |x_old|) or |f(x)| < tol
        Default: 1e-6
    max_iter : int, optional
        Maximum number of iterations. Default: 20

    Returns:
    --------
    complex
        Root of the polynomial if converged, or nan+0j if not converged
        within max_iter iterations or if the polynomial has degree < 1
        (a single coefficient).

    Notes:
    ------
    - Uses complex arithmetic internally to handle complex roots
    - Includes numerical safeguards against division by zero
    - Uses relative tolerance for better handling of roots at different scales
    - Once |f(x)| < tol the current iterate is returned as-is, without
      taking a further (numerically fragile) Laguerre step
    - Derivative coefficients are computed once, outside the iteration loop
    - The algorithm is JIT-compiled for performance
    """

    # Convert to complex arithmetic to handle complex roots
    cdtype = jnp.complex128 if jax.config.jax_enable_x64 else jnp.complex64
    coeffs = jnp.asarray(coeffs, dtype=cdtype)
    x0 = jnp.asarray(init_guess, dtype=cdtype)
    nan_root = jnp.asarray(jnp.nan + 0j, dtype=cdtype)

    # Polynomial degree (static under jit: it comes from the array shape)
    n = coeffs.shape[0] - 1

    # A constant polynomial has no isolated root. Without this guard the
    # update below is x - 0/denom, which "converges" on the first step and
    # would report init_guess as a root.
    if n < 1:
        return nan_root

    # Derivative coefficients do not depend on x: compute them once
    dcoeffs = jnp.polyder(coeffs)
    d2coeffs = jnp.polyder(dcoeffs)

    # Numerical safeguard: small magnitude used to protect divisions
    eps = 1e-16

    # Loop continuation condition
    def cond_fn(state):
        _, iteration, converged = state
        return (iteration < max_iter) & (~converged)

    # Laguerre iteration step
    def body_fn(state):
        x, iteration, _ = state

        # Evaluate polynomial and its first two derivatives (Horner)
        f_x = jnp.polyval(coeffs, x)
        df_x = jnp.polyval(dcoeffs, x)
        d2f_x = jnp.polyval(d2coeffs, x)

        # Protect against division by zero
        f_safe = jnp.where(jnp.abs(f_x) < eps, eps, f_x)

        # Laguerre's method: G = p'/p, H = G^2 - p''/p
        G = df_x / f_safe
        H = G ** 2 - d2f_x / f_safe

        # Discriminant: (n-1) * (n*H - G^2)
        sqrt_disc = jnp.sqrt((n - 1) * (n * H - G ** 2))

        # Choose denominator with larger absolute value for stability
        denom1 = G + sqrt_disc
        denom2 = G - sqrt_disc
        denom = jnp.where(jnp.abs(denom1) > jnp.abs(denom2), denom1, denom2)

        # Protect against division by zero
        denom = jnp.where(jnp.abs(denom) < eps, eps, denom)

        # If |f(x)| is already within tolerance, accept x itself. Stepping
        # from such a point divides by the eps-substituted f, and G**2 can
        # overflow (inf - inf -> nan), which would turn a found root into nan.
        converged_value = jnp.abs(f_x) < tol
        x_new = jnp.where(converged_value, x, x - n / denom)

        # Relative step-size test handles both large and small roots
        converged_step = jnp.abs(x_new - x) < tol * (1 + jnp.abs(x))

        return x_new, iteration + 1, converged_step | converged_value

    # Initialize state and run iteration loop
    init_state = (x0, jnp.array(0), jnp.array(False))
    final_x, _, converged = jax.lax.while_loop(cond_fn, body_fn, init_state)

    # Return root if converged, otherwise return nan
    return jnp.where(converged, final_x, nan_root)
242
243
244# ============================================================
245# Measure memory usage
246# ============================================================
247
def LogMemory():
    """
    Prints memory usage statistics for all available JAX devices.

    For each device from jax.devices(), prints the current, peak and
    allocated byte counts converted to MB, followed by every entry returned
    by device.memory_stats(). Devices that return no statistics (None) get a
    short notice instead.

    Notes:
    ------
    This function is intentionally not wrapped in jax.jit. A jitted function
    runs its Python body, including these print and memory_stats() calls,
    only while it is being traced. A jitted version therefore prints on the
    first call only and does nothing on later calls.
    """

    bytes_per_mb = 1024 * 1024

    for device in jax.devices():
        print(f"\nDevice: {device}")
        stats = device.memory_stats()

        if stats is None:
            print("Memory stats not available for this device")
            continue

        # Convert to MB for readability
        bytes_in_use = stats.get('bytes_in_use', 0) / bytes_per_mb
        peak_bytes = stats.get('peak_bytes_in_use', 0) / bytes_per_mb
        allocated_bytes = stats.get('bytes_allocated', 0) / bytes_per_mb

        print(f"Current memory usage: {bytes_in_use:.2f} MB")
        print(f"Peak memory usage: {peak_bytes:.2f} MB")
        print(f"Allocated memory: {allocated_bytes:.2f} MB")

        # Print all available stats for debugging. Only entries whose key
        # names a byte quantity are converted to MB. Other values (e.g. an
        # allocation count, if the backend reports one) are printed raw
        # rather than mislabeled. The key-name match is a heuristic.
        print("\nAll available memory stats:")
        for key, value in stats.items():
            is_byte_quantity = "bytes" in key or "size" in key
            if is_byte_quantity and isinstance(value, (int, float)):
                print(f"{key}: {value / bytes_per_mb:.2f} MB")
            else:
                print(f"{key}: {value}")