starter/pkg/restyx/resty.go
2026-03-28 19:29:40 +08:00

316 lines
8.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package restyx is an HTTP client wrapper built on go-resty.
package restyx
import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"reflect"
	"runtime"
	"time"

	"gitea.micah.wiki/pandora/starter/pkg/jsonx"
	"gitea.micah.wiki/pandora/starter/pkg/logx"
	"gitea.micah.wiki/pandora/starter/pkg/requestid"
	"github.com/go-resty/resty/v2"
)
// headerKeys holds header names registered by init (currently only
// HeaderJWTTokenKey). NOTE(review): not read anywhere in this file — confirm it
// is used elsewhere in the package.
var headerKeys []string

const (
	// HeaderJWTTokenKey is the response header whose value fillResponse copies
	// into Response.Header.
	HeaderJWTTokenKey = "X-Jwt-Token"

	// HeaderLogIDKey is the log-ID header: getHeaders sends it and the value
	// returned by the server is exposed as Response.LogID.
	HeaderLogIDKey = requestid.KLogIDKey

	// Values for Request.PrintLogType. The zero value is the default mode
	// (PrintLogTypeDefault = 0, not declared in this block): responses are logged
	// only while their body is shorter than MaxPrintLogLength bytes.
	PrintLogTypeForce   = 1 // always log the response body
	PrintLogTypeDisable = 2 // suppress the start/finish logs; failures are still logged

	// MaxPrintLogLength is the response-body size in bytes at which doRemote stops
	// logging the body, unless PrintLogTypeForce is set.
	MaxPrintLogLength = 100000

	// Scanner buffer sizes used by DoStream: maximum token size and initial buffer.
	maxCapacity  = 10 << 20 // 10 MiB
	initCapacity = 1 << 20  // 1 MiB
)

func init() {
	headerKeys = []string{HeaderJWTTokenKey}
}
// Auth holds HTTP basic-auth credentials. getClient applies them only when both
// User and Token are non-empty.
type Auth struct {
	User  string
	Token string
}

// String implements fmt.Stringer for logging. Request.String and the request
// logs in this package format Auth through this method, so the token is masked
// to keep credentials out of log output. An empty token is shown as empty so it
// remains obvious when credentials are missing.
func (a *Auth) String() string {
	if a == nil {
		return "Auth: {nil}"
	}
	token := ""
	if a.Token != "" {
		token = "******"
	}
	return fmt.Sprintf("Auth{User: %s, Token: %s}", a.User, token)
}
// Request describes one outgoing HTTP call.
//
// Auth is applied as HTTP basic auth only when both User and Token are
// non-empty. Header entries are sent as given, plus the log ID from the context
// (when present) under HeaderLogIDKey. Timeout is in seconds; zero or a negative
// value leaves the client timeout unset. PrintLogType selects the logging mode,
// see the constants at the top of the file: 0 (default) logs responses shorter
// than MaxPrintLogLength bytes, 1 (PrintLogTypeForce) always logs them, and
// 2 (PrintLogTypeDisable) suppresses the start/finish logs.
type Request struct {
	Auth         *Auth
	URL          string
	Body         interface{}
	FormData     map[string]string
	QueryData    map[string]string
	Header       map[string]string
	PathParam    map[string]string
	Timeout      int
	PrintLogType int
}

// String implements fmt.Stringer for logging; Auth is rendered through its own
// String method.
func (r *Request) String() string {
	if r == nil {
		return "Request: {nil}"
	}
	return fmt.Sprintf("Request: {Auth: %s, URL: %s, Body: %+v, FormData: %+v, QueryData: %+v, Header: %+v, PathParam: %+v, Timeout: %d, PrintLogType: %d}",
		r.Auth, r.URL, r.Body, r.FormData, r.QueryData, r.Header, r.PathParam, r.Timeout, r.PrintLogType)
}
// Response carries metadata extracted from a completed call.
type Response struct {
	// LogID is the value of the HeaderLogIDKey response header.
	LogID string
	// Header holds selected response headers, keyed by header name.
	Header map[string]string
	// TargetPoint is the decoded payload, if any.
	TargetPoint interface{}
}

