Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New User Journey (#388) #394

Merged
merged 1 commit into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""alter segment table add number_of_farmers column

Revision ID: a5f7676a8a23
Revises: 4bb17002ebca
Create Date: 2025-01-14 07:21:08.343305

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "a5f7676a8a23"
down_revision: Union[str, None] = "4bb17002ebca"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.add_column(
"segment", sa.Column("number_of_farmers", sa.Integer(), nullable=True)
)


def downgrade() -> None:
op.drop_column("segment", "number_of_farmers")
33 changes: 33 additions & 0 deletions backend/db/crud_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ def add_case(session: Session, payload: CaseBase, user: User) -> CaseDict:
for tag_id in payload.tags:
tag = CaseTag(tag=tag_id)
case.case_tags.append(tag)
# store segments
if payload.segments:
for segment in payload.segments:
new_segment = Segment(
name=segment.name, number_of_farmers=segment.number_of_farmers
)
case.case_segments.append(new_segment)
session.add(case)
session.commit()
session.flush()
Expand Down Expand Up @@ -245,6 +252,32 @@ def update_case(session: Session, id: int, payload: CaseBase) -> CaseDict:
volume_measurement_unit=val.volume_measurement_unit,
)
case.case_commodities.append(case_commodity)
# handle update segments
if payload.segments:
for segment in payload.segments:
prev_segment = (
session.query(Segment)
.filter(
and_(
Segment.case == case.id,
Segment.id == segment.id,
)
)
.first()
)
if prev_segment:
# update prev segment
prev_segment.name = segment.name
prev_segment.number_of_farmers = segment.number_of_farmers
session.commit()
session.flush()
session.refresh(prev_segment)
else:
new_segment = Segment(
name=segment.name,
number_of_farmers=segment.number_of_farmers,
)
case.case_segments.append(new_segment)
session.commit()
session.flush()
session.refresh(case)
Expand Down
14 changes: 8 additions & 6 deletions backend/db/crud_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from fastapi import HTTPException, status

from models.segment import (
Segment, SegmentBase, SegmentDict, SegmentUpdateBase,
SegmentWithAnswersDict
Segment,
SegmentBase,
SegmentDict,
SegmentUpdateBase,
SegmentWithAnswersDict,
)
from models.segment_answer import SegmentAnswer

Expand Down Expand Up @@ -44,7 +47,7 @@ def get_segment_by_id(session: Session, id: int) -> SegmentDict:
if not segment:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Segment {id} not found"
detail=f"Segment {id} not found",
)
return segment

Expand Down Expand Up @@ -93,8 +96,7 @@ def delete_segment(session: Session, id: int):
segment = get_segment_by_id(session=session, id=id)
# delete segment answers
segment_answers = (
session.query(SegmentAnswer)
.filter(SegmentAnswer.segment == id).all()
session.query(SegmentAnswer).filter(SegmentAnswer.segment == id).all()
)
for sa in segment_answers:
session.delete(sa)
Expand All @@ -110,6 +112,6 @@ def get_segments_by_case_id(
if not segments:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Segments with case {case_id} not found"
detail=f"Segments with case {case_id} not found",
)
return segments
8 changes: 7 additions & 1 deletion backend/models/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
SimplifiedCaseCommodityDict,
CaseCommodityType,
)
from models.segment import Segment, SegmentDict, SegmentWithAnswersDict
from models.segment import (
Segment,
SegmentDict,
SegmentWithAnswersDict,
CaseSettingSegmentPayload,
)
from models.case_tag import CaseTag


Expand Down Expand Up @@ -336,6 +341,7 @@ class CaseBase(BaseModel):
other_commodities: Optional[List[OtherCommoditysBase]] = None
tags: Optional[List[int]] = None
company: Optional[int] = None
segments: Optional[List[CaseSettingSegmentPayload]] = None


class PaginatedCaseResponse(BaseModel):
Expand Down
31 changes: 23 additions & 8 deletions backend/models/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class SegmentDict(TypedDict):
target: Optional[float]
adult: Optional[float]
child: Optional[float]
number_of_farmers: Optional[int]


class SimplifiedSegmentDict(TypedDict):
Expand All @@ -25,6 +26,7 @@ class SimplifiedSegmentDict(TypedDict):
target: Optional[float]
adult: Optional[float]
child: Optional[float]
number_of_farmers: Optional[int]


class SegmentWithAnswersDict(TypedDict):
Expand All @@ -35,42 +37,45 @@ class SegmentWithAnswersDict(TypedDict):
target: Optional[float]
adult: Optional[float]
child: Optional[float]
number_of_farmers: Optional[int]
answers: Optional[dict]
benchmark: Optional[LivingIncomeBenchmarkDict]


class Segment(Base):
__tablename__ = 'segment'
__tablename__ = "segment"

id = Column(Integer, primary_key=True, nullable=False)
case = Column(Integer, ForeignKey('case.id'))
region = Column(Integer, ForeignKey('region.id'), nullable=True)
case = Column(Integer, ForeignKey("case.id"))
region = Column(Integer, ForeignKey("region.id"), nullable=True)
name = Column(String, nullable=False)
target = Column(Float, nullable=True)
adult = Column(Float, nullable=True)
child = Column(Float, nullable=True)
number_of_farmers = Column(Integer, nullable=True)

