Preprocessing of Variables

Source Code: Preprocess

Preprocess.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: Preprocess.py
 18==========================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2025-4-3
 24:Description: generates static variables which will be re-used
 25numerous times during a simulation
 26"""
 27
 28
 29# ============================================================
 30#  Imports
 31# ============================================================
 32
 33import jax.numpy as jnp
 34
 35# Import configuration from namelist
 36from ..config.ConfigLoader import *
 37
 38
 39# ============================================================
 40#  FFT-related constants
 41# ============================================================
 42
 43
 44def Constant():
 45    """ Note: I am using _loc in variable names to distinguish
 46    them from outer scope variable names"""
 47
 48    # Determine padded dimensions
 49    mx_loc = int(1.5 * nx)
 50    my_loc = int(1.5 * ny)
 51
 52    # Number of Fourier modes in x- and y-directions for input arrays
 53    nx_rfft_loc = nx // 2 + 1
 54    ny_rfft_loc = ny // 2 + 1
 55
 56    # Number of Fourier modes in x- and y-directions for padded arrays
 57    mx_rfft_loc = mx_loc // 2 + 1
 58    my_rfft_loc = my_loc // 2 + 1
 59
 60    return (mx_loc, my_loc,
 61            nx_rfft_loc, ny_rfft_loc,
 62            mx_rfft_loc, my_rfft_loc)
 63
 64
 65mx, my, nx_rfft, ny_rfft, mx_rfft, my_rfft = Constant()
 66
 67
 68# ============================================================
 69# Wavenumbers related to real FFT
 70# ============================================================
 71
 72
 73def Wavenumber():
 74    """
 75    Parameters:
 76    -----------
 77    None: Uses global parameters from Config
 78
 79    Returns:
 80    --------
 81    kx2 : ndarray, shape (nx, ny//2 + 1, nz)
 82        Wavenumber array for x-direction derivatives, broadcast to match FFT dimensions
 83    ky2 : ndarray, shape (nx, ny//2 + 1, nz)
 84        Wavenumber array for y-direction derivatives, broadcast to match FFT dimensions
 85
 86    Note: rfftfreq is used for y as we are using real FFT
 87    """
 88
 89    kx = jnp.fft.fftfreq(nx, 1 / nx)
 90    ky = jnp.fft.rfftfreq(ny, 1 / ny)
 91
 92    # Zeroing Nyquist frequencies to avoid instabilities
 93    kx = kx.at[nx // 2].set(0)
 94    ky = ky.at[ny // 2].set(0)
 95
 96    kx2 = kx[:, None, None]  # Reshape for broadcasting
 97    kx2 = jnp.broadcast_to(kx2, (nx, ny_rfft, nz))
 98
 99    ky2 = ky[None, :, None]  # Reshape for broadcasting
100    ky2 = jnp.broadcast_to(ky2, (nx, ny_rfft, nz))
101
102    return kx2, ky2
103
104
105# ============================================================
106# Array Initialization Functions
107# ============================================================
108
109
def ZeRo3DIni():
    """Return a real-valued 3-D zero array of shape (nx, ny, nz).

    float64 when ``use_double_precision`` is truthy, otherwise float32.
    """
    # NOTE(review): float64 only takes effect if JAX x64 mode is enabled.
    real_dtype = jnp.float64 if use_double_precision else jnp.float32
    return jnp.zeros((nx, ny, nz), dtype=real_dtype)
117
118
def ZeRo2DIni():
    """Return a real-valued 2-D zero array of shape (nx, ny).

    float64 when ``use_double_precision`` is truthy, otherwise float32.
    """
    real_dtype = jnp.float64 if use_double_precision else jnp.float32
    return jnp.zeros((nx, ny), dtype=real_dtype)
126
127
def ZeRo1DIni():
    """Return a real-valued 1-D zero array of length nz.

    float64 when ``use_double_precision`` is truthy, otherwise float32.
    """
    real_dtype = jnp.float64 if use_double_precision else jnp.float32
    return jnp.zeros((nz,), dtype=real_dtype)
135
136
def ZeRo3D_fftIni():
    """Return a complex zero array of shape (nx, ny_rfft, nz).

    Buffer for Fourier-space operations with rfft2: the y-extent is
    ``ny_rfft`` (ny // 2 + 1), the same layout as the wavenumber arrays
    from Wavenumber(). complex128 when ``use_double_precision`` is truthy,
    otherwise complex64.
    """
    # NOTE(review): complex128 only takes effect if JAX x64 mode is enabled.
    complex_dtype = jnp.complex128 if use_double_precision else jnp.complex64
    return jnp.zeros((nx, ny_rfft, nz), dtype=complex_dtype)
144
145
def ZeRo3D_padIni():
    """Return a real-valued zero array on the padded grid, shape (mx, my, nz).

    float64 when ``use_double_precision`` is truthy, otherwise float32.
    """
    real_dtype = jnp.float64 if use_double_precision else jnp.float32
    return jnp.zeros((mx, my, nz), dtype=real_dtype)
153
154
def ZeRo3D_pad_fftIni():
    """Return a complex zero array on the padded grid, shape (mx, my_rfft, nz).

    Buffer for Fourier-space operations with rfft2 on the padded grid; the
    y-extent is ``my_rfft`` (my // 2 + 1). complex128 when
    ``use_double_precision`` is truthy, otherwise complex64.
    """
    complex_dtype = jnp.complex128 if use_double_precision else jnp.complex64
    return jnp.zeros((mx, my_rfft, nz), dtype=complex_dtype)