debjitpaul/Causal_CoT

Making Reasoning Matter: Measuring and Improving Faithfulness of Chain-of-Thought Reasoning

License: MIT

Official implementation of FRODO (Framework for Faithful Reasoning Over Deliberate Output).

Paper: Making Reasoning Matter (EMNLP 2024)
Authors: Debjit Paul, Robert West, Antoine Bosselut, Boi Faltings

Installation

git clone https://github.com/debjitpaul/Causal_CoT.git
cd Causal_CoT
pip install -r requirements.txt

Requirements: Python ≥ 3.8, PyTorch ≥ 2.0.0

Quick Start

Run Tests

cd src
python test_frodo.py

Train FRODO

cd src
python train_frodo.py

Use in Your Code

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from frodo import FRODO, FRODOConfig

# Configure
config = FRODOConfig(
    model_name="google/flan-t5-large",
    beta=0.1,              # DPO temperature
    lambda_lm=1.0,         # Language model loss weight
    lambda_ie=1.0,         # Indirect effect loss weight
    lambda_margin=1.0,     # Margin ranking loss weight
)

# Load models
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
inference_model = AutoModelForSeq2SeqLM.from_pretrained(config.model_name)
reasoning_model = AutoModelForSeq2SeqLM.from_pretrained(config.model_name)

# Create FRODO
frodo = FRODO(inference_model, reasoning_model, config)

# Train Phase 1: Inference Module
# Imports use the same module path as FRODO above (run from src/ or put src/ on PYTHONPATH)
from frodo import DPODataset
from torch.utils.data import DataLoader

# dpo_data is a list of dicts; see "Data Format" below
dpo_dataset = DPODataset(dpo_data, tokenizer)
dpo_dataloader = DataLoader(dpo_dataset, batch_size=8)
optimizer = torch.optim.AdamW(frodo.inference_module.parameters(), lr=5e-5)
frodo.train_inference_module(dpo_dataloader, num_epochs=3, optimizer=optimizer)

# Train Phase 2: Reasoning Module
from frodo import ReasoningDataset

# reasoning_data is a list of dicts; see "Data Format" below
reasoning_dataset = ReasoningDataset(reasoning_data, tokenizer)
reasoning_dataloader = DataLoader(reasoning_dataset, batch_size=8)
optimizer = torch.optim.AdamW(frodo.reasoning_module.parameters(), lr=5e-5)
frodo.train_reasoning_module(reasoning_dataloader, num_epochs=3, optimizer=optimizer)

# Inference
question = "Does a banana have more protein than an apple?"
reasoning, answer = frodo.generate_reasoning_and_answer(question, tokenizer)

Data Format

For Inference Module (DPO)

dpo_data = [
    {
        'question': 'Does a banana have more protein than an apple?',
        'preferred_reasoning': 'Step 1: Banana has 1.3g protein per 100g. Step 2: Apple has 0.3g. Step 3: 1.3 > 0.3.',
        'dispreferred_reasoning': 'Step 1: Banana has 0.4g protein per 100g. Step 2: Apple has 1.5g. Step 3: 1.5 > 0.4.',
    },
]

For Reasoning Module

reasoning_data = [
    {
        'question': 'Does a banana have more protein than an apple?',
        'reasoning': 'Step 1: Banana has 1.3g protein per 100g. Step 2: Apple has 0.3g. Step 3: 1.3 > 0.3.',
        'answer': 'Yes',
        'counterfactual_reasoning': 'Step 1: Banana has 0.4g protein per 100g. Step 2: Apple has 1.5g. Step 3: 1.5 > 0.4.',
    },
]
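
Before building the datasets, it can help to confirm that every record carries the fields shown above. The following is a minimal sketch, not part of the repository: the helper `missing_keys` and the two key sets are hypothetical names introduced here, derived only from the example records in this section.

```python
# Hypothetical helper, not part of src/frodo.py: checks that records
# contain the fields used in the Data Format examples.
DPO_KEYS = {"question", "preferred_reasoning", "dispreferred_reasoning"}
REASONING_KEYS = {"question", "reasoning", "answer", "counterfactual_reasoning"}


def missing_keys(records, required):
    """Return {record index: sorted list of missing field names}."""
    problems = {}
    for index, record in enumerate(records):
        missing = required - set(record)
        if missing:
            problems[index] = sorted(missing)
    return problems


example = [
    {"question": "q", "preferred_reasoning": "a", "dispreferred_reasoning": "b"},
    {"question": "q", "preferred_reasoning": "a"},  # incomplete on purpose
]
print(missing_keys(example, DPO_KEYS))
```

An empty result means every record has the expected fields. The same function works for the reasoning data by passing `REASONING_KEYS`.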

Repository Structure

Causal_CoT/
├── src/                     # FRODO implementation
│   ├── frodo.py            # Main framework
│   ├── train_frodo.py      # Training script
│   └── test_frodo.py       # Tests
├── requirements.txt
└── README.md

Datasets

| Dataset | Type | Links |
|---|---|---|
| StrategyQA | Multi-hop QA | Paper, Data |
| QuaRel | Qualitative Reasoning | Paper, Data |
| OpenBookQA | Science QA | Paper, Data |
| QASC | Multi-hop Science | Paper, Data |
| GSM8K | Math Problems | Paper, Data |
| Causal Understanding | Causal Reasoning | Paper, Data |

All datasets use silver rationales generated by GPT-3 (Text-Davinci-003).

Hyperparameters

| Parameter | Default | Range | Description |
|---|---|---|---|
| beta | 0.1 | 0.01-0.5 | DPO temperature |
| lambda_lm | 0.1 | 0.0-1.0 | Language model loss weight |
| lambda_ie | 0.1 | 0.0-1.0 | Indirect effect loss weight |
| lambda_margin | 0.1 | 0.0-1.0 | Margin ranking loss weight |
| learning_rate | 5e-5 | 1e-5 to 1e-4 | Optimizer learning rate |
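
When sweeping these values, a quick range check can catch typos before a long training run. This is a sketch under the assumption that the ranges listed above are the intended search bounds. The `RANGES` table and `out_of_range` helper are hypothetical and do not exist in the repository.

```python
# Hypothetical sanity check; the bounds are copied from the Range
# column of the hyperparameter table.
RANGES = {
    "beta": (0.01, 0.5),
    "lambda_lm": (0.0, 1.0),
    "lambda_ie": (0.0, 1.0),
    "lambda_margin": (0.0, 1.0),
    "learning_rate": (1e-5, 1e-4),
}


def out_of_range(params):
    """Return the subset of params whose values fall outside RANGES."""
    flagged = {}
    for name, value in params.items():
        if name in RANGES:
            low, high = RANGES[name]
            if not (low <= value <= high):
                flagged[name] = value
    return flagged


print(out_of_range({"beta": 0.1, "lambda_lm": 1.0, "learning_rate": 5e-5}))
print(out_of_range({"beta": 0.9}))
```

Names that are not in the table are ignored rather than rejected, so extra configuration fields pass through untouched.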

Loss Functions

FRODO uses two modules:

Inference Module: DPO Loss

L_DPO = -log(σ(β * (log π_θ(r_w|x) - log π_ref(r_w|x) - log π_θ(r_l|x) + log π_ref(r_l|x))))
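
To make the formula concrete, here is a small pure-Python sketch that evaluates it for a single preference pair. It assumes the four sequence log-probabilities have already been computed and are passed in as floats. The function names are illustrative and are not the ones used in src/frodo.py.

```python
import math


def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_loss(policy_logp_w, ref_logp_w, policy_logp_l, ref_logp_l, beta=0.1):
    """DPO loss for one (preferred r_w, dispreferred r_l) pair.

    Each argument is a sequence log-probability log pi(r|x) under either
    the trained policy or the frozen reference model.
    """
    margin = (policy_logp_w - ref_logp_w) - (policy_logp_l - ref_logp_l)
    return -log_sigmoid(beta * margin)


# When the policy equals the reference, the margin is 0 and the loss is log(2).
print(dpo_loss(-5.0, -5.0, -7.0, -7.0))
# Raising the preferred rationale's log-probability relative to the reference lowers the loss.
print(dpo_loss(-4.0, -5.0, -8.0, -7.0))
```

A batched implementation would average this quantity over all pairs. The scalar version is only meant to show how the four log-probabilities and β interact.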

Reasoning Module: Combined Loss

L_PREF = λ_LM * L_LM + λ_IE * L_IE + λ_MR * L_MR

See src/frodo.py for implementation details.
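
As a rough illustration of how the three weights combine, the sketch below treats each loss term as an already-computed scalar. The exact definitions of L_LM, L_IE, and L_MR are in src/frodo.py and the paper. The hinge form used for the margin term here is the generic margin ranking loss and is an assumption, as are the function names.

```python
def margin_ranking_term(score_preferred, score_dispreferred, margin=1.0):
    """Generic hinge-style ranking loss: zero once the preferred score
    exceeds the dispreferred score by at least `margin`."""
    return max(0.0, margin - (score_preferred - score_dispreferred))


def combined_loss(l_lm, l_ie, l_mr, lambda_lm=0.1, lambda_ie=0.1, lambda_margin=0.1):
    """Weighted sum L_PREF = lambda_lm * L_LM + lambda_ie * L_IE + lambda_margin * L_MR."""
    return lambda_lm * l_lm + lambda_ie * l_ie + lambda_margin * l_mr


# Placeholder values for the three terms; the default weights are taken
# from the hyperparameter table.
l_mr = margin_ranking_term(score_preferred=0.5, score_dispreferred=2.0)
print(combined_loss(l_lm=2.0, l_ie=1.0, l_mr=l_mr))
```

Setting a weight to 0.0 removes its term, which is a simple way to ablate one component of the objective.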

Citation

@inproceedings{paul2024making,
  title={Making Reasoning Matter: Measuring and Improving Faithfulness of Chain-of-Thought Reasoning},
  author={Paul, Debjit and West, Robert and Bosselut, Antoine and Faltings, Boi},
  booktitle={Findings of the Association for Computational Linguistics: EMNLP 2024},
  year={2024}
}

Contact

Do not hesitate to open an issue if you run into any trouble.
