diff --git a/internal/protocol/cnpeer.go b/internal/protocol/cnpeer.go
index 8186993..650d7f5 100644
--- a/internal/protocol/cnpeer.go
+++ b/internal/protocol/cnpeer.go
@@ -18,7 +18,6 @@ const (
 
 // CNPeer is a simple wrapper for net.Conn connections to send/recv packets over the Fusionfall packet protocol.
 type CNPeer struct {
 	conn     net.Conn
-	eRecv    chan *Event
 	whichKey int
 	alive    *atomic.Bool
@@ -33,10 +32,9 @@ func GetTime() uint64 {
 	return uint64(time.Now().UnixMilli())
 }
 
-func NewCNPeer(eRecv chan *Event, conn net.Conn) *CNPeer {
+func NewCNPeer(conn net.Conn) *CNPeer {
 	p := &CNPeer{
 		conn:     conn,
-		eRecv:    eRecv,
 		whichKey: USE_E,
 
 		alive: &atomic.Bool{},
@@ -96,59 +94,55 @@ func (peer *CNPeer) SetActiveKey(whichKey int) {
 }
 
 func (peer *CNPeer) Kill() {
-	log.Printf("Killing peer %p", peer)
-
+	// de-bounce: only kill if alive
 	if !peer.alive.CompareAndSwap(true, false) {
 		return
 	}
 
+	log.Printf("Killing peer %p", peer)
 	peer.conn.Close()
-	peer.eRecv <- &Event{Type: EVENT_CLIENT_DISCONNECT, Peer: peer}
 }
 
 // meant to be invoked as a goroutine
-func (peer *CNPeer) Handler() {
-	defer peer.Kill()
+func (peer *CNPeer) Handler(eRecv chan<- *Event) error {
+	defer func() {
+		eRecv <- &Event{Type: EVENT_CLIENT_DISCONNECT, Peer: peer}
+		close(eRecv)
+		peer.Kill()
+	}()
 
 	peer.alive.Store(true)
+	eRecv <- &Event{Type: EVENT_CLIENT_CONNECT, Peer: peer}
 	for {
 		// read packet size, the goroutine spends most of it's time parked here
 		var sz uint32
 		if err := binary.Read(peer.conn, binary.LittleEndian, &sz); err != nil {
-			log.Printf("[FATAL] failed to read packet size! %v\n", err)
-			return
+			return err
 		}
 
 		// client should never send a packet size outside of this range
 		if sz > CN_PACKET_BUFFER_SIZE || sz < 4 {
-			log.Printf("[FATAL] malicious packet size received! %d", sz)
-			return
+			return fmt.Errorf("invalid packet size: %d", sz)
 		}
 
 		// grab buffer && read packet body
-		if err := func() error {
-			buf := GetBuffer()
-			if _, err := buf.ReadFrom(io.LimitReader(peer.conn, int64(sz))); err != nil {
-				return fmt.Errorf("failed to read packet body! %v", err)
-			}
-
-			// decrypt
-			DecryptData(buf.Bytes(), peer.E_key)
-			pkt := NewPacket(buf)
-
-			// create packet && read pktID
-			var pktID uint32
-			if err := pkt.Decode(&pktID); err != nil {
-				return fmt.Errorf("failed to read packet type! %v", err)
-			}
-
-			// dispatch packet
-			log.Printf("Got packet ID: %x, with a sizeof: %d\n", pktID, sz)
-			peer.eRecv <- &Event{Type: EVENT_CLIENT_PACKET, Peer: peer, Pkt: buf, PktID: pktID}
-			return nil
-		}(); err != nil {
-			log.Printf("[FATAL] %v", err)
-			return
+		buf := GetBuffer()
+		if _, err := buf.ReadFrom(io.LimitReader(peer.conn, int64(sz))); err != nil {
+			return fmt.Errorf("failed to read packet body: %v", err)
 		}
+
+		// decrypt
+		DecryptData(buf.Bytes(), peer.E_key)
+		pkt := NewPacket(buf)
+
+		// create packet && read pktID
+		var pktID uint32
+		if err := pkt.Decode(&pktID); err != nil {
+			return fmt.Errorf("failed to read packet type! %v", err)
+		}
+
+		// dispatch packet
+		// log.Printf("Got packet ID: %x, with a sizeof: %d\n", pktID, sz)
+		eRecv <- &Event{Type: EVENT_CLIENT_PACKET, Peer: peer, Pkt: buf, PktID: pktID}
 	}
 }
diff --git a/internal/protocol/event.go b/internal/protocol/event.go
index bfb0613..56abd12 100644
--- a/internal/protocol/event.go
+++ b/internal/protocol/event.go
@@ -4,6 +4,7 @@ import "bytes"
 
 const (
 	EVENT_CLIENT_DISCONNECT = iota
+	EVENT_CLIENT_CONNECT
 	EVENT_CLIENT_PACKET
 )
 
diff --git a/internal/service/service.go b/internal/service/service.go
index 6f3a074..1acb250 100644
--- a/internal/service/service.go
+++ b/internal/service/service.go
@@ -1,9 +1,12 @@
 package service
 
 import (
+	"errors"
 	"fmt"
 	"log"
 	"net"
+	"reflect"
+	"strconv"
 	"sync"
 
 	"github.com/CPunch/gopenfusion/config"
@@ -20,9 +23,12 @@ type Service struct {
 	listener       net.Listener
 	port           int
 	Name           string
-	eRecv          chan *protocol.Event
+	stop           chan struct{} // tell active handleEvents() to stop
+	stopped        chan struct{}
+	started        chan struct{}
 	packetHandlers map[uint32]PacketHandler
-	peers          *sync.Map
+	peers          map[*protocol.CNPeer]interface{}
+	stateLock      sync.Mutex
 
 	// OnDisconnect is called when a peer disconnects from the service.
 	// uData is the stored value of the key/value pair in the peer map.
@@ -35,22 +41,34 @@ type Service struct {
 	OnConnect func(peer *protocol.CNPeer) (uData interface{})
 }
 
-func NewService(name string, port int) (*Service, error) {
-	listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
+func RandomPort() (int, error) {
+	l, err := net.Listen("tcp", "127.0.0.1:0")
 	if err != nil {
-		return nil, err
+		return 0, err
+	}
+	defer l.Close()
+
+	_, port, err := net.SplitHostPort(l.Addr().String())
+	if err != nil {
+		return 0, err
+	}
+	return strconv.Atoi(port)
+}
+
+func NewService(name string, port int) *Service {
+	srvc := &Service{
+		port: port,
+		Name: name,
 	}
 
-	service := &Service{
-		listener:       listener,
-		port:           port,
-		Name:           name,
-		eRecv:          make(chan *protocol.Event),
-		packetHandlers: make(map[uint32]PacketHandler),
-		peers:          &sync.Map{},
-	}
+	srvc.Reset()
+	return srvc
+}
 
-	return service, nil
+func (service *Service) Reset() {
+	service.packetHandlers = make(map[uint32]PacketHandler)
+	service.peers = make(map[*protocol.CNPeer]interface{})
+	service.started = make(chan struct{})
 }
 
 // may not be called while the service is running (eg. srvc.Start() has been called)
@@ -58,42 +76,169 @@ func (service *Service) AddPacketHandler(pktID uint32, handler PacketHandler) {
 	service.packetHandlers[pktID] = handler
 }
 
-func (service *Service) Start() {
-	log.Printf("%s service hosted on %s:%d\n", service.Name, config.GetAnnounceIP(), service.port)
+func (service *Service) Start() error {
+	service.stop = make(chan struct{})
+	service.stopped = make(chan struct{})
+	peerConnections := make(chan chan *protocol.Event)
+	go service.handleEvents(peerConnections)
 
-	go service.handleEvents()
+	// open listener socket
+	var err error
+	service.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", service.port))
+	if err != nil {
+		return err
+	}
+
+	close(service.started) // signal that the service has started
+	log.Printf("%s service hosted on %s:%d\n", service.Name, config.GetAnnounceIP(), service.port)
 	for {
 		conn, err := service.listener.Accept()
 		if err != nil {
-			log.Println("Connection error: ", err)
-			return
+			fmt.Println(err)
+			// we expect this to happen when the service is stopped
+			if errors.Is(err, net.ErrClosed) {
+				return nil
+			}
+			return err
 		}
 
-		peer := protocol.NewCNPeer(service.eRecv, conn)
-		service.connect(peer)
+		// create a new peer and pass it to the event loop
+		eRecv := make(chan *protocol.Event)
+		peer := protocol.NewCNPeer(conn)
+		log.Printf("New peer %p connected to %s\n", peer, service.Name)
+		peerConnections <- eRecv
+		go peer.Handler(eRecv)
 	}
 }
 
-func (service *Service) handleEvents() {
-	for event := range service.eRecv {
-		switch event.Type {
-		case protocol.EVENT_CLIENT_DISCONNECT:
-			service.disconnect(event.Peer)
-		case protocol.EVENT_CLIENT_PACKET:
-			if err := service.handlePacket(event.Peer, event.PktID, protocol.NewPacket(event.Pkt)); err != nil {
-				log.Printf("Error handling packet: %v", err)
-				event.Peer.Kill()
+// returns a channel that is closed when the service has started.
+// this is useful if you need to do something after the service has started.
+func (service *Service) Started() <-chan struct{} {
+	return service.started
+}
+
+// returns a channel that is closed when the service has stopped.
+// this is useful if you need to wait until the service has completely stopped.
+func (service *Service) Stopped() <-chan struct{} {
+	return service.stopped
+}
+
+// stops the service and disconnects all peers. OnDisconnect will be called
+// for each peer.
+func (service *Service) Stop() {
+	close(service.stop)
+
+	service.listener.Close()
+}
+
+// returns the stored uData for the peer.
+// if the peer does not exist, nil is returned.
+// NOTE: the peer map is not locked while accessing, if you're calling this
+// outside of the service's event loop, you'll need to lock the peer map yourself.
+func (service *Service) GetPeerData(peer *protocol.CNPeer) interface{} {
+	return service.peers[peer]
+}
+
+// sets the stored uData for the peer.
+// NOTE: the peer map is not locked while accessing, if you're calling this
+// outside of the service's event loop, you'll need to lock the peer map yourself.
+func (service *Service) SetPeerData(peer *protocol.CNPeer, uData interface{}) {
+	service.peers[peer] = uData
+}
+
+// calls f for each peer in the service passing the peer and the stored uData.
+// if f returns false, the iteration is stopped.
+// NOTE: the peer map is not locked while iterating, if you're calling this
+// outside of the service's event loop, you'll need to lock the peer map yourself.
+func (service *Service) RangePeers(f func(peer *protocol.CNPeer, uData interface{}) bool) {
+	for peer, uData := range service.peers {
+		if !f(peer, uData) {
+			break
+		}
+	}
+}
+
+// locks the peer map.
+func (service *Service) Lock() {
+	service.stateLock.Lock()
+}
+
+// unlocks the peer map.
+func (service *Service) Unlock() {
+	service.stateLock.Unlock()
+}
+
+// handleEvents is the main event loop for the service.
+// it handles all events from the peers and calls the appropriate handlers.
+func (service *Service) handleEvents(eRecv <-chan chan *protocol.Event) {
+	poll := make([]reflect.SelectCase, 0, 4)
+
+	// add the stop channel and the peer connection channel to our poll queue
+	poll = append(poll, reflect.SelectCase{
+		Dir:  reflect.SelectRecv,
+		Chan: reflect.ValueOf(service.stop),
+	})
+	poll = append(poll, reflect.SelectCase{
+		Dir:  reflect.SelectRecv,
+		Chan: reflect.ValueOf(eRecv),
+	})
+
+	for {
+		chosen, value, _ := reflect.Select(poll)
+		if chosen == 0 {
+			// stop
+
+			// OnDisconnect handler might need to do something important
+			service.Lock()
+			service.RangePeers(func(peer *protocol.CNPeer, uData interface{}) bool {
+				peer.Kill()
+				service.disconnect(peer)
+				return true
+			})
+			service.Unlock()
+
+			// signal we have stopped
+			close(service.stopped)
+			return
+		} else if chosen == 1 {
+			// new peer, add it to our poll queue
+			poll = append(poll, reflect.SelectCase{
+				Dir:  reflect.SelectRecv,
+				Chan: reflect.ValueOf(value.Interface()),
+			})
+		} else {
+			// peer event
+			event, ok := value.Interface().(*protocol.Event)
+			if !ok {
+				panic("invalid event type")
 			}
-
-			// the packet buffer is given to us by the event, so we'll need to make sure to return it to the pool
-			protocol.PutBuffer(event.Pkt)
+
+			service.Lock()
+			switch event.Type {
+			case protocol.EVENT_CLIENT_DISCONNECT:
+				// strip the peer from our poll queue
+				poll = append(poll[:chosen], poll[chosen+1:]...)
+				service.disconnect(value.Interface().(*protocol.Event).Peer)
+			case protocol.EVENT_CLIENT_CONNECT:
+				service.connect(event.Peer)
+			case protocol.EVENT_CLIENT_PACKET:
+				if err := service.handlePacket(event.Peer, event.PktID, protocol.NewPacket(event.Pkt)); err != nil {
+					log.Printf("Error handling packet: %v", err)
+					event.Peer.Kill()
+				}
+
+				// the packet buffer is given to us by the event, so we'll need to make sure to return it to the pool
+				protocol.PutBuffer(event.Pkt)
+			}
+			service.Unlock()
 		}
 	}
 }
 
 func (service *Service) handlePacket(peer *protocol.CNPeer, typeID uint32, pkt protocol.Packet) error {
-	uData, _ := service.peers.Load(peer)
+	uData := service.peers[peer]
 	if hndlr, ok := service.packetHandlers[typeID]; ok {
+		// fmt.Printf("Handling packet %x\n", typeID)
 		if err := hndlr(peer, uData, pkt); err != nil {
 			return err
 		}
@@ -106,12 +251,12 @@ func (service *Service) handlePacket(peer *protocol.CNPeer, typeID uint32, pkt p
 
 func (service *Service) disconnect(peer *protocol.CNPeer) {
 	if service.OnDisconnect != nil {
-		uData, _ := service.peers.Load(peer)
+		uData := service.peers[peer]
 		service.OnDisconnect(peer, uData)
 	}
 
 	log.Printf("Peer %p disconnected from %s\n", peer, service.Name)
-	service.peers.Delete(peer)
+	delete(service.peers, peer)
 }
 
 func (service *Service) connect(peer *protocol.CNPeer) {
@@ -123,16 +268,5 @@ func (service *Service) connect(peer *protocol.CNPeer) {
 	}
 
 	log.Printf("New peer %p connected to %s\n", peer, service.Name)
-	service.peers.Store(peer, uData)
-	go peer.Handler()
-}
-
-func (service *Service) SetPeerData(peer *protocol.CNPeer, uData interface{}) {
-	service.peers.Store(peer, uData)
-}
-
-func (service *Service) RangePeers(f func(peer *protocol.CNPeer, uData interface{}) bool) {
-	service.peers.Range(func(key, value any) bool {
-		return f(key.(*protocol.CNPeer), value)
-	})
+	service.SetPeerData(peer, uData)
 }
diff --git a/internal/service/service_test.go b/internal/service/service_test.go
new file mode 100644
index 0000000..733b77b
--- /dev/null
+++ b/internal/service/service_test.go
@@ -0,0 +1,91 @@
+package service_test
+
+import (
+	"fmt"
+	"net"
+	"os"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/CPunch/gopenfusion/internal/protocol"
+	"github.com/CPunch/gopenfusion/internal/service"
+)
+
+var (
+	srvc     *service.Service
+	srvcPort int
+)
+
+const (
+	timeout       = 5
+	maxDummyPeers = 5
+)
+
+func waitWithTimeout(wg *sync.WaitGroup, seconds int) bool {
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		wg.Wait()
+	}()
+
+	select {
+	case <-done:
+		return true
+	case <-time.After(time.Duration(seconds) * time.Second):
+		return false
+	}
+}
+
+func TestMain(m *testing.M) {
+	var err error
+	srvcPort, err = service.RandomPort()
+	if err != nil {
+		panic(err)
+	}
+
+	srvc = service.NewService("TEST", srvcPort)
+	os.Exit(m.Run())
+}
+
+func TestService(t *testing.T) {
+	// waitgroup to wait for test packet handler to be called
+	wg := sync.WaitGroup{}
+
+	srvc.AddPacketHandler(0x1234, func(peer *protocol.CNPeer, uData interface{}, pkt protocol.Packet) error {
+		wg.Done()
+		return nil
+	})
+
+	go func() {
+		if err := srvc.Start(); err != nil {
+			t.Error(err)
+		}
+	}()
+
+	// wait for service to start
+	<-srvc.Started()
+	wg.Add(maxDummyPeers)
+	for i := 0; i < maxDummyPeers; i++ {
+		go func() {
+			// make dummy client
+			conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", srvcPort))
+			if err != nil {
+				t.Error(err)
+			}
+
+			peer := protocol.NewCNPeer(conn)
+			defer peer.Kill()
+			// send dummy packet
+			if err := peer.Send(0x1234); err != nil {
+				t.Error(err)
+			}
+		}()
+	}
+
+	if !waitWithTimeout(&wg, timeout) {
+		t.Error("timeout waiting for packet handler to be called")
+	}
+	srvc.Stop()
+	<-srvc.Stopped()
+}
diff --git a/login/loginserver.go b/login/loginserver.go
index 5f6a374..deb488c 100644
--- a/login/loginserver.go
+++ b/login/loginserver.go
@@ -14,10 +14,7 @@ type LoginServer struct {
 }
 
 func NewLoginServer(dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port int) (*LoginServer, error) {
-	srvc, err := service.NewService("LOGIN", port)
-	if err != nil {
-		return nil, err
-	}
+	srvc := service.NewService("LOGIN", port)
 
 	server := &LoginServer{
 		service: srvc,
@@ -47,6 +44,10 @@ func NewLoginServer(dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port
 	return server, nil
 }
 
-func (server *LoginServer) Start() {
-	server.service.Start()
+func (server *LoginServer) Start() error {
+	return server.service.Start()
 }
+
+func (server *LoginServer) Stop() {
+	server.service.Stop()
+}
\ No newline at end of file
diff --git a/shard/shardserver.go b/shard/shardserver.go
index 6744881..716078a 100644
--- a/shard/shardserver.go
+++ b/shard/shardserver.go
@@ -19,10 +19,7 @@ type ShardServer struct {
 }
 
 func NewShardServer(dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port int) (*ShardServer, error) {
-	srvc, err := service.NewService("SHARD", port)
-	if err != nil {
-		return nil, err
-	}
+	srvc := service.NewService("SHARD", port)
 
 	server := &ShardServer{
 		service: srvc,
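Taken together, the service changes above replace the old fire-and-forget Start() with an explicit lifecycle: construct with NewService, run Start on its own goroutine, wait on Started, and shut down with Stop followed by Stopped. Below is a minimal usage sketch (not part of the patch) assuming only the API introduced here; the "EXAMPLE" service name and the 0x1234 packet ID are arbitrary illustration values, mirroring the test above.

package main

import (
	"log"

	"github.com/CPunch/gopenfusion/internal/protocol"
	"github.com/CPunch/gopenfusion/internal/service"
)

func main() {
	// grab a free port for the example; a real server would use its configured port
	port, err := service.RandomPort()
	if err != nil {
		log.Fatal(err)
	}

	srvc := service.NewService("EXAMPLE", port)
	srvc.AddPacketHandler(0x1234, func(peer *protocol.CNPeer, uData interface{}, pkt protocol.Packet) error {
		return nil // handle the decoded packet here
	})

	// Start blocks until the listener closes, so run it on its own goroutine
	go func() {
		if err := srvc.Start(); err != nil {
			log.Println(err)
		}
	}()
	<-srvc.Started() // the listener is accepting connections once this channel closes

	// ... run until shutdown is requested ...

	srvc.Stop()      // closes the listener; the event loop kills and disconnects every peer
	<-srvc.Stopped() // wait for handleEvents() to finish
}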