Installation
Requirements
Python ≥ 3.9
pip ≥ 20.0
From PyPI
The package is published on PyPI:
pip install gradhpo
This installs gradhpo and pulls in JAX, optax and the rest of the runtime
dependencies.
From Source
Clone and install from source:
git clone https://github.com/intsystems/gradhpo.git
pip install ./gradhpo/src
For development (editable install + test/lint extras):
git clone https://github.com/intsystems/gradhpo.git
cd gradhpo
pip install -e ./src
pip install pytest pytest-cov flake8
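After an editable install, confirm that `import gradhpo` resolves to the checkout and not to an older copy installed from PyPI. The helper `resolves_inside` below is a hypothetical sketch, not part of gradhpo. It assumes the package sources live somewhere under the `src` directory of the clone.

```python
import importlib.util
from pathlib import Path


def resolves_inside(module_name: str, directory) -> bool:
    """True if `module_name` would be imported from a file under `directory`."""
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        return False  # module not found, or a namespace package with no file
    origin = Path(spec.origin).resolve()
    return Path(directory).resolve() in origin.parents


if __name__ == "__main__":
    # Run from the repository root (the `gradhpo` clone) after `pip install -e ./src`.
    print(resolves_inside("gradhpo", "src"))
```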
Dependencies
Runtime dependencies declared in src/requirements.txt (see that file for the pinned version ranges):
- JAX
- jaxlib
- optax
- chex
- numpy
- scipy
- scikit-learn
- typing-extensions
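The standard-library `importlib.metadata` module can show which of these libraries are present in the active environment and at what version. The sketch below assumes the PyPI distribution names are the lowercase forms shown in the code (for example `jax`, `scikit-learn`). Compare its output against the ranges in src/requirements.txt.

```python
from importlib import metadata

# Assumed PyPI distribution names for the runtime dependencies listed above.
RUNTIME_DEPENDENCIES = (
    "jax", "jaxlib", "optax", "chex",
    "numpy", "scipy", "scikit-learn", "typing-extensions",
)


def installed_versions(distributions=RUNTIME_DEPENDENCIES) -> dict:
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in distributions:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:<20} {version or 'NOT INSTALLED'}")
```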
To build the documentation, additionally install:
pip install -r doc/requirements.txt
Verifying the Installation
import gradhpo
print(gradhpo.__version__)
from gradhpo import (
BilevelState,
BilevelOptimizer,
OnlineHypergradientOptimizer,
T1T2Optimizer,
GreedyOptimizer,
FOOptimizer,
OneStepOptimizer,
)
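The combined `from gradhpo import (...)` above raises an ImportError at the first name it cannot find. To list every missing name at once, you can check each name separately with a hypothetical helper such as `missing_exports` (not part of gradhpo):

```python
import importlib


def missing_exports(module_name: str, names) -> list:
    """Import `module_name` and return the subset of `names` it does not define."""
    module = importlib.import_module(module_name)
    return [name for name in names if not hasattr(module, name)]


if __name__ == "__main__":
    expected = [
        "BilevelState", "BilevelOptimizer", "OnlineHypergradientOptimizer",
        "T1T2Optimizer", "GreedyOptimizer", "FOOptimizer", "OneStepOptimizer",
    ]
    print(missing_exports("gradhpo", expected) or "all names found")
```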
Running the Test Suite
pip install pytest pytest-cov
pytest tests/ --cov=gradhpo --cov-report=term-missing
The full suite contains 76 tests and currently reaches 100 % statement coverage.
Building the Documentation
cd doc
sphinx-build -W --keep-going -b html source build/html
The output is written to doc/build/html/index.html.
Troubleshooting
- ModuleNotFoundError: No module named 'gradhpo'
The package is not installed in the active Python environment. Install it with pip install gradhpo (PyPI), or with pip install -e ./src from the root of a source checkout.
- No GPU acceleration
The default JAX wheels installed from PyPI are CPU-only. To enable GPU support, install a CUDA-enabled jaxlib following the JAX installation guide.
- Documentation build errors
Reinstall the documentation dependencies:
pip install --upgrade -r doc/requirements.txt
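You can narrow down the first two problems from Python itself. The sketch below reports three things: the interpreter in use, whether `gradhpo` is importable from it, and (if JAX is installed) the backend JAX selected via `jax.default_backend()`. The `diagnose` helper is hypothetical, not a gradhpo API. A backend of `'cpu'` means no GPU backend is active.

```python
import importlib.util
import sys


def diagnose() -> dict:
    """Collect facts relevant to the ModuleNotFoundError and GPU problems."""
    info = {
        # Interpreter whose environment `import gradhpo` searches.
        "python_executable": sys.executable,
        "gradhpo_importable": importlib.util.find_spec("gradhpo") is not None,
        "jax_backend": None,
    }
    if importlib.util.find_spec("jax") is not None:
        import jax  # imported lazily so the check also works without JAX
        # 'cpu' means the installed JAX build found no GPU/TPU backend.
        info["jax_backend"] = jax.default_backend()
    return info


if __name__ == "__main__":
    for key, value in diagnose().items():
        print(f"{key}: {value}")
```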