-
Notifications
You must be signed in to change notification settings - Fork 72
Expand file tree
/
Copy pathconvert_weights.py
More file actions
36 lines (28 loc) · 1.33 KB
/
convert_weights.py
File metadata and controls
36 lines (28 loc) · 1.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os

import torch
# Ordered (old, new) substring renames applied to every state-dict key.
# Order matters: each rename is applied to the keys produced by the previous
# one, exactly like the original chain of dict comprehensions.
_KEY_RENAMES = (
    ("backbone.vision_backbone", "image_encoder"),
    ("mlp.fc1", "mlp.layers.0"),
    ("mlp.fc2", "mlp.layers.1"),
    # NOTE(review): plain substring replace -- a key that already contains
    # "neck.convs" would become "neck.neck.convs". Confirm inputs never do.
    ("convs", "neck.convs"),
    ("transformer.encoder", "memory_attention"),
    ("maskmem_backbone", "memory_encoder"),
    ("mlp.lin1", "mlp.layers.0"),
    ("mlp.lin2", "mlp.layers.1"),
)


def _convert_state_dict(sd):
    """Drop teacher entries from *sd* and rename its keys via _KEY_RENAMES.

    Every occurrence of each ``old`` substring in a key is replaced
    (``str.replace`` semantics). Values are passed through untouched.
    """
    sd = {k: v for k, v in sd.items() if "teacher" not in k}
    # Rebuild the dict once per rule (rather than composing all renames per
    # key) so that key collisions resolve exactly as in the original script.
    for old, new in _KEY_RENAMES:
        sd = {k.replace(old, new): v for k, v in sd.items()}
    return sd


def _converted_path(src):
    """Return *src* with ``_converted`` inserted before its file extension.

    Only the final path component's extension is considered, so directory
    names are never rewritten and the result always differs from *src*
    (the original ``src.replace(".pt", ...)`` returned *src* unchanged for
    paths without ".pt", overwriting the input checkpoint).
    """
    root, ext = os.path.splitext(src)
    return f"{root}_converted{ext}"


def main(args):
    """Strip teacher weights from a checkpoint and rename its keys.

    Loads ``args.src`` onto the CPU, takes its ``"model"`` state dict, drops
    every entry whose key contains ``"teacher"``, applies ``_KEY_RENAMES`` in
    order, and saves ``{"model": converted}`` in the same directory as the
    source, with ``_converted`` inserted before the extension.

    Args:
        args: Namespace with a ``src`` attribute holding the checkpoint path.

    Returns:
        The path the converted checkpoint was written to.

    Raises:
        KeyError: if the loaded checkpoint dict has no ``"model"`` entry.
    """
    # NOTE: torch.load unpickles the file (unless weights_only is in effect
    # for the installed torch version); only convert trusted checkpoints.
    checkpoint = torch.load(args.src, map_location="cpu")
    converted = _convert_state_dict(checkpoint["model"])
    dst = _converted_path(args.src)
    torch.save({"model": converted}, dst)
    return dst
if __name__ == "__main__":
    # Command-line entry point: parse --src and hand off to main().
    parser = argparse.ArgumentParser(
        description=(
            "Drop teacher weights from a checkpoint's 'model' state dict, "
            "rename its keys, and save the result as a new checkpoint."
        )
    )
    parser.add_argument(
        "--src",
        type=str,
        required=True,
        help="Path to the training checkpoint to convert.",
    )
    args = parser.parse_args()
    main(args)