# Export a DiffTransformerLLM checkpoint to ONNX.
| import torch | |
| from inference.model import DiffTransformerLLM | |
| from inference.inference import load_model | |
| import argparse | |
| import os | |
def main() -> None:
    """Export a DiffTransformerLLM checkpoint to an ONNX graph.

    Parses command-line arguments, loads the checkpoint on CPU via
    ``load_model``, builds dummy token ids plus an additive causal
    attention mask, and writes the traced model to the requested ONNX
    path with dynamic batch/sequence axes.
    """
    parser = argparse.ArgumentParser(description="Export DiffTransformerLLM to ONNX")
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="Path to model checkpoint (.pt)"
    )
    parser.add_argument(
        "--onnx_path", type=str, default="model.onnx", help="Output ONNX file path"
    )
    parser.add_argument(
        "--seq_len", type=int, default=32, help="Dummy input sequence length"
    )
    # Generalization: the vocab size was a hard-coded 259; keep 259 as the
    # default so existing invocations behave identically.
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=259,
        help="Vocabulary size used to draw random dummy token ids (default: 259)",
    )
    args = parser.parse_args()

    # Export on CPU: tracing needs no GPU and keeps the exported graph portable.
    device = torch.device("cpu")
    print(f"Loading model from {args.checkpoint}")
    model = load_model(args.checkpoint, device=device, fp16=False, quantize=False)
    model.eval()

    # Dummy input: values only need to be valid embedding indices; the actual
    # numbers are irrelevant to tracing.
    batch_size = 1
    seq_len = args.seq_len
    input_ids = torch.randint(
        0, args.vocab_size, (batch_size, seq_len), dtype=torch.long
    )

    # Additive causal mask: 0.0 on/below the diagonal, -inf strictly above,
    # so attention cannot see future positions. This is a dynamic input to
    # the ONNX model, recomputed by callers for each sequence length.
    causal_mask = torch.triu(
        torch.ones(1, seq_len, seq_len, dtype=torch.bool), diagonal=1
    )
    attn_mask = torch.zeros(1, seq_len, seq_len, dtype=torch.float32)
    attn_mask.masked_fill_(causal_mask, float("-inf"))

    # Trace and export; dynamic_axes lets the ONNX runtime accept any
    # batch size / sequence length, not just the dummy shapes used here.
    print(f"Exporting to ONNX: {args.onnx_path}")
    torch.onnx.export(
        model,
        (input_ids, attn_mask),
        args.onnx_path,
        input_names=["input_ids", "attn_mask"],
        output_names=["logits"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "seq_len"},
            "attn_mask": {0: "batch_size", 1: "seq_len", 2: "seq_len"},
            "logits": {0: "batch_size", 1: "seq_len"},
        },
        opset_version=17,
        do_constant_folding=True,
    )
    print(f"ONNX export complete: {args.onnx_path}")
# Run the exporter only when invoked as a script, not when imported.
if __name__ == "__main__":
    main()