case_detail = relationship(
'Case',
"Case",
cascade="all, delete",
passive_deletes=True,
back_populates='case_segments'
back_populates="case_segments",
)
segment_answers = relationship(
'SegmentAnswer',
"SegmentAnswer",
cascade="all, delete",
passive_deletes=True,
backref='segment_detail'
backref="segment_detail",
)

def __init__(
self,
name: str,
case: int,
case: Optional[int] = None,
region: Optional[int] = None,
target: Optional[float] = None,
adult: Optional[float] = None,
child: Optional[float] = None,
number_of_farmers: Optional[int] = None,
id: Optional[int] = None,
):
self.id = id
Expand All @@ -80,6 +85,7 @@ def __init__(
self.target = target
self.adult = adult
self.child = child
self.number_of_farmers = number_of_farmers

def __repr__(self) -> int:
return f"<Segment {self.id}>"
Expand All @@ -94,6 +100,7 @@ def serialize(self) -> SegmentDict:
"target": self.target,
"adult": self.adult,
"child": self.child,
"number_of_farmers": self.number_of_farmers,
}

@property
Expand All @@ -105,6 +112,7 @@ def simplify(self) -> SimplifiedSegmentDict:
"target": self.target,
"adult": self.adult,
"child": self.child,
"number_of_farmers": self.number_of_farmers,
}

@property
Expand All @@ -126,6 +134,7 @@ def serialize_with_answers(self) -> SegmentWithAnswersDict:
"child": self.child,
"answers": answers,
"benchmark": None,
"number_of_farmers": self.number_of_farmers,
}


Expand All @@ -148,3 +157,9 @@ class SegmentUpdateBase(BaseModel):
adult: Optional[float] = None
child: Optional[float] = None
answers: Optional[List[SegmentAnswerBase]] = []


class CaseSettingSegmentPayload(BaseModel):
name: str
number_of_farmers: Optional[int] = None
id: Optional[int] = None
9 changes: 4 additions & 5 deletions backend/routes/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from db.connection import get_session
from models.segment import (
SegmentBase,
SegmentDict,
SegmentUpdateBase,
SegmentWithAnswersDict,
)
Expand All @@ -24,7 +23,7 @@

@segment_route.post(
"/segment",
response_model=List[SegmentDict],
response_model=List[SegmentWithAnswersDict],
summary="create segment",
name="segment:create",
tags=["Segment"],
Expand All @@ -45,12 +44,12 @@ def create_segment(
session=session, case_id=case_id, user_id=user.id
)
segments = crud_segment.add_segment(session=session, payloads=payload)
return [s.serialize for s in segments]
return [s.serialize_with_answers for s in segments]


@segment_route.put(
"/segment",
response_model=List[SegmentDict],
response_model=List[SegmentWithAnswersDict],
summary="update segment",
name="segment:update",
tags=["Segment"],
Expand All @@ -71,7 +70,7 @@ def update_segment(
crud_case.case_updated_by(
session=session, case_id=case_id, user_id=user.id
)
return [s.serialize for s in segments]
return [s.serialize_with_answers for s in segments]


@segment_route.delete(
Expand Down
33 changes: 33 additions & 0 deletions backend/tests/test_040_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ async def test_create_segment(
"target": 1000.0,
"adult": 2.0,
"child": 3.0,
"number_of_farmers": None,
"answers": {},
"benchmark": None,
}
]
# with admin user cred
Expand Down Expand Up @@ -102,6 +105,9 @@ async def test_create_segment(
"target": 2000.0,
"adult": 3.0,
"child": 2.0,
"number_of_farmers": None,
"answers": {},
"benchmark": None,
},
{
"id": res[1]["id"],
Expand All @@ -111,6 +117,14 @@ async def test_create_segment(
"target": 3000.0,
"adult": 4.0,
"child": 2.0,
"number_of_farmers": None,
"answers": {
"current-1-1": 10000.0,
"current-1-2": None,
"feasible-1-1": None,
"feasible-1-2": None,
},
"benchmark": None,
},
]

Expand Down Expand Up @@ -152,6 +166,9 @@ async def test_update_segment(
"target": 2000.0,
"adult": 4.0,
"child": 2.0,
"number_of_farmers": None,
"answers": {},
"benchmark": None,
}
]
# with admin user cred
Expand Down Expand Up @@ -222,6 +239,9 @@ async def test_update_segment(
"target": 2000.0,
"adult": 5.0,
"child": 0.0,
"number_of_farmers": None,
"answers": {},
"benchmark": None,
},
{
"id": 2,
Expand All @@ -231,6 +251,9 @@ async def test_update_segment(
"target": 2000.0,
"adult": 6.0,
"child": 0.0,
"number_of_farmers": None,
"answers": {},
"benchmark": None,
},
{
"id": 3,
Expand All @@ -240,5 +263,15 @@ async def test_update_segment(
"target": 3000.0,
"adult": 4.0,
"child": 2.0,
"number_of_farmers": None,
"answers": {
"current-1-1": 10000.0,
"current-1-2": None,
"current-1-3": None,
"feasible-1-1": None,
"feasible-1-2": None,
"feasible-1-3": 500.0,
},
"benchmark": None,
},
]
Loading
Loading