Updated Addition of FTOL (q_solver tolerance) Hyperparameter to Config #88 #89
Walkthrough
The changes update the optimization workflow by introducing a new parameter, …
Changes
Sequence Diagram(s)
sequenceDiagram
participant S as Stac Class
participant CS as Compute STAC Module
participant SC as StacCore Instance
S->>S: Instantiate stac_core_obj
S->>CS: Call optimization functions (root, offset, pose) with stac_core_obj
CS->>SC: Invoke q_opt / m_opt methods on stac_core_obj
SC-->>CS: Return optimization results
CS-->>S: Return final outputs
Actionable comments posted: 1
🧹 Nitpick comments (8)
stac_mjx/stac_core.py (6)
147-151: Add parameter and return details to the docstring.
While the docstring briefly describes the squared-error operation, consider including param/return entries to maintain consistency with other functions:
• Clarify the expected shape/type of x.
• Clarify the scalar or array returned.
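A minimal sketch of what the requested param/return entries could look like, assuming `squared_error` takes a 1-D array of residuals and returns the scalar sum of their squares (as a later review comment describes it). NumPy stands in for `jax.numpy` here so the example runs without JAX, and the Google-style `Args:`/`Returns:` layout is an assumption about the project's docstring convention.

```python
import numpy as np


def squared_error(x):
    """Compute the sum of squared elements of a residual vector.

    Args:
        x: 1-D array of residuals, e.g. keypoint position errors.

    Returns:
        Scalar sum of the squared entries of ``x``.
    """
    # np.square is elementwise; np.sum then reduces to a single scalar.
    return np.sum(np.square(x))
```

For example, `squared_error(np.array([1.0, -2.0, 3.0]))` evaluates to `14.0`.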
153-187: Use a structured logging approach instead of raw prints.
The try-except block currently prints errors to stdout, making them harder to filter or search. Leveraging loggers (e.g., Python's logging library) would provide more control and align better with production-level best practices.
 try:
     return mjx_data, q_solver.run(
         ...
     )
 except ValueError as ex:
-    print("Warning: optimization failed.", flush=True)
-    print(ex, flush=True)
+    import logging
+    logger = logging.getLogger(__name__)
+    logger.warning("Optimization failed: %s", ex)
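A minimal sketch of the same idea with the logger created once at module level rather than inside the `except` block. `run_with_logging` and its `solve` callable are hypothetical stand-ins for the `q_solver.run(...)` call, and returning `None` on failure is an assumed fallback, not necessarily what stac-mjx does.

```python
import logging

# Create the logger once at import time; getLogger returns the same object for the same name.
logger = logging.getLogger(__name__)


def run_with_logging(solve, *args, **kwargs):
    """Call ``solve`` and log, instead of printing, any ValueError it raises.

    Returns the solver's result, or None if it raised ValueError.
    """
    try:
        return solve(*args, **kwargs)
    except ValueError as ex:
        # %-style arguments defer string formatting until a handler emits the record.
        logger.warning("Optimization failed: %s", ex)
        return None
```

The application's logging setup (for example `logging.basicConfig(level=logging.INFO)`) then decides where these warnings go. If such a wrapper sits inside a `jax.jit`-compiled function, the `try`/`except` only sees exceptions raised while the function is being traced, not numerical failures during execution of the compiled code.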
234-239: Address docstring style issues (line 234).
The docstring triggers several style warnings (D205, D212, D415). Make sure the first line is a one-sentence summary ending with punctuation, with a blank line before detail sentences.
🧰 Tools
🪛 GitHub Actions: CI
[warning] 234-234: D205: 1 blank line required between summary line and description (found 0)
[warning] 234-234: D212: Multi-line docstring summary should start at the first line
[warning] 234-234: D415: First line should end with a period, question mark, or exclamation point (not 'o')
243-249: Document the __init__ method.
A direct docstring on the __init__ method will address the D107 warning and help clarify the meaning of the tol parameter.
🧰 Tools
🪛 GitHub Actions: CI
[warning] 243-243: D107: Missing docstring in init
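A sketch of docstrings that satisfy D205, D212, D415, and D107 together: the summary sits on the opening-quote line and ends with a period, a blank line follows it, and `__init__` gets its own docstring describing `tol`. `StacCoreSketch` is a hypothetical stand-in that only stores the tolerance instead of building the jaxopt/optax solvers, so it runs without those libraries; its default value is illustrative.

```python
class StacCoreSketch:
    """Hold the solver settings shared by the STAC optimization steps.

    Hypothetical stand-in for StacCore: it records the tolerance instead of
    constructing jaxopt/optax solver objects.
    """

    def __init__(self, tol=1e-5):
        """Initialize the solver settings.

        Args:
            tol: Convergence tolerance for the q_solver, typically taken from
                the FTOL value in the model config. The 1e-5 default here is
                illustrative only.
        """
        self.tol = tol
```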
263-279: Fix docstring format for q_opt.
The docstring summary line should end with a period, and a blank line should separate the summary from the detailed description.
🧰 Tools
🪛 GitHub Actions: CI
[warning] 263-263: D205: 1 blank line required between summary line and description (found 0)
[warning] 263-263: D212: Multi-line docstring summary should start at the first line
[warning] 263-263: D415: First line should end with a period, question mark, or exclamation point (not ')')
[warning] 263-263: D402: First line should not be the function's 'signature'
293-310: Refine docstring format for m_opt.
Similar style warnings (D205, D212, D415, D402) apply. Ensure the summary line is on the first line and ends with punctuation, and avoid putting the function signature in the summary line.
🧰 Tools
🪛 GitHub Actions: CI
[warning] 293-293: D205: 1 blank line required between summary line and description (found 0)
[warning] 293-293: D212: Multi-line docstring summary should start at the first line
[warning] 293-293: D415: First line should end with a period, question mark, or exclamation point (not ')')
[warning] 293-293: D402: First line should not be the function's 'signature'
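For `q_opt` and `m_opt`, D402 fires when the summary line repeats the call signature. A hedged sketch of the alternative: describe what the method does in the summary and move the arguments into an `Args:` section. The parameter names below are hypothetical placeholders rather than the real `q_opt` signature, and the body is a stub.

```python
class OptimizerSketch:
    """Hypothetical holder for per-frame optimization methods."""

    def q_opt(self, qs_to_opt, kps_to_opt):
        """Optimize the selected joint angles for a single frame.

        Args:
            qs_to_opt: Mask selecting which joint angles are free to change.
            kps_to_opt: Mask selecting which keypoints enter the loss.

        Returns:
            The inputs unchanged; the real method would return solver output.
        """
        # Stub body so the example runs; a real implementation would call the jitted solver.
        return qs_to_opt, kps_to_opt
```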
tests/test_stac_core.py (1)
26-26: Avoid referencing internal/private fields.
optax._src.base.GradientTransformationExtraArgs is a private API that may change without notice. Consider asserting a public type or verifying behavior in a more stable way.
stac_mjx/compute_stac.py (1)
13-280: Consider updating function docstrings to reflect the new StacCore parameter.
While the code changes are correct, the docstrings for all three functions (root_optimization, offset_optimization, and pose_optimization) have not been updated to document the new stac_core_obj parameter. This would improve code maintainability and help future developers understand the purpose of this parameter.
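A sketch of how the new parameter could be documented, assuming the optimization functions simply call methods on the object they receive. `root_optimization_sketch`, its arguments, and the two-argument `q_opt` call are hypothetical simplifications. The `FakeCore` test double illustrates a side benefit of passing `stac_core_obj` explicitly: the functions can be tested without building real solvers.

```python
def root_optimization_sketch(mjx_data, kp_data, stac_core_obj):
    """Optimize the root pose for one frame.

    Args:
        mjx_data: Simulation state to update.
        kp_data: Keypoint targets for the frame.
        stac_core_obj: Object exposing ``q_opt``; in stac-mjx this is the
            StacCore instance configured with the FTOL tolerance.

    Returns:
        The result of ``stac_core_obj.q_opt``.
    """
    return stac_core_obj.q_opt(mjx_data, kp_data)


class FakeCore:
    """Test double that records calls instead of running a solver."""

    def __init__(self):
        self.calls = []

    def q_opt(self, mjx_data, kp_data):
        # Record the arguments so a test can check what was passed through.
        self.calls.append((mjx_data, kp_data))
        return "optimized"
```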
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
stac_mjx/compute_stac.py (8 hunks)
stac_mjx/stac.py (11 hunks)
stac_mjx/stac_core.py (3 hunks)
tests/test_stac_core.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
tests/test_stac_core.py
1-1: pytest imported but unused. Remove unused import: pytest (F401)
3-3: dm_control.mjcf imported but unused. Remove unused import: dm_control.mjcf (F401)
4-4: dm_control imported but unused. Remove unused import: dm_control (F401)
6-6: mujoco imported but unused. Remove unused import: mujoco (F401)
7-7: stac_mjx.utils.mjx_load imported but unused. Remove unused import: stac_mjx.utils.mjx_load (F401)
8-8: Redefinition of unused stac_mjx from line 2. Remove definition: stac_mjx (F811)
🪛 GitHub Actions: CI
tests/test_stac_core.py
[error] 43-43: FileNotFoundError: [Errno 2] Unable to synchronously open file (unable to open file: name = '/home/runner/work/stac-mjx/tests/data/test_rodent_mocap_1000_frames.nwb', errno = 2, error message = 'No such file or directory')
[warning] 16-16: DeprecationWarning: jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.
stac_mjx/stac_core.py
[warning] 234-234: D205: 1 blank line required between summary line and description (found 0)
[warning] 234-234: D212: Multi-line docstring summary should start at the first line
[warning] 234-234: D415: First line should end with a period, question mark, or exclamation point (not 'o')
[warning] 243-243: D107: Missing docstring in init
[warning] 263-263: D205: 1 blank line required between summary line and description (found 0)
[warning] 263-263: D212: Multi-line docstring summary should start at the first line
[warning] 263-263: D415: First line should end with a period, question mark, or exclamation point (not ')')
[warning] 263-263: D402: First line should not be the function's 'signature'
[warning] 293-293: D205: 1 blank line required between summary line and description (found 0)
[warning] 293-293: D212: Multi-line docstring summary should start at the first line
[warning] 293-293: D415: First line should end with a period, question mark, or exclamation point (not ')')
[warning] 293-293: D402: First line should not be the function's 'signature'
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: Tests (ubuntu-latest, Python 3.11)
- GitHub Check: Lint
🔇 Additional comments (17)
stac_mjx/stac_core.py (2)
7-7: No concerns for the new import.
This import is standard practice for enabling partial application.
189-201: Consider symmetry in error handling.
Unlike _q_opt, _m_opt does not handle solver errors. If failures are possible here, adding a similar try-except block or a graceful fallback might keep the code consistent and robust.
stac_mjx/stac.py (5)
15-15: Imports look good.
No concerns about the new stac_core import. This improves modularity.
72-72: Initialize attribute consistently.
Defining self.stac_core_obj = None makes the attribute's presence explicit. Good approach for lazy initialization later on.
234-235: Great use of tolerances from config.
Passing self.cfg.model.FTOL ensures the solver respects the user-configured tolerance. Good design for future maintainability.
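The attribute initialization, config lookup, and lazy creation described in these comments can be sketched as follows. `StacSketch`, `_get_stac_core`, and the `make_core` factory are hypothetical; only the `stac_core_obj = None` default and the `cfg.model.FTOL` lookup come from the review. `SimpleNamespace` stands in for both the real config object and the StacCore constructor.

```python
from types import SimpleNamespace


class StacSketch:
    """Hypothetical stand-in for the Stac class's solver bookkeeping."""

    def __init__(self, cfg, make_core):
        self.cfg = cfg
        self._make_core = make_core  # factory, e.g. a StacCore constructor
        # Declare the attribute up front so it always exists.
        self.stac_core_obj = None

    def _get_stac_core(self):
        # Build the solver holder on first use only, with the configured tolerance.
        if self.stac_core_obj is None:
            self.stac_core_obj = self._make_core(tol=self.cfg.model.FTOL)
        return self.stac_core_obj


# Usage: a config carrying FTOL, and SimpleNamespace as a dummy constructor.
cfg = SimpleNamespace(model=SimpleNamespace(FTOL=1e-4))
stac = StacSketch(cfg, make_core=SimpleNamespace)
```

One trade-off of lazy creation: an instance that is only used for loading data never builds solvers, but a bad FTOL value surfaces at first use rather than at construction.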
244-296: Coordination with compute_stac calls is cohesive.
Passing stac_core_obj into each optimization function ensures consistent use of the new solver-based architecture.
344-346: Lazy instantiation logic.
Creating the StacCore object only if it does not yet exist keeps the class flexible and defers building the solvers until they are first needed.
tests/test_stac_core.py (2)
16-37: Redundant config usage.
The function signature includes config, mouse_config, yet you overwrite config from disk within the function. This might be confusing if the fixture config is needed. Check whether you intend to reuse the fixture config or the reloaded config.
🧰 Tools
🪛 GitHub Actions: CI
[warning] 16-16: DeprecationWarning: jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.
43-43: Check file paths or fixtures for the missing NWB file.
The test fails because the NWB file is not found. Ensure correct fixture usage or local test data.
Would you like to verify the path with a shell script to list available test files in the repository?
🧰 Tools
🪛 GitHub Actions: CI
[error] 43-43: FileNotFoundError: [Errno 2] Unable to synchronously open file (unable to open file: name = '/home/runner/work/stac-mjx/tests/data/test_rodent_mocap_1000_frames.nwb', errno = 2, error message = 'No such file or directory')
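One way to make this failure easier to diagnose is to resolve test data relative to the tests directory and fail with an explicit message naming the path that was tried. A minimal sketch, assuming the data lives in a `data/` folder next to the tests (as the path in the error suggests); `resolve_test_data` and the `clip.nwb` file name are hypothetical.

```python
import tempfile
from pathlib import Path


def resolve_test_data(name, tests_dir):
    """Return ``tests_dir/data/name``, raising a descriptive error if it is missing."""
    path = Path(tests_dir) / "data" / name
    if not path.is_file():
        raise FileNotFoundError(
            f"Test data {name!r} not found at {path}. "
            "Check the data path in the test config and that the file is committed."
        )
    return path


# Usage with a throwaway directory standing in for the tests/ folder.
demo_tests_dir = Path(tempfile.mkdtemp())
(demo_tests_dir / "data").mkdir()
(demo_tests_dir / "data" / "clip.nwb").write_bytes(b"")
```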
stac_mjx/compute_stac.py (8)
13-14: Function signature updated with StacCore parameter
This change aligns with the PR objective of transitioning from global variables to a StacCore object that encapsulates the optimization methods. The new parameter allows for customization of the optimization process through configurable hyperparameters like FTOL.
55-56: Method call updated to use StacCore object
The call has been correctly updated from a global method to the method on the StacCore object, allowing access to the encapsulated optimization logic and hyperparameters.
82-83: Method call updated to use StacCore object
Consistent with the first change, this call has been properly updated to use the StacCore object's method.
107-108: Function signature updated with StacCore parameter
Similar to the root_optimization function, the offset_optimization signature has been updated to include the StacCore object parameter, maintaining consistency across the API.
151-152: Method call updated to use StacCore object's m_opt
The m_opt method call has been correctly updated to use the StacCore object, aligning with the PR objective of encapsulating previously global optimization methods.
179-180: Function signature updated with StacCore parameter
The pose_optimization function signature has been consistently updated to include the StacCore object parameter, completing the refactoring across all optimization functions.
220-221: Method call updated to use StacCore object
This call within the pose_optimization function has been correctly updated to use the StacCore object's q_opt method.
237-238: Method call updated to use StacCore object
The second q_opt call within the pose_optimization function has been properly updated to use the StacCore object's method.
import pytest
import stac_mjx
from dm_control import mjcf
import dm_control
from pathlib import Path
import mujoco
from stac_mjx.utils import mjx_load
import stac_mjx
🛠️ Refactor suggestion
Remove unused imports.
Static analysis flags unused imports (pytest, dm_control.mjcf, dm_control, mujoco, stac_mjx.utils.mjx_load) and a redefinition of stac_mjx. Consider removing the extraneous imports to improve clarity and comply with lint rules.
- import pytest
- from dm_control import mjcf
- import dm_control
- import mujoco
- from stac_mjx.utils import mjx_load
- import stac_mjx
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Suggested replacement:
import stac_mjx
from pathlib import Path
Codecov Report
Attention: Patch coverage is …
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #89      +/-   ##
==========================================
+ Coverage   44.83%   45.13%   +0.30%
==========================================
  Files           9        9
  Lines         678      689      +11
==========================================
+ Hits          304      311       +7
- Misses        374      378       +4
☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
🔭 Outside diff range comments (1)
tests/test_stac_core.py (1)
39-57: 💡 Verification agent
🧩 Analysis chain
Test function has commented-out assertions and lacks validation.
The test loads configurations and data but doesn't actually execute the core functionality it's meant to test, because the call to run_stac is commented out. This makes the test ineffective for validating the compilation behavior.
Either uncomment the critical lines or add explicit comments explaining why the test is designed this way (perhaps it's intentionally a partial test that only checks initial conditions).
🌐 Web query:
What are best practices for handling long-running tests in Python?
💡 Result:
To handle long-running tests in Python effectively, developers should focus on optimization strategies, efficient resource management, and smart test execution. Here are key best practices:
Identify Performance Bottlenecks
- Profile test durations using pytest's --durations flag to pinpoint slow tests[1][13].
- Analyze test collection time by restricting test discovery to relevant directories or using -k for keyword filtering[1][6].
Optimize Test Execution
- Parallelize tests with pytest-xdist to leverage multi-core systems[1][6][7].
- Run subsets selectively using file paths or test markers to avoid unnecessary executions[1][6].
- Implement incremental testing to run only tests affected by code changes (requires custom tooling or plugins like pytest-testmon)[7].
Improve Test Efficiency
- Mock external dependencies (APIs, databases) to eliminate I/O delays[5][6][17].
- Replace time.sleep() with conditional polling or asynchronous waiting[1][6].
- Use session/module-scoped fixtures to minimize repeated setup/teardown[1][5].
Manage Long-Running Tests
- Set timeouts with pytest-timeout to prevent hanging tests (via decorators or CLI)[2].
- Separate slow tests into dedicated suites for scheduled/on-demand execution[5][17].
- Leverage background tasks in frameworks like FastAPI for non-blocking workflows[9].
Maintain Test Health
- Avoid redundant assertions by testing one concern per method[5].
- Profile regularly to catch new performance issues early[1][6].
- Optimize database interactions with transaction rollbacks or in-memory databases[1][6].
For critical systems, consider asynchronous execution patterns and state management for reliable long-running processes[8][11]. Always balance test coverage with execution speed—fast feedback loops are crucial for developer productivity[3][5][17].
Citations:
- 1: https://pytest-with-eric.com/pytest-advanced/pytest-improve-runtime/
- 2: https://www.lambdatest.com/blog/pytest-timeouts/
- 3: https://docs.python-guide.org/writing/tests/
- 4: https://eli.thegreenplace.net/2017/interacting-with-a-long-running-child-process-in-python/
- 5: https://pytest-with-eric.com/introduction/python-unit-testing-best-practices/
- 6: https://www.devopsschool.com/blog/checklist-to-reduce-unit-test-time-in-python/
- 7: https://engineering.ramp.com/faster-python-tests
- 8: https://autokitteh.com/technical-blog/durable-python-reliable-long-running-workflows-with-just-a-few-lines-of-code/
- 9: https://blog.stackademic.com/managing-long-running-processes-with-fastapi-in-python-a5de07eaf76a
- 10: https://discuss.python.org/t/unexpected-large-execution-time/26677
- 11: https://xsoar.pan.dev/docs/integrations/long-running
- 12: https://docs.pytest.org/en/stable/how-to/output.html
- 13: https://docs.pytest.org/en/6.2.x/usage.html
- 14: https://docs.djangoproject.com/en/5.1/topics/testing/overview/
- 15: https://realpython.com/python-testing/
- 16: https://panel.holoviz.org/how_to/callbacks/load.html
- 17: https://muuktest.com/blog/python-unit-test-guide
- 18: https://community.atlassian.com/t5/Bitbucket-questions/problem-with-a-long-running-python-unit-test/qaq-p/995653
- 19: pallets/click#2171
- 20: https://docs.python.org/3/library/unittest.html
Action Required: Update the Test to Clearly Address Long-Running Behavior
- The test currently does not execute the core functionality because the call to run_stac and its associated assertions are commented out.
- Given that the test is intentionally disabled due to long runtime without a GPU, consider using a pytest marker (for example, @pytest.mark.slow or @pytest.mark.skipif) to clearly indicate that this is a long-running test rather than simply commenting it out.
- If the decision to omit these assertions is intentional (e.g., to reduce execution time during regular testing), please add an explicit comment documenting the rationale and any conditions under which the full test should be run (such as in a dedicated performance testing suite).
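If the test is kept but made conditional, the skip decision can live in a small helper instead of in commented-out code. A sketch, assuming JAX's `jax.default_backend()` (which reports the platform, e.g. `"cpu"` or `"gpu"`) is an acceptable GPU check; `gpu_available`, `should_run_slow_tests`, and the `RUN_SLOW_TESTS` variable are hypothetical names, not part of stac-mjx.

```python
import os


def gpu_available():
    """Return True if JAX is importable and its default backend is a GPU."""
    try:
        import jax
    except ImportError:
        return False
    # default_backend() reports the platform JAX will use, e.g. "cpu" or "gpu".
    return jax.default_backend() == "gpu"


def should_run_slow_tests(env=None):
    """Decide whether long-running tests should run.

    Runs them when the (hypothetical) RUN_SLOW_TESTS variable is "1",
    otherwise only when a GPU is available.
    """
    env = os.environ if env is None else env
    if env.get("RUN_SLOW_TESTS") == "1":
        return True
    return gpu_available()
```

In the test module this could drive `@pytest.mark.skipif(not should_run_slow_tests(), reason="slow without a GPU")`. A custom marker such as `@pytest.mark.slow` should also be registered under `markers` in the pytest configuration; otherwise pytest emits `PytestUnknownMarkWarning`.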
🧹 Nitpick comments (1)
tests/test_stac_core.py (1)
16-37: Unused function parameter and missing documentation.
The mouse_config parameter isn't used in the function, and there's no docstring explaining the purpose of the test.
-def test_stac_core_obj(config, mouse_config):
+def test_stac_core_obj(config):
+    """Test that StacCore objects are created with correct types and tolerance values."""
Also, consider adding assertions with messages to make test failures more informative.
📜 Review details
📒 Files selected for processing (2)
stac_mjx/stac_core.py (3 hunks)
tests/test_stac_core.py (1 hunks)
🔇 Additional comments (7)
tests/test_stac_core.py (1)
1-11: Remove unused imports to improve code clarity.
Several imports in this file are unused and should be removed to comply with linting rules:
-import pytest
 import stac_mjx
-from dm_control import mjcf
-import dm_control
 from pathlib import Path
-import mujoco
-from stac_mjx.utils import mjx_load
-import stac_mjx
 import optax
 import jaxopt
The stac_mjx module is imported twice (lines 2 and 8).
stac_mjx/stac_core.py (6)
148-151: Implementation of squared_error function looks good. The squared_error function is a simple utility function that computes the sum of squared elements in a vector, which is appropriate for its use in the optimization.
153-187: Well-structured optimization function with proper error handling.
The _q_opt function correctly applies JIT compilation with the solver marked as a static argument, so the compiled version is reused when only the data changes and the solver stays the same. The error handling is also properly implemented to catch optimization failures.
189-230: Good use of partial-based JIT for the optimization function.
The _m_opt function correctly applies jit through functools.partial with static_argnames for the solver parameter, which avoids recompilation when the function is reused with the same solver and inputs of the same shapes and dtypes.
233-256: Well-documented class initialization with configurable tolerance.
The StacCore class provides a clear encapsulation of the optimization solvers with proper documentation. The tolerance parameter for the q_solver is correctly implemented, allowing the FTOL hyperparameter to be configured as required.
257-286: Well-designed wrapper method for optimization.
The q_opt method provides a clean interface to the underlying _q_opt function, passing the class's solver instance along with the other parameters. This design effectively encapsulates the optimization logic.
287-329: Comprehensive documentation for m_opt method.
The m_opt method is well-documented with clear parameter descriptions and properly delegates to the _m_opt function with the class's solver instance.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/test_stac_core.py (1)
39-57: Consider enabling the commented-out test assertions.
The test is designed to verify JIT compilation cache sizes, but the assertions after running run_stac are commented out, making this test incomplete.
If the test is too slow without a GPU, consider:
- Adding a proper skip condition using pytest markers
- Or including a clear comment explaining why assertions are disabled
 def test_stac_core_compilations():
     # tests cache sizes of jit compiled functions
-    # run time is long without gpu so commented out
-    # tests after run_stac
+    # NOTE: Full test requires GPU and is slow.
+    # Uncomment the assertions below when running with GPU support.
     config = stac_mjx.load_configs(TEST_DIR / "configs")
     stac_mjx.enable_xla_flags()
     assert stac_mjx.stac_core.m_loss._cache_size() == 0
     assert stac_mjx.stac_core._m_opt._cache_size() == 0
     assert stac_mjx.stac_core._q_opt._cache_size() == 0
     kp_data, sorted_kp_names = stac_mjx.load_mocap(config, base_path=PROJECT_DIR)
+
+    # Skip full test if running without GPU
+    import os
+    if os.environ.get("TEST_WITH_GPU") != "1":
+        return
+
     # _, _ = stac_mjx.run_stac(config, kp_data, sorted_kp_names, base_path=PROJECT_DIR)
     # assert stac_mjx.stac_core.m_loss._cache_size() == 0
     # assert stac_mjx.stac_core._q_opt._cache_size() == 2
     # assert stac_mjx.stac_core._m_opt._cache_size() == 2
📜 Review details
📒 Files selected for processing (2)
stac_mjx/stac_core.py (3 hunks)
tests/test_stac_core.py (1 hunks)
🔇 Additional comments (9)
tests/test_stac_core.py (2)
1-8: Remove unused imports.
Several imports are flagged as unused by the static analyzer. Consider removing these to improve code cleanliness.
- import pytest
- import stac_mjx
- from dm_control import mjcf
- import dm_control
  from pathlib import Path
- import mujoco
- from stac_mjx.utils import mjx_load
- import stac_mjx
+ import stac_mjx
  import optax
  import jaxopt
16-37: Unit testing of StacCore object initialization - LGTM!
The test function properly verifies:
- The instance type is correct
- Instance variables have the expected types
- The tolerance parameter is correctly applied to the q_solver
This effectively validates the new StacCore functionality and the FTOL parameter implementation.
stac_mjx/stac_core.py (7)
7-7: Good addition of the partial import for JIT optimization.
functools.partial enables the @partial(jit, static_argnames=...) decorator pattern, which lets JIT-compiled functions take static arguments such as a solver object. This is a common pattern when using JAX.
148-151: Function reintroduction looks good.
The squared_error function has been reintroduced as a separate utility function, which makes the code more modular and easier to maintain.
153-187: Good refactoring of q_opt into _q_opt with partial JIT.
The refactoring of q_opt into a private function _q_opt with @partial(jit, static_argnames=["q_solver"]) is a good improvement. It allows the solver to be passed as an argument while maintaining JIT optimization by marking it as a static argument.
The error handling is also well-maintained in this refactoring.
189-231: Good refactoring of m_opt into _m_opt with partial JIT.
Similar to the q_opt refactoring, moving the logic to a private function with partial JIT optimization is a good approach. The function maintains the same core logic while allowing the solver to be passed as an argument.
233-255: Well-designed StacCore class with good documentation.
The new StacCore class effectively encapsulates the optimization components that were previously global variables. The class documentation and constructor documentation are clear and informative.
The implementation of the FTOL parameter as a tolerance for the q_solver addresses the requirements from issue #27 mentioned in the PR objectives.
256-285: Good wrapper implementation for q_opt.
This wrapper method properly delegates to the private _q_opt function while passing the instance's q_solver. The documentation is clear and explains the function's purpose well.
286-329: Good wrapper implementation for m_opt with thorough documentation.
The m_opt wrapper method is well-implemented and includes comprehensive documentation that explains parameters and return values. This maintains the same interface while encapsulating the implementation details.
tests/test_stac_core.py
Outdated
TEST_DIR = Path(__file__).parent
PROJECT_DIR = TEST_DIR.parent
Bad pattern here -- use fixtures instead of hardcoding even relative paths
Actionable comments posted: 0
🔭 Outside diff range comments (1)
tests/test_stac_core.py (1)
46-64: 🛠️ Refactor suggestion
Improve test completeness and robustness.
The test has commented-out assertions that don't run, making it incomplete. Additionally, it accesses internal functions with leading underscores (_cache_size(), _m_opt, etc.), which could change.
- Consider moving the long-running tests to a separate file or using pytest markers to conditionally run them.
- Avoid accessing internal implementation details (functions with leading underscores).
-def test_stac_core_compilations(TEST_DIR, PROJECT_DIR):
+@pytest.mark.slow
+def test_stac_core_compilations(test_dir, project_dir):
     # tests cache sizes of jit compiled functions
-    # run time is long without gpu so commented out
-    # tests after run_stac
-    config = stac_mjx.load_configs(TEST_DIR / "configs")
+    config = stac_mjx.load_configs(test_dir / "configs")
     stac_mjx.enable_xla_flags()
-    assert stac_mjx.stac_core.m_loss._cache_size() == 0
-    assert stac_mjx.stac_core._m_opt._cache_size() == 0
-    assert stac_mjx.stac_core._q_opt._cache_size() == 0
+    # Create a minimal test case that doesn't require complete run_stac
+    stac_core_obj = stac_mjx.stac_core.StacCore(tol=config.model.FTOL)
+
+    # Test some basic functionality that can validate the JIT compilation
+    # For example, test if functions are callable or if minimal inputs produce expected outputs
+    # Without accessing internal implementation details
-    kp_data, sorted_kp_names = stac_mjx.load_mocap(config, base_path=PROJECT_DIR)
-    # _, _ = stac_mjx.run_stac(config, kp_data, sorted_kp_names, base_path=PROJECT_DIR)
-
-    # assert stac_mjx.stac_core.m_loss._cache_size() == 0
-    # assert stac_mjx.stac_core._q_opt._cache_size() == 2
-    # assert stac_mjx.stac_core._m_opt._cache_size() == 2
+    # If we need to test cache functionality, we should create a public API
+    # or use proper test doubles/mocks
Alternatively, if the runtime is an issue, create a separate test file with the @pytest.mark.slow marker:
# tests/test_stac_core_slow.py
import pytest
import stac_mjx


# test_dir and project_dir must be visible to this file, e.g. defined in conftest.py
@pytest.mark.slow
def test_full_stac_compilation(test_dir, project_dir):
    # Full test with run_stac that can be skipped in regular test runs
    config = stac_mjx.load_configs(test_dir / "configs")
    stac_mjx.enable_xla_flags()
    kp_data, sorted_kp_names = stac_mjx.load_mocap(config, base_path=project_dir)
    _, _ = stac_mjx.run_stac(config, kp_data, sorted_kp_names, base_path=project_dir)
    # Add assertions here to validate the run
♻️ Duplicate comments (1)
tests/test_stac_core.py (1)
1-11: 🛠️ Refactor suggestion
Clean up unused and duplicate imports.
Several imports are identified by static analysis as unused or duplicate:
- Duplicate import of stac_mjx (lines 2 and 8)
- Unused imports of dm_control.mjcf, dm_control, mujoco, and stac_mjx.utils.mjx_load
 import pytest
 import stac_mjx
-from dm_control import mjcf
-import dm_control
 from pathlib import Path
-import mujoco
-from stac_mjx.utils import mjx_load
-import stac_mjx
 import optax
 import jaxopt
🧰 Tools
🪛 Ruff (0.8.2)
3-3: dm_control.mjcf imported but unused. Remove unused import: dm_control.mjcf (F401)
4-4: dm_control imported but unused. Remove unused import: dm_control (F401)
6-6: mujoco imported but unused. Remove unused import: mujoco (F401)
7-7: stac_mjx.utils.mjx_load imported but unused. Remove unused import: stac_mjx.utils.mjx_load (F401)
8-8: Redefinition of unused stac_mjx from line 2. Remove definition: stac_mjx (F811)
🧹 Nitpick comments (2)
tests/test_stac_core.py (2)
13-20: Follow pytest naming conventions for fixtures.

Fixture names should be lowercase and descriptive. Also, consider moving common test fixtures to a `conftest.py` file for reuse across multiple test files.

```diff
 @pytest.fixture
-def TEST_DIR():
+def test_dir():
     return Path(__file__).parent


 @pytest.fixture
-def PROJECT_DIR():
+def project_dir():
     return Path(__file__).parent.parent
```
32-38: Avoid directly accessing internal implementation details.

The test checks against specific internal implementation types (`._src`), which makes it brittle to internal changes in those libraries. Consider checking against more stable interfaces.

```diff
-    assert isinstance(
-        stac_core_obj.opt, optax._src.base.GradientTransformationExtraArgs
-    )
-    assert isinstance(
-        stac_core_obj.q_solver, jaxopt._src.projected_gradient.ProjectedGradient
-    )
-    assert isinstance(stac_core_obj.m_solver, jaxopt._src.optax_wrapper.OptaxSolver)
+    # Check that the expected attributes exist
+    assert hasattr(stac_core_obj, "opt")
+    assert hasattr(stac_core_obj, "q_solver")
+    assert hasattr(stac_core_obj, "m_solver")
+
+    # Optionally check that they implement expected interfaces
+    assert hasattr(stac_core_obj.q_solver, "tol")
+    assert hasattr(stac_core_obj.m_solver, "update")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
`tests/test_stac_core.py` (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: Lint
- GitHub Check: Tests (ubuntu-latest, Python 3.11)
🔇 Additional comments (1)
tests/test_stac_core.py (1)
25-26: ❓ Verification inconclusive
Use hardcoded paths with caution.
Loading configurations from a relative path can be brittle and cause failures when tests are run from different directories. Consider using a more robust method to locate test configurations.
Are there other tests that use a similar pattern? You might want to create a shared fixture for configuration files.
🏁 Script executed:

```bash
#!/bin/bash
# Find other tests that use TEST_DIR / "configs"
rg -l "TEST_DIR.*configs" tests/
```

Length of output: 58
Ensure Reliable Configuration Path Usage

- The pattern `TEST_DIR / "configs"` is currently only present in `tests/test_stac_core.py` (lines 25-26).
- Using a relative path to load test configurations can be brittle if tests are executed from a directory other than the repository root.
- Consider refactoring this approach by centralizing configuration path resolution (for example, via a shared fixture or explicit absolute path determination) to guard against potential issues in different execution contexts.
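One way to make the configuration path resolution explicit, sketched under the assumption that the configs live in a `configs/` directory next to the test file (as `TEST_DIR / "configs"` implies): resolve relative to the test file rather than the working directory, and fail fast with a clear message. `resolve_config_dir` is a hypothetical helper name.

```python
from __future__ import annotations

from pathlib import Path


def resolve_config_dir(test_file: str | Path, name: str = "configs") -> Path:
    """Locate <dir of test_file>/<name>, independent of the current working directory."""
    config_dir = Path(test_file).resolve().parent / name
    if not config_dir.is_dir():
        # Fail with an explicit message instead of an obscure config-loader error later.
        raise FileNotFoundError(f"Expected test config directory at {config_dir}")
    return config_dir
```

A shared fixture could then return `resolve_config_dir(__file__)`, and tests would pass that directory to `stac_mjx.load_configs`.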
Tested on my end; looks great! Just make the suggested variable name change and feel free to merge 👍
```diff
-q_solver = ProjectedGradient(fun=q_loss, projection=projection_box, maxiter=250)
-m_solver = OptaxSolver(opt=opt, fun=m_loss, maxiter=2000)
+self.q_solver = ProjectedGradient(
+    fun=q_loss, projection=projection_box, maxiter=250, tol=tol
```
Change the variable name to `q_tol` throughout so we can add an `m_tol` later.
Second PR; the first one was closed because of failed linting checks and faulty file paths in the test file. Linting and the pytest tests have been updated.
Addresses issue #27 by adding an FTOL (tolerance) hyperparameter that is passed to the `q_solver` attribute of the new `StacCore` object in `stac_core.py`. The FTOL hyperparameter is set in the model config file.
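To make concrete what the FTOL value controls once it reaches `q_solver`, here is a dependency-free toy projected-gradient loop with `tol` and `maxiter` stopping criteria. It is a simplified stand-in, not jaxopt's `ProjectedGradient`: the fixed step size and the stopping rule (update norm divided by step size) are assumptions made for illustration.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Sequence, Tuple


@dataclass
class ToyProjectedGradient:
    """Hypothetical box-constrained projected-gradient solver (illustration only)."""

    grad: Callable[[List[float]], List[float]]
    lower: Sequence[float]
    upper: Sequence[float]
    stepsize: float = 0.1
    maxiter: int = 250
    tol: float = 1e-3

    def run(self, x0: Sequence[float]) -> Tuple[List[float], int]:
        x = list(x0)
        for it in range(1, self.maxiter + 1):
            g = self.grad(x)
            # Gradient step followed by projection onto the box [lower, upper].
            x_new = [
                min(max(xi - self.stepsize * gi, lo), hi)
                for xi, gi, lo, hi in zip(x, g, self.lower, self.upper)
            ]
            # Size of the update, normalized by the step size.
            residual = (
                math.sqrt(sum((a - b) ** 2 for a, b in zip(x_new, x))) / self.stepsize
            )
            x = x_new
            if residual < self.tol:
                return x, it  # converged: the update fell below tol
        return x, self.maxiter  # stopped by maxiter before reaching tol


def grad(x: List[float]) -> List[float]:
    # Gradient of sum((x_i - 3)^2); the box [0, 2] keeps the optimum at the bound x = 2.
    return [2.0 * (xi - 3.0) for xi in x]


tight = ToyProjectedGradient(grad, lower=[0.0], upper=[2.0], tol=1e-3)
loose = ToyProjectedGradient(grad, lower=[0.0], upper=[2.0], tol=3.5)
```

Starting from `[0.0]`, `tight.run` reaches the bound `2.0` after 6 iterations, while `loose.run` stops after 4 iterations at about `1.77`: a looser tolerance trades accuracy for fewer solver iterations, which is the knob FTOL exposes.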
The newly defined `StacCore` object has three attributes: `opt`, `q_solver`, and `m_solver`. These were previously global variables created when the Python module was imported; they are now `StacCore` attributes so that hyperparameter arguments (namely FTOL) can be passed in. `StacCore` also has two methods, `q_opt` and `m_opt`, which wrap the previously defined `q_opt` and `m_opt` functions. Because these functions use the `q_solver` and `m_solver` instance attributes, the wrappers use the `partial` decorator for JIT compilation to work around the 'dynamic' input types.
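The `partial`-decorated wrappers follow the JAX FAQ pattern of marking `self` as a static argument to `jax.jit`, since jit cannot trace an arbitrary Python object. A minimal sketch, assuming JAX is installed; `ToyStacCore`, `q_step`, and its clipped-step logic are hypothetical and are not the PR's `StacCore` code (which wraps jaxopt solvers). It uses the `q_tol` name the reviewer requests.

```python
import functools
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)  # frozen => hashable by value, as static jit arguments must be
class ToyStacCore:
    """Hypothetical stand-in for StacCore holding only a q-solver tolerance."""

    q_tol: float = 1e-2

    # Argument 0 (`self`) is static: it is hashed into the compilation cache key
    # instead of being traced, so self.q_tol is baked in as a Python constant.
    @functools.partial(jax.jit, static_argnums=0)
    def q_step(self, q, grad):
        stepsize = 0.1
        # One projected step onto the box [-1, 1].
        new_q = jnp.clip(q - stepsize * grad, -1.0, 1.0)
        residual = jnp.linalg.norm(new_q - q) / stepsize
        return new_q, residual < self.q_tol
```

Using a frozen dataclass means two objects with the same `q_tol` share one compiled function, while a different tolerance triggers recompilation; with a plain mutable class, the JAX FAQ warns that mutating attributes after the first call can silently reuse a stale trace.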
`Stac.py` was updated to store the `StacCore` object as a `Stac` instance attribute, and ComputeStac was updated to pass the `StacCore` object as an argument to the optimization functions so they can access `q_opt` and `m_opt`.
The test file was run to verify that the `StacCore` instance attributes have the correct object types, that the correct tolerance hyperparameters are passed in, and that JIT compilation works correctly with the `partial` decorator.
Summary by CodeRabbit
- … `StacCore` class for improved organization and modularity in optimization processes.
- … `StacCore` instance, enhancing the optimization workflow.
- … `stac_core_obj` for optimization processes.
- … `StacCore` class and the compilation process to ensure functionality and correctness.