From 2bd590db7afbcc718a461e8f2c3dc043dd4fb1fa Mon Sep 17 00:00:00 2001 From: spadek <2984932259@qq.com> Date: Wed, 16 Jul 2025 18:43:43 +0800 Subject: [PATCH 1/2] Feature: Implement multi-user task management service with SSE support - add a stateful_json-based service for multi-user task/todo management - support CRUD operations for tasks, with user isolation via 'user' parameter - implement SSE (Server-Sent Events) for real-time task updates - provide a CLI client for adding, querying, completing, and deleting tasks - ensure concurrent-safe data management on the server side --- examples/task/client/task_client.go | 110 ++++++++++++ examples/task/server/task_server.go | 208 +++++++++++++++++++++++ examples/task/server/task_server_test.go | 58 +++++++ 3 files changed, 376 insertions(+) create mode 100644 examples/task/client/task_client.go create mode 100644 examples/task/server/task_server.go create mode 100644 examples/task/server/task_server_test.go diff --git a/examples/task/client/task_client.go b/examples/task/client/task_client.go new file mode 100644 index 0000000..e692a38 --- /dev/null +++ b/examples/task/client/task_client.go @@ -0,0 +1,110 @@ +// Copyright (c) 2024, Tencent Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package main + +import ( + "encoding/json" + "flag" + "fmt" + "io/ioutil" + "net/http" + "os" +) + +type Task struct { + ID int `json:"id"` + Content string `json:"content"` + Done bool `json:"done"` +} + +func main() { + server := flag.String("server", "http://localhost:8080", "server address") + user := flag.String("user", "", "user id") + cmd := flag.String("cmd", "list", "command: add/list/done/delete") + content := flag.String("content", "", "task content (for add)") + id := flag.Int("id", 0, "task id (for done/delete)") + flag.Parse() + + if *user == "" { + fmt.Println("user is required, use -user=xxx") + os.Exit(1) + } + + switch *cmd { + case "add": + if *content == "" { + fmt.Println("content is required for add") + os.Exit(1) + } + resp, err := http.PostForm(fmt.Sprintf("%s/task/add?user=%s", *server, *user), map[string][]string{"content": {*content}}) + if err != nil { + fmt.Println("add failed:", err) + os.Exit(1) + } + defer resp.Body.Close() + body, _ := ioutil.ReadAll(resp.Body) + fmt.Println(string(body)) + case "list": + resp, err := http.Get(fmt.Sprintf("%s/task/list?user=%s", *server, *user)) + if err != nil { + fmt.Println("list failed:", err) + os.Exit(1) + } + defer resp.Body.Close() + var tasks []Task + if err := json.NewDecoder(resp.Body).Decode(&tasks); err != nil { + fmt.Println("decode failed:", err) + os.Exit(1) + } + if len(tasks) == 0 { + fmt.Println("No tasks.") + return + } + for _, t := range tasks { + status := "[ ]" + if t.Done { + status = "[x]" + } + fmt.Printf("%s %d: %s\n", status, t.ID, t.Content) + } + case "done": + if *id == 0 { + fmt.Println("id is required for done") + os.Exit(1) + } + resp, err := http.PostForm(fmt.Sprintf("%s/task/done?user=%s", *server, *user), map[string][]string{"id": {fmt.Sprintf("%d", *id)}}) + if err != nil { + fmt.Println("done failed:", err) + os.Exit(1) + } + defer resp.Body.Close() + body, _ := ioutil.ReadAll(resp.Body) + fmt.Println(string(body)) + case "delete": + if *id == 0 { + fmt.Println("id is required for delete") + os.Exit(1) + } + resp, err := http.PostForm(fmt.Sprintf("%s/task/delete?user=%s", *server, *user), map[string][]string{"id": {fmt.Sprintf("%d", *id)}}) + if err != nil { + fmt.Println("delete failed:", err) + os.Exit(1) + } + defer resp.Body.Close() + body, _ := ioutil.ReadAll(resp.Body) + fmt.Println(string(body)) + default: + fmt.Println("unknown command, use add/list/done/delete") + os.Exit(1) + } +} diff --git a/examples/task/server/task_server.go b/examples/task/server/task_server.go new file mode 100644 index 0000000..745c05e --- /dev/null +++ b/examples/task/server/task_server.go @@ -0,0 +1,208 @@ +// Copyright (c) 2024, Tencent Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package main + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "strconv" + "sync" +) + +type Task struct { + ID int `json:"id"` + Content string `json:"content"` + Done bool `json:"done"` +} + +var ( + userTasks = make(map[string][]Task) + taskIDCounter = 0 + mu sync.Mutex + + // SSE 相关 + userStreams = make(map[string]map[chan []Task]struct{}) // user -> set of channels +) + +// SSE推送任务列表 +func pushTaskUpdate(user string) { + mu.Lock() + streams := userStreams[user] + tasks := append([]Task(nil), userTasks[user]...) // 拷贝,避免并发问题 + mu.Unlock() + for ch := range streams { + select { + case ch <- tasks: + default: + } + } +} + +// SSE handler +func streamHandler(w http.ResponseWriter, r *http.Request) { + user := getUserID(r) + if user == "" { + http.Error(w, "user required", http.StatusBadRequest) + return + } + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + ch := make(chan []Task, 8) + mu.Lock() + if userStreams[user] == nil { + userStreams[user] = make(map[chan []Task]struct{}) + } + userStreams[user][ch] = struct{}{} + mu.Unlock() + defer func() { + mu.Lock() + delete(userStreams[user], ch) + mu.Unlock() + close(ch) + }() + + // 初始推送一次 + mu.Lock() + tasks := append([]Task(nil), userTasks[user]...) + mu.Unlock() + fmt.Fprintf(w, "data: %s\n\n", toJSON(tasks)) + flusher.Flush() + + // 持续推送 + for { + select { + case tasks, ok := <-ch: + if !ok { + return + } + fmt.Fprintf(w, "data: %s\n\n", toJSON(tasks)) + flusher.Flush() + case <-r.Context().Done(): + return + } + } +} + +func toJSON(v interface{}) string { + b, _ := json.Marshal(v) + return string(b) +} + +func getUserID(r *http.Request) string { + return r.URL.Query().Get("user") +} + +// 修改增删改接口,操作后推送SSE +func addTaskHandler(w http.ResponseWriter, r *http.Request) { + user := getUserID(r) + if user == "" { + http.Error(w, "user required", http.StatusBadRequest) + return + } + content := r.FormValue("content") + if content == "" { + http.Error(w, "content required", http.StatusBadRequest) + return + } + mu.Lock() + taskIDCounter++ + task := Task{ID: taskIDCounter, Content: content, Done: false} + userTasks[user] = append(userTasks[user], task) + mu.Unlock() + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + pushTaskUpdate(user) +} + +func listTaskHandler(w http.ResponseWriter, r *http.Request) { + user := getUserID(r) + if user == "" { + http.Error(w, "user required", http.StatusBadRequest) + return + } + mu.Lock() + tasks := userTasks[user] + mu.Unlock() + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(tasks) +} + +func doneTaskHandler(w http.ResponseWriter, r *http.Request) { + user := getUserID(r) + if user == "" { + http.Error(w, "user required", http.StatusBadRequest) + return + } + idStr := r.FormValue("id") + id, err := strconv.Atoi(idStr) + if err != nil { + http.Error(w, "invalid id", http.StatusBadRequest) + return + } + mu.Lock() + tasks := userTasks[user] + for i, t := range tasks { + if t.ID == id { + tasks[i].Done = true + break + } + } + userTasks[user] = tasks + mu.Unlock() + w.Write([]byte("ok")) + pushTaskUpdate(user) +} + +func deleteTaskHandler(w http.ResponseWriter, r *http.Request) { + user := getUserID(r) + if user == "" { + http.Error(w, "user required", http.StatusBadRequest) + return + } + idStr := r.FormValue("id") + id, err := strconv.Atoi(idStr) + if err != nil { + http.Error(w, "invalid id", http.StatusBadRequest) + return + } + mu.Lock() + tasks := userTasks[user] + for i, t := range tasks { + if t.ID == id { + userTasks[user] = append(tasks[:i], tasks[i+1:]...) + break + } + } + mu.Unlock() + w.Write([]byte("ok")) + pushTaskUpdate(user) +} + +func main() { + http.HandleFunc("/task/add", addTaskHandler) + http.HandleFunc("/task/list", listTaskHandler) + http.HandleFunc("/task/done", doneTaskHandler) + http.HandleFunc("/task/delete", deleteTaskHandler) + http.HandleFunc("/task/stream", streamHandler) // 新增SSE接口 + fmt.Println("Task server running at http://localhost:8080") + log.Fatal(http.ListenAndServe(":8080", nil)) +} diff --git a/examples/task/server/task_server_test.go b/examples/task/server/task_server_test.go new file mode 100644 index 0000000..df9d3a1 --- /dev/null +++ b/examples/task/server/task_server_test.go @@ -0,0 +1,58 @@ +// Copyright (c) 2024, Tencent Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package main + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +func TestAddListDoneDeleteTask(t *testing.T) { + // add + req := httptest.NewRequest("POST", "/task/add?user=test", strings.NewReader(url.Values{"content": {"hello"}}.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rw := httptest.NewRecorder() + addTaskHandler(rw, req) + if rw.Code != http.StatusOK { + t.Fatalf("addTaskHandler failed: %v", rw.Body.String()) + } + + // list + req = httptest.NewRequest("GET", "/task/list?user=test", nil) + rw = httptest.NewRecorder() + listTaskHandler(rw, req) + if !strings.Contains(rw.Body.String(), "hello") { + t.Fatalf("listTaskHandler missing task: %v", rw.Body.String()) + } + + // done + req = httptest.NewRequest("POST", "/task/done?user=test", strings.NewReader(url.Values{"id": {"1"}}.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rw = httptest.NewRecorder() + doneTaskHandler(rw, req) + if rw.Code != http.StatusOK { + t.Fatalf("doneTaskHandler failed: %v", rw.Body.String()) + } + + // delete + req = httptest.NewRequest("POST", "/task/delete?user=test", strings.NewReader(url.Values{"id": {"1"}}.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rw = httptest.NewRecorder() + deleteTaskHandler(rw, req) + if rw.Code != http.StatusOK { + t.Fatalf("deleteTaskHandler failed: %v", rw.Body.String()) + } +} From f2312600755406fa781e74e6143cd4a1a3b2a487 Mon Sep 17 00:00:00 2001 From: spadek <2984932259@qq.com> Date: Thu, 17 Jul 2025 12:15:19 +0800 Subject: [PATCH 2/2] Feature: Implement desktop file organizer tool demo with SSE progress and resource report - add organize_desktop_files tool supporting structured params (dir_path, mode) - implement SSE (Server-Sent Events) for real-time progress feedback during file organization - generate and register JSON summary report as MCP resource for client download - provide Go client example for tool invocation and report retrieval - support flexible directory input and multiple classification modes (type/ctime/project) --- examples/organize_desktop/client/main.go | 112 ++++++++++++ examples/organize_desktop/server/main.go | 195 +++++++++++++++++++++ examples/task/client/task_client.go | 110 ------------ examples/task/server/task_server.go | 208 ----------------------- examples/task/server/task_server_test.go | 58 ------- 5 files changed, 307 insertions(+), 376 deletions(-) create mode 100644 examples/organize_desktop/client/main.go create mode 100644 examples/organize_desktop/server/main.go delete mode 100644 examples/task/client/task_client.go delete mode 100644 examples/task/server/task_server.go delete mode 100644 examples/task/server/task_server_test.go diff --git a/examples/organize_desktop/client/main.go b/examples/organize_desktop/client/main.go new file mode 100644 index 0000000..1dd0249 --- /dev/null +++ b/examples/organize_desktop/client/main.go @@ -0,0 +1,112 @@ +// Tencent is pleased to support the open source community by making trpc-mcp-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-mcp-go is licensed under the Apache License Version 2.0. +package main + +import ( + "context" + "fmt" + "log" + + mcp "trpc.group/trpc-go/trpc-mcp-go" +) + +func main() { + log.Println("启动 organize_desktop_files 工具调用示例客户端...") + + ctx := context.Background() + serverURL := "http://localhost:3001/mcp" + client, err := mcp.NewClient( + serverURL, + mcp.Implementation{ + Name: "Organize-Desktop-Client", + Version: "1.0.0", + }, + mcp.WithClientLogger(mcp.GetDefaultLogger()), + ) + if err != nil { + log.Fatalf("创建 MCP 客户端失败: %v", err) + } + defer client.Close() + + _, err = client.Initialize(ctx, &mcp.InitializeRequest{}) + if err != nil { + log.Fatalf("初始化失败: %v", err) + } + client.RegisterNotificationHandler("notifications/progress", func(n *mcp.JSONRPCNotification) error { + progress, _ := n.Params.AdditionalFields["progress"].(float64) + message, _ := n.Params.AdditionalFields["message"].(string) + fmt.Printf("[进度] %.0f%% - %s\n", progress*100, message) + return nil + }) + + // 调用 organize_desktop_files 工具 + callReq := &mcp.CallToolRequest{} + callReq.Params.Name = "organize_desktop_files" + callReq.Params.Arguments = map[string]interface{}{ + // "dir_path": "C:\\Users\\你的用户名\\Desktop", // 可省略,默认桌面 + "dir_path": "D:\\Desktop", + "mode": "type", // 可选 type/ctime/project + } + log.Println("调用 organize_desktop_files 工具...") + resp, err := client.CallTool(ctx, callReq) + if err != nil { + log.Fatalf("工具调用失败: %v", err) + } + + log.Println("工具调用结果:") + var reportURI string + for _, item := range resp.Content { + if text, ok := item.(mcp.TextContent); ok { + fmt.Println(text.Text) + // 自动提取报告 URI + if idx := findReportURI(text.Text); idx != "" { + reportURI = idx + } + } else { + fmt.Printf("[其他类型内容] %+v\n", item) + } + } + + // 自动读取并打印报告内容 + if reportURI != "" { + log.Printf("\n读取报告资源: %s ...", reportURI) + readReq := &mcp.ReadResourceRequest{} + readReq.Params.URI = reportURI + resourceContent, err := client.ReadResource(ctx, readReq) + if err != nil { + log.Fatalf("读取资源失败: %v", err) + } + for _, content := range resourceContent.Contents { + if text, ok := content.(mcp.TextResourceContents); ok { + fmt.Println("\n报告内容:") + fmt.Println(text.Text) + } + } + } + log.Println("客户端示例结束.") +} + +// findReportURI 从文本中提取 resource://organize_desktop/report.json URI +func findReportURI(s string) string { + // 简单正则或字符串查找 + const prefix = "resource://organize_desktop/report.json" + if idx := len(s) - len(prefix); idx >= 0 && s[idx:] == prefix { + return prefix + } + if i := findIndex(s, prefix); i >= 0 { + return prefix + } + return "" +} + +func findIndex(s, sub string) int { + for i := 0; i+len(sub) <= len(s); i++ { + if s[i:i+len(sub)] == sub { + return i + } + } + return -1 +} diff --git a/examples/organize_desktop/server/main.go b/examples/organize_desktop/server/main.go new file mode 100644 index 0000000..97d28b3 --- /dev/null +++ b/examples/organize_desktop/server/main.go @@ -0,0 +1,195 @@ +// Tencent is pleased to support the open source community by making trpc-mcp-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-mcp-go is licensed under the Apache License Version 2.0. +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "path/filepath" + "syscall" + "time" + + "encoding/json" + "sync" + + mcp "trpc.group/trpc-go/trpc-mcp-go" +) + +// 分类方式枚举 +var organizeModes = []string{"type", "ctime", "project"} + +var ( + resourceOnce sync.Once + resourceRegistered bool + resourceURI string +) + +func main() { + log.Printf("Starting organize_desktop_files MCP server...") + + mcpServer := mcp.NewServer( + "Organize-Desktop-Server", + "0.1.0", + mcp.WithServerAddress(":3001"), + mcp.WithServerPath("/mcp"), + mcp.WithServerLogger(mcp.GetDefaultLogger()), + ) + + organizeTool := mcp.NewTool("organize_desktop_files", + mcp.WithDescription("自动分析桌面文件并归类整理,支持 SSE 进度与资源报告下载。"), + mcp.WithString("dir_path", mcp.Description("要整理的桌面目录路径。")), + mcp.WithString("mode", mcp.Description("归类方式:type/ctime/project。"), mcp.Enum(organizeModes...)), + ) + + mcpServer.RegisterTool(organizeTool, handleOrganizeDesktopFiles) + log.Printf("Registered tool: organize_desktop_files") + + // 注册资源(首次注册时) + resourceOnce.Do(func() { + resource := &mcp.Resource{ + URI: "resource://organize_desktop/report.json", + Name: "desktop-organize-report", + Description: "桌面整理 JSON 总结报告", + MimeType: "application/json", + } + mcpServer.RegisterResource(resource, func(ctx context.Context, req *mcp.ReadResourceRequest) (mcp.ResourceContents, error) { + // 读取最新报告内容 + data, err := os.ReadFile("desktop_organize_report.json") + if err != nil { + return mcp.TextResourceContents{ + URI: resource.URI, + MIMEType: resource.MimeType, + Text: "报告文件不存在或读取失败。", + }, nil + } + return mcp.TextResourceContents{ + URI: resource.URI, + MIMEType: resource.MimeType, + Text: string(data), + }, nil + }) + resourceRegistered = true + resourceURI = resource.URI + }) + + stop := make(chan os.Signal, 1) + signal.Notify(stop, os.Interrupt, syscall.SIGTERM) + go func() { + log.Printf("MCP server started, listening on port 3001, path /mcp") + if err := mcpServer.Start(); err != nil { + log.Fatalf("Server failed to start: %v", err) + } + }() + <-stop + log.Printf("Shutting down server...") +} + +// handleOrganizeDesktopFiles 是 organize_desktop_files 工具的处理函数 +func handleOrganizeDesktopFiles(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + dirPath, _ := req.Params.Arguments["dir_path"].(string) + mode, _ := req.Params.Arguments["mode"].(string) + if dirPath == "" { + dirPath = os.Getenv("USERPROFILE") + "\\Desktop" // Windows 桌面默认路径 + } + if mode == "" { + mode = "type" + } + + // 获取 SSE 通知 sender + notificationSender, hasSender := mcp.GetNotificationSender(ctx) + if !hasSender { + return mcp.NewTextResult("Error: 无法获取 SSE 通知 sender,无法推送进度。"), fmt.Errorf("no notification sender") + } + + // 模拟扫描文件阶段 + notificationSender.SendProgress(0.05, "开始扫描桌面文件...") + time.Sleep(300 * time.Millisecond) + + files := []string{} + err := filepath.Walk(dirPath, func(path string, info os.FileInfo, err error) error { + if err != nil { + return nil // 跳过无法访问的文件 + } + if !info.IsDir() { + files = append(files, path) + } + return nil + }) + if err != nil { + notificationSender.SendLogMessage("error", "扫描文件失败: "+err.Error()) + return mcp.NewTextResult("扫描文件失败: " + err.Error()), err + } + notificationSender.SendProgress(0.15, fmt.Sprintf("共发现 %d 个文件,准备归类...", len(files))) + time.Sleep(300 * time.Millisecond) + + // 模拟归类阶段 + total := len(files) + if total == 0 { + notificationSender.SendProgress(1.0, "桌面无可整理文件。") + return mcp.NewTextResult("桌面无可整理文件。"), nil + } + + // 统计结果结构 + type FileInfo struct { + Path string `json:"path"` + Type string `json:"type"` + CTime string `json:"ctime"` + } + result := map[string][]FileInfo{} + + for i, f := range files { + info, err := os.Stat(f) + if err != nil { + continue + } + fileType := filepath.Ext(f) + ctime := info.ModTime().Format("2006-01-02 15:04:05") + item := FileInfo{Path: f, Type: fileType, CTime: ctime} + var key string + switch mode { + case "type": + key = fileType + case "ctime": + key = info.ModTime().Format("2006-01") + default: + key = "other" + } + result[key] = append(result[key], item) + if i%10 == 0 || i == total-1 { + progress := 0.2 + 0.7*float64(i+1)/float64(total) + msg := fmt.Sprintf("已归类 %d/%d 个文件...", i+1, total) + notificationSender.SendProgress(progress, msg) + time.Sleep(10 * time.Millisecond) + } + } + + notificationSender.SendProgress(1.0, "整理完成,生成报告...") + time.Sleep(200 * time.Millisecond) + + // 生成 JSON 报告 + reportBytes, err := json.MarshalIndent(result, "", " ") + if err != nil { + notificationSender.SendLogMessage("error", "生成 JSON 报告失败: "+err.Error()) + return mcp.NewTextResult("生成 JSON 报告失败: " + err.Error()), err + } + reportPath := "desktop_organize_report.json" + err = os.WriteFile(reportPath, reportBytes, 0644) + if err != nil { + notificationSender.SendLogMessage("error", "写入报告文件失败: "+err.Error()) + return mcp.NewTextResult("写入报告文件失败: " + err.Error()), err + } + + // 注册资源(如果未注册) + if !resourceRegistered { + resourceOnce.Do(func() {}) // 确保 main 中的注册已执行 + } + + notificationSender.SendProgress(1.0, "报告已生成,可下载。") + return mcp.NewTextResult(fmt.Sprintf("整理完成,共 %d 个文件。报告资源 URI: %s", total, resourceURI)), nil +} diff --git a/examples/task/client/task_client.go b/examples/task/client/task_client.go deleted file mode 100644 index e692a38..0000000 --- a/examples/task/client/task_client.go +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (c) 2024, Tencent Inc. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package main - -import ( - "encoding/json" - "flag" - "fmt" - "io/ioutil" - "net/http" - "os" -) - -type Task struct { - ID int `json:"id"` - Content string `json:"content"` - Done bool `json:"done"` -} - -func main() { - server := flag.String("server", "http://localhost:8080", "server address") - user := flag.String("user", "", "user id") - cmd := flag.String("cmd", "list", "command: add/list/done/delete") - content := flag.String("content", "", "task content (for add)") - id := flag.Int("id", 0, "task id (for done/delete)") - flag.Parse() - - if *user == "" { - fmt.Println("user is required, use -user=xxx") - os.Exit(1) - } - - switch *cmd { - case "add": - if *content == "" { - fmt.Println("content is required for add") - os.Exit(1) - } - resp, err := http.PostForm(fmt.Sprintf("%s/task/add?user=%s", *server, *user), map[string][]string{"content": {*content}}) - if err != nil { - fmt.Println("add failed:", err) - os.Exit(1) - } - defer resp.Body.Close() - body, _ := ioutil.ReadAll(resp.Body) - fmt.Println(string(body)) - case "list": - resp, err := http.Get(fmt.Sprintf("%s/task/list?user=%s", *server, *user)) - if err != nil { - fmt.Println("list failed:", err) - os.Exit(1) - } - defer resp.Body.Close() - var tasks []Task - if err := json.NewDecoder(resp.Body).Decode(&tasks); err != nil { - fmt.Println("decode failed:", err) - os.Exit(1) - } - if len(tasks) == 0 { - fmt.Println("No tasks.") - return - } - for _, t := range tasks { - status := "[ ]" - if t.Done { - status = "[x]" - } - fmt.Printf("%s %d: %s\n", status, t.ID, t.Content) - } - case "done": - if *id == 0 { - fmt.Println("id is required for done") - os.Exit(1) - } - resp, err := http.PostForm(fmt.Sprintf("%s/task/done?user=%s", *server, *user), map[string][]string{"id": {fmt.Sprintf("%d", *id)}}) - if err != nil { - fmt.Println("done failed:", err) - os.Exit(1) - } - defer resp.Body.Close() - body, _ := ioutil.ReadAll(resp.Body) - fmt.Println(string(body)) - case "delete": - if *id == 0 { - fmt.Println("id is required for delete") - os.Exit(1) - } - resp, err := http.PostForm(fmt.Sprintf("%s/task/delete?user=%s", *server, *user), map[string][]string{"id": {fmt.Sprintf("%d", *id)}}) - if err != nil { - fmt.Println("delete failed:", err) - os.Exit(1) - } - defer resp.Body.Close() - body, _ := ioutil.ReadAll(resp.Body) - fmt.Println(string(body)) - default: - fmt.Println("unknown command, use add/list/done/delete") - os.Exit(1) - } -} diff --git a/examples/task/server/task_server.go b/examples/task/server/task_server.go deleted file mode 100644 index 745c05e..0000000 --- a/examples/task/server/task_server.go +++ /dev/null @@ -1,208 +0,0 @@ -// Copyright (c) 2024, Tencent Inc. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package main - -import ( - "encoding/json" - "fmt" - "log" - "net/http" - "strconv" - "sync" -) - -type Task struct { - ID int `json:"id"` - Content string `json:"content"` - Done bool `json:"done"` -} - -var ( - userTasks = make(map[string][]Task) - taskIDCounter = 0 - mu sync.Mutex - - // SSE 相关 - userStreams = make(map[string]map[chan []Task]struct{}) // user -> set of channels -) - -// SSE推送任务列表 -func pushTaskUpdate(user string) { - mu.Lock() - streams := userStreams[user] - tasks := append([]Task(nil), userTasks[user]...) // 拷贝,避免并发问题 - mu.Unlock() - for ch := range streams { - select { - case ch <- tasks: - default: - } - } -} - -// SSE handler -func streamHandler(w http.ResponseWriter, r *http.Request) { - user := getUserID(r) - if user == "" { - http.Error(w, "user required", http.StatusBadRequest) - return - } - flusher, ok := w.(http.Flusher) - if !ok { - http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - - ch := make(chan []Task, 8) - mu.Lock() - if userStreams[user] == nil { - userStreams[user] = make(map[chan []Task]struct{}) - } - userStreams[user][ch] = struct{}{} - mu.Unlock() - defer func() { - mu.Lock() - delete(userStreams[user], ch) - mu.Unlock() - close(ch) - }() - - // 初始推送一次 - mu.Lock() - tasks := append([]Task(nil), userTasks[user]...) - mu.Unlock() - fmt.Fprintf(w, "data: %s\n\n", toJSON(tasks)) - flusher.Flush() - - // 持续推送 - for { - select { - case tasks, ok := <-ch: - if !ok { - return - } - fmt.Fprintf(w, "data: %s\n\n", toJSON(tasks)) - flusher.Flush() - case <-r.Context().Done(): - return - } - } -} - -func toJSON(v interface{}) string { - b, _ := json.Marshal(v) - return string(b) -} - -func getUserID(r *http.Request) string { - return r.URL.Query().Get("user") -} - -// 修改增删改接口,操作后推送SSE -func addTaskHandler(w http.ResponseWriter, r *http.Request) { - user := getUserID(r) - if user == "" { - http.Error(w, "user required", http.StatusBadRequest) - return - } - content := r.FormValue("content") - if content == "" { - http.Error(w, "content required", http.StatusBadRequest) - return - } - mu.Lock() - taskIDCounter++ - task := Task{ID: taskIDCounter, Content: content, Done: false} - userTasks[user] = append(userTasks[user], task) - mu.Unlock() - w.WriteHeader(http.StatusOK) - w.Write([]byte("ok")) - pushTaskUpdate(user) -} - -func listTaskHandler(w http.ResponseWriter, r *http.Request) { - user := getUserID(r) - if user == "" { - http.Error(w, "user required", http.StatusBadRequest) - return - } - mu.Lock() - tasks := userTasks[user] - mu.Unlock() - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(tasks) -} - -func doneTaskHandler(w http.ResponseWriter, r *http.Request) { - user := getUserID(r) - if user == "" { - http.Error(w, "user required", http.StatusBadRequest) - return - } - idStr := r.FormValue("id") - id, err := strconv.Atoi(idStr) - if err != nil { - http.Error(w, "invalid id", http.StatusBadRequest) - return - } - mu.Lock() - tasks := userTasks[user] - for i, t := range tasks { - if t.ID == id { - tasks[i].Done = true - break - } - } - userTasks[user] = tasks - mu.Unlock() - w.Write([]byte("ok")) - pushTaskUpdate(user) -} - -func deleteTaskHandler(w http.ResponseWriter, r *http.Request) { - user := getUserID(r) - if user == "" { - http.Error(w, "user required", http.StatusBadRequest) - return - } - idStr := r.FormValue("id") - id, err := strconv.Atoi(idStr) - if err != nil { - http.Error(w, "invalid id", http.StatusBadRequest) - return - } - mu.Lock() - tasks := userTasks[user] - for i, t := range tasks { - if t.ID == id { - userTasks[user] = append(tasks[:i], tasks[i+1:]...) - break - } - } - mu.Unlock() - w.Write([]byte("ok")) - pushTaskUpdate(user) -} - -func main() { - http.HandleFunc("/task/add", addTaskHandler) - http.HandleFunc("/task/list", listTaskHandler) - http.HandleFunc("/task/done", doneTaskHandler) - http.HandleFunc("/task/delete", deleteTaskHandler) - http.HandleFunc("/task/stream", streamHandler) // 新增SSE接口 - fmt.Println("Task server running at http://localhost:8080") - log.Fatal(http.ListenAndServe(":8080", nil)) -} diff --git a/examples/task/server/task_server_test.go b/examples/task/server/task_server_test.go deleted file mode 100644 index df9d3a1..0000000 --- a/examples/task/server/task_server_test.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024, Tencent Inc. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package main - -import ( - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" -) - -func TestAddListDoneDeleteTask(t *testing.T) { - // add - req := httptest.NewRequest("POST", "/task/add?user=test", strings.NewReader(url.Values{"content": {"hello"}}.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - rw := httptest.NewRecorder() - addTaskHandler(rw, req) - if rw.Code != http.StatusOK { - t.Fatalf("addTaskHandler failed: %v", rw.Body.String()) - } - - // list - req = httptest.NewRequest("GET", "/task/list?user=test", nil) - rw = httptest.NewRecorder() - listTaskHandler(rw, req) - if !strings.Contains(rw.Body.String(), "hello") { - t.Fatalf("listTaskHandler missing task: %v", rw.Body.String()) - } - - // done - req = httptest.NewRequest("POST", "/task/done?user=test", strings.NewReader(url.Values{"id": {"1"}}.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - rw = httptest.NewRecorder() - doneTaskHandler(rw, req) - if rw.Code != http.StatusOK { - t.Fatalf("doneTaskHandler failed: %v", rw.Body.String()) - } - - // delete - req = httptest.NewRequest("POST", "/task/delete?user=test", strings.NewReader(url.Values{"id": {"1"}}.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - rw = httptest.NewRecorder() - deleteTaskHandler(rw, req) - if rw.Code != http.StatusOK { - t.Fatalf("deleteTaskHandler failed: %v", rw.Body.String()) - } -}