CNPeer/Service refactor

- each CNPeer is given a unique chan *protocol.Event to pass events to
the service.handleEvents() loop. this is now passed to CNPeer.Handler()
as opposed to NewCNPeer().
- service has basically been rewritten. handleEvents() main loop uses
reflect.SelectCase() now to handle all of the eRecv channels for each
peer
- new protocol Event type: EVENT_CLIENT_CONNECT
- Added service_test.go; blackbox-styled testing like the others.
TestService() starts a service and spins up  a bunch of dummy peers
and verifies that each packet sent causes the corresponding packet
handler to be called.
This commit is contained in:
CPunch 2023-11-29 19:57:45 -06:00
parent d0346b2382
commit c0ba365cf5
6 changed files with 310 additions and 92 deletions

View File

@ -18,7 +18,6 @@ const (
// CNPeer is a simple wrapper for net.Conn connections to send/recv packets over the Fusionfall packet protocol. // CNPeer is a simple wrapper for net.Conn connections to send/recv packets over the Fusionfall packet protocol.
type CNPeer struct { type CNPeer struct {
conn net.Conn conn net.Conn
eRecv chan *Event
whichKey int whichKey int
alive *atomic.Bool alive *atomic.Bool
@ -33,10 +32,9 @@ func GetTime() uint64 {
return uint64(time.Now().UnixMilli()) return uint64(time.Now().UnixMilli())
} }
func NewCNPeer(eRecv chan *Event, conn net.Conn) *CNPeer { func NewCNPeer(conn net.Conn) *CNPeer {
p := &CNPeer{ p := &CNPeer{
conn: conn, conn: conn,
eRecv: eRecv,
whichKey: USE_E, whichKey: USE_E,
alive: &atomic.Bool{}, alive: &atomic.Bool{},
@ -96,59 +94,55 @@ func (peer *CNPeer) SetActiveKey(whichKey int) {
} }
func (peer *CNPeer) Kill() { func (peer *CNPeer) Kill() {
log.Printf("Killing peer %p", peer) // de-bounce: only kill if alive
if !peer.alive.CompareAndSwap(true, false) { if !peer.alive.CompareAndSwap(true, false) {
return return
} }
log.Printf("Killing peer %p", peer)
peer.conn.Close() peer.conn.Close()
peer.eRecv <- &Event{Type: EVENT_CLIENT_DISCONNECT, Peer: peer}
} }
// meant to be invoked as a goroutine // meant to be invoked as a goroutine
func (peer *CNPeer) Handler() { func (peer *CNPeer) Handler(eRecv chan<- *Event) error {
defer peer.Kill() defer func() {
eRecv <- &Event{Type: EVENT_CLIENT_DISCONNECT, Peer: peer}
close(eRecv)
peer.Kill()
}()
peer.alive.Store(true) peer.alive.Store(true)
eRecv <- &Event{Type: EVENT_CLIENT_CONNECT, Peer: peer}
for { for {
// read packet size, the goroutine spends most of it's time parked here // read packet size, the goroutine spends most of it's time parked here
var sz uint32 var sz uint32
if err := binary.Read(peer.conn, binary.LittleEndian, &sz); err != nil { if err := binary.Read(peer.conn, binary.LittleEndian, &sz); err != nil {
log.Printf("[FATAL] failed to read packet size! %v\n", err) return err
return
} }
// client should never send a packet size outside of this range // client should never send a packet size outside of this range
if sz > CN_PACKET_BUFFER_SIZE || sz < 4 { if sz > CN_PACKET_BUFFER_SIZE || sz < 4 {
log.Printf("[FATAL] malicious packet size received! %d", sz) return fmt.Errorf("invalid packet size: %d", sz)
return
} }
// grab buffer && read packet body // grab buffer && read packet body
if err := func() error { buf := GetBuffer()
buf := GetBuffer() if _, err := buf.ReadFrom(io.LimitReader(peer.conn, int64(sz))); err != nil {
if _, err := buf.ReadFrom(io.LimitReader(peer.conn, int64(sz))); err != nil { return fmt.Errorf("failed to read packet body: %v", err)
return fmt.Errorf("failed to read packet body! %v", err)
}
// decrypt
DecryptData(buf.Bytes(), peer.E_key)
pkt := NewPacket(buf)
// create packet && read pktID
var pktID uint32
if err := pkt.Decode(&pktID); err != nil {
return fmt.Errorf("failed to read packet type! %v", err)
}
// dispatch packet
log.Printf("Got packet ID: %x, with a sizeof: %d\n", pktID, sz)
peer.eRecv <- &Event{Type: EVENT_CLIENT_PACKET, Peer: peer, Pkt: buf, PktID: pktID}
return nil
}(); err != nil {
log.Printf("[FATAL] %v", err)
return
} }
// decrypt
DecryptData(buf.Bytes(), peer.E_key)
pkt := NewPacket(buf)
// create packet && read pktID
var pktID uint32
if err := pkt.Decode(&pktID); err != nil {
return fmt.Errorf("failed to read packet type! %v", err)
}
// dispatch packet
// log.Printf("Got packet ID: %x, with a sizeof: %d\n", pktID, sz)
eRecv <- &Event{Type: EVENT_CLIENT_PACKET, Peer: peer, Pkt: buf, PktID: pktID}
} }
} }

View File

@ -4,6 +4,7 @@ import "bytes"
const ( const (
EVENT_CLIENT_DISCONNECT = iota EVENT_CLIENT_DISCONNECT = iota
EVENT_CLIENT_CONNECT
EVENT_CLIENT_PACKET EVENT_CLIENT_PACKET
) )

View File

@ -1,9 +1,12 @@
package service package service
import ( import (
"errors"
"fmt" "fmt"
"log" "log"
"net" "net"
"reflect"
"strconv"
"sync" "sync"
"github.com/CPunch/gopenfusion/config" "github.com/CPunch/gopenfusion/config"
@ -20,9 +23,12 @@ type Service struct {
listener net.Listener listener net.Listener
port int port int
Name string Name string
eRecv chan *protocol.Event stop chan struct{} // tell active handleEvents() to stop
stopped chan struct{}
started chan struct{}
packetHandlers map[uint32]PacketHandler packetHandlers map[uint32]PacketHandler
peers *sync.Map peers map[*protocol.CNPeer]interface{}
stateLock sync.Mutex
// OnDisconnect is called when a peer disconnects from the service. // OnDisconnect is called when a peer disconnects from the service.
// uData is the stored value of the key/value pair in the peer map. // uData is the stored value of the key/value pair in the peer map.
@ -35,22 +41,34 @@ type Service struct {
OnConnect func(peer *protocol.CNPeer) (uData interface{}) OnConnect func(peer *protocol.CNPeer) (uData interface{})
} }
func NewService(name string, port int) (*Service, error) { func RandomPort() (int, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
return nil, err return 0, err
}
defer l.Close()
_, port, err := net.SplitHostPort(l.Addr().String())
if err != nil {
return 0, err
}
return strconv.Atoi(port)
}
func NewService(name string, port int) *Service {
srvc := &Service{
port: port,
Name: name,
} }
service := &Service{ srvc.Reset()
listener: listener, return srvc
port: port, }
Name: name,
eRecv: make(chan *protocol.Event),
packetHandlers: make(map[uint32]PacketHandler),
peers: &sync.Map{},
}
return service, nil func (service *Service) Reset() {
service.packetHandlers = make(map[uint32]PacketHandler)
service.peers = make(map[*protocol.CNPeer]interface{})
service.started = make(chan struct{})
} }
// may not be called while the service is running (eg. srvc.Start() has been called) // may not be called while the service is running (eg. srvc.Start() has been called)
@ -58,42 +76,169 @@ func (service *Service) AddPacketHandler(pktID uint32, handler PacketHandler) {
service.packetHandlers[pktID] = handler service.packetHandlers[pktID] = handler
} }
func (service *Service) Start() { func (service *Service) Start() error {
log.Printf("%s service hosted on %s:%d\n", service.Name, config.GetAnnounceIP(), service.port) service.stop = make(chan struct{})
service.stopped = make(chan struct{})
peerConnections := make(chan chan *protocol.Event)
go service.handleEvents(peerConnections)
go service.handleEvents() // open listener socket
var err error
service.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", service.port))
if err != nil {
return err
}
close(service.started) // signal that the service has started
log.Printf("%s service hosted on %s:%d\n", service.Name, config.GetAnnounceIP(), service.port)
for { for {
conn, err := service.listener.Accept() conn, err := service.listener.Accept()
if err != nil { if err != nil {
log.Println("Connection error: ", err) fmt.Println(err)
return // we expect this to happen when the service is stopped
if errors.Is(err, net.ErrClosed) {
return nil
}
return err
} }
peer := protocol.NewCNPeer(service.eRecv, conn) // create a new peer and pass it to the event loop
service.connect(peer) eRecv := make(chan *protocol.Event)
peer := protocol.NewCNPeer(conn)
log.Printf("New peer %p connected to %s\n", peer, service.Name)
peerConnections <- eRecv
go peer.Handler(eRecv)
} }
} }
func (service *Service) handleEvents() { // returns a channel that is closed when the service has started.
for event := range service.eRecv { // this is useful if you need to do something after the service has started.
switch event.Type { func (service *Service) Started() <-chan struct{} {
case protocol.EVENT_CLIENT_DISCONNECT: return service.started
service.disconnect(event.Peer) }
case protocol.EVENT_CLIENT_PACKET:
if err := service.handlePacket(event.Peer, event.PktID, protocol.NewPacket(event.Pkt)); err != nil { // returns a channel that is closed when the service has stopped.
log.Printf("Error handling packet: %v", err) // this is useful if you need to wait until the service has completely stopped.
event.Peer.Kill() func (service *Service) Stopped() <-chan struct{} {
return service.stopped
}
// stops the service and disconnects all peers. OnDisconnect will be called
// for each peer.
func (service *Service) Stop() {
close(service.stop)
service.listener.Close()
}
// returns the stored uData for the peer.
// if the peer does not exist, nil is returned.
// NOTE: the peer map is not locked while accessing, if you're calling this
// outside of the service's event loop, you'll need to lock the peer map yourself.
func (service *Service) GetPeerData(peer *protocol.CNPeer) interface{} {
return service.peers[peer]
}
// sets the stored uData for the peer.
// NOTE: the peer map is not locked while accessing, if you're calling this
// outside of the service's event loop, you'll need to lock the peer map yourself.
func (service *Service) SetPeerData(peer *protocol.CNPeer, uData interface{}) {
service.peers[peer] = uData
}
// calls f for each peer in the service passing the peer and the stored uData.
// if f returns false, the iteration is stopped.
// NOTE: the peer map is not locked while iterating, if you're calling this
// outside of the service's event loop, you'll need to lock the peer map yourself.
func (service *Service) RangePeers(f func(peer *protocol.CNPeer, uData interface{}) bool) {
for peer, uData := range service.peers {
if !f(peer, uData) {
break
}
}
}
// locks the peer map.
func (service *Service) Lock() {
service.stateLock.Lock()
}
// unlocks the peer map.
func (service *Service) Unlock() {
service.stateLock.Unlock()
}
// handleEvents is the main event loop for the service.
// it handles all events from the peers and calls the appropriate handlers.
func (service *Service) handleEvents(eRecv <-chan chan *protocol.Event) {
poll := make([]reflect.SelectCase, 0, 4)
// add the stop channel and the peer connection channel to our poll queue
poll = append(poll, reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(service.stop),
})
poll = append(poll, reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(eRecv),
})
for {
chosen, value, _ := reflect.Select(poll)
if chosen == 0 {
// stop
// OnDisconnect handler might need to do something important
service.Lock()
service.RangePeers(func(peer *protocol.CNPeer, uData interface{}) bool {
peer.Kill()
service.disconnect(peer)
return true
})
service.Unlock()
// signal we have stopped
close(service.stopped)
return
} else if chosen == 1 {
// new peer, add it to our poll queue
poll = append(poll, reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(value.Interface()),
})
} else {
// peer event
event, ok := value.Interface().(*protocol.Event)
if !ok {
panic("invalid event type")
} }
// the packet buffer is given to us by the event, so we'll need to make sure to return it to the pool service.Lock()
protocol.PutBuffer(event.Pkt) switch event.Type {
case protocol.EVENT_CLIENT_DISCONNECT:
// strip the peer from our poll queue
poll = append(poll[:chosen], poll[chosen+1:]...)
service.disconnect(value.Interface().(*protocol.Event).Peer)
case protocol.EVENT_CLIENT_CONNECT:
service.connect(event.Peer)
case protocol.EVENT_CLIENT_PACKET:
if err := service.handlePacket(event.Peer, event.PktID, protocol.NewPacket(event.Pkt)); err != nil {
log.Printf("Error handling packet: %v", err)
event.Peer.Kill()
}
// the packet buffer is given to us by the event, so we'll need to make sure to return it to the pool
protocol.PutBuffer(event.Pkt)
}
service.Unlock()
} }
} }
} }
func (service *Service) handlePacket(peer *protocol.CNPeer, typeID uint32, pkt protocol.Packet) error { func (service *Service) handlePacket(peer *protocol.CNPeer, typeID uint32, pkt protocol.Packet) error {
uData, _ := service.peers.Load(peer) uData := service.peers[peer]
if hndlr, ok := service.packetHandlers[typeID]; ok { if hndlr, ok := service.packetHandlers[typeID]; ok {
// fmt.Printf("Handling packet %x\n", typeID)
if err := hndlr(peer, uData, pkt); err != nil { if err := hndlr(peer, uData, pkt); err != nil {
return err return err
} }
@ -106,12 +251,12 @@ func (service *Service) handlePacket(peer *protocol.CNPeer, typeID uint32, pkt p
func (service *Service) disconnect(peer *protocol.CNPeer) { func (service *Service) disconnect(peer *protocol.CNPeer) {
if service.OnDisconnect != nil { if service.OnDisconnect != nil {
uData, _ := service.peers.Load(peer) uData := service.peers[peer]
service.OnDisconnect(peer, uData) service.OnDisconnect(peer, uData)
} }
log.Printf("Peer %p disconnected from %s\n", peer, service.Name) log.Printf("Peer %p disconnected from %s\n", peer, service.Name)
service.peers.Delete(peer) delete(service.peers, peer)
} }
func (service *Service) connect(peer *protocol.CNPeer) { func (service *Service) connect(peer *protocol.CNPeer) {
@ -123,16 +268,5 @@ func (service *Service) connect(peer *protocol.CNPeer) {
} }
log.Printf("New peer %p connected to %s\n", peer, service.Name) log.Printf("New peer %p connected to %s\n", peer, service.Name)
service.peers.Store(peer, uData) service.SetPeerData(peer, uData)
go peer.Handler()
}
func (service *Service) SetPeerData(peer *protocol.CNPeer, uData interface{}) {
service.peers.Store(peer, uData)
}
func (service *Service) RangePeers(f func(peer *protocol.CNPeer, uData interface{}) bool) {
service.peers.Range(func(key, value any) bool {
return f(key.(*protocol.CNPeer), value)
})
} }

View File

@ -0,0 +1,91 @@
package service_test
import (
"fmt"
"net"
"os"
"sync"
"testing"
"time"
"github.com/CPunch/gopenfusion/internal/protocol"
"github.com/CPunch/gopenfusion/internal/service"
)
var (
srvc *service.Service
srvcPort int
)
const (
timeout = 5
maxDummyPeers = 5
)
func waitWithTimeout(wg *sync.WaitGroup, seconds int) bool {
done := make(chan struct{})
go func() {
defer close(done)
wg.Wait()
}()
select {
case <-done:
return true
case <-time.After(time.Duration(seconds) * time.Second):
return false
}
}
func TestMain(m *testing.M) {
var err error
srvcPort, err = service.RandomPort()
if err != nil {
panic(err)
}
srvc = service.NewService("TEST", srvcPort)
os.Exit(m.Run())
}
func TestService(t *testing.T) {
// waitgroup to wait for test packet handler to be called
wg := sync.WaitGroup{}
srvc.AddPacketHandler(0x1234, func(peer *protocol.CNPeer, uData interface{}, pkt protocol.Packet) error {
wg.Done()
return nil
})
go func() {
if err := srvc.Start(); err != nil {
t.Error(err)
}
}()
// wait for service to start
<-srvc.Started()
wg.Add(maxDummyPeers)
for i := 0; i < maxDummyPeers; i++ {
go func() {
// make dummy client
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", srvcPort))
if err != nil {
t.Error(err)
}
peer := protocol.NewCNPeer(conn)
defer peer.Kill()
// send dummy packet
if err := peer.Send(0x1234); err != nil {
t.Error(err)
}
}()
}
if !waitWithTimeout(&wg, timeout) {
t.Error("timeout waiting for packet handler to be called")
}
srvc.Stop()
<-srvc.Stopped()
}

View File

@ -14,10 +14,7 @@ type LoginServer struct {
} }
func NewLoginServer(dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port int) (*LoginServer, error) { func NewLoginServer(dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port int) (*LoginServer, error) {
srvc, err := service.NewService("LOGIN", port) srvc := service.NewService("LOGIN", port)
if err != nil {
return nil, err
}
server := &LoginServer{ server := &LoginServer{
service: srvc, service: srvc,
@ -47,6 +44,10 @@ func NewLoginServer(dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port
return server, nil return server, nil
} }
func (server *LoginServer) Start() { func (server *LoginServer) Start() error {
server.service.Start() return server.service.Start()
}
func (server *LoginServer) Stop() {
server.service.Stop()
} }

View File

@ -19,10 +19,7 @@ type ShardServer struct {
} }
func NewShardServer(dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port int) (*ShardServer, error) { func NewShardServer(dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port int) (*ShardServer, error) {
srvc, err := service.NewService("SHARD", port) srvc := service.NewService("SHARD", port)
if err != nil {
return nil, err
}
server := &ShardServer{ server := &ShardServer{
service: srvc, service: srvc,