316 lines
8.4 KiB
Go
316 lines
8.4 KiB
Go
// Package restyx is a thin HTTP client wrapper around github.com/go-resty/resty/v2.
|
||
package restyx
|
||
|
||
import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"reflect"
	"runtime"
	"sync"
	"time"

	"gitea.micah.wiki/pandora/starter/pkg/jsonx"
	"gitea.micah.wiki/pandora/starter/pkg/logx"
	"gitea.micah.wiki/pandora/starter/pkg/requestid"
	"github.com/go-resty/resty/v2"
)
|
||
|
||
var (
|
||
headerKeys []string
|
||
)
|
||
|
||
const (
|
||
HeaderJWTTokenKey = "X-Jwt-Token"
|
||
HeaderLogIDKey = requestid.KLogIDKey
|
||
|
||
// PrintLogTypeDefault = 0
|
||
|
||
PrintLogTypeForce = 1
|
||
PrintLogTypeDisable = 2
|
||
|
||
MaxPrintLogLength = 100000
|
||
|
||
maxCapacity = 10 * 1024 * 1024
|
||
initCapacity = 1024 * 1024
|
||
)
|
||
|
||
func init() {
|
||
headerKeys = make([]string, 0)
|
||
headerKeys = append(headerKeys, HeaderJWTTokenKey)
|
||
}
|
||
|
||
// Auth holds HTTP basic-auth credentials. getClient applies them only when
// both User and Token are non-empty.
type Auth struct {
	User  string
	Token string
}

// String implements fmt.Stringer for logging. The token is masked so that
// credentials do not leak into logs (requests, which embed Auth, are logged
// with %+v).
func (a *Auth) String() string {
	if a == nil {
		return "Auth: {nil}"
	}
	token := ""
	if a.Token != "" {
		token = "******"
	}
	return fmt.Sprintf("Auth{User: %s, Token: %s}", a.User, token)
}
|
||
|
||
type Request struct {
|
||
Auth *Auth
|
||
URL string
|
||
Body interface{}
|
||
FormData map[string]string
|
||
QueryData map[string]string
|
||
Header map[string]string
|
||
PathParam map[string]string
|
||
Timeout int
|
||
PrintLogType int // 相见头部var定义 0: 默认; 1: 打印;2: 不打印;
|
||
}
|
||
|
||
func (r *Request) String() string {
|
||
if r == nil {
|
||
return "Request: {nil}"
|
||
}
|
||
return fmt.Sprintf("Request: {Auth: %s, URL: %s, Body: %+v, FormData: %+v, QueryData: %+v, Header: %+v, PathParm: %+v, Timeout: %d, PrintLogType: %d}",
|
||
r.Auth, r.URL, r.Body, r.FormData, r.QueryData, r.Header, r.PathParam, r.Timeout, r.PrintLogType)
|
||
}
|
||
|
||
// Response carries metadata extracted from an HTTP response.
type Response struct {
	// LogID is the value of the HeaderLogIDKey response header.
	LogID string
	// Header holds selected response headers (fillResponse copies only
	// HeaderJWTTokenKey).
	Header map[string]string
	// TargetPoint is not populated by fillResponse in this file.
	TargetPoint interface{}
}

