mirror of
https://github.com/CPunch/gopenfusion.git
synced 2025-11-10 09:30:07 +00:00
more protocol/service refactor
- removed protocol.Event: CNPeers now send protocol.PacketEvents - peer uData is held in CNPeer, use SetUserData() and UserData() to set/read it - Service.PacketHandler calback has changed, removed uData: switched calls to peer.SetUserData() and peer.UserData() where appropriate - service.Service lots of tidying up, removed dependence on old protocol.Event. - service.Service && protocol.CNPeer now accept a cancelable context. hooray graceful shutdowns and unit tests! - general cleanup
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -13,9 +14,9 @@ import (
|
||||
"github.com/CPunch/gopenfusion/internal/protocol"
|
||||
)
|
||||
|
||||
type PacketHandler func(peer *protocol.CNPeer, uData interface{}, pkt protocol.Packet) error
|
||||
type PacketHandler func(peer *protocol.CNPeer, pkt protocol.Packet) error
|
||||
|
||||
func StubbedPacket(_ *protocol.CNPeer, _ interface{}, _ protocol.Packet) error {
|
||||
func StubbedPacket(_ *protocol.CNPeer, _ protocol.Packet) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -23,22 +24,22 @@ type Service struct {
|
||||
listener net.Listener
|
||||
port int
|
||||
Name string
|
||||
stop chan struct{} // tell active handleEvents() to stop
|
||||
stopped chan struct{}
|
||||
ctx context.Context
|
||||
started chan struct{}
|
||||
stopped chan struct{}
|
||||
packetHandlers map[uint32]PacketHandler
|
||||
peers map[*protocol.CNPeer]interface{}
|
||||
peers map[chan *protocol.PacketEvent]*protocol.CNPeer
|
||||
stateLock sync.Mutex
|
||||
|
||||
// OnDisconnect is called when a peer disconnects from the service.
|
||||
// uData is the stored value of the key/value pair in the peer map.
|
||||
// It may not be set while the service is running. (eg. srvc.Start() has been called)
|
||||
OnDisconnect func(peer *protocol.CNPeer, uData interface{})
|
||||
OnDisconnect func(peer *protocol.CNPeer)
|
||||
|
||||
// OnConnect is called when a peer connects to the service.
|
||||
// return value is used as the value in the peer map.
|
||||
// It may not be set while the service is running. (eg. srvc.Start() has been called)
|
||||
OnConnect func(peer *protocol.CNPeer) (uData interface{})
|
||||
OnConnect func(peer *protocol.CNPeer)
|
||||
}
|
||||
|
||||
func RandomPort() (int, error) {
|
||||
@@ -55,44 +56,52 @@ func RandomPort() (int, error) {
|
||||
return strconv.Atoi(port)
|
||||
}
|
||||
|
||||
func NewService(name string, port int) *Service {
|
||||
func NewService(ctx context.Context, name string, port int) *Service {
|
||||
srvc := &Service{
|
||||
port: port,
|
||||
Name: name,
|
||||
}
|
||||
|
||||
srvc.Reset()
|
||||
srvc.Reset(ctx)
|
||||
return srvc
|
||||
}
|
||||
|
||||
func (service *Service) Reset() {
|
||||
service.packetHandlers = make(map[uint32]PacketHandler)
|
||||
service.peers = make(map[*protocol.CNPeer]interface{})
|
||||
service.started = make(chan struct{})
|
||||
func (srvc *Service) Reset(ctx context.Context) {
|
||||
srvc.ctx = ctx
|
||||
srvc.packetHandlers = make(map[uint32]PacketHandler)
|
||||
srvc.peers = make(map[chan *protocol.PacketEvent]*protocol.CNPeer)
|
||||
srvc.started = make(chan struct{})
|
||||
srvc.stopped = make(chan struct{})
|
||||
}
|
||||
|
||||
// may not be called while the service is running (eg. srvc.Start() has been called)
|
||||
func (service *Service) AddPacketHandler(pktID uint32, handler PacketHandler) {
|
||||
service.packetHandlers[pktID] = handler
|
||||
func (srvc *Service) AddPacketHandler(pktID uint32, handler PacketHandler) {
|
||||
srvc.packetHandlers[pktID] = handler
|
||||
}
|
||||
|
||||
func (service *Service) Start() error {
|
||||
service.stop = make(chan struct{})
|
||||
service.stopped = make(chan struct{})
|
||||
peerConnections := make(chan chan *protocol.Event)
|
||||
go service.handleEvents(peerConnections)
|
||||
type newPeerConnection struct {
|
||||
peer *protocol.CNPeer
|
||||
channel chan *protocol.PacketEvent
|
||||
}
|
||||
|
||||
func (srvc *Service) Start() error {
|
||||
peerConnections := make(chan newPeerConnection)
|
||||
defer close(peerConnections)
|
||||
go srvc.handleEvents(peerConnections)
|
||||
|
||||
// open listener socket
|
||||
var err error
|
||||
service.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", service.port))
|
||||
srvc.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", srvc.port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer srvc.listener.Close()
|
||||
|
||||
close(service.started) // signal that the service has started
|
||||
log.Printf("%s service hosted on %s:%d\n", service.Name, config.GetAnnounceIP(), service.port)
|
||||
log.Printf("%s service hosted on %s:%d\n", srvc.Name, config.GetAnnounceIP(), srvc.port)
|
||||
|
||||
close(srvc.started) // signal that the service has started
|
||||
for {
|
||||
conn, err := service.listener.Accept()
|
||||
conn, err := srvc.listener.Accept()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
// we expect this to happen when the service is stopped
|
||||
@@ -103,143 +112,148 @@ func (service *Service) Start() error {
|
||||
}
|
||||
|
||||
// create a new peer and pass it to the event loop
|
||||
eRecv := make(chan *protocol.Event)
|
||||
peer := protocol.NewCNPeer(conn)
|
||||
log.Printf("New peer %p connected to %s\n", peer, service.Name)
|
||||
peerConnections <- eRecv
|
||||
peer := protocol.NewCNPeer(srvc.ctx, conn)
|
||||
eRecv := make(chan *protocol.PacketEvent)
|
||||
peerConnections <- newPeerConnection{channel: eRecv, peer: peer}
|
||||
go peer.Handler(eRecv)
|
||||
}
|
||||
}
|
||||
|
||||
func (srvc *Service) getPeer(channel chan *protocol.PacketEvent) *protocol.CNPeer {
|
||||
return srvc.peers[channel]
|
||||
}
|
||||
|
||||
func (srvc *Service) setPeer(channel chan *protocol.PacketEvent, peer *protocol.CNPeer) {
|
||||
srvc.peers[channel] = peer
|
||||
}
|
||||
|
||||
func (srvc *Service) removePeer(channel chan *protocol.PacketEvent) {
|
||||
delete(srvc.peers, channel)
|
||||
}
|
||||
|
||||
// returns a channel that is closed when the service has started.
|
||||
// this is useful if you need to do something after the service has started.
|
||||
func (service *Service) Started() <-chan struct{} {
|
||||
return service.started
|
||||
// this is useful if you need to wait until after the service has started.
|
||||
func (srvc *Service) Started() <-chan struct{} {
|
||||
return srvc.started
|
||||
}
|
||||
|
||||
// returns a channel that is closed when the service has stopped.
|
||||
// this is useful if you need to wait until the service has completely stopped.
|
||||
func (service *Service) Stopped() <-chan struct{} {
|
||||
return service.stopped
|
||||
}
|
||||
|
||||
// stops the service and disconnects all peers. OnDisconnect will be called
|
||||
// for each peer.
|
||||
func (service *Service) Stop() {
|
||||
close(service.stop)
|
||||
|
||||
service.listener.Close()
|
||||
}
|
||||
|
||||
// returns the stored uData for the peer.
|
||||
// if the peer does not exist, nil is returned.
|
||||
// NOTE: the peer map is not locked while accessing, if you're calling this
|
||||
// outside of the service's event loop, you'll need to lock the peer map yourself.
|
||||
func (service *Service) GetPeerData(peer *protocol.CNPeer) interface{} {
|
||||
return service.peers[peer]
|
||||
}
|
||||
|
||||
// sets the stored uData for the peer.
|
||||
// NOTE: the peer map is not locked while accessing, if you're calling this
|
||||
// outside of the service's event loop, you'll need to lock the peer map yourself.
|
||||
func (service *Service) SetPeerData(peer *protocol.CNPeer, uData interface{}) {
|
||||
service.peers[peer] = uData
|
||||
// this is useful if you need wait until after the service has stopped.
|
||||
func (srvc *Service) Stopped() <-chan struct{} {
|
||||
return srvc.stopped
|
||||
}
|
||||
|
||||
// calls f for each peer in the service passing the peer and the stored uData.
|
||||
// if f returns false, the iteration is stopped.
|
||||
// NOTE: the peer map is not locked while iterating, if you're calling this
|
||||
// outside of the service's event loop, you'll need to lock the peer map yourself.
|
||||
func (service *Service) RangePeers(f func(peer *protocol.CNPeer, uData interface{}) bool) {
|
||||
for peer, uData := range service.peers {
|
||||
if !f(peer, uData) {
|
||||
func (srvc *Service) RangePeers(f func(peer *protocol.CNPeer) bool) {
|
||||
for _, peer := range srvc.peers {
|
||||
if !f(peer) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// locks the peer map.
|
||||
func (service *Service) Lock() {
|
||||
service.stateLock.Lock()
|
||||
func (srvc *Service) Lock() {
|
||||
srvc.stateLock.Lock()
|
||||
}
|
||||
|
||||
// unlocks the peer map.
|
||||
func (service *Service) Unlock() {
|
||||
service.stateLock.Unlock()
|
||||
func (srvc *Service) Unlock() {
|
||||
srvc.stateLock.Unlock()
|
||||
}
|
||||
|
||||
func (srvc *Service) stop() {
|
||||
// OnDisconnect handler might need to do something important
|
||||
srvc.RangePeers(func(peer *protocol.CNPeer) bool {
|
||||
peer.Kill()
|
||||
if srvc.OnDisconnect != nil {
|
||||
srvc.OnDisconnect(peer)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
log.Printf("%s service stopped\n", srvc.Name)
|
||||
close(srvc.stopped)
|
||||
}
|
||||
|
||||
// handleEvents is the main event loop for the service.
|
||||
// it handles all events from the peers and calls the appropriate handlers.
|
||||
func (service *Service) handleEvents(eRecv <-chan chan *protocol.Event) {
|
||||
func (srvc *Service) handleEvents(peerPipe <-chan newPeerConnection) {
|
||||
defer srvc.stop()
|
||||
|
||||
poll := make([]reflect.SelectCase, 0, 4)
|
||||
|
||||
// add the stop channel and the peer connection channel to our poll queue
|
||||
poll = append(poll, reflect.SelectCase{
|
||||
Dir: reflect.SelectRecv,
|
||||
Chan: reflect.ValueOf(service.stop),
|
||||
Chan: reflect.ValueOf(srvc.ctx.Done()),
|
||||
})
|
||||
poll = append(poll, reflect.SelectCase{
|
||||
Dir: reflect.SelectRecv,
|
||||
Chan: reflect.ValueOf(eRecv),
|
||||
Chan: reflect.ValueOf(peerPipe),
|
||||
})
|
||||
|
||||
addPoll := func(channel chan *protocol.PacketEvent) {
|
||||
poll = append(poll, reflect.SelectCase{
|
||||
Dir: reflect.SelectRecv,
|
||||
Chan: reflect.ValueOf(channel),
|
||||
})
|
||||
}
|
||||
|
||||
removePoll := func(index int) {
|
||||
poll = append(poll[:index], poll[index+1:]...)
|
||||
}
|
||||
|
||||
for {
|
||||
chosen, value, _ := reflect.Select(poll)
|
||||
if chosen == 0 {
|
||||
// stop
|
||||
|
||||
// OnDisconnect handler might need to do something important
|
||||
service.Lock()
|
||||
service.RangePeers(func(peer *protocol.CNPeer, uData interface{}) bool {
|
||||
peer.Kill()
|
||||
service.disconnect(peer)
|
||||
return true
|
||||
})
|
||||
service.Unlock()
|
||||
|
||||
// signal we have stopped
|
||||
close(service.stopped)
|
||||
chosen, value, recvOK := reflect.Select(poll)
|
||||
switch chosen {
|
||||
case 0: // cancel signal received, stop the service
|
||||
return
|
||||
} else if chosen == 1 {
|
||||
// new peer, add it to our poll queue
|
||||
poll = append(poll, reflect.SelectCase{
|
||||
Dir: reflect.SelectRecv,
|
||||
Chan: reflect.ValueOf(value.Interface()),
|
||||
})
|
||||
} else {
|
||||
// peer event
|
||||
event, ok := value.Interface().(*protocol.Event)
|
||||
if !ok {
|
||||
panic("invalid event type")
|
||||
case 1: // new peer, add it to our poll queue
|
||||
if !recvOK {
|
||||
return
|
||||
}
|
||||
|
||||
service.Lock()
|
||||
switch event.Type {
|
||||
case protocol.EVENT_CLIENT_DISCONNECT:
|
||||
// strip the peer from our poll queue
|
||||
poll = append(poll[:chosen], poll[chosen+1:]...)
|
||||
service.disconnect(value.Interface().(*protocol.Event).Peer)
|
||||
case protocol.EVENT_CLIENT_CONNECT:
|
||||
service.connect(event.Peer)
|
||||
case protocol.EVENT_CLIENT_PACKET:
|
||||
if err := service.handlePacket(event.Peer, event.PktID, protocol.NewPacket(event.Pkt)); err != nil {
|
||||
log.Printf("Error handling packet: %v", err)
|
||||
event.Peer.Kill()
|
||||
}
|
||||
|
||||
// the packet buffer is given to us by the event, so we'll need to make sure to return it to the pool
|
||||
protocol.PutBuffer(event.Pkt)
|
||||
evnt := value.Interface().(newPeerConnection)
|
||||
addPoll(evnt.channel)
|
||||
srvc.connect(evnt.channel, evnt.peer)
|
||||
default: // peer event
|
||||
channel := poll[chosen].Chan.Interface().(chan *protocol.PacketEvent)
|
||||
peer := srvc.getPeer(channel)
|
||||
if peer == nil {
|
||||
log.Printf("Unknown peer event: %v", value)
|
||||
removePoll(chosen)
|
||||
continue
|
||||
}
|
||||
service.Unlock()
|
||||
|
||||
evnt, ok := value.Interface().(*protocol.PacketEvent)
|
||||
if !recvOK || !ok || evnt == nil {
|
||||
// peer disconnected, remove it from our poll queue
|
||||
removePoll(chosen)
|
||||
srvc.disconnect(channel, peer)
|
||||
continue
|
||||
}
|
||||
|
||||
srvc.Lock()
|
||||
if err := srvc.handlePacket(peer, evnt.PktID, protocol.NewPacket(evnt.Pkt)); err != nil {
|
||||
log.Printf("Error handling packet: %v", err)
|
||||
peer.Kill()
|
||||
}
|
||||
srvc.Unlock()
|
||||
|
||||
// the packet buffer is given to us by the event, so we'll need to make sure to return it to the pool
|
||||
protocol.PutBuffer(evnt.Pkt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (service *Service) handlePacket(peer *protocol.CNPeer, typeID uint32, pkt protocol.Packet) error {
|
||||
uData := service.peers[peer]
|
||||
if hndlr, ok := service.packetHandlers[typeID]; ok {
|
||||
func (srvc *Service) handlePacket(peer *protocol.CNPeer, typeID uint32, pkt protocol.Packet) error {
|
||||
if hndlr, ok := srvc.packetHandlers[typeID]; ok {
|
||||
// fmt.Printf("Handling packet %x\n", typeID)
|
||||
if err := hndlr(peer, uData, pkt); err != nil {
|
||||
if err := hndlr(peer, pkt); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
@@ -249,24 +263,20 @@ func (service *Service) handlePacket(peer *protocol.CNPeer, typeID uint32, pkt p
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service *Service) disconnect(peer *protocol.CNPeer) {
|
||||
if service.OnDisconnect != nil {
|
||||
uData := service.peers[peer]
|
||||
service.OnDisconnect(peer, uData)
|
||||
func (srvc *Service) disconnect(channel chan *protocol.PacketEvent, peer *protocol.CNPeer) {
|
||||
log.Printf("Peer %p disconnected from %s\n", peer, srvc.Name)
|
||||
if srvc.OnDisconnect != nil {
|
||||
srvc.OnDisconnect(peer)
|
||||
}
|
||||
|
||||
log.Printf("Peer %p disconnected from %s\n", peer, service.Name)
|
||||
delete(service.peers, peer)
|
||||
srvc.removePeer(channel)
|
||||
}
|
||||
|
||||
func (service *Service) connect(peer *protocol.CNPeer) {
|
||||
// default uData to nil, but if the service has an OnConnect
|
||||
// handler, use the result from that
|
||||
uData := interface{}(nil)
|
||||
if service.OnConnect != nil {
|
||||
uData = service.OnConnect(peer)
|
||||
func (srvc *Service) connect(channel chan *protocol.PacketEvent, peer *protocol.CNPeer) {
|
||||
log.Printf("New peer %p connected to %s\n", peer, srvc.Name)
|
||||
if srvc.OnConnect != nil {
|
||||
srvc.OnConnect(peer)
|
||||
}
|
||||
|
||||
log.Printf("New peer %p connected to %s\n", peer, service.Name)
|
||||
service.SetPeerData(peer, uData)
|
||||
srvc.setPeer(channel, peer)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package service_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -13,12 +15,11 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
srvc *service.Service
|
||||
srvcPort int
|
||||
)
|
||||
|
||||
const (
|
||||
timeout = 5
|
||||
timeout = 2
|
||||
maxDummyPeers = 5
|
||||
)
|
||||
|
||||
@@ -44,15 +45,18 @@ func TestMain(m *testing.M) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
srvc = service.NewService("TEST", srvcPort)
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestService(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
srvc := service.NewService(ctx, "TEST", srvcPort)
|
||||
|
||||
// waitgroup to wait for test packet handler to be called
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
srvc.AddPacketHandler(0x1234, func(peer *protocol.CNPeer, uData interface{}, pkt protocol.Packet) error {
|
||||
srvc.AddPacketHandler(0x1234, func(peer *protocol.CNPeer, pkt protocol.Packet) error {
|
||||
log.Printf("Received packet %#v", pkt)
|
||||
wg.Done()
|
||||
return nil
|
||||
})
|
||||
@@ -65,7 +69,7 @@ func TestService(t *testing.T) {
|
||||
|
||||
// wait for service to start
|
||||
<-srvc.Started()
|
||||
wg.Add(maxDummyPeers)
|
||||
wg.Add(maxDummyPeers * 3) // 2 wg.Done() calls per dummy peer
|
||||
for i := 0; i < maxDummyPeers; i++ {
|
||||
go func() {
|
||||
// make dummy client
|
||||
@@ -74,18 +78,28 @@ func TestService(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
peer := protocol.NewCNPeer(conn)
|
||||
defer peer.Kill()
|
||||
// send dummy packet
|
||||
if err := peer.Send(0x1234); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
peer := protocol.NewCNPeer(ctx, conn)
|
||||
go func() {
|
||||
defer peer.Kill()
|
||||
|
||||
// send dummy packets
|
||||
for i := 0; i < 2; i++ {
|
||||
if err := peer.Send(0x1234); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// we wait until Handler gracefully exits (peer was killed)
|
||||
peer.Handler(make(chan *protocol.PacketEvent))
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
if !waitWithTimeout(&wg, timeout) {
|
||||
t.Error("timeout waiting for packet handler to be called")
|
||||
t.Error("failed to wait for packet handler to be called")
|
||||
}
|
||||
srvc.Stop()
|
||||
|
||||
cancel()
|
||||
<-srvc.Stopped()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user