Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove select_column option in TaskInstance.get_task_instance #38571

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
)
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import reconstructor, relationship
from sqlalchemy.orm import lazyload, reconstructor, relationship
from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
from sqlalchemy.sql.expression import case, select

Expand Down Expand Up @@ -521,7 +521,6 @@ def _refresh_from_db(
task_id=task_instance.task_id,
run_id=task_instance.run_id,
map_index=task_instance.map_index,
select_columns=True,
lock_for_update=lock_for_update,
session=session,
)
Expand All @@ -532,8 +531,7 @@ def _refresh_from_db(
task_instance.end_date = ti.end_date
task_instance.duration = ti.duration
task_instance.state = ti.state
# Since we selected columns, not the object, this is the raw value
task_instance.try_number = ti.try_number
task_instance.try_number = ti._try_number # private attr to get value unaltered by accessor
task_instance.max_tries = ti.max_tries
task_instance.hostname = ti.hostname
task_instance.unixname = ti.unixname
Expand Down Expand Up @@ -911,7 +909,7 @@ def _get_try_number(*, task_instance: TaskInstance | TaskInstancePydantic):

:meta private:
"""
if task_instance.state == TaskInstanceState.RUNNING.RUNNING:
if task_instance.state == TaskInstanceState.RUNNING:
return task_instance._try_number
return task_instance._try_number + 1

Expand Down Expand Up @@ -1792,18 +1790,18 @@ def get_task_instance(
run_id: str,
task_id: str,
map_index: int,
select_columns: bool = False,
lock_for_update: bool = False,
session: Session = NEW_SESSION,
) -> TaskInstance | TaskInstancePydantic | None:
query = (
session.query(*TaskInstance.__table__.columns) if select_columns else session.query(TaskInstance)
)
query = query.filter_by(
dag_id=dag_id,
run_id=run_id,
task_id=task_id,
map_index=map_index,
session.query(TaskInstance)
.options(lazyload("dag_run")) # lazy load dag run to avoid locking it
.filter_by(
dag_id=dag_id,
run_id=run_id,
task_id=task_id,
map_index=map_index,
)
)

if lock_for_update:
Expand Down
13 changes: 13 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4561,3 +4561,16 @@ def test_taskinstance_with_note(create_task_instance, session):

assert session.query(TaskInstance).filter_by(**filter_kwargs).one_or_none() is None
assert session.query(TaskInstanceNote).filter_by(**filter_kwargs).one_or_none() is None


def test__refresh_from_db_should_not_increment_try_number(dag_maker, session):
with dag_maker():
BashOperator(task_id="hello", bash_command="hi")
dag_maker.create_dagrun(state="success")
ti = session.scalar(select(TaskInstance))
assert ti.task_id == "hello" # just to confirm...
assert ti.try_number == 1 # starts out as 1
ti.refresh_from_db()
assert ti.try_number == 1 # stays 1
ti.refresh_from_db()
assert ti.try_number == 1 # stays 1