// String implements fmt.Stringer for logging.
func (r *Response) String() string {
	if r != nil {
		return fmt.Sprintf("Response: {LogID: %s, Header: %+v, TargetPoint: %+v}", r.LogID, r.Header, r.TargetPoint)
	}
	return "Response: {nil}"
}
// StreamResponse is the result returned by DoStream and PostStream.
type StreamResponse struct {
	// LogID is the value of the HeaderLogIDKey response header.
	LogID string
	// Header holds response header values keyed by header name; presumably
	// filled like Response.Header — confirm against DoStream.
	Header map[string]string
	// Target is the value returned by the caller-supplied StreamFunc.
	Target interface{}
}
// PostStream sends req as an HTTP POST and streams the response body to
// streamFunc; it is shorthand for DoStream with resty.MethodPost.
func PostStream(ctx context.Context, req *Request, streamFunc StreamFunc) (*StreamResponse, error) {
	const method = resty.MethodPost
	return DoStream(ctx, method, req, streamFunc)
}
// DoStream sends req with the given HTTP method and hands the unbuffered
// response body to streamFunc through a bufio.Scanner.
//
// GET, POST, HEAD and DELETE are dispatched explicitly; any other method is
// sent as GET. The scanner starts with an initCapacity-byte buffer and accepts
// tokens of up to maxCapacity bytes.
//
// DoStream returns an error and a nil *StreamResponse in these cases:
//   - the request cannot be sent,
//   - the status is not 2xx,
//   - the raw body is missing,
//   - streamFunc fails, or
//   - the scanner reports an error.
//
// On success the result carries the log-ID header, the headers collected by
// fillResponse, and streamFunc's result. Once a response has been received, its
// body is closed before DoStream returns.
func DoStream(ctx context.Context, method string, req *Request, streamFunc StreamFunc) (*StreamResponse, error) {
	start := time.Now()
	// streamFunc consumes the body incrementally, so resty must not read it.
	request := getClient(ctx, req).SetDoNotParseResponse(true)

	var (
		resp *resty.Response
		err  error
	)
	switch method {
	case resty.MethodGet:
		resp, err = request.Get(req.URL)
	case resty.MethodPost:
		resp, err = request.Post(req.URL)
	case resty.MethodHead:
		resp, err = request.Head(req.URL)
	case resty.MethodDelete:
		resp, err = request.Delete(req.URL)
	default:
		resp, err = request.Get(req.URL)
	}
	if err != nil {
		return nil, err
	}

	// With SetDoNotParseResponse we own the raw body. Close it on every path,
	// including non-2xx responses, otherwise the connection leaks.
	body := resp.RawBody()
	if body != nil {
		defer func() {
			_ = body.Close() // a close error is not actionable at this point
		}()
	}

	if !resp.IsSuccess() {
		logx.CtxError(ctx, "HttpResty Resp IsFail. req: %+v, respCode: %d, respStatus: %s, resp: %+v, cost: %+v.", req, resp.StatusCode(), resp.Status(), resp, time.Since(start))
		// err is nil here, so report the failure explicitly instead of (nil, nil).
		return nil, fmt.Errorf("restyx: %s %s: unexpected status %s", method, req.URL, resp.Status())
	}
	if body == nil {
		return nil, fmt.Errorf("restyx: %s %s: response has no body", method, req.URL)
	}

	scanner := bufio.NewScanner(body)
	scanner.Buffer(make([]byte, initCapacity), maxCapacity)

	data, err := streamFunc(ctx, scanner)
	if err != nil {
		logx.CtxError(ctx, "HttpResty Resp streamFunc error. req: %+v, cost: %+v, err: %+v.", req, time.Since(start), err)
		return nil, err
	}
	// streamFunc typically stops when Scan returns false; surface the reason.
	if err = scanner.Err(); err != nil {
		logx.CtxError(ctx, "HttpResty Resp scanner error. req: %+v, respCode: %d, respStatus: %s, cost: %+v, err: %+v.", req, resp.StatusCode(), resp.Status(), time.Since(start), err)
		return nil, err
	}

	meta := fillResponse(resp)
	return &StreamResponse{
		LogID:  meta.LogID,
		Header: meta.Header,
		Target: data,
	}, nil
}
// Post sends req as an HTTP POST and decodes the response into targetPoint;
// it is shorthand for Do with resty.MethodPost.
func Post(ctx context.Context, req *Request, targetPoint interface{}) error {
	const method = resty.MethodPost
	return Do(ctx, method, req, targetPoint)
}
// Get sends req as an HTTP GET and decodes the response into targetPoint;
// it is shorthand for Do with resty.MethodGet.
func Get(ctx context.Context, req *Request, targetPoint interface{}) error {
	const method = resty.MethodGet
	return Do(ctx, method, req, targetPoint)
}
// Do sends req with the given HTTP method and decodes the response into
// targetPoint.
//
// The target is handled in one of three ways:
//   - nil: the response is not decoded.
//   - *SpecialResponse: its SetData method (if present) receives the raw body as
//     a string, and SetHeaders (if present) receives the JSON-marshalled headers
//     collected by fillResponse. No JSON decoding happens.
//   - anything else: the body is JSON-decoded into targetPoint, then
//     fillTargetBase propagates the log ID.
//
// doRemote reports non-2xx statuses with a nil error, so their bodies are
// decoded here as well. A JSON decoding failure is logged and returned.
func Do(ctx context.Context, method string, req *Request, targetPoint interface{}) error {
	restyResp, err := doRemote(ctx, method, req)
	if err != nil {
		return err
	}
	if targetPoint == nil {
		return nil
	}

	response := fillResponse(restyResp)

	if _, special := targetPoint.(*SpecialResponse); special {
		// Methods are looked up by name; missing ones are silently skipped.
		target := reflect.ValueOf(targetPoint)
		if setData := target.MethodByName("SetData"); setData.IsValid() {
			setData.Call([]reflect.Value{reflect.ValueOf(string(restyResp.Body()))})
		}
		if setHeaders := target.MethodByName("SetHeaders"); setHeaders.IsValid() {
			setHeaders.Call([]reflect.Value{reflect.ValueOf(jsonx.UnsafeMarshal(ctx, response.Header))})
		}
		return nil
	}

	if err := json.Unmarshal(restyResp.Body(), targetPoint); err != nil {
		logx.CtxInfo(ctx, "HttpResty Resp Unmarshal Error. req: %+v, targetType: %+v, restyResp: %+v, err: %+v.", req, reflect.TypeOf(targetPoint), string(restyResp.Body()), err)
		return err
	}
	fillTargetBase(ctx, response, targetPoint)
	return nil
}
// pooledTransport is the connection pool shared by every client built in
// getClient. resty.New creates a fresh *http.Transport per client, and a client
// is created per request. Without sharing, connections would never be reused,
// and each call's idle keep-alive connections would linger until their idle
// timeout.
var pooledTransport = newPooledTransport()

