Skip to content

Commit 3846ddb

Browse files
committed
fix: robust configuration handling and tokenizer support
- Updated training scripts (stage 1-4) to safely handle missing 'system' configuration section by initializing defaults. - Added support for 'tokenizer' configuration in stage 1, allowing custom tokenizers to be passed. - Included user-provided updates to configuration YAML files (adding 'system' section).
1 parent 43556fb commit 3846ddb

9 files changed

Lines changed: 34 additions & 13 deletions

configs/baseline_170m.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,6 @@ data:
3131
data_dir: "data/text_corpus"
3232
traces_path: "data/train_traces.jsonl"
3333

34-
tokenizer: "gpt2"
35-
seed: 42
34+
system:
35+
seed: 42
36+
device: "cuda"

configs/baseline_27m.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,6 @@ data:
3131
data_dir: "data/text_corpus"
3232
traces_path: "data/train_traces.jsonl"
3333

34-
tokenizer: "gpt2"
35-
seed: 42
34+
system:
35+
seed: 42
36+
device: "cuda"

configs/baseline_350m.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,6 @@ data:
3131
data_dir: "data/text_corpus"
3232
traces_path: "data/train_traces.jsonl"
3333

34-
tokenizer: "gpt2"
35-
seed: 42
34+
system:
35+
seed: 42
36+
device: "cuda"

configs/small.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ data:
3636
traces_path: "data/train_traces.jsonl"
3737

3838
tokenizer: null
39-
seed: 42
4039

4140
system:
42-
device: "cpu"
4341
seed: 42
42+
device: "cuda"

configs/test_config.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,7 @@ data:
3232
traces_path: "data/train_traces.jsonl"
3333

3434
tokenizer: null
35-
seed: 42
35+
36+
system:
37+
seed: 42
38+
device: "cuda"

scripts/training/stage_1_backbone.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,17 @@ def main():
9292
if config['training'].get('use_amp', False):
9393
cmd.append('--amp')
9494

95+
system_config = config.get('system', {})
96+
9597
if args.device:
9698
cmd.extend(['--device', args.device])
97-
elif 'device' in config['system']:
98-
cmd.extend(['--device', config['system']['device']])
99+
elif system_config.get('device'):
100+
cmd.extend(['--device', system_config['device']])
99101

100102
if args.seed is not None:
101103
cmd.extend(['--seed', str(args.seed)])
102-
elif 'seed' in config['system']:
103-
cmd.extend(['--seed', str(config['system']['seed'])])
104+
elif system_config.get('seed'):
105+
cmd.extend(['--seed', str(system_config['seed'])])
104106

105107
if args.no_wandb:
106108
cmd.append('--no-wandb')
@@ -109,6 +111,8 @@ def main():
109111

110112
if args.tokenizer:
111113
cmd.extend(['--tokenizer', args.tokenizer])
114+
elif config.get('tokenizer'):
115+
cmd.extend(['--tokenizer', config['tokenizer']])
112116

113117
if args.resume:
114118
cmd.extend(['--resume', args.resume])

scripts/training/stage_2_dynamics.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ def main():
3636

3737
config = load_config(args.config)
3838

39+
# Ensure system config exists
40+
if 'system' not in config:
41+
config['system'] = {'seed': 42, 'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
42+
3943
# Apply Overrides
4044
if args.epochs: config['dynamics']['dynamics_epochs'] = args.epochs
4145
if args.batch_size: config['training']['batch_size'] = args.batch_size

scripts/training/stage_3_value_head.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ async def main_async():
134134

135135
config = load_config(args.config)
136136

137+
# Ensure system config exists
138+
if 'system' not in config:
139+
config['system'] = {'seed': 42, 'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
140+
137141
# Apply Overrides
138142
if args.epochs: config['training']['value_training']['epochs'] = args.epochs
139143
if args.lr: config['training']['value_training']['lr'] = args.lr

scripts/training/stage_4_assembly.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ def main():
4040

4141
config_dict = load_config(args.config)
4242

43+
# Ensure system config exists
44+
if 'system' not in config_dict:
45+
config_dict['system'] = {'seed': 42, 'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
46+
4347
# Apply Overrides
4448
if args.seed: config_dict['system']['seed'] = args.seed
4549
if args.device: config_dict['system']['device'] = args.device

0 commit comments

Comments
 (0)