Skip to content

Commit

Permalink
初步实现了Websocket功能
Browse files Browse the repository at this point in the history
  • Loading branch information
steden committed Sep 2, 2024
1 parent 9c06e90 commit b95e298
Show file tree
Hide file tree
Showing 19 changed files with 375 additions and 70 deletions.
32 changes: 27 additions & 5 deletions applicationBuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/farseer-go/webapi/controller"
"github.com/farseer-go/webapi/middleware"
"github.com/farseer-go/webapi/minimal"
"github.com/farseer-go/webapi/websocket"
"net/http"
"net/http/pprof"
"strings"
Expand All @@ -29,8 +30,8 @@ type applicationBuilder struct {
MiddlewareList collections.List[context.IMiddleware] // 注册的中间件
printRoute bool // 打印所有路由信息到控制台
useApiResponse bool // 是否使用了ApiResponse中间件
addr string
hostAddress string
addr string // 地址
hostAddress string // 完整地址
}

func NewApplicationBuilder() *applicationBuilder {
Expand Down Expand Up @@ -65,10 +66,19 @@ func (r *applicationBuilder) RegisterDELETE(route string, actionFunc any, params
r.registerAction(Route{Url: route, Method: "DELETE", Action: actionFunc, Params: params})
}

// RegisterWS 注册WebSocket(支持占位符,例如:/{cateId}/{Id})
func (r *applicationBuilder) RegisterWS(route string, actionFunc any) {
r.registerAction(Route{Url: route, Method: "WS", Action: actionFunc})
}

// RegisterRoutes 批量注册路由
func (r *applicationBuilder) RegisterRoutes(routes ...Route) {
for i := 0; i < len(routes); i++ {
r.registerAction(routes[i])
if strings.ToUpper(routes[i].Method) == "WS" {
r.registerWS(routes[i])
} else {
r.registerAction(routes[i])
}
}
}

Expand All @@ -88,11 +98,24 @@ func (r *applicationBuilder) registerAction(route Route) {
route.Url = strings.Trim(route.Url, " ")
route.Url = strings.TrimLeft(route.Url, "/")
if route.Url == "" {
flog.Panicf("注册minimalApi失败:%s必须设置值", flog.Colors[eumLogLevel.Error]("routing"))
flog.Panicf("注册路由失败:%s必须设置值", flog.Colors[eumLogLevel.Error]("routing"))
}
r.mux.HandleRoute(minimal.Register(r.area, route.Method, route.Url, route.Action, route.Filters, route.Params...))
}

// registerWS 注册websocket
func (r *applicationBuilder) registerWS(route Route) {
// 需要先依赖模块
modules.ThrowIfNotLoad(Module{})

route.Url = strings.Trim(route.Url, " ")
route.Url = strings.TrimLeft(route.Url, "/")
if route.Url == "" {
flog.Panicf("注册websocket路由失败:%s必须设置值", flog.Colors[eumLogLevel.Error]("routing"))
}
r.mux.HandleRoute(websocket.Register(r.area, route.Method, route.Url, route.Action, route.Filters))
}

// Area 设置区域
func (r *applicationBuilder) Area(area string, f func()) {
if !strings.HasPrefix(area, "/") {
Expand Down Expand Up @@ -191,7 +214,6 @@ func (r *applicationBuilder) Run(params ...string) {
if apiDoc, exists := r.mux.m["/doc/api"]; exists {
flog.Infof("Api Document is:%s%s", r.hostAddress, apiDoc.RouteUrl)
}

if r.tls {
flog.Info(http.ListenAndServeTLS(r.addr, r.certFile, r.keyFile, r.mux))
} else {
Expand Down
15 changes: 15 additions & 0 deletions context/httpContext.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@ package context

import (
"github.com/farseer-go/collections"
"github.com/farseer-go/fs/asyncLocal"
"github.com/farseer-go/webapi/check"
"golang.org/x/net/websocket"
"net/http"
"reflect"
"strings"
)

var RoutineHttpContext = asyncLocal.New[*HttpContext]()

type HttpContext struct {
WebsocketConn *websocket.Conn // websocket
Request *HttpRequest // Request
Response *HttpResponse // Response
Header collections.ReadonlyDictionary[string, string] // 头部信息
Expand Down Expand Up @@ -86,13 +91,23 @@ func NewHttpContext(httpRoute *HttpRoute, w http.ResponseWriter, r *http.Request
return &httpContext
}

func (receiver *HttpContext) SetWebsocket(conn *websocket.Conn) {
receiver.WebsocketConn = conn
}

// ParseParams 转换成Handle函数需要的参数
func (receiver *HttpContext) ParseParams() []reflect.Value {
// 没有入参时,忽略request.body
if receiver.Route.RequestParamType.Count() == 0 {
return []reflect.Value{}
}

if receiver.Route.Schema == "ws" {
contextWebSocket := reflect.New(receiver.Route.RequestParamType.First().Elem())
contextWebSocket.MethodByName("SetContext").Call([]reflect.Value{reflect.ValueOf(receiver)})
return []reflect.Value{contextWebSocket}
}

if receiver.Method == "GET" {
return receiver.Route.FormToParams(receiver.Request.Query)
}
Expand Down
1 change: 1 addition & 0 deletions context/httpRoute.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type HttpRoute struct {
RouteRegexp *routeRegexp // 正则路由
Handler http.Handler // api处理函数
Filters []IFilter // 过滤器(对单个路由的执行单元)
Schema string // http or ws
}

// JsonToParams json入参转成param
Expand Down
1 change: 1 addition & 0 deletions controller/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ func registerAction(area string, actionMethod reflect.Method, actions map[string

// 添加到路由表
return &context.HttpRoute{
Schema: "http",
RouteUrl: area + strings.ToLower(controllerName) + "/" + strings.ToLower(actionName),
Action: methodType,
Method: collections.NewList(strings.Split(strings.ToUpper(actions[actionName].Method), "|")...),
Expand Down
37 changes: 37 additions & 0 deletions httpHandler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package webapi

import (
"github.com/farseer-go/fs/asyncLocal"
"github.com/farseer-go/fs/container"
"github.com/farseer-go/fs/flog"
"github.com/farseer-go/fs/trace"
"github.com/farseer-go/webapi/context"
"net/http"
)

func HttpHandler(route *context.HttpRoute) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// 解析报文、组装httpContext
httpContext := context.NewHttpContext(route, w, r)
// 创建链路追踪上下文
trackContext := container.Resolve[trace.IManager]().EntryWebApi(httpContext.URI.Host, httpContext.URI.Url, httpContext.Method, httpContext.ContentType, httpContext.Header.ToMap(), httpContext.URI.GetRealIp())
// 结束链路追踪
defer trackContext.End()
// 记录出入参
defer func() {
trackContext.SetBody(httpContext.Request.BodyString, httpContext.Response.GetHttpCode(), string(httpContext.Response.BodyBytes))
}()
httpContext.Data.Set("Trace", trackContext)

// 设置到routine,可用于任意子函数获取
context.RoutineHttpContext.Set(httpContext)
// 执行第一个中间件
route.HttpMiddleware.Invoke(httpContext)
// 记录异常
if httpContext.Exception != nil {
trackContext.Error(httpContext.Exception)
_ = flog.Errorf("[%s]%s 发生错误:%s", httpContext.Method, httpContext.URI.Url, httpContext.Exception.Error())
}
asyncLocal.Release()
}
}
7 changes: 4 additions & 3 deletions middleware/apiResponse.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ func (receiver *ApiResponse) Invoke(httpContext *context.HttpContext) {

var returnVal any
// 只有一个返回值
if len(httpContext.Response.Body) == 1 {
bodyLength := len(httpContext.Response.Body)
if bodyLength == 1 {
returnVal = httpContext.Response.Body[0]
} else {
} else if bodyLength > 1 {
// 多个返回值,则转成数组Json
lst := collections.NewListAny()
for i := 0; i < len(httpContext.Response.Body); i++ {
for i := 0; i < bodyLength; i++ {
lst.Add(httpContext.Response.Body[i])
}
returnVal = lst
Expand Down
9 changes: 7 additions & 2 deletions middleware/initMiddleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,16 @@ func InitMiddleware(lstRouteTable map[string]*context.HttpRoute, lstMiddleware c
func assembledPipeline(route *context.HttpRoute, lstMiddleware collections.List[context.IMiddleware]) collections.List[context.IMiddleware] {
// 添加系统中间件
lst := collections.NewList[context.IMiddleware](route.HttpMiddleware, &exceptionMiddleware{}, &routing{})

// 添加用户自定义中间件
for i := 0; i < lstMiddleware.Count(); i++ {
valIns := reflect.New(reflect.TypeOf(lstMiddleware.Index(i)).Elem()).Interface()
lst.Add(valIns.(context.IMiddleware))
middlewareType := reflect.TypeOf(lstMiddleware.Index(i)).Elem()
if route.Schema != "ws" && middlewareType.String() != "middleware.ApiResponse" {
valIns := reflect.New(middlewareType).Interface()
lst.Add(valIns.(context.IMiddleware))
}
}

// 添加Handle中间件
lst.Add(route.HandleMiddleware)

Expand Down
2 changes: 1 addition & 1 deletion middleware/routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ type routing struct {

func (receiver *routing) Invoke(httpContext *context.HttpContext) {
// 检查method
if httpContext.Method != "OPTIONS" && !httpContext.Route.Method.Contains(httpContext.Method) {
if httpContext.Route.Schema != "ws" && httpContext.Method != "OPTIONS" && !httpContext.Route.Method.Contains(httpContext.Method) {
// 响应码
httpContext.Response.Reject(405, "405 Method NotAllowed")
return
Expand Down
20 changes: 20 additions & 0 deletions middleware/websocket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package middleware

import (
"github.com/farseer-go/webapi/context"
"net/http"
)

type Websocket struct {
context.IMiddleware
}

func (receiver *Websocket) Invoke(httpContext *context.HttpContext) {
httpContext.Response.SetHttpCode(http.StatusOK)
httpContext.Response.SetStatusCode(http.StatusOK)

// 下一步:exceptionMiddleware
receiver.IMiddleware.Invoke(httpContext)

_ = httpContext.WebsocketConn.WriteClose(httpContext.Response.GetHttpCode())
}
12 changes: 6 additions & 6 deletions minimal/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package minimal

import (
"github.com/farseer-go/collections"
"github.com/farseer-go/fs/core/eumLogLevel"
"github.com/farseer-go/fs/flog"
"github.com/farseer-go/fs/types"
"github.com/farseer-go/webapi/check"
Expand All @@ -15,18 +14,18 @@ import (
// Register 注册单个Api
func Register(area string, method string, route string, actionFunc any, filters []context.IFilter, paramNames ...string) *context.HttpRoute {
actionType := reflect.TypeOf(actionFunc)
param := types.GetInParam(actionType)
params := types.GetInParam(actionType)

// 如果设置了方法的入参(多参数),则需要全部设置
if len(paramNames) > 0 && len(paramNames) != len(param) {
flog.Panicf("路由注册失败:%s函数入参与%s不匹配,建议重新运行fsctl -r命令", flog.Colors[eumLogLevel.Error](actionType.String()), flog.Colors[eumLogLevel.Error](paramNames))
if len(paramNames) > 0 && len(paramNames) != len(params) {
flog.Panicf("注册路由%s%s失败:%s函数入参与%s不匹配,建议重新运行fsctl -r命令", area, route, flog.Red(actionType.String()), flog.Blue(paramNames))
}

lstRequestParamType := collections.NewList(param...)
lstRequestParamType := collections.NewList(params...)
lstResponseParamType := collections.NewList(types.GetOutParam(actionType)...)

// 入参是否为DTO模式
isDtoModel := types.IsDtoModelIgnoreInterface(lstRequestParamType.ToArray())
isDtoModel := types.IsDtoModelIgnoreInterface(params)
// 是否实现了ICheck
var requestParamIsImplCheck bool
if isDtoModel {
Expand All @@ -40,6 +39,7 @@ func Register(area string, method string, route string, actionFunc any, filters

// 添加到路由表
return &context.HttpRoute{
Schema: "http",
RouteUrl: area + route,
Action: actionFunc,
Method: collections.NewList(strings.Split(strings.ToUpper(method), "|")...),
Expand Down
8 changes: 0 additions & 8 deletions routineHttpContext.go

This file was deleted.

41 changes: 8 additions & 33 deletions serveMux.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ package webapi

import (
"github.com/farseer-go/collections"
"github.com/farseer-go/fs/asyncLocal"
"github.com/farseer-go/fs/container"
"github.com/farseer-go/fs/flog"
"github.com/farseer-go/fs/trace"
"github.com/farseer-go/webapi/context"
"github.com/farseer-go/webapi/middleware"
"github.com/farseer-go/webapi/websocket"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -57,30 +55,12 @@ func (mux *serveMux) HandleRoute(route *context.HttpRoute) {
UseEncodedPath: false,
})

route.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 解析报文、组装httpContext
httpContext := context.NewHttpContext(route, w, r)
// 创建链路追踪上下文
trackContext := container.Resolve[trace.IManager]().EntryWebApi(httpContext.URI.Host, httpContext.URI.Url, httpContext.Method, httpContext.ContentType, httpContext.Header.ToMap(), "", httpContext.URI.GetRealIp())
// 结束链路追踪
defer trackContext.End()
// 记录出入参
defer func() {
trackContext.SetBody(httpContext.Request.BodyString, httpContext.Response.GetHttpCode(), string(httpContext.Response.BodyBytes))
}()
httpContext.Data.Set("Trace", trackContext)

// 设置到routine,可用于任意子函数获取
SetHttpContext(httpContext)
// 执行第一个中间件
route.HttpMiddleware.Invoke(httpContext)
// 记录异常
if httpContext.Exception != nil {
trackContext.Error(httpContext.Exception)
_ = flog.Errorf("[%s]%s 发生错误:%s", httpContext.Method, httpContext.URI.Url, httpContext.Exception.Error())
}
asyncLocal.Release()
})
// websocket
if route.Schema == "ws" {
route.Handler = websocket.SocketHandler(route)
} else { // http
route.Handler = HttpHandler(route)
}

// 检查规则
mux.checkHandle(route.RouteUrl, route.Handler)
Expand Down Expand Up @@ -321,10 +301,5 @@ func appendSorted(es []*context.HttpRoute, e *context.HttpRoute) []*context.Http

// GetHttpContext 在minimalApi模式下也可以获取到上下文
func GetHttpContext() *context.HttpContext {
return routineHttpContext.Get()
}

// SetHttpContext 设置当前上下文
func SetHttpContext(httpCtx *context.HttpContext) {
routineHttpContext.Set(httpCtx)
return context.RoutineHttpContext.Get()
}
9 changes: 3 additions & 6 deletions test/benchmark_run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ import (
"time"
)

func init() {
// BenchmarkRun-12 4434 304151 ns/op 22731 B/op 202 allocs/op(优化前)
// BenchmarkRun-12 4432 239972 ns/op 22758 B/op 203 allocs/op(第一次优化:HttpResponse.Body改为[]any)
func BenchmarkRun(b *testing.B) {
fs.Initialize[webapi.Module]("demo")
configure.SetDefault("Log.Component.webapi", true)

Expand All @@ -24,11 +26,6 @@ func init() {
webapi.UseApiResponse()
go webapi.Run(":8094")
time.Sleep(10 * time.Millisecond)
}

// BenchmarkRun-12 4434 304151 ns/op 22731 B/op 202 allocs/op(优化前)
// BenchmarkRun-12 4432 239972 ns/op 22758 B/op 203 allocs/op(第一次优化:HttpResponse.Body改为[]any)
func BenchmarkRun(b *testing.B) {
b.ReportAllocs()
sizeRequest := pageSizeRequest{PageSize: 10, PageIndex: 2}
marshal, _ := json.Marshal(sizeRequest)
Expand Down
7 changes: 1 addition & 6 deletions test/init_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
package test

import (
"github.com/farseer-go/fs"
"github.com/farseer-go/webapi"
)

func init() {
fs.Initialize[webapi.Module]("demo")
//fs.Initialize[webapi.Module]("demo")
}
Loading

0 comments on commit b95e298

Please sign in to comment.