// newPooledTransport clones http.DefaultTransport (proxy from environment,
// dial/TLS timeouts, HTTP/2). It raises the per-host idle limit to
// GOMAXPROCS+1, which mirrors go-resty v2's own default transport.
func newPooledTransport() *http.Transport {
	var t *http.Transport
	if dt, ok := http.DefaultTransport.(*http.Transport); ok {
		t = dt.Clone()
	} else {
		// http.DefaultTransport was replaced by something else; build equivalent settings.
		t = &http.Transport{
			Proxy:                 http.ProxyFromEnvironment,
			ForceAttemptHTTP2:     true,
			MaxIdleConns:          100,
			IdleConnTimeout:       90 * time.Second,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: time.Second,
		}
	}
	t.MaxIdleConnsPerHost = runtime.GOMAXPROCS(0) + 1
	return t
}

// getClient builds a resty request for req, bound to ctx. It applies:
//   - basic auth, only when both Auth.User and Auth.Token are non-empty;
//   - headers from getHeaders (caller headers plus the context log ID);
//   - path params, body, form data and query params, when provided.
//
// req.Timeout (seconds) becomes the client timeout when positive.
func getClient(ctx context.Context, req *Request) *resty.Request {
	// A new client keeps the cookie jar and timeout per request; only the
	// transport (and so the connection pool) is shared.
	cli := resty.New().SetTransport(pooledTransport)
	if req.Timeout > 0 {
		cli = cli.SetTimeout(time.Second * time.Duration(req.Timeout))
	}
	request := cli.R().SetContext(ctx)
	if req.Auth != nil && len(req.Auth.User) > 0 && len(req.Auth.Token) > 0 {
		request = request.SetBasicAuth(req.Auth.User, req.Auth.Token)
	}
	headers := getHeaders(ctx, req.Header)
	if len(headers) > 0 {
		request = request.SetHeaders(headers)
	}
	if len(req.PathParam) > 0 {
		request = request.SetPathParams(req.PathParam)
	}
	if req.Body != nil {
		request = request.SetBody(req.Body)
	}
	if req.FormData != nil {
		request = request.SetFormData(req.FormData)
	}
	if len(req.QueryData) > 0 {
		request = request.SetQueryParams(req.QueryData)
	}
	return request
}
// doRemote sends req with the given HTTP method and returns the buffered resty
// response.
//
// GET is the fallback: POST, HEAD and DELETE are dispatched explicitly, and any
// other method (including GET) is sent as GET. Unless req.PrintLogType is
// PrintLogTypeDisable, the request is logged before sending and the response
// after it arrives. Bodies of MaxPrintLogLength bytes or more are only logged
// with PrintLogTypeForce.
//
// A transport error is logged and returned with a nil response. A non-2xx
// status is logged as an error but returned with a nil error, so callers such as
// Do still receive the body. NOTE(review): confirm that treating non-2xx as
// success is intended.
func doRemote(ctx context.Context, method string, req *Request) (*resty.Response, error) {
	start := time.Now()
	logEnabled := req.PrintLogType != PrintLogTypeDisable
	if logEnabled {
		logx.CtxInfo(ctx, "HttpResty start. req: %+v", req)
	}

	request := getClient(ctx, req)
	var send func(url string) (*resty.Response, error)
	switch method {
	case resty.MethodPost:
		send = request.Post
	case resty.MethodHead:
		send = request.Head
	case resty.MethodDelete:
		send = request.Delete
	default: // resty.MethodGet and any unrecognised method
		send = request.Get
	}

	resp, err := send(req.URL)
	if err != nil {
		logx.CtxError(ctx, "HttpResty Handler Error. req: %+v, err: %+v, cost: %+v.", req, err, time.Since(start))
		return nil, err
	}
	if !resp.IsSuccess() {
		logx.CtxError(ctx, "HttpResty Resp IsFail. req: %+v, respCode: %d, respStatus: %s, resp: %+v, cost: %+v.", req, resp.StatusCode(), resp.Status(), resp, time.Since(start))
		return resp, nil
	}

	if logEnabled {
		// Very large bodies stay out of the log unless printing is forced.
		if req.PrintLogType == PrintLogTypeForce || len(resp.Body()) < MaxPrintLogLength {
			logx.CtxInfo(ctx, "HttpResty finish. req: %+v, resp: %+v, cost: %+v.", req, resp, time.Since(start))
		} else {
			logx.CtxInfo(ctx, "HttpResty finish. req: %+v, resp: too many resp(not print to log), cost: %+v.", req, time.Since(start))
		}
	}
	return resp, nil
}
// fillResponse extracts call metadata from resp.
//
// LogID comes from the HeaderLogIDKey header. Header contains only
// HeaderJWTTokenKey, and only when the server sent a non-empty value. A nil resp
// yields an empty Response whose Header map is nil.
func fillResponse(resp *resty.Response) *Response {
	if resp == nil {
		return &Response{}
	}
	headers := resp.Header()
	out := &Response{
		LogID:  headers.Get(HeaderLogIDKey),
		Header: map[string]string{},
	}
	if token := headers.Get(HeaderJWTTokenKey); token != "" {
		out.Header[HeaderJWTTokenKey] = token
	}
	return out
}
// fillTargetBase propagates response metadata into a decoded target.
//
// When target is a pointer whose method set has SetLogID, that method is called
// via reflection with response.LogID. Non-pointer targets are skipped: the
// function only holds a copy of a struct value, so any assignment would be lost.
//
// A panic while doing this is recovered and logged with the goroutine's stack,
// so metadata propagation never fails the call. Such a panic can come from a
// nil response, SetLogID itself, or a signature mismatch in reflect.Value.Call.
func fillTargetBase(ctx context.Context, response *Response, target interface{}) {
	defer func() {
		if err := recover(); err != nil {
			buf := make([]byte, 1<<16)
			// runtime.Stack reports how many bytes it wrote; the rest is zero padding.
			n := runtime.Stack(buf, false)
			logx.CtxError(ctx, "HttpResty fillTargetBase panic, response: %+v, err:%+v,\n%s", response, err, buf[:n])
		}
	}()
	if target == nil {
		return
	}
	if reflect.TypeOf(target).Kind() != reflect.Ptr {
		return
	}
	if setLogID := reflect.ValueOf(target).MethodByName("SetLogID"); setLogID.IsValid() {
		setLogID.Call([]reflect.Value{reflect.ValueOf(response.LogID)})
	}
}
// getHeaders returns a copy of headers with the context's log ID added under
// HeaderLogIDKey when one is present. The log ID overrides any caller-supplied
// value for that key. The input map is never modified.
func getHeaders(ctx context.Context, headers map[string]string) map[string]string {
	merged := make(map[string]string, len(headers)+1)
	for k, v := range headers {
		merged[k] = v
	}
	if logID := requestid.GetLogID(ctx); logID != "" {
		merged[HeaderLogIDKey] = logID
	}
	return merged
}