| | """
|
| | Setup script for BitLinear PyTorch extension.
|
| |
|
| | This script builds the C++/CUDA extension using PyTorch's built-in
|
| | cpp_extension utilities. It handles:
|
| | - CPU-only builds (development)
|
| | - CUDA builds (production)
|
| | - Conditional compilation based on CUDA availability
|
| | """
|
| |
|
| | import os
|
| | import torch
|
| | from setuptools import setup, find_packages
|
| | from torch.utils.cpp_extension import (
|
| | BuildExtension,
|
| | CppExtension,
|
| | CUDAExtension,
|
| | CUDA_HOME,
|
| | )
|
| |
|
| |
|
# Package version; bump on each release (PEP 440 format).
VERSION = "0.1.0"

# One-line summary shown in package indexes.
DESCRIPTION = "BitLinear: Ultra-Low-Precision Linear Layers for PyTorch"

# Long description rendered on the project page (passed to setup() below).
LONG_DESCRIPTION = """
A research-grade PyTorch extension for ultra-low-precision (1.58-bit) ternary
linear layers inspired by BitNet and recent JMLR work on ternary representations
of neural networks.

Features:
- Drop-in replacement for nn.Linear with ternary weights
- 20x memory compression
- Optimized CUDA kernels for GPU acceleration
- Greedy ternary decomposition for improved expressiveness
"""
|
| |
|
| |
|
def cuda_is_available():
    """Check whether the CUDA extension should be compiled.

    A CUDA toolkit (``CUDA_HOME`` resolved by PyTorch) is always required.
    Normally a visible GPU is required too, but setting ``FORCE_CUDA=1``
    in the environment allows building CUDA support on GPU-less machines
    (e.g. CI runners or docker build stages) where the toolkit is
    installed but ``torch.cuda.is_available()`` reports False.

    Returns:
        bool: True if CUDA sources should be compiled.
    """
    if CUDA_HOME is None:
        # No nvcc toolchain — a GPU alone is not enough to compile.
        return False
    if os.environ.get("FORCE_CUDA", "0") == "1":
        return True
    return torch.cuda.is_available()
|
| |
|
| |
|
def get_extensions():
    """Build the list of extension modules based on CUDA availability.

    Returns:
        list: A single-element list containing either a CUDAExtension
        (CUDA toolchain usable) or a CppExtension (CPU-only build).
    """
    source_dir = os.path.join("bitlinear", "cpp")
    sources = [os.path.join(source_dir, "bitlinear.cpp")]

    # Per-compiler flags; an "nvcc" entry is appended for CUDA builds.
    # BuildExtension accepts this dict form for both CppExtension and
    # CUDAExtension, so both branches below pass it unchanged.
    extra_compile_args = {
        "cxx": ["-O3", "-std=c++17"],
    }

    define_macros = []

    if cuda_is_available():
        print("CUDA detected, building with GPU support")

        sources.append(os.path.join(source_dir, "bitlinear_kernel.cu"))

        # Target Volta (sm_70) through Hopper (sm_90); GPUs older than
        # compute capability 7.0 are not supported by these kernels.
        extra_compile_args["nvcc"] = [
            "-O3",
            "-std=c++17",
            "--use_fast_math",
            "-gencode=arch=compute_70,code=sm_70",
            "-gencode=arch=compute_75,code=sm_75",
            "-gencode=arch=compute_80,code=sm_80",
            "-gencode=arch=compute_86,code=sm_86",
            "-gencode=arch=compute_89,code=sm_89",
            "-gencode=arch=compute_90,code=sm_90",
        ]

        # Lets the C++ translation units conditionally compile GPU paths.
        define_macros.append(("WITH_CUDA", None))

        extension = CUDAExtension(
            name="bitlinear_cpp",
            sources=sources,
            extra_compile_args=extra_compile_args,
            define_macros=define_macros,
        )
    else:
        print("CUDA not detected, building CPU-only version")

        extension = CppExtension(
            name="bitlinear_cpp",
            sources=sources,
            # Pass the dict form here too (previously this unwrapped
            # extra_compile_args["cxx"]); keeps both branches consistent.
            extra_compile_args=extra_compile_args,
            define_macros=define_macros,
        )

    return [extension]
|
| |
|
| |
|
| |
|
def read_requirements(req_file="requirements.txt"):
    """Read a pip requirements file, skipping blank and comment lines.

    Args:
        req_file: Path to the requirements file. Defaults to
            "requirements.txt" in the current directory (the original
            hard-coded behavior).

    Returns:
        list[str]: Stripped requirement specifiers, or an empty list if
        the file does not exist.
    """
    if not os.path.exists(req_file):
        return []
    # Explicit encoding: requirements files are UTF-8 by pip convention;
    # relying on the locale default can fail on non-UTF-8 systems.
    with open(req_file, "r", encoding="utf-8") as f:
        # Strip before the "#" check so indented comment lines are
        # skipped too (previously only column-0 comments were skipped).
        lines = (line.strip() for line in f)
        return [line for line in lines if line and not line.startswith("#")]
|
| |
|
| |
|
| |
|
# Prefer pinned dependencies from requirements.txt when one is present
# (read_requirements was previously defined but never called); fall back
# to the minimal runtime requirements so behavior is unchanged when the
# file is absent.
_install_requires = read_requirements() or [
    "torch>=2.0.0",
    "numpy>=1.20.0",
]

setup(
    name="bitlinear",
    version=VERSION,
    author="BitLinear Contributors",
    description=DESCRIPTION,
    long_description=LONG_DESCRIPTION,
    long_description_content_type="text/markdown",
    url="https://github.com/yourusername/bitlinear",
    packages=find_packages(),
    ext_modules=get_extensions(),
    cmdclass={
        # no_python_abi_suffix=True keeps the built library named
        # "bitlinear_cpp.so" rather than an ABI-tagged filename.
        "build_ext": BuildExtension.with_options(no_python_abi_suffix=True)
    },
    install_requires=_install_requires,
    extras_require={
        "dev": [
            "pytest>=7.0.0",
            "pytest-cov>=4.0.0",
            "black>=22.0.0",
            "flake8>=5.0.0",
            "mypy>=0.990",
        ],
        "test": [
            "pytest>=7.0.0",
            "pytest-cov>=4.0.0",
        ],
    },
    python_requires=">=3.8",
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Science/Research",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: C++",
        "Programming Language :: Python :: Implementation :: CPython",
    ],
    keywords="pytorch deep-learning quantization ternary bitnet transformer",
    project_urls={
        "Bug Reports": "https://github.com/yourusername/bitlinear/issues",
        "Source": "https://github.com/yourusername/bitlinear",
        "Documentation": "https://github.com/yourusername/bitlinear/blob/main/README.md",
    },
)
|
| |
|