Pressure Terms
Source Code: NSE_PressureTerms
NSE_PressureTerms.py
1# Copyright (C) 2025 Sukanta Basu
2#
3# This program is free software: you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation, either version 3 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16"""
17File: NSE_PressureTerms.py
18=============================
19
20:Author: Sukanta Basu
21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
22 code restructuring, and performance optimization
23:Date: 2025-4-3
24:Description: computes pressure terms by solving the Poisson equation using rfft2
25"""
26
27# ============================================================
28# Imports
29# ============================================================
30
31import jax
32import jax.numpy as jnp
33
34# Import configuration from namelist
35from ..config.ConfigLoader import *
36
37# Import derived variables
38from ..config.DerivedVars import *
39
40# Import FFT modules
41from ..operations.FFT import FFT
42
43# Import derivative functions
44from ..operations.Derivatives import Derivxy, Derivz_Generic_uvp
45
46from ..initialization.Preprocess import Constant, Wavenumber, ZeRo3DIni
47mx, my, nx_rfft, ny_rfft, mx_rfft, my_rfft = Constant()
48
49
50# ============================================================
51# Initialize static variables for pressure solver
52# ============================================================
53
@jax.jit
def PressureInit():
    """
    Build the static arrays used by the pressure Poisson solver.

    The horizontal directions are represented by rfft2 wavenumbers and the
    vertical direction by a three-point stencil with spacing dz, which
    yields one tridiagonal system of size nz+1 per horizontal wavenumber
    pair. All sizes (nx, ny_rfft, nz) and dz are read from module scope.

    Returns:
    --------
    kr2_pressure : ndarray, shape (nx, ny_rfft)
        Signed x-wavenumbers (0 .. nx//2 - 1, then nx//2 - nx .. -1),
        constant along the second axis
    kc2_pressure : ndarray, shape (nx, ny_rfft)
        Non-negative y-wavenumbers (0 .. ny_rfft - 1), constant along the
        first axis
    a_pressure : ndarray, shape (nz+1)
        Sub-diagonal: 1/dz**2 in the interior, 0 at index 0, -1 at index nz
    b_pressure : ndarray, shape (nx, ny_rfft, nz+1)
        Main diagonal: -(kr**2 + kc**2 + 2/dz**2) in the interior,
        -1 at level 0 and 1 at level nz
    c_pressure : ndarray, shape (nz+1)
        Super-diagonal: 1/dz**2 in the interior, 1 at index 0, 0 at index nz
    """
    # Signed x-wavenumbers. The split at nx // 2 matches the standard FFT
    # ordering when nx is even. Built in one vectorised step instead of a
    # Python loop of per-row updates, which jit would otherwise unroll.
    ix = jnp.arange(nx)
    kx = jnp.where(ix < nx // 2, ix, ix - nx)

    # rfft2 keeps only the non-negative y-wavenumbers
    ky = jnp.arange(ny_rfft)

    # Adding to a default-dtype zeros array keeps the floating dtype that
    # the element-wise fill of a jnp.zeros array would have produced.
    plane = jnp.zeros((nx, ny_rfft))
    kr2_pressure = plane + kx[:, jnp.newaxis]
    kc2_pressure = plane + ky[jnp.newaxis, :]

    # Off-diagonals of the vertical second-difference operator
    a_pressure = jnp.ones(nz + 1) / (dz ** 2)  # sub-diagonal
    c_pressure = jnp.ones(nz + 1) / (dz ** 2)  # super-diagonal

    # Boundary rows: row 0 reads (-p[0] + p[1]) and row nz reads
    # (-p[nz-1] + p[nz]) once combined with the main-diagonal values below.
    a_pressure = a_pressure.at[0].set(0)     # bottom boundary
    a_pressure = a_pressure.at[nz].set(-1)   # top boundary
    c_pressure = c_pressure.at[0].set(1)     # bottom boundary
    c_pressure = c_pressure.at[nz].set(0)    # top boundary

    # Main diagonal: identical at every interior level, so compute the 2D
    # plane once and replicate it along the vertical axis.
    bb = -(kr2_pressure ** 2 + kc2_pressure ** 2 + 2 / (dz ** 2))
    b_pressure = jnp.repeat(bb[:, :, jnp.newaxis], nz + 1, axis=2)
    b_pressure = b_pressure.at[:, :, 0].set(-1)
    b_pressure = b_pressure.at[:, :, nz].set(1)

    return (kr2_pressure, kc2_pressure,
            a_pressure, b_pressure, c_pressure)
102
103
104# ============================================================
105# Compute right-hand side for pressure equation
106# ============================================================
107
@jax.jit
def PressureRC(
        u, v, w,
        RHS_u, RHS_v, RHS_w,
        RHS_u_previous, RHS_v_previous, RHS_w_previous,
        divtz, kr2_pressure, kc2_pressure):
    """
    Assemble the right-hand side of the pressure Poisson equation in
    rfft2 (horizontal spectral) space.

    Parameters:
    -----------
    u, v, w : ndarray of shape (nx, ny, nz)
        Velocity components at the current time step
    RHS_u, RHS_v, RHS_w : ndarray of shape (nx, ny, nz)
        Momentum right-hand sides at the current time step
    RHS_u_previous, RHS_v_previous, RHS_w_previous : ndarray of shape (nx, ny, nz)
        Momentum right-hand sides at the previous time step
    divtz : ndarray of shape (nx, ny, nz)
        Divergence of the SGS stress tensor in z-direction; only its first
        and last vertical levels are used here
    kr2_pressure, kc2_pressure : ndarray of shape (nx, ny_rfft)
        x- and y-wavenumber arrays matching the rfft2 layout

    Returns:
    --------
    RC_real : ndarray of shape (nx, ny_rfft, nz+1)
        Real part of the Poisson right-hand side
    RC_imag : ndarray of shape (nx, ny_rfft, nz+1)
        Imaginary part of the Poisson right-hand side
    fRz_real : ndarray
        Real part of the transformed z-component term, returned for the
        zero-mode treatment in the solver
    """
    velocity_weight = 2 / (3 * dt_nondim)

    def to_spectral(rhs_now, rhs_before, velocity):
        # Combine current/previous tendencies with the velocity term, then
        # transform over the two horizontal axes.
        combined = rhs_now - (1 / 3) * rhs_before + velocity_weight * velocity
        return jnp.fft.rfft2(combined, axes=(0, 1))

    def assemble(bottom, interior, top):
        # Levels 0 and nz hold the scaled boundary terms; levels 1..nz-1
        # hold the interior values.
        field = jnp.zeros((nx, ny_rfft, nz + 1))
        field = field.at[:, :, 0].set(-dz * bottom)
        field = field.at[:, :, nz].set(-dz * top)
        return field.at[:, :, 1:nz].set(interior)

    fRx = to_spectral(RHS_u, RHS_u_previous, u)
    fRy = to_spectral(RHS_v, RHS_v_previous, v)
    fRz = to_spectral(RHS_w, RHS_w_previous, w)

    # Boundary terms come from the lowest and highest levels of divtz
    f_bottom = jnp.fft.rfft2(divtz[:, :, 0], axes=(0, 1))
    f_top = jnp.fft.rfft2(divtz[:, :, -1], axes=(0, 1))

    # Trailing axis lets the 2D wavenumbers broadcast over vertical levels
    kx = kr2_pressure[:, :, jnp.newaxis]
    ky = kc2_pressure[:, :, jnp.newaxis]

    # Horizontal terms use every level except the last, so they line up
    # with the nz-1 vertical differences of fRz.
    fRx_lower = fRx[:, :, :-1]
    fRy_lower = fRy[:, :, :-1]
    fRz_real = jnp.real(fRz)
    fRz_imag = jnp.imag(fRz)

    # Multiplying by i*k maps (re, im) -> (-k*im, k*re)
    interior_real = (-kx * jnp.imag(fRx_lower) - ky * jnp.imag(fRy_lower)
                     + jnp.diff(fRz_real, axis=2) / dz)
    interior_imag = (kx * jnp.real(fRx_lower) + ky * jnp.real(fRy_lower)
                     + jnp.diff(fRz_imag, axis=2) / dz)

    RC_real = assemble(jnp.real(f_bottom), interior_real, jnp.real(f_top))
    RC_imag = assemble(jnp.imag(f_bottom), interior_imag, jnp.imag(f_top))

    return RC_real, RC_imag, fRz_real
183
184
@jax.jit
def PressureMatrix(a_pressure, b_pressure_ij, c_pressure):
    """
    Assemble the dense (nz+1, nz+1) tridiagonal matrix for one horizontal
    wavenumber pair (i, j).

    Parameters:
    -----------
    a_pressure : ndarray, shape (nz+1)
        Sub-diagonal coefficients; entry 0 is not used
    b_pressure_ij : ndarray, shape (nz+1)
        Main-diagonal coefficients for the selected wavenumber pair
    c_pressure : ndarray, shape (nz+1)
        Super-diagonal coefficients; only the first nz entries are used

    Returns:
    --------
    L_matrix : ndarray, shape (nz+1, nz+1)
        Matrix with b on the diagonal, a[1:] directly below it and
        c[:nz] directly above it; all other entries are zero
    """
    size = nz + 1
    idx = jnp.arange(size)

    return (
        jnp.zeros((size, size))
        .at[idx, idx].set(b_pressure_ij)             # main diagonal
        .at[idx[1:], idx[:-1]].set(a_pressure[1:])   # below the diagonal
        .at[idx[:-1], idx[1:]].set(c_pressure[:nz])  # above the diagonal
    )
217
218
@jax.jit
def PressureSolve(RC_real, RC_imag, fRz_real, a_pressure, b_pressure, c_pressure):
    """
    Solve the pressure Poisson equation in rfft2 space, one vertical
    tridiagonal system per horizontal wavenumber pair, processed in
    batches of rows.

    Parameters:
    -----------
    RC_real : ndarray of shape (nx, ny_rfft, nz+1)
        Real part of the right-hand side
    RC_imag : ndarray of shape (nx, ny_rfft, nz+1)
        Imaginary part of the right-hand side
    fRz_real : ndarray
        Real part of fRz; only fRz_real[0, 0, 1:nz] is read, for the
        zero mode
    a_pressure, c_pressure : ndarray of shape (nz+1)
        Sub- and super-diagonal coefficients of the tridiagonal matrix
    b_pressure : ndarray of shape (nx, ny_rfft, nz+1)
        Main diagonal coefficients for all wavenumbers

    Returns:
    --------
    p : ndarray of shape (nx, ny, nz)
        Pressure field in physical space
    dpdx, dpdy, dpdz : ndarray
        Pressure gradient components as returned by Derivxy and
        Derivz_Generic_uvp
    """
    # Solve the system belonging to a single (i, j) wavenumber pair
    def solve_system(i, j, b_ij, rc_real, rc_imag):
        # Modes that are not solved here: the zero mode (0, 0), the whole
        # x-Nyquist row (i == nx_rfft - 1) and the whole y-Nyquist column
        # (j == ny_rfft - 1). They are returned as zeros; the zero mode is
        # filled in separately further down.
        is_zero_mode = (i == 0) & (j == 0)
        is_nyquist_x = (i == nx_rfft - 1)  # nx//2 if nx_rfft == nx//2 + 1 - TODO confirm
        is_nyquist_y = (j == ny_rfft - 1)  # ny//2 if ny_rfft == ny//2 + 1 - TODO confirm
        is_special = is_zero_mode | is_nyquist_x | is_nyquist_y

        def solve_case(_):
            L_matrix = PressureMatrix(a_pressure, b_ij, c_pressure)
            # Real and imaginary parts share the same matrix, so solve
            # them together as a two-column right-hand side; this
            # factorises the matrix once instead of twice.
            rhs = jnp.stack([rc_real, rc_imag], axis=1)
            solution = jnp.linalg.solve(L_matrix, rhs)
            # Drop level 0 so that nz levels remain
            return solution[1:, 0], solution[1:, 1]

        def skip_case(_):
            return jnp.zeros(nz), jnp.zeros(nz)

        return jax.lax.cond(is_special, skip_case, solve_case, None)

    # Number of i-rows handled per scan step
    batch_size = 8

    # Solve every (i, j) pair of a batch of rows via nested vmap
    def process_batch(carry, i_batch):
        fp_real_batch, fp_imag_batch = jax.vmap(
            jax.vmap(solve_system, in_axes=(None, 0, 0, 0, 0)),
            in_axes=(0, None, 0, 0, 0)
        )(i_batch, jnp.arange(ny_rfft), b_pressure[i_batch], RC_real[i_batch], RC_imag[i_batch])

        return carry, (fp_real_batch, fp_imag_batch)

    # Split the nx rows into full batches plus an optional remainder
    num_full_batches = nx // batch_size
    last_batch_size = nx % batch_size

    # Process full batches first
    i_batches = jnp.arange(num_full_batches * batch_size).reshape(-1, batch_size)
    _, (fp_real_batches, fp_imag_batches) = jax.lax.scan(process_batch, None, i_batches)

    # Merge the (batch, row-in-batch) axes back into a single row axis
    fp_real = fp_real_batches.reshape(num_full_batches * batch_size, ny_rfft, nz)
    fp_imag = fp_imag_batches.reshape(num_full_batches * batch_size, ny_rfft, nz)

    # Handle the last partial batch if nx is not a multiple of batch_size
    if last_batch_size > 0:
        last_batch = jnp.arange(num_full_batches * batch_size, nx)

        # Pad with the final row index so the batch has the usual size;
        # the duplicated rows are discarded below.
        padded_last_batch = jnp.pad(last_batch, (0, batch_size - last_batch_size),
                                    mode='constant', constant_values=last_batch[-1])

        _, (fp_real_last_padded, fp_imag_last_padded) = process_batch(None, padded_last_batch)

        # Keep only the valid rows
        fp_real_last = fp_real_last_padded[:last_batch_size]
        fp_imag_last = fp_imag_last_padded[:last_batch_size]

        fp_real = jnp.concatenate([fp_real, fp_real_last], axis=0)
        fp_imag = jnp.concatenate([fp_imag, fp_imag_last], axis=0)

    # Zero mode (i=0, j=0): start from RC_real[0, 0, 0] and accumulate
    # fRz_real * dz upwards over the remaining nz-1 levels.
    zero_mode_first = RC_real[0, 0, 0]
    zero_mode_rest = zero_mode_first + jnp.cumsum(fRz_real[0, 0, 1:nz] * dz)
    zero_mode = jnp.concatenate([jnp.array([zero_mode_first]), zero_mode_rest])

    fp_real = fp_real.at[0, 0].set(zero_mode)
    fp_imag = fp_imag.at[0, 0].set(jnp.zeros(nz))

    # Combine real and imaginary parts into complex coefficients
    fp = fp_real + 1j * fp_imag

    # Back to physical space over the two horizontal axes
    p = jnp.fft.irfft2(fp, axes=(0, 1), s=(nx, ny))

    # Horizontal pressure derivatives via the project's spectral helpers
    p_fft = FFT(p)

    kx2, ky2 = Wavenumber()

    dpdx = Derivxy(p_fft, kx2)  # x-derivative
    dpdy = Derivxy(p_fft, ky2)  # y-derivative

    # Vertical derivative via the project's finite-difference helper;
    # the second argument is whatever ZeRo3DIni() returns
    # (presumably a zero-initialised 3D array - TODO confirm).
    dum = ZeRo3DIni()
    dpdz = Derivz_Generic_uvp(p, dum)

    return p, dpdx, dpdy, dpdz