// String renders the response for logging; a nil receiver is handled.
func (r *Response) String() string {
	if r != nil {
		return fmt.Sprintf("Response: {LogID: %s, Header: %+v, TargetPoint: %+v}",
			r.LogID, r.Header, r.TargetPoint)
	}
	return "Response: {nil}"
}
|
||
|
||
// StreamResponse is the result of DoStream / PostStream.
type StreamResponse struct {
	// LogID is the value of the HeaderLogIDKey response header.
	LogID string
	// Header maps selected response header names to values; it may be nil.
	Header map[string]string
	// Target is whatever the caller-supplied StreamFunc returned.
	Target interface{}
}
|
||
|
||
func PostStream(ctx context.Context, req *Request, streamFunc StreamFunc) (*StreamResponse, error) {
|
||
return DoStream(ctx, resty.MethodPost, req, streamFunc)
|
||
}
|
||
func DoStream(ctx context.Context, method string, req *Request, streamFunc StreamFunc) (*StreamResponse, error) {
|
||
start := time.Now()
|
||
request := getClient(ctx, req)
|
||
request = request.SetDoNotParseResponse(true)
|
||
var (
|
||
resp *resty.Response
|
||
err error
|
||
)
|
||
switch method {
|
||
case resty.MethodGet:
|
||
resp, err = request.Get(req.URL)
|
||
case resty.MethodPost:
|
||
resp, err = request.Post(req.URL)
|
||
case resty.MethodHead:
|
||
resp, err = request.Head(req.URL)
|
||
case resty.MethodDelete:
|
||
resp, err = request.Delete(req.URL)
|
||
default:
|
||
resp, err = request.Get(req.URL)
|
||
}
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if !resp.IsSuccess() {
|
||
logx.CtxError(ctx, "HttpResty Resp IsFail. req: %+v, respCode: %d, respStatus: %s, resp: %+v, cost: %+v.", req, resp.StatusCode(), resp.Status(), resp, time.Since(start))
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
_ = resp.RawResponse.Body.Close() // 记得关闭流,防止资源泄露
|
||
}()
|
||
scanner := bufio.NewScanner(resp.RawResponse.Body)
|
||
|
||
buf := make([]byte, initCapacity) // 初始缓冲区1MB
|
||
scanner.Buffer(buf, maxCapacity) // 设置最大缓冲区大小,比如 10MB
|
||
data, err := streamFunc(ctx, scanner)
|
||
if err != nil {
|
||
logx.CtxError(ctx, "HttpResty Resp streamFunc error. req: %+v, cost: %+v, err: %+v.", req, time.Since(start), err)
|
||
return nil, err
|
||
}
|
||
if err = scanner.Err(); err != nil {
|
||
logx.CtxError(ctx, "HttpResty Resp scanner error. req: %+v, respCode: %d, respStatus: %s, cost: %+v, err: %+v.", req, resp.StatusCode(), resp.Status(), time.Since(start), err)
|
||
return nil, err
|
||
}
|
||
|
||
return &StreamResponse{
|
||
LogID: resp.Header().Get(HeaderLogIDKey),
|
||
Target: data,
|
||
}, nil
|
||
}
|
||
|
||
func Post(ctx context.Context, req *Request, targetPoint interface{}) error {
|
||
return Do(ctx, resty.MethodPost, req, targetPoint)
|
||
}
|
||
|
||
func Get(ctx context.Context, req *Request, targetPoint interface{}) error {
|
||
return Do(ctx, resty.MethodGet, req, targetPoint)
|
||
}
|
||
|
||
func Do(ctx context.Context, method string, req *Request, targetPoint interface{}) error {
|
||
restyResp, err := doRemote(ctx, method, req)
|
||
response := fillResponse(restyResp)
|
||
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if targetPoint != nil {
|
||
if _, ok := targetPoint.(*SpecialResponse); ok {
|
||
v := reflect.ValueOf(targetPoint)
|
||
setDataMethod := v.MethodByName("SetData")
|
||
if setDataMethod.IsValid() {
|
||
args := []reflect.Value{reflect.ValueOf(string(restyResp.Body()))}
|
||
setDataMethod.Call(args)
|
||
}
|
||
setHeadersMethod := v.MethodByName("SetHeaders")
|
||
if setHeadersMethod.IsValid() {
|
||
args := []reflect.Value{reflect.ValueOf(jsonx.UnsafeMarshal(ctx, response.Header))}
|
||
setHeadersMethod.Call(args)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
err = json.Unmarshal(restyResp.Body(), targetPoint)
|
||
if err != nil {
|
||
logx.CtxInfo(ctx, "HttpResty Resp Unmarshal Error. req: %+v, targetType: %+v, restyResp: %+v, err: %+v.", req, reflect.TypeOf(targetPoint), string(restyResp.Body()), err)
|
||
return err
|
||
}
|
||
fillTargetBase(ctx, response, targetPoint)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func getClient(ctx context.Context, req *Request) *resty.Request {
|
||
cli := resty.New()
|
||
if req.Timeout > 0 {
|
||
cli = cli.SetTimeout(time.Second * time.Duration(req.Timeout))
|
||
}
|
||
request := cli.R().SetContext(ctx)
|
||
if req.Auth != nil && len(req.Auth.User) > 0 && len(req.Auth.Token) > 0 {
|
||
request = request.SetBasicAuth(req.Auth.User, req.Auth.Token)
|
||
}
|
||
headers := getHeaders(ctx, req.Header)
|
||
if len(headers) > 0 {
|
||
request = request.SetHeaders(headers)
|
||
}
|
||
if len(req.PathParam) > 0 {
|
||
request = request.SetPathParams(req.PathParam)
|
||
}
|
||
if req.Body != nil {
|
||
request = request.SetBody(req.Body)
|
||
}
|
||
if req.FormData != nil {
|
||
request = request.SetFormData(req.FormData)
|
||
}
|
||
if len(req.QueryData) > 0 {
|
||
request = request.SetQueryParams(req.QueryData)
|
||
}
|
||
return request
|
||
}
|
||
|
||
func doRemote(ctx context.Context, method string, req *Request) (*resty.Response, error) {
|
||
start := time.Now()
|
||
if req.PrintLogType != PrintLogTypeDisable {
|
||
logx.CtxInfo(ctx, "HttpResty start. req: %+v", req)
|
||
}
|
||
request := getClient(ctx, req)
|
||
var (
|
||
resp *resty.Response
|
||
err error
|
||
)
|
||
switch method {
|
||
case resty.MethodGet:
|
||
resp, err = request.Get(req.URL)
|
||
case resty.MethodPost:
|
||
resp, err = request.Post(req.URL)
|
||
case resty.MethodHead:
|
||
resp, err = request.Head(req.URL)
|
||
case resty.MethodDelete:
|
||
resp, err = request.Delete(req.URL)
|
||
default:
|
||
resp, err = request.Get(req.URL)
|
||
}
|
||
if err != nil {
|
||
logx.CtxError(ctx, "HttpResty Handler Error. req: %+v, err: %+v, cost: %+v.", req, err, time.Since(start))
|
||
return nil, err
|
||
}
|
||
|
||
if !resp.IsSuccess() {
|
||
logx.CtxError(ctx, "HttpResty Resp IsFail. req: %+v, respCode: %d, respStatus: %s, resp: %+v, cost: %+v.", req, resp.StatusCode(), resp.Status(), resp, time.Since(start))
|
||
return resp, err
|
||
}
|
||
// 如果这里返回结构中过多,不进行日志打印。
|
||
if req.PrintLogType != PrintLogTypeDisable {
|
||
if req.PrintLogType == PrintLogTypeForce || len(resp.Body()) < MaxPrintLogLength {
|
||
logx.CtxInfo(ctx, "HttpResty finish. req: %+v, resp: %+v, cost: %+v.", req, resp, time.Since(start))
|
||
} else {
|
||
logx.CtxInfo(ctx, "HttpResty finish. req: %+v, resp: too many resp(not print to log), cost: %+v.", req, time.Since(start))
|
||
}
|
||
}
|
||
|
||
return resp, nil
|
||
}
|
||
|
||
func fillResponse(resp *resty.Response) *Response {
|
||
response := &Response{}
|
||
if resp == nil {
|
||
return response
|
||
}
|
||
response.LogID = resp.Header().Get(HeaderLogIDKey)
|
||
response.Header = make(map[string]string)
|
||
if token := resp.Header().Get(HeaderJWTTokenKey); len(token) > 0 {
|
||
response.Header[HeaderJWTTokenKey] = token
|
||
}
|
||
return response
|
||
}
|
||
|
||
func fillTargetBase(ctx context.Context, response *Response, target interface{}) {
|
||
defer func() {
|
||
if err := recover(); err != nil {
|
||
buf := make([]byte, 1<<16)
|
||
runtime.Stack(buf, false)
|
||
logx.CtxError(ctx, "HttpResty fillTargetBase panic, response: %+v, err:%+v,\n%s", response, err, buf)
|
||
}
|
||
}()
|
||
if target == nil {
|
||
return
|
||
}
|
||
|
||
t := reflect.TypeOf(target)
|
||
switch t.Kind() {
|
||
case reflect.Ptr:
|
||
v := reflect.ValueOf(target)
|
||
m := v.MethodByName("SetLogID")
|
||
if m.IsValid() {
|
||
args := []reflect.Value{reflect.ValueOf(response.LogID)}
|
||
m.Call(args)
|
||
}
|
||
case reflect.Struct:
|
||
b, ok := target.(Basic)
|
||
if !ok {
|
||
return
|
||
}
|
||
b.LogID = response.LogID
|
||
default:
|
||
}
|
||
}
|
||
|
||
func getHeaders(ctx context.Context, headers map[string]string) map[string]string {
|
||
newHeaders := make(map[string]string)
|
||
for key, value := range headers {
|
||
newHeaders[key] = value
|
||
}
|
||
logID := requestid.GetLogID(ctx)
|
||
if len(logID) > 0 {
|
||
newHeaders[HeaderLogIDKey] = logID
|
||
}
|
||
return newHeaders
|
||
}
|