more protocol/service refactor

- removed protocol.Event: CNPeers now send protocol.PacketEvents
- peer uData is held in CNPeer, use SetUserData() and UserData() to
set/read it
- Service.PacketHandler calback has changed, removed uData:
switched calls to peer.SetUserData() and peer.UserData() where appropriate
- service.Service lots of tidying up, removed dependence on old
protocol.Event.
- service.Service && protocol.CNPeer now accept a cancelable context.
hooray graceful shutdowns and unit tests!
- general cleanup
This commit is contained in:
2023-12-01 00:56:34 -06:00
parent c0ba365cf5
commit f4b17906ce
12 changed files with 292 additions and 261 deletions

View File

@@ -1,10 +1,11 @@
package protocol
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
"log"
"net"
"sync/atomic"
"time"
@@ -15,9 +16,17 @@ const (
USE_FE
)
type PacketEvent struct {
Type int
Pkt *bytes.Buffer
PktID uint32
}
// CNPeer is a simple wrapper for net.Conn connections to send/recv packets over the Fusionfall packet protocol.
type CNPeer struct {
uData interface{}
conn net.Conn
ctx context.Context
whichKey int
alive *atomic.Bool
@@ -32,9 +41,10 @@ func GetTime() uint64 {
return uint64(time.Now().UnixMilli())
}
func NewCNPeer(conn net.Conn) *CNPeer {
func NewCNPeer(ctx context.Context, conn net.Conn) *CNPeer {
p := &CNPeer{
conn: conn,
ctx: ctx,
whichKey: USE_E,
alive: &atomic.Bool{},
@@ -45,6 +55,14 @@ func NewCNPeer(conn net.Conn) *CNPeer {
return p
}
func (peer *CNPeer) SetUserData(uData interface{}) {
peer.uData = uData
}
func (peer *CNPeer) UserData() interface{} {
return peer.uData
}
func (peer *CNPeer) Send(typeID uint32, data ...interface{}) error {
// grab buffer from pool
buf := GetBuffer()
@@ -82,7 +100,7 @@ func (peer *CNPeer) Send(typeID uint32, data ...interface{}) error {
EncryptData(buf.Bytes()[4:], key)
// send full packet
log.Printf("Sending %#v, sizeof: %d, buffer: %v", data, buf.Len(), buf.Bytes())
// log.Printf("Sending %#v, sizeof: %d, buffer: %v", data, buf.Len(), buf.Bytes())
if _, err := peer.conn.Write(buf.Bytes()); err != nil {
return fmt.Errorf("failed to write packet body! %v", err)
}
@@ -99,50 +117,52 @@ func (peer *CNPeer) Kill() {
return
}
log.Printf("Killing peer %p", peer)
peer.conn.Close()
}
// meant to be invoked as a goroutine
func (peer *CNPeer) Handler(eRecv chan<- *Event) error {
func (peer *CNPeer) Handler(eRecv chan<- *PacketEvent) error {
defer func() {
eRecv <- &Event{Type: EVENT_CLIENT_DISCONNECT, Peer: peer}
close(eRecv)
peer.Kill()
}()
peer.alive.Store(true)
eRecv <- &Event{Type: EVENT_CLIENT_CONNECT, Peer: peer}
for {
// read packet size, the goroutine spends most of it's time parked here
var sz uint32
if err := binary.Read(peer.conn, binary.LittleEndian, &sz); err != nil {
return err
select {
case <-peer.ctx.Done():
return nil
default:
// read packet size, the goroutine spends most of it's time parked here
var sz uint32
if err := binary.Read(peer.conn, binary.LittleEndian, &sz); err != nil {
return err
}
// client should never send a packet size outside of this range
if sz > CN_PACKET_BUFFER_SIZE || sz < 4 {
return fmt.Errorf("invalid packet size: %d", sz)
}
// grab buffer && read packet body
buf := GetBuffer()
if _, err := buf.ReadFrom(io.LimitReader(peer.conn, int64(sz))); err != nil {
return fmt.Errorf("failed to read packet body: %v", err)
}
// decrypt
DecryptData(buf.Bytes(), peer.E_key)
pkt := NewPacket(buf)
// create packet && read pktID
var pktID uint32
if err := pkt.Decode(&pktID); err != nil {
return fmt.Errorf("failed to read packet type! %v", err)
}
// dispatch packet
// log.Printf("Got packet ID: %x, with a sizeof: %d\n", pktID, sz)
eRecv <- &PacketEvent{Pkt: buf, PktID: pktID}
}
// client should never send a packet size outside of this range
if sz > CN_PACKET_BUFFER_SIZE || sz < 4 {
return fmt.Errorf("invalid packet size: %d", sz)
}
// grab buffer && read packet body
buf := GetBuffer()
if _, err := buf.ReadFrom(io.LimitReader(peer.conn, int64(sz))); err != nil {
return fmt.Errorf("failed to read packet body: %v", err)
}
// decrypt
DecryptData(buf.Bytes(), peer.E_key)
pkt := NewPacket(buf)
// create packet && read pktID
var pktID uint32
if err := pkt.Decode(&pktID); err != nil {
return fmt.Errorf("failed to read packet type! %v", err)
}
// dispatch packet
// log.Printf("Got packet ID: %x, with a sizeof: %d\n", pktID, sz)
eRecv <- &Event{Type: EVENT_CLIENT_PACKET, Peer: peer, Pkt: buf, PktID: pktID}
}
}

View File

@@ -1,16 +0,0 @@
package protocol
import "bytes"
const (
EVENT_CLIENT_DISCONNECT = iota
EVENT_CLIENT_CONNECT
EVENT_CLIENT_PACKET
)
type Event struct {
Type int
Peer *CNPeer
Pkt *bytes.Buffer
PktID uint32
}

View File

@@ -1,6 +1,7 @@
package service
import (
"context"
"errors"
"fmt"
"log"
@@ -13,9 +14,9 @@ import (
"github.com/CPunch/gopenfusion/internal/protocol"
)
type PacketHandler func(peer *protocol.CNPeer, uData interface{}, pkt protocol.Packet) error
type PacketHandler func(peer *protocol.CNPeer, pkt protocol.Packet) error
func StubbedPacket(_ *protocol.CNPeer, _ interface{}, _ protocol.Packet) error {
func StubbedPacket(_ *protocol.CNPeer, _ protocol.Packet) error {
return nil
}
@@ -23,22 +24,22 @@ type Service struct {
listener net.Listener
port int
Name string
stop chan struct{} // tell active handleEvents() to stop
stopped chan struct{}
ctx context.Context
started chan struct{}
stopped chan struct{}
packetHandlers map[uint32]PacketHandler
peers map[*protocol.CNPeer]interface{}
peers map[chan *protocol.PacketEvent]*protocol.CNPeer
stateLock sync.Mutex
// OnDisconnect is called when a peer disconnects from the service.
// uData is the stored value of the key/value pair in the peer map.
// It may not be set while the service is running. (eg. srvc.Start() has been called)
OnDisconnect func(peer *protocol.CNPeer, uData interface{})
OnDisconnect func(peer *protocol.CNPeer)
// OnConnect is called when a peer connects to the service.
// return value is used as the value in the peer map.
// It may not be set while the service is running. (eg. srvc.Start() has been called)
OnConnect func(peer *protocol.CNPeer) (uData interface{})
OnConnect func(peer *protocol.CNPeer)
}
func RandomPort() (int, error) {
@@ -55,44 +56,52 @@ func RandomPort() (int, error) {
return strconv.Atoi(port)
}
func NewService(name string, port int) *Service {
func NewService(ctx context.Context, name string, port int) *Service {
srvc := &Service{
port: port,
Name: name,
}
srvc.Reset()
srvc.Reset(ctx)
return srvc
}
func (service *Service) Reset() {
service.packetHandlers = make(map[uint32]PacketHandler)
service.peers = make(map[*protocol.CNPeer]interface{})
service.started = make(chan struct{})
func (srvc *Service) Reset(ctx context.Context) {
srvc.ctx = ctx
srvc.packetHandlers = make(map[uint32]PacketHandler)
srvc.peers = make(map[chan *protocol.PacketEvent]*protocol.CNPeer)
srvc.started = make(chan struct{})
srvc.stopped = make(chan struct{})
}
// may not be called while the service is running (eg. srvc.Start() has been called)
func (service *Service) AddPacketHandler(pktID uint32, handler PacketHandler) {
service.packetHandlers[pktID] = handler
func (srvc *Service) AddPacketHandler(pktID uint32, handler PacketHandler) {
srvc.packetHandlers[pktID] = handler
}
func (service *Service) Start() error {
service.stop = make(chan struct{})
service.stopped = make(chan struct{})
peerConnections := make(chan chan *protocol.Event)
go service.handleEvents(peerConnections)
type newPeerConnection struct {
peer *protocol.CNPeer
channel chan *protocol.PacketEvent
}
func (srvc *Service) Start() error {
peerConnections := make(chan newPeerConnection)
defer close(peerConnections)
go srvc.handleEvents(peerConnections)
// open listener socket
var err error
service.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", service.port))
srvc.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", srvc.port))
if err != nil {
return err
}
defer srvc.listener.Close()
close(service.started) // signal that the service has started
log.Printf("%s service hosted on %s:%d\n", service.Name, config.GetAnnounceIP(), service.port)
log.Printf("%s service hosted on %s:%d\n", srvc.Name, config.GetAnnounceIP(), srvc.port)
close(srvc.started) // signal that the service has started
for {
conn, err := service.listener.Accept()
conn, err := srvc.listener.Accept()
if err != nil {
fmt.Println(err)
// we expect this to happen when the service is stopped
@@ -103,143 +112,148 @@ func (service *Service) Start() error {
}
// create a new peer and pass it to the event loop
eRecv := make(chan *protocol.Event)
peer := protocol.NewCNPeer(conn)
log.Printf("New peer %p connected to %s\n", peer, service.Name)
peerConnections <- eRecv
peer := protocol.NewCNPeer(srvc.ctx, conn)
eRecv := make(chan *protocol.PacketEvent)
peerConnections <- newPeerConnection{channel: eRecv, peer: peer}
go peer.Handler(eRecv)
}
}
func (srvc *Service) getPeer(channel chan *protocol.PacketEvent) *protocol.CNPeer {
return srvc.peers[channel]
}
func (srvc *Service) setPeer(channel chan *protocol.PacketEvent, peer *protocol.CNPeer) {
srvc.peers[channel] = peer
}
func (srvc *Service) removePeer(channel chan *protocol.PacketEvent) {
delete(srvc.peers, channel)
}
// returns a channel that is closed when the service has started.
// this is useful if you need to do something after the service has started.
func (service *Service) Started() <-chan struct{} {
return service.started
// this is useful if you need to wait until after the service has started.
func (srvc *Service) Started() <-chan struct{} {
return srvc.started
}
// returns a channel that is closed when the service has stopped.
// this is useful if you need to wait until the service has completely stopped.
func (service *Service) Stopped() <-chan struct{} {
return service.stopped
}
// stops the service and disconnects all peers. OnDisconnect will be called
// for each peer.
func (service *Service) Stop() {
close(service.stop)
service.listener.Close()
}
// returns the stored uData for the peer.
// if the peer does not exist, nil is returned.
// NOTE: the peer map is not locked while accessing, if you're calling this
// outside of the service's event loop, you'll need to lock the peer map yourself.
func (service *Service) GetPeerData(peer *protocol.CNPeer) interface{} {
return service.peers[peer]
}
// sets the stored uData for the peer.
// NOTE: the peer map is not locked while accessing, if you're calling this
// outside of the service's event loop, you'll need to lock the peer map yourself.
func (service *Service) SetPeerData(peer *protocol.CNPeer, uData interface{}) {
service.peers[peer] = uData
// this is useful if you need wait until after the service has stopped.
func (srvc *Service) Stopped() <-chan struct{} {
return srvc.stopped
}
// calls f for each peer in the service passing the peer and the stored uData.
// if f returns false, the iteration is stopped.
// NOTE: the peer map is not locked while iterating, if you're calling this
// outside of the service's event loop, you'll need to lock the peer map yourself.
func (service *Service) RangePeers(f func(peer *protocol.CNPeer, uData interface{}) bool) {
for peer, uData := range service.peers {
if !f(peer, uData) {
func (srvc *Service) RangePeers(f func(peer *protocol.CNPeer) bool) {
for _, peer := range srvc.peers {
if !f(peer) {
break
}
}
}
// locks the peer map.
func (service *Service) Lock() {
service.stateLock.Lock()
func (srvc *Service) Lock() {
srvc.stateLock.Lock()
}
// unlocks the peer map.
func (service *Service) Unlock() {
service.stateLock.Unlock()
func (srvc *Service) Unlock() {
srvc.stateLock.Unlock()
}
func (srvc *Service) stop() {
// OnDisconnect handler might need to do something important
srvc.RangePeers(func(peer *protocol.CNPeer) bool {
peer.Kill()
if srvc.OnDisconnect != nil {
srvc.OnDisconnect(peer)
}
return true
})
log.Printf("%s service stopped\n", srvc.Name)
close(srvc.stopped)
}
// handleEvents is the main event loop for the service.
// it handles all events from the peers and calls the appropriate handlers.
func (service *Service) handleEvents(eRecv <-chan chan *protocol.Event) {
func (srvc *Service) handleEvents(peerPipe <-chan newPeerConnection) {
defer srvc.stop()
poll := make([]reflect.SelectCase, 0, 4)
// add the stop channel and the peer connection channel to our poll queue
poll = append(poll, reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(service.stop),
Chan: reflect.ValueOf(srvc.ctx.Done()),
})
poll = append(poll, reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(eRecv),
Chan: reflect.ValueOf(peerPipe),
})
addPoll := func(channel chan *protocol.PacketEvent) {
poll = append(poll, reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(channel),
})
}
removePoll := func(index int) {
poll = append(poll[:index], poll[index+1:]...)
}
for {
chosen, value, _ := reflect.Select(poll)
if chosen == 0 {
// stop
// OnDisconnect handler might need to do something important
service.Lock()
service.RangePeers(func(peer *protocol.CNPeer, uData interface{}) bool {
peer.Kill()
service.disconnect(peer)
return true
})
service.Unlock()
// signal we have stopped
close(service.stopped)
chosen, value, recvOK := reflect.Select(poll)
switch chosen {
case 0: // cancel signal received, stop the service
return
} else if chosen == 1 {
// new peer, add it to our poll queue
poll = append(poll, reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(value.Interface()),
})
} else {
// peer event
event, ok := value.Interface().(*protocol.Event)
if !ok {
panic("invalid event type")
case 1: // new peer, add it to our poll queue
if !recvOK {
return
}
service.Lock()
switch event.Type {
case protocol.EVENT_CLIENT_DISCONNECT:
// strip the peer from our poll queue
poll = append(poll[:chosen], poll[chosen+1:]...)
service.disconnect(value.Interface().(*protocol.Event).Peer)
case protocol.EVENT_CLIENT_CONNECT:
service.connect(event.Peer)
case protocol.EVENT_CLIENT_PACKET:
if err := service.handlePacket(event.Peer, event.PktID, protocol.NewPacket(event.Pkt)); err != nil {
log.Printf("Error handling packet: %v", err)
event.Peer.Kill()
}
// the packet buffer is given to us by the event, so we'll need to make sure to return it to the pool
protocol.PutBuffer(event.Pkt)
evnt := value.Interface().(newPeerConnection)
addPoll(evnt.channel)
srvc.connect(evnt.channel, evnt.peer)
default: // peer event
channel := poll[chosen].Chan.Interface().(chan *protocol.PacketEvent)
peer := srvc.getPeer(channel)
if peer == nil {
log.Printf("Unknown peer event: %v", value)
removePoll(chosen)
continue
}
service.Unlock()
evnt, ok := value.Interface().(*protocol.PacketEvent)
if !recvOK || !ok || evnt == nil {
// peer disconnected, remove it from our poll queue
removePoll(chosen)
srvc.disconnect(channel, peer)
continue
}
srvc.Lock()
if err := srvc.handlePacket(peer, evnt.PktID, protocol.NewPacket(evnt.Pkt)); err != nil {
log.Printf("Error handling packet: %v", err)
peer.Kill()
}
srvc.Unlock()
// the packet buffer is given to us by the event, so we'll need to make sure to return it to the pool
protocol.PutBuffer(evnt.Pkt)
}
}
}
func (service *Service) handlePacket(peer *protocol.CNPeer, typeID uint32, pkt protocol.Packet) error {
uData := service.peers[peer]
if hndlr, ok := service.packetHandlers[typeID]; ok {
func (srvc *Service) handlePacket(peer *protocol.CNPeer, typeID uint32, pkt protocol.Packet) error {
if hndlr, ok := srvc.packetHandlers[typeID]; ok {
// fmt.Printf("Handling packet %x\n", typeID)
if err := hndlr(peer, uData, pkt); err != nil {
if err := hndlr(peer, pkt); err != nil {
return err
}
} else {
@@ -249,24 +263,20 @@ func (service *Service) handlePacket(peer *protocol.CNPeer, typeID uint32, pkt p
return nil
}
func (service *Service) disconnect(peer *protocol.CNPeer) {
if service.OnDisconnect != nil {
uData := service.peers[peer]
service.OnDisconnect(peer, uData)
func (srvc *Service) disconnect(channel chan *protocol.PacketEvent, peer *protocol.CNPeer) {
log.Printf("Peer %p disconnected from %s\n", peer, srvc.Name)
if srvc.OnDisconnect != nil {
srvc.OnDisconnect(peer)
}
log.Printf("Peer %p disconnected from %s\n", peer, service.Name)
delete(service.peers, peer)
srvc.removePeer(channel)
}
func (service *Service) connect(peer *protocol.CNPeer) {
// default uData to nil, but if the service has an OnConnect
// handler, use the result from that
uData := interface{}(nil)
if service.OnConnect != nil {
uData = service.OnConnect(peer)
func (srvc *Service) connect(channel chan *protocol.PacketEvent, peer *protocol.CNPeer) {
log.Printf("New peer %p connected to %s\n", peer, srvc.Name)
if srvc.OnConnect != nil {
srvc.OnConnect(peer)
}
log.Printf("New peer %p connected to %s\n", peer, service.Name)
service.SetPeerData(peer, uData)
srvc.setPeer(channel, peer)
}

View File

@@ -1,7 +1,9 @@
package service_test
import (
"context"
"fmt"
"log"
"net"
"os"
"sync"
@@ -13,12 +15,11 @@ import (
)
var (
srvc *service.Service
srvcPort int
)
const (
timeout = 5
timeout = 2
maxDummyPeers = 5
)
@@ -44,15 +45,18 @@ func TestMain(m *testing.M) {
panic(err)
}
srvc = service.NewService("TEST", srvcPort)
os.Exit(m.Run())
}
func TestService(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
srvc := service.NewService(ctx, "TEST", srvcPort)
// waitgroup to wait for test packet handler to be called
wg := sync.WaitGroup{}
srvc.AddPacketHandler(0x1234, func(peer *protocol.CNPeer, uData interface{}, pkt protocol.Packet) error {
srvc.AddPacketHandler(0x1234, func(peer *protocol.CNPeer, pkt protocol.Packet) error {
log.Printf("Received packet %#v", pkt)
wg.Done()
return nil
})
@@ -65,7 +69,7 @@ func TestService(t *testing.T) {
// wait for service to start
<-srvc.Started()
wg.Add(maxDummyPeers)
wg.Add(maxDummyPeers * 3) // 2 wg.Done() calls per dummy peer
for i := 0; i < maxDummyPeers; i++ {
go func() {
// make dummy client
@@ -74,18 +78,28 @@ func TestService(t *testing.T) {
t.Error(err)
}
peer := protocol.NewCNPeer(conn)
defer peer.Kill()
// send dummy packet
if err := peer.Send(0x1234); err != nil {
t.Error(err)
}
peer := protocol.NewCNPeer(ctx, conn)
go func() {
defer peer.Kill()
// send dummy packets
for i := 0; i < 2; i++ {
if err := peer.Send(0x1234); err != nil {
t.Error(err)
}
}
}()
// we wait until Handler gracefully exits (peer was killed)
peer.Handler(make(chan *protocol.PacketEvent))
wg.Done()
}()
}
if !waitWithTimeout(&wg, timeout) {
t.Error("timeout waiting for packet handler to be called")
t.Error("failed to wait for packet handler to be called")
}
srvc.Stop()
cancel()
<-srvc.Stopped()
}