Import All Modules

Source Code: Imports

Imports.py
  1# Copyright (C) 2025 Sukanta Basu
  2#
  3# This program is free software: you can redistribute it and/or modify
  4# it under the terms of the GNU General Public License as published by
  5# the Free Software Foundation, either version 3 of the License, or
  6# (at your option) any later version.
  7#
  8# This program is distributed in the hope that it will be useful,
  9# but WITHOUT ANY WARRANTY; without even the implied warranty of
 10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11# GNU General Public License for more details.
 12#
 13# You should have received a copy of the GNU General Public License
 14# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15
 16"""
 17File: Imports.py
 18=======================
 19
 20:Author: Sukanta Basu
 21:AI Assistance: Claude Code (Anthropic) and Codex (OpenAI) are used for documentation,
 22                code restructuring, and performance optimization
 23:Date: 2025-4-3
 24:Description: configures the JAX runtime (CPU/GPU backend, double precision) and imports all the modules for JAX-ALFA
 25"""
 26
 27import os
 28
 29# Import configurations (namelists)
 30from . import ConfigLoader as Config
 31
 32if Config.optGPU == 0:
 33    os.environ["CUDA_VISIBLE_DEVICES"] = ""
 34    os.environ["JAX_PLATFORMS"] = "cpu"
 35    os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=true"
 36    print("JAX environment configured for CPU")
 37else:
 38    if "CUDA_VISIBLE_DEVICES" not in os.environ:
 39        # Local run: honour GPU_ID from Config
 40        os.environ["CUDA_VISIBLE_DEVICES"] = str(Config.GPU_ID)
 41        print(f"JAX environment configured for GPU {Config.GPU_ID}")
 42    else:
 43        # Cluster run: SLURM already set CUDA_VISIBLE_DEVICES; don't override
 44        print(f"JAX environment configured for GPU "
 45              f"(CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']})")
 46
 47import jax
 48
 49if Config.use_double_precision:
 50    jax.config.update("jax_enable_x64", True)
 51
 52# Uncomment this line if the dynamic RAM allocation is needed
 53# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
 54
 55
 56def ImportLES():
 57    # Get the caller's global namespace
 58    import inspect
 59    caller_globals = inspect.currentframe().f_back.f_globals
 60
 61    # Basic libraries
 62    import time
 63    import numpy as np
 64    import jax
 65    import jax.numpy as jnp
 66
 67    # Import derived variables module
 68    from . import DerivedVars
 69
 70    # Imports from initialization
 71    from ..initialization.Initialization import Initialize_uvw, Initialize_TH
 72    from ..initialization.Initialization import Initialize_Q
 73    from ..initialization.Initialization import Initialize_MoistureSurfaceBC
 74    from ..initialization.Initialization import Initialize_GeoWind
 75    from ..initialization.Initialization import Initialize_GeoWind_Varying
 76    from ..initialization.Initialization import Initialize_RayleighDampingLayer
 77    from ..initialization.Initialization import Initialize_SurfaceBC
 78    from ..initialization.Initialization import Initialize_AdvForcing
 79    from ..initialization.Preprocess import Wavenumber, Constant
 80    from ..initialization.Preprocess import ZeRo3DIni, ZeRo2DIni, ZeRo1DIni
 81    from ..initialization.Preprocess import ZeRo3D_padIni, ZeRo3D_fftIni, ZeRo3D_pad_fftIni
 82
 83    # Imports from operations
 84    from ..operations.Derivatives import Derivxy, Derivz_M, velocityGradients
 85    from ..operations.Derivatives import Derivz_TH, potentialTemperatureGradients
 86    from ..operations.Derivatives import moistureGradients
 87    from ..operations.Derivatives import Derivz_Generic_uvp, Derivz_Generic_w
 88    from ..operations.FFT import FFT, FFT_pad
 89    from ..operations.Filtering import Filtering_Explicit, Filtering_Level1, Filtering_Level2
 90    from ..operations.Dealiasing import Dealias1, Dealias2
 91
 92    # Imports from utilities
 93    from ..utilities.Utilities import StagGridAvg, LogMemory
 94    from ..utilities.Statistics import ComputeStats, InitializeStats
 95
 96    # Imports from subgridscale
 97    from ..subgridscale.StrainRates import StrainsUVPnodes_Dealias, StrainsWnodes_Dealias
 98    from ..subgridscale.StrainRates import StrainsUVPnodes_NoDealias, StrainsWnodes_NoDealias
 99    from ..subgridscale.SGSStresses_SM import StressesUVPnodes_Dealias, StressesWnodes_Dealias
