zhiyucheng commited on
Commit
f6c9092
·
1 Parent(s): 574fdb8
Files changed (1) hide show
  1. generate_metadata.py +0 -36
generate_metadata.py DELETED
@@ -1,36 +0,0 @@
1
- #!/usr/bin/env python3
2
-
3
- import glob
4
- import json
5
- import os
6
-
7
- from safetensors import safe_open # pip install safetensors
8
-
9
-
10
- def main():
11
- # Collect all shard files matching "model-*-of-*.safetensors"
12
- shard_files = sorted(glob.glob("model-*-of-*.safetensors"))
13
-
14
- # Calculate total size of all shards (in bytes)
15
- total_size = sum(os.path.getsize(sf) for sf in shard_files)
16
-
17
- metadata = {"total_size": total_size}
18
- weight_map = {}
19
-
20
- # Iterate over each shard and map its tensor names to the shard filename
21
- for shard_file in shard_files:
22
- with safe_open(shard_file, framework="np") as f:
23
- for tensor_name in f.keys():
24
- weight_map[tensor_name] = os.path.basename(shard_file)
25
-
26
- output_dict = {"metadata": metadata, "weight_map": weight_map}
27
-
28
- # Write JSON structure to "model.safetensors.index.json"
29
- with open("model.safetensors.index.json", "w", encoding="utf-8") as out_file:
30
- json.dump(output_dict, out_file, indent=2)
31
-
32
- print("Created model.safetensors.index.json with total size =", total_size, "bytes.")
33
-
34
-
35
- if __name__ == "__main__":
36
- main()