Import All Modules
Source Code: Imports
Imports.py
1# Copyright (C) 2025 Sukanta Basu
2#
3# This program is free software: you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation, either version 3 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16"""
17File: Imports.py
18=======================
19
20:Author: Sukanta Basu
21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
22 code restructuring, and performance optimization
23:Date: 2025-4-3
24:Description: imports all the modules for JAX-ALFA
25"""
26
27import os
28
29# Import configurations (namelists)
30from . import ConfigLoader as Config
31
# Choose the compute backend from the namelist. These environment variables
# are written before the `import jax` statement below so that they are already
# in place when JAX is loaded.
if Config.optGPU != 0:
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        # Cluster run: the scheduler (e.g. SLURM) has already chosen the
        # visible devices, so the existing value is reported and left alone.
        print(f"JAX environment configured for GPU "
              f"(CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']})")
    else:
        # Local run: nothing was preset, so expose the GPU named in Config.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(Config.GPU_ID)
        print(f"JAX environment configured for GPU {Config.GPU_ID}")
else:
    # CPU-only run: hide every CUDA device and request the CPU platform.
    os.environ.update({
        "CUDA_VISIBLE_DEVICES": "",
        "JAX_PLATFORMS": "cpu",
        "XLA_FLAGS": "--xla_cpu_multi_thread_eigen=true",
    })
    print("JAX environment configured for CPU")

import jax

# Turn on 64-bit mode only when the namelist requests double precision;
# otherwise the JAX setting is left untouched.
if Config.use_double_precision:
    jax.config.update("jax_enable_x64", True)
51
52# Uncomment this line if the dynamic RAM allocation is needed
53# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
54
55
def ImportLES():
    """
    Inject the LES solver's working names into the caller's global namespace.

    The function imports the libraries and project routines listed below,
    evaluates ``Constant()`` from ``Preprocess``, selects the dealiased or
    non-dealiased strain-rate routines according to
    ``DerivedVars.optDealias``, and then copies the results into the globals
    of the module that called it:

    1. Every local name bound in this function is exported, except names
       starting with ``_`` and the bookkeeping names ``caller_globals`` and
       ``inspect``.
    2. Every attribute of ``Config`` whose name does not start with ``_`` or
       ``np`` is exported afterwards, so a ``Config`` attribute replaces a
       same-named entry written in step 1.

    The Thomas pressure-solver routines are imported (and therefore exported)
    only when ``Config.optPressureSolver == 1``.

    Returns
    -------
    None
        The function works purely by mutating the caller's ``globals()``.
    """
    # Get the caller's global namespace
    import inspect
    caller_globals = inspect.currentframe().f_back.f_globals

    # Basic libraries
    import time
    import numpy as np
    import jax
    import jax.numpy as jnp

    # Import derived variables module
    from . import DerivedVars

    # Imports from initialization
    from ..initialization.Initialization import Initialize_uvw, Initialize_TH
    from ..initialization.Initialization import Initialize_Q
    from ..initialization.Initialization import Initialize_MoistureSurfaceBC
    from ..initialization.Initialization import Initialize_GeoWind
    from ..initialization.Initialization import Initialize_GeoWind_Varying
    from ..initialization.Initialization import Initialize_RayleighDampingLayer
    from ..initialization.Initialization import Initialize_SurfaceBC
    from ..initialization.Initialization import Initialize_AdvForcing
    from ..initialization.Preprocess import Wavenumber, Constant
    from ..initialization.Preprocess import ZeRo3DIni, ZeRo2DIni, ZeRo1DIni
    from ..initialization.Preprocess import ZeRo3D_padIni, ZeRo3D_fftIni, ZeRo3D_pad_fftIni

    # Imports from operations
    from ..operations.Derivatives import Derivxy, Derivz_M, velocityGradients
    from ..operations.Derivatives import Derivz_TH, potentialTemperatureGradients
    from ..operations.Derivatives import moistureGradients
    from ..operations.Derivatives import Derivz_Generic_uvp, Derivz_Generic_w
    from ..operations.FFT import FFT, FFT_pad
    from ..operations.Filtering import Filtering_Explicit, Filtering_Level1, Filtering_Level2
    from ..operations.Dealiasing import Dealias1, Dealias2

    # Imports from utilities
    from ..utilities.Utilities import StagGridAvg, LogMemory
    from ..utilities.Statistics import ComputeStats, InitializeStats

    # Imports from subgridscale
    from ..subgridscale.StrainRates import StrainsUVPnodes_Dealias, StrainsWnodes_Dealias
    from ..subgridscale.StrainRates import StrainsUVPnodes_NoDealias, StrainsWnodes_NoDealias
    from ..subgridscale.SGSStresses_SM import StressesUVPnodes_Dealias, StressesWnodes_Dealias
    from ..subgridscale.SGSStresses_SM import StressesUVPnodes_NoDealias, StressesWnodes_NoDealias
    from ..subgridscale.SGSStresses_SM import Wall
    from ..subgridscale.DynamicSGS_Main import DynamicSGS, DynamicSGSscalar
    from ..subgridscale.DynamicSGS_LASDD_SM import LASDD as LASDD_SM
    from ..subgridscale.DynamicSGS_LASDD_WL import LASDD as LASDD_WL
    from ..subgridscale.DynamicSGS_ScalarLASDD_SM import ScalarLASDD as ScalarLASDD_SM
    from ..subgridscale.DynamicSGS_ScalarLASDD_WL import ScalarLASDD as ScalarLASDD_WL

    # Get constants from Preprocess
    mx, my, nx_rfft, ny_rfft, mx_rfft, my_rfft = Constant()

    # Select versions based on configuration
    if DerivedVars.optDealias == 1:
        StrainsUVPnodes = StrainsUVPnodes_Dealias
        StrainsWnodes = StrainsWnodes_Dealias
    else:
        StrainsUVPnodes = StrainsUVPnodes_NoDealias
        StrainsWnodes = StrainsWnodes_NoDealias

    # Imports from surface
    from ..surface.SurfaceFlux import MOSTstable, MOSTunstable
    from ..surface.SurfaceFlux import SurfaceFlux_HomogeneousConstantFlux
    from ..surface.SurfaceFlux import SurfaceFlux_HeterogeneousConstantFlux
    from ..surface.SurfaceFlux import SurfaceFlux_HomogeneousVaryingFlux
    from ..surface.SurfaceFlux import SurfaceFlux_HeterogeneousVaryingFlux
    from ..surface.SurfaceFlux import SurfaceFlux_HomogeneousPrescribedTemperature
    from ..surface.SurfaceFlux import SurfaceFlux_HeterogeneousPrescribedTemperature
    from ..surface.SurfaceFlux import SurfaceMoistureFlux_HomogeneousPrescribedQ
    from ..surface.SurfaceFlux import SurfaceMoistureFlux_HeterogeneousPrescribedQ

    # Imports from pde
    from ..pde.NSE_AdvectionTerms import Advection
    from ..pde.SCL_AdvectionTerms import ScalarAdvection
    from ..pde.NSE_BuoyancyTerms import BuoyancyOpt1, BuoyancyOpt2
    from ..pde.NSE_PressureTerms import PressureInit, PressureRC
    from ..pde.NSE_PressureTerms import PressureMatrix, PressureSolve
    if Config.optPressureSolver == 1:
        from ..pde.NSE_PressureTerms_Thomas import ThomasPressureInit, ThomasPressureSolve
    from ..pde.NSE_SGSTerms import DivStressStaticSGS, DivStressDynamicSGS
    from ..pde.SCL_SGSTerms import DivFluxStaticSGS, DivFluxDynamicSGS
    from ..pde.NSE_SGSTerms_STABSM import DivStressStaticSGS_STABSM
    from ..pde.SCL_SGSTerms_STABSM import DivFluxStaticSGS_STABSM
    from ..pde.NSE_AllTerms import RHS_Momentum
    from ..pde.SCL_AllTerms import RHS_Scalar, RHS_Moisture
    from ..pde.NSE_TimeAdvancement import AB2_uvw
    from ..pde.SCL_TimeAdvancement import AB2_TH, AB2_Q

    # Add all imports to namespace.
    # Iterate over a one-off copy rather than the live locals() mapping: the
    # loop itself binds new locals (`name`, `value`), and on CPython <= 3.12 a
    # debugger/tracer that reads the frame's locals re-syncs that same dict,
    # which raises "dictionary changed size during iteration". The copy is
    # held in an underscore-prefixed local so it is never exported itself.
    _snapshot = dict(locals())
    for name, value in _snapshot.items():
        if name.startswith('_') or name in ('caller_globals', 'inspect'):
            continue
        caller_globals[name] = value

    # Add all constants from Config to namespace (runs last, so a Config
    # attribute wins over a same-named entry exported above)
    for name in dir(Config):
        if not name.startswith('_') and not name.startswith('np'):
            # Add the variable to caller's namespace
            caller_globals[name] = getattr(Config, name)