100    from ..subgridscale.SGSStresses_SM import StressesUVPnodes_NoDealias, StressesWnodes_NoDealias
101    from ..subgridscale.SGSStresses_SM import Wall
102    from ..subgridscale.DynamicSGS_Main import DynamicSGS, DynamicSGSscalar
103    from ..subgridscale.DynamicSGS_LASDD_SM import LASDD as LASDD_SM
104    from ..subgridscale.DynamicSGS_LASDD_WL import LASDD as LASDD_WL
105    from ..subgridscale.DynamicSGS_ScalarLASDD_SM import ScalarLASDD as ScalarLASDD_SM
106    from ..subgridscale.DynamicSGS_ScalarLASDD_WL import ScalarLASDD as ScalarLASDD_WL
107
108    # Get constants from Preprocess
109    mx, my, nx_rfft, ny_rfft, mx_rfft, my_rfft = Constant()
110
111    # Select versions based on configuration
112    if DerivedVars.optDealias == 1:
113        StrainsUVPnodes = StrainsUVPnodes_Dealias
114        StrainsWnodes = StrainsWnodes_Dealias
115    else:
116        StrainsUVPnodes = StrainsUVPnodes_NoDealias
117        StrainsWnodes = StrainsWnodes_NoDealias
118
119    # Imports from surface
120    from ..surface.SurfaceFlux import MOSTstable, MOSTunstable
121    from ..surface.SurfaceFlux import SurfaceFlux_HomogeneousConstantFlux
122    from ..surface.SurfaceFlux import SurfaceFlux_HeterogeneousConstantFlux
123    from ..surface.SurfaceFlux import SurfaceFlux_HomogeneousVaryingFlux
124    from ..surface.SurfaceFlux import SurfaceFlux_HeterogeneousVaryingFlux
125    from ..surface.SurfaceFlux import SurfaceFlux_HomogeneousPrescribedTemperature
126    from ..surface.SurfaceFlux import SurfaceFlux_HeterogeneousPrescribedTemperature
127    from ..surface.SurfaceFlux import SurfaceMoistureFlux_HomogeneousPrescribedQ
128    from ..surface.SurfaceFlux import SurfaceMoistureFlux_HeterogeneousPrescribedQ
129
130    # Imports from pde
131    from ..pde.NSE_AdvectionTerms import Advection
132    from ..pde.SCL_AdvectionTerms import ScalarAdvection
133    from ..pde.NSE_BuoyancyTerms import BuoyancyOpt1, BuoyancyOpt2
134    from ..pde.NSE_PressureTerms import PressureInit, PressureRC
135    from ..pde.NSE_PressureTerms import PressureMatrix, PressureSolve
136    if Config.optPressureSolver == 1:
137        from ..pde.NSE_PressureTerms_Thomas import ThomasPressureInit, ThomasPressureSolve
138    from ..pde.NSE_SGSTerms import DivStressStaticSGS, DivStressDynamicSGS
139    from ..pde.SCL_SGSTerms import DivFluxStaticSGS, DivFluxDynamicSGS
140    from ..pde.NSE_SGSTerms_STABSM import DivStressStaticSGS_STABSM
141    from ..pde.SCL_SGSTerms_STABSM import DivFluxStaticSGS_STABSM
142    from ..pde.NSE_AllTerms import RHS_Momentum
143    from ..pde.SCL_AllTerms import RHS_Scalar, RHS_Moisture
144    from ..pde.NSE_TimeAdvancement import AB2_uvw
145    from ..pde.SCL_TimeAdvancement import AB2_TH, AB2_Q
146
147    # Add all imports to namespace
148    for name, value in locals().items():
149        if not name.startswith('_') and name != 'caller_globals' and name != 'inspect':
150            caller_globals[name] = value
151
152    # Add all constants from Config to namespace
153    for name in dir(Config):
154        if not name.startswith('_') and not name.startswith('np'):
155            # Add the variable to caller's namespace
156            caller_globals[name] = getattr(Config, name)