Official implementation of FRODO (Framework for Faithful Reasoning Over Deliberate Output).

- Paper: Making Reasoning Matter: Measuring and Improving Faithfulness of Chain-of-Thought Reasoning (Findings of EMNLP 2024)
- Authors: Debjit Paul, Robert West, Antoine Bosselut, Boi Faltings
```bash
git clone https://github.com/Causal_CoT.git
cd Causal_CoT
pip install -r requirements.txt
```

Requirements: Python ≥ 3.8, PyTorch ≥ 2.0.0
To run the tests:

```bash
cd src
python test_frodo.py
```

To train:

```bash
cd src
python train_frodo.py
```

Quick start (Python API):

```python
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Assumes the script runs from src/ (as in the commands above), so frodo.py is importable
from frodo import FRODO, FRODOConfig, DPODataset, ReasoningDataset

# Configure
config = FRODOConfig(
    model_name="google/flan-t5-large",
    beta=0.1,           # DPO temperature
    lambda_lm=1.0,      # Language model loss weight
    lambda_ie=1.0,      # Indirect effect loss weight
    lambda_margin=1.0,  # Margin ranking loss weight
)

# Load models: one copy for the inference module, one for the reasoning module
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
inference_model = AutoModelForSeq2SeqLM.from_pretrained(config.model_name)
reasoning_model = AutoModelForSeq2SeqLM.from_pretrained(config.model_name)

# Create FRODO
frodo = FRODO(inference_model, reasoning_model, config)

# Phase 1: train the inference module
# dpo_data is a list of dicts in the DPO data format shown below
dpo_dataset = DPODataset(dpo_data, tokenizer)
dpo_dataloader = DataLoader(dpo_dataset, batch_size=8)
optimizer = torch.optim.AdamW(frodo.inference_module.parameters(), lr=5e-5)
frodo.train_inference_module(dpo_dataloader, num_epochs=3, optimizer=optimizer)

# Phase 2: train the reasoning module
# reasoning_data is a list of dicts in the reasoning data format shown below
reasoning_dataset = ReasoningDataset(reasoning_data, tokenizer)
reasoning_dataloader = DataLoader(reasoning_dataset, batch_size=8)
optimizer = torch.optim.AdamW(frodo.reasoning_module.parameters(), lr=5e-5)
frodo.train_reasoning_module(reasoning_dataloader, num_epochs=3, optimizer=optimizer)

# Inference: returns the generated reasoning and the answer
question = "Does a banana have more protein than an apple?"
reasoning, answer = frodo.generate_reasoning_and_answer(question, tokenizer)
```

DPO data format (Phase 1, inference module). Each record pairs a preferred and a dispreferred reasoning chain for the same question:

```python
dpo_data = [
    {
        'question': 'Does a banana have more protein than an apple?',
        'preferred_reasoning': 'Step 1: Banana has 1.3g protein per 100g. Step 2: Apple has 0.3g. Step 3: 1.3 > 0.3.',
        'dispreferred_reasoning': 'Step 1: Banana has 0.4g protein per 100g. Step 2: Apple has 1.5g. Step 3: 1.5 > 0.4.',
    },
]
```

Reasoning data format (Phase 2, reasoning module). Each record holds the reasoning, the answer, and a counterfactual reasoning chain:

```python
reasoning_data = [
    {
        'question': 'Does a banana have more protein than an apple?',
        'reasoning': 'Step 1: Banana has 1.3g protein per 100g. Step 2: Apple has 0.3g. Step 3: 1.3 > 0.3.',
        'answer': 'Yes',
        'counterfactual_reasoning': 'Step 1: Banana has 0.4g protein per 100g. Step 2: Apple has 1.5g. Step 3: 1.5 > 0.4.',
    },
]
```

Project structure:

```
Causal_CoT/
├── src/                  # FRODO implementation
│   ├── frodo.py          # Main framework
│   ├── train_frodo.py    # Training script
│   └── test_frodo.py     # Tests
├── requirements.txt
└── README.md
```
| Dataset | Type | Links |
|---|---|---|
| StrategyQA | Multi-hop QA | Paper, Data |
| QuaRel | Qualitative Reasoning | Paper, Data |
| OpenBookQA | Science QA | Paper, Data |
| QASC | Multi-hop Science | Paper, Data |
| GSM8K | Math Problems | Paper, Data |
| Causal Understanding | Causal Reasoning | Paper, Data |

All datasets use silver rationales generated by GPT-3 (text-davinci-003).
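In the example records above, the dispreferred reasoning used for DPO and the counterfactual reasoning used by the reasoning module are the same perturbed rationale. Assuming each question comes with one silver rationale, its answer, and one perturbed rationale (how perturbed rationales are produced is not covered here), a single example can be turned into one record of each format. This is a minimal sketch: `to_frodo_records` is a hypothetical helper, not part of `src/frodo.py`, and reusing one perturbed rationale for both phases simply mirrors the examples.

```python
from typing import Dict, Tuple


def to_frodo_records(
    question: str,
    reasoning: str,
    answer: str,
    counterfactual_reasoning: str,
) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Hypothetical helper: build one DPO record and one reasoning record.

    Assumption: the same perturbed rationale serves as the dispreferred
    reasoning (Phase 1) and the counterfactual reasoning (Phase 2).
    """
    dpo_record = {
        "question": question,
        "preferred_reasoning": reasoning,
        "dispreferred_reasoning": counterfactual_reasoning,
    }
    reasoning_record = {
        "question": question,
        "reasoning": reasoning,
        "answer": answer,
        "counterfactual_reasoning": counterfactual_reasoning,
    }
    return dpo_record, reasoning_record
```

Collecting the first element of each returned pair gives a `dpo_data` list, and the second a `reasoning_data` list.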

Hyperparameters:

| Parameter | Default | Range | Description |
|---|---|---|---|
| `beta` | 0.1 | 0.01-0.5 | DPO temperature |
| `lambda_lm` | 0.1 | 0.0-1.0 | Language model loss weight |
| `lambda_ie` | 0.1 | 0.0-1.0 | Indirect effect loss weight |
| `lambda_margin` | 0.1 | 0.0-1.0 | Margin ranking loss weight |
| `learning_rate` | 5e-5 | 1e-5 to 1e-4 | Optimizer learning rate |
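Whether `FRODOConfig` itself checks these ranges is not stated in this README. A small helper like the one below can flag values outside the table's ranges before starting a run; `HPARAM_RANGES` and `out_of_range` are hypothetical names, not part of the repository.

```python
from typing import Dict, List, Tuple

# (low, high) ranges copied from the hyperparameter table
HPARAM_RANGES: Dict[str, Tuple[float, float]] = {
    "beta": (0.01, 0.5),
    "lambda_lm": (0.0, 1.0),
    "lambda_ie": (0.0, 1.0),
    "lambda_margin": (0.0, 1.0),
    "learning_rate": (1e-5, 1e-4),
}


def out_of_range(hparams: Dict[str, float]) -> List[str]:
    """Return names of known hyperparameters whose values fall outside the table's ranges.

    Keys that are not in HPARAM_RANGES are ignored.
    """
    flagged = []
    for name, value in hparams.items():
        if name in HPARAM_RANGES:
            low, high = HPARAM_RANGES[name]
            if not low <= value <= high:
                flagged.append(name)
    return flagged
```

For example, the quick-start settings (`beta=0.1`, all three lambdas at 1.0, `lr=5e-5`) produce an empty list, while `{"beta": 0.9, "learning_rate": 1e-3}` flags both entries.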

FRODO trains two modules, each with its own objective.

Inference module (DPO loss):
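
$$
\mathcal{L}_{\mathrm{DPO}} = -\log \sigma\Big(\beta \big[\log \pi_\theta(r_w \mid x) - \log \pi_{\mathrm{ref}}(r_w \mid x) - \log \pi_\theta(r_l \mid x) + \log \pi_{\mathrm{ref}}(r_l \mid x)\big]\Big)
$$

where $x$ is the question, $r_w$ the preferred reasoning, $r_l$ the dispreferred reasoning, $\pi_\theta$ the inference module being trained, $\pi_{\mathrm{ref}}$ the reference model, $\sigma$ the sigmoid, and $\beta$ the DPO temperature (`beta`).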
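To make the formula concrete, the loss can be computed from four sequence log-probabilities. The sketch below is scalar, plain Python and not the repository's implementation: the function names are hypothetical, and a real implementation would work on batched tensors. It assumes, as is standard, that a sequence log-probability $\log \pi(r \mid x)$ is the sum of the per-token log-probabilities of $r$ given $x$.

```python
import math
from typing import Sequence


def sequence_logprob(token_logprobs: Sequence[float]) -> float:
    """log pi(r | x): sum of the per-token log-probabilities of reasoning r."""
    return float(sum(token_logprobs))


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_loss(
    policy_chosen: float,    # log pi_theta(r_w | x)
    policy_rejected: float,  # log pi_theta(r_l | x)
    ref_chosen: float,       # log pi_ref(r_w | x)
    ref_rejected: float,     # log pi_ref(r_l | x)
    beta: float = 0.1,       # DPO temperature (table default)
) -> float:
    # Bracketed term: (log pi_theta(r_w|x) - log pi_ref(r_w|x)) - (log pi_theta(r_l|x) - log pi_ref(r_l|x))
    margin = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    return -log_sigmoid(beta * margin)


# Policy identical to the reference: the bracket is 0, so the loss is log 2
baseline = dpo_loss(-12.0, -15.0, -12.0, -15.0)  # ≈ 0.6931
# Policy raises the preferred reasoning relative to the reference: loss drops below log 2
improved = dpo_loss(-10.0, -15.0, -12.0, -15.0)
```

The loss only depends on how much more the trained model prefers $r_w$ over $r_l$ than the reference model does, which is why the two reference terms appear with opposite signs.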

Reasoning module (combined loss):
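
$$
\mathcal{L}_{\mathrm{PREF}} = \lambda_{\mathrm{LM}} \, \mathcal{L}_{\mathrm{LM}} + \lambda_{\mathrm{IE}} \, \mathcal{L}_{\mathrm{IE}} + \lambda_{\mathrm{MR}} \, \mathcal{L}_{\mathrm{MR}}
$$

where $\mathcal{L}_{\mathrm{LM}}$ is the language model loss, $\mathcal{L}_{\mathrm{IE}}$ the indirect effect loss, and $\mathcal{L}_{\mathrm{MR}}$ the margin ranking loss, weighted by `lambda_lm`, `lambda_ie`, and `lambda_margin` respectively.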
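The weighted sum itself is straightforward; the sketch below shows it together with a generic margin ranking (hinge) term. It is an illustration under stated assumptions, not the code in `src/frodo.py`: the function names are hypothetical, the indirect effect loss is taken as a precomputed number because its definition is not given here, and what the two ranking scores are in FRODO (for example, answer log-likelihoods under the original versus the counterfactual reasoning) and the margin value of 1.0 are assumptions.

```python
def margin_ranking_loss(score_pos: float, score_neg: float, margin: float = 1.0) -> float:
    """Hinge loss max(0, margin - (score_pos - score_neg)).

    Same formula as torch.nn.MarginRankingLoss with target y = 1.
    The margin of 1.0 is an illustrative assumption.
    """
    return max(0.0, margin - (score_pos - score_neg))


def pref_loss(
    l_lm: float,
    l_ie: float,
    l_mr: float,
    lambda_lm: float = 0.1,      # defaults follow the hyperparameter table
    lambda_ie: float = 0.1,
    lambda_margin: float = 0.1,  # weight lambda_MR on the margin ranking term
) -> float:
    """L_PREF = lambda_LM * L_LM + lambda_IE * L_IE + lambda_MR * L_MR."""
    return lambda_lm * l_lm + lambda_ie * l_ie + lambda_margin * l_mr


# Positive score already beats the negative one by more than the margin: no ranking penalty
l_mr = margin_ranking_loss(score_pos=-2.0, score_neg=-4.0)
total = pref_loss(l_lm=1.2, l_ie=0.5, l_mr=l_mr, lambda_lm=1.0, lambda_ie=1.0, lambda_margin=1.0)
```

Setting any lambda to 0.0 switches the corresponding term off, which is how the 0.0-1.0 ranges in the table can be read.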

See `src/frodo.py` for implementation details.

Citation:

```bibtex
@inproceedings{paul2024making,
  title={Making Reasoning Matter: Measuring and Improving Faithfulness of Chain-of-Thought Reasoning},
  author={Paul, Debjit and West, Robert and Bosselut, Antoine and Faltings, Boi},
  booktitle={Findings of the Association for Computational Linguistics: EMNLP 2024},
  year={2024}
}
```

Contact:

- Issues: GitHub Issues
- Email: debjit.paul@epfl.ch