-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathprofile_lightning_ir.py
More file actions
66 lines (61 loc) · 1.71 KB
/
profile_lightning_ir.py
File metadata and controls
66 lines (61 loc) · 1.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from lightning.pytorch.profilers import PyTorchProfiler
from lightning_ir import (
BiEncoderConfig,
BiEncoderModule,
FLOPSRegularization,
KLDivergence,
LightningIRDataModule,
LightningIRTrainer,
RunDataset,
ScoreBasedInBatchCrossEntropy,
SpladeConfig,
SupervisedMarginMSE,
)
from torch.optim import AdamW
from tite.utils.callbacks import DummyImportCallback
def main() -> None:
    """Profile a short Lightning IR SPLADE bi-encoder training run.

    Builds a PyTorch profiler, a trainer capped at 100 steps, a run-file
    training dataset, and a SPLADE bi-encoder module, then runs ``fit`` so
    the profiler can report CPU-time hotspots. Checkpointing, logging, and
    sanity validation are disabled to keep the profile focused on the
    training loop itself.
    """
    # Sort the profiler report by CPU time; Chrome-trace export and module
    # names are disabled to keep overhead and output size down.
    profiler = PyTorchProfiler(
        sort_by_key="cpu_time",
        export_to_chrome=False,
        record_module_names=False,
        # group_by_input_shapes=True,
        # record_shapes=True,
    )
    # Short, checkpoint-free run: 100 steps is enough for a stable profile.
    trainer = LightningIRTrainer(
        precision="bf16-mixed",
        logger=False,
        profiler=profiler,
        max_steps=100,
        num_sanity_val_steps=0,
        enable_checkpointing=False,
    )
    # Training data: a TREC run file sampled to 8 documents per query
    # (log-random over the top 100), with retrieval scores as targets.
    data_module = LightningIRDataModule(
        num_workers=4,
        train_batch_size=8,
        train_dataset=RunDataset(
            "/mnt/ceph/storage/data-tmp/current/fschlatt/lightning-ir-experiments/runs-archive/"
            "__10000__msmarco-passage-train-judged.run",
            depth=100,
            sample_size=8,
            sampling_strategy="log_random",
            targets="score",
        ),
    )
    # SPLADE config with max pooling on both sides; documents truncated to
    # 256 tokens. The commented BiEncoderConfig line is an alternate setup.
    model = BiEncoderModule(
        model_name_or_path="./models/pre-trained/tite-2-late-flash-new",
        config=SpladeConfig(
            similarity_function="dot",
            doc_length=256,
            tie_projection=False,
            query_pooling_strategy="max",
            doc_pooling_strategy="max",
        ),
        # config=BiEncoderConfig(doc_length=256),
        # Mixed loss setup: the (loss, weight) tuple down-weights MarginMSE
        # to 0.05; the remaining losses use their default weight.
        loss_functions=[
            (SupervisedMarginMSE(), 0.05),
            KLDivergence(),
            ScoreBasedInBatchCrossEntropy(min_target_diff=3),
            FLOPSRegularization(query_weight=0.01, doc_weight=0.02),
        ],
    )
    model.set_optimizer(AdamW, lr=1e-5)
    trainer.fit(model, data_module)


if __name__ == "__main__":
    # The entry guard is required, not cosmetic: num_workers=4 makes the
    # dataloader spawn worker processes, and spawn-based platforms re-import
    # this module in each worker. Without the guard, every worker would
    # re-execute the whole training run.
    main()