Skip to content

Commit

Permalink
Incorporate changes from #110
Browse files Browse the repository at this point in the history
* Remove Never type hints from core; mypy complained in test file
  • Loading branch information
glatterf42 committed Aug 22, 2024
1 parent 94a46ec commit dd143c3
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 121 deletions.
10 changes: 2 additions & 8 deletions ixmp4/core/optimization/variable.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import sys
from datetime import datetime
from typing import Any, ClassVar, Iterable

Expand All @@ -10,11 +9,6 @@
from ixmp4.data.abstract import Run
from ixmp4.data.abstract.optimization import Column

if sys.version_info >= (3, 11):
from typing import Never
else:
from typing import NoReturn as Never


class Variable(BaseModelFacade):
_model: VariableModel
Expand Down Expand Up @@ -55,11 +49,11 @@ def marginals(self) -> list:
return self._model.data.get("marginals", [])

@property
def constrained_to_indexsets(self) -> list[str | Never]:
def constrained_to_indexsets(self) -> list[str]:
return [column.indexset.name for column in self._model.columns]

@property
def columns(self) -> list[Column | Never]:
def columns(self) -> list[Column]:
return self._model.columns

@property
Expand Down
65 changes: 30 additions & 35 deletions tests/core/test_optimization_variable.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pandas as pd
import pytest

from ixmp4 import Platform
import ixmp4
from ixmp4.core import IndexSet, OptimizationVariable

from ..utils import all_platforms, create_indexsets_for_run
from ..utils import create_indexsets_for_run


def df_from_list(variables: list):
Expand All @@ -31,11 +31,9 @@ def df_from_list(variables: list):
)


@all_platforms
class TestCoreVariable:
def test_create_variable(self, test_mp, request):
test_mp: Platform = request.getfixturevalue(test_mp) # type: ignore
run = test_mp.runs.create("Model", "Scenario")
def test_create_variable(self, platform: ixmp4.Platform):
run = platform.runs.create("Model", "Scenario")

# Test creation without indexset
variable = run.optimization.variables.create("Variable")
Expand All @@ -49,8 +47,8 @@ def test_create_variable(self, test_mp, request):

