Spaces: Running on T4
Commit 6da47c0 · Parent: cf43f05 · "init commit"
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +6 -0
- .gitignore +234 -0
- LICENSE-APACHE +201 -0
- LICENSE-MIT +25 -0
- README.md +82 -12
- app.py +43 -0
- data/example-data/Amir-Khan-Lamont-Peterson_2689582.jpg +3 -0
- data/example-data/BNAAHPYGMYSE26U6C6T7VA6544.jpg +3 -0
- data/example-data/Canelo-Alvarez-b4d59f2080464e4d996177f5ce9792ee.jpg +3 -0
- data/example-data/Planche.jpg +3 -0
- data/example-data/yoga-example.jpg +3 -0
- pixi.lock +0 -0
- pyproject.toml +149 -0
- src/sam3d_body/__init__.py +12 -0
- src/sam3d_body/api/demo.py +241 -0
- src/sam3d_body/api/visualization.py +425 -0
- src/sam3d_body/build_models.py +56 -0
- src/sam3d_body/data/__init__.py +1 -0
- src/sam3d_body/data/transforms/__init__.py +21 -0
- src/sam3d_body/data/transforms/bbox_utils.py +380 -0
- src/sam3d_body/data/transforms/common.py +345 -0
- src/sam3d_body/data/utils/io.py +114 -0
- src/sam3d_body/data/utils/prepare_batch.py +99 -0
- src/sam3d_body/gradio_ui/sam3d_body_ui.py +164 -0
- src/sam3d_body/metadata/__init__.py +79 -0
- src/sam3d_body/metadata/mhr70.py +915 -0
- src/sam3d_body/models/__init__.py +1 -0
- src/sam3d_body/models/backbones/__init__.py +35 -0
- src/sam3d_body/models/backbones/dinov3.py +69 -0
- src/sam3d_body/models/backbones/vit.py +658 -0
- src/sam3d_body/models/decoders/__init__.py +32 -0
- src/sam3d_body/models/decoders/keypoint_prompt_sampler.py +183 -0
- src/sam3d_body/models/decoders/prompt_encoder.py +256 -0
- src/sam3d_body/models/decoders/promptable_decoder.py +194 -0
- src/sam3d_body/models/heads/__init__.py +28 -0
- src/sam3d_body/models/heads/camera_head.py +110 -0
- src/sam3d_body/models/heads/mhr_head.py +369 -0
- src/sam3d_body/models/meta_arch/__init__.py +3 -0
- src/sam3d_body/models/meta_arch/base_lightning_module.py +48 -0
- src/sam3d_body/models/meta_arch/base_model.py +162 -0
- src/sam3d_body/models/meta_arch/sam3d_body.py +1728 -0
- src/sam3d_body/models/modules/__init__.py +18 -0
- src/sam3d_body/models/modules/camera_embed.py +111 -0
- src/sam3d_body/models/modules/drop_path.py +42 -0
- src/sam3d_body/models/modules/geometry_utils.py +304 -0
- src/sam3d_body/models/modules/layer_scale.py +45 -0
- src/sam3d_body/models/modules/mhr_utils.py +392 -0
- src/sam3d_body/models/modules/misc.py +31 -0
- src/sam3d_body/models/modules/swiglu_ffn.py +96 -0
- src/sam3d_body/models/modules/transformer.py +651 -0
.gitattributes
CHANGED
@@ -1,3 +1,9 @@
+# LFS/Xet-managed assets
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,234 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
# Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
# poetry.lock
# poetry.toml

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
# pdm.lock
# pdm.toml
.pdm-python
.pdm-build/

# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
# pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# Redis
*.rdb
*.aof
*.pid

# RabbitMQ
mnesia/
rabbitmq/
rabbitmq-data/

# ActiveMQ
activemq-data/

# SageMath parsed files
*.sage.py

# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
# .idea/

# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/

# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/

# Streamlit
.streamlit/secrets.toml

# pixi environments
.pixi/*
!.pixi/config.toml

_checkpoints/*

# START Ruler Generated Files
/.codex/config.json
/.codex/config.json.bak
/.codex/config.toml
/.codex/config.toml.bak
/.vscode/mcp.json
/.vscode/mcp.json.bak
/AGENTS.md
/AGENTS.md.bak
# END Ruler Generated Files
LICENSE-APACHE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
LICENSE-MIT
ADDED
@@ -0,0 +1,25 @@
Copyright (c) 2022 Rerun Technologies AB <opensource@rerun.io>

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,82 @@
(The 12 lines of the previous README are removed; their content is not shown in this view.)

# SAM3D Body with Rerun
An unofficial playground for Meta's SAM3D Body (DINOv3) with promptable SAM3 masks and live Rerun visualization. Uses **Rerun** for 3D inspection, **Gradio** for the UI, and **Pixi** for one-command setup.

<p align="center">
  <a title="Rerun" href="https://rerun.io" target="_blank" rel="noopener noreferrer">
    <img src="https://img.shields.io/badge/Rerun-0.27%2B-0b82f9" alt="Rerun badge">
  </a>
  <a title="Pixi" href="https://pixi.sh/latest/" target="_blank" rel="noopener noreferrer">
    <img src="https://img.shields.io/badge/Install%20with-Pixi-16A34A" alt="Pixi badge">
  </a>
  <a title="CUDA" href="https://developer.nvidia.com/cuda-toolkit" target="_blank" rel="noopener noreferrer">
    <img src="https://img.shields.io/badge/CUDA-12.9%2B-76b900" alt="CUDA badge">
  </a>
  <a title="GitHub" href="https://github.com/rerun-io/sam3d-body-rerun" target="_blank" rel="noopener noreferrer">
    <img src="https://img.shields.io/github/stars/rerun-io/sam3d-body-rerun?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="GitHub stars">
  </a>
</p>

<p align="center">
  <!-- Drop your GIF/MP4 here once ready -->
  <img src="media/sam3d-body-demo.gif" alt="example output" width="720" />
</p>

## Installation
### Using Pixi
Make sure you have the [Pixi](https://pixi.sh/latest/#installation) package manager installed.

TL;DR install Pixi:
```bash
curl -fsSL https://pixi.sh/install.sh | sh
```
Restart your shell so the new `pixi` binary is on `PATH`.

This project is Linux-only and requires an NVIDIA GPU.

The SAM3 and SAM3D Body checkpoints are gated on Hugging Face: request access to both [facebook/sam-3d-body-dinov3](https://huggingface.co/facebook/sam-3d-body-dinov3) and [facebook/sam3](https://huggingface.co/facebook/sam3), then authenticate either by setting `HF_TOKEN=<your token>` or by running `huggingface-cli login` before the first download (see Meta's install notes).

The first run downloads the Hugging Face checkpoints for SAM3, SAM3D Body, and the relative-depth model.
```bash
git clone https://github.com/rerun-io/sam3d-body-rerun.git
cd sam3d-body-rerun
pixi run app
```

All commands can be listed with `pixi task list`.

## Usage
### Gradio App
```bash
pixi run app
```
Opens the Gradio UI with an embedded streaming Rerun viewer. Try the bundled samples in `data/example-data` or upload your own RGB image; toggle "Log relative depth" to stream predicted depth.

### CLI
From a dev shell (for tyro + dev deps):
```bash
pixi run cli
```

or

```bash
pixi shell -e dev
python tool/demo.py --help
```
Run on a folder of images and configure Rerun output/recordings via the CLI flags.

### Promptable SAM3 sandbox
If you just want SAM3 masks without 3D reconstruction:
```bash
pixi run -e dev python tool/gradio_sam3.py
```

## Acknowledgements
Thanks to the original projects that make this demo possible:

- [facebook/sam-3d-body-dinov3](https://huggingface.co/facebook/sam-3d-body-dinov3) — SAM3D Body checkpoints and assets.
- [facebook/sam3](https://huggingface.co/facebook/sam3) — promptable concept segmentation.
- Relative depth/FOV from `MogeV1Predictor` in [monopriors](https://github.com/pablovela5620/monoprior).
- Built with [Rerun](https://rerun.io/), [Gradio](https://www.gradio.app/), and [Pixi](https://pixi.sh/latest/).

The code in this repository is dual-licensed under Apache 2.0 and MIT (see `LICENSE-APACHE` and `LICENSE-MIT`); upstream models and assets retain their original licenses.
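Editor's note: as a quick sanity check that the gated checkpoints are reachable before launching the app, here is a minimal sketch (not part of the repo) that uses `huggingface_hub` directly. It assumes `HF_TOKEN` is set in the environment and that access to both gated repos has already been granted; running `huggingface-cli login` once is an equivalent alternative.

```python
import os

from huggingface_hub import login, snapshot_download

# Authenticate with the token from the environment.
login(token=os.environ["HF_TOKEN"])

# Pre-fetch the gated checkpoints into the local HF cache so the first
# `pixi run app` does not stall on a long download.
for repo_id in ("facebook/sam-3d-body-dinov3", "facebook/sam3"):
    local_dir = snapshot_download(repo_id=repo_id)
    print(f"{repo_id} cached at {local_dir}")
```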
app.py
ADDED
@@ -0,0 +1,43 @@
import os
import subprocess
from pathlib import Path

PIXI_PATH = Path("/home/user/.pixi/bin/pixi")
PIXI_VERSION = "0.59.0"
MOCK_CUDA_VERSION = "12.9"

# Pretend CUDA 12.9 is available so pixi can solve environments on machines without GPUs.
os.environ.setdefault("CONDA_OVERRIDE_CUDA", MOCK_CUDA_VERSION)


def check_and_install_pixi() -> None:
    try:
        subprocess.check_call(f"{PIXI_PATH} --version", shell=True)
    except subprocess.CalledProcessError:
        print("pixi not found. Installing pixi...")
        # Install pixi using the provided installation script
        subprocess.check_call(
            f"PIXI_VERSION=v{PIXI_VERSION} curl -fsSL https://pixi.sh/install.sh | bash",
            shell=True,
        )
    subprocess.check_call(f"{PIXI_PATH} self-update --version {PIXI_VERSION}", shell=True)
    subprocess.check_call(f"{PIXI_PATH} --version", shell=True)


def run_command(command: str) -> None:
    try:
        subprocess.check_call(command, shell=True)
    except subprocess.CalledProcessError as e:
        print(f"Command failed: {command}. Error: {e}")


if __name__ == "__main__":
    check_and_install_pixi()
    # install lsof
    # run_command(command=f"{PIXI_PATH} global install lsof")
    # # kill anything running on port 7860
    # run_command(command=f"{PIXI_PATH.parent}/lsof -t -i:7860 | xargs -r kill")
    # clean current environment
    run_command(command=f"{PIXI_PATH} clean")
    # run spaces app
    run_command(command=f"{PIXI_PATH} run app")
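Editor's note: for context on the check-then-install bootstrap pattern `app.py` uses on a fresh Space, here is a small hedged sketch (not part of the repo). It swaps the version-probe subprocess for `shutil.which`, which avoids shelling out just to test for the binary; the `pixi` name and the installer URL are the only details carried over from `app.py`.

```python
import shutil
import subprocess

def pixi_available() -> bool:
    # `shutil.which` returns the executable's path if it is on PATH, else None.
    return shutil.which("pixi") is not None

if not pixi_available():
    # Same installer script the Space entrypoint uses; shell=True keeps the curl | bash pipeline intact.
    subprocess.check_call("curl -fsSL https://pixi.sh/install.sh | bash", shell=True)

# Verify the install (raises CalledProcessError if pixi still cannot run).
subprocess.check_call("pixi --version", shell=True)
```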
data/example-data/Amir-Khan-Lamont-Peterson_2689582.jpg
ADDED (stored via Git LFS)
data/example-data/BNAAHPYGMYSE26U6C6T7VA6544.jpg
ADDED (stored via Git LFS)
data/example-data/Canelo-Alvarez-b4d59f2080464e4d996177f5ce9792ee.jpg
ADDED (stored via Git LFS)
data/example-data/Planche.jpg
ADDED (stored via Git LFS)
data/example-data/yoga-example.jpg
ADDED (stored via Git LFS)
pixi.lock
ADDED (the diff for this file is too large to render; see the raw diff)
pyproject.toml
ADDED
@@ -0,0 +1,149 @@
[project]
authors = [{ name = "pablo vela", email = "pablovela5620@gmail.com" }]
dependencies = [
  "jaxtyping<0.3.0",
  "numpy>=2.0",
  "einops>=0.8.0",
  "icecream>=2.1.3",
  "opencv-python>=4.10.0",
  "pyserde>=0.20.0",
  "rerun-sdk>=0.27.0",
  "tyro>=0.9.1",
  "tqdm",
  "hf-transfer>=0.1.9",
  "lovely-numpy>=0.2.13,<0.3",
  "pandas>=2.3.3",
  "braceexpand>=0.1.7,<0.2",
  "roma>=1.5.4,<2",
  "pytorch-lightning>=2.5.6,<3",
  "yacs>=0.1.8,<0.2",
  "omegaconf>=2.3.0,<3",
  "termcolor>=3.2.0,<4",
  "gradio-rerun>=0.27.0",
  "spaces>=0.43.0",
]
name = "sam3d_body"
requires-python = ">= 3.12"
version = "0.1.0"


[build-system]
build-backend = "hatchling.build"
requires = ["hatchling"]

[tool.hatch.metadata]
allow-direct-references = true

[tool.pixi.workspace]
channels = ["conda-forge"]
platforms = ["linux-64"]
preview = ["pixi-build"]

[tool.pixi.pypi-options]
no-build-isolation = ["detectron2", "moge"]
[tool.pixi.pypi-options.dependency-overrides]
# Allow iopath >=0.1.10 even though detectron2 pins <0.1.10, so it can satisfy sam-2.
iopath = ">=0.1.10"
gradio = ">=5.45.0,<6"
[tool.pixi.pypi-dependencies]
sam3d_body = { path = ".", editable = true }
moge = { git = "https://github.com/microsoft/MoGe.git" }
simplecv = { git = "https://github.com/pablovela5620/simplecv.git", branch = "main" }
timm = ">=0.9"
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "d08b98b965176ea9cf8c8e8b24995c955b7e2ec9" }
monopriors = { git = "https://github.com/pablovela5620/monoprior.git" }

[tool.pixi.tasks]
app = "python tool/gradio_sam3d_body.py"
cli = "python tool/demo.py --image-folder data/example-data"

[tool.pixi.feature.cuda129.system-requirements]
cuda = "12.9"

[tool.pixi.feature.cuda129.dependencies]
# CUDA Build Tools
cuda-compiler = "*"
cuda-version = "12.9.*"
cuda-cudart-dev = "*"
cuda-crt = "*"
libcusparse-dev = "*"
cuda-driver-dev = "*"
cuda-nvcc = "*"
cuda-nvrtc-dev = "*"
cuda-nvtx = "*"
cuda-nvtx-dev = "*"
cuda-nvml-dev = "*"
cuda-profiler-api = "*"

# CUDA Libraries
cudnn = "*"
libcublas-dev = "*"
libcudss-dev = "*"
libcufile-dev = "*"
libcufft-dev = "*"
libcurand-dev = "*"
libcusolver-dev = "*"
cusparselt = "*"
libnvjitlink = "*"
# cuda129 end

[tool.pixi.feature.gpu.dependencies]
pytorch-gpu = ">=2.8.0"
torchvision = "*"


[tool.pixi.feature.dev.dependencies]
beartype = "*"
pyrefly = ">=0.42.2,<0.43"
ruff = ">=0.14.5,<0.15"

[tool.pixi.feature.dev.pypi-dependencies]
types-tqdm = "*"

[tool.pixi.environments]
cuda128 = { features = [
  "cuda129",
], solve-group = "cuda129", no-default-feature = true }
default = { features = ["gpu", "cuda129"], solve-group = "cuda129" }
dev = { features = ["dev", "gpu", "cuda129"], solve-group = "cuda129" }

[tool.pixi.dependencies]
av = ">=16.0.1,<17"
gradio = ">=5.45.0,<6"
huggingface_hub = ">=1.0,<2"
tomlkit = "==0.12.0"
audioop-lts = "*"
pydub = "*"
open3d = ">=0.19.0,<0.20"

[tool.ruff]
line-length = 150

[tool.ruff.lint]
select = [
  # pycodestyle
  "E",
  # Pyflakes
  "F",
  # pyupgrade
  "UP",
  # flake8-bugbear
  "B",
  # flake8-simplify
  "SIM",
  # isort
  "I",
]

ignore = [
  "E501",  # Line too long.
  "F722",  # Forward annotation false positive from jaxtyping. Should be caught by pyright.
  "F821",  # Forward annotation false positive from jaxtyping. Should be caught by pyright.
  "UP037", # Remove quotes from type, false positive when using jaxtyping
  "UP040", # Beartype fails if not using this for typealias
]

[tool.pyrefly]
project-includes = ["**/*"]
project-excludes = ["**/node_modules", "**/__pycache__", "**/*venv/**/*"]
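Editor's note: to see how the pixi tasks and environments declared above fit together, here is a hedged, standalone sketch (not part of the repo) that parses this `pyproject.toml` with the standard-library `tomllib` and prints what `pixi run <task>` and `pixi shell -e <env>` would resolve. It assumes it is run from the repository root.

```python
import tomllib
from pathlib import Path

data = tomllib.loads(Path("pyproject.toml").read_text())
pixi = data["tool"]["pixi"]

# Tasks map a short name (e.g. `app`, `cli`) to the command that `pixi run <name>` executes.
for name, command in pixi["tasks"].items():
    print(f"task {name!r}: {command}")

# Environments bundle features; `pixi run -e dev ...` or `pixi shell -e dev` selects one.
for env_name, spec in pixi["environments"].items():
    features = spec["features"] if isinstance(spec, dict) else spec
    print(f"environment {env_name!r}: features={features}")
```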
src/sam3d_body/__init__.py
ADDED
@@ -0,0 +1,12 @@
import os

# Only enable beartype when running in the 'dev' environment
# Check the PIXI_ENVIRONMENT_NAME environment variable set by pixi
if os.environ.get("PIXI_ENVIRONMENT_NAME") == "dev":
    try:
        from beartype.claw import beartype_this_package

        beartype_this_package()
    except ImportError:
        # beartype not available even in dev environment
        pass
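Editor's note: the package enables beartype implicitly via `beartype_this_package()` only when pixi's dev environment is active, which turns the jaxtyping shape annotations used throughout the codebase into runtime checks. The following standalone sketch (not from the repo; the function and values are hypothetical) shows the kind of error that gating catches, using an explicit `@jaxtyped` decorator instead of the package-wide hook.

```python
import numpy as np
from beartype import beartype
from jaxtyping import Float32, jaxtyped
from numpy import ndarray


@jaxtyped(typechecker=beartype)
def normalize_boxes(boxes: Float32[ndarray, "n 4"], w: int, h: int) -> Float32[ndarray, "n 4"]:
    # Scale XYXY pixel boxes into [0, 1] normalized coordinates.
    scale = np.array([w, h, w, h], dtype=np.float32)
    return boxes / scale


good = np.zeros((3, 4), dtype=np.float32)
normalize_boxes(good, w=640, h=480)  # passes the shape/dtype check

bad = np.zeros((3, 5), dtype=np.float32)
normalize_boxes(bad, w=640, h=480)  # raises a type-check error: last dim is 5, not 4
```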
src/sam3d_body/api/demo.py
ADDED
@@ -0,0 +1,241 @@
"""Minimal standalone demo wiring for SAM 3D Body with Rerun visualization."""

import os
from dataclasses import dataclass
from glob import glob
from pathlib import Path
from typing import Literal, TypedDict

import cv2
import numpy as np
import rerun as rr
import rerun.blueprint as rrb
import torch
from jaxtyping import Float32, UInt8
from monopriors.relative_depth_models import BaseRelativePredictor, RelativeDepthPrediction, get_relative_predictor
from numpy import ndarray
from serde import serde
from simplecv.rerun_log_utils import RerunTyroConfig
from torch import Tensor
from tqdm import tqdm
from transformers.models.sam3 import Sam3Model, Sam3Processor
from yacs.config import CfgNode

from sam3d_body.api.visualization import create_view, set_annotation_context, visualize_sample
from sam3d_body.build_models import load_sam_3d_body, load_sam_3d_body_hf
from sam3d_body.models.meta_arch import SAM3DBody
from sam3d_body.sam_3d_body_estimator import FinalPosePrediction, SAM3DBodyEstimator


class SAM3ResultsDict(TypedDict):
    """Torch-format outputs returned directly by ``Sam3Processor`` post-processing."""

    scores: Float32[Tensor, "n"]
    boxes: Float32[Tensor, "n 4"]
    masks: Float32[Tensor, "n h w"]


@serde()
class SAM3Results:
    scores: Float32[ndarray, "n"]
    """Per-instance confidence scores ``[N]``."""
    boxes: Float32[ndarray, "n 4"]
    """Bounding boxes in XYXY pixel coordinates ``[N, 4]``."""
    masks: Float32[ndarray, "n h w"]
    """Probability masks for each detection ``[N, H, W]`` (float32 in ``[0, 1]``)."""


@dataclass
class SAM3Config:
    """Configuration for loading a SAM3 checkpoint and selecting device."""

    device: Literal["cpu", "cuda"] = "cuda"
    """Computation device passed to the Hugging Face SAM3 model."""
    sam3_checkpoint: str = "facebook/sam3"
    """Model identifier or path accepted by ``Sam3Model.from_pretrained``."""


class SAM3Predictor:
    """Lightweight wrapper around the SAM3 model for single-image inference."""

    def __init__(self, config: SAM3Config):
        self.config = config
        self.sam3_model = Sam3Model.from_pretrained(config.sam3_checkpoint).to(config.device)
        self.sam3_processor = Sam3Processor.from_pretrained(config.sam3_checkpoint)

    def predict_single_image(self, rgb_hw3: UInt8[ndarray, "h w 3"], text: str = "person") -> SAM3Results:
        """Run SAM3 instance segmentation on one RGB image.

        Args:
            rgb_hw3: Input image in RGB order with dtype ``uint8`` and shape ``[H, W, 3]``.
            text: Optional prompt used by SAM3's text-conditioned decoder (default: ``"person"``).

        Returns:
            ``SAM3Results`` with NumPy copies of scores, XYXY boxes, and probability masks.
        """
        inputs = self.sam3_processor(
            images=rgb_hw3,
            text=text,
            return_tensors="pt",
        ).to(self.config.device)

        with torch.no_grad():
            outputs = self.sam3_model(**inputs)

        results: SAM3ResultsDict = self.sam3_processor.post_process_instance_segmentation(
            outputs, threshold=0.5, mask_threshold=0.5, target_sizes=inputs.get("original_sizes").tolist()
        )[0]

        mask_probs: Float32[ndarray, "n h w"] = results["masks"].detach().cpu().numpy().astype(np.float32, copy=False)

        return SAM3Results(
            scores=results["scores"].detach().cpu().numpy().astype(np.float32, copy=False),
            boxes=results["boxes"].detach().cpu().numpy().astype(np.float32, copy=False),
            masks=mask_probs,
        )


@dataclass
class SAM3DBodyE2EConfig:
    """Bundle of sub-configurations required for the end-to-end demo."""

    sam3_config: SAM3Config
    """Settings for the underlying SAM3 detector."""
    fov_estimator: Literal["MogeV1Predictor"] = "MogeV1Predictor"
    """Identifier of the relative depth/FOV estimator to load."""
    mhr_path: Path = Path("checkpoints/sam-3d-body-dinov3/assets/mhr_model.pt")
    """Path to the MHR mesh/pose asset file required by the head network."""
    checkpoint_path: Path = Path("checkpoints/sam-3d-body-dinov3/model.ckpt")
    """Core SAM 3D Body model checkpoint (.ckpt)."""


class SAM3DBodyE2E:
    """Convenience facade that chains detection, FOV estimation, and 3D reconstruction."""

    def __init__(self, config: SAM3DBodyE2EConfig):
        self.sam3_predictor = SAM3Predictor(config.sam3_config)
        self.fov_predictor: BaseRelativePredictor = get_relative_predictor(config.fov_estimator)(device="cuda")
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        # load_output: tuple[SAM3DBody, CfgNode] = load_sam_3d_body(
        #     config.checkpoint_path,
        #     device=device,
        #     mhr_path=config.mhr_path,
        # )
        load_output: tuple[SAM3DBody, CfgNode] = load_sam_3d_body_hf(repo_id="facebook/sam-3d-body-dinov3")
        model: SAM3DBody = load_output[0]
        self.sam3d_body_estimator = SAM3DBodyEstimator(
            sam_3d_body_model=model,
        )

    def predict_single_image(
        self, rgb_hw3: UInt8[ndarray, "h w 3"]
    ) -> tuple[list[FinalPosePrediction], RelativeDepthPrediction]:
        """Estimate 3D poses for a single frame.

        Pipeline:
            1. Use the configured relative-depth predictor to derive camera intrinsics ``K_33``.
            2. Run SAM3 to obtain person masks and boxes.
            3. Feed detections and intrinsics into ``SAM3DBodyEstimator`` for per-person 3D bodies.

        Args:
            rgb_hw3: RGB image with shape ``[H, W, 3]`` and dtype ``uint8``.

        Returns:
            A list of ``FinalPosePrediction`` entries (one per detected person) and the relative-depth prediction.
        """
        # estimate the camera intrinsics
        relative_pred: RelativeDepthPrediction = self.fov_predictor(rgb=rgb_hw3, K_33=None)
        K_33: Float32[ndarray, "3 3"] = relative_pred.K_33

        sam3_results: SAM3Results = self.sam3_predictor.predict_single_image(rgb_hw3)

        outputs: list[FinalPosePrediction] = self.sam3d_body_estimator.process_one_image(
            rgb_hw3,
            xyxy=sam3_results.boxes,
            masks=sam3_results.masks,
            masks_score=sam3_results.scores,
            K_33=K_33,
        )
        return outputs, relative_pred


@dataclass(slots=True)
class Sam3DBodyDemoConfig:
    """Configuration for the standalone demo runner."""

    rr_config: RerunTyroConfig
    """Viewer/runtime options for Rerun (window layout, recording, etc.)."""

    sam3_e2e_config: SAM3DBodyE2EConfig
    """Configuration for the end-to-end SAM 3D Body model."""

    image_folder: Path | None = None
    """Directory containing input images to process."""

    image_path: Path | None = None
    """Path to a single input image to process."""

    max_frames: int | None = None
    """Optional limit on the number of images to process; ``None`` processes all images."""


def main(cfg: Sam3DBodyDemoConfig):
    """Run the Rerun-enabled demo on a folder or single image.

    Args:
        cfg: Aggregated configuration containing Rerun settings, SAM3 model options,
            and input image selection.
    """
    # Setup Rerun
    parent_log_path = Path("/world")
    set_annotation_context()
    view: rrb.ContainerLike = create_view()
    blueprint = rrb.Blueprint(view, collapse_panels=True)
    rr.send_blueprint(blueprint)
    rr.log("/", rr.ViewCoordinates.RDF, static=True)

    if cfg.image_path is not None:
        images_list = [str(cfg.image_path)]
    elif cfg.image_folder is not None:
        image_extensions: list[str] = [
            "*.jpg",
            "*.jpeg",
            "*.png",
            "*.gif",
            "*.bmp",
            "*.tiff",
            "*.webp",
        ]
        images_list: list[str] = sorted(
            [image for ext in image_extensions for image in glob(os.path.join(cfg.image_folder, ext))]
        )
    else:
        raise ValueError("Either image_path or image_folder must be specified.")

    # load end to end model
    sam3D_body_e2e = SAM3DBodyE2E(cfg.sam3_e2e_config)

    for idx, image_path in enumerate(tqdm(images_list)):
        rr.set_time(timeline="image_sequence", sequence=idx)
        # load image and convert to RGB
        bgr_hw3: UInt8[ndarray, "h w 3"] = cv2.imread(image_path)
        rgb_hw3: UInt8[ndarray, "h w 3"] = cv2.cvtColor(bgr_hw3, cv2.COLOR_BGR2RGB)

        outputs: tuple[list[FinalPosePrediction], RelativeDepthPrediction] = sam3D_body_e2e.predict_single_image(
            rgb_hw3
        )
        pred_list: list[FinalPosePrediction] = outputs[0]
        relative_pred: RelativeDepthPrediction = outputs[1]

        if len(pred_list) == 0:
            # Detector/FOV failed on this frame; avoid crashing the visualization step.
            print(f"[warn] No detections for {image_path}; skipping.")
            continue

        visualize_sample(
            pred_list=pred_list,
            rgb_hw3=rgb_hw3,
            parent_log_path=parent_log_path,
            faces=sam3D_body_e2e.sam3d_body_estimator.faces,
            relative_depth_pred=relative_pred,
        )
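Editor's note: a hedged usage sketch of the `SAM3DBodyE2E` facade above, outside the tyro-driven `main`. It is not part of the repo; it assumes a CUDA machine, granted access to the gated checkpoints, and a locally spawned Rerun viewer. All class names, arguments, and the example image path are taken from this commit.

```python
from pathlib import Path

import cv2
import rerun as rr
import rerun.blueprint as rrb

from sam3d_body.api.demo import SAM3Config, SAM3DBodyE2E, SAM3DBodyE2EConfig
from sam3d_body.api.visualization import create_view, set_annotation_context, visualize_sample

# Stream to a local Rerun viewer; the Gradio app uses an embedded viewer instead.
rr.init("sam3d_body_single_image", spawn=True)
set_annotation_context()
rr.send_blueprint(rrb.Blueprint(create_view(), collapse_panels=True))

# Detector + FOV estimator + SAM 3D Body stack (downloads checkpoints on first use).
e2e = SAM3DBodyE2E(SAM3DBodyE2EConfig(sam3_config=SAM3Config(device="cuda")))

bgr = cv2.imread("data/example-data/yoga-example.jpg")
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

pred_list, relative_pred = e2e.predict_single_image(rgb)
print(f"detected {len(pred_list)} people")

visualize_sample(
    pred_list=pred_list,
    rgb_hw3=rgb,
    parent_log_path=Path("/world"),
    faces=e2e.sam3d_body_estimator.faces,
    relative_depth_pred=relative_pred,
)
```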
src/sam3d_body/api/visualization.py
ADDED
|
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from pathlib import Path

import cv2
import numpy as np
import open3d as o3d
import rerun as rr
import rerun.blueprint as rrb
from jaxtyping import Bool, Float32, Int, UInt8
from monopriors.depth_utils import depth_edges_mask
from monopriors.relative_depth_models import RelativeDepthPrediction
from numpy import ndarray
from simplecv.camera_parameters import Extrinsics, Intrinsics, PinholeParameters
from simplecv.ops.pc_utils import estimate_voxel_size
from simplecv.rerun_log_utils import log_pinhole

from sam3d_body.metadata.mhr70 import MHR70_ID2NAME, MHR70_IDS, MHR70_LINKS
from sam3d_body.sam_3d_body_estimator import FinalPosePrediction

BOX_PALETTE: UInt8[np.ndarray, "n_colors 4"] = np.array(
    [
        [255, 99, 71, 255],  # tomato
        [65, 105, 225, 255],  # royal blue
        [60, 179, 113, 255],  # medium sea green
        [255, 215, 0, 255],  # gold
        [138, 43, 226, 255],  # blue violet
        [255, 140, 0, 255],  # dark orange
        [220, 20, 60, 255],  # crimson
        [70, 130, 180, 255],  # steel blue
    ],
    dtype=np.uint8,
)

# Use a separate id range for segmentation classes to avoid clobbering the person class (id=0).
SEG_CLASS_OFFSET = 1000  # background = 1000, persons start at 1001
MAX_POINT_CLOUD_POINTS = 50_000
MIN_DEPTH_CONFIDENCE = 0.5


def filter_out_of_bounds(
    uv: Float32[ndarray, "n_points 2"],
    h: int,
    w: int,
    xyz_cam: Float32[ndarray, "n_points 3"] | None = None,
) -> Float32[ndarray, "n_points 2"]:
    """Return a copy of ``uv`` with off-screen (and optional behind-camera) points masked.

    Args:
        uv: Pixel coordinates ``[N, 2]`` in (u, v) order.
        h: Image height in pixels.
        w: Image width in pixels.
        xyz_cam: Optional camera-frame coordinates ``[N, 3]`` to mask points with negative ``z``.

    Returns:
        Copy of ``uv`` where out-of-bounds rows are set to ``NaN`` so Rerun hides them.
    """

    uv_filtered: Float32[ndarray, "n_points 2"] = np.asarray(uv, dtype=np.float32).copy()

    out_of_bounds: Bool[ndarray, "n_points"] = np.logical_or(uv_filtered[:, 0] >= float(w), uv_filtered[:, 0] < 0.0)
    out_of_bounds = np.logical_or(out_of_bounds, uv_filtered[:, 1] >= float(h))
    out_of_bounds = np.logical_or(out_of_bounds, uv_filtered[:, 1] < 0.0)

    if xyz_cam is not None:
        out_of_bounds = np.logical_or(out_of_bounds, xyz_cam[:, 2] < 0.0)

    uv_filtered[out_of_bounds, :] = np.nan
    return uv_filtered


def compute_vertex_normals(
    verts: Float32[ndarray, "n_verts 3"],
    faces: Int[ndarray, "n_faces 3"],
    eps: float = 1e-12,
) -> Float32[ndarray, "n_verts 3"]:
    """Compute per-vertex normals for a single mesh.

    Args:
        verts: Float32 array of vertex positions with shape ``(n_verts, 3)``.
        faces: Int array of triangle indices with shape ``(n_faces, 3)``.
        eps: Small epsilon to avoid division by zero when normalizing.

    Returns:
        Float32 array of unit vertex normals with shape ``(n_verts, 3)``; zeros for degenerate vertices.
    """

    # Expand faces to vertex triplets and fetch their positions.
    faces_i: Int[ndarray, "n_faces 3"] = faces.astype(np.int64)
    v0: Float32[ndarray, "n_faces 3"] = verts[faces_i[:, 0]]
    v1: Float32[ndarray, "n_faces 3"] = verts[faces_i[:, 1]]
    v2: Float32[ndarray, "n_faces 3"] = verts[faces_i[:, 2]]

    # Face normal = cross(edge1, edge2).
    e1: Float32[ndarray, "n_faces 3"] = v1 - v0
    e2: Float32[ndarray, "n_faces 3"] = v2 - v0
    face_normals: Float32[ndarray, "n_faces 3"] = np.cross(e1, e2)

    # Accumulate each face normal into its three vertices with a vectorized scatter-add.
    vertex_normals: Float32[ndarray, "n_verts 3"] = np.zeros_like(verts, dtype=np.float32)
    flat_indices: Int[ndarray, "n_faces3"] = faces_i.reshape(-1)
    face_normals_repeated: Float32[ndarray, "n_faces3 3"] = np.repeat(face_normals, 3, axis=0)
    np.add.at(vertex_normals, flat_indices, face_normals_repeated)

    norms: Float32[ndarray, "n_verts 1"] = np.linalg.norm(vertex_normals, axis=-1, keepdims=True)
    denom: Float32[ndarray, "n_verts 1"] = np.maximum(norms, eps).astype(np.float32)
    vn_unit: Float32[ndarray, "n_verts 3"] = (vertex_normals / denom).astype(np.float32)
    mask: ndarray = norms > eps
    vn_unit = np.where(mask, vn_unit, np.float32(0.0))
    return vn_unit


def export_meshes_to_glb(
    pred_list: list[FinalPosePrediction],
    faces: Int[ndarray, "n_faces 3"],
    output_dir: Path,
    box_palette: UInt8[ndarray, "n_colors 4"] = BOX_PALETTE,
    center_mesh: bool = True,
) -> list[Path]:
    """Write one GLB per predicted mesh and return the file paths."""

    output_dir.mkdir(parents=True, exist_ok=True)
    written_paths: list[Path] = []
    faces_int: Int[ndarray, "n_faces 3"] = np.ascontiguousarray(faces, dtype=np.int32)

    for idx, output in enumerate(pred_list):
        verts_cam: Float32[ndarray, "n_verts 3"] = np.ascontiguousarray(output.pred_vertices, dtype=np.float32)
        cam_t: Float32[ndarray, "3"] = np.ascontiguousarray(output.pred_cam_t, dtype=np.float32)
        # Convert to world coordinates to mirror the viewer logging convention (cam → world via translation).
        verts_world: Float32[ndarray, "n_verts 3"] = np.ascontiguousarray(verts_cam + cam_t, dtype=np.float32)
        verts_export: Float32[ndarray, "n_verts 3"]
        verts_export = verts_world - np.mean(verts_world, axis=0, keepdims=True) if center_mesh else verts_world

        vertex_normals: Float32[ndarray, "n_verts 3"] = compute_vertex_normals(verts_export, faces_int)

        mesh = o3d.geometry.TriangleMesh()
        mesh.vertices = o3d.utility.Vector3dVector(verts_export.astype(np.float64))
        mesh.triangles = o3d.utility.Vector3iVector(faces_int.astype(np.int32))
        mesh.vertex_normals = o3d.utility.Vector3dVector(vertex_normals.astype(np.float64))

        color: Float32[ndarray, "3"] = box_palette[idx % len(box_palette), :3].astype(np.float32) / 255.0
        vertex_colors: Float32[ndarray, "n_verts 3"] = np.repeat(color[np.newaxis, :], verts_export.shape[0], axis=0)
        mesh.vertex_colors = o3d.utility.Vector3dVector(vertex_colors.astype(np.float64))

        glb_path: Path = output_dir / f"person_{idx:02d}.glb"
        success: bool = bool(
            o3d.io.write_triangle_mesh(
                str(glb_path),
                mesh,
                write_ascii=False,
                write_vertex_normals=True,
                write_vertex_colors=True,
            )
        )
        if not success:
            fallback_path: Path = output_dir / f"person_{idx:02d}.ply"
            success = bool(
                o3d.io.write_triangle_mesh(
                    str(fallback_path),
                    mesh,
                    write_ascii=False,
                    write_vertex_normals=True,
                    write_vertex_colors=True,
                )
            )
            if success:
                glb_path = fallback_path

        if success:
            written_paths.append(glb_path)

    return written_paths


def set_annotation_context() -> None:
    """Register MHR-70 semantic metadata so subsequent logs show names/edges and mask colors."""
    # Base person class (for keypoints / boxes) uses id=0 (original), segmentation uses 1000+ to avoid clashes.
    person_class = rr.ClassDescription(
        info=rr.AnnotationInfo(id=0, label="Person", color=(0, 0, 255)),
        keypoint_annotations=[rr.AnnotationInfo(id=idx, label=name) for idx, name in MHR70_ID2NAME.items()],
        keypoint_connections=MHR70_LINKS,
    )

    # Segmentation classes: id=SEG_CLASS_OFFSET background, ids SEG_CLASS_OFFSET+1..n for each instance color.
    seg_classes: list[rr.ClassDescription] = [
        rr.ClassDescription(info=rr.AnnotationInfo(id=SEG_CLASS_OFFSET, label="Background", color=(64, 64, 64))),
    ]
    for idx, color in enumerate(BOX_PALETTE[:, :3].tolist(), start=1):
        seg_classes.append(
            rr.ClassDescription(
                info=rr.AnnotationInfo(
                    id=SEG_CLASS_OFFSET + idx, label=f"Person-{idx}", color=tuple(int(c) for c in color)
                ),
            )
        )

    rr.log(
        "/",
        rr.AnnotationContext([person_class, *seg_classes]),
        static=True,
    )


def visualize_sample(
    pred_list: list[FinalPosePrediction],
    rgb_hw3: UInt8[ndarray, "h w 3"],
    parent_log_path: Path,
    faces: Int[ndarray, "n_faces 3"],
    relative_depth_pred: RelativeDepthPrediction | None = None,
) -> None:
    h: int = rgb_hw3.shape[0]
    w: int = rgb_hw3.shape[1]
    cam_log_path: Path = parent_log_path / "cam"
    pinhole_log_path: Path = cam_log_path / "pinhole"
    image_log_path: Path = pinhole_log_path / "image"
    pred_log_path: Path = pinhole_log_path / "pred"
    # log the pinhole camera parameters (assume fx=fy and center at image center)
    focal_length: float = float(pred_list[0].focal_length)
    intri: Intrinsics = Intrinsics(
        camera_conventions="RDF",
        fl_x=focal_length,
        fl_y=focal_length,
        cx=float(w) / 2.0,
        cy=float(h) / 2.0,
        height=h,
        width=w,
    )
    world_T_cam: Float32[ndarray, "4 4"] = np.eye(4, dtype=np.float32)
    extri: Extrinsics = Extrinsics(
        world_R_cam=world_T_cam[:3, :3],
        world_t_cam=world_T_cam[:3, 3],
    )

    pinhole_params: PinholeParameters = PinholeParameters(intrinsics=intri, extrinsics=extri, name="pinhole")
    log_pinhole(camera=pinhole_params, cam_log_path=cam_log_path)
    # clear the previous pred logs
    rr.log(f"{pred_log_path}", rr.Clear(recursive=True))
    rr.log(f"{image_log_path}", rr.Image(rgb_hw3, color_model=rr.ColorModel.RGB).compress(jpeg_quality=90))

    # Build per-pixel maps (SEG_CLASS_OFFSET = background). Also build RGBA overlay with transparent background.
    seg_map: Int[ndarray, "h w"] = np.full((h, w), SEG_CLASS_OFFSET, dtype=np.int32)
    seg_overlay: UInt8[ndarray, "h w 4"] = np.zeros((h, w, 4), dtype=np.uint8)
    human_mask: Bool[ndarray, "h w"] = np.zeros((h, w), dtype=bool)

    mesh_root_path: Path = parent_log_path / "pred"
    rr.log(str(mesh_root_path), rr.Clear(recursive=True))

    for i, output in enumerate(pred_list):
        box_color: UInt8[ndarray, "1 4"] = BOX_PALETTE[i % len(BOX_PALETTE)].reshape(1, 4)
        rr.log(
            f"{pred_log_path}/bbox_{i}",
            rr.Boxes2D(
                array=output.bbox,
                array_format=rr.Box2DFormat.XYXY,
                class_ids=0,
                colors=box_color,
                show_labels=True,
            ),
        )

        kpts_cam: Float32[ndarray, "n_kpts 3"] = np.ascontiguousarray(output.pred_keypoints_3d, dtype=np.float32)
        kpts_uv: Float32[ndarray, "n_kpts 2"] = np.ascontiguousarray(output.pred_keypoints_2d, dtype=np.float32)
        kpts_uv_in_bounds: Float32[ndarray, "n_kpts 2"] = filter_out_of_bounds(
            uv=kpts_uv,
            h=h,
            w=w,
            xyz_cam=None,  # Depth sign from the model can be negative; only cull by image bounds.
        )
        rr.log(
            f"{pred_log_path}/uv_{i}",
            rr.Points2D(
                positions=kpts_uv_in_bounds,
                keypoint_ids=MHR70_IDS,
                class_ids=0,
                colors=(0, 255, 0),
            ),
        )

        # Accumulate segmentation masks (if present) into a single segmentation image.
        mask = output.mask
        if mask is not None:
            mask_arr: ndarray = np.asarray(mask).squeeze()
            if mask_arr.shape != seg_map.shape:
                mask_arr = cv2.resize(
                    mask_arr.astype(np.uint8), (seg_map.shape[1], seg_map.shape[0]), interpolation=cv2.INTER_NEAREST
                )
            mask_bool = mask_arr.astype(bool)
            human_mask = np.logical_or(human_mask, mask_bool)
            seg_id = SEG_CLASS_OFFSET + i + 1  # keep person class (0) separate from seg classes
            seg_map = np.where(mask_bool, np.uint16(seg_id), seg_map)

            # Color overlay for this instance, background stays transparent.
            color = BOX_PALETTE[i % len(BOX_PALETTE), :3]
            seg_overlay[mask_bool] = np.array([color[0], color[1], color[2], 120], dtype=np.uint8)

        # Log 3D keypoints in world coordinates
        cam_t: Float32[ndarray, "3"] = np.ascontiguousarray(output.pred_cam_t, dtype=np.float32)
        kpts_world: Float32[ndarray, "n_kpts 3"] = np.ascontiguousarray(kpts_cam + cam_t, dtype=np.float32)
        rr.log(
            f"{parent_log_path}/pred/kpts3d_{i}",
            rr.Points3D(
                positions=kpts_world,
                keypoint_ids=MHR70_IDS,
                class_ids=0,
                colors=(0, 255, 0),
            ),
        )

        # Log the full-body mesh in world coordinates so it shows in 3D
        verts_cam: Float32[ndarray, "n_verts 3"] = np.ascontiguousarray(output.pred_vertices, dtype=np.float32)
        verts_world: Float32[ndarray, "n_verts 3"] = np.ascontiguousarray(verts_cam + cam_t, dtype=np.float32)
        faces_int: Int[ndarray, "n_faces 3"] = np.ascontiguousarray(faces, dtype=np.int32)
        vertex_normals: Float32[ndarray, "n_verts 3"] = compute_vertex_normals(verts_world, faces_int)
        rr.log(
            f"{parent_log_path}/pred/mesh_{i}",
            rr.Mesh3D(
                vertex_positions=verts_world,
                triangle_indices=faces_int,
                vertex_normals=vertex_normals,
                albedo_factor=(
                    float(box_color[0, 0]) / 255.0,
                    float(box_color[0, 1]) / 255.0,
                    float(box_color[0, 2]) / 255.0,
                    0.35,
                ),
            ),
        )

    # Log segmentation ids (full map) and an RGBA overlay with transparent background.
    if np.any(seg_map != SEG_CLASS_OFFSET):
        rr.log(f"{pred_log_path}/segmentation_ids", rr.SegmentationImage(seg_map))
        rr.log(f"{pred_log_path}/segmentation_overlay", rr.Image(seg_overlay, color_model=rr.ColorModel.RGBA))

    # Optionally log depth and a background-only point cloud (for 3D view only).
    if relative_depth_pred is not None:
        depth_hw: Float32[ndarray, "h w"] = np.asarray(relative_depth_pred.depth, dtype=np.float32)
        conf_hw: Float32[ndarray, "h w"] = np.asarray(relative_depth_pred.confidence, dtype=np.float32)
        if depth_hw.shape != (h, w):
            depth_hw = cv2.resize(depth_hw, (w, h), interpolation=cv2.INTER_NEAREST)
        if conf_hw.shape != (h, w):
            conf_hw = cv2.resize(conf_hw, (w, h), interpolation=cv2.INTER_NEAREST)
        depth_hw = np.nan_to_num(depth_hw, nan=0.0, posinf=0.0, neginf=0.0)

        # Remove flying pixels along depth discontinuities.
        edges_mask: Bool[ndarray, "h w"] = depth_edges_mask(depth_hw, threshold=0.01)
        depth_hw = depth_hw * np.logical_not(edges_mask)

        # Remove low-confidence pixels.
        conf_mask: Bool[ndarray, "h w"] = conf_hw >= MIN_DEPTH_CONFIDENCE
        depth_hw = depth_hw * conf_mask

        background_mask: Bool[ndarray, "h w"] = np.logical_not(human_mask)
        depth_bg: Float32[ndarray, "h w"] = depth_hw * background_mask

        # Log depth image (not referenced by the 2D blueprint).
        # rr.log(f"{pinhole_log_path}/depth", rr.DepthImage(depth_bg, meter=1.0))

        fx: float = float(relative_depth_pred.K_33[0, 0])
        fy: float = float(relative_depth_pred.K_33[1, 1])
        cx: float = float(relative_depth_pred.K_33[0, 2])
        cy: float = float(relative_depth_pred.K_33[1, 2])

        u: Float32[ndarray, "w"] = np.arange(w, dtype=np.float32)
        v: Float32[ndarray, "h"] = np.arange(h, dtype=np.float32)
        uu: Float32[ndarray, "h w"]
        vv: Float32[ndarray, "h w"]
        uu, vv = np.meshgrid(u, v)

        z_cam: Float32[ndarray, "h w"] = depth_bg
        valid: Bool[ndarray, "h w"] = np.logical_and(z_cam > 0.0, np.isfinite(z_cam))
        if np.any(valid):
            x_cam: Float32[ndarray, "h w"] = (uu - cx) * z_cam / fx
            y_cam: Float32[ndarray, "h w"] = (vv - cy) * z_cam / fy
            points_cam: Float32[ndarray, "h w 3"] = np.stack([x_cam, y_cam, z_cam], axis=-1)

            points_flat: Float32[ndarray, "n_valid 3"] = points_cam[valid]
            colors_flat: UInt8[ndarray, "n_valid 3"] = rgb_hw3[valid]

            if points_flat.shape[0] > MAX_POINT_CLOUD_POINTS:
                voxel_size: float = estimate_voxel_size(
                    points_flat, target_points=MAX_POINT_CLOUD_POINTS, tolerance=0.25
                )
                pcd: o3d.geometry.PointCloud = o3d.geometry.PointCloud()
                pcd.points = o3d.utility.Vector3dVector(points_flat)
                pcd.colors = o3d.utility.Vector3dVector(colors_flat.astype(np.float32) / 255.0)
                pcd_ds: o3d.geometry.PointCloud = pcd.voxel_down_sample(voxel_size)
                points_flat = np.asarray(pcd_ds.points, dtype=np.float32)
                colors_flat = (np.asarray(pcd_ds.colors, dtype=np.float32) * 255.0).astype(np.uint8)

            rr.log(
                f"{parent_log_path}/depth_point_cloud",
                rr.Points3D(
                    positions=points_flat,
                    colors=colors_flat,
                ),
            )


def create_view() -> rrb.ContainerLike:
    view_2d = rrb.Vertical(
        contents=[
            # Top: people-only overlay on the RGB image.
            rrb.Spatial2DView(
                name="image",
                origin="/world/cam/pinhole",
                contents=[
                    "/world/cam/pinhole/image",
                    "/world/cam/pinhole/pred/segmentation_overlay",
                ],
            ),
            # Bottom: 2D boxes + keypoints; segmentation hidden.
            rrb.Spatial2DView(
                name="mhr",
                origin="/world/cam/pinhole",
                contents=[
                    "/world/cam/pinhole/image",
                    "/world/cam/pinhole/pred/**",
                    "- /world/cam/pinhole/pred/segmentation_overlay/**",
                    "- /world/cam/pinhole/pred/segmentation_ids/**",
                ],
            ),
        ],
    )
    view_3d = rrb.Spatial3DView(name="mhr_3d", line_grid=rrb.LineGrid3D(visible=False))
    main_view = rrb.Horizontal(contents=[view_2d, view_3d], column_shares=[2, 3])
    view = rrb.Tabs(contents=[main_view], name="sam-3d-body-demo")
    return view
src/sam3d_body/build_models.py
ADDED
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import os
from os import PathLike

import torch

from .models.meta_arch import SAM3DBody
from .utils.checkpoint import load_state_dict
from .utils.config import CN, get_config


def load_sam_3d_body(
    checkpoint_path: str | PathLike[str] = "",
    device: str | torch.device = "cuda",
    mhr_path: str | PathLike[str] = "",
) -> tuple[SAM3DBody, CN]:
    print("Loading SAM 3D Body model...")

    checkpoint_path = os.fspath(checkpoint_path)
    mhr_path = os.fspath(mhr_path)

    # Check the current directory, and if not present check the parent dir.
    model_cfg = os.path.join(os.path.dirname(checkpoint_path), "model_config.yaml")
    if not os.path.exists(model_cfg):
        # Look at the parent dir
        model_cfg = os.path.join(os.path.dirname(os.path.dirname(checkpoint_path)), "model_config.yaml")

    model_cfg = get_config(model_cfg)

    # Disable face for inference
    model_cfg.defrost()
    model_cfg.MODEL.MHR_HEAD.MHR_MODEL_PATH = mhr_path
    model_cfg.freeze()

    # Initialize the model
    model = SAM3DBody(model_cfg)

    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    state_dict = checkpoint.get("state_dict", checkpoint)
    load_state_dict(model, state_dict, strict=False)

    model = model.to(device)
    model.eval()
    return model, model_cfg


def _hf_download(repo_id):
    from huggingface_hub import snapshot_download

    local_dir = snapshot_download(repo_id=repo_id)
    return os.path.join(local_dir, "model.ckpt"), os.path.join(local_dir, "assets", "mhr_model.pt")


def load_sam_3d_body_hf(repo_id, **kwargs):
    ckpt_path, mhr_path = _hf_download(repo_id)
    # Forward extra options (e.g., device) instead of silently dropping them.
    return load_sam_3d_body(checkpoint_path=ckpt_path, mhr_path=mhr_path, **kwargs)
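For reference, a brief usage sketch of the loaders above. The repo id string is a placeholder assumption, not a confirmed model id; the only requirement is the layout expected by _hf_download (model.ckpt at the repo root and assets/mhr_model.pt).

import torch

from sam3d_body.build_models import load_sam_3d_body_hf

# Hypothetical repo id for illustration only.
model, model_cfg = load_sam_3d_body_hf("your-org/sam-3d-body-checkpoint", device="cuda")

with torch.inference_mode():
    pass  # run model(batch) here on a batch prepared by the data transforms below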
src/sam3d_body/data/__init__.py
ADDED
@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
src/sam3d_body/data/transforms/__init__.py
ADDED
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from .bbox_utils import (
    bbox_cs2xywh,
    bbox_cs2xyxy,
    bbox_xywh2cs,
    bbox_xywh2xyxy,
    bbox_xyxy2cs,
    bbox_xyxy2xywh,
    flip_bbox,
    get_udp_warp_matrix,
    get_warp_matrix,
)
from .common import (
    Compose,
    GetBBoxCenterScale,
    NormalizeKeypoint,
    SquarePad,
    TopdownAffine,
    VisionTransformWrapper,
)
src/sam3d_body/data/transforms/bbox_utils.py
ADDED
@@ -0,0 +1,380 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import math

import cv2
import numpy as np


def bbox_xyxy2xywh(bbox_xyxy: np.ndarray) -> np.ndarray:
    """Transform the bbox format from x1y1x2y2 to xywh.

    Args:
        bbox_xyxy (np.ndarray): Bounding boxes (with scores), shaped (n, 4) or
            (n, 5). (left, top, right, bottom, [score])

    Returns:
        np.ndarray: Bounding boxes (with scores),
            shaped (n, 4) or (n, 5). (left, top, width, height, [score])
    """
    bbox_xywh = bbox_xyxy.copy()
    bbox_xywh[:, 2] = bbox_xywh[:, 2] - bbox_xywh[:, 0]
    bbox_xywh[:, 3] = bbox_xywh[:, 3] - bbox_xywh[:, 1]

    return bbox_xywh


def bbox_xywh2xyxy(bbox_xywh: np.ndarray) -> np.ndarray:
    """Transform the bbox format from xywh to x1y1x2y2.

    Args:
        bbox_xywh (ndarray): Bounding boxes (with scores),
            shaped (n, 4) or (n, 5). (left, top, width, height, [score])
    Returns:
        np.ndarray: Bounding boxes (with scores), shaped (n, 4) or
            (n, 5). (left, top, right, bottom, [score])
    """
    bbox_xyxy = bbox_xywh.copy()
    bbox_xyxy[:, 2] = bbox_xyxy[:, 2] + bbox_xyxy[:, 0]
    bbox_xyxy[:, 3] = bbox_xyxy[:, 3] + bbox_xyxy[:, 1]

    return bbox_xyxy


def bbox_xyxy2cs(bbox: np.ndarray, padding: float = 1.0) -> tuple[np.ndarray, np.ndarray]:
    """Transform the bbox format from (x1,y1,x2,y2) into (center, scale).

    Args:
        bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
            as (left, top, right, bottom)
        padding (float): BBox padding factor that will be multiplied to scale.
            Default: 1.0

    Returns:
        tuple: A tuple containing center and scale.
        - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
            (n, 2)
        - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
            (n, 2)
    """
    # convert single bbox from (4, ) to (1, 4)
    dim = bbox.ndim
    if dim == 1:
        bbox = bbox[None, :]

    x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
    center = np.hstack([x1 + x2, y1 + y2]) * 0.5
    scale = np.hstack([x2 - x1, y2 - y1]) * padding

    if dim == 1:
        center = center[0]
        scale = scale[0]

    return center, scale


def bbox_xywh2cs(bbox: np.ndarray, padding: float = 1.0) -> tuple[np.ndarray, np.ndarray]:
    """Transform the bbox format from (x,y,w,h) into (center, scale).

    Args:
        bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
            as (x, y, w, h)
        padding (float): BBox padding factor that will be multiplied to scale.
            Default: 1.0

    Returns:
        tuple: A tuple containing center and scale.
        - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
            (n, 2)
        - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
            (n, 2)
    """

    # convert single bbox from (4, ) to (1, 4)
    dim = bbox.ndim
    if dim == 1:
        bbox = bbox[None, :]

    x, y, w, h = np.hsplit(bbox, [1, 2, 3])
    center = np.hstack([x + w * 0.5, y + h * 0.5])
    scale = np.hstack([w, h]) * padding

    if dim == 1:
        center = center[0]
        scale = scale[0]

    return center, scale


def bbox_cs2xyxy(center: np.ndarray, scale: np.ndarray, padding: float = 1.0) -> np.ndarray:
    """Transform the bbox format from (center, scale) to (x1,y1,x2,y2).

    Args:
        center (ndarray): BBox center (x, y) in shape (2,) or (n, 2)
        scale (ndarray): BBox scale (w, h) in shape (2,) or (n, 2)
        padding (float): BBox padding factor that will be multiplied to scale.
            Default: 1.0

    Returns:
        ndarray[float32]: BBox (x1, y1, x2, y2) in shape (4, ) or (n, 4)
    """

    dim = center.ndim
    assert scale.ndim == dim

    if dim == 1:
        center = center[None, :]
        scale = scale[None, :]

    wh = scale / padding
    xy = center - 0.5 * wh
    bbox = np.hstack((xy, xy + wh))

    if dim == 1:
        bbox = bbox[0]

    return bbox


def bbox_cs2xywh(center: np.ndarray, scale: np.ndarray, padding: float = 1.0) -> np.ndarray:
    """Transform the bbox format from (center, scale) to (x,y,w,h).

    Args:
        center (ndarray): BBox center (x, y) in shape (2,) or (n, 2)
        scale (ndarray): BBox scale (w, h) in shape (2,) or (n, 2)
        padding (float): BBox padding factor that will be multiplied to scale.
            Default: 1.0

    Returns:
        ndarray[float32]: BBox (x, y, w, h) in shape (4, ) or (n, 4)
    """

    dim = center.ndim
    assert scale.ndim == dim

    if dim == 1:
        center = center[None, :]
        scale = scale[None, :]

    wh = scale / padding
    xy = center - 0.5 * wh
    bbox = np.hstack((xy, wh))

    if dim == 1:
        bbox = bbox[0]

    return bbox


def flip_bbox(
    bbox: np.ndarray,
    image_size: tuple[int, int],
    bbox_format: str = "xywh",
    direction: str = "horizontal",
) -> np.ndarray:
    """Flip the bbox in the given direction.

    Args:
        bbox (np.ndarray): The bounding boxes. The shape should be (..., 4)
            if ``bbox_format`` is ``'xyxy'`` or ``'xywh'``, and (..., 2) if
            ``bbox_format`` is ``'center'``
        image_size (tuple): The image shape in [w, h]
        bbox_format (str): The bbox format. Options are ``'xywh'``, ``'xyxy'``
            and ``'center'``.
        direction (str): The flip direction. Options are ``'horizontal'``,
            ``'vertical'`` and ``'diagonal'``. Defaults to ``'horizontal'``

    Returns:
        np.ndarray: The flipped bounding boxes.
    """
    direction_options = {"horizontal", "vertical", "diagonal"}
    assert direction in direction_options, f'Invalid flipping direction "{direction}". Options are {direction_options}'

    format_options = {"xywh", "xyxy", "center"}
    assert bbox_format in format_options, f'Invalid bbox format "{bbox_format}". Options are {format_options}'

    bbox_flipped = bbox.copy()
    w, h = image_size

    if direction == "horizontal":
        if bbox_format == "xywh" or bbox_format == "center":
            bbox_flipped[..., 0] = w - bbox[..., 0] - 1
        elif bbox_format == "xyxy":
            bbox_flipped[..., ::2] = w - bbox[..., ::2] - 1
    elif direction == "vertical":
        if bbox_format == "xywh" or bbox_format == "center":
            bbox_flipped[..., 1] = h - bbox[..., 1] - 1
        elif bbox_format == "xyxy":
            bbox_flipped[..., 1::2] = h - bbox[..., 1::2] - 1
    elif direction == "diagonal":
        if bbox_format == "xywh" or bbox_format == "center":
            bbox_flipped[..., :2] = [w, h] - bbox[..., :2] - 1
        elif bbox_format == "xyxy":
            bbox_flipped[...] = [w, h, w, h] - bbox - 1

    return bbox_flipped


def fix_aspect_ratio(bbox_scale: np.ndarray, aspect_ratio: float):
    """Reshape the bbox to a fixed aspect ratio.

    Args:
        bbox_scale (np.ndarray): The bbox scales (w, h) in shape (n, 2)
        aspect_ratio (float): The ratio of ``w/h``

    Returns:
        np.ndarray: The reshaped bbox scales in (n, 2)
    """
    dim = bbox_scale.ndim
    if dim == 1:
        bbox_scale = bbox_scale[None, :]

    w, h = np.hsplit(bbox_scale, [1])
    bbox_scale = np.where(
        w > h * aspect_ratio,
        np.hstack([w, w / aspect_ratio]),
        np.hstack([h * aspect_ratio, h]),
    )
    if dim == 1:
        bbox_scale = bbox_scale[0]

    return bbox_scale


def get_udp_warp_matrix(
    center: np.ndarray,
    scale: np.ndarray,
    rot: float,
    output_size: tuple[int, int],
) -> np.ndarray:
    """Calculate the affine transformation matrix under the unbiased
    constraint. See `UDP (CVPR 2020)`_ for details.

    Note:

        - The bbox number: N

    Args:
        center (np.ndarray[2, ]): Center of the bounding box (x, y).
        scale (np.ndarray[2, ]): Scale of the bounding box
            wrt [width, height].
        rot (float): Rotation angle (degree).
        output_size (tuple): Size ([w, h]) of the output image

    Returns:
        np.ndarray: A 2x3 transformation matrix

    .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524
    """
    assert len(center) == 2
    assert len(scale) == 2
    assert len(output_size) == 2

    input_size = center * 2
    rot_rad = np.deg2rad(rot)
    warp_mat = np.zeros((2, 3), dtype=np.float32)
    scale_x = (output_size[0] - 1) / scale[0]
    scale_y = (output_size[1] - 1) / scale[1]
    warp_mat[0, 0] = math.cos(rot_rad) * scale_x
    warp_mat[0, 1] = -math.sin(rot_rad) * scale_x
    warp_mat[0, 2] = scale_x * (
        -0.5 * input_size[0] * math.cos(rot_rad) + 0.5 * input_size[1] * math.sin(rot_rad) + 0.5 * scale[0]
    )
    warp_mat[1, 0] = math.sin(rot_rad) * scale_y
    warp_mat[1, 1] = math.cos(rot_rad) * scale_y
    warp_mat[1, 2] = scale_y * (
        -0.5 * input_size[0] * math.sin(rot_rad) - 0.5 * input_size[1] * math.cos(rot_rad) + 0.5 * scale[1]
    )
    return warp_mat


def get_warp_matrix(
    center: np.ndarray,
    scale: np.ndarray,
    rot: float,
    output_size: tuple[int, int],
    shift: tuple[float, float] = (0.0, 0.0),
    inv: bool = False,
) -> np.ndarray:
    """Calculate the affine transformation matrix that can warp the bbox area
    in the input image to the output size.

    Args:
        center (np.ndarray[2, ]): Center of the bounding box (x, y).
        scale (np.ndarray[2, ]): Scale of the bounding box
            wrt [width, height].
        rot (float): Rotation angle (degree).
        output_size (np.ndarray[2, ] | list(2,)): Size of the
            destination heatmaps.
        shift (0-100%): Shift translation ratio wrt the width/height.
            Default (0., 0.).
        inv (bool): Option to inverse the affine transform direction.
            (inv=False: src->dst or inv=True: dst->src)

    Returns:
        np.ndarray: A 2x3 transformation matrix
    """
    assert len(center) == 2
    assert len(scale) == 2
    assert len(output_size) == 2
    assert len(shift) == 2

    shift = np.array(shift)
    src_w = scale[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.deg2rad(rot)
    src_dir = _rotate_point(np.array([0.0, src_w * -0.5]), rot_rad)
    dst_dir = np.array([0.0, dst_w * -0.5])

    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale * shift
    src[1, :] = center + src_dir + scale * shift
    src[2, :] = _get_3rd_point(src[0, :], src[1, :])

    dst = np.zeros((3, 2), dtype=np.float32)
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
    return warp_mat


def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
    """Rotate a point by an angle.

    Args:
        pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
        angle_rad (float): rotation angle in radian

    Returns:
        np.ndarray: Rotated point in shape (2, )
    """

    sn, cs = np.sin(angle_rad), np.cos(angle_rad)
    rot_mat = np.array([[cs, -sn], [sn, cs]])
    return rot_mat @ pt


def _get_3rd_point(a: np.ndarray, b: np.ndarray):
    """To calculate the affine matrix, three pairs of points are required. This
    function is used to get the 3rd point, given 2D points a & b.

    The 3rd point is defined by rotating vector `a - b` by 90 degrees
    anticlockwise, using b as the rotation center.

    Args:
        a (np.ndarray): The 1st point (x,y) in shape (2, )
        b (np.ndarray): The 2nd point (x,y) in shape (2, )

    Returns:
        np.ndarray: The 3rd point.
    """
    direction = a - b
    c = b + np.r_[-direction[1], direction[0]]
    return c
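A short, self-contained sketch of how these helpers chain for a top-down crop: convert a detection box to (center, scale), pad it, build the affine warp, and round-trip back to pixel corners. The box values and the 192x256 output size are made-up inputs for illustration only.

import numpy as np

from sam3d_body.data.transforms.bbox_utils import bbox_cs2xyxy, bbox_xyxy2cs, get_warp_matrix

# Made-up detection in (x1, y1, x2, y2) pixels.
bbox_xyxy = np.array([100.0, 50.0, 300.0, 450.0], dtype=np.float32)

center, scale = bbox_xyxy2cs(bbox_xyxy, padding=1.25)  # pad the box by 25%
warp = get_warp_matrix(center, scale, rot=0.0, output_size=(192, 256))
# The 2x3 `warp` can be fed to cv2.warpAffine(img, warp, (192, 256)) to get the crop.

# Round-trip back to pixel corners (undoes the padding factor).
restored = bbox_cs2xyxy(center, scale, padding=1.25)
assert np.allclose(restored, bbox_xyxy, atol=1e-4)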
src/sam3d_body/data/transforms/common.py
ADDED
@@ -0,0 +1,345 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from collections.abc import Callable, Sequence

import cv2
import numpy as np
import torch.nn as nn
import torchvision.transforms.functional as F
from PIL import Image

from sam3d_body.models.modules import to_2tuple

from .bbox_utils import (
    bbox_xywh2cs,
    bbox_xyxy2cs,
    fix_aspect_ratio,
    get_udp_warp_matrix,
    get_warp_matrix,
)


class Compose:
    """Compose multiple transforms sequentially.

    Args:
        transforms (Sequence[dict, callable], optional): Sequence of transform
            object or config dict to be composed.
    """

    def __init__(self, transforms: list[Callable] | None = None):
        if transforms is None:
            transforms = []
        # Always store the list so an empty Compose is still usable.
        self.transforms = transforms

    def __call__(self, data: dict) -> dict | None:
        """Call function to apply transforms sequentially.

        Args:
            data (dict): A result dict contains the data to transform.

        Returns:
            dict: Transformed data.
        """
        for t in self.transforms:
            data = t(data)
            # The transform will return None when it failed to load images or
            # cannot find suitable augmentation parameters to augment the data.
            # Here we simply return None if the transform returns None and the
            # dataset will handle it by randomly selecting another data sample.
            if data is None:
                return None
        return data

    def __repr__(self):
        """Print ``self.transforms`` in sequence.

        Returns:
            str: Formatted string.
        """
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += f"    {t}"
        format_string += "\n)"
        return format_string


class VisionTransformWrapper:
    """A wrapper to use torchvision transform functions in this codebase."""

    def __init__(self, transform: Callable):
        self.transform = transform

    def __call__(self, results: dict) -> dict | None:
        results["img"] = self.transform(results["img"])
        return results

    def __repr__(self) -> str:
        """Print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.transform.__class__.__name__
        return repr_str


class GetBBoxCenterScale(nn.Module):
    """Convert bboxes to center and scale.

    The center is the coordinates of the bbox center, and the scale is the
    bbox width and height normalized by a scale factor.

    Required Keys:

    - bbox
    - bbox_format

    Added Keys:

    - bbox_center
    - bbox_scale

    Args:
        padding (float): The bbox padding scale that will be multiplied to
            `bbox_scale`. Defaults to 1.25
    """

    def __init__(self, padding: float = 1.25) -> None:
        super().__init__()

        self.padding = padding

    def forward(self, results: dict) -> dict | None:
        """The transform function of :class:`GetBBoxCenterScale`.

        Args:
            results (dict): The result dict

        Returns:
            dict: The result dict.
        """
        if "bbox_center" in results and "bbox_scale" in results:
            results["bbox_scale"] *= self.padding
        else:
            bbox = results["bbox"]
            bbox_format = results.get("bbox_format", "none")
            if bbox_format == "xywh":
                center, scale = bbox_xywh2cs(bbox, padding=self.padding)
            elif bbox_format == "xyxy":
                center, scale = bbox_xyxy2cs(bbox, padding=self.padding)
            else:
                raise ValueError("Invalid bbox format: {}".format(results["bbox_format"]))

            results["bbox_center"] = center
            results["bbox_scale"] = scale
        return results

    def __repr__(self) -> str:
        """Print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__ + f"(padding={self.padding})"
        return repr_str


class SquarePad:
    def __call__(self, results: dict) -> dict | None:
        assert isinstance(results["img"], Image.Image)
        w, h = results["img"].size

        max_wh = np.max([w, h])
        hp = int((max_wh - w) / 2)
        vp = int((max_wh - h) / 2)
        padding = (hp, vp, max_wh - w - hp, max_wh - h - vp)

        results["img"] = F.pad(results["img"], padding, 0, "constant")
        return results

    def __repr__(self) -> str:
        """Print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        return repr_str


class ToPIL:
    def __call__(self, results: dict) -> dict | None:
        if isinstance(results["img"], list):
            if isinstance(results["img"][0], np.ndarray):
                results["img"] = [Image.fromarray(img) for img in results["img"]]
        elif isinstance(results["img"], np.ndarray):
            results["img"] = Image.fromarray(results["img"])
        # Return the dict so this transform can be used inside Compose.
        return results


class ToCv2:
    def __call__(self, results: dict) -> dict | None:
        if isinstance(results["img"], list):
            if isinstance(results["img"][0], Image.Image):
                results["img"] = [np.array(img) for img in results["img"]]
        elif isinstance(results["img"], Image.Image):
            results["img"] = np.array(results["img"])
        # Return the dict so this transform can be used inside Compose.
        return results


class TopdownAffine(nn.Module):
    """Get the bbox image as the model input by affine transform.

    Required Keys:
    - img
    - bbox_center
    - bbox_scale
    - bbox_rotation (optional)
    - keypoints_2d (optional)
    - mask (optional)

    Modified Keys:
    - img
    - bbox_scale

    Added Keys:
    - input_size
    - transformed_keypoints

    Args:
        input_size (Tuple[int, int]): The input image size of the model in
            [w, h]. The bbox region will be cropped and resized to `input_size`
        use_udp (bool): Whether to use unbiased data processing. See
            `UDP (CVPR 2020)`_ for details. Defaults to ``False``
        aspect_ratio (float): both HMR2.0 and Sapiens will expand the input bbox to
            a fixed ratio (width/height = 192/256), then expand to the ratio of
            the model input size. E.g., HMR2.0 will eventually expand to 1:1, while
            Sapiens will be 768:1024.

    .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524
    """

    def __init__(
        self,
        input_size: int | tuple[int, int] | Sequence[int],
        use_udp: bool = False,
        aspect_ratio: float = 0.75,
        fix_square: bool = False,
    ) -> None:
        super().__init__()

        self.input_size = to_2tuple(input_size)
        self.use_udp = use_udp
        self.aspect_ratio = aspect_ratio
        self.fix_square = fix_square

    def forward(self, results: dict) -> dict | None:
        """The transform function of :class:`TopdownAffine`.

        See ``transform()`` method of :class:`BaseTransform` for details.

        Args:
            results (dict): The result dict

        Returns:
            dict: The result dict.
        """
        # # Debug only
        # import copy
        # results['ori_img'] = np.zeros((2000, 2000, 3), dtype=np.uint8)
        # results['ori_img'][:results['img'].shape[0], :results['img'].shape[1]] = copy.deepcopy(results['img'])

        w, h = self.input_size
        warp_size = (int(w), int(h))

        # expand bbox to fixed aspect ratio
        results["orig_bbox_scale"] = results["bbox_scale"].copy()
        if self.fix_square and results["bbox_scale"][0] == results["bbox_scale"][1]:
            # In HMR2.0 etc., no aspect-ratio expansion for a square bbox
            bbox_scale = fix_aspect_ratio(results["bbox_scale"], aspect_ratio=w / h)
        else:
            # first to a prior aspect ratio, then reshape to model input size
            bbox_scale = fix_aspect_ratio(results["bbox_scale"], aspect_ratio=self.aspect_ratio)
        results["bbox_scale"] = fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
        results["bbox_expand_factor"] = results["bbox_scale"].max() / results["orig_bbox_scale"].max()
        rot = 0.0
        if results["bbox_center"].ndim == 2:
            assert results["bbox_center"].shape[0] == 1, (
                "Only support cropping one instance at a time. Got invalid "
                f"shape of bbox_center {results['bbox_center'].shape}."
            )
            center = results["bbox_center"][0]
            scale = results["bbox_scale"][0]
            if "bbox_rotation" in results:
                rot = results["bbox_rotation"][0]
        else:
            center = results["bbox_center"]
            scale = results["bbox_scale"]
            if "bbox_rotation" in results:
                rot = results["bbox_rotation"]

        if self.use_udp:
            warp_mat = get_udp_warp_matrix(center, scale, rot, output_size=(w, h))
        else:
            warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))

        if "img" not in results:
            pass
        elif isinstance(results["img"], list):
            results["img"] = [
                cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) for img in results["img"]
            ]
            height, width = results["img"][0].shape[:2]
            results["ori_img_size"] = np.array([width, height])
        else:
            height, width = results["img"].shape[:2]
            results["ori_img_size"] = np.array([width, height])
            results["img"] = cv2.warpAffine(results["img"], warp_mat, warp_size, flags=cv2.INTER_LINEAR)

        if results.get("keypoints_2d") is not None:
            results["orig_keypoints_2d"] = results["keypoints_2d"].copy()
            transformed_keypoints = results["keypoints_2d"].copy()
            # Only transform (x, y) coordinates
            # cv2 expects the input to be [[[x1, y1], [x2, y2]]]
            transformed_keypoints[:, :2] = cv2.transform(results["keypoints_2d"][None, :, :2], warp_mat)[0]
            results["keypoints_2d"] = transformed_keypoints

        if results.get("mask") is not None:
            results["mask"] = cv2.warpAffine(results["mask"], warp_mat, warp_size, flags=cv2.INTER_LINEAR)

        results["img_size"] = np.array([w, h])
        results["input_size"] = np.array([w, h])
        results["affine_trans"] = warp_mat
        return results

    def __repr__(self) -> str:
        """Print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f"(input_size={self.input_size}, "
        repr_str += f"use_udp={self.use_udp})"
        return repr_str


class NormalizeKeypoint(nn.Module):
    """
    Normalize 2D keypoints to range [-0.5, 0.5].

    Required Keys:
    - keypoints_2d
    - img_size

    Modified Keys:
    - keypoints_2d
    """

    def forward(self, results: dict) -> dict | None:
        if "keypoints_2d" in results:
            img_size = results.get("img_size", results["input_size"])

            results["keypoints_2d"][:, :2] = results["keypoints_2d"][:, :2] / np.array(img_size).reshape(1, 2) - 0.5
        return results
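As a usage sketch, the crop transforms above compose into a top-down preprocessing pipeline. The 192x256 input size and the fake detection dict below are illustrative assumptions, not the model's configured values.

import numpy as np

from sam3d_body.data.transforms import Compose, GetBBoxCenterScale, TopdownAffine

pipeline = Compose(
    transforms=[
        GetBBoxCenterScale(padding=1.25),      # bbox -> (center, scale) with 25% padding
        TopdownAffine(input_size=(192, 256)),  # crop + resize to a 192x256 model input
    ]
)

sample = {
    "img": np.zeros((480, 640, 3), dtype=np.uint8),                  # placeholder image
    "bbox": np.array([[100.0, 50.0, 300.0, 450.0]], dtype=np.float32),  # (x1, y1, x2, y2)
    "bbox_format": "xyxy",
}
out = pipeline(sample)
print(out["img"].shape, out["bbox_scale"])  # (256, 192, 3) and the expanded scale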
src/sam3d_body/data/utils/io.py
ADDED
@@ -0,0 +1,114 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
from typing import Any, List
|
| 6 |
+
|
| 7 |
+
import braceexpand
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def expand(s):
|
| 15 |
+
return os.path.expanduser(os.path.expandvars(s))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def expand_urls(urls: str | List[str]):
|
| 19 |
+
if isinstance(urls, str):
|
| 20 |
+
urls = [urls]
|
| 21 |
+
urls = [u for url in urls for u in braceexpand.braceexpand(expand(url))]
|
| 22 |
+
return urls
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def load_image_from_file(
|
| 26 |
+
data_info: dict,
|
| 27 |
+
backend: str = "cv2",
|
| 28 |
+
image_format: str = "rgb",
|
| 29 |
+
retry: int = 10,
|
| 30 |
+
) -> dict:
|
| 31 |
+
img = load_image(data_info["img_path"], backend, image_format, retry)
|
| 32 |
+
data_info["img"] = img
|
| 33 |
+
data_info["img_shape"] = img.shape[:2]
|
| 34 |
+
data_info["ori_shape"] = img.shape[:2]
|
| 35 |
+
return data_info
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _pil_load(path: str, image_format: str) -> Image.Image:
|
| 39 |
+
with Image.open(path) as img:
|
| 40 |
+
if img is not None and image_format.lower() == "rgb":
|
| 41 |
+
img = img.convert("RGB")
|
| 42 |
+
return img
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _cv2_load(path: str, image_format: str) -> np.ndarray:
|
| 46 |
+
img = cv2.imread(path)
|
| 47 |
+
if img is not None and image_format.lower() == "rgb":
|
| 48 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 49 |
+
return img
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def load_image(
|
| 53 |
+
path: str,
|
| 54 |
+
backend: str = "pil",
|
| 55 |
+
image_format: str = "rgb",
|
| 56 |
+
retry: int = 10,
|
| 57 |
+
) -> Any:
|
| 58 |
+
for i_try in range(retry):
|
| 59 |
+
if backend == "pil":
|
| 60 |
+
img = _pil_load(path, image_format)
|
| 61 |
+
elif backend == "cv2":
|
| 62 |
+
img = _cv2_load(path, image_format)
|
| 63 |
+
else:
|
| 64 |
+
raise ValueError("Invalid backend {} for loading image.".format(backend))
|
| 65 |
+
|
| 66 |
+
if img is not None:
|
| 67 |
+
return img
|
| 68 |
+
else:
|
| 69 |
+
print("Reading {} failed. Will retry.".format(path))
|
| 70 |
+
time.sleep(1.0)
|
| 71 |
+
if i_try == retry - 1:
|
| 72 |
+
raise Exception("Failed to load image {}".format(path))
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def resize_image(img, target_size, center=None, scale=None):
|
| 76 |
+
height, width = img.shape[:2]
|
| 77 |
+
aspect_ratio = width / height
|
| 78 |
+
|
| 79 |
+
# Calculate the new size while maintaining the aspect ratio
|
| 80 |
+
if aspect_ratio > 1:
|
| 81 |
+
new_width = target_size
|
| 82 |
+
new_height = int(target_size / aspect_ratio)
|
| 83 |
+
else:
|
| 84 |
+
new_width = int(target_size * aspect_ratio)
|
| 85 |
+
new_height = target_size
|
| 86 |
+
|
| 87 |
+
# Resize the image using OpenCV
|
| 88 |
+
resized_img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
| 89 |
+
|
| 90 |
+
# Create a new blank image with the target size
|
| 91 |
+
final_img = np.ones((target_size, target_size, 3), dtype=np.uint8) * 255
|
| 92 |
+
|
| 93 |
+
# Paste the resized image onto the blank image, centering it
|
| 94 |
+
start_x = (target_size - new_width) // 2
|
| 95 |
+
start_y = (target_size - new_height) // 2
|
| 96 |
+
final_img[start_y : start_y + new_height, start_x : start_x + new_width] = (
|
| 97 |
+
resized_img
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
if center is not None and scale is not None:
|
| 101 |
+
ratio_width = new_width / width
|
| 102 |
+
ratio_height = new_height / height
|
| 103 |
+
|
| 104 |
+
new_scale = np.stack(
|
| 105 |
+
[scale[:, 0] * ratio_width, scale[:, 1] * ratio_height], axis=1
|
| 106 |
+
)
|
| 107 |
+
new_center = np.stack(
|
| 108 |
+
[center[:, 0] * ratio_width, center[:, 1] * ratio_height], axis=1
|
| 109 |
+
)
|
| 110 |
+
new_center[:, 0] += start_x
|
| 111 |
+
new_center[:, 1] += start_y
|
| 112 |
+
else:
|
| 113 |
+
new_center, new_scale = None, None
|
| 114 |
+
return aspect_ratio, final_img, new_center, new_scale
|
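A minimal usage sketch for the two helpers above; the file name is a placeholder and the import path assumes the package layout shown in this commit:

    from sam3d_body.data.utils.io import load_image, resize_image

    img = load_image("person.jpg", backend="cv2", image_format="rgb")  # H x W x 3 RGB array
    # Letterbox onto a white target_size x target_size canvas; optional centers/scales are rescaled accordingly.
    aspect_ratio, square_img, _, _ = resize_image(img, target_size=1024)
    print(square_img.shape)  # (1024, 1024, 3)
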
src/sam3d_body/data/utils/prepare_batch.py
ADDED
@@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from collections.abc import Callable
from typing import Any, TypedDict, cast

import numpy as np
import torch
from jaxtyping import Float, UInt8
from numpy import ndarray
from torch import Tensor
from torch.utils.data import default_collate


class PreparedBatchDict(TypedDict, total=False):
    img: Float[Tensor, "B N 3 H W"]
    img_size: Float[Tensor, "B N 2"]
    ori_img_size: Float[Tensor, "B N 2"]
    bbox_center: Float[Tensor, "B N 2"]
    bbox_scale: Float[Tensor, "B N 2"]
    bbox: Float[Tensor, "B N 4"]
    affine_trans: Float[Tensor, "B N 2 3"]
    mask: Float[Tensor, "B N 1 H W"]
    mask_score: Float[Tensor, "B N"]
    cam_int: Float[Tensor, "B 3 3"]
    person_valid: Float[Tensor, "B N"]
    img_ori: list["NoCollate"]


class NoCollate:
    def __init__(self, data: Any) -> None:
        self.data: Any = data


def prepare_batch(
    img: UInt8[ndarray, "h w 3"],
    transform: Callable[[dict[str, Any]], dict[str, Any]],
    boxes: Float[ndarray, "n 4"],
    masks: Float[ndarray, "n h w"] | None = None,
    masks_score: Float[ndarray, "n"] | None = None,
    cam_int: Float[Tensor, "B 3 3"] | None = None,
) -> PreparedBatchDict:
    """A helper function to prepare a data batch for SAM 3D Body model inference."""
    height, width = img.shape[:2]

    # construct batch data samples
    data_list: list[dict[str, Any]] = []
    for idx in range(boxes.shape[0]):
        data_info: dict[str, Any] = dict(img=img)
        data_info["bbox"] = boxes[idx]  # shape (4,)
        data_info["bbox_format"] = "xyxy"

        if masks is not None:
            data_info["mask"] = masks[idx].astype(np.float32, copy=False)
            if masks_score is not None:
                data_info["mask_score"] = masks_score[idx]
            else:
                data_info["mask_score"] = np.array(1.0, dtype=np.float32)
        else:
            data_info["mask"] = np.zeros((height, width, 1), dtype=np.uint8)
            data_info["mask_score"] = np.array(0.0, dtype=np.float32)

        data_list.append(transform(data_info))

    batch = default_collate(data_list)

    max_num_person = batch["img"].shape[0]
    for key in [
        "img",
        "img_size",
        "ori_img_size",
        "bbox_center",
        "bbox_scale",
        "bbox",
        "affine_trans",
        "mask",
        "mask_score",
    ]:
        if key in batch:
            batch[key] = batch[key].unsqueeze(0).float()
    if "mask" in batch:
        batch["mask"] = batch["mask"].unsqueeze(2)
    batch["person_valid"] = torch.ones((1, max_num_person))

    if cam_int is not None:
        batch["cam_int"] = cam_int.to(batch["img"])
    else:
        # Default camera intrinsics derived from the image size
        batch["cam_int"] = torch.tensor(
            [
                [
                    [(height**2 + width**2) ** 0.5, 0, width / 2.0],
                    [0, (height**2 + width**2) ** 0.5, height / 2.0],
                    [0, 0, 1],
                ]
            ],
        ).to(batch["img"])

    batch["img_ori"] = [NoCollate(img)]
    return cast(PreparedBatchDict, batch)

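A hedged usage sketch for prepare_batch: the boxes would normally come from a person detector and `transform` from the estimator's preprocessing pipeline; the toy transform below is a stand-in invented only to make the snippet self-contained.

    import numpy as np
    import torch
    from sam3d_body.data.utils.prepare_batch import prepare_batch

    def toy_transform(data_info):
        # Stand-in for the real crop/normalize pipeline; only tensorizes what default_collate needs.
        return {
            "img": torch.zeros(3, 256, 256),
            "bbox": torch.as_tensor(data_info["bbox"], dtype=torch.float32),
            "mask": torch.as_tensor(data_info["mask"], dtype=torch.float32)[..., 0],
            "mask_score": torch.as_tensor(data_info["mask_score"]),
        }

    rgb = np.zeros((480, 640, 3), dtype=np.uint8)
    boxes = np.array([[100.0, 50.0, 400.0, 450.0]], dtype=np.float32)  # one person, xyxy
    batch = prepare_batch(img=rgb, transform=toy_transform, boxes=boxes)  # masks omitted -> zero mask, score 0
    print(batch["img"].shape)         # torch.Size([1, 1, 3, 256, 256]) after the unsqueeze(0)
    print(batch["cam_int"][0, 0, 0])  # default focal length = sqrt(480**2 + 640**2) = 800
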
src/sam3d_body/gradio_ui/sam3d_body_ui.py
ADDED
@@ -0,0 +1,164 @@
"""
Demonstrates integrating Rerun visualization with Gradio.

Provides example implementations of data streaming, keypoint annotation, and dynamic
visualization across multiple Gradio tabs using Rerun's recording and visualization capabilities.
"""

import os
import shutil
import tempfile
from pathlib import Path
from typing import Final

import cv2
import gradio as gr
import rerun as rr
import rerun.blueprint as rrb
import spaces
from gradio_rerun import Rerun
from jaxtyping import Int, UInt8
from monopriors.relative_depth_models import RelativeDepthPrediction
from numpy import ndarray

from sam3d_body.api.demo import SAM3Config, SAM3DBodyE2E, SAM3DBodyE2EConfig, create_view, set_annotation_context
from sam3d_body.api.visualization import export_meshes_to_glb, visualize_sample
from sam3d_body.sam_3d_body_estimator import FinalPosePrediction

CFG: SAM3DBodyE2EConfig = SAM3DBodyE2EConfig(sam3_config=SAM3Config())
MODEL_E2E: SAM3DBodyE2E = SAM3DBodyE2E(config=CFG)
mesh_faces: Int[ndarray, "n_faces=36874 3"] = MODEL_E2E.sam3d_body_estimator.faces
STATE: Final[str] = "✅ Ready"
# Absolute path to bundled example data used by Gradio examples.
TEST_INPUT_DIR: Final[Path] = Path(__file__).resolve().parents[3] / "data" / "example-data"

# Allow Gradio to serve and cache files from the bundled test data directory.
gr.set_static_paths([str(TEST_INPUT_DIR)])


@spaces.GPU()
@rr.thread_local_stream("sam3d_body_gradio_ui")
def sam3d_prediction_fn(
    rgb_hw3,
    log_relative_depth,
    export_glb,
    center_glb,
    pending_cleanup=None,
) -> tuple[str, str, list[str]]:
    # Resize rgb so that its largest dimension is 1024.
    rgb_hw3: UInt8[ndarray, "h w 3"] = cv2.resize(
        rgb_hw3,  # type: ignore[arg-type]
        dsize=(0, 0),
        fx=1024 / max(rgb_hw3.shape[0], rgb_hw3.shape[1]),
        fy=1024 / max(rgb_hw3.shape[0], rgb_hw3.shape[1]),
        interpolation=cv2.INTER_AREA,
    )
    # We eventually want to clean up the RRD file after it is sent to the viewer, so we track
    # any pending files to be cleaned up when the state is deleted.
    temp = tempfile.NamedTemporaryFile(prefix="cube_", suffix=".rrd", delete=False)

    if pending_cleanup is not None:
        pending_cleanup.append(temp.name)

    view: rrb.ContainerLike = create_view()
    blueprint = rrb.Blueprint(view, collapse_panels=True)
    rr.save(path=temp.name, default_blueprint=blueprint)
    set_annotation_context()
    parent_log_path = Path("/world")
    rr.log("/", rr.ViewCoordinates.RDF, static=True)

    outputs: tuple[list[FinalPosePrediction], RelativeDepthPrediction] = MODEL_E2E.predict_single_image(rgb_hw3=rgb_hw3)
    pred_list: list[FinalPosePrediction] = outputs[0]
    relative_pred: RelativeDepthPrediction = outputs[1]
    rr.set_time(timeline="image_sequence", sequence=0)
    visualize_sample(
        pred_list=pred_list,
        rgb_hw3=rgb_hw3,
        parent_log_path=parent_log_path,
        faces=mesh_faces,
        relative_depth_pred=relative_pred if log_relative_depth else None,
    )

    glb_files: list[str] = []
    if export_glb and len(pred_list) > 0:
        glb_dir: Path = Path(tempfile.mkdtemp(prefix="sam3d_glb_"))
        glb_paths = export_meshes_to_glb(
            pred_list=pred_list,
            faces=mesh_faces,
            output_dir=glb_dir,
            center_mesh=center_glb,
        )
        glb_files = [str(p) for p in glb_paths]
        if pending_cleanup is not None:
            pending_cleanup.extend(glb_files)
            pending_cleanup.append(str(glb_dir))

    return temp.name, STATE, glb_files


def cleanup_rrds(pending_cleanup: list[str]) -> None:
    for f in pending_cleanup:
        if os.path.isdir(f):
            shutil.rmtree(f, ignore_errors=True)
        elif os.path.isfile(f):
            os.unlink(f)


def _switch_to_outputs() -> gr.Tabs:
    return gr.update(selected="outputs")


def main():
    viewer = Rerun(
        streaming=True,
        panel_states={
            "time": "collapsed",
            "blueprint": "hidden",
            "selection": "hidden",
        },
        height=800,
    )

    with gr.Blocks() as demo, gr.Tab("SAM3D Body Estimation"):
        pending_cleanup = gr.State([], time_to_live=10, delete_callback=cleanup_rrds)
        with gr.Row():
            with gr.Column(scale=1):
                tabs = gr.Tabs(selected="inputs")
                with tabs:
                    with gr.TabItem("Inputs", id="inputs"):
                        img = gr.Image(interactive=True, label="Image", type="numpy", image_mode="RGB")
                        depth_checkbox = gr.Checkbox(label="Log relative depth", value=False)
                        with gr.Row():
                            export_checkbox = gr.Checkbox(label="Export GLB meshes", value=False)
                            center_checkbox = gr.Checkbox(label="Center GLB at origin", value=True)
                        create_rrd = gr.Button("Predict Pose")
                    with gr.TabItem("Outputs", id="outputs"):
                        status = gr.Text(STATE, label="Status")
                        mesh_files = gr.Files(label="GLB meshes", file_count="multiple")
                gr.Examples(
                    examples=[
                        [str(TEST_INPUT_DIR / "Planche.jpg"), True, False, True],
                        [str(TEST_INPUT_DIR / "Amir-Khan-Lamont-Peterson_2689582.jpg"), False, False, True],
                        [str(TEST_INPUT_DIR / "BNAAHPYGMYSE26U6C6T7VA6544.jpg"), False, True, True],
                        [str(TEST_INPUT_DIR / "yoga-example.jpg"), True, True, False],
                    ],
                    inputs=[img, depth_checkbox, export_checkbox, center_checkbox],
                    outputs=[viewer, status, mesh_files],
                    fn=sam3d_prediction_fn,
                    run_on_click=True,
                    cache_examples=False,
                    examples_per_page=2,
                )
            with gr.Column(scale=5):
                viewer.render()

        create_rrd.click(
            fn=_switch_to_outputs,
            inputs=None,
            outputs=[tabs],
        ).then(
            sam3d_prediction_fn,
            inputs=[img, depth_checkbox, export_checkbox, center_checkbox, pending_cleanup],
            outputs=[viewer, status, mesh_files],
        )
    return demo

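The `main()` above only builds the Blocks app; something still has to launch it. Presumably the Space's app.py does the equivalent of this sketch:

    from sam3d_body.gradio_ui.sam3d_body_ui import main

    demo = main()
    demo.launch()
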
src/sam3d_body/metadata/__init__.py
ADDED
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

OPENPOSE_TO_COCO = [0, 16, 15, 18, 17, 5, 2, 6, 3, 7, 4, 12, 9, 13, 10, 14, 11]

# Mapping the J19 used in HMR2.0 to the 14 common points for evaluation
# J19 is defined as the first 19 keypoints in https://github.com/nkolot/SPIN/blob/master/constants.py#L42
# The first 14 keypoints in J19 are LSP keypoints
J19_TO_J14 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]

# Mapping from 14 LSP keypoints to 17 COCO keypoints
# Key: coco_idx, value: lsp_idx
LSP_TO_COCO = {5: 9, 6: 8, 7: 10, 8: 7, 9: 11, 10: 6, 11: 3, 12: 2, 13: 4, 14: 1, 15: 5, 16: 0}

# fmt: off
OPENPOSE_PERMUTATION = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21]
J19_PERMUTATION = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18]
COCO_PERMUTATION = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
# fmt: on

# Mapping the 70 MHR keypoints to OpenPose (COCO included)
# key: OpenPose, value: mhr_idx
MHR70_TO_OPENPOSE = {
    0: 0, 1: 69, 2: 6, 3: 8, 4: 41, 5: 5, 6: 7, 7: 62,
    9: 10, 10: 12, 11: 14, 12: 9, 13: 11, 14: 13,
    15: 2, 16: 1, 17: 4, 18: 3,
    19: 15, 20: 16, 21: 17, 22: 18, 23: 19, 24: 20,
}

# fmt: off
MHR70_PERMUTATION = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 18, 19, 20, 15, 16, 17, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 64, 63, 66, 65, 68, 67, 69]
# fmt: on
MHR70_TO_LSP = {
    0: 14, 1: 12, 2: 10, 3: 9, 4: 11, 5: 13,
    6: 41, 7: 8, 8: 6, 9: 5, 10: 7, 11: 62,
    12: 69,
}

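To illustrate how these index maps are meant to be consumed, here is a small hedged sketch that scatters an MHR-70 prediction into a 25-joint OpenPose-style array; the prediction itself is random placeholder data. Note that MHR70_TO_OPENPOSE has no entry for OpenPose index 8 (mid-hip), so that joint simply stays zero.

    import numpy as np
    from sam3d_body.metadata import MHR70_TO_OPENPOSE

    mhr_kpts = np.random.rand(70, 3)  # hypothetical (x, y, confidence) per MHR-70 keypoint

    openpose_kpts = np.zeros((25, 3))
    for op_idx, mhr_idx in MHR70_TO_OPENPOSE.items():
        openpose_kpts[op_idx] = mhr_kpts[mhr_idx]
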
src/sam3d_body/metadata/mhr70.py
ADDED
@@ -0,0 +1,915 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""The first 70 of the 308 MHR keypoints; the remaining ones are face keypoints and are ignored."""

from typing import Final

mhr_names = [
    "nose", "left-eye", "right-eye", "left-ear", "right-ear",
    "left-shoulder", "right-shoulder", "left-elbow", "right-elbow",
    "left-hip", "right-hip", "left-knee", "right-knee", "left-ankle", "right-ankle",
    "left-big-toe-tip", "left-small-toe-tip", "left-heel",
    "right-big-toe-tip", "right-small-toe-tip", "right-heel",
    "right-thumb-tip", "right-thumb-first-joint", "right-thumb-second-joint", "right-thumb-third-joint",
    "right-index-tip", "right-index-first-joint", "right-index-second-joint", "right-index-third-joint",
    "right-middle-tip", "right-middle-first-joint", "right-middle-second-joint", "right-middle-third-joint",
    "right-ring-tip", "right-ring-first-joint", "right-ring-second-joint", "right-ring-third-joint",
    "right-pinky-tip", "right-pinky-first-joint", "right-pinky-second-joint", "right-pinky-third-joint",
    "right-wrist",
    "left-thumb-tip", "left-thumb-first-joint", "left-thumb-second-joint", "left-thumb-third-joint",
    "left-index-tip", "left-index-first-joint", "left-index-second-joint", "left-index-third-joint",
    "left-middle-tip", "left-middle-first-joint", "left-middle-second-joint", "left-middle-third-joint",
    "left-ring-tip", "left-ring-first-joint", "left-ring-second-joint", "left-ring-third-joint",
    "left-pinky-tip", "left-pinky-first-joint", "left-pinky-second-joint", "left-pinky-third-joint",
    "left-wrist",
    "left-olecranon", "right-olecranon", "left-cubital-fossa", "right-cubital-fossa",
    "left-acromion", "right-acromion", "neck",
]

pose_info = dict(
    pose_format="mhr70",
    paper_info=dict(
        author="",
        year="",
        homepage="",
    ),
    min_visible_keypoints=8,
    image_height=4096,
    image_width=2668,
    original_keypoint_info={
        0: "nose", 1: "left_eye", 2: "right_eye", 3: "left_ear", 4: "right_ear",
        5: "left_shoulder", 6: "right_shoulder", 7: "left_elbow", 8: "right_elbow",
        9: "left_hip", 10: "right_hip", 11: "left_knee", 12: "right_knee",
        13: "left_ankle", 14: "right_ankle",
        15: "left_big_toe_tip", 16: "left_small_toe_tip", 17: "left_heel",
        18: "right_big_toe_tip", 19: "right_small_toe_tip", 20: "right_heel",
        21: "right_thumb_tip", 22: "right_thumb_first_joint", 23: "right_thumb_second_joint", 24: "right_thumb_third_joint",
        25: "right_index_tip", 26: "right_index_first_joint", 27: "right_index_second_joint", 28: "right_index_third_joint",
        29: "right_middle_tip", 30: "right_middle_first_joint", 31: "right_middle_second_joint", 32: "right_middle_third_joint",
        33: "right_ring_tip", 34: "right_ring_first_joint", 35: "right_ring_second_joint", 36: "right_ring_third_joint",
        37: "right_pinky_tip", 38: "right_pinky_first_joint", 39: "right_pinky_second_joint", 40: "right_pinky_third_joint",
        41: "right_wrist",
        42: "left_thumb_tip", 43: "left_thumb_first_joint", 44: "left_thumb_second_joint", 45: "left_thumb_third_joint",
        46: "left_index_tip", 47: "left_index_first_joint", 48: "left_index_second_joint", 49: "left_index_third_joint",
        50: "left_middle_tip", 51: "left_middle_first_joint", 52: "left_middle_second_joint", 53: "left_middle_third_joint",
        54: "left_ring_tip", 55: "left_ring_first_joint", 56: "left_ring_second_joint", 57: "left_ring_third_joint",
        58: "left_pinky_tip", 59: "left_pinky_first_joint", 60: "left_pinky_second_joint", 61: "left_pinky_third_joint",
        62: "left_wrist",
        63: "left_olecranon", 64: "right_olecranon", 65: "left_cubital_fossa", 66: "right_cubital_fossa",
        67: "left_acromion", 68: "right_acromion", 69: "neck",
    },
    keypoint_info={
        0: dict(name="nose", id=0, color=[51, 153, 255], type="upper", swap=""),
        1: dict(name="left_eye", id=1, color=[51, 153, 255], type="upper", swap="right_eye"),
        2: dict(name="right_eye", id=2, color=[51, 153, 255], type="upper", swap="left_eye"),
        3: dict(name="left_ear", id=3, color=[51, 153, 255], type="upper", swap="right_ear"),
        4: dict(name="right_ear", id=4, color=[51, 153, 255], type="upper", swap="left_ear"),
        5: dict(name="left_shoulder", id=5, color=[51, 153, 255], type="upper", swap="right_shoulder"),
        6: dict(name="right_shoulder", id=6, color=[51, 153, 255], type="upper", swap="left_shoulder"),
        7: dict(name="left_elbow", id=7, color=[51, 153, 255], type="upper", swap="right_elbow"),
        8: dict(name="right_elbow", id=8, color=[51, 153, 255], type="upper", swap="left_elbow"),
        9: dict(name="left_hip", id=9, color=[51, 153, 255], type="lower", swap="right_hip"),
        10: dict(name="right_hip", id=10, color=[51, 153, 255], type="lower", swap="left_hip"),
        11: dict(name="left_knee", id=11, color=[51, 153, 255], type="lower", swap="right_knee"),
        12: dict(name="right_knee", id=12, color=[51, 153, 255], type="lower", swap="left_knee"),
        13: dict(name="left_ankle", id=13, color=[51, 153, 255], type="lower", swap="right_ankle"),
        14: dict(name="right_ankle", id=14, color=[51, 153, 255], type="lower", swap="left_ankle"),
        15: dict(name="left_big_toe", id=15, color=[51, 153, 255], type="lower", swap="right_big_toe"),
        16: dict(name="left_small_toe", id=16, color=[51, 153, 255], type="lower", swap="right_small_toe"),
        17: dict(name="left_heel", id=17, color=[51, 153, 255], type="lower", swap="right_heel"),
        18: dict(name="right_big_toe", id=18, color=[51, 153, 255], type="lower", swap="left_big_toe"),
        19: dict(name="right_small_toe", id=19, color=[51, 153, 255], type="lower", swap="left_small_toe"),
        20: dict(name="right_heel", id=20, color=[51, 153, 255], type="lower", swap="left_heel"),
        21: dict(name="right_thumb4", id=21, color=[51, 153, 255], type="upper", swap="left_thumb4"),
        22: dict(name="right_thumb3", id=22, color=[51, 153, 255], type="upper", swap="left_thumb3"),
        23: dict(name="right_thumb2", id=23, color=[51, 153, 255], type="upper", swap="left_thumb2"),
        24: dict(name="right_thumb_third_joint", id=24, color=[51, 153, 255], type="upper", swap="left_thumb_third_joint"),
        25: dict(name="right_forefinger4", id=25, color=[51, 153, 255], type="upper", swap="left_forefinger4"),
        26: dict(name="right_forefinger3", id=26, color=[51, 153, 255], type="upper", swap="left_forefinger3"),
        27: dict(name="right_forefinger2", id=27, color=[51, 153, 255], type="upper", swap="left_forefinger2"),
        28: dict(name="right_forefinger_third_joint", id=28, color=[51, 153, 255], type="upper", swap="left_forefinger_third_joint"),
        29: dict(name="right_middle_finger4", id=29, color=[51, 153, 255], type="upper", swap="left_middle_finger4"),
        30: dict(name="right_middle_finger3", id=30, color=[51, 153, 255], type="upper", swap="left_middle_finger3"),
        31: dict(name="right_middle_finger2", id=31, color=[51, 153, 255], type="upper", swap="left_middle_finger2"),
        32: dict(name="right_middle_finger_third_joint", id=32, color=[51, 153, 255], type="upper", swap="left_middle_finger_third_joint"),
        33: dict(name="right_ring_finger4", id=33, color=[51, 153, 255], type="upper", swap="left_ring_finger4"),
        34: dict(name="right_ring_finger3", id=34, color=[51, 153, 255], type="upper", swap="left_ring_finger3"),
        35: dict(name="right_ring_finger2", id=35, color=[51, 153, 255], type="upper", swap="left_ring_finger2"),
        36: dict(name="right_ring_finger_third_joint", id=36, color=[51, 153, 255], type="upper", swap="left_ring_finger_third_joint"),
        37: dict(name="right_pinky_finger4", id=37, color=[51, 153, 255], type="upper", swap="left_pinky_finger4"),
        38: dict(name="right_pinky_finger3", id=38, color=[51, 153, 255], type="upper", swap="left_pinky_finger3"),
        39: dict(name="right_pinky_finger2", id=39, color=[51, 153, 255], type="upper", swap="left_pinky_finger2"),
        40: dict(name="right_pinky_finger_third_joint", id=40, color=[51, 153, 255], type="upper", swap="left_pinky_finger_third_joint"),
        41: dict(name="right_wrist", id=41, color=[51, 153, 255], type="upper", swap="left_wrist"),
        42: dict(name="left_thumb4", id=42, color=[51, 153, 255], type="upper", swap="right_thumb4"),
        43: dict(name="left_thumb3", id=43, color=[51, 153, 255], type="upper", swap="right_thumb3"),
        44: dict(name="left_thumb2", id=44, color=[51, 153, 255], type="upper", swap="right_thumb2"),
        45: dict(name="left_thumb_third_joint", id=45, color=[51, 153, 255], type="upper", swap="right_thumb_third_joint"),  # doesn't match with wholebody
        46: dict(name="left_forefinger4", id=46, color=[51, 153, 255], type="upper", swap="right_forefinger4"),
        47: dict(name="left_forefinger3", id=47, color=[51, 153, 255], type="upper", swap="right_forefinger3"),
        48: dict(name="left_forefinger2", id=48, color=[51, 153, 255], type="upper", swap="right_forefinger2"),
        49: dict(name="left_forefinger_third_joint", id=49, color=[51, 153, 255], type="upper", swap="right_forefinger_third_joint"),
        50: dict(name="left_middle_finger4", id=50, color=[51, 153, 255], type="upper", swap="right_middle_finger4"),
        51: dict(name="left_middle_finger3", id=51, color=[51, 153, 255], type="upper", swap="right_middle_finger3"),
        52: dict(name="left_middle_finger2", id=52, color=[51, 153, 255], type="upper", swap="right_middle_finger2"),
        53: dict(name="left_middle_finger_third_joint", id=53, color=[51, 153, 255], type="upper", swap="right_middle_finger_third_joint"),
        54: dict(name="left_ring_finger4", id=54, color=[51, 153, 255], type="upper", swap="right_ring_finger4"),
        55: dict(name="left_ring_finger3", id=55, color=[51, 153, 255], type="upper", swap="right_ring_finger3"),
        56: dict(name="left_ring_finger2", id=56, color=[51, 153, 255], type="upper", swap="right_ring_finger2"),
        57: dict(name="left_ring_finger_third_joint", id=57, color=[51, 153, 255], type="upper", swap="right_ring_finger_third_joint"),
        58: dict(name="left_pinky_finger4", id=58, color=[51, 153, 255], type="upper", swap="right_pinky_finger4"),
        59: dict(name="left_pinky_finger3", id=59, color=[51, 153, 255], type="upper", swap="right_pinky_finger3"),
        60: dict(name="left_pinky_finger2", id=60, color=[51, 153, 255], type="upper", swap="right_pinky_finger2"),
        61: dict(name="left_pinky_finger_third_joint", id=61, color=[51, 153, 255], type="upper", swap="right_pinky_finger_third_joint"),
        62: dict(name="left_wrist", id=62, color=[51, 153, 255], type="upper", swap="right_wrist"),
        63: dict(name="left_olecranon", id=63, color=[51, 153, 255], type="", swap="right_olecranon"),
        64: dict(name="right_olecranon", id=64, color=[51, 153, 255], type="", swap="left_olecranon"),
        65: dict(name="left_cubital_fossa", id=65, color=[51, 153, 255], type="", swap="right_cubital_fossa"),
        66: dict(name="right_cubital_fossa", id=66, color=[51, 153, 255], type="", swap="left_cubital_fossa"),
        67: dict(name="left_acromion", id=67, color=[51, 153, 255], type="", swap="right_acromion"),
        68: dict(name="right_acromion", id=68, color=[51, 153, 255], type="", swap="left_acromion"),
        69: dict(name="neck", id=69, color=[51, 153, 255], type="", swap=""),
    },
    skeleton_info={
        0: dict(link=("left_ankle", "left_knee"), id=0, color=[0, 255, 0]),
        1: dict(link=("left_knee", "left_hip"), id=1, color=[0, 255, 0]),
        2: dict(link=("right_ankle", "right_knee"), id=2, color=[255, 128, 0]),
        3: dict(link=("right_knee", "right_hip"), id=3, color=[255, 128, 0]),
        4: dict(link=("left_hip", "right_hip"), id=4, color=[51, 153, 255]),
        5: dict(link=("left_shoulder", "left_hip"), id=5, color=[51, 153, 255]),
        6: dict(link=("right_shoulder", "right_hip"), id=6, color=[51, 153, 255]),
        7: dict(link=("left_shoulder", "right_shoulder"), id=7, color=[51, 153, 255]),
        8: dict(link=("left_shoulder", "left_elbow"), id=8, color=[0, 255, 0]),
        9: dict(link=("right_shoulder", "right_elbow"), id=9, color=[255, 128, 0]),
        10: dict(link=("left_elbow", "left_wrist"), id=10, color=[0, 255, 0]),
        11: dict(link=("right_elbow", "right_wrist"), id=11, color=[255, 128, 0]),
        12: dict(link=("left_eye", "right_eye"), id=12, color=[51, 153, 255]),
        13: dict(link=("nose", "left_eye"), id=13, color=[51, 153, 255]),
        14: dict(link=("nose", "right_eye"), id=14, color=[51, 153, 255]),
        15: dict(link=("left_eye", "left_ear"), id=15, color=[51, 153, 255]),
        16: dict(link=("right_eye", "right_ear"), id=16, color=[51, 153, 255]),
        17: dict(link=("left_ear", "left_shoulder"), id=17, color=[51, 153, 255]),
        18: dict(link=("right_ear", "right_shoulder"), id=18, color=[51, 153, 255]),
        19: dict(link=("left_ankle", "left_big_toe"), id=19, color=[0, 255, 0]),
        20: dict(link=("left_ankle", "left_small_toe"), id=20, color=[0, 255, 0]),
        21: dict(link=("left_ankle", "left_heel"), id=21, color=[0, 255, 0]),
        22: dict(link=("right_ankle", "right_big_toe"), id=22, color=[255, 128, 0]),
        23: dict(link=("right_ankle", "right_small_toe"), id=23, color=[255, 128, 0]),
        24: dict(link=("right_ankle", "right_heel"), id=24, color=[255, 128, 0]),
        25: dict(link=("left_wrist", "left_thumb_third_joint"), id=25, color=[255, 128, 0]),
        26: dict(link=("left_thumb_third_joint", "left_thumb2"), id=26, color=[255, 128, 0]),
        27: dict(link=("left_thumb2", "left_thumb3"), id=27, color=[255, 128, 0]),
        28: dict(link=("left_thumb3", "left_thumb4"), id=28, color=[255, 128, 0]),
        29: dict(link=("left_wrist", "left_forefinger_third_joint"), id=29, color=[255, 153, 255]),
        30: dict(link=("left_forefinger_third_joint", "left_forefinger2"), id=30, color=[255, 153, 255]),
        31: dict(link=("left_forefinger2", "left_forefinger3"), id=31, color=[255, 153, 255]),
        32: dict(link=("left_forefinger3", "left_forefinger4"), id=32, color=[255, 153, 255]),
        33: dict(link=("left_wrist", "left_middle_finger_third_joint"), id=33, color=[102, 178, 255]),
        34: dict(link=("left_middle_finger_third_joint", "left_middle_finger2"), id=34, color=[102, 178, 255]),
        35: dict(link=("left_middle_finger2", "left_middle_finger3"), id=35, color=[102, 178, 255]),
        36: dict(link=("left_middle_finger3", "left_middle_finger4"), id=36, color=[102, 178, 255]),
        37: dict(link=("left_wrist", "left_ring_finger_third_joint"), id=37, color=[255, 51, 51]),
        38: dict(link=("left_ring_finger_third_joint", "left_ring_finger2"), id=38, color=[255, 51, 51]),
        39: dict(link=("left_ring_finger2", "left_ring_finger3"), id=39, color=[255, 51, 51]),
        40: dict(link=("left_ring_finger3", "left_ring_finger4"), id=40, color=[255, 51, 51]),
        41: dict(link=("left_wrist", "left_pinky_finger_third_joint"), id=41, color=[0, 255, 0]),
        42: dict(link=("left_pinky_finger_third_joint", "left_pinky_finger2"), id=42, color=[0, 255, 0]),
        43: dict(link=("left_pinky_finger2", "left_pinky_finger3"), id=43, color=[0, 255, 0]),
        44: dict(link=("left_pinky_finger3", "left_pinky_finger4"), id=44, color=[0, 255, 0]),
        45: dict(link=("right_wrist", "right_thumb_third_joint"), id=45, color=[255, 128, 0]),
        46: dict(link=("right_thumb_third_joint", "right_thumb2"), id=46, color=[255, 128, 0]),
        47: dict(link=("right_thumb2", "right_thumb3"), id=47, color=[255, 128, 0]),
        48: dict(link=("right_thumb3", "right_thumb4"), id=48, color=[255, 128, 0]),
        49: dict(link=("right_wrist", "right_forefinger_third_joint"), id=49, color=[255, 153, 255]),
        50: dict(link=("right_forefinger_third_joint", "right_forefinger2"), id=50, color=[255, 153, 255]),
        51: dict(link=("right_forefinger2", "right_forefinger3"), id=51, color=[255, 153, 255]),
        52: dict(link=("right_forefinger3", "right_forefinger4"), id=52, color=[255, 153, 255]),
        53: dict(link=("right_wrist", "right_middle_finger_third_joint"), id=53, color=[102, 178, 255]),
        54: dict(link=("right_middle_finger_third_joint", "right_middle_finger2"), id=54, color=[102, 178, 255]),
        55: dict(link=("right_middle_finger2", "right_middle_finger3"), id=55, color=[102, 178, 255]),
        56: dict(link=("right_middle_finger3", "right_middle_finger4"), id=56, color=[102, 178, 255]),
        57: dict(link=("right_wrist", "right_ring_finger_third_joint"), id=57, color=[255, 51, 51]),
        58: dict(link=("right_ring_finger_third_joint", "right_ring_finger2"), id=58, color=[255, 51, 51]),
        59: dict(link=("right_ring_finger2", "right_ring_finger3"), id=59, color=[255, 51, 51]),
        60: dict(link=("right_ring_finger3", "right_ring_finger4"), id=60, color=[255, 51, 51]),
        61: dict(link=("right_wrist", "right_pinky_finger_third_joint"), id=61, color=[0, 255, 0]),
        62: dict(link=("right_pinky_finger_third_joint", "right_pinky_finger2"), id=62, color=[0, 255, 0]),
        63: dict(link=("right_pinky_finger2", "right_pinky_finger3"), id=63, color=[0, 255, 0]),
        64: dict(link=("right_pinky_finger3", "right_pinky_finger4"), id=64, color=[0, 255, 0]),
    },
    joint_weights=[1.0] * 70,
    body_keypoint_names=[
        "nose", "left_eye", "right_eye", "left_ear", "right_ear",
        "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
        "left_wrist", "right_wrist", "left_hip", "right_hip",
        "left_knee", "right_knee", "left_ankle", "right_ankle",
    ],
    foot_keypoint_names=[
        "left_big_toe", "left_small_toe", "left_heel",
        "right_big_toe", "right_small_toe", "right_heel",
    ],
    left_hand_keypoint_names=[
        "left_thumb4", "left_thumb3", "left_thumb2", "left_thumb_third_joint",
        "left_forefinger4", "left_forefinger3", "left_forefinger2", "left_forefinger_third_joint",
        "left_middle_finger4", "left_middle_finger3", "left_middle_finger2", "left_middle_finger_third_joint",
        "left_ring_finger4", "left_ring_finger3", "left_ring_finger2", "left_ring_finger_third_joint",
        "left_pinky_finger4", "left_pinky_finger3", "left_pinky_finger2", "left_pinky_finger_third_joint",
    ],
    right_hand_keypoint_names=[
        "right_thumb4", "right_thumb3", "right_thumb2", "right_thumb_third_joint",
        "right_forefinger4", "right_forefinger3", "right_forefinger2", "right_forefinger_third_joint",
        "right_middle_finger4", "right_middle_finger3", "right_middle_finger2", "right_middle_finger_third_joint",
        "right_ring_finger4", "right_ring_finger3", "right_ring_finger2", "right_ring_finger_third_joint",
        "right_pinky_finger4", "right_pinky_finger3", "right_pinky_finger2", "right_pinky_finger_third_joint",
    ],
    # 7 of them
    extra_keypoint_names=[
        "neck", "left_olecranon", "right_olecranon",
        "left_cubital_fossa", "right_cubital_fossa",
        "left_acromion", "right_acromion",
    ],
    sigmas=[],
)

# Rerun-friendly helpers ----------------------------------------------------
# These mirror the COCO-133 helpers exposed by ``simplecv.data.skeleton.coco_133``
# so downstream code can build annotation contexts without re-deriving names/links.

MHR70_ID2NAME: Final[dict[int, str]] = {
    idx: info["name"] for idx, info in pose_info["keypoint_info"].items()
}

MHR70_IDS: Final[list[int]] = sorted(MHR70_ID2NAME.keys())

_NAME_TO_ID = {name: idx for idx, name in MHR70_ID2NAME.items()}
MHR70_LINKS: Final[list[tuple[int, int]]] = [
    (_NAME_TO_ID[link_info["link"][0]], _NAME_TO_ID[link_info["link"][1]])
    for link_info in pose_info["skeleton_info"].values()
]

__all__ = [
    "pose_info",
    "MHR70_ID2NAME",
    "MHR70_IDS",
    "MHR70_LINKS",
]

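One way these helpers can be consumed is to register the MHR-70 skeleton as a Rerun annotation context, so that logged keypoints pick up joint labels and connecting lines. This is only a sketch of that idea; `set_annotation_context()` in the demo API presumably does something similar, and the class id/label are arbitrary choices here.

    import rerun as rr
    from sam3d_body.metadata.mhr70 import MHR70_ID2NAME, MHR70_LINKS

    rr.init("mhr70_skeleton_demo", spawn=False)
    rr.log(
        "/",
        rr.AnnotationContext(
            [
                rr.ClassDescription(
                    info=rr.AnnotationInfo(id=0, label="person"),
                    keypoint_annotations=[rr.AnnotationInfo(id=i, label=n) for i, n in MHR70_ID2NAME.items()],
                    keypoint_connections=MHR70_LINKS,
                )
            ]
        ),
        static=True,
    )
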
src/sam3d_body/models/__init__.py
ADDED
@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

src/sam3d_body/models/backbones/__init__.py
ADDED
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.


def create_backbone(name, cfg=None):
    if name in ["vit_hmr"]:
        from .vit import vit

        backbone = vit(cfg)
    elif name in ["vit_hmr_512_384"]:
        from .vit import vit512_384

        backbone = vit512_384(cfg)
    elif name in ["vit_l"]:
        from .vit import vit_l

        backbone = vit_l(cfg)
    elif name in ["vit_b"]:
        from .vit import vit_b

        backbone = vit_b(cfg)
    elif name in [
        "dinov3_vit7b",
        "dinov3_vith16plus",
        "dinov3_vits16",
        "dinov3_vits16plus",
        "dinov3_vitb16",
        "dinov3_vitl16",
    ]:
        from .dinov3 import Dinov3Backbone

        backbone = Dinov3Backbone(name, cfg=cfg)
    else:
        raise NotImplementedError("Backbone type is not implemented")

    return backbone

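A brief, hedged dispatch example for the factory above; `cfg` stands for the project's config object (the DINOv3 branch expects it to carry MODEL.BACKBONE.DROP_PATH_RATE, as the wrapper that follows shows) and is left abstract here, so this is a sketch rather than a runnable call:

    from sam3d_body.models.backbones import create_backbone

    backbone = create_backbone("vit_b", cfg=cfg)            # HMR-style ViT-B
    # backbone = create_backbone("dinov3_vitl16", cfg=cfg)  # DINOv3 ViT-L/16 wrapper
    # Any other name raises NotImplementedError.
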
src/sam3d_body/models/backbones/dinov3.py
ADDED
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import torch
+from torch import nn
+
+
+class Dinov3Backbone(nn.Module):
+    def __init__(
+        self, name="dinov2_vitb14", pretrained_weight=None, cfg=None, *args, **kwargs
+    ):
+        super().__init__()
+        self.name = name
+        self.cfg = cfg
+
+        self.encoder = torch.hub.load(
+            "facebookresearch/dinov3",
+            self.name,
+            source="github",
+            pretrained=False,
+            drop_path=self.cfg.MODEL.BACKBONE.DROP_PATH_RATE,
+        )
+        self.patch_size = self.encoder.patch_size
+        self.embed_dim = self.embed_dims = self.encoder.embed_dim
+
+    def forward(self, x, extra_embed=None):
+        """
+        Encode an RGB image using a ViT backbone.
+        Args:
+            - x: torch.Tensor of shape [bs,3,w,h]
+        Return:
+            - y: torch.Tensor of shape [bs,k,d] - image in patchified mode
+        """
+        assert extra_embed is None, "Not Implemented Yet"
+
+        y = self.encoder.get_intermediate_layers(x, n=1, reshape=True, norm=True)[-1]
+
+        return y
+
+    def get_layer_depth(self, param_name: str, prefix: str = "encoder."):
+        """Get the layer-wise depth of a parameter.
+        Args:
+            param_name (str): The name of the parameter.
+            prefix (str): The prefix for the parameter.
+                Defaults to ``"encoder."``.
+        Returns:
+            Tuple[int, int]: The layer-wise depth and the number of layers.
+        Note:
+            The first depth is the stem module (``layer_depth=0``), and the
+            last depth is the subsequent module (``layer_depth=num_layers-1``)
+        """
+        num_layers = self.encoder.n_blocks + 2
+
+        if not param_name.startswith(prefix):
+            # For subsequent module like head
+            return num_layers - 1, num_layers
+
+        param_name = param_name[len(prefix) :]
+
+        if param_name in ("cls_token", "pos_embed", "storage_tokens"):
+            layer_depth = 0
+        elif param_name.startswith("patch_embed"):
+            layer_depth = 0
+        elif param_name.startswith("blocks"):
+            layer_id = int(param_name.split(".")[1])
+            layer_depth = layer_id + 1
+        else:
+            layer_depth = num_layers - 1
+
+        return layer_depth, num_layers
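`get_layer_depth` is the hook for layer-wise learning-rate decay: parameters closer to the stem get depth 0, transformer blocks get depth 1..n, and everything else the last depth. A minimal sketch of how that output could drive per-parameter learning rates follows; the decay factor, base LR, and optimizer wiring are assumptions, not part of the commit.

# Hypothetical usage sketch -- not part of the commit.
def build_param_groups(backbone, base_lr=1e-4, decay=0.75):
    """Scale each parameter's LR by decay^(distance from the last layer)."""
    groups = []
    for name, param in backbone.named_parameters():
        if not param.requires_grad:
            continue
        depth, num_layers = backbone.get_layer_depth(name)
        scale = decay ** (num_layers - 1 - depth)   # earlier layers get smaller LRs
        groups.append({"params": [param], "lr": base_lr * scale})
    return groups

# optimizer = torch.optim.AdamW(build_param_groups(backbone), weight_decay=0.05)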
src/sam3d_body/models/backbones/vit.py
ADDED
@@ -0,0 +1,658 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
from functools import partial
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import torch.utils.checkpoint as checkpoint
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from flash_attn.flash_attn_interface import flash_attn_func
|
| 12 |
+
except:
|
| 13 |
+
print("No Flash Attention!")
|
| 14 |
+
|
| 15 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
| 16 |
+
|
| 17 |
+
from ..modules.transformer import LayerNorm32
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def vit(cfg):
|
| 21 |
+
return ViT(
|
| 22 |
+
img_size=(256, 192),
|
| 23 |
+
patch_size=16,
|
| 24 |
+
embed_dim=1280,
|
| 25 |
+
depth=32,
|
| 26 |
+
num_heads=16,
|
| 27 |
+
ratio=1,
|
| 28 |
+
norm_layer=LayerNorm32,
|
| 29 |
+
use_checkpoint=False,
|
| 30 |
+
mlp_ratio=4,
|
| 31 |
+
qkv_bias=True,
|
| 32 |
+
drop_path_rate=0.55,
|
| 33 |
+
frozen_stages=cfg.MODEL.BACKBONE.get("FROZEN_STAGES", -1),
|
| 34 |
+
flash_attn=cfg.MODEL.BACKBONE.get("FLASH_ATTN", False),
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def vit_l(cfg):
|
| 39 |
+
return ViT(
|
| 40 |
+
img_size=(256, 192),
|
| 41 |
+
patch_size=16,
|
| 42 |
+
embed_dim=1024,
|
| 43 |
+
depth=24,
|
| 44 |
+
num_heads=16,
|
| 45 |
+
ratio=1,
|
| 46 |
+
norm_layer=LayerNorm32,
|
| 47 |
+
use_checkpoint=False,
|
| 48 |
+
mlp_ratio=4,
|
| 49 |
+
qkv_bias=True,
|
| 50 |
+
drop_path_rate=0.55,
|
| 51 |
+
frozen_stages=cfg.MODEL.BACKBONE.get("FROZEN_STAGES", -1),
|
| 52 |
+
flash_attn=cfg.MODEL.BACKBONE.get("FLASH_ATTN", False),
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def vit_b(cfg):
|
| 57 |
+
return ViT(
|
| 58 |
+
img_size=(256, 192),
|
| 59 |
+
patch_size=16,
|
| 60 |
+
embed_dim=768,
|
| 61 |
+
depth=12,
|
| 62 |
+
num_heads=12,
|
| 63 |
+
ratio=1,
|
| 64 |
+
norm_layer=LayerNorm32,
|
| 65 |
+
use_checkpoint=False,
|
| 66 |
+
mlp_ratio=4,
|
| 67 |
+
qkv_bias=True,
|
| 68 |
+
drop_path_rate=0.3,
|
| 69 |
+
frozen_stages=cfg.MODEL.BACKBONE.get("FROZEN_STAGES", -1),
|
| 70 |
+
flash_attn=cfg.MODEL.BACKBONE.get("FLASH_ATTN", False),
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def vit256(cfg):
|
| 75 |
+
return ViT(
|
| 76 |
+
img_size=(256, 256),
|
| 77 |
+
patch_size=16,
|
| 78 |
+
embed_dim=1280,
|
| 79 |
+
depth=32,
|
| 80 |
+
num_heads=16,
|
| 81 |
+
ratio=1,
|
| 82 |
+
norm_layer=LayerNorm32,
|
| 83 |
+
use_checkpoint=False,
|
| 84 |
+
mlp_ratio=4,
|
| 85 |
+
qkv_bias=True,
|
| 86 |
+
drop_path_rate=0.55,
|
| 87 |
+
frozen_stages=cfg.MODEL.BACKBONE.get("FROZEN_STAGES", -1),
|
| 88 |
+
flash_attn=cfg.MODEL.BACKBONE.get("FLASH_ATTN", False),
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def vit512_384(cfg):
|
| 93 |
+
return ViT(
|
| 94 |
+
img_size=(512, 384),
|
| 95 |
+
patch_size=16,
|
| 96 |
+
embed_dim=1280,
|
| 97 |
+
depth=32,
|
| 98 |
+
num_heads=16,
|
| 99 |
+
ratio=1,
|
| 100 |
+
norm_layer=LayerNorm32,
|
| 101 |
+
use_checkpoint=False,
|
| 102 |
+
mlp_ratio=4,
|
| 103 |
+
qkv_bias=True,
|
| 104 |
+
drop_path_rate=0.55,
|
| 105 |
+
frozen_stages=cfg.MODEL.BACKBONE.get("FROZEN_STAGES", -1),
|
| 106 |
+
flash_attn=cfg.MODEL.BACKBONE.get("FLASH_ATTN", False),
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
|
| 111 |
+
"""
|
| 112 |
+
Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
|
| 113 |
+
dimension for the original embeddings.
|
| 114 |
+
Args:
|
| 115 |
+
abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
|
| 116 |
+
has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
|
| 117 |
+
hw (Tuple): size of input image tokens.
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
Absolute positional embeddings after processing with shape (1, H, W, C)
|
| 121 |
+
"""
|
| 122 |
+
cls_token = None
|
| 123 |
+
B, L, C = abs_pos.shape
|
| 124 |
+
if has_cls_token:
|
| 125 |
+
cls_token = abs_pos[:, 0:1]
|
| 126 |
+
abs_pos = abs_pos[:, 1:]
|
| 127 |
+
|
| 128 |
+
if ori_h != h or ori_w != w:
|
| 129 |
+
new_abs_pos = (
|
| 130 |
+
F.interpolate(
|
| 131 |
+
abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
|
| 132 |
+
size=(h, w),
|
| 133 |
+
mode="bicubic",
|
| 134 |
+
align_corners=False,
|
| 135 |
+
)
|
| 136 |
+
.permute(0, 2, 3, 1)
|
| 137 |
+
.reshape(B, -1, C)
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
else:
|
| 141 |
+
new_abs_pos = abs_pos
|
| 142 |
+
|
| 143 |
+
if cls_token is not None:
|
| 144 |
+
new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
|
| 145 |
+
return new_abs_pos
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class DropPath(nn.Module):
|
| 149 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 150 |
+
|
| 151 |
+
def __init__(self, drop_prob=None):
|
| 152 |
+
super(DropPath, self).__init__()
|
| 153 |
+
self.drop_prob = drop_prob
|
| 154 |
+
|
| 155 |
+
def forward(self, x):
|
| 156 |
+
return drop_path(x, self.drop_prob, self.training)
|
| 157 |
+
|
| 158 |
+
def extra_repr(self):
|
| 159 |
+
return "p={}".format(self.drop_prob)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class Mlp(nn.Module):
|
| 163 |
+
def __init__(
|
| 164 |
+
self,
|
| 165 |
+
in_features,
|
| 166 |
+
hidden_features=None,
|
| 167 |
+
out_features=None,
|
| 168 |
+
act_layer=nn.GELU,
|
| 169 |
+
drop=0.0,
|
| 170 |
+
):
|
| 171 |
+
super().__init__()
|
| 172 |
+
out_features = out_features or in_features
|
| 173 |
+
hidden_features = hidden_features or in_features
|
| 174 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 175 |
+
self.act = act_layer()
|
| 176 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 177 |
+
self.drop = nn.Dropout(drop)
|
| 178 |
+
|
| 179 |
+
def forward(self, x):
|
| 180 |
+
x = self.fc1(x)
|
| 181 |
+
x = self.act(x)
|
| 182 |
+
x = self.fc2(x)
|
| 183 |
+
x = self.drop(x)
|
| 184 |
+
return x
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class Attention(nn.Module):
|
| 188 |
+
def __init__(
|
| 189 |
+
self,
|
| 190 |
+
dim,
|
| 191 |
+
num_heads=8,
|
| 192 |
+
qkv_bias=False,
|
| 193 |
+
qk_scale=None,
|
| 194 |
+
attn_drop=0.0,
|
| 195 |
+
proj_drop=0.0,
|
| 196 |
+
attn_head_dim=None,
|
| 197 |
+
):
|
| 198 |
+
super().__init__()
|
| 199 |
+
self.num_heads = num_heads
|
| 200 |
+
head_dim = dim // num_heads
|
| 201 |
+
self.dim = dim
|
| 202 |
+
|
| 203 |
+
if attn_head_dim is not None:
|
| 204 |
+
head_dim = attn_head_dim
|
| 205 |
+
all_head_dim = head_dim * self.num_heads
|
| 206 |
+
|
| 207 |
+
self.scale = qk_scale or head_dim**-0.5
|
| 208 |
+
|
| 209 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
|
| 210 |
+
|
| 211 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 212 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
| 213 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 214 |
+
|
| 215 |
+
def forward(self, x):
|
| 216 |
+
B, N, C = x.shape
|
| 217 |
+
qkv = self.qkv(x)
|
| 218 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
| 219 |
+
q, k, v = (
|
| 220 |
+
qkv[0],
|
| 221 |
+
qkv[1],
|
| 222 |
+
qkv[2],
|
| 223 |
+
) # make torchscript happy (cannot use tensor as tuple)
|
| 224 |
+
|
| 225 |
+
q = q * self.scale
|
| 226 |
+
attn = q @ k.transpose(-2, -1)
|
| 227 |
+
|
| 228 |
+
attn = attn.softmax(dim=-1)
|
| 229 |
+
attn = self.attn_drop(attn)
|
| 230 |
+
|
| 231 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
| 232 |
+
x = self.proj(x)
|
| 233 |
+
x = self.proj_drop(x)
|
| 234 |
+
|
| 235 |
+
return x
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class FlashAttention(nn.Module):
|
| 239 |
+
def __init__(
|
| 240 |
+
self,
|
| 241 |
+
dim,
|
| 242 |
+
num_heads=8,
|
| 243 |
+
qkv_bias=False,
|
| 244 |
+
qk_scale=None,
|
| 245 |
+
attn_drop=0.0,
|
| 246 |
+
proj_drop=0.0,
|
| 247 |
+
attn_head_dim=None,
|
| 248 |
+
):
|
| 249 |
+
super().__init__()
|
| 250 |
+
self.num_heads = num_heads
|
| 251 |
+
head_dim = attn_head_dim or (dim // num_heads)
|
| 252 |
+
self.head_dim = head_dim
|
| 253 |
+
self.dim = dim
|
| 254 |
+
self.qkv = nn.Linear(dim, head_dim * num_heads * 3, bias=qkv_bias)
|
| 255 |
+
self.proj = nn.Linear(head_dim * num_heads, dim)
|
| 256 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 257 |
+
self.attn_drop = attn_drop
|
| 258 |
+
|
| 259 |
+
def forward(self, x):
|
| 260 |
+
B, N, C = x.shape # (batch, sequence_length, embedding_dim)
|
| 261 |
+
|
| 262 |
+
qkv = self.qkv(x) # (B, N, 3 * num_heads * head_dim)
|
| 263 |
+
qkv = qkv.view(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 264 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # each: (B, num_heads, N, head_dim)
|
| 265 |
+
|
| 266 |
+
# FlashAttention expects (B, N, num_heads, head_dim)
|
| 267 |
+
q = q.transpose(1, 2).contiguous()
|
| 268 |
+
k = k.transpose(1, 2).contiguous()
|
| 269 |
+
v = v.transpose(1, 2).contiguous()
|
| 270 |
+
|
| 271 |
+
# Optional: FlashAttention requires fp16 or bf16
|
| 272 |
+
if q.dtype == torch.float32:
|
| 273 |
+
q = q.half()
|
| 274 |
+
k = k.half()
|
| 275 |
+
v = v.half()
|
| 276 |
+
|
| 277 |
+
out = flash_attn_func(
|
| 278 |
+
q, k, v, dropout_p=self.attn_drop, causal=False
|
| 279 |
+
) # (B, N, num_heads * head_dim)
|
| 280 |
+
|
| 281 |
+
# If needed, cast back to float32
|
| 282 |
+
out = out.reshape(B, N, -1)
|
| 283 |
+
out = out.to(x.dtype)
|
| 284 |
+
# breakpoint()
|
| 285 |
+
out = self.proj(out)
|
| 286 |
+
out = self.proj_drop(out)
|
| 287 |
+
return out
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class Block(nn.Module):
|
| 291 |
+
|
| 292 |
+
def __init__(
|
| 293 |
+
self,
|
| 294 |
+
dim,
|
| 295 |
+
num_heads,
|
| 296 |
+
mlp_ratio=4.0,
|
| 297 |
+
qkv_bias=False,
|
| 298 |
+
qk_scale=None,
|
| 299 |
+
drop=0.0,
|
| 300 |
+
attn_drop=0.0,
|
| 301 |
+
drop_path=0.0,
|
| 302 |
+
act_layer=nn.GELU,
|
| 303 |
+
norm_layer=nn.LayerNorm,
|
| 304 |
+
attn_head_dim=None,
|
| 305 |
+
flash_attn=False,
|
| 306 |
+
):
|
| 307 |
+
super().__init__()
|
| 308 |
+
|
| 309 |
+
self.norm1 = norm_layer(dim)
|
| 310 |
+
if flash_attn:
|
| 311 |
+
self.attn = FlashAttention(
|
| 312 |
+
dim,
|
| 313 |
+
num_heads=num_heads,
|
| 314 |
+
qkv_bias=qkv_bias,
|
| 315 |
+
qk_scale=qk_scale,
|
| 316 |
+
attn_drop=attn_drop,
|
| 317 |
+
proj_drop=drop,
|
| 318 |
+
attn_head_dim=attn_head_dim,
|
| 319 |
+
)
|
| 320 |
+
else:
|
| 321 |
+
self.attn = Attention(
|
| 322 |
+
dim,
|
| 323 |
+
num_heads=num_heads,
|
| 324 |
+
qkv_bias=qkv_bias,
|
| 325 |
+
qk_scale=qk_scale,
|
| 326 |
+
attn_drop=attn_drop,
|
| 327 |
+
proj_drop=drop,
|
| 328 |
+
attn_head_dim=attn_head_dim,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
| 332 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 333 |
+
self.norm2 = norm_layer(dim)
|
| 334 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 335 |
+
self.mlp = Mlp(
|
| 336 |
+
in_features=dim,
|
| 337 |
+
hidden_features=mlp_hidden_dim,
|
| 338 |
+
act_layer=act_layer,
|
| 339 |
+
drop=drop,
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
def forward(self, x):
|
| 343 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
| 344 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 345 |
+
return x
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
class PatchEmbed(nn.Module):
|
| 349 |
+
"""Image to Patch Embedding"""
|
| 350 |
+
|
| 351 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
|
| 352 |
+
super().__init__()
|
| 353 |
+
img_size = to_2tuple(img_size)
|
| 354 |
+
patch_size = to_2tuple(patch_size)
|
| 355 |
+
num_patches = (
|
| 356 |
+
(img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio**2)
|
| 357 |
+
)
|
| 358 |
+
self.patch_shape = (
|
| 359 |
+
int(img_size[0] // patch_size[0] * ratio),
|
| 360 |
+
int(img_size[1] // patch_size[1] * ratio),
|
| 361 |
+
)
|
| 362 |
+
self.origin_patch_shape = (
|
| 363 |
+
int(img_size[0] // patch_size[0]),
|
| 364 |
+
int(img_size[1] // patch_size[1]),
|
| 365 |
+
)
|
| 366 |
+
self.img_size = img_size
|
| 367 |
+
self.patch_size = patch_size
|
| 368 |
+
self.num_patches = num_patches
|
| 369 |
+
|
| 370 |
+
self.proj = nn.Conv2d(
|
| 371 |
+
in_chans,
|
| 372 |
+
embed_dim,
|
| 373 |
+
kernel_size=patch_size,
|
| 374 |
+
stride=(patch_size[0] // ratio),
|
| 375 |
+
padding=4 + 2 * (ratio // 2 - 1),
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
def forward(self, x, **kwargs):
|
| 379 |
+
B, C, H, W = x.shape
|
| 380 |
+
x = self.proj(x)
|
| 381 |
+
Hp, Wp = x.shape[2], x.shape[3]
|
| 382 |
+
|
| 383 |
+
x = x.flatten(2).transpose(1, 2)
|
| 384 |
+
return x, (Hp, Wp)
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class PatchEmbedNoPadding(nn.Module):
|
| 388 |
+
"""Image to Patch Embedding"""
|
| 389 |
+
|
| 390 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
|
| 391 |
+
super().__init__()
|
| 392 |
+
img_size = to_2tuple(img_size)
|
| 393 |
+
patch_size = to_2tuple(patch_size)
|
| 394 |
+
num_patches = (
|
| 395 |
+
(img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio**2)
|
| 396 |
+
)
|
| 397 |
+
self.patch_shape = (
|
| 398 |
+
int(img_size[0] // patch_size[0] * ratio),
|
| 399 |
+
int(img_size[1] // patch_size[1] * ratio),
|
| 400 |
+
)
|
| 401 |
+
self.origin_patch_shape = (
|
| 402 |
+
int(img_size[0] // patch_size[0]),
|
| 403 |
+
int(img_size[1] // patch_size[1]),
|
| 404 |
+
)
|
| 405 |
+
self.img_size = img_size
|
| 406 |
+
self.patch_size = patch_size
|
| 407 |
+
self.num_patches = num_patches
|
| 408 |
+
|
| 409 |
+
self.proj = nn.Conv2d(
|
| 410 |
+
in_chans,
|
| 411 |
+
embed_dim,
|
| 412 |
+
kernel_size=patch_size,
|
| 413 |
+
stride=(patch_size[0] // ratio),
|
| 414 |
+
padding=0,
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
def forward(self, x, **kwargs):
|
| 418 |
+
B, C, H, W = x.shape
|
| 419 |
+
x = self.proj(x)
|
| 420 |
+
Hp, Wp = x.shape[2], x.shape[3]
|
| 421 |
+
|
| 422 |
+
x = x.flatten(2).transpose(1, 2)
|
| 423 |
+
return x, (Hp, Wp)
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
class HybridEmbed(nn.Module):
|
| 427 |
+
"""CNN Feature Map Embedding
|
| 428 |
+
Extract feature map from CNN, flatten, project to embedding dim.
|
| 429 |
+
"""
|
| 430 |
+
|
| 431 |
+
def __init__(
|
| 432 |
+
self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768
|
| 433 |
+
):
|
| 434 |
+
super().__init__()
|
| 435 |
+
assert isinstance(backbone, nn.Module)
|
| 436 |
+
img_size = to_2tuple(img_size)
|
| 437 |
+
self.img_size = img_size
|
| 438 |
+
self.backbone = backbone
|
| 439 |
+
if feature_size is None:
|
| 440 |
+
with torch.no_grad():
|
| 441 |
+
training = backbone.training
|
| 442 |
+
if training:
|
| 443 |
+
backbone.eval()
|
| 444 |
+
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[
|
| 445 |
+
-1
|
| 446 |
+
]
|
| 447 |
+
feature_size = o.shape[-2:]
|
| 448 |
+
feature_dim = o.shape[1]
|
| 449 |
+
backbone.train(training)
|
| 450 |
+
else:
|
| 451 |
+
feature_size = to_2tuple(feature_size)
|
| 452 |
+
feature_dim = self.backbone.feature_info.channels()[-1]
|
| 453 |
+
self.num_patches = feature_size[0] * feature_size[1]
|
| 454 |
+
self.proj = nn.Linear(feature_dim, embed_dim)
|
| 455 |
+
|
| 456 |
+
def forward(self, x):
|
| 457 |
+
x = self.backbone(x)[-1]
|
| 458 |
+
x = x.flatten(2).transpose(1, 2)
|
| 459 |
+
x = self.proj(x)
|
| 460 |
+
return x
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
class ViT(nn.Module):
|
| 464 |
+
|
| 465 |
+
def __init__(
|
| 466 |
+
self,
|
| 467 |
+
img_size=224,
|
| 468 |
+
patch_size=16,
|
| 469 |
+
in_chans=3,
|
| 470 |
+
num_classes=80,
|
| 471 |
+
embed_dim=768,
|
| 472 |
+
depth=12,
|
| 473 |
+
num_heads=12,
|
| 474 |
+
mlp_ratio=4.0,
|
| 475 |
+
qkv_bias=False,
|
| 476 |
+
qk_scale=None,
|
| 477 |
+
drop_rate=0.0,
|
| 478 |
+
attn_drop_rate=0.0,
|
| 479 |
+
drop_path_rate=0.0,
|
| 480 |
+
hybrid_backbone=None,
|
| 481 |
+
norm_layer=None,
|
| 482 |
+
use_checkpoint=False,
|
| 483 |
+
frozen_stages=-1,
|
| 484 |
+
ratio=1,
|
| 485 |
+
last_norm=True,
|
| 486 |
+
patch_padding="pad",
|
| 487 |
+
freeze_attn=False,
|
| 488 |
+
freeze_ffn=False,
|
| 489 |
+
flash_attn=False,
|
| 490 |
+
no_patch_padding=False,
|
| 491 |
+
):
|
| 492 |
+
# Protect mutable default arguments
|
| 493 |
+
super(ViT, self).__init__()
|
| 494 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
| 495 |
+
self.num_classes = num_classes
|
| 496 |
+
self.num_features = self.embed_dim = self.embed_dims = (
|
| 497 |
+
embed_dim # num_features for consistency with other models
|
| 498 |
+
)
|
| 499 |
+
self.frozen_stages = frozen_stages
|
| 500 |
+
self.use_checkpoint = use_checkpoint
|
| 501 |
+
self.patch_padding = patch_padding
|
| 502 |
+
self.freeze_attn = freeze_attn
|
| 503 |
+
self.freeze_ffn = freeze_ffn
|
| 504 |
+
self.depth = depth
|
| 505 |
+
|
| 506 |
+
if hybrid_backbone is not None:
|
| 507 |
+
self.patch_embed = HybridEmbed(
|
| 508 |
+
hybrid_backbone,
|
| 509 |
+
img_size=img_size,
|
| 510 |
+
in_chans=in_chans,
|
| 511 |
+
embed_dim=embed_dim,
|
| 512 |
+
)
|
| 513 |
+
else:
|
| 514 |
+
if no_patch_padding:
|
| 515 |
+
self.patch_embed = PatchEmbedNoPadding(
|
| 516 |
+
img_size=img_size,
|
| 517 |
+
patch_size=patch_size,
|
| 518 |
+
in_chans=in_chans,
|
| 519 |
+
embed_dim=embed_dim,
|
| 520 |
+
ratio=ratio,
|
| 521 |
+
)
|
| 522 |
+
else:
|
| 523 |
+
self.patch_embed = PatchEmbed(
|
| 524 |
+
img_size=img_size,
|
| 525 |
+
patch_size=patch_size,
|
| 526 |
+
in_chans=in_chans,
|
| 527 |
+
embed_dim=embed_dim,
|
| 528 |
+
ratio=ratio,
|
| 529 |
+
)
|
| 530 |
+
num_patches = self.patch_embed.num_patches
|
| 531 |
+
self.patch_size = patch_size
|
| 532 |
+
|
| 533 |
+
# since the pretraining model has class token
|
| 534 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
| 535 |
+
|
| 536 |
+
dpr = [
|
| 537 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
| 538 |
+
] # stochastic depth decay rule
|
| 539 |
+
|
| 540 |
+
self.blocks = nn.ModuleList(
|
| 541 |
+
[
|
| 542 |
+
Block(
|
| 543 |
+
dim=embed_dim,
|
| 544 |
+
num_heads=num_heads,
|
| 545 |
+
mlp_ratio=mlp_ratio,
|
| 546 |
+
qkv_bias=qkv_bias,
|
| 547 |
+
qk_scale=qk_scale,
|
| 548 |
+
drop=drop_rate,
|
| 549 |
+
attn_drop=attn_drop_rate,
|
| 550 |
+
drop_path=dpr[i],
|
| 551 |
+
norm_layer=norm_layer,
|
| 552 |
+
flash_attn=flash_attn,
|
| 553 |
+
)
|
| 554 |
+
for i in range(depth)
|
| 555 |
+
]
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
|
| 559 |
+
|
| 560 |
+
if self.pos_embed is not None:
|
| 561 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
| 562 |
+
|
| 563 |
+
self._freeze_stages()
|
| 564 |
+
|
| 565 |
+
def _freeze_stages(self):
|
| 566 |
+
"""Freeze parameters."""
|
| 567 |
+
if self.frozen_stages >= 0:
|
| 568 |
+
self.patch_embed.eval()
|
| 569 |
+
for param in self.patch_embed.parameters():
|
| 570 |
+
param.requires_grad = False
|
| 571 |
+
|
| 572 |
+
for i in range(1, self.frozen_stages + 1):
|
| 573 |
+
m = self.blocks[i - 1]
|
| 574 |
+
m.eval()
|
| 575 |
+
for param in m.parameters():
|
| 576 |
+
param.requires_grad = False
|
| 577 |
+
|
| 578 |
+
if self.freeze_attn:
|
| 579 |
+
for i in range(0, self.depth):
|
| 580 |
+
m = self.blocks[i]
|
| 581 |
+
m.attn.eval()
|
| 582 |
+
m.norm1.eval()
|
| 583 |
+
for param in m.attn.parameters():
|
| 584 |
+
param.requires_grad = False
|
| 585 |
+
for param in m.norm1.parameters():
|
| 586 |
+
param.requires_grad = False
|
| 587 |
+
|
| 588 |
+
if self.freeze_ffn:
|
| 589 |
+
self.pos_embed.requires_grad = False
|
| 590 |
+
self.patch_embed.eval()
|
| 591 |
+
for param in self.patch_embed.parameters():
|
| 592 |
+
param.requires_grad = False
|
| 593 |
+
for i in range(0, self.depth):
|
| 594 |
+
m = self.blocks[i]
|
| 595 |
+
m.mlp.eval()
|
| 596 |
+
m.norm2.eval()
|
| 597 |
+
for param in m.mlp.parameters():
|
| 598 |
+
param.requires_grad = False
|
| 599 |
+
for param in m.norm2.parameters():
|
| 600 |
+
param.requires_grad = False
|
| 601 |
+
|
| 602 |
+
def init_weights(self):
|
| 603 |
+
"""Initialize the weights in backbone.
|
| 604 |
+
Args:
|
| 605 |
+
pretrained (str, optional): Path to pre-trained weights.
|
| 606 |
+
Defaults to None.
|
| 607 |
+
"""
|
| 608 |
+
|
| 609 |
+
def _init_weights(m):
|
| 610 |
+
if isinstance(m, nn.Linear):
|
| 611 |
+
trunc_normal_(m.weight, std=0.02)
|
| 612 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 613 |
+
nn.init.constant_(m.bias, 0)
|
| 614 |
+
elif isinstance(m, nn.LayerNorm):
|
| 615 |
+
nn.init.constant_(m.bias, 0)
|
| 616 |
+
nn.init.constant_(m.weight, 1.0)
|
| 617 |
+
|
| 618 |
+
self.apply(_init_weights)
|
| 619 |
+
|
| 620 |
+
def get_num_layers(self):
|
| 621 |
+
return len(self.blocks)
|
| 622 |
+
|
| 623 |
+
@torch.jit.ignore
|
| 624 |
+
def no_weight_decay(self):
|
| 625 |
+
return {"pos_embed", "cls_token"}
|
| 626 |
+
|
| 627 |
+
def forward_features(self, x, extra_embed=None):
|
| 628 |
+
B, C, H, W = x.shape
|
| 629 |
+
x, (Hp, Wp) = self.patch_embed(x)
|
| 630 |
+
|
| 631 |
+
if self.pos_embed is not None:
|
| 632 |
+
# fit for multiple GPU training
|
| 633 |
+
# since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
|
| 634 |
+
x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
|
| 635 |
+
|
| 636 |
+
if extra_embed is not None:
|
| 637 |
+
x = x + extra_embed.flatten(2).transpose(1, 2).to(x)
|
| 638 |
+
|
| 639 |
+
for blk in self.blocks:
|
| 640 |
+
if self.use_checkpoint:
|
| 641 |
+
x = checkpoint.checkpoint(blk, x)
|
| 642 |
+
else:
|
| 643 |
+
x = blk(x)
|
| 644 |
+
|
| 645 |
+
x = self.last_norm(x)
|
| 646 |
+
|
| 647 |
+
xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
|
| 648 |
+
|
| 649 |
+
return xp
|
| 650 |
+
|
| 651 |
+
def forward(self, x, *args, **kwargs):
|
| 652 |
+
x = self.forward_features(x, *args, **kwargs)
|
| 653 |
+
return x
|
| 654 |
+
|
| 655 |
+
def train(self, mode=True):
|
| 656 |
+
"""Convert the model into training mode."""
|
| 657 |
+
super().train(mode)
|
| 658 |
+
self._freeze_stages()
|
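As a quick sanity check on the ViT variants defined above, the sketch below (not part of the commit) runs a dummy forward pass through `vit_b`. It assumes the repo's dependencies (timm, optionally flash-attn) are installed; the config holds only the two keys `vit_b` reads, and the rest is an assumption.

# Hypothetical usage sketch -- not part of the commit.
import torch
from yacs.config import CfgNode as CN

from sam3d_body.models.backbones.vit import vit_b

cfg = CN({"MODEL": CN({"BACKBONE": CN({"FROZEN_STAGES": -1, "FLASH_ATTN": False})})})
model = vit_b(cfg).eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 256, 192))   # [B, 3, H, W] crop

# Patch size 16 on a 256x192 crop gives a 16x12 token grid, returned as a
# [B, 768, 16, 12] feature map by forward_features().
print(feats.shape)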
src/sam3d_body/models/decoders/__init__.py
ADDED
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from .keypoint_prompt_sampler import build_keypoint_sampler
+from .prompt_encoder import PromptEncoder
+from .promptable_decoder import PromptableDecoder
+
+
+def build_decoder(cfg, context_dim=None):
+    from .promptable_decoder import PromptableDecoder
+
+    if cfg.TYPE == "sam":
+        return PromptableDecoder(
+            dims=cfg.DIM,
+            context_dims=context_dim,
+            depth=cfg.DEPTH,
+            num_heads=cfg.HEADS,
+            head_dims=cfg.DIM_HEAD,
+            mlp_dims=cfg.MLP_DIM,
+            layer_scale_init_value=cfg.LAYER_SCALE_INIT,
+            drop_rate=cfg.DROP_RATE,
+            attn_drop_rate=cfg.ATTN_DROP_RATE,
+            drop_path_rate=cfg.DROP_PATH_RATE,
+            ffn_type=cfg.FFN_TYPE,
+            enable_twoway=cfg.ENABLE_TWOWAY,
+            repeat_pe=cfg.REPEAT_PE,
+            frozen=cfg.get("FROZEN", False),
+            do_interm_preds=cfg.get("DO_INTERM_PREDS", False),
+            do_keypoint_tokens=cfg.get("DO_KEYPOINT_TOKENS", False),
+            keypoint_token_update=cfg.get("KEYPOINT_TOKEN_UPDATE", None),
+        )
+    else:
+        raise ValueError("Invalid decoder type: ", cfg.TYPE)
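For reference, the sketch below (not part of the commit) spells out the DECODER config keys consumed by `build_decoder`. Only the key names come from the diff; every concrete value and the commented call are assumptions.

# Hypothetical config sketch -- not part of the commit.
from yacs.config import CfgNode as CN

decoder_cfg = CN()
decoder_cfg.TYPE = "sam"
decoder_cfg.DIM = 512
decoder_cfg.DEPTH = 6
decoder_cfg.HEADS = 8
decoder_cfg.DIM_HEAD = 64
decoder_cfg.MLP_DIM = 2048
decoder_cfg.LAYER_SCALE_INIT = 0.0
decoder_cfg.DROP_RATE = 0.0
decoder_cfg.ATTN_DROP_RATE = 0.0
decoder_cfg.DROP_PATH_RATE = 0.0
decoder_cfg.FFN_TYPE = "origin"
decoder_cfg.ENABLE_TWOWAY = True
decoder_cfg.REPEAT_PE = True

# decoder = build_decoder(decoder_cfg, context_dim=768)  # 768 = backbone embed dim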
src/sam3d_body/models/decoders/keypoint_prompt_sampler.py
ADDED
@@ -0,0 +1,183 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
import random
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from typing import Dict, List
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from omegaconf import DictConfig
|
| 10 |
+
from yacs.config import CfgNode
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def build_keypoint_sampler(sampler_cfg, prompt_keypoints, keybody_idx):
|
| 14 |
+
sampler_type = sampler_cfg.get("TYPE", "v1")
|
| 15 |
+
if sampler_type == "v1":
|
| 16 |
+
sampler_cls = KeypointSamplerV1
|
| 17 |
+
else:
|
| 18 |
+
raise ValueError("Invalid sampler type: ", sampler_type)
|
| 19 |
+
|
| 20 |
+
return sampler_cls(sampler_cfg, prompt_keypoints, keybody_idx)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class BaseKeypointSampler(ABC):
|
| 24 |
+
@abstractmethod
|
| 25 |
+
def sample(
|
| 26 |
+
self, gt_keypoints: torch.Tensor, pred_keypoints: torch.Tensor, is_train: bool
|
| 27 |
+
) -> torch.Tensor:
|
| 28 |
+
pass
|
| 29 |
+
|
| 30 |
+
def _get_worst_keypoint(self, distances, keypoint_list):
|
| 31 |
+
# Set distance to -1 for non-promptable keypoints
|
| 32 |
+
cur_dist = torch.ones_like(distances) * -1
|
| 33 |
+
cur_dist[keypoint_list] = distances[keypoint_list]
|
| 34 |
+
keypoint_idx = int(cur_dist.argmax())
|
| 35 |
+
if cur_dist[keypoint_idx] > self.distance_thresh:
|
| 36 |
+
valid_keypoint = True
|
| 37 |
+
else:
|
| 38 |
+
valid_keypoint = False
|
| 39 |
+
return keypoint_idx, valid_keypoint
|
| 40 |
+
|
| 41 |
+
def _get_random_keypoint(self, distances, keypoint_list):
|
| 42 |
+
candidates = [idx for idx in keypoint_list if distances[idx] > 0]
|
| 43 |
+
if len(candidates):
|
| 44 |
+
keypoint_idx = random.choice(candidates)
|
| 45 |
+
valid_keypoint = True
|
| 46 |
+
else:
|
| 47 |
+
keypoint_idx = None
|
| 48 |
+
valid_keypoint = False
|
| 49 |
+
return keypoint_idx, valid_keypoint
|
| 50 |
+
|
| 51 |
+
def _masked_distance(self, x, y, mask=None):
|
| 52 |
+
"""
|
| 53 |
+
Args:
|
| 54 |
+
x, y: [B, K, D]
|
| 55 |
+
mask: [B, K]
|
| 56 |
+
Return:
|
| 57 |
+
distances: [K, B]
|
| 58 |
+
"""
|
| 59 |
+
distances = (x - y).pow(2).sum(dim=-1)
|
| 60 |
+
if mask is not None:
|
| 61 |
+
distances[mask] = -1
|
| 62 |
+
return distances.T
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class KeypointSamplerV1(BaseKeypointSampler):
|
| 66 |
+
def __init__(
|
| 67 |
+
self,
|
| 68 |
+
sampler_cfg: DictConfig | CfgNode,
|
| 69 |
+
prompt_keypoints: Dict,
|
| 70 |
+
keybody_idx: List,
|
| 71 |
+
):
|
| 72 |
+
self.prompt_keypoints = prompt_keypoints
|
| 73 |
+
self._keybody_idx = keybody_idx
|
| 74 |
+
self._non_keybody_idx = [
|
| 75 |
+
idx for idx in self.prompt_keypoints if idx not in self._keybody_idx
|
| 76 |
+
]
|
| 77 |
+
|
| 78 |
+
self.keybody_ratio = sampler_cfg.get("KEYBODY_RATIO", 0.8)
|
| 79 |
+
self.worst_ratio = sampler_cfg.get("WORST_RATIO", 0.8)
|
| 80 |
+
self.negative_ratio = sampler_cfg.get("NEGATIVE_RATIO", 0.0)
|
| 81 |
+
self.dummy_ratio = sampler_cfg.get("DUMMY_RATIO", 0.1)
|
| 82 |
+
self.distance_thresh = sampler_cfg.get("DISTANCE_THRESH", 0.0)
|
| 83 |
+
|
| 84 |
+
def sample(
|
| 85 |
+
self,
|
| 86 |
+
gt_keypoints_2d: torch.Tensor,
|
| 87 |
+
pred_keypoints_2d: torch.Tensor,
|
| 88 |
+
is_train: bool = True,
|
| 89 |
+
force_dummy: bool = False,
|
| 90 |
+
) -> torch.Tensor:
|
| 91 |
+
# Get the distance between each predicted and gt keypoint
|
| 92 |
+
# Elements will be ignored if (1) the gt has low confidence or
|
| 93 |
+
# (2) both the gt and pred are outside of the image
|
| 94 |
+
mask_1 = gt_keypoints_2d[:, :, -1] < 0.5
|
| 95 |
+
mask_2 = (
|
| 96 |
+
(gt_keypoints_2d[:, :, :2] > 0.5) | (gt_keypoints_2d[:, :, :2] < -0.5)
|
| 97 |
+
).any(dim=-1)
|
| 98 |
+
|
| 99 |
+
# Elements to be ignored
|
| 100 |
+
if not is_train or torch.rand(1).item() > self.negative_ratio:
|
| 101 |
+
mask = mask_1 | mask_2
|
| 102 |
+
# print_base = "positive"
|
| 103 |
+
else:
|
| 104 |
+
mask_3 = (
|
| 105 |
+
(pred_keypoints_2d[:, :, :2] > 0.5)
|
| 106 |
+
| (pred_keypoints_2d[:, :, :2] < -0.5)
|
| 107 |
+
).any(dim=-1)
|
| 108 |
+
# To include negative prompts
|
| 109 |
+
mask = mask_1 | (mask_2 & mask_3)
|
| 110 |
+
# print_base = "negative"
|
| 111 |
+
|
| 112 |
+
# Get pairwise distances with shape [K, B]
|
| 113 |
+
distances = self._masked_distance(
|
| 114 |
+
pred_keypoints_2d, gt_keypoints_2d[..., :2], mask
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
batch_size = distances.shape[1]
|
| 118 |
+
keypoints_prompt = []
|
| 119 |
+
for b in range(batch_size):
|
| 120 |
+
# print_str = print_base
|
| 121 |
+
|
| 122 |
+
# Decide to get the worst keypoint or a random keypoint
|
| 123 |
+
if not is_train or torch.rand(1).item() < self.worst_ratio:
|
| 124 |
+
sampler = self._get_worst_keypoint
|
| 125 |
+
# print_str += "_worst"
|
| 126 |
+
else:
|
| 127 |
+
sampler = self._get_random_keypoint
|
| 128 |
+
# print_str += "_random"
|
| 129 |
+
|
| 130 |
+
# Decide to prompt keybody kepoints or non-keybody ones
|
| 131 |
+
if not is_train or torch.rand(1).item() < self.keybody_ratio:
|
| 132 |
+
cur_idx = self._keybody_idx
|
| 133 |
+
alt_idx = self._non_keybody_idx
|
| 134 |
+
# print_str += "_keybody"
|
| 135 |
+
else:
|
| 136 |
+
cur_idx = self._non_keybody_idx
|
| 137 |
+
alt_idx = self._keybody_idx
|
| 138 |
+
# print_str += "_nonkey"
|
| 139 |
+
|
| 140 |
+
# Get a valid or dummy prompt
|
| 141 |
+
if not is_train or torch.rand(1).item() > self.dummy_ratio:
|
| 142 |
+
keypoint_idx, valid_keypoint = sampler(distances[:, b], cur_idx)
|
| 143 |
+
|
| 144 |
+
if not valid_keypoint:
|
| 145 |
+
# Try the alternative keypoints
|
| 146 |
+
keypoint_idx, valid_keypoint = self._get_worst_keypoint(
|
| 147 |
+
distances[:, b], alt_idx
|
| 148 |
+
)
|
| 149 |
+
else:
|
| 150 |
+
valid_keypoint = False
|
| 151 |
+
|
| 152 |
+
if valid_keypoint:
|
| 153 |
+
cur_point = gt_keypoints_2d[b, keypoint_idx].clone()
|
| 154 |
+
if torch.any(cur_point[:2] > 0.5) or torch.any(cur_point[:2] < -0.5):
|
| 155 |
+
# Negative prompt --> indicating the predicted keypoint is incorrect
|
| 156 |
+
cur_point[:2] = pred_keypoints_2d[b, keypoint_idx][:2]
|
| 157 |
+
cur_point = torch.clamp(
|
| 158 |
+
cur_point + 0.5, min=0.0, max=1.0
|
| 159 |
+
) # shift from [-0.5, 0.5] to [0, 1]
|
| 160 |
+
cur_point[-1] = -1
|
| 161 |
+
# print_str += "_negative"
|
| 162 |
+
else:
|
| 163 |
+
cur_point = torch.clamp(
|
| 164 |
+
cur_point + 0.5, min=0.0, max=1.0
|
| 165 |
+
) # shift from [-0.5, 0.5] to [0, 1]
|
| 166 |
+
cur_point[-1] = self.prompt_keypoints[
|
| 167 |
+
keypoint_idx
|
| 168 |
+
] # map to prompt_idx
|
| 169 |
+
# print_str += "_positive"
|
| 170 |
+
else:
|
| 171 |
+
cur_point = torch.zeros(3).to(gt_keypoints_2d)
|
| 172 |
+
cur_point[-1] = -2
|
| 173 |
+
# print_str += "_dummy"
|
| 174 |
+
|
| 175 |
+
if force_dummy:
|
| 176 |
+
cur_point = torch.zeros(3).to(gt_keypoints_2d)
|
| 177 |
+
cur_point[-1] = -2
|
| 178 |
+
|
| 179 |
+
keypoints_prompt.append(cur_point)
|
| 180 |
+
# print(print_str)
|
| 181 |
+
|
| 182 |
+
keypoints_prompt = torch.stack(keypoints_prompt, dim=0).view(batch_size, 1, 3)
|
| 183 |
+
return keypoints_prompt
|
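The sampler above picks one keypoint prompt per sample (worst-offending, random, negative, or dummy, depending on the ratios). A minimal sketch of one training-time sampling round follows; the identity `prompt_keypoints` mapping, the `keybody_idx` subset, and the random tensors are assumptions, not part of the commit.

# Hypothetical usage sketch -- not part of the commit.
import torch
from yacs.config import CfgNode as CN

from sam3d_body.models.decoders import build_keypoint_sampler

prompt_keypoints = {i: i for i in range(70)}   # hypothetical keypoint-to-prompt mapping
keybody_idx = list(range(17))                  # hypothetical "key body" joint subset

sampler = build_keypoint_sampler(CN({"TYPE": "v1"}), prompt_keypoints, keybody_idx)

gt = torch.rand(4, 70, 3) - 0.5     # [B, K, (x, y, conf)], coords in [-0.5, 0.5]
gt[..., 2] = 1.0                    # mark all ground-truth keypoints as confident
pred = torch.rand(4, 70, 2) - 0.5   # current predictions in the same normalized frame

prompts = sampler.sample(gt, pred, is_train=True)   # [B, 1, 3]
# prompts[..., :2] are coordinates shifted into [0, 1]; prompts[..., 2] is the label
# (-2 = dummy prompt, -1 = negative prompt, otherwise the prompt index).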
src/sam3d_body/models/decoders/prompt_encoder.py
ADDED
@@ -0,0 +1,256 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
from typing import Any, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
from sam3d_body.models.modules.transformer import LayerNorm2d
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class PromptEncoder(nn.Module):
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
embed_dim: int,
|
| 17 |
+
num_body_joints: int,
|
| 18 |
+
# img_size: Tuple[int, int],
|
| 19 |
+
# patch_resolution: Tuple[int, int],
|
| 20 |
+
frozen: bool = False,
|
| 21 |
+
mask_embed_type: Optional[str] = None,
|
| 22 |
+
) -> None:
|
| 23 |
+
"""
|
| 24 |
+
Encodes prompts for input to SAM's mask decoder.
|
| 25 |
+
|
| 26 |
+
Arguments:
|
| 27 |
+
embed_dim (int): The prompts' embedding dimension
|
| 28 |
+
num_body_joints (int): The number of body joints
|
| 29 |
+
img_size (Tuple): The padded size of the image as input
|
| 30 |
+
to the image encoder, as (H, W).
|
| 31 |
+
patch_resolution (Tuple): image path size, as (H, W)
|
| 32 |
+
"""
|
| 33 |
+
super().__init__()
|
| 34 |
+
self.embed_dim = embed_dim
|
| 35 |
+
self.num_body_joints = num_body_joints
|
| 36 |
+
# self.img_size = img_size
|
| 37 |
+
# self.patch_resolution = patch_resolution
|
| 38 |
+
|
| 39 |
+
# Keypoint prompts
|
| 40 |
+
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
| 41 |
+
self.point_embeddings = nn.ModuleList(
|
| 42 |
+
[nn.Embedding(1, embed_dim) for _ in range(self.num_body_joints)]
|
| 43 |
+
)
|
| 44 |
+
self.not_a_point_embed = nn.Embedding(1, embed_dim)
|
| 45 |
+
self.invalid_point_embed = nn.Embedding(1, embed_dim)
|
| 46 |
+
|
| 47 |
+
# Mask prompt
|
| 48 |
+
if mask_embed_type in ["v1"]:
|
| 49 |
+
mask_in_chans = 16 # SAM2
|
| 50 |
+
self.mask_downscaling = nn.Sequential(
|
| 51 |
+
nn.Conv2d(1, mask_in_chans // 4, kernel_size=4, stride=4),
|
| 52 |
+
LayerNorm2d(mask_in_chans // 4),
|
| 53 |
+
nn.GELU(),
|
| 54 |
+
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=4, stride=4),
|
| 55 |
+
LayerNorm2d(mask_in_chans),
|
| 56 |
+
nn.GELU(),
|
| 57 |
+
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
|
| 58 |
+
)
|
| 59 |
+
elif mask_embed_type in ["v2"]:
|
| 60 |
+
mask_in_chans = 256
|
| 61 |
+
self.mask_downscaling = nn.Sequential(
|
| 62 |
+
nn.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2),
|
| 63 |
+
LayerNorm2d(mask_in_chans // 64),
|
| 64 |
+
nn.GELU(),
|
| 65 |
+
nn.Conv2d(
|
| 66 |
+
mask_in_chans // 64,
|
| 67 |
+
mask_in_chans // 16,
|
| 68 |
+
kernel_size=2,
|
| 69 |
+
stride=2,
|
| 70 |
+
),
|
| 71 |
+
LayerNorm2d(mask_in_chans // 16),
|
| 72 |
+
nn.GELU(),
|
| 73 |
+
nn.Conv2d(
|
| 74 |
+
mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2
|
| 75 |
+
),
|
| 76 |
+
LayerNorm2d(mask_in_chans // 4),
|
| 77 |
+
nn.GELU(),
|
| 78 |
+
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
|
| 79 |
+
LayerNorm2d(mask_in_chans),
|
| 80 |
+
nn.GELU(),
|
| 81 |
+
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
|
| 82 |
+
)
|
| 83 |
+
else:
|
| 84 |
+
assert mask_embed_type is None
|
| 85 |
+
|
| 86 |
+
if mask_embed_type is not None:
|
| 87 |
+
# Zero-initialize the last conv layer as gating
|
| 88 |
+
nn.init.zeros_(self.mask_downscaling[-1].weight)
|
| 89 |
+
nn.init.zeros_(self.mask_downscaling[-1].bias)
|
| 90 |
+
|
| 91 |
+
self.no_mask_embed = nn.Embedding(1, embed_dim)
|
| 92 |
+
nn.init.zeros_(self.no_mask_embed.weight)
|
| 93 |
+
|
| 94 |
+
self.frozen = frozen
|
| 95 |
+
self._freeze_stages()
|
| 96 |
+
|
| 97 |
+
def get_dense_pe(self, size: Tuple[int, int]) -> torch.Tensor:
|
| 98 |
+
"""
|
| 99 |
+
Returns the positional encoding used to encode point prompts,
|
| 100 |
+
applied to a dense set of points the shape of the image encoding.
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
torch.Tensor: Positional encoding with shape
|
| 104 |
+
1x(embed_dim)x(embedding_h)x(embedding_w)
|
| 105 |
+
"""
|
| 106 |
+
return self.pe_layer(size).unsqueeze(0)
|
| 107 |
+
|
| 108 |
+
def _embed_keypoints(
|
| 109 |
+
self,
|
| 110 |
+
points: torch.Tensor,
|
| 111 |
+
labels: torch.Tensor,
|
| 112 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 113 |
+
"""
|
| 114 |
+
Embeds point prompts.
|
| 115 |
+
Assuming points have been normalized to [0, 1].
|
| 116 |
+
|
| 117 |
+
Output shape [B, N, C], mask shape [B, N]
|
| 118 |
+
"""
|
| 119 |
+
assert points.min() >= 0 and points.max() <= 1
|
| 120 |
+
point_embedding = self.pe_layer._pe_encoding(points.to(torch.float))
|
| 121 |
+
point_embedding[labels == -2] = 0.0 # invalid points
|
| 122 |
+
point_embedding[labels == -2] += self.invalid_point_embed.weight
|
| 123 |
+
point_embedding[labels == -1] = 0.0
|
| 124 |
+
point_embedding[labels == -1] += self.not_a_point_embed.weight
|
| 125 |
+
for i in range(self.num_body_joints):
|
| 126 |
+
point_embedding[labels == i] += self.point_embeddings[i].weight
|
| 127 |
+
|
| 128 |
+
point_mask = labels > -2
|
| 129 |
+
return point_embedding, point_mask
|
| 130 |
+
|
| 131 |
+
def _get_batch_size(
|
| 132 |
+
self,
|
| 133 |
+
keypoints: Optional[torch.Tensor],
|
| 134 |
+
boxes: Optional[torch.Tensor],
|
| 135 |
+
masks: Optional[torch.Tensor],
|
| 136 |
+
) -> int:
|
| 137 |
+
"""
|
| 138 |
+
Gets the batch size of the output given the batch size of the input prompts.
|
| 139 |
+
"""
|
| 140 |
+
if keypoints is not None:
|
| 141 |
+
return keypoints.shape[0]
|
| 142 |
+
elif boxes is not None:
|
| 143 |
+
return boxes.shape[0]
|
| 144 |
+
elif masks is not None:
|
| 145 |
+
return masks.shape[0]
|
| 146 |
+
else:
|
| 147 |
+
return 1
|
| 148 |
+
|
| 149 |
+
def _get_device(self) -> torch.device:
|
| 150 |
+
return self.point_embeddings[0].weight.device
|
| 151 |
+
|
| 152 |
+
def forward(
|
| 153 |
+
self,
|
| 154 |
+
keypoints: Optional[torch.Tensor],
|
| 155 |
+
boxes: Optional[torch.Tensor] = None,
|
| 156 |
+
masks: Optional[torch.Tensor] = None,
|
| 157 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 158 |
+
"""
|
| 159 |
+
Embeds different types of prompts, returning both sparse and dense
|
| 160 |
+
embeddings.
|
| 161 |
+
|
| 162 |
+
Arguments:
|
| 163 |
+
keypoints (torchTensor or none): point coordinates and labels to embed.
|
| 164 |
+
boxes (torch.Tensor or none): boxes to embed
|
| 165 |
+
masks (torch.Tensor or none): masks to embed
|
| 166 |
+
|
| 167 |
+
Returns:
|
| 168 |
+
torch.Tensor: sparse embeddings for the points and boxes, with shape
|
| 169 |
+
BxNx(embed_dim), where N is determined by the number of input points
|
| 170 |
+
and boxes.
|
| 171 |
+
torch.Tensor: dense embeddings for the masks, in the shape
|
| 172 |
+
Bx(embed_dim)x(embed_H)x(embed_W)
|
| 173 |
+
"""
|
| 174 |
+
bs = self._get_batch_size(keypoints, boxes, masks)
|
| 175 |
+
sparse_embeddings = torch.empty(
|
| 176 |
+
(bs, 0, self.embed_dim), device=self._get_device()
|
| 177 |
+
)
|
| 178 |
+
sparse_masks = torch.empty((bs, 0), device=self._get_device())
|
| 179 |
+
if keypoints is not None:
|
| 180 |
+
coords = keypoints[:, :, :2]
|
| 181 |
+
labels = keypoints[:, :, -1]
|
| 182 |
+
point_embeddings, point_mask = self._embed_keypoints(
|
| 183 |
+
coords, labels
|
| 184 |
+
) # pad=(boxes is None))
|
| 185 |
+
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
|
| 186 |
+
sparse_masks = torch.cat([sparse_masks, point_mask], dim=1)
|
| 187 |
+
|
| 188 |
+
return sparse_embeddings, sparse_masks
|
| 189 |
+
|
| 190 |
+
def get_mask_embeddings(
|
| 191 |
+
self,
|
| 192 |
+
masks: Optional[torch.Tensor] = None,
|
| 193 |
+
bs: int = 1,
|
| 194 |
+
size: Tuple[int, int] = (16, 16), # [H, W]
|
| 195 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 196 |
+
"""Embeds mask inputs."""
|
| 197 |
+
no_mask_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
|
| 198 |
+
bs, -1, size[0], size[1]
|
| 199 |
+
)
|
| 200 |
+
if masks is not None:
|
| 201 |
+
mask_embeddings = self.mask_downscaling(masks)
|
| 202 |
+
else:
|
| 203 |
+
mask_embeddings = no_mask_embeddings
|
| 204 |
+
return mask_embeddings, no_mask_embeddings
|
| 205 |
+
|
| 206 |
+
def _freeze_stages(self):
|
| 207 |
+
"""Freeze parameters."""
|
| 208 |
+
if self.frozen:
|
| 209 |
+
for param in self.parameters():
|
| 210 |
+
param.requires_grad = False
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class PositionEmbeddingRandom(nn.Module):
|
| 214 |
+
"""
|
| 215 |
+
Positional encoding using random spatial frequencies.
|
| 216 |
+
"""
|
| 217 |
+
|
| 218 |
+
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
|
| 219 |
+
super().__init__()
|
| 220 |
+
if scale is None or scale <= 0.0:
|
| 221 |
+
scale = 1.0
|
| 222 |
+
self.register_buffer(
|
| 223 |
+
"positional_encoding_gaussian_matrix",
|
| 224 |
+
scale * torch.randn((2, num_pos_feats)),
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
| 228 |
+
"""Positionally encode points that are normalized to [0,1]."""
|
| 229 |
+
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
| 230 |
+
coords = 2 * coords - 1
|
| 231 |
+
coords = coords @ self.positional_encoding_gaussian_matrix
|
| 232 |
+
coords = 2 * np.pi * coords
|
| 233 |
+
# outputs d_1 x ... x d_n x C shape
|
| 234 |
+
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
|
| 235 |
+
|
| 236 |
+
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
|
| 237 |
+
"""Generate positional encoding for a grid of the specified size."""
|
| 238 |
+
h, w = size
|
| 239 |
+
device: Any = self.positional_encoding_gaussian_matrix.device
|
| 240 |
+
grid = torch.ones((h, w), device=device, dtype=torch.float32)
|
| 241 |
+
y_embed = grid.cumsum(dim=0) - 0.5
|
| 242 |
+
x_embed = grid.cumsum(dim=1) - 0.5
|
| 243 |
+
y_embed = y_embed / h
|
| 244 |
+
x_embed = x_embed / w
|
| 245 |
+
|
| 246 |
+
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
|
| 247 |
+
return pe.permute(2, 0, 1) # C x H x W
|
| 248 |
+
|
| 249 |
+
def forward_with_coords(
|
| 250 |
+
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
|
| 251 |
+
) -> torch.Tensor:
|
| 252 |
+
"""Positionally encode points that are not normalized to [0,1]."""
|
| 253 |
+
coords = coords_input.clone()
|
| 254 |
+
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
|
| 255 |
+
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
|
| 256 |
+
return self._pe_encoding(coords.to(torch.float)) # B x N x C
|
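A minimal sketch (not part of the commit) of feeding one keypoint prompt per person through `PromptEncoder`. The embedding width, joint count, and the example coordinates are assumptions; the label convention (joint index, -1 for "not a point", -2 for a dummy prompt) follows `_embed_keypoints` above.

# Hypothetical usage sketch -- not part of the commit.
import torch

from sam3d_body.models.decoders import PromptEncoder

encoder = PromptEncoder(embed_dim=256, num_body_joints=70)

# One prompt per sample: (x, y) normalized to [0, 1] plus a label.
keypoints = torch.tensor([[[0.40, 0.25, 12.0]],    # positive prompt for joint 12
                          [[0.00, 0.00, -2.0]]])   # dummy prompt

sparse_embeddings, sparse_mask = encoder(keypoints)   # [2, 1, 256], [2, 1]
dense_pe = encoder.get_dense_pe((16, 12))             # [1, 256, 16, 12]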
src/sam3d_body/models/decoders/promptable_decoder.py
ADDED
@@ -0,0 +1,194 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
import pickle
|
| 4 |
+
from typing import Dict, Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
from ..modules.transformer import build_norm_layer, TransformerDecoderLayer
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PromptableDecoder(nn.Module):
|
| 13 |
+
"""Cross-attention based Transformer decoder with prompts input.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
token_dims (int): The dimension of input pose tokens.
|
| 17 |
+
prompt_dims (int): The dimension of input prompt tokens.
|
| 18 |
+
context_dims (int): The dimension of image context features.
|
| 19 |
+
dims (int): The projected dimension of all tokens in the decoder.
|
| 20 |
+
depth (int): The number of layers for Transformer decoder.
|
| 21 |
+
num_heads (int): The number of heads for multi-head attention.
|
| 22 |
+
head_dims (int): The dimension of each head.
|
| 23 |
+
mlp_dims (int): The dimension of hidden layers in MLP.
|
| 24 |
+
layer_scale_init_value (float or torch.Tensor): Init value of layer
|
| 25 |
+
scale. Defaults to 0.
|
| 26 |
+
drop_rate (float): Probability of an element to be zeroed
|
| 27 |
+
after the feed forward layer. Defaults to 0.
|
| 28 |
+
attn_drop_rate (float): The drop out rate for attention output weights.
|
| 29 |
+
Defaults to 0.
|
| 30 |
+
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
|
| 31 |
+
ffn_type (str): Select the type of ffn layers. Defaults to 'origin'.
|
| 32 |
+
act_layer (nn.Module, optional): The activation layer for FFNs.
|
| 33 |
+
Default: nn.GELU
|
| 34 |
+
norm_cfg (dict): Config dict for normalization layer.
|
| 35 |
+
Defaults to ``dict(type='LN')``.
|
| 36 |
+
enable_twoway (bool): Whether to enable two-way Transformer (used in SAM).
|
| 37 |
+
repeat_pe (bool): Whether to re-add PE at each layer (used in SAM)
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
dims: int,
|
| 43 |
+
context_dims: int,
|
| 44 |
+
depth: int,
|
| 45 |
+
num_heads: int = 8,
|
| 46 |
+
head_dims: int = 64,
|
| 47 |
+
mlp_dims: int = 1024,
|
| 48 |
+
layer_scale_init_value: float = 0.0,
|
| 49 |
+
drop_rate: float = 0.0,
|
| 50 |
+
attn_drop_rate: float = 0.0,
|
| 51 |
+
drop_path_rate: float = 0.0,
|
| 52 |
+
ffn_type: str = "origin",
|
| 53 |
+
act_layer: nn.Module = nn.GELU,
|
| 54 |
+
        norm_cfg: Dict = dict(type="LN", eps=1e-6),
        enable_twoway: bool = False,
        repeat_pe: bool = False,
        frozen: bool = False,
        do_interm_preds: bool = False,
        do_keypoint_tokens: bool = False,
        keypoint_token_update: bool | str = False,
    ):
        super().__init__()

        self.layers = nn.ModuleList()
        for i in range(depth):
            self.layers.append(
                TransformerDecoderLayer(
                    token_dims=dims,
                    context_dims=context_dims,
                    num_heads=num_heads,
                    head_dims=head_dims,
                    mlp_dims=mlp_dims,
                    layer_scale_init_value=layer_scale_init_value,
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=drop_path_rate,
                    ffn_type=ffn_type,
                    act_layer=act_layer,
                    norm_cfg=norm_cfg,
                    enable_twoway=enable_twoway,
                    repeat_pe=repeat_pe,
                    skip_first_pe=(i == 0),
                )
            )

        self.norm_final = build_norm_layer(norm_cfg, dims)
        self.do_interm_preds = do_interm_preds
        self.do_keypoint_tokens = do_keypoint_tokens
        self.keypoint_token_update = keypoint_token_update

        self.frozen = frozen
        self._freeze_stages()

    def forward(
        self,
        token_embedding: torch.Tensor,
        image_embedding: torch.Tensor,
        token_augment: Optional[torch.Tensor] = None,
        image_augment: Optional[torch.Tensor] = None,
        token_mask: Optional[torch.Tensor] = None,
        channel_first: bool = True,
        token_to_pose_output_fn=None,
        keypoint_token_update_fn=None,
        hand_embeddings=None,
        hand_augment=None,
    ):
        """
        Args:
            token_embedding: [B, N, C]
            image_embedding: [B, C, H, W]
        """
        if channel_first:
            image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
            if image_augment is not None:
                image_augment = image_augment.flatten(2).permute(0, 2, 1)
        if hand_embeddings is not None:
            hand_embeddings = hand_embeddings.flatten(2).permute(0, 2, 1)
            hand_augment = hand_augment.flatten(2).permute(0, 2, 1)
            if len(hand_augment) == 1:
                # inflate batch dimension
                assert len(hand_augment.shape) == 3
                hand_augment = hand_augment.repeat(len(hand_embeddings), 1, 1)

        if self.do_interm_preds:
            assert token_to_pose_output_fn is not None
            all_pose_outputs = []

        for layer_idx, layer in enumerate(self.layers):
            if hand_embeddings is None:
                token_embedding, image_embedding = layer(
                    token_embedding,
                    image_embedding,
                    token_augment,
                    image_augment,
                    token_mask,
                )
            else:
                token_embedding, image_embedding = layer(
                    token_embedding,
                    torch.cat([image_embedding, hand_embeddings], dim=1),
                    token_augment,
                    torch.cat([image_augment, hand_augment], dim=1),
                    token_mask,
                )
                image_embedding = image_embedding[:, : image_augment.shape[1]]

            if self.do_interm_preds and layer_idx < len(self.layers) - 1:
                curr_pose_output = token_to_pose_output_fn(
                    self.norm_final(token_embedding),
                    prev_pose_output=(
                        all_pose_outputs[-1] if len(all_pose_outputs) > 0 else None
                    ),
                    layer_idx=layer_idx,
                )
                all_pose_outputs.append(curr_pose_output)

                if self.keypoint_token_update:
                    assert keypoint_token_update_fn is not None
                    token_embedding, token_augment, _, _ = keypoint_token_update_fn(
                        token_embedding, token_augment, curr_pose_output, layer_idx
                    )

        out = self.norm_final(token_embedding)

        if self.do_interm_preds:
            curr_pose_output = token_to_pose_output_fn(
                out,
                prev_pose_output=(
                    all_pose_outputs[-1] if len(all_pose_outputs) > 0 else None
                ),
                layer_idx=layer_idx,
            )
            all_pose_outputs.append(curr_pose_output)

            return out, all_pose_outputs
        else:
            return out

    def _freeze_stages(self):
        """Freeze parameters."""
        if self.frozen:
            for layer in self.layers:
                layer.eval()
            self.norm_final.eval()
            for param in self.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """
        Convert the model into training mode.
        (not called by lightning in trainer.fit() actually)
        """
        super().train(mode)
        self._freeze_stages()
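
Note (not part of the diff): when `do_interm_preds` is enabled, the decoder above reports one pose estimate per layer through the optional `token_to_pose_output_fn` callback. The following minimal, self-contained sketch mirrors that callback loop with toy layers so the control flow is easier to follow; the layer count, dimensions, and the toy callback are invented for illustration and are not the real model components.

import torch
import torch.nn as nn

# Toy stand-ins; the real decoder uses TransformerDecoderLayer blocks and the MHR head.
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
norm_final = nn.LayerNorm(8)

def toy_token_to_pose_output_fn(tokens, prev_pose_output=None, layer_idx=0):
    # In the real model this maps the normalized pose token to MHR pose parameters.
    return {"layer_idx": layer_idx, "token_mean": tokens.mean().item()}

token_embedding = torch.randn(2, 1, 8)  # [B, N, C]
all_pose_outputs = []
for layer_idx, layer in enumerate(layers):
    token_embedding = layer(token_embedding)
    if layer_idx < len(layers) - 1:  # intermediate prediction after each inner layer
        prev = all_pose_outputs[-1] if all_pose_outputs else None
        all_pose_outputs.append(
            toy_token_to_pose_output_fn(norm_final(token_embedding), prev, layer_idx)
        )
out = norm_final(token_embedding)
all_pose_outputs.append(toy_token_to_pose_output_fn(out, all_pose_outputs[-1], len(layers) - 1))
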
src/sam3d_body/models/heads/__init__.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from ..modules import to_2tuple
from .camera_head import PerspectiveHead
from .mhr_head import MHRHead


def build_head(cfg, head_type="mhr", enable_hand_model=False, default_scale_factor=1.0):
    if head_type == "mhr":
        return MHRHead(
            input_dim=cfg.MODEL.DECODER.DIM,
            mlp_depth=cfg.MODEL.MHR_HEAD.get("MLP_DEPTH", 1),
            mhr_model_path=cfg.MODEL.MHR_HEAD.MHR_MODEL_PATH,
            mlp_channel_div_factor=cfg.MODEL.MHR_HEAD.get("MLP_CHANNEL_DIV_FACTOR", 1),
            enable_hand_model=enable_hand_model,
        )
    elif head_type == "perspective":
        return PerspectiveHead(
            input_dim=cfg.MODEL.DECODER.DIM,
            img_size=to_2tuple(cfg.MODEL.IMAGE_SIZE),
            mlp_depth=cfg.MODEL.get("CAMERA_HEAD", dict()).get("MLP_DEPTH", 1),
            mlp_channel_div_factor=cfg.MODEL.get("CAMERA_HEAD", dict()).get(
                "MLP_CHANNEL_DIV_FACTOR", 1
            ),
            default_scale_factor=default_scale_factor,
        )
    else:
        raise ValueError("Invalid head type: ", head_type)
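
Note (not part of the diff): `build_head` only reads a handful of config keys. A minimal sketch of constructing the perspective camera head from a bare yacs config is shown below; the config values are placeholders for illustration, not the shipped defaults.

from yacs.config import CfgNode as CN

cfg = CN()
cfg.MODEL = CN()
cfg.MODEL.IMAGE_SIZE = 256          # placeholder model input size
cfg.MODEL.DECODER = CN()
cfg.MODEL.DECODER.DIM = 1024        # placeholder token dimension
cfg.MODEL.CAMERA_HEAD = CN()
cfg.MODEL.CAMERA_HEAD.MLP_DEPTH = 1

camera_head = build_head(cfg, head_type="perspective")  # -> PerspectiveHead
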
src/sam3d_body/models/heads/camera_head.py
ADDED
@@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn

from sam3d_body.models.modules.geometry_utils import perspective_projection

from ..modules import get_intrinsic_matrix, to_2tuple
from ..modules.transformer import FFN


class PerspectiveHead(nn.Module):
    """
    Predict camera translation (s, tx, ty) and perform full-perspective
    2D reprojection (CLIFF/CameraHMR setup).
    """

    def __init__(
        self,
        input_dim: int,
        img_size: Tuple[int, int] | Sequence[int],  # model input size (W, H)
        mlp_depth: int = 1,
        drop_ratio: float = 0.0,
        mlp_channel_div_factor: int = 8,
        default_scale_factor: float | int = 1,
    ):
        super().__init__()

        # Metadata to compute 3D skeleton and 2D reprojection
        self.img_size = to_2tuple(img_size)
        self.ncam = 3  # (s, tx, ty)
        self.default_scale_factor = default_scale_factor

        self.proj = FFN(
            embed_dims=input_dim,
            feedforward_channels=input_dim // mlp_channel_div_factor,
            output_dims=self.ncam,
            num_fcs=mlp_depth,
            ffn_drop=drop_ratio,
            add_identity=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        init_estimate: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            x: pose token with shape [B, C], usually C=DECODER.DIM
            init_estimate: [B, self.ncam]
        """
        pred_cam = self.proj(x)
        if init_estimate is not None:
            pred_cam = pred_cam + init_estimate

        return pred_cam

    def perspective_projection(
        self,
        points_3d: torch.Tensor,
        pred_cam: torch.Tensor,
        bbox_center: torch.Tensor,
        bbox_size: torch.Tensor,
        img_size: torch.Tensor,
        cam_int: torch.Tensor,
        use_intrin_center: bool = False,
    ):
        """
        Args:
            bbox_center / img_size: shape [N, 2], in original image space (w, h)
            bbox_size: shape [N,], in original image space
            cam_int: shape [N, 3, 3]
        """
        batch_size = points_3d.shape[0]
        pred_cam = pred_cam.clone()
        pred_cam[..., [0, 2]] *= -1  # Camera system difference

        # Compute camera translation: (scale, x, y) --> (x, y, depth)
        # depth ~= f / s
        # Note that f is in the NDC space (see Zolly section 3.1)
        s, tx, ty = pred_cam[:, 0], pred_cam[:, 1], pred_cam[:, 2]
        bs = bbox_size * s * self.default_scale_factor + 1e-8
        focal_length = cam_int[:, 0, 0]
        tz = 2 * focal_length / bs

        if not use_intrin_center:
            cx = 2 * (bbox_center[:, 0] - (img_size[:, 0] / 2)) / bs
            cy = 2 * (bbox_center[:, 1] - (img_size[:, 1] / 2)) / bs
        else:
            cx = 2 * (bbox_center[:, 0] - (cam_int[:, 0, 2])) / bs
            cy = 2 * (bbox_center[:, 1] - (cam_int[:, 1, 2])) / bs

        pred_cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)

        # Compute camera translation
        j3d_cam = points_3d + pred_cam_t.unsqueeze(1)

        # Projection to the image plane.
        # Note that the projection output is in *original* image space now.
        j2d = perspective_projection(j3d_cam, cam_int)

        return {
            "pred_keypoints_2d": j2d.reshape(batch_size, -1, 2),
            "pred_cam_t": pred_cam_t,
            "focal_length": focal_length,
            "pred_keypoints_2d_depth": j3d_cam.reshape(batch_size, -1, 3)[:, :, 2],
        }
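
Note (not part of the diff): ignoring the axis-convention sign flip, the translation recovery in `perspective_projection` follows the usual CLIFF-style relations tz = 2*f / (s*b) and cx = 2*(bbox_cx - W/2) / (s*b). A small numeric sketch with invented numbers:

# Invented numbers, for illustration only.
focal = 1500.0                       # fx from the full-image intrinsics
bbox_size, s, tx, ty = 300.0, 0.9, 0.05, -0.10
img_w, img_h = 1920.0, 1080.0
bbox_cx, bbox_cy = 820.0, 540.0

bs = bbox_size * s                   # default_scale_factor = 1
tz = 2 * focal / bs                  # ~11.1: depth grows as the crop shrinks
cx = 2 * (bbox_cx - img_w / 2) / bs  # shift toward the full-image principal point
cy = 2 * (bbox_cy - img_h / 2) / bs
pred_cam_t = (tx + cx, ty + cy, tz)
print(pred_cam_t)
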
src/sam3d_body/models/heads/mhr_head.py
ADDED
@@ -0,0 +1,369 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import os
import warnings
from typing import Optional

import roma
import torch
import torch.nn as nn

from ..modules import rot6d_to_rotmat
from ..modules.mhr_utils import (
    compact_cont_to_model_params_body,
    compact_cont_to_model_params_hand,
    compact_model_params_to_cont_body,
    mhr_param_hand_mask,
)

from ..modules.transformer import FFN

MOMENTUM_ENABLED = os.environ.get("MOMENTUM_ENABLED") is None
try:
    if MOMENTUM_ENABLED:
        from mhr.mhr import MHR

        MOMENTUM_ENABLED = True
        warnings.warn("Momentum is enabled")
    else:
        warnings.warn("Momentum is not enabled")
        raise ImportError
except:
    MOMENTUM_ENABLED = False
    warnings.warn("Momentum is not enabled")


class MHRHead(nn.Module):

    def __init__(
        self,
        input_dim: int,
        mlp_depth: int = 1,
        mhr_model_path: str = "",
        extra_joint_regressor: str = "",
        ffn_zero_bias: bool = True,
        mlp_channel_div_factor: int = 8,
        enable_hand_model=False,
    ):
        super().__init__()

        self.num_shape_comps = 45
        self.num_scale_comps = 28
        self.num_hand_comps = 54
        self.num_face_comps = 72
        self.enable_hand_model = enable_hand_model

        self.body_cont_dim = 260
        self.npose = (
            6  # Global Rotation
            + self.body_cont_dim  # then body
            + self.num_shape_comps
            + self.num_scale_comps
            + self.num_hand_comps * 2
            + self.num_face_comps
        )

        self.proj = FFN(
            embed_dims=input_dim,
            feedforward_channels=input_dim // mlp_channel_div_factor,
            output_dims=self.npose,
            num_fcs=mlp_depth,
            ffn_drop=0.0,
            add_identity=False,
        )

        if ffn_zero_bias:
            torch.nn.init.zeros_(self.proj.layers[-2].bias)

        # MHR Parameters
        self.model_data_dir = mhr_model_path
        self.num_hand_scale_comps = self.num_scale_comps - 18
        self.num_hand_pose_comps = self.num_hand_comps

        # Buffers to be filled in by model state dict
        self.joint_rotation = nn.Parameter(torch.zeros(127, 3, 3), requires_grad=False)
        self.scale_mean = nn.Parameter(torch.zeros(68), requires_grad=False)
        self.scale_comps = nn.Parameter(torch.zeros(28, 68), requires_grad=False)
        self.faces = nn.Parameter(torch.zeros(36874, 3).long(), requires_grad=False)
        self.hand_pose_mean = nn.Parameter(torch.zeros(54), requires_grad=False)
        self.hand_pose_comps = nn.Parameter(torch.eye(54), requires_grad=False)
        self.hand_joint_idxs_left = nn.Parameter(
            torch.zeros(27).long(), requires_grad=False
        )
        self.hand_joint_idxs_right = nn.Parameter(
            torch.zeros(27).long(), requires_grad=False
        )
        self.keypoint_mapping = nn.Parameter(
            torch.zeros(308, 18439 + 127), requires_grad=False
        )
        # Some special buffers for the hand-version
        self.right_wrist_coords = nn.Parameter(torch.zeros(3), requires_grad=False)
        self.root_coords = nn.Parameter(torch.zeros(3), requires_grad=False)
        self.local_to_world_wrist = nn.Parameter(torch.zeros(3, 3), requires_grad=False)
        self.nonhand_param_idxs = nn.Parameter(
            torch.zeros(145).long(), requires_grad=False
        )

        # Load MHR itself
        if MOMENTUM_ENABLED:
            self.mhr = MHR.from_files(
                device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                lod=1,
            )
        else:
            self.mhr = torch.jit.load(
                mhr_model_path,
                map_location=("cuda" if torch.cuda.is_available() else "cpu"),
            )

        for param in self.mhr.parameters():
            param.requires_grad = False

    def get_zero_pose_init(self, factor=1.0):
        # Initialize pose token with zero-initialized learnable params
        # Note: bias/initial value should be zero-pose in cont, not all-zeros
        weights = torch.zeros(1, self.npose)
        weights[:, : 6 + self.body_cont_dim] = torch.cat(
            [
                torch.FloatTensor([1, 0, 0, 0, 1, 0]),
                compact_model_params_to_cont_body(torch.zeros(1, 133)).squeeze()
                * factor,
            ],
            dim=0,
        )
        return weights

    def replace_hands_in_pose(self, full_pose_params, hand_pose_params):
        assert full_pose_params.shape[1] == 136

        # This drops in the hand poses from hand_pose_params (PCA 6D) into full_pose_params.
        # Split into left and right hands
        left_hand_params, right_hand_params = torch.split(
            hand_pose_params,
            [self.num_hand_pose_comps, self.num_hand_pose_comps],
            dim=1,
        )

        # Change from cont to model params
        left_hand_params_model_params = compact_cont_to_model_params_hand(
            self.hand_pose_mean
            + torch.einsum("da,ab->db", left_hand_params, self.hand_pose_comps)
        )
        right_hand_params_model_params = compact_cont_to_model_params_hand(
            self.hand_pose_mean
            + torch.einsum("da,ab->db", right_hand_params, self.hand_pose_comps)
        )

        # Drop it in
        full_pose_params[:, self.hand_joint_idxs_left] = left_hand_params_model_params
        full_pose_params[:, self.hand_joint_idxs_right] = right_hand_params_model_params

        return full_pose_params  # B x 207

    def mhr_forward(
        self,
        global_trans,
        global_rot,
        body_pose_params,
        hand_pose_params,
        scale_params,
        shape_params,
        expr_params=None,
        return_keypoints=False,
        do_pcblend=True,
        return_joint_coords=False,
        return_model_params=False,
        return_joint_rotations=False,
        scale_offsets=None,
        vertex_offsets=None,
    ):

        if self.enable_hand_model:
            # Transfer wrist-centric predictions to the body.
            global_rot_ori = global_rot.clone()
            global_trans_ori = global_trans.clone()
            global_rot = roma.rotmat_to_euler(
                "xyz",
                roma.euler_to_rotmat("xyz", global_rot_ori) @ self.local_to_world_wrist,
            )
            global_trans = (
                -(
                    roma.euler_to_rotmat("xyz", global_rot)
                    @ (self.right_wrist_coords - self.root_coords)
                    + self.root_coords
                )
                + global_trans_ori
            )

        body_pose_params = body_pose_params[..., :130]

        # Convert from scale and shape params to actual scales and vertices
        ## Add singleton batches in case...
        if len(scale_params.shape) == 1:
            scale_params = scale_params[None]
        if len(shape_params.shape) == 1:
            shape_params = shape_params[None]
        ## Convert scale...
        scales = self.scale_mean[None, :] + scale_params @ self.scale_comps
        if scale_offsets is not None:
            scales = scales + scale_offsets

        # Now, figure out the pose.
        ## 10 here is because it's more stable to optimize global translation in meters.
        full_pose_params = torch.cat(
            [global_trans * 10, global_rot, body_pose_params], dim=1
        )  # B x 127
        ## Put in hands
        if hand_pose_params is not None:
            full_pose_params = self.replace_hands_in_pose(
                full_pose_params, hand_pose_params
            )
        model_params = torch.cat([full_pose_params, scales], dim=1)

        if self.enable_hand_model:
            # Zero out non-hand parameters
            model_params[:, self.nonhand_param_idxs] = 0

        curr_skinned_verts, curr_skel_state = self.mhr(
            shape_params, model_params, expr_params
        )
        curr_joint_coords, curr_joint_quats, _ = torch.split(
            curr_skel_state, [3, 4, 1], dim=2
        )
        curr_skinned_verts = curr_skinned_verts / 100
        curr_joint_coords = curr_joint_coords / 100
        curr_joint_rots = roma.unitquat_to_rotmat(curr_joint_quats)

        # Prepare returns
        to_return = [curr_skinned_verts]
        if return_keypoints:
            # Get sapiens 308 keypoints
            model_vert_joints = torch.cat(
                [curr_skinned_verts, curr_joint_coords], dim=1
            )  # B x (num_verts + 127) x 3
            model_keypoints_pred = (
                (
                    self.keypoint_mapping
                    @ model_vert_joints.permute(1, 0, 2).flatten(1, 2)
                )
                .reshape(-1, model_vert_joints.shape[0], 3)
                .permute(1, 0, 2)
            )

            if self.enable_hand_model:
                # Zero out everything except for the right hand
                model_keypoints_pred[:, :21] = 0
                model_keypoints_pred[:, 42:] = 0

            to_return = to_return + [model_keypoints_pred]
        if return_joint_coords:
            to_return = to_return + [curr_joint_coords]
        if return_model_params:
            to_return = to_return + [model_params]
        if return_joint_rotations:
            to_return = to_return + [curr_joint_rots]

        if isinstance(to_return, list) and len(to_return) == 1:
            return to_return[0]
        else:
            return tuple(to_return)

    def forward(
        self,
        x: torch.Tensor,
        init_estimate: Optional[torch.Tensor] = None,
        do_pcblend=True,
        slim_keypoints=False,
    ):
        """
        Args:
            x: pose token with shape [B, C], usually C=DECODER.DIM
            init_estimate: [B, self.npose]
        """
        batch_size = x.shape[0]
        pred = self.proj(x)
        if init_estimate is not None:
            pred = pred + init_estimate

        # From pred, we want to pull out individual predictions.

        ## First, get globals
        ### Global rotation is first 6.
        count = 6
        global_rot_6d = pred[:, :count]
        global_rot_rotmat = rot6d_to_rotmat(global_rot_6d)  # B x 3 x 3
        global_rot_euler = roma.rotmat_to_euler("ZYX", global_rot_rotmat)  # B x 3
        global_trans = torch.zeros_like(global_rot_euler)

        ## Next, get body pose.
        ### Hold onto raw, continuous version for iterative correction.
        pred_pose_cont = pred[:, count : count + self.body_cont_dim]
        count += self.body_cont_dim
        ### Convert to eulers (and trans)
        pred_pose_euler = compact_cont_to_model_params_body(pred_pose_cont)
        ### Zero-out hands
        pred_pose_euler[:, mhr_param_hand_mask] = 0
        ### Zero-out jaw
        pred_pose_euler[:, -3:] = 0

        ## Get remaining parameters
        pred_shape = pred[:, count : count + self.num_shape_comps]
        count += self.num_shape_comps
        pred_scale = pred[:, count : count + self.num_scale_comps]
        count += self.num_scale_comps
        pred_hand = pred[:, count : count + self.num_hand_comps * 2]
        count += self.num_hand_comps * 2
        pred_face = pred[:, count : count + self.num_face_comps] * 0
        count += self.num_face_comps

        # Run everything through mhr
        output = self.mhr_forward(
            global_trans=global_trans,
            global_rot=global_rot_euler,
            body_pose_params=pred_pose_euler,
            hand_pose_params=pred_hand,
            scale_params=pred_scale,
            shape_params=pred_shape,
            expr_params=pred_face,
            do_pcblend=do_pcblend,
            return_keypoints=True,
            return_joint_coords=True,
            return_model_params=True,
            return_joint_rotations=True,
        )

        # Some existing code to get joints and fix camera system
        verts, j3d, jcoords, mhr_model_params, joint_global_rots = output
        j3d = j3d[:, :70]  # 308 --> 70 keypoints

        if verts is not None:
            verts[..., [1, 2]] *= -1  # Camera system difference
        j3d[..., [1, 2]] *= -1  # Camera system difference
        if jcoords is not None:
            jcoords[..., [1, 2]] *= -1

        # Prep outputs
        output = {
            "pred_pose_raw": torch.cat(
                [global_rot_6d, pred_pose_cont], dim=1
            ),  # Both global rot and continuous pose
            "pred_pose_rotmat": None,  # This normally used for mhr pose param rotmat supervision.
            "global_rot": global_rot_euler,
            "body_pose": pred_pose_euler,  # Unused during training
            "shape": pred_shape,
            "scale": pred_scale,
            "hand": pred_hand,
            "face": pred_face,
            "pred_keypoints_3d": j3d.reshape(batch_size, -1, 3),
            "pred_vertices": (
                verts.reshape(batch_size, -1, 3) if verts is not None else None
            ),
            "pred_joint_coords": (
                jcoords.reshape(batch_size, -1, 3) if jcoords is not None else None
            ),
            "faces": self.faces.cpu().numpy(),
            "joint_global_rots": joint_global_rots,
            "mhr_model_params": mhr_model_params,
        }

        return output
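
Note (not part of the diff): the head's flat output vector is consumed strictly by offset in `forward()`. The short sketch below just recomputes those offsets from the constants set in `__init__`, as a reading aid.

sections = [
    ("global_rot_6d", 6),
    ("body_pose_cont", 260),
    ("shape", 45),
    ("scale", 28),
    ("hands_left_right", 54 * 2),
    ("face_expr", 72),
]
offset = 0
for name, size in sections:
    print(f"{name:16s} pred[:, {offset}:{offset + size}]")
    offset += size
print("npose =", offset)  # 519
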
src/sam3d_body/models/meta_arch/__init__.py
ADDED
@@ -0,0 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from .sam3d_body import SAM3DBody
src/sam3d_body/models/meta_arch/base_lightning_module.py
ADDED
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger


class BaseLightningModule(pl.LightningModule):
    def _log_metric(self, name, value, step=None):
        for logger in self.trainer.loggers:
            if isinstance(logger, WandbLogger):
                if step is not None:
                    logger.experiment.log({name: value, "step": step})
                else:
                    logger.experiment.log({name: value})
            elif isinstance(logger, TensorBoardLogger):
                logger.experiment.add_scalar(name, value, step)
            else:
                raise ValueError(f"Unsupported logger: {logger}")

    def _log_image(self, name, img_tensor, dataformats="CHW", step_count=None):
        """Log image tensor to both W&B and TensorBoard."""
        step = step_count if step_count is not None else self.global_step
        for logger in self.trainer.loggers:
            if isinstance(logger, WandbLogger):
                import wandb

                img = img_tensor
                if dataformats.upper() == "CHW":
                    # If in PyTorch format (C,H,W), convert to (H,W,C) for wandb
                    img = img_tensor.permute(1, 2, 0).cpu().numpy()
                logger.experiment.log({name: wandb.Image(img), "step": step})
            elif isinstance(logger, TensorBoardLogger):
                logger.experiment.add_image(
                    name, img_tensor, step, dataformats=dataformats
                )
            else:
                raise ValueError(f"Unsupported logger: {logger}")

    def _log_hist(self, name, array, step_count=None):
        for logger in self.trainer.loggers:
            if isinstance(logger, WandbLogger):
                import wandb

                value = wandb.Histogram(
                    np_histogram=(array, np.arange(array.shape[0] + 1)),
                )
                logger.experiment.log({name: value, "step": step_count})
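
Note (not part of the diff): the three helpers above are meant to be called from subclasses during training or validation steps; a hypothetical call site might look like the commented lines below (names and values are illustrative only).

# Inside a training/validation step of a subclass:
# self._log_metric("train/loss_keypoints_2d", 0.42, step=self.global_step)
# self._log_image("val/overlay", image_chw, dataformats="CHW")      # CHW image tensor
# self._log_hist("val/per_joint_error_bins", bin_counts)            # 1D numpy array of counts
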
src/sam3d_body/models/meta_arch/base_model.py
ADDED
@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""Define an abstract base model for consistent format input / processing / output."""

from abc import abstractmethod
from functools import partial

import torch
from yacs.config import CfgNode

from ..optim.fp16_utils import convert_module_to_f16, convert_to_fp16_safe
from .base_lightning_module import BaseLightningModule


class BaseModel(BaseLightningModule):
    def __init__(self, cfg: CfgNode | None, **kwargs):
        super().__init__()

        # Save hyperparameters
        self.save_hyperparameters(logger=False)
        self.cfg = cfg

        self._initialze_model(**kwargs)

        # Initialize attributes for image-based batch format
        self._max_num_person = None
        self._person_valid = None

    @abstractmethod
    def _initialze_model(self, **kwargs) -> None:
        pass

    def data_preprocess(
        self,
        inputs: torch.Tensor,
        crop_width: bool = False,
        is_full: bool = False,  # whether for full_branch
        crop_hand: int = 0,
    ) -> torch.Tensor:
        image_mean = self.image_mean if not is_full else self.full_image_mean
        image_std = self.image_std if not is_full else self.full_image_std

        if inputs.max() > 1 and image_mean.max() <= 1.0:
            inputs = inputs / 255.0
        elif inputs.max() <= 1.0 and image_mean.max() > 1:
            inputs = inputs * 255.0
        batch_inputs = (inputs - image_mean) / image_std

        if crop_width:
            if crop_hand > 0:
                batch_inputs = batch_inputs[:, :, :, crop_hand:-crop_hand]
            elif self.cfg.MODEL.BACKBONE.TYPE in [
                "vit_hmr",
                "vit",
            ]:
                # ViT backbone assumes a different aspect ratio as input size
                batch_inputs = batch_inputs[:, :, :, 32:-32]
            elif self.cfg.MODEL.BACKBONE.TYPE in [
                "vit_hmr_512_384",
            ]:
                batch_inputs = batch_inputs[:, :, :, 64:-64]
            else:
                raise Exception

        return batch_inputs

    def _initialize_batch(self, batch: dict) -> None:
        # Check whether the input batch is with format
        # [batch_size, num_person, ...]
        if batch["img"].dim() == 5:
            self._batch_size, self._max_num_person = batch["img"].shape[:2]
            self._person_valid = self._flatten_person(batch["person_valid"]) > 0
        else:
            self._batch_size = batch["img"].shape[0]
            self._max_num_person = 0
            self._person_valid = None

    def _flatten_person(self, x: torch.Tensor) -> torch.Tensor:
        assert self._max_num_person is not None, "No max_num_person initialized"

        if self._max_num_person:
            # Merge person crops to batch dimension
            shape = x.shape
            x = x.view(self._batch_size * self._max_num_person, *shape[2:])
        return x

    def _unflatten_person(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.shape
        if self._max_num_person:
            x = x.view(self._batch_size, self._max_num_person, *shape[1:])
        return x

    def _get_valid(self, x: torch.Tensor) -> torch.Tensor:
        assert self._max_num_person is not None, "No max_num_person initialized"

        if self._person_valid is not None:
            x = x[self._person_valid]
        return x

    def _full_to_crop(self, batch: dict, pred_keypoints_2d: torch.Tensor) -> torch.Tensor:
        """Convert full-image keypoints coordinates to crop and normalize to [-0.5, 0.5]"""
        pred_keypoints_2d_cropped = torch.cat(
            [pred_keypoints_2d, torch.ones_like(pred_keypoints_2d[:, :, [-1]])], dim=-1
        )
        affine_trans = self._flatten_person(batch["affine_trans"]).to(pred_keypoints_2d_cropped)
        img_size = self._flatten_person(batch["img_size"]).unsqueeze(1)
        pred_keypoints_2d_cropped = pred_keypoints_2d_cropped @ affine_trans.mT
        pred_keypoints_2d_cropped = pred_keypoints_2d_cropped[..., :2] / img_size - 0.5

        return pred_keypoints_2d_cropped

    def _cam_full_to_crop(
        self, batch: dict, pred_cam_t: torch.Tensor, focal_length: torch.Tensor = None
    ) -> torch.Tensor:
        """Revert the camera translation from full to crop image space"""
        num_person = batch["img"].shape[1]
        cam_int = self._flatten_person(batch["cam_int"].unsqueeze(1).expand(-1, num_person, -1, -1).contiguous())
        bbox_center = self._flatten_person(batch["bbox_center"])
        bbox_size = self._flatten_person(batch["bbox_scale"])[:, 0]
        input_size = self._flatten_person(batch["img_size"])[:, 0]

        tx, ty, tz = pred_cam_t[:, 0], pred_cam_t[:, 1], pred_cam_t[:, 2]
        if focal_length is None:
            focal_length = cam_int[:, 0, 0]
        bs = 2 * focal_length / (tz + 1e-8)

        cx = 2 * (bbox_center[:, 0] - (cam_int[:, 0, 2])) / bs
        cy = 2 * (bbox_center[:, 1] - (cam_int[:, 1, 2])) / bs

        crop_cam_t = torch.stack([tx - cx, ty - cy, tz * bbox_size / input_size], dim=-1)
        return crop_cam_t

    def convert_to_fp16(self) -> torch.dtype:
        """
        Convert the torso of the model to float16.
        """
        fp16_type = torch.float16 if self.cfg.TRAIN.get("FP16_TYPE", "float16") == "float16" else torch.bfloat16

        if hasattr(self, "backbone"):
            self._set_fp16(self.backbone, fp16_type)
        if hasattr(self, "full_encoder"):
            self._set_fp16(self.full_encoder, fp16_type)

        if hasattr(self.backbone, "lhand_pos_embed"):
            self.backbone.lhand_pos_embed.data = self.backbone.lhand_pos_embed.data.to(fp16_type)

        if hasattr(self.backbone, "rhand_pos_embed"):
            self.backbone.rhand_pos_embed.data = self.backbone.rhand_pos_embed.data.to(fp16_type)

        return fp16_type

    def _set_fp16(self, module, fp16_type):
        if hasattr(module, "pos_embed"):
            module.apply(partial(convert_module_to_f16, dtype=fp16_type))
            module.pos_embed.data = module.pos_embed.data.to(fp16_type)
        elif hasattr(module.encoder, "rope_embed"):
            # DINOv3
            module.encoder.apply(partial(convert_to_fp16_safe, dtype=fp16_type))
            module.encoder.rope_embed = module.encoder.rope_embed.to(fp16_type)
        else:
            # DINOv2
            module.encoder.pos_embed.data = module.encoder.pos_embed.data.to(fp16_type)
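
Note (not part of the diff): `_flatten_person` / `_unflatten_person` simply fold the per-person axis into the batch axis and back. A shape-only sketch with made-up sizes:

import torch

B, P = 2, 4                                  # batch size, padded persons per image
imgs = torch.randn(B, P, 3, 256, 192)        # [B, P, C, H, W]

flat = imgs.view(B * P, 3, 256, 192)         # what _flatten_person produces
restored = flat.view(B, P, 3, 256, 192)      # what _unflatten_person produces
assert torch.equal(imgs, restored)
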
src/sam3d_body/models/meta_arch/sam3d_body.py
ADDED
|
@@ -0,0 +1,1728 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
from collections.abc import Sequence
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import roma
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
from sam3d_body.data.utils.prepare_batch import prepare_batch
|
| 14 |
+
from sam3d_body.models.decoders.prompt_encoder import PositionEmbeddingRandom
|
| 15 |
+
from sam3d_body.models.modules.mhr_utils import (
|
| 16 |
+
fix_wrist_euler,
|
| 17 |
+
rotation_angle_difference,
|
| 18 |
+
)
|
| 19 |
+
from sam3d_body.utils import recursive_to
|
| 20 |
+
from sam3d_body.utils.logging import get_pylogger
|
| 21 |
+
|
| 22 |
+
from ..backbones import create_backbone
|
| 23 |
+
from ..decoders import PromptEncoder, build_decoder, build_keypoint_sampler
|
| 24 |
+
from ..heads import build_head
|
| 25 |
+
from ..modules.camera_embed import CameraEncoder
|
| 26 |
+
from ..modules.transformer import FFN, MLP
|
| 27 |
+
from .base_model import BaseModel
|
| 28 |
+
|
| 29 |
+
logger = get_pylogger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# fmt: off
|
| 33 |
+
PROMPT_KEYPOINTS = { # keypoint_idx: prompt_idx
|
| 34 |
+
"mhr70": {
|
| 35 |
+
i: i for i in range(70)
|
| 36 |
+
}, # all 70 keypoints are supported for prompting
|
| 37 |
+
}
|
| 38 |
+
KEY_BODY = [5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 41, 62] # key body joints for prompting
|
| 39 |
+
KEY_RIGHT_HAND = list(range(21, 42))
|
| 40 |
+
# fmt: on
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass
|
| 44 |
+
class BodyPredContainer:
|
| 45 |
+
"""Structured container for main body + optional hand inference outputs."""
|
| 46 |
+
|
| 47 |
+
pose_output: dict[str, Any]
|
| 48 |
+
batch_lhand: dict[str, Any] | None = None
|
| 49 |
+
batch_rhand: dict[str, Any] | None = None
|
| 50 |
+
lhand_output: dict[str, Any] | None = None
|
| 51 |
+
rhand_output: dict[str, Any] | None = None
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class SAM3DBody(BaseModel):
|
| 55 |
+
pelvis_idx = [9, 10] # left_hip, right_hip
|
| 56 |
+
|
| 57 |
+
def _initialze_model(self):
|
| 58 |
+
self.register_buffer("image_mean", torch.tensor(self.cfg.MODEL.IMAGE_MEAN).view(-1, 1, 1), False)
|
| 59 |
+
self.register_buffer("image_std", torch.tensor(self.cfg.MODEL.IMAGE_STD).view(-1, 1, 1), False)
|
| 60 |
+
|
| 61 |
+
# Create backbone feature extractor for human crops
|
| 62 |
+
self.backbone = create_backbone(self.cfg.MODEL.BACKBONE.TYPE, self.cfg)
|
| 63 |
+
|
| 64 |
+
# Create header for pose estimation output
|
| 65 |
+
self.head_pose = build_head(self.cfg, self.cfg.MODEL.PERSON_HEAD.POSE_TYPE)
|
| 66 |
+
self.head_pose.hand_pose_comps_ori = nn.Parameter(self.head_pose.hand_pose_comps.clone(), requires_grad=False)
|
| 67 |
+
self.head_pose.hand_pose_comps.data = torch.eye(54).to(self.head_pose.hand_pose_comps.data).float()
|
| 68 |
+
|
| 69 |
+
# Initialize pose token with learnable params
|
| 70 |
+
# Note: bias/initial value should be zero-pose in cont, not all-zeros
|
| 71 |
+
self.init_pose = nn.Embedding(1, self.head_pose.npose)
|
| 72 |
+
|
| 73 |
+
# Define header for hand pose estimation
|
| 74 |
+
self.head_pose_hand = build_head(self.cfg, self.cfg.MODEL.PERSON_HEAD.POSE_TYPE, enable_hand_model=True)
|
| 75 |
+
self.head_pose_hand.hand_pose_comps_ori = nn.Parameter(
|
| 76 |
+
self.head_pose_hand.hand_pose_comps.clone(), requires_grad=False
|
| 77 |
+
)
|
| 78 |
+
self.head_pose_hand.hand_pose_comps.data = torch.eye(54).to(self.head_pose_hand.hand_pose_comps.data).float()
|
| 79 |
+
self.init_pose_hand = nn.Embedding(1, self.head_pose_hand.npose)
|
| 80 |
+
|
| 81 |
+
self.head_camera = build_head(self.cfg, self.cfg.MODEL.PERSON_HEAD.CAMERA_TYPE)
|
| 82 |
+
self.init_camera = nn.Embedding(1, self.head_camera.ncam)
|
| 83 |
+
nn.init.zeros_(self.init_camera.weight)
|
| 84 |
+
|
| 85 |
+
self.head_camera_hand = build_head(
|
| 86 |
+
self.cfg,
|
| 87 |
+
self.cfg.MODEL.PERSON_HEAD.CAMERA_TYPE,
|
| 88 |
+
default_scale_factor=self.cfg.MODEL.CAMERA_HEAD.get("DEFAULT_SCALE_FACTOR_HAND", 1.0),
|
| 89 |
+
)
|
| 90 |
+
self.init_camera_hand = nn.Embedding(1, self.head_camera_hand.ncam)
|
| 91 |
+
nn.init.zeros_(self.init_camera_hand.weight)
|
| 92 |
+
|
| 93 |
+
self.camera_type = "perspective"
|
| 94 |
+
|
| 95 |
+
# Support conditioned information for decoder
|
| 96 |
+
cond_dim = 3
|
| 97 |
+
init_dim = self.head_pose.npose + self.head_camera.ncam + cond_dim
|
| 98 |
+
self.init_to_token_mhr = nn.Linear(init_dim, self.cfg.MODEL.DECODER.DIM)
|
| 99 |
+
self.prev_to_token_mhr = nn.Linear(init_dim - cond_dim, self.cfg.MODEL.DECODER.DIM)
|
| 100 |
+
self.init_to_token_mhr_hand = nn.Linear(init_dim, self.cfg.MODEL.DECODER.DIM)
|
| 101 |
+
self.prev_to_token_mhr_hand = nn.Linear(init_dim - cond_dim, self.cfg.MODEL.DECODER.DIM)
|
| 102 |
+
|
| 103 |
+
# Create prompt encoder
|
| 104 |
+
self.max_num_clicks = 0
|
| 105 |
+
if self.cfg.MODEL.PROMPT_ENCODER.ENABLE:
|
| 106 |
+
self.max_num_clicks = self.cfg.MODEL.PROMPT_ENCODER.MAX_NUM_CLICKS
|
| 107 |
+
self.prompt_keypoints = PROMPT_KEYPOINTS[self.cfg.MODEL.PROMPT_ENCODER.PROMPT_KEYPOINTS]
|
| 108 |
+
|
| 109 |
+
self.prompt_encoder = PromptEncoder(
|
| 110 |
+
embed_dim=self.backbone.embed_dims, # need to match backbone dims for PE
|
| 111 |
+
num_body_joints=len(set(self.prompt_keypoints.values())),
|
| 112 |
+
frozen=self.cfg.MODEL.PROMPT_ENCODER.get("frozen", False),
|
| 113 |
+
mask_embed_type=self.cfg.MODEL.PROMPT_ENCODER.get("MASK_EMBED_TYPE", None),
|
| 114 |
+
)
|
| 115 |
+
self.prompt_to_token = nn.Linear(self.backbone.embed_dims, self.cfg.MODEL.DECODER.DIM)
|
| 116 |
+
|
| 117 |
+
self.keypoint_prompt_sampler = build_keypoint_sampler(
|
| 118 |
+
self.cfg.MODEL.PROMPT_ENCODER.get("KEYPOINT_SAMPLER", {}),
|
| 119 |
+
prompt_keypoints=self.prompt_keypoints,
|
| 120 |
+
keybody_idx=(
|
| 121 |
+
KEY_BODY if not self.cfg.MODEL.PROMPT_ENCODER.get("SAMPLE_HAND", False) else KEY_RIGHT_HAND
|
| 122 |
+
),
|
| 123 |
+
)
|
| 124 |
+
# To keep track of prompting history
|
| 125 |
+
self.prompt_hist = np.zeros(
|
| 126 |
+
(len(set(self.prompt_keypoints.values())) + 2, self.max_num_clicks),
|
| 127 |
+
dtype=np.float32,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
if self.cfg.MODEL.DECODER.FROZEN:
|
| 131 |
+
for param in self.prompt_to_token.parameters():
|
| 132 |
+
param.requires_grad = False
|
| 133 |
+
|
| 134 |
+
# Create promptable decoder
|
| 135 |
+
self.decoder = build_decoder(self.cfg.MODEL.DECODER, context_dim=self.backbone.embed_dims)
|
| 136 |
+
# shared config for the two decoders
|
| 137 |
+
self.decoder_hand = build_decoder(self.cfg.MODEL.DECODER, context_dim=self.backbone.embed_dims)
|
| 138 |
+
self.hand_pe_layer = PositionEmbeddingRandom(self.backbone.embed_dims // 2)
|
| 139 |
+
|
| 140 |
+
# Manually convert the torso of the model to fp16.
|
| 141 |
+
if self.cfg.TRAIN.USE_FP16:
|
| 142 |
+
self.convert_to_fp16()
|
| 143 |
+
if self.cfg.TRAIN.get("FP16_TYPE", "float16") == "float16":
|
| 144 |
+
self.backbone_dtype = torch.float16
|
| 145 |
+
else:
|
| 146 |
+
self.backbone_dtype = torch.bfloat16
|
| 147 |
+
else:
|
| 148 |
+
self.backbone_dtype = torch.float32
|
| 149 |
+
|
| 150 |
+
self.ray_cond_emb = CameraEncoder(
|
| 151 |
+
self.backbone.embed_dim,
|
| 152 |
+
self.backbone.patch_size,
|
| 153 |
+
)
|
| 154 |
+
self.ray_cond_emb_hand = CameraEncoder(
|
| 155 |
+
self.backbone.embed_dim,
|
| 156 |
+
self.backbone.patch_size,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
self.keypoint_embedding_idxs = list(range(70))
|
| 160 |
+
self.keypoint_embedding = nn.Embedding(len(self.keypoint_embedding_idxs), self.cfg.MODEL.DECODER.DIM)
|
| 161 |
+
self.keypoint_embedding_idxs_hand = list(range(70))
|
| 162 |
+
self.keypoint_embedding_hand = nn.Embedding(len(self.keypoint_embedding_idxs_hand), self.cfg.MODEL.DECODER.DIM)
|
| 163 |
+
|
| 164 |
+
if self.cfg.MODEL.DECODER.get("DO_HAND_DETECT_TOKENS", False):
|
| 165 |
+
self.hand_box_embedding = nn.Embedding(2, self.cfg.MODEL.DECODER.DIM) # for two hands
|
| 166 |
+
# decice if there is left or right hand inside the image
|
| 167 |
+
self.hand_cls_embed = nn.Linear(self.cfg.MODEL.DECODER.DIM, 2)
|
| 168 |
+
self.bbox_embed = MLP(self.cfg.MODEL.DECODER.DIM, self.cfg.MODEL.DECODER.DIM, 4, 3)
|
| 169 |
+
|
| 170 |
+
self.keypoint_posemb_linear = FFN(
|
| 171 |
+
embed_dims=2,
|
| 172 |
+
feedforward_channels=self.cfg.MODEL.DECODER.DIM,
|
| 173 |
+
output_dims=self.cfg.MODEL.DECODER.DIM,
|
| 174 |
+
num_fcs=2,
|
| 175 |
+
add_identity=False,
|
| 176 |
+
)
|
| 177 |
+
self.keypoint_posemb_linear_hand = FFN(
|
| 178 |
+
embed_dims=2,
|
| 179 |
+
feedforward_channels=self.cfg.MODEL.DECODER.DIM,
|
| 180 |
+
output_dims=self.cfg.MODEL.DECODER.DIM,
|
| 181 |
+
num_fcs=2,
|
| 182 |
+
add_identity=False,
|
| 183 |
+
)
|
| 184 |
+
self.keypoint_feat_linear = nn.Linear(self.backbone.embed_dims, self.cfg.MODEL.DECODER.DIM)
|
| 185 |
+
self.keypoint_feat_linear_hand = nn.Linear(self.backbone.embed_dims, self.cfg.MODEL.DECODER.DIM)
|
| 186 |
+
|
| 187 |
+
# Do all KPS
|
| 188 |
+
self.keypoint3d_embedding_idxs = list(range(70))
|
| 189 |
+
self.keypoint3d_embedding = nn.Embedding(len(self.keypoint3d_embedding_idxs), self.cfg.MODEL.DECODER.DIM)
|
| 190 |
+
|
| 191 |
+
# Assume always do full body for the hand decoder
|
| 192 |
+
self.keypoint3d_embedding_idxs_hand = list(range(70))
|
| 193 |
+
self.keypoint3d_embedding_hand = nn.Embedding(
|
| 194 |
+
len(self.keypoint3d_embedding_idxs_hand), self.cfg.MODEL.DECODER.DIM
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
self.keypoint3d_posemb_linear = FFN(
|
| 198 |
+
embed_dims=3,
|
| 199 |
+
feedforward_channels=self.cfg.MODEL.DECODER.DIM,
|
| 200 |
+
output_dims=self.cfg.MODEL.DECODER.DIM,
|
| 201 |
+
num_fcs=2,
|
| 202 |
+
add_identity=False,
|
| 203 |
+
)
|
| 204 |
+
self.keypoint3d_posemb_linear_hand = FFN(
|
| 205 |
+
embed_dims=3,
|
| 206 |
+
feedforward_channels=self.cfg.MODEL.DECODER.DIM,
|
| 207 |
+
output_dims=self.cfg.MODEL.DECODER.DIM,
|
| 208 |
+
num_fcs=2,
|
| 209 |
+
add_identity=False,
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
def _get_decoder_condition(self, batch: dict) -> torch.Tensor | None:
|
| 213 |
+
num_person = batch["img"].shape[1]
|
| 214 |
+
|
| 215 |
+
if self.cfg.MODEL.DECODER.CONDITION_TYPE == "cliff":
|
| 216 |
+
# CLIFF-style condition info (cx/f, cy/f, b/f)
|
| 217 |
+
cx, cy = torch.chunk(self._flatten_person(batch["bbox_center"]), chunks=2, dim=-1)
|
| 218 |
+
img_w, img_h = torch.chunk(self._flatten_person(batch["ori_img_size"]), chunks=2, dim=-1)
|
| 219 |
+
b = self._flatten_person(batch["bbox_scale"])[:, [0]]
|
| 220 |
+
|
| 221 |
+
focal_length = self._flatten_person(
|
| 222 |
+
batch["cam_int"].unsqueeze(1).expand(-1, num_person, -1, -1).contiguous()
|
| 223 |
+
)[:, 0, 0]
|
| 224 |
+
if not self.cfg.MODEL.DECODER.get("USE_INTRIN_CENTER", False):
|
| 225 |
+
condition_info = torch.cat([cx - img_w / 2.0, cy - img_h / 2.0, b], dim=-1)
|
| 226 |
+
else:
|
| 227 |
+
full_img_cxy = self._flatten_person(
|
| 228 |
+
batch["cam_int"].unsqueeze(1).expand(-1, num_person, -1, -1).contiguous()
|
| 229 |
+
)[:, [0, 1], [2, 2]]
|
| 230 |
+
condition_info = torch.cat([cx - full_img_cxy[:, [0]], cy - full_img_cxy[:, [1]], b], dim=-1)
|
| 231 |
+
condition_info[:, :2] = condition_info[:, :2] / focal_length.unsqueeze(-1) # [-1, 1]
|
| 232 |
+
condition_info[:, 2] = condition_info[:, 2] / focal_length # [-1, 1]
|
| 233 |
+
elif self.cfg.MODEL.DECODER.CONDITION_TYPE == "none":
|
| 234 |
+
return None
|
| 235 |
+
else:
|
| 236 |
+
raise NotImplementedError
|
| 237 |
+
|
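# Worked example with assumed values (image-centre variant, USE_INTRIN_CENTER=False):
# for a 1000x1000 image with focal length f = 1000 and a person box centred at
# (700, 500) with scale b = 300, the condition vector is
# [(700 - 500)/1000, (500 - 500)/1000, 300/1000] = [0.2, 0.0, 0.3].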
| 238 |
+
return condition_info.type(batch["img"].dtype)
|
| 239 |
+
|
| 240 |
+
def forward_decoder(
|
| 241 |
+
self,
|
| 242 |
+
image_embeddings: torch.Tensor,
|
| 243 |
+
init_estimate: torch.Tensor | None = None,
|
| 244 |
+
keypoints: torch.Tensor | None = None,
|
| 245 |
+
prev_estimate: torch.Tensor | None = None,
|
| 246 |
+
condition_info: torch.Tensor | None = None,
|
| 247 |
+
batch=None,
|
| 248 |
+
):
|
| 249 |
+
"""
|
| 250 |
+
Args:
|
| 251 |
+
image_embeddings: image features from the backbone, shape (B, C, H, W)
|
| 252 |
+
init_estimate: initial estimate to be refined on, shape (B, 1, C)
|
| 253 |
+
keypoints: optional prompt input, shape (B, N, 3),
|
| 254 |
+
3 for coordinates (x,y) + label.
|
| 255 |
+
(x, y) should be normalized to range [0, 1].
|
| 256 |
+
label==-1 indicates incorrect points,
|
| 257 |
+
label==-2 indicates invalid points
|
| 258 |
+
prev_estimate: optional prompt input, shape (B, 1, C),
|
| 259 |
+
previous estimate for pose refinement.
|
| 260 |
+
condition_info: optional condition information that is concatenated with
|
| 261 |
+
the input tokens, shape (B, c)
|
| 262 |
+
"""
|
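# Minimal sketch of the keypoint prompt tensor (values are illustrative only):
#   dummy = torch.zeros(B, 1, 3); dummy[:, :, -1] = -2          # unconditioned pass
#   wrist = torch.tensor([[[0.6, 0.4, 41.0]]])                  # right wrist at normalized (0.6, 0.4)
#   keypoints = torch.cat([dummy, wrist.expand(B, -1, -1)], dim=1)
# Non-negative labels are keypoint indices; -1 marks incorrect and -2 invalid points.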
| 263 |
+
batch_size = image_embeddings.shape[0]
|
| 264 |
+
|
| 265 |
+
# Initial estimation for residual prediction.
|
| 266 |
+
if init_estimate is None:
|
| 267 |
+
init_pose = self.init_pose.weight.expand(batch_size, -1).unsqueeze(dim=1)
|
| 268 |
+
if hasattr(self, "init_camera"):
|
| 269 |
+
init_camera = self.init_camera.weight.expand(batch_size, -1).unsqueeze(dim=1)
|
| 270 |
+
|
| 271 |
+
init_estimate = (
|
| 272 |
+
init_pose if not hasattr(self, "init_camera") else torch.cat([init_pose, init_camera], dim=-1)
|
| 273 |
+
) # Pose parameters with the camera translation appended: B x 1 x (404 + 3)
|
| 274 |
+
|
| 275 |
+
init_input = (
|
| 276 |
+
torch.cat([condition_info.view(batch_size, 1, -1), init_estimate], dim=-1)
|
| 277 |
+
if condition_info is not None
|
| 278 |
+
else init_estimate
|
| 279 |
+
) # B x 1 x 410 (this is with the CLIFF condition)
|
| 280 |
+
token_embeddings = self.init_to_token_mhr(init_input).view(batch_size, 1, -1) # B x 1 x 1024 (linear layered)
|
| 281 |
+
|
| 282 |
+
num_pose_token = token_embeddings.shape[1]
|
| 283 |
+
assert num_pose_token == 1
|
| 284 |
+
|
| 285 |
+
image_augment, token_augment, token_mask = None, None, None
|
| 286 |
+
if hasattr(self, "prompt_encoder") and keypoints is not None:
|
| 287 |
+
if prev_estimate is None:
|
| 288 |
+
# Use initial embedding if no previous embedding
|
| 289 |
+
prev_estimate = init_estimate
|
| 290 |
+
# Previous estimate w/o the CLIFF condition.
|
| 291 |
+
prev_embeddings = self.prev_to_token_mhr(prev_estimate).view(
|
| 292 |
+
batch_size, 1, -1
|
| 293 |
+
) # 407 -> B x 1 x 1024 via a linear layer
|
| 294 |
+
|
| 295 |
+
if self.cfg.MODEL.BACKBONE.TYPE in [
|
| 296 |
+
"vit_hmr",
|
| 297 |
+
"vit",
|
| 298 |
+
"vit_b",
|
| 299 |
+
"vit_l",
|
| 300 |
+
]:
|
| 301 |
+
# The ViT backbone assumes a different aspect ratio for its input size
|
| 302 |
+
image_augment = self.prompt_encoder.get_dense_pe((16, 16))[:, :, :, 2:-2]
|
| 303 |
+
elif self.cfg.MODEL.BACKBONE.TYPE in [
|
| 304 |
+
"vit_hmr_512_384",
|
| 305 |
+
]:
|
| 306 |
+
# The ViT backbone assumes a different aspect ratio for its input size
|
| 307 |
+
image_augment = self.prompt_encoder.get_dense_pe((32, 32))[:, :, :, 4:-4]
|
| 308 |
+
else:
|
| 309 |
+
image_augment = self.prompt_encoder.get_dense_pe(image_embeddings.shape[-2:]) # (1, C, H, W)
|
| 310 |
+
|
| 311 |
+
image_embeddings = self.ray_cond_emb(image_embeddings, batch["ray_cond"])
|
| 312 |
+
|
| 313 |
+
# To start, keypoints is all [0, 0, -2]. The points get sent into self.pe_layer._pe_encoding,
|
| 314 |
+
# the labels select the embedding weight (a special one each for -2 and -1, then one per joint).
|
| 315 |
+
prompt_embeddings, prompt_mask = self.prompt_encoder(keypoints=keypoints) # B x 1 x 1280
|
| 316 |
+
prompt_embeddings = self.prompt_to_token(prompt_embeddings) # Linear layered: B x 1 x 1024
|
| 317 |
+
|
| 318 |
+
# Concatenate pose tokens and prompt embeddings as decoder input
|
| 319 |
+
token_embeddings = torch.cat(
|
| 320 |
+
[
|
| 321 |
+
token_embeddings,
|
| 322 |
+
prev_embeddings,
|
| 323 |
+
prompt_embeddings,
|
| 324 |
+
],
|
| 325 |
+
dim=1,
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
token_augment = torch.zeros_like(token_embeddings)
|
| 329 |
+
token_augment[:, [num_pose_token]] = prev_embeddings
|
| 330 |
+
token_augment[:, (num_pose_token + 1) :] = prompt_embeddings
|
| 331 |
+
token_mask = None
|
| 332 |
+
|
| 333 |
+
if self.cfg.MODEL.DECODER.get("DO_HAND_DETECT_TOKENS", False):
|
| 334 |
+
# Put in a token for each hand
|
| 335 |
+
hand_det_emb_start_idx = token_embeddings.shape[1]
|
| 336 |
+
token_embeddings = torch.cat(
|
| 337 |
+
[
|
| 338 |
+
token_embeddings,
|
| 339 |
+
self.hand_box_embedding.weight[None, :, :].repeat(batch_size, 1, 1),
|
| 340 |
+
],
|
| 341 |
+
dim=1,
|
| 342 |
+
) # B x 5 + 70 x 1024
|
| 343 |
+
# No positional embeddings
|
| 344 |
+
token_augment = torch.cat(
|
| 345 |
+
[
|
| 346 |
+
token_augment,
|
| 347 |
+
torch.zeros_like(token_embeddings[:, token_augment.shape[1] :, :]),
|
| 348 |
+
],
|
| 349 |
+
dim=1,
|
| 350 |
+
) # B x 5 + 70 x 1024
|
| 351 |
+
|
| 352 |
+
assert self.cfg.MODEL.DECODER.get("DO_KEYPOINT_TOKENS", False)
|
| 353 |
+
# Put in a token for each keypoint
|
| 354 |
+
kps_emb_start_idx = token_embeddings.shape[1]
|
| 355 |
+
token_embeddings = torch.cat(
|
| 356 |
+
[
|
| 357 |
+
token_embeddings,
|
| 358 |
+
self.keypoint_embedding.weight[None, :, :].repeat(batch_size, 1, 1),
|
| 359 |
+
],
|
| 360 |
+
dim=1,
|
| 361 |
+
) # B x 3 + 70 x 1024
|
| 362 |
+
# No positional embeddings
|
| 363 |
+
token_augment = torch.cat(
|
| 364 |
+
[
|
| 365 |
+
token_augment,
|
| 366 |
+
torch.zeros_like(token_embeddings[:, token_augment.shape[1] :, :]),
|
| 367 |
+
],
|
| 368 |
+
dim=1,
|
| 369 |
+
) # B x 3 + 70 x 1024
|
| 370 |
+
if self.cfg.MODEL.DECODER.get("DO_KEYPOINT3D_TOKENS", False):
|
| 371 |
+
# Put in a token for each keypoint
|
| 372 |
+
kps3d_emb_start_idx = token_embeddings.shape[1]
|
| 373 |
+
token_embeddings = torch.cat(
|
| 374 |
+
[
|
| 375 |
+
token_embeddings,
|
| 376 |
+
self.keypoint3d_embedding.weight[None, :, :].repeat(batch_size, 1, 1),
|
| 377 |
+
],
|
| 378 |
+
dim=1,
|
| 379 |
+
) # B x 3 + 70 + 70 x 1024
|
| 380 |
+
# No positional embeddings
|
| 381 |
+
token_augment = torch.cat(
|
| 382 |
+
[
|
| 383 |
+
token_augment,
|
| 384 |
+
torch.zeros_like(token_embeddings[:, token_augment.shape[1] :, :]),
|
| 385 |
+
],
|
| 386 |
+
dim=1,
|
| 387 |
+
) # B x 3 + 70 + 70 x 1024
|
| 388 |
+
|
| 389 |
+
# Produce intermediate model predictions after each decoder layer
|
| 390 |
+
def token_to_pose_output_fn(tokens, prev_pose_output, layer_idx):
|
| 391 |
+
# Get the pose token
|
| 392 |
+
pose_token = tokens[:, 0]
|
| 393 |
+
|
| 394 |
+
prev_pose = init_pose.view(batch_size, -1)
|
| 395 |
+
prev_camera = init_camera.view(batch_size, -1)
|
| 396 |
+
|
| 397 |
+
# Get pose outputs
|
| 398 |
+
pose_output = self.head_pose(pose_token, prev_pose)
|
| 399 |
+
# Get Camera Translation
|
| 400 |
+
if hasattr(self, "head_camera"):
|
| 401 |
+
pred_cam = self.head_camera(pose_token, prev_camera)
|
| 402 |
+
pose_output["pred_cam"] = pred_cam
|
| 403 |
+
# Run camera projection
|
| 404 |
+
pose_output = self.camera_project(pose_output, batch)
|
| 405 |
+
|
| 406 |
+
# Get 2D KPS in crop
|
| 407 |
+
pose_output["pred_keypoints_2d_cropped"] = self._full_to_crop(
|
| 408 |
+
batch, pose_output["pred_keypoints_2d"], self.body_batch_idx
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
return pose_output
|
| 412 |
+
|
| 413 |
+
kp_token_update_fn = self.keypoint_token_update_fn
|
| 414 |
+
|
| 415 |
+
# Now for 3D
|
| 416 |
+
kp3d_token_update_fn = self.keypoint3d_token_update_fn
|
| 417 |
+
|
| 418 |
+
# Combine the 2D and 3D keypoint token update functions
|
| 419 |
+
def keypoint_token_update_fn_comb(*args):
|
| 420 |
+
if kp_token_update_fn is not None:
|
| 421 |
+
args = kp_token_update_fn(kps_emb_start_idx, image_embeddings, *args)
|
| 422 |
+
if kp3d_token_update_fn is not None:
|
| 423 |
+
args = kp3d_token_update_fn(kps3d_emb_start_idx, *args)
|
| 424 |
+
return args
|
| 425 |
+
|
| 426 |
+
pose_token, pose_output = self.decoder(
|
| 427 |
+
token_embeddings,
|
| 428 |
+
image_embeddings,
|
| 429 |
+
token_augment,
|
| 430 |
+
image_augment,
|
| 431 |
+
token_mask,
|
| 432 |
+
token_to_pose_output_fn=token_to_pose_output_fn,
|
| 433 |
+
keypoint_token_update_fn=keypoint_token_update_fn_comb,
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
if self.cfg.MODEL.DECODER.get("DO_HAND_DETECT_TOKENS", False):
|
| 437 |
+
return (
|
| 438 |
+
pose_token[:, hand_det_emb_start_idx : hand_det_emb_start_idx + 2],
|
| 439 |
+
pose_output,
|
| 440 |
+
)
|
| 441 |
+
else:
|
| 442 |
+
return pose_token, pose_output
|
| 443 |
+
|
| 444 |
+
def forward_decoder_hand(
|
| 445 |
+
self,
|
| 446 |
+
image_embeddings: torch.Tensor,
|
| 447 |
+
init_estimate: torch.Tensor | None = None,
|
| 448 |
+
keypoints: torch.Tensor | None = None,
|
| 449 |
+
prev_estimate: torch.Tensor | None = None,
|
| 450 |
+
condition_info: torch.Tensor | None = None,
|
| 451 |
+
batch=None,
|
| 452 |
+
):
|
| 453 |
+
"""
|
| 454 |
+
Args:
|
| 455 |
+
image_embeddings: image features from the backbone, shape (B, C, H, W)
|
| 456 |
+
init_estimate: initial estimate to be refined on, shape (B, 1, C)
|
| 457 |
+
keypoints: optional prompt input, shape (B, N, 3),
|
| 458 |
+
3 for coordinates (x,y) + label.
|
| 459 |
+
(x, y) should be normalized to range [0, 1].
|
| 460 |
+
label==-1 indicates incorrect points,
|
| 461 |
+
label==-2 indicates invalid points
|
| 462 |
+
prev_estimate: optional prompt input, shape (B, 1, C),
|
| 463 |
+
previous estimate for pose refinement.
|
| 464 |
+
condition_info: optional condition information that is concatenated with
|
| 465 |
+
the input tokens, shape (B, c)
|
| 466 |
+
"""
|
| 467 |
+
batch_size = image_embeddings.shape[0]
|
| 468 |
+
|
| 469 |
+
# Initial estimation for residual prediction.
|
| 470 |
+
if init_estimate is None:
|
| 471 |
+
init_pose = self.init_pose_hand.weight.expand(batch_size, -1).unsqueeze(dim=1)
|
| 472 |
+
if hasattr(self, "init_camera_hand"):
|
| 473 |
+
init_camera = self.init_camera_hand.weight.expand(batch_size, -1).unsqueeze(dim=1)
|
| 474 |
+
|
| 475 |
+
init_estimate = (
|
| 476 |
+
init_pose if not hasattr(self, "init_camera_hand") else torch.cat([init_pose, init_camera], dim=-1)
|
| 477 |
+
) # Pose parameters with the camera translation appended: B x 1 x (404 + 3)
|
| 478 |
+
|
| 479 |
+
init_input = (
|
| 480 |
+
torch.cat([condition_info.view(batch_size, 1, -1), init_estimate], dim=-1)
|
| 481 |
+
if condition_info is not None
|
| 482 |
+
else init_estimate
|
| 483 |
+
) # B x 1 x 410 (this is with the CLIFF condition)
|
| 484 |
+
token_embeddings = self.init_to_token_mhr_hand(init_input).view(
|
| 485 |
+
batch_size, 1, -1
|
| 486 |
+
) # B x 1 x 1024 (linear layered)
|
| 487 |
+
num_pose_token = token_embeddings.shape[1]
|
| 488 |
+
|
| 489 |
+
image_augment, token_augment, token_mask = None, None, None
|
| 490 |
+
if hasattr(self, "prompt_encoder") and keypoints is not None:
|
| 491 |
+
if prev_estimate is None:
|
| 492 |
+
# Use initial embedding if no previous embedding
|
| 493 |
+
prev_estimate = init_estimate
|
| 494 |
+
# Previous estimate w/o the CLIFF condition.
|
| 495 |
+
prev_embeddings = self.prev_to_token_mhr_hand(prev_estimate).view(
|
| 496 |
+
batch_size, 1, -1
|
| 497 |
+
) # 407 -> B x 1 x 1024 via a linear layer
|
| 498 |
+
|
| 499 |
+
if self.cfg.MODEL.BACKBONE.TYPE in [
|
| 500 |
+
"vit_hmr",
|
| 501 |
+
"vit",
|
| 502 |
+
"vit_b",
|
| 503 |
+
"vit_l",
|
| 504 |
+
]:
|
| 505 |
+
# The ViT backbone assumes a different aspect ratio for its input size
|
| 506 |
+
image_augment = self.hand_pe_layer((16, 16)).unsqueeze(0)[:, :, :, 2:-2]
|
| 507 |
+
elif self.cfg.MODEL.BACKBONE.TYPE in [
|
| 508 |
+
"vit_hmr_512_384",
|
| 509 |
+
]:
|
| 510 |
+
# The ViT backbone assumes a different aspect ratio for its input size
|
| 511 |
+
image_augment = self.hand_pe_layer((32, 32)).unsqueeze(0)[:, :, :, 4:-4]
|
| 512 |
+
else:
|
| 513 |
+
image_augment = self.hand_pe_layer(image_embeddings.shape[-2:]).unsqueeze(0) # (1, C, H, W)
|
| 514 |
+
|
| 515 |
+
image_embeddings = self.ray_cond_emb_hand(image_embeddings, batch["ray_cond_hand"])
|
| 516 |
+
|
| 517 |
+
# To start, keypoints is all [0, 0, -2]. The points get sent into self.pe_layer._pe_encoding,
|
| 518 |
+
# the labels select the embedding weight (a special one each for -2 and -1, then one per joint).
|
| 519 |
+
prompt_embeddings, prompt_mask = self.prompt_encoder(keypoints=keypoints) # B x 1 x 1280
|
| 520 |
+
prompt_embeddings = self.prompt_to_token(prompt_embeddings) # Linear layered: B x 1 x 1024
|
| 521 |
+
|
| 522 |
+
# Concatenate pose tokens and prompt embeddings as decoder input
|
| 523 |
+
token_embeddings = torch.cat(
|
| 524 |
+
[
|
| 525 |
+
token_embeddings,
|
| 526 |
+
prev_embeddings,
|
| 527 |
+
prompt_embeddings,
|
| 528 |
+
],
|
| 529 |
+
dim=1,
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
token_augment = torch.zeros_like(token_embeddings)
|
| 533 |
+
token_augment[:, [num_pose_token]] = prev_embeddings
|
| 534 |
+
token_augment[:, (num_pose_token + 1) :] = prompt_embeddings
|
| 535 |
+
token_mask = None
|
| 536 |
+
|
| 537 |
+
if self.cfg.MODEL.DECODER.get("DO_HAND_DETECT_TOKENS", False):
|
| 538 |
+
# Put in a token for each hand
|
| 539 |
+
hand_det_emb_start_idx = token_embeddings.shape[1]
|
| 540 |
+
token_embeddings = torch.cat(
|
| 541 |
+
[
|
| 542 |
+
token_embeddings,
|
| 543 |
+
self.hand_box_embedding.weight[None, :, :].repeat(batch_size, 1, 1),
|
| 544 |
+
],
|
| 545 |
+
dim=1,
|
| 546 |
+
) # B x 5 + 70 x 1024
|
| 547 |
+
# No positional embeddings
|
| 548 |
+
token_augment = torch.cat(
|
| 549 |
+
[
|
| 550 |
+
token_augment,
|
| 551 |
+
torch.zeros_like(token_embeddings[:, token_augment.shape[1] :, :]),
|
| 552 |
+
],
|
| 553 |
+
dim=1,
|
| 554 |
+
) # B x 5 + 70 x 1024
|
| 555 |
+
|
| 556 |
+
assert self.cfg.MODEL.DECODER.get("DO_KEYPOINT_TOKENS", False)
|
| 557 |
+
# Put in a token for each keypoint
|
| 558 |
+
kps_emb_start_idx = token_embeddings.shape[1]
|
| 559 |
+
token_embeddings = torch.cat(
|
| 560 |
+
[
|
| 561 |
+
token_embeddings,
|
| 562 |
+
self.keypoint_embedding_hand.weight[None, :, :].repeat(batch_size, 1, 1),
|
| 563 |
+
],
|
| 564 |
+
dim=1,
|
| 565 |
+
) # B x 3 + 70 x 1024
|
| 566 |
+
# No positional embeddings
|
| 567 |
+
token_augment = torch.cat(
|
| 568 |
+
[
|
| 569 |
+
token_augment,
|
| 570 |
+
torch.zeros_like(token_embeddings[:, token_augment.shape[1] :, :]),
|
| 571 |
+
],
|
| 572 |
+
dim=1,
|
| 573 |
+
) # B x 3 + 70 x 1024
|
| 574 |
+
|
| 575 |
+
if self.cfg.MODEL.DECODER.get("DO_KEYPOINT3D_TOKENS", False):
|
| 576 |
+
# Put in a token for each keypoint
|
| 577 |
+
kps3d_emb_start_idx = token_embeddings.shape[1]
|
| 578 |
+
token_embeddings = torch.cat(
|
| 579 |
+
[
|
| 580 |
+
token_embeddings,
|
| 581 |
+
self.keypoint3d_embedding_hand.weight[None, :, :].repeat(batch_size, 1, 1),
|
| 582 |
+
],
|
| 583 |
+
dim=1,
|
| 584 |
+
) # B x 3 + 70 + 70 x 1024
|
| 585 |
+
# No positional embeddings
|
| 586 |
+
token_augment = torch.cat(
|
| 587 |
+
[
|
| 588 |
+
token_augment,
|
| 589 |
+
torch.zeros_like(token_embeddings[:, token_augment.shape[1] :, :]),
|
| 590 |
+
],
|
| 591 |
+
dim=1,
|
| 592 |
+
) # B x 3 + 70 + 70 x 1024
|
| 593 |
+
|
| 594 |
+
# Produce intermediate model predictions after each decoder layer
|
| 595 |
+
def token_to_pose_output_fn(tokens, prev_pose_output, layer_idx):
|
| 596 |
+
# Get the pose token
|
| 597 |
+
pose_token = tokens[:, 0]
|
| 598 |
+
|
| 599 |
+
prev_pose = init_pose.view(batch_size, -1)
|
| 600 |
+
prev_camera = init_camera.view(batch_size, -1)
|
| 601 |
+
|
| 602 |
+
# Get pose outputs
|
| 603 |
+
pose_output = self.head_pose_hand(pose_token, prev_pose)
|
| 604 |
+
|
| 605 |
+
# Get Camera Translation
|
| 606 |
+
if hasattr(self, "head_camera_hand"):
|
| 607 |
+
pred_cam = self.head_camera_hand(pose_token, prev_camera)
|
| 608 |
+
pose_output["pred_cam"] = pred_cam
|
| 609 |
+
# Run camera projection
|
| 610 |
+
pose_output = self.camera_project_hand(pose_output, batch)
|
| 611 |
+
|
| 612 |
+
# Get 2D KPS in crop
|
| 613 |
+
pose_output["pred_keypoints_2d_cropped"] = self._full_to_crop(
|
| 614 |
+
batch, pose_output["pred_keypoints_2d"], self.hand_batch_idx
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
return pose_output
|
| 618 |
+
|
| 619 |
+
kp_token_update_fn = self.keypoint_token_update_fn_hand
|
| 620 |
+
|
| 621 |
+
# Now for 3D
|
| 622 |
+
kp3d_token_update_fn = self.keypoint3d_token_update_fn_hand
|
| 623 |
+
|
| 624 |
+
# Combine the 2D and 3D keypoint token update functions
|
| 625 |
+
def keypoint_token_update_fn_comb(*args):
|
| 626 |
+
if kp_token_update_fn is not None:
|
| 627 |
+
args = kp_token_update_fn(kps_emb_start_idx, image_embeddings, *args)
|
| 628 |
+
if kp3d_token_update_fn is not None:
|
| 629 |
+
args = kp3d_token_update_fn(kps3d_emb_start_idx, *args)
|
| 630 |
+
return args
|
| 631 |
+
|
| 632 |
+
pose_token, pose_output = self.decoder_hand(
|
| 633 |
+
token_embeddings,
|
| 634 |
+
image_embeddings,
|
| 635 |
+
token_augment,
|
| 636 |
+
image_augment,
|
| 637 |
+
token_mask,
|
| 638 |
+
token_to_pose_output_fn=token_to_pose_output_fn,
|
| 639 |
+
keypoint_token_update_fn=keypoint_token_update_fn_comb,
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
if self.cfg.MODEL.DECODER.get("DO_HAND_DETECT_TOKENS", False):
|
| 643 |
+
return (
|
| 644 |
+
pose_token[:, hand_det_emb_start_idx : hand_det_emb_start_idx + 2],
|
| 645 |
+
pose_output,
|
| 646 |
+
)
|
| 647 |
+
else:
|
| 648 |
+
return pose_token, pose_output
|
| 649 |
+
|
| 650 |
+
@torch.no_grad()
|
| 651 |
+
def _get_keypoint_prompt(self, batch, pred_keypoints_2d, force_dummy=False):
|
| 652 |
+
if self.camera_type == "perspective":
|
| 653 |
+
pred_keypoints_2d = self._full_to_crop(batch, pred_keypoints_2d)
|
| 654 |
+
|
| 655 |
+
gt_keypoints_2d = self._flatten_person(batch["keypoints_2d"]).clone()
|
| 656 |
+
|
| 657 |
+
keypoint_prompt = self.keypoint_prompt_sampler.sample(
|
| 658 |
+
gt_keypoints_2d,
|
| 659 |
+
pred_keypoints_2d,
|
| 660 |
+
is_train=self.training,
|
| 661 |
+
force_dummy=force_dummy,
|
| 662 |
+
)
|
| 663 |
+
return keypoint_prompt
|
| 664 |
+
|
| 665 |
+
def _get_mask_prompt(self, batch, image_embeddings):
|
| 666 |
+
x_mask = self._flatten_person(batch["mask"])
|
| 667 |
+
mask_embeddings, no_mask_embeddings = self.prompt_encoder.get_mask_embeddings(
|
| 668 |
+
x_mask, image_embeddings.shape[0], image_embeddings.shape[2:]
|
| 669 |
+
)
|
| 670 |
+
if self.cfg.MODEL.BACKBONE.TYPE in [
|
| 671 |
+
"vit_hmr",
|
| 672 |
+
"vit",
|
| 673 |
+
]:
|
| 674 |
+
# The ViT backbone assumes a different aspect ratio for its input size
|
| 675 |
+
mask_embeddings = mask_embeddings[:, :, :, 2:-2]
|
| 676 |
+
elif self.cfg.MODEL.BACKBONE.TYPE in [
|
| 677 |
+
"vit_hmr_512_384",
|
| 678 |
+
]:
|
| 679 |
+
# for x2 resolution
|
| 680 |
+
mask_embeddings = mask_embeddings[:, :, :, 4:-4]
|
| 681 |
+
|
| 682 |
+
mask_score = self._flatten_person(batch["mask_score"]).view(-1, 1, 1, 1)
|
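# Gate the dense prompt on mask quality: where mask_score > 0 the mask embedding
# (weighted by its score) is used, otherwise a learned "no mask" embedding is
# substituted; the caller adds the result onto the image features.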
| 683 |
+
mask_embeddings = torch.where(
|
| 684 |
+
mask_score > 0,
|
| 685 |
+
mask_score * mask_embeddings.to(image_embeddings),
|
| 686 |
+
no_mask_embeddings.to(image_embeddings),
|
| 687 |
+
)
|
| 688 |
+
return mask_embeddings
|
| 689 |
+
|
| 690 |
+
def _one_prompt_iter(self, batch, output, prev_prompt, full_output):
|
| 691 |
+
image_embeddings = output["image_embeddings"]
|
| 692 |
+
condition_info = output["condition_info"]
|
| 693 |
+
|
| 694 |
+
if "mhr" in output and output["mhr"] is not None:
|
| 695 |
+
pose_output = output["mhr"] # body-only output
|
| 696 |
+
# Use previous estimate as initialization
|
| 697 |
+
prev_estimate = torch.cat(
|
| 698 |
+
[
|
| 699 |
+
pose_output["pred_pose_raw"].detach(), # (B, 6)
|
| 700 |
+
pose_output["shape"].detach(),
|
| 701 |
+
pose_output["scale"].detach(),
|
| 702 |
+
pose_output["hand"].detach(),
|
| 703 |
+
pose_output["face"].detach(),
|
| 704 |
+
],
|
| 705 |
+
dim=1,
|
| 706 |
+
).unsqueeze(dim=1)
|
| 707 |
+
if hasattr(self, "init_camera"):
|
| 708 |
+
prev_estimate = torch.cat(
|
| 709 |
+
[prev_estimate, pose_output["pred_cam"].detach().unsqueeze(1)],
|
| 710 |
+
dim=-1,
|
| 711 |
+
)
|
| 712 |
+
prev_shape = prev_estimate.shape[1:]
|
| 713 |
+
|
| 714 |
+
pred_keypoints_2d = output["mhr"]["pred_keypoints_2d"].detach().clone()
|
| 715 |
+
kpt_shape = pred_keypoints_2d.shape[1:]
|
| 716 |
+
|
| 717 |
+
if "mhr_hand" in output and output["mhr_hand"] is not None:
|
| 718 |
+
pose_output_hand = output["mhr_hand"]
|
| 719 |
+
# Use previous estimate as initialization
|
| 720 |
+
prev_estimate_hand = torch.cat(
|
| 721 |
+
[
|
| 722 |
+
pose_output_hand["pred_pose_raw"].detach(), # (B, 6)
|
| 723 |
+
pose_output_hand["shape"].detach(),
|
| 724 |
+
pose_output_hand["scale"].detach(),
|
| 725 |
+
pose_output_hand["hand"].detach(),
|
| 726 |
+
pose_output_hand["face"].detach(),
|
| 727 |
+
],
|
| 728 |
+
dim=1,
|
| 729 |
+
).unsqueeze(dim=1)
|
| 730 |
+
if hasattr(self, "init_camera_hand"):
|
| 731 |
+
prev_estimate_hand = torch.cat(
|
| 732 |
+
[
|
| 733 |
+
prev_estimate_hand,
|
| 734 |
+
pose_output_hand["pred_cam"].detach().unsqueeze(1),
|
| 735 |
+
],
|
| 736 |
+
dim=-1,
|
| 737 |
+
)
|
| 738 |
+
prev_shape = prev_estimate_hand.shape[1:]
|
| 739 |
+
|
| 740 |
+
pred_keypoints_2d_hand = output["mhr_hand"]["pred_keypoints_2d"].detach().clone()
|
| 741 |
+
kpt_shape = pred_keypoints_2d_hand.shape[1:]
|
| 742 |
+
|
| 743 |
+
all_prev_estimate = torch.zeros((image_embeddings.shape[0], *prev_shape), device=image_embeddings.device)
|
| 744 |
+
if "mhr" in output and output["mhr"] is not None:
|
| 745 |
+
all_prev_estimate[self.body_batch_idx] = prev_estimate
|
| 746 |
+
if "mhr_hand" in output and output["mhr_hand"] is not None:
|
| 747 |
+
all_prev_estimate[self.hand_batch_idx] = prev_estimate_hand
|
| 748 |
+
|
| 749 |
+
# Get keypoint prompts
|
| 750 |
+
all_pred_keypoints_2d = torch.zeros((image_embeddings.shape[0], *kpt_shape), device=image_embeddings.device)
|
| 751 |
+
if "mhr" in output and output["mhr"] is not None:
|
| 752 |
+
all_pred_keypoints_2d[self.body_batch_idx] = pred_keypoints_2d
|
| 753 |
+
if "mhr_hand" in output and output["mhr_hand"] is not None:
|
| 754 |
+
all_pred_keypoints_2d[self.hand_batch_idx] = pred_keypoints_2d_hand
|
| 755 |
+
|
| 756 |
+
keypoint_prompt = self._get_keypoint_prompt(batch, all_pred_keypoints_2d)
|
| 757 |
+
cur_keypoint_prompt = (
|
| 758 |
+
torch.cat(prev_prompt + [keypoint_prompt], dim=1) if len(prev_prompt) else keypoint_prompt
|
| 759 |
+
) # [B, 1, 3]
|
| 760 |
+
|
| 761 |
+
pose_output, pose_output_hand = None, None
|
| 762 |
+
if len(self.body_batch_idx):
|
| 763 |
+
tokens_output, pose_output = self.forward_decoder(
|
| 764 |
+
image_embeddings[self.body_batch_idx],
|
| 765 |
+
init_estimate=None, # not recurring previous estimate
|
| 766 |
+
keypoints=cur_keypoint_prompt[self.body_batch_idx],
|
| 767 |
+
prev_estimate=all_prev_estimate[self.body_batch_idx],
|
| 768 |
+
condition_info=condition_info[self.body_batch_idx],
|
| 769 |
+
batch=batch,
|
| 770 |
+
|
| 771 |
+
)
|
| 772 |
+
pose_output = pose_output[-1]
|
| 773 |
+
|
| 774 |
+
# Update prediction output
|
| 775 |
+
output.update(
|
| 776 |
+
{
|
| 777 |
+
"mhr": pose_output,
|
| 778 |
+
"mhr_hand": pose_output_hand,
|
| 779 |
+
}
|
| 780 |
+
)
|
| 781 |
+
|
| 782 |
+
return output, keypoint_prompt
|
| 783 |
+
|
| 784 |
+
def _full_to_crop(
|
| 785 |
+
self,
|
| 786 |
+
batch: dict,
|
| 787 |
+
pred_keypoints_2d: torch.Tensor,
|
| 788 |
+
batch_idx: torch.Tensor | Sequence[int] | None = None,
|
| 789 |
+
) -> torch.Tensor:
|
| 790 |
+
"""Convert full-image keypoints coordinates to crop and normalize to [-0.5. 0.5]"""
|
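# Illustrative behaviour: a full-image point at the crop centre maps to (0, 0),
# the crop's top-left corner to (-0.5, -0.5) and its bottom-right corner to
# (0.5, 0.5); points outside the crop leave the [-0.5, 0.5] range and are later
# treated as invalid prompts.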
| 791 |
+
pred_keypoints_2d_cropped = torch.cat(
|
| 792 |
+
[pred_keypoints_2d, torch.ones_like(pred_keypoints_2d[:, :, [-1]])], dim=-1
|
| 793 |
+
)
|
| 794 |
+
if batch_idx is not None:
|
| 795 |
+
affine_trans = self._flatten_person(batch["affine_trans"])[batch_idx].to(pred_keypoints_2d_cropped)
|
| 796 |
+
img_size = self._flatten_person(batch["img_size"])[batch_idx].unsqueeze(1)
|
| 797 |
+
else:
|
| 798 |
+
affine_trans = self._flatten_person(batch["affine_trans"]).to(pred_keypoints_2d_cropped)
|
| 799 |
+
img_size = self._flatten_person(batch["img_size"]).unsqueeze(1)
|
| 800 |
+
pred_keypoints_2d_cropped = pred_keypoints_2d_cropped @ affine_trans.mT
|
| 801 |
+
pred_keypoints_2d_cropped = pred_keypoints_2d_cropped[..., :2] / img_size - 0.5
|
| 802 |
+
|
| 803 |
+
return pred_keypoints_2d_cropped
|
| 804 |
+
|
| 805 |
+
def camera_project(self, pose_output: dict, batch: dict) -> dict:
|
| 806 |
+
"""
|
| 807 |
+
Project 3D keypoints to 2D using the camera parameters.
|
| 808 |
+
Args:
|
| 809 |
+
pose_output (Dict): Dictionary containing the pose output.
|
| 810 |
+
batch (Dict): Dictionary containing the batch data.
|
| 811 |
+
Returns:
|
| 812 |
+
Dict: Dictionary containing the projected 2D keypoints.
|
| 813 |
+
"""
|
| 814 |
+
if hasattr(self, "head_camera"):
|
| 815 |
+
head_camera = self.head_camera
|
| 816 |
+
pred_cam = pose_output["pred_cam"]
|
| 817 |
+
else:
|
| 818 |
+
raise AssertionError("head_camera is not defined")
|
| 819 |
+
|
| 820 |
+
cam_out = head_camera.perspective_projection(
|
| 821 |
+
pose_output["pred_keypoints_3d"],
|
| 822 |
+
pred_cam,
|
| 823 |
+
self._flatten_person(batch["bbox_center"])[self.body_batch_idx],
|
| 824 |
+
self._flatten_person(batch["bbox_scale"])[self.body_batch_idx, 0],
|
| 825 |
+
self._flatten_person(batch["ori_img_size"])[self.body_batch_idx],
|
| 826 |
+
self._flatten_person(batch["cam_int"].unsqueeze(1).expand(-1, batch["img"].shape[1], -1, -1).contiguous())[
|
| 827 |
+
self.body_batch_idx
|
| 828 |
+
],
|
| 829 |
+
use_intrin_center=self.cfg.MODEL.DECODER.get("USE_INTRIN_CENTER", False),
|
| 830 |
+
)
|
| 831 |
+
|
| 832 |
+
if pose_output.get("pred_vertices") is not None:
|
| 833 |
+
cam_out_vertices = head_camera.perspective_projection(
|
| 834 |
+
pose_output["pred_vertices"],
|
| 835 |
+
pred_cam,
|
| 836 |
+
self._flatten_person(batch["bbox_center"])[self.body_batch_idx],
|
| 837 |
+
self._flatten_person(batch["bbox_scale"])[self.body_batch_idx, 0],
|
| 838 |
+
self._flatten_person(batch["ori_img_size"])[self.body_batch_idx],
|
| 839 |
+
self._flatten_person(
|
| 840 |
+
batch["cam_int"].unsqueeze(1).expand(-1, batch["img"].shape[1], -1, -1).contiguous()
|
| 841 |
+
)[self.body_batch_idx],
|
| 842 |
+
use_intrin_center=self.cfg.MODEL.DECODER.get("USE_INTRIN_CENTER", False),
|
| 843 |
+
)
|
| 844 |
+
pose_output["pred_keypoints_2d_verts"] = cam_out_vertices["pred_keypoints_2d"]
|
| 845 |
+
|
| 846 |
+
pose_output.update(cam_out)
|
| 847 |
+
|
| 848 |
+
return pose_output
|
| 849 |
+
|
| 850 |
+
def camera_project_hand(self, pose_output: dict, batch: dict) -> dict:
|
| 851 |
+
"""
|
| 852 |
+
Project 3D keypoints to 2D using the camera parameters.
|
| 853 |
+
Args:
|
| 854 |
+
pose_output (Dict): Dictionary containing the pose output.
|
| 855 |
+
batch (Dict): Dictionary containing the batch data.
|
| 856 |
+
Returns:
|
| 857 |
+
Dict: Dictionary containing the projected 2D keypoints.
|
| 858 |
+
"""
|
| 859 |
+
if hasattr(self, "head_camera_hand"):
|
| 860 |
+
head_camera = self.head_camera_hand
|
| 861 |
+
pred_cam = pose_output["pred_cam"]
|
| 862 |
+
else:
|
| 863 |
+
raise AssertionError("head_camera_hand is not defined")
|
| 864 |
+
|
| 865 |
+
cam_out = head_camera.perspective_projection(
|
| 866 |
+
pose_output["pred_keypoints_3d"],
|
| 867 |
+
pred_cam,
|
| 868 |
+
self._flatten_person(batch["bbox_center"])[self.hand_batch_idx],
|
| 869 |
+
self._flatten_person(batch["bbox_scale"])[self.hand_batch_idx, 0],
|
| 870 |
+
self._flatten_person(batch["ori_img_size"])[self.hand_batch_idx],
|
| 871 |
+
self._flatten_person(batch["cam_int"].unsqueeze(1).expand(-1, batch["img"].shape[1], -1, -1).contiguous())[
|
| 872 |
+
self.hand_batch_idx
|
| 873 |
+
],
|
| 874 |
+
use_intrin_center=self.cfg.MODEL.DECODER.get("USE_INTRIN_CENTER", False),
|
| 875 |
+
)
|
| 876 |
+
|
| 877 |
+
if pose_output.get("pred_vertices") is not None:
|
| 878 |
+
cam_out_vertices = head_camera.perspective_projection(
|
| 879 |
+
pose_output["pred_vertices"],
|
| 880 |
+
pred_cam,
|
| 881 |
+
self._flatten_person(batch["bbox_center"])[self.hand_batch_idx],
|
| 882 |
+
self._flatten_person(batch["bbox_scale"])[self.hand_batch_idx, 0],
|
| 883 |
+
self._flatten_person(batch["ori_img_size"])[self.hand_batch_idx],
|
| 884 |
+
self._flatten_person(
|
| 885 |
+
batch["cam_int"].unsqueeze(1).expand(-1, batch["img"].shape[1], -1, -1).contiguous()
|
| 886 |
+
)[self.hand_batch_idx],
|
| 887 |
+
use_intrin_center=self.cfg.MODEL.DECODER.get("USE_INTRIN_CENTER", False),
|
| 888 |
+
)
|
| 889 |
+
pose_output["pred_keypoints_2d_verts"] = cam_out_vertices["pred_keypoints_2d"]
|
| 890 |
+
|
| 891 |
+
pose_output.update(cam_out)
|
| 892 |
+
|
| 893 |
+
return pose_output
|
| 894 |
+
|
| 895 |
+
def get_ray_condition(self, batch):
|
| 896 |
+
B, N, _, H, W = batch["img"].shape
|
| 897 |
+
meshgrid_xy = (
|
| 898 |
+
torch.stack(torch.meshgrid(torch.arange(H), torch.arange(W), indexing="xy"), dim=2)[None, None, :, :, :]
|
| 899 |
+
.repeat(B, N, 1, 1, 1)
|
| 900 |
+
.cuda()
|
| 901 |
+
) # B x N x H x W x 2
|
| 902 |
+
meshgrid_xy = meshgrid_xy / batch["affine_trans"][:, :, None, None, [0, 1], [0, 1]]
|
| 903 |
+
meshgrid_xy = (
|
| 904 |
+
meshgrid_xy
|
| 905 |
+
- batch["affine_trans"][:, :, None, None, [0, 1], [2, 2]]
|
| 906 |
+
/ batch["affine_trans"][:, :, None, None, [0, 1], [0, 1]]
|
| 907 |
+
)
|
| 908 |
+
|
| 909 |
+
# Subtract out center & normalize to be rays
|
| 910 |
+
meshgrid_xy = meshgrid_xy - batch["cam_int"][:, None, None, None, [0, 1], [2, 2]]
|
| 911 |
+
meshgrid_xy = meshgrid_xy / batch["cam_int"][:, None, None, None, [0, 1], [0, 1]]
|
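# At this point each pixel of the crop stores ((x - cx) / fx, (y - cy) / fy) in
# full-image coordinates, i.e. the x/y components of its normalized camera ray
# (the ray direction at depth z = 1).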
| 912 |
+
|
| 913 |
+
return meshgrid_xy.permute(0, 1, 4, 2, 3).to(batch["img"].dtype) # This is B x num_person x 2 x H x W
|
| 914 |
+
|
| 915 |
+
def forward_pose_branch(self, batch: dict) -> dict:
|
| 916 |
+
"""Run a forward pass for the crop-image (pose) branch."""
|
| 917 |
+
batch_size, num_person = batch["img"].shape[:2]
|
| 918 |
+
|
| 919 |
+
# Forward backbone encoder
|
| 920 |
+
x = self.data_preprocess(
|
| 921 |
+
self._flatten_person(batch["img"]),
|
| 922 |
+
crop_width=(
|
| 923 |
+
self.cfg.MODEL.BACKBONE.TYPE
|
| 924 |
+
in [
|
| 925 |
+
"vit_hmr",
|
| 926 |
+
"vit",
|
| 927 |
+
"vit_b",
|
| 928 |
+
"vit_l",
|
| 929 |
+
"vit_hmr_512_384",
|
| 930 |
+
]
|
| 931 |
+
),
|
| 932 |
+
)
|
| 933 |
+
|
| 934 |
+
# Optionally get ray conditioning
|
| 935 |
+
ray_cond = self.get_ray_condition(batch) # This is B x num_person x 2 x H x W
|
| 936 |
+
ray_cond = self._flatten_person(ray_cond)
|
| 937 |
+
if self.cfg.MODEL.BACKBONE.TYPE in [
|
| 938 |
+
"vit_hmr",
|
| 939 |
+
"vit",
|
| 940 |
+
"vit_b",
|
| 941 |
+
"vit_l",
|
| 942 |
+
]:
|
| 943 |
+
ray_cond = ray_cond[:, :, :, 32:-32]
|
| 944 |
+
elif self.cfg.MODEL.BACKBONE.TYPE in [
|
| 945 |
+
"vit_hmr_512_384",
|
| 946 |
+
]:
|
| 947 |
+
ray_cond = ray_cond[:, :, :, 64:-64]
|
| 948 |
+
|
| 949 |
+
if len(self.body_batch_idx):
|
| 950 |
+
batch["ray_cond"] = ray_cond[self.body_batch_idx].clone()
|
| 951 |
+
if len(self.hand_batch_idx):
|
| 952 |
+
batch["ray_cond_hand"] = ray_cond[self.hand_batch_idx].clone()
|
| 953 |
+
ray_cond = None
|
| 954 |
+
|
| 955 |
+
image_embeddings = self.backbone(x.type(self.backbone_dtype), extra_embed=ray_cond) # (B, C, H, W)
|
| 956 |
+
|
| 957 |
+
if isinstance(image_embeddings, tuple):
|
| 958 |
+
image_embeddings = image_embeddings[-1]
|
| 959 |
+
image_embeddings = image_embeddings.type(x.dtype)
|
| 960 |
+
|
| 961 |
+
# Mask condition if available
|
| 962 |
+
if self.cfg.MODEL.PROMPT_ENCODER.get("MASK_EMBED_TYPE", None) is not None:
|
| 963 |
+
# v1: non-iterative mask conditioning
|
| 964 |
+
if self.cfg.MODEL.PROMPT_ENCODER.get("MASK_PROMPT", "v1") == "v1":
|
| 965 |
+
mask_embeddings = self._get_mask_prompt(batch, image_embeddings)
|
| 966 |
+
image_embeddings = image_embeddings + mask_embeddings
|
| 967 |
+
else:
|
| 968 |
+
raise NotImplementedError
|
| 969 |
+
|
| 970 |
+
# Prepare input for promptable decoder
|
| 971 |
+
condition_info = self._get_decoder_condition(batch)
|
| 972 |
+
|
| 973 |
+
# Initial estimate with a dummy prompt
|
| 974 |
+
keypoints_prompt = torch.zeros((batch_size * num_person, 1, 3)).to(batch["img"])
|
| 975 |
+
keypoints_prompt[:, :, -1] = -2
|
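# A single dummy point with label -2 ("invalid") makes this first decoder pass
# effectively unprompted; real keypoint prompts are only injected in the later
# prompting iterations (see _one_prompt_iter and run_inference).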
| 976 |
+
|
| 977 |
+
# Forward promptable decoder to get updated pose tokens and regression output
|
| 978 |
+
pose_output, pose_output_hand = None, None
|
| 979 |
+
if len(self.body_batch_idx):
|
| 980 |
+
tokens_output, pose_output = self.forward_decoder(
|
| 981 |
+
image_embeddings[self.body_batch_idx],
|
| 982 |
+
init_estimate=None,
|
| 983 |
+
keypoints=keypoints_prompt[self.body_batch_idx],
|
| 984 |
+
prev_estimate=None,
|
| 985 |
+
condition_info=condition_info[self.body_batch_idx],
|
| 986 |
+
batch=batch,
|
| 987 |
+
)
|
| 988 |
+
pose_output = pose_output[-1]
|
| 989 |
+
if len(self.hand_batch_idx):
|
| 990 |
+
tokens_output_hand, pose_output_hand = self.forward_decoder_hand(
|
| 991 |
+
image_embeddings[self.hand_batch_idx],
|
| 992 |
+
init_estimate=None,
|
| 993 |
+
keypoints=keypoints_prompt[self.hand_batch_idx],
|
| 994 |
+
prev_estimate=None,
|
| 995 |
+
condition_info=condition_info[self.hand_batch_idx],
|
| 996 |
+
batch=batch,
|
| 997 |
+
)
|
| 998 |
+
pose_output_hand = pose_output_hand[-1]
|
| 999 |
+
|
| 1000 |
+
output = {
|
| 1001 |
+
# "pose_token": pose_token,
|
| 1002 |
+
"mhr": pose_output, # mhr prediction output
|
| 1003 |
+
"mhr_hand": pose_output_hand, # mhr prediction output
|
| 1004 |
+
"condition_info": condition_info,
|
| 1005 |
+
"image_embeddings": image_embeddings,
|
| 1006 |
+
}
|
| 1007 |
+
|
| 1008 |
+
if self.cfg.MODEL.DECODER.get("DO_HAND_DETECT_TOKENS", False):
|
| 1009 |
+
if len(self.body_batch_idx):
|
| 1010 |
+
output_hand_box_tokens = tokens_output
|
| 1011 |
+
hand_coords = self.bbox_embed(output_hand_box_tokens).sigmoid() # x1, y1, w, h for body samples, 0 ~ 1
|
| 1012 |
+
hand_logits = self.hand_cls_embed(output_hand_box_tokens)
|
| 1013 |
+
|
| 1014 |
+
output["mhr"]["hand_box"] = hand_coords
|
| 1015 |
+
output["mhr"]["hand_logits"] = hand_logits
|
| 1016 |
+
|
| 1017 |
+
if len(self.hand_batch_idx):
|
| 1018 |
+
output_hand_box_tokens_hand_batch = tokens_output_hand
|
| 1019 |
+
|
| 1020 |
+
hand_coords_hand_batch = self.bbox_embed(
|
| 1021 |
+
output_hand_box_tokens_hand_batch
|
| 1022 |
+
).sigmoid() # x1, y1, w, h for hand samples
|
| 1023 |
+
hand_logits_hand_batch = self.hand_cls_embed(output_hand_box_tokens_hand_batch)
|
| 1024 |
+
|
| 1025 |
+
output["mhr_hand"]["hand_box"] = hand_coords_hand_batch
|
| 1026 |
+
output["mhr_hand"]["hand_logits"] = hand_logits_hand_batch
|
| 1027 |
+
|
| 1028 |
+
return output
|
| 1029 |
+
|
| 1030 |
+
def forward_step(self, batch: dict, decoder_type: str = "body") -> dict:
|
| 1031 |
+
batch_size, num_person = batch["img"].shape[:2]
|
| 1032 |
+
|
| 1033 |
+
if decoder_type == "body":
|
| 1034 |
+
self.hand_batch_idx = []
|
| 1035 |
+
self.body_batch_idx = list(range(batch_size * num_person))
|
| 1036 |
+
elif decoder_type == "hand":
|
| 1037 |
+
self.hand_batch_idx = list(range(batch_size * num_person))
|
| 1038 |
+
self.body_batch_idx = []
|
| 1039 |
+
else:
|
| 1040 |
+
ValueError("Invalid decoder type: ", decoder_type)
|
| 1041 |
+
|
| 1042 |
+
# Crop-image (pose) branch
|
| 1043 |
+
pose_output = self.forward_pose_branch(batch)
|
| 1044 |
+
|
| 1045 |
+
return pose_output
|
| 1046 |
+
|
| 1047 |
+
def run_inference(
|
| 1048 |
+
self,
|
| 1049 |
+
img,
|
| 1050 |
+
batch: dict,
|
| 1051 |
+
inference_type: str = "full",
|
| 1052 |
+
transform_hand: Any = None,
|
| 1053 |
+
thresh_wrist_angle=1.4,
|
| 1054 |
+
):
|
| 1055 |
+
"""
|
| 1056 |
+
Run SAM 3D Body (3DB) inference, optionally refining hands with the dedicated hand decoder.
|
| 1057 |
+
|
| 1058 |
+
inference_type:
|
| 1059 |
+
- full: full-body inference with both body and hand decoders
|
| 1060 |
+
- body: inference with body decoder only (still full-body output)
|
| 1061 |
+
- hand: inference with hand decoder only (only hand output)
|
| 1062 |
+
"""
|
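# Rough usage sketch (prepare_batch / recursive_to are the helpers used below;
# the exact calling code in the demo may differ):
#   batch = prepare_batch(img, transform, person_boxes_xyxy, cam_int=cam_int)
#   batch = recursive_to(batch, "cuda")
#   pred = model.run_inference(img, batch, inference_type="full",
#                              transform_hand=transform_hand)
#   mhr = pred.pose_output["mhr"]   # full-body MHR parameters and keypoints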
| 1063 |
+
|
| 1064 |
+
height, width = img.shape[:2]
|
| 1065 |
+
cam_int = batch["cam_int"].clone()
|
| 1066 |
+
|
| 1067 |
+
if inference_type == "body":
|
| 1068 |
+
pose_output = self.forward_step(batch, decoder_type="body")
|
| 1069 |
+
return BodyPredContainer(pose_output=pose_output)
|
| 1070 |
+
elif inference_type == "hand":
|
| 1071 |
+
pose_output = self.forward_step(batch, decoder_type="hand")
|
| 1072 |
+
return BodyPredContainer(pose_output=pose_output)
|
| 1073 |
+
elif inference_type != "full":
|
| 1074 |
+
raise ValueError("Invalid inference type: ", inference_type)
|
| 1075 |
+
|
| 1076 |
+
# Step 1. For full-body inference, we first run inference with the body decoder.
|
| 1077 |
+
pose_output = self.forward_step(batch, decoder_type="body")
|
| 1078 |
+
left_xyxy, right_xyxy = self._get_hand_box(pose_output, batch)
|
| 1079 |
+
ori_local_wrist_rotmat = roma.euler_to_rotmat(
|
| 1080 |
+
"XZY",
|
| 1081 |
+
pose_output["mhr"]["body_pose"][:, [41, 43, 42, 31, 33, 32]].unflatten(1, (2, 3)),
|
| 1082 |
+
)
|
| 1083 |
+
|
| 1084 |
+
# Step 2. Re-run with each hand
|
| 1085 |
+
## Left... Flip image & box
|
| 1086 |
+
flipped_img = img[:, ::-1]
|
| 1087 |
+
tmp = left_xyxy.copy()
|
| 1088 |
+
left_xyxy[:, 0] = width - tmp[:, 2] - 1
|
| 1089 |
+
left_xyxy[:, 2] = width - tmp[:, 0] - 1
|
| 1090 |
+
|
| 1091 |
+
batch_lhand = prepare_batch(flipped_img, transform_hand, left_xyxy, cam_int=cam_int.clone())
|
| 1092 |
+
batch_lhand = recursive_to(batch_lhand, "cuda")
|
| 1093 |
+
lhand_output = self.forward_step(batch_lhand, decoder_type="hand")
|
| 1094 |
+
|
| 1095 |
+
# Unflip output
|
| 1096 |
+
## Flip scale
|
| 1097 |
+
### Get MHR values
|
| 1098 |
+
scale_r_hands_mean = self.head_pose.scale_mean[8].item()
|
| 1099 |
+
scale_l_hands_mean = self.head_pose.scale_mean[9].item()
|
| 1100 |
+
scale_r_hands_std = self.head_pose.scale_comps[8, 8].item()
|
| 1101 |
+
scale_l_hands_std = self.head_pose.scale_comps[9, 9].item()
|
| 1102 |
+
### Apply
|
| 1103 |
+
lhand_output["mhr_hand"]["scale"][:, 9] = (
|
| 1104 |
+
(scale_r_hands_mean + scale_r_hands_std * lhand_output["mhr_hand"]["scale"][:, 8]) - scale_l_hands_mean
|
| 1105 |
+
) / scale_l_hands_std
|
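### The flipped left-hand crop was processed as a right hand, so its right-hand
### scale coefficient is de-normalized with the right-hand mean/std and
### re-normalized with the left-hand mean/std to obtain the left-hand entry.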
| 1106 |
+
## Get the right hand global rotation, flip it, put it in as left.
|
| 1107 |
+
lhand_output["mhr_hand"]["joint_global_rots"][:, 78] = lhand_output["mhr_hand"]["joint_global_rots"][
|
| 1108 |
+
:, 42
|
| 1109 |
+
].clone()
|
| 1110 |
+
lhand_output["mhr_hand"]["joint_global_rots"][:, 78, [1, 2], :] *= -1
|
| 1111 |
+
### Flip hand pose
|
| 1112 |
+
lhand_output["mhr_hand"]["hand"][:, :54] = lhand_output["mhr_hand"]["hand"][:, 54:]
|
| 1113 |
+
### Unflip box
|
| 1114 |
+
batch_lhand["bbox_center"][:, :, 0] = width - batch_lhand["bbox_center"][:, :, 0] - 1
|
| 1115 |
+
|
| 1116 |
+
## Right...
|
| 1117 |
+
batch_rhand = prepare_batch(img, transform_hand, right_xyxy, cam_int=cam_int.clone())
|
| 1118 |
+
batch_rhand = recursive_to(batch_rhand, "cuda")
|
| 1119 |
+
rhand_output = self.forward_step(batch_rhand, decoder_type="hand")
|
| 1120 |
+
|
| 1121 |
+
# Step 3. Replace the body decoder's hand pose estimates with the hand decoder's, where the checks below pass.
|
| 1122 |
+
## CRITERIA 1: LOCAL WRIST POSE DIFFERENCE
|
| 1123 |
+
joint_rotations = pose_output["mhr"]["joint_global_rots"]
|
| 1124 |
+
### Get lowarm
|
| 1125 |
+
lowarm_joint_idxs = torch.LongTensor([76, 40]).cuda() # left, right
|
| 1126 |
+
lowarm_joint_rotations = joint_rotations[:, lowarm_joint_idxs] # B x 2 x 3 x 3
|
| 1127 |
+
### Get zero-wrist pose
|
| 1128 |
+
wrist_twist_joint_idxs = torch.LongTensor([77, 41]).cuda() # left, right
|
| 1129 |
+
wrist_zero_rot_pose = lowarm_joint_rotations @ self.head_pose.joint_rotation[wrist_twist_joint_idxs]
|
| 1130 |
+
### Get globals from left & right
|
| 1131 |
+
left_joint_global_rots = lhand_output["mhr_hand"]["joint_global_rots"]
|
| 1132 |
+
right_joint_global_rots = rhand_output["mhr_hand"]["joint_global_rots"]
|
| 1133 |
+
pred_global_wrist_rotmat = torch.stack(
|
| 1134 |
+
[
|
| 1135 |
+
left_joint_global_rots[:, 78],
|
| 1136 |
+
right_joint_global_rots[:, 42],
|
| 1137 |
+
],
|
| 1138 |
+
dim=1,
|
| 1139 |
+
)
|
| 1140 |
+
### Get the local poses that lead to the wrist being pred_global_wrist_rotmat
|
| 1141 |
+
fused_local_wrist_rotmat = torch.einsum("kabc,kabd->kadc", pred_global_wrist_rotmat, wrist_zero_rot_pose)
|
| 1142 |
+
angle_difference = rotation_angle_difference(ori_local_wrist_rotmat, fused_local_wrist_rotmat)  # B x 2 (one angle per hand)
|
| 1143 |
+
angle_difference_valid_mask = angle_difference < thresh_wrist_angle
|
| 1144 |
+
|
| 1145 |
+
## CRITERIA 2: hand box size
|
| 1146 |
+
hand_box_size_thresh = 64
|
| 1147 |
+
hand_box_size_valid_mask = torch.stack(
|
| 1148 |
+
[
|
| 1149 |
+
(batch_lhand["bbox_scale"].flatten(0, 1) > hand_box_size_thresh).all(dim=1),
|
| 1150 |
+
(batch_rhand["bbox_scale"].flatten(0, 1) > hand_box_size_thresh).all(dim=1),
|
| 1151 |
+
],
|
| 1152 |
+
dim=1,
|
| 1153 |
+
)
|
| 1154 |
+
|
| 1155 |
+
## CRITERIA 3: all hand 2D KPS (including wrist) inside of box.
|
| 1156 |
+
hand_kps2d_thresh = 0.5
|
| 1157 |
+
hand_kps2d_valid_mask = torch.stack(
|
| 1158 |
+
[
|
| 1159 |
+
lhand_output["mhr_hand"]["pred_keypoints_2d_cropped"].abs().amax(dim=(1, 2)) < hand_kps2d_thresh,
|
| 1160 |
+
rhand_output["mhr_hand"]["pred_keypoints_2d_cropped"].abs().amax(dim=(1, 2)) < hand_kps2d_thresh,
|
| 1161 |
+
],
|
| 1162 |
+
dim=1,
|
| 1163 |
+
)
|
| 1164 |
+
|
| 1165 |
+
## CRITERIA 4: 2D wrist distance.
|
| 1166 |
+
hand_wrist_kps2d_thresh = 0.25
|
| 1167 |
+
kps_right_wrist_idx = 41
|
| 1168 |
+
kps_left_wrist_idx = 62
|
| 1169 |
+
right_kps_full = rhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone()
|
| 1170 |
+
left_kps_full = lhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone()
|
| 1171 |
+
left_kps_full[:, :, 0] = width - left_kps_full[:, :, 0] - 1 # Flip left hand
|
| 1172 |
+
body_right_kps_full = pose_output["mhr"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone()
|
| 1173 |
+
body_left_kps_full = pose_output["mhr"]["pred_keypoints_2d"][:, [kps_left_wrist_idx]].clone()
|
| 1174 |
+
right_kps_dist = (right_kps_full - body_right_kps_full).flatten(0, 1).norm(dim=-1) / batch_lhand[
|
| 1175 |
+
"bbox_scale"
|
| 1176 |
+
].flatten(0, 1)[:, 0]
|
| 1177 |
+
left_kps_dist = (left_kps_full - body_left_kps_full).flatten(0, 1).norm(dim=-1) / batch_rhand[
|
| 1178 |
+
"bbox_scale"
|
| 1179 |
+
].flatten(0, 1)[:, 0]
|
| 1180 |
+
hand_wrist_kps2d_valid_mask = torch.stack(
|
| 1181 |
+
[
|
| 1182 |
+
left_kps_dist < hand_wrist_kps2d_thresh,
|
| 1183 |
+
right_kps_dist < hand_wrist_kps2d_thresh,
|
| 1184 |
+
],
|
| 1185 |
+
dim=1,
|
| 1186 |
+
)
|
| 1187 |
+
## Left-right
|
| 1188 |
+
hand_valid_mask = (
|
| 1189 |
+
angle_difference_valid_mask & hand_box_size_valid_mask & hand_kps2d_valid_mask & hand_wrist_kps2d_valid_mask
|
| 1190 |
+
)
|
| 1191 |
+
|
| 1192 |
+
# Keypoint prompting with the body decoder.
|
| 1193 |
+
# We use the wrist location from the hand decoder and the elbow location
|
| 1194 |
+
# from the body decoder as prompts to get an updated body pose estimation.
|
| 1195 |
+
batch_size, num_person = batch["img"].shape[:2]
|
| 1196 |
+
self.hand_batch_idx = []
|
| 1197 |
+
self.body_batch_idx = list(range(batch_size * num_person))
|
| 1198 |
+
|
| 1199 |
+
## Get right & left wrist keypoints from crops; full image. Each is B x 1 x 2
|
| 1200 |
+
kps_right_wrist_idx = 41
|
| 1201 |
+
kps_left_wrist_idx = 62
|
| 1202 |
+
right_kps_full = rhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone()
|
| 1203 |
+
left_kps_full = lhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone()
|
| 1204 |
+
left_kps_full[:, :, 0] = width - left_kps_full[:, :, 0] - 1 # Flip left hand
|
| 1205 |
+
|
| 1206 |
+
# Next, get them to crop-normalized space.
|
| 1207 |
+
right_kps_crop = self._full_to_crop(batch, right_kps_full)
|
| 1208 |
+
left_kps_crop = self._full_to_crop(batch, left_kps_full)
|
| 1209 |
+
|
| 1210 |
+
# Get right & left elbow keypoints from crops; full image. Each is B x 1 x 2
|
| 1211 |
+
kps_right_elbow_idx = 8
|
| 1212 |
+
kps_left_elbow_idx = 7
|
| 1213 |
+
right_kps_elbow_full = pose_output["mhr"]["pred_keypoints_2d"][:, [kps_right_elbow_idx]].clone()
|
| 1214 |
+
left_kps_elbow_full = pose_output["mhr"]["pred_keypoints_2d"][:, [kps_left_elbow_idx]].clone()
|
| 1215 |
+
|
| 1216 |
+
# Next, get them to crop-normalized space.
|
| 1217 |
+
right_kps_elbow_crop = self._full_to_crop(batch, right_kps_elbow_full)
|
| 1218 |
+
left_kps_elbow_crop = self._full_to_crop(batch, left_kps_elbow_full)
|
| 1219 |
+
|
| 1220 |
+
# Assemble them into keypoint prompts
|
| 1221 |
+
keypoint_prompt = torch.cat(
|
| 1222 |
+
[right_kps_crop, left_kps_crop, right_kps_elbow_crop, left_kps_elbow_crop],
|
| 1223 |
+
dim=1,
|
| 1224 |
+
)
|
| 1225 |
+
keypoint_prompt = torch.cat([keypoint_prompt, keypoint_prompt[..., [-1]]], dim=-1)
|
| 1226 |
+
keypoint_prompt[:, 0, -1] = kps_right_wrist_idx
|
| 1227 |
+
keypoint_prompt[:, 1, -1] = kps_left_wrist_idx
|
| 1228 |
+
keypoint_prompt[:, 2, -1] = kps_right_elbow_idx
|
| 1229 |
+
keypoint_prompt[:, 3, -1] = kps_left_elbow_idx
|
| 1230 |
+
|
| 1231 |
+
if keypoint_prompt.shape[0] > 1:
|
| 1232 |
+
# Replace invalid keypoints to dummy prompts
|
| 1233 |
+
invalid_prompt = (
|
| 1234 |
+
(keypoint_prompt[..., 0] < -0.5)
|
| 1235 |
+
| (keypoint_prompt[..., 0] > 0.5)
|
| 1236 |
+
| (keypoint_prompt[..., 1] < -0.5)
|
| 1237 |
+
| (keypoint_prompt[..., 1] > 0.5)
|
| 1238 |
+
| (~hand_valid_mask[..., [1, 0, 1, 0]])
|
| 1239 |
+
).unsqueeze(-1)
|
| 1240 |
+
dummy_prompt = torch.zeros((1, 1, 3)).to(keypoint_prompt)
|
| 1241 |
+
dummy_prompt[:, :, -1] = -2
|
| 1242 |
+
keypoint_prompt[:, :, :2] = torch.clamp(
|
| 1243 |
+
keypoint_prompt[:, :, :2] + 0.5, min=0.0, max=1.0
|
| 1244 |
+
) # [-0.5, 0.5] --> [0, 1]
|
| 1245 |
+
keypoint_prompt = torch.where(invalid_prompt, dummy_prompt, keypoint_prompt)
|
| 1246 |
+
else:
|
| 1247 |
+
# Only keep valid keypoints
|
| 1248 |
+
valid_keypoint = (
|
| 1249 |
+
torch.all(
|
| 1250 |
+
(keypoint_prompt[:, :, :2] > -0.5) & (keypoint_prompt[:, :, :2] < 0.5),
|
| 1251 |
+
dim=2,
|
| 1252 |
+
)
|
| 1253 |
+
& hand_valid_mask[..., [1, 0, 1, 0]]
|
| 1254 |
+
).squeeze()
|
| 1255 |
+
keypoint_prompt = keypoint_prompt[:, valid_keypoint]
|
| 1256 |
+
keypoint_prompt[:, :, :2] = torch.clamp(
|
| 1257 |
+
keypoint_prompt[:, :, :2] + 0.5, min=0.0, max=1.0
|
| 1258 |
+
) # [-0.5, 0.5] --> [0, 1]
|
| 1259 |
+
|
| 1260 |
+
if keypoint_prompt.numel() != 0:
|
| 1261 |
+
pose_output, _ = self.run_keypoint_prompt(batch, pose_output, keypoint_prompt)
|
| 1262 |
+
|
| 1263 |
+
##############################################################################
|
| 1264 |
+
|
| 1265 |
+
# Drop in hand pose
|
| 1266 |
+
left_hand_pose_params = lhand_output["mhr_hand"]["hand"][:, :54]
|
| 1267 |
+
right_hand_pose_params = rhand_output["mhr_hand"]["hand"][:, 54:]
|
| 1268 |
+
updated_hand_pose = torch.cat([left_hand_pose_params, right_hand_pose_params], dim=1)
|
| 1269 |
+
|
| 1270 |
+
# Drop in hand scales
|
| 1271 |
+
updated_scale = pose_output["mhr"]["scale"].clone()
|
| 1272 |
+
updated_scale[:, 9] = lhand_output["mhr_hand"]["scale"][:, 9]
|
| 1273 |
+
updated_scale[:, 8] = rhand_output["mhr_hand"]["scale"][:, 8]
|
| 1274 |
+
updated_scale[:, 18:] = (
|
| 1275 |
+
lhand_output["mhr_hand"]["scale"][:, 18:] + rhand_output["mhr_hand"]["scale"][:, 18:]
|
| 1276 |
+
) / 2
|
| 1277 |
+
|
| 1278 |
+
# Update hand shape
|
| 1279 |
+
updated_shape = pose_output["mhr"]["shape"].clone()
|
| 1280 |
+
updated_shape[:, 40:] = (
|
| 1281 |
+
lhand_output["mhr_hand"]["shape"][:, 40:] + rhand_output["mhr_hand"]["shape"][:, 40:]
|
| 1282 |
+
) / 2
|
| 1283 |
+
|
| 1284 |
+
############################ Doing IK ############################
|
| 1285 |
+
|
| 1286 |
+
# First, forward just FK
|
| 1287 |
+
joint_rotations = self.head_pose.mhr_forward(
|
| 1288 |
+
global_trans=pose_output["mhr"]["global_rot"] * 0,
|
| 1289 |
+
global_rot=pose_output["mhr"]["global_rot"],
|
| 1290 |
+
body_pose_params=pose_output["mhr"]["body_pose"],
|
| 1291 |
+
hand_pose_params=updated_hand_pose,
|
| 1292 |
+
scale_params=updated_scale,
|
| 1293 |
+
shape_params=updated_shape,
|
| 1294 |
+
expr_params=pose_output["mhr"]["face"],
|
| 1295 |
+
return_joint_rotations=True,
|
| 1296 |
+
)[1]
|
| 1297 |
+
|
| 1298 |
+
# Get lowarm
|
| 1299 |
+
lowarm_joint_idxs = torch.LongTensor([76, 40]).cuda() # left, right
|
| 1300 |
+
lowarm_joint_rotations = joint_rotations[:, lowarm_joint_idxs] # B x 2 x 3 x 3
|
| 1301 |
+
|
| 1302 |
+
# Get zero-wrist pose
|
| 1303 |
+
wrist_twist_joint_idxs = torch.LongTensor([77, 41]).cuda() # left, right
|
| 1304 |
+
wrist_zero_rot_pose = lowarm_joint_rotations @ self.head_pose.joint_rotation[wrist_twist_joint_idxs]
|
| 1305 |
+
|
| 1306 |
+
# Get globals from left & right
|
| 1307 |
+
left_joint_global_rots = lhand_output["mhr_hand"]["joint_global_rots"]
|
| 1308 |
+
right_joint_global_rots = rhand_output["mhr_hand"]["joint_global_rots"]
|
| 1309 |
+
pred_global_wrist_rotmat = torch.stack(
|
| 1310 |
+
[
|
| 1311 |
+
left_joint_global_rots[:, 78],
|
| 1312 |
+
right_joint_global_rots[:, 42],
|
| 1313 |
+
],
|
| 1314 |
+
dim=1,
|
| 1315 |
+
)
|
| 1316 |
+
|
| 1317 |
+
# Now we want to get the local poses that lead to the wrist being pred_global_wrist_rotmat
|
| 1318 |
+
fused_local_wrist_rotmat = torch.einsum("kabc,kabd->kadc", pred_global_wrist_rotmat, wrist_zero_rot_pose)
|
| 1319 |
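# The einsum computes R_zero^T @ R_global per hand: the residual local wrist
# rotation which, composed with the zero-wrist forearm chain, reproduces the
# wrist orientation predicted by the hand decoder.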
+
wrist_xzy = fix_wrist_euler(roma.rotmat_to_euler("XZY", fused_local_wrist_rotmat))
|
| 1320 |
+
|
| 1321 |
+
# Fuse the wrist rotation back in where it passes the validity checks.
|
| 1322 |
+
angle_difference = rotation_angle_difference(ori_local_wrist_rotmat, fused_local_wrist_rotmat)  # B x 2 (one angle per hand)
|
| 1323 |
+
valid_angle = angle_difference < thresh_wrist_angle
|
| 1324 |
+
valid_angle = valid_angle & hand_valid_mask
|
| 1325 |
+
valid_angle = valid_angle.unsqueeze(-1)
|
| 1326 |
+
|
| 1327 |
+
body_pose = pose_output["mhr"]["body_pose"][:, [41, 43, 42, 31, 33, 32]].unflatten(1, (2, 3))
|
| 1328 |
+
updated_body_pose = torch.where(valid_angle, wrist_xzy, body_pose)
|
| 1329 |
+
pose_output["mhr"]["body_pose"][:, [41, 43, 42, 31, 33, 32]] = updated_body_pose.flatten(1, 2)
|
| 1330 |
+
|
| 1331 |
+
hand_pose = pose_output["mhr"]["hand"].unflatten(1, (2, 54))
|
| 1332 |
+
pose_output["mhr"]["hand"] = torch.where(
|
| 1333 |
+
valid_angle, updated_hand_pose.unflatten(1, (2, 54)), hand_pose
|
| 1334 |
+
).flatten(1, 2)
|
| 1335 |
+
|
| 1336 |
+
hand_scale = torch.stack(
|
| 1337 |
+
[pose_output["mhr"]["scale"][:, 9], pose_output["mhr"]["scale"][:, 8]],
|
| 1338 |
+
dim=1,
|
| 1339 |
+
)
|
| 1340 |
+
updated_hand_scale = torch.stack([updated_scale[:, 9], updated_scale[:, 8]], dim=1)
|
| 1341 |
+
masked_hand_scale = torch.where(valid_angle.squeeze(-1), updated_hand_scale, hand_scale)
|
| 1342 |
+
pose_output["mhr"]["scale"][:, 9] = masked_hand_scale[:, 0]
|
| 1343 |
+
pose_output["mhr"]["scale"][:, 8] = masked_hand_scale[:, 1]
|
| 1344 |
+
|
| 1345 |
+
# Replace shared shape and scale
|
| 1346 |
+
pose_output["mhr"]["scale"][:, 18:] = torch.where(
|
| 1347 |
+
valid_angle.squeeze(-1).sum(dim=1, keepdim=True) > 0,
|
| 1348 |
+
(
|
| 1349 |
+
lhand_output["mhr_hand"]["scale"][:, 18:] * valid_angle.squeeze(-1)[:, [0]]
|
| 1350 |
+
+ rhand_output["mhr_hand"]["scale"][:, 18:] * valid_angle.squeeze(-1)[:, [1]]
|
| 1351 |
+
)
|
| 1352 |
+
/ (valid_angle.squeeze(-1).sum(dim=1, keepdim=True) + 1e-8),
|
| 1353 |
+
pose_output["mhr"]["scale"][:, 18:],
|
| 1354 |
+
)
|
| 1355 |
+
pose_output["mhr"]["shape"][:, 40:] = torch.where(
|
| 1356 |
+
valid_angle.squeeze(-1).sum(dim=1, keepdim=True) > 0,
|
| 1357 |
+
(
|
| 1358 |
+
lhand_output["mhr_hand"]["shape"][:, 40:] * valid_angle.squeeze(-1)[:, [0]]
|
| 1359 |
+
+ rhand_output["mhr_hand"]["shape"][:, 40:] * valid_angle.squeeze(-1)[:, [1]]
|
| 1360 |
+
)
|
| 1361 |
+
/ (valid_angle.squeeze(-1).sum(dim=1, keepdim=True) + 1e-8),
|
| 1362 |
+
pose_output["mhr"]["shape"][:, 40:],
|
| 1363 |
+
)
|
| 1364 |
+
|
| 1365 |
+
########################################################
|
| 1366 |
+
|
| 1367 |
+
# Re-run forward
|
| 1368 |
+
with torch.no_grad():
|
| 1369 |
+
verts, j3d, jcoords, mhr_model_params, joint_global_rots = self.head_pose.mhr_forward(
|
| 1370 |
+
global_trans=pose_output["mhr"]["global_rot"] * 0,
|
| 1371 |
+
global_rot=pose_output["mhr"]["global_rot"],
|
| 1372 |
+
body_pose_params=pose_output["mhr"]["body_pose"],
|
| 1373 |
+
hand_pose_params=pose_output["mhr"]["hand"],
|
| 1374 |
+
scale_params=pose_output["mhr"]["scale"],
|
| 1375 |
+
shape_params=pose_output["mhr"]["shape"],
|
| 1376 |
+
expr_params=pose_output["mhr"]["face"],
|
| 1377 |
+
return_keypoints=True,
|
| 1378 |
+
return_joint_coords=True,
|
| 1379 |
+
return_model_params=True,
|
| 1380 |
+
return_joint_rotations=True,
|
| 1381 |
+
)
|
| 1382 |
+
j3d = j3d[:, :70] # 308 --> 70 keypoints
|
| 1383 |
+
verts[..., [1, 2]] *= -1 # Camera system difference
|
| 1384 |
+
j3d[..., [1, 2]] *= -1 # Camera system difference
|
| 1385 |
+
jcoords[..., [1, 2]] *= -1
|
| 1386 |
+
pose_output["mhr"]["pred_keypoints_3d"] = j3d
|
| 1387 |
+
pose_output["mhr"]["pred_vertices"] = verts
|
| 1388 |
+
pose_output["mhr"]["pred_joint_coords"] = jcoords
|
| 1389 |
+
pose_output["mhr"]["pred_pose_raw"][...] = 0 # pred_pose_raw is not valid anymore
|
| 1390 |
+
pose_output["mhr"]["mhr_model_params"] = mhr_model_params
|
| 1391 |
+
|
| 1392 |
+
########################################################
|
| 1393 |
+
# Project to 2D
|
| 1394 |
+
pred_keypoints_3d_proj = pose_output["mhr"]["pred_keypoints_3d"] + pose_output["mhr"]["pred_cam_t"][:, None, :]
|
| 1395 |
+
pred_keypoints_3d_proj[:, :, [0, 1]] *= pose_output["mhr"]["focal_length"][:, None, None]
|
| 1396 |
+
pred_keypoints_3d_proj[:, :, [0, 1]] = (
|
| 1397 |
+
pred_keypoints_3d_proj[:, :, [0, 1]]
|
| 1398 |
+
+ torch.FloatTensor([width / 2, height / 2]).to(pred_keypoints_3d_proj)[None, None, :]
|
| 1399 |
+
* pred_keypoints_3d_proj[:, :, [2]]
|
| 1400 |
+
)
|
| 1401 |
+
pred_keypoints_3d_proj[:, :, :2] = pred_keypoints_3d_proj[:, :, :2] / pred_keypoints_3d_proj[:, :, [2]]
|
| 1402 |
+
pose_output["mhr"]["pred_keypoints_2d"] = pred_keypoints_3d_proj[:, :, :2]
|
| 1403 |
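# Equivalent pinhole projection: u = f * X / Z + W / 2 and v = f * Y / Z + H / 2,
# with (X, Y, Z) = pred_keypoints_3d + pred_cam_t and the principal point assumed
# at the image centre.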
+
|
| 1404 |
+
return BodyPredContainer(
|
| 1405 |
+
pose_output=pose_output,
|
| 1406 |
+
batch_lhand=batch_lhand,
|
| 1407 |
+
batch_rhand=batch_rhand,
|
| 1408 |
+
lhand_output=lhand_output,
|
| 1409 |
+
rhand_output=rhand_output,
|
| 1410 |
+
)
|
| 1411 |
+
|
| 1412 |
+
def run_keypoint_prompt(self, batch, output, keypoint_prompt):
|
| 1413 |
+
image_embeddings = output["image_embeddings"]
|
| 1414 |
+
condition_info = output["condition_info"]
|
| 1415 |
+
pose_output = output["mhr"] # body-only output
|
| 1416 |
+
# Use previous estimate as initialization
|
| 1417 |
+
prev_estimate = torch.cat(
|
| 1418 |
+
[
|
| 1419 |
+
pose_output["pred_pose_raw"].detach(), # (B, 6)
|
| 1420 |
+
pose_output["shape"].detach(),
|
| 1421 |
+
pose_output["scale"].detach(),
|
| 1422 |
+
pose_output["hand"].detach(),
|
| 1423 |
+
pose_output["face"].detach(),
|
| 1424 |
+
],
|
| 1425 |
+
dim=1,
|
| 1426 |
+
).unsqueeze(dim=1)
|
| 1427 |
+
if hasattr(self, "init_camera"):
|
| 1428 |
+
prev_estimate = torch.cat(
|
| 1429 |
+
[prev_estimate, pose_output["pred_cam"].detach().unsqueeze(1)],
|
| 1430 |
+
dim=-1,
|
| 1431 |
+
)
|
| 1432 |
+
|
| 1433 |
+
tokens_output, pose_output = self.forward_decoder(
|
| 1434 |
+
image_embeddings,
|
| 1435 |
+
init_estimate=None, # not recurring previous estimate
|
| 1436 |
+
keypoints=keypoint_prompt,
|
| 1437 |
+
prev_estimate=prev_estimate,
|
| 1438 |
+
condition_info=condition_info,
|
| 1439 |
+
batch=batch,
|
| 1440 |
+
)
|
| 1441 |
+
pose_output = pose_output[-1]
|
| 1442 |
+
|
| 1443 |
+
output.update({"mhr": pose_output})
|
| 1444 |
+
return output, keypoint_prompt
|
| 1445 |
+
|
| 1446 |
+
def _get_hand_box(self, pose_output, batch):
|
| 1447 |
+
"""Get hand bbox from the hand detector"""
|
| 1448 |
+
pred_left_hand_box = pose_output["mhr"]["hand_box"][:, 0].detach().cpu().numpy() * self.cfg.MODEL.IMAGE_SIZE[0]
|
| 1449 |
+
pred_right_hand_box = pose_output["mhr"]["hand_box"][:, 1].detach().cpu().numpy() * self.cfg.MODEL.IMAGE_SIZE[0]
|
| 1450 |
+
|
| 1451 |
+
# Change boxes into squares
|
| 1452 |
+
batch["left_center"] = pred_left_hand_box[:, :2]
|
| 1453 |
+
batch["left_scale"] = pred_left_hand_box[:, 2:].max(axis=1, keepdims=True).repeat(2, axis=1)
|
| 1454 |
+
batch["right_center"] = pred_right_hand_box[:, :2]
|
| 1455 |
+
batch["right_scale"] = pred_right_hand_box[:, 2:].max(axis=1, keepdims=True).repeat(2, axis=1)
|
| 1456 |
+
|
| 1457 |
+
# Convert crop-space boxes to full-image space. batch["affine_trans"] is the full-to-crop transform (applied on the right)
|
| 1458 |
+
batch["left_scale"] = batch["left_scale"] / batch["affine_trans"][0, :, 0, 0].cpu().numpy()[:, None]
|
| 1459 |
+
batch["right_scale"] = batch["right_scale"] / batch["affine_trans"][0, :, 0, 0].cpu().numpy()[:, None]
|
| 1460 |
+
batch["left_center"] = (
|
| 1461 |
+
batch["left_center"] - batch["affine_trans"][0, :, [0, 1], [2, 2]].cpu().numpy()
|
| 1462 |
+
) / batch["affine_trans"][0, :, 0, 0].cpu().numpy()[:, None]
|
| 1463 |
+
batch["right_center"] = (
|
| 1464 |
+
batch["right_center"] - batch["affine_trans"][0, :, [0, 1], [2, 2]].cpu().numpy()
|
| 1465 |
+
) / batch["affine_trans"][0, :, 0, 0].cpu().numpy()[:, None]
|
| 1466 |
+
|
| 1467 |
+
left_xyxy = np.concatenate(
|
| 1468 |
+
[
|
| 1469 |
+
(batch["left_center"][:, 0] - batch["left_scale"][:, 0] * 1 / 2).reshape(-1, 1),
|
| 1470 |
+
(batch["left_center"][:, 1] - batch["left_scale"][:, 1] * 1 / 2).reshape(-1, 1),
|
| 1471 |
+
(batch["left_center"][:, 0] + batch["left_scale"][:, 0] * 1 / 2).reshape(-1, 1),
|
| 1472 |
+
(batch["left_center"][:, 1] + batch["left_scale"][:, 1] * 1 / 2).reshape(-1, 1),
|
| 1473 |
+
],
|
| 1474 |
+
axis=1,
|
| 1475 |
+
)
|
| 1476 |
+
right_xyxy = np.concatenate(
|
| 1477 |
+
[
|
| 1478 |
+
(batch["right_center"][:, 0] - batch["right_scale"][:, 0] * 1 / 2).reshape(-1, 1),
|
| 1479 |
+
(batch["right_center"][:, 1] - batch["right_scale"][:, 1] * 1 / 2).reshape(-1, 1),
|
| 1480 |
+
(batch["right_center"][:, 0] + batch["right_scale"][:, 0] * 1 / 2).reshape(-1, 1),
|
| 1481 |
+
(batch["right_center"][:, 1] + batch["right_scale"][:, 1] * 1 / 2).reshape(-1, 1),
|
| 1482 |
+
],
|
| 1483 |
+
axis=1,
|
| 1484 |
+
)
|
| 1485 |
+
|
| 1486 |
+
return left_xyxy, right_xyxy
|
| 1487 |
+
|
| 1488 |
+
def keypoint_token_update_fn(
|
| 1489 |
+
self,
|
| 1490 |
+
kps_emb_start_idx,
|
| 1491 |
+
image_embeddings,
|
| 1492 |
+
token_embeddings,
|
| 1493 |
+
token_augment,
|
| 1494 |
+
pose_output,
|
| 1495 |
+
layer_idx,
|
| 1496 |
+
):
|
| 1497 |
+
# It's already after the last layer, we're done.
|
| 1498 |
+
if layer_idx == len(self.decoder.layers) - 1:
|
| 1499 |
+
return token_embeddings, token_augment, pose_output, layer_idx
|
| 1500 |
+
|
| 1501 |
+
# Clone
|
| 1502 |
+
token_embeddings = token_embeddings.clone()
|
| 1503 |
+
token_augment = token_augment.clone()
|
| 1504 |
+
|
| 1505 |
+
num_keypoints = self.keypoint_embedding.weight.shape[0]
|
| 1506 |
+
|
| 1507 |
+
# Get current 2D KPS predictions
|
| 1508 |
+
pred_keypoints_2d_cropped = pose_output["pred_keypoints_2d_cropped"].clone() # These are -0.5 ~ 0.5
|
| 1509 |
+
pred_keypoints_2d_depth = pose_output["pred_keypoints_2d_depth"].clone()
|
| 1510 |
+
|
| 1511 |
+
pred_keypoints_2d_cropped = pred_keypoints_2d_cropped[:, self.keypoint_embedding_idxs]
|
| 1512 |
+
pred_keypoints_2d_depth = pred_keypoints_2d_depth[:, self.keypoint_embedding_idxs]
|
| 1513 |
+
|
| 1514 |
+
# Get 2D KPS to be 0 ~ 1
|
| 1515 |
+
pred_keypoints_2d_cropped_01 = pred_keypoints_2d_cropped + 0.5
|
| 1516 |
+
|
| 1517 |
+
# Get a mask of those that are 1) beyond image boundaries or 2) behind the camera
|
| 1518 |
+
invalid_mask = (
|
| 1519 |
+
(pred_keypoints_2d_cropped_01[:, :, 0] < 0)
|
| 1520 |
+
| (pred_keypoints_2d_cropped_01[:, :, 0] > 1)
|
| 1521 |
+
| (pred_keypoints_2d_cropped_01[:, :, 1] < 0)
|
| 1522 |
+
| (pred_keypoints_2d_cropped_01[:, :, 1] > 1)
|
| 1523 |
+
| (pred_keypoints_2d_depth[:, :] < 1e-5)
|
| 1524 |
+
)
|
| 1525 |
+
|
| 1526 |
+
# Run them through the prompt encoder's pos emb function
|
| 1527 |
+
token_augment[:, kps_emb_start_idx : kps_emb_start_idx + num_keypoints, :] = self.keypoint_posemb_linear(
|
| 1528 |
+
pred_keypoints_2d_cropped
|
| 1529 |
+
) * (~invalid_mask[:, :, None])
|
| 1530 |
+
|
| 1531 |
+
# Also maybe update token_embeddings with the grid sampled 2D feature.
|
| 1532 |
+
# Remember that pred_keypoints_2d_cropped are -0.5 ~ 0.5. We want -1 ~ 1
|
| 1533 |
+
# Sample points...
|
| 1534 |
+
## Get sampling points
|
| 1535 |
+
pred_keypoints_2d_cropped_sample_points = pred_keypoints_2d_cropped * 2
|
| 1536 |
+
if self.cfg.MODEL.BACKBONE.TYPE in [
|
| 1537 |
+
"vit_hmr",
|
| 1538 |
+
"vit",
|
| 1539 |
+
"vit_b",
|
| 1540 |
+
"vit_l",
|
| 1541 |
+
"vit_hmr_512_384",
|
| 1542 |
+
]:
|
| 1543 |
+
# Need to go from 256 x 256 coords to 256 x 192 (HW) because image_embeddings is 16x12
|
| 1544 |
+
# Aka, for x, what was normally -1 ~ 1 for 256 should be -16/12 ~ 16/12 (since to sample at original 256, need to overflow)
|
| 1545 |
+
pred_keypoints_2d_cropped_sample_points[:, :, 0] = (
|
| 1546 |
+
pred_keypoints_2d_cropped_sample_points[:, :, 0] / 12 * 16
|
| 1547 |
+
)
|
| 1548 |
+
|
| 1549 |
+
# Version 2 is projecting & bilinear sampling
|
| 1550 |
+
pred_keypoints_2d_cropped_feats = (
|
| 1551 |
+
F.grid_sample(
|
| 1552 |
+
image_embeddings,
|
| 1553 |
+
pred_keypoints_2d_cropped_sample_points[:, :, None, :], # -1 ~ 1, xy
|
| 1554 |
+
mode="bilinear",
|
| 1555 |
+
padding_mode="zeros",
|
| 1556 |
+
align_corners=False,
|
| 1557 |
+
)
|
| 1558 |
+
.squeeze(3)
|
| 1559 |
+
.permute(0, 2, 1)
|
| 1560 |
+
) # B x kps x C
|
| 1561 |
+
# Zero out invalid locations...
|
| 1562 |
+
pred_keypoints_2d_cropped_feats = pred_keypoints_2d_cropped_feats * (~invalid_mask[:, :, None])
|
| 1563 |
+
# Note: this ADDS to the existing token embeddings rather than overwriting them
|
| 1564 |
+
token_embeddings = token_embeddings.clone()
|
| 1565 |
+
token_embeddings[
|
| 1566 |
+
:,
|
| 1567 |
+
kps_emb_start_idx : kps_emb_start_idx + num_keypoints,
|
| 1568 |
+
:,
|
| 1569 |
+
] += self.keypoint_feat_linear(pred_keypoints_2d_cropped_feats)
|
| 1570 |
+
|
| 1571 |
+
return token_embeddings, token_augment, pose_output, layer_idx
|
| 1572 |
+
|
| 1573 |
+
def keypoint3d_token_update_fn(
|
| 1574 |
+
self,
|
| 1575 |
+
kps3d_emb_start_idx,
|
| 1576 |
+
token_embeddings,
|
| 1577 |
+
token_augment,
|
| 1578 |
+
pose_output,
|
| 1579 |
+
layer_idx,
|
| 1580 |
+
):
|
| 1581 |
+
# It's already after the last layer, we're done.
|
| 1582 |
+
if layer_idx == len(self.decoder.layers) - 1:
|
| 1583 |
+
return token_embeddings, token_augment, pose_output, layer_idx
|
| 1584 |
+
|
| 1585 |
+
num_keypoints3d = self.keypoint3d_embedding.weight.shape[0]
|
| 1586 |
+
|
| 1587 |
+
# Get current 3D kps predictions
|
| 1588 |
+
pred_keypoints_3d = pose_output["pred_keypoints_3d"].clone()
|
| 1589 |
+
|
| 1590 |
+
# Now, pelvis normalize
|
| 1591 |
+
pred_keypoints_3d = (
|
| 1592 |
+
pred_keypoints_3d
|
| 1593 |
+
- (pred_keypoints_3d[:, [self.pelvis_idx[0]], :] + pred_keypoints_3d[:, [self.pelvis_idx[1]], :]) / 2
|
| 1594 |
+
)
|
| 1595 |
+
|
| 1596 |
+
# Get the kps we care about, _after_ pelvis norm (just in case idxs shift)
|
| 1597 |
+
pred_keypoints_3d = pred_keypoints_3d[:, self.keypoint3d_embedding_idxs]
|
| 1598 |
+
|
| 1599 |
+
# Run through embedding MLP & put in
|
| 1600 |
+
token_augment = token_augment.clone()
|
| 1601 |
+
token_augment[
|
| 1602 |
+
:,
|
| 1603 |
+
kps3d_emb_start_idx : kps3d_emb_start_idx + num_keypoints3d,
|
| 1604 |
+
:,
|
| 1605 |
+
] = self.keypoint3d_posemb_linear(pred_keypoints_3d)
|
| 1606 |
+
|
| 1607 |
+
return token_embeddings, token_augment, pose_output, layer_idx
|
| 1608 |
+
|
| 1609 |
+
def keypoint_token_update_fn_hand(
|
| 1610 |
+
self,
|
| 1611 |
+
kps_emb_start_idx,
|
| 1612 |
+
image_embeddings,
|
| 1613 |
+
token_embeddings,
|
| 1614 |
+
token_augment,
|
| 1615 |
+
pose_output,
|
| 1616 |
+
layer_idx,
|
| 1617 |
+
):
|
| 1618 |
+
# It's already after the last layer, we're done.
|
| 1619 |
+
if layer_idx == len(self.decoder_hand.layers) - 1:
|
| 1620 |
+
return token_embeddings, token_augment, pose_output, layer_idx
|
| 1621 |
+
|
| 1622 |
+
# Clone
|
| 1623 |
+
token_embeddings = token_embeddings.clone()
|
| 1624 |
+
token_augment = token_augment.clone()
|
| 1625 |
+
|
| 1626 |
+
num_keypoints = self.keypoint_embedding_hand.weight.shape[0]
|
| 1627 |
+
|
| 1628 |
+
# Get current 2D KPS predictions
|
| 1629 |
+
pred_keypoints_2d_cropped = pose_output["pred_keypoints_2d_cropped"].clone() # These are -0.5 ~ 0.5
|
| 1630 |
+
pred_keypoints_2d_depth = pose_output["pred_keypoints_2d_depth"].clone()
|
| 1631 |
+
|
| 1632 |
+
pred_keypoints_2d_cropped = pred_keypoints_2d_cropped[:, self.keypoint_embedding_idxs_hand]
|
| 1633 |
+
pred_keypoints_2d_depth = pred_keypoints_2d_depth[:, self.keypoint_embedding_idxs_hand]
|
| 1634 |
+
|
| 1635 |
+
# Get 2D KPS to be 0 ~ 1
|
| 1636 |
+
pred_keypoints_2d_cropped_01 = pred_keypoints_2d_cropped + 0.5
|
| 1637 |
+
|
| 1638 |
+
# Get a mask of those that are 1) beyond image boundaries or 2) behind the camera
|
| 1639 |
+
invalid_mask = (
|
| 1640 |
+
(pred_keypoints_2d_cropped_01[:, :, 0] < 0)
|
| 1641 |
+
| (pred_keypoints_2d_cropped_01[:, :, 0] > 1)
|
| 1642 |
+
| (pred_keypoints_2d_cropped_01[:, :, 1] < 0)
|
| 1643 |
+
| (pred_keypoints_2d_cropped_01[:, :, 1] > 1)
|
| 1644 |
+
| (pred_keypoints_2d_depth[:, :] < 1e-5)
|
| 1645 |
+
)
|
| 1646 |
+
|
| 1647 |
+
# Run them through the prompt encoder's pos emb function
|
| 1648 |
+
token_augment[:, kps_emb_start_idx : kps_emb_start_idx + num_keypoints, :] = self.keypoint_posemb_linear_hand(
|
| 1649 |
+
pred_keypoints_2d_cropped
|
| 1650 |
+
) * (~invalid_mask[:, :, None])
|
| 1651 |
+
|
| 1652 |
+
# Also maybe update token_embeddings with the grid sampled 2D feature.
|
| 1653 |
+
# Remember that pred_keypoints_2d_cropped are -0.5 ~ 0.5. We want -1 ~ 1
|
| 1654 |
+
# Sample points...
|
| 1655 |
+
## Get sampling points
|
| 1656 |
+
pred_keypoints_2d_cropped_sample_points = pred_keypoints_2d_cropped * 2
|
| 1657 |
+
if self.cfg.MODEL.BACKBONE.TYPE in [
|
| 1658 |
+
"vit_hmr",
|
| 1659 |
+
"vit",
|
| 1660 |
+
"vit_b",
|
| 1661 |
+
"vit_l",
|
| 1662 |
+
"vit_hmr_512_384",
|
| 1663 |
+
]:
|
| 1664 |
+
# Need to go from 256 x 256 coords to 256 x 192 (HW) because image_embeddings is 16x12
|
| 1665 |
+
# Aka, for x, what was normally -1 ~ 1 for 256 should be -16/12 ~ 16/12 (since to sample at original 256, need to overflow)
|
| 1666 |
+
pred_keypoints_2d_cropped_sample_points[:, :, 0] = (
|
| 1667 |
+
pred_keypoints_2d_cropped_sample_points[:, :, 0] / 12 * 16
|
| 1668 |
+
)
|
| 1669 |
+
|
| 1670 |
+
# Version 2 is projecting & bilinear sampling
|
| 1671 |
+
pred_keypoints_2d_cropped_feats = (
|
| 1672 |
+
F.grid_sample(
|
| 1673 |
+
image_embeddings,
|
| 1674 |
+
pred_keypoints_2d_cropped_sample_points[:, :, None, :], # -1 ~ 1, xy
|
| 1675 |
+
mode="bilinear",
|
| 1676 |
+
padding_mode="zeros",
|
| 1677 |
+
align_corners=False,
|
| 1678 |
+
)
|
| 1679 |
+
.squeeze(3)
|
| 1680 |
+
.permute(0, 2, 1)
|
| 1681 |
+
) # B x kps x C
|
| 1682 |
+
# Zero out invalid locations...
|
| 1683 |
+
pred_keypoints_2d_cropped_feats = pred_keypoints_2d_cropped_feats * (~invalid_mask[:, :, None])
|
| 1684 |
+
# Note: this ADDS to the existing token embeddings rather than overwriting them
|
| 1685 |
+
token_embeddings = token_embeddings.clone()
|
| 1686 |
+
token_embeddings[
|
| 1687 |
+
:,
|
| 1688 |
+
kps_emb_start_idx : kps_emb_start_idx + num_keypoints,
|
| 1689 |
+
:,
|
| 1690 |
+
] += self.keypoint_feat_linear_hand(pred_keypoints_2d_cropped_feats)
|
| 1691 |
+
|
| 1692 |
+
return token_embeddings, token_augment, pose_output, layer_idx
|
| 1693 |
+
|
| 1694 |
+
def keypoint3d_token_update_fn_hand(
|
| 1695 |
+
self,
|
| 1696 |
+
kps3d_emb_start_idx,
|
| 1697 |
+
token_embeddings,
|
| 1698 |
+
token_augment,
|
| 1699 |
+
pose_output,
|
| 1700 |
+
layer_idx,
|
| 1701 |
+
):
|
| 1702 |
+
# It's already after the last layer, we're done.
|
| 1703 |
+
if layer_idx == len(self.decoder_hand.layers) - 1:
|
| 1704 |
+
return token_embeddings, token_augment, pose_output, layer_idx
|
| 1705 |
+
|
| 1706 |
+
num_keypoints3d = self.keypoint3d_embedding_hand.weight.shape[0]
|
| 1707 |
+
|
| 1708 |
+
# Get current 3D kps predictions
|
| 1709 |
+
pred_keypoints_3d = pose_output["pred_keypoints_3d"].clone()
|
| 1710 |
+
|
| 1711 |
+
# Now, pelvis normalize
|
| 1712 |
+
pred_keypoints_3d = (
|
| 1713 |
+
pred_keypoints_3d
|
| 1714 |
+
- (pred_keypoints_3d[:, [self.pelvis_idx[0]], :] + pred_keypoints_3d[:, [self.pelvis_idx[1]], :]) / 2
|
| 1715 |
+
)
|
| 1716 |
+
|
| 1717 |
+
# Get the kps we care about, _after_ pelvis norm (just in case idxs shift)
|
| 1718 |
+
pred_keypoints_3d = pred_keypoints_3d[:, self.keypoint3d_embedding_idxs_hand]
|
| 1719 |
+
|
| 1720 |
+
# Run through embedding MLP & put in
|
| 1721 |
+
token_augment = token_augment.clone()
|
| 1722 |
+
token_augment[
|
| 1723 |
+
:,
|
| 1724 |
+
kps3d_emb_start_idx : kps3d_emb_start_idx + num_keypoints3d,
|
| 1725 |
+
:,
|
| 1726 |
+
] = self.keypoint3d_posemb_linear_hand(pred_keypoints_3d)
|
| 1727 |
+
|
| 1728 |
+
return token_embeddings, token_augment, pose_output, layer_idx
|
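The 2D keypoint projection in forward_step above follows a standard pinhole model: add the predicted camera translation, scale x/y by the focal length, divide by depth, and shift by the principal point (assumed at the image center). Below is a minimal, self-contained sketch of that step with hypothetical names and toy values, not the model's actual tensors:

import torch

def project_keypoints(j3d, cam_t, focal_length, width, height):
    """Project camera-frame 3D keypoints (B, K, 3) to pixel coordinates (B, K, 2).

    Assumes a pinhole camera with the principal point at the image center and
    focal_length of shape (B,) expressed in pixels (all names are illustrative).
    """
    pts = j3d + cam_t[:, None, :]                               # shift into the camera frame
    uv = pts[:, :, :2] * focal_length[:, None, None]            # scale x, y by the focal length
    uv = uv / pts[:, :, 2:3]                                    # perspective divide by depth
    uv = uv + torch.tensor([width / 2.0, height / 2.0]).to(uv)  # move origin to the principal point
    return uv

# toy usage: one person, two keypoints roughly two meters in front of the camera
j3d = torch.tensor([[[0.0, 0.0, 2.0], [0.1, -0.2, 2.0]]])
uv = project_keypoints(j3d, torch.zeros(1, 3), torch.tensor([1000.0]), width=640, height=480)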
src/sam3d_body/models/modules/__init__.py
ADDED
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from .geometry_utils import (
    aa_to_rotmat,
    cam_crop_to_full,
    focal_length_normalization,
    get_focalLength_from_fieldOfView,
    get_intrinsic_matrix,
    inverse_perspective_projection,
    log_depth,
    perspective_projection,
    rot6d_to_rotmat,
    transform_points,
    undo_focal_length_normalization,
    undo_log_depth,
)

from .misc import to_2tuple, to_3tuple, to_4tuple, to_ntuple
src/sam3d_body/models/modules/camera_embed.py
ADDED
@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import einops
import numpy as np
import torch
import torch.nn.functional as F

from sam3d_body.models.modules.transformer import LayerNorm2d
from torch import nn


class CameraEncoder(nn.Module):
    def __init__(self, embed_dim, patch_size=14):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.camera = FourierPositionEncoding(n=3, num_bands=16, max_resolution=64)

        self.conv = nn.Conv2d(embed_dim + 99, embed_dim, kernel_size=1, bias=False)
        self.norm = LayerNorm2d(embed_dim)

    def forward(self, img_embeddings, rays):
        B, D, _h, _w = img_embeddings.shape

        with torch.no_grad():
            scale = 1 / self.patch_size
            rays = F.interpolate(
                rays,
                scale_factor=(scale, scale),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
            rays = rays.permute(0, 2, 3, 1).contiguous()  # [b, h, w, 2]
            rays = torch.cat([rays, torch.ones_like(rays[..., :1])], dim=-1)
            rays_embeddings = self.camera(
                pos=rays.reshape(B, -1, 3)
            )  # (bs, N, 99): rays Fourier embedding
            rays_embeddings = einops.rearrange(
                rays_embeddings, "b (h w) c -> b c h w", h=_h, w=_w
            ).contiguous()

        z = torch.concat([img_embeddings, rays_embeddings], dim=1)
        z = self.norm(self.conv(z))

        return z


class FourierPositionEncoding(nn.Module):
    def __init__(self, n, num_bands, max_resolution):
        """
        Module that generates a Fourier encoding - no learning involved
        """
        super().__init__()

        self.num_bands = num_bands
        self.max_resolution = [max_resolution] * n

    @property
    def channels(self):
        """
        Return the output dimension
        """
        num_dims = len(self.max_resolution)
        encoding_size = self.num_bands * num_dims
        encoding_size *= 2  # sin-cos
        encoding_size += num_dims  # concat

        return encoding_size

    def forward(self, pos):
        """
        Forward pass that takes rays as input and generates Fourier positional encodings
        """
        fourier_pos_enc = _generate_fourier_features(
            pos, num_bands=self.num_bands, max_resolution=self.max_resolution
        )
        return fourier_pos_enc


def _generate_fourier_features(pos, num_bands, max_resolution):
    """Generate Fourier features from a given set of positions and frequencies"""
    b, n = pos.shape[:2]
    device = pos.device

    # Linear frequency sampling
    min_freq = 1.0
    freq_bands = torch.stack(
        [
            torch.linspace(start=min_freq, end=res / 2, steps=num_bands, device=device)
            for res in max_resolution
        ],
        dim=0,
    )

    # Stacking
    per_pos_features = torch.stack(
        [pos[i, :, :][:, :, None] * freq_bands[None, :, :] for i in range(b)], 0
    )
    per_pos_features = per_pos_features.reshape(b, n, -1)

    # Sin-Cos
    per_pos_features = torch.cat(
        [torch.sin(np.pi * per_pos_features), torch.cos(np.pi * per_pos_features)],
        dim=-1,
    )

    # Concat with initial pos
    per_pos_features = torch.cat([pos, per_pos_features], dim=-1)

    return per_pos_features
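As a sanity check on the `embed_dim + 99` channel count in CameraEncoder: with 3 ray coordinates and 16 frequency bands, the Fourier encoding concatenates the raw position with sin/cos features, i.e. 3 + 2 * 16 * 3 = 99 channels. A small usage sketch of FourierPositionEncoding; the batch and grid sizes here are arbitrary assumptions:

import torch

enc = FourierPositionEncoding(n=3, num_bands=16, max_resolution=64)
print(enc.channels)                 # 99 = 3 raw coords + 2 * 16 bands * 3 dims (sin and cos)

rays = torch.randn(2, 16 * 12, 3)   # e.g. one ray per location of a 16x12 feature map
print(enc(rays).shape)              # torch.Size([2, 192, 99])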
src/sam3d_body/models/modules/drop_path.py
ADDED
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import torch
import torch.nn as nn


def drop_path(
    x: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).

    We follow the implementation
    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501
    """
    if not training:
        return x
    keep_prob = 1 - drop_prob
    # handle tensors with different dimensions, not just 4D tensors.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    output = x.div(keep_prob) * random_tensor.floor()
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).

    We follow the implementation
    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501

    Args:
        drop_prob (float): Probability of the path to be zeroed. Default: 0.1
    """

    def __init__(self, drop_prob: float = 0.1):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return drop_path(x, self.drop_prob, self.training)
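DropPath is typically wrapped around a residual branch so that, during training, entire branches are dropped per sample. A hedged usage sketch; the surrounding block is a toy module, not one of the model's layers:

import torch
import torch.nn as nn

class ToyResidualBlock(nn.Module):
    """Illustrative residual block: stochastic depth wraps the residual branch."""

    def __init__(self, dim: int, drop_prob: float = 0.1):
        super().__init__()
        self.mlp = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU())
        self.drop_path = DropPath(drop_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # in training mode the whole branch is randomly zeroed per sample;
        # kept samples are rescaled by 1 / keep_prob inside drop_path
        return x + self.drop_path(self.mlp(x))

out = ToyResidualBlock(dim=64)(torch.randn(8, 64))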
src/sam3d_body/models/modules/geometry_utils.py
ADDED
|
@@ -0,0 +1,304 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
from jaxtyping import Float
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def cam_crop_to_full(cam_bbox, box_center, box_size, img_size, focal_length=5000.0):
|
| 14 |
+
# Convert cam_bbox to full image
|
| 15 |
+
img_w, img_h = img_size[:, 0], img_size[:, 1]
|
| 16 |
+
cx, cy, b = box_center[:, 0], box_center[:, 1], box_size
|
| 17 |
+
w_2, h_2 = img_w / 2.0, img_h / 2.0
|
| 18 |
+
bs = b * cam_bbox[:, 0] + 1e-9
|
| 19 |
+
if type(focal_length) is float:
|
| 20 |
+
focal_length = torch.ones_like(cam_bbox[:, 0]) * focal_length
|
| 21 |
+
tz = 2 * focal_length / bs
|
| 22 |
+
tx = (2 * (cx - w_2) / bs) + cam_bbox[:, 1]
|
| 23 |
+
ty = (2 * (cy - h_2) / bs) + cam_bbox[:, 2]
|
| 24 |
+
full_cam = torch.stack([tx, ty, tz], dim=-1)
|
| 25 |
+
return full_cam
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def aa_to_rotmat(theta: torch.Tensor):
|
| 29 |
+
"""
|
| 30 |
+
Convert axis-angle representation to rotation matrix.
|
| 31 |
+
Works by first converting it to a quaternion.
|
| 32 |
+
Args:
|
| 33 |
+
theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
|
| 34 |
+
Returns:
|
| 35 |
+
torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
|
| 36 |
+
|
| 37 |
+
Alternatives:
|
| 38 |
+
import roma
|
| 39 |
+
y = roma.rotvec_to_rotmat(x)
|
| 40 |
+
"""
|
| 41 |
+
norm = torch.norm(theta + 1e-8, p=2, dim=1)
|
| 42 |
+
angle = torch.unsqueeze(norm, -1)
|
| 43 |
+
normalized = torch.div(theta, angle)
|
| 44 |
+
angle = angle * 0.5
|
| 45 |
+
v_cos = torch.cos(angle)
|
| 46 |
+
v_sin = torch.sin(angle)
|
| 47 |
+
quat = torch.cat([v_cos, v_sin * normalized], dim=1)
|
| 48 |
+
return _quat_to_rotmat(quat)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
|
| 52 |
+
"""
|
| 53 |
+
Convert quaternion representation to rotation matrix.
|
| 54 |
+
Args:
|
| 55 |
+
quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z).
|
| 56 |
+
Returns:
|
| 57 |
+
torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
|
| 58 |
+
"""
|
| 59 |
+
norm_quat = quat
|
| 60 |
+
norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
|
| 61 |
+
w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
|
| 62 |
+
|
| 63 |
+
B = quat.size(0)
|
| 64 |
+
|
| 65 |
+
w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
|
| 66 |
+
wx, wy, wz = w * x, w * y, w * z
|
| 67 |
+
xy, xz, yz = x * y, x * z, y * z
|
| 68 |
+
|
| 69 |
+
rotMat = torch.stack(
|
| 70 |
+
[
|
| 71 |
+
w2 + x2 - y2 - z2,
|
| 72 |
+
2 * xy - 2 * wz,
|
| 73 |
+
2 * wy + 2 * xz,
|
| 74 |
+
2 * wz + 2 * xy,
|
| 75 |
+
w2 - x2 + y2 - z2,
|
| 76 |
+
2 * yz - 2 * wx,
|
| 77 |
+
2 * xz - 2 * wy,
|
| 78 |
+
2 * wx + 2 * yz,
|
| 79 |
+
w2 - x2 - y2 + z2,
|
| 80 |
+
],
|
| 81 |
+
dim=1,
|
| 82 |
+
).view(B, 3, 3)
|
| 83 |
+
return rotMat
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
|
| 87 |
+
"""
|
| 88 |
+
Convert 6D rotation representation to 3x3 rotation matrix.
|
| 89 |
+
Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
|
| 90 |
+
Args:
|
| 91 |
+
x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
|
| 92 |
+
Returns:
|
| 93 |
+
torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
|
| 94 |
+
|
| 95 |
+
Alternatives:
|
| 96 |
+
import roma
|
| 97 |
+
x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
|
| 98 |
+
y = roma.special_gramschmidt(x)
|
| 99 |
+
"""
|
| 100 |
+
x = x.reshape(-1, 2, 3).permute(0, 2, 1).contiguous()
|
| 101 |
+
a1 = x[:, :, 0]
|
| 102 |
+
a2 = x[:, :, 1]
|
| 103 |
+
b1 = F.normalize(a1)
|
| 104 |
+
b2 = F.normalize(a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1)
|
| 105 |
+
b3 = torch.linalg.cross(b1, b2)
|
| 106 |
+
return torch.stack((b1, b2, b3), dim=-1)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def rotmat_to_rot6d(x: torch.Tensor) -> torch.Tensor:
|
| 110 |
+
"""
|
| 111 |
+
Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
|
| 112 |
+
by dropping the last row. Note that 6D representation is not unique.
|
| 113 |
+
Args:
|
| 114 |
+
x: batch of rotation matrices of size (B, 3, 3)
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
6D rotation representation, of size (B, 6)
|
| 118 |
+
|
| 119 |
+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
| 120 |
+
On the Continuity of Rotation Representations in Neural Networks.
|
| 121 |
+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
| 122 |
+
Retrieved from http://arxiv.org/abs/1812.07035
|
| 123 |
+
"""
|
| 124 |
+
batch_dim = x.size()[:-2]
|
| 125 |
+
return x[..., :2, :].clone().reshape(batch_dim + (6,))
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def rot_aa(aa: Float[np.ndarray, "3"], rot: float) -> Float[np.ndarray, "3"]:
|
| 129 |
+
"""
|
| 130 |
+
Rotate axis angle parameters.
|
| 131 |
+
Args:
|
| 132 |
+
aa (np.array): Axis-angle vector of shape (3,).
|
| 133 |
+
rot (np.array): Rotation angle in degrees.
|
| 134 |
+
Returns:
|
| 135 |
+
np.array: Rotated axis-angle vector.
|
| 136 |
+
"""
|
| 137 |
+
# pose parameters
|
| 138 |
+
R: Float[np.ndarray, "3 3"] = np.array(
|
| 139 |
+
[
|
| 140 |
+
[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
|
| 141 |
+
[np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
|
| 142 |
+
[0, 0, 1],
|
| 143 |
+
],
|
| 144 |
+
dtype=np.float64,
|
| 145 |
+
)
|
| 146 |
+
# find the rotation of the body in camera frame
|
| 147 |
+
per_rdg: Float[np.ndarray, "3 3"]
|
| 148 |
+
per_rdg, _ = cv2.Rodrigues(aa)
|
| 149 |
+
# apply the global rotation to the global orientation
|
| 150 |
+
resrot: Float[np.ndarray, "3 3"]
|
| 151 |
+
resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
|
| 152 |
+
aa_vec: Float[np.ndarray, "3"] = (resrot.T)[0]
|
| 153 |
+
return aa_vec.astype(np.float32)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def transform_points(
|
| 157 |
+
points: torch.Tensor,
|
| 158 |
+
translation: Optional[torch.Tensor] = None,
|
| 159 |
+
rotation: Optional[torch.Tensor] = None,
|
| 160 |
+
) -> torch.Tensor:
|
| 161 |
+
"""
|
| 162 |
+
Transform a set of 3D points given translation and rotation.
|
| 163 |
+
Args:
|
| 164 |
+
points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
|
| 165 |
+
translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
|
| 166 |
+
rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
|
| 167 |
+
Returns:
|
| 168 |
+
torch.Tensor: Tensor of shape (B, N, 3) containing the transformed points.
|
| 169 |
+
"""
|
| 170 |
+
if rotation is not None:
|
| 171 |
+
points = torch.einsum("bij,bkj->bki", rotation, points)
|
| 172 |
+
|
| 173 |
+
if translation is not None:
|
| 174 |
+
points = points + translation.unsqueeze(1)
|
| 175 |
+
|
| 176 |
+
return points
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def get_intrinsic_matrix(
|
| 180 |
+
focal_length: torch.Tensor, principle: torch.Tensor
|
| 181 |
+
) -> torch.Tensor:
|
| 182 |
+
"""
|
| 183 |
+
Populate intrinsic camera matrix K given focal length and principal point.
|
| 184 |
+
Args:
|
| 185 |
+
focal_length: Tensor of shape (2,)
|
| 186 |
+
principle: Tensor of shape (2,)
|
| 187 |
+
Returns:
|
| 188 |
+
Tensor of shape (3, 3)
|
| 189 |
+
"""
|
| 190 |
+
if isinstance(focal_length, float):
|
| 191 |
+
fl_x = fl_y = focal_length
|
| 192 |
+
elif len(focal_length) == 1:
|
| 193 |
+
fl_x = fl_y = focal_length[0]
|
| 194 |
+
else:
|
| 195 |
+
fl_x, fl_y = focal_length[0], focal_length[1]
|
| 196 |
+
K = torch.eye(3)
|
| 197 |
+
K[0, 0] = fl_x
|
| 198 |
+
K[1, 1] = fl_y
|
| 199 |
+
K[0, -1] = principle[0]
|
| 200 |
+
K[1, -1] = principle[1]
|
| 201 |
+
|
| 202 |
+
return K
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def perspective_projection(x, K):
|
| 206 |
+
"""
|
| 207 |
+
Computes the perspective projection of a set of points assuming the extrinsic params have already been applied
|
| 208 |
+
Args:
|
| 209 |
+
- x [bs,N,3]: 3D points
|
| 210 |
+
- K [bs,3,3]: Camera intrinsics params
|
| 211 |
+
"""
|
| 212 |
+
# Apply perspective distortion
|
| 213 |
+
y = x / x[:, :, -1].unsqueeze(-1) # (bs, N, 3)
|
| 214 |
+
|
| 215 |
+
# Apply camera intrinsics
|
| 216 |
+
y = torch.einsum("bij,bkj->bki", K, y) # (bs, N, 3)
|
| 217 |
+
|
| 218 |
+
return y[:, :, :2]
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def inverse_perspective_projection(points, K, distance):
|
| 222 |
+
"""
|
| 223 |
+
Computes the inverse perspective projection of a set of points given an estimated distance.
|
| 224 |
+
Input:
|
| 225 |
+
points (bs, N, 2): 2D points
|
| 226 |
+
K (bs,3,3): camera intrinsics params
|
| 227 |
+
distance (bs, N, 1): distance in the 3D world
|
| 228 |
+
Similar to:
|
| 229 |
+
- pts_l_norm = cv2.undistortPoints(np.expand_dims(pts_l, axis=1), cameraMatrix=K_l, distCoeffs=None)
|
| 230 |
+
"""
|
| 231 |
+
# Apply camera intrinsics
|
| 232 |
+
points = torch.cat([points, torch.ones_like(points[..., :1])], -1)
|
| 233 |
+
points = torch.einsum("bij,bkj->bki", torch.inverse(K), points)
|
| 234 |
+
|
| 235 |
+
# Apply perspective distortion
|
| 236 |
+
if distance is None:
|
| 237 |
+
return points
|
| 238 |
+
points = points * distance
|
| 239 |
+
return points
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def get_cam_intrinsics(img_size, fov=55, p_x=None, p_y=None):
|
| 243 |
+
"""Given image size, fov and principal point coordinates, return K the camera parameter matrix"""
|
| 244 |
+
K = np.eye(3)
|
| 245 |
+
# Get focal length.
|
| 246 |
+
focal = get_focalLength_from_fieldOfView(fov=fov, img_size=img_size)
|
| 247 |
+
K[0, 0], K[1, 1] = focal, focal
|
| 248 |
+
|
| 249 |
+
# Set principal point
|
| 250 |
+
if p_x is not None and p_y is not None:
|
| 251 |
+
K[0, -1], K[1, -1] = p_x * img_size, p_y * img_size
|
| 252 |
+
else:
|
| 253 |
+
K[0, -1], K[1, -1] = img_size // 2, img_size // 2
|
| 254 |
+
|
| 255 |
+
return K
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def get_focalLength_from_fieldOfView(fov=60, img_size=512):
|
| 259 |
+
"""
|
| 260 |
+
Compute the focal length of the camera lens by assuming a certain FOV for the entire image
|
| 261 |
+
Args:
|
| 262 |
+
- fov: float, expressed in degree
|
| 263 |
+
- img_size: int
|
| 264 |
+
Return:
|
| 265 |
+
focal: float
|
| 266 |
+
"""
|
| 267 |
+
focal = img_size / (2 * np.tan(np.radians(fov) / 2))
|
| 268 |
+
return focal
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def focal_length_normalization(x, f, fovn=60, img_size=448):
|
| 272 |
+
"""
|
| 273 |
+
Section 3.1 of https://arxiv.org/pdf/1904.02028.pdf
|
| 274 |
+
E = (fn/f) * E' where E is 1/d
|
| 275 |
+
"""
|
| 276 |
+
fn = get_focalLength_from_fieldOfView(fov=fovn, img_size=img_size)
|
| 277 |
+
y = x * (fn / f)
|
| 278 |
+
return y
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def undo_focal_length_normalization(y, f, fovn=60, img_size=448):
|
| 282 |
+
"""
|
| 283 |
+
Undo focal_length_normalization()
|
| 284 |
+
"""
|
| 285 |
+
fn = get_focalLength_from_fieldOfView(fov=fovn, img_size=img_size)
|
| 286 |
+
x = y * (f / fn)
|
| 287 |
+
return x
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
EPS_LOG = 1e-10
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def log_depth(x, eps=EPS_LOG):
|
| 294 |
+
"""
|
| 295 |
+
Move depth to log space
|
| 296 |
+
"""
|
| 297 |
+
return torch.log(x + eps)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def undo_log_depth(y, eps=EPS_LOG):
|
| 301 |
+
"""
|
| 302 |
+
Undo log_depth()
|
| 303 |
+
"""
|
| 304 |
+
return torch.exp(y) - eps
|
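A short round-trip sketch of the projection helpers defined above; the intrinsics and points are arbitrary example values, not taken from the model:

import torch

# pinhole intrinsics: focal length in pixels, principal point at (320, 240)
K = get_intrinsic_matrix(torch.tensor([1000.0]), torch.tensor([320.0, 240.0])).unsqueeze(0)

pts3d = torch.tensor([[[0.1, -0.2, 2.0], [0.0, 0.0, 3.0]]])   # (1, N, 3), in front of the camera
pts2d = perspective_projection(pts3d, K)                      # (1, N, 2) pixel coordinates

# inverting with the known depths recovers the original points (up to numerical error)
recovered = inverse_perspective_projection(pts2d, K, pts3d[..., 2:3])
print(torch.allclose(recovered, pts3d, atol=1e-4))            # True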
src/sam3d_body/models/modules/layer_scale.py
ADDED
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from typing import Union

import torch
import torch.nn as nn


class LayerScale(nn.Module):
    """LayerScale layer.

    Args:
        dim (int): Dimension of input features.
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 1e-5.
        inplace (bool): Whether to do the operation in-place.
            Defaults to False.
        data_format (str): The input data format, could be 'channels_last'
            or 'channels_first', representing (B, N, C) and
            (B, C, H, W) format data respectively. Defaults to 'channels_last'.
    """

    def __init__(
        self,
        dim: int,
        layer_scale_init_value: Union[float, torch.Tensor] = 1e-5,
        inplace: bool = False,
        data_format: str = "channels_last",
    ):
        super().__init__()
        assert data_format in (
            "channels_last",
            "channels_first",
        ), "'data_format' could only be channels_last or channels_first."
        self.inplace = inplace
        self.data_format = data_format
        self.weight = nn.Parameter(torch.ones(dim) * layer_scale_init_value)

    def forward(self, x):
        if self.data_format == "channels_first":
            if self.inplace:
                return x.mul_(self.weight.view(-1, 1, 1))
            else:
                return x * self.weight.view(-1, 1, 1)
        return x.mul_(self.weight) if self.inplace else x * self.weight
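A brief usage sketch of LayerScale for both data formats; the dimensions are arbitrary examples:

import torch

ls_tokens = LayerScale(dim=256)                               # 'channels_last': scales (B, N, C) along C
print(ls_tokens(torch.randn(2, 196, 256)).shape)              # torch.Size([2, 196, 256])

ls_maps = LayerScale(dim=256, data_format="channels_first")   # scales (B, C, H, W) along C
print(ls_maps(torch.randn(2, 256, 16, 12)).shape)             # torch.Size([2, 256, 16, 12])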
src/sam3d_body/models/modules/mhr_utils.py
ADDED
|
@@ -0,0 +1,392 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import os.path as osp
|
| 6 |
+
import pickle
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def rotation_angle_difference(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
|
| 17 |
+
"""
|
| 18 |
+
Compute the angle difference (magnitude) between two batches of SO(3) rotation matrices.
|
| 19 |
+
Args:
|
| 20 |
+
A: Tensor of shape (*, 3, 3), batch of rotation matrices.
|
| 21 |
+
B: Tensor of shape (*, 3, 3), batch of rotation matrices.
|
| 22 |
+
Returns:
|
| 23 |
+
Tensor of shape (*,), angle differences in radians.
|
| 24 |
+
"""
|
| 25 |
+
# Compute relative rotation matrix
|
| 26 |
+
R_rel = torch.matmul(A, B.transpose(-2, -1)) # (B, 3, 3)
|
| 27 |
+
# Compute trace of relative rotation
|
| 28 |
+
trace = R_rel[..., 0, 0] + R_rel[..., 1, 1] + R_rel[..., 2, 2] # (B,)
|
| 29 |
+
# Compute angle using the trace formula
|
| 30 |
+
cos_theta = (trace - 1) / 2
|
| 31 |
+
# Clamp for numerical stability
|
| 32 |
+
cos_theta_clamped = torch.clamp(cos_theta, -1.0, 1.0)
|
| 33 |
+
# Compute angle difference
|
| 34 |
+
angle = torch.acos(cos_theta_clamped)
|
| 35 |
+
return angle
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def fix_wrist_euler(
|
| 39 |
+
wrist_xzy, limits_x=(-2.2, 1.0), limits_z=(-2.2, 1.5), limits_y=(-1.2, 1.5)
|
| 40 |
+
):
|
| 41 |
+
"""
|
| 42 |
+
wrist_xzy: B x 2 x 3 (X, Z, Y angles)
|
| 43 |
+
Returns: Fixed angles within joint limits
|
| 44 |
+
"""
|
| 45 |
+
x, z, y = wrist_xzy[..., 0], wrist_xzy[..., 1], wrist_xzy[..., 2]
|
| 46 |
+
|
| 47 |
+
x_alt = torch.atan2(torch.sin(x + torch.pi), torch.cos(x + torch.pi))
|
| 48 |
+
z_alt = torch.atan2(torch.sin(-(z + torch.pi)), torch.cos(-(z + torch.pi)))
|
| 49 |
+
y_alt = torch.atan2(torch.sin(y + torch.pi), torch.cos(y + torch.pi))
|
| 50 |
+
|
| 51 |
+
# Calculate L2 violation distance
|
| 52 |
+
def calc_violation(val, limits):
|
| 53 |
+
below = torch.clamp(limits[0] - val, min=0.0)
|
| 54 |
+
above = torch.clamp(val - limits[1], min=0.0)
|
| 55 |
+
return below**2 + above**2
|
| 56 |
+
|
| 57 |
+
violation_orig = (
|
| 58 |
+
calc_violation(x, limits_x)
|
| 59 |
+
+ calc_violation(z, limits_z)
|
| 60 |
+
+ calc_violation(y, limits_y)
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
violation_alt = (
|
| 64 |
+
calc_violation(x_alt, limits_x)
|
| 65 |
+
+ calc_violation(z_alt, limits_z)
|
| 66 |
+
+ calc_violation(y_alt, limits_y)
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
# Use alternative where it has lower L2 violation
|
| 70 |
+
use_alt = violation_alt < violation_orig
|
| 71 |
+
|
| 72 |
+
# Stack alternative and apply mask
|
| 73 |
+
wrist_xzy_alt = torch.stack([x_alt, z_alt, y_alt], dim=-1)
|
| 74 |
+
result = torch.where(use_alt.unsqueeze(-1), wrist_xzy_alt, wrist_xzy)
|
| 75 |
+
|
| 76 |
+
return result
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def batch6DFromXYZ(r, return_9D=False):
|
| 80 |
+
"""
|
| 81 |
+
Generate a matrix representing a rotation defined by a XYZ-Euler
|
| 82 |
+
rotation.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
r: ... x 3 rotation vectors
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
... x 6
|
| 89 |
+
"""
|
| 90 |
+
rc = torch.cos(r)
|
| 91 |
+
rs = torch.sin(r)
|
| 92 |
+
cx = rc[..., 0]
|
| 93 |
+
cy = rc[..., 1]
|
| 94 |
+
cz = rc[..., 2]
|
| 95 |
+
sx = rs[..., 0]
|
| 96 |
+
sy = rs[..., 1]
|
| 97 |
+
sz = rs[..., 2]
|
| 98 |
+
|
| 99 |
+
result = torch.empty(list(r.shape[:-1]) + [3, 3], dtype=r.dtype).to(r.device)
|
| 100 |
+
|
| 101 |
+
result[..., 0, 0] = cy * cz
|
| 102 |
+
result[..., 0, 1] = -cx * sz + sx * sy * cz
|
| 103 |
+
result[..., 0, 2] = sx * sz + cx * sy * cz
|
| 104 |
+
result[..., 1, 0] = cy * sz
|
| 105 |
+
result[..., 1, 1] = cx * cz + sx * sy * sz
|
| 106 |
+
result[..., 1, 2] = -sx * cz + cx * sy * sz
|
| 107 |
+
result[..., 2, 0] = -sy
|
| 108 |
+
result[..., 2, 1] = sx * cy
|
| 109 |
+
result[..., 2, 2] = cx * cy
|
| 110 |
+
|
| 111 |
+
if not return_9D:
|
| 112 |
+
return torch.cat([result[..., :, 0], result[..., :, 1]], dim=-1)
|
| 113 |
+
else:
|
| 114 |
+
return result
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py#L82
|
| 118 |
+
def batchXYZfrom6D(poses):
|
| 119 |
+
# Args: poses: ... x 6, where "6" is the combined first and second columns
|
| 120 |
+
# First, get the rotation matrix
|
| 121 |
+
x_raw = poses[..., :3]
|
| 122 |
+
y_raw = poses[..., 3:]
|
| 123 |
+
|
| 124 |
+
x = F.normalize(x_raw, dim=-1)
|
| 125 |
+
z = torch.cross(x, y_raw, dim=-1)
|
| 126 |
+
z = F.normalize(z, dim=-1)
|
| 127 |
+
y = torch.cross(z, x, dim=-1)
|
| 128 |
+
|
| 129 |
+
matrix = torch.stack([x, y, z], dim=-1) # ... x 3 x 3
|
| 130 |
+
|
| 131 |
+
# Now get it into euler
|
| 132 |
+
# https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py#L412
|
| 133 |
+
sy = torch.sqrt(
|
| 134 |
+
matrix[..., 0, 0] * matrix[..., 0, 0] + matrix[..., 1, 0] * matrix[..., 1, 0]
|
| 135 |
+
)
|
| 136 |
+
singular = sy < 1e-6
|
| 137 |
+
singular = singular.float()
|
| 138 |
+
|
| 139 |
+
x = torch.atan2(matrix[..., 2, 1], matrix[..., 2, 2])
|
| 140 |
+
y = torch.atan2(-matrix[..., 2, 0], sy)
|
| 141 |
+
z = torch.atan2(matrix[..., 1, 0], matrix[..., 0, 0])
|
| 142 |
+
|
| 143 |
+
xs = torch.atan2(-matrix[..., 1, 2], matrix[..., 1, 1])
|
| 144 |
+
ys = torch.atan2(-matrix[..., 2, 0], sy)
|
| 145 |
+
zs = matrix[..., 1, 0] * 0
|
| 146 |
+
|
| 147 |
+
out_euler = torch.zeros_like(matrix[..., 0])
|
| 148 |
+
out_euler[..., 0] = x * (1 - singular) + xs * singular
|
| 149 |
+
out_euler[..., 1] = y * (1 - singular) + ys * singular
|
| 150 |
+
out_euler[..., 2] = z * (1 - singular) + zs * singular
|
| 151 |
+
|
| 152 |
+
return out_euler
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def resize_image(image_array, scale_factor, interpolation=cv2.INTER_LINEAR):
|
| 156 |
+
new_height = int(image_array.shape[0] // scale_factor)
|
| 157 |
+
new_width = int(image_array.shape[1] // scale_factor)
|
| 158 |
+
resized_image = cv2.resize(
|
| 159 |
+
image_array, (new_width, new_height), interpolation=interpolation
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
return resized_image
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def compact_cont_to_model_params_hand(hand_cont):
|
| 166 |
+
# These are ordered by joint, not model params ^^
|
| 167 |
+
assert hand_cont.shape[-1] == 54
|
| 168 |
+
hand_dofs_in_order = torch.tensor([3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 2, 3, 1, 1])
|
| 169 |
+
assert sum(hand_dofs_in_order) == 27
|
| 170 |
+
# Mask of 3DoFs into hand_cont
|
| 171 |
+
mask_cont_threedofs = torch.cat(
|
| 172 |
+
[torch.ones(2 * k).bool() * (k in [3]) for k in hand_dofs_in_order]
|
| 173 |
+
)
|
| 174 |
+
# Mask of 1DoFs (including 2DoF) into hand_cont
|
| 175 |
+
mask_cont_onedofs = torch.cat(
|
| 176 |
+
[torch.ones(2 * k).bool() * (k in [1, 2]) for k in hand_dofs_in_order]
|
| 177 |
+
)
|
| 178 |
+
# Mask of 3DoFs into hand_model_params
|
| 179 |
+
mask_model_params_threedofs = torch.cat(
|
| 180 |
+
[torch.ones(k).bool() * (k in [3]) for k in hand_dofs_in_order]
|
| 181 |
+
)
|
| 182 |
+
# Mask of 1DoFs (including 2DoF) into hand_model_params
|
| 183 |
+
mask_model_params_onedofs = torch.cat(
|
| 184 |
+
[torch.ones(k).bool() * (k in [1, 2]) for k in hand_dofs_in_order]
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
# Convert hand_cont to eulers
|
| 188 |
+
## First for 3DoFs
|
| 189 |
+
hand_cont_threedofs = hand_cont[..., mask_cont_threedofs].unflatten(-1, (-1, 6))
|
| 190 |
+
hand_model_params_threedofs = batchXYZfrom6D(hand_cont_threedofs).flatten(-2, -1)
|
| 191 |
+
## Next for 1DoFs
|
| 192 |
+
hand_cont_onedofs = hand_cont[..., mask_cont_onedofs].unflatten(
|
| 193 |
+
-1, (-1, 2)
|
| 194 |
+
) # (sincos)
|
| 195 |
+
hand_model_params_onedofs = torch.atan2(
|
| 196 |
+
hand_cont_onedofs[..., -2], hand_cont_onedofs[..., -1]
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
# Finally, assemble into a 27-dim vector, ordered by joint, then XYZ.
|
| 200 |
+
hand_model_params = torch.zeros(*hand_cont.shape[:-1], 27).to(hand_cont)
|
| 201 |
+
hand_model_params[..., mask_model_params_threedofs] = hand_model_params_threedofs
|
| 202 |
+
hand_model_params[..., mask_model_params_onedofs] = hand_model_params_onedofs
|
| 203 |
+
|
| 204 |
+
return hand_model_params
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def compact_model_params_to_cont_hand(hand_model_params):
|
| 208 |
+
# These are ordered by joint, not model params ^^
|
| 209 |
+
assert hand_model_params.shape[-1] == 27
|
| 210 |
+
hand_dofs_in_order = torch.tensor([3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 2, 3, 1, 1])
|
| 211 |
+
assert sum(hand_dofs_in_order) == 27
|
| 212 |
+
# Mask of 3DoFs into hand_cont
|
| 213 |
+
mask_cont_threedofs = torch.cat(
|
| 214 |
+
[torch.ones(2 * k).bool() * (k in [3]) for k in hand_dofs_in_order]
|
| 215 |
+
)
|
| 216 |
+
# Mask of 1DoFs (including 2DoF) into hand_cont
|
| 217 |
+
mask_cont_onedofs = torch.cat(
|
| 218 |
+
[torch.ones(2 * k).bool() * (k in [1, 2]) for k in hand_dofs_in_order]
|
| 219 |
+
)
|
| 220 |
+
# Mask of 3DoFs into hand_model_params
|
| 221 |
+
mask_model_params_threedofs = torch.cat(
|
| 222 |
+
[torch.ones(k).bool() * (k in [3]) for k in hand_dofs_in_order]
|
| 223 |
+
)
|
| 224 |
+
# Mask of 1DoFs (including 2DoF) into hand_model_params
|
| 225 |
+
mask_model_params_onedofs = torch.cat(
|
| 226 |
+
[torch.ones(k).bool() * (k in [1, 2]) for k in hand_dofs_in_order]
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
# Convert Euler angles to hand_cont
|
| 230 |
+
## First for 3DoFs
|
| 231 |
+
hand_model_params_threedofs = hand_model_params[
|
| 232 |
+
..., mask_model_params_threedofs
|
| 233 |
+
].unflatten(-1, (-1, 3))
|
| 234 |
+
hand_cont_threedofs = batch6DFromXYZ(hand_model_params_threedofs).flatten(-2, -1)
|
| 235 |
+
## Next for 1DoFs
|
| 236 |
+
hand_model_params_onedofs = hand_model_params[..., mask_model_params_onedofs]
|
| 237 |
+
hand_cont_onedofs = torch.stack(
|
| 238 |
+
[hand_model_params_onedofs.sin(), hand_model_params_onedofs.cos()], dim=-1
|
| 239 |
+
).flatten(-2, -1)
|
| 240 |
+
|
| 241 |
+
# Finally, assemble into a 27-dim vector, ordered by joint, then XYZ.
|
| 242 |
+
hand_cont = torch.zeros(*hand_model_params.shape[:-1], 54).to(hand_model_params)
|
| 243 |
+
hand_cont[..., mask_cont_threedofs] = hand_cont_threedofs
|
| 244 |
+
hand_cont[..., mask_cont_onedofs] = hand_cont_onedofs
|
| 245 |
+
|
| 246 |
+
return hand_cont
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def batch9Dfrom6D(poses):
|
| 250 |
+
# Args: poses: ... x 6, where "6" is the combined first and second columns
|
| 251 |
+
# First, get the rotaiton matrix
|
| 252 |
+
x_raw = poses[..., :3]
|
| 253 |
+
y_raw = poses[..., 3:]
|
| 254 |
+
|
| 255 |
+
x = F.normalize(x_raw, dim=-1)
|
| 256 |
+
z = torch.cross(x, y_raw, dim=-1)
|
| 257 |
+
z = F.normalize(z, dim=-1)
|
| 258 |
+
y = torch.cross(z, x, dim=-1)
|
| 259 |
+
|
| 260 |
+
matrix = torch.stack([x, y, z], dim=-1).flatten(-2, -1) # ... x 3 x 3 -> x9
|
| 261 |
+
|
| 262 |
+
return matrix
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def batch4Dfrom2D(poses):
|
| 266 |
+
# Args: poses: ... x 2, where "2" is sincos
|
| 267 |
+
poses_norm = F.normalize(poses, dim=-1)
|
| 268 |
+
|
| 269 |
+
poses_4d = torch.stack(
|
| 270 |
+
[
|
| 271 |
+
poses_norm[..., 1],
|
| 272 |
+
poses_norm[..., 0],
|
| 273 |
+
-poses_norm[..., 0],
|
| 274 |
+
poses_norm[..., 1],
|
| 275 |
+
],
|
| 276 |
+
dim=-1,
|
| 277 |
+
) # Flattened SO2.
|
| 278 |
+
|
| 279 |
+
return poses_4d # .... x 4
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def compact_cont_to_rotmat_body(body_pose_cont, inflate_trans=False):
|
| 283 |
+
# fmt: off
|
| 284 |
+
all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)])
|
| 285 |
+
all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123])
|
| 286 |
+
all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129])
|
| 287 |
+
# fmt: on
|
| 288 |
+
num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
|
| 289 |
+
num_1dof_angles = len(all_param_1dof_rot_idxs)
|
| 290 |
+
num_1dof_trans = len(all_param_1dof_trans_idxs)
|
| 291 |
+
assert body_pose_cont.shape[-1] == (
|
| 292 |
+
2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans
|
| 293 |
+
)
|
| 294 |
+
# Get subsets
|
| 295 |
+
body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles]
|
| 296 |
+
body_cont_1dofs = body_pose_cont[
|
| 297 |
+
..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles
|
| 298 |
+
]
|
| 299 |
+
body_cont_trans = body_pose_cont[..., 2 * num_3dof_angles + 2 * num_1dof_angles :]
|
| 300 |
+
# Convert conts to model params
|
| 301 |
+
## First for 3dofs
|
| 302 |
+
body_cont_3dofs = body_cont_3dofs.unflatten(-1, (-1, 6))
|
| 303 |
+
body_rotmat_3dofs = batch9Dfrom6D(body_cont_3dofs).flatten(-2, -1)
|
| 304 |
+
## Next for 1dofs
|
| 305 |
+
body_cont_1dofs = body_cont_1dofs.unflatten(-1, (-1, 2)) # (sincos)
|
| 306 |
+
body_rotmat_1dofs = batch4Dfrom2D(body_cont_1dofs).flatten(-2, -1)
|
| 307 |
+
if inflate_trans:
|
| 308 |
+
assert (
|
| 309 |
+
False
|
| 310 |
+
), "This is left as a possibility to increase the space/contribution/supervision trans params gets compared to rots"
|
| 311 |
+
else:
|
| 312 |
+
## Nothing to do for trans
|
| 313 |
+
body_rotmat_trans = body_cont_trans
|
| 314 |
+
# Put them together
|
| 315 |
+
body_rotmat_params = torch.cat(
|
| 316 |
+
[body_rotmat_3dofs, body_rotmat_1dofs, body_rotmat_trans], dim=-1
|
| 317 |
+
)
|
| 318 |
+
return body_rotmat_params
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def compact_cont_to_model_params_body(body_pose_cont):
|
| 322 |
+
# fmt: off
|
| 323 |
+
    all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)])
    all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123])
    all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129])
    # fmt: on
    num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
    num_1dof_angles = len(all_param_1dof_rot_idxs)
    num_1dof_trans = len(all_param_1dof_trans_idxs)
    assert body_pose_cont.shape[-1] == (
        2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans
    )
    # Get subsets
    body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles]
    body_cont_1dofs = body_pose_cont[
        ..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles
    ]
    body_cont_trans = body_pose_cont[..., 2 * num_3dof_angles + 2 * num_1dof_angles :]
    # Convert conts to model params
    ## First for 3dofs
    body_cont_3dofs = body_cont_3dofs.unflatten(-1, (-1, 6))
    body_params_3dofs = batchXYZfrom6D(body_cont_3dofs).flatten(-2, -1)
    ## Next for 1dofs
    body_cont_1dofs = body_cont_1dofs.unflatten(-1, (-1, 2))  # (sincos)
    body_params_1dofs = torch.atan2(body_cont_1dofs[..., -2], body_cont_1dofs[..., -1])
    ## Nothing to do for trans
    body_params_trans = body_cont_trans
    # Put them together
    body_pose_params = torch.zeros(*body_pose_cont.shape[:-1], 133).to(body_pose_cont)
    body_pose_params[..., all_param_3dof_rot_idxs.flatten()] = body_params_3dofs
    body_pose_params[..., all_param_1dof_rot_idxs] = body_params_1dofs
    body_pose_params[..., all_param_1dof_trans_idxs] = body_params_trans
    return body_pose_params


def compact_model_params_to_cont_body(body_pose_params):
    # fmt: off
    all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)])
    all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123])
    all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129])
    # fmt: on
    num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
    num_1dof_angles = len(all_param_1dof_rot_idxs)
    num_1dof_trans = len(all_param_1dof_trans_idxs)
    assert body_pose_params.shape[-1] == (
        num_3dof_angles + num_1dof_angles + num_1dof_trans
    )
    # Take out params
    body_params_3dofs = body_pose_params[..., all_param_3dof_rot_idxs.flatten()]
    body_params_1dofs = body_pose_params[..., all_param_1dof_rot_idxs]
    body_params_trans = body_pose_params[..., all_param_1dof_trans_idxs]
    # params to cont
    body_cont_3dofs = batch6DFromXYZ(body_params_3dofs.unflatten(-1, (-1, 3))).flatten(
        -2, -1
    )
    body_cont_1dofs = torch.stack(
        [body_params_1dofs.sin(), body_params_1dofs.cos()], dim=-1
    ).flatten(-2, -1)
    body_cont_trans = body_params_trans
    # Put them together
    body_pose_cont = torch.cat(
        [body_cont_3dofs, body_cont_1dofs, body_cont_trans], dim=-1
    )
    return body_pose_cont


# fmt: off
mhr_param_hand_idxs = [62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115]
mhr_cont_hand_idxs = [72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237]
mhr_param_hand_mask = torch.zeros(133).bool(); mhr_param_hand_mask[mhr_param_hand_idxs] = True
mhr_cont_hand_mask = torch.zeros(260).bool(); mhr_cont_hand_mask[mhr_cont_hand_idxs] = True
# fmt: on
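The two conversion helpers above round-trip MHR poses between the compact 133-parameter vector (23 three-DoF rotations, 58 one-DoF hinge angles, 6 translations) and a 260-dim continuous encoding (6D rotation representation plus sin/cos hinge angles), and the hand masks pick out the hand subset of each. A minimal shape-check sketch, assuming the repo's src/ layout is installed so that `sam3d_body` is importable:

# Sketch: sanity-check the compact -> continuous conversion and the hand masks.
# Import path is an assumption based on the src/ layout of this repo.
import torch

from sam3d_body.models.modules.mhr_utils import (
    compact_model_params_to_cont_body,
    mhr_cont_hand_mask,
    mhr_param_hand_mask,
)

params = torch.zeros(4, 133)                      # batch of compact MHR pose vectors
cont = compact_model_params_to_cont_body(params)

# 23 rotations as 6D (138) + 58 hinge angles as sin/cos (116) + 6 translations = 260
assert cont.shape == (4, 260)
# The hand masks select 54 of the 133 parameters and 108 of the 260 continuous dims
assert int(mhr_param_hand_mask.sum()) == 54
assert int(mhr_cont_hand_mask.sum()) == 108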
src/sam3d_body/models/modules/misc.py
ADDED
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import collections.abc
from itertools import repeat


# From PyTorch internals
def _ntuple(n):
    """A `to_tuple` function generator.

    Returns a function that repeats its input into a tuple of length ``n``
    when the input is not an iterable; otherwise the input is returned
    unchanged.

    Args:
        n (int): The target tuple length.
    """

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
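These `to_Ntuple` helpers normalize scalar-or-tuple arguments such as patch or kernel sizes. A quick illustration (import path assumed from the src/ layout):

from sam3d_body.models.modules.misc import to_2tuple

assert to_2tuple(14) == (14, 14)        # a scalar is repeated
assert to_2tuple((16, 12)) == (16, 12)  # an iterable is passed through unchanged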
src/sam3d_body/models/modules/swiglu_ffn.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .drop_path import DropPath

from .layer_scale import LayerScale


class SwiGLUFFN(nn.Module):
    """SwiGLU FFN layer.

    Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py
    """  # noqa

    def __init__(
        self,
        embed_dims: int,
        feedforward_channels: Optional[int] = None,
        out_dims: Optional[int] = None,
        layer_scale_init_value: float = 0.0,
        bias: bool = True,
        drop_path_rate: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        add_identity: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dims = embed_dims
        self.out_dims = out_dims or embed_dims
        hidden_dims = feedforward_channels or embed_dims

        self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias)

        self.norm = norm_layer

        self.w3 = nn.Linear(hidden_dims, self.out_dims, bias=bias)

        if layer_scale_init_value > 0:
            self.gamma2 = LayerScale(
                dim=embed_dims, layer_scale_init_value=layer_scale_init_value
            )
        else:
            self.gamma2 = nn.Identity()

        self.dropout_layer = DropPath(drop_path_rate)
        self.add_identity = add_identity

    def forward(
        self, x: torch.Tensor, identity: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        hidden = self.norm(hidden)
        out = self.w3(hidden)
        out = self.gamma2(out)
        out = self.dropout_layer(out)

        if self.out_dims != self.embed_dims or not self.add_identity:
            # Skip the residual connection when the dims do not match or when
            # the user disabled it.
            return out

        if identity is None:
            identity = x
        return identity + out


class SwiGLUFFNFused(SwiGLUFFN):
    """SwiGLU FFN layer with fusing.

    Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py
    """  # noqa

    def __init__(
        self,
        embed_dims: int,
        feedforward_channels: Optional[int] = None,
        out_dims: Optional[int] = None,
        layer_scale_init_value: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_dims = out_dims or embed_dims
        feedforward_channels = feedforward_channels or embed_dims
        feedforward_channels = (int(feedforward_channels * 2 / 3) + 7) // 8 * 8
        super().__init__(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            out_dims=out_dims,
            layer_scale_init_value=layer_scale_init_value,
            bias=bias,
        )
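`SwiGLUFFNFused` shrinks the requested FFN width by a factor of 2/3 (SwiGLU uses two input projections, so this roughly preserves the parameter count of a plain two-layer FFN) and rounds the result up to a multiple of 8. A small sketch of that arithmetic, with the example widths chosen purely for illustration:

def fused_hidden_dims(feedforward_channels: int) -> int:
    # Same rounding as in SwiGLUFFNFused.__init__ above.
    return (int(feedforward_channels * 2 / 3) + 7) // 8 * 8

assert fused_hidden_dims(1024) == 688   # 682.67 -> 682 -> rounded up to 688
assert fused_hidden_dims(3072) == 2048  # already a multiple of 8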
src/sam3d_body/models/modules/transformer.py
ADDED
@@ -0,0 +1,651 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .drop_path import DropPath

from .layer_scale import LayerScale
from .swiglu_ffn import SwiGLUFFNFused


class MLP(nn.Module):
    # borrowed from DETR
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class LayerNorm32(nn.LayerNorm):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x.float()).type(x.dtype)


def build_norm_layer(cfg: Dict, num_features: int):
    """Build normalization layer.

    Args:
        cfg (dict): The norm layer config, which should contain:

            - type (str): Layer type.
            - layer args: Args needed to instantiate a norm layer.
            - requires_grad (bool, optional): Whether to stop gradient updates.
        num_features (int): Number of input channels.

    Returns:
        nn.Module: The created norm layer.
    """
    if not isinstance(cfg, dict):
        raise TypeError("cfg must be a dict")
    if "type" not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')
    cfg_ = cfg.copy()

    layer_type = cfg_.pop("type")
    if layer_type == "LN":
        norm_layer = LayerNorm32
    else:
        raise ValueError("Unsupported norm layer: ", layer_type)

    requires_grad = cfg_.pop("requires_grad", True)
    cfg_.setdefault("eps", 1e-5)
    if norm_layer is not nn.GroupNorm:
        layer = norm_layer(num_features, **cfg_)
        if layer_type == "SyncBN" and hasattr(layer, "_specify_ddp_gpu_num"):
            layer._specify_ddp_gpu_num(1)
    else:
        assert "num_groups" in cfg_
        layer = norm_layer(num_channels=num_features, **cfg_)

    for param in layer.parameters():
        param.requires_grad = requires_grad

    return layer

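`build_norm_layer` is config-driven; only the `"LN"` type is supported here, and it resolves to `LayerNorm32`, which runs LayerNorm in fp32 and casts back to the input dtype (handy under mixed precision). A short usage sketch, with the import path assumed from the src/ layout:

import torch

from sam3d_body.models.modules.transformer import LayerNorm32, build_norm_layer

norm = build_norm_layer(dict(type="LN", eps=1e-6), 256)
assert isinstance(norm, LayerNorm32)

# The fp32 round-trip preserves the input dtype on the way out.
x = torch.randn(2, 16, 256, dtype=torch.float16)
assert norm(x).dtype == torch.float16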
class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


class FFN(nn.Module):
    """Implements feed-forward networks (FFNs) with identity connection.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Defaults: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Defaults: 1024.
        output_dims (int, optional): The output dimension. Defaults to
            ``embed_dims``.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Default: 2.
        act_layer (nn.Module, optional): The activation layer for FFNs.
            Default: nn.ReLU
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        add_identity (bool, optional): Whether to add the
            identity connection. Default: `True`.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): Initial value of scale factor in
            LayerScale. Default: 0.0
    """

    # @deprecated_api_warning(
    #     {
    #         'dropout': 'ffn_drop',
    #         'add_residual': 'add_identity'
    #     },
    #     cls_name='FFN')
    def __init__(
        self,
        embed_dims=256,
        feedforward_channels=1024,
        output_dims=None,
        num_fcs=2,
        act_layer=nn.ReLU,
        ffn_drop=0.0,
        drop_path_rate=0.0,
        add_identity=True,
        layer_scale_init_value=0.0,
    ):
        super().__init__()
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.output_dims = output_dims or embed_dims
        self.num_fcs = num_fcs

        layers = []
        in_channels = embed_dims
        for _ in range(num_fcs - 1):
            layers.append(
                nn.Sequential(
                    nn.Linear(in_channels, feedforward_channels),
                    act_layer(),
                    nn.Dropout(ffn_drop),
                )
            )
            in_channels = feedforward_channels
        layers.append(nn.Linear(in_channels, self.output_dims))
        layers.append(nn.Dropout(ffn_drop))
        self.layers = nn.Sequential(*layers)
        self.dropout_layer = (
            DropPath(drop_path_rate) if drop_path_rate > 0.0 else torch.nn.Identity()
        )
        self.add_identity = add_identity

        if layer_scale_init_value > 0:
            self.gamma2 = LayerScale(embed_dims, scale=layer_scale_init_value)
        else:
            self.gamma2 = nn.Identity()

    # @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
    def forward(self, x, identity=None):
        """Forward function for `FFN`.

        Adds ``x`` to the output tensor if ``identity`` is None.
        """
        out = self.layers(x)
        out = self.gamma2(out)
        if not self.add_identity:
            return self.dropout_layer(out)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)


class MultiheadAttention(nn.Module):
    """Multi-head Attention Module.

    This module implements multi-head attention that supports different input
    dims and embed dims. It also supports a shortcut from ``value``, which
    is useful if the input dims differ from the embed dims.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        input_dims (int, optional): The input dimension, and if None,
            use ``embed_dims``. Defaults to None.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        proj_bias (bool): If True, add a learnable bias to output projection.
            Defaults to True.
        v_shortcut (bool): Add a shortcut from value to output. It's usually
            used if ``input_dims`` is different from ``embed_dims``.
            Defaults to False.
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 0.
    """

    def __init__(
        self,
        embed_dims,
        num_heads,
        input_dims=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path_rate=0.0,
        qkv_bias=True,
        proj_bias=True,
        v_shortcut=False,
        layer_scale_init_value=0.0,
    ):
        super().__init__()

        self.input_dims = input_dims or embed_dims
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.v_shortcut = v_shortcut

        self.head_dims = embed_dims // num_heads

        self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = attn_drop
        self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.out_drop = DropPath(drop_path_rate)

        if layer_scale_init_value > 0:
            layer_scale_init_value = layer_scale_init_value or 1e-5
            self.gamma1 = LayerScale(
                embed_dims, layer_scale_init_value=layer_scale_init_value
            )
        else:
            self.gamma1 = nn.Identity()

    def forward(self, x):
        B, N, _ = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dims)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_drop = self.attn_drop if self.training else 0.0
        x = F.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
        x = x.transpose(1, 2).reshape(B, N, self.embed_dims)

        x = self.proj(x)
        x = self.out_drop(self.gamma1(self.proj_drop(x)))

        if self.v_shortcut:
            x = v.squeeze(1) + x
        return x

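`MultiheadAttention` is plain fused-QKV self-attention built on `F.scaled_dot_product_attention`; it keeps the token count and width unchanged. A shape-only sketch with illustrative sizes (import path assumed from the src/ layout):

import torch

from sam3d_body.models.modules.transformer import MultiheadAttention

attn = MultiheadAttention(embed_dims=256, num_heads=8).eval()  # head_dims = 256 // 8 = 32

x = torch.randn(2, 196, 256)  # (B, N_tokens, C)
with torch.no_grad():
    out = attn(x)
assert out.shape == (2, 196, 256)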
class Attention(nn.Module):
    """Multi-head Attention Module for both self and cross attention.

    Supports masking invalid elements for attention.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        query_dims (int, optional): The query dimension, and if None,
            use ``embed_dims``. Defaults to None.
        key_dims (int, optional): The key dimension, and if None,
            use ``embed_dims``. Defaults to None.
        value_dims (int, optional): The value dimension, and if None,
            use ``embed_dims``. Defaults to None.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        proj_bias (bool): If True, add a learnable bias to output projection.
            Defaults to True.
        v_shortcut (bool): Add a shortcut from value to output. It's usually
            used if the input dims differ from ``embed_dims``.
            Defaults to False.
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 0.
    """

    def __init__(
        self,
        embed_dims,
        num_heads,
        query_dims=None,
        key_dims=None,
        value_dims=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path_rate=0.0,
        qkv_bias=True,
        proj_bias=True,
        v_shortcut=False,
        layer_scale_init_value=0.0,
    ):
        super().__init__()

        self.query_dims = query_dims or embed_dims
        self.key_dims = key_dims or embed_dims
        self.value_dims = value_dims or embed_dims
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.v_shortcut = v_shortcut

        self.head_dims = embed_dims // num_heads

        self.q_proj = nn.Linear(self.query_dims, embed_dims, bias=qkv_bias)
        self.k_proj = nn.Linear(self.key_dims, embed_dims, bias=qkv_bias)
        self.v_proj = nn.Linear(self.value_dims, embed_dims, bias=qkv_bias)
        self.attn_drop = attn_drop
        self.proj = nn.Linear(embed_dims, self.query_dims, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.out_drop = DropPath(drop_path_rate)

        if layer_scale_init_value > 0:
            layer_scale_init_value = layer_scale_init_value or 1e-5
            self.gamma1 = LayerScale(
                embed_dims, layer_scale_init_value=layer_scale_init_value
            )
        else:
            self.gamma1 = nn.Identity()

    def _separate_heads(self, x: torch.Tensor) -> torch.Tensor:
        b, n, _ = x.shape
        x = x.reshape(b, n, self.num_heads, self.head_dims)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        B, N, _ = q.shape
        q = self._separate_heads(self.q_proj(q))
        k = self._separate_heads(self.k_proj(k))
        v = self._separate_heads(self.v_proj(v))

        attn_drop = self.attn_drop if self.training else 0.0
        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)

        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=attn_drop
        )
        x = x.transpose(1, 2).reshape(B, N, self.embed_dims)

        x = self.proj(x)
        x = self.out_drop(self.gamma1(self.proj_drop(x)))

        if self.v_shortcut:
            x = v.squeeze(1) + x
        return x

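`Attention` generalizes the block above to cross-attention: queries and context may have different widths, the output projection maps back to `query_dims`, and a boolean `attn_mask` (True = attend) is broadcast over heads. A sketch with illustrative sizes, import path assumed:

import torch

from sam3d_body.models.modules.transformer import Attention

cross_attn = Attention(
    embed_dims=512, num_heads=8, query_dims=256, key_dims=384, value_dims=384
).eval()

q = torch.randn(2, 24, 256)                      # decoder tokens
kv = torch.randn(2, 100, 384)                    # context features
mask = torch.ones(2, 24, 100, dtype=torch.bool)  # True = this key may be attended to

with torch.no_grad():
    out = cross_attn(q=q, k=kv, v=kv, attn_mask=mask)
assert out.shape == (2, 24, 256)  # projected back to query_dims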
class TransformerEncoderLayer(nn.Module):
    """Implements one encoder layer in Vision Transformer.

    Args:
        embed_dims (int): The feature dimension
        num_heads (int): Parallel attention heads
        feedforward_channels (int): The hidden dimension for FFNs
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 0.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The dropout rate for attention output weights.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to True.
        ffn_type (str): Select the type of ffn layers. Defaults to 'origin'.
        act_layer (nn.Module, optional): The activation layer for FFNs.
            Default: nn.GELU
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
    """

    def __init__(
        self,
        embed_dims,
        num_heads,
        feedforward_channels,
        layer_scale_init_value=0.0,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        num_fcs=2,
        qkv_bias=True,
        ffn_type="origin",
        act_layer=nn.GELU,
        norm_cfg=dict(type="LN", eps=1e-6),
    ):
        super().__init__()

        self.embed_dims = embed_dims

        self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)

        self.attn = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            drop_path_rate=drop_path_rate,
            qkv_bias=qkv_bias,
            layer_scale_init_value=layer_scale_init_value,
        )

        self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)

        if ffn_type == "origin":
            self.ffn = FFN(
                embed_dims=embed_dims,
                feedforward_channels=feedforward_channels,
                num_fcs=num_fcs,
                ffn_drop=drop_rate,
                drop_path_rate=drop_path_rate,
                act_layer=act_layer,
                layer_scale_init_value=layer_scale_init_value,
            )
        elif ffn_type == "swiglu_fused":
            self.ffn = SwiGLUFFNFused(
                embed_dims=embed_dims,
                feedforward_channels=feedforward_channels,
                layer_scale_init_value=layer_scale_init_value,
            )
        else:
            raise NotImplementedError

    @property
    def norm1(self):
        return self.ln1

    @property
    def norm2(self):
        return self.ln2

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = self.ffn(self.ln2(x), identity=x)
        return x

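`TransformerEncoderLayer` is a standard pre-norm ViT block: `x + attn(LN(x))`, then an FFN with a residual, so a stack of them preserves the token shape. A sketch with illustrative sizes and the default `ffn_type="origin"`, import path assumed:

import torch
import torch.nn as nn

from sam3d_body.models.modules.transformer import TransformerEncoderLayer

blocks = nn.Sequential(
    *[
        TransformerEncoderLayer(embed_dims=384, num_heads=6, feedforward_channels=1536)
        for _ in range(4)
    ]
)

x = torch.randn(2, 256, 384)  # (B, N_patches, C)
with torch.no_grad():
    assert blocks(x).shape == x.shape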
class TransformerDecoderLayer(nn.Module):
    """Implements one decoder layer in cross-attention Transformer.

    Adapted from Segment Anything Model (SAM) implementation.

    Args:
        token_dims (int): The feature dimension of the query tokens
        context_dims (int): The feature dimension of the context (e.g. image embedding)
        num_heads (int): Parallel attention heads
        head_dims (int): Dimension per attention head
        mlp_dims (int): The hidden dimension for FFNs
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 0.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The dropout rate for attention output weights.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        ffn_type (str): Select the type of ffn layers. Defaults to 'origin'.
        act_layer (nn.Module, optional): The activation layer for FFNs.
            Default: nn.GELU
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        enable_twoway (bool): Whether to enable two-way Transformer (used in SAM).
        repeat_pe (bool): Whether to re-add PE at each layer (used in SAM).
        skip_first_pe (bool): Whether to skip adding PE in the first
            self-attention block.
    """

    def __init__(
        self,
        token_dims: int,
        context_dims: int,
        num_heads: int = 8,
        head_dims: int = 64,
        mlp_dims: int = 1024,
        layer_scale_init_value: float = 0.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        ffn_type: str = "origin",
        act_layer: type[nn.Module] | nn.Module = nn.GELU,
        norm_cfg: Dict = dict(type="LN", eps=1e-6),
        enable_twoway: bool = False,
        repeat_pe: bool = False,
        skip_first_pe: bool = False,
    ):
        super().__init__()
        self.repeat_pe = repeat_pe
        self.skip_first_pe = skip_first_pe
        if self.repeat_pe:
            self.ln_pe_1 = build_norm_layer(norm_cfg, token_dims)
            self.ln_pe_2 = build_norm_layer(norm_cfg, context_dims)

        self.ln1 = build_norm_layer(norm_cfg, token_dims)

        self.self_attn = Attention(
            embed_dims=num_heads * head_dims,
            num_heads=num_heads,
            query_dims=token_dims,
            key_dims=token_dims,
            value_dims=token_dims,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            drop_path_rate=drop_path_rate,
            layer_scale_init_value=layer_scale_init_value,
        )

        self.ln2_1 = build_norm_layer(norm_cfg, token_dims)
        self.ln2_2 = build_norm_layer(norm_cfg, context_dims)

        self.cross_attn = Attention(
            embed_dims=num_heads * head_dims,
            num_heads=num_heads,
            query_dims=token_dims,
            key_dims=context_dims,
            value_dims=context_dims,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            drop_path_rate=drop_path_rate,
            layer_scale_init_value=layer_scale_init_value,
        )

        self.ln3 = build_norm_layer(norm_cfg, token_dims)

        if ffn_type == "origin":
            self.ffn = FFN(
                embed_dims=token_dims,
                feedforward_channels=mlp_dims,
                ffn_drop=drop_rate,
                drop_path_rate=drop_path_rate,
                act_layer=act_layer,
                layer_scale_init_value=layer_scale_init_value,
            )
        elif ffn_type == "swiglu_fused":
            self.ffn = SwiGLUFFNFused(
                embed_dims=token_dims,
                feedforward_channels=mlp_dims,
                layer_scale_init_value=layer_scale_init_value,
            )
        else:
            raise NotImplementedError

        self.enable_twoway = enable_twoway
        if self.enable_twoway:
            self.ln4_1 = build_norm_layer(norm_cfg, context_dims)
            self.ln4_2 = build_norm_layer(norm_cfg, token_dims)

            self.cross_attn_2 = Attention(
                embed_dims=num_heads * head_dims,
                num_heads=num_heads,
                query_dims=context_dims,
                key_dims=token_dims,
                value_dims=token_dims,
                attn_drop=attn_drop_rate,
                proj_drop=drop_rate,
                drop_path_rate=drop_path_rate,
                layer_scale_init_value=layer_scale_init_value,
            )

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor,
        x_pe: Optional[torch.Tensor] = None,
        context_pe: Optional[torch.Tensor] = None,
        x_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            x: shape [B, N, C]
            context: shape [B, N_context, C_context]
            x_mask: shape [B, N]
        """
        if self.repeat_pe and context_pe is not None:
            # LaPE: https://openaccess.thecvf.com/content/ICCV2023/papers/Yu_LaPE_Layer-adaptive_Position_Embedding_for_Vision_Transformers_with_Independent_Layer_ICCV_2023_paper.pdf
            x_pe = self.ln_pe_1(x_pe)
            context_pe = self.ln_pe_2(context_pe)

        # Self attention block for tokens
        if self.repeat_pe and not self.skip_first_pe and x_pe is not None:
            q = k = self.ln1(x) + x_pe
            v = self.ln1(x)
        else:
            q = k = v = self.ln1(x)

        attn_mask = None
        if x_mask is not None:
            attn_mask = x_mask[:, :, None] @ x_mask[:, None, :]
            # Set diagonal to 1 to prevent nan output
            attn_mask.diagonal(dim1=1, dim2=2).fill_(1)
            attn_mask = attn_mask > 0
        x = x + self.self_attn(q=q, k=k, v=v, attn_mask=attn_mask)

        # Cross attention block, tokens attending to image embedding
        if self.repeat_pe and context_pe is not None:
            q = self.ln2_1(x) + x_pe
            k = self.ln2_2(context) + context_pe
            v = self.ln2_2(context)
        else:
            q = self.ln2_1(x)
            k = v = self.ln2_2(context)
        x = x + self.cross_attn(q=q, k=k, v=v)

        # MLP block
        x = self.ffn(self.ln3(x), identity=x)

        # (Optional) Cross attention block, image embeddings attending to tokens
        if self.enable_twoway:
            if self.repeat_pe and context_pe is not None:
                q = self.ln4_1(context) + context_pe
                k = self.ln4_2(x) + x_pe
                v = self.ln4_2(x)
            else:
                q = self.ln4_1(context)
                k = v = self.ln4_2(x)
            attn_mask = (
                (x_mask[:, None, :].repeat(1, context.shape[1], 1)) > 0
                if x_mask is not None
                else None
            )
            context = context + self.cross_attn_2(q=q, k=k, v=v, attn_mask=attn_mask)

        return x, context
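`TransformerDecoderLayer` is one step of the SAM-style two-way decoder: tokens self-attend, cross-attend to the image embedding, pass through the FFN, and with `enable_twoway=True` the image embedding attends back to the tokens; with `repeat_pe=True` the positional encodings are layer-normalized and re-added at every layer (LaPE). A usage sketch with illustrative sizes, import path assumed from the src/ layout:

import torch

from sam3d_body.models.modules.transformer import TransformerDecoderLayer

layer = TransformerDecoderLayer(
    token_dims=256,
    context_dims=256,
    num_heads=8,
    head_dims=32,
    mlp_dims=1024,
    enable_twoway=True,
    repeat_pe=True,
)

tokens = torch.randn(2, 16, 256)     # prompt / query tokens
context = torch.randn(2, 1024, 256)  # flattened image embedding (e.g. a 32x32 grid)
tokens_pe = torch.randn(2, 16, 256)
context_pe = torch.randn(2, 1024, 256)

with torch.no_grad():
    tokens, context = layer(tokens, context, x_pe=tokens_pe, context_pe=context_pe)
assert tokens.shape == (2, 16, 256) and context.shape == (2, 1024, 256)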