Skip to content

Commit b9036f0

Browse files
author
jngaravitoc
committed
minor updated in basis functionality
1 parent 90ec1f9 commit b9036f0

2 files changed

Lines changed: 52 additions & 35 deletions

File tree

EXPtools/basis/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
check_basis_params,
44
write_config,
55
make_basis,
6+
config_dict_to_yaml,
67
)
78
from .makemodel import (
89
write_table,
@@ -16,4 +17,5 @@
1617
"make_basis",
1718
"write_table",
1819
"make_model",
20+
"config_dict_to_yaml",
1921
]

EXPtools/basis/basis_utils.py

Lines changed: 50 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
import pyEXP
66
from EXPtools.basis.makemodel import make_model
77

8-
def load_basis(config_name, cache_dir=None):
8+
def load_basis(config_file, cache_dir=None):
99
"""
1010
Load a basis configuration from a YAML file and initialize a Basis object.
1111
1212
Parameters
1313
----------
14-
config_name : str
14+
config_file : str
1515
Path to the YAML configuration file. If the provided filename does not
1616
end with `.yaml`, the extension is automatically appended.
1717
cache_dir : str (optional)
@@ -30,18 +30,28 @@ def load_basis(config_name, cache_dir=None):
3030
"""
3131

3232
# Check file existence
33-
if not os.path.exists(config_name):
34-
raise FileNotFoundError(f"Configuration file not found: {config_name}")
33+
if not os.path.exists(config_file):
34+
raise FileNotFoundError(f"Configuration file not found: {config_file}")
3535

3636
# Load YAML safely
37-
with open(config_name, "r") as f:
38-
config = yaml.load(f, Loader=yaml.FullLoader)
37+
#with open(config_file) as f:
38+
# config_yaml = f.read()
39+
with open(config_file, "r") as f:
40+
config_yaml = yaml.safe_load(f)
41+
42+
43+
modelfile = config_yaml["parameters"]["modelname"]
44+
if not os.path.exists(modelfile):
45+
raise FileNotFoundError(f"Modelname file not found: {modelfile}")
3946

40-
if cache_dir:
41-
config = re.sub(r"(modelname:\s*)(\S+)", rf"\1{cache_dir}\2", config)
42-
config = re.sub(r"(cachename:\s*)(\S+)", rf"\1{cache_dir}\2", config)
47+
#if cache_dir:
48+
# config_yaml = re.sub(r"(modelname:\s*)(\S+)", rf"\1{cache_dir}\2", config_yaml)
49+
# config_yaml = re.sub(r"(cachename:\s*)(\S+)", rf"\1{cache_dir}\2", config_yaml)
4350
# Build basis from configuration
44-
basis = pyEXP.basis.Basis.factory(config)
51+
config_str = yaml.safe_dump(config_yaml)
52+
53+
54+
basis = pyEXP.basis.Basis.factory(config_str)
4555
return basis
4656

4757
def check_basis_params(basis_params):
@@ -112,11 +122,12 @@ def check_basis_params(basis_params):
112122
raise AttributeError(f"basis id {basis_params['basis_id']} not found. Please chose between sphereSL or cylinder")
113123

114124

115-
def write_config(
116-
basis_params,
117-
write_yaml=False,
118-
filename="basis_config.yaml",
119-
):
125+
def write_config(basis, basis_filename):
126+
with open(basis_filename, 'w') as file:
127+
yaml.safe_dump(basis, file, default_flow_style=False)
128+
129+
130+
def config_dict_to_yaml(basis_params):
120131

121132
"""
122133
Create a YAML configuration file string for building a basis model.
@@ -156,7 +167,9 @@ def write_config(
156167
If ``modelname`` is required but cannot be opened.
157168
ValueError
158169
If the model file does not contain valid radius data.
170+
159171
"""
172+
basis_params = basis_params.copy()
160173
check_basis_params(basis_params)
161174

162175
if basis_params['basis_id'] == "sphereSL":
@@ -173,31 +186,25 @@ def write_config(
173186
basis_params["rmax"] = float("{:.3f}".format(rmax))
174187
basis_params["numr"] = int(numr)
175188

176-
basis_id = basis_params['basis_id']
177-
basis_params.pop('basis_id')
189+
#remove id
190+
basis_id = basis_params.pop("basis_id")
178191
config_dict = {
179192
"id": basis_id,
180193
"parameters": basis_params
181194
}
182-
print('OK')
183-
yaml_str = yaml.dump(config_dict, sort_keys=False)
184-
print('here')
185-
if write_yaml:
186-
with open(filename, "w") as f:
187-
f.write(yaml_str)
188-
print('----')
189-
return yaml_str
190-
195+
config_yaml = yaml.dump(config_dict, sort_keys=False)
196+
return config_yaml
197+
191198

192-
def make_basis(R, D, Mtotal, basis_params, physical_units=True, write_yaml=False):
199+
def make_basis(radii, density, Mtotal, basis_params, physical_units=True, write_basis=False, basis_filename='test_config.yaml'):
193200
"""
194-
Construct a basis from a given radial density profile.
201+
Construct a basis from a density profile.
195202
196203
Parameters
197204
----------
198-
R : array_like
205+
radii : array_like
199206
Radial grid points (e.g., radii at which density `D` is defined).
200-
D : array_like
207+
density : array_like
201208
Density values corresponding to each radius in `R`.
202209
Mtotal : float, optional
203210
Total mass normalization (default is 1.0).
@@ -214,7 +221,7 @@ def make_basis(R, D, Mtotal, basis_params, physical_units=True, write_yaml=False
214221
A basis object initialized with the given density model.
215222
216223
Notes
217-
-----
224+
-----
218225
- This function wraps `makemodel.makemodel` to generate a model from
219226
the supplied density profile and total mass.
220227
- It then builds a basis either spherical (`sphereSL`) or cylindrical using `EXPtools.make_config`
@@ -229,12 +236,20 @@ def make_basis(R, D, Mtotal, basis_params, physical_units=True, write_yaml=False
229236
basis_params['cachename']="test_cache.txt"
230237

231238
_ = make_model(
232-
R, D, Mtotal=Mtotal,
239+
radii, density, Mtotal=Mtotal,
233240
output_filename=basis_params['modelname'],
234241
physical_units=physical_units
235242
)
236243
print('Done making model')
237-
config = write_config(basis_params, write_yaml)
244+
config = config_dict_to_yaml(basis_params)
245+
#if write_basis == True:
246+
# yaml_config = yaml.safe_load(config)
247+
# write_config(yaml_config, basis_filename=basis_filename)
248+
#basis = load_basis(basis_filename)
249+
250+
if write_basis:
251+
yaml_config = yaml.safe_load(config)
252+
write_config(yaml_config, basis_filename)
253+
return load_basis(basis_filename)
238254

239-
basis = pyEXP.basis.Basis.factory(config)
240-
return basis
255+
return pyEXP.basis.Basis.factory(config)

0 commit comments

Comments
 (0)