-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_evi.py
More file actions
63 lines (51 loc) · 1.92 KB
/
Copy pathrun_evi.py
File metadata and controls
63 lines (51 loc) · 1.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import sys
import json
from tqdm import tqdm
from tasks.evi_select.evi_runner import run_evi_task
# from models.openai import OpenAIModel
from models.llama import LlamaModel
from models.qwen import QwenModel
from utils import format_caption_and_table, extract_answer, format_cells
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# Key into `model_registry`; also used (lower-cased) in the output file name.
model_name = "qwen-2.5-7b"
# Forwarded to run_evi_task() as `shots`; also part of the output file name.
prompt_type = "zeroshot"
# Forwarded to format_caption_and_table() as `type_`; also names the output
# sub-directory.
table_format = "pipe_tagging"

# Model factories keyed by model name. Each value is a zero-argument callable,
# so only the selected model is ever constructed.
# Candidate names: ["qwen-2.5-7b", "qwen-2.5-72b", "llama-3.1-8b", "llama-3.1-70b"]
# Uncomment an entry below to enable the corresponding model.
model_registry = {
    # "qwen-2.5-72b": lambda: QwenModel(model_name="Qwen2.5-72B"),
    "qwen-2.5-7b": lambda: QwenModel(model_name="Qwen2.5-7B"),
    # "llama-3.1-70b": lambda: LlamaModel(model_name="Llama-3.1-70B"),
    # "llama-3.1-8b": lambda: LlamaModel(model_name="Llama-3.1-8B")
}

# Look the factory up separately from calling it, so that a KeyError raised
# inside a model constructor is never mistaken for an unknown model name.
try:
    model_factory = model_registry[model_name]
except KeyError:
    raise KeyError(
        f"Model {model_name!r} is not enabled in model_registry; "
        f"enabled models: {sorted(model_registry)}"
    ) from None
model = model_factory()

# Results are grouped by table format: outputs/evi_task/<table_format>/
OUTPUT_DIR = f"outputs/evi_task/{table_format}"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Extra keyword arguments forwarded unchanged to run_evi_task().
kwargs = {
    "temperature": 0,
    "max_tokens": 1024,
}
# Read the whole evaluation set (a JSON document) into memory.
dataset_file = "data/data_100.json"
with open(dataset_file, encoding="utf-8") as dataset_fh:
    samples = json.loads(dataset_fh.read())
# For a quick smoke test, restrict the run to a few samples:
# samples = samples[0:5]
# Run the task once per sample and collect one output record for each.
results = []
for item in tqdm(samples):
    # Build the table input from the sample's column names, cell values and
    # caption, using the configured table format.
    table = format_caption_and_table(
        item["table_column_names"],
        item["table_content_values"],
        item["table_caption"],
        type_=table_format,
    )

    # run_evi_task returns the prompt it used together with the model response.
    prompt, response = run_evi_task(
        item["claim"],
        table,
        model,
        shots=prompt_type,
        **kwargs,
    )

    # Keep the reference cells, the extracted prediction, the raw response and
    # the prompt side by side so each record can be inspected on its own.
    record = {
        "id": item["id"],
        "claim": item["claim"],
        "label_cells": str(format_cells(item["explanation_cells"])),
        "pred_label": extract_answer(response),
        "generated_response": response,
        "paper_id": item["paper_id"],
        "table_id": item["table_id"],
        "user_prompt": prompt,
    }
    results.append(record)
# Write all collected records to <OUTPUT_DIR>/<model_name>_<prompt_type>.json.
output_filename = f"{model_name.lower()}_{prompt_type}.json"
output_path = os.path.join(OUTPUT_DIR, output_filename)
print(f"Saving output to: {output_path}")

# ensure_ascii=False keeps non-ASCII characters as-is instead of \u escapes.
with open(output_path, "w", encoding="utf-8") as out_f:
    json.dump(
        results,
        out_f,
        indent=2,
        ensure_ascii=False,
    )