From f4b17906cea21f62b8b7a19176de5980677ec088 Mon Sep 17 00:00:00 2001 From: CPunch Date: Fri, 1 Dec 2023 00:56:34 -0600 Subject: [PATCH] more protocol/service refactor - removed protocol.Event: CNPeers now send protocol.PacketEvents - peer uData is held in CNPeer, use SetUserData() and UserData() to set/read it - Service.PacketHandler calback has changed, removed uData: switched calls to peer.SetUserData() and peer.UserData() where appropriate - service.Service lots of tidying up, removed dependence on old protocol.Event. - service.Service && protocol.CNPeer now accept a cancelable context. hooray graceful shutdowns and unit tests! - general cleanup --- cmd/login.go | 2 +- cmd/shard.go | 2 +- internal/protocol/cnpeer.go | 92 ++++++----- internal/protocol/event.go | 16 -- internal/service/service.go | 266 ++++++++++++++++--------------- internal/service/service_test.go | 40 +++-- login/login.go | 45 +++--- login/loginserver.go | 14 +- shard/chat.go | 24 +-- shard/join.go | 13 +- shard/movement.go | 22 +-- shard/shardserver.go | 17 +- 12 files changed, 292 insertions(+), 261 deletions(-) delete mode 100644 internal/protocol/event.go diff --git a/cmd/login.go b/cmd/login.go index d5d9203..b564eed 100644 --- a/cmd/login.go +++ b/cmd/login.go @@ -31,7 +31,7 @@ func (s *loginCommand) SetFlags(f *flag.FlagSet) { } func (s *loginCommand) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { - loginServer, err := login.NewLoginServer(dbHndlr, redisHndlr, s.port) + loginServer, err := login.NewLoginServer(ctx, dbHndlr, redisHndlr, s.port) if err != nil { log.Panicf("failed to create shard server: %v", err) } diff --git a/cmd/shard.go b/cmd/shard.go index fef22f0..0ed1653 100644 --- a/cmd/shard.go +++ b/cmd/shard.go @@ -31,7 +31,7 @@ func (s *shardCommand) SetFlags(f *flag.FlagSet) { } func (s *shardCommand) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { - shardServer, err := shard.NewShardServer(dbHndlr, redisHndlr, s.port) + shardServer, err := shard.NewShardServer(ctx, dbHndlr, redisHndlr, s.port) if err != nil { log.Panicf("failed to create shard server: %v", err) } diff --git a/internal/protocol/cnpeer.go b/internal/protocol/cnpeer.go index 650d7f5..66a7eeb 100644 --- a/internal/protocol/cnpeer.go +++ b/internal/protocol/cnpeer.go @@ -1,10 +1,11 @@ package protocol import ( + "bytes" + "context" "encoding/binary" "fmt" "io" - "log" "net" "sync/atomic" "time" @@ -15,9 +16,17 @@ const ( USE_FE ) +type PacketEvent struct { + Type int + Pkt *bytes.Buffer + PktID uint32 +} + // CNPeer is a simple wrapper for net.Conn connections to send/recv packets over the Fusionfall packet protocol. type CNPeer struct { + uData interface{} conn net.Conn + ctx context.Context whichKey int alive *atomic.Bool @@ -32,9 +41,10 @@ func GetTime() uint64 { return uint64(time.Now().UnixMilli()) } -func NewCNPeer(conn net.Conn) *CNPeer { +func NewCNPeer(ctx context.Context, conn net.Conn) *CNPeer { p := &CNPeer{ conn: conn, + ctx: ctx, whichKey: USE_E, alive: &atomic.Bool{}, @@ -45,6 +55,14 @@ func NewCNPeer(conn net.Conn) *CNPeer { return p } +func (peer *CNPeer) SetUserData(uData interface{}) { + peer.uData = uData +} + +func (peer *CNPeer) UserData() interface{} { + return peer.uData +} + func (peer *CNPeer) Send(typeID uint32, data ...interface{}) error { // grab buffer from pool buf := GetBuffer() @@ -82,7 +100,7 @@ func (peer *CNPeer) Send(typeID uint32, data ...interface{}) error { EncryptData(buf.Bytes()[4:], key) // send full packet - log.Printf("Sending %#v, sizeof: %d, buffer: %v", data, buf.Len(), buf.Bytes()) + // log.Printf("Sending %#v, sizeof: %d, buffer: %v", data, buf.Len(), buf.Bytes()) if _, err := peer.conn.Write(buf.Bytes()); err != nil { return fmt.Errorf("failed to write packet body! %v", err) } @@ -99,50 +117,52 @@ func (peer *CNPeer) Kill() { return } - log.Printf("Killing peer %p", peer) peer.conn.Close() } // meant to be invoked as a goroutine -func (peer *CNPeer) Handler(eRecv chan<- *Event) error { +func (peer *CNPeer) Handler(eRecv chan<- *PacketEvent) error { defer func() { - eRecv <- &Event{Type: EVENT_CLIENT_DISCONNECT, Peer: peer} close(eRecv) peer.Kill() }() peer.alive.Store(true) - eRecv <- &Event{Type: EVENT_CLIENT_CONNECT, Peer: peer} for { - // read packet size, the goroutine spends most of it's time parked here - var sz uint32 - if err := binary.Read(peer.conn, binary.LittleEndian, &sz); err != nil { - return err + select { + case <-peer.ctx.Done(): + return nil + default: + // read packet size, the goroutine spends most of it's time parked here + var sz uint32 + if err := binary.Read(peer.conn, binary.LittleEndian, &sz); err != nil { + return err + } + + // client should never send a packet size outside of this range + if sz > CN_PACKET_BUFFER_SIZE || sz < 4 { + return fmt.Errorf("invalid packet size: %d", sz) + } + + // grab buffer && read packet body + buf := GetBuffer() + if _, err := buf.ReadFrom(io.LimitReader(peer.conn, int64(sz))); err != nil { + return fmt.Errorf("failed to read packet body: %v", err) + } + + // decrypt + DecryptData(buf.Bytes(), peer.E_key) + pkt := NewPacket(buf) + + // create packet && read pktID + var pktID uint32 + if err := pkt.Decode(&pktID); err != nil { + return fmt.Errorf("failed to read packet type! %v", err) + } + + // dispatch packet + // log.Printf("Got packet ID: %x, with a sizeof: %d\n", pktID, sz) + eRecv <- &PacketEvent{Pkt: buf, PktID: pktID} } - - // client should never send a packet size outside of this range - if sz > CN_PACKET_BUFFER_SIZE || sz < 4 { - return fmt.Errorf("invalid packet size: %d", sz) - } - - // grab buffer && read packet body - buf := GetBuffer() - if _, err := buf.ReadFrom(io.LimitReader(peer.conn, int64(sz))); err != nil { - return fmt.Errorf("failed to read packet body: %v", err) - } - - // decrypt - DecryptData(buf.Bytes(), peer.E_key) - pkt := NewPacket(buf) - - // create packet && read pktID - var pktID uint32 - if err := pkt.Decode(&pktID); err != nil { - return fmt.Errorf("failed to read packet type! %v", err) - } - - // dispatch packet - // log.Printf("Got packet ID: %x, with a sizeof: %d\n", pktID, sz) - eRecv <- &Event{Type: EVENT_CLIENT_PACKET, Peer: peer, Pkt: buf, PktID: pktID} } } diff --git a/internal/protocol/event.go b/internal/protocol/event.go deleted file mode 100644 index 56abd12..0000000 --- a/internal/protocol/event.go +++ /dev/null @@ -1,16 +0,0 @@ -package protocol - -import "bytes" - -const ( - EVENT_CLIENT_DISCONNECT = iota - EVENT_CLIENT_CONNECT - EVENT_CLIENT_PACKET -) - -type Event struct { - Type int - Peer *CNPeer - Pkt *bytes.Buffer - PktID uint32 -} diff --git a/internal/service/service.go b/internal/service/service.go index 1acb250..157f58b 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -1,6 +1,7 @@ package service import ( + "context" "errors" "fmt" "log" @@ -13,9 +14,9 @@ import ( "github.com/CPunch/gopenfusion/internal/protocol" ) -type PacketHandler func(peer *protocol.CNPeer, uData interface{}, pkt protocol.Packet) error +type PacketHandler func(peer *protocol.CNPeer, pkt protocol.Packet) error -func StubbedPacket(_ *protocol.CNPeer, _ interface{}, _ protocol.Packet) error { +func StubbedPacket(_ *protocol.CNPeer, _ protocol.Packet) error { return nil } @@ -23,22 +24,22 @@ type Service struct { listener net.Listener port int Name string - stop chan struct{} // tell active handleEvents() to stop - stopped chan struct{} + ctx context.Context started chan struct{} + stopped chan struct{} packetHandlers map[uint32]PacketHandler - peers map[*protocol.CNPeer]interface{} + peers map[chan *protocol.PacketEvent]*protocol.CNPeer stateLock sync.Mutex // OnDisconnect is called when a peer disconnects from the service. // uData is the stored value of the key/value pair in the peer map. // It may not be set while the service is running. (eg. srvc.Start() has been called) - OnDisconnect func(peer *protocol.CNPeer, uData interface{}) + OnDisconnect func(peer *protocol.CNPeer) // OnConnect is called when a peer connects to the service. // return value is used as the value in the peer map. // It may not be set while the service is running. (eg. srvc.Start() has been called) - OnConnect func(peer *protocol.CNPeer) (uData interface{}) + OnConnect func(peer *protocol.CNPeer) } func RandomPort() (int, error) { @@ -55,44 +56,52 @@ func RandomPort() (int, error) { return strconv.Atoi(port) } -func NewService(name string, port int) *Service { +func NewService(ctx context.Context, name string, port int) *Service { srvc := &Service{ port: port, Name: name, } - srvc.Reset() + srvc.Reset(ctx) return srvc } -func (service *Service) Reset() { - service.packetHandlers = make(map[uint32]PacketHandler) - service.peers = make(map[*protocol.CNPeer]interface{}) - service.started = make(chan struct{}) +func (srvc *Service) Reset(ctx context.Context) { + srvc.ctx = ctx + srvc.packetHandlers = make(map[uint32]PacketHandler) + srvc.peers = make(map[chan *protocol.PacketEvent]*protocol.CNPeer) + srvc.started = make(chan struct{}) + srvc.stopped = make(chan struct{}) } // may not be called while the service is running (eg. srvc.Start() has been called) -func (service *Service) AddPacketHandler(pktID uint32, handler PacketHandler) { - service.packetHandlers[pktID] = handler +func (srvc *Service) AddPacketHandler(pktID uint32, handler PacketHandler) { + srvc.packetHandlers[pktID] = handler } -func (service *Service) Start() error { - service.stop = make(chan struct{}) - service.stopped = make(chan struct{}) - peerConnections := make(chan chan *protocol.Event) - go service.handleEvents(peerConnections) +type newPeerConnection struct { + peer *protocol.CNPeer + channel chan *protocol.PacketEvent +} + +func (srvc *Service) Start() error { + peerConnections := make(chan newPeerConnection) + defer close(peerConnections) + go srvc.handleEvents(peerConnections) // open listener socket var err error - service.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", service.port)) + srvc.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", srvc.port)) if err != nil { return err } + defer srvc.listener.Close() - close(service.started) // signal that the service has started - log.Printf("%s service hosted on %s:%d\n", service.Name, config.GetAnnounceIP(), service.port) + log.Printf("%s service hosted on %s:%d\n", srvc.Name, config.GetAnnounceIP(), srvc.port) + + close(srvc.started) // signal that the service has started for { - conn, err := service.listener.Accept() + conn, err := srvc.listener.Accept() if err != nil { fmt.Println(err) // we expect this to happen when the service is stopped @@ -103,143 +112,148 @@ func (service *Service) Start() error { } // create a new peer and pass it to the event loop - eRecv := make(chan *protocol.Event) - peer := protocol.NewCNPeer(conn) - log.Printf("New peer %p connected to %s\n", peer, service.Name) - peerConnections <- eRecv + peer := protocol.NewCNPeer(srvc.ctx, conn) + eRecv := make(chan *protocol.PacketEvent) + peerConnections <- newPeerConnection{channel: eRecv, peer: peer} go peer.Handler(eRecv) } } +func (srvc *Service) getPeer(channel chan *protocol.PacketEvent) *protocol.CNPeer { + return srvc.peers[channel] +} + +func (srvc *Service) setPeer(channel chan *protocol.PacketEvent, peer *protocol.CNPeer) { + srvc.peers[channel] = peer +} + +func (srvc *Service) removePeer(channel chan *protocol.PacketEvent) { + delete(srvc.peers, channel) +} + // returns a channel that is closed when the service has started. -// this is useful if you need to do something after the service has started. -func (service *Service) Started() <-chan struct{} { - return service.started +// this is useful if you need to wait until after the service has started. +func (srvc *Service) Started() <-chan struct{} { + return srvc.started } // returns a channel that is closed when the service has stopped. -// this is useful if you need to wait until the service has completely stopped. -func (service *Service) Stopped() <-chan struct{} { - return service.stopped -} - -// stops the service and disconnects all peers. OnDisconnect will be called -// for each peer. -func (service *Service) Stop() { - close(service.stop) - - service.listener.Close() -} - -// returns the stored uData for the peer. -// if the peer does not exist, nil is returned. -// NOTE: the peer map is not locked while accessing, if you're calling this -// outside of the service's event loop, you'll need to lock the peer map yourself. -func (service *Service) GetPeerData(peer *protocol.CNPeer) interface{} { - return service.peers[peer] -} - -// sets the stored uData for the peer. -// NOTE: the peer map is not locked while accessing, if you're calling this -// outside of the service's event loop, you'll need to lock the peer map yourself. -func (service *Service) SetPeerData(peer *protocol.CNPeer, uData interface{}) { - service.peers[peer] = uData +// this is useful if you need wait until after the service has stopped. +func (srvc *Service) Stopped() <-chan struct{} { + return srvc.stopped } // calls f for each peer in the service passing the peer and the stored uData. // if f returns false, the iteration is stopped. // NOTE: the peer map is not locked while iterating, if you're calling this // outside of the service's event loop, you'll need to lock the peer map yourself. -func (service *Service) RangePeers(f func(peer *protocol.CNPeer, uData interface{}) bool) { - for peer, uData := range service.peers { - if !f(peer, uData) { +func (srvc *Service) RangePeers(f func(peer *protocol.CNPeer) bool) { + for _, peer := range srvc.peers { + if !f(peer) { break } } } // locks the peer map. -func (service *Service) Lock() { - service.stateLock.Lock() +func (srvc *Service) Lock() { + srvc.stateLock.Lock() } // unlocks the peer map. -func (service *Service) Unlock() { - service.stateLock.Unlock() +func (srvc *Service) Unlock() { + srvc.stateLock.Unlock() +} + +func (srvc *Service) stop() { + // OnDisconnect handler might need to do something important + srvc.RangePeers(func(peer *protocol.CNPeer) bool { + peer.Kill() + if srvc.OnDisconnect != nil { + srvc.OnDisconnect(peer) + } + return true + }) + + log.Printf("%s service stopped\n", srvc.Name) + close(srvc.stopped) } // handleEvents is the main event loop for the service. // it handles all events from the peers and calls the appropriate handlers. -func (service *Service) handleEvents(eRecv <-chan chan *protocol.Event) { +func (srvc *Service) handleEvents(peerPipe <-chan newPeerConnection) { + defer srvc.stop() + poll := make([]reflect.SelectCase, 0, 4) // add the stop channel and the peer connection channel to our poll queue poll = append(poll, reflect.SelectCase{ Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(service.stop), + Chan: reflect.ValueOf(srvc.ctx.Done()), }) poll = append(poll, reflect.SelectCase{ Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(eRecv), + Chan: reflect.ValueOf(peerPipe), }) + addPoll := func(channel chan *protocol.PacketEvent) { + poll = append(poll, reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(channel), + }) + } + + removePoll := func(index int) { + poll = append(poll[:index], poll[index+1:]...) + } + for { - chosen, value, _ := reflect.Select(poll) - if chosen == 0 { - // stop - - // OnDisconnect handler might need to do something important - service.Lock() - service.RangePeers(func(peer *protocol.CNPeer, uData interface{}) bool { - peer.Kill() - service.disconnect(peer) - return true - }) - service.Unlock() - - // signal we have stopped - close(service.stopped) + chosen, value, recvOK := reflect.Select(poll) + switch chosen { + case 0: // cancel signal received, stop the service return - } else if chosen == 1 { - // new peer, add it to our poll queue - poll = append(poll, reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(value.Interface()), - }) - } else { - // peer event - event, ok := value.Interface().(*protocol.Event) - if !ok { - panic("invalid event type") + case 1: // new peer, add it to our poll queue + if !recvOK { + return } - service.Lock() - switch event.Type { - case protocol.EVENT_CLIENT_DISCONNECT: - // strip the peer from our poll queue - poll = append(poll[:chosen], poll[chosen+1:]...) - service.disconnect(value.Interface().(*protocol.Event).Peer) - case protocol.EVENT_CLIENT_CONNECT: - service.connect(event.Peer) - case protocol.EVENT_CLIENT_PACKET: - if err := service.handlePacket(event.Peer, event.PktID, protocol.NewPacket(event.Pkt)); err != nil { - log.Printf("Error handling packet: %v", err) - event.Peer.Kill() - } - - // the packet buffer is given to us by the event, so we'll need to make sure to return it to the pool - protocol.PutBuffer(event.Pkt) + evnt := value.Interface().(newPeerConnection) + addPoll(evnt.channel) + srvc.connect(evnt.channel, evnt.peer) + default: // peer event + channel := poll[chosen].Chan.Interface().(chan *protocol.PacketEvent) + peer := srvc.getPeer(channel) + if peer == nil { + log.Printf("Unknown peer event: %v", value) + removePoll(chosen) + continue } - service.Unlock() + + evnt, ok := value.Interface().(*protocol.PacketEvent) + if !recvOK || !ok || evnt == nil { + // peer disconnected, remove it from our poll queue + removePoll(chosen) + srvc.disconnect(channel, peer) + continue + } + + srvc.Lock() + if err := srvc.handlePacket(peer, evnt.PktID, protocol.NewPacket(evnt.Pkt)); err != nil { + log.Printf("Error handling packet: %v", err) + peer.Kill() + } + srvc.Unlock() + + // the packet buffer is given to us by the event, so we'll need to make sure to return it to the pool + protocol.PutBuffer(evnt.Pkt) } } } -func (service *Service) handlePacket(peer *protocol.CNPeer, typeID uint32, pkt protocol.Packet) error { - uData := service.peers[peer] - if hndlr, ok := service.packetHandlers[typeID]; ok { +func (srvc *Service) handlePacket(peer *protocol.CNPeer, typeID uint32, pkt protocol.Packet) error { + if hndlr, ok := srvc.packetHandlers[typeID]; ok { // fmt.Printf("Handling packet %x\n", typeID) - if err := hndlr(peer, uData, pkt); err != nil { + if err := hndlr(peer, pkt); err != nil { return err } } else { @@ -249,24 +263,20 @@ func (service *Service) handlePacket(peer *protocol.CNPeer, typeID uint32, pkt p return nil } -func (service *Service) disconnect(peer *protocol.CNPeer) { - if service.OnDisconnect != nil { - uData := service.peers[peer] - service.OnDisconnect(peer, uData) +func (srvc *Service) disconnect(channel chan *protocol.PacketEvent, peer *protocol.CNPeer) { + log.Printf("Peer %p disconnected from %s\n", peer, srvc.Name) + if srvc.OnDisconnect != nil { + srvc.OnDisconnect(peer) } - log.Printf("Peer %p disconnected from %s\n", peer, service.Name) - delete(service.peers, peer) + srvc.removePeer(channel) } -func (service *Service) connect(peer *protocol.CNPeer) { - // default uData to nil, but if the service has an OnConnect - // handler, use the result from that - uData := interface{}(nil) - if service.OnConnect != nil { - uData = service.OnConnect(peer) +func (srvc *Service) connect(channel chan *protocol.PacketEvent, peer *protocol.CNPeer) { + log.Printf("New peer %p connected to %s\n", peer, srvc.Name) + if srvc.OnConnect != nil { + srvc.OnConnect(peer) } - log.Printf("New peer %p connected to %s\n", peer, service.Name) - service.SetPeerData(peer, uData) + srvc.setPeer(channel, peer) } diff --git a/internal/service/service_test.go b/internal/service/service_test.go index 733b77b..c8c8128 100644 --- a/internal/service/service_test.go +++ b/internal/service/service_test.go @@ -1,7 +1,9 @@ package service_test import ( + "context" "fmt" + "log" "net" "os" "sync" @@ -13,12 +15,11 @@ import ( ) var ( - srvc *service.Service srvcPort int ) const ( - timeout = 5 + timeout = 2 maxDummyPeers = 5 ) @@ -44,15 +45,18 @@ func TestMain(m *testing.M) { panic(err) } - srvc = service.NewService("TEST", srvcPort) os.Exit(m.Run()) } func TestService(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + srvc := service.NewService(ctx, "TEST", srvcPort) + // waitgroup to wait for test packet handler to be called wg := sync.WaitGroup{} - srvc.AddPacketHandler(0x1234, func(peer *protocol.CNPeer, uData interface{}, pkt protocol.Packet) error { + srvc.AddPacketHandler(0x1234, func(peer *protocol.CNPeer, pkt protocol.Packet) error { + log.Printf("Received packet %#v", pkt) wg.Done() return nil }) @@ -65,7 +69,7 @@ func TestService(t *testing.T) { // wait for service to start <-srvc.Started() - wg.Add(maxDummyPeers) + wg.Add(maxDummyPeers * 3) // 2 wg.Done() calls per dummy peer for i := 0; i < maxDummyPeers; i++ { go func() { // make dummy client @@ -74,18 +78,28 @@ func TestService(t *testing.T) { t.Error(err) } - peer := protocol.NewCNPeer(conn) - defer peer.Kill() - // send dummy packet - if err := peer.Send(0x1234); err != nil { - t.Error(err) - } + peer := protocol.NewCNPeer(ctx, conn) + go func() { + defer peer.Kill() + + // send dummy packets + for i := 0; i < 2; i++ { + if err := peer.Send(0x1234); err != nil { + t.Error(err) + } + } + }() + + // we wait until Handler gracefully exits (peer was killed) + peer.Handler(make(chan *protocol.PacketEvent)) + wg.Done() }() } if !waitWithTimeout(&wg, timeout) { - t.Error("timeout waiting for packet handler to be called") + t.Error("failed to wait for packet handler to be called") } - srvc.Stop() + + cancel() <-srvc.Stopped() } diff --git a/login/login.go b/login/login.go index 3337d6c..25ce378 100644 --- a/login/login.go +++ b/login/login.go @@ -61,7 +61,7 @@ func (server *LoginServer) AcceptLogin(peer *protocol.CNPeer, SzID string, IClie return nil } -func (server *LoginServer) Login(peer *protocol.CNPeer, _account interface{}, pkt protocol.Packet) error { +func (server *LoginServer) Login(peer *protocol.CNPeer, pkt protocol.Packet) error { var loginPkt protocol.SP_CL2LS_REQ_LOGIN pkt.Decode(&loginPkt) @@ -73,9 +73,9 @@ func (server *LoginServer) Login(peer *protocol.CNPeer, _account interface{}, pk } // client is resending a login packet?? - if _account != nil { + if peer.UserData() != nil { SendError(LOGIN_ERROR) - return fmt.Errorf("out of order P_CL2LS_REQ_LOGIN: %v", _account) + return fmt.Errorf("out of order P_CL2LS_REQ_LOGIN: %v", peer.UserData()) } // attempt login @@ -98,7 +98,7 @@ func (server *LoginServer) Login(peer *protocol.CNPeer, _account interface{}, pk } // grab player data - server.service.SetPeerData(peer, account) + peer.SetUserData(account) plrs, err := server.dbHndlr.GetPlayers(account.AccountID) if err != nil { SendError(LOGIN_DATABASE_ERROR) @@ -137,7 +137,7 @@ func (server *LoginServer) Login(peer *protocol.CNPeer, _account interface{}, pk return server.AcceptLogin(peer, loginPkt.SzID, loginPkt.IClientVerC, 1, charInfo[:len(plrs)]) } -func (server *LoginServer) CheckCharacterName(peer *protocol.CNPeer, account interface{}, pkt protocol.Packet) error { +func (server *LoginServer) CheckCharacterName(peer *protocol.CNPeer, pkt protocol.Packet) error { var charPkt protocol.SP_CL2LS_REQ_CHECK_CHAR_NAME pkt.Decode(&charPkt) @@ -148,17 +148,18 @@ func (server *LoginServer) CheckCharacterName(peer *protocol.CNPeer, account int }) } -func (server *LoginServer) SaveCharacterName(peer *protocol.CNPeer, account interface{}, pkt protocol.Packet) error { +func (server *LoginServer) SaveCharacterName(peer *protocol.CNPeer, pkt protocol.Packet) error { var charPkt protocol.SP_CL2LS_REQ_SAVE_CHAR_NAME pkt.Decode(&charPkt) - if account == nil { + account, ok := peer.UserData().(*db.Account) + if !ok || account == nil { peer.Send(protocol.P_LS2CL_REP_SAVE_CHAR_NAME_FAIL, protocol.SP_LS2CL_REP_SAVE_CHAR_NAME_FAIL{}) return fmt.Errorf("out of order P_LS2CL_REP_SAVE_CHAR_NAME_FAIL") } // TODO: sanity check SzFirstName && SzLastName - PlayerID, err := server.dbHndlr.NewPlayer(account.(*db.Account).AccountID, charPkt.SzFirstName, charPkt.SzLastName, int(charPkt.ISlotNum)) + PlayerID, err := server.dbHndlr.NewPlayer(account.AccountID, charPkt.SzFirstName, charPkt.SzLastName, int(charPkt.ISlotNum)) if err != nil { peer.Send(protocol.P_LS2CL_REP_SAVE_CHAR_NAME_FAIL, protocol.SP_LS2CL_REP_SAVE_CHAR_NAME_FAIL{}) return err @@ -210,11 +211,12 @@ func SendFail(peer *protocol.CNPeer) error { return nil } -func (server *LoginServer) CharacterCreate(peer *protocol.CNPeer, account interface{}, pkt protocol.Packet) error { +func (server *LoginServer) CharacterCreate(peer *protocol.CNPeer, pkt protocol.Packet) error { var charPkt protocol.SP_CL2LS_REQ_CHAR_CREATE pkt.Decode(&charPkt) - if account == nil { + account, ok := peer.UserData().(*db.Account) + if !ok || account == nil { return SendFail(peer) } @@ -223,7 +225,7 @@ func (server *LoginServer) CharacterCreate(peer *protocol.CNPeer, account interf return SendFail(peer) } - if err := server.dbHndlr.FinishPlayer(&charPkt, account.(*db.Account).AccountID); err != nil { + if err := server.dbHndlr.FinishPlayer(&charPkt, account.AccountID); err != nil { log.Printf("Error finishing player: %v", err) return SendFail(peer) } @@ -242,15 +244,16 @@ func (server *LoginServer) CharacterCreate(peer *protocol.CNPeer, account interf }) } -func (server *LoginServer) CharacterDelete(peer *protocol.CNPeer, account interface{}, pkt protocol.Packet) error { +func (server *LoginServer) CharacterDelete(peer *protocol.CNPeer, pkt protocol.Packet) error { var charPkt protocol.SP_CL2LS_REQ_CHAR_DELETE pkt.Decode(&charPkt) - if account == nil { + account, ok := peer.UserData().(*db.Account) + if !ok || account == nil { return SendFail(peer) } - slot, err := server.dbHndlr.DeletePlayer(int(charPkt.IPC_UID), account.(*db.Account).AccountID) + slot, err := server.dbHndlr.DeletePlayer(int(charPkt.IPC_UID), account.AccountID) if err != nil { return SendFail(peer) } @@ -260,11 +263,12 @@ func (server *LoginServer) CharacterDelete(peer *protocol.CNPeer, account interf }) } -func (server *LoginServer) ShardSelect(peer *protocol.CNPeer, account interface{}, pkt protocol.Packet) error { +func (server *LoginServer) ShardSelect(peer *protocol.CNPeer, pkt protocol.Packet) error { var selection protocol.SP_CL2LS_REQ_CHAR_SELECT pkt.Decode(&selection) - if account == nil { + account, ok := peer.UserData().(*db.Account) + if !ok || account == nil { return SendFail(peer) } @@ -289,7 +293,7 @@ func (server *LoginServer) ShardSelect(peer *protocol.CNPeer, account interface{ log.Printf("Error getting player: %v", err) return SendFail(peer) } - accountID := account.(*db.Account).AccountID + accountID := account.AccountID if plr.AccountID != accountID { log.Printf("HACK: player %d tried to join shard as player %d", accountID, plr.AccountID) @@ -315,15 +319,16 @@ func (server *LoginServer) ShardSelect(peer *protocol.CNPeer, account interface{ return peer.Send(protocol.P_LS2CL_REP_SHARD_SELECT_SUCC, resp) } -func (server *LoginServer) FinishTutorial(peer *protocol.CNPeer, account interface{}, pkt protocol.Packet) error { +func (server *LoginServer) FinishTutorial(peer *protocol.CNPeer, pkt protocol.Packet) error { var charPkt protocol.SP_CL2LS_REQ_SAVE_CHAR_TUTOR pkt.Decode(&charPkt) - if account == nil { + account, ok := peer.UserData().(*db.Account) + if !ok || account == nil { return SendFail(peer) } - if err := server.dbHndlr.FinishTutorial(int(charPkt.IPC_UID), account.(*db.Account).AccountID); err != nil { + if err := server.dbHndlr.FinishTutorial(int(charPkt.IPC_UID), account.AccountID); err != nil { return SendFail(peer) } diff --git a/login/loginserver.go b/login/loginserver.go index deb488c..77a59e8 100644 --- a/login/loginserver.go +++ b/login/loginserver.go @@ -1,6 +1,8 @@ package login import ( + "context" + "github.com/CPunch/gopenfusion/internal/db" "github.com/CPunch/gopenfusion/internal/protocol" "github.com/CPunch/gopenfusion/internal/redis" @@ -13,8 +15,8 @@ type LoginServer struct { redisHndlr *redis.RedisHandler } -func NewLoginServer(dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port int) (*LoginServer, error) { - srvc := service.NewService("LOGIN", port) +func NewLoginServer(ctx context.Context, dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port int) (*LoginServer, error) { + srvc := service.NewService(ctx, "LOGIN", port) server := &LoginServer{ service: srvc, @@ -37,17 +39,9 @@ func NewLoginServer(dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port srvc.AddPacketHandler(protocol.P_CL2LS_REQ_CHANGE_CHAR_NAME, service.StubbedPacket) srvc.AddPacketHandler(protocol.P_CL2LS_REQ_SERVER_SELECT, service.StubbedPacket) - srvc.OnConnect = func(peer *protocol.CNPeer) interface{} { - return nil - } - return server, nil } func (server *LoginServer) Start() error { return server.service.Start() } - -func (server *LoginServer) Stop() { - server.service.Stop() -} \ No newline at end of file diff --git a/shard/chat.go b/shard/chat.go index 15fe52c..97159b0 100644 --- a/shard/chat.go +++ b/shard/chat.go @@ -7,14 +7,14 @@ import ( "github.com/CPunch/gopenfusion/internal/protocol" ) -func (server *ShardServer) freeChat(peer *protocol.CNPeer, _plr interface{}, pkt protocol.Packet) error { +func (server *ShardServer) freeChat(peer *protocol.CNPeer, pkt protocol.Packet) error { var chat protocol.SP_CL2FE_REQ_SEND_FREECHAT_MESSAGE pkt.Decode(&chat) - if _plr == nil { - return fmt.Errorf("freeChat: _plr is nil") + plr, ok := peer.UserData().(*entity.Player) + if !ok || plr == nil { + return fmt.Errorf("freeChat: plr is nil") } - plr := _plr.(*entity.Player) // spread message return server.sendAllPacket(plr, protocol.P_FE2CL_REP_SEND_FREECHAT_MESSAGE_SUCC, protocol.SP_FE2CL_REP_SEND_FREECHAT_MESSAGE_SUCC{ @@ -24,14 +24,14 @@ func (server *ShardServer) freeChat(peer *protocol.CNPeer, _plr interface{}, pkt }) } -func (server *ShardServer) menuChat(peer *protocol.CNPeer, _plr interface{}, pkt protocol.Packet) error { +func (server *ShardServer) menuChat(peer *protocol.CNPeer, pkt protocol.Packet) error { var chat protocol.SP_CL2FE_REQ_SEND_MENUCHAT_MESSAGE pkt.Decode(&chat) - if _plr == nil { - return fmt.Errorf("menuChat: _plr is nil") + plr, ok := peer.UserData().(*entity.Player) + if !ok || plr == nil { + return fmt.Errorf("menuChat: plr is nil") } - plr := _plr.(*entity.Player) // spread message return server.sendAllPacket(plr, protocol.P_FE2CL_REP_SEND_MENUCHAT_MESSAGE_SUCC, protocol.SP_FE2CL_REP_SEND_MENUCHAT_MESSAGE_SUCC{ @@ -41,14 +41,14 @@ func (server *ShardServer) menuChat(peer *protocol.CNPeer, _plr interface{}, pkt }) } -func (server *ShardServer) emoteChat(peer *protocol.CNPeer, _plr interface{}, pkt protocol.Packet) error { +func (server *ShardServer) emoteChat(peer *protocol.CNPeer, pkt protocol.Packet) error { var chat protocol.SP_CL2FE_REQ_PC_AVATAR_EMOTES_CHAT pkt.Decode(&chat) - if _plr == nil { - return fmt.Errorf("emoteChat: _plr is nil") + plr, ok := peer.UserData().(*entity.Player) + if !ok || plr == nil { + return fmt.Errorf("emoteChat: plr is nil") } - plr := _plr.(*entity.Player) // spread message return server.sendAllPacket(plr, protocol.P_FE2CL_REP_PC_AVATAR_EMOTES_CHAT, protocol.SP_FE2CL_REP_PC_AVATAR_EMOTES_CHAT{ diff --git a/shard/join.go b/shard/join.go index 75cf4f1..23f10b3 100644 --- a/shard/join.go +++ b/shard/join.go @@ -20,16 +20,17 @@ func (server *ShardServer) attachPlayer(peer *protocol.CNPeer, meta redis.LoginM // server.Start() goroutine. the only functions allowed to access // it are the packet handlers as no other goroutines will be // concurrently accessing it. - server.service.SetPeerData(peer, plr) + peer.SetUserData(plr) return plr, nil } -func (server *ShardServer) RequestEnter(peer *protocol.CNPeer, _plr interface{}, pkt protocol.Packet) error { +func (server *ShardServer) RequestEnter(peer *protocol.CNPeer, pkt protocol.Packet) error { var enter protocol.SP_CL2FE_REQ_PC_ENTER pkt.Decode(&enter) // resending a shard enter packet? - if _plr != nil { + _plr, ok := peer.UserData().(*entity.Player) + if ok && _plr != nil { return fmt.Errorf("resent enter packet") } @@ -64,15 +65,15 @@ func (server *ShardServer) RequestEnter(peer *protocol.CNPeer, _plr interface{}, return nil } -func (server *ShardServer) LoadingComplete(peer *protocol.CNPeer, _plr interface{}, pkt protocol.Packet) error { +func (server *ShardServer) LoadingComplete(peer *protocol.CNPeer, pkt protocol.Packet) error { var loadComplete protocol.SP_CL2FE_REQ_PC_LOADING_COMPLETE pkt.Decode(&loadComplete) // was the peer attached to a player? - if _plr == nil { + plr, ok := peer.UserData().(*entity.Player) + if !ok || plr == nil { return fmt.Errorf("loadingComplete: plr is nil") } - plr := _plr.(*entity.Player) err := peer.Send(protocol.P_FE2CL_REP_PC_LOADING_COMPLETE_SUCC, protocol.SP_FE2CL_REP_PC_LOADING_COMPLETE_SUCC{IPC_ID: int32(plr.PlayerID)}) if err != nil { diff --git a/shard/movement.go b/shard/movement.go index 898df47..625769e 100644 --- a/shard/movement.go +++ b/shard/movement.go @@ -15,14 +15,14 @@ func (server *ShardServer) updatePlayerPosition(plr *entity.Player, X, Y, Z, Ang server.updateEntityChunk(plr, plr.GetChunkPos(), entity.MakeChunkPosition(X, Y)) } -func (server *ShardServer) playerMove(peer *protocol.CNPeer, _plr interface{}, pkt protocol.Packet) error { +func (server *ShardServer) playerMove(peer *protocol.CNPeer, pkt protocol.Packet) error { var move protocol.SP_CL2FE_REQ_PC_MOVE pkt.Decode(&move) - if _plr == nil { - return fmt.Errorf("playerMove: _plr is nil") + plr, ok := peer.UserData().(*entity.Player) + if !ok || plr == nil { + return fmt.Errorf("playerMove: plr is nil") } - plr := _plr.(*entity.Player) // update chunking server.updatePlayerPosition(plr, int(move.IX), int(move.IY), int(move.IZ), int(move.IAngle)) @@ -43,14 +43,14 @@ func (server *ShardServer) playerMove(peer *protocol.CNPeer, _plr interface{}, p }) } -func (server *ShardServer) playerStop(peer *protocol.CNPeer, _plr interface{}, pkt protocol.Packet) error { +func (server *ShardServer) playerStop(peer *protocol.CNPeer, pkt protocol.Packet) error { var stop protocol.SP_CL2FE_REQ_PC_STOP pkt.Decode(&stop) - if _plr == nil { - return fmt.Errorf("playerStop: _plr is nil") + plr, ok := peer.UserData().(*entity.Player) + if !ok || plr == nil { + return fmt.Errorf("playerStop: plr is nil") } - plr := _plr.(*entity.Player) // update chunking server.updatePlayerPosition(plr, int(stop.IX), int(stop.IY), int(stop.IZ), plr.Angle) @@ -65,14 +65,14 @@ func (server *ShardServer) playerStop(peer *protocol.CNPeer, _plr interface{}, p }) } -func (server *ShardServer) playerJump(peer *protocol.CNPeer, _plr interface{}, pkt protocol.Packet) error { +func (server *ShardServer) playerJump(peer *protocol.CNPeer, pkt protocol.Packet) error { var jump protocol.SP_CL2FE_REQ_PC_JUMP pkt.Decode(&jump) - if _plr == nil { + plr, ok := peer.UserData().(*entity.Player) + if !ok || plr == nil { return fmt.Errorf("playerJump: _plr is nil") } - plr := _plr.(*entity.Player) // update chunking server.updatePlayerPosition(plr, int(jump.IX), int(jump.IY), int(jump.IZ), plr.Angle) diff --git a/shard/shardserver.go b/shard/shardserver.go index 716078a..5a156ee 100644 --- a/shard/shardserver.go +++ b/shard/shardserver.go @@ -1,6 +1,8 @@ package shard import ( + "context" + "github.com/CPunch/gopenfusion/config" "github.com/CPunch/gopenfusion/internal/db" "github.com/CPunch/gopenfusion/internal/entity" @@ -18,8 +20,8 @@ type ShardServer struct { chunks map[entity.ChunkPosition]*entity.Chunk } -func NewShardServer(dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port int) (*ShardServer, error) { - srvc := service.NewService("SHARD", port) +func NewShardServer(ctx context.Context, dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port int) (*ShardServer, error) { + srvc := service.NewService(ctx, "SHARD", port) server := &ShardServer{ service: srvc, @@ -53,13 +55,14 @@ func (server *ShardServer) Start() { server.service.Start() } -func (server *ShardServer) onDisconnect(peer *protocol.CNPeer, _plr interface{}) { +func (server *ShardServer) onDisconnect(peer *protocol.CNPeer) { // remove from chunks - if _plr != nil { - server.removeEntity(_plr.(*entity.Player)) + plr, ok := peer.UserData().(*entity.Player) + if ok && plr != nil { + server.removeEntity(plr) } } -func (server *ShardServer) onConnect(peer *protocol.CNPeer) interface{} { - return nil +func (server *ShardServer) onConnect(peer *protocol.CNPeer) { + }