前面分析了thrift协议的总体的大致流程,这里分析一下thrift的序列化。说到序列化,对于http有json,对于grpc使用的protobuf,那么对于thrift来说,可选的序列化方式有binary、compact、json、simplejson等几种。
因为这里compact是默认的,这里也是主要分析的compact方式。同时对于这几种序列化方式来说,这里都是实现了TProtocol这个interface,也就是如下方法
type TProtocol interface {WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqid int32) errorWriteMessageEnd(ctx context.Context) errorWriteStructBegin(ctx context.Context, name string) errorWriteStructEnd(ctx context.Context) errorWriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) errorWriteFieldEnd(ctx context.Context) errorWriteFieldStop(ctx context.Context) errorWriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) errorWriteMapEnd(ctx context.Context) errorWriteListBegin(ctx context.Context, elemType TType, size int) errorWriteListEnd(ctx context.Context) errorWriteSetBegin(ctx context.Context, elemType TType, size int) errorWriteSetEnd(ctx context.Context) errorWriteBool(ctx context.Context, value bool) errorWriteByte(ctx context.Context, value int8) errorWriteI16(ctx context.Context, value int16) errorWriteI32(ctx context.Context, value int32) errorWriteI64(ctx context.Context, value int64) errorWriteDouble(ctx context.Context, value float64) errorWriteString(ctx context.Context, value string) errorWriteBinary(ctx context.Context, value []byte) errorWriteUUID(ctx context.Context, value Tuuid) errorReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqid int32, err error)ReadMessageEnd(ctx context.Context) errorReadStructBegin(ctx context.Context) (name string, err error)ReadStructEnd(ctx context.Context) errorReadFieldBegin(ctx context.Context) (name string, typeId TType, id int16, err error)ReadFieldEnd(ctx context.Context) errorReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, err error)ReadMapEnd(ctx context.Context) errorReadListBegin(ctx context.Context) (elemType TType, size int, err error)ReadListEnd(ctx context.Context) errorReadSetBegin(ctx context.Context) (elemType TType, size int, err error)ReadSetEnd(ctx context.Context) errorReadBool(ctx context.Context) (value bool, err error)ReadByte(ctx context.Context) (value int8, err 
error)ReadI16(ctx context.Context) (value int16, err error)ReadI32(ctx context.Context) (value int32, err error)ReadI64(ctx context.Context) (value int64, err error)ReadDouble(ctx context.Context) (value float64, err error)ReadString(ctx context.Context) (value string, err error)ReadBinary(ctx context.Context) (value []byte, err error)ReadUUID(ctx context.Context) (value Tuuid, err error)Skip(ctx context.Context, fieldType TType) (err error)Flush(ctx context.Context) (err error)Transport() TTransport
}
可以看出来这里的方法是比较多的,不过不要紧,这里核心的方法就几个。接下来结合具体的示例进行分析。看过前面都知道在tutorial.thrift中有定义如下方法。
service Calculator extends shared.SharedService {

  /**
   * Methods are declared much like C functions: a return type, an argument
   * list, and optionally a list of exceptions that may be thrown. Argument
   * and exception lists use the same syntax as the field lists of struct
   * or exception definitions.
   */
  void ping(),

  i32 add(1:i32 num1, 2:i32 num2),

  i32 calculate(1:i32 logid, 2:Work w) throws (1:InvalidOperation ouch),

  /**
   * The oneway modifier means the client only sends the request and does
   * not wait for any response. Oneway methods must return void.
   */
  oneway void zip()

}
然后看一下调用add的方法的client请求到后端响应的整个流程,中间使用compact方法进行序列化。
这里的client是CalculatorClient。需要注意这个是thrift协议自己生成的,我们需要做的就是传入参数,然后看一下对应的实现
// Ahh, now onto the cool part, defining a service. Services just need a name
// and can optionally inherit from another service using the extends keyword.
//
// CalculatorClient embeds *shared.SharedServiceClient, so the methods of the
// embedded client are promoted onto CalculatorClient.
type CalculatorClient struct {*shared.SharedServiceClient
}
// Call the generated client's Add method with the arguments 1 and 1; the
// error result is discarded in this example.
sum, _ := client.Add(context.TODO(), 1, 1)
这里的client可以看一下前面的文章,这里简单说一下,因为thrift是建立在了tcp上,因此client是在conn上面的封装。
然后看一下ADD的实现,也就是CalculatorClient的方法。
// Add sends num1 and num2 to the remote "add" method and returns its result.
//
// The arguments are packed into a CalculatorAddArgs value, the call is
// delegated to the underlying client's Call method, and the response metadata
// is recorded through SetLastResponseMeta_. When Call reports an error, that
// error is returned together with the zero value of _r.
func (p *CalculatorClient) Add(ctx context.Context, num1 int32, num2 int32) (_r int32, _err error) {
	// Copy the parameters into the generated argument struct.
	var _args3 CalculatorAddArgs
	_args3.Num1 = num1
	_args3.Num2 = num2

	var _result5 CalculatorAddResult
	var _meta4 thrift.ResponseMeta

	// Invoke the client's Call method, passing the remote method name.
	_meta4, _err = p.Client_().Call(ctx, "add", &_args3, &_result5)
	p.SetLastResponseMeta_(_meta4)
	if _err != nil {
		return
	}
	return _result5.GetSuccess(), nil
}
// Call performs one request/response exchange for method.
//
// It increments the client's sequence id, sends args via Send, and, unless
// result is nil (a oneway method), reads the reply into result via Recv.
// When the input protocol is a *THeaderProtocol, the headers read by its
// transport are returned in the ResponseMeta.
func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) (ResponseMeta, error) {
	// Every call uses the next sequence id.
	p.seqId++
	seqId := p.seqId

	if err := p.Send(ctx, p.oprot, seqId, method, args); err != nil {
		return ResponseMeta{}, err
	}

	// method is oneway
	if result == nil {
		return ResponseMeta{}, nil
	}

	err := p.Recv(ctx, p.iprot, seqId, method, result)
	var headers THeaderMap
	if hp, ok := p.iprot.(*THeaderProtocol); ok {
		headers = hp.transport.readHeaders
	}
	return ResponseMeta{
		Headers: headers,
	}, err
}
然后这里主要就是Send和Recv方法。然后分别看一下对应的实现
// Send writes one CALL message for method to oprot.
//
// When oprot is a *THeaderProtocol, its write headers are first cleared and
// then repopulated from ctx. The message is then written as
// WriteMessageBegin, args.Write, WriteMessageEnd, followed by Flush. The
// first error encountered is returned.
func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32, method string, args TStruct) error {
	// Set headers from context object on THeaderProtocol
	if headerProt, ok := oprot.(*THeaderProtocol); ok {
		headerProt.ClearWriteHeaders()
		for _, key := range GetWriteHeaderList(ctx) {
			if value, ok := GetHeader(ctx, key); ok {
				headerProt.SetWriteHeader(key, value)
			}
		}
	}

	if err := oprot.WriteMessageBegin(ctx, method, CALL, seqId); err != nil {
		return err
	}
	if err := args.Write(ctx, oprot); err != nil {
		return err
	}
	if err := oprot.WriteMessageEnd(ctx); err != nil {
		return err
	}
	return oprot.Flush(ctx)
}
这里Send主要就是四步:WriteMessageBegin写入消息头,args.Write写入参数,WriteMessageEnd结束消息,最后调用Flush。
这里的oprot就是compact所实现的方法,然后args.Write就是传入的args的Write方法(这里好像有点废话)。
然后接下来看一下这几个方法
首先看一下具体的实现
// Constants used by the compact protocol message header.
//
// WriteMessageBegin writes COMPACT_PROTOCOL_ID as the first byte. The second
// byte combines the version (masked with COMPACT_VERSION_MASK, the low five
// bits) with the message type shifted left by COMPACT_TYPE_SHIFT_AMOUNT and
// masked with COMPACT_TYPE_MASK (the high three bits). When reading, the type
// is recovered by shifting right and masking with COMPACT_TYPE_BITS.
const (
	COMPACT_PROTOCOL_ID       = 0x082
	COMPACT_VERSION           = 1
	COMPACT_VERSION_MASK      = 0x1f
	COMPACT_TYPE_MASK         = 0x0E0
	COMPACT_TYPE_BITS         = 0x07
	COMPACT_TYPE_SHIFT_AMOUNT = 5
)
const (INVALID_TMESSAGE_TYPE TMessageType = 0CALL TMessageType = 1REPLY TMessageType = 2EXCEPTION TMessageType = 3ONEWAY TMessageType = 4
)//
// Public Writing methods.
//// Write a message header to the wire. Compact Protocol messages contain the
// protocol version so we can migrate forwards in the future if need be.
func (p *TCompactProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqid int32) error {err := p.writeByteDirect(COMPACT_PROTOCOL_ID)if err != nil {return NewTProtocolException(err)}err = p.writeByteDirect((COMPACT_VERSION & COMPACT_VERSION_MASK) | ((byte(typeId) << COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_MASK))if err != nil {return NewTProtocolException(err)}_, err = p.writeVarint32(seqid)if err != nil {return NewTProtocolException(err)}e := p.WriteString(ctx, name)return e}
这个可以理解成写入协议头,标明使用了哪种序列化方法。可以看出首先在conn中写入COMPACT_PROTOCOL_ID,对应的是0x82,说明使用了compact方法。然后写入的一个字节同时包含compact的版本和消息类型typeId,也就是(COMPACT_VERSION & COMPACT_VERSION_MASK) | ((byte(typeId) << COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_MASK)。
首先是COMPACT_VERSION & COMPACT_VERSION_MASK,因为COMPACT_VERSION=1,而COMPACT_VERSION_MASK是0x1f,也就是低五位都是1,所以这里值就是1。
然后是(byte(typeId) << COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_MASK。结合上文,这里的typeId是CALL也就是1,COMPACT_TYPE_SHIFT_AMOUNT是5,也就是向左移动5位得到0x20;COMPACT_TYPE_MASK是0xE0,二进制是11100000,按位与之后值仍然是0x20。
所以这里的结果就是0x21。也就是这个字节的低五位表示compact的版本,高三位表示的是typeId。
然后就是写入seqid,这个是每次调用都会递增,使用int32去表示,然后看一下writeVarint32这个方法。
// Write an i32 as a varint. Results in 1-5 bytes on the wire.
// TODO(pomack): make a permanent buffer like writeVarint64?
func (p *TCompactProtocol) writeVarint32(n int32) (int, error) {i32buf := p.buffer[0:5]idx := 0for {if (n & ^0x7F) == 0 {i32buf[idx] = byte(n)idx++// p.writeByteDirect(byte(n));break// return;} else {i32buf[idx] = byte((n & 0x7F) | 0x80)idx++// p.writeByteDirect(byte(((n & 0x7F) | 0x80)));u := uint32(n)n = int32(u >> 7)}}return p.trans.Write(i32buf[0:idx])
}
这里首先是通过 n & ^0x7F,也就是把n的低7位全部置零,判断剩下的高位是否都等于0,通俗来说就是n在0到127之间(低7位就能放下)。是的话直接把这一个字节写入到buf中,相比固定4个字节的int32,这样就省下了24个bit也就是3个字节。
然后如果不满足的话,就是分段处理:每次取低7位并把最高位置为1表示后面还有数据,然后把n无符号右移7位继续循环。
这里的trans.Write就可以理解为conn.Write。
// Write a string to the wire with a varint size preceding.
func (p *TCompactProtocol) WriteString(ctx context.Context, value string) error {_, e := p.writeVarint32(int32(len(value)))if e != nil {return NewTProtocolException(e)}if len(value) == 0 {return nil}_, e = p.trans.WriteString(value)return e
}
这里的writeVarint32上面已经说了,然后trans.WriteString其实和conn.Write([]byte)是一样的。
这里是写入了方法名称,typeId,seqid。然后接下来就是写入结构体的参数。也就是CalculatorAddArgs的Write方法。
func (p *CalculatorAddArgs) Write(ctx context.Context, oprot thrift.TProtocol) error {if err := oprot.WriteStructBegin(ctx, "add_args"); err != nil {return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)}if p != nil {if err := p.writeField1(ctx, oprot); err != nil {return err}if err := p.writeField2(ctx, oprot); err != nil {return err}}if err := oprot.WriteFieldStop(ctx); err != nil {return thrift.PrependError("write field stop error: ", err)}if err := oprot.WriteStructEnd(ctx); err != nil {return thrift.PrependError("write struct stop error: ", err)}return nil
}
从这个上面可以看出来主要调用的还是compact的这几个方法。
func (p *TCompactProtocol) WriteStructBegin(ctx context.Context, name string) error {p.lastField = append(p.lastField, p.lastFieldId)p.lastFieldId = 0return nil
}
这个就是把当前的lastFieldId压入到lastField中,然后把lastFieldId重置为0。
writeField1 就是写入CalculatorAddArgs中的第一个num1字段,然后写入相应的值。
func (p *CalculatorAddArgs) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) {if err := oprot.WriteFieldBegin(ctx, "num1", thrift.I32, 1); err != nil {return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:num1: ", p), err)}if err := oprot.WriteI32(ctx, int32(p.Num1)); err != nil {return thrift.PrependError(fmt.Sprintf("%T.num1 (1) field write error: ", p), err)}if err := oprot.WriteFieldEnd(ctx); err != nil {return thrift.PrependError(fmt.Sprintf("%T write field end error 1:num1: ", p), err)}return err
}
func (p *TCompactProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error {if typeId == BOOL {// we want to possibly include the value, so we'll wait.p.booleanFieldName, p.booleanFieldId, p.booleanFieldPending = name, id, truereturn nil}_, err := p.writeFieldBeginInternal(ctx, name, typeId, id, 0xFF)return NewTProtocolException(err)
}// The workhorse of writeFieldBegin. It has the option of doing a
// 'type override' of the type header. This is used specifically in the
// boolean field case.
func (p *TCompactProtocol) writeFieldBeginInternal(ctx context.Context, name string, typeId TType, id int16, typeOverride byte) (int, error) {// short lastField = lastField_.pop();// if there's a type override, use that.var typeToWrite byteif typeOverride == 0xFF {typeToWrite = byte(p.getCompactType(typeId))} else {typeToWrite = typeOverride}// check if we can use delta encoding for the field idfieldId := int(id)written := 0if fieldId > p.lastFieldId && fieldId-p.lastFieldId <= 15 {// write them togethererr := p.writeByteDirect(byte((fieldId-p.lastFieldId)<<4) | typeToWrite)if err != nil {return 0, err}} else {// write them separateerr := p.writeByteDirect(typeToWrite)if err != nil {return 0, err}err = p.WriteI16(ctx, id)written = 1 + 2if err != nil {return 0, err}}p.lastFieldId = fieldIdreturn written, nil
}
这里主要是writeFieldBeginInternal这个方法,逻辑也比较简单:如果当前字段id比上一个字段id大,并且差值不超过15(也就是四个比特能够表示),那么就把差值放在高四位,和低四位的typeToWrite组装成一个byte一起发送,不然就是先写入类型,再单独写入字段id。
这个方法上面已经说过,就是把值写入过去,告诉对端。
// WriteFieldEnd writes nothing and always returns nil.
func (p *TCompactProtocol) WriteFieldEnd(ctx context.Context) error { return nil }
func (p *TCompactProtocol) WriteFieldStop(ctx context.Context) error {err := p.writeByteDirect(STOP)return NewTProtocolException(err)
}
func (p *TCompactProtocol) WriteStructEnd(ctx context.Context) error {if len(p.lastField) <= 0 {return NewTProtocolExceptionWithType(INVALID_DATA, errors.New("WriteStructEnd called without matching WriteStructBegin call before"))}p.lastFieldId = p.lastField[len(p.lastField)-1]p.lastField = p.lastField[:len(p.lastField)-1]return nil
}
这个就没有做什么写入操作了,就是把lastField中最近保存的值弹出来,恢复到lastFieldId中去。
好的这里write就写完了,这里总结一下,主要就是写入compact的协议头,然后将结构成员写入,这里是通过field去判断是哪一个字段。
服务端的处理开始,首先就是process的处理。
func (p *SharedServiceProcessor) Process(ctx context.Context, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {name, _, seqId, err2 := iprot.ReadMessageBegin(ctx)if err2 != nil {return false, thrift.WrapTException(err2)}if processor, ok := p.GetProcessorFunction(name); ok {return processor.Process(ctx, seqId, iprot, oprot)}iprot.Skip(ctx, thrift.STRUCT)iprot.ReadMessageEnd(ctx)x4 := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function "+name)oprot.WriteMessageBegin(ctx, name, thrift.EXCEPTION, seqId)x4.Write(ctx, oprot)oprot.WriteMessageEnd(ctx)oprot.Flush(ctx)return false, x4}
这个就是读取请求头的,看一下具体的实现
// ReadMessageBegin reads a message header: the protocol id byte, the combined
// version/type byte, the sequence id as a varint, and the method name.
//
// While ctx has a deadline and is not yet done, I/O timeout errors on the
// first byte are retried. A protocol id or version that does not match the
// expected constant yields a BAD_VERSION protocol exception.
func (p *TCompactProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) {
	var protocolId byte

	// Read the first byte.
	_, deadlineSet := ctx.Deadline()
	for {
		protocolId, err = p.readByteDirect()
		if deadlineSet && isTimeoutError(err) && ctx.Err() == nil {
			// keep retrying I/O timeout errors since we still have
			// time left
			continue
		}
		// For anything else, don't retry
		break
	}
	if err != nil {
		return
	}

	// Check that this is the compact protocol.
	if protocolId != COMPACT_PROTOCOL_ID {
		e := fmt.Errorf("Expected protocol id %02x but got %02x", COMPACT_PROTOCOL_ID, protocolId)
		return "", typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, e)
	}

	// Read the combined version/type byte.
	versionAndType, err := p.readByteDirect()
	if err != nil {
		return
	}

	// Extract the version.
	version := versionAndType & COMPACT_VERSION_MASK
	// Extract the message type.
	typeId = TMessageType((versionAndType >> COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_BITS)
	if version != COMPACT_VERSION {
		e := fmt.Errorf("Expected version %02x but got %02x", COMPACT_VERSION, version)
		err = NewTProtocolExceptionWithType(BAD_VERSION, e)
		return
	}
	// Read the sequence id.
	seqId, e := p.readVarint32()
	if e != nil {
		err = NewTProtocolException(e)
		return
	}
	// Read the method name.
	name, err = p.ReadString(ctx)
	return
}
这里读取到了方法名称之后就是,根据方法名然后执行相应的handler。结合上文这里传的Add也就是
func NewCalculatorProcessor(handler Calculator) *CalculatorProcessor {self10 := &CalculatorProcessor{shared.NewSharedServiceProcessor(handler)}self10.AddToProcessorMap("ping", &calculatorProcessorPing{handler: handler})self10.AddToProcessorMap("add", &calculatorProcessorAdd{handler: handler})self10.AddToProcessorMap("calculate", &calculatorProcessorCalculate{handler: handler})self10.AddToProcessorMap("zip", &calculatorProcessorZip{handler: handler})return self10
}
所以这里看一下calculatorProcessorAdd的process方法
func (p *calculatorProcessorAdd) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {args := CalculatorAddArgs{}var err2 error// 读取参数if err2 = args.Read(ctx, iprot); err2 != nil {iprot.ReadMessageEnd(ctx)x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err2.Error())oprot.WriteMessageBegin(ctx, "add", thrift.EXCEPTION, seqId)x.Write(ctx, oprot)oprot.WriteMessageEnd(ctx)oprot.Flush(ctx)return false, thrift.WrapTException(err2)}iprot.ReadMessageEnd(ctx)tickerCancel := func() {}// Start a goroutine to do server side connectivity check.if thrift.ServerConnectivityCheckInterval > 0 {var cancel context.CancelFuncctx, cancel = context.WithCancel(ctx)defer cancel()var tickerCtx context.ContexttickerCtx, tickerCancel = context.WithCancel(context.Background())defer tickerCancel()go func(ctx context.Context, cancel context.CancelFunc) {ticker := time.NewTicker(thrift.ServerConnectivityCheckInterval)defer ticker.Stop()for {select {case <-ctx.Done():returncase <-ticker.C:if !iprot.Transport().IsOpen() {cancel()return}}}}(tickerCtx, cancel)}result := CalculatorAddResult{}var retval int32if retval, err2 = p.handler.Add(ctx, args.Num1, args.Num2); err2 != nil {tickerCancel()if err2 == thrift.ErrAbandonRequest {return false, thrift.WrapTException(err2)}x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing add: "+err2.Error())oprot.WriteMessageBegin(ctx, "add", thrift.EXCEPTION, seqId)x.Write(ctx, oprot)oprot.WriteMessageEnd(ctx)oprot.Flush(ctx)return true, thrift.WrapTException(err2)} else {result.Success = &retval}tickerCancel()if err2 = oprot.WriteMessageBegin(ctx, "add", thrift.REPLY, seqId); err2 != nil {err = thrift.WrapTException(err2)}if err2 = result.Write(ctx, oprot); err == nil && err2 != nil {err = thrift.WrapTException(err2)}if err2 = oprot.WriteMessageEnd(ctx); err == nil && err2 != nil {err = thrift.WrapTException(err2)}if err2 = oprot.Flush(ctx); err == nil && err2 
!= nil {err = thrift.WrapTException(err2)}if err != nil {return}return true, err
}
func (p *CalculatorAddArgs) Read(ctx context.Context, iprot thrift.TProtocol) error {if _, err := iprot.ReadStructBegin(ctx); err != nil {return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err)}for {_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx)if err != nil {return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)}if fieldTypeId == thrift.STOP {break}switch fieldId {case 1:if fieldTypeId == thrift.I32 {if err := p.ReadField1(ctx, iprot); err != nil {return err}} else {if err := iprot.Skip(ctx, fieldTypeId); err != nil {return err}}case 2:if fieldTypeId == thrift.I32 {if err := p.ReadField2(ctx, iprot); err != nil {return err}} else {if err := iprot.Skip(ctx, fieldTypeId); err != nil {return err}}default:if err := iprot.Skip(ctx, fieldTypeId); err != nil {return err}}if err := iprot.ReadFieldEnd(ctx); err != nil {return err}}if err := iprot.ReadStructEnd(ctx); err != nil {return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err)}return nil
}
这里和前面也一样,主要是ReadStructBegin,ReadFieldBegin,ReadField1,ReadField2,ReadFieldEnd,ReadStructEnd
// Read a struct begin. There's nothing on the wire for this, but it is our
// opportunity to push a new struct begin marker onto the field stack.
func (p *TCompactProtocol) ReadStructBegin(ctx context.Context) (name string, err error) {p.lastField = append(p.lastField, p.lastFieldId)p.lastFieldId = 0return
}
// ReadFieldBegin reads a field header and returns the field's type and id.
// The name result is never populated by this method.
//
// The low four bits of the header byte hold the compact type; STOP ends the
// struct. The high four bits hold a field id delta: when non-zero it is added
// to the last field id, otherwise the id is read explicitly with ReadI16.
// For boolean fields the value is taken from the type bits and stored in
// boolValue.
func (p *TCompactProtocol) ReadFieldBegin(ctx context.Context) (name string, typeId TType, id int16, err error) {
	// Read the field header byte.
	t, err := p.readByteDirect()
	if err != nil {
		return
	}

	// if it's a stop, then we can return immediately, as the struct is over.
	if (t & 0x0f) == STOP {
		return "", STOP, 0, nil
	}

	// Work out the field id, i.e. which field this is.
	// mask off the 4 MSB of the type header. it could contain a field id delta.
	modifier := int16((t & 0xf0) >> 4)
	if modifier == 0 {
		// not a delta. look ahead for the zigzag varint field id.
		id, err = p.ReadI16(ctx)
		if err != nil {
			return
		}
	} else {
		// has a delta. add the delta to the last read field id.
		id = int16(p.lastFieldId) + modifier
	}

	// Map the compact type to the field's TType.
	typeId, e := p.getTType(tCompactType(t & 0x0f))
	if e != nil {
		err = NewTProtocolException(e)
		return
	}

	// if this happens to be a boolean field, the value is encoded in the type
	if p.isBoolType(t) {
		// save the boolean value in a special instance variable.
		p.boolValue = (byte(t)&0x0f == COMPACT_BOOLEAN_TRUE)
		p.boolValueIsNotNull = true
	}

	// push the new field onto the field stack so we can keep the deltas going.
	p.lastFieldId = int(id)
	return
}
结合上文也容易看出,这里主要就是读取字段id和typeId。字段id表示是第几个字段,typeId表示字段的类型,当typeId为STOP时说明这个结构体已经读取结束。
然后接下来的ReadField1和ReadField2就是读取字段值。读取结束就是具体的逻辑。这个是在自己的文件中定义的, 这个就不细说了。然后看一下返回的代码
// Excerpt from calculatorProcessorAdd.Process: write the REPLY message.
// All four calls are always made; only the first error is kept in err.
if err2 = oprot.WriteMessageBegin(ctx, "add", thrift.REPLY, seqId); err2 != nil {
	err = thrift.WrapTException(err2)
}
if err2 = result.Write(ctx, oprot); err == nil && err2 != nil {
	err = thrift.WrapTException(err2)
}
if err2 = oprot.WriteMessageEnd(ctx); err == nil && err2 != nil {
	err = thrift.WrapTException(err2)
}
if err2 = oprot.Flush(ctx); err == nil && err2 != nil {
	err = thrift.WrapTException(err2)
}
if err != nil {
	return
}
这里的写入逻辑和上面一样。
好的到这里thrift协议的go版本协议大致就说完了。
下一篇:WPA-hashcat渗透