@wunonglin #16 这是 JSON 的方案，如果请求的序列化协议是 application/proto 就不行了。
嗯……公司的需求是要同时支持 JSON 和 proto 两种序列化方式（屎山）。分享一下我的实现（性能不算好，能用就行）：在 encode 的地方动态构造 proto 对象，这样就不用给所有 response 都包一层了。
```
package http
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	stdhttp "net/http"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/go-kratos/kratos/v2/encoding"
	ejson "github.com/go-kratos/kratos/v2/encoding/json"
	eproto "github.com/go-kratos/kratos/v2/encoding/proto"
	"github.com/go-kratos/kratos/v2/errors"
	"github.com/go-kratos/kratos/v2/transport/http"
	pproto "github.com/golang/protobuf/proto"
	"github.com/jhump/protoreflect/desc"
	"github.com/jhump/protoreflect/desc/builder"
	"github.com/jhump/protoreflect/dynamic"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/proto"

	".../api/_gen/go/ecode"
)
// Protobuf response-envelope descriptor state.
var (
	// messagePool caches the "Response" descriptors built by getProtoBuilder,
	// one per data message type.
	messagePool = new(sync.Map)

	// defaultErrorMessageDescriptor describes the envelope without a data
	// field (code, message, ts); it is assigned in init.
	defaultErrorMessageDescriptor *desc.MessageDescriptor
)

// protojson settings.
var (
	// jsonMarshalOptions renders proto messages as JSON and also emits fields
	// that are unpopulated.
	jsonMarshalOptions = protojson.MarshalOptions{EmitUnpopulated: true}

	// jsonUnmarshalOptions parses JSON into proto messages, discarding
	// unknown fields instead of failing on them.
	jsonUnmarshalOptions = protojson.UnmarshalOptions{DiscardUnknown: true}
)

// Codec negotiation tables.
var (
	// Media types handled by each codec; init copies them into registeredCodecs.
	jsonCodecHeaders  = []string{"application/json", "text/json"}
	protoCodecHeaders = []string{"application/x-protobuf", "application/proto", "application/octet-stream"}

	jsonCodec  = encoding.GetCodec(ejson.Name)
	protoCodec = encoding.GetCodec(eproto.Name)

	// registeredCodecs maps a bare media type (no parameters) to its codec.
	registeredCodecs = map[string]encoding.Codec{}
)
// init registers every JSON and proto media type in registeredCodecs and
// builds defaultErrorMessageDescriptor, the "Response" envelope used when there
// is no data payload (fields: 1 code int32, 2 message string, 3 ts int64).
func init() {
	for _, contentType := range jsonCodecHeaders {
		registeredCodecs[contentType] = jsonCodec
	}
	for _, contentType := range protoCodecHeaders {
		registeredCodecs[contentType] = protoCodec
	}

	md, err := builder.NewMessage("Response").
		AddField(builder.NewField("code", builder.FieldTypeInt32()).SetNumber(1)).
		AddField(builder.NewField("message", builder.FieldTypeString()).SetNumber(2)).
		AddField(builder.NewField("ts", builder.FieldTypeInt64()).SetNumber(3)).
		Build()
	if err != nil {
		// The schema is static, so a failure is a programming error. Fail at
		// startup instead of leaving the descriptor nil and failing later,
		// when the first data-less proto response is encoded.
		panic(fmt.Errorf("http: build default response descriptor: %w", err))
	}
	defaultErrorMessageDescriptor = md
}
// requestDecoder decodes the request body into v using the codec registered
// for the request's Content-Type. An unregistered Content-Type, a read failure
// or an unmarshal failure is reported as a "CODEC" BadRequest. An empty body
// leaves v untouched. After a successful decode the consumed body is replaced
// with an in-memory copy so it can be read again downstream.
func requestDecoder(r *http.Request, v any) error {
	codec, _, found := codecForRequest(r, "Content-Type")
	if !found {
		return errors.BadRequest("CODEC", fmt.Sprintf("unregister Content-Type: %s", r.Header.Get("Content-Type")))
	}

	body, readErr := io.ReadAll(r.Body)
	switch {
	case readErr != nil:
		return errors.BadRequest("CODEC", readErr.Error())
	case len(body) == 0:
		return nil
	}

	if decodeErr := codec.Unmarshal(body, v); decodeErr != nil {
		reason := fmt.Sprintf("body unmarshal err: %s, body: %s", decodeErr.Error(), string(body))
		return errors.BadRequest("CODEC", reason)
	}

	// Re-arm the body: the original reader has been drained by ReadAll.
	r.Body = io.NopCloser(bytes.NewBuffer(body))
	return nil
}
// ErrorEncoder writes err to w as a response envelope with HTTP status 200.
// The failure is reported through the envelope's code field. That field holds
// the ecode.ServiceErrorReason value registered for the error's reason, or the
// kratos error code when the reason is unknown or maps to 0. The envelope is
// serialized with the codec negotiated from the Accept header, falling back to
// Content-Type. If the envelope cannot be built, the error is logged and a
// bare 500 is written instead.
func ErrorEncoder(w http.ResponseWriter, r *http.Request, err error) {
	er := errors.FromError(err)

	// Prefer the business error code registered for the reason.
	code, ok := ecode.ServiceErrorReason_value[er.Reason]
	if !ok || code == 0 {
		code = er.Code
	}

	codec, contentType, ok := codecForRequest(r, "Accept")
	if !ok {
		codec, contentType, _ = codecForRequest(r, "Content-Type")
	}

	var (
		body   []byte
		encErr error
	)
	switch codec.Name() {
	case ejson.Name:
		body, encErr = encodeJSONResponse(code, er.Message, []byte("{}"))
	case eproto.Name:
		body, encErr = encodeProtoResponse(code, er.Message, nil)
	default:
		// Only the JSON and proto codecs are registered; write nothing otherwise.
		return
	}
	if encErr != nil {
		// slog takes key/value attributes, not printf verbs.
		slog.Error("fail to encode error response", "codec", codec.Name(), "err", encErr)
		w.WriteHeader(stdhttp.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", contentType)
	w.WriteHeader(stdhttp.StatusOK)
	// Headers are already sent; a write failure (e.g. client gone) cannot be reported.
	_, _ = w.Write(body)
}
// responseEncoder wraps a successful handler result in the response envelope
// (code 200, message "success") and writes it with HTTP status 200. The codec
// is negotiated from the Accept header, falling back to Content-Type. The
// result must be a proto.Message; anything else is rejected with a "CODEC"
// BadRequest. Marshal/envelope errors are returned before anything is written.
func responseEncoder(w http.ResponseWriter, r *http.Request, i any) error {
	codec, contentType, negotiated := codecForRequest(r, "Accept")
	if !negotiated {
		codec, contentType, _ = codecForRequest(r, "Content-Type")
	}

	msg, isProto := i.(proto.Message)
	if !isProto {
		return errors.BadRequest("CODEC", fmt.Sprintf("response is not proto.Message: %s", reflect.TypeOf(i)))
	}

	var payload []byte
	switch codec.Name() {
	case ejson.Name:
		raw, err := jsonMarshalOptions.Marshal(msg)
		if err != nil {
			return err
		}
		if payload, err = encodeJSONResponse(200, "success", raw); err != nil {
			return err
		}
	case eproto.Name:
		var err error
		if payload, err = encodeProtoResponse(200, "success", msg); err != nil {
			return err
		}
	default:
		return nil
	}

	w.Header().Set("Content-Type", contentType)
	w.WriteHeader(stdhttp.StatusOK)
	_, _ = w.Write(payload)
	return nil
}
// codecForRequest selects a codec from the media type(s) in header name of r.
//
// The header may hold a single media type with parameters
// ("application/json; charset=utf-8") or, as Accept usually does, a
// comma-separated list ("application/x-protobuf, application/json;q=0.9").
// Entries are tried in order and the first one with a registered codec wins;
// q-values are not weighed. Matching ignores surrounding whitespace and case,
// because media types are case-insensitive.
//
// It returns the codec, the matched bare media type (usable as a response
// Content-Type) and true. When nothing matches, including when the header is
// missing, it returns the JSON codec, "application/json" and false.
func codecForRequest(r *http.Request, name string) (encoding.Codec, string, bool) {
	for _, entry := range strings.Split(r.Header.Get(name), ",") {
		mediaType, _, _ := strings.Cut(entry, ";")
		mediaType = strings.ToLower(strings.TrimSpace(mediaType))
		if codec := registeredCodecs[mediaType]; codec != nil {
			return codec, mediaType, true
		}
	}
	return jsonCodec, "application/json", false
}
// encodeJSONResponse renders the JSON response envelope
//
//	{"code":<code>,"message":<message>,"ts":<unix seconds>,"data":<data>}
//
// message is JSON-encoded, so quotes, backslashes and control characters in
// (error) messages cannot break the document. data must already be a
// serialized JSON value and is embedded verbatim. An empty data is rendered as
// {} so the output is always valid JSON.
func encodeJSONResponse(code int32, message string, data []byte) ([]byte, error) {
	quotedMessage, err := json.Marshal(message)
	if err != nil {
		return nil, fmt.Errorf("encode response message: %w", err)
	}
	if len(data) == 0 {
		data = []byte("{}")
	}

	var buf bytes.Buffer
	buf.WriteString(`{"code":`)
	buf.WriteString(strconv.FormatInt(int64(code), 10))
	buf.WriteString(`,"message":`)
	buf.Write(quotedMessage)
	buf.WriteString(`,"ts":`)
	buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10))
	buf.WriteString(`,"data":`)
	buf.Write(data)
	buf.WriteByte('}')
	return buf.Bytes(), nil
}
// encodeProtoResponse serializes the protobuf response envelope with fields
// 1 code, 2 message, 3 ts (Unix seconds) and, when data is a valid message,
// 4 data. Its descriptor comes from getProtoBuilder, which falls back to the
// data-less default envelope when data is nil. Failures while setting a field
// are returned instead of silently emitting an envelope without its payload.
func encodeProtoResponse(code int32, message string, data proto.Message) ([]byte, error) {
	md, err := getProtoBuilder(data)
	if err != nil {
		return nil, err
	}

	response := dynamic.NewMessage(md)
	if err := response.TrySetFieldByNumber(1, code); err != nil {
		return nil, fmt.Errorf("set response code: %w", err)
	}
	if err := response.TrySetFieldByNumber(2, message); err != nil {
		return nil, fmt.Errorf("set response message: %w", err)
	}
	// Field 3 is declared int64; an int32 timestamp would overflow in 2038.
	if err := response.TrySetFieldByNumber(3, time.Now().Unix()); err != nil {
		return nil, fmt.Errorf("set response ts: %w", err)
	}
	// A typed-nil message (IsValid false) is treated like "no data".
	if data != nil && data.ProtoReflect().IsValid() {
		// Hand the payload over as an APIv1 message, mirroring getProtoBuilder's
		// use of MessageV1 with the jhump/protoreflect APIs.
		if err := response.TrySetFieldByNumber(4, pproto.MessageV1(data)); err != nil {
			return nil, fmt.Errorf("set response data: %w", err)
		}
	}
	return response.Marshal()
}
// getProtoBuilder returns the descriptor of a "Response" envelope whose field 4
// ("data") has the type of message. Fields 1-3 are code (int32), message
// (string) and ts (int64). A nil message yields defaultErrorMessageDescriptor,
// which has no data field.
//
// Descriptors are built once per message type and cached in messagePool. The
// cache is keyed by the fully-qualified message name, so same-named messages
// from different proto packages do not share a cache entry.
func getProtoBuilder(message proto.Message) (*desc.MessageDescriptor, error) {
	if message == nil {
		return defaultErrorMessageDescriptor, nil
	}

	key := message.ProtoReflect().Descriptor().FullName()
	if cached, ok := messagePool.Load(key); ok {
		return cached.(*desc.MessageDescriptor), nil
	}

	dataDesc, err := desc.LoadMessageDescriptorForMessage(pproto.MessageV1(message))
	if err != nil {
		return nil, fmt.Errorf("loadMessageDescriptorForMessage err: %w", err)
	}
	md, err := builder.NewMessage("Response").
		AddField(builder.NewField("code", builder.FieldTypeInt32()).SetNumber(1)).
		AddField(builder.NewField("message", builder.FieldTypeString()).SetNumber(2)).
		AddField(builder.NewField("ts", builder.FieldTypeInt64()).SetNumber(3)).
		AddField(builder.NewField("data", builder.FieldTypeImportedMessage(dataDesc)).SetNumber(4)).
		Build()
	if err != nil {
		return nil, fmt.Errorf("build new message err: %w", err)
	}

	// Concurrent first requests for a type may both build a descriptor; every
	// caller ends up with whichever one was stored first.
	actual, _ := messagePool.LoadOrStore(key, md)
	return actual.(*desc.MessageDescriptor), nil
}
```