diff --git a/.github/bump_version.py b/.github/bump_version.py
index bb0fd6dd3..779a82e38 100644
--- a/.github/bump_version.py
+++ b/.github/bump_version.py
@@ -19,9 +19,7 @@ def get_current_version(pyproject_path: Path) -> str:
def infer_bump(changelog_dir: Path) -> str:
fragments = [
- f
- for f in changelog_dir.iterdir()
- if f.is_file() and f.name != ".gitkeep"
+ f for f in changelog_dir.iterdir() if f.is_file() and f.name != ".gitkeep"
]
if not fragments:
print("No changelog fragments found", file=sys.stderr)
diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml
index 820eea12f..b8bcfafe6 100644
--- a/.github/workflows/pr.yaml
+++ b/.github/workflows/pr.yaml
@@ -7,10 +7,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
+ - name: Install uv
+ uses: astral-sh/setup-uv@v4
- name: Check formatting
- uses: "lgeiger/black-action@master"
- with:
- args: ". -l 79 --check"
+ run: uvx ruff format --check .
check-changelog:
name: Check changelog fragment
runs-on: ubuntu-latest
diff --git a/.github/workflows/push.yaml b/.github/workflows/push.yaml
index e02ad0317..b9422f88e 100644
--- a/.github/workflows/push.yaml
+++ b/.github/workflows/push.yaml
@@ -10,10 +10,10 @@ jobs:
&& (github.event.head_commit.message == 'Update PolicyEngine Core')
steps:
- uses: actions/checkout@v4
+ - name: Install uv
+ uses: astral-sh/setup-uv@v4
- name: Check formatting
- uses: "lgeiger/black-action@master"
- with:
- args: ". -l 79 --check"
+ run: uvx ruff format --check .
versioning:
name: Update versioning
if: |
diff --git a/Makefile b/Makefile
index a17cc83f4..fd99cbbbf 100644
--- a/Makefile
+++ b/Makefile
@@ -6,7 +6,7 @@ documentation:
python docs/add_plotly_to_book.py docs/_build
format:
- black . -l 79
+ ruff format .
install:
pip install -e ".[dev]" --config-settings editable_mode=compat
diff --git a/README.md b/README.md
index bd8fcafd0..c5b6af075 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
[](https://codecov.io/gh/PolicyEngine/policyengine-core)
[](https://badge.fury.io/py/policyengine-core)
-[](https://github.com/psf/black)
+[](https://github.com/astral-sh/ruff)
This package, a fork of [OpenFisca-Core](https://github.com/OpenFisca/OpenFisca-Core), powers PolicyEngine country models and apps.
@@ -64,7 +64,7 @@ We also ask that you add tests for any new features or bug-fixes you add, so we
### Step 2: Formatting
-In addition to the tests, we use [Black](https://github.com/psf/black) to lint our codebase, so before opening a pull request, Step 2 is to lint the code by running
+In addition to the tests, we use [Ruff](https://github.com/astral-sh/ruff) to format our codebase, so before opening a pull request, Step 2 is to format the code by running
```
make format
diff --git a/changelog.d/switch-to-ruff.changed.md b/changelog.d/switch-to-ruff.changed.md
new file mode 100644
index 000000000..3e1764241
--- /dev/null
+++ b/changelog.d/switch-to-ruff.changed.md
@@ -0,0 +1 @@
+Switched code formatter from black to ruff format.
diff --git a/docs/usage/datasets.ipynb b/docs/usage/datasets.ipynb
index 0efe93fbc..a7cfbce22 100644
--- a/docs/usage/datasets.ipynb
+++ b/docs/usage/datasets.ipynb
@@ -39,9 +39,7 @@
" # Specify metadata used to describe and store the dataset.\n",
" name = \"country_template_dataset\"\n",
" label = \"Country template dataset\"\n",
- " file_path = (\n",
- " COUNTRY_DIR / \"data\" / \"storage\" / \"country_template_dataset.h5\"\n",
- " )\n",
+ " file_path = COUNTRY_DIR / \"data\" / \"storage\" / \"country_template_dataset.h5\"\n",
" data_format = Dataset.TIME_PERIOD_ARRAYS\n",
"\n",
" # The generation function is the most important part: it defines\n",
diff --git a/docs/usage/reforms.ipynb b/docs/usage/reforms.ipynb
index 49f228bdf..b9bd4c2ea 100644
--- a/docs/usage/reforms.ipynb
+++ b/docs/usage/reforms.ipynb
@@ -1,330 +1,336 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Reforms\n",
- "\n",
- "Reforms allow you to modify a tax-benefit system's parameters or variable logic to simulate policy changes. PolicyEngine Core provides two main approaches to creating reforms:\n",
- "\n",
- "1. **`Reform.from_dict()`** - Create reforms programmatically from a dictionary (recommended for most use cases)\n",
- "2. **Reform subclass** - Create a custom class for complex reforms requiring variable logic changes\n",
- "\n",
- "## Creating reforms with `Reform.from_dict()`\n",
- "\n",
- "The `Reform.from_dict()` method is the simplest way to create parameter-based reforms. It takes a dictionary mapping parameter paths to their new values."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Method signature\n",
- "\n",
- "```python\n",
- "Reform.from_dict(\n",
- " parameter_values: dict, # Parameter path -> period -> value mappings\n",
- " country_id: str = None, # Optional: country code for API integration\n",
- " name: str = None, # Optional: human-readable name for the reform\n",
- ") -> Reform\n",
- "```"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Dictionary format\n",
- "\n",
- "The `parameter_values` dictionary uses parameter paths as keys and period-value mappings as values:\n",
- "\n",
- "```python\n",
- "{\n",
- " 'path.to.parameter': {\n",
- " 'YYYY-MM-DD.YYYY-MM-DD': value, # Period format: start.end\n",
- " },\n",
- " 'path.to.another.parameter': {\n",
- " '2024-01-01.2100-12-31': 1000, # Example: effective 2024 onwards\n",
- " },\n",
- "}\n",
- "```\n",
- "\n",
- "**Period formats:**\n",
- "- `'YYYY-MM-DD.YYYY-MM-DD'` - Date range (start date, end date separated by `.`)\n",
- "- `'year:YYYY:N'` - N years starting from YYYY (e.g., `'year:2024:10'` for 2024-2033)\n",
- "- Simple value (no dict) - Applies for 100 years from 2000 (convenience shorthand)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Basic example\n",
- "\n",
- "Let's create a simple reform that increases the basic income amount in the country template."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from policyengine_core.country_template import Microsimulation\n",
- "from policyengine_core.reforms import Reform\n",
- "\n",
- "# Create a reform that increases basic income from 600 to 1000\n",
- "basic_income_reform = Reform.from_dict(\n",
- " {\n",
- " 'benefits.basic_income': {\n",
- " '2020-01-01.2100-12-31': 1000,\n",
- " }\n",
- " }\n",
- ")\n",
- "\n",
- "# Run simulations\n",
- "baseline = Microsimulation()\n",
- "reformed = Microsimulation(reform=basic_income_reform)\n",
- "\n",
- "# Compare results\n",
- "print(\"Baseline basic income:\")\n",
- "print(baseline.calculate(\"basic_income\", \"2022-01\"))\n",
- "print(\"\\nReformed basic income:\")\n",
- "print(reformed.calculate(\"basic_income\", \"2022-01\"))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Multiple parameters\n",
- "\n",
- "You can modify multiple parameters in a single reform."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Reform modifying multiple parameters\n",
- "multi_param_reform = Reform.from_dict(\n",
- " {\n",
- " 'benefits.basic_income': {\n",
- " '2020-01-01.2100-12-31': 800,\n",
- " },\n",
- " 'taxes.income_tax_rate': {\n",
- " '2020-01-01.2100-12-31': 0.20, # Increase from 0.15 to 0.20\n",
- " },\n",
- " }\n",
- ")\n",
- "\n",
- "reformed_multi = Microsimulation(reform=multi_param_reform)\n",
- "\n",
- "print(\"Reformed basic income:\", reformed_multi.calculate(\"basic_income\", \"2022-01\"))\n",
- "print(\"Reformed income tax:\", reformed_multi.calculate(\"income_tax\", \"2022-01\"))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Bracket parameters\n",
- "\n",
- "For parameters with brackets (like tax scales), use array index notation: `parameter.brackets[index].rate` or `parameter.brackets[index].threshold`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Reform modifying a bracket parameter (social security contribution scale)\n",
- "bracket_reform = Reform.from_dict(\n",
- " {\n",
- " # Modify the first bracket's rate\n",
- " 'taxes.social_security_contribution.brackets[0].rate': {\n",
- " '2020-01-01.2100-12-31': 0.05,\n",
- " },\n",
- " # Modify the second bracket's threshold\n",
- " 'taxes.social_security_contribution.brackets[1].threshold': {\n",
- " '2020-01-01.2100-12-31': 15000,\n",
- " },\n",
- " }\n",
- ")\n",
- "\n",
- "reformed_bracket = Microsimulation(reform=bracket_reform)\n",
- "print(\"Reformed social security contribution:\", reformed_bracket.calculate(\"social_security_contribution\", \"2022-01\"))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Shorthand for permanent changes\n",
- "\n",
- "If you want a parameter change to apply indefinitely, you can omit the period dictionary and just provide the value directly."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Shorthand: value applies from year 2000 for 100 years\n",
- "simple_reform = Reform.from_dict(\n",
- " {\n",
- " 'benefits.basic_income': 1200, # No period dict needed\n",
- " 'taxes.income_tax_rate': 0.18,\n",
- " }\n",
- ")\n",
- "\n",
- "reformed_simple = Microsimulation(reform=simple_reform)\n",
- "print(\"Basic income:\", reformed_simple.calculate(\"basic_income\", \"2022-01\"))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Named reforms for API integration\n",
- "\n",
- "When working with the PolicyEngine API, you can provide a `country_id` and `name` for better integration."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Named reform with country ID (useful for API integration)\n",
- "named_reform = Reform.from_dict(\n",
- " {\n",
- " 'benefits.basic_income': {\n",
- " '2024-01-01.2100-12-31': 1500,\n",
- " },\n",
- " },\n",
- " country_id='country_template',\n",
- " name='Increased Basic Income Reform',\n",
- ")\n",
- "\n",
- "# The reform class now has these attributes set\n",
- "print(f\"Reform name: {named_reform.name}\")\n",
- "print(f\"Country ID: {named_reform.country_id}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Creating reforms with a Reform subclass\n",
- "\n",
- "For more complex reforms that need to modify variable logic or simulation data, define a class inheriting from `Reform` with an `apply(self)` method. Inside it, `self` is the tax-benefit system attached to the simulation with loaded data via `self.simulation: Simulation`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from policyengine_core.model_api import *\n",
- "\n",
- "baseline = Microsimulation()\n",
- "\n",
- "\n",
- "class custom_reform(Reform):\n",
- " def apply(self):\n",
- " simulation = self.simulation\n",
- "\n",
- " # Modify parameters\n",
- " simulation.tax_benefit_system.parameters.taxes.housing_tax.rate.update(\n",
- " 20\n",
- " )\n",
- "\n",
- " # Modify simulation data\n",
- " salary = simulation.calculate(\"salary\", \"2022-01\")\n",
- " new_salary = salary * 1.1\n",
- " simulation.set_input(\"salary\", \"2022-01\", new_salary)\n",
- "\n",
- "\n",
- "reformed = Microsimulation(reform=custom_reform)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "reformed.calculate(\"salary\", \"2022-01\"), baseline.calculate(\n",
- " \"salary\", \"2022-01\"\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "reformed.calculate(\"housing_tax\", 2022), baseline.calculate(\n",
- " \"housing_tax\", 2022\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Comparing baseline and reformed simulations\n",
- "\n",
- "A common workflow is to compare results between baseline and reformed scenarios."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Create reform\n",
- "reform = Reform.from_dict(\n",
- " {'benefits.basic_income': {'2020-01-01.2100-12-31': 1000}}\n",
- ")\n",
- "\n",
- "# Run both simulations\n",
- "baseline = Microsimulation()\n",
- "reformed = Microsimulation(reform=reform)\n",
- "\n",
- "# Calculate and compare\n",
- "baseline_income = baseline.calculate(\"basic_income\", \"2022-01\")\n",
- "reformed_income = reformed.calculate(\"basic_income\", \"2022-01\")\n",
- "\n",
- "print(\"Baseline basic income:\")\n",
- "print(baseline_income)\n",
- "print(\"\\nReformed basic income:\")\n",
- "print(reformed_income)\n",
- "print(f\"\\nDifference: {reformed_income['value'].sum() - baseline_income['value'].sum():.2f}\")"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "name": "python",
- "version": "3.10.0"
- }
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Reforms\n",
+ "\n",
+ "Reforms allow you to modify a tax-benefit system's parameters or variable logic to simulate policy changes. PolicyEngine Core provides two main approaches to creating reforms:\n",
+ "\n",
+ "1. **`Reform.from_dict()`** - Create reforms programmatically from a dictionary (recommended for most use cases)\n",
+ "2. **Reform subclass** - Create a custom class for complex reforms requiring variable logic changes\n",
+ "\n",
+ "## Creating reforms with `Reform.from_dict()`\n",
+ "\n",
+ "The `Reform.from_dict()` method is the simplest way to create parameter-based reforms. It takes a dictionary mapping parameter paths to their new values."
+ ]
},
- "nbformat": 4,
- "nbformat_minor": 4
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Method signature\n",
+ "\n",
+ "```python\n",
+ "Reform.from_dict(\n",
+ " parameter_values: dict, # Parameter path -> period -> value mappings\n",
+ " country_id: str = None, # Optional: country code for API integration\n",
+ " name: str = None, # Optional: human-readable name for the reform\n",
+ ") -> Reform\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Dictionary format\n",
+ "\n",
+ "The `parameter_values` dictionary uses parameter paths as keys and period-value mappings as values:\n",
+ "\n",
+ "```python\n",
+ "{\n",
+ " 'path.to.parameter': {\n",
+ " 'YYYY-MM-DD.YYYY-MM-DD': value, # Period format: start.end\n",
+ " },\n",
+ " 'path.to.another.parameter': {\n",
+ " '2024-01-01.2100-12-31': 1000, # Example: effective 2024 onwards\n",
+ " },\n",
+ "}\n",
+ "```\n",
+ "\n",
+ "**Period formats:**\n",
+ "- `'YYYY-MM-DD.YYYY-MM-DD'` - Date range (start date, end date separated by `.`)\n",
+ "- `'year:YYYY:N'` - N years starting from YYYY (e.g., `'year:2024:10'` for 2024-2033)\n",
+ "- Simple value (no dict) - Applies for 100 years from 2000 (convenience shorthand)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Basic example\n",
+ "\n",
+ "Let's create a simple reform that increases the basic income amount in the country template."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from policyengine_core.country_template import Microsimulation\n",
+ "from policyengine_core.reforms import Reform\n",
+ "\n",
+ "# Create a reform that increases basic income from 600 to 1000\n",
+ "basic_income_reform = Reform.from_dict(\n",
+ " {\n",
+ " \"benefits.basic_income\": {\n",
+ " \"2020-01-01.2100-12-31\": 1000,\n",
+ " }\n",
+ " }\n",
+ ")\n",
+ "\n",
+ "# Run simulations\n",
+ "baseline = Microsimulation()\n",
+ "reformed = Microsimulation(reform=basic_income_reform)\n",
+ "\n",
+ "# Compare results\n",
+ "print(\"Baseline basic income:\")\n",
+ "print(baseline.calculate(\"basic_income\", \"2022-01\"))\n",
+ "print(\"\\nReformed basic income:\")\n",
+ "print(reformed.calculate(\"basic_income\", \"2022-01\"))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Multiple parameters\n",
+ "\n",
+ "You can modify multiple parameters in a single reform."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Reform modifying multiple parameters\n",
+ "multi_param_reform = Reform.from_dict(\n",
+ " {\n",
+ " \"benefits.basic_income\": {\n",
+ " \"2020-01-01.2100-12-31\": 800,\n",
+ " },\n",
+ " \"taxes.income_tax_rate\": {\n",
+ " \"2020-01-01.2100-12-31\": 0.20, # Increase from 0.15 to 0.20\n",
+ " },\n",
+ " }\n",
+ ")\n",
+ "\n",
+ "reformed_multi = Microsimulation(reform=multi_param_reform)\n",
+ "\n",
+ "print(\n",
+ " \"Reformed basic income:\",\n",
+ " reformed_multi.calculate(\"basic_income\", \"2022-01\"),\n",
+ ")\n",
+ "print(\"Reformed income tax:\", reformed_multi.calculate(\"income_tax\", \"2022-01\"))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Bracket parameters\n",
+ "\n",
+ "For parameters with brackets (like tax scales), use array index notation: `parameter.brackets[index].rate` or `parameter.brackets[index].threshold`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Reform modifying a bracket parameter (social security contribution scale)\n",
+ "bracket_reform = Reform.from_dict(\n",
+ " {\n",
+ " # Modify the first bracket's rate\n",
+ " \"taxes.social_security_contribution.brackets[0].rate\": {\n",
+ " \"2020-01-01.2100-12-31\": 0.05,\n",
+ " },\n",
+ " # Modify the second bracket's threshold\n",
+ " \"taxes.social_security_contribution.brackets[1].threshold\": {\n",
+ " \"2020-01-01.2100-12-31\": 15000,\n",
+ " },\n",
+ " }\n",
+ ")\n",
+ "\n",
+ "reformed_bracket = Microsimulation(reform=bracket_reform)\n",
+ "print(\n",
+ " \"Reformed social security contribution:\",\n",
+ " reformed_bracket.calculate(\"social_security_contribution\", \"2022-01\"),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Shorthand for permanent changes\n",
+ "\n",
+ "If you want a parameter change to apply indefinitely, you can omit the period dictionary and just provide the value directly."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Shorthand: value applies from year 2000 for 100 years\n",
+ "simple_reform = Reform.from_dict(\n",
+ " {\n",
+ " \"benefits.basic_income\": 1200, # No period dict needed\n",
+ " \"taxes.income_tax_rate\": 0.18,\n",
+ " }\n",
+ ")\n",
+ "\n",
+ "reformed_simple = Microsimulation(reform=simple_reform)\n",
+ "print(\"Basic income:\", reformed_simple.calculate(\"basic_income\", \"2022-01\"))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Named reforms for API integration\n",
+ "\n",
+ "When working with the PolicyEngine API, you can provide a `country_id` and `name` for better integration."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Named reform with country ID (useful for API integration)\n",
+ "named_reform = Reform.from_dict(\n",
+ " {\n",
+ " \"benefits.basic_income\": {\n",
+ " \"2024-01-01.2100-12-31\": 1500,\n",
+ " },\n",
+ " },\n",
+ " country_id=\"country_template\",\n",
+ " name=\"Increased Basic Income Reform\",\n",
+ ")\n",
+ "\n",
+ "# The reform class now has these attributes set\n",
+ "print(f\"Reform name: {named_reform.name}\")\n",
+ "print(f\"Country ID: {named_reform.country_id}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Creating reforms with a Reform subclass\n",
+ "\n",
+ "For more complex reforms that need to modify variable logic or simulation data, define a class inheriting from `Reform` with an `apply(self)` method. Inside it, `self` is the tax-benefit system attached to the simulation with loaded data via `self.simulation: Simulation`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from policyengine_core.model_api import *\n",
+ "\n",
+ "baseline = Microsimulation()\n",
+ "\n",
+ "\n",
+ "class custom_reform(Reform):\n",
+ " def apply(self):\n",
+ " simulation = self.simulation\n",
+ "\n",
+ " # Modify parameters\n",
+ " simulation.tax_benefit_system.parameters.taxes.housing_tax.rate.update(20)\n",
+ "\n",
+ " # Modify simulation data\n",
+ " salary = simulation.calculate(\"salary\", \"2022-01\")\n",
+ " new_salary = salary * 1.1\n",
+ " simulation.set_input(\"salary\", \"2022-01\", new_salary)\n",
+ "\n",
+ "\n",
+ "reformed = Microsimulation(reform=custom_reform)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "(\n",
+ " reformed.calculate(\"salary\", \"2022-01\"),\n",
+ " baseline.calculate(\"salary\", \"2022-01\"),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "(\n",
+ " reformed.calculate(\"housing_tax\", 2022),\n",
+ " baseline.calculate(\"housing_tax\", 2022),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Comparing baseline and reformed simulations\n",
+ "\n",
+ "A common workflow is to compare results between baseline and reformed scenarios."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create reform\n",
+ "reform = Reform.from_dict({\"benefits.basic_income\": {\"2020-01-01.2100-12-31\": 1000}})\n",
+ "\n",
+ "# Run both simulations\n",
+ "baseline = Microsimulation()\n",
+ "reformed = Microsimulation(reform=reform)\n",
+ "\n",
+ "# Calculate and compare\n",
+ "baseline_income = baseline.calculate(\"basic_income\", \"2022-01\")\n",
+ "reformed_income = reformed.calculate(\"basic_income\", \"2022-01\")\n",
+ "\n",
+ "print(\"Baseline basic income:\")\n",
+ "print(baseline_income)\n",
+ "print(\"\\nReformed basic income:\")\n",
+ "print(reformed_income)\n",
+ "print(\n",
+ " f\"\\nDifference: {reformed_income['value'].sum() - baseline_income['value'].sum():.2f}\"\n",
+ ")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
}
diff --git a/policyengine_core/charts/api.py b/policyengine_core/charts/api.py
index 3b8fe039f..cddc2bf32 100644
--- a/policyengine_core/charts/api.py
+++ b/policyengine_core/charts/api.py
@@ -21,9 +21,7 @@ def get_api_chart_data(
version: str = None,
) -> dict:
if baseline_policy_id is None or version is None:
- response = requests.get(
- f"https://api.policyengine.org/{country_id}/metadata"
- )
+ response = requests.get(f"https://api.policyengine.org/{country_id}/metadata")
result = response.json().get("result", {})
baseline_policy_id = result.get("current_law_id")
version = result.get("version")
@@ -92,9 +90,7 @@ def intra_decile_chart(
"color": outcome_colour,
},
"orientation": "h",
- "text": [
- f"{impact['intra_decile']['all'][outcome_label] * 100:.0f}%"
- ],
+ "text": [f"{impact['intra_decile']['all'][outcome_label] * 100:.0f}%"],
"textposition": "inside",
"textangle": 0,
"xaxis": "x",
@@ -117,9 +113,7 @@ def intra_decile_chart(
"orientation": "h",
"text": [
f"{value * 100:.0f}%"
- for value in impact["intra_decile"]["deciles"][
- outcome_label
- ]
+ for value in impact["intra_decile"]["deciles"][outcome_label]
],
"textposition": "inside",
"textangle": 0,
@@ -201,8 +195,7 @@ def decile_chart(
"type": "bar",
"marker": {
"color": [
- DARK_GRAY if value < 0 else BLUE_PRIMARY
- for value in decile_values
+ DARK_GRAY if value < 0 else BLUE_PRIMARY for value in decile_values
],
},
"text": [f"{value:+.1%}" for value in decile_values],
diff --git a/policyengine_core/charts/bar.py b/policyengine_core/charts/bar.py
index 8f54ed169..03798cb0c 100644
--- a/policyengine_core/charts/bar.py
+++ b/policyengine_core/charts/bar.py
@@ -32,11 +32,7 @@ def bar_chart(
"""
hover_text_labels = [
- (
- hover_text_function(index, value)
- if hover_text_function is not None
- else None
- )
+ (hover_text_function(index, value) if hover_text_function is not None else None)
for index, value in data.items()
]
@@ -56,8 +52,7 @@ def bar_chart(
)
.update_traces(
marker_color=[
- positive_colour if v > 0 else negative_colour
- for v in data.values
+ positive_colour if v > 0 else negative_colour for v in data.values
],
hovertemplate=(
"%{customdata[0]}"
diff --git a/policyengine_core/charts/formatting.py b/policyengine_core/charts/formatting.py
index 182774c05..8be4d171b 100644
--- a/policyengine_core/charts/formatting.py
+++ b/policyengine_core/charts/formatting.py
@@ -73,9 +73,7 @@ def format_fig(fig: go.Figure) -> go.Figure:
def display_fig(fig: go.Figure) -> HTML:
- return HTML(
- format_fig(fig).to_html(full_html=False, include_plotlyjs="cdn")
- )
+ return HTML(format_fig(fig).to_html(full_html=False, include_plotlyjs="cdn"))
def cardinal(n: int) -> int:
diff --git a/policyengine_core/commons/formulas.py b/policyengine_core/commons/formulas.py
index 2bd1c494d..425628d55 100644
--- a/policyengine_core/commons/formulas.py
+++ b/policyengine_core/commons/formulas.py
@@ -98,14 +98,10 @@ def concat(this: ArrayLike[str], that: ArrayLike[str]) -> ArrayType[str]:
if isinstance(that, tuple):
raise TypeError("Second argument must not be a tuple.")
- if isinstance(this, numpy.ndarray) and not numpy.issubdtype(
- this.dtype, numpy.str_
- ):
+ if isinstance(this, numpy.ndarray) and not numpy.issubdtype(this.dtype, numpy.str_):
this = this.astype("str")
- if isinstance(that, numpy.ndarray) and not numpy.issubdtype(
- that.dtype, numpy.str_
- ):
+ if isinstance(that, numpy.ndarray) and not numpy.issubdtype(that.dtype, numpy.str_):
that = that.astype("str")
return numpy.char.add(this, that)
@@ -139,13 +135,11 @@ def switch(
"""
- assert (
- len(value_by_condition) > 0
- ), "'switch' must be called with at least one value."
+ assert len(value_by_condition) > 0, (
+ "'switch' must be called with at least one value."
+ )
- condlist = [
- conditions == condition for condition in value_by_condition.keys()
- ]
+ condlist = [conditions == condition for condition in value_by_condition.keys()]
return numpy.select(condlist, value_by_condition.values())
@@ -187,17 +181,13 @@ def for_each_variable(
if variable_entity.key == entity.entity.key:
values = entity(variable, period, options=options)
elif variable_entity.is_person:
- values = group_agg_func(
- entity.members(variable, period, options=options)
- )
+ values = group_agg_func(entity.members(variable, period, options=options))
elif entity.entity.is_person:
raise ValueError(
f"You requested to aggregate {variable} (defined for {variable_entity.plural}) to {entity.entity.plural}, but this is not yet implemented."
)
else: # Group-to-group aggregation
- variable_population = entity.simulation.populations[
- variable_entity.key
- ]
+ variable_population = entity.simulation.populations[variable_entity.key]
person_shares = variable_population.project(
variable_population(variable, period)
) / variable_population.project(variable_population.nb_persons())
@@ -229,9 +219,7 @@ def add(
Returns:
ArrayLike: The result of the operation.
"""
- return for_each_variable(
- entity, period, variables, agg_func="add", options=options
- )
+ return for_each_variable(entity, period, variables, agg_func="add", options=options)
def and_(
@@ -283,9 +271,7 @@ def amount_over(amount: ArrayLike, threshold: float) -> ArrayLike:
Returns:
ArrayLike: The amounts over the threshold.
"""
- logging.debug(
- "amount_over(x, y) is deprecated, use max_(x - y, 0) instead."
- )
+ logging.debug("amount_over(x, y) is deprecated, use max_(x - y, 0) instead.")
return max_(0, amount - threshold)
@@ -334,9 +320,9 @@ def random(population):
entity_ids = population(f"{population.entity.key}_id", period)
# Generate deterministic random values using vectorised hash
- seeds = np.abs(
- entity_ids * 100 + population.simulation.count_random_calls
- ).astype(np.uint64)
+ seeds = np.abs(entity_ids * 100 + population.simulation.count_random_calls).astype(
+ np.uint64
+ )
# PCG-style mixing function for high-quality pseudo-random generation
x = seeds * np.uint64(0x5851F42D4C957F2D)
diff --git a/policyengine_core/commons/rates.py b/policyengine_core/commons/rates.py
index 92ea774fd..06bef6ed0 100644
--- a/policyengine_core/commons/rates.py
+++ b/policyengine_core/commons/rates.py
@@ -98,9 +98,7 @@ def marginal_rate(
"""
marginal_rate: ArrayType[float]
- marginal_rate = +1 - (target[:-1] - target[1:]) / (
- varying[:-1] - varying[1:]
- )
+ marginal_rate = +1 - (target[:-1] - target[1:]) / (varying[:-1] - varying[1:])
if trim is not None:
marginal_rate = numpy.where(
diff --git a/policyengine_core/country_template/data/datasets/country_template_dataset.py b/policyengine_core/country_template/data/datasets/country_template_dataset.py
index cc5c3e974..d0ae21795 100644
--- a/policyengine_core/country_template/data/datasets/country_template_dataset.py
+++ b/policyengine_core/country_template/data/datasets/country_template_dataset.py
@@ -7,9 +7,7 @@ class CountryTemplateDataset(Dataset):
# Specify metadata used to describe and store the dataset.
name = "country_template_dataset"
label = "Country template dataset"
- file_path = (
- COUNTRY_DIR / "data" / "storage" / "country_template_dataset.h5"
- )
+ file_path = COUNTRY_DIR / "data" / "storage" / "country_template_dataset.h5"
data_format = Dataset.TIME_PERIOD_ARRAYS
# The generation function is the most important part: it defines
diff --git a/policyengine_core/country_template/variables/benefits.py b/policyengine_core/country_template/variables/benefits.py
index 640af7b9c..afb12a88f 100644
--- a/policyengine_core/country_template/variables/benefits.py
+++ b/policyengine_core/country_template/variables/benefits.py
@@ -19,7 +19,9 @@ class basic_income(Variable):
entity = Person
definition_period = MONTH
label = "Basic income provided to adults"
- reference = "https://law.gov.example/basic_income" # Always use the most official source
+ reference = (
+ "https://law.gov.example/basic_income" # Always use the most official source
+ )
def formula_2016_12(person, period, parameters):
"""
@@ -46,9 +48,7 @@ def formula_2015_12(person, period, parameters):
)
salary_condition = person("salary", period) == 0
return (
- age_condition
- * salary_condition
- * parameters(period).benefits.basic_income
+ age_condition * salary_condition * parameters(period).benefits.basic_income
) # The '*' is also used as a vectorial 'and'. See https://openfisca.org/doc/coding-the-legislation/25_vectorial_computing.html#boolean-operations
@@ -75,10 +75,7 @@ def formula_1980(household, period, parameters):
To compute this allowance, the 'rent' value must be provided for the same month,
but 'housing_occupancy_status' is not necessary.
"""
- return (
- household("rent", period)
- * parameters(period).benefits.housing_allowance
- )
+ return household("rent", period) * parameters(period).benefits.housing_allowance
# By default, you can use utf-8 characters in a variable. OpenFisca web API manages utf-8 encoding.
@@ -101,8 +98,7 @@ def formula(person, period, parameters):
In Arabic: تقاعد.
"""
age_condition = (
- person("age", period)
- >= parameters(period).general.age_of_retirement
+ person("age", period) >= parameters(period).general.age_of_retirement
)
return age_condition
@@ -135,9 +131,7 @@ def formula(household, period, parameters):
under_8 = household.any(ages < 8)
under_6 = household.any(ages < 6)
- allowance_condition = income_condition * (
- (is_single * under_8) + under_6
- )
+ allowance_condition = income_condition * ((is_single * under_8) + under_6)
allowance_amount = parenting_allowance.amount
return allowance_condition * allowance_amount
diff --git a/policyengine_core/country_template/variables/income.py b/policyengine_core/country_template/variables/income.py
index 2e2ebc6e5..18d4a9ac4 100644
--- a/policyengine_core/country_template/variables/income.py
+++ b/policyengine_core/country_template/variables/income.py
@@ -22,9 +22,7 @@ class salary(Variable):
definition_period = MONTH
set_input = set_input_divide_by_period # Optional attribute. Allows user to declare a salary for a year. OpenFisca will spread the yearly amount over the months contained in the year.
label = "Salary"
- reference = (
- "https://law.gov.example/salary" # Always use the most official source
- )
+ reference = "https://law.gov.example/salary" # Always use the most official source
class disposable_income(Variable):
diff --git a/policyengine_core/country_template/variables/taxes.py b/policyengine_core/country_template/variables/taxes.py
index 6fd74cd12..a7043bf7d 100644
--- a/policyengine_core/country_template/variables/taxes.py
+++ b/policyengine_core/country_template/variables/taxes.py
@@ -21,7 +21,9 @@ class income_tax(Variable):
entity = Person
definition_period = MONTH
label = "Income tax"
- reference = "https://law.gov.example/income_tax" # Always use the most official source
+ reference = (
+ "https://law.gov.example/income_tax" # Always use the most official source
+ )
def formula(person, period, parameters):
"""
@@ -29,18 +31,14 @@ def formula(person, period, parameters):
The formula to compute the income tax for a given person at a given period
"""
- return (
- person("salary", period) * parameters(period).taxes.income_tax_rate
- )
+ return person("salary", period) * parameters(period).taxes.income_tax_rate
class social_security_contribution(Variable):
value_type = float
entity = Person
definition_period = MONTH
- label = (
- "Progressive contribution paid on salaries to finance social security"
- )
+ label = "Progressive contribution paid on salaries to finance social security"
reference = "https://law.gov.example/social_security_contribution" # Always use the most official source
def formula(person, period, parameters):
@@ -60,7 +58,9 @@ class housing_tax(Variable):
entity = Household
definition_period = YEAR # This housing tax is defined for a year.
label = "Tax paid by each household proportionally to the size of its accommodation"
- reference = "https://law.gov.example/housing_tax" # Always use the most official source
+ reference = (
+ "https://law.gov.example/housing_tax" # Always use the most official source
+ )
def formula(household, period, parameters):
"""
diff --git a/policyengine_core/data/dataset.py b/policyengine_core/data/dataset.py
index 2cd715b6f..7caebb0c0 100644
--- a/policyengine_core/data/dataset.py
+++ b/policyengine_core/data/dataset.py
@@ -85,19 +85,21 @@ def __init__(self, require: bool = False):
self.file_path.parent.mkdir(parents=True, exist_ok=True)
- assert (
- self.name
- ), "You tried to instantiate a Dataset object, but no name has been provided."
- assert (
- self.label
- ), "You tried to instantiate a Dataset object, but no label has been provided."
+ assert self.name, (
+ "You tried to instantiate a Dataset object, but no name has been provided."
+ )
+ assert self.label, (
+ "You tried to instantiate a Dataset object, but no label has been provided."
+ )
assert self.data_format in [
Dataset.TABLES,
Dataset.ARRAYS,
Dataset.TIME_PERIOD_ARRAYS,
Dataset.FLAT_FILE,
- ], f"You tried to instantiate a Dataset object, but your data_format attribute is invalid ({self.data_format})."
+ ], (
+ f"You tried to instantiate a Dataset object, but your data_format attribute is invalid ({self.data_format})."
+ )
self._table_cache = {}
@@ -251,9 +253,7 @@ def load_dataset(
data[variable][time_period] = np.array(f[key])
elif self.data_format == Dataset.ARRAYS:
with h5py.File(file, "r") as f:
- data = {
- variable: np.array(f[variable]) for variable in f.keys()
- }
+ data = {variable: np.array(f[variable]) for variable in f.keys()}
return data
def generate(self):
@@ -342,7 +342,9 @@ def download(self, url: str = None, version: str = None) -> None:
if url.startswith("release://"):
org, repo, release_tag, file_path = url.split("/")[2:]
- url = f"https://api.github.com/repos/{org}/{repo}/releases/tags/{release_tag}"
+ url = (
+ f"https://api.github.com/repos/{org}/{repo}/releases/tags/{release_tag}"
+ )
response = requests.get(url, headers=auth_headers)
if response.status_code != 200:
raise ValueError(
@@ -486,9 +488,7 @@ def from_dataframe(dataframe: pd.DataFrame, time_period: str = None):
return dataset
- def upload_to_huggingface(
- self, owner_name: str, model_name: str, file_name: str
- ):
+ def upload_to_huggingface(self, owner_name: str, model_name: str, file_name: str):
"""Uploads the dataset to HuggingFace.
Args:
diff --git a/policyengine_core/data_storage/in_memory_storage.py b/policyengine_core/data_storage/in_memory_storage.py
index 969bae5a3..a5e510b42 100644
--- a/policyengine_core/data_storage/in_memory_storage.py
+++ b/policyengine_core/data_storage/in_memory_storage.py
@@ -21,9 +21,7 @@ def __init__(self, is_eternal: bool):
def clone(self) -> "InMemoryStorage":
clone = InMemoryStorage(self.is_eternal)
- clone._arrays = {
- period: array.copy() for period, array in self._arrays.items()
- }
+ clone._arrays = {period: array.copy() for period, array in self._arrays.items()}
return clone
def get(self, period: Period, branch_name: str = "default") -> ArrayLike:
@@ -44,9 +42,7 @@ def put(
self._arrays[f"{branch_name}:{period}"] = value
- def delete(
- self, period: Period = None, branch_name: str = "default"
- ) -> None:
+ def delete(self, period: Period = None, branch_name: str = "default") -> None:
if period is None:
self._arrays = {}
return
@@ -62,16 +58,12 @@ def delete(
}
def get_known_periods(self) -> list:
- return list(
- map(lambda x: periods.period(x.split(":")[1]), self._arrays.keys())
- )
+ return list(map(lambda x: periods.period(x.split(":")[1]), self._arrays.keys()))
def get_known_branch_periods(self) -> list:
return [
(branch_name, periods.period(period))
- for branch_name, period in map(
- lambda x: x.split(":"), self._arrays.keys()
- )
+ for branch_name, period in map(lambda x: x.split(":"), self._arrays.keys())
]
def get_memory_usage(self) -> dict:
diff --git a/policyengine_core/data_storage/on_disk_storage.py b/policyengine_core/data_storage/on_disk_storage.py
index 29954e2a3..3e92ae063 100644
--- a/policyengine_core/data_storage/on_disk_storage.py
+++ b/policyengine_core/data_storage/on_disk_storage.py
@@ -58,9 +58,7 @@ def put(
numpy.save(path, value)
self._files[filename] = path
- def delete(
- self, period: Period = None, branch_name: str = "default"
- ) -> None:
+ def delete(self, period: Period = None, branch_name: str = "default") -> None:
if period is None:
self._files = {}
return
@@ -77,16 +75,12 @@ def delete(
}
def get_known_periods(self) -> list:
- return list(
- [periods.period(x.split("_")[1]) for x in self._files.keys()]
- )
+ return list([periods.period(x.split("_")[1]) for x in self._files.keys()])
def get_known_branch_periods(self) -> list:
return [
(branch_name, periods.period(period))
- for branch_name, period in map(
- lambda x: x.split("_"), self._files.keys()
- )
+ for branch_name, period in map(lambda x: x.split("_"), self._files.keys())
]
def restore(self) -> None:
diff --git a/policyengine_core/entities/entity.py b/policyengine_core/entities/entity.py
index f97d68e73..46364600a 100644
--- a/policyengine_core/entities/entity.py
+++ b/policyengine_core/entities/entity.py
@@ -26,14 +26,10 @@ def check_role_validity(self, role: Any) -> None:
raise ValueError("{} is not a valid role".format(role))
def get_variable(self, variable_name: str, check_existence: bool = False):
- return self._tax_benefit_system.get_variable(
- variable_name, check_existence
- )
+ return self._tax_benefit_system.get_variable(variable_name, check_existence)
def check_variable_defined_for_entity(self, variable_name: str) -> None:
- variable_entity = self.get_variable(
- variable_name, check_existence=True
- ).entity
+ variable_entity = self.get_variable(variable_name, check_existence=True).entity
# Should be this:
# if variable_entity is not self:
if variable_entity.key != self.key:
diff --git a/policyengine_core/enums/enum.py b/policyengine_core/enums/enum.py
index c13a99b60..0522cdaf5 100644
--- a/policyengine_core/enums/enum.py
+++ b/policyengine_core/enums/enum.py
@@ -74,9 +74,7 @@ def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray:
else:
first_elem = None
if first_elem is not None and isinstance(first_elem, Enum):
- indices = np.array(
- [item.index for item in array], dtype=ENUM_ARRAY_DTYPE
- )
+ indices = np.array([item.index for item in array], dtype=ENUM_ARRAY_DTYPE)
return EnumArray(indices, cls)
# Convert byte-strings or object arrays to Unicode strings
diff --git a/policyengine_core/errors/variable_not_found_error.py b/policyengine_core/errors/variable_not_found_error.py
index a58d265a7..41152ee4d 100644
--- a/policyengine_core/errors/variable_not_found_error.py
+++ b/policyengine_core/errors/variable_not_found_error.py
@@ -25,9 +25,7 @@ def __init__(self, variable_name, tax_benefit_system):
"You tried to calculate or to set a value for variable '{0}', but it was not found in the loaded tax and benefit system ({1}).".format(
variable_name, country_package_id
),
- "Are you sure you spelled '{0}' correctly?".format(
- variable_name
- ),
+ "Are you sure you spelled '{0}' correctly?".format(variable_name),
"If this code used to work and suddenly does not, this is most probably linked to an update of the tax and benefit system.",
"Look at its changelog to learn about renames and removals and update your code. If it is an official package,",
"it is probably available on .".format(
diff --git a/policyengine_core/experimental/memory_config.py b/policyengine_core/experimental/memory_config.py
index 8d4c096c6..4876096c5 100644
--- a/policyengine_core/experimental/memory_config.py
+++ b/policyengine_core/experimental/memory_config.py
@@ -25,6 +25,4 @@ def __init__(
self.priority_variables = (
set(priority_variables) if priority_variables else set()
)
- self.variables_to_drop = (
- set(variables_to_drop) if variables_to_drop else set()
- )
+ self.variables_to_drop = set(variables_to_drop) if variables_to_drop else set()
diff --git a/policyengine_core/holders/helpers.py b/policyengine_core/holders/helpers.py
index b7ff4175b..21a9ebdb9 100644
--- a/policyengine_core/holders/helpers.py
+++ b/policyengine_core/holders/helpers.py
@@ -10,9 +10,7 @@
log = logging.getLogger(__name__)
-def set_input_dispatch_by_period(
- holder: Holder, period: Period, array: ArrayLike
-):
+def set_input_dispatch_by_period(holder: Holder, period: Period, array: ArrayLike):
"""
This function can be declared as a ``set_input`` attribute of a variable.
@@ -49,9 +47,7 @@ def set_input_dispatch_by_period(
sub_period = sub_period.offset(1)
-def set_input_divide_by_period(
- holder: Holder, period: Period, array: ArrayLike
-):
+def set_input_divide_by_period(holder: Holder, period: Period, array: ArrayLike):
"""
This function can be declared as a ``set_input`` attribute of a variable.
diff --git a/policyengine_core/holders/holder.py b/policyengine_core/holders/holder.py
index 972a1ff28..8bb5d9e4d 100644
--- a/policyengine_core/holders/holder.py
+++ b/policyengine_core/holders/holder.py
@@ -41,10 +41,7 @@ def __init__(self, variable: "Variable", population: "Population"):
):
self._disk_storage = self.create_disk_storage()
self._on_disk_storable = True
- if (
- self.variable.name
- in self.simulation.memory_config.variables_to_drop
- ):
+ if self.variable.name in self.simulation.memory_config.variables_to_drop:
self._do_not_store = True
def clone(self, population: "Population") -> "Holder":
@@ -97,9 +94,7 @@ def delete_arrays(
if self._disk_storage:
self._disk_storage.delete(period, branch_name)
- def get_array(
- self, period: Period, branch_name: str = "default"
- ) -> ArrayLike:
+ def get_array(self, period: Period, branch_name: str = "default") -> ArrayLike:
"""
Get the value of the variable for the given period.
@@ -151,9 +146,7 @@ def get_memory_usage(self) -> dict:
usage.update(self._memory_storage.get_memory_usage())
if self.simulation.trace:
- nb_requests = self.simulation.tracer.get_nb_requests(
- self.variable.name
- )
+ nb_requests = self.simulation.tracer.get_nb_requests(self.variable.name)
usage.update(
dict(
nb_requests=nb_requests,
@@ -173,11 +166,7 @@ def get_known_periods(self) -> List[Period]:
"""
return list(self._memory_storage.get_known_periods()) + list(
- (
- self._disk_storage.get_known_periods()
- if self._disk_storage
- else []
- )
+ (self._disk_storage.get_known_periods() if self._disk_storage else [])
)
def get_known_branch_periods(self) -> List[Tuple[str, Period]]:
@@ -236,10 +225,7 @@ def set_input(
return warnings.warn(warning_message, Warning)
if self.variable.value_type in (float, int) and isinstance(array, str):
array = tools.eval_expression(array)
- if (
- self.variable.set_input
- and period.unit != self.variable.definition_period
- ):
+ if self.variable.set_input and period.unit != self.variable.definition_period:
return self.variable.set_input(self, period, array)
return self._set(period, array, branch_name)
@@ -263,9 +249,7 @@ def _to_array(self, value: Any) -> ArrayLike:
original_value = value
value = self.variable.possible_values.encode(value)
if value.shape != original_value.shape:
- value = self.variable.possible_values.encode(
- original_value.astype("O")
- )
+ value = self.variable.possible_values.encode(original_value.astype("O"))
if value.dtype != self.variable.dtype:
try:
value = value.astype(self.variable.dtype)
@@ -311,8 +295,7 @@ def put_in_cache(
if (
self.simulation.opt_out_cache
and self.simulation.tax_benefit_system.cache_blacklist
- and self.variable.name
- in self.simulation.tax_benefit_system.cache_blacklist
+ and self.variable.name in self.simulation.tax_benefit_system.cache_blacklist
):
return
diff --git a/policyengine_core/parameters/config.py b/policyengine_core/parameters/config.py
index f1e86f0ec..73982d8a3 100644
--- a/policyengine_core/parameters/config.py
+++ b/policyengine_core/parameters/config.py
@@ -29,9 +29,7 @@ def date_constructor(_loader, node):
return node.value
-yaml.add_constructor(
- "tag:yaml.org,2002:timestamp", date_constructor, Loader=Loader
-)
+yaml.add_constructor("tag:yaml.org,2002:timestamp", date_constructor, Loader=Loader)
def dict_no_duplicate_constructor(loader, node, deep=False):
diff --git a/policyengine_core/parameters/helpers.py b/policyengine_core/parameters/helpers.py
index 242246805..b11a8aa87 100644
--- a/policyengine_core/parameters/helpers.py
+++ b/policyengine_core/parameters/helpers.py
@@ -73,9 +73,7 @@ def _parse_child(child_name, child, child_path):
):
return parameters.Parameter(child_name, child, child_path)
else:
- return parameters.ParameterNode(
- child_name, data=child, file_path=child_path
- )
+ return parameters.ParameterNode(child_name, data=child, file_path=child_path)
def _validate_parameter(parameter, data, data_type=None, allowed_keys=None):
@@ -86,8 +84,6 @@ def _validate_parameter(parameter, data, data_type=None, allowed_keys=None):
if data_type is not None and not isinstance(data, data_type):
raise ParameterParsingError(
- "'{}' must be of type {}.".format(
- parameter.name, type_map[data_type]
- ),
+ "'{}' must be of type {}.".format(parameter.name, type_map[data_type]),
parameter.file_path,
)
diff --git a/policyengine_core/parameters/operations/homogenize_parameters.py b/policyengine_core/parameters/operations/homogenize_parameters.py
index 5eb89f28a..6d2e830a7 100644
--- a/policyengine_core/parameters/operations/homogenize_parameters.py
+++ b/policyengine_core/parameters/operations/homogenize_parameters.py
@@ -24,9 +24,7 @@ def homogenize_parameter_structures(
for node in root.get_descendants():
if isinstance(node, ParameterNode):
breakdown = get_breakdown_variables(node)
- node = homogenize_parameter_node(
- node, breakdown, variables, default_value
- )
+ node = homogenize_parameter_node(node, breakdown, variables, default_value)
return root
diff --git a/policyengine_core/parameters/operations/interpolate_parameters.py b/policyengine_core/parameters/operations/interpolate_parameters.py
index ce51ee059..c91a8ce2f 100644
--- a/policyengine_core/parameters/operations/interpolate_parameters.py
+++ b/policyengine_core/parameters/operations/interpolate_parameters.py
@@ -17,24 +17,18 @@ def interpolate_parameters(root: ParameterNode) -> ParameterNode:
"""
for parameter in root.get_descendants():
if isinstance(parameter, Parameter):
- if (
- "interpolation" in parameter.metadata
- and not parameter.metadata["interpolation"].get(
- "completed", False
- )
- ):
+ if "interpolation" in parameter.metadata and not parameter.metadata[
+ "interpolation"
+ ].get("completed", False):
interpolated_entries = []
for i in range(len(parameter.values_list) - 1):
# For each gap in parameter values
start = instant(parameter.values_list[::-1][i].instant_str)
num_intervals = 1
# Find the number of intervals to fill
- interval_size = parameter.metadata["interpolation"][
- "interval"
- ]
+ interval_size = parameter.metadata["interpolation"]["interval"]
parameter_dates = [
- at_instant.instant_str
- for at_instant in parameter.values_list
+ at_instant.instant_str for at_instant in parameter.values_list
]
while (
str(start.offset(num_intervals, interval_size))
@@ -46,16 +40,13 @@ def interpolate_parameters(root: ParameterNode) -> ParameterNode:
start_str = str(
start.offset(
j,
- parameter.metadata["interpolation"][
- "interval"
- ],
+ parameter.metadata["interpolation"]["interval"],
)
)
start_value = parameter.values_list[::-1][i].value
end_value = parameter.values_list[::-1][i + 1].value
new_value = (
- start_value
- + (end_value - start_value) * j / num_intervals
+ start_value + (end_value - start_value) * j / num_intervals
)
interpolated_entries += [
ParameterAtInstant(
@@ -64,8 +55,6 @@ def interpolate_parameters(root: ParameterNode) -> ParameterNode:
]
for entry in interpolated_entries:
parameter.values_list.append(entry)
- parameter.values_list.sort(
- key=lambda x: x.instant_str, reverse=True
- )
+ parameter.values_list.sort(key=lambda x: x.instant_str, reverse=True)
parameter.metadata["interpolation"]["completed"] = True
return root
diff --git a/policyengine_core/parameters/operations/uprate_parameters.py b/policyengine_core/parameters/operations/uprate_parameters.py
index b09b2f5c7..415b9fbf9 100644
--- a/policyengine_core/parameters/operations/uprate_parameters.py
+++ b/policyengine_core/parameters/operations/uprate_parameters.py
@@ -77,9 +77,7 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode:
)
# Construct cadence options object
- cadence_options = construct_cadence_options(
- cadence_meta, parameter
- )
+ cadence_options = construct_cadence_options(cadence_meta, parameter)
# Ensure that end comes after start and enactment comes after end
if cadence_options["end"] <= cadence_options["start"]:
@@ -120,9 +118,7 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode:
if "start_instant" in meta:
last_instant = instant(meta["start_instant"])
else:
- last_instant = instant(
- parameter.values_list[0].instant_str
- )
+ last_instant = instant(parameter.values_list[0].instant_str)
# Pre-compute values that don't change in the loop
last_instant_str = str(last_instant)
@@ -143,17 +139,11 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode:
# If the uprater instant is defined after the last parameter instant
if entry_instant > last_instant:
# Apply the uprater and add to the parameter
- uprater_at_entry = uprating_parameter(
- entry_instant
- )
- uprater_change = (
- uprater_at_entry / uprater_at_start
- )
+ uprater_at_entry = uprating_parameter(entry_instant)
+ uprater_change = uprater_at_entry / uprater_at_start
uprated_value = value_at_start * uprater_change
if has_rounding:
- uprated_value = round_uprated_value(
- meta, uprated_value
- )
+ uprated_value = round_uprated_value(meta, uprated_value)
parameter.values_list.append(
ParameterAtInstant(
parameter.name,
@@ -162,9 +152,7 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode:
)
)
# Whether using cadence or not, sort the parameter values_list
- parameter.values_list.sort(
- key=lambda x: x.instant_str, reverse=True
- )
+ parameter.values_list.sort(key=lambda x: x.instant_str, reverse=True)
return root
@@ -184,9 +172,7 @@ def round_uprated_value(meta: dict, uprated_value: float) -> float:
return uprated_value
-def find_cadence_first(
- parameter: Parameter, cadence_options: dict
-) -> datetime:
+def find_cadence_first(parameter: Parameter, cadence_options: dict) -> datetime:
"""
Find first value to uprate. This should be the same (month, day) as
the uprating enactment date, but occurring after the last value within
@@ -236,13 +222,9 @@ def find_cadence_first(
# Conditionally determine settings to utilize
# Set the month only if the interval is "year"
- month_option = (
- cadence_options["enactment"].month if interval == "year" else None
- )
+ month_option = cadence_options["enactment"].month if interval == "year" else None
# Set the day if interval is "year" or "month", just not "day"
- day_option = (
- cadence_options["enactment"].day if interval != "day" else None
- )
+ day_option = cadence_options["enactment"].day if interval != "day" else None
rrule_obj = rrule.rrule(
freq=rrule.YEARLY,
@@ -294,9 +276,7 @@ def find_cadence_last(uprater: Parameter, cadence_options: dict) -> datetime:
# resulting in 25% of offsets being one day short if calculated
# using days only; further, this must be done with relativedelta and not
# datetime's timedelta, which doesn't handle year-based offsets
- start_date: datetime = (
- last_param - relativedelta(years=1) + relativedelta(days=1)
- )
+ start_date: datetime = last_param - relativedelta(years=1) + relativedelta(days=1)
# Conditionally determine settings to utilize
# Set the month only if the interval is "year"
@@ -357,9 +337,7 @@ def uprate_by_cadence(
rrule_interval = rrule.YEARLY
# Generate a list of iterations
- iterations = rrule.rrule(
- freq=rrule_interval, dtstart=first_date, until=last_date
- )
+ iterations = rrule.rrule(freq=rrule_interval, dtstart=first_date, until=last_date)
# Determine the offset between the first enactment
# date and the first start and end date
@@ -381,14 +359,10 @@ def uprate_by_cadence(
end_calc_date: datetime = enactment_date - enactment_end_offset
# Find uprater value at cadence start
- start_val = uprating_parameter.get_at_instant(
- instant(start_calc_date.date())
- )
+ start_val = uprating_parameter.get_at_instant(instant(start_calc_date.date()))
# Find uprater value at cadence end
- end_val = uprating_parameter.get_at_instant(
- instant(end_calc_date.date())
- )
+ end_val = uprating_parameter.get_at_instant(instant(end_calc_date.date()))
# Ensure that earliest date exists within uprater
if not start_val:
@@ -417,9 +391,7 @@ def uprate_by_cadence(
return uprated_data
-def construct_cadence_options(
- cadence_settings: dict, parameter: Parameter
-) -> dict:
+def construct_cadence_options(cadence_settings: dict, parameter: Parameter) -> dict:
# Define all settings with fixed input options
# so as to test that input is valid
FIXED_OPTIONS = {"interval": ["year", "month", "day"]}
diff --git a/policyengine_core/parameters/parameter.py b/policyengine_core/parameters/parameter.py
index 9717c6554..50c3205f7 100644
--- a/policyengine_core/parameters/parameter.py
+++ b/policyengine_core/parameters/parameter.py
@@ -45,9 +45,7 @@ class Parameter(AtInstantLike):
"""
- def __init__(
- self, name: str, data: dict, file_path: Optional[str] = None
- ) -> None:
+ def __init__(self, name: str, data: dict, file_path: Optional[str] = None) -> None:
self.name: str = name
self.file_path: Optional[str] = file_path
_validate_parameter(self, data, data_type=dict)
@@ -59,9 +57,7 @@ def __init__(
if data.get("values"):
# 'unit' and 'reference' are only listed here for backward compatibility
self.metadata.update(data.get("metadata", {}))
- _validate_parameter(
- self, data, allowed_keys=COMMON_KEYS.union({"values"})
- )
+ _validate_parameter(self, data, allowed_keys=COMMON_KEYS.union({"values"}))
self.description = data.get("description")
_validate_parameter(self, data["values"], data_type=dict)
@@ -122,9 +118,7 @@ def __repr__(self):
)
def __eq__(self, other):
- return (self.name == other.name) and (
- self.values_list == other.values_list
- )
+ return (self.name == other.name) and (self.values_list == other.values_list)
def clone(self):
clone = empty_clone(self)
@@ -132,8 +126,7 @@ def clone(self):
clone.metadata = copy.deepcopy(self.metadata)
clone.values_list = [
- parameter_at_instant.clone()
- for parameter_at_instant in self.values_list
+ parameter_at_instant.clone() for parameter_at_instant in self.values_list
]
return clone
@@ -200,9 +193,7 @@ def update(
# Insert new interval
value_name = _compose_name(self.name, item_name=start_str)
- new_interval = ParameterAtInstant(
- value_name, start_str, data={"value": value}
- )
+ new_interval = ParameterAtInstant(value_name, start_str, data={"value": value})
new_values.append(new_interval)
# Remove covered intervals
diff --git a/policyengine_core/parameters/parameter_at_instant.py b/policyengine_core/parameters/parameter_at_instant.py
index 23fda2159..b815f1c10 100644
--- a/policyengine_core/parameters/parameter_at_instant.py
+++ b/policyengine_core/parameters/parameter_at_instant.py
@@ -33,9 +33,7 @@ def __init__(
self.metadata: typing.Dict = {}
# Accept { 2015-01-01: 4000 }
- if not isinstance(data, dict) and isinstance(
- data, ALLOWED_PARAM_TYPES
- ):
+ if not isinstance(data, dict) and isinstance(data, ALLOWED_PARAM_TYPES):
self.value = data
return
@@ -49,9 +47,7 @@ def __init__(
self.value: float = data["value"]
def validate(self, data: dict) -> None:
- _validate_parameter(
- self, data, data_type=dict, allowed_keys=self._allowed_keys
- )
+ _validate_parameter(self, data, data_type=dict, allowed_keys=self._allowed_keys)
try:
value = data["value"]
except KeyError:
diff --git a/policyengine_core/parameters/parameter_node.py b/policyengine_core/parameters/parameter_node.py
index 12b58e38e..94dbd88ff 100644
--- a/policyengine_core/parameters/parameter_node.py
+++ b/policyengine_core/parameters/parameter_node.py
@@ -84,9 +84,7 @@ def __init__(
self.trace: bool = False
self.tracer = None
self.branch_name = None
- self._at_instant_cache: typing.Dict[
- Instant, ParameterNodeAtInstant
- ] = {}
+ self._at_instant_cache: typing.Dict[Instant, ParameterNodeAtInstant] = {}
self.parent = None
if directory_path:
@@ -104,12 +102,8 @@ def __init__(
with open(child_path, "r") as f:
# Get the header as the label (making sure to remove the leading hash), and the rest as the description
lines = f.readlines()
- metadata["label"] = (
- lines[0].replace("# ", "").strip()
- )
- metadata["description"] = "".join(
- lines[1:]
- ).strip()
+ metadata["label"] = lines[0].replace("# ", "").strip()
+ metadata["description"] = "".join(lines[1:]).strip()
self.metadata.update(metadata)
@@ -118,17 +112,13 @@ def __init__(
if child_name == "index":
data = _load_yaml_file(child_path) or {}
- _validate_parameter(
- self, data, allowed_keys=COMMON_KEYS
- )
+ _validate_parameter(self, data, allowed_keys=COMMON_KEYS)
self.description = data.get("description")
self.documentation = data.get("documentation")
self.metadata.update(data.get("metadata", {}))
elif child_name not in EXCLUDED_PARAMETER_CHILD_NAMES:
child_name_expanded = _compose_name(name, child_name)
- child = load_parameter_file(
- child_path, child_name_expanded
- )
+ child = load_parameter_file(child_path, child_name_expanded)
self.add_child(child_name, child)
elif os.path.isdir(child_path):
@@ -178,9 +168,7 @@ def add_child(self, name: str, child: Union["ParameterNode", Parameter]):
:param child: The new child, an instance of :class:`.ParameterScale` or :class:`.Parameter` or :class:`.ParameterNode`.
"""
if name in self.children:
- raise ValueError(
- "{} has already a child named {}".format(self.name, name)
- )
+ raise ValueError("{} has already a child named {}".format(self.name, name))
if not (
isinstance(child, ParameterNode)
or isinstance(child, Parameter)
@@ -198,9 +186,7 @@ def add_child(self, name: str, child: Union["ParameterNode", Parameter]):
def __repr__(self) -> str:
result = os.linesep.join(
[
- os.linesep.join(["{}:", "{}"]).format(
- name, tools.indent(repr(value))
- )
+ os.linesep.join(["{}:", "{}"]).format(name, tools.indent(repr(value)))
for name, value in sorted(self.children.items())
]
)
@@ -219,9 +205,7 @@ def clone(self) -> "ParameterNode":
clone.__dict__ = self.__dict__.copy()
clone.metadata = copy.deepcopy(self.metadata)
- clone.children = {
- key: child.clone() for key, child in self.children.items()
- }
+ clone.children = {key: child.clone() for key, child in self.children.items()}
for child_key, child in clone.children.items():
setattr(clone, child_key, child)
clone._at_instant_cache = {}
diff --git a/policyengine_core/parameters/parameter_node_at_instant.py b/policyengine_core/parameters/parameter_node_at_instant.py
index f3d8c6ef8..343e152dc 100644
--- a/policyengine_core/parameters/parameter_node_at_instant.py
+++ b/policyengine_core/parameters/parameter_node_at_instant.py
@@ -38,9 +38,7 @@ def __init__(self, name: str, node: "ParameterNode", instant_str: str):
if child_at_instant is not None:
self.add_child(child_name, child_at_instant)
- def add_child(
- self, child_name: str, child_at_instant: "ParameterNodeAtInstant"
- ):
+ def add_child(self, child_name: str, child_at_instant: "ParameterNodeAtInstant"):
self._children[child_name] = child_at_instant
setattr(self, child_name, child_at_instant)
@@ -59,9 +57,7 @@ def __getitem__(
if hasattr(key, "__array__") and not isinstance(key, numpy.ndarray):
key = numpy.asarray(key)
if isinstance(key, numpy.ndarray):
- return parameters.VectorialParameterNodeAtInstant.build_from_node(
- self
- )[key]
+ return parameters.VectorialParameterNodeAtInstant.build_from_node(self)[key]
return self._children[key]
def __iter__(self) -> Iterable:
@@ -70,9 +66,7 @@ def __iter__(self) -> Iterable:
def __repr__(self) -> str:
result = os.linesep.join(
[
- os.linesep.join(["{}:", "{}"]).format(
- name, tools.indent(repr(value))
- )
+ os.linesep.join(["{}:", "{}"]).format(name, tools.indent(repr(value)))
for name, value in self._children.items()
]
)
diff --git a/policyengine_core/parameters/parameter_scale.py b/policyengine_core/parameters/parameter_scale.py
index 39dcd2297..9a587ff52 100644
--- a/policyengine_core/parameters/parameter_scale.py
+++ b/policyengine_core/parameters/parameter_scale.py
@@ -85,9 +85,9 @@ def propagate_units(self) -> None:
child_key in bracket.children
and "unit" not in bracket.children[child_key].metadata
):
- bracket.children[child_key].metadata["unit"] = (
- self.metadata[unit_key]
- )
+ bracket.children[child_key].metadata["unit"] = self.metadata[
+ unit_key
+ ]
def propagate_uprating(self) -> None:
for bracket in self.brackets:
@@ -111,17 +111,12 @@ def clone(self) -> "ParameterScale":
return clone
def _get_at_instant(self, instant: Instant) -> TaxScaleLike:
- brackets = [
- bracket.get_at_instant(instant) for bracket in self.brackets
- ]
+ brackets = [bracket.get_at_instant(instant) for bracket in self.brackets]
if self.metadata.get("type") == "single_amount":
scale = SingleAmountTaxScale()
for bracket in brackets:
- if (
- "amount" in bracket._children
- and "threshold" in bracket._children
- ):
+ if "amount" in bracket._children and "threshold" in bracket._children:
amount = bracket.amount
threshold = bracket.threshold
scale.add_bracket(threshold, amount)
@@ -129,10 +124,7 @@ def _get_at_instant(self, instant: Instant) -> TaxScaleLike:
elif any("amount" in bracket._children for bracket in brackets):
scale = MarginalAmountTaxScale()
for bracket in brackets:
- if (
- "amount" in bracket._children
- and "threshold" in bracket._children
- ):
+ if "amount" in bracket._children and "threshold" in bracket._children:
amount = bracket.amount
threshold = bracket.threshold
scale.add_bracket(threshold, amount)
@@ -161,10 +153,7 @@ def _get_at_instant(self, instant: Instant) -> TaxScaleLike:
base = bracket.base
else:
base = 1.0
- if (
- "rate" in bracket._children
- and "threshold" in bracket._children
- ):
+ if "rate" in bracket._children and "threshold" in bracket._children:
rate = bracket.rate
threshold = bracket.threshold
scale.add_bracket(threshold, rate * base)
diff --git a/policyengine_core/parameters/parameter_scale_bracket.py b/policyengine_core/parameters/parameter_scale_bracket.py
index e8b200009..703b25a77 100644
--- a/policyengine_core/parameters/parameter_scale_bracket.py
+++ b/policyengine_core/parameters/parameter_scale_bracket.py
@@ -7,9 +7,7 @@ class ParameterScaleBracket(ParameterNode):
A parameter scale bracket.
"""
- _allowed_keys = set(
- ["amount", "threshold", "rate", "average_rate", "base"]
- )
+ _allowed_keys = set(["amount", "threshold", "rate", "average_rate", "base"])
@staticmethod
def allowed_unit_keys():
@@ -20,13 +18,11 @@ def get_descendants(self) -> Iterable[Parameter]:
if key in self.children:
yield self.children[key]
- def propagate_uprating(
- self, uprating: str, threshold: bool = False
- ) -> None:
+ def propagate_uprating(self, uprating: str, threshold: bool = False) -> None:
for key in self._allowed_keys:
if key in self.children:
if key == "threshold" and not threshold:
continue
- self.children[key].metadata["uprating"] = (
- uprating or self.children[key].metadata.get("uprating")
- )
+ self.children[key].metadata["uprating"] = uprating or self.children[
+ key
+ ].metadata.get("uprating")
diff --git a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py
index c5ce1367c..514b17141 100644
--- a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py
+++ b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py
@@ -31,9 +31,7 @@ def build_from_node(
VectorialParameterNodeAtInstant.build_from_node(
node[subnode_name]
).vector
- if isinstance(
- node[subnode_name], parameters.ParameterNodeAtInstant
- )
+ if isinstance(node[subnode_name], parameters.ParameterNodeAtInstant)
else node[subnode_name]
)
for subnode_name in subnodes_name
@@ -46,15 +44,9 @@ def build_from_node(
dtype=[
(
subnode_name,
- (
- subnode.dtype
- if isinstance(subnode, numpy.recarray)
- else "float"
- ),
- )
- for (subnode_name, subnode) in zip(
- subnodes_name, vectorial_subnodes
+ (subnode.dtype if isinstance(subnode, numpy.recarray) else "float"),
)
+ for (subnode_name, subnode) in zip(subnodes_name, vectorial_subnodes)
],
)
@@ -68,12 +60,12 @@ def check_node_vectorisable(node: "ParameterNode") -> None:
Check that a node can be casted to a vectorial node, in order to be able to use fancy indexing.
"""
MESSAGE_PART_1 = "Cannot use fancy indexing on parameter node '{}', as"
- MESSAGE_PART_3 = "To use fancy indexing on parameter node, its children must be homogenous."
+ MESSAGE_PART_3 = (
+ "To use fancy indexing on parameter node, its children must be homogenous."
+ )
MESSAGE_PART_4 = "See more at ."
- def raise_key_inhomogeneity_error(
- node_with_key, node_without_key, missing_key
- ):
+ def raise_key_inhomogeneity_error(node_with_key, node_without_key, missing_key):
message = " ".join(
[
MESSAGE_PART_1,
@@ -146,24 +138,16 @@ def check_nodes_homogeneous(named_nodes):
first_node_keys = first_node._children.keys()
node_keys = node._children.keys()
if not first_node_keys == node_keys:
- missing_keys = set(first_node_keys).difference(
- node_keys
- )
- if (
- missing_keys
- ): # If the first_node has a key that node hasn't
+ missing_keys = set(first_node_keys).difference(node_keys)
+ if missing_keys: # If the first_node has a key that node hasn't
raise_key_inhomogeneity_error(
first_name, name, missing_keys.pop()
)
else: # If If the node has a key that first_node doesn't have
missing_key = (
- set(node_keys)
- .difference(first_node_keys)
- .pop()
- )
- raise_key_inhomogeneity_error(
- name, first_name, missing_key
+ set(node_keys).difference(first_node_keys).pop()
)
+ raise_key_inhomogeneity_error(name, first_name, missing_key)
children.update(extract_named_children(node))
check_nodes_homogeneous(children)
elif isinstance(first_node, float) or isinstance(first_node, int):
@@ -232,9 +216,7 @@ def __getitem__(self, key: str) -> Any:
and values[0].dtype.names
):
# Check if all values have the same dtype
- dtypes_match = all(
- val.dtype == values[0].dtype for val in values
- )
+ dtypes_match = all(val.dtype == values[0].dtype for val in values)
if not dtypes_match:
# Find the union of all field names across all values, preserving first seen order
@@ -247,9 +229,7 @@ def __getitem__(self, key: str) -> Any:
seen.add(field)
# Create unified dtype with all fields
- unified_dtype = numpy.dtype(
- [(f, " Any:
casted[field] = val[field]
values_cast.append(casted)
- default = numpy.zeros(
- len(values_cast[0]), dtype=unified_dtype
- )
+ default = numpy.zeros(len(values_cast[0]), dtype=unified_dtype)
# Fill with NaN
for field in unified_dtype.names:
default[field] = numpy.nan
@@ -289,9 +267,9 @@ def __getitem__(self, key: str) -> Any:
)
# If the result is not a leaf, wrap the result in a vectorial node.
- if numpy.issubdtype(
- result.dtype, numpy.record
- ) or numpy.issubdtype(result.dtype, numpy.void):
+ if numpy.issubdtype(result.dtype, numpy.record) or numpy.issubdtype(
+ result.dtype, numpy.void
+ ):
return VectorialParameterNodeAtInstant(
self._name, result.view(numpy.recarray), self._instant_str
)
diff --git a/policyengine_core/periods/helpers.py b/policyengine_core/periods/helpers.py
index f8ee95b8e..cfd011ac1 100644
--- a/policyengine_core/periods/helpers.py
+++ b/policyengine_core/periods/helpers.py
@@ -92,9 +92,7 @@ def instant_date(instant):
return None
instant_date = config.date_by_instant_cache.get(instant)
if instant_date is None:
- config.date_by_instant_cache[instant] = instant_date = datetime.date(
- *instant
- )
+ config.date_by_instant_cache[instant] = instant_date = datetime.date(*instant)
return instant_date
@@ -215,9 +213,7 @@ def period(value):
return periods.Period((config.DAY, value, 1))
if value == "ETERNITY" or value == config.ETERNITY:
- return periods.Period(
- ("eternity", instant(datetime.date.min), float("inf"))
- )
+ return periods.Period(("eternity", instant(datetime.date.min), float("inf")))
# check the type
if isinstance(value, int):
diff --git a/policyengine_core/periods/instant_.py b/policyengine_core/periods/instant_.py
index 4288e3a4a..ee9dbe92e 100644
--- a/policyengine_core/periods/instant_.py
+++ b/policyengine_core/periods/instant_.py
@@ -17,9 +17,7 @@ def __repr__(self) -> str:
>>> repr(instant('2014-2-3'))
'Instant((2014, 2, 3))'
"""
- return "{}({})".format(
- self.__class__.__name__, super(Instant, self).__repr__()
- )
+ return "{}({})".format(self.__class__.__name__, super(Instant, self).__repr__())
def __str__(self) -> str:
"""
@@ -35,9 +33,7 @@ def __str__(self) -> str:
"""
instant_str = config.str_by_instant_cache.get(self)
if instant_str is None:
- config.str_by_instant_cache[self] = instant_str = (
- self.date.isoformat()
- )
+ config.str_by_instant_cache[self] = instant_str = self.date.isoformat()
return instant_str
@property
@@ -54,9 +50,7 @@ def date(self) -> datetime.date:
"""
instant_date = config.date_by_instant_cache.get(self)
if instant_date is None:
- config.date_by_instant_cache[self] = instant_date = datetime.date(
- *self
- )
+ config.date_by_instant_cache[self] = instant_date = datetime.date(*self)
return instant_date
@property
@@ -103,9 +97,9 @@ def period(self, unit: str, size: int = 1):
config.MONTH,
config.YEAR,
), "Invalid unit: {} of type {}".format(unit, type(unit))
- assert (
- isinstance(size, int) and size >= 1
- ), "Invalid size: {} of type {}".format(size, type(size))
+ assert isinstance(size, int) and size >= 1, (
+ "Invalid size: {} of type {}".format(size, type(size))
+ )
return periods.Period((unit, self, size))
def offset(self, offset: int, unit: str) -> "Instant":
@@ -208,9 +202,9 @@ def offset(self, offset: int, unit: str) -> "Instant":
month = 12
day = 31
else:
- assert isinstance(
- offset, int
- ), "Invalid offset: {} of type {}".format(offset, type(offset))
+ assert isinstance(offset, int), "Invalid offset: {} of type {}".format(
+ offset, type(offset)
+ )
if unit == config.DAY:
day += offset
if offset < 0:
diff --git a/policyengine_core/periods/period_.py b/policyengine_core/periods/period_.py
index e7656b37a..fd58f69c0 100644
--- a/policyengine_core/periods/period_.py
+++ b/policyengine_core/periods/period_.py
@@ -30,9 +30,7 @@ def __repr__(self) -> str:
>>> repr(period('day', '2014-2-3'))
"Period(('day', Instant((2014, 2, 3)), 1))"
"""
- return "{}({})".format(
- self.__class__.__name__, super(Period, self).__repr__()
- )
+ return "{}({})".format(self.__class__.__name__, super(Period, self).__repr__())
def __str__(self) -> str:
"""
@@ -67,12 +65,7 @@ def __str__(self) -> str:
year, month, day = start_instant
# 1 year long period
- if (
- unit == config.MONTH
- and size == 12
- or unit == config.YEAR
- and size == 1
- ):
+ if unit == config.MONTH and size == 12 or unit == config.YEAR and size == 1:
if month == 1:
# civil year starting from january
return str(year)
@@ -90,18 +83,16 @@ def __str__(self) -> str:
if size == 1:
return "{}-{:02d}-{:02d}".format(year, month, day)
else:
- return "{}:{}-{:02d}-{:02d}:{}".format(
- unit, year, month, day, size
- )
+ return "{}:{}-{:02d}-{:02d}:{}".format(unit, year, month, day, size)
# complex period
return "{}:{}-{:02d}:{}".format(unit, year, month, size)
@property
def date(self) -> datetime.date:
- assert (
- self.size == 1
- ), '"date" is undefined for a period of size > 1: {}'.format(self)
+ assert self.size == 1, (
+ '"date" is undefined for a period of size > 1: {}'.format(self)
+ )
return self.start.date
@property
@@ -145,10 +136,7 @@ def intersection(self, start: Instant, stop: Instant):
return None
intersection_start = max(period_start, start)
intersection_stop = min(period_stop, stop)
- if (
- intersection_start == period_start
- and intersection_stop == period_stop
- ):
+ if intersection_start == period_start and intersection_stop == period_stop:
return self
if (
intersection_start.day == 1
@@ -166,9 +154,7 @@ def intersection(self, start: Instant, stop: Instant):
if (
intersection_start.day == 1
and intersection_stop.day
- == calendar.monthrange(
- intersection_stop.year, intersection_stop.month
- )[1]
+ == calendar.monthrange(intersection_stop.year, intersection_stop.month)[1]
):
return self.__class__(
(
@@ -203,14 +189,10 @@ def get_subperiods(self, unit: str) -> List["Period"]:
>>> [period('2014'), period('2015')]
"""
if helpers.unit_weight(self.unit) < helpers.unit_weight(unit):
- raise ValueError(
- "Cannot subdivide {0} into {1}".format(self.unit, unit)
- )
+ raise ValueError("Cannot subdivide {0} into {1}".format(self.unit, unit))
if unit == config.YEAR:
- return [
- self.this_year.offset(i, config.YEAR) for i in range(self.size)
- ]
+ return [self.this_year.offset(i, config.YEAR) for i in range(self.size)]
if unit == config.MONTH:
return [
@@ -220,8 +202,7 @@ def get_subperiods(self, unit: str) -> List["Period"]:
if unit == config.DAY:
return [
- self.first_day.offset(i, config.DAY)
- for i in range(self.size_in_days)
+ self.first_day.offset(i, config.DAY) for i in range(self.size_in_days)
]
def offset(self, offset: int, unit: str = None) -> "Period":
@@ -397,9 +378,7 @@ def size_in_months(self) -> int:
return self[2]
if self[0] == config.YEAR:
return self[2] * 12
- raise ValueError(
- "Cannot calculate number of months in {0}".format(self[0])
- )
+ raise ValueError("Cannot calculate number of months in {0}".format(self[0]))
@property
def size_in_days(self) -> int:
@@ -466,9 +445,7 @@ def stop(self) -> periods.Instant:
# Use datetime arithmetic for efficient day calculation
start_date = date(year, month, day)
end_date = start_date + timedelta(days=size - 1)
- return periods.Instant(
- (end_date.year, end_date.month, end_date.day)
- )
+ return periods.Instant((end_date.year, end_date.month, end_date.day))
else:
if unit == "month":
month += size
diff --git a/policyengine_core/populations/group_population.py b/policyengine_core/populations/group_population.py
index fd6b1bee2..c770ee751 100644
--- a/policyengine_core/populations/group_population.py
+++ b/policyengine_core/populations/group_population.py
@@ -29,22 +29,17 @@ def __call__(
period: Period = None,
options: Optional[Container[str]] = None,
):
- variable = self.simulation.tax_benefit_system.variables.get(
- variable_name
- )
+ variable = self.simulation.tax_benefit_system.variables.get(variable_name)
if variable.entity.is_person:
return self.sum(self.members(variable_name, period, options))
else:
return super().__call__(variable_name, period, options)
- def clone(
- self, simulation: "Simulation", members: Population
- ) -> "GroupPopulation":
+ def clone(self, simulation: "Simulation", members: Population) -> "GroupPopulation":
result = GroupPopulation(self.entity, members)
result.simulation = simulation
result._holders = {
- variable: holder.clone(self)
- for (variable, holder) in self._holders.items()
+ variable: holder.clone(self) for (variable, holder) in self._holders.items()
}
result.count = self.count
result.ids = self.ids
@@ -56,10 +51,7 @@ def clone(
@property
def members_position(self) -> ArrayLike:
- if (
- self._members_position is None
- and self.members_entity_id is not None
- ):
+ if self._members_position is None and self.members_entity_id is not None:
# We could use self.count and self.members.count , but with the current initilization, we are not sure count will be set before members_position is called
nb_entities = numpy.max(self.members_entity_id) + 1
nb_persons = len(self.members_entity_id)
@@ -88,9 +80,7 @@ def members_entity_id(self, members_entity_id: ArrayLike) -> None:
def members_role(self) -> ArrayLike:
if self._members_role is None:
default_role = self.entity.flattened_roles[0]
- self._members_role = numpy.repeat(
- default_role, len(self.members_entity_id)
- )
+ self._members_role = numpy.repeat(default_role, len(self.members_entity_id))
return self._members_role
@members_role.setter
@@ -110,11 +100,7 @@ def ordered_members_map(self) -> ArrayLike:
def get_role(self, role_name: str) -> Role:
return next(
- (
- role
- for role in self.entity.flattened_roles
- if role.key == role_name
- ),
+ (role for role in self.entity.flattened_roles if role.key == role_name),
None,
)
@@ -188,9 +174,7 @@ def reduce(
biggest_entity_size = numpy.max(position_in_entity) + 1
for p in range(biggest_entity_size):
- values = self.value_nth_person(
- p, filtered_array, default=neutral_element
- )
+ values = self.value_nth_person(p, filtered_array, default=neutral_element)
result = reducer(result, values)
return result
@@ -313,9 +297,7 @@ def value_from_person(
return result
@projectors.projectable
- def value_nth_person(
- self, n: int, array: ArrayLike, default: Any = 0
- ) -> ArrayLike:
+ def value_nth_person(self, n: int, array: ArrayLike, default: Any = 0) -> ArrayLike:
"""
Get the value of array for the person whose position in the entity is n.
@@ -354,6 +336,4 @@ def project(self, array: ArrayLike, role: Role = None) -> ArrayLike:
return array[self.members_entity_id]
else:
role_condition = self.members.has_role(role)
- return numpy.where(
- role_condition, array[self.members_entity_id], 0
- )
+ return numpy.where(role_condition, array[self.members_entity_id], 0)
diff --git a/policyengine_core/populations/population.py b/policyengine_core/populations/population.py
index 988485ec0..fbce9dcbf 100644
--- a/policyengine_core/populations/population.py
+++ b/policyengine_core/populations/population.py
@@ -72,19 +72,19 @@ def check_array_compatible_with_entity(self, array: numpy.ndarray) -> None:
)
)
- def check_period_validity(
- self, variable_name: str, period: Period
- ) -> None:
+ def check_period_validity(self, variable_name: str, period: Period) -> None:
if period is None:
stack = traceback.extract_stack()
filename, line_number, function_name, line_of_code = stack[-3]
- raise ValueError("""
+ raise ValueError(
+ """
You requested computation of variable "{}", but you did not specify on which period in "{}:{}":
{}
When you request the computation of a variable within a formula, you must always specify the period as the second parameter. The convention is to call this parameter "period". For example:
computed_salary = person('salary', period).
See more information at .
-""".format(variable_name, filename, line_number, line_of_code))
+""".format(variable_name, filename, line_number, line_of_code)
+ )
def __call__(
self,
@@ -112,9 +112,7 @@ def __call__(
raise ValueError(
"Options config.ADD and config.DIVIDE are incompatible (trying to compute variable {})".format(
variable_name
- ).encode(
- "utf-8"
- )
+ ).encode("utf-8")
)
from policyengine_core.simulations.microsimulation import (
@@ -139,9 +137,7 @@ def __call__(
variable_name, period, **calculate_kwargs
)
else:
- return self.simulation.calculate(
- variable_name, period, **calculate_kwargs
- )
+ return self.simulation.calculate(variable_name, period, **calculate_kwargs)
# Helpers
@@ -166,9 +162,7 @@ def get_memory_usage(self, variables: List[str] = None):
for holder_memory_usage in holders_memory_usage.values()
)
- return dict(
- total_nb_bytes=total_memory_usage, by_variable=holders_memory_usage
- )
+ return dict(total_nb_bytes=total_memory_usage, by_variable=holders_memory_usage)
@projectors.projectable
def has_role(self, role: Role) -> ArrayLike:
@@ -184,10 +178,7 @@ def has_role(self, role: Role) -> ArrayLike:
group_population = self.simulation.get_population(role.entity.plural)
if role.subroles:
return numpy.logical_or.reduce(
- [
- group_population.members_role == subrole
- for subrole in role.subroles
- ]
+ [group_population.members_role == subrole for subrole in role.subroles]
)
else:
return group_population.members_role == role
@@ -235,9 +226,7 @@ def get_rank(
# If entity is for instance 'person.household', we get the reference entity 'household' behind the projector
entity = (
- entity
- if not isinstance(entity, Projector)
- else entity.reference_entity
+ entity if not isinstance(entity, Projector) else entity.reference_entity
)
positions = entity.members_position
@@ -248,9 +237,7 @@ def get_rank(
# Matrix: the value in line i and column j is the value of criteria for the jth person of the ith entity
matrix = numpy.asarray(
[
- entity.value_nth_person(
- k, filtered_criteria, default=numpy.inf
- )
+ entity.value_nth_person(k, filtered_criteria, default=numpy.inf)
for k in range(biggest_entity_size)
]
).transpose()
diff --git a/policyengine_core/projectors/helpers.py b/policyengine_core/projectors/helpers.py
index 62fe5b840..bb7d50715 100644
--- a/policyengine_core/projectors/helpers.py
+++ b/policyengine_core/projectors/helpers.py
@@ -27,9 +27,7 @@ def get_projector_from_shortcut(population, shortcut, parent=None):
None,
)
if role:
- return projectors.UniqueRoleToEntityProjector(
- population, role, parent
- )
+ return projectors.UniqueRoleToEntityProjector(population, role, parent)
if shortcut in population.entity.containing_entities:
return getattr(
projectors.FirstPersonToEntityProjector(population, parent),
diff --git a/policyengine_core/reforms/reform.py b/policyengine_core/reforms/reform.py
index 340940fc2..068cac240 100644
--- a/policyengine_core/reforms/reform.py
+++ b/policyengine_core/reforms/reform.py
@@ -72,9 +72,7 @@ def __init__(self, baseline: TaxBenefitSystem):
super().__init__(baseline.entities)
self.baseline = baseline
self.parameters = baseline.parameters
- self._parameters_at_instant_cache = (
- baseline._parameters_at_instant_cache
- )
+ self._parameters_at_instant_cache = baseline._parameters_at_instant_cache
self.variables = baseline.variables.copy()
self.decomposition_file_path = baseline.decomposition_file_path
self.key = self.__class__.__name__
@@ -90,9 +88,9 @@ def __getattr__(self, attribute):
@property
def full_key(self) -> str:
key = self.key
- assert (
- key is not None
- ), "key was not set for reform {} (name: {!r})".format(self, self.name)
+ assert key is not None, "key was not set for reform {} (name: {!r})".format(
+ self, self.name
+ )
if self.baseline is not None and hasattr(self.baseline, "key"):
baseline_full_key = self.baseline.full_key
key = ".".join([baseline_full_key, key])
@@ -141,16 +139,12 @@ def apply(self):
for path, period_values in parameter_values.items():
parameter = self.parameters.get_child(path)
if not isinstance(period_values, dict):
- parameter.update(
- period="year:2000:100", value=period_values
- )
+ parameter.update(period="year:2000:100", value=period_values)
else:
for period, value in period_values.items():
try:
period = period_(period)
- parameter = parameter.update(
- period=period, value=value
- )
+ parameter = parameter.update(period=period, value=value)
except:
if "." in period:
start, stop = period.split(".")
@@ -229,9 +223,7 @@ def api_id(self):
sanitised_period_values = {}
for period, value in period_values.items():
period = period_(period)
- sanitised_period_values[f"{period.start}.{period.stop}"] = (
- value
- )
+ sanitised_period_values[f"{period.start}.{period.stop}"] = value
sanitised_parameter_values[path] = sanitised_period_values
response = requests.post(
diff --git a/policyengine_core/scripts/policyengine_command.py b/policyengine_core/scripts/policyengine_command.py
index 2319ce951..7aa85c888 100644
--- a/policyengine_core/scripts/policyengine_command.py
+++ b/policyengine_core/scripts/policyengine_command.py
@@ -12,9 +12,7 @@
def get_parser():
parser = argparse.ArgumentParser()
- subparsers = parser.add_subparsers(
- help="Available commands", dest="command"
- )
+ subparsers = parser.add_subparsers(help="Available commands", dest="command")
subparsers.required = (
True # Can be added as an argument of add_subparsers in Python 3
)
@@ -107,9 +105,7 @@ def build_data_parser(parser):
return parser
- parser_test = subparsers.add_parser(
- "test", help="Run OpenFisca YAML tests"
- )
+ parser_test = subparsers.add_parser("test", help="Run OpenFisca YAML tests")
parser_test = build_test_parser(parser_test)
parser_data = subparsers.add_parser("data", help="Manage OpenFisca data")
diff --git a/policyengine_core/scripts/run_data.py b/policyengine_core/scripts/run_data.py
index 53c465370..6cb914f03 100644
--- a/policyengine_core/scripts/run_data.py
+++ b/policyengine_core/scripts/run_data.py
@@ -53,8 +53,6 @@ def main(parser: ArgumentParser):
print("Saved datasets:")
for year in years:
filepath = dataset.file(year).absolute()
- print(
- " * " + filepath.name + " | " + str(filepath.absolute())
- )
+ print(" * " + filepath.name + " | " + str(filepath.absolute()))
else:
raise ValueError(f"Action {args.action} not recognised.")
diff --git a/policyengine_core/simulations/individual_sim.py b/policyengine_core/simulations/individual_sim.py
index 3a82a8109..c00933c87 100644
--- a/policyengine_core/simulations/individual_sim.py
+++ b/policyengine_core/simulations/individual_sim.py
@@ -41,38 +41,28 @@ def __init__(
self.sim_builder = SimulationBuilder()
self.parametric_vary = False
self.entities = {var.key: var for var in self.system.entities}
- self.situation_data = {
- entity.plural: {} for entity in self.system.entities
- }
+ self.situation_data = {entity.plural: {} for entity in self.system.entities}
self.varying = False
self.num_points = None
self.group_entity_names = [
- entity.key
- for entity in self.system.entities
- if not entity.is_person
+ entity.key for entity in self.system.entities if not entity.is_person
]
# Add add_entity functions
for entity in self.entities:
- setattr(
- self, f"add_{entity}", partial(self.add_data, entity=entity)
- )
+ setattr(self, f"add_{entity}", partial(self.add_data, entity=entity))
def build(self):
if self.required_entities is not None:
# Check for missing entities
entities = {entity.key: entity for entity in self.system.entities}
- person_entity = list(
- filter(lambda x: x.is_person, entities.values())
- )[0]
+ person_entity = list(filter(lambda x: x.is_person, entities.values()))[0]
for entity in self.required_entities:
entity_metadata = entities[entity]
roles = {role.key: role for role in entities[entity].roles}
default_role = roles[self.default_roles[entity]]
- no_entity_plural = (
- entity_metadata.plural not in self.situation_data
- )
+ no_entity_plural = entity_metadata.plural not in self.situation_data
if (
no_entity_plural
or len(self.situation_data[entity_metadata.plural]) == 0
@@ -113,25 +103,17 @@ def add_data(
input_period = input_period or self.year
entity_plural = self.entities[entity].plural
if name is None:
- name = (
- entity + "_" + str(len(self.situation_data[entity_plural]) + 1)
- )
+ name = entity + "_" + str(len(self.situation_data[entity_plural]) + 1)
if auto_period:
data = {}
for var, value in kwargs.items():
try:
- def_period = self.system.get_variable(
- var
- ).definition_period
+ def_period = self.system.get_variable(var).definition_period
if def_period in ["eternity", "year"]:
input_periods = [input_period]
else:
- input_periods = period(input_period).get_subperiods(
- def_period
- )
- data[var] = {
- str(subperiod): value for subperiod in input_periods
- }
+ input_periods = period(input_period).get_subperiods(def_period)
+ data[var] = {str(subperiod): value for subperiod in input_periods}
except:
data[var] = value
self.situation_data[entity_plural][name] = data
@@ -152,9 +134,7 @@ def get_entity(self, name: str) -> Entity:
][0]
return entity_type
- def map_to(
- self, arr: np.array, entity: str, target_entity: str, how: str = None
- ):
+ def map_to(self, arr: np.array, entity: str, target_entity: str, how: str = None):
"""Maps values from one entity to another.
Args:
@@ -265,10 +245,7 @@ def calc(
result = self.sim.calculate_add(var, period)
except:
result = self.simulation.calculate_divide(var, period)
- if (
- target is not None
- and target not in self.situation_data[entity.plural]
- ):
+ if target is not None and target not in self.situation_data[entity.plural]:
map_to = self.get_entity(target).key
if map_to is not None:
result = self.map_to(result, entity.key, map_to)
@@ -319,9 +296,9 @@ def deriv(
pass
x = x.astype(np.float32)
y = y.astype(np.float32)
- assert (
- len(y) > 1 and len(x) > 1
- ), "Simulation must vary on an axis to calculate derivatives."
+ assert len(y) > 1 and len(x) > 1, (
+ "Simulation must vary on an axis to calculate derivatives."
+ )
deriv = (y[1:] - y[:-1]) / (x[1:] - x[:-1])
deriv = np.append(deriv, deriv[-1])
return deriv
diff --git a/policyengine_core/simulations/microsimulation.py b/policyengine_core/simulations/microsimulation.py
index 2ef2334cf..28410d044 100644
--- a/policyengine_core/simulations/microsimulation.py
+++ b/policyengine_core/simulations/microsimulation.py
@@ -20,9 +20,7 @@ def get_weights(
variable = self.tax_benefit_system.get_variable(variable_name)
entity_key = map_to or variable.entity.key
weight_variable_name = f"{entity_key}_weight"
- weight_variable = self.tax_benefit_system.get_variable(
- weight_variable_name
- )
+ weight_variable = self.tax_benefit_system.get_variable(weight_variable_name)
weights = None
if time_period.unit == weight_variable.definition_period:
diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py
index fea8216ae..8f79c3b70 100644
--- a/policyengine_core/simulations/simulation.py
+++ b/policyengine_core/simulations/simulation.py
@@ -96,22 +96,15 @@ def __init__(
default_input_period: str = None,
default_calculation_period: str = None,
):
- self.default_input_period = (
- default_input_period or self.default_input_period
- )
+ self.default_input_period = default_input_period or self.default_input_period
self.default_calculation_period = (
default_calculation_period or self.default_calculation_period
)
if tax_benefit_system is None:
- if (
- self.default_tax_benefit_system_instance is not None
- and reform is None
- ):
+ if self.default_tax_benefit_system_instance is not None and reform is None:
tax_benefit_system = self.default_tax_benefit_system_instance
else:
- tax_benefit_system = self.default_tax_benefit_system(
- reform=reform
- )
+ tax_benefit_system = self.default_tax_benefit_system(reform=reform)
self.tax_benefit_system = tax_benefit_system
self.reform = reform
@@ -127,9 +120,7 @@ def __init__(
self.invalidated_caches = set()
self.debug: bool = False
self.trace: bool = trace
- self.tracer: SimpleTracer = (
- SimpleTracer() if not trace else FullTracer()
- )
+ self.tracer: SimpleTracer = SimpleTracer() if not trace else FullTracer()
self.opt_out_cache: bool = False
# controls the spirals detection; check for performance impact if > 1
self.max_spiral_loops: int = 10
@@ -146,9 +137,7 @@ def __init__(
raise ValueError(
"You provided both a situation and a dataset. Only one input method is allowed."
)
- self.build_from_populations(
- self.tax_benefit_system.instantiate_entities()
- )
+ self.build_from_populations(self.tax_benefit_system.instantiate_entities())
from policyengine_core.simulations.simulation_builder import (
SimulationBuilder,
) # Import here to avoid circular dependency
@@ -177,15 +166,11 @@ def __init__(
file_path=file_path,
version=version,
)
- datasets_by_name = {
- dataset.name: dataset for dataset in self.datasets
- }
+ datasets_by_name = {dataset.name: dataset for dataset in self.datasets}
if dataset in datasets_by_name:
dataset = datasets_by_name.get(dataset)
elif Path(dataset).exists():
- dataset = Dataset.from_file(
- dataset, self.default_input_period
- )
+ dataset = Dataset.from_file(dataset, self.default_input_period)
if isinstance(dataset, type):
self.dataset: Dataset = dataset(require=True)
elif isinstance(dataset, pd.DataFrame):
@@ -225,9 +210,7 @@ def __init__(
self.baseline = self.get_branch("baseline")
self.baseline.trace = self.trace
self.baseline.tracer = self.tracer
- self.baseline.tax_benefit_system = (
- self.default_tax_benefit_system_instance
- )
+ self.baseline.tax_benefit_system = self.default_tax_benefit_system_instance
else:
self.baseline = None
@@ -242,9 +225,7 @@ def apply_reform(self, reform: Union[tuple, Reform]):
reform = Reform.from_dict(reform)
reform.apply(self.tax_benefit_system)
- def build_from_populations(
- self, populations: Dict[str, Population]
- ) -> None:
+ def build_from_populations(self, populations: Dict[str, Population]) -> None:
"""This method of initialisation requires the populations to be pre-initialised.
Args:
@@ -263,9 +244,7 @@ def build_from_populations(
def build_from_dataset(self) -> None:
"""Build a simulation from a dataset."""
- self.build_from_populations(
- self.tax_benefit_system.instantiate_entities()
- )
+ self.build_from_populations(self.tax_benefit_system.instantiate_entities())
from policyengine_core.simulations.simulation_builder import (
SimulationBuilder,
) # Import here to avoid circular dependency
@@ -299,13 +278,11 @@ def get_eternity_array(name):
return data[name]
if self.dataset.data_format != Dataset.FLAT_FILE:
- assert (
- entity_id_field in data
- ), f"Missing {entity_id_field} column in the dataset. Each person entity must have an ID array defined for ETERNITY."
- elif entity_id_field not in data:
- data[entity_id_field] = np.arange(
- len(get_eternity_array("person_id"))
+ assert entity_id_field in data, (
+ f"Missing {entity_id_field} column in the dataset. Each person entity must have an ID array defined for ETERNITY."
)
+ elif entity_id_field not in data:
+ data[entity_id_field] = np.arange(len(get_eternity_array("person_id")))
entity_ids = get_eternity_array(entity_id_field)
builder.declare_person_entity(person_entity.key, entity_ids)
@@ -313,35 +290,29 @@ def get_eternity_array(name):
for group_entity in self.tax_benefit_system.group_entities:
entity_id_field = f"{group_entity.key}_id"
if self.dataset.data_format != Dataset.FLAT_FILE:
- assert (
- entity_id_field in data
- ), f"Missing {entity_id_field} column in the dataset. Each group entity must have an ID array defined for ETERNITY."
+ assert entity_id_field in data, (
+ f"Missing {entity_id_field} column in the dataset. Each group entity must have an ID array defined for ETERNITY."
+ )
entity_ids = get_eternity_array(entity_id_field)
elif entity_id_field not in data:
entity_id_field_values = get_eternity_array(
f"person_{group_entity.key}_id"
)
if entity_id_field_values is not None:
- entity_ids = np.arange(
- len(np.unique(entity_id_field_values))
- )
+ entity_ids = np.arange(len(np.unique(entity_id_field_values)))
else:
entity_ids = np.arange(len(data[list(data.keys())[0]]))
builder.declare_entity(group_entity.key, entity_ids)
- person_membership_id_field = (
- f"{person_entity.key}_{group_entity.key}_id"
- )
+ person_membership_id_field = f"{person_entity.key}_{group_entity.key}_id"
if self.dataset.data_format != Dataset.FLAT_FILE:
- assert (
- person_membership_id_field in data
- ), f"Missing {person_membership_id_field} column in the dataset. Each group entity must have a person membership array defined for ETERNITY."
+ assert person_membership_id_field in data, (
+ f"Missing {person_membership_id_field} column in the dataset. Each group entity must have a person membership array defined for ETERNITY."
+ )
elif person_membership_id_field not in data:
data[person_membership_id_field] = np.arange(len(data))
- person_membership_ids = get_eternity_array(
- person_membership_id_field
- )
+ person_membership_ids = get_eternity_array(person_membership_id_field)
person_role_field = f"{person_entity.key}_{group_entity.key}_role"
if person_role_field in data:
@@ -389,16 +360,12 @@ def get_eternity_array(name):
variable_name, time_period = variable.split("__")
else:
variable_name = variable
- time_period = (
- self.dataset.time_period or self.default_input_period
- )
+ time_period = self.dataset.time_period or self.default_input_period
if variable_name not in self.tax_benefit_system.variables:
continue
- variable_meta = self.tax_benefit_system.get_variable(
- variable_name
- )
+ variable_meta = self.tax_benefit_system.get_variable(variable_name)
entity = variable_meta.entity
population = self.get_population(entity.plural)
@@ -481,9 +448,7 @@ def calculate(
elif period is None and self.default_calculation_period is not None:
period = periods.period(self.default_calculation_period)
- self.tracer.record_calculation_start(
- variable_name, period, self.branch_name
- )
+ self.tracer.record_calculation_start(variable_name, period, self.branch_name)
np.random.seed(hash(variable_name + str(period)) % 1000000)
@@ -599,9 +564,7 @@ def calculate_dataframe(
df[variable_name] = self.calculate(variable_name, period, map_to)
return df
- def _calculate(
- self, variable_name: str, period: Period = None
- ) -> ArrayLike:
+ def _calculate(self, variable_name: str, period: Period = None) -> ArrayLike:
"""
Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists.
@@ -651,10 +614,7 @@ def _calculate(
)
cache_path = smc.get_cache_path()
if cache_path.exists():
- if (
- not self.macro_cache_read
- or self.tax_benefit_system.data_modified
- ):
+ if not self.macro_cache_read or self.tax_benefit_system.data_modified:
value = None
else:
value = smc.get_cache_value(cache_path)
@@ -663,18 +623,14 @@ def _calculate(
return value
if variable.requires_computation_after is not None:
- variables_in_stack = [
- node.get("name") for node in self.tracer.stack
- ]
+ variables_in_stack = [node.get("name") for node in self.tracer.stack]
variable_in_stack = (
variable.requires_computation_after in variables_in_stack
)
required_is_known_periods = self.get_holder(
variable.requires_computation_after
).get_known_periods()
- if (not variable_in_stack) and (
- not len(required_is_known_periods) > 0
- ):
+ if (not variable_in_stack) and (not len(required_is_known_periods) > 0):
raise ValueError(
f"Variable {variable_name} requires {variable.requires_computation_after} to be requested first. That variable is known in: {required_is_known_periods}. The full stack is: {variables_in_stack}. {variable_in_stack, len(required_is_known_periods) > 0}"
)
@@ -702,9 +658,7 @@ def _calculate(
if variable.defined_for is not None:
mask = (
- self.calculate(
- variable.defined_for, period, map_to=variable.entity.key
- )
+ self.calculate(variable.defined_for, period, map_to=variable.entity.key)
> 0
)
if np.all(~mask):
@@ -731,9 +685,7 @@ def _calculate(
and known_period.start < period.start
]
if variable.uprating is not None and len(start_instants) > 0:
- latest_known_period = known_periods[
- np.argmax(start_instants)
- ]
+ latest_known_period = known_periods[np.argmax(start_instants)]
try:
uprating_parameter = get_parameter(
self.tax_benefit_system.parameters,
@@ -743,16 +695,12 @@ def _calculate(
raise ValueError(
f"Could not find uprating parameter {variable.uprating} when trying to uprate {variable_name}."
)
- value_in_last_period = uprating_parameter(
- latest_known_period.start
- )
+ value_in_last_period = uprating_parameter(latest_known_period.start)
value_in_this_period = uprating_parameter(period.start)
if value_in_last_period == 0:
uprating_factor = 1
else:
- uprating_factor = (
- value_in_this_period / value_in_last_period
- )
+ uprating_factor = value_in_this_period / value_in_last_period
array = (
holder.get_array(latest_known_period, self.branch_name)
@@ -829,9 +777,9 @@ def calculate_add(
period = periods.period(period)
# Check that the requested period matches definition_period
- if periods.unit_weight(
- variable.definition_period
- ) > periods.unit_weight(period.unit):
+ if periods.unit_weight(variable.definition_period) > periods.unit_weight(
+ period.unit
+ ):
raise ValueError(
"Unable to compute variable '{0}' for period {1}: '{0}' can only be computed for {2}-long periods. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'.".format(
variable.name, period, variable.definition_period
@@ -885,9 +833,7 @@ def calculate_divide(
if period.unit == periods.MONTH:
computation_period = period.this_year
- result = (
- self.calculate(variable_name, period=computation_period) / 12.0
- )
+ result = self.calculate(variable_name, period=computation_period) / 12.0
holder = self.get_holder(variable.name)
holder.put_in_cache(result, period, self.branch_name)
return result
@@ -900,9 +846,7 @@ def calculate_divide(
)
)
- def calculate_output(
- self, variable_name: str, period: Period = None
- ) -> ArrayLike:
+ def calculate_output(self, variable_name: str, period: Period = None) -> ArrayLike:
"""
Calculate the value of a variable using the ``calculate_output`` attribute of the variable.
"""
@@ -974,10 +918,7 @@ def _run_formula(
if values is None:
values = 0
for subtracted_variable in subtracts_list:
- if (
- subtracted_variable
- in self.tax_benefit_system.variables
- ):
+ if subtracted_variable in self.tax_benefit_system.variables:
values = values - self.calculate(
subtracted_variable,
period,
@@ -1012,29 +953,21 @@ def _run_formula(
return array
- def _check_period_consistency(
- self, period: Period, variable: Variable
- ) -> None:
+ def _check_period_consistency(self, period: Period, variable: Variable) -> None:
"""
Check that a period matches the variable definition_period
"""
if variable.definition_period == periods.ETERNITY:
return # For variables which values are constant in time, all periods are accepted
- if (
- variable.definition_period == periods.MONTH
- and period.unit != periods.MONTH
- ):
+ if variable.definition_period == periods.MONTH and period.unit != periods.MONTH:
raise ValueError(
"Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole month. You can use the ADD option to sum '{0}' over the requested period, or change the requested period to 'period.first_month'.".format(
variable.name, period
)
)
- if (
- variable.definition_period == periods.YEAR
- and period.unit != periods.YEAR
- ):
+ if variable.definition_period == periods.YEAR and period.unit != periods.YEAR:
raise ValueError(
"Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole year. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'.".format(
variable.name, period
@@ -1080,8 +1013,7 @@ def _check_for_cycle(self, variable: str, period: Period) -> None:
previous_periods = [
frame["period"]
for frame in self.tracer.stack[:-1]
- if frame["name"] == variable
- and frame["branch_name"] == self.branch_name
+ if frame["name"] == variable and frame["branch_name"] == self.branch_name
]
if period in previous_periods:
found_last_frame = False
@@ -1140,17 +1072,13 @@ def get_array(self, variable_name: str, period: Period) -> ArrayLike:
"""
if period is not None and not isinstance(period, Period):
period = periods.period(period)
- return self.get_holder(variable_name).get_array(
- period, self.branch_name
- )
+ return self.get_holder(variable_name).get_array(period, self.branch_name)
def get_holder(self, variable_name: str) -> Holder:
"""
Get the :obj:`.Holder` associated with the variable ``variable_name`` for the simulation
"""
- return self.get_variable_population(variable_name).get_holder(
- variable_name
- )
+ return self.get_variable_population(variable_name).get_holder(variable_name)
def get_memory_usage(self, variables: List[str] = None) -> dict:
"""
@@ -1211,9 +1139,7 @@ def get_known_periods(self, variable: str) -> List[Period]:
"""
return self.get_holder(variable).get_known_periods()
- def set_input(
- self, variable_name: str, period: Period, value: ArrayLike
- ) -> None:
+ def set_input(self, variable_name: str, period: Period, value: ArrayLike) -> None:
"""
Set a variable's value for a given period
@@ -1238,9 +1164,7 @@ def set_input(
)
if (variable.end is not None) and (period.start.date > variable.end):
return
- self.get_holder(variable_name).set_input(
- period, value, self.branch_name
- )
+ self.get_holder(variable_name).set_input(period, value, self.branch_name)
def get_variable_population(self, variable_name: str) -> Population:
variable = self.tax_benefit_system.get_variable(
@@ -1391,10 +1315,7 @@ def extract_person(
for population in self.populations.values():
entity = population.entity
- if (
- not population.entity.is_person
- and entity.key not in exclude_entities
- ):
+ if not population.entity.is_person and entity.key not in exclude_entities:
situation[entity.plural] = {
entity.key: {
"members": [],
@@ -1412,19 +1333,15 @@ def extract_person(
people_indices_by_entity[entity.key] = other_people_indices
for variable in self.input_variables:
if (
- self.tax_benefit_system.get_variable(
- variable
- ).entity.key
+ self.tax_benefit_system.get_variable(variable).entity.key
== entity.key
):
- known_periods = self.get_holder(
- variable
- ).get_known_periods()
+ known_periods = self.get_holder(variable).get_known_periods()
if len(known_periods) > 0:
first_known_period = known_periods[0]
- value = self.calculate(
- variable, first_known_period
- )[group_index]
+ value = self.calculate(variable, first_known_period)[
+ group_index
+ ]
situation[entity.plural][entity.key][variable] = {
str(known_periods[0]): value
}
@@ -1436,18 +1353,14 @@ def extract_person(
for entity_key in people_indices_by_entity:
entity = self.populations[entity_key].entity
if person_index in people_indices_by_entity[entity.key]:
- situation[entity.plural][entity.key]["members"].append(
- person_name
- )
+ situation[entity.plural][entity.key]["members"].append(person_name)
situation[person.plural][person_name] = {}
for variable in self.input_variables:
if (
self.tax_benefit_system.get_variable(variable).entity.key
== person.key
):
- known_periods = self.get_holder(
- variable
- ).get_known_periods()
+ known_periods = self.get_holder(variable).get_known_periods()
if len(known_periods) > 0:
first_known_period = known_periods[0]
value = self.calculate(variable, first_known_period)[
@@ -1488,9 +1401,7 @@ def check_macro_cache(self, variable_name: str, period: str) -> bool:
return False
for parameter in parameter_deps:
- param = get_parameter(
- self.tax_benefit_system.parameters, parameter
- )
+ param = get_parameter(self.tax_benefit_system.parameters, parameter)
if param.modified:
return False
@@ -1624,9 +1535,7 @@ def subsample(
household_id_to_count[household_id] = 0
household_id_to_count[household_id] += 1
- subset_df = df[
- df[df_household_id_column].isin(chosen_household_ids)
- ].copy()
+ subset_df = df[df[df_household_id_column].isin(chosen_household_ids)].copy()
household_counts = subset_df[df_household_id_column].map(
lambda x: household_id_to_count.get(x, 0)
@@ -1641,9 +1550,7 @@ def subsample(
subset_df[col] *= household_counts.values
else:
subset_df[col] = household_counts.values
- subset_df[col] *= (
- target_total_weight / subset_df[col].values.sum()
- )
+ subset_df[col] *= target_total_weight / subset_df[col].values.sum()
df = subset_df
@@ -1653,9 +1560,7 @@ def subsample(
# Ensure the baseline branch has the new data.
if "baseline" in self.branches:
- baseline_tax_benefit_system = self.branches[
- "baseline"
- ].tax_benefit_system
+ baseline_tax_benefit_system = self.branches["baseline"].tax_benefit_system
self.branches["baseline"] = self.clone()
self.branches["tax_benefit_system"] = baseline_tax_benefit_system
diff --git a/policyengine_core/simulations/simulation_builder.py b/policyengine_core/simulations/simulation_builder.py
index f7a17e017..628a478c4 100644
--- a/policyengine_core/simulations/simulation_builder.py
+++ b/policyengine_core/simulations/simulation_builder.py
@@ -60,9 +60,7 @@ def __init__(self):
self.has_axes = False
self.axes_entity_counts: typing.Dict[Entity.plural, int] = {}
self.axes_entity_ids: typing.Dict[Entity.plural, typing.List[int]] = {}
- self.axes_memberships: typing.Dict[Entity.plural, typing.List[int]] = (
- {}
- )
+ self.axes_memberships: typing.Dict[Entity.plural, typing.List[int]] = {}
self.axes_roles: typing.Dict[Entity.plural, typing.List[int]] = {}
def build_from_dict(
@@ -77,12 +75,9 @@ def build_from_dict(
This method uses :any:`build_from_entities` if entities are fully specified, or :any:`build_from_variables` if not.
"""
- input_dict = self.explicit_singular_entities(
- tax_benefit_system, input_dict
- )
+ input_dict = self.explicit_singular_entities(tax_benefit_system, input_dict)
if any(
- key in tax_benefit_system.entities_plural()
- for key in input_dict.keys()
+ key in tax_benefit_system.entities_plural() for key in input_dict.keys()
):
simulation = self.build_from_entities(
tax_benefit_system, input_dict, simulation
@@ -95,8 +90,7 @@ def build_from_dict(
simulation.input_variables = [
variable.name
for variable in simulation.tax_benefit_system.variables.values()
- if len(simulation.get_holder(variable.name).get_known_periods())
- > 0
+ if len(simulation.get_holder(variable.name).get_known_periods()) > 0
]
return simulation
@@ -155,9 +149,7 @@ def build_from_entities(
", ".join(tax_benefit_system.entities_plural()),
),
)
- persons_json = input_dict.get(
- tax_benefit_system.person_entity.plural, None
- )
+ persons_json = input_dict.get(tax_benefit_system.person_entity.plural, None)
if not persons_json:
raise SituationParsingError(
@@ -167,9 +159,7 @@ def build_from_entities(
),
)
- persons_ids = self.add_person_entity(
- simulation.persons.entity, persons_json
- )
+ persons_ids = self.add_person_entity(simulation.persons.entity, persons_json)
for entity_class in tax_benefit_system.group_entities:
instances_json = input_dict.get(entity_class.plural)
@@ -190,18 +180,14 @@ def build_from_entities(
try:
self.finalize_variables_init(simulation.persons)
except PeriodMismatchError as e:
- self.raise_period_mismatch(
- simulation.persons.entity, persons_json, e
- )
+ self.raise_period_mismatch(simulation.persons.entity, persons_json, e)
for entity_class in tax_benefit_system.group_entities:
try:
population = simulation.populations[entity_class.key]
self.finalize_variables_init(population)
except PeriodMismatchError as e:
- self.raise_period_mismatch(
- population.entity, instances_json, e
- )
+ self.raise_period_mismatch(population.entity, instances_json, e)
return simulation
@@ -296,9 +282,9 @@ def join_with_persons(
roles: typing.Iterable[str],
) -> None:
# Maps group's identifiers to a 0-based integer range, for indexing into members_roles (see PR#876)
- group_sorted_indices = np.unique(
- persons_group_assignment, return_inverse=True
- )[1]
+ group_sorted_indices = np.unique(persons_group_assignment, return_inverse=True)[
+ 1
+ ]
group_population.members_entity_id = np.argsort(group_population.ids)[
group_sorted_indices
]
@@ -306,9 +292,7 @@ def join_with_persons(
flattened_roles = group_population.entity.flattened_roles
roles_array = np.array(roles)
if np.issubdtype(roles_array.dtype, np.integer):
- group_population.members_role = np.array(flattened_roles)[
- roles_array
- ]
+ group_population.members_role = np.array(flattened_roles)[roles_array]
else:
if len(flattened_roles) == 0:
group_population.members_role = np.int64(0)
@@ -360,9 +344,7 @@ def explicit_singular_entities(
return result
- def add_person_entity(
- self, entity: Entity, instances_json: dict
- ) -> List[int]:
+ def add_person_entity(self, entity: Entity, instances_json: dict) -> List[int]:
"""
Add the simulation's instances of the persons entity as described in ``instances_json``.
"""
@@ -374,24 +356,16 @@ def add_person_entity(
for instance_id, instance_object in instances_json.items():
check_type(instance_object, dict, [entity.plural, instance_id])
- self.init_variable_values(
- entity, instance_object, str(instance_id)
- )
+ self.init_variable_values(entity, instance_object, str(instance_id))
return self.get_ids(entity.plural)
- def add_default_group_entity(
- self, persons_ids: ArrayLike, entity: Entity
- ) -> None:
+ def add_default_group_entity(self, persons_ids: ArrayLike, entity: Entity) -> None:
persons_count = len(persons_ids)
self.entity_ids[entity.plural] = [entity.key]
self.entity_counts[entity.plural] = 1
- self.memberships[entity.plural] = np.zeros(
- persons_count, dtype=np.int32
- )
- self.roles[entity.plural] = np.repeat(
- entity.flattened_roles[0], persons_count
- )
+ self.memberships[entity.plural] = np.zeros(persons_count, dtype=np.int32)
+ self.roles[entity.plural] = np.repeat(entity.flattened_roles[0], persons_count)
def add_group_entity(
self,
@@ -411,9 +385,7 @@ def add_group_entity(
persons_count = len(persons_ids)
persons_to_allocate = set(persons_ids)
- self.memberships[entity.plural] = np.empty(
- persons_count, dtype=np.int32
- )
+ self.memberships[entity.plural] = np.empty(persons_count, dtype=np.int32)
self.roles[entity.plural] = np.empty(persons_count, dtype=object)
self.entity_ids[entity.plural] = entity_ids
@@ -422,13 +394,10 @@ def add_group_entity(
for instance_id, instance_object in instances_json.items():
check_type(instance_object, dict, [entity.plural, instance_id])
- variables_json = (
- instance_object.copy()
- ) # Don't mutate function input
+ variables_json = instance_object.copy() # Don't mutate function input
roles_json = {
- role.plural
- or role.key: transform_to_strict_syntax(
+ role.plural or role.key: transform_to_strict_syntax(
variables_json.pop(role.plural or role.key, [])
)
for role in entity.roles
@@ -456,9 +425,7 @@ def add_group_entity(
persons_to_allocate.discard(person_id)
entity_index = entity_ids.index(instance_id)
- role_by_plural = {
- role.plural or role.key: role for role in entity.roles
- }
+ role_by_plural = {role.plural or role.key: role for role in entity.roles}
for role_plural, persons_with_role in roles_json.items():
role = role_by_plural[role_plural]
@@ -469,17 +436,11 @@ def add_group_entity(
f"There can be at most {role.max} {role_plural} in a {entity.key}. {len(persons_with_role)} were declared in '{instance_id}'.",
)
- for index_within_role, person_id in enumerate(
- persons_with_role
- ):
+ for index_within_role, person_id in enumerate(persons_with_role):
person_index = persons_ids.index(person_id)
- self.memberships[entity.plural][
- person_index
- ] = entity_index
+ self.memberships[entity.plural][person_index] = entity_index
person_role = (
- role.subroles[index_within_role]
- if role.subroles
- else role
+ role.subroles[index_within_role] if role.subroles else role
)
self.roles[entity.plural][person_index] = person_role
@@ -489,21 +450,17 @@ def add_group_entity(
entity_ids = entity_ids + list(persons_to_allocate)
for person_id in persons_to_allocate:
person_index = persons_ids.index(person_id)
- self.memberships[entity.plural][person_index] = (
- entity_ids.index(person_id)
- )
- self.roles[entity.plural][person_index] = (
- entity.flattened_roles[0]
+ self.memberships[entity.plural][person_index] = entity_ids.index(
+ person_id
)
+ self.roles[entity.plural][person_index] = entity.flattened_roles[0]
# Adjust previously computed ids and counts
self.entity_ids[entity.plural] = entity_ids
self.entity_counts[entity.plural] = len(entity_ids)
# Convert back to Python array
self.roles[entity.plural] = self.roles[entity.plural].tolist()
- self.memberships[entity.plural] = self.memberships[
- entity.plural
- ].tolist()
+ self.memberships[entity.plural] = self.memberships[entity.plural].tolist()
def set_default_period(self, period_str: str) -> None:
if period_str:
@@ -529,9 +486,7 @@ def check_persons_to_allocate(
persons_to_allocate,
index,
):
- check_type(
- person_id, str, [entity_plural, entity_id, role_id, str(index)]
- )
+ check_type(person_id, str, [entity_plural, entity_id, role_id, str(index)])
if person_id not in persons_ids:
raise SituationParsingError(
[entity_plural, entity_id, role_id],
@@ -552,9 +507,7 @@ def init_variable_values(self, entity, instance_object, instance_id):
path_in_json = [entity.plural, instance_id, variable_name]
try:
entity.check_variable_defined_for_entity(variable_name)
- except (
- ValueError
- ) as e: # The variable is defined for another entity
+ except ValueError as e: # The variable is defined for another entity
raise SituationParsingError(path_in_json, e.args[0])
except VariableNotFoundError as e: # The variable doesn't exist
raise SituationParsingError(path_in_json, str(e), code=404)
@@ -605,9 +558,7 @@ def add_variable_value(
array[instance_index] = value
- self.input_buffer[variable.name][
- str(periods.period(period_str))
- ] = array
+ self.input_buffer[variable.name][str(periods.period(period_str))] = array
def finalize_variables_init(self, population):
# Due to set_input mechanism, we must bufferize all inputs, then actually set them,
@@ -617,9 +568,7 @@ def finalize_variables_init(self, population):
population.count = self.get_count(plural_key)
population.ids = self.get_ids(plural_key)
if plural_key in self.memberships:
- population.members_entity_id = np.array(
- self.get_memberships(plural_key)
- )
+ population.members_entity_id = np.array(self.get_memberships(plural_key))
population.members_role = np.array(self.get_roles(plural_key))
for variable_name in self.input_buffer.keys():
try:
@@ -632,9 +581,7 @@ def finalize_variables_init(self, population):
for period_str in self.input_buffer[variable_name].keys()
]
# We need to handle small periods first for set_input to work
- sorted_periods = sorted(
- unsorted_periods, key=periods.key_period_size
- )
+ sorted_periods = sorted(unsorted_periods, key=periods.key_period_size)
for period_value in sorted_periods:
values = buffer[str(period_value)]
# Hack to replicate the values in the persons entity
@@ -643,9 +590,7 @@ def finalize_variables_init(self, population):
variable = holder.variable
# TODO - this duplicates the check in Simulation.set_input, but
# fixing that requires improving Simulation's handling of entities
- if (variable.end is None) or (
- period_value.start.date <= variable.end
- ):
+ if (variable.end is None) or (period_value.start.date <= variable.end):
holder.set_input(period_value, array)
def raise_period_mismatch(self, entity, json, e):
@@ -671,15 +616,11 @@ def raise_period_mismatch(self, entity, json, e):
# Returns the total number of instances of this entity, including when there is replication along axes
def get_count(self, entity_name):
- return self.axes_entity_counts.get(
- entity_name, self.entity_counts[entity_name]
- )
+ return self.axes_entity_counts.get(entity_name, self.entity_counts[entity_name])
# Returns the ids of instances of this entity, including when there is replication along axes
def get_ids(self, entity_name):
- return self.axes_entity_ids.get(
- entity_name, self.entity_ids[entity_name]
- )
+ return self.axes_entity_ids.get(entity_name, self.entity_ids[entity_name])
# Returns the memberships of individuals in this entity, including when there is replication along axes
def get_memberships(self, entity_name):
@@ -691,9 +632,7 @@ def get_memberships(self, entity_name):
# Returns the roles of individuals in this entity, including when there is replication along axes
def get_roles(self, entity_name):
# Return empty array for the "persons" entity
- return self.axes_roles.get(
- entity_name, self.roles.get(entity_name, [])
- )
+ return self.axes_roles.get(entity_name, self.roles.get(entity_name, []))
def add_parallel_axis(self, axis):
# All parallel axes have the same count and entity.
@@ -726,9 +665,7 @@ def expand_axes(self):
indices = np.repeat(
np.arange(0, cell_count), self.entity_counts[entity_name]
)
- adjusted_ids = [
- id + str(ix) for id, ix in zip(original_ids, indices)
- ]
+ adjusted_ids = [id + str(ix) for id, ix in zip(original_ids, indices)]
self.axes_entity_ids[entity_name] = adjusted_ids
# Adjust roles
original_roles = self.get_roles(entity_name)
@@ -738,13 +675,9 @@ def expand_axes(self):
if entity_name != self.persons_plural:
original_memberships = self.get_memberships(entity_name)
# repeat membership, e.g. [1, 0] -> [1, 0, 1, 0, ...]
- repeated_memberships = np.tile(
- original_memberships, cell_count
- )
+ repeated_memberships = np.tile(original_memberships, cell_count)
indices = (
- np.repeat(
- np.arange(0, cell_count), len(original_memberships)
- )
+ np.repeat(np.arange(0, cell_count), len(original_memberships))
* self.entity_counts[entity_name]
)
adjusted_memberships = (
@@ -768,9 +701,7 @@ def expand_axes(self):
variable = axis_entity.get_variable(axis_name)
array = self.get_input(axis_name, str(axis_period))
if array is None:
- array = variable.default_array(
- axis_count * axis_entity_step_size
- )
+ array = variable.default_array(axis_count * axis_entity_step_size)
elif array.size == axis_entity_step_size:
array = np.tile(array, axis_count)
array[axis_index::axis_entity_step_size] = np.linspace(
@@ -809,9 +740,7 @@ def expand_axes(self):
array = np.tile(array, cell_count)
array[axis_index::axis_entity_step_size] = axis[
"min"
- ] + mesh.reshape(cell_count) * (
- axis["max"] - axis["min"]
- ) / (
+ ] + mesh.reshape(cell_count) * (axis["max"] - axis["min"]) / (
axis_count - 1
)
self.input_buffer[axis_name][str(axis_period)] = array
diff --git a/policyengine_core/simulations/simulation_macro_cache.py b/policyengine_core/simulations/simulation_macro_cache.py
index 9ecb702f3..80520c6b3 100644
--- a/policyengine_core/simulations/simulation_macro_cache.py
+++ b/policyengine_core/simulations/simulation_macro_cache.py
@@ -12,18 +12,14 @@ class Singleton(type):
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
- cls._instances[cls] = super(Singleton, cls).__call__(
- *args, **kwargs
- )
+ cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class SimulationMacroCache(metaclass=Singleton):
def __init__(self, tax_benefit_system: TaxBenefitSystem):
self.core_version = importlib.metadata.version("policyengine-core")
- self.country_package_metadata = (
- tax_benefit_system.get_package_metadata()
- )
+ self.country_package_metadata = tax_benefit_system.get_package_metadata()
self.country_version = self.country_package_metadata["version"]
self.cache_folder_path = None
self.cache_file_path = None
@@ -61,13 +57,9 @@ def get_cache_path(self):
def get_cache_value(self, cache_file_path: Path):
with h5py.File(cache_file_path, "r") as f:
# Validate both core version and country package metadata are up-to-date, otherwise flush the cache
- if (
- "metadata:core_version" in f
- and "metadata:country_version" in f
- ):
+ if "metadata:core_version" in f and "metadata:country_version" in f:
if (
- f["metadata:core_version"][()].decode("utf-8")
- != self.core_version
+ f["metadata:core_version"][()].decode("utf-8") != self.core_version
or f["metadata:country_version"][()].decode("utf-8")
!= self.country_version
):
diff --git a/policyengine_core/taxbenefitsystems/tax_benefit_system.py b/policyengine_core/taxbenefitsystems/tax_benefit_system.py
index 39aab4bf3..077f6f27e 100644
--- a/policyengine_core/taxbenefitsystems/tax_benefit_system.py
+++ b/policyengine_core/taxbenefitsystems/tax_benefit_system.py
@@ -73,9 +73,7 @@ class TaxBenefitSystem:
_parameters_at_instant_cache: Optional[Dict[Any, Any]] = None
person_key_plural: str = None
preprocess_parameters: str = None
- baseline: "TaxBenefitSystem" = (
- None # Baseline tax-benefit system. Used only by reforms. Note: Reforms can be chained.
- )
+ baseline: "TaxBenefitSystem" = None # Baseline tax-benefit system. Used only by reforms. Note: Reforms can be chained.
cache_blacklist = None
decomposition_file_path = None
variable_module_metadata: dict = None
@@ -108,13 +106,9 @@ def __init__(self, entities: Sequence[Entity] = None, reform=None) -> None:
self.variables: Dict[Any, Any] = {}
# Tax benefit systems are mutable, so entities (which need to know about our variables) can't be shared among them
if entities is None or len(entities) == 0:
- raise Exception(
- "A tax and benefit sytem must have at least an entity."
- )
+ raise Exception("A tax and benefit sytem must have at least an entity.")
self.entities = [copy.copy(entity) for entity in entities]
- self.person_entity = [
- entity for entity in self.entities if entity.is_person
- ][0]
+ self.person_entity = [entity for entity in self.entities if entity.is_person][0]
self.group_entities = [
entity for entity in self.entities if not entity.is_person
]
@@ -183,9 +177,7 @@ def add_abolition_parameters(self):
if not "abolitions" in self.parameters.gov.children:
self.parameters.gov.add_child(
"abolitions",
- ParameterNode(
- name="gov.abolitions", data=abolition_folder_data
- ),
+ ParameterNode(name="gov.abolitions", data=abolition_folder_data),
)
@property
@@ -293,14 +285,10 @@ def add_variables_from_file(self, file_path: str) -> None:
# As Python remembers loaded modules by name, in order to prevent collisions, we need to make sure that:
# - Files with the same name, but located in different directories, have a different module names. Hence the file path hash in the module name.
# - The same file, loaded by different tax and benefit systems, has distinct module names. Hence the `id(self)` in the module name.
- module_name = (
- f"{id(self)}_{hash(os.path.abspath(file_path))}_{file_name}"
- )
+ module_name = f"{id(self)}_{hash(os.path.abspath(file_path))}_{file_name}"
try:
- spec = importlib.util.spec_from_file_location(
- module_name, file_path
- )
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
@@ -317,9 +305,7 @@ def add_variables_from_file(self, file_path: str) -> None:
]
metadata = {}
- metadata["label"] = module.__dict__.get(
- "label", relative_file_path
- )
+ metadata["label"] = module.__dict__.get("label", relative_file_path)
metadata["description"] = module.__dict__.get("description", None)
metadata["index"] = module.__dict__.get("index", 0)
self.variable_module_metadata[relative_file_path] = metadata
@@ -338,9 +324,7 @@ def add_variables_from_file(self, file_path: str) -> None:
self.add_variable(pot_variable)
except Exception:
log.error(
- 'Unable to load OpenFisca variables from file "{}"'.format(
- file_path
- )
+ 'Unable to load OpenFisca variables from file "{}"'.format(file_path)
)
raise
@@ -370,9 +354,7 @@ def add_variable_metadata_from_folder(self, file_path: str) -> None:
self.variable_module_metadata[relative_file_path] = metadata
except Exception:
log.error(
- 'Unable to load OpenFisca metadata from file "{}"'.format(
- file_path
- )
+ 'Unable to load OpenFisca metadata from file "{}"'.format(file_path)
)
raise
@@ -388,9 +370,7 @@ def add_variables_from_directory(self, directory: str) -> None:
py_files.remove(init_module)
self.add_variable_metadata_from_folder(init_module)
if "README.md" in os.listdir(directory):
- self.add_variable_metadata_from_folder(
- os.path.join(directory, "README.md")
- )
+ self.add_variable_metadata_from_folder(os.path.join(directory, "README.md"))
for py_file in py_files:
self.add_variables_from_file(py_file)
subdirectories = glob.glob(os.path.join(directory, "*/"))
@@ -480,9 +460,7 @@ def apply_reform(self, reform_path: str) -> "TaxBenefitSystem":
)
if not issubclass(reform, Reform):
raise ValueError(
- "`{}` does not seem to be a valid Openfisca reform.".format(
- reform_path
- )
+ "`{}` does not seem to be a valid Openfisca reform.".format(reform_path)
)
return reform(self)
@@ -554,9 +532,7 @@ def _get_baseline_parameters_at_instant(
return self.get_parameters_at_instant(instant)
return baseline._get_baseline_parameters_at_instant(instant)
- def get_parameters_at_instant(
- self, instant: Instant
- ) -> ParameterNodeAtInstant:
+ def get_parameters_at_instant(self, instant: Instant) -> ParameterNodeAtInstant:
"""
Get the parameters of the legislation at a given instant
@@ -569,17 +545,15 @@ def get_parameters_at_instant(
elif isinstance(instant, (str, int)):
instant = periods.instant(instant)
else:
- assert isinstance(
- instant, Instant
- ), "Expected an Instant (e.g. Instant((2017, 1, 1)) ). Got: {}.".format(
- instant
+ assert isinstance(instant, Instant), (
+ "Expected an Instant (e.g. Instant((2017, 1, 1)) ). Got: {}.".format(
+ instant
+ )
)
parameters_at_instant = self._parameters_at_instant_cache.get(instant)
if parameters_at_instant is None and self.parameters is not None:
- parameters_at_instant = self.parameters.get_at_instant(
- str(instant)
- )
+ parameters_at_instant = self.parameters.get_at_instant(str(instant))
self._parameters_at_instant_cache[instant] = parameters_at_instant
return parameters_at_instant
@@ -621,15 +595,11 @@ def get_package_metadata(self) -> dict:
except importlib.metadata.PackageNotFoundError:
return fallback_metadata
- location = (
- inspect.getsourcefile(module).split(package_name)[0].rstrip("/")
- )
+ location = inspect.getsourcefile(module).split(package_name)[0].rstrip("/")
try:
home_page_metadatas = [
metadata.split(":", 1)[1].strip(" ")
- for metadata in distribution._get_metadata(
- distribution.PKG_INFO
- )
+ for metadata in distribution._get_metadata(distribution.PKG_INFO)
if "Home-page" in metadata
]
except:
diff --git a/policyengine_core/taxscales/abstract_rate_tax_scale.py b/policyengine_core/taxscales/abstract_rate_tax_scale.py
index 8a32be2df..4c2006b19 100644
--- a/policyengine_core/taxscales/abstract_rate_tax_scale.py
+++ b/policyengine_core/taxscales/abstract_rate_tax_scale.py
@@ -37,6 +37,5 @@ def calc(
right: bool,
) -> typing.NoReturn:
raise NotImplementedError(
- "Method 'calc' is not implemented for "
- f"{self.__class__.__name__}",
+ f"Method 'calc' is not implemented for {self.__class__.__name__}",
)
diff --git a/policyengine_core/taxscales/abstract_tax_scale.py b/policyengine_core/taxscales/abstract_tax_scale.py
index dae99df02..53a3427d2 100644
--- a/policyengine_core/taxscales/abstract_tax_scale.py
+++ b/policyengine_core/taxscales/abstract_tax_scale.py
@@ -33,8 +33,7 @@ def __init__(
def __repr__(self) -> typing.NoReturn:
raise NotImplementedError(
- "Method '__repr__' is not implemented for "
- f"{self.__class__.__name__}",
+ f"Method '__repr__' is not implemented for {self.__class__.__name__}",
)
def calc(
@@ -43,12 +42,10 @@ def calc(
right: bool,
) -> typing.NoReturn:
raise NotImplementedError(
- "Method 'calc' is not implemented for "
- f"{self.__class__.__name__}",
+ f"Method 'calc' is not implemented for {self.__class__.__name__}",
)
def to_dict(self) -> typing.NoReturn:
raise NotImplementedError(
- f"Method 'to_dict' is not implemented for "
- f"{self.__class__.__name__}",
+ f"Method 'to_dict' is not implemented for {self.__class__.__name__}",
)
diff --git a/policyengine_core/taxscales/amount_tax_scale_like.py b/policyengine_core/taxscales/amount_tax_scale_like.py
index 4619189e7..1a0cfebbf 100644
--- a/policyengine_core/taxscales/amount_tax_scale_like.py
+++ b/policyengine_core/taxscales/amount_tax_scale_like.py
@@ -29,9 +29,7 @@ def __repr__(self) -> str:
os.linesep.join(
[
f"- threshold: {threshold}{os.linesep} amount: {amount}"
- for (threshold, amount) in zip(
- self.thresholds, self.amounts
- )
+ for (threshold, amount) in zip(self.thresholds, self.amounts)
]
)
)
diff --git a/policyengine_core/taxscales/single_amount_tax_scale.py b/policyengine_core/taxscales/single_amount_tax_scale.py
index 934a8a2f6..52effc785 100644
--- a/policyengine_core/taxscales/single_amount_tax_scale.py
+++ b/policyengine_core/taxscales/single_amount_tax_scale.py
@@ -20,9 +20,7 @@ def calc(
Matches the input amount to a set of brackets and returns the single
cell value that fits within that bracket.
"""
- guarded_thresholds = numpy.array(
- [-numpy.inf] + self.thresholds + [numpy.inf]
- )
+ guarded_thresholds = numpy.array([-numpy.inf] + self.thresholds + [numpy.inf])
bracket_indices = numpy.digitize(
tax_base,
diff --git a/policyengine_core/taxscales/tax_scale_like.py b/policyengine_core/taxscales/tax_scale_like.py
index dbedea580..1a7254c7a 100644
--- a/policyengine_core/taxscales/tax_scale_like.py
+++ b/policyengine_core/taxscales/tax_scale_like.py
@@ -37,14 +37,12 @@ def __init__(
def __eq__(self, _other: object) -> typing.NoReturn:
raise NotImplementedError(
- "Method '__eq__' is not implemented for "
- f"{self.__class__.__name__}",
+ f"Method '__eq__' is not implemented for {self.__class__.__name__}",
)
def __ne__(self, _other: object) -> typing.NoReturn:
raise NotImplementedError(
- "Method '__ne__' is not implemented for "
- f"{self.__class__.__name__}",
+ f"Method '__ne__' is not implemented for {self.__class__.__name__}",
)
@abc.abstractmethod
diff --git a/policyengine_core/tools/__init__.py b/policyengine_core/tools/__init__.py
index b1d67e64d..5cc386fd8 100644
--- a/policyengine_core/tools/__init__.py
+++ b/policyengine_core/tools/__init__.py
@@ -48,20 +48,20 @@ def assert_near(
value = np.array(value).astype(np.float32)
diff = abs(target_value - value)
if absolute_error_margin is not None:
- assert (
- diff <= absolute_error_margin
- ).all(), "{}{} differs from {} with an absolute margin {} > {}".format(
- message, value, target_value, diff, absolute_error_margin
+ assert (diff <= absolute_error_margin).all(), (
+ "{}{} differs from {} with an absolute margin {} > {}".format(
+ message, value, target_value, diff, absolute_error_margin
+ )
)
if relative_error_margin is not None:
- assert (
- diff <= abs(relative_error_margin * target_value)
- ).all(), "{}{} differs from {} with a relative margin {} > {}".format(
- message,
- value,
- target_value,
- diff,
- abs(relative_error_margin * target_value),
+ assert (diff <= abs(relative_error_margin * target_value)).all(), (
+ "{}{} differs from {} with a relative margin {} > {}".format(
+ message,
+ value,
+ target_value,
+ diff,
+ abs(relative_error_margin * target_value),
+ )
)
diff --git a/policyengine_core/tools/google_cloud.py b/policyengine_core/tools/google_cloud.py
index 825063c61..791b1121b 100644
--- a/policyengine_core/tools/google_cloud.py
+++ b/policyengine_core/tools/google_cloud.py
@@ -70,9 +70,7 @@ def download_gcs_file(
credentials, project_id = _get_gcs_credentials()
- storage_client = storage.Client(
- credentials=credentials, project=project_id
- )
+ storage_client = storage.Client(credentials=credentials, project=project_id)
bucket_obj = storage_client.bucket(bucket)
blob = bucket_obj.blob(file_path)
@@ -115,9 +113,7 @@ def upload_gcs_file(
credentials, project_id = _get_gcs_credentials()
- storage_client = storage.Client(
- credentials=credentials, project=project_id
- )
+ storage_client = storage.Client(credentials=credentials, project=project_id)
bucket_obj = storage_client.bucket(bucket)
blob = bucket_obj.blob(file_path)
diff --git a/policyengine_core/tools/simulation_dumper.py b/policyengine_core/tools/simulation_dumper.py
index 4e2f8a183..c3db0c4fa 100644
--- a/policyengine_core/tools/simulation_dumper.py
+++ b/policyengine_core/tools/simulation_dumper.py
@@ -56,9 +56,7 @@ def restore_simulation(directory, tax_benefit_system, **kwargs):
population.count = person_count
variables_to_restore = (
- variable
- for variable in os.listdir(directory)
- if variable != "__entities__"
+ variable for variable in os.listdir(directory) if variable != "__entities__"
)
for variable in variables_to_restore:
_restore_holder(simulation, variable, directory)
@@ -81,9 +79,7 @@ def _dump_entity(population, directory):
if population.entity.is_person:
return
- np.save(
- os.path.join(path, "members_position.npy"), population.members_position
- )
+ np.save(os.path.join(path, "members_position.npy"), population.members_position)
np.save(
os.path.join(path, "members_entity_id.npy"),
population.members_entity_id,
@@ -109,12 +105,8 @@ def _restore_entity(population, directory):
if population.entity.is_person:
return
- population.members_position = np.load(
- os.path.join(path, "members_position.npy")
- )
- population.members_entity_id = np.load(
- os.path.join(path, "members_entity_id.npy")
- )
+ population.members_position = np.load(os.path.join(path, "members_position.npy"))
+ population.members_entity_id = np.load(os.path.join(path, "members_entity_id.npy"))
encoded_roles = np.load(os.path.join(path, "members_role.npy"))
flattened_roles = population.entity.flattened_roles
diff --git a/policyengine_core/tools/test_runner.py b/policyengine_core/tools/test_runner.py
index 4dbf75b58..1874ceed7 100644
--- a/policyengine_core/tools/test_runner.py
+++ b/policyengine_core/tools/test_runner.py
@@ -156,9 +156,7 @@ class YamlItem(pytest.Item):
Terminal nodes of the test collection tree.
"""
- def __init__(
- self, *, baseline_tax_benefit_system, test, options, **kwargs
- ):
+ def __init__(self, *, baseline_tax_benefit_system, test, options, **kwargs):
super(YamlItem, self).__init__(**kwargs)
self.baseline_tax_benefit_system = baseline_tax_benefit_system
self.options = options
@@ -238,9 +236,7 @@ def apply(self):
try:
builder.set_default_period(period)
- self.simulation = builder.build_from_dict(
- self.tax_benefit_system, input
- )
+ self.simulation = builder.build_from_dict(self.tax_benefit_system, input)
self.simulation.default_calculation_period = builder.default_period
except (VariableNotFoundError, SituationParsingError):
raise
@@ -283,9 +279,7 @@ def generate_performance_tables(self, tracer):
tracer.generate_performance_tables(".")
def generate_variable_graph(self, tracer):
- tracer.generate_variable_graph(
- self.test.get("name"), self._all_output_vars()
- )
+ tracer.generate_variable_graph(self.test.get("name"), self._all_output_vars())
def _all_output_vars(self):
return self._get_leaf_keys(self.test["output"])
@@ -306,19 +300,11 @@ def check_output(self):
if output is None:
return
for key, expected_value in output.items():
- if self.tax_benefit_system.get_variable(
- key
- ): # If key is a variable
- self.check_variable(
- key, expected_value, self.test.get("period")
- )
- elif self.simulation.populations.get(
- key
- ): # If key is an entity singular
+ if self.tax_benefit_system.get_variable(key): # If key is a variable
+ self.check_variable(key, expected_value, self.test.get("period"))
+ elif self.simulation.populations.get(key): # If key is an entity singular
for variable_name, value in expected_value.items():
- self.check_variable(
- variable_name, value, self.test.get("period")
- )
+ self.check_variable(variable_name, value, self.test.get("period"))
else:
population = self.simulation.get_population(plural=key)
if population is not None: # If key is an entity plural
@@ -334,9 +320,7 @@ def check_output(self):
else:
raise VariableNotFoundError(key, self.tax_benefit_system)
- def check_variable(
- self, variable_name, expected_value, period, entity_index=None
- ):
+ def check_variable(self, variable_name, expected_value, period, entity_index=None):
if self.should_ignore_variable(variable_name):
return
if isinstance(expected_value, dict):
@@ -430,12 +414,7 @@ def _get_tax_benefit_system(
key = hash(
(
id(baseline),
- ":".join(
- [
- reform if isinstance(reform, str) else ""
- for reform in reforms
- ]
- ),
+ ":".join([reform if isinstance(reform, str) else "" for reform in reforms]),
reform_key,
frozenset(extensions),
)
@@ -447,13 +426,11 @@ def _get_tax_benefit_system(
for reform_path in reforms:
if isinstance(reform_path, str):
- current_tax_benefit_system = (
- current_tax_benefit_system.apply_reform(reform_path)
+ current_tax_benefit_system = current_tax_benefit_system.apply_reform(
+ reform_path
)
else:
- current_tax_benefit_system = reform_path(
- current_tax_benefit_system
- )
+ current_tax_benefit_system = reform_path(current_tax_benefit_system)
current_tax_benefit_system._parameters_at_instant_cache = {}
for extension in extensions:
@@ -503,25 +480,25 @@ def assert_near(
value = np.array(value).astype(np.float32)
except ValueError:
# Data type not translatable to floating point, assert complete equality
- assert np.array(value) == np.array(
- target_value
- ), "{}{} differs from {}".format(message, value, target_value)
+ assert np.array(value) == np.array(target_value), "{}{} differs from {}".format(
+ message, value, target_value
+ )
return
diff = abs(target_value - value)
if absolute_error_margin is not None:
- assert (
- diff <= absolute_error_margin
- ).all(), "{}{} differs from {} with an absolute margin {} > {}".format(
- message, value, target_value, diff, absolute_error_margin
+ assert (diff <= absolute_error_margin).all(), (
+ "{}{} differs from {} with an absolute margin {} > {}".format(
+ message, value, target_value, diff, absolute_error_margin
+ )
)
if relative_error_margin is not None:
- assert (
- diff <= abs(relative_error_margin * target_value)
- ).all(), "{}{} differs from {} with a relative margin {} > {}".format(
- message,
- value,
- target_value,
- diff,
- abs(relative_error_margin * target_value),
+ assert (diff <= abs(relative_error_margin * target_value)).all(), (
+ "{}{} differs from {} with a relative margin {} > {}".format(
+ message,
+ value,
+ target_value,
+ diff,
+ abs(relative_error_margin * target_value),
+ )
)
diff --git a/policyengine_core/tracers/full_tracer.py b/policyengine_core/tracers/full_tracer.py
index 052d5c2b5..3cab802ec 100644
--- a/policyengine_core/tracers/full_tracer.py
+++ b/policyengine_core/tracers/full_tracer.py
@@ -30,9 +30,7 @@ def record_calculation_start(
period: str,
branch_name: str = "default",
) -> None:
- self._simple_tracer.record_calculation_start(
- variable, period, branch_name
- )
+ self._simple_tracer.record_calculation_start(variable, period, branch_name)
self._enter_calculation(variable, period, branch_name)
self._record_start_time()
@@ -143,9 +141,7 @@ def generate_performance_graph(self, dir_path: str) -> None:
def generate_performance_tables(self, dir_path: str) -> None:
self.performance_log.generate_performance_tables(dir_path)
- def generate_variable_graph(
- self, name: str, output_vars: list[str]
- ) -> None:
+ def generate_variable_graph(self, name: str, output_vars: list[str]) -> None:
self.variable_graph.visualize(
name, aggregate=False, max_depth=None, output_vars=output_vars
)
@@ -159,9 +155,7 @@ def _get_nb_requests(self, tree: tracers.TraceNode, variable: str) -> int:
return tree_call + children_calls
def get_nb_requests(self, variable: str) -> int:
- return sum(
- self._get_nb_requests(tree, variable) for tree in self.trees
- )
+ return sum(self._get_nb_requests(tree, variable) for tree in self.trees)
def get_flat_trace(self) -> dict:
return self.flat_trace.get_trace()
diff --git a/policyengine_core/tracers/performance_log.py b/policyengine_core/tracers/performance_log.py
index 7f1b89630..8f3c81d04 100644
--- a/policyengine_core/tracers/performance_log.py
+++ b/policyengine_core/tracers/performance_log.py
@@ -70,8 +70,7 @@ def _aggregate_calculations(calculations: list) -> dict:
calculation_count = len(calculations)
calculation_time = sum(
- calculation[1]["calculation_time"]
- for calculation in calculations
+ calculation[1]["calculation_time"] for calculation in calculations
)
formula_time = sum(
diff --git a/policyengine_core/tracers/simple_tracer.py b/policyengine_core/tracers/simple_tracer.py
index 0fb7843a9..d503ec5b7 100644
--- a/policyengine_core/tracers/simple_tracer.py
+++ b/policyengine_core/tracers/simple_tracer.py
@@ -27,9 +27,7 @@ def record_calculation_start(
def record_calculation_result(self, value: ArrayLike) -> None:
pass # ignore calculation result
- def record_parameter_access(
- self, parameter: str, period, branch_name: str, value
- ):
+ def record_parameter_access(self, parameter: str, period, branch_name: str, value):
pass
def record_calculation_end(self) -> None:
diff --git a/policyengine_core/tracers/trace_node.py b/policyengine_core/tracers/trace_node.py
index d051028e8..501648774 100644
--- a/policyengine_core/tracers/trace_node.py
+++ b/policyengine_core/tracers/trace_node.py
@@ -20,9 +20,7 @@ class TraceNode:
branch_name: str = "default"
parent: typing.Optional[TraceNode] = None
children: typing.List[TraceNode] = dataclasses.field(default_factory=list)
- parameters: typing.List[TraceNode] = dataclasses.field(
- default_factory=list
- )
+ parameters: typing.List[TraceNode] = dataclasses.field(default_factory=list)
value: typing.Optional[Array] = None
start: float = 0
end: float = 0
@@ -40,9 +38,7 @@ def formula_time(self) -> float:
child.calculation_time(round_=False) for child in self.children
)
- result = (
- +self.calculation_time(round_=False) - children_calculation_time
- )
+ result = +self.calculation_time(round_=False) - children_calculation_time
return self.round(result)
diff --git a/policyengine_core/tracers/tracing_parameter_node_at_instant.py b/policyengine_core/tracers/tracing_parameter_node_at_instant.py
index 032e37f0c..d4bf17466 100644
--- a/policyengine_core/tracers/tracing_parameter_node_at_instant.py
+++ b/policyengine_core/tracers/tracing_parameter_node_at_instant.py
@@ -16,9 +16,7 @@
VectorialParameterNodeAtInstant,
)
- ParameterNode = Union[
- ParameterNodeAtInstant, VectorialParameterNodeAtInstant
- ]
+ ParameterNode = Union[ParameterNodeAtInstant, VectorialParameterNodeAtInstant]
Child = Union[ParameterNode, ArrayLike]
@@ -62,9 +60,7 @@ def get_traced_child(
parameters.VectorialParameterNodeAtInstant,
),
):
- return TracingParameterNodeAtInstant(
- child, self.tracer, self.branch_name
- )
+ return TracingParameterNodeAtInstant(child, self.tracer, self.branch_name)
if not isinstance(key, str) or isinstance(
self.parameter_node_at_instant,
@@ -78,11 +74,7 @@ def get_traced_child(
else:
name = ".".join([self.parameter_node_at_instant._name, key])
- if isinstance(
- child, (numpy.ndarray,) + parameters.ALLOWED_PARAM_TYPES
- ):
- self.tracer.record_parameter_access(
- name, period, self.branch_name, child
- )
+ if isinstance(child, (numpy.ndarray,) + parameters.ALLOWED_PARAM_TYPES):
+ self.tracer.record_parameter_access(name, period, self.branch_name, child)
return child
diff --git a/policyengine_core/tracers/variable_graph.py b/policyengine_core/tracers/variable_graph.py
index 677c65f86..723e20e99 100644
--- a/policyengine_core/tracers/variable_graph.py
+++ b/policyengine_core/tracers/variable_graph.py
@@ -121,9 +121,7 @@ def _add_nodes_and_edges(self, net: Network, root_node: VisualizeNode):
continue
- net.add_node(
- id, color=node.color(), title=node.value, label=node.name
- )
+ net.add_node(id, color=node.color(), title=node.value, label=node.name)
for child in node.children:
edge = (id, child.name)
diff --git a/policyengine_core/variables/defined_for.py b/policyengine_core/variables/defined_for.py
index 29a2a2e1a..597bc874b 100644
--- a/policyengine_core/variables/defined_for.py
+++ b/policyengine_core/variables/defined_for.py
@@ -9,9 +9,7 @@
class CallableSubset:
- def __init__(
- self, population: Population, callable: Callable, mask: ArrayLike
- ):
+ def __init__(self, population: Population, callable: Callable, mask: ArrayLike):
self.population = population
self.callable = callable
self.mask = mask
diff --git a/policyengine_core/variables/helpers.py b/policyengine_core/variables/helpers.py
index 5eca45f5d..987dcfc7d 100644
--- a/policyengine_core/variables/helpers.py
+++ b/policyengine_core/variables/helpers.py
@@ -20,8 +20,7 @@ def get_annualized_variable(
def make_annual_formula(original_formula, annualization_period=None):
def annual_formula(population, period, parameters):
if period.start.month != 1 and (
- annualization_period is None
- or annualization_period.contains(period)
+ annualization_period is None or annualization_period.contains(period)
):
return population(variable.name, period.this_year.first_month)
if original_formula.__code__.co_argcount == 2:
diff --git a/policyengine_core/variables/variable.py b/policyengine_core/variables/variable.py
index 78ae0a835..dda76e0bb 100644
--- a/policyengine_core/variables/variable.py
+++ b/policyengine_core/variables/variable.py
@@ -155,9 +155,7 @@ def __init__(self, baseline_variable=None):
)
for property_name in INHERITED_ALLOWED_PROPERTIES:
- if not attr.get(property_name) and property_name in dir(
- self.__class__
- ):
+ if not attr.get(property_name) and property_name in dir(self.__class__):
attr[property_name] = getattr(self, property_name)
self.baseline_variable = baseline_variable
@@ -194,9 +192,7 @@ def __init__(self, baseline_variable=None):
allowed_type=self.value_type,
default=config.VALUE_TYPES[self.value_type].get("default"),
)
- self.entity = self.set(
- attr, "entity", required=True, setter=self.set_entity
- )
+ self.entity = self.set(attr, "entity", required=True, setter=self.set_entity)
self.definition_period = self.set(
attr,
"definition_period",
@@ -208,20 +204,14 @@ def __init__(self, baseline_variable=None):
periods.ETERNITY,
),
)
- self.label = self.set(
- attr, "label", allowed_type=str, setter=self.set_label
- )
+ self.label = self.set(attr, "label", allowed_type=str, setter=self.set_label)
if self.label is None:
- raise ValueError(
- 'Variable "{name}" has no label'.format(name=self.name)
- )
+ raise ValueError('Variable "{name}" has no label'.format(name=self.name))
self.end = self.set(attr, "end", allowed_type=str, setter=self.set_end)
self.reference = self.set(attr, "reference", setter=self.set_reference)
- self.cerfa_field = self.set(
- attr, "cerfa_field", allowed_type=(str, dict)
- )
+ self.cerfa_field = self.set(attr, "cerfa_field", allowed_type=(str, dict))
self.unit = self.set(attr, "unit", allowed_type=str)
self.quantity_type = self.set(
attr,
@@ -262,9 +252,7 @@ def __init__(self, baseline_variable=None):
attr,
"is_period_size_independent",
allowed_type=bool,
- default=config.VALUE_TYPES[self.value_type][
- "is_period_size_independent"
- ],
+ default=config.VALUE_TYPES[self.value_type]["is_period_size_independent"],
)
self.defined_for = self.set_defined_for(attr.pop("defined_for", None))
@@ -360,11 +348,7 @@ def set(
attribute_name, self.name
)
)
- if (
- required
- and allowed_values is not None
- and value not in allowed_values
- ):
+ if required and allowed_values is not None and value not in allowed_values:
raise ValueError(
"Invalid value '{}' for attribute '{}' in variable '{}'. Allowed values are '{}'.".format(
value, attribute_name, self.name, allowed_values
@@ -463,9 +447,7 @@ def set_formulas(self, formulas_attr):
# If the variable is reforming a baseline variable, keep the formulas from the latter when they are not overridden by new formulas.
if self.baseline_variable is not None:
- first_reform_formula_date = (
- formulas.peekitem(0)[0] if formulas else None
- )
+ first_reform_formula_date = formulas.peekitem(0)[0] if formulas else None
formulas.update(
{
baseline_start_date: baseline_formula
@@ -582,7 +564,9 @@ def get_formula(self, period=None):
return None
if period is None:
- return self.formulas.peekitem(index=0)[
+ return self.formulas.peekitem(
+ index=0
+ )[
1
] # peekitem gets the 1st key-value tuple (the oldest start_date and formula). Return the formula.
@@ -635,13 +619,17 @@ def check_set_value(self, value):
if self.value_type == datetime.date:
error_message = "Can't deal with date: '{}'.".format(value)
else:
- error_message = "Can't deal with value: expected type {}, received '{}'.".format(
- self.json_type, value
+ error_message = (
+ "Can't deal with value: expected type {}, received '{}'.".format(
+ self.json_type, value
+ )
)
raise ValueError(error_message)
except OverflowError:
- error_message = "Can't deal with value: '{}', it's too large for type '{}'.".format(
- value, self.json_type
+ error_message = (
+ "Can't deal with value: '{}', it's too large for type '{}'.".format(
+ value, self.json_type
+ )
)
raise ValueError(error_message)
diff --git a/pyproject.toml b/pyproject.toml
index 850ca667a..85f789b15 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@ dependencies = [
[project.optional-dependencies]
dev = [
"build",
- "black",
+ "ruff>=0.9.0",
"coverage",
"furo",
"jupyter-book<1",
diff --git a/tests/core/commons/test_formulas.py b/tests/core/commons/test_formulas.py
index ecd3e1828..b0b885125 100644
--- a/tests/core/commons/test_formulas.py
+++ b/tests/core/commons/test_formulas.py
@@ -85,7 +85,5 @@ def test_concat_tuple_inputs():
with pytest.raises(TypeError, match="First argument must not be a tuple."):
commons.concat(("a", "b"), numpy.array(["c", "d"]))
- with pytest.raises(
- TypeError, match="Second argument must not be a tuple."
- ):
+ with pytest.raises(TypeError, match="Second argument must not be a tuple."):
commons.concat(numpy.array(["a", "b"]), ("c", "d"))
diff --git a/tests/core/data/test_dataset.py b/tests/core/data/test_dataset.py
index cd166a37d..d53703659 100644
--- a/tests/core/data/test_dataset.py
+++ b/tests/core/data/test_dataset.py
@@ -42,7 +42,6 @@ def test_atomic_write():
file.flush()
# Open the file before overwriting
with open(file.name, "r") as file_original:
-
atomic_write(Path(file.name), "NOPE\n".encode())
# Open file descriptor still points to the old node
@@ -55,9 +54,7 @@ def test_atomic_write():
def test_atomic_write_windows():
if sys.platform == "win32":
temp_dir = Path(tempfile.gettempdir())
- temp_files = [
- temp_dir / f"tempfile_{uuid.uuid4().hex}.tmp" for _ in range(5)
- ]
+ temp_files = [temp_dir / f"tempfile_{uuid.uuid4().hex}.tmp" for _ in range(5)]
managers = [WindowsAtomicFileManager(path) for path in temp_files]
@@ -81,9 +78,9 @@ def test_atomic_write_windows():
for i, results in enumerate(check_results):
for expected, actual in results:
- assert (
- expected == actual
- ), f"Mismatch in file {i}: expected {expected}, got {actual}"
+ assert expected == actual, (
+ f"Mismatch in file {i}: expected {expected}, got {actual}"
+ )
for temp_file in temp_files:
if temp_file.exists():
diff --git a/tests/core/enums/test_enum.py b/tests/core/enums/test_enum.py
index 0fe586f06..19983f835 100644
--- a/tests/core/enums/test_enum.py
+++ b/tests/core/enums/test_enum.py
@@ -18,9 +18,7 @@ class Sample(Enum):
DWORKIN = "dworkin"
sample_string_array = np.array(test_simple_array)
- sample_item_array = np.array(
- [Sample.MAXWELL, Sample.DWORKIN, Sample.MAXWELL]
- )
+ sample_item_array = np.array([Sample.MAXWELL, Sample.DWORKIN, Sample.MAXWELL])
explicit_s_array = np.array(test_simple_array, "S")
encoded_array = Sample.encode(sample_string_array)
@@ -73,11 +71,7 @@ class Sample(Enum):
error_message = str(exc_info.value)
# Should mention all unique invalid values
- assert (
- "FOO" in error_message
- or "BAR" in error_message
- or "BAZ" in error_message
- )
+ assert "FOO" in error_message or "BAR" in error_message or "BAZ" in error_message
def test_enum_encode_empty_string_raises_error():
diff --git a/tests/core/parameter_validation/test_parameter_clone.py b/tests/core/parameter_validation/test_parameter_clone.py
index 2069281dd..508d7f35d 100644
--- a/tests/core/parameter_validation/test_parameter_clone.py
+++ b/tests/core/parameter_validation/test_parameter_clone.py
@@ -36,10 +36,7 @@ def test_clone_parameter_node(tax_benefit_system):
assert clone is not node
assert clone.income_tax_rate is not node.income_tax_rate
- assert (
- clone.children["income_tax_rate"]
- is not node.children["income_tax_rate"]
- )
+ assert clone.children["income_tax_rate"] is not node.children["income_tax_rate"]
def test_clone_scale(tax_benefit_system):
diff --git a/tests/core/parameters/operations/test_nesting.py b/tests/core/parameters/operations/test_nesting.py
index a4a808798..8a1090ac1 100644
--- a/tests/core/parameters/operations/test_nesting.py
+++ b/tests/core/parameters/operations/test_nesting.py
@@ -85,8 +85,8 @@ class family_size(Variable):
family_sizes = np.array([1, 2, 3])
assert (
- system.parameters("2021-01-01").value_by_country_and_region[countries][
- regions
- ][family_sizes]
+ system.parameters("2021-01-01").value_by_country_and_region[countries][regions][
+ family_sizes
+ ]
== [1, 0, 0]
).all()
diff --git a/tests/core/parameters/operations/test_propagation.py b/tests/core/parameters/operations/test_propagation.py
index 5a52e13aa..c67829224 100644
--- a/tests/core/parameters/operations/test_propagation.py
+++ b/tests/core/parameters/operations/test_propagation.py
@@ -41,18 +41,18 @@ def test_parameter_interpolation():
propagated = propagate_parameter_metadata(root)
- assert (
- "example_field" in propagated.a.b.metadata
- ), "Metadata not passed down to direct child"
+ assert "example_field" in propagated.a.b.metadata, (
+ "Metadata not passed down to direct child"
+ )
- assert (
- "example_field" in propagated.a.c.d.e.metadata
- ), "Metadata not passed down to descendent"
+ assert "example_field" in propagated.a.c.d.e.metadata, (
+ "Metadata not passed down to descendent"
+ )
- assert (
- "some_existing_key" in propagated.a.c.d.metadata
- ), "Existing descendent metadata not preserved"
+ assert "some_existing_key" in propagated.a.c.d.metadata, (
+ "Existing descendent metadata not preserved"
+ )
- assert (
- propagated.a.c.d.metadata["example_field"] != "value_to_be_overwritten"
- ), "Existing descendent metadata field not overwritten"
+ assert propagated.a.c.d.metadata["example_field"] != "value_to_be_overwritten", (
+ "Existing descendent metadata field not overwritten"
+ )
diff --git a/tests/core/parameters/test_numpy2_structured_arrays.py b/tests/core/parameters/test_numpy2_structured_arrays.py
index 787df9039..b0a4be1be 100644
--- a/tests/core/parameters/test_numpy2_structured_arrays.py
+++ b/tests/core/parameters/test_numpy2_structured_arrays.py
@@ -93,9 +93,7 @@ def test_mismatched_structured_array_fields():
dtype_ny = np.dtype([("10", float), ("11", float)]) # Different fields
# Create data for 2 rows (age brackets)
- data_ca = np.array(
- [(100.0, 110.0, 120.0), (200.0, 210.0, 220.0)], dtype=dtype_ca
- )
+ data_ca = np.array([(100.0, 110.0, 120.0), (200.0, 210.0, 220.0)], dtype=dtype_ca)
data_ny = np.array([(300.0, 310.0), (400.0, 410.0)], dtype=dtype_ny)
# Create parent structure with both states
@@ -106,9 +104,7 @@ def test_mismatched_structured_array_fields():
)
parent_vector = parent_data.view(np.recarray)
- node = VectorialParameterNodeAtInstant(
- "states", parent_vector, "2024-01-01"
- )
+ node = VectorialParameterNodeAtInstant("states", parent_vector, "2024-01-01")
# Access both states - this triggers dtype mismatch handling
keys = np.array(["CA", "NY"])
diff --git a/tests/core/parameters_fancy_indexing/test_fancy_indexing.py b/tests/core/parameters_fancy_indexing/test_fancy_indexing.py
index 476a4f383..a3eaed648 100644
--- a/tests/core/parameters_fancy_indexing/test_fancy_indexing.py
+++ b/tests/core/parameters_fancy_indexing/test_fancy_indexing.py
@@ -31,9 +31,7 @@ def test_on_leaf():
def test_on_node():
- housing_occupancy_status = np.asarray(
- ["owner", "owner", "tenant", "tenant"]
- )
+ housing_occupancy_status = np.asarray(["owner", "owner", "tenant", "tenant"])
node = P.single[housing_occupancy_status]
assert_near(node.z1, [100, 100, 300, 300])
assert_near(node["z1"], [100, 100, 300, 300])
@@ -41,17 +39,13 @@ def test_on_node():
def test_double_fancy_indexing():
zone = np.asarray(["z1", "z2", "z2", "z1"])
- housing_occupancy_status = np.asarray(
- ["owner", "owner", "tenant", "tenant"]
- )
+ housing_occupancy_status = np.asarray(["owner", "owner", "tenant", "tenant"])
assert_near(P.single[housing_occupancy_status][zone], [100, 200, 400, 300])
def test_double_fancy_indexing_on_node():
family_status = np.asarray(["single", "couple", "single", "couple"])
- housing_occupancy_status = np.asarray(
- ["owner", "owner", "tenant", "tenant"]
- )
+ housing_occupancy_status = np.asarray(["owner", "owner", "tenant", "tenant"])
node = P[family_status][housing_occupancy_status]
assert_near(node.z1, [100, 500, 300, 700])
assert_near(node["z1"], [100, 500, 300, 700])
@@ -115,9 +109,6 @@ class TypesZone(Enum):
z2 = "Zone 2"
zone = np.asarray(
- [
- z.name
- for z in [TypesZone.z1, TypesZone.z2, TypesZone.z2, TypesZone.z1]
- ]
+ [z.name for z in [TypesZone.z1, TypesZone.z2, TypesZone.z2, TypesZone.z1]]
)
assert_near(P.single.owner[zone], [100, 200, 200, 100])
diff --git a/tests/core/tax_scales/test_marginal_rate_tax_scale.py b/tests/core/tax_scales/test_marginal_rate_tax_scale.py
index 106aa1728..cc73f399f 100644
--- a/tests/core/tax_scales/test_marginal_rate_tax_scale.py
+++ b/tests/core/tax_scales/test_marginal_rate_tax_scale.py
@@ -87,9 +87,7 @@ def test_calc():
def test_calc_without_round():
- tax_base = numpy.array(
- [200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]
- )
+ tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005])
tax_scale = taxscales.MarginalRateTaxScale()
tax_scale.add_bracket(0, 0)
tax_scale.add_bracket(100, 0.1)
@@ -104,9 +102,7 @@ def test_calc_without_round():
def test_calc_when_round_is_1():
- tax_base = numpy.array(
- [200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]
- )
+ tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005])
tax_scale = taxscales.MarginalRateTaxScale()
tax_scale.add_bracket(0, 0)
tax_scale.add_bracket(100, 0.1)
@@ -121,9 +117,7 @@ def test_calc_when_round_is_1():
def test_calc_when_round_is_2():
- tax_base = numpy.array(
- [200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]
- )
+ tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005])
tax_scale = taxscales.MarginalRateTaxScale()
tax_scale.add_bracket(0, 0)
tax_scale.add_bracket(100, 0.1)
@@ -138,9 +132,7 @@ def test_calc_when_round_is_2():
def test_calc_when_round_is_3():
- tax_base = numpy.array(
- [200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]
- )
+ tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005])
tax_scale = taxscales.MarginalRateTaxScale()
tax_scale.add_bracket(0, 0)
tax_scale.add_bracket(100, 0.1)
@@ -208,9 +200,7 @@ def test_inverse_scaled_marginal_tax_scales():
result = scaled_tax_scale.inverse()
- tools.assert_near(
- result.calc(scaled_net_tax_base), scaled_gross_tax_base, 1e-13
- )
+ tools.assert_near(result.calc(scaled_net_tax_base), scaled_gross_tax_base, 1e-13)
def test_to_average():
diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py
index 35668b3fa..bf4eda1a4 100644
--- a/tests/core/test_axes.py
+++ b/tests/core/test_axes.py
@@ -343,6 +343,4 @@ def test_simulation_with_axes(tax_benefit_system):
assert simulation.get_array("salary", "2018-11") == pytest.approx(
[0, 0, 0, 0, 0, 0]
)
- assert simulation.get_array("rent", "2018-11") == pytest.approx(
- [0, 0, 3000, 0]
- )
+ assert simulation.get_array("rent", "2018-11") == pytest.approx([0, 0, 3000, 0])
diff --git a/tests/core/test_calculate_output.py b/tests/core/test_calculate_output.py
index 7211afb58..cc0b50638 100644
--- a/tests/core/test_calculate_output.py
+++ b/tests/core/test_calculate_output.py
@@ -50,20 +50,14 @@ def test_calculate_output_add(simulation):
simulation.set_input("variable_with_calculate_output_add", "2017-05", [20])
simulation.set_input("variable_with_calculate_output_add", "2017-12", [70])
tools.assert_near(
- simulation.calculate_output(
- "variable_with_calculate_output_add", 2017
- ),
+ simulation.calculate_output("variable_with_calculate_output_add", 2017),
100,
)
def test_calculate_output_divide(simulation):
- simulation.set_input(
- "variable_with_calculate_output_divide", 2017, [12000]
- )
+ simulation.set_input("variable_with_calculate_output_divide", 2017, [12000])
tools.assert_near(
- simulation.calculate_output(
- "variable_with_calculate_output_divide", "2017-06"
- ),
+ simulation.calculate_output("variable_with_calculate_output_divide", "2017-06"),
1000,
)
diff --git a/tests/core/test_countries.py b/tests/core/test_countries.py
index 83b82fa81..6fa70b457 100644
--- a/tests/core/test_countries.py
+++ b/tests/core/test_countries.py
@@ -11,25 +11,19 @@
PERIOD = periods.period("2016-01")
-@pytest.mark.parametrize(
- "simulation", [({"salary": 2000}, PERIOD)], indirect=True
-)
+@pytest.mark.parametrize("simulation", [({"salary": 2000}, PERIOD)], indirect=True)
def test_input_variable(simulation):
result = simulation.calculate("salary", PERIOD)
tools.assert_near(result, [2000], absolute_error_margin=0.01)
-@pytest.mark.parametrize(
- "simulation", [({"salary": 2000}, PERIOD)], indirect=True
-)
+@pytest.mark.parametrize("simulation", [({"salary": 2000}, PERIOD)], indirect=True)
def test_basic_calculation(simulation):
result = simulation.calculate("income_tax", PERIOD)
tools.assert_near(result, [300], absolute_error_margin=0.01)
-@pytest.mark.parametrize(
- "simulation", [({"salary": 24000}, PERIOD)], indirect=True
-)
+@pytest.mark.parametrize("simulation", [({"salary": 24000}, PERIOD)], indirect=True)
def test_calculate_add(simulation):
result = simulation.calculate_add("income_tax", PERIOD)
tools.assert_near(result, [3600], absolute_error_margin=0.01)
@@ -50,9 +44,7 @@ def test_calculate_divide(simulation):
tools.assert_near(result, [1000 / 12.0], absolute_error_margin=0.01)
-@pytest.mark.parametrize(
- "simulation", [({"salary": 20000}, PERIOD)], indirect=True
-)
+@pytest.mark.parametrize("simulation", [({"salary": 20000}, PERIOD)], indirect=True)
def test_bareme(simulation):
result = simulation.calculate("social_security_contribution", PERIOD)
expected = [0.02 * 6000 + 0.06 * 6400 + 0.12 * 7600]
@@ -68,9 +60,7 @@ def test_non_existing_variable(simulation):
@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True)
def test_divide_option_on_month_defined_variable(simulation):
with pytest.raises(ValueError):
- simulation.person(
- "disposable_income", PERIOD, options=[populations.DIVIDE]
- )
+ simulation.person("disposable_income", PERIOD, options=[populations.DIVIDE])
@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True)
@@ -78,24 +68,20 @@ def test_divide_option_with_complex_period(simulation):
quarter = PERIOD.last_3_months
with pytest.raises(ValueError) as error:
- simulation.household(
- "housing_tax", quarter, options=[populations.DIVIDE]
- )
+ simulation.household("housing_tax", quarter, options=[populations.DIVIDE])
error_message = str(error.value)
expected_words = ["DIVIDE", "one-year", "one-month", "period"]
for word in expected_words:
- assert (
- word in error_message
- ), f"Expected '{word}' in error message '{error_message}'"
+ assert word in error_message, (
+ f"Expected '{word}' in error message '{error_message}'"
+ )
def test_variable_with_reference(make_simulation, isolated_tax_benefit_system):
variables = {"salary": 4000}
- simulation = make_simulation(
- isolated_tax_benefit_system, variables, PERIOD
- )
+ simulation = make_simulation(isolated_tax_benefit_system, variables, PERIOD)
result = simulation.calculate("disposable_income", PERIOD)
@@ -108,9 +94,7 @@ def formula(household, period):
return household.empty_array()
isolated_tax_benefit_system.update_variable(disposable_income)
- simulation = make_simulation(
- isolated_tax_benefit_system, variables, PERIOD
- )
+ simulation = make_simulation(isolated_tax_benefit_system, variables, PERIOD)
result = simulation.calculate("disposable_income", PERIOD)
diff --git a/tests/core/test_cycles.py b/tests/core/test_cycles.py
index cee41bbac..a488765b1 100644
--- a/tests/core/test_cycles.py
+++ b/tests/core/test_cycles.py
@@ -146,9 +146,7 @@ def test_spiral_heuristic(simulation: Simulation, reference_period):
def test_spiral_cache(simulation, reference_period):
simulation.calculate("variable7", period=reference_period)
- cached_variable7 = simulation.get_holder("variable7").get_array(
- reference_period
- )
+ cached_variable7 = simulation.get_holder("variable7").get_array(reference_period)
assert cached_variable7 is not None
diff --git a/tests/core/test_dump_restore.py b/tests/core/test_dump_restore.py
index ed15994b1..5282cb72d 100644
--- a/tests/core/test_dump_restore.py
+++ b/tests/core/test_dump_restore.py
@@ -16,22 +16,14 @@ def test_dump(tax_benefit_system):
calculated_value = simulation.calculate("disposable_income", "2018-01")
simulation_dumper.dump_simulation(simulation, directory)
- simulation_2 = simulation_dumper.restore_simulation(
- directory, tax_benefit_system
- )
+ simulation_2 = simulation_dumper.restore_simulation(directory, tax_benefit_system)
# Check entities structure have been restored
testing.assert_array_equal(simulation.person.ids, simulation_2.person.ids)
- testing.assert_array_equal(
- simulation.person.count, simulation_2.person.count
- )
- testing.assert_array_equal(
- simulation.household.ids, simulation_2.household.ids
- )
- testing.assert_array_equal(
- simulation.household.count, simulation_2.household.count
- )
+ testing.assert_array_equal(simulation.person.count, simulation_2.person.count)
+ testing.assert_array_equal(simulation.household.ids, simulation_2.household.ids)
+ testing.assert_array_equal(simulation.household.count, simulation_2.household.count)
testing.assert_array_equal(
simulation.household.members_position,
simulation_2.household.members_position,
@@ -46,9 +38,7 @@ def test_dump(tax_benefit_system):
# Check calculated values are in cache
- disposable_income_holder = simulation_2.person.get_holder(
- "disposable_income"
- )
+ disposable_income_holder = simulation_2.person.get_holder("disposable_income")
cached_value = disposable_income_holder.get_array("2018-01")
assert cached_value is not None
testing.assert_array_equal(cached_value, calculated_value)
diff --git a/tests/core/test_entities.py b/tests/core/test_entities.py
index 5c2eb56b4..f9660e112 100644
--- a/tests/core/test_entities.py
+++ b/tests/core/test_entities.py
@@ -38,23 +38,17 @@
def new_simulation(tax_benefit_system, test_case, period=MONTH):
simulation_builder = SimulationBuilder()
simulation_builder.set_default_period(period)
- return simulation_builder.build_from_entities(
- tax_benefit_system, test_case
- )
+ return simulation_builder.build_from_entities(tax_benefit_system, test_case)
def test_role_index_and_positions(tax_benefit_system):
simulation = new_simulation(tax_benefit_system, TEST_CASE)
- tools.assert_near(
- simulation.household.members_entity_id, [0, 0, 0, 0, 1, 1]
- )
+ tools.assert_near(simulation.household.members_entity_id, [0, 0, 0, 0, 1, 1])
assert (
simulation.household.members_role
== [FIRST_PARENT, SECOND_PARENT, CHILD, CHILD, FIRST_PARENT, CHILD]
).all()
- tools.assert_near(
- simulation.household.members_position, [0, 1, 2, 3, 0, 1]
- )
+ tools.assert_near(simulation.household.members_position, [0, 1, 2, 3, 0, 1])
assert simulation.person.ids == [
"ind0",
"ind1",
@@ -213,9 +207,7 @@ def test_set_input_with_constructor(tax_benefit_system):
def test_has_role(tax_benefit_system):
simulation = new_simulation(tax_benefit_system, TEST_CASE)
individu = simulation.persons
- tools.assert_near(
- individu.has_role(CHILD), [False, False, True, True, False, True]
- )
+ tools.assert_near(individu.has_role(CHILD), [False, False, True, True, False, True])
def test_has_role_with_subrole(tax_benefit_system):
@@ -244,16 +236,10 @@ def test_project(tax_benefit_system):
housing_tax = household("housing_tax", YEAR)
projected_housing_tax = household.project(housing_tax)
- tools.assert_near(
- projected_housing_tax, [20000, 20000, 20000, 20000, 0, 0]
- )
+ tools.assert_near(projected_housing_tax, [20000, 20000, 20000, 20000, 0, 0])
- housing_tax_projected_on_parents = household.project(
- housing_tax, role=PARENT
- )
- tools.assert_near(
- housing_tax_projected_on_parents, [20000, 20000, 0, 0, 0, 0]
- )
+ housing_tax_projected_on_parents = household.project(housing_tax, role=PARENT)
+ tools.assert_near(housing_tax_projected_on_parents, [20000, 20000, 0, 0, 0, 0])
def test_implicit_projection(tax_benefit_system):
@@ -298,9 +284,7 @@ def test_any(tax_benefit_system):
tools.assert_near(has_household_member_with_age_inf_18, [True, False])
condition_age_2 = age > 18
- has_household_CHILD_with_age_sup_18 = household.any(
- condition_age_2, role=CHILD
- )
+ has_household_CHILD_with_age_sup_18 = household.any(condition_age_2, role=CHILD)
tools.assert_near(has_household_CHILD_with_age_sup_18, [False, True])
@@ -395,9 +379,7 @@ def test_partner(tax_benefit_system):
salary = persons("salary", period=MONTH)
- salary_second_parent = persons.value_from_partner(
- salary, persons.household, PARENT
- )
+ salary_second_parent = persons.value_from_partner(salary, persons.household, PARENT)
tools.assert_near(salary_second_parent, [1500, 1000, 0, 0, 0, 0])
@@ -425,27 +407,15 @@ def test_projectors_methods(tax_benefit_system):
household = simulation.household
person = simulation.person
- projected_vector = household.first_parent.has_role(
- entities.Household.FIRST_PARENT
- )
+ projected_vector = household.first_parent.has_role(entities.Household.FIRST_PARENT)
assert len(projected_vector) == 1 # Must be of a household dimension
salary_i = person.household.members("salary", "2017-01")
- assert (
- len(person.household.sum(salary_i)) == 2
- ) # Must be of a person dimension
- assert (
- len(person.household.max(salary_i)) == 2
- ) # Must be of a person dimension
- assert (
- len(person.household.min(salary_i)) == 2
- ) # Must be of a person dimension
- assert (
- len(person.household.all(salary_i)) == 2
- ) # Must be of a person dimension
- assert (
- len(person.household.any(salary_i)) == 2
- ) # Must be of a person dimension
+ assert len(person.household.sum(salary_i)) == 2 # Must be of a person dimension
+ assert len(person.household.max(salary_i)) == 2 # Must be of a person dimension
+ assert len(person.household.min(salary_i)) == 2 # Must be of a person dimension
+ assert len(person.household.all(salary_i)) == 2 # Must be of a person dimension
+ assert len(person.household.any(salary_i)) == 2 # Must be of a person dimension
assert (
len(household.first_parent.get_rank(household, salary_i)) == 1
) # Must be of a person dimension
@@ -498,9 +468,7 @@ def test_sum_following_bug_ipp_2(tax_benefit_system):
def test_get_memory_usage(tax_benefit_system):
test_case = deepcopy(situation_examples.single)
test_case["persons"]["Alicia"]["salary"] = {"2017-01": 0}
- simulation = SimulationBuilder().build_from_dict(
- tax_benefit_system, test_case
- )
+ simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_case)
simulation.calculate("disposable_income", "2017-01")
memory_usage = simulation.person.get_memory_usage(variables=["salary"])
assert memory_usage["total_nb_bytes"] > 0
@@ -539,12 +507,8 @@ def test_unordered_persons(tax_benefit_system):
household = simulation.household
person = simulation.person
- salary = household.members(
- "salary", "2016-01"
- ) # [ 3000, 0, 1500, 20, 500, 1000 ]
- accommodation_size = household(
- "accommodation_size", "2016-01"
- ) # [ 160, 60 ]
+ salary = household.members("salary", "2016-01") # [ 3000, 0, 1500, 20, 500, 1000 ]
+ accommodation_size = household("accommodation_size", "2016-01") # [ 160, 60 ]
# Aggregation/Projection persons -> entity
@@ -554,9 +518,7 @@ def test_unordered_persons(tax_benefit_system):
tools.assert_near(household.all(salary > 0), [False, True])
tools.assert_near(household.any(salary > 2000), [False, True])
tools.assert_near(household.first_person("salary", "2016-01"), [0, 3000])
- tools.assert_near(
- household.first_parent("salary", "2016-01"), [1000, 3000]
- )
+ tools.assert_near(household.first_parent("salary", "2016-01"), [1000, 3000])
tools.assert_near(household.second_parent("salary", "2016-01"), [1500, 0])
tools.assert_near(
person.value_from_partner(salary, person.household, PARENT),
diff --git a/tests/core/test_extensions.py b/tests/core/test_extensions.py
index b0a1f2855..8e647e8cb 100644
--- a/tests/core/test_extensions.py
+++ b/tests/core/test_extensions.py
@@ -8,9 +8,7 @@ def test_load_extension(tax_benefit_system):
tbs.load_extension("policyengine_core.extension_template")
assert tbs.get_variable("local_town_child_allowance") is not None
- assert (
- tax_benefit_system.get_variable("local_town_child_allowance") is None
- )
+ assert tax_benefit_system.get_variable("local_town_child_allowance") is None
def test_access_to_parameters(tax_benefit_system):
diff --git a/tests/core/test_formulas.py b/tests/core/test_formulas.py
index 3504283a3..57a1e56ec 100644
--- a/tests/core/test_formulas.py
+++ b/tests/core/test_formulas.py
@@ -86,9 +86,7 @@ def test_switch(simulation, month):
def test_multiplication(simulation, month):
- uses_multiplication = simulation.calculate(
- "uses_multiplication", period=month
- )
+ uses_multiplication = simulation.calculate("uses_multiplication", period=month)
assert isinstance(uses_multiplication, numpy.ndarray)
@@ -99,9 +97,7 @@ def test_broadcast_scalar(simulation, month):
def test_compare_multiplication_and_switch(simulation, month):
- uses_multiplication = simulation.calculate(
- "uses_multiplication", period=month
- )
+ uses_multiplication = simulation.calculate("uses_multiplication", period=month)
uses_switch = simulation.calculate("uses_switch", period=month)
assert numpy.all(uses_switch == uses_multiplication)
@@ -169,9 +165,7 @@ class projected_family_level_variable(Variable):
def formula(family, period):
return family.household("household_level_variable", period)
- system.add_variables(
- household_level_variable, projected_family_level_variable
- )
+ system.add_variables(household_level_variable, projected_family_level_variable)
simulation = SimulationBuilder().build_from_dict(
system,
@@ -191,6 +185,5 @@ def formula(family, period):
)
assert (
- simulation.calculate("projected_family_level_variable", "2021-01-01")
- == 5
+ simulation.calculate("projected_family_level_variable", "2021-01-01") == 5
).all()
diff --git a/tests/core/test_holders.py b/tests/core/test_holders.py
index 5d4e481ae..94aebe09d 100644
--- a/tests/core/test_holders.py
+++ b/tests/core/test_holders.py
@@ -49,9 +49,7 @@ def test_set_input_enum_int(couple):
def test_set_input_enum_item(couple):
simulation = couple
- status_occupancy = numpy.asarray(
- [housing.HousingOccupancyStatus.free_lodger]
- )
+ status_occupancy = numpy.asarray([housing.HousingOccupancyStatus.free_lodger])
simulation.household.get_holder("housing_occupancy_status").set_input(
period, status_occupancy
)
@@ -136,9 +134,7 @@ def test_get_memory_usage_with_trace(single):
def test_set_input_dispatch_by_period(single):
simulation = single
- variable = simulation.tax_benefit_system.get_variable(
- "housing_occupancy_status"
- )
+ variable = simulation.tax_benefit_system.get_variable("housing_occupancy_status")
entity = simulation.household
holder = Holder(variable, entity)
holders.set_input_dispatch_by_period(holder, periods.period(2019), "owner")
@@ -195,9 +191,7 @@ def test_cache_enum_on_disk(single):
simulation = single
simulation.memory_config = force_storage_on_disk
month = periods.period("2017-01")
- simulation.calculate(
- "housing_occupancy_status", month
- ) # First calculation
+ simulation.calculate("housing_occupancy_status", month) # First calculation
housing_occupancy_status = simulation.calculate(
"housing_occupancy_status", month
) # Read from cache
diff --git a/tests/core/test_microsimulation_person_accessor.py b/tests/core/test_microsimulation_person_accessor.py
index a5ef598a9..0f9cebd64 100644
--- a/tests/core/test_microsimulation_person_accessor.py
+++ b/tests/core/test_microsimulation_person_accessor.py
@@ -83,9 +83,7 @@ def test_person_accessor_kwargs_passed_correctly(self):
result_person = sim.person("salary", "2022-01")
# Call calculate() with use_weights=False directly
- result_calculate = sim.calculate(
- "salary", "2022-01", use_weights=False
- )
+ result_calculate = sim.calculate("salary", "2022-01", use_weights=False)
# Both should return numpy arrays with the same values
assert isinstance(result_person, np.ndarray)
diff --git a/tests/core/test_pandas3_compatibility.py b/tests/core/test_pandas3_compatibility.py
index 38a02f84c..8687bf225 100644
--- a/tests/core/test_pandas3_compatibility.py
+++ b/tests/core/test_pandas3_compatibility.py
@@ -67,9 +67,7 @@ def test_filled_array_with_pyarrow_string_dtype(self):
# PyArrow string dtype (proper way to create it)
arrow_string_dtype = pd.ArrowDtype(pa.string())
- result = population.filled_array(
- "test_value", dtype=arrow_string_dtype
- )
+ result = population.filled_array("test_value", dtype=arrow_string_dtype)
assert len(result) == 5
@@ -227,9 +225,5 @@ def test_is_pandas_extension_dtype(self):
assert isinstance(pd.StringDtype(), pd.api.extensions.ExtensionDtype)
# numpy dtypes are not
- assert not isinstance(
- np.dtype("float64"), pd.api.extensions.ExtensionDtype
- )
- assert not isinstance(
- np.dtype("object"), pd.api.extensions.ExtensionDtype
- )
+ assert not isinstance(np.dtype("float64"), pd.api.extensions.ExtensionDtype)
+ assert not isinstance(np.dtype("object"), pd.api.extensions.ExtensionDtype)
diff --git a/tests/core/test_parameters.py b/tests/core/test_parameters.py
index 2cb71e2d9..cbcdaf36f 100644
--- a/tests/core/test_parameters.py
+++ b/tests/core/test_parameters.py
@@ -14,9 +14,9 @@ def test_get_at_instant(tax_benefit_system):
parameters = tax_benefit_system.parameters
assert isinstance(parameters, ParameterNode), parameters
parameters_at_instant = parameters("2016-01-01")
- assert isinstance(
- parameters_at_instant, ParameterNodeAtInstant
- ), parameters_at_instant
+ assert isinstance(parameters_at_instant, ParameterNodeAtInstant), (
+ parameters_at_instant
+ )
assert parameters_at_instant.taxes.income_tax_rate == 0.15
assert parameters_at_instant.benefits.basic_income == 600
@@ -31,26 +31,20 @@ def test_param_values(tax_benefit_system):
for date, value in dated_values.items():
assert (
- tax_benefit_system.get_parameters_at_instant(
- date
- ).taxes.income_tax_rate
+ tax_benefit_system.get_parameters_at_instant(date).taxes.income_tax_rate
== value
)
def test_param_before_it_is_defined(tax_benefit_system):
with pytest.raises(ParameterNotFoundError):
- tax_benefit_system.get_parameters_at_instant(
- "1997-12-31"
- ).taxes.income_tax_rate
+ tax_benefit_system.get_parameters_at_instant("1997-12-31").taxes.income_tax_rate
# The placeholder should have no effect on the parameter computation
def test_param_with_placeholder(tax_benefit_system):
assert (
- tax_benefit_system.get_parameters_at_instant(
- "2018-01-01"
- ).taxes.income_tax_rate
+ tax_benefit_system.get_parameters_at_instant("2018-01-01").taxes.income_tax_rate
== 0.15
)
@@ -94,8 +88,7 @@ def test_parameter_repr(tax_benefit_system):
def test_parameters_metadata(tax_benefit_system):
parameter = tax_benefit_system.parameters.benefits.basic_income
assert (
- parameter.metadata["reference"]
- == "https://law.gov.example/basic-income/amount"
+ parameter.metadata["reference"] == "https://law.gov.example/basic-income/amount"
)
assert parameter.metadata["unit"] == "currency-EUR"
assert (
@@ -126,8 +119,7 @@ def test_parameter_documentation(tax_benefit_system):
def test_get_descendants(tax_benefit_system):
all_parameters = {
- parameter.name
- for parameter in tax_benefit_system.parameters.get_descendants()
+ parameter.name for parameter in tax_benefit_system.parameters.get_descendants()
}
assert all_parameters.issuperset(
{"taxes", "taxes.housing_tax", "taxes.housing_tax.minimal_amount"}
diff --git a/tests/core/test_periods.py b/tests/core/test_periods.py
index 9613c73c9..693d28933 100644
--- a/tests/core/test_periods.py
+++ b/tests/core/test_periods.py
@@ -133,17 +133,11 @@ def test_leap_month_size_in_days():
def test_3_month_size_in_days():
- assert (
- Period(("month", Instant((2013, 1, 3)), 3)).size_in_days
- == 31 + 28 + 31
- )
+ assert Period(("month", Instant((2013, 1, 3)), 3)).size_in_days == 31 + 28 + 31
def test_leap_3_month_size_in_days():
- assert (
- Period(("month", Instant((2012, 1, 3)), 3)).size_in_days
- == 31 + 29 + 31
- )
+ assert Period(("month", Instant((2012, 1, 3)), 3)).size_in_days == 31 + 29 + 31
def test_year_size_in_days():
diff --git a/tests/core/test_projectors.py b/tests/core/test_projectors.py
index 9b9c8adb5..b1795336e 100644
--- a/tests/core/test_projectors.py
+++ b/tests/core/test_projectors.py
@@ -176,9 +176,7 @@ def formula(person, period):
)
assert (
- simulation.calculate(
- "projected_enum_variable", "2021-01-01"
- ).decode_to_str()
+ simulation.calculate("projected_enum_variable", "2021-01-01").decode_to_str()
== np.array(["SECOND_OPTION"] * 3)
).all()
@@ -243,9 +241,7 @@ class person_enum_variable(Variable):
system,
{
"people": {
- "person1": {
- "person_enum_variable": {"ETERNITY": "SECOND_OPTION"}
- },
+ "person1": {"person_enum_variable": {"ETERNITY": "SECOND_OPTION"}},
"person2": {},
"person3": {},
},
@@ -337,9 +333,7 @@ class decoded_projected_family_level_variable(Variable):
label = "decoded projected family level variable"
def formula(family, period):
- return family.household(
- "household_level_variable", period
- ).decode_to_str()
+ return family.household("household_level_variable", period).decode_to_str()
system.add_variables(
household_level_variable,
@@ -371,8 +365,6 @@ def formula(family, period):
== np.array(["SECOND_OPTION"])
).all()
assert (
- simulation.calculate(
- "decoded_projected_family_level_variable", "2021-01-01"
- )
+ simulation.calculate("decoded_projected_family_level_variable", "2021-01-01")
== np.array(["SECOND_OPTION"])
).all()
diff --git a/tests/core/test_reforms.py b/tests/core/test_reforms.py
index 68eaf9bea..a6075ed9d 100644
--- a/tests/core/test_reforms.py
+++ b/tests/core/test_reforms.py
@@ -37,17 +37,13 @@ def test_formula_neutralization(make_simulation, tax_benefit_system):
basic_income = simulation.calculate("basic_income", period=period)
assert_near(basic_income, 600)
- disposable_income = simulation.calculate(
- "disposable_income", period=period
- )
+ disposable_income = simulation.calculate("disposable_income", period=period)
assert disposable_income > 0
reform_simulation = make_simulation(reform, {}, period)
reform_simulation.debug = True
- basic_income_reform = reform_simulation.calculate(
- "basic_income", period="2013-01"
- )
+ basic_income_reform = reform_simulation.calculate("basic_income", period="2013-01")
assert_near(basic_income_reform, 0, absolute_error_margin=0)
disposable_income_reform = reform_simulation.calculate(
"disposable_income", period=period
@@ -83,9 +79,7 @@ def apply(self):
reform = test_salary_neutralization(tax_benefit_system)
with warnings.catch_warnings(record=True) as raised_warnings:
- reform_simulation = make_simulation(
- reform, {"salary": [1200, 1000]}, period
- )
+ reform_simulation = make_simulation(reform, {"salary": [1200, 1000]}, period)
assert (
"You cannot set a value for the variable"
in raised_warnings[0].message.args[0]
@@ -101,9 +95,7 @@ def apply(self):
assert_near(disposable_income_reform, [600, 600])
-def test_permanent_variable_neutralization(
- make_simulation, tax_benefit_system
-):
+def test_permanent_variable_neutralization(make_simulation, tax_benefit_system):
class test_date_naissance_neutralization(Reform):
def apply(self):
self.neutralize_variable("birth")
@@ -115,9 +107,7 @@ def apply(self):
reform.base_tax_benefit_system, {"birth": "1980-01-01"}, period
)
with warnings.catch_warnings(record=True) as raised_warnings:
- reform_simulation = make_simulation(
- reform, {"birth": "1980-01-01"}, period
- )
+ reform_simulation = make_simulation(reform, {"birth": "1980-01-01"}, period)
assert (
"You cannot set a value for the variable"
in raised_warnings[0].message.args[0]
@@ -145,9 +135,7 @@ def apply(self):
assert tax_benefit_system.get_variable("new_variable") is None
reform_simulation = make_simulation(reform, {}, 2013)
reform_simulation.debug = True
- new_variable1 = reform_simulation.calculate(
- "new_variable", period="2013-01"
- )
+ new_variable1 = reform_simulation.calculate("new_variable", period="2013-01")
assert_near(new_variable1, 10, absolute_error_margin=0)
@@ -192,9 +180,7 @@ def apply(self):
reform = test_update_variable(tax_benefit_system)
disposable_income_reform = reform.get_variable("disposable_income")
- disposable_income_baseline = tax_benefit_system.get_variable(
- "disposable_income"
- )
+ disposable_income_baseline = tax_benefit_system.get_variable("disposable_income")
assert disposable_income_reform is not None
assert (
@@ -301,14 +287,9 @@ def apply(self):
assert reform_variable.value_type == baseline_variable.value_type
assert reform_variable.entity == baseline_variable.entity
assert reform_variable.label == baseline_variable.label
- assert (
- reform_variable.definition_period
- == baseline_variable.definition_period
- )
+ assert reform_variable.definition_period == baseline_variable.definition_period
assert reform_variable.set_input == baseline_variable.set_input
- assert (
- reform_variable.calculate_output == baseline_variable.calculate_output
- )
+ assert reform_variable.calculate_output == baseline_variable.calculate_output
def test_formulas_removal(tax_benefit_system):
diff --git a/tests/core/test_simulation_builder.py b/tests/core/test_simulation_builder.py
index d9c064be7..f838d6b3c 100644
--- a/tests/core/test_simulation_builder.py
+++ b/tests/core/test_simulation_builder.py
@@ -68,8 +68,7 @@ def test_build_default_simulation(tax_benefit_system):
assert one_person_simulation.household.count == 1
assert one_person_simulation.household.members_entity_id == [0]
assert (
- one_person_simulation.household.members_role
- == entities.Household.FIRST_PARENT
+ one_person_simulation.household.members_role == entities.Household.FIRST_PARENT
)
several_persons_simulation = SimulationBuilder().build_default_simulation(
@@ -116,9 +115,7 @@ def test_add_person_entity_with_values(persons):
persons_json = {"Alicia": {"salary": {"2018-11": 3000}}, "Javier": {}}
simulation_builder = SimulationBuilder()
simulation_builder.add_person_entity(persons, persons_json)
- tools.assert_near(
- simulation_builder.get_input("salary", "2018-11"), [3000, 0]
- )
+ tools.assert_near(simulation_builder.get_input("salary", "2018-11"), [3000, 0])
def test_add_person_values_with_default_period(persons):
@@ -126,9 +123,7 @@ def test_add_person_values_with_default_period(persons):
simulation_builder = SimulationBuilder()
simulation_builder.set_default_period("2018-11")
simulation_builder.add_person_entity(persons, persons_json)
- tools.assert_near(
- simulation_builder.get_input("salary", "2018-11"), [3000, 0]
- )
+ tools.assert_near(simulation_builder.get_input("salary", "2018-11"), [3000, 0])
def test_add_person_values_with_default_period_old_syntax(persons):
@@ -136,9 +131,7 @@ def test_add_person_values_with_default_period_old_syntax(persons):
simulation_builder = SimulationBuilder()
simulation_builder.set_default_period("month:2018-11")
simulation_builder.add_person_entity(persons, persons_json)
- tools.assert_near(
- simulation_builder.get_input("salary", "2018-11"), [3000, 0]
- )
+ tools.assert_near(simulation_builder.get_input("salary", "2018-11"), [3000, 0])
def test_add_group_entity(households):
@@ -158,9 +151,12 @@ def test_add_group_entity(households):
"Household_2",
]
assert simulation_builder.get_memberships("households") == [0, 0, 1, 1]
- assert [
- role.key for role in simulation_builder.get_roles("households")
- ] == ["parent", "parent", "child", "parent"]
+ assert [role.key for role in simulation_builder.get_roles("households")] == [
+ "parent",
+ "parent",
+ "child",
+ "parent",
+ ]
def test_add_group_entity_loose_syntax(households):
@@ -180,9 +176,12 @@ def test_add_group_entity_loose_syntax(households):
"Household_2",
]
assert simulation_builder.get_memberships("households") == [0, 0, 1, 1]
- assert [
- role.key for role in simulation_builder.get_roles("households")
- ] == ["parent", "parent", "child", "parent"]
+ assert [role.key for role in simulation_builder.get_roles("households")] == [
+ "parent",
+ "parent",
+ "child",
+ "parent",
+ ]
def test_add_variable_value(persons):
@@ -288,9 +287,7 @@ def test_fail_on_date_parsing(persons, date_variable):
)
assert excinfo.value.error == {
"persons": {
- "Alicia": {
- "datevar": {"2018-11": "Can't deal with date: '2019-02-30'."}
- }
+ "Alicia": {"datevar": {"2018-11": "Can't deal with date: '2019-02-30'."}}
}
}
@@ -311,9 +308,7 @@ def test_finalize_person_entity(persons):
simulation_builder.add_person_entity(persons, persons_json)
population = Population(persons)
simulation_builder.finalize_variables_init(population)
- tools.assert_near(
- population.get_holder("salary").get_array("2018-11"), [3000, 0]
- )
+ tools.assert_near(population.get_holder("salary").get_array("2018-11"), [3000, 0])
assert population.count == 2
assert population.ids == ["Alicia", "Javier"]
@@ -324,9 +319,7 @@ def test_canonicalize_period_keys(persons):
simulation_builder.add_person_entity(persons, persons_json)
population = Population(persons)
simulation_builder.finalize_variables_init(population)
- tools.assert_near(
- population.get_holder("salary").get_array("2018-12"), [100]
- )
+ tools.assert_near(population.get_holder("salary").get_array("2018-12"), [100])
def test_finalize_households(tax_benefit_system):
@@ -470,9 +463,7 @@ def test_nb_persons_in_households(tax_benefit_system):
simulation_builder = SimulationBuilder()
simulation_builder.create_entities(tax_benefit_system)
simulation_builder.declare_person_entity("person", persons_ids)
- household_instance = simulation_builder.declare_entity(
- "household", households_ids
- )
+ household_instance = simulation_builder.declare_entity("household", households_ids)
simulation_builder.join_with_persons(
household_instance, persons_households, ["first_parent"] * 5
)
@@ -490,9 +481,7 @@ def test_nb_persons_no_role(tax_benefit_system):
simulation_builder = SimulationBuilder()
simulation_builder.create_entities(tax_benefit_system)
simulation_builder.declare_person_entity("person", persons_ids)
- household_instance = simulation_builder.declare_entity(
- "household", households_ids
- )
+ household_instance = simulation_builder.declare_entity("household", households_ids)
simulation_builder.join_with_persons(
household_instance, persons_households, ["first_parent"] * 5
@@ -523,9 +512,7 @@ def test_nb_persons_by_role(tax_benefit_system):
simulation_builder = SimulationBuilder()
simulation_builder.create_entities(tax_benefit_system)
simulation_builder.declare_person_entity("person", persons_ids)
- household_instance = simulation_builder.declare_entity(
- "household", households_ids
- )
+ household_instance = simulation_builder.declare_entity("household", households_ids)
simulation_builder.join_with_persons(
household_instance, persons_households, persons_households_roles
@@ -547,9 +534,7 @@ def test_integral_roles(tax_benefit_system):
simulation_builder = SimulationBuilder()
simulation_builder.create_entities(tax_benefit_system)
simulation_builder.declare_person_entity("person", persons_ids)
- household_instance = simulation_builder.declare_entity(
- "household", households_ids
- )
+ household_instance = simulation_builder.declare_entity("household", households_ids)
simulation_builder.join_with_persons(
household_instance, persons_households, persons_households_roles
@@ -579,9 +564,7 @@ def test_from_person_variable_to_group(tax_benefit_system):
simulation_builder.create_entities(tax_benefit_system)
simulation_builder.declare_person_entity("person", persons_ids)
- household_instance = simulation_builder.declare_entity(
- "household", households_ids
- )
+ household_instance = simulation_builder.declare_entity("household", households_ids)
simulation_builder.join_with_persons(
household_instance, persons_households, ["first_parent"] * 5
)
@@ -592,9 +575,7 @@ def test_from_person_variable_to_group(tax_benefit_system):
total_taxes = simulation.calculate("total_taxes", period)
assert total_taxes == pytest.approx(households_rents)
- assert total_taxes / simulation.calculate("rent", period) == pytest.approx(
- 1
- )
+ assert total_taxes / simulation.calculate("rent", period) == pytest.approx(1)
def test_simulation(tax_benefit_system):
@@ -622,9 +603,7 @@ def test_vectorial_input(tax_benefit_system):
tax_benefit_system, test_runner.yaml.safe_load(input_yaml)
)
- tools.assert_near(
- simulation.get_array("salary", "2016-10"), [12000, 20000]
- )
+ tools.assert_near(simulation.get_array("salary", "2016-10"), [12000, 20000])
simulation.calculate("income_tax", "2016-10")
simulation.calculate("total_taxes", "2016-10")
diff --git a/tests/core/test_simulations.py b/tests/core/test_simulations.py
index 9a9c5f632..c6cf3140c 100644
--- a/tests/core/test_simulations.py
+++ b/tests/core/test_simulations.py
@@ -9,9 +9,7 @@
def test_calculate_full_tracer(tax_benefit_system):
- simulation = SimulationBuilder().build_default_simulation(
- tax_benefit_system
- )
+ simulation = SimulationBuilder().build_default_simulation(tax_benefit_system)
simulation.trace = True
simulation.calculate("income_tax", "2017-01")
@@ -26,9 +24,7 @@ def test_calculate_full_tracer(tax_benefit_system):
def test_get_entity_not_found(tax_benefit_system):
- simulation = SimulationBuilder().build_default_simulation(
- tax_benefit_system
- )
+ simulation = SimulationBuilder().build_default_simulation(tax_benefit_system)
assert simulation.get_entity(plural="no_such_entities") is None
@@ -60,9 +56,7 @@ def test_clone(tax_benefit_system):
def test_get_memory_usage(tax_benefit_system):
- simulation = SimulationBuilder().build_from_entities(
- tax_benefit_system, single
- )
+ simulation = SimulationBuilder().build_from_entities(tax_benefit_system, single)
simulation.calculate("disposable_income", "2017-01")
memory_usage = simulation.get_memory_usage(variables=["salary"])
assert memory_usage["total_nb_bytes"] > 0
@@ -70,14 +64,10 @@ def test_get_memory_usage(tax_benefit_system):
def test_macro_cache(tax_benefit_system):
- simulation = SimulationBuilder().build_from_entities(
- tax_benefit_system, single
- )
+ simulation = SimulationBuilder().build_from_entities(tax_benefit_system, single)
cache = SimulationMacroCache(tax_benefit_system)
- assert cache.core_version == importlib.metadata.version(
- "policyengine-core"
- )
+ assert cache.core_version == importlib.metadata.version("policyengine-core")
assert cache.country_version == "0.0.0"
cache.set_cache_path(
diff --git a/tests/core/test_tracers.py b/tests/core/test_tracers.py
index 66410692f..e653f39be 100644
--- a/tests/core/test_tracers.py
+++ b/tests/core/test_tracers.py
@@ -60,9 +60,7 @@ def tracer():
def test_stack_one_level(tracer):
tracer.record_calculation_start("a", 2017)
assert len(tracer.stack) == 1
- assert tracer.stack == [
- {"name": "a", "period": 2017, "branch_name": "default"}
- ]
+ assert tracer.stack == [{"name": "a", "period": 2017, "branch_name": "default"}]
tracer.record_calculation_end()
assert tracer.stack == []
@@ -80,9 +78,7 @@ def test_stack_two_levels(tracer):
tracer.record_calculation_end()
assert len(tracer.stack) == 1
- assert tracer.stack == [
- {"name": "a", "period": 2017, "branch_name": "default"}
- ]
+ assert tracer.stack == [{"name": "a", "period": 2017, "branch_name": "default"}]
@mark.parametrize("tracer", [SimpleTracer(), FullTracer()])
@@ -213,9 +209,7 @@ def test_flat_trace(tracer):
trace = tracer.get_flat_trace()
assert len(trace) == 2
- assert trace["a<2019, (default)>"]["dependencies"] == [
- "b<2019, (default)>"
- ]
+ assert trace["a<2019, (default)>"]["dependencies"] == ["b<2019, (default)>"]
assert trace["b<2019, (default)>"]["dependencies"] == []
@@ -260,9 +254,7 @@ def test_flat_trace_with_cache(tracer):
trace = tracer.get_flat_trace()
- assert trace["b<2019, (default)>"]["dependencies"] == [
- "c<2019, (default)>"
- ]
+ assert trace["b<2019, (default)>"]["dependencies"] == ["c<2019, (default)>"]
def test_calculation_time():
@@ -344,9 +336,7 @@ def test_rounding():
node_a.start = 1.23456789
node_a.end = node_a.start + 1.23456789e-03
- assert (
- node_a.calculation_time() == 1.235e-03
- ) # Keep only 3 significant figures
+ assert node_a.calculation_time() == 1.235e-03 # Keep only 3 significant figures
node_b = TraceNode("b", 2017)
node_b.start = node_a.start
@@ -428,10 +418,7 @@ def test_log_aggregate_with_strings(tracer):
tracer._exit_calculation()
lines = tracer.computation_log.lines(aggregate=True)
- assert (
- lines[0]
- == " A<2017, (default)> = {'avg': '?', 'max': '?', 'min': '?'}"
- )
+ assert lines[0] == " A<2017, (default)> = {'avg': '?', 'max': '?', 'min': '?'}"
def test_log_max_depth(tracer):
diff --git a/tests/core/test_yaml.py b/tests/core/test_yaml.py
index acb740d10..c6523cade 100644
--- a/tests/core/test_yaml.py
+++ b/tests/core/test_yaml.py
@@ -27,40 +27,31 @@ def test_success(tax_benefit_system):
def test_fail(tax_benefit_system):
- assert (
- run_yaml_test(tax_benefit_system, "test_failure.yaml")
- == EXIT_TESTSFAILED
- )
+ assert run_yaml_test(tax_benefit_system, "test_failure.yaml") == EXIT_TESTSFAILED
def test_relative_error_margin_success(tax_benefit_system):
assert (
- run_yaml_test(tax_benefit_system, "test_relative_error_margin.yaml")
- == EXIT_OK
+ run_yaml_test(tax_benefit_system, "test_relative_error_margin.yaml") == EXIT_OK
)
def test_relative_error_margin_fail(tax_benefit_system):
assert (
- run_yaml_test(
- tax_benefit_system, "failing_test_relative_error_margin.yaml"
- )
+ run_yaml_test(tax_benefit_system, "failing_test_relative_error_margin.yaml")
== EXIT_TESTSFAILED
)
def test_absolute_error_margin_success(tax_benefit_system):
assert (
- run_yaml_test(tax_benefit_system, "test_absolute_error_margin.yaml")
- == EXIT_OK
+ run_yaml_test(tax_benefit_system, "test_absolute_error_margin.yaml") == EXIT_OK
)
def test_absolute_error_margin_fail(tax_benefit_system):
assert (
- run_yaml_test(
- tax_benefit_system, "failing_test_absolute_error_margin.yaml"
- )
+ run_yaml_test(tax_benefit_system, "failing_test_absolute_error_margin.yaml")
== EXIT_TESTSFAILED
)
@@ -71,28 +62,19 @@ def test_run_tests_from_directory(tax_benefit_system):
def test_with_reform(tax_benefit_system):
- assert (
- run_yaml_test(tax_benefit_system, "test_with_reform.yaml") == EXIT_OK
- )
+ assert run_yaml_test(tax_benefit_system, "test_with_reform.yaml") == EXIT_OK
def test_with_extension(tax_benefit_system):
- assert (
- run_yaml_test(tax_benefit_system, "test_with_extension.yaml")
- == EXIT_OK
- )
+ assert run_yaml_test(tax_benefit_system, "test_with_extension.yaml") == EXIT_OK
def test_with_anchors(tax_benefit_system):
- assert (
- run_yaml_test(tax_benefit_system, "test_with_anchors.yaml") == EXIT_OK
- )
+ assert run_yaml_test(tax_benefit_system, "test_with_anchors.yaml") == EXIT_OK
def test_run_tests_from_directory_fail(tax_benefit_system):
- assert (
- run_yaml_test(tax_benefit_system, yaml_tests_dir) == EXIT_TESTSFAILED
- )
+ assert run_yaml_test(tax_benefit_system, yaml_tests_dir) == EXIT_TESTSFAILED
def test_name_filter(tax_benefit_system):
diff --git a/tests/core/tools/test_google_cloud.py b/tests/core/tools/test_google_cloud.py
index e9ba42ea1..ccb0403ba 100644
--- a/tests/core/tools/test_google_cloud.py
+++ b/tests/core/tools/test_google_cloud.py
@@ -8,17 +8,13 @@ def test_basic_url(self):
assert (bucket, file_path, version) == ("my-bucket", "file.h5", None)
def test_subdirectory_url(self):
- bucket, file_path, version = parse_gs_url(
- "gs://my-bucket/data/2024/file.h5"
- )
+ bucket, file_path, version = parse_gs_url("gs://my-bucket/data/2024/file.h5")
assert bucket == "my-bucket"
assert file_path == "data/2024/file.h5"
assert version is None
def test_url_with_version(self):
- bucket, file_path, version = parse_gs_url(
- "gs://my-bucket/file.h5@12345"
- )
+ bucket, file_path, version = parse_gs_url("gs://my-bucket/file.h5@12345")
assert (file_path, version) == ("file.h5", "12345")
def test_subdirectory_with_version(self):
@@ -29,9 +25,7 @@ def test_subdirectory_with_version(self):
assert (file_path, version) == ("path/to/file.h5", "67890")
def test_deep_subdirectory(self):
- bucket, file_path, version = parse_gs_url(
- "gs://my-bucket/a/b/c/d/e/file.h5"
- )
+ bucket, file_path, version = parse_gs_url("gs://my-bucket/a/b/c/d/e/file.h5")
assert file_path == "a/b/c/d/e/file.h5"
def test_invalid_url_no_gs_prefix(self):
diff --git a/tests/core/tools/test_hugging_face.py b/tests/core/tools/test_hugging_face.py
index d34175a95..2a68682e2 100644
--- a/tests/core/tools/test_hugging_face.py
+++ b/tests/core/tools/test_hugging_face.py
@@ -26,9 +26,7 @@ def test_download_public_repo(self):
) as mock_model_info:
# Create mock ModelInfo object emulating public repo
test_id = 0
- mock_model_info.return_value = ModelInfo(
- id=test_id, private=False
- )
+ mock_model_info.return_value = ModelInfo(id=test_id, private=False)
download_huggingface_dataset(
test_repo, test_filename, test_version, test_dir
@@ -114,9 +112,7 @@ class TestGetOrPromptHfToken:
def test_get_token_from_environment(self):
"""Test retrieving token when it exists in environment variables"""
test_token = "test_token_123"
- with patch.dict(
- os.environ, {"HUGGING_FACE_TOKEN": test_token}, clear=True
- ):
+ with patch.dict(os.environ, {"HUGGING_FACE_TOKEN": test_token}, clear=True):
result = get_or_prompt_hf_token()
assert result == test_token
@@ -193,9 +189,7 @@ def test_environment_variable_persistence(self):
class TestParseHfUrl:
def test_basic_url(self):
- owner, repo, file_path, version = parse_hf_url(
- "hf://owner/repo/file.h5"
- )
+ owner, repo, file_path, version = parse_hf_url("hf://owner/repo/file.h5")
assert (owner, repo, file_path, version) == (
"owner",
"repo",
@@ -213,9 +207,7 @@ def test_subdirectory_url(self):
assert version is None
def test_url_with_version(self):
- owner, repo, file_path, version = parse_hf_url(
- "hf://owner/repo/file.h5@v1.0"
- )
+ owner, repo, file_path, version = parse_hf_url("hf://owner/repo/file.h5@v1.0")
assert (file_path, version) == ("file.h5", "v1.0")
def test_subdirectory_with_version(self):
diff --git a/tests/core/tools/test_runner/test_yaml_runner.py b/tests/core/tools/test_runner/test_yaml_runner.py
index 48ddf65dd..72e3faadb 100644
--- a/tests/core/tools/test_runner/test_yaml_runner.py
+++ b/tests/core/tools/test_runner/test_yaml_runner.py
@@ -119,21 +119,15 @@ def test_reforms_formats():
baseline = TaxBenefitSystem()
lonely_reform_tbs = _get_tax_benefit_system(baseline, "lonely_reform", [])
- list_lonely_reform_tbs = _get_tax_benefit_system(
- baseline, ["lonely_reform"], []
- )
+ list_lonely_reform_tbs = _get_tax_benefit_system(baseline, ["lonely_reform"], [])
assert lonely_reform_tbs == list_lonely_reform_tbs
def test_reforms_order():
baseline = TaxBenefitSystem()
- abba_tax_benefit_system = _get_tax_benefit_system(
- baseline, ["ab", "ba"], []
- )
- baab_tax_benefit_system = _get_tax_benefit_system(
- baseline, ["ba", "ab"], []
- )
+ abba_tax_benefit_system = _get_tax_benefit_system(baseline, ["ab", "ba"], [])
+ baab_tax_benefit_system = _get_tax_benefit_system(baseline, ["ba", "ab"], [])
assert (
abba_tax_benefit_system != baab_tax_benefit_system
) # keep reforms order in cache
@@ -150,9 +144,7 @@ def test_tax_benefit_systems_with_extensions_cache():
def test_extensions_formats():
baseline = TaxBenefitSystem()
- lonely_extension_tbs = _get_tax_benefit_system(
- baseline, [], "lonely_extension"
- )
+ lonely_extension_tbs = _get_tax_benefit_system(baseline, [], "lonely_extension")
list_lonely_extension_tbs = _get_tax_benefit_system(
baseline, [], ["lonely_extension"]
)
diff --git a/tests/core/variables/test_annualize.py b/tests/core/variables/test_annualize.py
index 61f8168b8..eff500e4e 100644
--- a/tests/core/variables/test_annualize.py
+++ b/tests/core/variables/test_annualize.py
@@ -46,8 +46,7 @@ def test_without_annualize(monthly_variable):
person = PopulationMock(monthly_variable)
yearly_sum = sum(
- person("monthly_variable", month)
- for month in period.get_subperiods(MONTH)
+ person("monthly_variable", month) for month in period.get_subperiods(MONTH)
)
assert monthly_variable.calculation_count == 11
@@ -61,8 +60,7 @@ def test_with_annualize(monthly_variable):
person = PopulationMock(annualized_variable)
yearly_sum = sum(
- person("monthly_variable", month)
- for month in period.get_subperiods(MONTH)
+ person("monthly_variable", month) for month in period.get_subperiods(MONTH)
)
assert monthly_variable.calculation_count == 0
@@ -78,8 +76,7 @@ def test_with_partial_annualize(monthly_variable):
person = PopulationMock(annualized_variable)
yearly_sum = sum(
- person("monthly_variable", month)
- for month in period.get_subperiods(MONTH)
+ person("monthly_variable", month) for month in period.get_subperiods(MONTH)
)
assert monthly_variable.calculation_count == 11
diff --git a/tests/core/variables/test_variables.py b/tests/core/variables/test_variables.py
index 12e5e8acb..d98a5232b 100644
--- a/tests/core/variables/test_variables.py
+++ b/tests/core/variables/test_variables.py
@@ -42,9 +42,7 @@ def vectorize(individu, number):
return individu.filled_array(number)
-def check_error_at_add_variable(
- tax_benefit_system, variable, error_message_prefix
-):
+def check_error_at_add_variable(tax_benefit_system, variable, error_message_prefix):
try:
tax_benefit_system.add_variable(variable)
except ValueError as e:
@@ -107,9 +105,7 @@ def test_variable__strange_end_attribute():
)
# Check that Error at variable adding prevents it from registration in the taxbenefitsystem.
- assert not tax_benefit_system.variables.get(
- "variable__strange_end_attribute"
- )
+ assert not tax_benefit_system.variables.get("variable__strange_end_attribute")
# end, no formula
@@ -136,12 +132,8 @@ def test_variable__end_attribute_set_input(simulation):
month_after_end = "1990-01"
simulation.set_input("variable__end_attribute", month_before_end, 10)
simulation.set_input("variable__end_attribute", month_after_end, 10)
- assert (
- simulation.calculate("variable__end_attribute", month_before_end) == 10
- )
- assert (
- simulation.calculate("variable__end_attribute", month_after_end) == 0
- )
+ assert simulation.calculate("variable__end_attribute", month_before_end) == 10
+ assert simulation.calculate("variable__end_attribute", month_after_end) == 0
# end, one formula without date
@@ -170,25 +162,17 @@ def test_formulas_attributes_single_formula():
def test_call__end_attribute__one_simple_formula(simulation):
month = "1979-12"
- assert (
- simulation.calculate("end_attribute__one_simple_formula", month) == 100
- )
+ assert simulation.calculate("end_attribute__one_simple_formula", month) == 100
month = "1989-12"
- assert (
- simulation.calculate("end_attribute__one_simple_formula", month) == 100
- )
+ assert simulation.calculate("end_attribute__one_simple_formula", month) == 100
month = "1990-01"
- assert (
- simulation.calculate("end_attribute__one_simple_formula", month) == 0
- )
+ assert simulation.calculate("end_attribute__one_simple_formula", month) == 0
def test_dates__end_attribute__one_simple_formula():
- variable = tax_benefit_system.variables[
- "end_attribute__one_simple_formula"
- ]
+ variable = tax_benefit_system.variables["end_attribute__one_simple_formula"]
assert variable.end == datetime.date(1989, 12, 31)
assert len(variable.formulas) == 1
@@ -237,28 +221,17 @@ def formula_2000_01_01(individu, period):
def test_call__no_end_attribute__one_formula__start(simulation):
month = "1999-12"
- assert (
- simulation.calculate("no_end_attribute__one_formula__start", month)
- == 0
- )
+ assert simulation.calculate("no_end_attribute__one_formula__start", month) == 0
month = "2000-05"
- assert (
- simulation.calculate("no_end_attribute__one_formula__start", month)
- == 100
- )
+ assert simulation.calculate("no_end_attribute__one_formula__start", month) == 100
month = "2020-01"
- assert (
- simulation.calculate("no_end_attribute__one_formula__start", month)
- == 100
- )
+ assert simulation.calculate("no_end_attribute__one_formula__start", month) == 100
def test_dates__no_end_attribute__one_formula__start():
- variable = tax_benefit_system.variables[
- "no_end_attribute__one_formula__start"
- ]
+ variable = tax_benefit_system.variables["no_end_attribute__one_formula__start"]
assert variable.end is None
assert len(variable.formulas) == 1
@@ -268,7 +241,9 @@ def test_dates__no_end_attribute__one_formula__start():
class no_end_attribute__one_formula__eternity(Variable):
value_type = int
entity = Person
- definition_period = ETERNITY # For this entity, this variable shouldn't evolve through time
+ definition_period = (
+ ETERNITY # For this entity, this variable shouldn't evolve through time
+ )
label = "Variable without end attribute, one dated formula."
def formula_2000_01_01(individu, period):
@@ -281,33 +256,21 @@ def formula_2000_01_01(individu, period):
@mark.xfail()
def test_call__no_end_attribute__one_formula__eternity(simulation):
month = "1999-12"
- assert (
- simulation.calculate("no_end_attribute__one_formula__eternity", month)
- == 0
- )
+ assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 0
# This fails because a definition period of "ETERNITY" caches for all periods
month = "2000-01"
- assert (
- simulation.calculate("no_end_attribute__one_formula__eternity", month)
- == 100
- )
+ assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 100
def test_call__no_end_attribute__one_formula__eternity_before(simulation):
month = "1999-12"
- assert (
- simulation.calculate("no_end_attribute__one_formula__eternity", month)
- == 0
- )
+ assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 0
def test_call__no_end_attribute__one_formula__eternity_after(simulation):
month = "2000-01"
- assert (
- simulation.calculate("no_end_attribute__one_formula__eternity", month)
- == 100
- )
+ assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 100
# formula, different start formats
@@ -339,9 +302,7 @@ def test_formulas_attributes_dated_formulas():
def test_get_formulas():
- variable = tax_benefit_system.variables[
- "no_end_attribute__formulas__start_formats"
- ]
+ variable = tax_benefit_system.variables["no_end_attribute__formulas__start_formats"]
formula_2000 = variable.formulas["2000-01-01"]
formula_2010 = variable.formulas["2010-01-01"]
@@ -355,35 +316,21 @@ def test_get_formulas():
def test_call__no_end_attribute__formulas__start_formats(simulation):
month = "1999-12"
- assert (
- simulation.calculate(
- "no_end_attribute__formulas__start_formats", month
- )
- == 0
- )
+ assert simulation.calculate("no_end_attribute__formulas__start_formats", month) == 0
month = "2000-01"
assert (
- simulation.calculate(
- "no_end_attribute__formulas__start_formats", month
- )
- == 100
+ simulation.calculate("no_end_attribute__formulas__start_formats", month) == 100
)
month = "2009-12"
assert (
- simulation.calculate(
- "no_end_attribute__formulas__start_formats", month
- )
- == 100
+ simulation.calculate("no_end_attribute__formulas__start_formats", month) == 100
)
month = "2010-01"
assert (
- simulation.calculate(
- "no_end_attribute__formulas__start_formats", month
- )
- == 200
+ simulation.calculate("no_end_attribute__formulas__start_formats", month) == 200
)
@@ -428,9 +375,7 @@ def formula_2010_01_01(individu, period):
return vectorize(individu, 200)
-tax_benefit_system.add_variable(
- no_attribute__formulas__different_names__no_overlap
-)
+tax_benefit_system.add_variable(no_attribute__formulas__different_names__no_overlap)
def test_call__no_attribute__formulas__different_names__no_overlap(simulation):
@@ -473,19 +418,13 @@ def formula_2000_01_01(individu, period):
def test_call__end_attribute__one_formula__start(simulation):
month = "1980-01"
- assert (
- simulation.calculate("end_attribute__one_formula__start", month) == 0
- )
+ assert simulation.calculate("end_attribute__one_formula__start", month) == 0
month = "2000-01"
- assert (
- simulation.calculate("end_attribute__one_formula__start", month) == 100
- )
+ assert simulation.calculate("end_attribute__one_formula__start", month) == 100
month = "2002-01"
- assert (
- simulation.calculate("end_attribute__one_formula__start", month) == 0
- )
+ assert simulation.calculate("end_attribute__one_formula__start", month) == 0
# end < formula, start.
@@ -517,7 +456,9 @@ class end_attribute_restrictive__one_formula(Variable):
value_type = int
entity = Person
definition_period = MONTH
- label = "Variable with end attribute, one dated formula and dates intervals overlap."
+ label = (
+ "Variable with end attribute, one dated formula and dates intervals overlap."
+ )
end = "2001-01-01"
def formula_2001_01_01(individu, period):
@@ -529,22 +470,13 @@ def formula_2001_01_01(individu, period):
def test_call__end_attribute_restrictive__one_formula(simulation):
month = "2000-12"
- assert (
- simulation.calculate("end_attribute_restrictive__one_formula", month)
- == 0
- )
+ assert simulation.calculate("end_attribute_restrictive__one_formula", month) == 0
month = "2001-01"
- assert (
- simulation.calculate("end_attribute_restrictive__one_formula", month)
- == 100
- )
+ assert simulation.calculate("end_attribute_restrictive__one_formula", month) == 100
month = "2000-05"
- assert (
- simulation.calculate("end_attribute_restrictive__one_formula", month)
- == 0
- )
+ assert simulation.calculate("end_attribute_restrictive__one_formula", month) == 0
# formulas of different names (without dates overlap on formulas)
@@ -573,20 +505,17 @@ def formula_2010_01_01(individu, period):
def test_call__end_attribute__formulas__different_names(simulation):
month = "2000-01"
assert (
- simulation.calculate("end_attribute__formulas__different_names", month)
- == 100
+ simulation.calculate("end_attribute__formulas__different_names", month) == 100
)
month = "2005-01"
assert (
- simulation.calculate("end_attribute__formulas__different_names", month)
- == 200
+ simulation.calculate("end_attribute__formulas__different_names", month) == 200
)
month = "2010-12"
assert (
- simulation.calculate("end_attribute__formulas__different_names", month)
- == 300
+ simulation.calculate("end_attribute__formulas__different_names", month) == 300
)
diff --git a/tests/fixtures/simulations.py b/tests/fixtures/simulations.py
index f7894e261..7fa203543 100644
--- a/tests/fixtures/simulations.py
+++ b/tests/fixtures/simulations.py
@@ -24,8 +24,6 @@ def make_simulation():
def _simulation(simulation_builder, tax_benefit_system, variables, period):
simulation_builder.set_default_period(period)
- simulation = simulation_builder.build_from_variables(
- tax_benefit_system, variables
- )
+ simulation = simulation_builder.build_from_variables(tax_benefit_system, variables)
return simulation