Ensure that dag_id, run_id and execution_date are non-null on DagRun #18804

Merged · 9 commits · Oct 8, 2021
@@ -41,10 +41,17 @@
depends_on = None


-def _mssql_datetime():
-    from sqlalchemy.dialects import mssql
-
-    return mssql.DATETIME2(precision=6)
+def _datetime_type(dialect_name):
+    if dialect_name == "mssql":
+        from sqlalchemy.dialects import mssql
+
+        return mssql.DATETIME2(precision=6)
+    elif dialect_name == "mysql":
+        from sqlalchemy.dialects import mysql
+
+        return mysql.DATETIME(fsp=6)
+
+    return sa.TIMESTAMP(timezone=True)


# Just Enough Table to run the conditions for update.
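(The stub definitions themselves are collapsed in this view. They are lightweight table() constructs rather than the ORM models, carrying only the columns the UPDATE conditions below reference. A minimal sketch of their plausible shape — the column lists are inferred from the migration, everything else is assumption:

import sqlalchemy as sa
from sqlalchemy.sql import column, table

# "Just enough" stubs: no constraints, no defaults -- only the columns the
# UPDATE conditions and backfills below actually reference.
dag_run = table(
    'dag_run',
    column('id', sa.Integer),
    column('dag_id', sa.String),
    column('run_id', sa.String),
    column('execution_date', sa.TIMESTAMP),
)
task_instance = table(
    'task_instance',
    column('task_id', sa.String),
    column('dag_id', sa.String),
    column('run_id', sa.String),
    column('execution_date', sa.TIMESTAMP),
)
task_reschedule = table(
    'task_reschedule',
    column('task_id', sa.String),
    column('dag_id', sa.String),
    column('run_id', sa.String),
    column('execution_date', sa.TIMESTAMP),
)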
@@ -101,21 +108,30 @@ def upgrade():
"""Apply TaskInstance keyed to DagRun"""
conn = op.get_bind()
dialect_name = conn.dialect.name
dt_type = _datetime_type(dialect_name)

-    run_id_col_type = sa.String(length=ID_LEN, **COLLATION_ARGS)
+    string_id_col_type = sa.String(length=ID_LEN, **COLLATION_ARGS)

if dialect_name == 'sqlite':
naming_convention = {
"uq": "%(table_name)s_%(column_0_N_name)s_key",
}
-        with op.batch_alter_table('dag_run', naming_convention=naming_convention, recreate="always"):
-            # The naming_convention force the previously un-named UNIQUE constraints to have the right name --
-            # but we still need to enter the context manager to trigger it
-            pass
+        # The naming_convention forces the previously un-named UNIQUE constraints to have the right name
+        with op.batch_alter_table(
+            'dag_run', naming_convention=naming_convention, recreate="always"
+        ) as batch_op:
+            batch_op.alter_column('dag_id', existing_type=string_id_col_type, nullable=False)
+            batch_op.alter_column('run_id', existing_type=string_id_col_type, nullable=False)
+            batch_op.alter_column('execution_date', existing_type=dt_type, nullable=False)
elif dialect_name == 'mysql':
with op.batch_alter_table('dag_run') as batch_op:
-            batch_op.alter_column('dag_id', existing_type=sa.String(length=ID_LEN), type_=run_id_col_type)
-            batch_op.alter_column('run_id', existing_type=sa.String(length=ID_LEN), type_=run_id_col_type)
+            batch_op.alter_column(
+                'dag_id', existing_type=sa.String(length=ID_LEN), type_=string_id_col_type, nullable=False
+            )
+            batch_op.alter_column(
+                'run_id', existing_type=sa.String(length=ID_LEN), type_=string_id_col_type, nullable=False
+            )
+            batch_op.alter_column('execution_date', existing_type=dt_type, nullable=False)
batch_op.drop_constraint('dag_id', 'unique')
batch_op.drop_constraint('dag_id_2', 'unique')
batch_op.create_unique_constraint(
@@ -124,16 +140,47 @@
batch_op.create_unique_constraint('dag_run_dag_id_run_id_key', ['dag_id', 'run_id'])
elif dialect_name == 'mssql':

-        # _Somehow_ mssql was missing these constraints entirely!
with op.batch_alter_table('dag_run') as batch_op:
batch_op.drop_index('idx_not_null_dag_id_execution_date')
batch_op.drop_index('idx_not_null_dag_id_run_id')

batch_op.drop_index('dag_id_state')
batch_op.drop_index('idx_dag_run_dag_id')
batch_op.drop_index('idx_dag_run_running_dags')
batch_op.drop_index('idx_dag_run_queued_dags')

batch_op.alter_column('dag_id', existing_type=string_id_col_type, nullable=False)
batch_op.alter_column('execution_date', existing_type=dt_type, nullable=False)
batch_op.alter_column('run_id', existing_type=string_id_col_type, nullable=False)

# _Somehow_ mssql was missing these constraints entirely
batch_op.create_unique_constraint(
'dag_run_dag_id_execution_date_key', ['dag_id', 'execution_date']
)
batch_op.create_unique_constraint('dag_run_dag_id_run_id_key', ['dag_id', 'run_id'])

batch_op.create_index('dag_id_state', ['dag_id', 'state'], unique=False)
batch_op.create_index('idx_dag_run_dag_id', ['dag_id'])
batch_op.create_index(
'idx_dag_run_running_dags',
["state", "dag_id"],
mssql_where=sa.text("state='running'"),
)
batch_op.create_index(
'idx_dag_run_queued_dags',
["state", "dag_id"],
mssql_where=sa.text("state='queued'"),
)
else:
# Make sure DagRun id columns are non-nullable
with op.batch_alter_table('dag_run', schema=None) as batch_op:
batch_op.alter_column('dag_id', existing_type=string_id_col_type, nullable=False)
batch_op.alter_column('execution_date', existing_type=dt_type, nullable=False)
batch_op.alter_column('run_id', existing_type=string_id_col_type, nullable=False)

    # First create the column as nullable
-    op.add_column('task_instance', sa.Column('run_id', type_=run_id_col_type, nullable=True))
-    op.add_column('task_reschedule', sa.Column('run_id', type_=run_id_col_type, nullable=True))
+    op.add_column('task_instance', sa.Column('run_id', type_=string_id_col_type, nullable=True))
+    op.add_column('task_reschedule', sa.Column('run_id', type_=string_id_col_type, nullable=True))

# Then update the new column by selecting the right value from DagRun
update_query = _multi_table_update(dialect_name, task_instance, task_instance.c.run_id)
@@ -147,7 +194,9 @@ def upgrade():
op.execute(update_query)

with op.batch_alter_table('task_reschedule', schema=None) as batch_op:
-        batch_op.alter_column('run_id', existing_type=run_id_col_type, existing_nullable=True, nullable=False)
+        batch_op.alter_column(
+            'run_id', existing_type=string_id_col_type, existing_nullable=True, nullable=False
+        )

batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', 'foreignkey')
if dialect_name == "mysql":
@@ -157,7 +206,14 @@

with op.batch_alter_table('task_instance', schema=None) as batch_op:
# Then make it non-nullable
-        batch_op.alter_column('run_id', existing_type=run_id_col_type, existing_nullable=True, nullable=False)
+        batch_op.alter_column(
+            'run_id', existing_type=string_id_col_type, existing_nullable=True, nullable=False
+        )
+
+        batch_op.alter_column(
+            'dag_id', existing_type=string_id_col_type, existing_nullable=True, nullable=False
+        )
+        batch_op.alter_column('execution_date', existing_type=dt_type, existing_nullable=True, nullable=False)

# TODO: Is this right for non-postgres?
if dialect_name == 'mssql':
@@ -212,14 +268,11 @@ def upgrade():
def downgrade():
"""Unapply TaskInstance keyed to DagRun"""
dialect_name = op.get_bind().dialect.name
-    if dialect_name == "mssql":
-        col_type = _mssql_datetime()
-    else:
-        col_type = sa.TIMESTAMP(timezone=True)
-
-    op.add_column('task_instance', sa.Column('execution_date', col_type, nullable=True))
-    op.add_column('task_reschedule', sa.Column('execution_date', col_type, nullable=True))
+    dt_type = _datetime_type(dialect_name)
+    string_id_col_type = sa.String(length=ID_LEN, **COLLATION_ARGS)
+
+    op.add_column('task_instance', sa.Column('execution_date', dt_type, nullable=True))
+    op.add_column('task_reschedule', sa.Column('execution_date', dt_type, nullable=True))

update_query = _multi_table_update(dialect_name, task_instance, task_instance.c.execution_date)
op.execute(update_query)
@@ -228,18 +281,17 @@ def downgrade():
op.execute(update_query)

with op.batch_alter_table('task_reschedule', schema=None) as batch_op:
-        batch_op.alter_column(
-            'execution_date', existing_type=col_type, existing_nullable=True, nullable=False
-        )
+        batch_op.alter_column('execution_date', existing_type=dt_type, existing_nullable=True, nullable=False)

# Can't drop PK index while there is a FK referencing it
batch_op.drop_constraint('task_reschedule_ti_fkey')
batch_op.drop_constraint('task_reschedule_dr_fkey')
batch_op.drop_index('idx_task_reschedule_dag_task_run')

with op.batch_alter_table('task_instance', schema=None) as batch_op:
-        batch_op.alter_column(
-            'execution_date', existing_type=col_type, existing_nullable=True, nullable=False
-        )
+        batch_op.alter_column('execution_date', existing_type=dt_type, existing_nullable=True, nullable=False)
+        batch_op.alter_column(
+            'dag_id', existing_type=string_id_col_type, existing_nullable=True, nullable=True
+        )

batch_op.drop_constraint('task_instance_pkey', type_='primary')
@@ -269,6 +321,49 @@ def downgrade():
ondelete='CASCADE',
)

if dialect_name == "mssql":

with op.batch_alter_table('dag_run', schema=None) as batch_op:
batch_op.drop_constraint('dag_run_dag_id_execution_date_key', 'unique')
batch_op.drop_constraint('dag_run_dag_id_run_id_key', 'unique')
batch_op.drop_index('dag_id_state')
batch_op.drop_index('idx_dag_run_running_dags')
batch_op.drop_index('idx_dag_run_queued_dags')

batch_op.alter_column('dag_id', existing_type=string_id_col_type, nullable=True)
batch_op.alter_column('execution_date', existing_type=dt_type, nullable=True)
batch_op.alter_column('run_id', existing_type=string_id_col_type, nullable=True)

batch_op.create_index('dag_id_state', ['dag_id', 'state'], unique=False)
batch_op.create_index('idx_dag_run_dag_id', ['dag_id'])
batch_op.create_index(
'idx_dag_run_running_dags',
["state", "dag_id"],
mssql_where=sa.text("state='running'"),
)
batch_op.create_index(
'idx_dag_run_queued_dags',
["state", "dag_id"],
mssql_where=sa.text("state='queued'"),
)
op.execute(
"""CREATE UNIQUE NONCLUSTERED INDEX idx_not_null_dag_id_execution_date
ON dag_run(dag_id,execution_date)
WHERE dag_id IS NOT NULL and execution_date is not null"""
)
op.execute(
"""CREATE UNIQUE NONCLUSTERED INDEX idx_not_null_dag_id_run_id
ON dag_run(dag_id,run_id)
WHERE dag_id IS NOT NULL and run_id is not null"""
)
else:
        with op.batch_alter_table('dag_run', schema=None) as batch_op:
            batch_op.drop_index('dag_id_state')
            batch_op.alter_column('run_id', existing_type=sa.VARCHAR(length=250), nullable=True)
            batch_op.alter_column('execution_date', existing_type=dt_type, nullable=True)
            batch_op.alter_column('dag_id', existing_type=sa.VARCHAR(length=250), nullable=True)
            batch_op.create_index('dag_id_state', ['dag_id', 'state'], unique=False)


def _multi_table_update(dialect_name, target, column):
condition = dag_run.c.dag_id == target.c.dag_id
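(The body of _multi_table_update is collapsed above; only its first condition is visible. A sketch of how such a dialect-aware backfill UPDATE is typically assembled — the condition chaining and the SQLite special case are assumptions consistent with the visible signature, not a verbatim copy of the collapsed code:

from sqlalchemy import and_, select

def _multi_table_update(dialect_name, target, column):
    # Join each target row (task_instance / task_reschedule) to its DagRun.
    condition = dag_run.c.dag_id == target.c.dag_id
    if column is target.c.run_id:
        # Backfilling run_id: rows still match their DagRun on execution_date.
        condition = and_(condition, dag_run.c.execution_date == target.c.execution_date)
    else:
        # Backfilling execution_date: rows already match on run_id.
        condition = and_(condition, dag_run.c.run_id == target.c.run_id)

    if dialect_name == "sqlite":
        # SQLite cannot join another table in UPDATE, so fall back to a
        # correlated scalar subquery evaluated per row.
        sub_q = select([dag_run.c[column.name]]).where(condition)
        return target.update().values({column: sub_q})
    # MySQL, Postgres and MSSQL accept a multi-table UPDATE directly.
    return target.update().where(condition).values({column: dag_run.c[column.name]})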
6 changes: 3 additions & 3 deletions airflow/models/dagrun.py
@@ -78,13 +78,13 @@ class DagRun(Base, LoggingMixin):
__NO_VALUE = object()

id = Column(Integer, primary_key=True)
-    dag_id = Column(String(ID_LEN, **COLLATION_ARGS))
+    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
    queued_at = Column(UtcDateTime)
-    execution_date = Column(UtcDateTime, default=timezone.utcnow)
+    execution_date = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    start_date = Column(UtcDateTime)
    end_date = Column(UtcDateTime)
    _state = Column('state', String(50), default=State.QUEUED)
-    run_id = Column(String(ID_LEN, **COLLATION_ARGS))
+    run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
creating_job_id = Column(Integer)
external_trigger = Column(Boolean, default=True)
run_type = Column(String(50), nullable=False)
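(With nullable=False on dag_id, run_id and execution_date, a DagRun missing any of these is now rejected by the database at flush time instead of persisting as a half-formed row. A minimal illustration of the new failure mode — the session helper and values are illustrative, not part of the PR:

from sqlalchemy.exc import IntegrityError

from airflow.models import DagRun
from airflow.utils.session import create_session

try:
    with create_session() as session:
        # run_id is deliberately omitted, so it would flush as NULL.
        session.add(DagRun(dag_id="example_dag", run_type="manual"))
except IntegrityError:
    print("rejected: dag_run.run_id may no longer be NULL")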
6 changes: 3 additions & 3 deletions airflow/models/taskinstance.py
@@ -324,9 +324,9 @@ class TaskInstance(Base, LoggingMixin):

__tablename__ = "task_instance"

-    task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
-    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
-    run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
+    task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True, nullable=False)
+    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True, nullable=False)
+    run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True, nullable=False)
start_date = Column(UtcDateTime)
end_date = Column(UtcDateTime)
duration = Column(Float)
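(Because run_id replaces execution_date in the primary key, a task instance is now addressed by the triple (task_id, dag_id, run_id). A sketch of a primary-key lookup under the new key, assuming a SQLAlchemy session and made-up identifiers:

from airflow.models import TaskInstance
from airflow.utils.session import create_session

with create_session() as session:
    # Query.get() takes the composite key in column-definition order:
    # (task_id, dag_id, run_id).
    ti = session.query(TaskInstance).get(
        ("extract", "example_dag", "manual__2021-10-08T00:00:00+00:00")
    )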
32 changes: 32 additions & 0 deletions airflow/utils/db.py
@@ -697,6 +697,37 @@ def check_conn_type_null(session=None) -> Iterable[str]:
)


def check_run_id_null(session) -> Iterable[str]:
import sqlalchemy.schema

metadata = sqlalchemy.schema.MetaData(session.bind)
try:
metadata.reflect(only=["dag_run"])
except exc.InvalidRequestError:
# Table doesn't exist -- empty db
return

dag_run = metadata.tables["dag_run"]

for colname in ('run_id', 'dag_id', 'execution_date'):

col = dag_run.columns.get(colname)
if col is None:
continue

if not col.nullable:
continue

num = session.query(dag_run).filter(col.is_(None)).count()
if num > 0:
yield (
f'The {dag_run.name} table has {num} row{"s" if num != 1 else ""} with a NULL value in '
f'{col.name!r}. You must manually correct this problem (possibly by deleting the problem '
'rows).'
)
session.rollback()


def check_task_tables_without_matching_dagruns(session) -> Iterable[str]:
from itertools import chain

@@ -762,6 +793,7 @@ def _check_migration_errors(session=None) -> Iterable[str]:
for check_fn in (
check_conn_id_duplicates,
check_conn_type_null,
check_run_id_null,
check_task_tables_without_matching_dagruns,
):
yield from check_fn(session)
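(Each check function is a generator yielding human-readable error strings, and _check_migration_errors chains them, so a caller can refuse to migrate while anything is yielded. A sketch of that consumption pattern — the wrapper name and exit handling are assumptions; the actual upgrade-command wiring is outside this diff:

import sys

def _abort_upgrade_on_errors(session):
    # Materialize every yielded message; an empty list means the db is clean.
    errors = list(_check_migration_errors(session=session))
    if errors:
        for err in errors:
            print(err, file=sys.stderr)
        # Leave the schema untouched until the offending rows are corrected.
        sys.exit(1)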
9 changes: 6 additions & 3 deletions tests/api_connexion/schemas/test_dag_run_schema.py
@@ -51,6 +51,7 @@ class TestDAGRunSchema(TestDAGRunBase):
@provide_session
def test_serialize(self, session):
dagrun_model = DagRun(
dag_id="my-dag-run",
run_id="my-dag-run",
state='running',
run_type=DagRunType.MANUAL.value,
@@ -64,7 +65,7 @@ def test_serialize(self, session):
deserialized_dagrun = dagrun_schema.dump(dagrun_model)

assert deserialized_dagrun == {
"dag_id": None,
"dag_id": "my-dag-run",
"dag_run_id": "my-dag-run",
"end_date": None,
"state": "running",
@@ -128,6 +129,7 @@ class TestDagRunCollection(TestDAGRunBase):
@provide_session
def test_serialize(self, session):
dagrun_model_1 = DagRun(
dag_id="my-dag-run",
run_id="my-dag-run",
state='running',
execution_date=timezone.parse(self.default_time),
@@ -136,6 +138,7 @@ def test_serialize(self, session):
conf='{"start": "stop"}',
)
dagrun_model_2 = DagRun(
dag_id="my-dag-run",
run_id="my-dag-run-2",
state='running',
execution_date=timezone.parse(self.second_time),
@@ -150,7 +153,7 @@
assert deserialized_dagruns == {
"dag_runs": [
{
"dag_id": None,
"dag_id": "my-dag-run",
"dag_run_id": "my-dag-run",
"end_date": None,
"execution_date": self.default_time,
@@ -161,7 +164,7 @@
"conf": {"start": "stop"},
},
{
"dag_id": None,
"dag_id": "my-dag-run",
"dag_run_id": "my-dag-run-2",
"end_date": None,
"state": "running",
3 changes: 3 additions & 0 deletions tests/models/test_dagrun.py

@@ -114,6 +114,7 @@ def test_dagrun_find(self):
dag_id1 = "test_dagrun_find_externally_triggered"
dag_run = models.DagRun(
dag_id=dag_id1,
run_id=dag_id1,
run_type=DagRunType.MANUAL,
execution_date=now,
start_date=now,
@@ -125,6 +126,7 @@
dag_id2 = "test_dagrun_find_not_externally_triggered"
dag_run = models.DagRun(
dag_id=dag_id2,
run_id=dag_id2,
run_type=DagRunType.MANUAL,
execution_date=now,
start_date=now,
@@ -532,6 +534,7 @@ def test_get_task_instance_on_empty_dagrun(self):
# don't want
dag_run = models.DagRun(
dag_id=dag.dag_id,
run_id="test_get_task_instance_on_empty_dagrun",
run_type=DagRunType.MANUAL,
execution_date=now,
start_date=now,