Preprocessing of Variables
Source Code: Preprocess
Preprocess.py
1# Copyright (C) 2025 Sukanta Basu
2#
3# This program is free software: you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation, either version 3 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16"""
17File: Preprocess.py
18==========================
19
20:Author: Sukanta Basu
21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
22 code restructuring, and performance optimization
23:Date: 2025-4-3
24:Description: generates static variables which will be re-used
25numerous times during a simulation
26"""
27
28
29# ============================================================
30# Imports
31# ============================================================
32
33import jax.numpy as jnp
34
35# Import configuration from namelist
36from ..config.ConfigLoader import *
37
38
39# ============================================================
40# FFT-related constants
41# ============================================================
42
43
def Constant(nx_loc=None, ny_loc=None):
    """Compute padded grid sizes and real-FFT mode counts.

    Note: I am using _loc in variable names to distinguish them from
    outer-scope (configuration / module-level) variable names.

    Parameters
    ----------
    nx_loc, ny_loc : int, optional
        Number of grid points in the x- and y-directions. When omitted
        (the default) the configuration values ``nx`` and ``ny`` are used,
        so ``Constant()`` behaves exactly as it always has.

    Returns
    -------
    tuple of int
        ``(mx, my, nx_rfft, ny_rfft, mx_rfft, my_rfft)``:

        * ``mx``, ``my`` -- padded dimensions, ``int(1.5 * n)`` (presumably
          used for 3/2-rule dealiasing elsewhere in the solver).
        * ``nx_rfft``, ``ny_rfft`` -- ``n // 2 + 1``, the number of
          non-negative-frequency modes of a real FFT of the input length.
        * ``mx_rfft``, ``my_rfft`` -- the same count for the padded lengths.
    """
    # Fall back to the configured grid when no explicit size is given.
    if nx_loc is None:
        nx_loc = nx
    if ny_loc is None:
        ny_loc = ny

    # Determine padded dimensions (1.5x the physical grid, truncated)
    mx_loc = int(1.5 * nx_loc)
    my_loc = int(1.5 * ny_loc)

    # Number of Fourier modes in x- and y-directions for input arrays
    nx_rfft_loc = nx_loc // 2 + 1
    ny_rfft_loc = ny_loc // 2 + 1

    # Number of Fourier modes in x- and y-directions for padded arrays
    mx_rfft_loc = mx_loc // 2 + 1
    my_rfft_loc = my_loc // 2 + 1

    return (mx_loc, my_loc,
            nx_rfft_loc, ny_rfft_loc,
            mx_rfft_loc, my_rfft_loc)
63
64
# Module-level grid constants, evaluated once at import time and reused by
# the array initializers below: padded sizes (mx, my) and real-FFT mode
# counts for the physical (nx_rfft, ny_rfft) and padded (mx_rfft, my_rfft) grids.
(mx, my,
 nx_rfft, ny_rfft,
 mx_rfft, my_rfft) = Constant()
66
67
68# ============================================================
69# Wavenumbers related to real FFT
70# ============================================================
71
72
def Wavenumber(nx_loc=None, ny_loc=None, nz_loc=None):
    """
    Build broadcast wavenumber arrays for spectral x/y derivatives.

    Parameters:
    -----------
    nx_loc, ny_loc, nz_loc : int, optional
        Grid sizes in x, y and z. Each one defaults to the configuration
        value (nx, ny, nz) when omitted, so Wavenumber() keeps using the
        global parameters from Config.

    Returns:
    --------
    kx2 : ndarray, shape (nx, ny//2 + 1, nz)
        Wavenumber array for x-direction derivatives (jnp.fft.fftfreq with
        d = 1/nx, i.e. integer-valued), broadcast to match FFT dimensions
    ky2 : ndarray, shape (nx, ny//2 + 1, nz)
        Wavenumber array for y-direction derivatives (jnp.fft.rfftfreq with
        d = 1/ny, i.e. integer-valued), broadcast to match FFT dimensions

    Notes:
    ------
    rfftfreq is used for y as we are using real FFT.

    The Nyquist wavenumber is zeroed to avoid instabilities. A Nyquist mode
    exists only for even n, at index n//2. For odd n, index n//2 holds the
    genuine highest mode (n-1)/2, so nothing is zeroed: zeroing it would
    corrupt that mode's derivative and break the +/-k symmetry in x.
    """
    if nx_loc is None:
        nx_loc = nx
    if ny_loc is None:
        ny_loc = ny
    if nz_loc is None:
        nz_loc = nz

    kx = jnp.fft.fftfreq(nx_loc, 1 / nx_loc)
    ky = jnp.fft.rfftfreq(ny_loc, 1 / ny_loc)

    # Zero the Nyquist frequencies (only present for even lengths)
    if nx_loc % 2 == 0:
        kx = kx.at[nx_loc // 2].set(0)
    if ny_loc % 2 == 0:
        ky = ky.at[ny_loc // 2].set(0)

    # ky has ny//2 + 1 entries, matching the rfft2 output along y
    out_shape = (nx_loc, ky.shape[0], nz_loc)

    kx2 = jnp.broadcast_to(kx[:, None, None], out_shape)
    ky2 = jnp.broadcast_to(ky[None, :, None], out_shape)

    return kx2, ky2
103
104
105# ============================================================
106# Array Initialization Functions
107# ============================================================
108
109
def ZeRo3DIni():
    """Return a zero-filled real array of shape (nx, ny, nz).

    The dtype follows the ``use_double_precision`` setting: float64 when it
    is truthy, float32 otherwise.
    """
    real_dtype = jnp.float64 if use_double_precision else jnp.float32
    return jnp.zeros((nx, ny, nz), dtype=real_dtype)
117
118
def ZeRo2DIni():
    """Return a zero-filled real 2-D array of shape (nx, ny).

    Uses float64 when ``use_double_precision`` is truthy, float32 otherwise.
    """
    plane_shape = (nx, ny)
    if not use_double_precision:
        return jnp.zeros(plane_shape, dtype=jnp.float32)
    return jnp.zeros(plane_shape, dtype=jnp.float64)
126
127
def ZeRo1DIni():
    """Return a zero-filled real 1-D array of length nz.

    Uses float64 when ``use_double_precision`` is truthy, float32 otherwise.
    """
    column_dtype = jnp.float64 if use_double_precision else jnp.float32
    # Explicit 1-tuple shape: the original's ``(nz)`` is just ``nz``.
    return jnp.zeros((nz,), dtype=column_dtype)
135
136
def ZeRo3D_fftIni():
    """Return a zero-filled complex array of shape (nx, ny_rfft, nz).

    Sized for Fourier-space work with rfft2 (y is the halved axis).
    Uses complex128 when ``use_double_precision`` is truthy, complex64
    otherwise.
    """
    spectral_dtype = jnp.complex128 if use_double_precision else jnp.complex64
    spectral_shape = (nx, ny_rfft, nz)
    return jnp.zeros(spectral_shape, dtype=spectral_dtype)
144
145
def ZeRo3D_padIni():
    """Return a zero-filled real array on the padded grid, shape (mx, my, nz).

    Uses float64 when ``use_double_precision`` is truthy, float32 otherwise.
    """
    padded_dtype = jnp.float64 if use_double_precision else jnp.float32
    return jnp.zeros((mx, my, nz), dtype=padded_dtype)
153
154
def ZeRo3D_pad_fftIni():
    """Return a zero-filled complex array on the padded spectral grid.

    Shape is (mx, my_rfft, nz), sized for rfft2 of padded fields. Uses
    complex128 when ``use_double_precision`` is truthy, complex64 otherwise.
    """
    if use_double_precision:
        padded_spectral_dtype = jnp.complex128
    else:
        padded_spectral_dtype = jnp.complex64
    return jnp.zeros((mx, my_rfft, nz), dtype=padded_spectral_dtype)