diff --git a/protocol/packet.go b/protocol/packet.go index 6c764d3..d2b0829 100644 --- a/protocol/packet.go +++ b/protocol/packet.go @@ -16,61 +16,16 @@ import ( */ type Packet struct { - ByteOrder binary.ByteOrder - Buf []byte + readWriter io.ReadWriter } -func NewPacket(buf []byte) *Packet { +func NewPacket(readWriter io.ReadWriter) *Packet { pkt := &Packet{ - ByteOrder: binary.LittleEndian, - Buf: buf, + readWriter: readWriter, } return pkt } -func (pkt *Packet) writeRaw(data []byte) { - pkt.Buf = append(pkt.Buf, data...) -} - -func (pkt *Packet) Write(data []byte) (int, error) { - pkt.writeRaw(data) - - if len(pkt.Buf) > CN_PACKET_BUFFER_SIZE { - return 0, fmt.Errorf("Failed to write to packet, invalid size!") - } - - return len(data), nil -} - -func (pkt *Packet) writeByte(data byte) { - pkt.Write([]byte{data}) -} - -func (pkt *Packet) readRaw(data []byte) (int, error) { - sz := copy(data, pkt.Buf) - pkt.Buf = pkt.Buf[sz:] - - if sz != len(data) { - return sz, io.EOF - } - - return sz, nil -} - -func (pkt *Packet) Read(data []byte) (int, error) { - if len(data) > len(pkt.Buf) { - return 0, fmt.Errorf("Failed to read from packet, invalid size!") - } - - return pkt.readRaw(data) -} - -func (pkt *Packet) readByte() byte { - data := pkt.Buf[0] - pkt.Buf = pkt.Buf[1:] - return data -} - func (pkt *Packet) encodeStructField(field reflect.StructField, value reflect.Value) { log.Printf("Encoding '%s'", field.Name) @@ -89,13 +44,14 @@ func (pkt *Packet) encodeStructField(field reflect.StructField, value reflect.Va buf16 = buf16[:sz] } else { // grow + // TODO: probably a better way to do this? for len(buf16) < sz { buf16 = append(buf16, 0) } } // write - binary.Write(pkt, pkt.ByteOrder, buf16) + binary.Write(pkt.readWriter, binary.LittleEndian, buf16) default: pkt.Encode(value.Addr().Interface()) } @@ -104,7 +60,7 @@ func (pkt *Packet) encodeStructField(field reflect.StructField, value reflect.Va pad, err := strconv.Atoi(field.Tag.Get("pad")) if err == nil { for i := 0; i < pad; i++ { - pkt.writeByte(0) + pkt.readWriter.Write([]byte{0}) } } } @@ -121,7 +77,7 @@ func (pkt *Packet) Encode(data interface{}) { } default: // we pass everything else to go's binary package - binary.Write(pkt, pkt.ByteOrder, data) + binary.Write(pkt.readWriter, binary.LittleEndian, data) } } @@ -136,7 +92,7 @@ func (pkt *Packet) decodeStructField(field reflect.StructField, value reflect.Va } buf16 := make([]uint16, sz) - binary.Read(pkt, pkt.ByteOrder, buf16) + binary.Read(pkt.readWriter, binary.LittleEndian, buf16) // find null terminator var realSize int @@ -151,11 +107,11 @@ func (pkt *Packet) decodeStructField(field reflect.StructField, value reflect.Va pkt.Decode(value.Addr().Interface()) } - // read padding bytes + // consume padding bytes pad, err := strconv.Atoi(field.Tag.Get("pad")) if err == nil { for i := 0; i < pad; i++ { - pkt.readByte() + pkt.readWriter.Read([]byte{0}) } } } @@ -171,6 +127,6 @@ func (pkt *Packet) Decode(data interface{}) { pkt.decodeStructField(rv.Type().Field(i), rv.Field(i)) } default: - binary.Read(pkt, pkt.ByteOrder, data) + binary.Read(pkt.readWriter, binary.LittleEndian, data) } } diff --git a/protocol/pool/pool.go b/protocol/pool/pool.go new file mode 100644 index 0000000..4d5893d --- /dev/null +++ b/protocol/pool/pool.go @@ -0,0 +1,19 @@ +package pool + +import ( + "bytes" + "sync" +) + +var allocator = sync.Pool{ + New: func() any { return new(bytes.Buffer) }, +} + +func Get() *bytes.Buffer { + return allocator.Get().(*bytes.Buffer) +} + +func Put(buf *bytes.Buffer) { + buf.Reset() + allocator.Put(buf) +} diff --git a/server/peer.go b/server/peer.go index 130fb1e..0862269 100644 --- a/server/peer.go +++ b/server/peer.go @@ -3,11 +3,13 @@ package server import ( "encoding/binary" "fmt" + "io" "log" "net" "github.com/CPunch/gopenfusion/db" "github.com/CPunch/gopenfusion/protocol" + "github.com/CPunch/gopenfusion/protocol/pool" ) const ( @@ -48,32 +50,32 @@ func NewPeer(handler PeerHandler, conn net.Conn) *Peer { } func (client *Peer) Send(data interface{}, typeID uint32) { + buf := pool.Get() + defer func() { // always return the buffer to the pool + pool.Put(buf) + }() + // encode - pkt := protocol.NewPacket(make([]byte, 0)) + pkt := protocol.NewPacket(buf) + + // write the typeID and packet body + pkt.Encode(uint32(typeID)) pkt.Encode(data) - log.Printf("Sending %#v, sizeof: %d", data, len(pkt.Buf)) - // write packet size - tmp := make([]byte, 4) - binary.LittleEndian.PutUint32(tmp, uint32(len(pkt.Buf)+4)) - if _, err := client.conn.Write(tmp); err != nil { - panic(fmt.Errorf("[FATAL] failed to write packet size! %v", err)) - } - - // prepend the typeID to the packet body - binary.LittleEndian.PutUint32(tmp, uint32(typeID)) - tmp = append(tmp, pkt.Buf...) + // write the packet size + binary.Write(client.conn, binary.LittleEndian, uint32(buf.Len())) // encrypt typeID & body switch client.whichKey { case USE_E: - protocol.EncryptData(tmp, client.E_key) + protocol.EncryptData(buf.Bytes(), client.E_key) case USE_FE: - protocol.EncryptData(tmp, client.FE_key) + protocol.EncryptData(buf.Bytes(), client.FE_key) } - // write packet body - if _, err := client.conn.Write(tmp); err != nil { + // write packet type && packet body + log.Printf("Sending %#v, sizeof: %d", data, buf.Len()) + if _, err := client.conn.Write(buf.Bytes()); err != nil { panic(fmt.Errorf("[FATAL] failed to write packet body! %v", err)) } } @@ -96,13 +98,12 @@ func (client *Peer) ClientHandler() { client.Kill() }() - tmp := make([]byte, 4, protocol.CN_PACKET_BUFFER_SIZE) for { // read packet size - if _, err := client.conn.Read(tmp); err != nil { + var sz uint32 + if err := binary.Read(client.conn, binary.LittleEndian, &sz); err != nil { panic(fmt.Errorf("[FATAL] failed to read packet size! %v", err)) } - sz := int(binary.LittleEndian.Uint32(tmp)) // client should never send a packet size outside of this range if sz > protocol.CN_PACKET_BUFFER_SIZE || sz < 4 { @@ -110,20 +111,26 @@ func (client *Peer) ClientHandler() { } // read packet body - if _, err := client.conn.Read(tmp[:sz]); err != nil { + buf := pool.Get() + if _, err := buf.ReadFrom(io.LimitReader(client.conn, int64(sz))); err != nil { panic(fmt.Errorf("[FATAL] failed to read packet body! %v", err)) } - // decrypt && grab typeID - protocol.DecryptData(tmp[:sz], client.E_key) - typeID := uint32(binary.LittleEndian.Uint32(tmp[:4])) + fmt.Printf("%#v", buf) + + // decrypt + protocol.DecryptData(buf.Bytes(), client.E_key) + + // create packet && read typeID + var typeID uint32 + pkt := protocol.NewPacket(buf) + pkt.Decode(&typeID) // dispatch packet log.Printf("Got packet ID: %x, with a sizeof: %d\n", typeID, sz) - pkt := protocol.NewPacket(tmp[4:sz]) client.handler.HandlePacket(client, typeID, pkt) - // reset tmp - tmp = tmp[:4] + // restore buffer to pool + pool.Put(buf) } }