# Test creation with indexset
indexset, indexset_2 = tuple(
IndexSet(_backend=test_mp.backend, _model=model)
for model in create_indexsets_for_run(platform=test_mp, run_id=run.id)
IndexSet(_backend=platform.backend, _model=model)
for model in create_indexsets_for_run(platform=platform, run_id=run.id)
)
variable_2 = run.optimization.variables.create(
name="Variable 2",
Expand Down Expand Up @@ -105,11 +103,10 @@ def test_create_variable(self, test_mp, request):
assert variable_4.columns[0].dtype == "object"
assert variable_4.columns[1].dtype == "int64"

def test_get_variable(self, test_mp, request):
test_mp: Platform = request.getfixturevalue(test_mp) # type: ignore
run = test_mp.runs.create("Model", "Scenario")
def test_get_variable(self, platform: ixmp4.Platform):
run = platform.runs.create("Model", "Scenario")
(indexset,) = create_indexsets_for_run(
platform=test_mp, run_id=run.id, amount=1
platform=platform, run_id=run.id, amount=1
)
_ = run.optimization.variables.create(
name="Variable", constrained_to_indexsets=[indexset.name]
Expand All @@ -127,12 +124,11 @@ def test_get_variable(self, test_mp, request):
with pytest.raises(OptimizationVariable.NotFound):
_ = run.optimization.variables.get("Variable 2")

def test_variable_add_data(self, test_mp, request):
test_mp: Platform = request.getfixturevalue(test_mp) # type: ignore
run = test_mp.runs.create("Model", "Scenario")
def test_variable_add_data(self, platform: ixmp4.Platform):
run = platform.runs.create("Model", "Scenario")
indexset, indexset_2 = tuple(
IndexSet(_backend=test_mp.backend, _model=model)
for model in create_indexsets_for_run(platform=test_mp, run_id=run.id)
IndexSet(_backend=platform.backend, _model=model)
for model in create_indexsets_for_run(platform=platform, run_id=run.id)
)
indexset.add(elements=["foo", "bar", ""])
indexset_2.add(elements=[1, 2, 3])
Expand Down Expand Up @@ -248,25 +244,26 @@ def test_variable_add_data(self, test_mp, request):
variable_3.add(data=test_data_4)
test_data_5 = test_data_3.copy()
for key, value in test_data_4.items():
test_data_5[key].extend(value)
test_data_5[key].extend(value) # type: ignore
assert variable_3.data == test_data_5
assert variable_3.levels == test_data_5["levels"]
assert variable_3.marginals == test_data_5["marginals"]

def test_list_variable(self, test_mp, request):
test_mp: Platform = request.getfixturevalue(test_mp) # type: ignore
run = test_mp.runs.create("Model", "Scenario")
indexset, indexset_2 = create_indexsets_for_run(platform=test_mp, run_id=run.id)
def test_list_variable(self, platform: ixmp4.Platform):
run = platform.runs.create("Model", "Scenario")
indexset, indexset_2 = create_indexsets_for_run(
platform=platform, run_id=run.id
)
variable = run.optimization.variables.create(
"Variable", constrained_to_indexsets=[indexset.name]
)
variable_2 = run.optimization.variables.create(
"Variable 2", constrained_to_indexsets=[indexset_2.name]
)
# Create new run to test listing variables for specific run
run_2 = test_mp.runs.create("Model", "Scenario")
run_2 = platform.runs.create("Model", "Scenario")
(indexset,) = create_indexsets_for_run(
platform=test_mp, run_id=run_2.id, amount=1
platform=platform, run_id=run_2.id, amount=1
)
run_2.optimization.variables.create(
"Variable", constrained_to_indexsets=[indexset.name]
Expand All @@ -282,12 +279,11 @@ def test_list_variable(self, test_mp, request):
]
assert not (set(expected_id) ^ set(list_id))

def test_tabulate_variable(self, test_mp, request):
test_mp: Platform = request.getfixturevalue(test_mp) # type: ignore
run = test_mp.runs.create("Model", "Scenario")
def test_tabulate_variable(self, platform: ixmp4.Platform):
run = platform.runs.create("Model", "Scenario")
indexset, indexset_2 = tuple(
IndexSet(_backend=test_mp.backend, _model=model)
for model in create_indexsets_for_run(platform=test_mp, run_id=run.id)
IndexSet(_backend=platform.backend, _model=model)
for model in create_indexsets_for_run(platform=platform, run_id=run.id)
)
variable = run.optimization.variables.create(
name="Variable",
Expand All @@ -298,9 +294,9 @@ def test_tabulate_variable(self, test_mp, request):
constrained_to_indexsets=[indexset.name, indexset_2.name],
)
# Create new run to test tabulating variables for specific run
run_2 = test_mp.runs.create("Model", "Scenario")
run_2 = platform.runs.create("Model", "Scenario")
(indexset_3,) = create_indexsets_for_run(
platform=test_mp, run_id=run_2.id, amount=1
platform=platform, run_id=run_2.id, amount=1
)
run_2.optimization.variables.create(
"Variable", constrained_to_indexsets=[indexset_3.name]
Expand Down Expand Up @@ -332,11 +328,10 @@ def test_tabulate_variable(self, test_mp, request):
run.optimization.variables.tabulate(),
)

def test_variable_docs(self, test_mp, request):
test_mp: Platform = request.getfixturevalue(test_mp) # type: ignore
run = test_mp.runs.create("Model", "Scenario")
def test_variable_docs(self, platform: ixmp4.Platform):
run = platform.runs.create("Model", "Scenario")
(indexset,) = create_indexsets_for_run(
platform=test_mp, run_id=run.id, amount=1
platform=platform, run_id=run.id, amount=1
)
variable_1 = run.optimization.variables.create(
"Variable 1", constrained_to_indexsets=[indexset.name]
Expand Down
Loading

0 comments on commit dd143c3

Please sign in to comment.