mirror of
https://github.com/CPunch/gopenfusion.git
synced 2024-12-04 22:46:32 +00:00
Compare commits
3 Commits
3abba0ca3c
...
c0ba365cf5
Author | SHA1 | Date | |
---|---|---|---|
c0ba365cf5 | |||
d0346b2382 | |||
18a6c5ab42 |
@ -1,7 +1,7 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
@ -44,8 +44,8 @@ func (db *DBHandler) NewAccount(Login, Password string) (*Account, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrLoginInvalidID = fmt.Errorf("invalid Login ID")
|
ErrLoginInvalidID = errors.New("invalid Login ID")
|
||||||
ErrLoginInvalidPassword = fmt.Errorf("invalid ID && Password combo")
|
ErrLoginInvalidPassword = errors.New("invalid ID && Password combo")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (db *DBHandler) TryLogin(Login, Password string) (*Account, error) {
|
func (db *DBHandler) TryLogin(Login, Password string) (*Account, error) {
|
||||||
|
@ -2,6 +2,7 @@ package db_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -50,7 +51,7 @@ func TestDBAccount(t *testing.T) {
|
|||||||
t.Error("account username is not test")
|
t.Error("account username is not test")
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = testDB.TryLogin("test", "wrongpassword"); err != db.ErrLoginInvalidPassword {
|
if _, err = testDB.TryLogin("test", "wrongpassword"); !errors.Is(err, db.ErrLoginInvalidPassword) {
|
||||||
t.Error("expected ErrLoginInvalidPassword")
|
t.Error("expected ErrLoginInvalidPassword")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -40,15 +40,7 @@ func (c *Chunk) SendPacket(typeID uint32, pkt ...interface{}) {
|
|||||||
// calls f for each entity in this chunk, if f returns true, stop iterating
|
// calls f for each entity in this chunk, if f returns true, stop iterating
|
||||||
// f can safely add/remove entities from the chunk
|
// f can safely add/remove entities from the chunk
|
||||||
func (c *Chunk) ForEachEntity(f func(entity Entity) bool) {
|
func (c *Chunk) ForEachEntity(f func(entity Entity) bool) {
|
||||||
// copy entities to avoid locking for the entire iteration
|
|
||||||
entities := make(map[Entity]struct{})
|
|
||||||
c.lock.Lock()
|
|
||||||
for entity := range c.entities {
|
for entity := range c.entities {
|
||||||
entities[entity] = struct{}{}
|
|
||||||
}
|
|
||||||
c.lock.Unlock()
|
|
||||||
|
|
||||||
for entity := range entities {
|
|
||||||
if f(entity) {
|
if f(entity) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -18,7 +18,6 @@ const (
|
|||||||
// CNPeer is a simple wrapper for net.Conn connections to send/recv packets over the Fusionfall packet protocol.
|
// CNPeer is a simple wrapper for net.Conn connections to send/recv packets over the Fusionfall packet protocol.
|
||||||
type CNPeer struct {
|
type CNPeer struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
eRecv chan *Event
|
|
||||||
whichKey int
|
whichKey int
|
||||||
alive *atomic.Bool
|
alive *atomic.Bool
|
||||||
|
|
||||||
@ -33,10 +32,9 @@ func GetTime() uint64 {
|
|||||||
return uint64(time.Now().UnixMilli())
|
return uint64(time.Now().UnixMilli())
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCNPeer(eRecv chan *Event, conn net.Conn) *CNPeer {
|
func NewCNPeer(conn net.Conn) *CNPeer {
|
||||||
p := &CNPeer{
|
p := &CNPeer{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
eRecv: eRecv,
|
|
||||||
whichKey: USE_E,
|
whichKey: USE_E,
|
||||||
alive: &atomic.Bool{},
|
alive: &atomic.Bool{},
|
||||||
|
|
||||||
@ -96,59 +94,55 @@ func (peer *CNPeer) SetActiveKey(whichKey int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (peer *CNPeer) Kill() {
|
func (peer *CNPeer) Kill() {
|
||||||
log.Printf("Killing peer %p", peer)
|
// de-bounce: only kill if alive
|
||||||
|
|
||||||
if !peer.alive.CompareAndSwap(true, false) {
|
if !peer.alive.CompareAndSwap(true, false) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Printf("Killing peer %p", peer)
|
||||||
peer.conn.Close()
|
peer.conn.Close()
|
||||||
peer.eRecv <- &Event{Type: EVENT_CLIENT_DISCONNECT, Peer: peer}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// meant to be invoked as a goroutine
|
// meant to be invoked as a goroutine
|
||||||
func (peer *CNPeer) Handler() {
|
func (peer *CNPeer) Handler(eRecv chan<- *Event) error {
|
||||||
defer peer.Kill()
|
defer func() {
|
||||||
|
eRecv <- &Event{Type: EVENT_CLIENT_DISCONNECT, Peer: peer}
|
||||||
|
close(eRecv)
|
||||||
|
peer.Kill()
|
||||||
|
}()
|
||||||
|
|
||||||
peer.alive.Store(true)
|
peer.alive.Store(true)
|
||||||
|
eRecv <- &Event{Type: EVENT_CLIENT_CONNECT, Peer: peer}
|
||||||
for {
|
for {
|
||||||
// read packet size, the goroutine spends most of it's time parked here
|
// read packet size, the goroutine spends most of it's time parked here
|
||||||
var sz uint32
|
var sz uint32
|
||||||
if err := binary.Read(peer.conn, binary.LittleEndian, &sz); err != nil {
|
if err := binary.Read(peer.conn, binary.LittleEndian, &sz); err != nil {
|
||||||
log.Printf("[FATAL] failed to read packet size! %v\n", err)
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// client should never send a packet size outside of this range
|
// client should never send a packet size outside of this range
|
||||||
if sz > CN_PACKET_BUFFER_SIZE || sz < 4 {
|
if sz > CN_PACKET_BUFFER_SIZE || sz < 4 {
|
||||||
log.Printf("[FATAL] malicious packet size received! %d", sz)
|
return fmt.Errorf("invalid packet size: %d", sz)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// grab buffer && read packet body
|
// grab buffer && read packet body
|
||||||
if err := func() error {
|
buf := GetBuffer()
|
||||||
buf := GetBuffer()
|
if _, err := buf.ReadFrom(io.LimitReader(peer.conn, int64(sz))); err != nil {
|
||||||
if _, err := buf.ReadFrom(io.LimitReader(peer.conn, int64(sz))); err != nil {
|
return fmt.Errorf("failed to read packet body: %v", err)
|
||||||
return fmt.Errorf("failed to read packet body! %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// decrypt
|
|
||||||
DecryptData(buf.Bytes(), peer.E_key)
|
|
||||||
pkt := NewPacket(buf)
|
|
||||||
|
|
||||||
// create packet && read pktID
|
|
||||||
var pktID uint32
|
|
||||||
if err := pkt.Decode(&pktID); err != nil {
|
|
||||||
return fmt.Errorf("failed to read packet type! %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// dispatch packet
|
|
||||||
log.Printf("Got packet ID: %x, with a sizeof: %d\n", pktID, sz)
|
|
||||||
peer.eRecv <- &Event{Type: EVENT_CLIENT_PACKET, Peer: peer, Pkt: buf, PktID: pktID}
|
|
||||||
return nil
|
|
||||||
}(); err != nil {
|
|
||||||
log.Printf("[FATAL] %v", err)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// decrypt
|
||||||
|
DecryptData(buf.Bytes(), peer.E_key)
|
||||||
|
pkt := NewPacket(buf)
|
||||||
|
|
||||||
|
// create packet && read pktID
|
||||||
|
var pktID uint32
|
||||||
|
if err := pkt.Decode(&pktID); err != nil {
|
||||||
|
return fmt.Errorf("failed to read packet type! %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// dispatch packet
|
||||||
|
// log.Printf("Got packet ID: %x, with a sizeof: %d\n", pktID, sz)
|
||||||
|
eRecv <- &Event{Type: EVENT_CLIENT_PACKET, Peer: peer, Pkt: buf, PktID: pktID}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import "bytes"
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
EVENT_CLIENT_DISCONNECT = iota
|
EVENT_CLIENT_DISCONNECT = iota
|
||||||
|
EVENT_CLIENT_CONNECT
|
||||||
EVENT_CLIENT_PACKET
|
EVENT_CLIENT_PACKET
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/CPunch/gopenfusion/config"
|
"github.com/CPunch/gopenfusion/config"
|
||||||
@ -20,9 +23,12 @@ type Service struct {
|
|||||||
listener net.Listener
|
listener net.Listener
|
||||||
port int
|
port int
|
||||||
Name string
|
Name string
|
||||||
eRecv chan *protocol.Event
|
stop chan struct{} // tell active handleEvents() to stop
|
||||||
|
stopped chan struct{}
|
||||||
|
started chan struct{}
|
||||||
packetHandlers map[uint32]PacketHandler
|
packetHandlers map[uint32]PacketHandler
|
||||||
peers *sync.Map
|
peers map[*protocol.CNPeer]interface{}
|
||||||
|
stateLock sync.Mutex
|
||||||
|
|
||||||
// OnDisconnect is called when a peer disconnects from the service.
|
// OnDisconnect is called when a peer disconnects from the service.
|
||||||
// uData is the stored value of the key/value pair in the peer map.
|
// uData is the stored value of the key/value pair in the peer map.
|
||||||
@ -35,22 +41,34 @@ type Service struct {
|
|||||||
OnConnect func(peer *protocol.CNPeer) (uData interface{})
|
OnConnect func(peer *protocol.CNPeer) (uData interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(name string, port int) (*Service, error) {
|
func RandomPort() (int, error) {
|
||||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return 0, err
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(l.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return strconv.Atoi(port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewService(name string, port int) *Service {
|
||||||
|
srvc := &Service{
|
||||||
|
port: port,
|
||||||
|
Name: name,
|
||||||
}
|
}
|
||||||
|
|
||||||
service := &Service{
|
srvc.Reset()
|
||||||
listener: listener,
|
return srvc
|
||||||
port: port,
|
}
|
||||||
Name: name,
|
|
||||||
eRecv: make(chan *protocol.Event),
|
|
||||||
packetHandlers: make(map[uint32]PacketHandler),
|
|
||||||
peers: &sync.Map{},
|
|
||||||
}
|
|
||||||
|
|
||||||
return service, nil
|
func (service *Service) Reset() {
|
||||||
|
service.packetHandlers = make(map[uint32]PacketHandler)
|
||||||
|
service.peers = make(map[*protocol.CNPeer]interface{})
|
||||||
|
service.started = make(chan struct{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// may not be called while the service is running (eg. srvc.Start() has been called)
|
// may not be called while the service is running (eg. srvc.Start() has been called)
|
||||||
@ -58,42 +76,169 @@ func (service *Service) AddPacketHandler(pktID uint32, handler PacketHandler) {
|
|||||||
service.packetHandlers[pktID] = handler
|
service.packetHandlers[pktID] = handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *Service) Start() {
|
func (service *Service) Start() error {
|
||||||
log.Printf("%s service hosted on %s:%d\n", service.Name, config.GetAnnounceIP(), service.port)
|
service.stop = make(chan struct{})
|
||||||
|
service.stopped = make(chan struct{})
|
||||||
|
peerConnections := make(chan chan *protocol.Event)
|
||||||
|
go service.handleEvents(peerConnections)
|
||||||
|
|
||||||
go service.handleEvents()
|
// open listener socket
|
||||||
|
var err error
|
||||||
|
service.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", service.port))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
close(service.started) // signal that the service has started
|
||||||
|
log.Printf("%s service hosted on %s:%d\n", service.Name, config.GetAnnounceIP(), service.port)
|
||||||
for {
|
for {
|
||||||
conn, err := service.listener.Accept()
|
conn, err := service.listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Connection error: ", err)
|
fmt.Println(err)
|
||||||
return
|
// we expect this to happen when the service is stopped
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
peer := protocol.NewCNPeer(service.eRecv, conn)
|
// create a new peer and pass it to the event loop
|
||||||
service.connect(peer)
|
eRecv := make(chan *protocol.Event)
|
||||||
|
peer := protocol.NewCNPeer(conn)
|
||||||
|
log.Printf("New peer %p connected to %s\n", peer, service.Name)
|
||||||
|
peerConnections <- eRecv
|
||||||
|
go peer.Handler(eRecv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *Service) handleEvents() {
|
// returns a channel that is closed when the service has started.
|
||||||
for event := range service.eRecv {
|
// this is useful if you need to do something after the service has started.
|
||||||
switch event.Type {
|
func (service *Service) Started() <-chan struct{} {
|
||||||
case protocol.EVENT_CLIENT_DISCONNECT:
|
return service.started
|
||||||
service.disconnect(event.Peer)
|
}
|
||||||
case protocol.EVENT_CLIENT_PACKET:
|
|
||||||
if err := service.handlePacket(event.Peer, event.PktID, protocol.NewPacket(event.Pkt)); err != nil {
|
// returns a channel that is closed when the service has stopped.
|
||||||
log.Printf("Error handling packet: %v", err)
|
// this is useful if you need to wait until the service has completely stopped.
|
||||||
event.Peer.Kill()
|
func (service *Service) Stopped() <-chan struct{} {
|
||||||
|
return service.stopped
|
||||||
|
}
|
||||||
|
|
||||||
|
// stops the service and disconnects all peers. OnDisconnect will be called
|
||||||
|
// for each peer.
|
||||||
|
func (service *Service) Stop() {
|
||||||
|
close(service.stop)
|
||||||
|
|
||||||
|
service.listener.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns the stored uData for the peer.
|
||||||
|
// if the peer does not exist, nil is returned.
|
||||||
|
// NOTE: the peer map is not locked while accessing, if you're calling this
|
||||||
|
// outside of the service's event loop, you'll need to lock the peer map yourself.
|
||||||
|
func (service *Service) GetPeerData(peer *protocol.CNPeer) interface{} {
|
||||||
|
return service.peers[peer]
|
||||||
|
}
|
||||||
|
|
||||||
|
// sets the stored uData for the peer.
|
||||||
|
// NOTE: the peer map is not locked while accessing, if you're calling this
|
||||||
|
// outside of the service's event loop, you'll need to lock the peer map yourself.
|
||||||
|
func (service *Service) SetPeerData(peer *protocol.CNPeer, uData interface{}) {
|
||||||
|
service.peers[peer] = uData
|
||||||
|
}
|
||||||
|
|
||||||
|
// calls f for each peer in the service passing the peer and the stored uData.
|
||||||
|
// if f returns false, the iteration is stopped.
|
||||||
|
// NOTE: the peer map is not locked while iterating, if you're calling this
|
||||||
|
// outside of the service's event loop, you'll need to lock the peer map yourself.
|
||||||
|
func (service *Service) RangePeers(f func(peer *protocol.CNPeer, uData interface{}) bool) {
|
||||||
|
for peer, uData := range service.peers {
|
||||||
|
if !f(peer, uData) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// locks the peer map.
|
||||||
|
func (service *Service) Lock() {
|
||||||
|
service.stateLock.Lock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// unlocks the peer map.
|
||||||
|
func (service *Service) Unlock() {
|
||||||
|
service.stateLock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleEvents is the main event loop for the service.
|
||||||
|
// it handles all events from the peers and calls the appropriate handlers.
|
||||||
|
func (service *Service) handleEvents(eRecv <-chan chan *protocol.Event) {
|
||||||
|
poll := make([]reflect.SelectCase, 0, 4)
|
||||||
|
|
||||||
|
// add the stop channel and the peer connection channel to our poll queue
|
||||||
|
poll = append(poll, reflect.SelectCase{
|
||||||
|
Dir: reflect.SelectRecv,
|
||||||
|
Chan: reflect.ValueOf(service.stop),
|
||||||
|
})
|
||||||
|
poll = append(poll, reflect.SelectCase{
|
||||||
|
Dir: reflect.SelectRecv,
|
||||||
|
Chan: reflect.ValueOf(eRecv),
|
||||||
|
})
|
||||||
|
|
||||||
|
for {
|
||||||
|
chosen, value, _ := reflect.Select(poll)
|
||||||
|
if chosen == 0 {
|
||||||
|
// stop
|
||||||
|
|
||||||
|
// OnDisconnect handler might need to do something important
|
||||||
|
service.Lock()
|
||||||
|
service.RangePeers(func(peer *protocol.CNPeer, uData interface{}) bool {
|
||||||
|
peer.Kill()
|
||||||
|
service.disconnect(peer)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
service.Unlock()
|
||||||
|
|
||||||
|
// signal we have stopped
|
||||||
|
close(service.stopped)
|
||||||
|
return
|
||||||
|
} else if chosen == 1 {
|
||||||
|
// new peer, add it to our poll queue
|
||||||
|
poll = append(poll, reflect.SelectCase{
|
||||||
|
Dir: reflect.SelectRecv,
|
||||||
|
Chan: reflect.ValueOf(value.Interface()),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// peer event
|
||||||
|
event, ok := value.Interface().(*protocol.Event)
|
||||||
|
if !ok {
|
||||||
|
panic("invalid event type")
|
||||||
}
|
}
|
||||||
|
|
||||||
// the packet buffer is given to us by the event, so we'll need to make sure to return it to the pool
|
service.Lock()
|
||||||
protocol.PutBuffer(event.Pkt)
|
switch event.Type {
|
||||||
|
case protocol.EVENT_CLIENT_DISCONNECT:
|
||||||
|
// strip the peer from our poll queue
|
||||||
|
poll = append(poll[:chosen], poll[chosen+1:]...)
|
||||||
|
service.disconnect(value.Interface().(*protocol.Event).Peer)
|
||||||
|
case protocol.EVENT_CLIENT_CONNECT:
|
||||||
|
service.connect(event.Peer)
|
||||||
|
case protocol.EVENT_CLIENT_PACKET:
|
||||||
|
if err := service.handlePacket(event.Peer, event.PktID, protocol.NewPacket(event.Pkt)); err != nil {
|
||||||
|
log.Printf("Error handling packet: %v", err)
|
||||||
|
event.Peer.Kill()
|
||||||
|
}
|
||||||
|
|
||||||
|
// the packet buffer is given to us by the event, so we'll need to make sure to return it to the pool
|
||||||
|
protocol.PutBuffer(event.Pkt)
|
||||||
|
}
|
||||||
|
service.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *Service) handlePacket(peer *protocol.CNPeer, typeID uint32, pkt protocol.Packet) error {
|
func (service *Service) handlePacket(peer *protocol.CNPeer, typeID uint32, pkt protocol.Packet) error {
|
||||||
uData, _ := service.peers.Load(peer)
|
uData := service.peers[peer]
|
||||||
if hndlr, ok := service.packetHandlers[typeID]; ok {
|
if hndlr, ok := service.packetHandlers[typeID]; ok {
|
||||||
|
// fmt.Printf("Handling packet %x\n", typeID)
|
||||||
if err := hndlr(peer, uData, pkt); err != nil {
|
if err := hndlr(peer, uData, pkt); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -106,12 +251,12 @@ func (service *Service) handlePacket(peer *protocol.CNPeer, typeID uint32, pkt p
|
|||||||
|
|
||||||
func (service *Service) disconnect(peer *protocol.CNPeer) {
|
func (service *Service) disconnect(peer *protocol.CNPeer) {
|
||||||
if service.OnDisconnect != nil {
|
if service.OnDisconnect != nil {
|
||||||
uData, _ := service.peers.Load(peer)
|
uData := service.peers[peer]
|
||||||
service.OnDisconnect(peer, uData)
|
service.OnDisconnect(peer, uData)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Peer %p disconnected from %s\n", peer, service.Name)
|
log.Printf("Peer %p disconnected from %s\n", peer, service.Name)
|
||||||
service.peers.Delete(peer)
|
delete(service.peers, peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *Service) connect(peer *protocol.CNPeer) {
|
func (service *Service) connect(peer *protocol.CNPeer) {
|
||||||
@ -123,16 +268,5 @@ func (service *Service) connect(peer *protocol.CNPeer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("New peer %p connected to %s\n", peer, service.Name)
|
log.Printf("New peer %p connected to %s\n", peer, service.Name)
|
||||||
service.peers.Store(peer, uData)
|
service.SetPeerData(peer, uData)
|
||||||
go peer.Handler()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (service *Service) SetPeerData(peer *protocol.CNPeer, uData interface{}) {
|
|
||||||
service.peers.Store(peer, uData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (service *Service) RangePeers(f func(peer *protocol.CNPeer, uData interface{}) bool) {
|
|
||||||
service.peers.Range(func(key, value any) bool {
|
|
||||||
return f(key.(*protocol.CNPeer), value)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
91
internal/service/service_test.go
Normal file
91
internal/service/service_test.go
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
package service_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/CPunch/gopenfusion/internal/protocol"
|
||||||
|
"github.com/CPunch/gopenfusion/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
srvc *service.Service
|
||||||
|
srvcPort int
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
timeout = 5
|
||||||
|
maxDummyPeers = 5
|
||||||
|
)
|
||||||
|
|
||||||
|
func waitWithTimeout(wg *sync.WaitGroup, seconds int) bool {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
wg.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return true
|
||||||
|
case <-time.After(time.Duration(seconds) * time.Second):
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
var err error
|
||||||
|
srvcPort, err = service.RandomPort()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srvc = service.NewService("TEST", srvcPort)
|
||||||
|
os.Exit(m.Run())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestService(t *testing.T) {
|
||||||
|
// waitgroup to wait for test packet handler to be called
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
|
||||||
|
srvc.AddPacketHandler(0x1234, func(peer *protocol.CNPeer, uData interface{}, pkt protocol.Packet) error {
|
||||||
|
wg.Done()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := srvc.Start(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// wait for service to start
|
||||||
|
<-srvc.Started()
|
||||||
|
wg.Add(maxDummyPeers)
|
||||||
|
for i := 0; i < maxDummyPeers; i++ {
|
||||||
|
go func() {
|
||||||
|
// make dummy client
|
||||||
|
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", srvcPort))
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peer := protocol.NewCNPeer(conn)
|
||||||
|
defer peer.Kill()
|
||||||
|
// send dummy packet
|
||||||
|
if err := peer.Send(0x1234); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !waitWithTimeout(&wg, timeout) {
|
||||||
|
t.Error("timeout waiting for packet handler to be called")
|
||||||
|
}
|
||||||
|
srvc.Stop()
|
||||||
|
<-srvc.Stopped()
|
||||||
|
}
|
@ -2,6 +2,7 @@ package login
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
@ -79,7 +80,7 @@ func (server *LoginServer) Login(peer *protocol.CNPeer, _account interface{}, pk
|
|||||||
|
|
||||||
// attempt login
|
// attempt login
|
||||||
account, err := server.dbHndlr.TryLogin(loginPkt.SzID, loginPkt.SzPassword)
|
account, err := server.dbHndlr.TryLogin(loginPkt.SzID, loginPkt.SzPassword)
|
||||||
if err == db.ErrLoginInvalidID {
|
if errors.Is(err, db.ErrLoginInvalidID) {
|
||||||
// this is the default behavior, auto create the account if the ID isn't in use
|
// this is the default behavior, auto create the account if the ID isn't in use
|
||||||
account, err = server.dbHndlr.NewAccount(loginPkt.SzID, loginPkt.SzPassword)
|
account, err = server.dbHndlr.NewAccount(loginPkt.SzID, loginPkt.SzPassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -87,7 +88,7 @@ func (server *LoginServer) Login(peer *protocol.CNPeer, _account interface{}, pk
|
|||||||
SendError(LOGIN_DATABASE_ERROR)
|
SendError(LOGIN_DATABASE_ERROR)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else if err == db.ErrLoginInvalidPassword {
|
} else if errors.Is(err, db.ErrLoginInvalidPassword) {
|
||||||
// respond with invalid password
|
// respond with invalid password
|
||||||
SendError(LOGIN_ID_AND_PASSWORD_DO_NOT_MATCH)
|
SendError(LOGIN_ID_AND_PASSWORD_DO_NOT_MATCH)
|
||||||
return nil
|
return nil
|
||||||
|
@ -14,10 +14,7 @@ type LoginServer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewLoginServer(dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port int) (*LoginServer, error) {
|
func NewLoginServer(dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port int) (*LoginServer, error) {
|
||||||
srvc, err := service.NewService("LOGIN", port)
|
srvc := service.NewService("LOGIN", port)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
server := &LoginServer{
|
server := &LoginServer{
|
||||||
service: srvc,
|
service: srvc,
|
||||||
@ -47,6 +44,10 @@ func NewLoginServer(dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port
|
|||||||
return server, nil
|
return server, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *LoginServer) Start() {
|
func (server *LoginServer) Start() error {
|
||||||
server.service.Start()
|
return server.service.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (server *LoginServer) Stop() {
|
||||||
|
server.service.Stop()
|
||||||
|
}
|
@ -19,10 +19,7 @@ type ShardServer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewShardServer(dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port int) (*ShardServer, error) {
|
func NewShardServer(dbHndlr *db.DBHandler, redisHndlr *redis.RedisHandler, port int) (*ShardServer, error) {
|
||||||
srvc, err := service.NewService("SHARD", port)
|
srvc := service.NewService("SHARD", port)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
server := &ShardServer{
|
server := &ShardServer{
|
||||||
service: srvc,
|
service: srvc,
|
||||||
|
Loading…
Reference in New Issue
Block a user