pablovela5620 committed
Commit 6da47c0 · 1 Parent(s): cf43f05

init commit

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +6 -0
  2. .gitignore +234 -0
  3. LICENSE-APACHE +201 -0
  4. LICENSE-MIT +25 -0
  5. README.md +82 -12
  6. app.py +43 -0
  7. data/example-data/Amir-Khan-Lamont-Peterson_2689582.jpg +3 -0
  8. data/example-data/BNAAHPYGMYSE26U6C6T7VA6544.jpg +3 -0
  9. data/example-data/Canelo-Alvarez-b4d59f2080464e4d996177f5ce9792ee.jpg +3 -0
  10. data/example-data/Planche.jpg +3 -0
  11. data/example-data/yoga-example.jpg +3 -0
  12. pixi.lock +0 -0
  13. pyproject.toml +149 -0
  14. src/sam3d_body/__init__.py +12 -0
  15. src/sam3d_body/api/demo.py +241 -0
  16. src/sam3d_body/api/visualization.py +425 -0
  17. src/sam3d_body/build_models.py +56 -0
  18. src/sam3d_body/data/__init__.py +1 -0
  19. src/sam3d_body/data/transforms/__init__.py +21 -0
  20. src/sam3d_body/data/transforms/bbox_utils.py +380 -0
  21. src/sam3d_body/data/transforms/common.py +345 -0
  22. src/sam3d_body/data/utils/io.py +114 -0
  23. src/sam3d_body/data/utils/prepare_batch.py +99 -0
  24. src/sam3d_body/gradio_ui/sam3d_body_ui.py +164 -0
  25. src/sam3d_body/metadata/__init__.py +79 -0
  26. src/sam3d_body/metadata/mhr70.py +915 -0
  27. src/sam3d_body/models/__init__.py +1 -0
  28. src/sam3d_body/models/backbones/__init__.py +35 -0
  29. src/sam3d_body/models/backbones/dinov3.py +69 -0
  30. src/sam3d_body/models/backbones/vit.py +658 -0
  31. src/sam3d_body/models/decoders/__init__.py +32 -0
  32. src/sam3d_body/models/decoders/keypoint_prompt_sampler.py +183 -0
  33. src/sam3d_body/models/decoders/prompt_encoder.py +256 -0
  34. src/sam3d_body/models/decoders/promptable_decoder.py +194 -0
  35. src/sam3d_body/models/heads/__init__.py +28 -0
  36. src/sam3d_body/models/heads/camera_head.py +110 -0
  37. src/sam3d_body/models/heads/mhr_head.py +369 -0
  38. src/sam3d_body/models/meta_arch/__init__.py +3 -0
  39. src/sam3d_body/models/meta_arch/base_lightning_module.py +48 -0
  40. src/sam3d_body/models/meta_arch/base_model.py +162 -0
  41. src/sam3d_body/models/meta_arch/sam3d_body.py +1728 -0
  42. src/sam3d_body/models/modules/__init__.py +18 -0
  43. src/sam3d_body/models/modules/camera_embed.py +111 -0
  44. src/sam3d_body/models/modules/drop_path.py +42 -0
  45. src/sam3d_body/models/modules/geometry_utils.py +304 -0
  46. src/sam3d_body/models/modules/layer_scale.py +45 -0
  47. src/sam3d_body/models/modules/mhr_utils.py +392 -0
  48. src/sam3d_body/models/modules/misc.py +31 -0
  49. src/sam3d_body/models/modules/swiglu_ffn.py +96 -0
  50. src/sam3d_body/models/modules/transformer.py +651 -0
.gitattributes CHANGED
@@ -1,3 +1,9 @@
+ # LFS/Xet-managed assets
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
+
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
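For context, rules like these are usually written by the Git LFS CLI rather than by hand; a hedged example (not part of this commit) that appends equivalent filter lines to `.gitattributes`:

```bash
# Track the newly added image formats with Git LFS; each call appends a filter rule to .gitattributes.
git lfs track "*.jpg"
git lfs track "*.jpeg"
git lfs track "*.png"
git lfs track "*.gif"
```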
.gitignore ADDED
@@ -0,0 +1,234 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ # Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ # poetry.lock
109
+ # poetry.toml
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
+ # pdm.lock
116
+ # pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # pixi
121
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
+ # pixi.lock
123
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
+ # in the .venv directory. It is recommended not to include this directory in version control.
125
+ .pixi
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # Redis
135
+ *.rdb
136
+ *.aof
137
+ *.pid
138
+
139
+ # RabbitMQ
140
+ mnesia/
141
+ rabbitmq/
142
+ rabbitmq-data/
143
+
144
+ # ActiveMQ
145
+ activemq-data/
146
+
147
+ # SageMath parsed files
148
+ *.sage.py
149
+
150
+ # Environments
151
+ .env
152
+ .envrc
153
+ .venv
154
+ env/
155
+ venv/
156
+ ENV/
157
+ env.bak/
158
+ venv.bak/
159
+
160
+ # Spyder project settings
161
+ .spyderproject
162
+ .spyproject
163
+
164
+ # Rope project settings
165
+ .ropeproject
166
+
167
+ # mkdocs documentation
168
+ /site
169
+
170
+ # mypy
171
+ .mypy_cache/
172
+ .dmypy.json
173
+ dmypy.json
174
+
175
+ # Pyre type checker
176
+ .pyre/
177
+
178
+ # pytype static type analyzer
179
+ .pytype/
180
+
181
+ # Cython debug symbols
182
+ cython_debug/
183
+
184
+ # PyCharm
185
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
186
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
187
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
188
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
189
+ # .idea/
190
+
191
+ # Abstra
192
+ # Abstra is an AI-powered process automation framework.
193
+ # Ignore directories containing user credentials, local state, and settings.
194
+ # Learn more at https://abstra.io/docs
195
+ .abstra/
196
+
197
+ # Visual Studio Code
198
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
199
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
200
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
201
+ # you could uncomment the following to ignore the entire vscode folder
202
+ # .vscode/
203
+
204
+ # Ruff stuff:
205
+ .ruff_cache/
206
+
207
+ # PyPI configuration file
208
+ .pypirc
209
+
210
+ # Marimo
211
+ marimo/_static/
212
+ marimo/_lsp/
213
+ __marimo__/
214
+
215
+ # Streamlit
216
+ .streamlit/secrets.toml
217
+
218
+ # pixi environments
219
+ .pixi/*
220
+ !.pixi/config.toml
221
+
222
+ _checkpoints/*
223
+
224
+
225
+ # START Ruler Generated Files
226
+ /.codex/config.json
227
+ /.codex/config.json.bak
228
+ /.codex/config.toml
229
+ /.codex/config.toml.bak
230
+ /.vscode/mcp.json
231
+ /.vscode/mcp.json.bak
232
+ /AGENTS.md
233
+ /AGENTS.md.bak
234
+ # END Ruler Generated Files
LICENSE-APACHE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
LICENSE-MIT ADDED
@@ -0,0 +1,25 @@
+ Copyright (c) 2022 Rerun Technologies AB <opensource@rerun.io>
+
+ Permission is hereby granted, free of charge, to any
+ person obtaining a copy of this software and associated
+ documentation files (the "Software"), to deal in the
+ Software without restriction, including without
+ limitation the rights to use, copy, modify, merge,
+ publish, distribute, sublicense, and/or sell copies of
+ the Software, and to permit persons to whom the Software
+ is furnished to do so, subject to the following
+ conditions:
+
+ The above copyright notice and this permission notice
+ shall be included in all copies or substantial portions
+ of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
+ ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
+ TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+ PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
+ SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+ OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
+ IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ DEALINGS IN THE SOFTWARE.
README.md CHANGED
@@ -1,12 +1,82 @@
- ---
- title: Sam3d Body Rerun
- emoji: 🏆
- colorFrom: gray
- colorTo: yellow
- sdk: gradio
- sdk_version: 6.0.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # SAM3D Body with Rerun
+ An unofficial playground for Meta's SAM3D Body (DINOv3) with promptable SAM3 masks and live Rerun visualization. Uses **Rerun** for 3D inspection, **Gradio** for the UI, and **Pixi** for one-command setup.
+
+ <p align="center">
+ <a title="Rerun" href="https://rerun.io" target="_blank" rel="noopener noreferrer">
+ <img src="https://img.shields.io/badge/Rerun-0.27%2B-0b82f9" alt="Rerun badge">
+ </a>
+ <a title="Pixi" href="https://pixi.sh/latest/" target="_blank" rel="noopener noreferrer">
+ <img src="https://img.shields.io/badge/Install%20with-Pixi-16A34A" alt="Pixi badge">
+ </a>
+ <a title="CUDA" href="https://developer.nvidia.com/cuda-toolkit" target="_blank" rel="noopener noreferrer">
+ <img src="https://img.shields.io/badge/CUDA-12.9%2B-76b900" alt="CUDA badge">
+ </a>
+ <a title="GitHub" href="https://github.com/rerun-io/sam3d-body-rerun" target="_blank" rel="noopener noreferrer">
+ <img src="https://img.shields.io/github/stars/rerun-io/sam3d-body-rerun?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="GitHub stars">
+ </a>
+ </p>
+
+ <p align="center">
+ <!-- Drop your GIF/MP4 here once ready -->
+ <img src="media/sam3d-body-demo.gif" alt="example output" width="720" />
+ </p>
+
+ ## Installation
+ ### Using Pixi
+ Make sure you have the [Pixi](https://pixi.sh/latest/#installation) package manager installed.
+
+ TL;DR install Pixi:
+ ```bash
+ curl -fsSL https://pixi.sh/install.sh | sh
+ ```
+ Restart your shell so the new `pixi` binary is on `PATH`.
+
+ This project is Linux-only and requires an NVIDIA GPU.
+
+ The SAM3 and SAM3D Body checkpoints are gated on Hugging Face. Request access to both [facebook/sam-3d-body-dinov3](https://huggingface.co/facebook/sam-3d-body-dinov3) and [facebook/sam3](https://huggingface.co/facebook/sam3), then authenticate either by setting `HF_TOKEN=<your token>` or by running `huggingface-cli login` before the first download (see Meta's install notes).
+
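As a concrete illustration of the two authentication options mentioned above (standard Hugging Face Hub commands; the token value is a placeholder):

```bash
# Option 1: interactive login, stores the token locally.
huggingface-cli login

# Option 2: export the token for the current shell session only.
export HF_TOKEN=<your token>
```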
+ The first run downloads the Hugging Face checkpoints for SAM3, SAM3D Body, and the relative-depth model.
+ ```bash
+ git clone https://github.com/rerun-io/sam3d-body-rerun.git
+ cd sam3d-body-rerun
+ pixi run app
+ ```
+
+ All commands can be listed with `pixi task list`.
+
+ ## Usage
+ ### Gradio App
+ ```bash
+ pixi run app
+ ```
+ Opens the Gradio UI with an embedded streaming Rerun viewer. Try the bundled samples in `data/example-data` or upload your own RGB image; toggle “Log relative depth” to stream predicted depth.
+
+ ### CLI
+ Run the bundled demo task:
+ ```bash
+ pixi run cli
+ ```
+
+ or work from a dev shell (for tyro and the dev dependencies):
+
+ ```bash
+ pixi shell -e dev
+ python tool/demo.py --help
+ ```
+ Run on a folder of images and configure Rerun output/recordings via the CLI flags.
+
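For example, from the dev environment (the `--image-folder` flag mirrors the `cli` task defined in `pyproject.toml`; the single-image flag is inferred from the demo's tyro config and may differ):

```bash
# Process the bundled example images.
pixi run -e dev python tool/demo.py --image-folder data/example-data
# Or point the demo at a single image (field name from Sam3DBodyDemoConfig).
pixi run -e dev python tool/demo.py --image-path data/example-data/Planche.jpg
```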
+ ### Promptable SAM3 sandbox
+ If you just want SAM3 masks without 3D reconstruction:
+ ```bash
+ pixi run -e dev python tool/gradio_sam3.py
+ ```
+
+ ## Acknowledgements
+ Thanks to the original projects that make this demo possible:
+
+ - [facebook/sam-3d-body-dinov3](https://huggingface.co/facebook/sam-3d-body-dinov3) — SAM3D Body checkpoints and assets.
+ - [facebook/sam3](https://huggingface.co/facebook/sam3) — promptable concept segmentation.
+ - Relative depth/FOV from `MogeV1Predictor` in [monopriors](https://github.com/pablovela5620/monoprior).
+ - Built with [Rerun](https://rerun.io/), [Gradio](https://www.gradio.app/), and [Pixi](https://pixi.sh/latest/).
+
+ The code in this repository is dual-licensed under Apache 2.0 and MIT (see `LICENSE-APACHE` and `LICENSE-MIT`); upstream models and assets retain their original licenses.
app.py ADDED
@@ -0,0 +1,43 @@
+ import os
+ import subprocess
+ from pathlib import Path
+
+ PIXI_PATH = Path("/home/user/.pixi/bin/pixi")
+ PIXI_VERSION = "0.59.0"
+ MOCK_CUDA_VERSION = "12.9"
+
+ # Pretend CUDA 12.9 is available so pixi can solve environments on machines without GPUs.
+ os.environ.setdefault("CONDA_OVERRIDE_CUDA", MOCK_CUDA_VERSION)
+
+
+ def check_and_install_pixi() -> None:
+     try:
+         subprocess.check_call(f"{PIXI_PATH} --version", shell=True)
+     except subprocess.CalledProcessError:
+         print("pixi not found. Installing pixi...")
+         # Install pixi using the official installation script, pinned to PIXI_VERSION.
+         subprocess.check_call(
+             f"PIXI_VERSION=v{PIXI_VERSION} curl -fsSL https://pixi.sh/install.sh | bash",
+             shell=True,
+         )
+         subprocess.check_call(f"{PIXI_PATH} self-update --version {PIXI_VERSION}", shell=True)
+         subprocess.check_call(f"{PIXI_PATH} --version", shell=True)
+
+
+ def run_command(command: str) -> None:
+     try:
+         subprocess.check_call(command, shell=True)
+     except subprocess.CalledProcessError as e:
+         print(f"Command failed: {command}. Error: {e}")
+
+
+ if __name__ == "__main__":
+     check_and_install_pixi()
+     # install lsof
+     # run_command(command=f"{PIXI_PATH} global install lsof")
+     # # kill anything running on port 7860
+     # run_command(command=f"{PIXI_PATH.parent}/lsof -t -i:7860 | xargs -r kill")
+     # clean current environment
+     run_command(command=f"{PIXI_PATH} clean")
+     # run spaces app
+     run_command(command=f"{PIXI_PATH} run app")
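For reference, this Space bootstrap amounts to roughly the following shell steps (a hedged sketch; the version pin and CUDA override come straight from `app.py` above):

```bash
# Let pixi solve the CUDA 12.9 environment even when no GPU is visible at build time.
export CONDA_OVERRIDE_CUDA=12.9
# Install a pinned pixi, wipe any stale environment, then launch the app task.
PIXI_VERSION=v0.59.0 curl -fsSL https://pixi.sh/install.sh | bash
pixi clean
pixi run app
```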
data/example-data/Amir-Khan-Lamont-Peterson_2689582.jpg ADDED

Git LFS Details

  • SHA256: 85013a25f46cad9ba86bc05786b48dfb6e5a2d5dfa9f19328997480ec23226e5
  • Pointer size: 131 Bytes
  • Size of remote file: 155 kB
data/example-data/BNAAHPYGMYSE26U6C6T7VA6544.jpg ADDED

Git LFS Details

  • SHA256: c5d64d944c10ffde20645075b9078b7359899dd019062ceaa6fd54b18be21042
  • Pointer size: 131 Bytes
  • Size of remote file: 719 kB
data/example-data/Canelo-Alvarez-b4d59f2080464e4d996177f5ce9792ee.jpg ADDED

Git LFS Details

  • SHA256: bc029593f9dae5bd0473148fe9b920d6e708220b126c1b0a09bb9b48bfa999be
  • Pointer size: 131 Bytes
  • Size of remote file: 134 kB
data/example-data/Planche.jpg ADDED

Git LFS Details

  • SHA256: 898a2376f2adac0676408cc5c563b8f50df9966caa3299d4013b0476dd5cdbbe
  • Pointer size: 131 Bytes
  • Size of remote file: 216 kB
data/example-data/yoga-example.jpg ADDED

Git LFS Details

  • SHA256: 260c554cb3e8cc582a37873951f05ee10e99631e2c858d7b77d246554212fdae
  • Pointer size: 130 Bytes
  • Size of remote file: 50.6 kB
pixi.lock ADDED
The diff for this file is too large to render.
 
pyproject.toml ADDED
@@ -0,0 +1,149 @@
1
+ [project]
2
+ authors = [{ name = "pablo vela", email = "pablovela5620@gmail.com" }]
3
+ dependencies = [
4
+ "jaxtyping<0.3.0",
5
+ "numpy>=2.0",
6
+ "einops>=0.8.0",
7
+ "icecream>=2.1.3",
8
+ "opencv-python>=4.10.0",
9
+ "pyserde>=0.20.0",
10
+ "rerun-sdk>=0.27.0",
11
+ "tyro>=0.9.1",
12
+ "tqdm",
13
+ "hf-transfer>=0.1.9",
14
+ "lovely-numpy>=0.2.13,<0.3",
15
+ "pandas>=2.3.3",
16
+ "braceexpand>=0.1.7,<0.2",
17
+ "roma>=1.5.4,<2",
18
+ "pytorch-lightning>=2.5.6,<3",
19
+ "yacs>=0.1.8,<0.2",
20
+ "omegaconf>=2.3.0,<3",
21
+ "termcolor>=3.2.0,<4",
22
+ "gradio-rerun>=0.27.0",
23
+ "spaces>=0.43.0",
24
+ ]
25
+ name = "sam3d_body"
26
+ requires-python = ">= 3.12"
27
+ version = "0.1.0"
28
+
29
+
30
+ [build-system]
31
+ build-backend = "hatchling.build"
32
+ requires = ["hatchling"]
33
+
34
+ [tool.hatch.metadata]
35
+ allow-direct-references = true
36
+
37
+ [tool.pixi.workspace]
38
+ channels = ["conda-forge"]
39
+ platforms = ["linux-64"]
40
+ preview = ["pixi-build"]
41
+
42
+ [tool.pixi.pypi-options]
43
+ no-build-isolation = ["detectron2", "moge"]
44
+ [tool.pixi.pypi-options.dependency-overrides]
45
+ # Allow iopath >=0.1.10 even though detectron2 pins <0.1.10, so it can satisfy sam-2.
46
+ iopath = ">=0.1.10"
47
+ gradio = ">=5.45.0,<6"
48
+ [tool.pixi.pypi-dependencies]
49
+ sam3d_body = { path = ".", editable = true }
50
+ moge = { git = "https://github.com/microsoft/MoGe.git" }
51
+ simplecv = { git = "https://github.com/pablovela5620/simplecv.git", branch = "main" }
52
+ timm = ">=0.9"
53
+ transformers = { git = "https://github.com/huggingface/transformers.git", rev = "d08b98b965176ea9cf8c8e8b24995c955b7e2ec9" }
54
+ monopriors = { git = "https://github.com/pablovela5620/monoprior.git" }
55
+
56
+ [tool.pixi.tasks]
57
+ app = "python tool/gradio_sam3d_body.py"
58
+ cli = "python tool/demo.py --image-folder data/example-data"
59
+
60
+ [tool.pixi.feature.cuda129.system-requirements]
61
+ cuda = "12.9"
62
+
63
+ [tool.pixi.feature.cuda129.dependencies]
64
+ # CUDA Build Tools
65
+ cuda-compiler = "*"
66
+ cuda-version = "12.9.*"
67
+ cuda-cudart-dev = "*"
68
+ cuda-crt = "*"
69
+ libcusparse-dev = "*"
70
+ cuda-driver-dev = "*"
71
+ cuda-nvcc = "*"
72
+ cuda-nvrtc-dev = "*"
73
+ cuda-nvtx = "*"
74
+ cuda-nvtx-dev = "*"
75
+ cuda-nvml-dev = "*"
76
+ cuda-profiler-api = "*"
77
+
78
+ # CUDA Libraries
79
+ cudnn = "*"
80
+ libcublas-dev = "*"
81
+ libcudss-dev = "*"
82
+ libcufile-dev = "*"
83
+ libcufft-dev = "*"
84
+ libcurand-dev = "*"
85
+ libcusolver-dev = "*"
86
+ cusparselt = "*"
87
+ libnvjitlink = "*"
88
+ # cuda129 end
89
+
90
+ [tool.pixi.feature.gpu.dependencies]
91
+ pytorch-gpu = ">=2.8.0"
92
+ torchvision = "*"
93
+
94
+
95
+ [tool.pixi.feature.dev.dependencies]
96
+ beartype = "*"
97
+ pyrefly = ">=0.42.2,<0.43"
98
+ ruff = ">=0.14.5,<0.15"
99
+
100
+ [tool.pixi.feature.dev.pypi-dependencies]
101
+ types-tqdm = "*"
102
+
103
+ [tool.pixi.environments]
104
+ cuda128 = { features = [
105
+ "cuda129",
106
+ ], solve-group = "cuda129", no-default-feature = true }
107
+ default = { features = ["gpu", "cuda129"], solve-group = "cuda129" }
108
+ dev = { features = ["dev", "gpu", "cuda129"], solve-group = "cuda129" }
109
+
110
+ [tool.pixi.dependencies]
111
+ av = ">=16.0.1,<17"
112
+ gradio = ">=5.45.0,<6"
113
+ huggingface_hub = ">=1.0,<2"
114
+ tomlkit = "==0.12.0"
115
+ audioop-lts = "*"
116
+ pydub = "*"
117
+ open3d = ">=0.19.0,<0.20"
118
+
119
+ [tool.ruff]
120
+ line-length = 150
121
+
122
+ [tool.ruff.lint]
123
+ select = [
124
+ # pycodestyle
125
+ "E",
126
+ # Pyflakes
127
+ "F",
128
+ # pyupgrade
129
+ "UP",
130
+ # flake8-bugbear
131
+ "B",
132
+ # flake8-simplify
133
+ "SIM",
134
+ # isort
135
+ "I",
136
+ ]
137
+
138
+ ignore = [
139
+ "E501", # Line too long.
140
+ "F722", # Forward annotation false positive from jaxtyping. Should be caught by pyright.
141
+ "F821", # Forward annotation false positive from jaxtyping. Should be caught by pyright.
142
+ "UP037", # Remove quotes from type, false positive when using jaxtyping
143
+ "UP040", # Beartype fails if not using this for typealias
144
+
145
+ ]
146
+
147
+ [tool.pyrefly]
148
+ project-includes = ["**/*"]
149
+ project-excludes = ["**/node_modules", "**/__pycache__", "**/*venv/**/*"]
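Given the `[tool.pixi.tasks]` and `[tool.pixi.environments]` tables above, typical invocations look like the following (a hedged sketch; the task and environment names are exactly the ones defined in this file):

```bash
pixi task list                 # show the app/cli tasks defined above
pixi run app                   # default env: Gradio UI (tool/gradio_sam3d_body.py)
pixi run cli                   # default env: tool/demo.py on data/example-data
pixi run -e dev ruff check .   # dev env adds ruff, pyrefly, and beartype
```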
src/sam3d_body/__init__.py ADDED
@@ -0,0 +1,12 @@
+ import os
+
+ # Only enable beartype when running in the 'dev' environment
+ # Check the PIXI_ENVIRONMENT_NAME environment variable set by pixi
+ if os.environ.get("PIXI_ENVIRONMENT_NAME") == "dev":
+     try:
+         from beartype.claw import beartype_this_package
+
+         beartype_this_package()
+     except ImportError:
+         # beartype not available even in dev environment
+         pass
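The guard above relies on pixi exporting `PIXI_ENVIRONMENT_NAME` for the spawned process, so runtime type checking switches on only in the `dev` environment. A hedged illustration of the intended behavior:

```bash
pixi run -e dev python -c "import sam3d_body"   # PIXI_ENVIRONMENT_NAME=dev, beartype hooks installed
pixi run python -c "import sam3d_body"          # default env, import-time checks skipped
```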
src/sam3d_body/api/demo.py ADDED
@@ -0,0 +1,241 @@
1
+ """Minimal standalone demo wiring for SAM 3D Body with Rerun visualization."""
2
+
3
+ import os
4
+ from dataclasses import dataclass
5
+ from glob import glob
6
+ from pathlib import Path
7
+ from typing import Literal, TypedDict
8
+
9
+ import cv2
10
+ import numpy as np
11
+ import rerun as rr
12
+ import rerun.blueprint as rrb
13
+ import torch
14
+ from jaxtyping import Float32, UInt8
15
+ from monopriors.relative_depth_models import BaseRelativePredictor, RelativeDepthPrediction, get_relative_predictor
16
+ from numpy import ndarray
17
+ from serde import serde
18
+ from simplecv.rerun_log_utils import RerunTyroConfig
19
+ from torch import Tensor
20
+ from tqdm import tqdm
21
+ from transformers.models.sam3 import Sam3Model, Sam3Processor
22
+ from yacs.config import CfgNode
23
+
24
+ from sam3d_body.api.visualization import create_view, set_annotation_context, visualize_sample
25
+ from sam3d_body.build_models import load_sam_3d_body, load_sam_3d_body_hf
26
+ from sam3d_body.models.meta_arch import SAM3DBody
27
+ from sam3d_body.sam_3d_body_estimator import FinalPosePrediction, SAM3DBodyEstimator
28
+
29
+
30
+ class SAM3ResultsDict(TypedDict):
31
+ """Torch-format outputs returned directly by ``Sam3Processor`` post-processing."""
32
+
33
+ scores: Float32[Tensor, "n"]
34
+ boxes: Float32[Tensor, "n 4"]
35
+ masks: Float32[Tensor, "n h w"]
36
+
37
+
38
+ @serde()
39
+ class SAM3Results:
40
+ scores: Float32[ndarray, "n"]
41
+ """Per-instance confidence scores ``[N]``."""
42
+ boxes: Float32[ndarray, "n 4"]
43
+ """Bounding boxes in XYXY pixel coordinates ``[N, 4]``."""
44
+ masks: Float32[ndarray, "n h w"]
45
+ """Probability masks for each detection ``[N, H, W]`` (float32 in ``[0, 1]``)."""
46
+
47
+
48
+ @dataclass
49
+ class SAM3Config:
50
+ """Configuration for loading a SAM3 checkpoint and selecting device."""
51
+
52
+ device: Literal["cpu", "cuda"] = "cuda"
53
+ """Computation device passed to the Hugging Face SAM3 model."""
54
+ sam3_checkpoint: str = "facebook/sam3"
55
+ """Model identifier or path accepted by ``Sam3Model.from_pretrained``."""
56
+
57
+
58
+ class SAM3Predictor:
59
+ """Lightweight wrapper around the SAM3 model for single-image inference."""
60
+
61
+ def __init__(self, config: SAM3Config):
62
+ self.config = config
63
+ self.sam3_model = Sam3Model.from_pretrained(config.sam3_checkpoint).to(config.device)
64
+ self.sam3_processor = Sam3Processor.from_pretrained(config.sam3_checkpoint)
65
+
66
+ def predict_single_image(self, rgb_hw3: UInt8[ndarray, "h w 3"], text: str = "person") -> SAM3Results:
67
+ """Run SAM3 instance segmentation on one RGB image.
68
+
69
+ Args:
70
+ rgb_hw3: Input image in RGB order with dtype ``uint8`` and shape ``[H, W, 3]``.
71
+ text: Optional prompt used by SAM3's text-conditioned decoder (default: ``"person"``).
72
+
73
+ Returns:
74
+ ``SAM3Results`` with NumPy copies of scores, XYXY boxes, and binary masks.
75
+ """
76
+ inputs = self.sam3_processor(
77
+ images=rgb_hw3,
78
+ text=text,
79
+ return_tensors="pt",
80
+ ).to(self.config.device)
81
+
82
+ with torch.no_grad():
83
+ outputs = self.sam3_model(**inputs)
84
+
85
+ results: SAM3ResultsDict = self.sam3_processor.post_process_instance_segmentation(
86
+ outputs, threshold=0.5, mask_threshold=0.5, target_sizes=inputs.get("original_sizes").tolist()
87
+ )[0]
88
+
89
+ mask_probs: Float32[ndarray, "n h w"] = results["masks"].detach().cpu().numpy().astype(np.float32, copy=False)
90
+
91
+ return SAM3Results(
92
+ scores=results["scores"].detach().cpu().numpy().astype(np.float32, copy=False),
93
+ boxes=results["boxes"].detach().cpu().numpy().astype(np.float32, copy=False),
94
+ masks=mask_probs,
95
+ )
96
+
97
+
98
+ @dataclass
99
+ class SAM3DBodyE2EConfig:
100
+ """Bundle of sub-configurations required for the end-to-end demo."""
101
+
102
+ sam3_config: SAM3Config
103
+ """Settings for the underlying SAM3 detector."""
104
+ fov_estimator: Literal["MogeV1Predictor"] = "MogeV1Predictor"
105
+ """Identifier of the relative depth/FOV estimator to load."""
106
+ mhr_path: Path = Path("checkpoints/sam-3d-body-dinov3/assets/mhr_model.pt")
107
+ """Path to the MHR mesh/pose asset file required by the head network."""
108
+ checkpoint_path: Path = Path("checkpoints/sam-3d-body-dinov3/model.ckpt")
109
+ """Core SAM 3D Body model checkpoint (.ckpt)."""
110
+
111
+
112
+ class SAM3DBodyE2E:
113
+ """Convenience facade that chains detection, FOV estimation, and 3D reconstruction."""
114
+
115
+ def __init__(self, config: SAM3DBodyE2EConfig):
116
+ self.sam3_predictor = SAM3Predictor(config.sam3_config)
117
+ self.fov_predictor: BaseRelativePredictor = get_relative_predictor(config.fov_estimator)(device="cuda")
118
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
119
+ # load_output: tuple[SAM3DBody, CfgNode] = load_sam_3d_body(
120
+ # config.checkpoint_path,
121
+ # device=device,
122
+ # mhr_path=config.mhr_path,
123
+ # )
124
+ load_output: tuple[SAM3DBody, CfgNode] = load_sam_3d_body_hf(repo_id="facebook/sam-3d-body-dinov3")
125
+ model: SAM3DBody = load_output[0]
126
+ self.sam3d_body_estimator = SAM3DBodyEstimator(
127
+ sam_3d_body_model=model,
128
+ )
129
+
130
+ def predict_single_image(
131
+ self, rgb_hw3: UInt8[ndarray, "h w 3"]
132
+ ) -> tuple[list[FinalPosePrediction], RelativeDepthPrediction]:
133
+ """Estimate 3D poses for a single frame.
134
+
135
+ Pipeline:
136
+ 1. Use the configured relative-depth predictor to derive camera intrinsics ``K_33``.
137
+ 2. Run SAM3 to obtain person masks and boxes.
138
+ 3. Feed detections and intrinsics into ``SAM3DBodyEstimator`` for per-person 3D bodies.
139
+
140
+ Args:
141
+ rgb_hw3: RGB image with shape ``[H, W, 3]`` and dtype ``uint8``.
142
+
143
+ Returns:
144
+ A list of ``FinalPosePrediction`` entries—one per detected person.
145
+ """
146
+ # estimate the camera intrinsics
147
+ relative_pred: RelativeDepthPrediction = self.fov_predictor(rgb=rgb_hw3, K_33=None)
148
+ K_33: Float32[ndarray, "3 3"] = relative_pred.K_33
149
+
150
+ sam3_results: SAM3Results = self.sam3_predictor.predict_single_image(rgb_hw3)
151
+
152
+ outputs: list[FinalPosePrediction] = self.sam3d_body_estimator.process_one_image(
153
+ rgb_hw3,
154
+ xyxy=sam3_results.boxes,
155
+ masks=sam3_results.masks,
156
+ masks_score=sam3_results.scores,
157
+ K_33=K_33,
158
+ )
159
+ return outputs, relative_pred
160
+
161
+
162
+ @dataclass(slots=True)
163
+ class Sam3DBodyDemoConfig:
164
+ """Configuration for the standalone demo runner."""
165
+
166
+ rr_config: RerunTyroConfig
167
+ """Viewer/runtime options for Rerun (window layout, recording, etc.)."""
168
+
169
+ sam3_e2e_config: SAM3DBodyE2EConfig
170
+ """Configuration for the end-to-end SAM 3D Body model."""
171
+
172
+ image_folder: Path | None = None
173
+ """Directory containing input images to process."""
174
+
175
+ image_path: Path | None = None
176
+ """Path to a single input image to process."""
177
+
178
+ max_frames: int | None = None
179
+ """Optional limit on the number of images to process; ``None`` processes all images."""
180
+
181
+
182
+ def main(cfg: Sam3DBodyDemoConfig):
183
+ """Run the Rerun-enabled demo on a folder or single image.
184
+
185
+ Args:
186
+ cfg: Aggregated configuration containing Rerun settings, SAM3 model options,
187
+ and input image selection.
188
+ """
189
+ # Setup Rerun
190
+ parent_log_path = Path("/world")
191
+ set_annotation_context()
192
+ view: rrb.ContainerLike = create_view()
193
+ blueprint = rrb.Blueprint(view, collapse_panels=True)
194
+ rr.send_blueprint(blueprint)
195
+ rr.log("/", rr.ViewCoordinates.RDF, static=True)
196
+
197
+ if cfg.image_path is not None:
198
+ images_list = [str(cfg.image_path)]
199
+ elif cfg.image_folder is not None:
200
+ image_extensions: list[str] = [
201
+ "*.jpg",
202
+ "*.jpeg",
203
+ "*.png",
204
+ "*.gif",
205
+ "*.bmp",
206
+ "*.tiff",
207
+ "*.webp",
208
+ ]
209
+ images_list: list[str] = sorted(
210
+ [image for ext in image_extensions for image in glob(os.path.join(cfg.image_folder, ext))]
211
+ )
212
+ else:
213
+ raise ValueError("Either image_path or image_folder must be specified.")
214
+
215
+ # load end to end model
216
+ sam3D_body_e2e = SAM3DBodyE2E(cfg.sam3_e2e_config)
217
+
218
+ for idx, image_path in enumerate(tqdm(images_list)):
219
+ rr.set_time(timeline="image_sequence", sequence=idx)
220
+ # load image and convert to RGB
221
+ bgr_hw3: UInt8[ndarray, "h w 3"] = cv2.imread(image_path)
222
+ rgb_hw3: UInt8[ndarray, "h w 3"] = cv2.cvtColor(bgr_hw3, cv2.COLOR_BGR2RGB)
223
+
224
+ outputs: tuple[list[FinalPosePrediction], RelativeDepthPrediction] = sam3D_body_e2e.predict_single_image(
225
+ rgb_hw3
226
+ )
227
+ pred_list: list[FinalPosePrediction] = outputs[0]
228
+ relative_pred: RelativeDepthPrediction = outputs[1]
229
+
230
+ if len(pred_list) == 0:
231
+ # Detector/FOV failed on this frame; avoid crashing the visualization step.
232
+ print(f"[warn] No detections for {image_path}; skipping.")
233
+ continue
234
+
235
+ visualize_sample(
236
+ pred_list=pred_list,
237
+ rgb_hw3=rgb_hw3,
238
+ parent_log_path=parent_log_path,
239
+ faces=sam3D_body_e2e.sam3d_body_estimator.faces,
240
+ relative_depth_pred=relative_pred,
241
+ )
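A short usage sketch for the end-to-end wrapper defined above (hedged: it assumes the gated checkpoints are already accessible and a CUDA device is present; not part of the commit):

```python
# Minimal single-image run using the classes from sam3d_body/api/demo.py.
import cv2

from sam3d_body.api.demo import SAM3Config, SAM3DBodyE2E, SAM3DBodyE2EConfig

config = SAM3DBodyE2EConfig(sam3_config=SAM3Config(device="cuda"))
pipeline = SAM3DBodyE2E(config)  # loads SAM3, the FOV/depth predictor, and SAM 3D Body

# Load a bundled example image and convert BGR -> RGB as main() does.
bgr = cv2.imread("data/example-data/Planche.jpg")
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

pred_list, relative_pred = pipeline.predict_single_image(rgb)
for pred in pred_list:
    print(pred.bbox, pred.pred_cam_t)  # per-person box and camera-frame translation
```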
src/sam3d_body/api/visualization.py ADDED
@@ -0,0 +1,425 @@
1
+ from pathlib import Path
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import open3d as o3d
6
+ import rerun as rr
7
+ import rerun.blueprint as rrb
8
+ from jaxtyping import Bool, Float32, Int, UInt8
9
+ from monopriors.depth_utils import depth_edges_mask
10
+ from monopriors.relative_depth_models import RelativeDepthPrediction
11
+ from numpy import ndarray
12
+ from simplecv.camera_parameters import Extrinsics, Intrinsics, PinholeParameters
13
+ from simplecv.ops.pc_utils import estimate_voxel_size
14
+ from simplecv.rerun_log_utils import log_pinhole
15
+
16
+ from sam3d_body.metadata.mhr70 import MHR70_ID2NAME, MHR70_IDS, MHR70_LINKS
17
+ from sam3d_body.sam_3d_body_estimator import FinalPosePrediction
18
+
19
+ BOX_PALETTE: UInt8[np.ndarray, "n_colors 4"] = np.array(
20
+ [
21
+ [255, 99, 71, 255], # tomato
22
+ [65, 105, 225, 255], # royal blue
23
+ [60, 179, 113, 255], # medium sea green
24
+ [255, 215, 0, 255], # gold
25
+ [138, 43, 226, 255], # blue violet
26
+ [255, 140, 0, 255], # dark orange
27
+ [220, 20, 60, 255], # crimson
28
+ [70, 130, 180, 255], # steel blue
29
+ ],
30
+ dtype=np.uint8,
31
+ )
32
+
33
+ # Use a separate id range for segmentation classes to avoid clobbering the person class (id=0).
34
+ SEG_CLASS_OFFSET = 1000 # background = 1000, persons start at 1001
35
+ MAX_POINT_CLOUD_POINTS = 50_000
36
+ MIN_DEPTH_CONFIDENCE = 0.5
37
+
38
+
39
+ def filter_out_of_bounds(
40
+ uv: Float32[ndarray, "n_points 2"],
41
+ h: int,
42
+ w: int,
43
+ xyz_cam: Float32[ndarray, "n_points 3"] | None = None,
44
+ ) -> Float32[ndarray, "n_points 2"]:
45
+ """Return a copy of ``uv`` with off-screen (and optional behind-camera) points masked.
46
+
47
+ Args:
48
+ uv: Pixel coordinates ``[N, 2]`` in (u, v) order.
49
+ h: Image height in pixels.
50
+ w: Image width in pixels.
51
+ xyz_cam: Optional camera-frame coordinates ``[N, 3]`` to mask points with negative ``z``.
52
+
53
+ Returns:
54
+ Copy of ``uv`` where out-of-bounds rows are set to ``NaN`` so Rerun hides them.
55
+ """
56
+
57
+ uv_filtered: Float32[ndarray, "n_points 2"] = np.asarray(uv, dtype=np.float32).copy()
58
+
59
+ out_of_bounds: Bool[ndarray, "n_points"] = np.logical_or(uv_filtered[:, 0] >= float(w), uv_filtered[:, 0] < 0.0)
60
+ out_of_bounds = np.logical_or(out_of_bounds, uv_filtered[:, 1] >= float(h))
61
+ out_of_bounds = np.logical_or(out_of_bounds, uv_filtered[:, 1] < 0.0)
62
+
63
+ if xyz_cam is not None:
64
+ out_of_bounds = np.logical_or(out_of_bounds, xyz_cam[:, 2] < 0.0)
65
+
66
+ uv_filtered[out_of_bounds, :] = np.nan
67
+ return uv_filtered
68
+
69
+
70
+ def compute_vertex_normals(
71
+ verts: Float32[ndarray, "n_verts 3"],
72
+ faces: Int[ndarray, "n_faces 3"],
73
+ eps: float = 1e-12,
74
+ ) -> Float32[ndarray, "n_verts 3"]:
75
+ """Compute per-vertex normals for a single mesh.
76
+
77
+ Args:
78
+ verts: Float32 array of vertex positions with shape ``(n_verts, 3)``.
79
+ faces: Int array of triangle indices with shape ``(n_faces, 3)``.
80
+ eps: Small epsilon to avoid division by zero when normalizing.
81
+
82
+ Returns:
83
+ Float32 array of unit vertex normals with shape ``(n_verts, 3)``; zeros for degenerate vertices.
84
+ """
85
+
86
+ # Expand faces to vertex triplets and fetch their positions.
87
+ faces_i: Int[ndarray, "n_faces 3"] = faces.astype(np.int64)
88
+ v0: Float32[ndarray, "n_faces 3"] = verts[faces_i[:, 0]]
89
+ v1: Float32[ndarray, "n_faces 3"] = verts[faces_i[:, 1]]
90
+ v2: Float32[ndarray, "n_faces 3"] = verts[faces_i[:, 2]]
91
+
92
+ # Face normal = cross(edge1, edge2).
93
+ e1: Float32[ndarray, "n_faces 3"] = v1 - v0
94
+ e2: Float32[ndarray, "n_faces 3"] = v2 - v0
95
+ face_normals: Float32[ndarray, "n_faces 3"] = np.cross(e1, e2)
96
+
97
+ # Accumulate each face normal into its three vertices with a vectorized scatter-add.
98
+ vertex_normals: Float32[ndarray, "n_verts 3"] = np.zeros_like(verts, dtype=np.float32)
99
+ flat_indices: Int[ndarray, "n_faces3"] = faces_i.reshape(-1)
100
+ face_normals_repeated: Float32[ndarray, "n_faces3 3"] = np.repeat(face_normals, 3, axis=0)
101
+ np.add.at(vertex_normals, flat_indices, face_normals_repeated)
102
+
103
+ norms: Float32[ndarray, "n_verts 1"] = np.linalg.norm(vertex_normals, axis=-1, keepdims=True)
104
+ denom: Float32[ndarray, "n_verts 1"] = np.maximum(norms, eps).astype(np.float32)
105
+ vn_unit: Float32[ndarray, "n_verts 3"] = (vertex_normals / denom).astype(np.float32)
106
+ mask: ndarray = norms > eps
107
+ vn_unit = np.where(mask, vn_unit, np.float32(0.0))
108
+ return vn_unit
109
+
110
+
111
+ def export_meshes_to_glb(
112
+ pred_list: list[FinalPosePrediction],
113
+ faces: Int[ndarray, "n_faces 3"],
114
+ output_dir: Path,
115
+ box_palette: UInt8[ndarray, "n_colors 4"] = BOX_PALETTE,
116
+ center_mesh: bool = True,
117
+ ) -> list[Path]:
118
+ """Write one GLB per predicted mesh and return the file paths."""
119
+
120
+ output_dir.mkdir(parents=True, exist_ok=True)
121
+ written_paths: list[Path] = []
122
+ faces_int: Int[ndarray, "n_faces 3"] = np.ascontiguousarray(faces, dtype=np.int32)
123
+
124
+ for idx, output in enumerate(pred_list):
125
+ verts_cam: Float32[ndarray, "n_verts 3"] = np.ascontiguousarray(output.pred_vertices, dtype=np.float32)
126
+ cam_t: Float32[ndarray, "3"] = np.ascontiguousarray(output.pred_cam_t, dtype=np.float32)
127
+ # Convert to world coordinates to mirror the viewer logging convention (cam → world via translation).
128
+ verts_world: Float32[ndarray, "n_verts 3"] = np.ascontiguousarray(verts_cam + cam_t, dtype=np.float32)
129
+ verts_export: Float32[ndarray, "n_verts 3"]
130
+ verts_export = verts_world - np.mean(verts_world, axis=0, keepdims=True) if center_mesh else verts_world
131
+
132
+ vertex_normals: Float32[ndarray, "n_verts 3"] = compute_vertex_normals(verts_export, faces_int)
133
+
134
+ mesh = o3d.geometry.TriangleMesh()
135
+ mesh.vertices = o3d.utility.Vector3dVector(verts_export.astype(np.float64))
136
+ mesh.triangles = o3d.utility.Vector3iVector(faces_int.astype(np.int32))
137
+ mesh.vertex_normals = o3d.utility.Vector3dVector(vertex_normals.astype(np.float64))
138
+
139
+ color: Float32[ndarray, "3"] = box_palette[idx % len(box_palette), :3].astype(np.float32) / 255.0
140
+ vertex_colors: Float32[ndarray, "n_verts 3"] = np.repeat(color[np.newaxis, :], verts_export.shape[0], axis=0)
141
+ mesh.vertex_colors = o3d.utility.Vector3dVector(vertex_colors.astype(np.float64))
142
+
143
+ glb_path: Path = output_dir / f"person_{idx:02d}.glb"
144
+ success: bool = bool(
145
+ o3d.io.write_triangle_mesh(
146
+ str(glb_path),
147
+ mesh,
148
+ write_ascii=False,
149
+ write_vertex_normals=True,
150
+ write_vertex_colors=True,
151
+ )
152
+ )
153
+ if not success:
154
+ fallback_path: Path = output_dir / f"person_{idx:02d}.ply"
155
+ success = bool(
156
+ o3d.io.write_triangle_mesh(
157
+ str(fallback_path),
158
+ mesh,
159
+ write_ascii=False,
160
+ write_vertex_normals=True,
161
+ write_vertex_colors=True,
162
+ )
163
+ )
164
+ if success:
165
+ glb_path = fallback_path
166
+
167
+ if success:
168
+ written_paths.append(glb_path)
169
+
170
+ return written_paths
171
+
172
+
173
+ def set_annotation_context() -> None:
174
+ """Register MHR-70 semantic metadata so subsequent logs show names/edges and mask colors."""
175
+ # Base person class (for keypoints / boxes) uses id=0 (original), segmentation uses 1000+ to avoid clashes.
176
+ person_class = rr.ClassDescription(
177
+ info=rr.AnnotationInfo(id=0, label="Person", color=(0, 0, 255)),
178
+ keypoint_annotations=[rr.AnnotationInfo(id=idx, label=name) for idx, name in MHR70_ID2NAME.items()],
179
+ keypoint_connections=MHR70_LINKS,
180
+ )
181
+
182
+ # Segmentation classes: id=SEG_CLASS_OFFSET background, ids SEG_CLASS_OFFSET+1..n for each instance color.
183
+ seg_classes: list[rr.ClassDescription] = [
184
+ rr.ClassDescription(info=rr.AnnotationInfo(id=SEG_CLASS_OFFSET, label="Background", color=(64, 64, 64))),
185
+ ]
186
+ for idx, color in enumerate(BOX_PALETTE[:, :3].tolist(), start=1):
187
+ seg_classes.append(
188
+ rr.ClassDescription(
189
+ info=rr.AnnotationInfo(
190
+ id=SEG_CLASS_OFFSET + idx, label=f"Person-{idx}", color=tuple(int(c) for c in color)
191
+ ),
192
+ )
193
+ )
194
+
195
+ rr.log(
196
+ "/",
197
+ rr.AnnotationContext([person_class, *seg_classes]),
198
+ static=True,
199
+ )
200
+
201
+
202
+ def visualize_sample(
203
+ pred_list: list[FinalPosePrediction],
204
+ rgb_hw3: UInt8[ndarray, "h w 3"],
205
+ parent_log_path: Path,
206
+ faces: Int[ndarray, "n_faces 3"],
207
+ relative_depth_pred: RelativeDepthPrediction | None = None,
208
+ ) -> None:
209
+ h: int = rgb_hw3.shape[0]
210
+ w: int = rgb_hw3.shape[1]
211
+ cam_log_path: Path = parent_log_path / "cam"
212
+ pinhole_log_path: Path = cam_log_path / "pinhole"
213
+ image_log_path: Path = pinhole_log_path / "image"
214
+ pred_log_path: Path = pinhole_log_path / "pred"
215
+ # log the pinhole camera parameters (assume fx=fy and center at image center)
216
+ focal_length: float = float(pred_list[0].focal_length)
217
+ intri: Intrinsics = Intrinsics(
218
+ camera_conventions="RDF",
219
+ fl_x=focal_length,
220
+ fl_y=focal_length,
221
+ cx=float(w) / 2.0,
222
+ cy=float(h) / 2.0,
223
+ height=h,
224
+ width=w,
225
+ )
226
+ world_T_cam: Float32[ndarray, "4 4"] = np.eye(4, dtype=np.float32)
227
+ extri: Extrinsics = Extrinsics(
228
+ world_R_cam=world_T_cam[:3, :3],
229
+ world_t_cam=world_T_cam[:3, 3],
230
+ )
231
+
232
+ pinhole_params: PinholeParameters = PinholeParameters(intrinsics=intri, extrinsics=extri, name="pinhole")
233
+ log_pinhole(camera=pinhole_params, cam_log_path=cam_log_path)
234
+ # clear the previous pred logs
235
+ rr.log(f"{pred_log_path}", rr.Clear(recursive=True))
236
+ rr.log(f"{image_log_path}", rr.Image(rgb_hw3, color_model=rr.ColorModel.RGB).compress(jpeg_quality=90))
237
+
238
+ # Build per-pixel maps (SEG_CLASS_OFFSET = background). Also build RGBA overlay with transparent background.
239
+ seg_map: Int[ndarray, "h w"] = np.full((h, w), SEG_CLASS_OFFSET, dtype=np.int32)
240
+ seg_overlay: UInt8[ndarray, "h w 4"] = np.zeros((h, w, 4), dtype=np.uint8)
241
+ human_mask: Bool[ndarray, "h w"] = np.zeros((h, w), dtype=bool)
242
+
243
+ mesh_root_path: Path = parent_log_path / "pred"
244
+ rr.log(str(mesh_root_path), rr.Clear(recursive=True))
245
+
246
+ for i, output in enumerate(pred_list):
247
+ box_color: UInt8[ndarray, "1 4"] = BOX_PALETTE[i % len(BOX_PALETTE)].reshape(1, 4)
248
+ rr.log(
249
+ f"{pred_log_path}/bbox_{i}",
250
+ rr.Boxes2D(
251
+ array=output.bbox,
252
+ array_format=rr.Box2DFormat.XYXY,
253
+ class_ids=0,
254
+ colors=box_color,
255
+ show_labels=True,
256
+ ),
257
+ )
258
+
259
+ kpts_cam: Float32[ndarray, "n_kpts 3"] = np.ascontiguousarray(output.pred_keypoints_3d, dtype=np.float32)
260
+ kpts_uv: Float32[ndarray, "n_kpts 2"] = np.ascontiguousarray(output.pred_keypoints_2d, dtype=np.float32)
261
+ kpts_uv_in_bounds: Float32[ndarray, "n_kpts 2"] = filter_out_of_bounds(
262
+ uv=kpts_uv,
263
+ h=h,
264
+ w=w,
265
+ xyz_cam=None, # Depth sign from the model can be negative; only cull by image bounds.
266
+ )
267
+ rr.log(
268
+ f"{pred_log_path}/uv_{i}",
269
+ rr.Points2D(
270
+ positions=kpts_uv_in_bounds,
271
+ keypoint_ids=MHR70_IDS,
272
+ class_ids=0,
273
+ colors=(0, 255, 0),
274
+ ),
275
+ )
276
+
277
+ # Accumulate segmentation masks (if present) into a single segmentation image.
278
+ mask = output.mask
279
+ if mask is not None:
280
+ mask_arr: ndarray = np.asarray(mask).squeeze()
281
+ if mask_arr.shape != seg_map.shape:
282
+ mask_arr = cv2.resize(
283
+ mask_arr.astype(np.uint8), (seg_map.shape[1], seg_map.shape[0]), interpolation=cv2.INTER_NEAREST
284
+ )
285
+ mask_bool = mask_arr.astype(bool)
286
+ human_mask = np.logical_or(human_mask, mask_bool)
287
+ seg_id = SEG_CLASS_OFFSET + i + 1 # keep person class (0) separate from seg classes
288
+ seg_map = np.where(mask_bool, np.int32(seg_id), seg_map)
289
+
290
+ # Color overlay for this instance, background stays transparent.
291
+ color = BOX_PALETTE[i % len(BOX_PALETTE), :3]
292
+ seg_overlay[mask_bool] = np.array([color[0], color[1], color[2], 120], dtype=np.uint8)
293
+
294
+ # Log 3D keypoints in world coordinates
295
+ cam_t: Float32[ndarray, "3"] = np.ascontiguousarray(output.pred_cam_t, dtype=np.float32)
296
+ kpts_world: Float32[ndarray, "n_kpts 3"] = np.ascontiguousarray(kpts_cam + cam_t, dtype=np.float32)
297
+ rr.log(
298
+ f"{parent_log_path}/pred/kpts3d_{i}",
299
+ rr.Points3D(
300
+ positions=kpts_world,
301
+ keypoint_ids=MHR70_IDS,
302
+ class_ids=0,
303
+ colors=(0, 255, 0),
304
+ ),
305
+ )
306
+
307
+ # Log the full-body mesh in world coordinates so it shows in 3D
308
+ verts_cam: Float32[ndarray, "n_verts 3"] = np.ascontiguousarray(output.pred_vertices, dtype=np.float32)
309
+ verts_world: Float32[ndarray, "n_verts 3"] = np.ascontiguousarray(verts_cam + cam_t, dtype=np.float32)
310
+ faces_int: Int[ndarray, "n_faces 3"] = np.ascontiguousarray(faces, dtype=np.int32)
311
+ vertex_normals: Float32[ndarray, "n_verts 3"] = compute_vertex_normals(verts_world, faces_int)
312
+ rr.log(
313
+ f"{parent_log_path}/pred/mesh_{i}",
314
+ rr.Mesh3D(
315
+ vertex_positions=verts_world,
316
+ triangle_indices=faces_int,
317
+ vertex_normals=vertex_normals,
318
+ albedo_factor=(
319
+ float(box_color[0, 0]) / 255.0,
320
+ float(box_color[0, 1]) / 255.0,
321
+ float(box_color[0, 2]) / 255.0,
322
+ 0.35,
323
+ ),
324
+ ),
325
+ )
326
+
327
+ # Log segmentation ids (full map) and an RGBA overlay with transparent background.
328
+ if np.any(seg_map != SEG_CLASS_OFFSET):
329
+ rr.log(f"{pred_log_path}/segmentation_ids", rr.SegmentationImage(seg_map))
330
+ rr.log(f"{pred_log_path}/segmentation_overlay", rr.Image(seg_overlay, color_model=rr.ColorModel.RGBA))
331
+
332
+ # Optionally log depth and a background-only point cloud (for 3D view only).
333
+ if relative_depth_pred is not None:
334
+ depth_hw: Float32[ndarray, "h w"] = np.asarray(relative_depth_pred.depth, dtype=np.float32)
335
+ conf_hw: Float32[ndarray, "h w"] = np.asarray(relative_depth_pred.confidence, dtype=np.float32)
336
+ if depth_hw.shape != (h, w):
337
+ depth_hw = cv2.resize(depth_hw, (w, h), interpolation=cv2.INTER_NEAREST)
338
+ if conf_hw.shape != (h, w):
339
+ conf_hw = cv2.resize(conf_hw, (w, h), interpolation=cv2.INTER_NEAREST)
340
+ depth_hw = np.nan_to_num(depth_hw, nan=0.0, posinf=0.0, neginf=0.0)
341
+
342
+ # Remove flying pixels along depth discontinuities.
343
+ edges_mask: Bool[ndarray, "h w"] = depth_edges_mask(depth_hw, threshold=0.01)
344
+ depth_hw = depth_hw * np.logical_not(edges_mask)
345
+
346
+ # Remove low-confidence pixels.
347
+ conf_mask: Bool[ndarray, "h w"] = conf_hw >= MIN_DEPTH_CONFIDENCE
348
+ depth_hw = depth_hw * conf_mask
349
+
350
+ background_mask: Bool[ndarray, "h w"] = np.logical_not(human_mask)
351
+ depth_bg: Float32[ndarray, "h w"] = depth_hw * background_mask
352
+
353
+ # Log depth image (not referenced by the 2D blueprint).
354
+ # rr.log(f"{pinhole_log_path}/depth", rr.DepthImage(depth_bg, meter=1.0))
355
+
356
+ fx: float = float(relative_depth_pred.K_33[0, 0])
357
+ fy: float = float(relative_depth_pred.K_33[1, 1])
358
+ cx: float = float(relative_depth_pred.K_33[0, 2])
359
+ cy: float = float(relative_depth_pred.K_33[1, 2])
360
+
361
+ u: Float32[ndarray, "w"] = np.arange(w, dtype=np.float32)
362
+ v: Float32[ndarray, "h"] = np.arange(h, dtype=np.float32)
363
+ uu: Float32[ndarray, "h w"]
364
+ vv: Float32[ndarray, "h w"]
365
+ uu, vv = np.meshgrid(u, v)
366
+
367
+ z_cam: Float32[ndarray, "h w"] = depth_bg
368
+ valid: Bool[ndarray, "h w"] = np.logical_and(z_cam > 0.0, np.isfinite(z_cam))
369
+ if np.any(valid):
370
+ x_cam: Float32[ndarray, "h w"] = (uu - cx) * z_cam / fx
371
+ y_cam: Float32[ndarray, "h w"] = (vv - cy) * z_cam / fy
372
+ points_cam: Float32[ndarray, "h w 3"] = np.stack([x_cam, y_cam, z_cam], axis=-1)
373
+
374
+ points_flat: Float32[ndarray, "n_valid 3"] = points_cam[valid]
375
+ colors_flat: UInt8[ndarray, "n_valid 3"] = rgb_hw3[valid]
376
+
377
+ if points_flat.shape[0] > MAX_POINT_CLOUD_POINTS:
378
+ voxel_size: float = estimate_voxel_size(
379
+ points_flat, target_points=MAX_POINT_CLOUD_POINTS, tolerance=0.25
380
+ )
381
+ pcd: o3d.geometry.PointCloud = o3d.geometry.PointCloud()
382
+ pcd.points = o3d.utility.Vector3dVector(points_flat)
383
+ pcd.colors = o3d.utility.Vector3dVector(colors_flat.astype(np.float32) / 255.0)
384
+ pcd_ds: o3d.geometry.PointCloud = pcd.voxel_down_sample(voxel_size)
385
+ points_flat = np.asarray(pcd_ds.points, dtype=np.float32)
386
+ colors_flat = (np.asarray(pcd_ds.colors, dtype=np.float32) * 255.0).astype(np.uint8)
387
+
388
+ rr.log(
389
+ f"{parent_log_path}/depth_point_cloud",
390
+ rr.Points3D(
391
+ positions=points_flat,
392
+ colors=colors_flat,
393
+ ),
394
+ )
395
+
396
+
397
+ def create_view() -> rrb.ContainerLike:
398
+ view_2d = rrb.Vertical(
399
+ contents=[
400
+ # Top: people-only overlay on the RGB image.
401
+ rrb.Spatial2DView(
402
+ name="image",
403
+ origin="/world/cam/pinhole",
404
+ contents=[
405
+ "/world/cam/pinhole/image",
406
+ "/world/cam/pinhole/pred/segmentation_overlay",
407
+ ],
408
+ ),
409
+ # Bottom: 2D boxes + keypoints; segmentation hidden.
410
+ rrb.Spatial2DView(
411
+ name="mhr",
412
+ origin="/world/cam/pinhole",
413
+ contents=[
414
+ "/world/cam/pinhole/image",
415
+ "/world/cam/pinhole/pred/**",
416
+ "- /world/cam/pinhole/pred/segmentation_overlay/**",
417
+ "- /world/cam/pinhole/pred/segmentation_ids/**",
418
+ ],
419
+ ),
420
+ ],
421
+ )
422
+ view_3d = rrb.Spatial3DView(name="mhr_3d", line_grid=rrb.LineGrid3D(visible=False))
423
+ main_view = rrb.Horizontal(contents=[view_2d, view_3d], column_shares=[2, 3])
424
+ view = rrb.Tabs(contents=[main_view], name="sam-3d-body-demo")
425
+ return view
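For reference, a minimal sketch of driving these helpers outside the Gradio app, assuming `set_annotation_context`, `visualize_sample`, and `create_view` are importable and that predictions were already computed elsewhere; the application id is illustrative:

import rerun as rr
import rerun.blueprint as rrb

rr.init("sam3d_body_preview", spawn=True)  # open a local Rerun viewer (illustrative app id)
rr.send_blueprint(rrb.Blueprint(create_view(), collapse_panels=True))
set_annotation_context()  # register the person/segmentation classes once, as static data
rr.log("/", rr.ViewCoordinates.RDF, static=True)  # match the RDF convention assumed above
# visualize_sample(pred_list, rgb_hw3, Path("/world"), faces)  # then log one frame of predictions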
src/sam3d_body/build_models.py ADDED
@@ -0,0 +1,56 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import os
3
+ from os import PathLike
4
+
5
+ import torch
6
+
7
+ from .models.meta_arch import SAM3DBody
8
+ from .utils.checkpoint import load_state_dict
9
+ from .utils.config import CN, get_config
10
+
11
+
12
+ def load_sam_3d_body(
13
+ checkpoint_path: str | PathLike[str] = "",
14
+ device: str | torch.device = "cuda",
15
+ mhr_path: str | PathLike[str] = "",
16
+ ) -> tuple[SAM3DBody, CN]:
17
+ print("Loading SAM 3D Body model...")
18
+
19
+ checkpoint_path = os.fspath(checkpoint_path)
20
+ mhr_path = os.fspath(mhr_path)
21
+
22
+ # Check the current directory, and if not present check the parent dir.
23
+ model_cfg = os.path.join(os.path.dirname(checkpoint_path), "model_config.yaml")
24
+ if not os.path.exists(model_cfg):
25
+ # Looks at parent dir
26
+ model_cfg = os.path.join(os.path.dirname(os.path.dirname(checkpoint_path)), "model_config.yaml")
27
+
28
+ model_cfg = get_config(model_cfg)
29
+
30
+ # Point the MHR head at the provided MHR model assets
31
+ model_cfg.defrost()
32
+ model_cfg.MODEL.MHR_HEAD.MHR_MODEL_PATH = mhr_path
33
+ model_cfg.freeze()
34
+
35
+ # Initialize the model
36
+ model = SAM3DBody(model_cfg)
37
+
38
+ checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
39
+ state_dict = checkpoint.get("state_dict", checkpoint)
40
+ load_state_dict(model, state_dict, strict=False)
41
+
42
+ model = model.to(device)
43
+ model.eval()
44
+ return model, model_cfg
45
+
46
+
47
+ def _hf_download(repo_id):
48
+ from huggingface_hub import snapshot_download
49
+
50
+ local_dir = snapshot_download(repo_id=repo_id)
51
+ return os.path.join(local_dir, "model.ckpt"), os.path.join(local_dir, "assets", "mhr_model.pt")
52
+
53
+
54
+ def load_sam_3d_body_hf(repo_id, **kwargs):
55
+ ckpt_path, mhr_path = _hf_download(repo_id)
56
+ return load_sam_3d_body(checkpoint_path=ckpt_path, mhr_path=mhr_path, **kwargs)
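A hedged usage sketch; the repo id below is a placeholder, and the expected checkpoint layout (model.ckpt plus assets/mhr_model.pt) follows _hf_download above:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Placeholder repo id -- substitute the actual SAM 3D Body checkpoint repo.
model, cfg = load_sam_3d_body_hf("your-org/sam-3d-body", device=device)
# Or with local files:
# model, cfg = load_sam_3d_body("weights/model.ckpt", device=device, mhr_path="weights/assets/mhr_model.pt")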
src/sam3d_body/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
src/sam3d_body/data/transforms/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from .bbox_utils import (
4
+ bbox_cs2xywh,
5
+ bbox_cs2xyxy,
6
+ bbox_xywh2cs,
7
+ bbox_xywh2xyxy,
8
+ bbox_xyxy2cs,
9
+ bbox_xyxy2xywh,
10
+ flip_bbox,
11
+ get_udp_warp_matrix,
12
+ get_warp_matrix,
13
+ )
14
+ from .common import (
15
+ Compose,
16
+ GetBBoxCenterScale,
17
+ NormalizeKeypoint,
18
+ SquarePad,
19
+ TopdownAffine,
20
+ VisionTransformWrapper,
21
+ )
src/sam3d_body/data/transforms/bbox_utils.py ADDED
@@ -0,0 +1,380 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import math
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+
9
+ def bbox_xyxy2xywh(bbox_xyxy: np.ndarray) -> np.ndarray:
10
+ """Transform the bbox format from x1y1x2y2 to xywh.
11
+
12
+ Args:
13
+ bbox_xyxy (np.ndarray): Bounding boxes (with scores), shaped (n, 4) or
14
+ (n, 5). (left, top, right, bottom, [score])
15
+
16
+ Returns:
17
+ np.ndarray: Bounding boxes (with scores),
18
+ shaped (n, 4) or (n, 5). (left, top, width, height, [score])
19
+ """
20
+ bbox_xywh = bbox_xyxy.copy()
21
+ bbox_xywh[:, 2] = bbox_xywh[:, 2] - bbox_xywh[:, 0]
22
+ bbox_xywh[:, 3] = bbox_xywh[:, 3] - bbox_xywh[:, 1]
23
+
24
+ return bbox_xywh
25
+
26
+
27
+ def bbox_xywh2xyxy(bbox_xywh: np.ndarray) -> np.ndarray:
28
+ """Transform the bbox format from xywh to x1y1x2y2.
29
+
30
+ Args:
31
+ bbox_xywh (ndarray): Bounding boxes (with scores),
32
+ shaped (n, 4) or (n, 5). (left, top, width, height, [score])
33
+ Returns:
34
+ np.ndarray: Bounding boxes (with scores), shaped (n, 4) or
35
+ (n, 5). (left, top, right, bottom, [score])
36
+ """
37
+ bbox_xyxy = bbox_xywh.copy()
38
+ bbox_xyxy[:, 2] = bbox_xyxy[:, 2] + bbox_xyxy[:, 0]
39
+ bbox_xyxy[:, 3] = bbox_xyxy[:, 3] + bbox_xyxy[:, 1]
40
+
41
+ return bbox_xyxy
42
+
43
+
44
+ def bbox_xyxy2cs(bbox: np.ndarray, padding: float = 1.0) -> tuple[np.ndarray, np.ndarray]:
45
+ """Transform the bbox format from (x,y,w,h) into (center, scale)
46
+
47
+ Args:
48
+ bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
49
+ as (left, top, right, bottom)
50
+ padding (float): BBox padding factor that will be multiplied with scale.
51
+ Default: 1.0
52
+
53
+ Returns:
54
+ tuple: A tuple containing center and scale.
55
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
56
+ (n, 2)
57
+ - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
58
+ (n, 2)
59
+ """
60
+ # convert single bbox from (4, ) to (1, 4)
61
+ dim = bbox.ndim
62
+ if dim == 1:
63
+ bbox = bbox[None, :]
64
+
65
+ x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
66
+ center = np.hstack([x1 + x2, y1 + y2]) * 0.5
67
+ scale = np.hstack([x2 - x1, y2 - y1]) * padding
68
+
69
+ if dim == 1:
70
+ center = center[0]
71
+ scale = scale[0]
72
+
73
+ return center, scale
74
+
75
+
76
+ def bbox_xywh2cs(bbox: np.ndarray, padding: float = 1.0) -> tuple[np.ndarray, np.ndarray]:
77
+ """Transform the bbox format from (x,y,w,h) into (center, scale)
78
+
79
+ Args:
80
+ bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
81
+ as (x, y, w, h)
82
+ padding (float): BBox padding factor that will be multiplied with scale.
83
+ Default: 1.0
84
+
85
+ Returns:
86
+ tuple: A tuple containing center and scale.
87
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
88
+ (n, 2)
89
+ - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
90
+ (n, 2)
91
+ """
92
+
93
+ # convert single bbox from (4, ) to (1, 4)
94
+ dim = bbox.ndim
95
+ if dim == 1:
96
+ bbox = bbox[None, :]
97
+
98
+ x, y, w, h = np.hsplit(bbox, [1, 2, 3])
99
+ center = np.hstack([x + w * 0.5, y + h * 0.5])
100
+ scale = np.hstack([w, h]) * padding
101
+
102
+ if dim == 1:
103
+ center = center[0]
104
+ scale = scale[0]
105
+
106
+ return center, scale
107
+
108
+
109
+ def bbox_cs2xyxy(center: np.ndarray, scale: np.ndarray, padding: float = 1.0) -> np.ndarray:
110
+ """Transform the bbox format from (center, scale) to (x1,y1,x2,y2).
111
+
112
+ Args:
113
+ center (ndarray): BBox center (x, y) in shape (2,) or (n, 2)
114
+ scale (ndarray): BBox scale (w, h) in shape (2,) or (n, 2)
115
+ padding (float): BBox padding factor that will be multiplied with scale.
116
+ Default: 1.0
117
+
118
+ Returns:
119
+ ndarray[float32]: BBox (x1, y1, x2, y2) in shape (4, ) or (n, 4)
120
+ """
121
+
122
+ dim = center.ndim
123
+ assert scale.ndim == dim
124
+
125
+ if dim == 1:
126
+ center = center[None, :]
127
+ scale = scale[None, :]
128
+
129
+ wh = scale / padding
130
+ xy = center - 0.5 * wh
131
+ bbox = np.hstack((xy, xy + wh))
132
+
133
+ if dim == 1:
134
+ bbox = bbox[0]
135
+
136
+ return bbox
137
+
138
+
139
+ def bbox_cs2xywh(center: np.ndarray, scale: np.ndarray, padding: float = 1.0) -> np.ndarray:
140
+ """Transform the bbox format from (center, scale) to (x,y,w,h).
141
+
142
+ Args:
143
+ center (ndarray): BBox center (x, y) in shape (2,) or (n, 2)
144
+ scale (ndarray): BBox scale (w, h) in shape (2,) or (n, 2)
145
+ padding (float): BBox padding factor that will be multiplied with scale.
146
+ Default: 1.0
147
+
148
+ Returns:
149
+ ndarray[float32]: BBox (x, y, w, h) in shape (4, ) or (n, 4)
150
+ """
151
+
152
+ dim = center.ndim
153
+ assert scale.ndim == dim
154
+
155
+ if dim == 1:
156
+ center = center[None, :]
157
+ scale = scale[None, :]
158
+
159
+ wh = scale / padding
160
+ xy = center - 0.5 * wh
161
+ bbox = np.hstack((xy, wh))
162
+
163
+ if dim == 1:
164
+ bbox = bbox[0]
165
+
166
+ return bbox
167
+
168
+
169
+ def flip_bbox(
170
+ bbox: np.ndarray,
171
+ image_size: tuple[int, int],
172
+ bbox_format: str = "xywh",
173
+ direction: str = "horizontal",
174
+ ) -> np.ndarray:
175
+ """Flip the bbox in the given direction.
176
+
177
+ Args:
178
+ bbox (np.ndarray): The bounding boxes. The shape should be (..., 4)
179
+ if ``bbox_format`` is ``'xyxy'`` or ``'xywh'``, and (..., 2) if
180
+ ``bbox_format`` is ``'center'``
181
+ image_size (tuple): The image shape in [w, h]
182
+ bbox_format (str): The bbox format. Options are ``'xywh'``, ``'xyxy'``
183
+ and ``'center'``.
184
+ direction (str): The flip direction. Options are ``'horizontal'``,
185
+ ``'vertical'`` and ``'diagonal'``. Defaults to ``'horizontal'``
186
+
187
+ Returns:
188
+ np.ndarray: The flipped bounding boxes.
189
+ """
190
+ direction_options = {"horizontal", "vertical", "diagonal"}
191
+ assert direction in direction_options, f'Invalid flipping direction "{direction}". Options are {direction_options}'
192
+
193
+ format_options = {"xywh", "xyxy", "center"}
194
+ assert bbox_format in format_options, f'Invalid bbox format "{bbox_format}". Options are {format_options}'
195
+
196
+ bbox_flipped = bbox.copy()
197
+ w, h = image_size
198
+
199
+ if direction == "horizontal":
200
+ if bbox_format == "xywh" or bbox_format == "center":
201
+ bbox_flipped[..., 0] = w - bbox[..., 0] - 1
202
+ elif bbox_format == "xyxy":
203
+ bbox_flipped[..., ::2] = w - bbox[..., ::2] - 1
204
+ elif direction == "vertical":
205
+ if bbox_format == "xywh" or bbox_format == "center":
206
+ bbox_flipped[..., 1] = h - bbox[..., 1] - 1
207
+ elif bbox_format == "xyxy":
208
+ bbox_flipped[..., 1::2] = h - bbox[..., 1::2] - 1
209
+ elif direction == "diagonal":
210
+ if bbox_format == "xywh" or bbox_format == "center":
211
+ bbox_flipped[..., :2] = [w, h] - bbox[..., :2] - 1
212
+ elif bbox_format == "xyxy":
213
+ bbox_flipped[...] = [w, h, w, h] - bbox - 1
214
+
215
+ return bbox_flipped
216
+
217
+
218
+ def fix_aspect_ratio(bbox_scale: np.ndarray, aspect_ratio: float):
219
+ """Reshape the bbox to a fixed aspect ratio.
220
+
221
+ Args:
222
+ bbox_scale (np.ndarray): The bbox scales (w, h) in shape (n, 2)
223
+ aspect_ratio (float): The ratio of ``w/h``
224
+
225
+ Returns:
226
+ np.ndarray: The reshaped bbox scales in (n, 2)
227
+ """
228
+ dim = bbox_scale.ndim
229
+ if dim == 1:
230
+ bbox_scale = bbox_scale[None, :]
231
+
232
+ w, h = np.hsplit(bbox_scale, [1])
233
+ bbox_scale = np.where(
234
+ w > h * aspect_ratio,
235
+ np.hstack([w, w / aspect_ratio]),
236
+ np.hstack([h * aspect_ratio, h]),
237
+ )
238
+ if dim == 1:
239
+ bbox_scale = bbox_scale[0]
240
+
241
+ return bbox_scale
242
+
243
+
244
+ def get_udp_warp_matrix(
245
+ center: np.ndarray,
246
+ scale: np.ndarray,
247
+ rot: float,
248
+ output_size: tuple[int, int],
249
+ ) -> np.ndarray:
250
+ """Calculate the affine transformation matrix under the unbiased
251
+ constraint. See `UDP (CVPR 2020)`_ for details.
252
+
253
+ Note:
254
+
255
+ - The bbox number: N
256
+
257
+ Args:
258
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
259
+ scale (np.ndarray[2, ]): Scale of the bounding box
260
+ wrt [width, height].
261
+ rot (float): Rotation angle (degree).
262
+ output_size (tuple): Size ([w, h]) of the output image
263
+
264
+ Returns:
265
+ np.ndarray: A 2x3 transformation matrix
266
+
267
+ .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524
268
+ """
269
+ assert len(center) == 2
270
+ assert len(scale) == 2
271
+ assert len(output_size) == 2
272
+
273
+ input_size = center * 2
274
+ rot_rad = np.deg2rad(rot)
275
+ warp_mat = np.zeros((2, 3), dtype=np.float32)
276
+ scale_x = (output_size[0] - 1) / scale[0]
277
+ scale_y = (output_size[1] - 1) / scale[1]
278
+ warp_mat[0, 0] = math.cos(rot_rad) * scale_x
279
+ warp_mat[0, 1] = -math.sin(rot_rad) * scale_x
280
+ warp_mat[0, 2] = scale_x * (
281
+ -0.5 * input_size[0] * math.cos(rot_rad) + 0.5 * input_size[1] * math.sin(rot_rad) + 0.5 * scale[0]
282
+ )
283
+ warp_mat[1, 0] = math.sin(rot_rad) * scale_y
284
+ warp_mat[1, 1] = math.cos(rot_rad) * scale_y
285
+ warp_mat[1, 2] = scale_y * (
286
+ -0.5 * input_size[0] * math.sin(rot_rad) - 0.5 * input_size[1] * math.cos(rot_rad) + 0.5 * scale[1]
287
+ )
288
+ return warp_mat
289
+
290
+
291
+ def get_warp_matrix(
292
+ center: np.ndarray,
293
+ scale: np.ndarray,
294
+ rot: float,
295
+ output_size: tuple[int, int],
296
+ shift: tuple[float, float] = (0.0, 0.0),
297
+ inv: bool = False,
298
+ ) -> np.ndarray:
299
+ """Calculate the affine transformation matrix that can warp the bbox area
300
+ in the input image to the output size.
301
+
302
+ Args:
303
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
304
+ scale (np.ndarray[2, ]): Scale of the bounding box
305
+ wrt [width, height].
306
+ rot (float): Rotation angle (degree).
307
+ output_size (np.ndarray[2, ] | list(2,)): Size of the
308
+ destination heatmaps.
309
+ shift (0-100%): Shift translation ratio wrt the width/height.
310
+ Default (0., 0.).
311
+ inv (bool): Option to inverse the affine transform direction.
312
+ (inv=False: src->dst or inv=True: dst->src)
313
+
314
+ Returns:
315
+ np.ndarray: A 2x3 transformation matrix
316
+ """
317
+ assert len(center) == 2
318
+ assert len(scale) == 2
319
+ assert len(output_size) == 2
320
+ assert len(shift) == 2
321
+
322
+ shift = np.array(shift)
323
+ src_w = scale[0]
324
+ dst_w = output_size[0]
325
+ dst_h = output_size[1]
326
+
327
+ rot_rad = np.deg2rad(rot)
328
+ src_dir = _rotate_point(np.array([0.0, src_w * -0.5]), rot_rad)
329
+ dst_dir = np.array([0.0, dst_w * -0.5])
330
+
331
+ src = np.zeros((3, 2), dtype=np.float32)
332
+ src[0, :] = center + scale * shift
333
+ src[1, :] = center + src_dir + scale * shift
334
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
335
+
336
+ dst = np.zeros((3, 2), dtype=np.float32)
337
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
338
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
339
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
340
+
341
+ if inv:
342
+ warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
343
+ else:
344
+ warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
345
+ return warp_mat
346
+
347
+
348
+ def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
349
+ """Rotate a point by an angle.
350
+
351
+ Args:
352
+ pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
353
+ angle_rad (float): rotation angle in radian
354
+
355
+ Returns:
356
+ np.ndarray: Rotated point in shape (2, )
357
+ """
358
+
359
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
360
+ rot_mat = np.array([[cs, -sn], [sn, cs]])
361
+ return rot_mat @ pt
362
+
363
+
364
+ def _get_3rd_point(a: np.ndarray, b: np.ndarray):
365
+ """To calculate the affine matrix, three pairs of points are required. This
366
+ function is used to get the 3rd point, given 2D points a & b.
367
+
368
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
369
+ anticlockwise, using b as the rotation center.
370
+
371
+ Args:
372
+ a (np.ndarray): The 1st point (x,y) in shape (2, )
373
+ b (np.ndarray): The 2nd point (x,y) in shape (2, )
374
+
375
+ Returns:
376
+ np.ndarray: The 3rd point.
377
+ """
378
+ direction = a - b
379
+ c = b + np.r_[-direction[1], direction[0]]
380
+ return c
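The center/scale conversions above are exact inverses for a matching padding factor; a small self-contained check with illustrative numbers:

import numpy as np

bbox_xyxy = np.array([[100.0, 50.0, 300.0, 450.0]])     # (x1, y1, x2, y2)
center, scale = bbox_xyxy2cs(bbox_xyxy, padding=1.25)   # scale is the padded (w, h)
restored = bbox_cs2xyxy(center, scale, padding=1.25)    # undoing the same padding recovers the box
assert np.allclose(restored, bbox_xyxy)

# fix_aspect_ratio only ever grows a side so that w/h hits the target ratio.
print(fix_aspect_ratio(np.array([200.0, 400.0]), aspect_ratio=1.0))  # -> [400. 400.]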
src/sam3d_body/data/transforms/common.py ADDED
@@ -0,0 +1,345 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from collections.abc import Callable, Sequence
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch.nn as nn
8
+ import torchvision.transforms.functional as F
9
+ from PIL import Image
10
+
11
+ from sam3d_body.models.modules import to_2tuple
12
+
13
+ from .bbox_utils import (
14
+ bbox_xywh2cs,
15
+ bbox_xyxy2cs,
16
+ fix_aspect_ratio,
17
+ get_udp_warp_matrix,
18
+ get_warp_matrix,
19
+ )
20
+
21
+
22
+ class Compose:
23
+ """Compose multiple transforms sequentially.
24
+
25
+ Args:
26
+ transforms (Sequence[dict, callable], optional): Sequence of transform
27
+ object or config dict to be composed.
28
+ """
29
+
30
+ def __init__(self, transforms: list[Callable] | None = None):
31
+ if transforms is None:
32
+ self.transforms = []
33
+ else:
34
+ self.transforms = transforms
35
+
36
+ def __call__(self, data: dict) -> dict | None:
37
+ """Call function to apply transforms sequentially.
38
+
39
+ Args:
40
+ data (dict): A result dict contains the data to transform.
41
+
42
+ Returns:
43
+ dict: Transformed data.
44
+ """
45
+ for t in self.transforms:
46
+ data = t(data)
47
+ # The transform will return None when it failed to load images or
48
+ # cannot find suitable augmentation parameters to augment the data.
49
+ # Here we simply return None if the transform returns None and the
50
+ # dataset will handle it by randomly selecting another data sample.
51
+ if data is None:
52
+ return None
53
+ return data
54
+
55
+ def __repr__(self):
56
+ """Print ``self.transforms`` in sequence.
57
+
58
+ Returns:
59
+ str: Formatted string.
60
+ """
61
+ format_string = self.__class__.__name__ + "("
62
+ for t in self.transforms:
63
+ format_string += "\n"
64
+ format_string += f" {t}"
65
+ format_string += "\n)"
66
+ return format_string
67
+
68
+
69
+ class VisionTransformWrapper:
70
+ """A wrapper to use torchvision transform functions in this codebase."""
71
+
72
+ def __init__(self, transform: Callable):
73
+ self.transform = transform
74
+
75
+ def __call__(self, results: dict) -> dict | None:
76
+ results["img"] = self.transform(results["img"])
77
+ return results
78
+
79
+ def __repr__(self) -> str:
80
+ """print the basic information of the transform.
81
+
82
+ Returns:
83
+ str: Formatted string.
84
+ """
85
+ repr_str = self.transform.__class__.__name__
86
+ return repr_str
87
+
88
+
89
+ class GetBBoxCenterScale(nn.Module):
90
+ """Convert bboxes to center and scale.
91
+
92
+ The center is the coordinates of the bbox center, and the scale is the
93
+ bbox width and height normalized by a scale factor.
94
+
95
+ Required Keys:
96
+
97
+ - bbox
98
+ - bbox_format
99
+
100
+ Added Keys:
101
+
102
+ - bbox_center
103
+ - bbox_scale
104
+
105
+ Args:
106
+ padding (float): The bbox padding scale that will be multiplied with
107
+ `bbox_scale`. Defaults to 1.25
108
+ """
109
+
110
+ def __init__(self, padding: float = 1.25) -> None:
111
+ super().__init__()
112
+
113
+ self.padding = padding
114
+
115
+ def forward(self, results: dict) -> dict | None:
116
+ """The transform function of :class:`GetBBoxCenterScale`.
117
+
118
+ Args:
119
+ results (dict): The result dict
120
+
121
+ Returns:
122
+ dict: The result dict.
123
+ """
124
+ if "bbox_center" in results and "bbox_scale" in results:
125
+ results["bbox_scale"] *= self.padding
126
+ else:
127
+ bbox = results["bbox"]
128
+ bbox_format = results.get("bbox_format", "none")
129
+ if bbox_format == "xywh":
130
+ center, scale = bbox_xywh2cs(bbox, padding=self.padding)
131
+ elif bbox_format == "xyxy":
132
+ center, scale = bbox_xyxy2cs(bbox, padding=self.padding)
133
+ else:
134
+ raise ValueError("Invalid bbox format: {}".format(results["bbox_format"]))
135
+
136
+ results["bbox_center"] = center
137
+ results["bbox_scale"] = scale
138
+ return results
139
+
140
+ def __repr__(self) -> str:
141
+ """print the basic information of the transform.
142
+
143
+ Returns:
144
+ str: Formatted string.
145
+ """
146
+ repr_str = self.__class__.__name__ + f"(padding={self.padding})"
147
+ return repr_str
148
+
149
+
150
+ class SquarePad:
151
+ def __call__(self, results: dict) -> dict | None:
152
+ assert isinstance(results["img"], Image.Image)
153
+ w, h = results["img"].size
154
+
155
+ max_wh = np.max([w, h])
156
+ hp = int((max_wh - w) / 2)
157
+ vp = int((max_wh - h) / 2)
158
+ padding = (hp, vp, max_wh - w - hp, max_wh - h - vp)
159
+
160
+ results["img"] = F.pad(results["img"], padding, 0, "constant")
161
+ return results
162
+
163
+ def __repr__(self) -> str:
164
+ """print the basic information of the transform.
165
+
166
+ Returns:
167
+ str: Formatted string.
168
+ """
169
+ repr_str = self.__class__.__name__
170
+ return repr_str
171
+
172
+
173
+ class ToPIL:
174
+ def __call__(self, results: dict) -> dict | None:
175
+ if isinstance(results["img"], list):
176
+ if isinstance(results["img"][0], np.ndarray):
177
+ results["img"] = [Image.fromarray(img) for img in results["img"]]
178
+ elif isinstance(results["img"], np.ndarray):
179
+ results["img"] = Image.fromarray(results["img"])
180
+
181
+
182
+ class ToCv2:
183
+ def __call__(self, results: dict) -> dict | None:
184
+ if isinstance(results["img"], list):
185
+ if isinstance(results["img"][0], Image.Image):
186
+ results["img"] = [np.array(img) for img in results["img"]]
187
+ elif isinstance(results["img"], Image.Image):
188
+ results["img"] = np.array(results["img"])
189
+
190
+
191
+ class TopdownAffine(nn.Module):
192
+ """Get the bbox image as the model input by affine transform.
193
+
194
+ Required Keys:
195
+ - img
196
+ - bbox_center
197
+ - bbox_scale
198
+ - bbox_rotation (optional)
199
+ - keypoints_2d (optional)
200
+ - mask (optional)
201
+
202
+ Modified Keys:
203
+ - img
204
+ - bbox_scale
205
+
206
+ Added Keys:
207
+ - input_size
208
+ - transformed_keypoints
209
+
210
+ Args:
211
+ input_size (Tuple[int, int]): The input image size of the model in
212
+ [w, h]. The bbox region will be cropped and resize to `input_size`
213
+ use_udp (bool): Whether use unbiased data processing. See
214
+ `UDP (CVPR 2020)`_ for details. Defaults to ``False``
215
+ aspect_ratio (float): both HMR2.0 and Sapiens will expand input bbox to
216
+ a fixed ratio (width/height = 192/256), then expand to the ratio of
217
+ the model input size. E.g., HMR2.0 will eventually expand to 1:1, while
218
+ Sapiens will be 768:1024.
219
+
220
+ .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524
221
+ """
222
+
223
+ def __init__(
224
+ self,
225
+ input_size: int | tuple[int, int] | Sequence[int],
226
+ use_udp: bool = False,
227
+ aspect_ratio: float = 0.75,
228
+ fix_square: bool = False,
229
+ ) -> None:
230
+ super().__init__()
231
+
232
+ self.input_size = to_2tuple(input_size)
233
+ self.use_udp = use_udp
234
+ self.aspect_ratio = aspect_ratio
235
+ self.fix_square = fix_square
236
+
237
+ def forward(self, results: dict) -> dict | None:
238
+ """The transform function of :class:`TopdownAffine`.
239
+
240
+ See ``transform()`` method of :class:`BaseTransform` for details.
241
+
242
+ Args:
243
+ results (dict): The result dict
244
+
245
+ Returns:
246
+ dict: The result dict.
247
+ """
248
+ # # Debug only
249
+ # import copy
250
+ # results['ori_img'] = np.zeros((2000, 2000, 3), dtype=np.uint8)
251
+ # results['ori_img'][:results['img'].shape[0], :results['img'].shape[1]] = copy.deepcopy(results['img'])
252
+
253
+ w, h = self.input_size
254
+ warp_size = (int(w), int(h))
255
+
256
+ # expand bbox to fixed aspect ratio
257
+ results["orig_bbox_scale"] = results["bbox_scale"].copy()
258
+ if self.fix_square and results["bbox_scale"][0] == results["bbox_scale"][1]:
259
+ # In HMR2.0 etc., skip the prior aspect-ratio expansion for an already-square bbox
260
+ bbox_scale = fix_aspect_ratio(results["bbox_scale"], aspect_ratio=w / h)
261
+ else:
262
+ # first to a prior aspect ratio, then reshape to model input size
263
+ bbox_scale = fix_aspect_ratio(results["bbox_scale"], aspect_ratio=self.aspect_ratio)
264
+ results["bbox_scale"] = fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
265
+ results["bbox_expand_factor"] = results["bbox_scale"].max() / results["orig_bbox_scale"].max()
266
+ rot = 0.0
267
+ if results["bbox_center"].ndim == 2:
268
+ assert results["bbox_center"].shape[0] == 1, (
269
+ "Only support cropping one instance at a time. Got invalid "
270
+ f"shape of bbox_center {results['bbox_center'].shape}."
271
+ )
272
+ center = results["bbox_center"][0]
273
+ scale = results["bbox_scale"][0]
274
+ if "bbox_rotation" in results:
275
+ rot = results["bbox_rotation"][0]
276
+ else:
277
+ center = results["bbox_center"]
278
+ scale = results["bbox_scale"]
279
+ if "bbox_rotation" in results:
280
+ rot = results["bbox_rotation"]
281
+
282
+ if self.use_udp:
283
+ warp_mat = get_udp_warp_matrix(center, scale, rot, output_size=(w, h))
284
+ else:
285
+ warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
286
+
287
+ if "img" not in results:
288
+ pass
289
+ elif isinstance(results["img"], list):
290
+ results["img"] = [
291
+ cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) for img in results["img"]
292
+ ]
293
+ height, width = results["img"][0].shape[:2]
294
+ results["ori_img_size"] = np.array([width, height])
295
+ else:
296
+ height, width = results["img"].shape[:2]
297
+ results["ori_img_size"] = np.array([width, height])
298
+ results["img"] = cv2.warpAffine(results["img"], warp_mat, warp_size, flags=cv2.INTER_LINEAR)
299
+
300
+ if results.get("keypoints_2d") is not None:
301
+ results["orig_keypoints_2d"] = results["keypoints_2d"].copy()
302
+ transformed_keypoints = results["keypoints_2d"].copy()
303
+ # Only transform (x, y) coordinates
304
+ # cv2 expect the input to be [[[x1, y1], [x2, y2]]]
305
+ transformed_keypoints[:, :2] = cv2.transform(results["keypoints_2d"][None, :, :2], warp_mat)[0]
306
+ results["keypoints_2d"] = transformed_keypoints
307
+
308
+ if results.get("mask") is not None:
309
+ results["mask"] = cv2.warpAffine(results["mask"], warp_mat, warp_size, flags=cv2.INTER_LINEAR)
310
+
311
+ results["img_size"] = np.array([w, h])
312
+ results["input_size"] = np.array([w, h])
313
+ results["affine_trans"] = warp_mat
314
+ return results
315
+
316
+ def __repr__(self) -> str:
317
+ """print the basic information of the transform.
318
+
319
+ Returns:
320
+ str: Formatted string.
321
+ """
322
+ repr_str = self.__class__.__name__
323
+ repr_str += f"(input_size={self.input_size}, "
324
+ repr_str += f"use_udp={self.use_udp})"
325
+ return repr_str
326
+
327
+
328
+ class NormalizeKeypoint(nn.Module):
329
+ """
330
+ Normalize 2D keypoints to range [-0.5, 0.5].
331
+
332
+ Required Keys:
333
+ - keypoints_2d
334
+ - img_size
335
+
336
+ Modified Keys:
337
+ - keypoints_2d
338
+ """
339
+
340
+ def forward(self, results: dict) -> dict | None:
341
+ if "keypoints_2d" in results:
342
+ img_size = results.get("img_size", results["input_size"])
343
+
344
+ results["keypoints_2d"][:, :2] = results["keypoints_2d"][:, :2] / np.array(img_size).reshape(1, 2) - 0.5
345
+ return results
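A sketch of how these transforms are typically chained for a top-down crop; the 192x256 input size, padding, and ToTensor step are illustrative choices, not the model's confirmed configuration:

import numpy as np
import torchvision.transforms as T

pipeline = Compose([
    GetBBoxCenterScale(padding=1.25),                 # bbox -> (center, scale) with padding
    TopdownAffine(input_size=(192, 256)),             # warp the padded box to the model input size
    VisionTransformWrapper(T.ToTensor()),             # torchvision transform applied to results["img"]
])

sample = {
    "img": np.zeros((480, 640, 3), dtype=np.uint8),   # dummy RGB frame
    "bbox": np.array([100.0, 50.0, 300.0, 450.0]),    # one person box
    "bbox_format": "xyxy",
}
out = pipeline(sample)
print(out["img"].shape, out["affine_trans"].shape)    # cropped (3, 256, 192) tensor and the 2x3 warp matrix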
src/sam3d_body/data/utils/io.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import os
4
+ import time
5
+ from typing import Any
6
+
7
+ import braceexpand
8
+ import cv2
9
+ import numpy as np
10
+
11
+ from PIL import Image
12
+
13
+
14
+ def expand(s):
15
+ return os.path.expanduser(os.path.expandvars(s))
16
+
17
+
18
+ def expand_urls(urls: str | list[str]):
19
+ if isinstance(urls, str):
20
+ urls = [urls]
21
+ urls = [u for url in urls for u in braceexpand.braceexpand(expand(url))]
22
+ return urls
23
+
24
+
25
+ def load_image_from_file(
26
+ data_info: dict,
27
+ backend: str = "cv2",
28
+ image_format: str = "rgb",
29
+ retry: int = 10,
30
+ ) -> dict:
31
+ img = load_image(data_info["img_path"], backend, image_format, retry)
32
+ data_info["img"] = img
33
+ data_info["img_shape"] = img.shape[:2]
34
+ data_info["ori_shape"] = img.shape[:2]
35
+ return data_info
36
+
37
+
38
+ def _pil_load(path: str, image_format: str) -> Image.Image:
39
+ with Image.open(path) as img:
40
+ if img is not None and image_format.lower() == "rgb":
41
+ img = img.convert("RGB")
42
+ return img
43
+
44
+
45
+ def _cv2_load(path: str, image_format: str) -> np.ndarray:
46
+ img = cv2.imread(path)
47
+ if img is not None and image_format.lower() == "rgb":
48
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
49
+ return img
50
+
51
+
52
+ def load_image(
53
+ path: str,
54
+ backend: str = "pil",
55
+ image_format: str = "rgb",
56
+ retry: int = 10,
57
+ ) -> Any:
58
+ for i_try in range(retry):
59
+ if backend == "pil":
60
+ img = _pil_load(path, image_format)
61
+ elif backend == "cv2":
62
+ img = _cv2_load(path, image_format)
63
+ else:
64
+ raise ValueError("Invalid backend {} for loading image.".format(backend))
65
+
66
+ if img is not None:
67
+ return img
68
+ else:
69
+ print("Reading {} failed. Will retry.".format(path))
70
+ time.sleep(1.0)
71
+ if i_try == retry - 1:
72
+ raise Exception("Failed to load image {}".format(path))
73
+
74
+
75
+ def resize_image(img, target_size, center=None, scale=None):
76
+ height, width = img.shape[:2]
77
+ aspect_ratio = width / height
78
+
79
+ # Calculate the new size while maintaining the aspect ratio
80
+ if aspect_ratio > 1:
81
+ new_width = target_size
82
+ new_height = int(target_size / aspect_ratio)
83
+ else:
84
+ new_width = int(target_size * aspect_ratio)
85
+ new_height = target_size
86
+
87
+ # Resize the image using OpenCV
88
+ resized_img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)
89
+
90
+ # Create a new blank image with the target size
91
+ final_img = np.ones((target_size, target_size, 3), dtype=np.uint8) * 255
92
+
93
+ # Paste the resized image onto the blank image, centering it
94
+ start_x = (target_size - new_width) // 2
95
+ start_y = (target_size - new_height) // 2
96
+ final_img[start_y : start_y + new_height, start_x : start_x + new_width] = (
97
+ resized_img
98
+ )
99
+
100
+ if center is not None and scale is not None:
101
+ ratio_width = new_width / width
102
+ ratio_height = new_height / height
103
+
104
+ new_scale = np.stack(
105
+ [scale[:, 0] * ratio_width, scale[:, 1] * ratio_height], axis=1
106
+ )
107
+ new_center = np.stack(
108
+ [center[:, 0] * ratio_width, center[:, 1] * ratio_height], axis=1
109
+ )
110
+ new_center[:, 0] += start_x
111
+ new_center[:, 1] += start_y
112
+ else:
113
+ new_center, new_scale = None, None
114
+ return aspect_ratio, final_img, new_center, new_scale
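A quick illustrative call, assuming it is run from the repository root so the bundled example image resolves:

import numpy as np

img = load_image("data/example-data/Planche.jpg", backend="cv2", image_format="rgb")
center = np.array([[320.0, 240.0]])                    # one (x, y) box center
scale = np.array([[200.0, 400.0]])                     # its (w, h)
_, padded, new_center, new_scale = resize_image(img, target_size=1024, center=center, scale=scale)
print(padded.shape)                                    # (1024, 1024, 3): image letterboxed onto a white square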
src/sam3d_body/data/utils/prepare_batch.py ADDED
@@ -0,0 +1,99 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from collections.abc import Callable
4
+ from typing import Any, TypedDict, cast
5
+
6
+ import numpy as np
7
+ import torch
8
+ from jaxtyping import Float, UInt8
9
+ from numpy import ndarray
10
+ from torch import Tensor
11
+ from torch.utils.data import default_collate
12
+
13
+
14
+ class PreparedBatchDict(TypedDict, total=False):
15
+ img: Float[Tensor, "B N 3 H W"]
16
+ img_size: Float[Tensor, "B N 2"]
17
+ ori_img_size: Float[Tensor, "B N 2"]
18
+ bbox_center: Float[Tensor, "B N 2"]
19
+ bbox_scale: Float[Tensor, "B N 2"]
20
+ bbox: Float[Tensor, "B N 4"]
21
+ affine_trans: Float[Tensor, "B N 2 3"]
22
+ mask: Float[Tensor, "B N 1 H W"]
23
+ mask_score: Float[Tensor, "B N"]
24
+ cam_int: Float[Tensor, "B 3 3"]
25
+ person_valid: Float[Tensor, "B N"]
26
+ img_ori: list["NoCollate"]
27
+
28
+
29
+ class NoCollate:
30
+ def __init__(self, data: Any) -> None:
31
+ self.data: Any = data
32
+
33
+
34
+ def prepare_batch(
35
+ img: UInt8[ndarray, "h w 3"],
36
+ transform: Callable[[dict[str, Any]], dict[str, Any]],
37
+ boxes: Float[ndarray, "n 4"],
38
+ masks: Float[ndarray, "n h w"] | None = None,
39
+ masks_score: Float[ndarray, "n"] | None = None,
40
+ cam_int: Float[Tensor, "B 3 3"] | None = None,
41
+ ) -> PreparedBatchDict:
42
+ """A helper function to prepare data batch for SAM 3D Body model inference."""
43
+ height, width = img.shape[:2]
44
+
45
+ # construct batch data samples
46
+ data_list: list[dict[str, Any]] = []
47
+ for idx in range(boxes.shape[0]):
48
+ data_info: dict[str, Any] = dict(img=img)
49
+ data_info["bbox"] = boxes[idx] # shape (4,)
50
+ data_info["bbox_format"] = "xyxy"
51
+
52
+ if masks is not None:
53
+ data_info["mask"] = masks[idx].astype(np.float32, copy=False)
54
+ if masks_score is not None:
55
+ data_info["mask_score"] = masks_score[idx]
56
+ else:
57
+ data_info["mask_score"] = np.array(1.0, dtype=np.float32)
58
+ else:
59
+ data_info["mask"] = np.zeros((height, width, 1), dtype=np.uint8)
60
+ data_info["mask_score"] = np.array(0.0, dtype=np.float32)
61
+
62
+ data_list.append(transform(data_info))
63
+
64
+ batch = default_collate(data_list)
65
+
66
+ max_num_person = batch["img"].shape[0]
67
+ for key in [
68
+ "img",
69
+ "img_size",
70
+ "ori_img_size",
71
+ "bbox_center",
72
+ "bbox_scale",
73
+ "bbox",
74
+ "affine_trans",
75
+ "mask",
76
+ "mask_score",
77
+ ]:
78
+ if key in batch:
79
+ batch[key] = batch[key].unsqueeze(0).float()
80
+ if "mask" in batch:
81
+ batch["mask"] = batch["mask"].unsqueeze(2)
82
+ batch["person_valid"] = torch.ones((1, max_num_person))
83
+
84
+ if cam_int is not None:
85
+ batch["cam_int"] = cam_int.to(batch["img"])
86
+ else:
87
+ # Default camera intrinsics according image size
88
+ batch["cam_int"] = torch.tensor(
89
+ [
90
+ [
91
+ [(height**2 + width**2) ** 0.5, 0, width / 2.0],
92
+ [0, (height**2 + width**2) ** 0.5, height / 2.0],
93
+ [0, 0, 1],
94
+ ]
95
+ ],
96
+ ).to(batch["img"])
97
+
98
+ batch["img_ori"] = [NoCollate(img)]
99
+ return cast(PreparedBatchDict, batch)
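A sketch of feeding detector boxes through prepare_batch; the transform pipeline mirrors the illustrative one above and is not the estimator's confirmed configuration:

import numpy as np
import torchvision.transforms as T
from sam3d_body.data.transforms import Compose, GetBBoxCenterScale, TopdownAffine, VisionTransformWrapper

transform = Compose([
    GetBBoxCenterScale(padding=1.25),
    TopdownAffine(input_size=(192, 256)),
    VisionTransformWrapper(T.ToTensor()),
])
rgb = np.zeros((480, 640, 3), dtype=np.uint8)                       # dummy image
boxes = np.array([[100.0, 50.0, 300.0, 450.0]], dtype=np.float32)   # one detection in xyxy
batch = prepare_batch(img=rgb, transform=transform, boxes=boxes)
print(batch["img"].shape)     # (1, n_person, 3, 256, 192)
print(batch["cam_int"][0])    # default f = sqrt(H^2 + W^2) = 800 here, principal point at image center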
src/sam3d_body/gradio_ui/sam3d_body_ui.py ADDED
@@ -0,0 +1,164 @@
1
+ """
2
+ Gradio UI for the SAM 3D Body demo, streaming results to an embedded Rerun viewer.
3
+
4
+ Runs pose prediction on an uploaded image, logs 2D/3D keypoints, meshes, segmentation overlays
5
+ and optional relative depth to Rerun, and can export per-person GLB meshes.
6
+ """
7
+
8
+ import os
9
+ import shutil
10
+ import tempfile
11
+ from pathlib import Path
12
+ from typing import Final
13
+
14
+ import cv2
15
+ import gradio as gr
16
+ import rerun as rr
17
+ import rerun.blueprint as rrb
18
+ import spaces
19
+ from gradio_rerun import Rerun
20
+ from jaxtyping import Int, UInt8
21
+ from monopriors.relative_depth_models import RelativeDepthPrediction
22
+ from numpy import ndarray
23
+
24
+ from sam3d_body.api.demo import SAM3Config, SAM3DBodyE2E, SAM3DBodyE2EConfig, create_view, set_annotation_context
25
+ from sam3d_body.api.visualization import export_meshes_to_glb, visualize_sample
26
+ from sam3d_body.sam_3d_body_estimator import FinalPosePrediction
27
+
28
+ CFG: SAM3DBodyE2EConfig = SAM3DBodyE2EConfig(sam3_config=SAM3Config())
29
+ MODEL_E2E: SAM3DBodyE2E = SAM3DBodyE2E(config=CFG)
30
+ mesh_faces: Int[ndarray, "n_faces=36874 3"] = MODEL_E2E.sam3d_body_estimator.faces
31
+ STATE: Final[str] = "✅ Ready"
32
+ # Absolute path to bundled example data used by Gradio examples.
33
+ TEST_INPUT_DIR: Final[Path] = Path(__file__).resolve().parents[3] / "data" / "example-data"
34
+
35
+ # Allow Gradio to serve and cache files from the bundled test data directory.
36
+ gr.set_static_paths([str(TEST_INPUT_DIR)])
37
+
38
+
39
+ @spaces.GPU()
40
+ @rr.thread_local_stream("sam3d_body_gradio_ui")
41
+ def sam3d_prediction_fn(
42
+ rgb_hw3,
43
+ log_relative_depth,
44
+ export_glb,
45
+ center_glb,
46
+ pending_cleanup=None,
47
+ ) -> tuple[str, str, list[str]]:
48
+ # resize rgb so that its largest dimension is 1024
49
+ rgb_hw3: UInt8[ndarray, "h w 3"] = cv2.resize(
50
+ rgb_hw3, # type: ignore[arg-type]
51
+ dsize=(0, 0),
52
+ fx=1024 / max(rgb_hw3.shape[0], rgb_hw3.shape[1]),
53
+ fy=1024 / max(rgb_hw3.shape[0], rgb_hw3.shape[1]),
54
+ interpolation=cv2.INTER_AREA,
55
+ )
56
+ # We eventually want to clean up the RRD file after it's sent to the viewer, so tracking
57
+ # any pending files to be cleaned up when the state is deleted.
58
+ temp = tempfile.NamedTemporaryFile(prefix="cube_", suffix=".rrd", delete=False)
59
+
60
+ if pending_cleanup is not None:
61
+ pending_cleanup.append(temp.name)
62
+
63
+ view: rrb.ContainerLike = create_view()
64
+ blueprint = rrb.Blueprint(view, collapse_panels=True)
65
+ rr.save(path=temp.name, default_blueprint=blueprint)
66
+ set_annotation_context()
67
+ parent_log_path = Path("/world")
68
+ rr.log("/", rr.ViewCoordinates.RDF, static=True)
69
+
70
+ outputs: tuple[list[FinalPosePrediction], RelativeDepthPrediction] = MODEL_E2E.predict_single_image(rgb_hw3=rgb_hw3)
71
+ pred_list: list[FinalPosePrediction] = outputs[0]
72
+ relative_pred: RelativeDepthPrediction = outputs[1]
73
+ rr.set_time(timeline="image_sequence", sequence=0)
74
+ visualize_sample(
75
+ pred_list=pred_list,
76
+ rgb_hw3=rgb_hw3,
77
+ parent_log_path=parent_log_path,
78
+ faces=mesh_faces,
79
+ relative_depth_pred=relative_pred if log_relative_depth else None,
80
+ )
81
+
82
+ glb_files: list[str] = []
83
+ if export_glb and len(pred_list) > 0:
84
+ glb_dir: Path = Path(tempfile.mkdtemp(prefix="sam3d_glb_"))
85
+ glb_paths = export_meshes_to_glb(
86
+ pred_list=pred_list,
87
+ faces=mesh_faces,
88
+ output_dir=glb_dir,
89
+ center_mesh=center_glb,
90
+ )
91
+ glb_files = [str(p) for p in glb_paths]
92
+ if pending_cleanup is not None:
93
+ pending_cleanup.extend(glb_files)
94
+ pending_cleanup.append(str(glb_dir))
95
+
96
+ return temp.name, STATE, glb_files
97
+
98
+
99
+ def cleanup_rrds(pending_cleanup: list[str]) -> None:
100
+ for f in pending_cleanup:
101
+ if os.path.isdir(f):
102
+ shutil.rmtree(f, ignore_errors=True)
103
+ elif os.path.isfile(f):
104
+ os.unlink(f)
105
+
106
+
107
+ def _switch_to_outputs() -> gr.Tabs:
108
+ return gr.update(selected="outputs")
109
+
110
+
111
+ def main():
112
+ viewer = Rerun(
113
+ streaming=True,
114
+ panel_states={
115
+ "time": "collapsed",
116
+ "blueprint": "hidden",
117
+ "selection": "hidden",
118
+ },
119
+ height=800,
120
+ )
121
+
122
+ with gr.Blocks() as demo, gr.Tab("SAM3D Body Estimation"):
123
+ pending_cleanup = gr.State([], time_to_live=10, delete_callback=cleanup_rrds)
124
+ with gr.Row():
125
+ with gr.Column(scale=1):
126
+ tabs = gr.Tabs(selected="inputs")
127
+ with tabs:
128
+ with gr.TabItem("Inputs", id="inputs"):
129
+ img = gr.Image(interactive=True, label="Image", type="numpy", image_mode="RGB")
130
+ depth_checkbox = gr.Checkbox(label="Log relative depth", value=False)
131
+ with gr.Row():
132
+ export_checkbox = gr.Checkbox(label="Export GLB meshes", value=False)
133
+ center_checkbox = gr.Checkbox(label="Center GLB at origin", value=True)
134
+ create_rrd = gr.Button("Predict Pose")
135
+ with gr.TabItem("Outputs", id="outputs"):
136
+ status = gr.Text(STATE, label="Status")
137
+ mesh_files = gr.Files(label="GLB meshes", file_count="multiple")
138
+ gr.Examples(
139
+ examples=[
140
+ [str(TEST_INPUT_DIR / "Planche.jpg"), True, False, True],
141
+ [str(TEST_INPUT_DIR / "Amir-Khan-Lamont-Peterson_2689582.jpg"), False, False, True],
142
+ [str(TEST_INPUT_DIR / "BNAAHPYGMYSE26U6C6T7VA6544.jpg"), False, True, True],
143
+ [str(TEST_INPUT_DIR / "yoga-example.jpg"), True, True, False],
144
+ ],
145
+ inputs=[img, depth_checkbox, export_checkbox, center_checkbox],
146
+ outputs=[viewer, status, mesh_files],
147
+ fn=sam3d_prediction_fn,
148
+ run_on_click=True,
149
+ cache_examples=False,
150
+ examples_per_page=2,
151
+ )
152
+ with gr.Column(scale=5):
153
+ viewer.render()
154
+
155
+ create_rrd.click(
156
+ fn=_switch_to_outputs,
157
+ inputs=None,
158
+ outputs=[tabs],
159
+ ).then(
160
+ sam3d_prediction_fn,
161
+ inputs=[img, depth_checkbox, export_checkbox, center_checkbox, pending_cleanup],
162
+ outputs=[viewer, status, mesh_files],
163
+ )
164
+ return demo
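To serve the UI locally (a sketch; the repo's app.py presumably does something equivalent):

if __name__ == "__main__":
    demo = main()
    demo.queue().launch()   # queueing is optional but plays nicely with the GPU-decorated handler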
src/sam3d_body/metadata/__init__.py ADDED
@@ -0,0 +1,79 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ OPENPOSE_TO_COCO = [0, 16, 15, 18, 17, 5, 2, 6, 3, 7, 4, 12, 9, 13, 10, 14, 11]
4
+
5
+ # Mapping the J19 used in HMR2.0 to the 14 common points for evaluation
6
+ # J19 is defined as the first 19 keypoints in https://github.com/nkolot/SPIN/blob/master/constants.py#L42
7
+ # The first 14 keypoints in J19 are LSP keypoints
8
+ J19_TO_J14 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
9
+
10
+ # Mapping from 14 LSP keypoints to 17 COCO keypoints
11
+ # Key: coco_idx, value: lsp_idx
12
+ LSP_TO_COCO = {
13
+ 5: 9,
14
+ 6: 8,
15
+ 7: 10,
16
+ 8: 7,
17
+ 9: 11,
18
+ 10: 6,
19
+ 11: 3,
20
+ 12: 2,
21
+ 13: 4,
22
+ 14: 1,
23
+ 15: 5,
24
+ 16: 0,
25
+ }
26
+
27
+ # fmt: off
28
+ OPENPOSE_PERMUTATION = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21]
29
+ J19_PERMUTATION = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18]
30
+ COCO_PERMUTATION = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
31
+ # fmt: on
32
+
33
+ # Mapping the 70 MHR keypoints to OpenPose (COCO included)
34
+ # key: OpenPose, value: mhr_idx
35
+ MHR70_TO_OPENPOSE = {
36
+ 0: 0,
37
+ 1: 69,
38
+ 2: 6,
39
+ 3: 8,
40
+ 4: 41,
41
+ 5: 5,
42
+ 6: 7,
43
+ 7: 62,
44
+ 9: 10,
45
+ 10: 12,
46
+ 11: 14,
47
+ 12: 9,
48
+ 13: 11,
49
+ 14: 13,
50
+ 15: 2,
51
+ 16: 1,
52
+ 17: 4,
53
+ 18: 3,
54
+ 19: 15,
55
+ 20: 16,
56
+ 21: 17,
57
+ 22: 18,
58
+ 23: 19,
59
+ 24: 20,
60
+ }
61
+
62
+ # fmt: off
63
+ MHR70_PERMUTATION = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 18, 19, 20, 15, 16, 17, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 64, 63, 66, 65, 68, 67, 69]
64
+ # fmt: on
65
+ MHR70_TO_LSP = {
66
+ 0: 14,
67
+ 1: 12,
68
+ 2: 10,
69
+ 3: 9,
70
+ 4: 11,
71
+ 5: 13,
72
+ 6: 41,
73
+ 7: 8,
74
+ 8: 6,
75
+ 9: 5,
76
+ 10: 7,
77
+ 11: 62,
78
+ 12: 69,
79
+ }
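These tables are plain index maps; a small sketch of scattering MHR-70 keypoints into an OpenPose-25 array and applying the left/right permutation (the confidence handling is illustrative):

import numpy as np

mhr70 = np.random.rand(70, 3).astype(np.float32)     # (x, y, confidence) per MHR-70 keypoint
openpose25 = np.zeros((25, 3), dtype=np.float32)     # joints without a source stay at confidence 0
for op_idx, mhr_idx in MHR70_TO_OPENPOSE.items():
    openpose25[op_idx] = mhr70[mhr_idx]

# Index permutation used when an image is mirrored (x coordinates still need flipping separately).
mhr70_swapped = mhr70[MHR70_PERMUTATION]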
src/sam3d_body/metadata/mhr70.py ADDED
@@ -0,0 +1,915 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ """The first 70 of 308 MHR keypoints, ignoring the rest for face keypoints"""
4
+
5
+ from typing import Final
6
+
7
+ mhr_names = [
8
+ "nose",
9
+ "left-eye",
10
+ "right-eye",
11
+ "left-ear",
12
+ "right-ear",
13
+ "left-shoulder",
14
+ "right-shoulder",
15
+ "left-elbow",
16
+ "right-elbow",
17
+ "left-hip",
18
+ "right-hip",
19
+ "left-knee",
20
+ "right-knee",
21
+ "left-ankle",
22
+ "right-ankle",
23
+ "left-big-toe-tip",
24
+ "left-small-toe-tip",
25
+ "left-heel",
26
+ "right-big-toe-tip",
27
+ "right-small-toe-tip",
28
+ "right-heel",
29
+ "right-thumb-tip",
30
+ "right-thumb-first-joint",
31
+ "right-thumb-second-joint",
32
+ "right-thumb-third-joint",
33
+ "right-index-tip",
34
+ "right-index-first-joint",
35
+ "right-index-second-joint",
36
+ "right-index-third-joint",
37
+ "right-middle-tip",
38
+ "right-middle-first-joint",
39
+ "right-middle-second-joint",
40
+ "right-middle-third-joint",
41
+ "right-ring-tip",
42
+ "right-ring-first-joint",
43
+ "right-ring-second-joint",
44
+ "right-ring-third-joint",
45
+ "right-pinky-tip",
46
+ "right-pinky-first-joint",
47
+ "right-pinky-second-joint",
48
+ "right-pinky-third-joint",
49
+ "right-wrist",
50
+ "left-thumb-tip",
51
+ "left-thumb-first-joint",
52
+ "left-thumb-second-joint",
53
+ "left-thumb-third-joint",
54
+ "left-index-tip",
55
+ "left-index-first-joint",
56
+ "left-index-second-joint",
57
+ "left-index-third-joint",
58
+ "left-middle-tip",
59
+ "left-middle-first-joint",
60
+ "left-middle-second-joint",
61
+ "left-middle-third-joint",
62
+ "left-ring-tip",
63
+ "left-ring-first-joint",
64
+ "left-ring-second-joint",
65
+ "left-ring-third-joint",
66
+ "left-pinky-tip",
67
+ "left-pinky-first-joint",
68
+ "left-pinky-second-joint",
69
+ "left-pinky-third-joint",
70
+ "left-wrist",
71
+ "left-olecranon",
72
+ "right-olecranon",
73
+ "left-cubital-fossa",
74
+ "right-cubital-fossa",
75
+ "left-acromion",
76
+ "right-acromion",
77
+ "neck",
78
+ ]
79
+
80
+ pose_info = dict(
81
+ pose_format="mhr70",
82
+ paper_info=dict(
83
+ author="",
84
+ year="",
85
+ homepage="",
86
+ ),
87
+ min_visible_keypoints=8,
88
+ image_height=4096,
89
+ image_width=2668,
90
+ original_keypoint_info={
91
+ 0: "nose",
92
+ 1: "left_eye",
93
+ 2: "right_eye",
94
+ 3: "left_ear",
95
+ 4: "right_ear",
96
+ 5: "left_shoulder",
97
+ 6: "right_shoulder",
98
+ 7: "left_elbow",
99
+ 8: "right_elbow",
100
+ 9: "left_hip",
101
+ 10: "right_hip",
102
+ 11: "left_knee",
103
+ 12: "right_knee",
104
+ 13: "left_ankle",
105
+ 14: "right_ankle",
106
+ 15: "left_big_toe_tip",
107
+ 16: "left_small_toe_tip",
108
+ 17: "left_heel",
109
+ 18: "right_big_toe_tip",
110
+ 19: "right_small_toe_tip",
111
+ 20: "right_heel",
112
+ 21: "right_thumb_tip",
113
+ 22: "right_thumb_first_joint",
114
+ 23: "right_thumb_second_joint",
115
+ 24: "right_thumb_third_joint",
116
+ 25: "right_index_tip",
117
+ 26: "right_index_first_joint",
118
+ 27: "right_index_second_joint",
119
+ 28: "right_index_third_joint",
120
+ 29: "right_middle_tip",
121
+ 30: "right_middle_first_joint",
122
+ 31: "right_middle_second_joint",
123
+ 32: "right_middle_third_joint",
124
+ 33: "right_ring_tip",
125
+ 34: "right_ring_first_joint",
126
+ 35: "right_ring_second_joint",
127
+ 36: "right_ring_third_joint",
128
+ 37: "right_pinky_tip",
129
+ 38: "right_pinky_first_joint",
130
+ 39: "right_pinky_second_joint",
131
+ 40: "right_pinky_third_joint",
132
+ 41: "right_wrist",
133
+ 42: "left_thumb_tip",
134
+ 43: "left_thumb_first_joint",
135
+ 44: "left_thumb_second_joint",
136
+ 45: "left_thumb_third_joint",
137
+ 46: "left_index_tip",
138
+ 47: "left_index_first_joint",
139
+ 48: "left_index_second_joint",
140
+ 49: "left_index_third_joint",
141
+ 50: "left_middle_tip",
142
+ 51: "left_middle_first_joint",
143
+ 52: "left_middle_second_joint",
144
+ 53: "left_middle_third_joint",
145
+ 54: "left_ring_tip",
146
+ 55: "left_ring_first_joint",
147
+ 56: "left_ring_second_joint",
148
+ 57: "left_ring_third_joint",
149
+ 58: "left_pinky_tip",
150
+ 59: "left_pinky_first_joint",
151
+ 60: "left_pinky_second_joint",
152
+ 61: "left_pinky_third_joint",
153
+ 62: "left_wrist",
154
+ 63: "left_olecranon",
155
+ 64: "right_olecranon",
156
+ 65: "left_cubital_fossa",
157
+ 66: "right_cubital_fossa",
158
+ 67: "left_acromion",
159
+ 68: "right_acromion",
160
+ 69: "neck",
161
+ },
162
+ keypoint_info={
163
+ 0: dict(name="nose", id=0, color=[51, 153, 255], type="upper", swap=""),
164
+ 1: dict(
165
+ name="left_eye", id=1, color=[51, 153, 255], type="upper", swap="right_eye"
166
+ ),
167
+ 2: dict(
168
+ name="right_eye", id=2, color=[51, 153, 255], type="upper", swap="left_eye"
169
+ ),
170
+ 3: dict(
171
+ name="left_ear", id=3, color=[51, 153, 255], type="upper", swap="right_ear"
172
+ ),
173
+ 4: dict(
174
+ name="right_ear", id=4, color=[51, 153, 255], type="upper", swap="left_ear"
175
+ ),
176
+ 5: dict(
177
+ name="left_shoulder",
178
+ id=5,
179
+ color=[51, 153, 255],
180
+ type="upper",
181
+ swap="right_shoulder",
182
+ ),
183
+ 6: dict(
184
+ name="right_shoulder",
185
+ id=6,
186
+ color=[51, 153, 255],
187
+ type="upper",
188
+ swap="left_shoulder",
189
+ ),
190
+ 7: dict(
191
+ name="left_elbow",
192
+ id=7,
193
+ color=[51, 153, 255],
194
+ type="upper",
195
+ swap="right_elbow",
196
+ ),
197
+ 8: dict(
198
+ name="right_elbow",
199
+ id=8,
200
+ color=[51, 153, 255],
201
+ type="upper",
202
+ swap="left_elbow",
203
+ ),
204
+ 9: dict(
205
+ name="left_hip", id=9, color=[51, 153, 255], type="lower", swap="right_hip"
206
+ ),
207
+ 10: dict(
208
+ name="right_hip", id=10, color=[51, 153, 255], type="lower", swap="left_hip"
209
+ ),
210
+ 11: dict(
211
+ name="left_knee",
212
+ id=11,
213
+ color=[51, 153, 255],
214
+ type="lower",
215
+ swap="right_knee",
216
+ ),
217
+ 12: dict(
218
+ name="right_knee",
219
+ id=12,
220
+ color=[51, 153, 255],
221
+ type="lower",
222
+ swap="left_knee",
223
+ ),
224
+ 13: dict(
225
+ name="left_ankle",
226
+ id=13,
227
+ color=[51, 153, 255],
228
+ type="lower",
229
+ swap="right_ankle",
230
+ ),
231
+ 14: dict(
232
+ name="right_ankle",
233
+ id=14,
234
+ color=[51, 153, 255],
235
+ type="lower",
236
+ swap="left_ankle",
237
+ ),
238
+ 15: dict(
239
+ name="left_big_toe",
240
+ id=15,
241
+ color=[51, 153, 255],
242
+ type="lower",
243
+ swap="right_big_toe",
244
+ ),
245
+ 16: dict(
246
+ name="left_small_toe",
247
+ id=16,
248
+ color=[51, 153, 255],
249
+ type="lower",
250
+ swap="right_small_toe",
251
+ ),
252
+ 17: dict(
253
+ name="left_heel",
254
+ id=17,
255
+ color=[51, 153, 255],
256
+ type="lower",
257
+ swap="right_heel",
258
+ ),
259
+ 18: dict(
260
+ name="right_big_toe",
261
+ id=18,
262
+ color=[51, 153, 255],
263
+ type="lower",
264
+ swap="left_big_toe",
265
+ ),
266
+ 19: dict(
267
+ name="right_small_toe",
268
+ id=19,
269
+ color=[51, 153, 255],
270
+ type="lower",
271
+ swap="left_small_toe",
272
+ ),
273
+ 20: dict(
274
+ name="right_heel",
275
+ id=20,
276
+ color=[51, 153, 255],
277
+ type="lower",
278
+ swap="left_heel",
279
+ ),
280
+ 21: dict(
281
+ name="right_thumb4",
282
+ id=21,
283
+ color=[51, 153, 255],
284
+ type="upper",
285
+ swap="left_thumb4",
286
+ ),
287
+ 22: dict(
288
+ name="right_thumb3",
289
+ id=22,
290
+ color=[51, 153, 255],
291
+ type="upper",
292
+ swap="left_thumb3",
293
+ ),
294
+ 23: dict(
295
+ name="right_thumb2",
296
+ id=23,
297
+ color=[51, 153, 255],
298
+ type="upper",
299
+ swap="left_thumb2",
300
+ ),
301
+ 24: dict(
302
+ name="right_thumb_third_joint",
303
+ id=24,
304
+ color=[51, 153, 255],
305
+ type="upper",
306
+ swap="left_thumb_third_joint",
307
+ ),
308
+ 25: dict(
309
+ name="right_forefinger4",
310
+ id=25,
311
+ color=[51, 153, 255],
312
+ type="upper",
313
+ swap="left_forefinger4",
314
+ ),
315
+ 26: dict(
316
+ name="right_forefinger3",
317
+ id=26,
318
+ color=[51, 153, 255],
319
+ type="upper",
320
+ swap="left_forefinger3",
321
+ ),
322
+ 27: dict(
323
+ name="right_forefinger2",
324
+ id=27,
325
+ color=[51, 153, 255],
326
+ type="upper",
327
+ swap="left_forefinger2",
328
+ ),
329
+ 28: dict(
330
+ name="right_forefinger_third_joint",
331
+ id=28,
332
+ color=[51, 153, 255],
333
+ type="upper",
334
+ swap="left_forefinger_third_joint",
335
+ ),
336
+ 29: dict(
337
+ name="right_middle_finger4",
338
+ id=29,
339
+ color=[51, 153, 255],
340
+ type="upper",
341
+ swap="left_middle_finger4",
342
+ ),
343
+ 30: dict(
344
+ name="right_middle_finger3",
345
+ id=30,
346
+ color=[51, 153, 255],
347
+ type="upper",
348
+ swap="left_middle_finger3",
349
+ ),
350
+ 31: dict(
351
+ name="right_middle_finger2",
352
+ id=31,
353
+ color=[51, 153, 255],
354
+ type="upper",
355
+ swap="left_middle_finger2",
356
+ ),
357
+ 32: dict(
358
+ name="right_middle_finger_third_joint",
359
+ id=32,
360
+ color=[51, 153, 255],
361
+ type="upper",
362
+ swap="left_middle_finger_third_joint",
363
+ ),
364
+ 33: dict(
365
+ name="right_ring_finger4",
366
+ id=33,
367
+ color=[51, 153, 255],
368
+ type="upper",
369
+ swap="left_ring_finger4",
370
+ ),
371
+ 34: dict(
372
+ name="right_ring_finger3",
373
+ id=34,
374
+ color=[51, 153, 255],
375
+ type="upper",
376
+ swap="left_ring_finger3",
377
+ ),
378
+ 35: dict(
379
+ name="right_ring_finger2",
380
+ id=35,
381
+ color=[51, 153, 255],
382
+ type="upper",
383
+ swap="left_ring_finger2",
384
+ ),
385
+ 36: dict(
386
+ name="right_ring_finger_third_joint",
387
+ id=36,
388
+ color=[51, 153, 255],
389
+ type="upper",
390
+ swap="left_ring_finger_third_joint",
391
+ ),
392
+ 37: dict(
393
+ name="right_pinky_finger4",
394
+ id=37,
395
+ color=[51, 153, 255],
396
+ type="upper",
397
+ swap="left_pinky_finger4",
398
+ ),
399
+ 38: dict(
400
+ name="right_pinky_finger3",
401
+ id=38,
402
+ color=[51, 153, 255],
403
+ type="upper",
404
+ swap="left_pinky_finger3",
405
+ ),
406
+ 39: dict(
407
+ name="right_pinky_finger2",
408
+ id=39,
409
+ color=[51, 153, 255],
410
+ type="upper",
411
+ swap="left_pinky_finger2",
412
+ ),
413
+ 40: dict(
414
+ name="right_pinky_finger_third_joint",
415
+ id=40,
416
+ color=[51, 153, 255],
417
+ type="upper",
418
+ swap="left_pinky_finger_third_joint",
419
+ ),
420
+ 41: dict(
421
+ name="right_wrist",
422
+ id=41,
423
+ color=[51, 153, 255],
424
+ type="upper",
425
+ swap="left_wrist",
426
+ ),
427
+ 42: dict(
428
+ name="left_thumb4",
429
+ id=42,
430
+ color=[51, 153, 255],
431
+ type="upper",
432
+ swap="right_thumb4",
433
+ ),
434
+ 43: dict(
435
+ name="left_thumb3",
436
+ id=43,
437
+ color=[51, 153, 255],
438
+ type="upper",
439
+ swap="right_thumb3",
440
+ ),
441
+ 44: dict(
442
+ name="left_thumb2",
443
+ id=44,
444
+ color=[51, 153, 255],
445
+ type="upper",
446
+ swap="right_thumb2",
447
+ ),
448
+ 45: dict(
449
+ name="left_thumb_third_joint",
450
+ id=45,
451
+ color=[51, 153, 255],
452
+ type="upper",
453
+ swap="right_thumb_third_joint",
454
+ ), ## doesn't match with wholebody
455
+ 46: dict(
456
+ name="left_forefinger4",
457
+ id=46,
458
+ color=[51, 153, 255],
459
+ type="upper",
460
+ swap="right_forefinger4",
461
+ ),
462
+ 47: dict(
463
+ name="left_forefinger3",
464
+ id=47,
465
+ color=[51, 153, 255],
466
+ type="upper",
467
+ swap="right_forefinger3",
468
+ ),
469
+ 48: dict(
470
+ name="left_forefinger2",
471
+ id=48,
472
+ color=[51, 153, 255],
473
+ type="upper",
474
+ swap="right_forefinger2",
475
+ ),
476
+ 49: dict(
477
+ name="left_forefinger_third_joint",
478
+ id=49,
479
+ color=[51, 153, 255],
480
+ type="upper",
481
+ swap="right_forefinger_third_joint",
482
+ ),
483
+ 50: dict(
484
+ name="left_middle_finger4",
485
+ id=50,
486
+ color=[51, 153, 255],
487
+ type="upper",
488
+ swap="right_middle_finger4",
489
+ ),
490
+ 51: dict(
491
+ name="left_middle_finger3",
492
+ id=51,
493
+ color=[51, 153, 255],
494
+ type="upper",
495
+ swap="right_middle_finger3",
496
+ ),
497
+ 52: dict(
498
+ name="left_middle_finger2",
499
+ id=52,
500
+ color=[51, 153, 255],
501
+ type="upper",
502
+ swap="right_middle_finger2",
503
+ ),
504
+ 53: dict(
505
+ name="left_middle_finger_third_joint",
506
+ id=53,
507
+ color=[51, 153, 255],
508
+ type="upper",
509
+ swap="right_middle_finger_third_joint",
510
+ ),
511
+ 54: dict(
512
+ name="left_ring_finger4",
513
+ id=54,
514
+ color=[51, 153, 255],
515
+ type="upper",
516
+ swap="right_ring_finger4",
517
+ ),
518
+ 55: dict(
519
+ name="left_ring_finger3",
520
+ id=55,
521
+ color=[51, 153, 255],
522
+ type="upper",
523
+ swap="right_ring_finger3",
524
+ ),
525
+ 56: dict(
526
+ name="left_ring_finger2",
527
+ id=56,
528
+ color=[51, 153, 255],
529
+ type="upper",
530
+ swap="right_ring_finger2",
531
+ ),
532
+ 57: dict(
533
+ name="left_ring_finger_third_joint",
534
+ id=57,
535
+ color=[51, 153, 255],
536
+ type="upper",
537
+ swap="right_ring_finger_third_joint",
538
+ ),
539
+ 58: dict(
540
+ name="left_pinky_finger4",
541
+ id=58,
542
+ color=[51, 153, 255],
543
+ type="upper",
544
+ swap="right_pinky_finger4",
545
+ ),
546
+ 59: dict(
547
+ name="left_pinky_finger3",
548
+ id=59,
549
+ color=[51, 153, 255],
550
+ type="upper",
551
+ swap="right_pinky_finger3",
552
+ ),
553
+ 60: dict(
554
+ name="left_pinky_finger2",
555
+ id=60,
556
+ color=[51, 153, 255],
557
+ type="upper",
558
+ swap="right_pinky_finger2",
559
+ ),
560
+ 61: dict(
561
+ name="left_pinky_finger_third_joint",
562
+ id=61,
563
+ color=[51, 153, 255],
564
+ type="upper",
565
+ swap="right_pinky_finger_third_joint",
566
+ ),
567
+ 62: dict(
568
+ name="left_wrist",
569
+ id=62,
570
+ color=[51, 153, 255],
571
+ type="upper",
572
+ swap="right_wrist",
573
+ ),
574
+ 63: dict(
575
+ name="left_olecranon",
576
+ id=63,
577
+ color=[51, 153, 255],
578
+ type="",
579
+ swap="right_olecranon",
580
+ ),
581
+ 64: dict(
582
+ name="right_olecranon",
583
+ id=64,
584
+ color=[51, 153, 255],
585
+ type="",
586
+ swap="left_olecranon",
587
+ ),
588
+ 65: dict(
589
+ name="left_cubital_fossa",
590
+ id=65,
591
+ color=[51, 153, 255],
592
+ type="",
593
+ swap="right_cubital_fossa",
594
+ ),
595
+ 66: dict(
596
+ name="right_cubital_fossa",
597
+ id=66,
598
+ color=[51, 153, 255],
599
+ type="",
600
+ swap="left_cubital_fossa",
601
+ ),
602
+ 67: dict(
603
+ name="left_acromion",
604
+ id=67,
605
+ color=[51, 153, 255],
606
+ type="",
607
+ swap="right_acromion",
608
+ ),
609
+ 68: dict(
610
+ name="right_acromion",
611
+ id=68,
612
+ color=[51, 153, 255],
613
+ type="",
614
+ swap="left_acromion",
615
+ ),
616
+ 69: dict(name="neck", id=69, color=[51, 153, 255], type="", swap=""),
617
+ },
618
+ skeleton_info={
619
+ 0: dict(link=("left_ankle", "left_knee"), id=0, color=[0, 255, 0]),
620
+ 1: dict(link=("left_knee", "left_hip"), id=1, color=[0, 255, 0]),
621
+ 2: dict(link=("right_ankle", "right_knee"), id=2, color=[255, 128, 0]),
622
+ 3: dict(link=("right_knee", "right_hip"), id=3, color=[255, 128, 0]),
623
+ 4: dict(link=("left_hip", "right_hip"), id=4, color=[51, 153, 255]),
624
+ 5: dict(link=("left_shoulder", "left_hip"), id=5, color=[51, 153, 255]),
625
+ 6: dict(link=("right_shoulder", "right_hip"), id=6, color=[51, 153, 255]),
626
+ 7: dict(link=("left_shoulder", "right_shoulder"), id=7, color=[51, 153, 255]),
627
+ 8: dict(link=("left_shoulder", "left_elbow"), id=8, color=[0, 255, 0]),
628
+ 9: dict(link=("right_shoulder", "right_elbow"), id=9, color=[255, 128, 0]),
629
+ 10: dict(link=("left_elbow", "left_wrist"), id=10, color=[0, 255, 0]),
630
+ 11: dict(link=("right_elbow", "right_wrist"), id=11, color=[255, 128, 0]),
631
+ 12: dict(link=("left_eye", "right_eye"), id=12, color=[51, 153, 255]),
632
+ 13: dict(link=("nose", "left_eye"), id=13, color=[51, 153, 255]),
633
+ 14: dict(link=("nose", "right_eye"), id=14, color=[51, 153, 255]),
634
+ 15: dict(link=("left_eye", "left_ear"), id=15, color=[51, 153, 255]),
635
+ 16: dict(link=("right_eye", "right_ear"), id=16, color=[51, 153, 255]),
636
+ 17: dict(link=("left_ear", "left_shoulder"), id=17, color=[51, 153, 255]),
637
+ 18: dict(link=("right_ear", "right_shoulder"), id=18, color=[51, 153, 255]),
638
+ 19: dict(link=("left_ankle", "left_big_toe"), id=19, color=[0, 255, 0]),
639
+ 20: dict(link=("left_ankle", "left_small_toe"), id=20, color=[0, 255, 0]),
640
+ 21: dict(link=("left_ankle", "left_heel"), id=21, color=[0, 255, 0]),
641
+ 22: dict(link=("right_ankle", "right_big_toe"), id=22, color=[255, 128, 0]),
642
+ 23: dict(link=("right_ankle", "right_small_toe"), id=23, color=[255, 128, 0]),
643
+ 24: dict(link=("right_ankle", "right_heel"), id=24, color=[255, 128, 0]),
644
+ 25: dict(
645
+ link=("left_wrist", "left_thumb_third_joint"), id=25, color=[255, 128, 0]
646
+ ),
647
+ 26: dict(
648
+ link=("left_thumb_third_joint", "left_thumb2"), id=26, color=[255, 128, 0]
649
+ ),
650
+ 27: dict(link=("left_thumb2", "left_thumb3"), id=27, color=[255, 128, 0]),
651
+ 28: dict(link=("left_thumb3", "left_thumb4"), id=28, color=[255, 128, 0]),
652
+ 29: dict(
653
+ link=("left_wrist", "left_forefinger_third_joint"),
654
+ id=29,
655
+ color=[255, 153, 255],
656
+ ),
657
+ 30: dict(
658
+ link=("left_forefinger_third_joint", "left_forefinger2"),
659
+ id=30,
660
+ color=[255, 153, 255],
661
+ ),
662
+ 31: dict(
663
+ link=("left_forefinger2", "left_forefinger3"), id=31, color=[255, 153, 255]
664
+ ),
665
+ 32: dict(
666
+ link=("left_forefinger3", "left_forefinger4"), id=32, color=[255, 153, 255]
667
+ ),
668
+ 33: dict(
669
+ link=("left_wrist", "left_middle_finger_third_joint"),
670
+ id=33,
671
+ color=[102, 178, 255],
672
+ ),
673
+ 34: dict(
674
+ link=("left_middle_finger_third_joint", "left_middle_finger2"),
675
+ id=34,
676
+ color=[102, 178, 255],
677
+ ),
678
+ 35: dict(
679
+ link=("left_middle_finger2", "left_middle_finger3"),
680
+ id=35,
681
+ color=[102, 178, 255],
682
+ ),
683
+ 36: dict(
684
+ link=("left_middle_finger3", "left_middle_finger4"),
685
+ id=36,
686
+ color=[102, 178, 255],
687
+ ),
688
+ 37: dict(
689
+ link=("left_wrist", "left_ring_finger_third_joint"),
690
+ id=37,
691
+ color=[255, 51, 51],
692
+ ),
693
+ 38: dict(
694
+ link=("left_ring_finger_third_joint", "left_ring_finger2"),
695
+ id=38,
696
+ color=[255, 51, 51],
697
+ ),
698
+ 39: dict(
699
+ link=("left_ring_finger2", "left_ring_finger3"), id=39, color=[255, 51, 51]
700
+ ),
701
+ 40: dict(
702
+ link=("left_ring_finger3", "left_ring_finger4"), id=40, color=[255, 51, 51]
703
+ ),
704
+ 41: dict(
705
+ link=("left_wrist", "left_pinky_finger_third_joint"),
706
+ id=41,
707
+ color=[0, 255, 0],
708
+ ),
709
+ 42: dict(
710
+ link=("left_pinky_finger_third_joint", "left_pinky_finger2"),
711
+ id=42,
712
+ color=[0, 255, 0],
713
+ ),
714
+ 43: dict(
715
+ link=("left_pinky_finger2", "left_pinky_finger3"), id=43, color=[0, 255, 0]
716
+ ),
717
+ 44: dict(
718
+ link=("left_pinky_finger3", "left_pinky_finger4"), id=44, color=[0, 255, 0]
719
+ ),
720
+ 45: dict(
721
+ link=("right_wrist", "right_thumb_third_joint"), id=45, color=[255, 128, 0]
722
+ ),
723
+ 46: dict(
724
+ link=("right_thumb_third_joint", "right_thumb2"), id=46, color=[255, 128, 0]
725
+ ),
726
+ 47: dict(link=("right_thumb2", "right_thumb3"), id=47, color=[255, 128, 0]),
727
+ 48: dict(link=("right_thumb3", "right_thumb4"), id=48, color=[255, 128, 0]),
728
+ 49: dict(
729
+ link=("right_wrist", "right_forefinger_third_joint"),
730
+ id=49,
731
+ color=[255, 153, 255],
732
+ ),
733
+ 50: dict(
734
+ link=("right_forefinger_third_joint", "right_forefinger2"),
735
+ id=50,
736
+ color=[255, 153, 255],
737
+ ),
738
+ 51: dict(
739
+ link=("right_forefinger2", "right_forefinger3"),
740
+ id=51,
741
+ color=[255, 153, 255],
742
+ ),
743
+ 52: dict(
744
+ link=("right_forefinger3", "right_forefinger4"),
745
+ id=52,
746
+ color=[255, 153, 255],
747
+ ),
748
+ 53: dict(
749
+ link=("right_wrist", "right_middle_finger_third_joint"),
750
+ id=53,
751
+ color=[102, 178, 255],
752
+ ),
753
+ 54: dict(
754
+ link=("right_middle_finger_third_joint", "right_middle_finger2"),
755
+ id=54,
756
+ color=[102, 178, 255],
757
+ ),
758
+ 55: dict(
759
+ link=("right_middle_finger2", "right_middle_finger3"),
760
+ id=55,
761
+ color=[102, 178, 255],
762
+ ),
763
+ 56: dict(
764
+ link=("right_middle_finger3", "right_middle_finger4"),
765
+ id=56,
766
+ color=[102, 178, 255],
767
+ ),
768
+ 57: dict(
769
+ link=("right_wrist", "right_ring_finger_third_joint"),
770
+ id=57,
771
+ color=[255, 51, 51],
772
+ ),
773
+ 58: dict(
774
+ link=("right_ring_finger_third_joint", "right_ring_finger2"),
775
+ id=58,
776
+ color=[255, 51, 51],
777
+ ),
778
+ 59: dict(
779
+ link=("right_ring_finger2", "right_ring_finger3"),
780
+ id=59,
781
+ color=[255, 51, 51],
782
+ ),
783
+ 60: dict(
784
+ link=("right_ring_finger3", "right_ring_finger4"),
785
+ id=60,
786
+ color=[255, 51, 51],
787
+ ),
788
+ 61: dict(
789
+ link=("right_wrist", "right_pinky_finger_third_joint"),
790
+ id=61,
791
+ color=[0, 255, 0],
792
+ ),
793
+ 62: dict(
794
+ link=("right_pinky_finger_third_joint", "right_pinky_finger2"),
795
+ id=62,
796
+ color=[0, 255, 0],
797
+ ),
798
+ 63: dict(
799
+ link=("right_pinky_finger2", "right_pinky_finger3"),
800
+ id=63,
801
+ color=[0, 255, 0],
802
+ ),
803
+ 64: dict(
804
+ link=("right_pinky_finger3", "right_pinky_finger4"),
805
+ id=64,
806
+ color=[0, 255, 0],
807
+ ),
808
+ },
809
+ joint_weights=[1.0] * 70,
810
+ body_keypoint_names=[
811
+ "nose",
812
+ "left_eye",
813
+ "right_eye",
814
+ "left_ear",
815
+ "right_ear",
816
+ "left_shoulder",
817
+ "right_shoulder",
818
+ "left_elbow",
819
+ "right_elbow",
820
+ "left_wrist",
821
+ "right_wrist",
822
+ "left_hip",
823
+ "right_hip",
824
+ "left_knee",
825
+ "right_knee",
826
+ "left_ankle",
827
+ "right_ankle",
828
+ ],
829
+ foot_keypoint_names=[
830
+ "left_big_toe",
831
+ "left_small_toe",
832
+ "left_heel",
833
+ "right_big_toe",
834
+ "right_small_toe",
835
+ "right_heel",
836
+ ],
837
+ left_hand_keypoint_names=[
838
+ "left_thumb4",
839
+ "left_thumb3",
840
+ "left_thumb2",
841
+ "left_thumb_third_joint",
842
+ "left_forefinger4",
843
+ "left_forefinger3",
844
+ "left_forefinger2",
845
+ "left_forefinger_third_joint",
846
+ "left_middle_finger4",
847
+ "left_middle_finger3",
848
+ "left_middle_finger2",
849
+ "left_middle_finger_third_joint",
850
+ "left_ring_finger4",
851
+ "left_ring_finger3",
852
+ "left_ring_finger2",
853
+ "left_ring_finger_third_joint",
854
+ "left_pinky_finger4",
855
+ "left_pinky_finger3",
856
+ "left_pinky_finger2",
857
+ "left_pinky_finger_third_joint",
858
+ ],
859
+ right_hand_keypoint_names=[
860
+ "right_thumb4",
861
+ "right_thumb3",
862
+ "right_thumb2",
863
+ "right_thumb_third_joint",
864
+ "right_forefinger4",
865
+ "right_forefinger3",
866
+ "right_forefinger2",
867
+ "right_forefinger_third_joint",
868
+ "right_middle_finger4",
869
+ "right_middle_finger3",
870
+ "right_middle_finger2",
871
+ "right_middle_finger_third_joint",
872
+ "right_ring_finger4",
873
+ "right_ring_finger3",
874
+ "right_ring_finger2",
875
+ "right_ring_finger_third_joint",
876
+ "right_pinky_finger4",
877
+ "right_pinky_finger3",
878
+ "right_pinky_finger2",
879
+ "right_pinky_finger_third_joint",
880
+ ],
881
+ ## 7 extra keypoints beyond the body/foot/hand groups
882
+ extra_keypoint_names=[
883
+ "neck",
884
+ "left_olecranon",
885
+ "right_olecranon",
886
+ "left_cubital_fossa",
887
+ "right_cubital_fossa",
888
+ "left_acromion",
889
+ "right_acromion",
890
+ ],
891
+ sigmas=[],
892
+ )
893
+
894
+ # Rerun-friendly helpers ----------------------------------------------------
895
+ # These mirror the COCO-133 helpers exposed by ``simplecv.data.skeleton.coco_133``
896
+ # so downstream code can build annotation contexts without re-deriving names/links.
897
+
898
+ MHR70_ID2NAME: Final[dict[int, str]] = {
899
+ idx: info["name"] for idx, info in pose_info["keypoint_info"].items()
900
+ }
901
+
902
+ MHR70_IDS: Final[list[int]] = sorted(MHR70_ID2NAME.keys())
903
+
904
+ _NAME_TO_ID = {name: idx for idx, name in MHR70_ID2NAME.items()}
905
+ MHR70_LINKS: Final[list[tuple[int, int]]] = [
906
+ (_NAME_TO_ID[link_info["link"][0]], _NAME_TO_ID[link_info["link"][1]])
907
+ for link_info in pose_info["skeleton_info"].values()
908
+ ]
909
+
910
+ __all__ = [
911
+ "pose_info",
912
+ "MHR70_ID2NAME",
913
+ "MHR70_IDS",
914
+ "MHR70_LINKS",
915
+ ]
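
A minimal usage sketch for the helpers above, assuming a recent rerun-sdk and a root-level annotation context (the app id and entity path are illustrative, not part of this module):

    import rerun as rr
    from sam3d_body.metadata.mhr70 import MHR70_ID2NAME, MHR70_IDS, MHR70_LINKS

    # Describe a single "person" class whose keypoints reuse the MHR-70 names/edges.
    person = rr.ClassDescription(
        info=rr.AnnotationInfo(id=0, label="person"),
        keypoint_annotations=[
            rr.AnnotationInfo(id=idx, label=name) for idx, name in MHR70_ID2NAME.items()
        ],
        keypoint_connections=MHR70_LINKS,
    )

    rr.init("mhr70_skeleton_demo", spawn=True)
    rr.log("/", rr.AnnotationContext([person]), static=True)
    # Points2D/Points3D logged with class_ids=0 and keypoint_ids=MHR70_IDS will then
    # be rendered with these labels and skeleton connections.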
src/sam3d_body/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
src/sam3d_body/models/backbones/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+
4
+ def create_backbone(name, cfg=None):
5
+ if name in ["vit_hmr"]:
6
+ from .vit import vit
7
+
8
+ backbone = vit(cfg)
9
+ elif name in ["vit_hmr_512_384"]:
10
+ from .vit import vit512_384
11
+
12
+ backbone = vit512_384(cfg)
13
+ elif name in ["vit_l"]:
14
+ from .vit import vit_l
15
+
16
+ backbone = vit_l(cfg)
17
+ elif name in ["vit_b"]:
18
+ from .vit import vit_b
19
+
20
+ backbone = vit_b(cfg)
21
+ elif name in [
22
+ "dinov3_vit7b",
23
+ "dinov3_vith16plus",
24
+ "dinov3_vits16",
25
+ "dinov3_vits16plus",
26
+ "dinov3_vitb16",
27
+ "dinov3_vitl16",
28
+ ]:
29
+ from .dinov3 import Dinov3Backbone
30
+
31
+ backbone = Dinov3Backbone(name, cfg=cfg)
32
+ else:
33
+ raise NotImplementedError("Backbone type is not implemented")
34
+
35
+ return backbone
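
A hedged sketch of how this factory is typically driven by a yacs config; the field values below are illustrative assumptions:

    import torch
    from yacs.config import CfgNode as CN

    from sam3d_body.models.backbones import create_backbone

    cfg = CN()
    cfg.MODEL = CN()
    cfg.MODEL.BACKBONE = CN()
    cfg.MODEL.BACKBONE.FROZEN_STAGES = -1   # keep every block trainable
    cfg.MODEL.BACKBONE.FLASH_ATTN = False   # fall back to the vanilla attention path

    backbone = create_backbone("vit_b", cfg)        # ViT-B/16 over 256x192 crops
    feats = backbone(torch.zeros(1, 3, 256, 192))   # -> [1, 768, 16, 12] feature map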
src/sam3d_body/models/backbones/dinov3.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ class Dinov3Backbone(nn.Module):
8
+ def __init__(
9
+ self, name="dinov2_vitb14", pretrained_weight=None, cfg=None, *args, **kwargs
10
+ ):
11
+ super().__init__()
12
+ self.name = name
13
+ self.cfg = cfg
14
+
15
+ self.encoder = torch.hub.load(
16
+ "facebookresearch/dinov3",
17
+ self.name,
18
+ source="github",
19
+ pretrained=False,
20
+ drop_path=self.cfg.MODEL.BACKBONE.DROP_PATH_RATE,
21
+ )
22
+ self.patch_size = self.encoder.patch_size
23
+ self.embed_dim = self.embed_dims = self.encoder.embed_dim
24
+
25
+ def forward(self, x, extra_embed=None):
26
+ """
27
+ Encode an RGB image using a ViT backbone.
28
+ Args:
29
+ - x: torch.Tensor of shape [bs,3,h,w]
30
+ Return:
31
+ - y: torch.Tensor of shape [bs,d,h/patch_size,w/patch_size] - reshaped feature map
32
+ """
33
+ assert extra_embed is None, "Not Implemented Yet"
34
+
35
+ y = self.encoder.get_intermediate_layers(x, n=1, reshape=True, norm=True)[-1]
36
+
37
+ return y
38
+
39
+ def get_layer_depth(self, param_name: str, prefix: str = "encoder."):
40
+ """Get the layer-wise depth of a parameter.
41
+ Args:
42
+ param_name (str): The name of the parameter.
43
+ prefix (str): The prefix for the parameter.
44
+ Defaults to ``"encoder."``.
45
+ Returns:
46
+ Tuple[int, int]: The layer-wise depth and the num of layers.
47
+ Note:
48
+ The first depth is the stem module (``layer_depth=0``), and the
49
+ last depth is the subsequent module (``layer_depth=num_layers-1``)
50
+ """
51
+ num_layers = self.encoder.n_blocks + 2
52
+
53
+ if not param_name.startswith(prefix):
54
+ # For subsequent module like head
55
+ return num_layers - 1, num_layers
56
+
57
+ param_name = param_name[len(prefix) :]
58
+
59
+ if param_name in ("cls_token", "pos_embed", "storage_tokens"):
60
+ layer_depth = 0
61
+ elif param_name.startswith("patch_embed"):
62
+ layer_depth = 0
63
+ elif param_name.startswith("blocks"):
64
+ layer_id = int(param_name.split(".")[1])
65
+ layer_depth = layer_id + 1
66
+ else:
67
+ layer_depth = num_layers - 1
68
+
69
+ return layer_depth, num_layers
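
get_layer_depth is the hook for layer-wise learning-rate decay; a sketch of that wiring (the helper name, decay factor, base LR, and optimizer choice are assumptions):

    import torch

    def layerwise_param_groups(backbone, base_lr=1e-4, decay=0.75):
        """Scale each parameter's LR by decay ** (num_layers - 1 - depth)."""
        groups = []
        for name, param in backbone.named_parameters():
            if not param.requires_grad:
                continue
            depth, num_layers = backbone.get_layer_depth(name)
            scale = decay ** (num_layers - 1 - depth)
            groups.append({"params": [param], "lr": base_lr * scale})
        return groups

    # optimizer = torch.optim.AdamW(layerwise_param_groups(dinov3_backbone))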
src/sam3d_body/models/backbones/vit.py ADDED
@@ -0,0 +1,658 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from functools import partial
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.utils.checkpoint as checkpoint
9
+
10
+ try:
11
+ from flash_attn.flash_attn_interface import flash_attn_func
12
+ except ImportError:
13
+ print("No Flash Attention!")
14
+
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+
17
+ from ..modules.transformer import LayerNorm32
18
+
19
+
20
+ def vit(cfg):
21
+ return ViT(
22
+ img_size=(256, 192),
23
+ patch_size=16,
24
+ embed_dim=1280,
25
+ depth=32,
26
+ num_heads=16,
27
+ ratio=1,
28
+ norm_layer=LayerNorm32,
29
+ use_checkpoint=False,
30
+ mlp_ratio=4,
31
+ qkv_bias=True,
32
+ drop_path_rate=0.55,
33
+ frozen_stages=cfg.MODEL.BACKBONE.get("FROZEN_STAGES", -1),
34
+ flash_attn=cfg.MODEL.BACKBONE.get("FLASH_ATTN", False),
35
+ )
36
+
37
+
38
+ def vit_l(cfg):
39
+ return ViT(
40
+ img_size=(256, 192),
41
+ patch_size=16,
42
+ embed_dim=1024,
43
+ depth=24,
44
+ num_heads=16,
45
+ ratio=1,
46
+ norm_layer=LayerNorm32,
47
+ use_checkpoint=False,
48
+ mlp_ratio=4,
49
+ qkv_bias=True,
50
+ drop_path_rate=0.55,
51
+ frozen_stages=cfg.MODEL.BACKBONE.get("FROZEN_STAGES", -1),
52
+ flash_attn=cfg.MODEL.BACKBONE.get("FLASH_ATTN", False),
53
+ )
54
+
55
+
56
+ def vit_b(cfg):
57
+ return ViT(
58
+ img_size=(256, 192),
59
+ patch_size=16,
60
+ embed_dim=768,
61
+ depth=12,
62
+ num_heads=12,
63
+ ratio=1,
64
+ norm_layer=LayerNorm32,
65
+ use_checkpoint=False,
66
+ mlp_ratio=4,
67
+ qkv_bias=True,
68
+ drop_path_rate=0.3,
69
+ frozen_stages=cfg.MODEL.BACKBONE.get("FROZEN_STAGES", -1),
70
+ flash_attn=cfg.MODEL.BACKBONE.get("FLASH_ATTN", False),
71
+ )
72
+
73
+
74
+ def vit256(cfg):
75
+ return ViT(
76
+ img_size=(256, 256),
77
+ patch_size=16,
78
+ embed_dim=1280,
79
+ depth=32,
80
+ num_heads=16,
81
+ ratio=1,
82
+ norm_layer=LayerNorm32,
83
+ use_checkpoint=False,
84
+ mlp_ratio=4,
85
+ qkv_bias=True,
86
+ drop_path_rate=0.55,
87
+ frozen_stages=cfg.MODEL.BACKBONE.get("FROZEN_STAGES", -1),
88
+ flash_attn=cfg.MODEL.BACKBONE.get("FLASH_ATTN", False),
89
+ )
90
+
91
+
92
+ def vit512_384(cfg):
93
+ return ViT(
94
+ img_size=(512, 384),
95
+ patch_size=16,
96
+ embed_dim=1280,
97
+ depth=32,
98
+ num_heads=16,
99
+ ratio=1,
100
+ norm_layer=LayerNorm32,
101
+ use_checkpoint=False,
102
+ mlp_ratio=4,
103
+ qkv_bias=True,
104
+ drop_path_rate=0.55,
105
+ frozen_stages=cfg.MODEL.BACKBONE.get("FROZEN_STAGES", -1),
106
+ flash_attn=cfg.MODEL.BACKBONE.get("FLASH_ATTN", False),
107
+ )
108
+
109
+
110
+ def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
111
+ """
112
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
113
+ dimension for the original embeddings.
114
+ Args:
115
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
116
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
117
+ h, w (int): target token grid size; ori_h, ori_w (int): original token grid size.
118
+
119
+ Returns:
120
+ Absolute positional embeddings after processing, with shape (1, num_position, C)
121
+ """
122
+ cls_token = None
123
+ B, L, C = abs_pos.shape
124
+ if has_cls_token:
125
+ cls_token = abs_pos[:, 0:1]
126
+ abs_pos = abs_pos[:, 1:]
127
+
128
+ if ori_h != h or ori_w != w:
129
+ new_abs_pos = (
130
+ F.interpolate(
131
+ abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
132
+ size=(h, w),
133
+ mode="bicubic",
134
+ align_corners=False,
135
+ )
136
+ .permute(0, 2, 3, 1)
137
+ .reshape(B, -1, C)
138
+ )
139
+
140
+ else:
141
+ new_abs_pos = abs_pos
142
+
143
+ if cls_token is not None:
144
+ new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
145
+ return new_abs_pos
146
+
147
+
148
+ class DropPath(nn.Module):
149
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
150
+
151
+ def __init__(self, drop_prob=None):
152
+ super(DropPath, self).__init__()
153
+ self.drop_prob = drop_prob
154
+
155
+ def forward(self, x):
156
+ return drop_path(x, self.drop_prob, self.training)
157
+
158
+ def extra_repr(self):
159
+ return "p={}".format(self.drop_prob)
160
+
161
+
162
+ class Mlp(nn.Module):
163
+ def __init__(
164
+ self,
165
+ in_features,
166
+ hidden_features=None,
167
+ out_features=None,
168
+ act_layer=nn.GELU,
169
+ drop=0.0,
170
+ ):
171
+ super().__init__()
172
+ out_features = out_features or in_features
173
+ hidden_features = hidden_features or in_features
174
+ self.fc1 = nn.Linear(in_features, hidden_features)
175
+ self.act = act_layer()
176
+ self.fc2 = nn.Linear(hidden_features, out_features)
177
+ self.drop = nn.Dropout(drop)
178
+
179
+ def forward(self, x):
180
+ x = self.fc1(x)
181
+ x = self.act(x)
182
+ x = self.fc2(x)
183
+ x = self.drop(x)
184
+ return x
185
+
186
+
187
+ class Attention(nn.Module):
188
+ def __init__(
189
+ self,
190
+ dim,
191
+ num_heads=8,
192
+ qkv_bias=False,
193
+ qk_scale=None,
194
+ attn_drop=0.0,
195
+ proj_drop=0.0,
196
+ attn_head_dim=None,
197
+ ):
198
+ super().__init__()
199
+ self.num_heads = num_heads
200
+ head_dim = dim // num_heads
201
+ self.dim = dim
202
+
203
+ if attn_head_dim is not None:
204
+ head_dim = attn_head_dim
205
+ all_head_dim = head_dim * self.num_heads
206
+
207
+ self.scale = qk_scale or head_dim**-0.5
208
+
209
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
210
+
211
+ self.attn_drop = nn.Dropout(attn_drop)
212
+ self.proj = nn.Linear(all_head_dim, dim)
213
+ self.proj_drop = nn.Dropout(proj_drop)
214
+
215
+ def forward(self, x):
216
+ B, N, C = x.shape
217
+ qkv = self.qkv(x)
218
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
219
+ q, k, v = (
220
+ qkv[0],
221
+ qkv[1],
222
+ qkv[2],
223
+ ) # make torchscript happy (cannot use tensor as tuple)
224
+
225
+ q = q * self.scale
226
+ attn = q @ k.transpose(-2, -1)
227
+
228
+ attn = attn.softmax(dim=-1)
229
+ attn = self.attn_drop(attn)
230
+
231
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
232
+ x = self.proj(x)
233
+ x = self.proj_drop(x)
234
+
235
+ return x
236
+
237
+
238
+ class FlashAttention(nn.Module):
239
+ def __init__(
240
+ self,
241
+ dim,
242
+ num_heads=8,
243
+ qkv_bias=False,
244
+ qk_scale=None,
245
+ attn_drop=0.0,
246
+ proj_drop=0.0,
247
+ attn_head_dim=None,
248
+ ):
249
+ super().__init__()
250
+ self.num_heads = num_heads
251
+ head_dim = attn_head_dim or (dim // num_heads)
252
+ self.head_dim = head_dim
253
+ self.dim = dim
254
+ self.qkv = nn.Linear(dim, head_dim * num_heads * 3, bias=qkv_bias)
255
+ self.proj = nn.Linear(head_dim * num_heads, dim)
256
+ self.proj_drop = nn.Dropout(proj_drop)
257
+ self.attn_drop = attn_drop
258
+
259
+ def forward(self, x):
260
+ B, N, C = x.shape # (batch, sequence_length, embedding_dim)
261
+
262
+ qkv = self.qkv(x) # (B, N, 3 * num_heads * head_dim)
263
+ qkv = qkv.view(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
264
+ q, k, v = qkv[0], qkv[1], qkv[2] # each: (B, num_heads, N, head_dim)
265
+
266
+ # FlashAttention expects (B, N, num_heads, head_dim)
267
+ q = q.transpose(1, 2).contiguous()
268
+ k = k.transpose(1, 2).contiguous()
269
+ v = v.transpose(1, 2).contiguous()
270
+
271
+ # FlashAttention kernels require fp16/bf16, so cast fp32 inputs down
272
+ if q.dtype == torch.float32:
273
+ q = q.half()
274
+ k = k.half()
275
+ v = v.half()
276
+
277
+ out = flash_attn_func(
278
+ q, k, v, dropout_p=self.attn_drop, causal=False
279
+ ) # (B, N, num_heads, head_dim)
280
+
281
+ # Merge heads and cast back to the input dtype
282
+ out = out.reshape(B, N, -1)
283
+ out = out.to(x.dtype)
284
+ # breakpoint()
285
+ out = self.proj(out)
286
+ out = self.proj_drop(out)
287
+ return out
288
+
289
+
290
+ class Block(nn.Module):
291
+
292
+ def __init__(
293
+ self,
294
+ dim,
295
+ num_heads,
296
+ mlp_ratio=4.0,
297
+ qkv_bias=False,
298
+ qk_scale=None,
299
+ drop=0.0,
300
+ attn_drop=0.0,
301
+ drop_path=0.0,
302
+ act_layer=nn.GELU,
303
+ norm_layer=nn.LayerNorm,
304
+ attn_head_dim=None,
305
+ flash_attn=False,
306
+ ):
307
+ super().__init__()
308
+
309
+ self.norm1 = norm_layer(dim)
310
+ if flash_attn:
311
+ self.attn = FlashAttention(
312
+ dim,
313
+ num_heads=num_heads,
314
+ qkv_bias=qkv_bias,
315
+ qk_scale=qk_scale,
316
+ attn_drop=attn_drop,
317
+ proj_drop=drop,
318
+ attn_head_dim=attn_head_dim,
319
+ )
320
+ else:
321
+ self.attn = Attention(
322
+ dim,
323
+ num_heads=num_heads,
324
+ qkv_bias=qkv_bias,
325
+ qk_scale=qk_scale,
326
+ attn_drop=attn_drop,
327
+ proj_drop=drop,
328
+ attn_head_dim=attn_head_dim,
329
+ )
330
+
331
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
332
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
333
+ self.norm2 = norm_layer(dim)
334
+ mlp_hidden_dim = int(dim * mlp_ratio)
335
+ self.mlp = Mlp(
336
+ in_features=dim,
337
+ hidden_features=mlp_hidden_dim,
338
+ act_layer=act_layer,
339
+ drop=drop,
340
+ )
341
+
342
+ def forward(self, x):
343
+ x = x + self.drop_path(self.attn(self.norm1(x)))
344
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
345
+ return x
346
+
347
+
348
+ class PatchEmbed(nn.Module):
349
+ """Image to Patch Embedding"""
350
+
351
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
352
+ super().__init__()
353
+ img_size = to_2tuple(img_size)
354
+ patch_size = to_2tuple(patch_size)
355
+ num_patches = (
356
+ (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio**2)
357
+ )
358
+ self.patch_shape = (
359
+ int(img_size[0] // patch_size[0] * ratio),
360
+ int(img_size[1] // patch_size[1] * ratio),
361
+ )
362
+ self.origin_patch_shape = (
363
+ int(img_size[0] // patch_size[0]),
364
+ int(img_size[1] // patch_size[1]),
365
+ )
366
+ self.img_size = img_size
367
+ self.patch_size = patch_size
368
+ self.num_patches = num_patches
369
+
370
+ self.proj = nn.Conv2d(
371
+ in_chans,
372
+ embed_dim,
373
+ kernel_size=patch_size,
374
+ stride=(patch_size[0] // ratio),
375
+ padding=4 + 2 * (ratio // 2 - 1),
376
+ )
377
+
378
+ def forward(self, x, **kwargs):
379
+ B, C, H, W = x.shape
380
+ x = self.proj(x)
381
+ Hp, Wp = x.shape[2], x.shape[3]
382
+
383
+ x = x.flatten(2).transpose(1, 2)
384
+ return x, (Hp, Wp)
385
+
386
+
387
+ class PatchEmbedNoPadding(nn.Module):
388
+ """Image to Patch Embedding"""
389
+
390
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
391
+ super().__init__()
392
+ img_size = to_2tuple(img_size)
393
+ patch_size = to_2tuple(patch_size)
394
+ num_patches = (
395
+ (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio**2)
396
+ )
397
+ self.patch_shape = (
398
+ int(img_size[0] // patch_size[0] * ratio),
399
+ int(img_size[1] // patch_size[1] * ratio),
400
+ )
401
+ self.origin_patch_shape = (
402
+ int(img_size[0] // patch_size[0]),
403
+ int(img_size[1] // patch_size[1]),
404
+ )
405
+ self.img_size = img_size
406
+ self.patch_size = patch_size
407
+ self.num_patches = num_patches
408
+
409
+ self.proj = nn.Conv2d(
410
+ in_chans,
411
+ embed_dim,
412
+ kernel_size=patch_size,
413
+ stride=(patch_size[0] // ratio),
414
+ padding=0,
415
+ )
416
+
417
+ def forward(self, x, **kwargs):
418
+ B, C, H, W = x.shape
419
+ x = self.proj(x)
420
+ Hp, Wp = x.shape[2], x.shape[3]
421
+
422
+ x = x.flatten(2).transpose(1, 2)
423
+ return x, (Hp, Wp)
424
+
425
+
426
+ class HybridEmbed(nn.Module):
427
+ """CNN Feature Map Embedding
428
+ Extract feature map from CNN, flatten, project to embedding dim.
429
+ """
430
+
431
+ def __init__(
432
+ self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768
433
+ ):
434
+ super().__init__()
435
+ assert isinstance(backbone, nn.Module)
436
+ img_size = to_2tuple(img_size)
437
+ self.img_size = img_size
438
+ self.backbone = backbone
439
+ if feature_size is None:
440
+ with torch.no_grad():
441
+ training = backbone.training
442
+ if training:
443
+ backbone.eval()
444
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[
445
+ -1
446
+ ]
447
+ feature_size = o.shape[-2:]
448
+ feature_dim = o.shape[1]
449
+ backbone.train(training)
450
+ else:
451
+ feature_size = to_2tuple(feature_size)
452
+ feature_dim = self.backbone.feature_info.channels()[-1]
453
+ self.num_patches = feature_size[0] * feature_size[1]
454
+ self.proj = nn.Linear(feature_dim, embed_dim)
455
+
456
+ def forward(self, x):
457
+ x = self.backbone(x)[-1]
458
+ x = x.flatten(2).transpose(1, 2)
459
+ x = self.proj(x)
460
+ return x
461
+
462
+
463
+ class ViT(nn.Module):
464
+
465
+ def __init__(
466
+ self,
467
+ img_size=224,
468
+ patch_size=16,
469
+ in_chans=3,
470
+ num_classes=80,
471
+ embed_dim=768,
472
+ depth=12,
473
+ num_heads=12,
474
+ mlp_ratio=4.0,
475
+ qkv_bias=False,
476
+ qk_scale=None,
477
+ drop_rate=0.0,
478
+ attn_drop_rate=0.0,
479
+ drop_path_rate=0.0,
480
+ hybrid_backbone=None,
481
+ norm_layer=None,
482
+ use_checkpoint=False,
483
+ frozen_stages=-1,
484
+ ratio=1,
485
+ last_norm=True,
486
+ patch_padding="pad",
487
+ freeze_attn=False,
488
+ freeze_ffn=False,
489
+ flash_attn=False,
490
+ no_patch_padding=False,
491
+ ):
492
+ # Protect mutable default arguments
493
+ super(ViT, self).__init__()
494
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
495
+ self.num_classes = num_classes
496
+ self.num_features = self.embed_dim = self.embed_dims = (
497
+ embed_dim # num_features for consistency with other models
498
+ )
499
+ self.frozen_stages = frozen_stages
500
+ self.use_checkpoint = use_checkpoint
501
+ self.patch_padding = patch_padding
502
+ self.freeze_attn = freeze_attn
503
+ self.freeze_ffn = freeze_ffn
504
+ self.depth = depth
505
+
506
+ if hybrid_backbone is not None:
507
+ self.patch_embed = HybridEmbed(
508
+ hybrid_backbone,
509
+ img_size=img_size,
510
+ in_chans=in_chans,
511
+ embed_dim=embed_dim,
512
+ )
513
+ else:
514
+ if no_patch_padding:
515
+ self.patch_embed = PatchEmbedNoPadding(
516
+ img_size=img_size,
517
+ patch_size=patch_size,
518
+ in_chans=in_chans,
519
+ embed_dim=embed_dim,
520
+ ratio=ratio,
521
+ )
522
+ else:
523
+ self.patch_embed = PatchEmbed(
524
+ img_size=img_size,
525
+ patch_size=patch_size,
526
+ in_chans=in_chans,
527
+ embed_dim=embed_dim,
528
+ ratio=ratio,
529
+ )
530
+ num_patches = self.patch_embed.num_patches
531
+ self.patch_size = patch_size
532
+
533
+ # since the pretraining model has class token
534
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
535
+
536
+ dpr = [
537
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
538
+ ] # stochastic depth decay rule
539
+
540
+ self.blocks = nn.ModuleList(
541
+ [
542
+ Block(
543
+ dim=embed_dim,
544
+ num_heads=num_heads,
545
+ mlp_ratio=mlp_ratio,
546
+ qkv_bias=qkv_bias,
547
+ qk_scale=qk_scale,
548
+ drop=drop_rate,
549
+ attn_drop=attn_drop_rate,
550
+ drop_path=dpr[i],
551
+ norm_layer=norm_layer,
552
+ flash_attn=flash_attn,
553
+ )
554
+ for i in range(depth)
555
+ ]
556
+ )
557
+
558
+ self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
559
+
560
+ if self.pos_embed is not None:
561
+ trunc_normal_(self.pos_embed, std=0.02)
562
+
563
+ self._freeze_stages()
564
+
565
+ def _freeze_stages(self):
566
+ """Freeze parameters."""
567
+ if self.frozen_stages >= 0:
568
+ self.patch_embed.eval()
569
+ for param in self.patch_embed.parameters():
570
+ param.requires_grad = False
571
+
572
+ for i in range(1, self.frozen_stages + 1):
573
+ m = self.blocks[i - 1]
574
+ m.eval()
575
+ for param in m.parameters():
576
+ param.requires_grad = False
577
+
578
+ if self.freeze_attn:
579
+ for i in range(0, self.depth):
580
+ m = self.blocks[i]
581
+ m.attn.eval()
582
+ m.norm1.eval()
583
+ for param in m.attn.parameters():
584
+ param.requires_grad = False
585
+ for param in m.norm1.parameters():
586
+ param.requires_grad = False
587
+
588
+ if self.freeze_ffn:
589
+ self.pos_embed.requires_grad = False
590
+ self.patch_embed.eval()
591
+ for param in self.patch_embed.parameters():
592
+ param.requires_grad = False
593
+ for i in range(0, self.depth):
594
+ m = self.blocks[i]
595
+ m.mlp.eval()
596
+ m.norm2.eval()
597
+ for param in m.mlp.parameters():
598
+ param.requires_grad = False
599
+ for param in m.norm2.parameters():
600
+ param.requires_grad = False
601
+
602
+ def init_weights(self):
603
+ """Initialize the weights in backbone.
604
+ Args:
605
+ pretrained (str, optional): Path to pre-trained weights.
606
+ Defaults to None.
607
+ """
608
+
609
+ def _init_weights(m):
610
+ if isinstance(m, nn.Linear):
611
+ trunc_normal_(m.weight, std=0.02)
612
+ if isinstance(m, nn.Linear) and m.bias is not None:
613
+ nn.init.constant_(m.bias, 0)
614
+ elif isinstance(m, nn.LayerNorm):
615
+ nn.init.constant_(m.bias, 0)
616
+ nn.init.constant_(m.weight, 1.0)
617
+
618
+ self.apply(_init_weights)
619
+
620
+ def get_num_layers(self):
621
+ return len(self.blocks)
622
+
623
+ @torch.jit.ignore
624
+ def no_weight_decay(self):
625
+ return {"pos_embed", "cls_token"}
626
+
627
+ def forward_features(self, x, extra_embed=None):
628
+ B, C, H, W = x.shape
629
+ x, (Hp, Wp) = self.patch_embed(x)
630
+
631
+ if self.pos_embed is not None:
632
+ # fit for multiple GPU training
633
+ # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
634
+ x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
635
+
636
+ if extra_embed is not None:
637
+ x = x + extra_embed.flatten(2).transpose(1, 2).to(x)
638
+
639
+ for blk in self.blocks:
640
+ if self.use_checkpoint:
641
+ x = checkpoint.checkpoint(blk, x)
642
+ else:
643
+ x = blk(x)
644
+
645
+ x = self.last_norm(x)
646
+
647
+ xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
648
+
649
+ return xp
650
+
651
+ def forward(self, x, *args, **kwargs):
652
+ x = self.forward_features(x, *args, **kwargs)
653
+ return x
654
+
655
+ def train(self, mode=True):
656
+ """Convert the model into training mode."""
657
+ super().train(mode)
658
+ self._freeze_stages()
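
A quick sanity check of the patch-grid arithmetic behind the 256x192 factories above (a sketch using the ViT-B sizes; the other variants behave the same):

    import torch
    from sam3d_body.models.backbones.vit import PatchEmbed

    embed = PatchEmbed(img_size=(256, 192), patch_size=16, embed_dim=768, ratio=1)
    tokens, (Hp, Wp) = embed(torch.zeros(1, 3, 256, 192))

    # padding = 4 + 2 * (1 // 2 - 1) = 2 and stride = 16, so
    #   Hp = (256 + 2*2 - 16) // 16 + 1 = 16,  Wp = (192 + 2*2 - 16) // 16 + 1 = 12
    print(tokens.shape)        # torch.Size([1, 192, 768])
    print(embed.num_patches)   # 192 -> pos_embed stores 192 + 1 entries (extra cls slot)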
src/sam3d_body/models/decoders/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from .keypoint_prompt_sampler import build_keypoint_sampler
4
+ from .prompt_encoder import PromptEncoder
5
+ from .promptable_decoder import PromptableDecoder
6
+
7
+
8
+ def build_decoder(cfg, context_dim=None):
9
+ from .promptable_decoder import PromptableDecoder
10
+
11
+ if cfg.TYPE == "sam":
12
+ return PromptableDecoder(
13
+ dims=cfg.DIM,
14
+ context_dims=context_dim,
15
+ depth=cfg.DEPTH,
16
+ num_heads=cfg.HEADS,
17
+ head_dims=cfg.DIM_HEAD,
18
+ mlp_dims=cfg.MLP_DIM,
19
+ layer_scale_init_value=cfg.LAYER_SCALE_INIT,
20
+ drop_rate=cfg.DROP_RATE,
21
+ attn_drop_rate=cfg.ATTN_DROP_RATE,
22
+ drop_path_rate=cfg.DROP_PATH_RATE,
23
+ ffn_type=cfg.FFN_TYPE,
24
+ enable_twoway=cfg.ENABLE_TWOWAY,
25
+ repeat_pe=cfg.REPEAT_PE,
26
+ frozen=cfg.get("FROZEN", False),
27
+ do_interm_preds=cfg.get("DO_INTERM_PREDS", False),
28
+ do_keypoint_tokens=cfg.get("DO_KEYPOINT_TOKENS", False),
29
+ keypoint_token_update=cfg.get("KEYPOINT_TOKEN_UPDATE", None),
30
+ )
31
+ else:
32
+ raise ValueError("Invalid decoder type: ", cfg.TYPE)
src/sam3d_body/models/decoders/keypoint_prompt_sampler.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import random
4
+ from abc import ABC, abstractmethod
5
+ from typing import Dict, List
6
+
7
+ import torch
8
+
9
+ from omegaconf import DictConfig
10
+ from yacs.config import CfgNode
11
+
12
+
13
+ def build_keypoint_sampler(sampler_cfg, prompt_keypoints, keybody_idx):
14
+ sampler_type = sampler_cfg.get("TYPE", "v1")
15
+ if sampler_type == "v1":
16
+ sampler_cls = KeypointSamplerV1
17
+ else:
18
+ raise ValueError("Invalid sampler type: ", sampler_type)
19
+
20
+ return sampler_cls(sampler_cfg, prompt_keypoints, keybody_idx)
21
+
22
+
23
+ class BaseKeypointSampler(ABC):
24
+ @abstractmethod
25
+ def sample(
26
+ self, gt_keypoints: torch.Tensor, pred_keypoints: torch.Tensor, is_train: bool
27
+ ) -> torch.Tensor:
28
+ pass
29
+
30
+ def _get_worst_keypoint(self, distances, keypoint_list):
31
+ # Set distance to -1 for non-promptable keypoints
32
+ cur_dist = torch.ones_like(distances) * -1
33
+ cur_dist[keypoint_list] = distances[keypoint_list]
34
+ keypoint_idx = int(cur_dist.argmax())
35
+ if cur_dist[keypoint_idx] > self.distance_thresh:
36
+ valid_keypoint = True
37
+ else:
38
+ valid_keypoint = False
39
+ return keypoint_idx, valid_keypoint
40
+
41
+ def _get_random_keypoint(self, distances, keypoint_list):
42
+ candidates = [idx for idx in keypoint_list if distances[idx] > 0]
43
+ if len(candidates):
44
+ keypoint_idx = random.choice(candidates)
45
+ valid_keypoint = True
46
+ else:
47
+ keypoint_idx = None
48
+ valid_keypoint = False
49
+ return keypoint_idx, valid_keypoint
50
+
51
+ def _masked_distance(self, x, y, mask=None):
52
+ """
53
+ Args:
54
+ x, y: [B, K, D]
55
+ mask: [B, K]
56
+ Return:
57
+ distances: [K, B]
58
+ """
59
+ distances = (x - y).pow(2).sum(dim=-1)
60
+ if mask is not None:
61
+ distances[mask] = -1
62
+ return distances.T
63
+
64
+
65
+ class KeypointSamplerV1(BaseKeypointSampler):
66
+ def __init__(
67
+ self,
68
+ sampler_cfg: DictConfig | CfgNode,
69
+ prompt_keypoints: Dict,
70
+ keybody_idx: List,
71
+ ):
72
+ self.prompt_keypoints = prompt_keypoints
73
+ self._keybody_idx = keybody_idx
74
+ self._non_keybody_idx = [
75
+ idx for idx in self.prompt_keypoints if idx not in self._keybody_idx
76
+ ]
77
+
78
+ self.keybody_ratio = sampler_cfg.get("KEYBODY_RATIO", 0.8)
79
+ self.worst_ratio = sampler_cfg.get("WORST_RATIO", 0.8)
80
+ self.negative_ratio = sampler_cfg.get("NEGATIVE_RATIO", 0.0)
81
+ self.dummy_ratio = sampler_cfg.get("DUMMY_RATIO", 0.1)
82
+ self.distance_thresh = sampler_cfg.get("DISTANCE_THRESH", 0.0)
83
+
84
+ def sample(
85
+ self,
86
+ gt_keypoints_2d: torch.Tensor,
87
+ pred_keypoints_2d: torch.Tensor,
88
+ is_train: bool = True,
89
+ force_dummy: bool = False,
90
+ ) -> torch.Tensor:
91
+ # Get the distance between each predicted and gt keypoint
92
+ # Elements will be ignored if (1) the gt has low confidence or
93
+ # (2) both the gt and pred are outside of the image
94
+ mask_1 = gt_keypoints_2d[:, :, -1] < 0.5
95
+ mask_2 = (
96
+ (gt_keypoints_2d[:, :, :2] > 0.5) | (gt_keypoints_2d[:, :, :2] < -0.5)
97
+ ).any(dim=-1)
98
+
99
+ # Elements to be ignored
100
+ if not is_train or torch.rand(1).item() > self.negative_ratio:
101
+ mask = mask_1 | mask_2
102
+ # print_base = "positive"
103
+ else:
104
+ mask_3 = (
105
+ (pred_keypoints_2d[:, :, :2] > 0.5)
106
+ | (pred_keypoints_2d[:, :, :2] < -0.5)
107
+ ).any(dim=-1)
108
+ # To include negative prompts
109
+ mask = mask_1 | (mask_2 & mask_3)
110
+ # print_base = "negative"
111
+
112
+ # Get pairwise distances with shape [K, B]
113
+ distances = self._masked_distance(
114
+ pred_keypoints_2d, gt_keypoints_2d[..., :2], mask
115
+ )
116
+
117
+ batch_size = distances.shape[1]
118
+ keypoints_prompt = []
119
+ for b in range(batch_size):
120
+ # print_str = print_base
121
+
122
+ # Decide to get the worst keypoint or a random keypoint
123
+ if not is_train or torch.rand(1).item() < self.worst_ratio:
124
+ sampler = self._get_worst_keypoint
125
+ # print_str += "_worst"
126
+ else:
127
+ sampler = self._get_random_keypoint
128
+ # print_str += "_random"
129
+
130
+ # Decide to prompt keybody keypoints or non-keybody ones
131
+ if not is_train or torch.rand(1).item() < self.keybody_ratio:
132
+ cur_idx = self._keybody_idx
133
+ alt_idx = self._non_keybody_idx
134
+ # print_str += "_keybody"
135
+ else:
136
+ cur_idx = self._non_keybody_idx
137
+ alt_idx = self._keybody_idx
138
+ # print_str += "_nonkey"
139
+
140
+ # Get a valid or dummy prompt
141
+ if not is_train or torch.rand(1).item() > self.dummy_ratio:
142
+ keypoint_idx, valid_keypoint = sampler(distances[:, b], cur_idx)
143
+
144
+ if not valid_keypoint:
145
+ # Try the alternative keypoints
146
+ keypoint_idx, valid_keypoint = self._get_worst_keypoint(
147
+ distances[:, b], alt_idx
148
+ )
149
+ else:
150
+ valid_keypoint = False
151
+
152
+ if valid_keypoint:
153
+ cur_point = gt_keypoints_2d[b, keypoint_idx].clone()
154
+ if torch.any(cur_point[:2] > 0.5) or torch.any(cur_point[:2] < -0.5):
155
+ # Negative prompt --> indicating the predicted keypoint is incorrect
156
+ cur_point[:2] = pred_keypoints_2d[b, keypoint_idx][:2]
157
+ cur_point = torch.clamp(
158
+ cur_point + 0.5, min=0.0, max=1.0
159
+ ) # shift from [-0.5, 0.5] to [0, 1]
160
+ cur_point[-1] = -1
161
+ # print_str += "_negative"
162
+ else:
163
+ cur_point = torch.clamp(
164
+ cur_point + 0.5, min=0.0, max=1.0
165
+ ) # shift from [-0.5, 0.5] to [0, 1]
166
+ cur_point[-1] = self.prompt_keypoints[
167
+ keypoint_idx
168
+ ] # map to prompt_idx
169
+ # print_str += "_positive"
170
+ else:
171
+ cur_point = torch.zeros(3).to(gt_keypoints_2d)
172
+ cur_point[-1] = -2
173
+ # print_str += "_dummy"
174
+
175
+ if force_dummy:
176
+ cur_point = torch.zeros(3).to(gt_keypoints_2d)
177
+ cur_point[-1] = -2
178
+
179
+ keypoints_prompt.append(cur_point)
180
+ # print(print_str)
181
+
182
+ keypoints_prompt = torch.stack(keypoints_prompt, dim=0).view(batch_size, 1, 3)
183
+ return keypoints_prompt
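
A hedged usage sketch of the sampler; the config values, the identity prompt mapping, and the choice of shoulders/hips (ids 5, 6, 9, 10 in the MHR-70 metadata) as the "key body" subset are illustrative assumptions:

    import torch
    from yacs.config import CfgNode as CN

    from sam3d_body.models.decoders import build_keypoint_sampler

    sampler_cfg = CN({"TYPE": "v1", "WORST_RATIO": 0.8, "DUMMY_RATIO": 0.1})
    prompt_keypoints = {i: i for i in range(70)}   # keypoint index -> prompt label
    keybody_idx = [5, 6, 9, 10]                    # shoulders and hips

    sampler = build_keypoint_sampler(sampler_cfg, prompt_keypoints, keybody_idx)

    gt = torch.rand(4, 70, 3) - 0.5     # [B, K, (x, y, conf)], coords in [-0.5, 0.5]
    gt[..., 2] = 1.0                    # mark every ground-truth keypoint as visible
    pred = torch.rand(4, 70, 2) - 0.5   # [B, K, (x, y)] predictions in the same frame

    prompts = sampler.sample(gt, pred, is_train=False)   # -> [4, 1, 3], one prompt per sample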
src/sam3d_body/models/decoders/prompt_encoder.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from typing import Any, Optional, Tuple
4
+
5
+ import numpy as np
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from sam3d_body.models.modules.transformer import LayerNorm2d
11
+
12
+
13
+ class PromptEncoder(nn.Module):
14
+ def __init__(
15
+ self,
16
+ embed_dim: int,
17
+ num_body_joints: int,
18
+ # img_size: Tuple[int, int],
19
+ # patch_resolution: Tuple[int, int],
20
+ frozen: bool = False,
21
+ mask_embed_type: Optional[str] = None,
22
+ ) -> None:
23
+ """
24
+ Encodes prompts for input to SAM's mask decoder.
25
+
26
+ Arguments:
27
+ embed_dim (int): The prompts' embedding dimension
28
+ num_body_joints (int): The number of body joints
29
+ img_size (Tuple): The padded size of the image as input
30
+ to the image encoder, as (H, W).
31
+ patch_resolution (Tuple): image patch size, as (H, W)
32
+ """
33
+ super().__init__()
34
+ self.embed_dim = embed_dim
35
+ self.num_body_joints = num_body_joints
36
+ # self.img_size = img_size
37
+ # self.patch_resolution = patch_resolution
38
+
39
+ # Keypoint prompts
40
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
41
+ self.point_embeddings = nn.ModuleList(
42
+ [nn.Embedding(1, embed_dim) for _ in range(self.num_body_joints)]
43
+ )
44
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
45
+ self.invalid_point_embed = nn.Embedding(1, embed_dim)
46
+
47
+ # Mask prompt
48
+ if mask_embed_type in ["v1"]:
49
+ mask_in_chans = 16 # SAM2
50
+ self.mask_downscaling = nn.Sequential(
51
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=4, stride=4),
52
+ LayerNorm2d(mask_in_chans // 4),
53
+ nn.GELU(),
54
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=4, stride=4),
55
+ LayerNorm2d(mask_in_chans),
56
+ nn.GELU(),
57
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
58
+ )
59
+ elif mask_embed_type in ["v2"]:
60
+ mask_in_chans = 256
61
+ self.mask_downscaling = nn.Sequential(
62
+ nn.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2),
63
+ LayerNorm2d(mask_in_chans // 64),
64
+ nn.GELU(),
65
+ nn.Conv2d(
66
+ mask_in_chans // 64,
67
+ mask_in_chans // 16,
68
+ kernel_size=2,
69
+ stride=2,
70
+ ),
71
+ LayerNorm2d(mask_in_chans // 16),
72
+ nn.GELU(),
73
+ nn.Conv2d(
74
+ mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2
75
+ ),
76
+ LayerNorm2d(mask_in_chans // 4),
77
+ nn.GELU(),
78
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
79
+ LayerNorm2d(mask_in_chans),
80
+ nn.GELU(),
81
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
82
+ )
83
+ else:
84
+ assert mask_embed_type is None
85
+
86
+ if mask_embed_type is not None:
87
+ # Zero-initialize the last conv layer as gating
88
+ nn.init.zeros_(self.mask_downscaling[-1].weight)
89
+ nn.init.zeros_(self.mask_downscaling[-1].bias)
90
+
91
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
92
+ nn.init.zeros_(self.no_mask_embed.weight)
93
+
94
+ self.frozen = frozen
95
+ self._freeze_stages()
96
+
97
+ def get_dense_pe(self, size: Tuple[int, int]) -> torch.Tensor:
98
+ """
99
+ Returns the positional encoding used to encode point prompts,
100
+ applied to a dense set of points the shape of the image encoding.
101
+
102
+ Returns:
103
+ torch.Tensor: Positional encoding with shape
104
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
105
+ """
106
+ return self.pe_layer(size).unsqueeze(0)
107
+
108
+ def _embed_keypoints(
109
+ self,
110
+ points: torch.Tensor,
111
+ labels: torch.Tensor,
112
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
113
+ """
114
+ Embeds point prompts.
115
+ Assuming points have been normalized to [0, 1].
116
+
117
+ Output shape [B, N, C], mask shape [B, N]
118
+ """
119
+ assert points.min() >= 0 and points.max() <= 1
120
+ point_embedding = self.pe_layer._pe_encoding(points.to(torch.float))
121
+ point_embedding[labels == -2] = 0.0 # invalid points
122
+ point_embedding[labels == -2] += self.invalid_point_embed.weight
123
+ point_embedding[labels == -1] = 0.0
124
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
125
+ for i in range(self.num_body_joints):
126
+ point_embedding[labels == i] += self.point_embeddings[i].weight
127
+
128
+ point_mask = labels > -2
129
+ return point_embedding, point_mask
130
+
131
+ def _get_batch_size(
132
+ self,
133
+ keypoints: Optional[torch.Tensor],
134
+ boxes: Optional[torch.Tensor],
135
+ masks: Optional[torch.Tensor],
136
+ ) -> int:
137
+ """
138
+ Gets the batch size of the output given the batch size of the input prompts.
139
+ """
140
+ if keypoints is not None:
141
+ return keypoints.shape[0]
142
+ elif boxes is not None:
143
+ return boxes.shape[0]
144
+ elif masks is not None:
145
+ return masks.shape[0]
146
+ else:
147
+ return 1
148
+
149
+ def _get_device(self) -> torch.device:
150
+ return self.point_embeddings[0].weight.device
151
+
152
+ def forward(
153
+ self,
154
+ keypoints: Optional[torch.Tensor],
155
+ boxes: Optional[torch.Tensor] = None,
156
+ masks: Optional[torch.Tensor] = None,
157
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
158
+ """
159
+ Embeds different types of prompts, returning both sparse and dense
160
+ embeddings.
161
+
162
+ Arguments:
163
+ keypoints (torch.Tensor or none): point coordinates and labels to embed.
164
+ boxes (torch.Tensor or none): boxes to embed
165
+ masks (torch.Tensor or none): masks to embed
166
+
167
+ Returns:
168
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
169
+ BxNx(embed_dim), where N is determined by the number of input points
170
+ and boxes.
171
+ torch.Tensor: dense embeddings for the masks, in the shape
172
+ Bx(embed_dim)x(embed_H)x(embed_W)
173
+ """
174
+ bs = self._get_batch_size(keypoints, boxes, masks)
175
+ sparse_embeddings = torch.empty(
176
+ (bs, 0, self.embed_dim), device=self._get_device()
177
+ )
178
+ sparse_masks = torch.empty((bs, 0), device=self._get_device())
179
+ if keypoints is not None:
180
+ coords = keypoints[:, :, :2]
181
+ labels = keypoints[:, :, -1]
182
+ point_embeddings, point_mask = self._embed_keypoints(
183
+ coords, labels
184
+ ) # pad=(boxes is None))
185
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
186
+ sparse_masks = torch.cat([sparse_masks, point_mask], dim=1)
187
+
188
+ return sparse_embeddings, sparse_masks
189
+
190
+ def get_mask_embeddings(
191
+ self,
192
+ masks: Optional[torch.Tensor] = None,
193
+ bs: int = 1,
194
+ size: Tuple[int, int] = (16, 16), # [H, W]
195
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
196
+ """Embeds mask inputs."""
197
+ no_mask_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
198
+ bs, -1, size[0], size[1]
199
+ )
200
+ if masks is not None:
201
+ mask_embeddings = self.mask_downscaling(masks)
202
+ else:
203
+ mask_embeddings = no_mask_embeddings
204
+ return mask_embeddings, no_mask_embeddings
205
+
206
+ def _freeze_stages(self):
207
+ """Freeze parameters."""
208
+ if self.frozen:
209
+ for param in self.parameters():
210
+ param.requires_grad = False
211
+
212
+
213
+ class PositionEmbeddingRandom(nn.Module):
214
+ """
215
+ Positional encoding using random spatial frequencies.
216
+ """
217
+
218
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
219
+ super().__init__()
220
+ if scale is None or scale <= 0.0:
221
+ scale = 1.0
222
+ self.register_buffer(
223
+ "positional_encoding_gaussian_matrix",
224
+ scale * torch.randn((2, num_pos_feats)),
225
+ )
226
+
227
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
228
+ """Positionally encode points that are normalized to [0,1]."""
229
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
230
+ coords = 2 * coords - 1
231
+ coords = coords @ self.positional_encoding_gaussian_matrix
232
+ coords = 2 * np.pi * coords
233
+ # outputs d_1 x ... x d_n x C shape
234
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
235
+
236
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
237
+ """Generate positional encoding for a grid of the specified size."""
238
+ h, w = size
239
+ device: Any = self.positional_encoding_gaussian_matrix.device
240
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
241
+ y_embed = grid.cumsum(dim=0) - 0.5
242
+ x_embed = grid.cumsum(dim=1) - 0.5
243
+ y_embed = y_embed / h
244
+ x_embed = x_embed / w
245
+
246
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
247
+ return pe.permute(2, 0, 1) # C x H x W
248
+
249
+ def forward_with_coords(
250
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
251
+ ) -> torch.Tensor:
252
+ """Positionally encode points that are not normalized to [0,1]."""
253
+ coords = coords_input.clone()
254
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
255
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
256
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
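For reference, a minimal standalone sketch of the random-Fourier-feature encoding that PositionEmbeddingRandom implements above; the feature size and the sample point here are illustrative, not values taken from the config.

import math
import torch

num_pos_feats = 4                                    # the class above defaults to 64
gauss = torch.randn(2, num_pos_feats)                # random spatial frequencies, one row per coordinate

coords = torch.tensor([[0.25, 0.75]])                # a single (x, y) point normalized to [0, 1]
mapped = (2 * coords - 1) @ gauss                    # rescale to [-1, 1], then project onto the frequencies
mapped = 2 * math.pi * mapped
pe = torch.cat([mapped.sin(), mapped.cos()], dim=-1)
print(pe.shape)                                      # torch.Size([1, 8]) == (N, 2 * num_pos_feats)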
src/sam3d_body/models/decoders/promptable_decoder.py ADDED
@@ -0,0 +1,194 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import pickle
4
+ from typing import Dict, Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from ..modules.transformer import build_norm_layer, TransformerDecoderLayer
10
+
11
+
12
+ class PromptableDecoder(nn.Module):
13
+ """Cross-attention-based Transformer decoder with prompt inputs.
14
+
15
+ Args:
16
+ token_dims (int): The dimension of input pose tokens.
17
+ prompt_dims (int): The dimension of input prompt tokens.
18
+ context_dims (int): The dimension of image context features.
19
+ dims (int): The projected dimension of all tokens in the decoder.
20
+ depth (int): The number of layers for Transformer decoder.
21
+ num_heads (int): The number of heads for multi-head attention.
22
+ head_dims (int): The dimension of each head.
23
+ mlp_dims (int): The dimension of hidden layers in MLP.
24
+ layer_scale_init_value (float or torch.Tensor): Init value of layer
25
+ scale. Defaults to 0.
26
+ drop_rate (float): Probability of an element to be zeroed
27
+ after the feed forward layer. Defaults to 0.
28
+ attn_drop_rate (float): The drop out rate for attention output weights.
29
+ Defaults to 0.
30
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
31
+ ffn_type (str): Select the type of ffn layers. Defaults to 'origin'.
32
+ act_layer (nn.Module, optional): The activation layer for FFNs.
33
+ Default: nn.GELU
34
+ norm_cfg (dict): Config dict for normalization layer.
35
+ Defaults to ``dict(type='LN')``.
36
+ enable_twoway (bool): Whether to enable two-way Transformer (used in SAM).
37
+ repeat_pe (bool): Whether to re-add PE at each layer (used in SAM)
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ dims: int,
43
+ context_dims: int,
44
+ depth: int,
45
+ num_heads: int = 8,
46
+ head_dims: int = 64,
47
+ mlp_dims: int = 1024,
48
+ layer_scale_init_value: float = 0.0,
49
+ drop_rate: float = 0.0,
50
+ attn_drop_rate: float = 0.0,
51
+ drop_path_rate: float = 0.0,
52
+ ffn_type: str = "origin",
53
+ act_layer: nn.Module = nn.GELU,
54
+ norm_cfg: Dict = dict(type="LN", eps=1e-6),
55
+ enable_twoway: bool = False,
56
+ repeat_pe: bool = False,
57
+ frozen: bool = False,
58
+ do_interm_preds: bool = False,
59
+ do_keypoint_tokens: bool = False,
60
+ keypoint_token_update: bool | str = False,
61
+ ):
62
+ super().__init__()
63
+
64
+ self.layers = nn.ModuleList()
65
+ for i in range(depth):
66
+ self.layers.append(
67
+ TransformerDecoderLayer(
68
+ token_dims=dims,
69
+ context_dims=context_dims,
70
+ num_heads=num_heads,
71
+ head_dims=head_dims,
72
+ mlp_dims=mlp_dims,
73
+ layer_scale_init_value=layer_scale_init_value,
74
+ drop_rate=drop_rate,
75
+ attn_drop_rate=attn_drop_rate,
76
+ drop_path_rate=drop_path_rate,
77
+ ffn_type=ffn_type,
78
+ act_layer=act_layer,
79
+ norm_cfg=norm_cfg,
80
+ enable_twoway=enable_twoway,
81
+ repeat_pe=repeat_pe,
82
+ skip_first_pe=(i == 0),
83
+ )
84
+ )
85
+
86
+ self.norm_final = build_norm_layer(norm_cfg, dims)
87
+ self.do_interm_preds = do_interm_preds
88
+ self.do_keypoint_tokens = do_keypoint_tokens
89
+ self.keypoint_token_update = keypoint_token_update
90
+
91
+ self.frozen = frozen
92
+ self._freeze_stages()
93
+
94
+ def forward(
95
+ self,
96
+ token_embedding: torch.Tensor,
97
+ image_embedding: torch.Tensor,
98
+ token_augment: Optional[torch.Tensor] = None,
99
+ image_augment: Optional[torch.Tensor] = None,
100
+ token_mask: Optional[torch.Tensor] = None,
101
+ channel_first: bool = True,
102
+ token_to_pose_output_fn=None,
103
+ keypoint_token_update_fn=None,
104
+ hand_embeddings=None,
105
+ hand_augment=None,
106
+ ):
107
+ """
108
+ Args:
109
+ token_embedding: [B, N, C]
110
+ image_embedding: [B, C, H, W]
111
+ """
112
+ if channel_first:
113
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
114
+ if image_augment is not None:
115
+ image_augment = image_augment.flatten(2).permute(0, 2, 1)
116
+ if hand_embeddings is not None:
117
+ hand_embeddings = hand_embeddings.flatten(2).permute(0, 2, 1)
118
+ hand_augment = hand_augment.flatten(2).permute(0, 2, 1)
119
+ if len(hand_augment) == 1:
120
+ # inflate batch dimension
121
+ assert len(hand_augment.shape) == 3
122
+ hand_augment = hand_augment.repeat(len(hand_embeddings), 1, 1)
123
+
124
+ if self.do_interm_preds:
125
+ assert token_to_pose_output_fn is not None
126
+ all_pose_outputs = []
127
+
128
+ for layer_idx, layer in enumerate(self.layers):
129
+ if hand_embeddings is None:
130
+ token_embedding, image_embedding = layer(
131
+ token_embedding,
132
+ image_embedding,
133
+ token_augment,
134
+ image_augment,
135
+ token_mask,
136
+ )
137
+ else:
138
+ token_embedding, image_embedding = layer(
139
+ token_embedding,
140
+ torch.cat([image_embedding, hand_embeddings], dim=1),
141
+ token_augment,
142
+ torch.cat([image_augment, hand_augment], dim=1),
143
+ token_mask,
144
+ )
145
+ image_embedding = image_embedding[:, : image_augment.shape[1]]
146
+
147
+ if self.do_interm_preds and layer_idx < len(self.layers) - 1:
148
+ curr_pose_output = token_to_pose_output_fn(
149
+ self.norm_final(token_embedding),
150
+ prev_pose_output=(
151
+ all_pose_outputs[-1] if len(all_pose_outputs) > 0 else None
152
+ ),
153
+ layer_idx=layer_idx,
154
+ )
155
+ all_pose_outputs.append(curr_pose_output)
156
+
157
+ if self.keypoint_token_update:
158
+ assert keypoint_token_update_fn is not None
159
+ token_embedding, token_augment, _, _ = keypoint_token_update_fn(
160
+ token_embedding, token_augment, curr_pose_output, layer_idx
161
+ )
162
+
163
+ out = self.norm_final(token_embedding)
164
+
165
+ if self.do_interm_preds:
166
+ curr_pose_output = token_to_pose_output_fn(
167
+ out,
168
+ prev_pose_output=(
169
+ all_pose_outputs[-1] if len(all_pose_outputs) > 0 else None
170
+ ),
171
+ layer_idx=layer_idx,
172
+ )
173
+ all_pose_outputs.append(curr_pose_output)
174
+
175
+ return out, all_pose_outputs
176
+ else:
177
+ return out
178
+
179
+ def _freeze_stages(self):
180
+ """Freeze parameters."""
181
+ if self.frozen:
182
+ for layer in self.layers:
183
+ layer.eval()
184
+ self.norm_final.eval()
185
+ for param in self.parameters():
186
+ param.requires_grad = False
187
+
188
+ def train(self, mode=True):
189
+ """
190
+ Convert the model into training mode.
191
+ (Note: not actually invoked by Lightning during trainer.fit().)
192
+ """
193
+ super().train(mode)
194
+ self._freeze_stages()
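A rough usage sketch of the decoder above, assuming the sam3d_body package is importable and that TransformerDecoderLayer accepts the default None positional augments; all shapes are illustrative.

import torch
from sam3d_body.models.decoders.promptable_decoder import PromptableDecoder

decoder = PromptableDecoder(dims=1024, context_dims=1280, depth=2)
tokens = torch.randn(2, 3, 1024)       # [B, N, C] pose / prompt tokens
feats = torch.randn(2, 1280, 16, 16)   # [B, C, H, W] backbone features
out = decoder(tokens, feats)           # [B, N, 1024] refined tokens (do_interm_preds is False here)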
src/sam3d_body/models/heads/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from ..modules import to_2tuple
4
+ from .camera_head import PerspectiveHead
5
+ from .mhr_head import MHRHead
6
+
7
+
8
+ def build_head(cfg, head_type="mhr", enable_hand_model=False, default_scale_factor=1.0):
9
+ if head_type == "mhr":
10
+ return MHRHead(
11
+ input_dim=cfg.MODEL.DECODER.DIM,
12
+ mlp_depth=cfg.MODEL.MHR_HEAD.get("MLP_DEPTH", 1),
13
+ mhr_model_path=cfg.MODEL.MHR_HEAD.MHR_MODEL_PATH,
14
+ mlp_channel_div_factor=cfg.MODEL.MHR_HEAD.get("MLP_CHANNEL_DIV_FACTOR", 1),
15
+ enable_hand_model=enable_hand_model,
16
+ )
17
+ elif head_type == "perspective":
18
+ return PerspectiveHead(
19
+ input_dim=cfg.MODEL.DECODER.DIM,
20
+ img_size=to_2tuple(cfg.MODEL.IMAGE_SIZE),
21
+ mlp_depth=cfg.MODEL.get("CAMERA_HEAD", dict()).get("MLP_DEPTH", 1),
22
+ mlp_channel_div_factor=cfg.MODEL.get("CAMERA_HEAD", dict()).get(
23
+ "MLP_CHANNEL_DIV_FACTOR", 1
24
+ ),
25
+ default_scale_factor=default_scale_factor,
26
+ )
27
+ else:
28
+ raise ValueError("Invalid head type: ", head_type)
src/sam3d_body/models/heads/camera_head.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from typing import Optional, Sequence, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from sam3d_body.models.modules.geometry_utils import perspective_projection
9
+
10
+ from ..modules import get_intrinsic_matrix, to_2tuple
11
+ from ..modules.transformer import FFN
12
+
13
+
14
+ class PerspectiveHead(nn.Module):
15
+ """
16
+ Predict camera translation (s, tx, ty) and perform full-perspective
17
+ 2D reprojection (CLIFF/CameraHMR setup).
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ input_dim: int,
23
+ img_size: Tuple[int, int] | Sequence[int], # model input size (W, H)
24
+ mlp_depth: int = 1,
25
+ drop_ratio: float = 0.0,
26
+ mlp_channel_div_factor: int = 8,
27
+ default_scale_factor: float | int = 1,
28
+ ):
29
+ super().__init__()
30
+
31
+ # Metadata to compute 3D skeleton and 2D reprojection
32
+ self.img_size = to_2tuple(img_size)
33
+ self.ncam = 3 # (s, tx, ty)
34
+ self.default_scale_factor = default_scale_factor
35
+
36
+ self.proj = FFN(
37
+ embed_dims=input_dim,
38
+ feedforward_channels=input_dim // mlp_channel_div_factor,
39
+ output_dims=self.ncam,
40
+ num_fcs=mlp_depth,
41
+ ffn_drop=drop_ratio,
42
+ add_identity=False,
43
+ )
44
+
45
+ def forward(
46
+ self,
47
+ x: torch.Tensor,
48
+ init_estimate: Optional[torch.Tensor] = None,
49
+ ):
50
+ """
51
+ Args:
52
+ x: pose token with shape [B, C], usually C=DECODER.DIM
53
+ init_estimate: [B, self.ncam]
54
+ """
55
+ pred_cam = self.proj(x)
56
+ if init_estimate is not None:
57
+ pred_cam = pred_cam + init_estimate
58
+
59
+ return pred_cam
60
+
61
+ def perspective_projection(
62
+ self,
63
+ points_3d: torch.Tensor,
64
+ pred_cam: torch.Tensor,
65
+ bbox_center: torch.Tensor,
66
+ bbox_size: torch.Tensor,
67
+ img_size: torch.Tensor,
68
+ cam_int: torch.Tensor,
69
+ use_intrin_center: bool = False,
70
+ ):
71
+ """
72
+ Args:
73
+ bbox_center / img_size: shape [N, 2], in original image space (w, h)
74
+ bbox_size: shape [N,], in original image space
75
+ cam_int: shape [N, 3, 3]
76
+ """
77
+ batch_size = points_3d.shape[0]
78
+ pred_cam = pred_cam.clone()
79
+ pred_cam[..., [0, 2]] *= -1 # Camera system difference
80
+
81
+ # Compute camera translation: (scale, x, y) --> (x, y, depth)
82
+ # depth ~= f / s
83
+ # Note that f is in the NDC space (see Zolly section 3.1)
84
+ s, tx, ty = pred_cam[:, 0], pred_cam[:, 1], pred_cam[:, 2]
85
+ bs = bbox_size * s * self.default_scale_factor + 1e-8
86
+ focal_length = cam_int[:, 0, 0]
87
+ tz = 2 * focal_length / bs
88
+
89
+ if not use_intrin_center:
90
+ cx = 2 * (bbox_center[:, 0] - (img_size[:, 0] / 2)) / bs
91
+ cy = 2 * (bbox_center[:, 1] - (img_size[:, 1] / 2)) / bs
92
+ else:
93
+ cx = 2 * (bbox_center[:, 0] - (cam_int[:, 0, 2])) / bs
94
+ cy = 2 * (bbox_center[:, 1] - (cam_int[:, 1, 2])) / bs
95
+
96
+ pred_cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)
97
+
98
+ # Compute camera translation
99
+ j3d_cam = points_3d + pred_cam_t.unsqueeze(1)
100
+
101
+ # Projection to the image plane.
102
+ # Note that the projection output is in *original* image space now.
103
+ j2d = perspective_projection(j3d_cam, cam_int)
104
+
105
+ return {
106
+ "pred_keypoints_2d": j2d.reshape(batch_size, -1, 2),
107
+ "pred_cam_t": pred_cam_t,
108
+ "focal_length": focal_length,
109
+ "pred_keypoints_2d_depth": j3d_cam.reshape(batch_size, -1, 3)[:, :, 2],
110
+ }
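A small numeric sketch of the depth recovery used in perspective_projection above (tz = 2 * f / (s * b)); the focal length, box size, and scale below are made-up numbers for illustration.

import torch

focal_length = torch.tensor([1500.0])   # cam_int[:, 0, 0], in pixels
bbox_size = torch.tensor([300.0])       # crop size in the original image, in pixels
s = torch.tensor([0.9])                 # predicted scale component of (s, tx, ty)

bs = bbox_size * s                      # effective box scale (default_scale_factor = 1)
tz = 2 * focal_length / bs              # depth of the person along the camera axis
print(tz)                               # tensor([11.1111])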
src/sam3d_body/models/heads/mhr_head.py ADDED
@@ -0,0 +1,369 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import os
4
+ import warnings
5
+ from typing import Optional
6
+
7
+ import roma
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from ..modules import rot6d_to_rotmat
12
+ from ..modules.mhr_utils import (
13
+ compact_cont_to_model_params_body,
14
+ compact_cont_to_model_params_hand,
15
+ compact_model_params_to_cont_body,
16
+ mhr_param_hand_mask,
17
+ )
18
+
19
+ from ..modules.transformer import FFN
20
+
21
+ MOMENTUM_ENABLED = os.environ.get("MOMENTUM_ENABLED") is None
22
+ try:
23
+ if MOMENTUM_ENABLED:
24
+ from mhr.mhr import MHR
25
+
26
+ MOMENTUM_ENABLED = True
27
+ warnings.warn("Momentum is enabled")
28
+ else:
29
+ warnings.warn("Momentum is not enabled")
30
+ raise ImportError
31
+ except Exception:
32
+ MOMENTUM_ENABLED = False
33
+ warnings.warn("Momentum is not enabled")
34
+
35
+
36
+ class MHRHead(nn.Module):
37
+
38
+ def __init__(
39
+ self,
40
+ input_dim: int,
41
+ mlp_depth: int = 1,
42
+ mhr_model_path: str = "",
43
+ extra_joint_regressor: str = "",
44
+ ffn_zero_bias: bool = True,
45
+ mlp_channel_div_factor: int = 8,
46
+ enable_hand_model=False,
47
+ ):
48
+ super().__init__()
49
+
50
+ self.num_shape_comps = 45
51
+ self.num_scale_comps = 28
52
+ self.num_hand_comps = 54
53
+ self.num_face_comps = 72
54
+ self.enable_hand_model = enable_hand_model
55
+
56
+ self.body_cont_dim = 260
57
+ self.npose = (
58
+ 6 # Global Rotation
59
+ + self.body_cont_dim # then body
60
+ + self.num_shape_comps
61
+ + self.num_scale_comps
62
+ + self.num_hand_comps * 2
63
+ + self.num_face_comps
64
+ )
65
+
66
+ self.proj = FFN(
67
+ embed_dims=input_dim,
68
+ feedforward_channels=input_dim // mlp_channel_div_factor,
69
+ output_dims=self.npose,
70
+ num_fcs=mlp_depth,
71
+ ffn_drop=0.0,
72
+ add_identity=False,
73
+ )
74
+
75
+ if ffn_zero_bias:
76
+ torch.nn.init.zeros_(self.proj.layers[-2].bias)
77
+
78
+ # MHR Parameters
79
+ self.model_data_dir = mhr_model_path
80
+ self.num_hand_scale_comps = self.num_scale_comps - 18
81
+ self.num_hand_pose_comps = self.num_hand_comps
82
+
83
+ # Buffers to be filled in by model state dict
84
+ self.joint_rotation = nn.Parameter(torch.zeros(127, 3, 3), requires_grad=False)
85
+ self.scale_mean = nn.Parameter(torch.zeros(68), requires_grad=False)
86
+ self.scale_comps = nn.Parameter(torch.zeros(28, 68), requires_grad=False)
87
+ self.faces = nn.Parameter(torch.zeros(36874, 3).long(), requires_grad=False)
88
+ self.hand_pose_mean = nn.Parameter(torch.zeros(54), requires_grad=False)
89
+ self.hand_pose_comps = nn.Parameter(torch.eye(54), requires_grad=False)
90
+ self.hand_joint_idxs_left = nn.Parameter(
91
+ torch.zeros(27).long(), requires_grad=False
92
+ )
93
+ self.hand_joint_idxs_right = nn.Parameter(
94
+ torch.zeros(27).long(), requires_grad=False
95
+ )
96
+ self.keypoint_mapping = nn.Parameter(
97
+ torch.zeros(308, 18439 + 127), requires_grad=False
98
+ )
99
+ # Some special buffers for the hand-version
100
+ self.right_wrist_coords = nn.Parameter(torch.zeros(3), requires_grad=False)
101
+ self.root_coords = nn.Parameter(torch.zeros(3), requires_grad=False)
102
+ self.local_to_world_wrist = nn.Parameter(torch.zeros(3, 3), requires_grad=False)
103
+ self.nonhand_param_idxs = nn.Parameter(
104
+ torch.zeros(145).long(), requires_grad=False
105
+ )
106
+
107
+ # Load MHR itself
108
+ if MOMENTUM_ENABLED:
109
+ self.mhr = MHR.from_files(
110
+ device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
111
+ lod=1,
112
+ )
113
+ else:
114
+ self.mhr = torch.jit.load(
115
+ mhr_model_path,
116
+ map_location=("cuda" if torch.cuda.is_available() else "cpu"),
117
+ )
118
+
119
+ for param in self.mhr.parameters():
120
+ param.requires_grad = False
121
+
122
+ def get_zero_pose_init(self, factor=1.0):
123
+ # Initialize pose token with zero-initialized learnable params
124
+ # Note: bias/initial value should be zero-pose in cont, not all-zeros
125
+ weights = torch.zeros(1, self.npose)
126
+ weights[:, : 6 + self.body_cont_dim] = torch.cat(
127
+ [
128
+ torch.FloatTensor([1, 0, 0, 0, 1, 0]),
129
+ compact_model_params_to_cont_body(torch.zeros(1, 133)).squeeze()
130
+ * factor,
131
+ ],
132
+ dim=0,
133
+ )
134
+ return weights
135
+
136
+ def replace_hands_in_pose(self, full_pose_params, hand_pose_params):
137
+ assert full_pose_params.shape[1] == 136
138
+
139
+ # This drops in the hand poses from hand_pose_params (PCA 6D) into full_pose_params.
140
+ # Split into left and right hands
141
+ left_hand_params, right_hand_params = torch.split(
142
+ hand_pose_params,
143
+ [self.num_hand_pose_comps, self.num_hand_pose_comps],
144
+ dim=1,
145
+ )
146
+
147
+ # Change from cont to model params
148
+ left_hand_params_model_params = compact_cont_to_model_params_hand(
149
+ self.hand_pose_mean
150
+ + torch.einsum("da,ab->db", left_hand_params, self.hand_pose_comps)
151
+ )
152
+ right_hand_params_model_params = compact_cont_to_model_params_hand(
153
+ self.hand_pose_mean
154
+ + torch.einsum("da,ab->db", right_hand_params, self.hand_pose_comps)
155
+ )
156
+
157
+ # Drop it in
158
+ full_pose_params[:, self.hand_joint_idxs_left] = left_hand_params_model_params
159
+ full_pose_params[:, self.hand_joint_idxs_right] = right_hand_params_model_params
160
+
161
+ return full_pose_params # B x 207
162
+
163
+ def mhr_forward(
164
+ self,
165
+ global_trans,
166
+ global_rot,
167
+ body_pose_params,
168
+ hand_pose_params,
169
+ scale_params,
170
+ shape_params,
171
+ expr_params=None,
172
+ return_keypoints=False,
173
+ do_pcblend=True,
174
+ return_joint_coords=False,
175
+ return_model_params=False,
176
+ return_joint_rotations=False,
177
+ scale_offsets=None,
178
+ vertex_offsets=None,
179
+ ):
180
+
181
+ if self.enable_hand_model:
182
+ # Transfer wrist-centric predictions to the body.
183
+ global_rot_ori = global_rot.clone()
184
+ global_trans_ori = global_trans.clone()
185
+ global_rot = roma.rotmat_to_euler(
186
+ "xyz",
187
+ roma.euler_to_rotmat("xyz", global_rot_ori) @ self.local_to_world_wrist,
188
+ )
189
+ global_trans = (
190
+ -(
191
+ roma.euler_to_rotmat("xyz", global_rot)
192
+ @ (self.right_wrist_coords - self.root_coords)
193
+ + self.root_coords
194
+ )
195
+ + global_trans_ori
196
+ )
197
+
198
+ body_pose_params = body_pose_params[..., :130]
199
+
200
+ # Convert from scale and shape params to actual scales and vertices
201
+ ## Add singleton batches in case...
202
+ if len(scale_params.shape) == 1:
203
+ scale_params = scale_params[None]
204
+ if len(shape_params.shape) == 1:
205
+ shape_params = shape_params[None]
206
+ ## Convert scale...
207
+ scales = self.scale_mean[None, :] + scale_params @ self.scale_comps
208
+ if scale_offsets is not None:
209
+ scales = scales + scale_offsets
210
+
211
+ # Now, figure out the pose.
212
+ ## 10 here is because it's more stable to optimize global translation in meters.
213
+ full_pose_params = torch.cat(
214
+ [global_trans * 10, global_rot, body_pose_params], dim=1
215
+ ) # B x 127
216
+ ## Put in hands
217
+ if hand_pose_params is not None:
218
+ full_pose_params = self.replace_hands_in_pose(
219
+ full_pose_params, hand_pose_params
220
+ )
221
+ model_params = torch.cat([full_pose_params, scales], dim=1)
222
+
223
+ if self.enable_hand_model:
224
+ # Zero out non-hand parameters
225
+ model_params[:, self.nonhand_param_idxs] = 0
226
+
227
+ curr_skinned_verts, curr_skel_state = self.mhr(
228
+ shape_params, model_params, expr_params
229
+ )
230
+ curr_joint_coords, curr_joint_quats, _ = torch.split(
231
+ curr_skel_state, [3, 4, 1], dim=2
232
+ )
233
+ curr_skinned_verts = curr_skinned_verts / 100
234
+ curr_joint_coords = curr_joint_coords / 100
235
+ curr_joint_rots = roma.unitquat_to_rotmat(curr_joint_quats)
236
+
237
+ # Prepare returns
238
+ to_return = [curr_skinned_verts]
239
+ if return_keypoints:
240
+ # Get sapiens 308 keypoints
241
+ model_vert_joints = torch.cat(
242
+ [curr_skinned_verts, curr_joint_coords], dim=1
243
+ ) # B x (num_verts + 127) x 3
244
+ model_keypoints_pred = (
245
+ (
246
+ self.keypoint_mapping
247
+ @ model_vert_joints.permute(1, 0, 2).flatten(1, 2)
248
+ )
249
+ .reshape(-1, model_vert_joints.shape[0], 3)
250
+ .permute(1, 0, 2)
251
+ )
252
+
253
+ if self.enable_hand_model:
254
+ # Zero out everything except for the right hand
255
+ model_keypoints_pred[:, :21] = 0
256
+ model_keypoints_pred[:, 42:] = 0
257
+
258
+ to_return = to_return + [model_keypoints_pred]
259
+ if return_joint_coords:
260
+ to_return = to_return + [curr_joint_coords]
261
+ if return_model_params:
262
+ to_return = to_return + [model_params]
263
+ if return_joint_rotations:
264
+ to_return = to_return + [curr_joint_rots]
265
+
266
+ if isinstance(to_return, list) and len(to_return) == 1:
267
+ return to_return[0]
268
+ else:
269
+ return tuple(to_return)
270
+
271
+ def forward(
272
+ self,
273
+ x: torch.Tensor,
274
+ init_estimate: Optional[torch.Tensor] = None,
275
+ do_pcblend=True,
276
+ slim_keypoints=False,
277
+ ):
278
+ """
279
+ Args:
280
+ x: pose token with shape [B, C], usually C=DECODER.DIM
281
+ init_estimate: [B, self.npose]
282
+ """
283
+ batch_size = x.shape[0]
284
+ pred = self.proj(x)
285
+ if init_estimate is not None:
286
+ pred = pred + init_estimate
287
+
288
+ # From pred, we want to pull out individual predictions.
289
+
290
+ ## First, get globals
291
+ ### Global rotation is first 6.
292
+ count = 6
293
+ global_rot_6d = pred[:, :count]
294
+ global_rot_rotmat = rot6d_to_rotmat(global_rot_6d) # B x 3 x 3
295
+ global_rot_euler = roma.rotmat_to_euler("ZYX", global_rot_rotmat) # B x 3
296
+ global_trans = torch.zeros_like(global_rot_euler)
297
+
298
+ ## Next, get body pose.
299
+ ### Hold onto raw, continuous version for iterative correction.
300
+ pred_pose_cont = pred[:, count : count + self.body_cont_dim]
301
+ count += self.body_cont_dim
302
+ ### Convert to eulers (and trans)
303
+ pred_pose_euler = compact_cont_to_model_params_body(pred_pose_cont)
304
+ ### Zero-out hands
305
+ pred_pose_euler[:, mhr_param_hand_mask] = 0
306
+ ### Zero-out jaw
307
+ pred_pose_euler[:, -3:] = 0
308
+
309
+ ## Get remaining parameters
310
+ pred_shape = pred[:, count : count + self.num_shape_comps]
311
+ count += self.num_shape_comps
312
+ pred_scale = pred[:, count : count + self.num_scale_comps]
313
+ count += self.num_scale_comps
314
+ pred_hand = pred[:, count : count + self.num_hand_comps * 2]
315
+ count += self.num_hand_comps * 2
316
+ pred_face = pred[:, count : count + self.num_face_comps] * 0
317
+ count += self.num_face_comps
318
+
319
+ # Run everything through mhr
320
+ output = self.mhr_forward(
321
+ global_trans=global_trans,
322
+ global_rot=global_rot_euler,
323
+ body_pose_params=pred_pose_euler,
324
+ hand_pose_params=pred_hand,
325
+ scale_params=pred_scale,
326
+ shape_params=pred_shape,
327
+ expr_params=pred_face,
328
+ do_pcblend=do_pcblend,
329
+ return_keypoints=True,
330
+ return_joint_coords=True,
331
+ return_model_params=True,
332
+ return_joint_rotations=True,
333
+ )
334
+
335
+ # Some existing code to get joints and fix camera system
336
+ verts, j3d, jcoords, mhr_model_params, joint_global_rots = output
337
+ j3d = j3d[:, :70] # 308 --> 70 keypoints
338
+
339
+ if verts is not None:
340
+ verts[..., [1, 2]] *= -1 # Camera system difference
341
+ j3d[..., [1, 2]] *= -1 # Camera system difference
342
+ if jcoords is not None:
343
+ jcoords[..., [1, 2]] *= -1
344
+
345
+ # Prep outputs
346
+ output = {
347
+ "pred_pose_raw": torch.cat(
348
+ [global_rot_6d, pred_pose_cont], dim=1
349
+ ), # Both global rot and continuous pose
350
+ "pred_pose_rotmat": None, # This normally used for mhr pose param rotmat supervision.
351
+ "global_rot": global_rot_euler,
352
+ "body_pose": pred_pose_euler, # Unused during training
353
+ "shape": pred_shape,
354
+ "scale": pred_scale,
355
+ "hand": pred_hand,
356
+ "face": pred_face,
357
+ "pred_keypoints_3d": j3d.reshape(batch_size, -1, 3),
358
+ "pred_vertices": (
359
+ verts.reshape(batch_size, -1, 3) if verts is not None else None
360
+ ),
361
+ "pred_joint_coords": (
362
+ jcoords.reshape(batch_size, -1, 3) if jcoords is not None else None
363
+ ),
364
+ "faces": self.faces.cpu().numpy(),
365
+ "joint_global_rots": joint_global_rots,
366
+ "mhr_model_params": mhr_model_params,
367
+ }
368
+
369
+ return output
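The flat MHRHead prediction is sliced in a fixed order; a small bookkeeping sketch using the component sizes from the constructor above (6 + 260 + 45 + 28 + 2*54 + 72 = 519).

sizes = [
    ("global_rot_6d", 6),
    ("body_pose_cont", 260),
    ("shape", 45),
    ("scale", 28),
    ("hands", 54 * 2),
    ("face", 72),
]
offset = 0
for name, size in sizes:
    print(f"{name}: pred[:, {offset}:{offset + size}]")
    offset += size
print("npose =", offset)  # 519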
src/sam3d_body/models/meta_arch/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from .sam3d_body import SAM3DBody
src/sam3d_body/models/meta_arch/base_lightning_module.py ADDED
@@ -0,0 +1,48 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import numpy as np
4
+ import pytorch_lightning as pl
5
+ from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
6
+
7
+
8
+ class BaseLightningModule(pl.LightningModule):
9
+ def _log_metric(self, name, value, step=None):
10
+ for logger in self.trainer.loggers:
11
+ if isinstance(logger, WandbLogger):
12
+ if step is not None:
13
+ logger.experiment.log({name: value, "step": step})
14
+ else:
15
+ logger.experiment.log({name: value})
16
+ elif isinstance(logger, TensorBoardLogger):
17
+ logger.experiment.add_scalar(name, value, step)
18
+ else:
19
+ raise ValueError(f"Unsupported logger: {logger}")
20
+
21
+ def _log_image(self, name, img_tensor, dataformats="CHW", step_count=None):
22
+ """Log image tensor to both W&B and TensorBoard."""
23
+ step = step_count if step_count is not None else self.global_step
24
+ for logger in self.trainer.loggers:
25
+ if isinstance(logger, WandbLogger):
26
+ import wandb
27
+
28
+ img = img_tensor
29
+ if dataformats.upper() == "CHW":
30
+ # If in PyTorch format (C,H,W), convert to (H,W,C) for wandb
31
+ img = img_tensor.permute(1, 2, 0).cpu().numpy()
32
+ logger.experiment.log({name: wandb.Image(img), "step": step})
33
+ elif isinstance(logger, TensorBoardLogger):
34
+ logger.experiment.add_image(
35
+ name, img_tensor, step, dataformats=dataformats
36
+ )
37
+ else:
38
+ raise ValueError(f"Unsupported logger: {logger}")
39
+
40
+ def _log_hist(self, name, array, step_count=None):
41
+ for logger in self.trainer.loggers:
42
+ if isinstance(logger, WandbLogger):
43
+ import wandb
44
+
45
+ value = wandb.Histogram(
46
+ np_histogram=(array, np.arange(array.shape[0] + 1)),
47
+ )
48
+ logger.experiment.log({name: value, "step": step_count})
src/sam3d_body/models/meta_arch/base_model.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ """Define an abstract base model with a consistent input / processing / output format."""
4
+
5
+ from abc import abstractmethod
6
+ from functools import partial
7
+
8
+ import torch
9
+ from yacs.config import CfgNode
10
+
11
+ from ..optim.fp16_utils import convert_module_to_f16, convert_to_fp16_safe
12
+ from .base_lightning_module import BaseLightningModule
13
+
14
+
15
+ class BaseModel(BaseLightningModule):
16
+ def __init__(self, cfg: CfgNode | None, **kwargs):
17
+ super().__init__()
18
+
19
+ # Save hyperparameters
20
+ self.save_hyperparameters(logger=False)
21
+ self.cfg = cfg
22
+
23
+ self._initialze_model(**kwargs)
24
+
25
+ # Initialize attributes for image-based batch format
26
+ self._max_num_person = None
27
+ self._person_valid = None
28
+
29
+ @abstractmethod
30
+ def _initialze_model(self, **kwargs) -> None:
31
+ pass
32
+
33
+ def data_preprocess(
34
+ self,
35
+ inputs: torch.Tensor,
36
+ crop_width: bool = False,
37
+ is_full: bool = False, # whether for full_branch
38
+ crop_hand: int = 0,
39
+ ) -> torch.Tensor:
40
+ image_mean = self.image_mean if not is_full else self.full_image_mean
41
+ image_std = self.image_std if not is_full else self.full_image_std
42
+
43
+ if inputs.max() > 1 and image_mean.max() <= 1.0:
44
+ inputs = inputs / 255.0
45
+ elif inputs.max() <= 1.0 and image_mean.max() > 1:
46
+ inputs = inputs * 255.0
47
+ batch_inputs = (inputs - image_mean) / image_std
48
+
49
+ if crop_width:
50
+ if crop_hand > 0:
51
+ batch_inputs = batch_inputs[:, :, :, crop_hand:-crop_hand]
52
+ elif self.cfg.MODEL.BACKBONE.TYPE in [
53
+ "vit_hmr",
54
+ "vit",
55
+ ]:
56
+ # ViT backbone assumes a different aspect ratio as input size
57
+ batch_inputs = batch_inputs[:, :, :, 32:-32]
58
+ elif self.cfg.MODEL.BACKBONE.TYPE in [
59
+ "vit_hmr_512_384",
60
+ ]:
61
+ batch_inputs = batch_inputs[:, :, :, 64:-64]
62
+ else:
63
+ raise Exception
64
+
65
+ return batch_inputs
66
+
67
+ def _initialize_batch(self, batch: dict) -> None:
68
+ # Check whether the input batch is with format
69
+ # [batch_size, num_person, ...]
70
+ if batch["img"].dim() == 5:
71
+ self._batch_size, self._max_num_person = batch["img"].shape[:2]
72
+ self._person_valid = self._flatten_person(batch["person_valid"]) > 0
73
+ else:
74
+ self._batch_size = batch["img"].shape[0]
75
+ self._max_num_person = 0
76
+ self._person_valid = None
77
+
78
+ def _flatten_person(self, x: torch.Tensor) -> torch.Tensor:
79
+ assert self._max_num_person is not None, "No max_num_person initialized"
80
+
81
+ if self._max_num_person:
82
+ # Merge person crops to batch dimension
83
+ shape = x.shape
84
+ x = x.view(self._batch_size * self._max_num_person, *shape[2:])
85
+ return x
86
+
87
+ def _unflatten_person(self, x: torch.Tensor) -> torch.Tensor:
88
+ shape = x.shape
89
+ if self._max_num_person:
90
+ x = x.view(self._batch_size, self._max_num_person, *shape[1:])
91
+ return x
92
+
93
+ def _get_valid(self, x: torch.Tensor) -> torch.Tensor:
94
+ assert self._max_num_person is not None, "No max_num_person initialized"
95
+
96
+ if self._person_valid is not None:
97
+ x = x[self._person_valid]
98
+ return x
99
+
100
+ def _full_to_crop(self, batch: dict, pred_keypoints_2d: torch.Tensor) -> torch.Tensor:
101
+ """Convert full-image keypoint coordinates to crop space and normalize to [-0.5, 0.5]."""
102
+ pred_keypoints_2d_cropped = torch.cat(
103
+ [pred_keypoints_2d, torch.ones_like(pred_keypoints_2d[:, :, [-1]])], dim=-1
104
+ )
105
+ affine_trans = self._flatten_person(batch["affine_trans"]).to(pred_keypoints_2d_cropped)
106
+ img_size = self._flatten_person(batch["img_size"]).unsqueeze(1)
107
+ pred_keypoints_2d_cropped = pred_keypoints_2d_cropped @ affine_trans.mT
108
+ pred_keypoints_2d_cropped = pred_keypoints_2d_cropped[..., :2] / img_size - 0.5
109
+
110
+ return pred_keypoints_2d_cropped
111
+
112
+ def _cam_full_to_crop(
113
+ self, batch: dict, pred_cam_t: torch.Tensor, focal_length: torch.Tensor = None
114
+ ) -> torch.Tensor:
115
+ """Revert the camera translation from full to crop image space"""
116
+ num_person = batch["img"].shape[1]
117
+ cam_int = self._flatten_person(batch["cam_int"].unsqueeze(1).expand(-1, num_person, -1, -1).contiguous())
118
+ bbox_center = self._flatten_person(batch["bbox_center"])
119
+ bbox_size = self._flatten_person(batch["bbox_scale"])[:, 0]
120
+ input_size = self._flatten_person(batch["img_size"])[:, 0]
121
+
122
+ tx, ty, tz = pred_cam_t[:, 0], pred_cam_t[:, 1], pred_cam_t[:, 2]
123
+ if focal_length is None:
124
+ focal_length = cam_int[:, 0, 0]
125
+ bs = 2 * focal_length / (tz + 1e-8)
126
+
127
+ cx = 2 * (bbox_center[:, 0] - (cam_int[:, 0, 2])) / bs
128
+ cy = 2 * (bbox_center[:, 1] - (cam_int[:, 1, 2])) / bs
129
+
130
+ crop_cam_t = torch.stack([tx - cx, ty - cy, tz * bbox_size / input_size], dim=-1)
131
+ return crop_cam_t
132
+
133
+ def convert_to_fp16(self) -> torch.dtype:
134
+ """
135
+ Convert the torso of the model to float16.
136
+ """
137
+ fp16_type = torch.float16 if self.cfg.TRAIN.get("FP16_TYPE", "float16") == "float16" else torch.bfloat16
138
+
139
+ if hasattr(self, "backbone"):
140
+ self._set_fp16(self.backbone, fp16_type)
141
+ if hasattr(self, "full_encoder"):
142
+ self._set_fp16(self.full_encoder, fp16_type)
143
+
144
+ if hasattr(self.backbone, "lhand_pos_embed"):
145
+ self.backbone.lhand_pos_embed.data = self.backbone.lhand_pos_embed.data.to(fp16_type)
146
+
147
+ if hasattr(self.backbone, "rhand_pos_embed"):
148
+ self.backbone.rhand_pos_embed.data = self.backbone.rhand_pos_embed.data.to(fp16_type)
149
+
150
+ return fp16_type
151
+
152
+ def _set_fp16(self, module, fp16_type):
153
+ if hasattr(module, "pos_embed"):
154
+ module.apply(partial(convert_module_to_f16, dtype=fp16_type))
155
+ module.pos_embed.data = module.pos_embed.data.to(fp16_type)
156
+ elif hasattr(module.encoder, "rope_embed"):
157
+ # DINOv3
158
+ module.encoder.apply(partial(convert_to_fp16_safe, dtype=fp16_type))
159
+ module.encoder.rope_embed = module.encoder.rope_embed.to(fp16_type)
160
+ else:
161
+ # DINOv2
162
+ module.encoder.pos_embed.data = module.encoder.pos_embed.data.to(fp16_type)
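A minimal sketch of the full-image to crop mapping performed by _full_to_crop above, with a toy 2x3 affine; the crop size, translation, and keypoint are invented for illustration.

import torch

kp_full = torch.tensor([[[400.0, 300.0]]])      # [B, N, 2] keypoints in the original image
affine = torch.tensor([[[1.0, 0.0, -350.0],
                        [0.0, 1.0, -250.0]]])   # [B, 2, 3] full-image -> crop transform
crop_size = torch.tensor([[[256.0, 256.0]]])    # [B, 1, 2] crop resolution

kp_h = torch.cat([kp_full, torch.ones_like(kp_full[..., :1])], dim=-1)  # homogeneous coordinates
kp_crop = kp_h @ affine.mT                       # [B, N, 2] pixels in the crop -> (50, 50)
kp_norm = kp_crop / crop_size - 0.5              # normalized to roughly [-0.5, 0.5]
print(kp_norm)                                   # tensor([[[-0.3047, -0.3047]]])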
src/sam3d_body/models/meta_arch/sam3d_body.py ADDED
@@ -0,0 +1,1728 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from collections.abc import Sequence
4
+ from dataclasses import dataclass
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ import roma
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from sam3d_body.data.utils.prepare_batch import prepare_batch
14
+ from sam3d_body.models.decoders.prompt_encoder import PositionEmbeddingRandom
15
+ from sam3d_body.models.modules.mhr_utils import (
16
+ fix_wrist_euler,
17
+ rotation_angle_difference,
18
+ )
19
+ from sam3d_body.utils import recursive_to
20
+ from sam3d_body.utils.logging import get_pylogger
21
+
22
+ from ..backbones import create_backbone
23
+ from ..decoders import PromptEncoder, build_decoder, build_keypoint_sampler
24
+ from ..heads import build_head
25
+ from ..modules.camera_embed import CameraEncoder
26
+ from ..modules.transformer import FFN, MLP
27
+ from .base_model import BaseModel
28
+
29
+ logger = get_pylogger(__name__)
30
+
31
+
32
+ # fmt: off
33
+ PROMPT_KEYPOINTS = { # keypoint_idx: prompt_idx
34
+ "mhr70": {
35
+ i: i for i in range(70)
36
+ }, # all 70 keypoints are supported for prompting
37
+ }
38
+ KEY_BODY = [5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 41, 62] # key body joints for prompting
39
+ KEY_RIGHT_HAND = list(range(21, 42))
40
+ # fmt: on
41
+
42
+
43
+ @dataclass
44
+ class BodyPredContainer:
45
+ """Structured container for main body + optional hand inference outputs."""
46
+
47
+ pose_output: dict[str, Any]
48
+ batch_lhand: dict[str, Any] | None = None
49
+ batch_rhand: dict[str, Any] | None = None
50
+ lhand_output: dict[str, Any] | None = None
51
+ rhand_output: dict[str, Any] | None = None
52
+
53
+
54
+ class SAM3DBody(BaseModel):
55
+ pelvis_idx = [9, 10] # left_hip, right_hip
56
+
57
+ def _initialze_model(self):
58
+ self.register_buffer("image_mean", torch.tensor(self.cfg.MODEL.IMAGE_MEAN).view(-1, 1, 1), False)
59
+ self.register_buffer("image_std", torch.tensor(self.cfg.MODEL.IMAGE_STD).view(-1, 1, 1), False)
60
+
61
+ # Create backbone feature extractor for human crops
62
+ self.backbone = create_backbone(self.cfg.MODEL.BACKBONE.TYPE, self.cfg)
63
+
64
+ # Create header for pose estimation output
65
+ self.head_pose = build_head(self.cfg, self.cfg.MODEL.PERSON_HEAD.POSE_TYPE)
66
+ self.head_pose.hand_pose_comps_ori = nn.Parameter(self.head_pose.hand_pose_comps.clone(), requires_grad=False)
67
+ self.head_pose.hand_pose_comps.data = torch.eye(54).to(self.head_pose.hand_pose_comps.data).float()
68
+
69
+ # Initialize pose token with learnable params
70
+ # Note: bias/initial value should be zero-pose in cont, not all-zeros
71
+ self.init_pose = nn.Embedding(1, self.head_pose.npose)
72
+
73
+ # Define header for hand pose estimation
74
+ self.head_pose_hand = build_head(self.cfg, self.cfg.MODEL.PERSON_HEAD.POSE_TYPE, enable_hand_model=True)
75
+ self.head_pose_hand.hand_pose_comps_ori = nn.Parameter(
76
+ self.head_pose_hand.hand_pose_comps.clone(), requires_grad=False
77
+ )
78
+ self.head_pose_hand.hand_pose_comps.data = torch.eye(54).to(self.head_pose_hand.hand_pose_comps.data).float()
79
+ self.init_pose_hand = nn.Embedding(1, self.head_pose_hand.npose)
80
+
81
+ self.head_camera = build_head(self.cfg, self.cfg.MODEL.PERSON_HEAD.CAMERA_TYPE)
82
+ self.init_camera = nn.Embedding(1, self.head_camera.ncam)
83
+ nn.init.zeros_(self.init_camera.weight)
84
+
85
+ self.head_camera_hand = build_head(
86
+ self.cfg,
87
+ self.cfg.MODEL.PERSON_HEAD.CAMERA_TYPE,
88
+ default_scale_factor=self.cfg.MODEL.CAMERA_HEAD.get("DEFAULT_SCALE_FACTOR_HAND", 1.0),
89
+ )
90
+ self.init_camera_hand = nn.Embedding(1, self.head_camera_hand.ncam)
91
+ nn.init.zeros_(self.init_camera_hand.weight)
92
+
93
+ self.camera_type = "perspective"
94
+
95
+ # Support conditioned information for decoder
96
+ cond_dim = 3
97
+ init_dim = self.head_pose.npose + self.head_camera.ncam + cond_dim
98
+ self.init_to_token_mhr = nn.Linear(init_dim, self.cfg.MODEL.DECODER.DIM)
99
+ self.prev_to_token_mhr = nn.Linear(init_dim - cond_dim, self.cfg.MODEL.DECODER.DIM)
100
+ self.init_to_token_mhr_hand = nn.Linear(init_dim, self.cfg.MODEL.DECODER.DIM)
101
+ self.prev_to_token_mhr_hand = nn.Linear(init_dim - cond_dim, self.cfg.MODEL.DECODER.DIM)
102
+
103
+ # Create prompt encoder
104
+ self.max_num_clicks = 0
105
+ if self.cfg.MODEL.PROMPT_ENCODER.ENABLE:
106
+ self.max_num_clicks = self.cfg.MODEL.PROMPT_ENCODER.MAX_NUM_CLICKS
107
+ self.prompt_keypoints = PROMPT_KEYPOINTS[self.cfg.MODEL.PROMPT_ENCODER.PROMPT_KEYPOINTS]
108
+
109
+ self.prompt_encoder = PromptEncoder(
110
+ embed_dim=self.backbone.embed_dims, # need to match backbone dims for PE
111
+ num_body_joints=len(set(self.prompt_keypoints.values())),
112
+ frozen=self.cfg.MODEL.PROMPT_ENCODER.get("frozen", False),
113
+ mask_embed_type=self.cfg.MODEL.PROMPT_ENCODER.get("MASK_EMBED_TYPE", None),
114
+ )
115
+ self.prompt_to_token = nn.Linear(self.backbone.embed_dims, self.cfg.MODEL.DECODER.DIM)
116
+
117
+ self.keypoint_prompt_sampler = build_keypoint_sampler(
118
+ self.cfg.MODEL.PROMPT_ENCODER.get("KEYPOINT_SAMPLER", {}),
119
+ prompt_keypoints=self.prompt_keypoints,
120
+ keybody_idx=(
121
+ KEY_BODY if not self.cfg.MODEL.PROMPT_ENCODER.get("SAMPLE_HAND", False) else KEY_RIGHT_HAND
122
+ ),
123
+ )
124
+ # To keep track of prompting history
125
+ self.prompt_hist = np.zeros(
126
+ (len(set(self.prompt_keypoints.values())) + 2, self.max_num_clicks),
127
+ dtype=np.float32,
128
+ )
129
+
130
+ if self.cfg.MODEL.DECODER.FROZEN:
131
+ for param in self.prompt_to_token.parameters():
132
+ param.requires_grad = False
133
+
134
+ # Create promptable decoder
135
+ self.decoder = build_decoder(self.cfg.MODEL.DECODER, context_dim=self.backbone.embed_dims)
136
+ # shared config for the two decoders
137
+ self.decoder_hand = build_decoder(self.cfg.MODEL.DECODER, context_dim=self.backbone.embed_dims)
138
+ self.hand_pe_layer = PositionEmbeddingRandom(self.backbone.embed_dims // 2)
139
+
140
+ # Manually convert the torso of the model to fp16.
141
+ if self.cfg.TRAIN.USE_FP16:
142
+ self.convert_to_fp16()
143
+ if self.cfg.TRAIN.get("FP16_TYPE", "float16") == "float16":
144
+ self.backbone_dtype = torch.float16
145
+ else:
146
+ self.backbone_dtype = torch.bfloat16
147
+ else:
148
+ self.backbone_dtype = torch.float32
149
+
150
+ self.ray_cond_emb = CameraEncoder(
151
+ self.backbone.embed_dim,
152
+ self.backbone.patch_size,
153
+ )
154
+ self.ray_cond_emb_hand = CameraEncoder(
155
+ self.backbone.embed_dim,
156
+ self.backbone.patch_size,
157
+ )
158
+
159
+ self.keypoint_embedding_idxs = list(range(70))
160
+ self.keypoint_embedding = nn.Embedding(len(self.keypoint_embedding_idxs), self.cfg.MODEL.DECODER.DIM)
161
+ self.keypoint_embedding_idxs_hand = list(range(70))
162
+ self.keypoint_embedding_hand = nn.Embedding(len(self.keypoint_embedding_idxs_hand), self.cfg.MODEL.DECODER.DIM)
163
+
164
+ if self.cfg.MODEL.DECODER.get("DO_HAND_DETECT_TOKENS", False):
165
+ self.hand_box_embedding = nn.Embedding(2, self.cfg.MODEL.DECODER.DIM) # for two hands
166
+ # decide whether a left or right hand is present in the image
167
+ self.hand_cls_embed = nn.Linear(self.cfg.MODEL.DECODER.DIM, 2)
168
+ self.bbox_embed = MLP(self.cfg.MODEL.DECODER.DIM, self.cfg.MODEL.DECODER.DIM, 4, 3)
169
+
170
+ self.keypoint_posemb_linear = FFN(
171
+ embed_dims=2,
172
+ feedforward_channels=self.cfg.MODEL.DECODER.DIM,
173
+ output_dims=self.cfg.MODEL.DECODER.DIM,
174
+ num_fcs=2,
175
+ add_identity=False,
176
+ )
177
+ self.keypoint_posemb_linear_hand = FFN(
178
+ embed_dims=2,
179
+ feedforward_channels=self.cfg.MODEL.DECODER.DIM,
180
+ output_dims=self.cfg.MODEL.DECODER.DIM,
181
+ num_fcs=2,
182
+ add_identity=False,
183
+ )
184
+ self.keypoint_feat_linear = nn.Linear(self.backbone.embed_dims, self.cfg.MODEL.DECODER.DIM)
185
+ self.keypoint_feat_linear_hand = nn.Linear(self.backbone.embed_dims, self.cfg.MODEL.DECODER.DIM)
186
+
187
+ # Do all KPS
188
+ self.keypoint3d_embedding_idxs = list(range(70))
189
+ self.keypoint3d_embedding = nn.Embedding(len(self.keypoint3d_embedding_idxs), self.cfg.MODEL.DECODER.DIM)
190
+
191
+ # Assume always do full body for the hand decoder
192
+ self.keypoint3d_embedding_idxs_hand = list(range(70))
193
+ self.keypoint3d_embedding_hand = nn.Embedding(
194
+ len(self.keypoint3d_embedding_idxs_hand), self.cfg.MODEL.DECODER.DIM
195
+ )
196
+
197
+ self.keypoint3d_posemb_linear = FFN(
198
+ embed_dims=3,
199
+ feedforward_channels=self.cfg.MODEL.DECODER.DIM,
200
+ output_dims=self.cfg.MODEL.DECODER.DIM,
201
+ num_fcs=2,
202
+ add_identity=False,
203
+ )
204
+ self.keypoint3d_posemb_linear_hand = FFN(
205
+ embed_dims=3,
206
+ feedforward_channels=self.cfg.MODEL.DECODER.DIM,
207
+ output_dims=self.cfg.MODEL.DECODER.DIM,
208
+ num_fcs=2,
209
+ add_identity=False,
210
+ )
211
+
212
+ def _get_decoder_condition(self, batch: dict) -> torch.Tensor | None:
213
+ num_person = batch["img"].shape[1]
214
+
215
+ if self.cfg.MODEL.DECODER.CONDITION_TYPE == "cliff":
216
+ # CLIFF-style condition info (cx/f, cy/f, b/f)
217
+ cx, cy = torch.chunk(self._flatten_person(batch["bbox_center"]), chunks=2, dim=-1)
218
+ img_w, img_h = torch.chunk(self._flatten_person(batch["ori_img_size"]), chunks=2, dim=-1)
219
+ b = self._flatten_person(batch["bbox_scale"])[:, [0]]
220
+
221
+ focal_length = self._flatten_person(
222
+ batch["cam_int"].unsqueeze(1).expand(-1, num_person, -1, -1).contiguous()
223
+ )[:, 0, 0]
224
+ if not self.cfg.MODEL.DECODER.get("USE_INTRIN_CENTER", False):
225
+ condition_info = torch.cat([cx - img_w / 2.0, cy - img_h / 2.0, b], dim=-1)
226
+ else:
227
+ full_img_cxy = self._flatten_person(
228
+ batch["cam_int"].unsqueeze(1).expand(-1, num_person, -1, -1).contiguous()
229
+ )[:, [0, 1], [2, 2]]
230
+ condition_info = torch.cat([cx - full_img_cxy[:, [0]], cy - full_img_cxy[:, [1]], b], dim=-1)
231
+ condition_info[:, :2] = condition_info[:, :2] / focal_length.unsqueeze(-1) # [-1, 1]
232
+ condition_info[:, 2] = condition_info[:, 2] / focal_length # [-1, 1]
233
+ elif self.cfg.MODEL.DECODER.CONDITION_TYPE == "none":
234
+ return None
235
+ else:
236
+ raise NotImplementedError
237
+
238
+ return condition_info.type(batch["img"].dtype)
239
+
240
+ def forward_decoder(
241
+ self,
242
+ image_embeddings: torch.Tensor,
243
+ init_estimate: torch.Tensor | None = None,
244
+ keypoints: torch.Tensor | None = None,
245
+ prev_estimate: torch.Tensor | None = None,
246
+ condition_info: torch.Tensor | None = None,
247
+ batch=None,
248
+ ):
249
+ """
250
+ Args:
251
+ image_embeddings: image features from the backbone, shape (B, C, H, W)
252
+ init_estimate: initial estimate to be refined on, shape (B, 1, C)
253
+ keypoints: optional prompt input, shape (B, N, 3),
254
+ 3 for coordinates (x,y) + label.
255
+ (x, y) should be normalized to range [0, 1].
256
+ label==-1 indicates incorrect points,
257
+ label==-2 indicates invalid points
258
+ prev_estimate: optional prompt input, shape (B, 1, C),
259
+ previous estimate for pose refinement.
260
+ condition_info: optional condition information that is concatenated with
261
+ the input tokens, shape (B, c)
262
+ """
263
+ batch_size = image_embeddings.shape[0]
264
+
265
+ # Initial estimation for residual prediction.
266
+ if init_estimate is None:
267
+ init_pose = self.init_pose.weight.expand(batch_size, -1).unsqueeze(dim=1)
268
+ if hasattr(self, "init_camera"):
269
+ init_camera = self.init_camera.weight.expand(batch_size, -1).unsqueeze(dim=1)
270
+
271
+ init_estimate = (
272
+ init_pose if not hasattr(self, "init_camera") else torch.cat([init_pose, init_camera], dim=-1)
273
+ ) # This is basically pose & camera translation at the end. B x 1 x (404 + 3)
274
+
275
+ init_input = (
276
+ torch.cat([condition_info.view(batch_size, 1, -1), init_estimate], dim=-1)
277
+ if condition_info is not None
278
+ else init_estimate
279
+ ) # B x 1 x 410 (this is with the CLIFF condition)
280
+ token_embeddings = self.init_to_token_mhr(init_input).view(batch_size, 1, -1) # B x 1 x 1024 (linear layered)
281
+
282
+ num_pose_token = token_embeddings.shape[1]
283
+ assert num_pose_token == 1
284
+
285
+ image_augment, token_augment, token_mask = None, None, None
286
+ if hasattr(self, "prompt_encoder") and keypoints is not None:
287
+ if prev_estimate is None:
288
+ # Use initial embedding if no previous embedding
289
+ prev_estimate = init_estimate
290
+ # Previous estimate w/o the CLIFF condition.
291
+ prev_embeddings = self.prev_to_token_mhr(prev_estimate).view(
292
+ batch_size, 1, -1
293
+ ) # 407 -> B x 1 x 1024; linear layer-ed
294
+
295
+ if self.cfg.MODEL.BACKBONE.TYPE in [
296
+ "vit_hmr",
297
+ "vit",
298
+ "vit_b",
299
+ "vit_l",
300
+ ]:
301
+ # ViT backbone assumes a different aspect ratio as input size
302
+ image_augment = self.prompt_encoder.get_dense_pe((16, 16))[:, :, :, 2:-2]
303
+ elif self.cfg.MODEL.BACKBONE.TYPE in [
304
+ "vit_hmr_512_384",
305
+ ]:
306
+ # ViT backbone assumes a different aspect ratio as input size
307
+ image_augment = self.prompt_encoder.get_dense_pe((32, 32))[:, :, :, 4:-4]
308
+ else:
309
+ image_augment = self.prompt_encoder.get_dense_pe(image_embeddings.shape[-2:]) # (1, C, H, W)
310
+
311
+ image_embeddings = self.ray_cond_emb(image_embeddings, batch["ray_cond"])
312
+
313
+ # To start, keypoints is all [0, 0, -2]. The points get sent into self.pe_layer._pe_encoding,
314
+ # the labels determine the embedding weight (a special one for -2 and -1, then one per joint).
315
+ prompt_embeddings, prompt_mask = self.prompt_encoder(keypoints=keypoints) # B x 1 x 1280
316
+ prompt_embeddings = self.prompt_to_token(prompt_embeddings) # Linear layered: B x 1 x 1024
317
+
318
+ # Concatenate pose tokens and prompt embeddings as decoder input
319
+ token_embeddings = torch.cat(
320
+ [
321
+ token_embeddings,
322
+ prev_embeddings,
323
+ prompt_embeddings,
324
+ ],
325
+ dim=1,
326
+ )
327
+
328
+ token_augment = torch.zeros_like(token_embeddings)
329
+ token_augment[:, [num_pose_token]] = prev_embeddings
330
+ token_augment[:, (num_pose_token + 1) :] = prompt_embeddings
331
+ token_mask = None
332
+
333
+ if self.cfg.MODEL.DECODER.get("DO_HAND_DETECT_TOKENS", False):
334
+ # Put in a token for each hand
335
+ hand_det_emb_start_idx = token_embeddings.shape[1]
336
+ token_embeddings = torch.cat(
337
+ [
338
+ token_embeddings,
339
+ self.hand_box_embedding.weight[None, :, :].repeat(batch_size, 1, 1),
340
+ ],
341
+ dim=1,
342
+ ) # B x 5 + 70 x 1024
343
+ # No positional embeddings
344
+ token_augment = torch.cat(
345
+ [
346
+ token_augment,
347
+ torch.zeros_like(token_embeddings[:, token_augment.shape[1] :, :]),
348
+ ],
349
+ dim=1,
350
+ ) # B x 5 + 70 x 1024
351
+
352
+ assert self.cfg.MODEL.DECODER.get("DO_KEYPOINT_TOKENS", False)
353
+ # Put in a token for each keypoint
354
+ kps_emb_start_idx = token_embeddings.shape[1]
355
+ token_embeddings = torch.cat(
356
+ [
357
+ token_embeddings,
358
+ self.keypoint_embedding.weight[None, :, :].repeat(batch_size, 1, 1),
359
+ ],
360
+ dim=1,
361
+ ) # B x 3 + 70 x 1024
362
+ # No positional embeddings
363
+ token_augment = torch.cat(
364
+ [
365
+ token_augment,
366
+ torch.zeros_like(token_embeddings[:, token_augment.shape[1] :, :]),
367
+ ],
368
+ dim=1,
369
+ ) # B x 3 + 70 x 1024
370
+ if self.cfg.MODEL.DECODER.get("DO_KEYPOINT3D_TOKENS", False):
371
+ # Put in a token for each keypoint
372
+ kps3d_emb_start_idx = token_embeddings.shape[1]
373
+ token_embeddings = torch.cat(
374
+ [
375
+ token_embeddings,
376
+ self.keypoint3d_embedding.weight[None, :, :].repeat(batch_size, 1, 1),
377
+ ],
378
+ dim=1,
379
+ ) # B x 3 + 70 + 70 x 1024
380
+ # No positional embeddings
381
+ token_augment = torch.cat(
382
+ [
383
+ token_augment,
384
+ torch.zeros_like(token_embeddings[:, token_augment.shape[1] :, :]),
385
+ ],
386
+ dim=1,
387
+ ) # B x 3 + 70 + 70 x 1024
388
+
389
+ # We're doing intermediate model predictions
390
+ def token_to_pose_output_fn(tokens, prev_pose_output, layer_idx):
391
+ # Get the pose token
392
+ pose_token = tokens[:, 0]
393
+
394
+ prev_pose = init_pose.view(batch_size, -1)
395
+ prev_camera = init_camera.view(batch_size, -1)
396
+
397
+ # Get pose outputs
398
+ pose_output = self.head_pose(pose_token, prev_pose)
399
+ # Get Camera Translation
400
+ if hasattr(self, "head_camera"):
401
+ pred_cam = self.head_camera(pose_token, prev_camera)
402
+ pose_output["pred_cam"] = pred_cam
403
+ # Run camera projection
404
+ pose_output = self.camera_project(pose_output, batch)
405
+
406
+ # Get 2D KPS in crop
407
+ pose_output["pred_keypoints_2d_cropped"] = self._full_to_crop(
408
+ batch, pose_output["pred_keypoints_2d"], self.body_batch_idx
409
+ )
410
+
411
+ return pose_output
412
+
413
+ kp_token_update_fn = self.keypoint_token_update_fn
414
+
415
+ # Now for 3D
416
+ kp3d_token_update_fn = self.keypoint3d_token_update_fn
417
+
418
+ # Combine the 2D and 3D functionse
419
+ def keypoint_token_update_fn_comb(*args):
420
+ if kp_token_update_fn is not None:
421
+ args = kp_token_update_fn(kps_emb_start_idx, image_embeddings, *args)
422
+ if kp3d_token_update_fn is not None:
423
+ args = kp3d_token_update_fn(kps3d_emb_start_idx, *args)
424
+ return args
425
+
426
+ pose_token, pose_output = self.decoder(
427
+ token_embeddings,
428
+ image_embeddings,
429
+ token_augment,
430
+ image_augment,
431
+ token_mask,
432
+ token_to_pose_output_fn=token_to_pose_output_fn,
433
+ keypoint_token_update_fn=keypoint_token_update_fn_comb,
434
+ )
435
+
436
+ if self.cfg.MODEL.DECODER.get("DO_HAND_DETECT_TOKENS", False):
437
+ return (
438
+ pose_token[:, hand_det_emb_start_idx : hand_det_emb_start_idx + 2],
439
+ pose_output,
440
+ )
441
+ else:
442
+ return pose_token, pose_output
443
+
444
+ def forward_decoder_hand(
445
+ self,
446
+ image_embeddings: torch.Tensor,
447
+ init_estimate: torch.Tensor | None = None,
448
+ keypoints: torch.Tensor | None = None,
449
+ prev_estimate: torch.Tensor | None = None,
450
+ condition_info: torch.Tensor | None = None,
451
+ batch=None,
452
+ ):
453
+ """
454
+ Args:
455
+ image_embeddings: image features from the backbone, shape (B, C, H, W)
456
+ init_estimate: initial estimate to be refined on, shape (B, 1, C)
457
+ keypoints: optional prompt input, shape (B, N, 3),
458
+ 3 for coordinates (x,y) + label.
459
+ (x, y) should be normalized to range [0, 1].
460
+ label==-1 indicates incorrect points,
461
+ label==-2 indicates invalid points
462
+ prev_estimate: optional prompt input, shape (B, 1, C),
463
+ previous estimate for pose refinement.
464
+ condition_info: optional condition information that is concatenated with
465
+ the input tokens, shape (B, c)
466
+ """
467
+ batch_size = image_embeddings.shape[0]
468
+
469
+ # Initial estimation for residual prediction.
470
+ if init_estimate is None:
471
+ init_pose = self.init_pose_hand.weight.expand(batch_size, -1).unsqueeze(dim=1)
472
+ if hasattr(self, "init_camera_hand"):
473
+ init_camera = self.init_camera_hand.weight.expand(batch_size, -1).unsqueeze(dim=1)
474
+
475
+ init_estimate = (
476
+ init_pose if not hasattr(self, "init_camera_hand") else torch.cat([init_pose, init_camera], dim=-1)
477
+ ) # This is basically pose & camera translation at the end. B x 1 x (404 + 3)
478
+
479
+ init_input = (
480
+ torch.cat([condition_info.view(batch_size, 1, -1), init_estimate], dim=-1)
481
+ if condition_info is not None
482
+ else init_estimate
483
+ ) # B x 1 x 410 (this is with the CLIFF condition)
484
+ token_embeddings = self.init_to_token_mhr_hand(init_input).view(
485
+ batch_size, 1, -1
486
+ ) # B x 1 x 1024 (linear layered)
487
+ num_pose_token = token_embeddings.shape[1]
488
+
489
+ image_augment, token_augment, token_mask = None, None, None
490
+ if hasattr(self, "prompt_encoder") and keypoints is not None:
491
+ if prev_estimate is None:
492
+ # Use initial embedding if no previous embedding
493
+ prev_estimate = init_estimate
494
+ # Previous estimate w/o the CLIFF condition.
495
+ prev_embeddings = self.prev_to_token_mhr_hand(prev_estimate).view(
496
+ batch_size, 1, -1
497
+ ) # 407 -> B x 1 x 1024; linear layer-ed
498
+
499
+ if self.cfg.MODEL.BACKBONE.TYPE in [
500
+ "vit_hmr",
501
+ "vit",
502
+ "vit_b",
503
+ "vit_l",
504
+ ]:
505
+ # ViT backbone assumes a different aspect ratio as input size
506
+ image_augment = self.hand_pe_layer((16, 16)).unsqueeze(0)[:, :, :, 2:-2]
507
+ elif self.cfg.MODEL.BACKBONE.TYPE in [
508
+ "vit_hmr_512_384",
509
+ ]:
510
+ # ViT backbone assumes a different aspect ratio as input size
511
+ image_augment = self.hand_pe_layer((32, 32)).unsqueeze(0)[:, :, :, 4:-4]
512
+ else:
513
+ image_augment = self.hand_pe_layer(image_embeddings.shape[-2:]).unsqueeze(0) # (1, C, H, W)
514
+
515
+ image_embeddings = self.ray_cond_emb_hand(image_embeddings, batch["ray_cond_hand"])
516
+
517
+ # To start, keypoints is all [0, 0, -2]. The points get sent into self.pe_layer._pe_encoding,
518
+ # the labels determine the embedding weight (special ones for -2 and -1, then one per joint).
519
+ prompt_embeddings, prompt_mask = self.prompt_encoder(keypoints=keypoints) # B x 1 x 1280
520
+ prompt_embeddings = self.prompt_to_token(prompt_embeddings) # Linear layered: B x 1 x 1024
521
+
522
+ # Concatenate pose tokens and prompt embeddings as decoder input
523
+ token_embeddings = torch.cat(
524
+ [
525
+ token_embeddings,
526
+ prev_embeddings,
527
+ prompt_embeddings,
528
+ ],
529
+ dim=1,
530
+ )
531
+
532
+ token_augment = torch.zeros_like(token_embeddings)
533
+ token_augment[:, [num_pose_token]] = prev_embeddings
534
+ token_augment[:, (num_pose_token + 1) :] = prompt_embeddings
535
+ token_mask = None
536
+
537
+ if self.cfg.MODEL.DECODER.get("DO_HAND_DETECT_TOKENS", False):
538
+ # Put in a token for each hand
539
+ hand_det_emb_start_idx = token_embeddings.shape[1]
540
+ token_embeddings = torch.cat(
541
+ [
542
+ token_embeddings,
543
+ self.hand_box_embedding.weight[None, :, :].repeat(batch_size, 1, 1),
544
+ ],
545
+ dim=1,
546
+ ) # B x 5 + 70 x 1024
547
+ # No positional embeddings
548
+ token_augment = torch.cat(
549
+ [
550
+ token_augment,
551
+ torch.zeros_like(token_embeddings[:, token_augment.shape[1] :, :]),
552
+ ],
553
+ dim=1,
554
+ ) # B x 5 + 70 x 1024
555
+
556
+ assert self.cfg.MODEL.DECODER.get("DO_KEYPOINT_TOKENS", False)
557
+ # Put in a token for each keypoint
558
+ kps_emb_start_idx = token_embeddings.shape[1]
559
+ token_embeddings = torch.cat(
560
+ [
561
+ token_embeddings,
562
+ self.keypoint_embedding_hand.weight[None, :, :].repeat(batch_size, 1, 1),
563
+ ],
564
+ dim=1,
565
+ ) # B x 3 + 70 x 1024
566
+ # No positional embeddings
567
+ token_augment = torch.cat(
568
+ [
569
+ token_augment,
570
+ torch.zeros_like(token_embeddings[:, token_augment.shape[1] :, :]),
571
+ ],
572
+ dim=1,
573
+ ) # B x 3 + 70 x 1024
574
+
575
+ if self.cfg.MODEL.DECODER.get("DO_KEYPOINT3D_TOKENS", False):
576
+ # Put in a token for each keypoint
577
+ kps3d_emb_start_idx = token_embeddings.shape[1]
578
+ token_embeddings = torch.cat(
579
+ [
580
+ token_embeddings,
581
+ self.keypoint3d_embedding_hand.weight[None, :, :].repeat(batch_size, 1, 1),
582
+ ],
583
+ dim=1,
584
+ ) # B x 3 + 70 + 70 x 1024
585
+ # No positional embeddings
586
+ token_augment = torch.cat(
587
+ [
588
+ token_augment,
589
+ torch.zeros_like(token_embeddings[:, token_augment.shape[1] :, :]),
590
+ ],
591
+ dim=1,
592
+ ) # B x 3 + 70 + 70 x 1024
593
+
594
+ # We're doing intermediate model predictions
595
+ def token_to_pose_output_fn(tokens, prev_pose_output, layer_idx):
596
+ # Get the pose token
597
+ pose_token = tokens[:, 0]
598
+
599
+ prev_pose = init_pose.view(batch_size, -1)
600
+ prev_camera = init_camera.view(batch_size, -1)
601
+
602
+ # Get pose outputs
603
+ pose_output = self.head_pose_hand(pose_token, prev_pose)
604
+
605
+ # Get Camera Translation
606
+ if hasattr(self, "head_camera_hand"):
607
+ pred_cam = self.head_camera_hand(pose_token, prev_camera)
608
+ pose_output["pred_cam"] = pred_cam
609
+ # Run camera projection
610
+ pose_output = self.camera_project_hand(pose_output, batch)
611
+
612
+ # Get 2D KPS in crop
613
+ pose_output["pred_keypoints_2d_cropped"] = self._full_to_crop(
614
+ batch, pose_output["pred_keypoints_2d"], self.hand_batch_idx
615
+ )
616
+
617
+ return pose_output
618
+
619
+ kp_token_update_fn = self.keypoint_token_update_fn_hand
620
+
621
+ # Now for 3D
622
+ kp3d_token_update_fn = self.keypoint3d_token_update_fn_hand
623
+
624
+ # Combine the 2D and 3D functions
625
+ def keypoint_token_update_fn_comb(*args):
626
+ if kp_token_update_fn is not None:
627
+ args = kp_token_update_fn(kps_emb_start_idx, image_embeddings, *args)
628
+ if kp3d_token_update_fn is not None:
629
+ args = kp3d_token_update_fn(kps3d_emb_start_idx, *args)
630
+ return args
631
+
632
+ pose_token, pose_output = self.decoder_hand(
633
+ token_embeddings,
634
+ image_embeddings,
635
+ token_augment,
636
+ image_augment,
637
+ token_mask,
638
+ token_to_pose_output_fn=token_to_pose_output_fn,
639
+ keypoint_token_update_fn=keypoint_token_update_fn_comb,
640
+ )
641
+
642
+ if self.cfg.MODEL.DECODER.get("DO_HAND_DETECT_TOKENS", False):
643
+ return (
644
+ pose_token[:, hand_det_emb_start_idx : hand_det_emb_start_idx + 2],
645
+ pose_output,
646
+ )
647
+ else:
648
+ return pose_token, pose_output
649
+
650
+ @torch.no_grad()
651
+ def _get_keypoint_prompt(self, batch, pred_keypoints_2d, force_dummy=False):
652
+ if self.camera_type == "perspective":
653
+ pred_keypoints_2d = self._full_to_crop(batch, pred_keypoints_2d)
654
+
655
+ gt_keypoints_2d = self._flatten_person(batch["keypoints_2d"]).clone()
656
+
657
+ keypoint_prompt = self.keypoint_prompt_sampler.sample(
658
+ gt_keypoints_2d,
659
+ pred_keypoints_2d,
660
+ is_train=self.training,
661
+ force_dummy=force_dummy,
662
+ )
663
+ return keypoint_prompt
664
+
665
+ def _get_mask_prompt(self, batch, image_embeddings):
666
+ x_mask = self._flatten_person(batch["mask"])
667
+ mask_embeddings, no_mask_embeddings = self.prompt_encoder.get_mask_embeddings(
668
+ x_mask, image_embeddings.shape[0], image_embeddings.shape[2:]
669
+ )
670
+ if self.cfg.MODEL.BACKBONE.TYPE in [
671
+ "vit_hmr",
672
+ "vit",
673
+ ]:
674
+ # ViT backbone assumes a different aspect ratio as input size
675
+ mask_embeddings = mask_embeddings[:, :, :, 2:-2]
676
+ elif self.cfg.MODEL.BACKBONE.TYPE in [
677
+ "vit_hmr_512_384",
678
+ ]:
679
+ # for x2 resolution
680
+ mask_embeddings = mask_embeddings[:, :, :, 4:-4]
681
+
682
+ mask_score = self._flatten_person(batch["mask_score"]).view(-1, 1, 1, 1)
683
+ mask_embeddings = torch.where(
684
+ mask_score > 0,
685
+ mask_score * mask_embeddings.to(image_embeddings),
686
+ no_mask_embeddings.to(image_embeddings),
687
+ )
688
+ return mask_embeddings
689
+
690
+ def _one_prompt_iter(self, batch, output, prev_prompt, full_output):
691
+ image_embeddings = output["image_embeddings"]
692
+ condition_info = output["condition_info"]
693
+
694
+ if "mhr" in output and output["mhr"] is not None:
695
+ pose_output = output["mhr"] # body-only output
696
+ # Use previous estimate as initialization
697
+ prev_estimate = torch.cat(
698
+ [
699
+ pose_output["pred_pose_raw"].detach(), # (B, 6)
700
+ pose_output["shape"].detach(),
701
+ pose_output["scale"].detach(),
702
+ pose_output["hand"].detach(),
703
+ pose_output["face"].detach(),
704
+ ],
705
+ dim=1,
706
+ ).unsqueeze(dim=1)
707
+ if hasattr(self, "init_camera"):
708
+ prev_estimate = torch.cat(
709
+ [prev_estimate, pose_output["pred_cam"].detach().unsqueeze(1)],
710
+ dim=-1,
711
+ )
712
+ prev_shape = prev_estimate.shape[1:]
713
+
714
+ pred_keypoints_2d = output["mhr"]["pred_keypoints_2d"].detach().clone()
715
+ kpt_shape = pred_keypoints_2d.shape[1:]
716
+
717
+ if "mhr_hand" in output and output["mhr_hand"] is not None:
718
+ pose_output_hand = output["mhr_hand"]
719
+ # Use previous estimate as initialization
720
+ prev_estimate_hand = torch.cat(
721
+ [
722
+ pose_output_hand["pred_pose_raw"].detach(), # (B, 6)
723
+ pose_output_hand["shape"].detach(),
724
+ pose_output_hand["scale"].detach(),
725
+ pose_output_hand["hand"].detach(),
726
+ pose_output_hand["face"].detach(),
727
+ ],
728
+ dim=1,
729
+ ).unsqueeze(dim=1)
730
+ if hasattr(self, "init_camera_hand"):
731
+ prev_estimate_hand = torch.cat(
732
+ [
733
+ prev_estimate_hand,
734
+ pose_output_hand["pred_cam"].detach().unsqueeze(1),
735
+ ],
736
+ dim=-1,
737
+ )
738
+ prev_shape = prev_estimate_hand.shape[1:]
739
+
740
+ pred_keypoints_2d_hand = output["mhr_hand"]["pred_keypoints_2d"].detach().clone()
741
+ kpt_shape = pred_keypoints_2d_hand.shape[1:]
742
+
743
+ all_prev_estimate = torch.zeros((image_embeddings.shape[0], *prev_shape), device=image_embeddings.device)
744
+ if "mhr" in output and output["mhr"] is not None:
745
+ all_prev_estimate[self.body_batch_idx] = prev_estimate
746
+ if "mhr_hand" in output and output["mhr_hand"] is not None:
747
+ all_prev_estimate[self.hand_batch_idx] = prev_estimate_hand
748
+
749
+ # Get keypoint prompts
750
+ all_pred_keypoints_2d = torch.zeros((image_embeddings.shape[0], *kpt_shape), device=image_embeddings.device)
751
+ if "mhr" in output and output["mhr"] is not None:
752
+ all_pred_keypoints_2d[self.body_batch_idx] = pred_keypoints_2d
753
+ if "mhr_hand" in output and output["mhr_hand"] is not None:
754
+ all_pred_keypoints_2d[self.hand_batch_idx] = pred_keypoints_2d_hand
755
+
756
+ keypoint_prompt = self._get_keypoint_prompt(batch, all_pred_keypoints_2d)
757
+ cur_keypoint_prompt = (
758
+ torch.cat(prev_prompt + [keypoint_prompt], dim=1) if len(prev_prompt) else keypoint_prompt
759
+ ) # [B, 1, 3]
760
+
761
+ pose_output, pose_output_hand = None, None
762
+ if len(self.body_batch_idx):
763
+ tokens_output, pose_output = self.forward_decoder(
764
+ image_embeddings[self.body_batch_idx],
765
+ init_estimate=None, # not recurring previous estimate
766
+ keypoints=cur_keypoint_prompt[self.body_batch_idx],
767
+ prev_estimate=all_prev_estimate[self.body_batch_idx],
768
+ condition_info=condition_info[self.body_batch_idx],
769
+ batch=batch,
770
+ full_output=None,
771
+ )
772
+ pose_output = pose_output[-1]
773
+
774
+ # Update prediction output
775
+ output.update(
776
+ {
777
+ "mhr": pose_output,
778
+ "mhr_hand": pose_output_hand,
779
+ }
780
+ )
781
+
782
+ return output, keypoint_prompt
783
+
784
+ def _full_to_crop(
785
+ self,
786
+ batch: dict,
787
+ pred_keypoints_2d: torch.Tensor,
788
+ batch_idx: torch.Tensor | Sequence[int] | None = None,
789
+ ) -> torch.Tensor:
790
+ """Convert full-image keypoints coordinates to crop and normalize to [-0.5. 0.5]"""
791
+ pred_keypoints_2d_cropped = torch.cat(
792
+ [pred_keypoints_2d, torch.ones_like(pred_keypoints_2d[:, :, [-1]])], dim=-1
793
+ )
794
+ if batch_idx is not None:
795
+ affine_trans = self._flatten_person(batch["affine_trans"])[batch_idx].to(pred_keypoints_2d_cropped)
796
+ img_size = self._flatten_person(batch["img_size"])[batch_idx].unsqueeze(1)
797
+ else:
798
+ affine_trans = self._flatten_person(batch["affine_trans"]).to(pred_keypoints_2d_cropped)
799
+ img_size = self._flatten_person(batch["img_size"]).unsqueeze(1)
800
+ pred_keypoints_2d_cropped = pred_keypoints_2d_cropped @ affine_trans.mT
801
+ pred_keypoints_2d_cropped = pred_keypoints_2d_cropped[..., :2] / img_size - 0.5
802
+
803
+ return pred_keypoints_2d_cropped
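+ # Illustrative sketch (annotation with hypothetical numbers): for a single full-image
+ # point p = (x, y) and a 2x3 crop affine A, the mapping above is A @ [x, y, 1]^T
+ # followed by division by the crop size and a -0.5 shift, e.g.
+ # >>> A = torch.tensor([[0.5, 0.0, -32.0], [0.0, 0.5, -96.0]])
+ # >>> p = torch.tensor([320.0, 320.0, 1.0])
+ # >>> (A @ p) / 256 - 0.5  # -> tensor([0.0000, -0.2500])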
804
+
805
+ def camera_project(self, pose_output: dict, batch: dict) -> dict:
806
+ """
807
+ Project 3D keypoints to 2D using the camera parameters.
808
+ Args:
809
+ pose_output (Dict): Dictionary containing the pose output.
810
+ batch (Dict): Dictionary containing the batch data.
811
+ Returns:
812
+ Dict: Dictionary containing the projected 2D keypoints.
813
+ """
814
+ if hasattr(self, "head_camera"):
815
+ head_camera = self.head_camera
816
+ pred_cam = pose_output["pred_cam"]
817
+ else:
818
+ raise AssertionError("head_camera is not defined")
819
+
820
+ cam_out = head_camera.perspective_projection(
821
+ pose_output["pred_keypoints_3d"],
822
+ pred_cam,
823
+ self._flatten_person(batch["bbox_center"])[self.body_batch_idx],
824
+ self._flatten_person(batch["bbox_scale"])[self.body_batch_idx, 0],
825
+ self._flatten_person(batch["ori_img_size"])[self.body_batch_idx],
826
+ self._flatten_person(batch["cam_int"].unsqueeze(1).expand(-1, batch["img"].shape[1], -1, -1).contiguous())[
827
+ self.body_batch_idx
828
+ ],
829
+ use_intrin_center=self.cfg.MODEL.DECODER.get("USE_INTRIN_CENTER", False),
830
+ )
831
+
832
+ if pose_output.get("pred_vertices") is not None:
833
+ cam_out_vertices = head_camera.perspective_projection(
834
+ pose_output["pred_vertices"],
835
+ pred_cam,
836
+ self._flatten_person(batch["bbox_center"])[self.body_batch_idx],
837
+ self._flatten_person(batch["bbox_scale"])[self.body_batch_idx, 0],
838
+ self._flatten_person(batch["ori_img_size"])[self.body_batch_idx],
839
+ self._flatten_person(
840
+ batch["cam_int"].unsqueeze(1).expand(-1, batch["img"].shape[1], -1, -1).contiguous()
841
+ )[self.body_batch_idx],
842
+ use_intrin_center=self.cfg.MODEL.DECODER.get("USE_INTRIN_CENTER", False),
843
+ )
844
+ pose_output["pred_keypoints_2d_verts"] = cam_out_vertices["pred_keypoints_2d"]
845
+
846
+ pose_output.update(cam_out)
847
+
848
+ return pose_output
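+ # Annotation (illustrative, simple pinhole model): after adding the predicted camera
+ # translation, a 3D point (X, Y, Z) projects as u = fx * X / Z + cx, v = fy * Y / Z + cy;
+ # the exact crop/full-image conventions are implemented by head_camera.perspective_projection.
+ # >>> X, Y, Z, fx, fy, cx, cy = 0.1, -0.2, 2.0, 1000.0, 1000.0, 320.0, 240.0
+ # >>> (fx * X / Z + cx, fy * Y / Z + cy)  # -> (370.0, 140.0)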
849
+
850
+ def camera_project_hand(self, pose_output: dict, batch: dict) -> dict:
851
+ """
852
+ Project 3D keypoints to 2D using the camera parameters.
853
+ Args:
854
+ pose_output (Dict): Dictionary containing the pose output.
855
+ batch (Dict): Dictionary containing the batch data.
856
+ Returns:
857
+ Dict: Dictionary containing the projected 2D keypoints.
858
+ """
859
+ if hasattr(self, "head_camera_hand"):
860
+ head_camera = self.head_camera_hand
861
+ pred_cam = pose_output["pred_cam"]
862
+ else:
863
+ raise AssertionError("head_camera_hand is not defined")
864
+
865
+ cam_out = head_camera.perspective_projection(
866
+ pose_output["pred_keypoints_3d"],
867
+ pred_cam,
868
+ self._flatten_person(batch["bbox_center"])[self.hand_batch_idx],
869
+ self._flatten_person(batch["bbox_scale"])[self.hand_batch_idx, 0],
870
+ self._flatten_person(batch["ori_img_size"])[self.hand_batch_idx],
871
+ self._flatten_person(batch["cam_int"].unsqueeze(1).expand(-1, batch["img"].shape[1], -1, -1).contiguous())[
872
+ self.hand_batch_idx
873
+ ],
874
+ use_intrin_center=self.cfg.MODEL.DECODER.get("USE_INTRIN_CENTER", False),
875
+ )
876
+
877
+ if pose_output.get("pred_vertices") is not None:
878
+ cam_out_vertices = head_camera.perspective_projection(
879
+ pose_output["pred_vertices"],
880
+ pred_cam,
881
+ self._flatten_person(batch["bbox_center"])[self.hand_batch_idx],
882
+ self._flatten_person(batch["bbox_scale"])[self.hand_batch_idx, 0],
883
+ self._flatten_person(batch["ori_img_size"])[self.hand_batch_idx],
884
+ self._flatten_person(
885
+ batch["cam_int"].unsqueeze(1).expand(-1, batch["img"].shape[1], -1, -1).contiguous()
886
+ )[self.hand_batch_idx],
887
+ use_intrin_center=self.cfg.MODEL.DECODER.get("USE_INTRIN_CENTER", False),
888
+ )
889
+ pose_output["pred_keypoints_2d_verts"] = cam_out_vertices["pred_keypoints_2d"]
890
+
891
+ pose_output.update(cam_out)
892
+
893
+ return pose_output
894
+
895
+ def get_ray_condition(self, batch):
896
+ B, N, _, H, W = batch["img"].shape
897
+ meshgrid_xy = (
898
+ torch.stack(torch.meshgrid(torch.arange(H), torch.arange(W), indexing="xy"), dim=2)[None, None, :, :, :]
899
+ .repeat(B, N, 1, 1, 1)
900
+ .cuda()
901
+ ) # B x N x H x W x 2
902
+ meshgrid_xy = meshgrid_xy / batch["affine_trans"][:, :, None, None, [0, 1], [0, 1]]
903
+ meshgrid_xy = (
904
+ meshgrid_xy
905
+ - batch["affine_trans"][:, :, None, None, [0, 1], [2, 2]]
906
+ / batch["affine_trans"][:, :, None, None, [0, 1], [0, 1]]
907
+ )
908
+
909
+ # Subtract out center & normalize to be rays
910
+ meshgrid_xy = meshgrid_xy - batch["cam_int"][:, None, None, None, [0, 1], [2, 2]]
911
+ meshgrid_xy = meshgrid_xy / batch["cam_int"][:, None, None, None, [0, 1], [0, 1]]
912
+
913
+ return meshgrid_xy.permute(0, 1, 4, 2, 3).to(batch["img"].dtype) # This is B x num_person x 2 x H x W
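+ # Annotation (illustrative): per crop pixel (u, v), the code above first undoes the
+ # full-to-crop affine, u_full = (u - t_x) / s_x, and then converts to a normalized ray
+ # direction with the intrinsics, r_x = (u_full - cx) / fx (and likewise for y), e.g.
+ # >>> u_full, cx, fx = 960.0, 640.0, 1000.0
+ # >>> (u_full - cx) / fx  # -> 0.32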
914
+
915
+ def forward_pose_branch(self, batch: dict) -> dict:
916
+ """Run a forward pass for the crop-image (pose) branch."""
917
+ batch_size, num_person = batch["img"].shape[:2]
918
+
919
+ # Forward backbone encoder
920
+ x = self.data_preprocess(
921
+ self._flatten_person(batch["img"]),
922
+ crop_width=(
923
+ self.cfg.MODEL.BACKBONE.TYPE
924
+ in [
925
+ "vit_hmr",
926
+ "vit",
927
+ "vit_b",
928
+ "vit_l",
929
+ "vit_hmr_512_384",
930
+ ]
931
+ ),
932
+ )
933
+
934
+ # Optionally get ray conditioning
935
+ ray_cond = self.get_ray_condition(batch) # This is B x num_person x 2 x H x W
936
+ ray_cond = self._flatten_person(ray_cond)
937
+ if self.cfg.MODEL.BACKBONE.TYPE in [
938
+ "vit_hmr",
939
+ "vit",
940
+ "vit_b",
941
+ "vit_l",
942
+ ]:
943
+ ray_cond = ray_cond[:, :, :, 32:-32]
944
+ elif self.cfg.MODEL.BACKBONE.TYPE in [
945
+ "vit_hmr_512_384",
946
+ ]:
947
+ ray_cond = ray_cond[:, :, :, 64:-64]
948
+
949
+ if len(self.body_batch_idx):
950
+ batch["ray_cond"] = ray_cond[self.body_batch_idx].clone()
951
+ if len(self.hand_batch_idx):
952
+ batch["ray_cond_hand"] = ray_cond[self.hand_batch_idx].clone()
953
+ ray_cond = None
954
+
955
+ image_embeddings = self.backbone(x.type(self.backbone_dtype), extra_embed=ray_cond) # (B, C, H, W)
956
+
957
+ if isinstance(image_embeddings, tuple):
958
+ image_embeddings = image_embeddings[-1]
959
+ image_embeddings = image_embeddings.type(x.dtype)
960
+
961
+ # Mask condition if available
962
+ if self.cfg.MODEL.PROMPT_ENCODER.get("MASK_EMBED_TYPE", None) is not None:
963
+ # v1: non-iterative mask conditioning
964
+ if self.cfg.MODEL.PROMPT_ENCODER.get("MASK_PROMPT", "v1") == "v1":
965
+ mask_embeddings = self._get_mask_prompt(batch, image_embeddings)
966
+ image_embeddings = image_embeddings + mask_embeddings
967
+ else:
968
+ raise NotImplementedError
969
+
970
+ # Prepare input for promptable decoder
971
+ condition_info = self._get_decoder_condition(batch)
972
+
973
+ # Initial estimate with a dummy prompt
974
+ keypoints_prompt = torch.zeros((batch_size * num_person, 1, 3)).to(batch["img"])
975
+ keypoints_prompt[:, :, -1] = -2
976
+
977
+ # Forward promptable decoder to get updated pose tokens and regression output
978
+ pose_output, pose_output_hand = None, None
979
+ if len(self.body_batch_idx):
980
+ tokens_output, pose_output = self.forward_decoder(
981
+ image_embeddings[self.body_batch_idx],
982
+ init_estimate=None,
983
+ keypoints=keypoints_prompt[self.body_batch_idx],
984
+ prev_estimate=None,
985
+ condition_info=condition_info[self.body_batch_idx],
986
+ batch=batch,
987
+ )
988
+ pose_output = pose_output[-1]
989
+ if len(self.hand_batch_idx):
990
+ tokens_output_hand, pose_output_hand = self.forward_decoder_hand(
991
+ image_embeddings[self.hand_batch_idx],
992
+ init_estimate=None,
993
+ keypoints=keypoints_prompt[self.hand_batch_idx],
994
+ prev_estimate=None,
995
+ condition_info=condition_info[self.hand_batch_idx],
996
+ batch=batch,
997
+ )
998
+ pose_output_hand = pose_output_hand[-1]
999
+
1000
+ output = {
1001
+ # "pose_token": pose_token,
1002
+ "mhr": pose_output, # mhr prediction output
1003
+ "mhr_hand": pose_output_hand, # mhr prediction output
1004
+ "condition_info": condition_info,
1005
+ "image_embeddings": image_embeddings,
1006
+ }
1007
+
1008
+ if self.cfg.MODEL.DECODER.get("DO_HAND_DETECT_TOKENS", False):
1009
+ if len(self.body_batch_idx):
1010
+ output_hand_box_tokens = tokens_output
1011
+ hand_coords = self.bbox_embed(output_hand_box_tokens).sigmoid() # x1, y1, w, h for body samples, 0 ~ 1
1012
+ hand_logits = self.hand_cls_embed(output_hand_box_tokens)
1013
+
1014
+ output["mhr"]["hand_box"] = hand_coords
1015
+ output["mhr"]["hand_logits"] = hand_logits
1016
+
1017
+ if len(self.hand_batch_idx):
1018
+ output_hand_box_tokens_hand_batch = tokens_output_hand
1019
+
1020
+ hand_coords_hand_batch = self.bbox_embed(
1021
+ output_hand_box_tokens_hand_batch
1022
+ ).sigmoid() # x1, y1, w, h for hand samples
1023
+ hand_logits_hand_batch = self.hand_cls_embed(output_hand_box_tokens_hand_batch)
1024
+
1025
+ output["mhr_hand"]["hand_box"] = hand_coords_hand_batch
1026
+ output["mhr_hand"]["hand_logits"] = hand_logits_hand_batch
1027
+
1028
+ return output
1029
+
1030
+ def forward_step(self, batch: dict, decoder_type: str = "body") -> dict:
1031
+ batch_size, num_person = batch["img"].shape[:2]
1032
+
1033
+ if decoder_type == "body":
1034
+ self.hand_batch_idx = []
1035
+ self.body_batch_idx = list(range(batch_size * num_person))
1036
+ elif decoder_type == "hand":
1037
+ self.hand_batch_idx = list(range(batch_size * num_person))
1038
+ self.body_batch_idx = []
1039
+ else:
1040
+ ValueError("Invalid decoder type: ", decoder_type)
1041
+
1042
+ # Crop-image (pose) branch
1043
+ pose_output = self.forward_pose_branch(batch)
1044
+
1045
+ return pose_output
1046
+
1047
+ def run_inference(
1048
+ self,
1049
+ img,
1050
+ batch: dict,
1051
+ inference_type: str = "full",
1052
+ transform_hand: Any = None,
1053
+ thresh_wrist_angle=1.4,
1054
+ ):
1055
+ """
1056
+ Run 3DB inference (optionally with hand detector).
1057
+
1058
+ inference_type:
1059
+ - full: full-body inference with both body and hand decoders
1060
+ - body: inference with body decoder only (still full-body output)
1061
+ - hand: inference with hand decoder only (only hand output)
1062
+ """
1063
+
1064
+ height, width = img.shape[:2]
1065
+ cam_int = batch["cam_int"].clone()
1066
+
1067
+ if inference_type == "body":
1068
+ pose_output = self.forward_step(batch, decoder_type="body")
1069
+ return BodyPredContainer(pose_output=pose_output)
1070
+ elif inference_type == "hand":
1071
+ pose_output = self.forward_step(batch, decoder_type="hand")
1072
+ return BodyPredContainer(pose_output=pose_output)
1073
+ elif inference_type != "full":
1074
+ raise ValueError("Invalid inference type: ", inference_type)
1075
+
1076
+ # Step 1. For full-body inference, we first inference with the body decoder.
1077
+ pose_output = self.forward_step(batch, decoder_type="body")
1078
+ left_xyxy, right_xyxy = self._get_hand_box(pose_output, batch)
1079
+ ori_local_wrist_rotmat = roma.euler_to_rotmat(
1080
+ "XZY",
1081
+ pose_output["mhr"]["body_pose"][:, [41, 43, 42, 31, 33, 32]].unflatten(1, (2, 3)),
1082
+ )
1083
+
1084
+ # Step 2. Re-run with each hand
1085
+ ## Left... Flip image & box
1086
+ flipped_img = img[:, ::-1]
1087
+ tmp = left_xyxy.copy()
1088
+ left_xyxy[:, 0] = width - tmp[:, 2] - 1
1089
+ left_xyxy[:, 2] = width - tmp[:, 0] - 1
1090
+
1091
+ batch_lhand = prepare_batch(flipped_img, transform_hand, left_xyxy, cam_int=cam_int.clone())
1092
+ batch_lhand = recursive_to(batch_lhand, "cuda")
1093
+ lhand_output = self.forward_step(batch_lhand, decoder_type="hand")
1094
+
1095
+ # Unflip output
1096
+ ## Flip scale
1097
+ ### Get MHR values
1098
+ scale_r_hands_mean = self.head_pose.scale_mean[8].item()
1099
+ scale_l_hands_mean = self.head_pose.scale_mean[9].item()
1100
+ scale_r_hands_std = self.head_pose.scale_comps[8, 8].item()
1101
+ scale_l_hands_std = self.head_pose.scale_comps[9, 9].item()
1102
+ ### Apply
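+ # Annotation: de-standardize the (flipped) right-hand scale with the right-hand
+ # mean/std, then re-standardize it with the left-hand statistics, i.e.
+ # s_left = ((mean_r + std_r * s_right) - mean_l) / std_l.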
1103
+ lhand_output["mhr_hand"]["scale"][:, 9] = (
1104
+ (scale_r_hands_mean + scale_r_hands_std * lhand_output["mhr_hand"]["scale"][:, 8]) - scale_l_hands_mean
1105
+ ) / scale_l_hands_std
1106
+ ## Get the right hand global rotation, flip it, put it in as left.
1107
+ lhand_output["mhr_hand"]["joint_global_rots"][:, 78] = lhand_output["mhr_hand"]["joint_global_rots"][
1108
+ :, 42
1109
+ ].clone()
1110
+ lhand_output["mhr_hand"]["joint_global_rots"][:, 78, [1, 2], :] *= -1
1111
+ ### Flip hand pose
1112
+ lhand_output["mhr_hand"]["hand"][:, :54] = lhand_output["mhr_hand"]["hand"][:, 54:]
1113
+ ### Unflip box
1114
+ batch_lhand["bbox_center"][:, :, 0] = width - batch_lhand["bbox_center"][:, :, 0] - 1
1115
+
1116
+ ## Right...
1117
+ batch_rhand = prepare_batch(img, transform_hand, right_xyxy, cam_int=cam_int.clone())
1118
+ batch_rhand = recursive_to(batch_rhand, "cuda")
1119
+ rhand_output = self.forward_step(batch_rhand, decoder_type="hand")
1120
+
1121
+ # Step 3. replace hand pose estimation from the body decoder.
1122
+ ## CRITERIA 1: LOCAL WRIST POSE DIFFERENCE
1123
+ joint_rotations = pose_output["mhr"]["joint_global_rots"]
1124
+ ### Get lowarm
1125
+ lowarm_joint_idxs = torch.LongTensor([76, 40]).cuda() # left, right
1126
+ lowarm_joint_rotations = joint_rotations[:, lowarm_joint_idxs] # B x 2 x 3 x 3
1127
+ ### Get zero-wrist pose
1128
+ wrist_twist_joint_idxs = torch.LongTensor([77, 41]).cuda() # left, right
1129
+ wrist_zero_rot_pose = lowarm_joint_rotations @ self.head_pose.joint_rotation[wrist_twist_joint_idxs]
1130
+ ### Get globals from left & right
1131
+ left_joint_global_rots = lhand_output["mhr_hand"]["joint_global_rots"]
1132
+ right_joint_global_rots = rhand_output["mhr_hand"]["joint_global_rots"]
1133
+ pred_global_wrist_rotmat = torch.stack(
1134
+ [
1135
+ left_joint_global_rots[:, 78],
1136
+ right_joint_global_rots[:, 42],
1137
+ ],
1138
+ dim=1,
1139
+ )
1140
+ ### Get the local poses that lead to the wrist being pred_global_wrist_rotmat
1141
+ fused_local_wrist_rotmat = torch.einsum("kabc,kabd->kadc", pred_global_wrist_rotmat, wrist_zero_rot_pose)
1142
+ angle_difference = rotation_angle_difference(ori_local_wrist_rotmat, fused_local_wrist_rotmat) # per-hand angle between the B x 2 x 3 x 3 rotation matrices
1143
+ angle_difference_valid_mask = angle_difference < thresh_wrist_angle
1144
+
1145
+ ## CRITERIA 2: hand box size
1146
+ hand_box_size_thresh = 64
1147
+ hand_box_size_valid_mask = torch.stack(
1148
+ [
1149
+ (batch_lhand["bbox_scale"].flatten(0, 1) > hand_box_size_thresh).all(dim=1),
1150
+ (batch_rhand["bbox_scale"].flatten(0, 1) > hand_box_size_thresh).all(dim=1),
1151
+ ],
1152
+ dim=1,
1153
+ )
1154
+
1155
+ ## CRITERIA 3: all hand 2D KPS (including wrist) inside of box.
1156
+ hand_kps2d_thresh = 0.5
1157
+ hand_kps2d_valid_mask = torch.stack(
1158
+ [
1159
+ lhand_output["mhr_hand"]["pred_keypoints_2d_cropped"].abs().amax(dim=(1, 2)) < hand_kps2d_thresh,
1160
+ rhand_output["mhr_hand"]["pred_keypoints_2d_cropped"].abs().amax(dim=(1, 2)) < hand_kps2d_thresh,
1161
+ ],
1162
+ dim=1,
1163
+ )
1164
+
1165
+ ## CRITERIA 4: 2D wrist distance.
1166
+ hand_wrist_kps2d_thresh = 0.25
1167
+ kps_right_wrist_idx = 41
1168
+ kps_left_wrist_idx = 62
1169
+ right_kps_full = rhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone()
1170
+ left_kps_full = lhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone()
1171
+ left_kps_full[:, :, 0] = width - left_kps_full[:, :, 0] - 1 # Flip left hand
1172
+ body_right_kps_full = pose_output["mhr"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone()
1173
+ body_left_kps_full = pose_output["mhr"]["pred_keypoints_2d"][:, [kps_left_wrist_idx]].clone()
1174
+ right_kps_dist = (right_kps_full - body_right_kps_full).flatten(0, 1).norm(dim=-1) / batch_lhand[
1175
+ "bbox_scale"
1176
+ ].flatten(0, 1)[:, 0]
1177
+ left_kps_dist = (left_kps_full - body_left_kps_full).flatten(0, 1).norm(dim=-1) / batch_rhand[
1178
+ "bbox_scale"
1179
+ ].flatten(0, 1)[:, 0]
1180
+ hand_wrist_kps2d_valid_mask = torch.stack(
1181
+ [
1182
+ left_kps_dist < hand_wrist_kps2d_thresh,
1183
+ right_kps_dist < hand_wrist_kps2d_thresh,
1184
+ ],
1185
+ dim=1,
1186
+ )
1187
+ ## Left-right
1188
+ hand_valid_mask = (
1189
+ angle_difference_valid_mask & hand_box_size_valid_mask & hand_kps2d_valid_mask & hand_wrist_kps2d_valid_mask
1190
+ )
1191
+
1192
+ # Keypoint prompting with the body decoder.
1193
+ # We use the wrist location from the hand decoder and the elbow location
1194
+ # from the body decoder as prompts to get an updated body pose estimation.
1195
+ batch_size, num_person = batch["img"].shape[:2]
1196
+ self.hand_batch_idx = []
1197
+ self.body_batch_idx = list(range(batch_size * num_person))
1198
+
1199
+ ## Get right & left wrist keypoints from crops; full image. Each are B x 1 x 2
1200
+ kps_right_wrist_idx = 41
1201
+ kps_left_wrist_idx = 62
1202
+ right_kps_full = rhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone()
1203
+ left_kps_full = lhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone()
1204
+ left_kps_full[:, :, 0] = width - left_kps_full[:, :, 0] - 1 # Flip left hand
1205
+
1206
+ # Next, get them to crop-normalized space.
1207
+ right_kps_crop = self._full_to_crop(batch, right_kps_full)
1208
+ left_kps_crop = self._full_to_crop(batch, left_kps_full)
1209
+
1210
+ # Get right & left elbow keypoints from crops; full image. Each are B x 1 x 2
1211
+ kps_right_elbow_idx = 8
1212
+ kps_left_elbow_idx = 7
1213
+ right_kps_elbow_full = pose_output["mhr"]["pred_keypoints_2d"][:, [kps_right_elbow_idx]].clone()
1214
+ left_kps_elbow_full = pose_output["mhr"]["pred_keypoints_2d"][:, [kps_left_elbow_idx]].clone()
1215
+
1216
+ # Next, get them to crop-normalized space.
1217
+ right_kps_elbow_crop = self._full_to_crop(batch, right_kps_elbow_full)
1218
+ left_kps_elbow_crop = self._full_to_crop(batch, left_kps_elbow_full)
1219
+
1220
+ # Assemble them into keypoint prompts
1221
+ keypoint_prompt = torch.cat(
1222
+ [right_kps_crop, left_kps_crop, right_kps_elbow_crop, left_kps_elbow_crop],
1223
+ dim=1,
1224
+ )
1225
+ keypoint_prompt = torch.cat([keypoint_prompt, keypoint_prompt[..., [-1]]], dim=-1)
1226
+ keypoint_prompt[:, 0, -1] = kps_right_wrist_idx
1227
+ keypoint_prompt[:, 1, -1] = kps_left_wrist_idx
1228
+ keypoint_prompt[:, 2, -1] = kps_right_elbow_idx
1229
+ keypoint_prompt[:, 3, -1] = kps_left_elbow_idx
1230
+
1231
+ if keypoint_prompt.shape[0] > 1:
1232
+ # Replace invalid keypoints to dummy prompts
1233
+ invalid_prompt = (
1234
+ (keypoint_prompt[..., 0] < -0.5)
1235
+ | (keypoint_prompt[..., 0] > 0.5)
1236
+ | (keypoint_prompt[..., 1] < -0.5)
1237
+ | (keypoint_prompt[..., 1] > 0.5)
1238
+ | (~hand_valid_mask[..., [1, 0, 1, 0]])
1239
+ ).unsqueeze(-1)
1240
+ dummy_prompt = torch.zeros((1, 1, 3)).to(keypoint_prompt)
1241
+ dummy_prompt[:, :, -1] = -2
1242
+ keypoint_prompt[:, :, :2] = torch.clamp(
1243
+ keypoint_prompt[:, :, :2] + 0.5, min=0.0, max=1.0
1244
+ ) # [-0.5, 0.5] --> [0, 1]
1245
+ keypoint_prompt = torch.where(invalid_prompt, dummy_prompt, keypoint_prompt)
1246
+ else:
1247
+ # Only keep valid keypoints
1248
+ valid_keypoint = (
1249
+ torch.all(
1250
+ (keypoint_prompt[:, :, :2] > -0.5) & (keypoint_prompt[:, :, :2] < 0.5),
1251
+ dim=2,
1252
+ )
1253
+ & hand_valid_mask[..., [1, 0, 1, 0]]
1254
+ ).squeeze()
1255
+ keypoint_prompt = keypoint_prompt[:, valid_keypoint]
1256
+ keypoint_prompt[:, :, :2] = torch.clamp(
1257
+ keypoint_prompt[:, :, :2] + 0.5, min=0.0, max=1.0
1258
+ ) # [-0.5, 0.5] --> [0, 1]
1259
+
1260
+ if keypoint_prompt.numel() != 0:
1261
+ pose_output, _ = self.run_keypoint_prompt(batch, pose_output, keypoint_prompt)
1262
+
1263
+ ##############################################################################
1264
+
1265
+ # Drop in hand pose
1266
+ left_hand_pose_params = lhand_output["mhr_hand"]["hand"][:, :54]
1267
+ right_hand_pose_params = rhand_output["mhr_hand"]["hand"][:, 54:]
1268
+ updated_hand_pose = torch.cat([left_hand_pose_params, right_hand_pose_params], dim=1)
1269
+
1270
+ # Drop in hand scales
1271
+ updated_scale = pose_output["mhr"]["scale"].clone()
1272
+ updated_scale[:, 9] = lhand_output["mhr_hand"]["scale"][:, 9]
1273
+ updated_scale[:, 8] = rhand_output["mhr_hand"]["scale"][:, 8]
1274
+ updated_scale[:, 18:] = (
1275
+ lhand_output["mhr_hand"]["scale"][:, 18:] + rhand_output["mhr_hand"]["scale"][:, 18:]
1276
+ ) / 2
1277
+
1278
+ # Update hand shape
1279
+ updated_shape = pose_output["mhr"]["shape"].clone()
1280
+ updated_shape[:, 40:] = (
1281
+ lhand_output["mhr_hand"]["shape"][:, 40:] + rhand_output["mhr_hand"]["shape"][:, 40:]
1282
+ ) / 2
1283
+
1284
+ ############################ Doing IK ############################
1285
+
1286
+ # First, forward just FK
1287
+ joint_rotations = self.head_pose.mhr_forward(
1288
+ global_trans=pose_output["mhr"]["global_rot"] * 0,
1289
+ global_rot=pose_output["mhr"]["global_rot"],
1290
+ body_pose_params=pose_output["mhr"]["body_pose"],
1291
+ hand_pose_params=updated_hand_pose,
1292
+ scale_params=updated_scale,
1293
+ shape_params=updated_shape,
1294
+ expr_params=pose_output["mhr"]["face"],
1295
+ return_joint_rotations=True,
1296
+ )[1]
1297
+
1298
+ # Get lowarm
1299
+ lowarm_joint_idxs = torch.LongTensor([76, 40]).cuda() # left, right
1300
+ lowarm_joint_rotations = joint_rotations[:, lowarm_joint_idxs] # B x 2 x 3 x 3
1301
+
1302
+ # Get zero-wrist pose
1303
+ wrist_twist_joint_idxs = torch.LongTensor([77, 41]).cuda() # left, right
1304
+ wrist_zero_rot_pose = lowarm_joint_rotations @ self.head_pose.joint_rotation[wrist_twist_joint_idxs]
1305
+
1306
+ # Get globals from left & right
1307
+ left_joint_global_rots = lhand_output["mhr_hand"]["joint_global_rots"]
1308
+ right_joint_global_rots = rhand_output["mhr_hand"]["joint_global_rots"]
1309
+ pred_global_wrist_rotmat = torch.stack(
1310
+ [
1311
+ left_joint_global_rots[:, 78],
1312
+ right_joint_global_rots[:, 42],
1313
+ ],
1314
+ dim=1,
1315
+ )
1316
+
1317
+ # Now we want to get the local poses that lead to the wrist being pred_global_wrist_rotmat
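+ # Annotation: with R_global = R_zero @ R_local, the local wrist rotation is
+ # R_local = R_zero^T @ R_global; the einsum below computes exactly that, batched
+ # over (batch, hand).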
1318
+ fused_local_wrist_rotmat = torch.einsum("kabc,kabd->kadc", pred_global_wrist_rotmat, wrist_zero_rot_pose)
1319
+ wrist_xzy = fix_wrist_euler(roma.rotmat_to_euler("XZY", fused_local_wrist_rotmat))
1320
+
1321
+ # Put it in.
1322
+ angle_difference = rotation_angle_difference(ori_local_wrist_rotmat, fused_local_wrist_rotmat) # per-hand angle between the B x 2 x 3 x 3 rotation matrices
1323
+ valid_angle = angle_difference < thresh_wrist_angle
1324
+ valid_angle = valid_angle & hand_valid_mask
1325
+ valid_angle = valid_angle.unsqueeze(-1)
1326
+
1327
+ body_pose = pose_output["mhr"]["body_pose"][:, [41, 43, 42, 31, 33, 32]].unflatten(1, (2, 3))
1328
+ updated_body_pose = torch.where(valid_angle, wrist_xzy, body_pose)
1329
+ pose_output["mhr"]["body_pose"][:, [41, 43, 42, 31, 33, 32]] = updated_body_pose.flatten(1, 2)
1330
+
1331
+ hand_pose = pose_output["mhr"]["hand"].unflatten(1, (2, 54))
1332
+ pose_output["mhr"]["hand"] = torch.where(
1333
+ valid_angle, updated_hand_pose.unflatten(1, (2, 54)), hand_pose
1334
+ ).flatten(1, 2)
1335
+
1336
+ hand_scale = torch.stack(
1337
+ [pose_output["mhr"]["scale"][:, 9], pose_output["mhr"]["scale"][:, 8]],
1338
+ dim=1,
1339
+ )
1340
+ updated_hand_scale = torch.stack([updated_scale[:, 9], updated_scale[:, 8]], dim=1)
1341
+ masked_hand_scale = torch.where(valid_angle.squeeze(-1), updated_hand_scale, hand_scale)
1342
+ pose_output["mhr"]["scale"][:, 9] = masked_hand_scale[:, 0]
1343
+ pose_output["mhr"]["scale"][:, 8] = masked_hand_scale[:, 1]
1344
+
1345
+ # Replace shared shape and scale
1346
+ pose_output["mhr"]["scale"][:, 18:] = torch.where(
1347
+ valid_angle.squeeze(-1).sum(dim=1, keepdim=True) > 0,
1348
+ (
1349
+ lhand_output["mhr_hand"]["scale"][:, 18:] * valid_angle.squeeze(-1)[:, [0]]
1350
+ + rhand_output["mhr_hand"]["scale"][:, 18:] * valid_angle.squeeze(-1)[:, [1]]
1351
+ )
1352
+ / (valid_angle.squeeze(-1).sum(dim=1, keepdim=True) + 1e-8),
1353
+ pose_output["mhr"]["scale"][:, 18:],
1354
+ )
1355
+ pose_output["mhr"]["shape"][:, 40:] = torch.where(
1356
+ valid_angle.squeeze(-1).sum(dim=1, keepdim=True) > 0,
1357
+ (
1358
+ lhand_output["mhr_hand"]["shape"][:, 40:] * valid_angle.squeeze(-1)[:, [0]]
1359
+ + rhand_output["mhr_hand"]["shape"][:, 40:] * valid_angle.squeeze(-1)[:, [1]]
1360
+ )
1361
+ / (valid_angle.squeeze(-1).sum(dim=1, keepdim=True) + 1e-8),
1362
+ pose_output["mhr"]["shape"][:, 40:],
1363
+ )
1364
+
1365
+ ########################################################
1366
+
1367
+ # Re-run forward
1368
+ with torch.no_grad():
1369
+ verts, j3d, jcoords, mhr_model_params, joint_global_rots = self.head_pose.mhr_forward(
1370
+ global_trans=pose_output["mhr"]["global_rot"] * 0,
1371
+ global_rot=pose_output["mhr"]["global_rot"],
1372
+ body_pose_params=pose_output["mhr"]["body_pose"],
1373
+ hand_pose_params=pose_output["mhr"]["hand"],
1374
+ scale_params=pose_output["mhr"]["scale"],
1375
+ shape_params=pose_output["mhr"]["shape"],
1376
+ expr_params=pose_output["mhr"]["face"],
1377
+ return_keypoints=True,
1378
+ return_joint_coords=True,
1379
+ return_model_params=True,
1380
+ return_joint_rotations=True,
1381
+ )
1382
+ j3d = j3d[:, :70] # 308 --> 70 keypoints
1383
+ verts[..., [1, 2]] *= -1 # Camera system difference
1384
+ j3d[..., [1, 2]] *= -1 # Camera system difference
1385
+ jcoords[..., [1, 2]] *= -1
1386
+ pose_output["mhr"]["pred_keypoints_3d"] = j3d
1387
+ pose_output["mhr"]["pred_vertices"] = verts
1388
+ pose_output["mhr"]["pred_joint_coords"] = jcoords
1389
+ pose_output["mhr"]["pred_pose_raw"][...] = 0 # pred_pose_raw is not valid anymore
1390
+ pose_output["mhr"]["mhr_model_params"] = mhr_model_params
1391
+
1392
+ ########################################################
1393
+ # Project to 2D
1394
+ pred_keypoints_3d_proj = pose_output["mhr"]["pred_keypoints_3d"] + pose_output["mhr"]["pred_cam_t"][:, None, :]
1395
+ pred_keypoints_3d_proj[:, :, [0, 1]] *= pose_output["mhr"]["focal_length"][:, None, None]
1396
+ pred_keypoints_3d_proj[:, :, [0, 1]] = (
1397
+ pred_keypoints_3d_proj[:, :, [0, 1]]
1398
+ + torch.FloatTensor([width / 2, height / 2]).to(pred_keypoints_3d_proj)[None, None, :]
1399
+ * pred_keypoints_3d_proj[:, :, [2]]
1400
+ )
1401
+ pred_keypoints_3d_proj[:, :, :2] = pred_keypoints_3d_proj[:, :, :2] / pred_keypoints_3d_proj[:, :, [2]]
1402
+ pose_output["mhr"]["pred_keypoints_2d"] = pred_keypoints_3d_proj[:, :, :2]
1403
+
1404
+ return BodyPredContainer(
1405
+ pose_output=pose_output,
1406
+ batch_lhand=batch_lhand,
1407
+ batch_rhand=batch_rhand,
1408
+ lhand_output=lhand_output,
1409
+ rhand_output=rhand_output,
1410
+ )
1411
+
1412
+ def run_keypoint_prompt(self, batch, output, keypoint_prompt):
1413
+ image_embeddings = output["image_embeddings"]
1414
+ condition_info = output["condition_info"]
1415
+ pose_output = output["mhr"] # body-only output
1416
+ # Use previous estimate as initialization
1417
+ prev_estimate = torch.cat(
1418
+ [
1419
+ pose_output["pred_pose_raw"].detach(), # (B, 6)
1420
+ pose_output["shape"].detach(),
1421
+ pose_output["scale"].detach(),
1422
+ pose_output["hand"].detach(),
1423
+ pose_output["face"].detach(),
1424
+ ],
1425
+ dim=1,
1426
+ ).unsqueeze(dim=1)
1427
+ if hasattr(self, "init_camera"):
1428
+ prev_estimate = torch.cat(
1429
+ [prev_estimate, pose_output["pred_cam"].detach().unsqueeze(1)],
1430
+ dim=-1,
1431
+ )
1432
+
1433
+ tokens_output, pose_output = self.forward_decoder(
1434
+ image_embeddings,
1435
+ init_estimate=None, # not recurring previous estimate
1436
+ keypoints=keypoint_prompt,
1437
+ prev_estimate=prev_estimate,
1438
+ condition_info=condition_info,
1439
+ batch=batch,
1440
+ )
1441
+ pose_output = pose_output[-1]
1442
+
1443
+ output.update({"mhr": pose_output})
1444
+ return output, keypoint_prompt
1445
+
1446
+ def _get_hand_box(self, pose_output, batch):
1447
+ """Get hand bbox from the hand detector"""
1448
+ pred_left_hand_box = pose_output["mhr"]["hand_box"][:, 0].detach().cpu().numpy() * self.cfg.MODEL.IMAGE_SIZE[0]
1449
+ pred_right_hand_box = pose_output["mhr"]["hand_box"][:, 1].detach().cpu().numpy() * self.cfg.MODEL.IMAGE_SIZE[0]
1450
+
1451
+ # Change boxes into squares
1452
+ batch["left_center"] = pred_left_hand_box[:, :2]
1453
+ batch["left_scale"] = pred_left_hand_box[:, 2:].max(axis=1, keepdims=True).repeat(2, axis=1)
1454
+ batch["right_center"] = pred_right_hand_box[:, :2]
1455
+ batch["right_scale"] = pred_right_hand_box[:, 2:].max(axis=1, keepdims=True).repeat(2, axis=1)
1456
+
1457
+ # Crop to full. batch["affine_trans"] is full-to-crop, right application
1458
+ batch["left_scale"] = batch["left_scale"] / batch["affine_trans"][0, :, 0, 0].cpu().numpy()[:, None]
1459
+ batch["right_scale"] = batch["right_scale"] / batch["affine_trans"][0, :, 0, 0].cpu().numpy()[:, None]
1460
+ batch["left_center"] = (
1461
+ batch["left_center"] - batch["affine_trans"][0, :, [0, 1], [2, 2]].cpu().numpy()
1462
+ ) / batch["affine_trans"][0, :, 0, 0].cpu().numpy()[:, None]
1463
+ batch["right_center"] = (
1464
+ batch["right_center"] - batch["affine_trans"][0, :, [0, 1], [2, 2]].cpu().numpy()
1465
+ ) / batch["affine_trans"][0, :, 0, 0].cpu().numpy()[:, None]
1466
+
1467
+ left_xyxy = np.concatenate(
1468
+ [
1469
+ (batch["left_center"][:, 0] - batch["left_scale"][:, 0] * 1 / 2).reshape(-1, 1),
1470
+ (batch["left_center"][:, 1] - batch["left_scale"][:, 1] * 1 / 2).reshape(-1, 1),
1471
+ (batch["left_center"][:, 0] + batch["left_scale"][:, 0] * 1 / 2).reshape(-1, 1),
1472
+ (batch["left_center"][:, 1] + batch["left_scale"][:, 1] * 1 / 2).reshape(-1, 1),
1473
+ ],
1474
+ axis=1,
1475
+ )
1476
+ right_xyxy = np.concatenate(
1477
+ [
1478
+ (batch["right_center"][:, 0] - batch["right_scale"][:, 0] * 1 / 2).reshape(-1, 1),
1479
+ (batch["right_center"][:, 1] - batch["right_scale"][:, 1] * 1 / 2).reshape(-1, 1),
1480
+ (batch["right_center"][:, 0] + batch["right_scale"][:, 0] * 1 / 2).reshape(-1, 1),
1481
+ (batch["right_center"][:, 1] + batch["right_scale"][:, 1] * 1 / 2).reshape(-1, 1),
1482
+ ],
1483
+ axis=1,
1484
+ )
1485
+
1486
+ return left_xyxy, right_xyxy
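+ # Annotation: since affine_trans maps full-image to crop coordinates
+ # (p_crop = s * p_full + t), the inversion above is p_full = (p_crop - t) / s,
+ # applied to both the box centers and the (isotropic) box scales.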
1487
+
1488
+ def keypoint_token_update_fn(
1489
+ self,
1490
+ kps_emb_start_idx,
1491
+ image_embeddings,
1492
+ token_embeddings,
1493
+ token_augment,
1494
+ pose_output,
1495
+ layer_idx,
1496
+ ):
1497
+ # It's already after the last layer, we're done.
1498
+ if layer_idx == len(self.decoder.layers) - 1:
1499
+ return token_embeddings, token_augment, pose_output, layer_idx
1500
+
1501
+ # Clone
1502
+ token_embeddings = token_embeddings.clone()
1503
+ token_augment = token_augment.clone()
1504
+
1505
+ num_keypoints = self.keypoint_embedding.weight.shape[0]
1506
+
1507
+ # Get current 2D KPS predictions
1508
+ pred_keypoints_2d_cropped = pose_output["pred_keypoints_2d_cropped"].clone() # These are -0.5 ~ 0.5
1509
+ pred_keypoints_2d_depth = pose_output["pred_keypoints_2d_depth"].clone()
1510
+
1511
+ pred_keypoints_2d_cropped = pred_keypoints_2d_cropped[:, self.keypoint_embedding_idxs]
1512
+ pred_keypoints_2d_depth = pred_keypoints_2d_depth[:, self.keypoint_embedding_idxs]
1513
+
1514
+ # Get 2D KPS to be 0 ~ 1
1515
+ pred_keypoints_2d_cropped_01 = pred_keypoints_2d_cropped + 0.5
1516
+
1517
+ # Get a mask of those that are 1) beyond image boundaries or 2) behind the camera
1518
+ invalid_mask = (
1519
+ (pred_keypoints_2d_cropped_01[:, :, 0] < 0)
1520
+ | (pred_keypoints_2d_cropped_01[:, :, 0] > 1)
1521
+ | (pred_keypoints_2d_cropped_01[:, :, 1] < 0)
1522
+ | (pred_keypoints_2d_cropped_01[:, :, 1] > 1)
1523
+ | (pred_keypoints_2d_depth[:, :] < 1e-5)
1524
+ )
1525
+
1526
+ # Run them through the prompt encoder's pos emb function
1527
+ token_augment[:, kps_emb_start_idx : kps_emb_start_idx + num_keypoints, :] = self.keypoint_posemb_linear(
1528
+ pred_keypoints_2d_cropped
1529
+ ) * (~invalid_mask[:, :, None])
1530
+
1531
+ # Also maybe update token_embeddings with the grid sampled 2D feature.
1532
+ # Remember that pred_keypoints_2d_cropped are -0.5 ~ 0.5. We want -1 ~ 1
1533
+ # Sample points...
1534
+ ## Get sampling points
1535
+ pred_keypoints_2d_cropped_sample_points = pred_keypoints_2d_cropped * 2
1536
+ if self.cfg.MODEL.BACKBONE.TYPE in [
1537
+ "vit_hmr",
1538
+ "vit",
1539
+ "vit_b",
1540
+ "vit_l",
1541
+ "vit_hmr_512_384",
1542
+ ]:
1543
+ # Need to go from 256 x 256 coords to 256 x 192 (HW) because image_embeddings is 16x12
1544
+ # I.e., for x, a coordinate spanning -1 ~ 1 over the 256-wide square crop becomes -16/12 ~ 16/12 over the 192-wide feature grid (border points intentionally overflow and are zero-padded)
1545
+ pred_keypoints_2d_cropped_sample_points[:, :, 0] = (
1546
+ pred_keypoints_2d_cropped_sample_points[:, :, 0] / 12 * 16
1547
+ )
1548
+
1549
+ # Version 2 is projecting & bilinear sampling
1550
+ pred_keypoints_2d_cropped_feats = (
1551
+ F.grid_sample(
1552
+ image_embeddings,
1553
+ pred_keypoints_2d_cropped_sample_points[:, :, None, :], # -1 ~ 1, xy
1554
+ mode="bilinear",
1555
+ padding_mode="zeros",
1556
+ align_corners=False,
1557
+ )
1558
+ .squeeze(3)
1559
+ .permute(0, 2, 1)
1560
+ ) # B x kps x C
1561
+ # Zero out invalid locations...
1562
+ pred_keypoints_2d_cropped_feats = pred_keypoints_2d_cropped_feats * (~invalid_mask[:, :, None])
1563
+ # This is ADDING
1564
+ token_embeddings = token_embeddings.clone()
1565
+ token_embeddings[
1566
+ :,
1567
+ kps_emb_start_idx : kps_emb_start_idx + num_keypoints,
1568
+ :,
1569
+ ] += self.keypoint_feat_linear(pred_keypoints_2d_cropped_feats)
1570
+
1571
+ return token_embeddings, token_augment, pose_output, layer_idx
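+ # Annotation (illustrative sketch of the feature sampling above, hypothetical shapes):
+ # >>> feats = torch.randn(2, 1024, 16, 12)    # B x C x H x W image embeddings
+ # >>> pts = torch.rand(2, 70, 2) - 0.5        # 2D keypoints in [-0.5, 0.5], xy order
+ # >>> grid = (pts * 2)[:, :, None, :]         # B x K x 1 x 2 in [-1, 1]
+ # >>> out = F.grid_sample(feats, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
+ # >>> out.squeeze(3).permute(0, 2, 1).shape   # -> torch.Size([2, 70, 1024])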
1572
+
1573
+ def keypoint3d_token_update_fn(
1574
+ self,
1575
+ kps3d_emb_start_idx,
1576
+ token_embeddings,
1577
+ token_augment,
1578
+ pose_output,
1579
+ layer_idx,
1580
+ ):
1581
+ # It's already after the last layer, we're done.
1582
+ if layer_idx == len(self.decoder.layers) - 1:
1583
+ return token_embeddings, token_augment, pose_output, layer_idx
1584
+
1585
+ num_keypoints3d = self.keypoint3d_embedding.weight.shape[0]
1586
+
1587
+ # Get current 3D kps predictions
1588
+ pred_keypoints_3d = pose_output["pred_keypoints_3d"].clone()
1589
+
1590
+ # Now, pelvis normalize
1591
+ pred_keypoints_3d = (
1592
+ pred_keypoints_3d
1593
+ - (pred_keypoints_3d[:, [self.pelvis_idx[0]], :] + pred_keypoints_3d[:, [self.pelvis_idx[1]], :]) / 2
1594
+ )
1595
+
1596
+ # Get the kps we care about, _after_ pelvis norm (just in case idxs shift)
1597
+ pred_keypoints_3d = pred_keypoints_3d[:, self.keypoint3d_embedding_idxs]
1598
+
1599
+ # Run through embedding MLP & put in
1600
+ token_augment = token_augment.clone()
1601
+ token_augment[
1602
+ :,
1603
+ kps3d_emb_start_idx : kps3d_emb_start_idx + num_keypoints3d,
1604
+ :,
1605
+ ] = self.keypoint3d_posemb_linear(pred_keypoints_3d)
1606
+
1607
+ return token_embeddings, token_augment, pose_output, layer_idx
1608
+
1609
+ def keypoint_token_update_fn_hand(
1610
+ self,
1611
+ kps_emb_start_idx,
1612
+ image_embeddings,
1613
+ token_embeddings,
1614
+ token_augment,
1615
+ pose_output,
1616
+ layer_idx,
1617
+ ):
1618
+ # It's already after the last layer, we're done.
1619
+ if layer_idx == len(self.decoder_hand.layers) - 1:
1620
+ return token_embeddings, token_augment, pose_output, layer_idx
1621
+
1622
+ # Clone
1623
+ token_embeddings = token_embeddings.clone()
1624
+ token_augment = token_augment.clone()
1625
+
1626
+ num_keypoints = self.keypoint_embedding_hand.weight.shape[0]
1627
+
1628
+ # Get current 2D KPS predictions
1629
+ pred_keypoints_2d_cropped = pose_output["pred_keypoints_2d_cropped"].clone() # These are -0.5 ~ 0.5
1630
+ pred_keypoints_2d_depth = pose_output["pred_keypoints_2d_depth"].clone()
1631
+
1632
+ pred_keypoints_2d_cropped = pred_keypoints_2d_cropped[:, self.keypoint_embedding_idxs_hand]
1633
+ pred_keypoints_2d_depth = pred_keypoints_2d_depth[:, self.keypoint_embedding_idxs_hand]
1634
+
1635
+ # Get 2D KPS to be 0 ~ 1
1636
+ pred_keypoints_2d_cropped_01 = pred_keypoints_2d_cropped + 0.5
1637
+
1638
+ # Get a mask of those that are 1) beyond image boundaries or 2) behind the camera
1639
+ invalid_mask = (
1640
+ (pred_keypoints_2d_cropped_01[:, :, 0] < 0)
1641
+ | (pred_keypoints_2d_cropped_01[:, :, 0] > 1)
1642
+ | (pred_keypoints_2d_cropped_01[:, :, 1] < 0)
1643
+ | (pred_keypoints_2d_cropped_01[:, :, 1] > 1)
1644
+ | (pred_keypoints_2d_depth[:, :] < 1e-5)
1645
+ )
1646
+
1647
+ # Run them through the prompt encoder's pos emb function
1648
+ token_augment[:, kps_emb_start_idx : kps_emb_start_idx + num_keypoints, :] = self.keypoint_posemb_linear_hand(
1649
+ pred_keypoints_2d_cropped
1650
+ ) * (~invalid_mask[:, :, None])
1651
+
1652
+ # Also maybe update token_embeddings with the grid sampled 2D feature.
1653
+ # Remember that pred_keypoints_2d_cropped are -0.5 ~ 0.5. We want -1 ~ 1
1654
+ # Sample points...
1655
+ ## Get sampling points
1656
+ pred_keypoints_2d_cropped_sample_points = pred_keypoints_2d_cropped * 2
1657
+ if self.cfg.MODEL.BACKBONE.TYPE in [
1658
+ "vit_hmr",
1659
+ "vit",
1660
+ "vit_b",
1661
+ "vit_l",
1662
+ "vit_hmr_512_384",
1663
+ ]:
1664
+ # Need to go from 256 x 256 coords to 256 x 192 (HW) because image_embeddings is 16x12
1665
+ # I.e., for x, a coordinate spanning -1 ~ 1 over the 256-wide square crop becomes -16/12 ~ 16/12 over the 192-wide feature grid (border points intentionally overflow and are zero-padded)
1666
+ pred_keypoints_2d_cropped_sample_points[:, :, 0] = (
1667
+ pred_keypoints_2d_cropped_sample_points[:, :, 0] / 12 * 16
1668
+ )
1669
+
1670
+ # Version 2 is projecting & bilinear sampling
1671
+ pred_keypoints_2d_cropped_feats = (
1672
+ F.grid_sample(
1673
+ image_embeddings,
1674
+ pred_keypoints_2d_cropped_sample_points[:, :, None, :], # -1 ~ 1, xy
1675
+ mode="bilinear",
1676
+ padding_mode="zeros",
1677
+ align_corners=False,
1678
+ )
1679
+ .squeeze(3)
1680
+ .permute(0, 2, 1)
1681
+ ) # B x kps x C
1682
+ # Zero out invalid locations...
1683
+ pred_keypoints_2d_cropped_feats = pred_keypoints_2d_cropped_feats * (~invalid_mask[:, :, None])
1684
+ # This is ADDING
1685
+ token_embeddings = token_embeddings.clone()
1686
+ token_embeddings[
1687
+ :,
1688
+ kps_emb_start_idx : kps_emb_start_idx + num_keypoints,
1689
+ :,
1690
+ ] += self.keypoint_feat_linear_hand(pred_keypoints_2d_cropped_feats)
1691
+
1692
+ return token_embeddings, token_augment, pose_output, layer_idx
1693
+
1694
+ def keypoint3d_token_update_fn_hand(
1695
+ self,
1696
+ kps3d_emb_start_idx,
1697
+ token_embeddings,
1698
+ token_augment,
1699
+ pose_output,
1700
+ layer_idx,
1701
+ ):
1702
+ # It's already after the last layer, we're done.
1703
+ if layer_idx == len(self.decoder_hand.layers) - 1:
1704
+ return token_embeddings, token_augment, pose_output, layer_idx
1705
+
1706
+ num_keypoints3d = self.keypoint3d_embedding_hand.weight.shape[0]
1707
+
1708
+ # Get current 3D kps predictions
1709
+ pred_keypoints_3d = pose_output["pred_keypoints_3d"].clone()
1710
+
1711
+ # Now, pelvis normalize
1712
+ pred_keypoints_3d = (
1713
+ pred_keypoints_3d
1714
+ - (pred_keypoints_3d[:, [self.pelvis_idx[0]], :] + pred_keypoints_3d[:, [self.pelvis_idx[1]], :]) / 2
1715
+ )
1716
+
1717
+ # Get the kps we care about, _after_ pelvis norm (just in case idxs shift)
1718
+ pred_keypoints_3d = pred_keypoints_3d[:, self.keypoint3d_embedding_idxs_hand]
1719
+
1720
+ # Run through embedding MLP & put in
1721
+ token_augment = token_augment.clone()
1722
+ token_augment[
1723
+ :,
1724
+ kps3d_emb_start_idx : kps3d_emb_start_idx + num_keypoints3d,
1725
+ :,
1726
+ ] = self.keypoint3d_posemb_linear_hand(pred_keypoints_3d)
1727
+
1728
+ return token_embeddings, token_augment, pose_output, layer_idx
src/sam3d_body/models/modules/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from .geometry_utils import (
4
+ aa_to_rotmat,
5
+ cam_crop_to_full,
6
+ focal_length_normalization,
7
+ get_focalLength_from_fieldOfView,
8
+ get_intrinsic_matrix,
9
+ inverse_perspective_projection,
10
+ log_depth,
11
+ perspective_projection,
12
+ rot6d_to_rotmat,
13
+ transform_points,
14
+ undo_focal_length_normalization,
15
+ undo_log_depth,
16
+ )
17
+
18
+ from .misc import to_2tuple, to_3tuple, to_4tuple, to_ntuple
src/sam3d_body/models/modules/camera_embed.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import einops
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from sam3d_body.models.modules.transformer import LayerNorm2d
9
+ from torch import nn
10
+
11
+
12
+ class CameraEncoder(nn.Module):
13
+ def __init__(self, embed_dim, patch_size=14):
14
+ super().__init__()
15
+ self.patch_size = patch_size
16
+ self.embed_dim = embed_dim
17
+ self.camera = FourierPositionEncoding(n=3, num_bands=16, max_resolution=64)
18
+
19
+ self.conv = nn.Conv2d(embed_dim + 99, embed_dim, kernel_size=1, bias=False)
20
+ self.norm = LayerNorm2d(embed_dim)
21
+
22
+ def forward(self, img_embeddings, rays):
23
+ B, D, _h, _w = img_embeddings.shape
24
+
25
+ with torch.no_grad():
26
+ scale = 1 / self.patch_size
27
+ rays = F.interpolate(
28
+ rays,
29
+ scale_factor=(scale, scale),
30
+ mode="bilinear",
31
+ align_corners=False,
32
+ antialias=True,
33
+ )
34
+ rays = rays.permute(0, 2, 3, 1).contiguous() # [b, h, w, 2]
35
+ rays = torch.cat([rays, torch.ones_like(rays[..., :1])], dim=-1)
36
+ rays_embeddings = self.camera(
37
+ pos=rays.reshape(B, -1, 3)
38
+ ) # (bs, N, 99): rays fourier embedding
39
+ rays_embeddings = einops.rearrange(
40
+ rays_embeddings, "b (h w) c -> b c h w", h=_h, w=_w
41
+ ).contiguous()
42
+
43
+ z = torch.concat([img_embeddings, rays_embeddings], dim=1)
44
+ z = self.norm(self.conv(z))
45
+
46
+ return z
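+ # Annotation (illustrative usage, hypothetical sizes): with patch_size=14, a 224x224
+ # ray map is pooled to the 16x16 token grid, Fourier-encoded to 99 channels,
+ # concatenated with the image tokens, and projected back to embed_dim:
+ # >>> enc = CameraEncoder(embed_dim=1024, patch_size=14)
+ # >>> z = enc(torch.randn(1, 1024, 16, 16), torch.randn(1, 2, 224, 224))
+ # >>> z.shape  # -> torch.Size([1, 1024, 16, 16])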
47
+
48
+
49
+ class FourierPositionEncoding(nn.Module):
50
+ def __init__(self, n, num_bands, max_resolution):
51
+ """
52
+ Module that generates Fourier positional encodings - no learning involved
53
+ """
54
+ super().__init__()
55
+
56
+ self.num_bands = num_bands
57
+ self.max_resolution = [max_resolution] * n
58
+
59
+ @property
60
+ def channels(self):
61
+ """
62
+ Return the output dimension
63
+ """
64
+ num_dims = len(self.max_resolution)
65
+ encoding_size = self.num_bands * num_dims
66
+ encoding_size *= 2 # sin-cos
67
+ encoding_size += num_dims # concat
68
+
69
+ return encoding_size
70
+
71
+ def forward(self, pos):
72
+ """
73
+ Forward pass that takes rays as input and generates Fourier positional encodings
74
+ """
75
+ fourier_pos_enc = _generate_fourier_features(
76
+ pos, num_bands=self.num_bands, max_resolution=self.max_resolution
77
+ )
78
+ return fourier_pos_enc
79
+
80
+
81
+ def _generate_fourier_features(pos, num_bands, max_resolution):
82
+ """Generate fourier features from a given set of positions and frequencies"""
83
+ b, n = pos.shape[:2]
84
+ device = pos.device
85
+
86
+ # Linear frequency sampling
87
+ min_freq = 1.0
88
+ freq_bands = torch.stack(
89
+ [
90
+ torch.linspace(start=min_freq, end=res / 2, steps=num_bands, device=device)
91
+ for res in max_resolution
92
+ ],
93
+ dim=0,
94
+ )
95
+
96
+ # Stacking
97
+ per_pos_features = torch.stack(
98
+ [pos[i, :, :][:, :, None] * freq_bands[None, :, :] for i in range(b)], 0
99
+ )
100
+ per_pos_features = per_pos_features.reshape(b, n, -1)
101
+
102
+ # Sin-Cos
103
+ per_pos_features = torch.cat(
104
+ [torch.sin(np.pi * per_pos_features), torch.cos(np.pi * per_pos_features)],
105
+ dim=-1,
106
+ )
107
+
108
+ # Concat with initial pos
109
+ per_pos_features = torch.cat([pos, per_pos_features], dim=-1)
110
+
111
+ return per_pos_features
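For orientation, here is a minimal usage sketch for `FourierPositionEncoding` (not part of the commit; the import path is assumed from the repo's `src/` layout). It shows where the 99 extra channels consumed by `CameraEncoder`'s 1x1 conv come from: 3 ray coordinates x 16 bands x 2 (sin/cos), plus the 3 raw coordinates.

```python
# Illustrative sketch, not part of the commit; import path assumed from the src/ layout.
import torch

from sam3d_body.models.modules.camera_embed import FourierPositionEncoding

enc = FourierPositionEncoding(n=3, num_bands=16, max_resolution=64)
print(enc.channels)                # 3 * 16 * 2 + 3 = 99, matching the embed_dim + 99 conv above

rays = torch.randn(2, 37 * 37, 3)  # (B, N, 3) per-token ray coordinates
emb = enc(rays)
print(emb.shape)                   # torch.Size([2, 1369, 99])
```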
src/sam3d_body/models/modules/drop_path.py ADDED
@@ -0,0 +1,42 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ def drop_path(
8
+ x: torch.Tensor, drop_prob: float = 0.0, training: bool = False
9
+ ) -> torch.Tensor:
10
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
11
+ residual blocks).
12
+
13
+ We follow the implementation
14
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
15
+ """
16
+ if not training:
17
+ return x
18
+ keep_prob = 1 - drop_prob
19
+ # handle tensors with different dimensions, not just 4D tensors.
20
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
21
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
22
+ output = x.div(keep_prob) * random_tensor.floor()
23
+ return output
24
+
25
+
26
+ class DropPath(nn.Module):
27
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
28
+ residual blocks).
29
+
30
+ We follow the implementation
31
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
32
+
33
+ Args:
34
+ drop_prob (float): Probability of the path to be zeroed. Default: 0.1
35
+ """
36
+
37
+ def __init__(self, drop_prob: float = 0.1):
38
+ super().__init__()
39
+ self.drop_prob = drop_prob
40
+
41
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
42
+ return drop_path(x, self.drop_prob, self.training)
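A quick sketch of the stochastic-depth behaviour (not part of the commit): during training whole samples are zeroed and the survivors are rescaled by 1/keep_prob, so the expected value is preserved; in eval mode the module is the identity.

```python
# Illustrative sketch, not part of the commit; import path assumed from the src/ layout.
import torch

from sam3d_body.models.modules.drop_path import DropPath

dp = DropPath(drop_prob=0.5)
x = torch.ones(8, 4, 16)

dp.eval()
assert torch.equal(dp(x), x)  # identity at inference time

dp.train()
y = dp(x)                     # each sample is either all zeros or scaled by 1 / 0.5 = 2
print(y[:, 0, 0])
```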
src/sam3d_body/models/modules/geometry_utils.py ADDED
@@ -0,0 +1,304 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from typing import Optional
4
+
5
+ import cv2
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.nn import functional as F
10
+ from jaxtyping import Float
11
+
12
+
13
+ def cam_crop_to_full(cam_bbox, box_center, box_size, img_size, focal_length=5000.0):
14
+ # Convert cam_bbox to full image
15
+ img_w, img_h = img_size[:, 0], img_size[:, 1]
16
+ cx, cy, b = box_center[:, 0], box_center[:, 1], box_size
17
+ w_2, h_2 = img_w / 2.0, img_h / 2.0
18
+ bs = b * cam_bbox[:, 0] + 1e-9
19
+ if type(focal_length) is float:
20
+ focal_length = torch.ones_like(cam_bbox[:, 0]) * focal_length
21
+ tz = 2 * focal_length / bs
22
+ tx = (2 * (cx - w_2) / bs) + cam_bbox[:, 1]
23
+ ty = (2 * (cy - h_2) / bs) + cam_bbox[:, 2]
24
+ full_cam = torch.stack([tx, ty, tz], dim=-1)
25
+ return full_cam
26
+
27
+
28
+ def aa_to_rotmat(theta: torch.Tensor):
29
+ """
30
+ Convert axis-angle representation to rotation matrix.
31
+ Works by first converting it to a quaternion.
32
+ Args:
33
+ theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
34
+ Returns:
35
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
36
+
37
+ Alternatives:
38
+ import roma
39
+ y = roma.rotvec_to_rotmat(x)
40
+ """
41
+ norm = torch.norm(theta + 1e-8, p=2, dim=1)
42
+ angle = torch.unsqueeze(norm, -1)
43
+ normalized = torch.div(theta, angle)
44
+ angle = angle * 0.5
45
+ v_cos = torch.cos(angle)
46
+ v_sin = torch.sin(angle)
47
+ quat = torch.cat([v_cos, v_sin * normalized], dim=1)
48
+ return _quat_to_rotmat(quat)
49
+
50
+
51
+ def _quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
52
+ """
53
+ Convert quaternion representation to rotation matrix.
54
+ Args:
55
+ quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z).
56
+ Returns:
57
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
58
+ """
59
+ norm_quat = quat
60
+ norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
61
+ w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
62
+
63
+ B = quat.size(0)
64
+
65
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
66
+ wx, wy, wz = w * x, w * y, w * z
67
+ xy, xz, yz = x * y, x * z, y * z
68
+
69
+ rotMat = torch.stack(
70
+ [
71
+ w2 + x2 - y2 - z2,
72
+ 2 * xy - 2 * wz,
73
+ 2 * wy + 2 * xz,
74
+ 2 * wz + 2 * xy,
75
+ w2 - x2 + y2 - z2,
76
+ 2 * yz - 2 * wx,
77
+ 2 * xz - 2 * wy,
78
+ 2 * wx + 2 * yz,
79
+ w2 - x2 - y2 + z2,
80
+ ],
81
+ dim=1,
82
+ ).view(B, 3, 3)
83
+ return rotMat
84
+
85
+
86
+ def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
87
+ """
88
+ Convert 6D rotation representation to 3x3 rotation matrix.
89
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
90
+ Args:
91
+ x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
92
+ Returns:
93
+ torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
94
+
95
+ Alternatives:
96
+ import roma
97
+ x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
98
+ y = roma.special_gramschmidt(x)
99
+ """
100
+ x = x.reshape(-1, 2, 3).permute(0, 2, 1).contiguous()
101
+ a1 = x[:, :, 0]
102
+ a2 = x[:, :, 1]
103
+ b1 = F.normalize(a1)
104
+ b2 = F.normalize(a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1)
105
+ b3 = torch.linalg.cross(b1, b2)
106
+ return torch.stack((b1, b2, b3), dim=-1)
107
+
108
+
109
+ def rotmat_to_rot6d(x: torch.Tensor) -> torch.Tensor:
110
+ """
111
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
112
+ by dropping the last row. Note that 6D representation is not unique.
113
+ Args:
114
+ x: batch of rotation matrices of size (B, 3, 3)
115
+
116
+ Returns:
117
+ 6D rotation representation, of size (B, 6)
118
+
119
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
120
+ On the Continuity of Rotation Representations in Neural Networks.
121
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
122
+ Retrieved from http://arxiv.org/abs/1812.07035
123
+ """
124
+ batch_dim = x.size()[:-2]
125
+ return x[..., :2, :].clone().reshape(batch_dim + (6,))
126
+
127
+
128
+ def rot_aa(aa: Float[np.ndarray, "3"], rot: float) -> Float[np.ndarray, "3"]:
129
+ """
130
+ Rotate axis angle parameters.
131
+ Args:
132
+ aa (np.array): Axis-angle vector of shape (3,).
133
+ rot (float): Rotation angle in degrees.
134
+ Returns:
135
+ np.array: Rotated axis-angle vector.
136
+ """
137
+ # pose parameters
138
+ R: Float[np.ndarray, "3 3"] = np.array(
139
+ [
140
+ [np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
141
+ [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
142
+ [0, 0, 1],
143
+ ],
144
+ dtype=np.float64,
145
+ )
146
+ # find the rotation of the body in camera frame
147
+ per_rdg: Float[np.ndarray, "3 3"]
148
+ per_rdg, _ = cv2.Rodrigues(aa)
149
+ # apply the global rotation to the global orientation
150
+ resrot: Float[np.ndarray, "3 3"]
151
+ resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
152
+ aa_vec: Float[np.ndarray, "3"] = (resrot.T)[0]
153
+ return aa_vec.astype(np.float32)
154
+
155
+
156
+ def transform_points(
157
+ points: torch.Tensor,
158
+ translation: Optional[torch.Tensor] = None,
159
+ rotation: Optional[torch.Tensor] = None,
160
+ ) -> torch.Tensor:
161
+ """
162
+ Transform a set of 3D points given translation and rotation.
163
+ Args:
164
+ points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
165
+ translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
166
+ rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
167
+ Returns:
168
+ torch.Tensor: Tensor of shape (B, N, 3) containing the transformed points.
169
+ """
170
+ if rotation is not None:
171
+ points = torch.einsum("bij,bkj->bki", rotation, points)
172
+
173
+ if translation is not None:
174
+ points = points + translation.unsqueeze(1)
175
+
176
+ return points
177
+
178
+
179
+ def get_intrinsic_matrix(
180
+ focal_length: torch.Tensor, principle: torch.Tensor
181
+ ) -> torch.Tensor:
182
+ """
183
+ Populate intrinsic camera matrix K given focal length and principal point.
184
+ Args:
185
+ focal_length: Tensor of shape (2,)
186
+ principle: Tensor of shape (2,)
187
+ Returns:
188
+ Tensor of shape (3, 3)
189
+ """
190
+ if isinstance(focal_length, float):
191
+ fl_x = fl_y = focal_length
192
+ elif len(focal_length) == 1:
193
+ fl_x = fl_y = focal_length[0]
194
+ else:
195
+ fl_x, fl_y = focal_length[0], focal_length[1]
196
+ K = torch.eye(3)
197
+ K[0, 0] = fl_x
198
+ K[1, 1] = fl_y
199
+ K[0, -1] = principle[0]
200
+ K[1, -1] = principle[1]
201
+
202
+ return K
203
+
204
+
205
+ def perspective_projection(x, K):
206
+ """
207
+ Computes the perspective projection of a set of points, assuming the extrinsic params have already been applied
208
+ Args:
209
+ - x [bs,N,3]: 3D points
210
+ - K [bs,3,3]: Camera intrinsics params
211
+ """
212
+ # Apply perspective distortion
213
+ y = x / x[:, :, -1].unsqueeze(-1) # (bs, N, 3)
214
+
215
+ # Apply camera intrinsics
216
+ y = torch.einsum("bij,bkj->bki", K, y) # (bs, N, 3)
217
+
218
+ return y[:, :, :2]
219
+
220
+
221
+ def inverse_perspective_projection(points, K, distance):
222
+ """
223
+ Computes the inverse perspective projection of a set of points given an estimated distance.
224
+ Input:
225
+ points (bs, N, 2): 2D points
226
+ K (bs,3,3): camera intrinsics params
227
+ distance (bs, N, 1): distance in the 3D world
228
+ Similar to:
229
+ - pts_l_norm = cv2.undistortPoints(np.expand_dims(pts_l, axis=1), cameraMatrix=K_l, distCoeffs=None)
230
+ """
231
+ # Apply camera intrinsics
232
+ points = torch.cat([points, torch.ones_like(points[..., :1])], -1)
233
+ points = torch.einsum("bij,bkj->bki", torch.inverse(K), points)
234
+
235
+ # Apply perspective distortion
236
+ if distance is None:
237
+ return points
238
+ points = points * distance
239
+ return points
240
+
241
+
242
+ def get_cam_intrinsics(img_size, fov=55, p_x=None, p_y=None):
243
+ """Given image size, fov and principal point coordinates, return K the camera parameter matrix"""
244
+ K = np.eye(3)
245
+ # Get focal length.
246
+ focal = get_focalLength_from_fieldOfView(fov=fov, img_size=img_size)
247
+ K[0, 0], K[1, 1] = focal, focal
248
+
249
+ # Set principal point
250
+ if p_x is not None and p_y is not None:
251
+ K[0, -1], K[1, -1] = p_x * img_size, p_y * img_size
252
+ else:
253
+ K[0, -1], K[1, -1] = img_size // 2, img_size // 2
254
+
255
+ return K
256
+
257
+
258
+ def get_focalLength_from_fieldOfView(fov=60, img_size=512):
259
+ """
260
+ Compute the focal length of the camera lens by assuming a certain FOV for the entire image
261
+ Args:
262
+ - fov: float, expressed in degree
263
+ - img_size: int
264
+ Return:
265
+ focal: float
266
+ """
267
+ focal = img_size / (2 * np.tan(np.radians(fov) / 2))
268
+ return focal
269
+
270
+
271
+ def focal_length_normalization(x, f, fovn=60, img_size=448):
272
+ """
273
+ Section 3.1 of https://arxiv.org/pdf/1904.02028.pdf
274
+ E = (fn/f) * E' where E is 1/d
275
+ """
276
+ fn = get_focalLength_from_fieldOfView(fov=fovn, img_size=img_size)
277
+ y = x * (fn / f)
278
+ return y
279
+
280
+
281
+ def undo_focal_length_normalization(y, f, fovn=60, img_size=448):
282
+ """
283
+ Undo focal_length_normalization()
284
+ """
285
+ fn = get_focalLength_from_fieldOfView(fov=fovn, img_size=img_size)
286
+ x = y * (f / fn)
287
+ return x
288
+
289
+
290
+ EPS_LOG = 1e-10
291
+
292
+
293
+ def log_depth(x, eps=EPS_LOG):
294
+ """
295
+ Move depth to log space
296
+ """
297
+ return torch.log(x + eps)
298
+
299
+
300
+ def undo_log_depth(y, eps=EPS_LOG):
301
+ """
302
+ Undo log_depth()
303
+ """
304
+ return torch.exp(y) - eps
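To make the projection helpers concrete, a hedged round-trip sketch (not part of the commit): `inverse_perspective_projection` takes the per-point z-depth as `distance`, so projecting with `perspective_projection` and unprojecting with the original depths recovers the 3D points.

```python
# Illustrative sketch, not part of the commit; import path assumed from the src/ layout.
import torch

from sam3d_body.models.modules.geometry_utils import (
    get_intrinsic_matrix,
    inverse_perspective_projection,
    perspective_projection,
)

K = get_intrinsic_matrix(torch.tensor([1000.0, 1000.0]), torch.tensor([256.0, 256.0]))
K = K.unsqueeze(0)                                           # (1, 3, 3)

pts3d = torch.tensor([[[0.1, -0.2, 3.0], [0.4, 0.3, 5.0]]])  # (1, 2, 3)
pts2d = perspective_projection(pts3d, K)                     # (1, 2, 2) pixel coordinates
depth = pts3d[..., 2:]                                       # (1, 2, 1) z-depth
recon = inverse_perspective_projection(pts2d, K, depth)
assert torch.allclose(recon, pts3d, atol=1e-3)
```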
src/sam3d_body/models/modules/layer_scale.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from typing import Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ class LayerScale(nn.Module):
10
+ """LayerScale layer.
11
+
12
+ Args:
13
+ dim (int): Dimension of input features.
14
+ layer_scale_init_value (float or torch.Tensor): Init value of layer
15
+ scale. Defaults to 1e-5.
16
+ inplace (bool): Whether to do the
17
+ operation in-place. Defaults to False.
18
+ data_format (str): The input data format, could be 'channels_last'
19
+ or 'channels_first', representing (B, N, C) and
20
+ (B, C, H, W) format data respectively. Defaults to 'channels_last'.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ dim: int,
26
+ layer_scale_init_value: Union[float, torch.Tensor] = 1e-5,
27
+ inplace: bool = False,
28
+ data_format: str = "channels_last",
29
+ ):
30
+ super().__init__()
31
+ assert data_format in (
32
+ "channels_last",
33
+ "channels_first",
34
+ ), "'data_format' could only be channels_last or channels_first."
35
+ self.inplace = inplace
36
+ self.data_format = data_format
37
+ self.weight = nn.Parameter(torch.ones(dim) * layer_scale_init_value)
38
+
39
+ def forward(self, x):
40
+ if self.data_format == "channels_first":
41
+ if self.inplace:
42
+ return x.mul_(self.weight.view(-1, 1, 1))
43
+ else:
44
+ return x * self.weight.view(-1, 1, 1)
45
+ return x.mul_(self.weight) if self.inplace else x * self.weight
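A short usage sketch for `LayerScale` (not part of the commit): the same module handles token-format and map-format features depending on `data_format`.

```python
# Illustrative sketch, not part of the commit; import path assumed from the src/ layout.
import torch

from sam3d_body.models.modules.layer_scale import LayerScale

ls_tokens = LayerScale(dim=256)                              # (B, N, C) input
ls_maps = LayerScale(dim=256, data_format="channels_first")  # (B, C, H, W) input

print(ls_tokens(torch.randn(2, 196, 256)).shape)   # torch.Size([2, 196, 256])
print(ls_maps(torch.randn(2, 256, 14, 14)).shape)  # torch.Size([2, 256, 14, 14])
```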
src/sam3d_body/models/modules/mhr_utils.py ADDED
@@ -0,0 +1,392 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import json
4
+ import math
5
+ import os.path as osp
6
+ import pickle
7
+
8
+ import cv2
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+
16
+ def rotation_angle_difference(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
17
+ """
18
+ Compute the angle difference (magnitude) between two batches of SO(3) rotation matrices.
19
+ Args:
20
+ A: Tensor of shape (*, 3, 3), batch of rotation matrices.
21
+ B: Tensor of shape (*, 3, 3), batch of rotation matrices.
22
+ Returns:
23
+ Tensor of shape (*,), angle differences in radians.
24
+ """
25
+ # Compute relative rotation matrix
26
+ R_rel = torch.matmul(A, B.transpose(-2, -1)) # (B, 3, 3)
27
+ # Compute trace of relative rotation
28
+ trace = R_rel[..., 0, 0] + R_rel[..., 1, 1] + R_rel[..., 2, 2] # (B,)
29
+ # Compute angle using the trace formula
30
+ cos_theta = (trace - 1) / 2
31
+ # Clamp for numerical stability
32
+ cos_theta_clamped = torch.clamp(cos_theta, -1.0, 1.0)
33
+ # Compute angle difference
34
+ angle = torch.acos(cos_theta_clamped)
35
+ return angle
36
+
37
+
38
+ def fix_wrist_euler(
39
+ wrist_xzy, limits_x=(-2.2, 1.0), limits_z=(-2.2, 1.5), limits_y=(-1.2, 1.5)
40
+ ):
41
+ """
42
+ wrist_xzy: B x 2 x 3 (X, Z, Y angles)
43
+ Returns: Fixed angles within joint limits
44
+ """
45
+ x, z, y = wrist_xzy[..., 0], wrist_xzy[..., 1], wrist_xzy[..., 2]
46
+
47
+ x_alt = torch.atan2(torch.sin(x + torch.pi), torch.cos(x + torch.pi))
48
+ z_alt = torch.atan2(torch.sin(-(z + torch.pi)), torch.cos(-(z + torch.pi)))
49
+ y_alt = torch.atan2(torch.sin(y + torch.pi), torch.cos(y + torch.pi))
50
+
51
+ # Calculate L2 violation distance
52
+ def calc_violation(val, limits):
53
+ below = torch.clamp(limits[0] - val, min=0.0)
54
+ above = torch.clamp(val - limits[1], min=0.0)
55
+ return below**2 + above**2
56
+
57
+ violation_orig = (
58
+ calc_violation(x, limits_x)
59
+ + calc_violation(z, limits_z)
60
+ + calc_violation(y, limits_y)
61
+ )
62
+
63
+ violation_alt = (
64
+ calc_violation(x_alt, limits_x)
65
+ + calc_violation(z_alt, limits_z)
66
+ + calc_violation(y_alt, limits_y)
67
+ )
68
+
69
+ # Use alternative where it has lower L2 violation
70
+ use_alt = violation_alt < violation_orig
71
+
72
+ # Stack alternative and apply mask
73
+ wrist_xzy_alt = torch.stack([x_alt, z_alt, y_alt], dim=-1)
74
+ result = torch.where(use_alt.unsqueeze(-1), wrist_xzy_alt, wrist_xzy)
75
+
76
+ return result
77
+
78
+
79
+ def batch6DFromXYZ(r, return_9D=False):
80
+ """
81
+ Generate a matrix representing a rotation defined by a XYZ-Euler
82
+ rotation.
83
+
84
+ Args:
85
+ r: ... x 3 rotation vectors
86
+
87
+ Returns:
88
+ ... x 6
89
+ """
90
+ rc = torch.cos(r)
91
+ rs = torch.sin(r)
92
+ cx = rc[..., 0]
93
+ cy = rc[..., 1]
94
+ cz = rc[..., 2]
95
+ sx = rs[..., 0]
96
+ sy = rs[..., 1]
97
+ sz = rs[..., 2]
98
+
99
+ result = torch.empty(list(r.shape[:-1]) + [3, 3], dtype=r.dtype).to(r.device)
100
+
101
+ result[..., 0, 0] = cy * cz
102
+ result[..., 0, 1] = -cx * sz + sx * sy * cz
103
+ result[..., 0, 2] = sx * sz + cx * sy * cz
104
+ result[..., 1, 0] = cy * sz
105
+ result[..., 1, 1] = cx * cz + sx * sy * sz
106
+ result[..., 1, 2] = -sx * cz + cx * sy * sz
107
+ result[..., 2, 0] = -sy
108
+ result[..., 2, 1] = sx * cy
109
+ result[..., 2, 2] = cx * cy
110
+
111
+ if not return_9D:
112
+ return torch.cat([result[..., :, 0], result[..., :, 1]], dim=-1)
113
+ else:
114
+ return result
115
+
116
+
117
+ # https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py#L82
118
+ def batchXYZfrom6D(poses):
119
+ # Args: poses: ... x 6, where "6" is the combined first and second columns
120
+ # First, get the rotaiton matrix
121
+ x_raw = poses[..., :3]
122
+ y_raw = poses[..., 3:]
123
+
124
+ x = F.normalize(x_raw, dim=-1)
125
+ z = torch.cross(x, y_raw, dim=-1)
126
+ z = F.normalize(z, dim=-1)
127
+ y = torch.cross(z, x, dim=-1)
128
+
129
+ matrix = torch.stack([x, y, z], dim=-1) # ... x 3 x 3
130
+
131
+ # Now get it into euler
132
+ # https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py#L412
133
+ sy = torch.sqrt(
134
+ matrix[..., 0, 0] * matrix[..., 0, 0] + matrix[..., 1, 0] * matrix[..., 1, 0]
135
+ )
136
+ singular = sy < 1e-6
137
+ singular = singular.float()
138
+
139
+ x = torch.atan2(matrix[..., 2, 1], matrix[..., 2, 2])
140
+ y = torch.atan2(-matrix[..., 2, 0], sy)
141
+ z = torch.atan2(matrix[..., 1, 0], matrix[..., 0, 0])
142
+
143
+ xs = torch.atan2(-matrix[..., 1, 2], matrix[..., 1, 1])
144
+ ys = torch.atan2(-matrix[..., 2, 0], sy)
145
+ zs = matrix[..., 1, 0] * 0
146
+
147
+ out_euler = torch.zeros_like(matrix[..., 0])
148
+ out_euler[..., 0] = x * (1 - singular) + xs * singular
149
+ out_euler[..., 1] = y * (1 - singular) + ys * singular
150
+ out_euler[..., 2] = z * (1 - singular) + zs * singular
151
+
152
+ return out_euler
153
+
154
+
155
+ def resize_image(image_array, scale_factor, interpolation=cv2.INTER_LINEAR):
156
+ new_height = int(image_array.shape[0] // scale_factor)
157
+ new_width = int(image_array.shape[1] // scale_factor)
158
+ resized_image = cv2.resize(
159
+ image_array, (new_width, new_height), interpolation=interpolation
160
+ )
161
+
162
+ return resized_image
163
+
164
+
165
+ def compact_cont_to_model_params_hand(hand_cont):
166
+ # These are ordered by joint, not model params ^^
167
+ assert hand_cont.shape[-1] == 54
168
+ hand_dofs_in_order = torch.tensor([3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 2, 3, 1, 1])
169
+ assert sum(hand_dofs_in_order) == 27
170
+ # Mask of 3DoFs into hand_cont
171
+ mask_cont_threedofs = torch.cat(
172
+ [torch.ones(2 * k).bool() * (k in [3]) for k in hand_dofs_in_order]
173
+ )
174
+ # Mask of 1DoFs (including 2DoF) into hand_cont
175
+ mask_cont_onedofs = torch.cat(
176
+ [torch.ones(2 * k).bool() * (k in [1, 2]) for k in hand_dofs_in_order]
177
+ )
178
+ # Mask of 3DoFs into hand_model_params
179
+ mask_model_params_threedofs = torch.cat(
180
+ [torch.ones(k).bool() * (k in [3]) for k in hand_dofs_in_order]
181
+ )
182
+ # Mask of 1DoFs (including 2DoF) into hand_model_params
183
+ mask_model_params_onedofs = torch.cat(
184
+ [torch.ones(k).bool() * (k in [1, 2]) for k in hand_dofs_in_order]
185
+ )
186
+
187
+ # Convert hand_cont to eulers
188
+ ## First for 3DoFs
189
+ hand_cont_threedofs = hand_cont[..., mask_cont_threedofs].unflatten(-1, (-1, 6))
190
+ hand_model_params_threedofs = batchXYZfrom6D(hand_cont_threedofs).flatten(-2, -1)
191
+ ## Next for 1DoFs
192
+ hand_cont_onedofs = hand_cont[..., mask_cont_onedofs].unflatten(
193
+ -1, (-1, 2)
194
+ ) # (sincos)
195
+ hand_model_params_onedofs = torch.atan2(
196
+ hand_cont_onedofs[..., -2], hand_cont_onedofs[..., -1]
197
+ )
198
+
199
+ # Finally, assemble into a 27-dim vector, ordered by joint, then XYZ.
200
+ hand_model_params = torch.zeros(*hand_cont.shape[:-1], 27).to(hand_cont)
201
+ hand_model_params[..., mask_model_params_threedofs] = hand_model_params_threedofs
202
+ hand_model_params[..., mask_model_params_onedofs] = hand_model_params_onedofs
203
+
204
+ return hand_model_params
205
+
206
+
207
+ def compact_model_params_to_cont_hand(hand_model_params):
208
+ # These are ordered by joint, not model params ^^
209
+ assert hand_model_params.shape[-1] == 27
210
+ hand_dofs_in_order = torch.tensor([3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 2, 3, 1, 1])
211
+ assert sum(hand_dofs_in_order) == 27
212
+ # Mask of 3DoFs into hand_cont
213
+ mask_cont_threedofs = torch.cat(
214
+ [torch.ones(2 * k).bool() * (k in [3]) for k in hand_dofs_in_order]
215
+ )
216
+ # Mask of 1DoFs (including 2DoF) into hand_cont
217
+ mask_cont_onedofs = torch.cat(
218
+ [torch.ones(2 * k).bool() * (k in [1, 2]) for k in hand_dofs_in_order]
219
+ )
220
+ # Mask of 3DoFs into hand_model_params
221
+ mask_model_params_threedofs = torch.cat(
222
+ [torch.ones(k).bool() * (k in [3]) for k in hand_dofs_in_order]
223
+ )
224
+ # Mask of 1DoFs (including 2DoF) into hand_model_params
225
+ mask_model_params_onedofs = torch.cat(
226
+ [torch.ones(k).bool() * (k in [1, 2]) for k in hand_dofs_in_order]
227
+ )
228
+
229
+ # Convert eulers to hand_cont hand_cont
230
+ ## First for 3DoFs
231
+ hand_model_params_threedofs = hand_model_params[
232
+ ..., mask_model_params_threedofs
233
+ ].unflatten(-1, (-1, 3))
234
+ hand_cont_threedofs = batch6DFromXYZ(hand_model_params_threedofs).flatten(-2, -1)
235
+ ## Next for 1DoFs
236
+ hand_model_params_onedofs = hand_model_params[..., mask_model_params_onedofs]
237
+ hand_cont_onedofs = torch.stack(
238
+ [hand_model_params_onedofs.sin(), hand_model_params_onedofs.cos()], dim=-1
239
+ ).flatten(-2, -1)
240
+
241
+ # Finally, assemble into a 27-dim vector, ordered by joint, then XYZ.
242
+ hand_cont = torch.zeros(*hand_model_params.shape[:-1], 54).to(hand_model_params)
243
+ hand_cont[..., mask_cont_threedofs] = hand_cont_threedofs
244
+ hand_cont[..., mask_cont_onedofs] = hand_cont_onedofs
245
+
246
+ return hand_cont
247
+
248
+
249
+ def batch9Dfrom6D(poses):
250
+ # Args: poses: ... x 6, where "6" is the combined first and second columns
251
+ # First, get the rotaiton matrix
252
+ x_raw = poses[..., :3]
253
+ y_raw = poses[..., 3:]
254
+
255
+ x = F.normalize(x_raw, dim=-1)
256
+ z = torch.cross(x, y_raw, dim=-1)
257
+ z = F.normalize(z, dim=-1)
258
+ y = torch.cross(z, x, dim=-1)
259
+
260
+ matrix = torch.stack([x, y, z], dim=-1).flatten(-2, -1) # ... x 3 x 3 -> x9
261
+
262
+ return matrix
263
+
264
+
265
+ def batch4Dfrom2D(poses):
266
+ # Args: poses: ... x 2, where "2" is sincos
267
+ poses_norm = F.normalize(poses, dim=-1)
268
+
269
+ poses_4d = torch.stack(
270
+ [
271
+ poses_norm[..., 1],
272
+ poses_norm[..., 0],
273
+ -poses_norm[..., 0],
274
+ poses_norm[..., 1],
275
+ ],
276
+ dim=-1,
277
+ ) # Flattened SO2.
278
+
279
+ return poses_4d # .... x 4
280
+
281
+
282
+ def compact_cont_to_rotmat_body(body_pose_cont, inflate_trans=False):
283
+ # fmt: off
284
+ all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)])
285
+ all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123])
286
+ all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129])
287
+ # fmt: on
288
+ num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
289
+ num_1dof_angles = len(all_param_1dof_rot_idxs)
290
+ num_1dof_trans = len(all_param_1dof_trans_idxs)
291
+ assert body_pose_cont.shape[-1] == (
292
+ 2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans
293
+ )
294
+ # Get subsets
295
+ body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles]
296
+ body_cont_1dofs = body_pose_cont[
297
+ ..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles
298
+ ]
299
+ body_cont_trans = body_pose_cont[..., 2 * num_3dof_angles + 2 * num_1dof_angles :]
300
+ # Convert conts to model params
301
+ ## First for 3dofs
302
+ body_cont_3dofs = body_cont_3dofs.unflatten(-1, (-1, 6))
303
+ body_rotmat_3dofs = batch9Dfrom6D(body_cont_3dofs).flatten(-2, -1)
304
+ ## Next for 1dofs
305
+ body_cont_1dofs = body_cont_1dofs.unflatten(-1, (-1, 2)) # (sincos)
306
+ body_rotmat_1dofs = batch4Dfrom2D(body_cont_1dofs).flatten(-2, -1)
307
+ if inflate_trans:
308
+ assert (
309
+ False
310
+ ), "This is left as a possibility to increase the space/contribution/supervision trans params gets compared to rots"
311
+ else:
312
+ ## Nothing to do for trans
313
+ body_rotmat_trans = body_cont_trans
314
+ # Put them together
315
+ body_rotmat_params = torch.cat(
316
+ [body_rotmat_3dofs, body_rotmat_1dofs, body_rotmat_trans], dim=-1
317
+ )
318
+ return body_rotmat_params
319
+
320
+
321
+ def compact_cont_to_model_params_body(body_pose_cont):
322
+ # fmt: off
323
+ all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)])
324
+ all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123])
325
+ all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129])
326
+ # fmt: on
327
+ num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
328
+ num_1dof_angles = len(all_param_1dof_rot_idxs)
329
+ num_1dof_trans = len(all_param_1dof_trans_idxs)
330
+ assert body_pose_cont.shape[-1] == (
331
+ 2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans
332
+ )
333
+ # Get subsets
334
+ body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles]
335
+ body_cont_1dofs = body_pose_cont[
336
+ ..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles
337
+ ]
338
+ body_cont_trans = body_pose_cont[..., 2 * num_3dof_angles + 2 * num_1dof_angles :]
339
+ # Convert conts to model params
340
+ ## First for 3dofs
341
+ body_cont_3dofs = body_cont_3dofs.unflatten(-1, (-1, 6))
342
+ body_params_3dofs = batchXYZfrom6D(body_cont_3dofs).flatten(-2, -1)
343
+ ## Next for 1dofs
344
+ body_cont_1dofs = body_cont_1dofs.unflatten(-1, (-1, 2)) # (sincos)
345
+ body_params_1dofs = torch.atan2(body_cont_1dofs[..., -2], body_cont_1dofs[..., -1])
346
+ ## Nothing to do for trans
347
+ body_params_trans = body_cont_trans
348
+ # Put them together
349
+ body_pose_params = torch.zeros(*body_pose_cont.shape[:-1], 133).to(body_pose_cont)
350
+ body_pose_params[..., all_param_3dof_rot_idxs.flatten()] = body_params_3dofs
351
+ body_pose_params[..., all_param_1dof_rot_idxs] = body_params_1dofs
352
+ body_pose_params[..., all_param_1dof_trans_idxs] = body_params_trans
353
+ return body_pose_params
354
+
355
+
356
+ def compact_model_params_to_cont_body(body_pose_params):
357
+ # fmt: off
358
+ all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)])
359
+ all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123])
360
+ all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129])
361
+ # fmt: on
362
+ num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
363
+ num_1dof_angles = len(all_param_1dof_rot_idxs)
364
+ num_1dof_trans = len(all_param_1dof_trans_idxs)
365
+ assert body_pose_params.shape[-1] == (
366
+ num_3dof_angles + num_1dof_angles + num_1dof_trans
367
+ )
368
+ # Take out params
369
+ body_params_3dofs = body_pose_params[..., all_param_3dof_rot_idxs.flatten()]
370
+ body_params_1dofs = body_pose_params[..., all_param_1dof_rot_idxs]
371
+ body_params_trans = body_pose_params[..., all_param_1dof_trans_idxs]
372
+ # params to cont
373
+ body_cont_3dofs = batch6DFromXYZ(body_params_3dofs.unflatten(-1, (-1, 3))).flatten(
374
+ -2, -1
375
+ )
376
+ body_cont_1dofs = torch.stack(
377
+ [body_params_1dofs.sin(), body_params_1dofs.cos()], dim=-1
378
+ ).flatten(-2, -1)
379
+ body_cont_trans = body_params_trans
380
+ # Put them together
381
+ body_pose_cont = torch.cat(
382
+ [body_cont_3dofs, body_cont_1dofs, body_cont_trans], dim=-1
383
+ )
384
+ return body_pose_cont
385
+
386
+
387
+ # fmt: off
388
+ mhr_param_hand_idxs = [62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115]
389
+ mhr_cont_hand_idxs = [72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237]
390
+ mhr_param_hand_mask = torch.zeros(133).bool(); mhr_param_hand_mask[mhr_param_hand_idxs] = True
391
+ mhr_cont_hand_mask = torch.zeros(260).bool(); mhr_cont_hand_mask[mhr_cont_hand_idxs] = True
392
+ # fmt: on
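To illustrate the hand parameterisation helpers (hedged sketch, not part of the commit): the 27 model parameters (Euler angles for 3-DoF joints, single angles otherwise) and the 54-dim continuous representation (6D per 3-DoF joint, sin/cos otherwise) round-trip through the two conversion functions for angles in the non-singular range, as in this small-angle example.

```python
# Illustrative sketch, not part of the commit; import path assumed from the src/ layout.
import torch

from sam3d_body.models.modules.mhr_utils import (
    compact_cont_to_model_params_hand,
    compact_model_params_to_cont_hand,
)

hand_params = torch.zeros(2, 27).uniform_(-0.3, 0.3)        # (B, 27) joint angles
hand_cont = compact_model_params_to_cont_hand(hand_params)  # (B, 54) continuous encoding
recon = compact_cont_to_model_params_hand(hand_cont)        # (B, 27) back to angles
assert torch.allclose(recon, hand_params, atol=1e-4)
```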
src/sam3d_body/models/modules/misc.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import collections.abc
4
+ from itertools import repeat
5
+
6
+
7
+ # From PyTorch internals
8
+ def _ntuple(n):
9
+ """A `to_tuple` function generator.
10
+
11
+ It returns a function, this function will repeat the input to a tuple of
12
+ length ``n`` if the input is not an Iterable object, otherwise, return the
13
+ input directly.
14
+
15
+ Args:
16
+ n (int): The number of the target length.
17
+ """
18
+
19
+ def parse(x):
20
+ if isinstance(x, collections.abc.Iterable):
21
+ return x
22
+ return tuple(repeat(x, n))
23
+
24
+ return parse
25
+
26
+
27
+ to_1tuple = _ntuple(1)
28
+ to_2tuple = _ntuple(2)
29
+ to_3tuple = _ntuple(3)
30
+ to_4tuple = _ntuple(4)
31
+ to_ntuple = _ntuple
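A one-line illustration of the tuple helpers (not part of the commit):

```python
# Illustrative sketch, not part of the commit; import path assumed from the src/ layout.
from sam3d_body.models.modules.misc import to_2tuple

print(to_2tuple(14))        # (14, 14) -- scalars are repeated
print(to_2tuple((16, 12)))  # (16, 12) -- iterables pass through unchanged
```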
src/sam3d_body/models/modules/swiglu_ffn.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from .drop_path import DropPath
10
+
11
+ from .layer_scale import LayerScale
12
+
13
+
14
+ class SwiGLUFFN(nn.Module):
15
+ """SwiGLU FFN layer.
16
+
17
+ Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py
18
+ """ # noqa
19
+
20
+ def __init__(
21
+ self,
22
+ embed_dims: int,
23
+ feedforward_channels: Optional[int] = None,
24
+ out_dims: Optional[int] = None,
25
+ layer_scale_init_value: float = 0.0,
26
+ bias: bool = True,
27
+ drop_path_rate: float = 0.0,
28
+ norm_layer: nn.Module = nn.LayerNorm,
29
+ add_identity: bool = True,
30
+ ) -> None:
31
+ super().__init__()
32
+ self.embed_dims = embed_dims
33
+ self.out_dims = out_dims or embed_dims
34
+ hidden_dims = feedforward_channels or embed_dims
35
+
36
+ self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias)
37
+
38
+ # norm_layer may be passed as a class (e.g. nn.LayerNorm) or an instance;
+ # instantiate it over the hidden width when a class is given.
+ self.norm = norm_layer(hidden_dims) if isinstance(norm_layer, type) else norm_layer
39
+
40
+ self.w3 = nn.Linear(hidden_dims, self.out_dims, bias=bias)
41
+
42
+ if layer_scale_init_value > 0:
43
+ self.gamma2 = LayerScale(
44
+ dim=embed_dims, layer_scale_init_value=layer_scale_init_value
45
+ )
46
+ else:
47
+ self.gamma2 = nn.Identity()
48
+
49
+ self.dropout_layer = DropPath(drop_path_rate)
50
+ self.add_identity = add_identity
51
+
52
+ def forward(
53
+ self, x: torch.Tensor, identity: Optional[torch.Tensor] = None
54
+ ) -> torch.Tensor:
55
+ x12 = self.w12(x)
56
+ x1, x2 = x12.chunk(2, dim=-1)
57
+ hidden = F.silu(x1) * x2
58
+ hidden = self.norm(hidden)
59
+ out = self.w3(hidden)
60
+ out = self.gamma2(out)
61
+ out = self.dropout_layer(out)
62
+
63
+ if self.out_dims != self.embed_dims or not self.add_identity:
64
+ # due to the dimension inconsistency or user setting
65
+ # not to apply residual operation
66
+ return out
67
+
68
+ if identity is None:
69
+ identity = x
70
+ return identity + out
71
+
72
+
73
+ class SwiGLUFFNFused(SwiGLUFFN):
74
+ """SwiGLU FFN layer with fusing.
75
+
76
+ Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py
77
+ """ # noqa
78
+
79
+ def __init__(
80
+ self,
81
+ embed_dims: int,
82
+ feedforward_channels: Optional[int] = None,
83
+ out_dims: Optional[int] = None,
84
+ layer_scale_init_value: float = 0.0,
85
+ bias: bool = True,
86
+ ) -> None:
87
+ out_dims = out_dims or embed_dims
88
+ feedforward_channels = feedforward_channels or embed_dims
89
+ feedforward_channels = (int(feedforward_channels * 2 / 3) + 7) // 8 * 8
90
+ super().__init__(
91
+ embed_dims=embed_dims,
92
+ feedforward_channels=feedforward_channels,
93
+ out_dims=out_dims,
94
+ layer_scale_init_value=layer_scale_init_value,
95
+ bias=bias,
96
+ )
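A hedged usage sketch for the SwiGLU FFN (not part of the commit); here the norm layer is passed as an `nn.LayerNorm` instantiated over the hidden width.

```python
# Illustrative sketch, not part of the commit; import path assumed from the src/ layout.
import torch
import torch.nn as nn

from sam3d_body.models.modules.swiglu_ffn import SwiGLUFFN

ffn = SwiGLUFFN(embed_dims=256, feedforward_channels=512, norm_layer=nn.LayerNorm(512))
x = torch.randn(2, 196, 256)
out = ffn(x)      # identity connection added since out_dims == embed_dims
print(out.shape)  # torch.Size([2, 196, 256])
```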
src/sam3d_body/models/modules/transformer.py ADDED
@@ -0,0 +1,651 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from typing import Dict, Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from .drop_path import DropPath
10
+
11
+ from .layer_scale import LayerScale
12
+ from .swiglu_ffn import SwiGLUFFNFused
13
+
14
+
15
+ class MLP(nn.Module):
16
+ # borrowed from DETR
17
+ """Very simple multi-layer perceptron (also called FFN)"""
18
+
19
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
20
+ super().__init__()
21
+ self.num_layers = num_layers
22
+ h = [hidden_dim] * (num_layers - 1)
23
+ self.layers = nn.ModuleList(
24
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
25
+ )
26
+
27
+ def forward(self, x):
28
+ for i, layer in enumerate(self.layers):
29
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
30
+ return x
31
+
32
+
33
+ class LayerNorm32(nn.LayerNorm):
34
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
35
+ return super().forward(x.float()).type(x.dtype)
36
+
37
+
38
+ def build_norm_layer(cfg: Dict, num_features: int):
39
+ """Build normalization layer.
40
+
41
+ Args:
42
+ cfg (dict): The norm layer config, which should contain:
43
+
44
+ - type (str): Layer type.
45
+ - layer args: Args needed to instantiate a norm layer.
46
+ - requires_grad (bool, optional): Whether stop gradient updates.
47
+ num_features (int): Number of input channels.
48
+
+ Returns:
+ nn.Module: The created norm layer (a ``LayerNorm32`` instance for
+ type ``"LN"``), with ``requires_grad`` set on its parameters
+ according to the config.
55
+ """
56
+ if not isinstance(cfg, dict):
57
+ raise TypeError("cfg must be a dict")
58
+ if "type" not in cfg:
59
+ raise KeyError('the cfg dict must contain the key "type"')
60
+ cfg_ = cfg.copy()
61
+
62
+ layer_type = cfg_.pop("type")
63
+ if layer_type == "LN":
64
+ norm_layer = LayerNorm32
65
+ else:
66
+ raise ValueError("Unsupported norm layer: ", layer_type)
67
+
68
+ requires_grad = cfg_.pop("requires_grad", True)
69
+ cfg_.setdefault("eps", 1e-5)
70
+ if norm_layer is not nn.GroupNorm:
71
+ layer = norm_layer(num_features, **cfg_)
72
+ if layer_type == "SyncBN" and hasattr(layer, "_specify_ddp_gpu_num"):
73
+ layer._specify_ddp_gpu_num(1)
74
+ else:
75
+ assert "num_groups" in cfg_
76
+ layer = norm_layer(num_channels=num_features, **cfg_)
77
+
78
+ for param in layer.parameters():
79
+ param.requires_grad = requires_grad
80
+
81
+ return layer
82
+
83
+
84
+ class LayerNorm2d(nn.Module):
85
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
86
+ super().__init__()
87
+ self.weight = nn.Parameter(torch.ones(num_channels))
88
+ self.bias = nn.Parameter(torch.zeros(num_channels))
89
+ self.eps = eps
90
+
91
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
92
+ u = x.mean(1, keepdim=True)
93
+ s = (x - u).pow(2).mean(1, keepdim=True)
94
+ x = (x - u) / torch.sqrt(s + self.eps)
95
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
96
+ return x
97
+
98
+
99
+ class FFN(nn.Module):
100
+ """Implements feed-forward networks (FFNs) with identity connection.
101
+
102
+ Args:
103
+ embed_dims (int): The feature dimension. Same as
104
+ `MultiheadAttention`. Defaults: 256.
105
+ feedforward_channels (int): The hidden dimension of FFNs.
106
+ Defaults: 1024.
107
+ num_fcs (int, optional): The number of fully-connected layers in
108
+ FFNs. Default: 2.
109
+ act_layer (nn.Module, optional): The activation layer for FFNs.
110
+ Default: nn.ReLU
111
+ ffn_drop (float, optional): Probability of an element to be
112
+ zeroed in FFN. Default 0.0.
113
+ add_identity (bool, optional): Whether to add the
114
+ identity connection. Default: `True`.
115
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
116
+ layer_scale_init_value (float): Initial value of scale factor in
117
+ LayerScale. Default: 1.0
118
+ """
119
+
120
+ # @deprecated_api_warning(
121
+ # {
122
+ # 'dropout': 'ffn_drop',
123
+ # 'add_residual': 'add_identity'
124
+ # },
125
+ # cls_name='FFN')
126
+ def __init__(
127
+ self,
128
+ embed_dims=256,
129
+ feedforward_channels=1024,
130
+ output_dims=None,
131
+ num_fcs=2,
132
+ act_layer=nn.ReLU,
133
+ ffn_drop=0.0,
134
+ drop_path_rate=0.0,
135
+ add_identity=True,
136
+ layer_scale_init_value=0.0,
137
+ ):
138
+ super().__init__()
139
+ self.embed_dims = embed_dims
140
+ self.feedforward_channels = feedforward_channels
141
+ self.output_dims = output_dims or embed_dims
142
+ self.num_fcs = num_fcs
143
+
144
+ layers = []
145
+ in_channels = embed_dims
146
+ for _ in range(num_fcs - 1):
147
+ layers.append(
148
+ nn.Sequential(
149
+ nn.Linear(in_channels, feedforward_channels),
150
+ act_layer(),
151
+ nn.Dropout(ffn_drop),
152
+ )
153
+ )
154
+ in_channels = feedforward_channels
155
+ layers.append(nn.Linear(in_channels, self.output_dims))
156
+ layers.append(nn.Dropout(ffn_drop))
157
+ self.layers = nn.Sequential(*layers)
158
+ self.dropout_layer = (
159
+ DropPath(drop_path_rate) if drop_path_rate > 0.0 else torch.nn.Identity()
160
+ )
161
+ self.add_identity = add_identity
162
+
163
+ if layer_scale_init_value > 0:
164
+ self.gamma2 = LayerScale(embed_dims, layer_scale_init_value=layer_scale_init_value)
165
+ else:
166
+ self.gamma2 = nn.Identity()
167
+
168
+ # @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
169
+ def forward(self, x, identity=None):
170
+ """Forward function for `FFN`.
171
+
172
+ The function would add x to the output tensor if residue is None.
173
+ """
174
+ out = self.layers(x)
175
+ out = self.gamma2(out)
176
+ if not self.add_identity:
177
+ return self.dropout_layer(out)
178
+ if identity is None:
179
+ identity = x
180
+ return identity + self.dropout_layer(out)
181
+
182
+
183
+ class MultiheadAttention(nn.Module):
184
+ """Multi-head Attention Module.
185
+
186
+ This module implements multi-head attention that supports different input
187
+ dims and embed dims. And it also supports a shortcut from ``value``, which
188
+ is useful if input dims is not the same with embed dims.
189
+
190
+ Args:
191
+ embed_dims (int): The embedding dimension.
192
+ num_heads (int): Parallel attention heads.
193
+ input_dims (int, optional): The input dimension, and if None,
194
+ use ``embed_dims``. Defaults to None.
195
+ attn_drop (float): Dropout rate of the dropout layer after the
196
+ attention calculation of query and key. Defaults to 0.
197
+ proj_drop (float): Dropout rate of the dropout layer after the
198
+ output projection. Defaults to 0.
199
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
200
+ qkv_bias (bool): If True, add a learnable bias to q, k, v.
201
+ Defaults to True.
202
+ qk_scale (float, optional): Override default qk scale of
203
+ ``head_dim ** -0.5`` if set. Defaults to None.
204
+ proj_bias (bool): If True, add a learnable bias to output projection.
205
+ Defaults to True.
206
+ v_shortcut (bool): Add a shortcut from value to output. It's usually
207
+ used if ``input_dims`` is different from ``embed_dims``.
208
+ Defaults to False.
209
+ use_layer_scale (bool): Whether to use layer scale. Defaults to False.
210
+ layer_scale_init_value (float or torch.Tensor): Init value of layer
211
+ scale. Defaults to 0.
212
+ """
213
+
214
+ def __init__(
215
+ self,
216
+ embed_dims,
217
+ num_heads,
218
+ input_dims=None,
219
+ attn_drop=0.0,
220
+ proj_drop=0.0,
221
+ drop_path_rate=0.0,
222
+ qkv_bias=True,
223
+ proj_bias=True,
224
+ v_shortcut=False,
225
+ layer_scale_init_value=0.0,
226
+ ):
227
+ super().__init__()
228
+
229
+ self.input_dims = input_dims or embed_dims
230
+ self.embed_dims = embed_dims
231
+ self.num_heads = num_heads
232
+ self.v_shortcut = v_shortcut
233
+
234
+ self.head_dims = embed_dims // num_heads
235
+
236
+ self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
237
+ self.attn_drop = attn_drop
238
+ self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
239
+ self.proj_drop = nn.Dropout(proj_drop)
240
+
241
+ self.out_drop = DropPath(drop_path_rate)
242
+
243
+ if layer_scale_init_value > 0:
244
+ layer_scale_init_value = layer_scale_init_value or 1e-5
245
+ self.gamma1 = LayerScale(
246
+ embed_dims, layer_scale_init_value=layer_scale_init_value
247
+ )
248
+ else:
249
+ self.gamma1 = nn.Identity()
250
+
251
+ def forward(self, x):
252
+ B, N, _ = x.shape
253
+ qkv = (
254
+ self.qkv(x)
255
+ .reshape(B, N, 3, self.num_heads, self.head_dims)
256
+ .permute(2, 0, 3, 1, 4)
257
+ )
258
+ q, k, v = qkv[0], qkv[1], qkv[2]
259
+
260
+ attn_drop = self.attn_drop if self.training else 0.0
261
+ x = F.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
262
+ x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
263
+
264
+ x = self.proj(x)
265
+ x = self.out_drop(self.gamma1(self.proj_drop(x)))
266
+
267
+ if self.v_shortcut:
268
+ x = v.squeeze(1) + x
269
+ return x
270
+
271
+
272
+ class Attention(nn.Module):
273
+ """Multi-head Attention Module for both self and cross attention.
274
+
275
+ Support masking invalid elements for attention.
276
+
277
+ Args:
278
+ embed_dims (int): The embedding dimension.
279
+ num_heads (int): Parallel attention heads.
280
+ input_dims (int, optional): The input dimension, and if None,
281
+ use ``embed_dims``. Defaults to None.
282
+ attn_drop (float): Dropout rate of the dropout layer after the
283
+ attention calculation of query and key. Defaults to 0.
284
+ proj_drop (float): Dropout rate of the dropout layer after the
285
+ output projection. Defaults to 0.
286
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
287
+ qkv_bias (bool): If True, add a learnable bias to q, k, v.
288
+ Defaults to True.
289
+ qk_scale (float, optional): Override default qk scale of
290
+ ``head_dim ** -0.5`` if set. Defaults to None.
291
+ proj_bias (bool): If True, add a learnable bias to output projection.
292
+ Defaults to True.
293
+ v_shortcut (bool): Add a shortcut from value to output. It's usually
294
+ used if ``input_dims`` is different from ``embed_dims``.
295
+ Defaults to False.
296
+ use_layer_scale (bool): Whether to use layer scale. Defaults to False.
297
+ layer_scale_init_value (float or torch.Tensor): Init value of layer
298
+ scale. Defaults to 0.
299
+ """
300
+
301
+ def __init__(
302
+ self,
303
+ embed_dims,
304
+ num_heads,
305
+ query_dims=None,
306
+ key_dims=None,
307
+ value_dims=None,
308
+ attn_drop=0.0,
309
+ proj_drop=0.0,
310
+ drop_path_rate=0.0,
311
+ qkv_bias=True,
312
+ proj_bias=True,
313
+ v_shortcut=False,
314
+ layer_scale_init_value=0.0,
315
+ ):
316
+ super().__init__()
317
+
318
+ self.query_dims = query_dims or embed_dims
319
+ self.key_dims = key_dims or embed_dims
320
+ self.value_dims = value_dims or embed_dims
321
+ self.embed_dims = embed_dims
322
+ self.num_heads = num_heads
323
+ self.v_shortcut = v_shortcut
324
+
325
+ self.head_dims = embed_dims // num_heads
326
+
327
+ self.q_proj = nn.Linear(self.query_dims, embed_dims, bias=qkv_bias)
328
+ self.k_proj = nn.Linear(self.key_dims, embed_dims, bias=qkv_bias)
329
+ self.v_proj = nn.Linear(self.value_dims, embed_dims, bias=qkv_bias)
330
+ self.attn_drop = attn_drop
331
+ self.proj = nn.Linear(embed_dims, self.query_dims, bias=proj_bias)
332
+ self.proj_drop = nn.Dropout(proj_drop)
333
+
334
+ self.out_drop = DropPath(drop_path_rate)
335
+
336
+ if layer_scale_init_value > 0:
337
+ layer_scale_init_value = layer_scale_init_value or 1e-5
338
+ self.gamma1 = LayerScale(
339
+ embed_dims, layer_scale_init_value=layer_scale_init_value
340
+ )
341
+ else:
342
+ self.gamma1 = nn.Identity()
343
+
344
+ def _separate_heads(self, x: torch.Tensor) -> torch.Tensor:
345
+ b, n, _ = x.shape
346
+ x = x.reshape(b, n, self.num_heads, self.head_dims)
347
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
348
+
349
+ def forward(
350
+ self,
351
+ q: torch.Tensor,
352
+ k: torch.Tensor,
353
+ v: torch.Tensor,
354
+ attn_mask: Optional[torch.Tensor] = None,
355
+ ):
356
+ B, N, _ = q.shape
357
+ q = self._separate_heads(self.q_proj(q))
358
+ k = self._separate_heads(self.k_proj(k))
359
+ v = self._separate_heads(self.v_proj(v))
360
+
361
+ attn_drop = self.attn_drop if self.training else 0.0
362
+ if attn_mask is not None:
363
+ attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
364
+
365
+ x = F.scaled_dot_product_attention(
366
+ q, k, v, attn_mask=attn_mask, dropout_p=attn_drop
367
+ )
368
+ x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
369
+
370
+ x = self.proj(x)
371
+ x = self.out_drop(self.gamma1(self.proj_drop(x)))
372
+
373
+ if self.v_shortcut:
374
+ x = v.squeeze(1) + x
375
+ return x
376
+
377
+
378
+ class TransformerEncoderLayer(nn.Module):
379
+ """Implements one encoder layer in Vision Transformer.
380
+
381
+ Args:
382
+ embed_dims (int): The feature dimension
383
+ num_heads (int): Parallel attention heads
384
+ feedforward_channels (int): The hidden dimension for FFNs
385
+ layer_scale_init_value (float or torch.Tensor): Init value of layer
386
+ scale. Defaults to 0.
387
+ drop_rate (float): Probability of an element to be zeroed
388
+ after the feed forward layer. Defaults to 0.
389
+ attn_drop_rate (float): The drop out rate for attention output weights.
390
+ Defaults to 0.
391
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
392
+ num_fcs (int): The number of fully-connected layers for FFNs.
393
+ Defaults to 2.
394
+ qkv_bias (bool): enable bias for qkv if True. Defaults to True.
395
+ ffn_type (str): Select the type of ffn layers. Defaults to 'origin'.
396
+ act_layer (nn.Module, optional): The activation layer for FFNs.
397
+ Default: nn.GELU
398
+ norm_cfg (dict): Config dict for normalization layer.
399
+ Defaults to ``dict(type='LN')``.
400
+ """
401
+
402
+ def __init__(
403
+ self,
404
+ embed_dims,
405
+ num_heads,
406
+ feedforward_channels,
407
+ layer_scale_init_value=0.0,
408
+ drop_rate=0.0,
409
+ attn_drop_rate=0.0,
410
+ drop_path_rate=0.0,
411
+ num_fcs=2,
412
+ qkv_bias=True,
413
+ ffn_type="origin",
414
+ act_layer=nn.GELU,
415
+ norm_cfg=dict(type="LN", eps=1e-6),
416
+ ):
417
+ super().__init__()
418
+
419
+ self.embed_dims = embed_dims
420
+
421
+ self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
422
+
423
+ self.attn = MultiheadAttention(
424
+ embed_dims=embed_dims,
425
+ num_heads=num_heads,
426
+ attn_drop=attn_drop_rate,
427
+ proj_drop=drop_rate,
428
+ drop_path_rate=drop_path_rate,
429
+ qkv_bias=qkv_bias,
430
+ layer_scale_init_value=layer_scale_init_value,
431
+ )
432
+
433
+ self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
434
+
435
+ if ffn_type == "origin":
436
+ self.ffn = FFN(
437
+ embed_dims=embed_dims,
438
+ feedforward_channels=feedforward_channels,
439
+ num_fcs=num_fcs,
440
+ ffn_drop=drop_rate,
441
+ drop_path_rate=drop_path_rate,
442
+ act_layer=act_layer,
443
+ layer_scale_init_value=layer_scale_init_value,
444
+ )
445
+ elif ffn_type == "swiglu_fused":
446
+ self.ffn = SwiGLUFFNFused(
447
+ embed_dims=embed_dims,
448
+ feedforward_channels=feedforward_channels,
449
+ layer_scale_init_value=layer_scale_init_value,
450
+ )
451
+ else:
452
+ raise NotImplementedError
453
+
454
+ @property
455
+ def norm1(self):
456
+ return self.ln1
457
+
458
+ @property
459
+ def norm2(self):
460
+ return self.ln2
461
+
462
+ def forward(self, x):
463
+ x = x + self.attn(self.ln1(x))
464
+ x = self.ffn(self.ln2(x), identity=x)
465
+ return x
466
+
467
+
468
+ class TransformerDecoderLayer(nn.Module):
469
+ """Implements one decoder layer in cross-attention Transformer.
470
+
471
+ Adapted from Segment Anything Model (SAM) implementation.
472
+
473
+ Args:
474
+ token_dims (int): The feature dimension of the decoder tokens
+ context_dims (int): The feature dimension of the context (image) embeddings
+ num_heads (int): Parallel attention heads
+ head_dims (int): The dimension of each attention head
+ mlp_dims (int): The hidden dimension for FFNs
477
+ layer_scale_init_value (float or torch.Tensor): Init value of layer
478
+ scale. Defaults to 0.
479
+ drop_rate (float): Probability of an element to be zeroed
480
+ after the feed forward layer. Defaults to 0.
481
+ attn_drop_rate (float): The drop out rate for attention output weights.
482
+ Defaults to 0.
483
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
484
+ num_fcs (int): The number of fully-connected layers for FFNs.
485
+ Defaults to 2.
486
+ qkv_bias (bool): enable bias for qkv if True. Defaults to True.
487
+ ffn_type (str): Select the type of ffn layers. Defaults to 'origin'.
488
+ act_layer (nn.Module, optional): The activation layer for FFNs.
489
+ Default: nn.GELU
490
+ norm_cfg (dict): Config dict for normalization layer.
491
+ Defaults to ``dict(type='LN')``.
492
+ enable_twoway (bool): Whether to enable two-way Transformer (used in SAM).
493
+ repeat_pe (bool): Whether to re-add PE at each layer (used in SAM)
494
+ skip_first_pe (bool): Whether to skip adding the PE in this layer's self-attention block
495
+ """
496
+
497
+ def __init__(
498
+ self,
499
+ token_dims: int,
500
+ context_dims: int,
501
+ num_heads: int = 8,
502
+ head_dims: int = 64,
503
+ mlp_dims: int = 1024,
504
+ layer_scale_init_value: float = 0.0,
505
+ drop_rate: float = 0.0,
506
+ attn_drop_rate: float = 0.0,
507
+ drop_path_rate: float = 0.0,
508
+ ffn_type: str = "origin",
509
+ act_layer: type[nn.Module] | nn.Module = nn.GELU,
510
+ norm_cfg: Dict = dict(type="LN", eps=1e-6),
511
+ enable_twoway: bool = False,
512
+ repeat_pe: bool = False,
513
+ skip_first_pe: bool = False,
514
+ ):
515
+ super().__init__()
516
+ self.repeat_pe = repeat_pe
517
+ self.skip_first_pe = skip_first_pe
518
+ if self.repeat_pe:
519
+ self.ln_pe_1 = build_norm_layer(norm_cfg, token_dims)
520
+ self.ln_pe_2 = build_norm_layer(norm_cfg, context_dims)
521
+
522
+ self.ln1 = build_norm_layer(norm_cfg, token_dims)
523
+
524
+ self.self_attn = Attention(
525
+ embed_dims=num_heads * head_dims,
526
+ num_heads=num_heads,
527
+ query_dims=token_dims,
528
+ key_dims=token_dims,
529
+ value_dims=token_dims,
530
+ attn_drop=attn_drop_rate,
531
+ proj_drop=drop_rate,
532
+ drop_path_rate=drop_path_rate,
533
+ layer_scale_init_value=layer_scale_init_value,
534
+ )
535
+
536
+ self.ln2_1 = build_norm_layer(norm_cfg, token_dims)
537
+ self.ln2_2 = build_norm_layer(norm_cfg, context_dims)
538
+
539
+ self.cross_attn = Attention(
540
+ embed_dims=num_heads * head_dims,
541
+ num_heads=num_heads,
542
+ query_dims=token_dims,
543
+ key_dims=context_dims,
544
+ value_dims=context_dims,
545
+ attn_drop=attn_drop_rate,
546
+ proj_drop=drop_rate,
547
+ drop_path_rate=drop_path_rate,
548
+ layer_scale_init_value=layer_scale_init_value,
549
+ )
550
+
551
+ self.ln3 = build_norm_layer(norm_cfg, token_dims)
552
+
553
+ if ffn_type == "origin":
554
+ self.ffn = FFN(
555
+ embed_dims=token_dims,
556
+ feedforward_channels=mlp_dims,
557
+ ffn_drop=drop_rate,
558
+ drop_path_rate=drop_path_rate,
559
+ act_layer=act_layer,
560
+ layer_scale_init_value=layer_scale_init_value,
561
+ )
562
+ elif ffn_type == "swiglu_fused":
563
+ self.ffn = SwiGLUFFNFused(
564
+ embed_dims=token_dims,
565
+ feedforward_channels=mlp_dims,
566
+ layer_scale_init_value=layer_scale_init_value,
567
+ )
568
+ else:
569
+ raise NotImplementedError
570
+
571
+ self.enable_twoway = enable_twoway
572
+ if self.enable_twoway:
573
+ self.ln4_1 = build_norm_layer(norm_cfg, context_dims)
574
+ self.ln4_2 = build_norm_layer(norm_cfg, token_dims)
575
+
576
+ self.cross_attn_2 = Attention(
577
+ embed_dims=num_heads * head_dims,
578
+ num_heads=num_heads,
579
+ query_dims=context_dims,
580
+ key_dims=token_dims,
581
+ value_dims=token_dims,
582
+ attn_drop=attn_drop_rate,
583
+ proj_drop=drop_rate,
584
+ drop_path_rate=drop_path_rate,
585
+ layer_scale_init_value=layer_scale_init_value,
586
+ )
587
+
588
+ def forward(
589
+ self,
590
+ x: torch.Tensor,
591
+ context: torch.Tensor,
592
+ x_pe: Optional[torch.Tensor] = None,
593
+ context_pe: Optional[torch.Tensor] = None,
594
+ x_mask: Optional[torch.Tensor] = None,
595
+ ):
596
+ """
597
+ Args:
598
+ x: shape [B, N, C]
599
+ context: shape [B, N, C]
600
+ x_mask: shape [B, N]
601
+ """
602
+ if self.repeat_pe and context_pe is not None:
603
+ # LaPE: https://openaccess.thecvf.com/content/ICCV2023/papers/Yu_LaPE_Layer-adaptive_Position_Embedding_for_Vision_Transformers_with_Independent_Layer_ICCV_2023_paper.pdf
604
+ x_pe = self.ln_pe_1(x_pe)
605
+ context_pe = self.ln_pe_2(context_pe)
606
+
607
+ # Self attention block for tokens
608
+ if self.repeat_pe and not self.skip_first_pe and x_pe is not None:
609
+ q = k = self.ln1(x) + x_pe
610
+ v = self.ln1(x)
611
+ else:
612
+ q = k = v = self.ln1(x)
613
+
614
+ attn_mask = None
615
+ if x_mask is not None:
616
+ attn_mask = x_mask[:, :, None] @ x_mask[:, None, :]
617
+ # Set diagonal to 1 to prevent nan output
618
+ attn_mask.diagonal(dim1=1, dim2=2).fill_(1)
619
+ attn_mask = attn_mask > 0
620
+ x = x + self.self_attn(q=q, k=k, v=v, attn_mask=attn_mask)
621
+
622
+ # Cross attention block, tokens attending to image embedding
623
+ if self.repeat_pe and context_pe is not None:
624
+ q = self.ln2_1(x) + x_pe
625
+ k = self.ln2_2(context) + context_pe
626
+ v = self.ln2_2(context)
627
+ else:
628
+ q = self.ln2_1(x)
629
+ k = v = self.ln2_2(context)
630
+ x = x + self.cross_attn(q=q, k=k, v=v)
631
+
632
+ # MLP block
633
+ x = self.ffn(self.ln3(x), identity=x)
634
+
635
+ # (Optional) Cross attention block, image embeddings attending to tokens
636
+ if self.enable_twoway:
637
+ if self.repeat_pe and context_pe is not None:
638
+ q = self.ln4_1(context) + context_pe
639
+ k = self.ln4_2(x) + x_pe
640
+ v = self.ln4_2(x)
641
+ else:
642
+ q = self.ln4_1(context)
643
+ k = v = self.ln4_2(x)
644
+ attn_mask = (
645
+ (x_mask[:, None, :].repeat(1, context.shape[1], 1)) > 0
646
+ if x_mask is not None
647
+ else None
648
+ )
649
+ context = context + self.cross_attn_2(q=q, k=k, v=v, attn_mask=attn_mask)
650
+
651
+ return x, context
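Finally, a hedged sketch of driving one `TransformerDecoderLayer` (not part of the commit; dimensions are illustrative): prompt tokens self-attend, cross-attend into the image embedding, and with `enable_twoway=True` the image embedding attends back to the tokens.

```python
# Illustrative sketch, not part of the commit; import path assumed from the src/ layout.
import torch

from sam3d_body.models.modules.transformer import TransformerDecoderLayer

layer = TransformerDecoderLayer(
    token_dims=256,
    context_dims=1024,
    num_heads=8,
    head_dims=32,
    mlp_dims=1024,
    enable_twoway=True,
    repeat_pe=True,
)

tokens = torch.randn(2, 16, 256)         # e.g. prompt / keypoint tokens
context = torch.randn(2, 37 * 37, 1024)  # flattened image embedding
tokens_pe = torch.randn(2, 16, 256)
context_pe = torch.randn(2, 37 * 37, 1024)

tokens, context = layer(tokens, context, x_pe=tokens_pe, context_pe=context_pe)
print(tokens.shape, context.shape)       # (2, 16, 256) and (2, 1369, 1024)
```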