Pressure Terms

Source Code: NSE_PressureTerms

NSE_PressureTerms.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: NSE_PressureTerms.py
 18=============================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2025-4-3
  24:Description: computes the pressure field and its gradients by solving the pressure Poisson equation, using rfft2 in the horizontal and a tridiagonal solve in the vertical
 25"""
 26
 27# ============================================================
 28#  Imports
 29# ============================================================
 30
 31import jax
 32import jax.numpy as jnp
 33
 34# Import configuration from namelist
 35from ..config.ConfigLoader import *
 36
 37# Import derived variables
 38from ..config.DerivedVars import *
 39
 40# Import FFT modules
 41from ..operations.FFT import FFT
 42
 43# Import derivative functions
 44from ..operations.Derivatives import Derivxy, Derivz_Generic_uvp
 45
 46from ..initialization.Preprocess import Constant, Wavenumber, ZeRo3DIni
 47mx, my, nx_rfft, ny_rfft, mx_rfft, my_rfft = Constant()
 48
 49
 50# ============================================================
 51#  Initialize static variables for pressure solver
 52# ============================================================
 53
 54@jax.jit
 55def PressureInit():
 56    """
 57    Returns:
 58    --------
 59    kr2_pressure : ndarray, shape (nx, ny_rfft)
 60        Wavenumber array for x-direction derivatives, for rfft2
 61    kc2_pressure : ndarray, shape (nx, ny_rfft)
 62        Wavenumber array for y-direction derivatives, for rfft2
 63    a_pressure : ndarray, shape (nz+1)
 64        Coefficient array for tridiagonal matrix (sub-diagonal)
 65    b_pressure : ndarray, shape (nx, ny_rfft, nz+1)
 66        Main diagonal coefficients for each grid point
 67    c_pressure : ndarray, shape (nz+1)
 68        Coefficient array for tridiagonal matrix (super-diagonal)
 69    """
 70    # Create 2D wavenumber arrays
 71    # For x-direction, use the full range
 72    kr2_pressure = jnp.zeros((nx, ny_rfft))
 73    for i in range(nx):
 74        if i < nx // 2:
 75            kr2_pressure = kr2_pressure.at[i, :].set(i)
 76        else:
 77            kr2_pressure = kr2_pressure.at[i, :].set(i - nx)
 78
 79    # For y-direction, use only positive frequencies
 80    kc2_pressure = jnp.zeros((nx, ny_rfft))
 81    for j in range(ny_rfft):
 82        kc2_pressure = kc2_pressure.at[:, j].set(j)
 83
 84    # Create tridiagonal matrix coefficients
 85    a_pressure = jnp.ones(nz + 1) / (dz ** 2)  # sub-diagonal
 86    c_pressure = jnp.ones(nz + 1) / (dz ** 2)  # super-diagonal
 87
 88    # Set boundary conditions for tridiagonal matrix
 89    a_pressure = a_pressure.at[0].set(0)  # bottom boundary
 90    a_pressure = a_pressure.at[nz].set(-1)  # top boundary
 91    c_pressure = c_pressure.at[0].set(1)  # bottom boundary
 92    c_pressure = c_pressure.at[nz].set(0)  # top boundary
 93
 94    # Calculate the main diagonal values for each grid point
 95    bb = -(kr2_pressure ** 2 + kc2_pressure ** 2 + 2 / (dz ** 2))
 96    b_pressure = jnp.repeat(bb[:, :, jnp.newaxis], nz + 1, axis=2)
 97    b_pressure = b_pressure.at[:, :, 0].set(-1)
 98    b_pressure = b_pressure.at[:, :, nz].set(1)
 99
100    return (kr2_pressure, kc2_pressure,
101            a_pressure, b_pressure, c_pressure)
102
103
104# ============================================================
105#  Compute right-hand side for pressure equation
106# ============================================================
107
@jax.jit
def PressureRC(
        u, v, w,
        RHS_u, RHS_v, RHS_w,
        RHS_u_previous, RHS_v_previous, RHS_w_previous,
        divtz, kr2_pressure, kc2_pressure):
    """
    Assemble the spectral right-hand side of the pressure Poisson equation.

    Each velocity component is combined with its current and previous
    momentum tendencies as ``RHS - (1/3) * RHS_previous + 2/(3*dt_nondim) * vel``.
    The result is transformed over the horizontal axes with ``rfft2``.
    Its divergence is then formed from two parts:
    - spectral x/y derivatives (multiplication by i*k);
    - a vertical first difference of the z component.

    Parameters:
    -----------
    u, v, w : ndarray of shape (nx, ny, nz)
        Velocity components at the current time step
    RHS_u, RHS_v, RHS_w : ndarray of shape (nx, ny, nz)
        Momentum right-hand sides at the current time step
    RHS_u_previous, RHS_v_previous, RHS_w_previous : ndarray of shape (nx, ny, nz)
        Momentum right-hand sides at the previous time step
    divtz : ndarray of shape (nx, ny, nz)
        Divergence of the SGS stress tensor in the z-direction; only its
        first and last vertical levels are used here
    kr2_pressure, kc2_pressure : ndarray of shape (nx, ny_rfft)
        x/y wavenumbers matching the rfft2 layout

    Returns:
    --------
    RC_real, RC_imag : ndarray of shape (nx, ny_rfft, nz+1)
        Real/imaginary parts of the right-hand side. The rows are:
        - row 0 and row nz: ``-dz`` times the rfft2 of ``divtz`` at its
          first and last levels;
        - row k (1..nz-1): the horizontal terms at level k-1 plus
          ``(fRz[k] - fRz[k-1]) / dz``.
    fRz_real : ndarray of shape (nx, ny_rfft, nz)
        Real part of the transformed z component, used by the solver for
        the mean (0, 0) mode
    """
    def horizontal_rfft(field):
        # Real FFT over the two horizontal axes; the vertical axis is kept
        return jnp.fft.rfft2(field, axes=(0, 1))

    prev_weight = 1 / 3
    vel_weight = 2 / (3 * dt_nondim)

    fRx = horizontal_rfft(RHS_u - prev_weight * RHS_u_previous + vel_weight * u)
    fRy = horizontal_rfft(RHS_v - prev_weight * RHS_v_previous + vel_weight * v)
    fRz = horizontal_rfft(RHS_w - prev_weight * RHS_w_previous + vel_weight * w)

    # Broadcast wavenumbers along the vertical axis
    kx = kr2_pressure[:, :, jnp.newaxis]
    ky = kc2_pressure[:, :, jnp.newaxis]

    # Horizontal terms are taken at levels 0..nz-2 for interior rows 1..nz-1
    fRx_below = fRx[:, :, :-1]
    fRy_below = fRy[:, :, :-1]

    # Multiplication by i*k: Re(i*k*f) = -k*Im(f), Im(i*k*f) = k*Re(f)
    interior_real = (-kx * fRx_below.imag - ky * fRy_below.imag
                     + jnp.diff(fRz.real, axis=2) / dz)
    interior_imag = (kx * fRx_below.real + ky * fRy_below.real
                     + jnp.diff(fRz.imag, axis=2) / dz)

    # Boundary forcing from the SGS stress divergence at the end levels
    f_bottom = horizontal_rfft(divtz[:, :, 0])
    f_top = horizontal_rfft(divtz[:, :, -1])

    def stack_levels(bottom, interior, top):
        # Place boundary rows 0 and nz around the interior rows 1..nz-1
        levels = jnp.zeros((nx, ny_rfft, nz + 1))
        levels = levels.at[:, :, 0].set(bottom)
        levels = levels.at[:, :, nz].set(top)
        return levels.at[:, :, 1:nz].set(interior)

    RC_real = stack_levels(-dz * f_bottom.real, interior_real, -dz * f_top.real)
    RC_imag = stack_levels(-dz * f_bottom.imag, interior_imag, -dz * f_top.imag)

    return RC_real, RC_imag, fRz.real
183
184
@jax.jit
def PressureMatrix(a_pressure, b_pressure_ij, c_pressure):
    """
    Assemble the dense tridiagonal matrix for one horizontal wavenumber pair.

    The matrix size n is taken from ``b_pressure_ij``; in the solver this is
    nz+1. Deriving it from the input means the function no longer depends on
    the module-level ``nz`` and works for any consistent input length.

    Parameters:
    -----------
    a_pressure : ndarray, shape (>= n)
        Lower-diagonal coefficients; ``a_pressure[k]`` goes to row k,
        column k-1. ``a_pressure[0]`` lies outside the matrix and is ignored.
    b_pressure_ij : ndarray, shape (n)
        Main-diagonal coefficients for a specific wavenumber pair (i, j)
    c_pressure : ndarray, shape (>= n)
        Upper-diagonal coefficients; ``c_pressure[k]`` goes to row k,
        column k+1. Entries from index n-1 onward are ignored.

    Returns:
    --------
    L_matrix : ndarray, shape (n, n)
        Tridiagonal matrix with ``L[k, k] = b[k]``, ``L[k, k-1] = a[k]``
        and ``L[k, k+1] = c[k]``
    """
    n = b_pressure_ij.shape[0]
    return (jnp.diag(b_pressure_ij)
            + jnp.diag(a_pressure[1:n], k=-1)
            + jnp.diag(c_pressure[:n - 1], k=1))
217
218
@jax.jit
def PressureSolve(RC_real, RC_imag, fRz_real, a_pressure, b_pressure, c_pressure):
    """
    Solve the pressure Poisson equation mode by mode and return p and grad(p).

    For every horizontal rfft2 mode (i, j), the vertical problem is a
    tridiagonal system. It has sub-diagonal ``a_pressure[1:]``, diagonal
    ``b_pressure[i, j]`` and super-diagonal ``c_pressure[:-1]``, i.e. the
    matrix ``PressureMatrix`` assembles. All modes are solved in one batched
    ``jax.lax.linalg.tridiagonal_solve`` call, with the real and imaginary
    right-hand sides as two columns. This costs O(nz) per mode, instead of
    the O(nz**3) dense factorisation a general linear solve would need.

    Modes excluded from the solve:
    - (0, 0): the mean mode. Reconstructed separately by starting from
      ``RC_real[0, 0, 0]`` and cumulatively adding ``fRz_real[0, 0, 1:nz] * dz``.
    - Row ``i == nx_rfft - 1`` and column ``j == ny_rfft - 1``: set to zero.
      These are intended as the x/y Nyquist modes (nx//2, ny//2 for even
      sizes) — this depends on the values returned by ``Constant()``.

    Parameters:
    -----------
    RC_real, RC_imag : ndarray of shape (nx, ny_rfft, nz+1)
        Real/imaginary parts of the right-hand side
    fRz_real : ndarray of shape (nx, ny_rfft, nz)
        Real part of the transformed z forcing, used for the mean mode
    a_pressure, c_pressure : ndarray of shape (nz+1)
        Sub-/super-diagonal coefficients. ``a_pressure[0]`` and
        ``c_pressure[nz]`` lie outside the matrix and are ignored.
    b_pressure : ndarray of shape (nx, ny_rfft, nz+1)
        Main-diagonal coefficients for all wavenumbers

    Returns:
    --------
    p : ndarray of shape (nx, ny, nz)
        Pressure field (row 0 of each vertical solution is dropped)
    dpdx, dpdy, dpdz : ndarray of shape (nx, ny, nz)
        Pressure gradient components
    """
    dtype = jnp.result_type(RC_real, RC_imag, a_pressure, b_pressure, c_pressure)
    shape = b_pressure.shape  # (nx, ny_rfft, nz + 1)

    # --- Modes excluded from the tridiagonal solve ------------------------
    i_mode = jnp.arange(nx)[:, jnp.newaxis]
    j_mode = jnp.arange(ny_rfft)[jnp.newaxis, :]
    skip = (((i_mode == 0) & (j_mode == 0))
            | (i_mode == nx_rfft - 1)
            | (j_mode == ny_rfft - 1))          # shape (nx, ny_rfft)
    skip_3d = skip[:, :, jnp.newaxis]

    # --- Batched tridiagonal solve -----------------------------------------
    # tridiagonal_solve expects dl[..., 0] == 0 and du[..., -1] == 0; these
    # entries are outside the matrix (PressureMatrix ignores them as well).
    dl = jnp.broadcast_to(a_pressure.at[0].set(0), shape).astype(dtype)
    du = jnp.broadcast_to(c_pressure.at[-1].set(0), shape).astype(dtype)
    d = b_pressure.astype(dtype)

    # Replace skipped modes by identity rows. With PressureInit's
    # coefficients, the (0, 0) system has constants in its null space, so it
    # must never reach the solver.
    dl = jnp.where(skip_3d, 0, dl)
    du = jnp.where(skip_3d, 0, du)
    d = jnp.where(skip_3d, 1, d)

    # Real and imaginary parts share the matrix: solve them as two columns
    rhs = jnp.stack([RC_real, RC_imag], axis=-1).astype(dtype)
    sol = jax.lax.linalg.tridiagonal_solve(dl, d, du, rhs)
    sol = jnp.where(skip_3d[..., jnp.newaxis], 0, sol)

    # Drop row 0 (the bottom boundary row) to keep nz levels per mode
    fp_real = sol[:, :, 1:, 0]
    fp_imag = sol[:, :, 1:, 1]

    # --- Mean (0, 0) mode ---------------------------------------------------
    zero_mode_first = RC_real[0, 0, 0]
    zero_mode_rest = zero_mode_first + jnp.cumsum(fRz_real[0, 0, 1:nz] * dz)
    zero_mode = jnp.concatenate([zero_mode_first[jnp.newaxis], zero_mode_rest])

    fp_real = fp_real.at[0, 0].set(zero_mode)
    fp_imag = fp_imag.at[0, 0].set(jnp.zeros(nz))

    # --- Back to physical space --------------------------------------------
    fp = fp_real + 1j * fp_imag
    p = jnp.fft.irfft2(fp, axes=(0, 1), s=(nx, ny))

    # Horizontal derivatives spectrally, vertical derivative by finite
    # differences
    p_fft = FFT(p)
    kx2, ky2 = Wavenumber()
    dpdx = Derivxy(p_fft, kx2)
    dpdy = Derivxy(p_fft, ky2)

    dum = ZeRo3DIni()
    dpdz = Derivz_Generic_uvp(p, dum)

    return p, dpdx, dpdy, dpdz
335    return p, dpdx, dpdy, dpdz