diff --git a/controller/stories/stories.go b/controller/stories/stories.go index 214c66a..29f61bb 100644 --- a/controller/stories/stories.go +++ b/controller/stories/stories.go @@ -13,6 +13,7 @@ import ( "github.com/source-academy/stories-backend/controller" "github.com/source-academy/stories-backend/internal/database" apierrors "github.com/source-academy/stories-backend/internal/errors" + "github.com/source-academy/stories-backend/internal/usergroups" "github.com/source-academy/stories-backend/model" storyparams "github.com/source-academy/stories-backend/params/stories" storyviews "github.com/source-academy/stories-backend/view/stories" @@ -26,7 +27,14 @@ func HandleList(w http.ResponseWriter, r *http.Request) error { return err } - stories, err := model.GetAllStories(db) + // Get group id from context + groupID, err := usergroups.GetGroupIDFrom(r) + if err != nil { + logrus.Error(err) + return err + } + + stories, err := model.GetAllStoriesInGroup(db, groupID) if err != nil { logrus.Error(err) return err diff --git a/internal/usergroups/middleware.go b/internal/usergroups/middleware.go index 84a7ad3..4c07422 100644 --- a/internal/usergroups/middleware.go +++ b/internal/usergroups/middleware.go @@ -72,8 +72,8 @@ func InjectUserGroupIntoContext(next http.Handler) http.Handler { }) } -func GetGroupIDFrom(r *http.Request) (*int, error) { - groupID, ok := r.Context().Value(groupKey).(*int) +func GetGroupIDFrom(r *http.Request) (*uint, error) { + groupID, ok := r.Context().Value(groupKey).(*uint) if !ok { return nil, errors.New("Could not get groupID from request context") } diff --git a/migrations/20230810152313-add_course_groups.sql b/migrations/20230810152313-add_course_groups.sql index 5c975b6..0b7a06e 100644 --- a/migrations/20230810152313-add_course_groups.sql +++ b/migrations/20230810152313-add_course_groups.sql @@ -16,4 +16,3 @@ CREATE UNIQUE INDEX idx_unique_course ON course_groups (course_id); DROP INDEX idx_unique_course; DROP TABLE IF EXISTS course_groups; - diff --git a/model/stories.go b/model/stories.go index 8721f67..a3ccc07 100644 --- a/model/stories.go +++ b/model/stories.go @@ -17,9 +17,13 @@ type Story struct { PinOrder *int // nil if not pinned } -func GetAllStories(db *gorm.DB) ([]Story, error) { +// Passing nil to omit the filtering and get all stories +// TODO: Use nullable types instead +func GetAllStoriesInGroup(db *gorm.DB, groupID *uint) ([]Story, error) { var stories []Story err := db. + // FIXME: Handle nil case properly + Where(Story{GroupID: groupID}). Preload(clause.Associations). // TODO: Abstract out the sorting logic Order("pin_order ASC NULLS LAST, title ASC, content ASC"). diff --git a/model/stories_test.go b/model/stories_test.go index 9b10244..3847fe7 100644 --- a/model/stories_test.go +++ b/model/stories_test.go @@ -45,9 +45,6 @@ func TestCreateStory(t *testing.T) { db, cleanUp := testutils.SetupDBConnection(t, dbConfig, migrationPath) defer cleanUp(t) - initialStories, err := GetAllStories(db) - assert.Nil(t, err, "Expected no error when getting all stories") - // We need to first create a user and a group due to the foreign key constraint user := User{ Username: "testStoryAuthor", @@ -60,6 +57,9 @@ func TestCreateStory(t *testing.T) { } _ = CreateGroup(db, &group) + initialStories, err := GetAllStoriesInGroup(db, &group.ID) + assert.Nil(t, err, "Expected no error when getting all stories") + story := Story{ AuthorID: user.ID, Group: group, @@ -68,7 +68,7 @@ func TestCreateStory(t *testing.T) { err = CreateStory(db, &story) assert.Nil(t, err, "Expected no error when creating story") - newStories, err := GetAllStories(db) + newStories, err := GetAllStoriesInGroup(db, &group.ID) assert.Nil(t, err, "Expected no error when getting all stories") assert.Len(t, newStories, len(initialStories)+1, "Expected number of stories to increase by 1") @@ -85,7 +85,7 @@ func TestCreateStory(t *testing.T) { db, cleanUp := testutils.SetupDBConnection(t, dbConfig, migrationPath) defer cleanUp(t) - initialStories, err := GetAllStories(db) + initialStories, err := GetAllStoriesInGroup(db, nil) assert.Nil(t, err, "Expected no error when getting all stories") // We need to first create a user and a group due to the foreign key constraint @@ -102,7 +102,7 @@ func TestCreateStory(t *testing.T) { err = CreateStory(db, &story) assert.Nil(t, err, "Expected no error when creating story") - newStories, err := GetAllStories(db) + newStories, err := GetAllStoriesInGroup(db, nil) assert.Nil(t, err, "Expected no error when getting all stories") assert.Len(t, newStories, len(initialStories)+1, "Expected number of stories to increase by 1") @@ -131,14 +131,48 @@ func TestCreateStory(t *testing.T) { }) } -func TestGetStoryByID(t *testing.T) { - t.Run("should get the correct story", func(t *testing.T) { +func TestGetAllStoriesInGroup(t *testing.T) { + t.Run("Should get all stories in a group", func(t *testing.T) { + db, cleanUp := testutils.SetupDBConnection(t, dbConfig, migrationPath) + defer cleanUp(t) + + // We need to first create a user and a group due to the foreign key constraint + user := User{ + Username: "testStoryAuthor", + LoginProvider: userenums.LoginProvider(rand.Int31()), + } + _ = CreateUser(db, &user) + + groups := []*Group{ + {Name: "testGroup"}, {Name: "testGroup2"}, + } + for i, group := range groups { + _ = CreateGroup(db, group) + for j := 0; j < i+1; j++ { + story := Story{ + AuthorID: user.ID, + Group: *group, + Content: fmt.Sprintf("testStoies %d", j), + } + err := CreateStory(db, &story) + assert.Nil(t, err, "Expected no error when creating story") + } + } + + for i, group := range groups { + newStories, err := GetAllStoriesInGroup(db, &group.ID) + assert.Nil(t, err, "Expected no error when getting all stories") + assert.Len(t, newStories, i+1, "Expected number of stories to be correct") + } + }) + + t.Run("Should get stories with null group", func(t *testing.T) { db, cleanUp := testutils.SetupDBConnection(t, dbConfig, migrationPath) defer cleanUp(t) - // We need to first create a user due to the foreign key constraint + // We need to first create a user and a group due to the foreign key constraint user := User{ - Username: "testMultipleStoriesAuthor", + Username: "testStoryAuthor", LoginProvider: userenums.LoginProvider(rand.Int31()), } _ = CreateUser(db, &user) @@ -148,25 +182,32 @@ func TestGetStoryByID(t *testing.T) { } _ = CreateGroup(db, &group) - stories := []*Story{ - {AuthorID: user.ID, Group: group, Content: "The quick"}, - {AuthorID: user.ID, Group: group, Content: "brown fox"}, - {AuthorID: user.ID, Group: group, Content: "jumps over"}, + // Create 3 stories without group + for i := 0; i < 3; i++ { + story := Story{ + AuthorID: user.ID, + Content: fmt.Sprintf("testStoies %d", i), + } + err := CreateStory(db, &story) + assert.Nil(t, err, "Expected no error when creating story") } - for _, storyToAdd := range stories { - _ = CreateStory(db, storyToAdd) + // Create 1 story with group + story := Story{ + AuthorID: user.ID, + GroupID: &group.ID, + Content: "testStoies", } + err := CreateStory(db, &story) + assert.Nil(t, err, "Expected no error when creating story") - for _, story := range stories { - // FIXME: Don't use typecast - dbStory, err := GetStoryByID(db, int(story.ID)) - assert.Nil(t, err, "Expected no error when getting story with valid ID") - assert.Equal(t, story.ID, dbStory.ID, fmt.Sprintf(expectReadEqualMessage, "story")) - assert.Equal(t, story.AuthorID, dbStory.AuthorID, fmt.Sprintf(expectReadEqualMessage, "story")) - assert.Equal(t, story.GroupID, dbStory.GroupID, fmt.Sprintf(expectReadEqualMessage, "story")) - assert.Equal(t, story.Content, dbStory.Content, fmt.Sprintf(expectReadEqualMessage, "story")) - } + allStories, err := GetAllStoriesInGroup(db, nil) + assert.Nil(t, err, "Expected no error when getting all stories without group") + assert.Len(t, allStories, 4, "Expected number of stories to be correct") + + groupStories, err := GetAllStoriesInGroup(db, &group.ID) + assert.Nil(t, err, "Expected no error when getting all stories without group") + assert.Len(t, groupStories, 1, "Expected number of stories to be correct") }) }