diff --git a/cnet/service_test.go b/cnet/service_test.go index ee85705..705e555 100644 --- a/cnet/service_test.go +++ b/cnet/service_test.go @@ -8,10 +8,10 @@ import ( "os" "sync" "testing" - "time" "github.com/CPunch/gopenfusion/cnet" "github.com/CPunch/gopenfusion/cnet/protocol" + "github.com/CPunch/gopenfusion/util" "github.com/matryer/is" ) @@ -24,25 +24,6 @@ const ( maxDummyPeers = 5 ) -func selectWithTimeout(ch <-chan struct{}, seconds int) bool { - select { - case <-ch: - return true - case <-time.After(time.Duration(seconds) * time.Second): - return false - } -} - -func waitWithTimeout(wg *sync.WaitGroup, seconds int) bool { - done := make(chan struct{}) - go func() { - defer close(done) - wg.Wait() - }() - - return selectWithTimeout(done, seconds) -} - func TestMain(m *testing.M) { var err error srvcPort, err = cnet.RandomPort() @@ -62,7 +43,7 @@ func TestService(t *testing.T) { // shutdown service when test is done defer func() { cancel() - is.True(selectWithTimeout(srvc.Stopped(), timeout)) // wait for service to stop with timeout + is.True(util.SelectWithTimeout(srvc.Stopped(), timeout)) // wait for service to stop with timeout }() // our dummy packet handler @@ -85,8 +66,8 @@ func TestService(t *testing.T) { } // run service - go func() { is.NoErr(srvc.Start()) }() // srvc.Start error - is.True(selectWithTimeout(srvc.Started(), timeout)) // wait for service to start with timeout + go func() { is.NoErr(srvc.Start()) }() // srvc.Start error + is.True(util.SelectWithTimeout(srvc.Started(), timeout)) // wait for service to start with timeout wg.Add(maxDummyPeers * 2) // 2 wg.Done() per peer for receiving packets for i := 0; i < maxDummyPeers; i++ { @@ -111,5 +92,5 @@ func TestService(t *testing.T) { }() } - is.True(waitWithTimeout(&wg, timeout)) // wait for all dummy peers to be done with timeout + is.True(util.WaitWithTimeout(&wg, timeout)) // wait for all dummy peers to be done with timeout } diff --git a/util/util.go b/util/util.go index 24d000f..4208723 100644 --- a/util/util.go +++ b/util/util.go @@ -1,7 +1,29 @@ package util -import "time" +import ( + "sync" + "time" +) func GetTime() uint64 { return uint64(time.Now().UnixMilli()) } + +func SelectWithTimeout(ch <-chan struct{}, seconds int) bool { + select { + case <-ch: + return true + case <-time.After(time.Duration(seconds) * time.Second): + return false + } +} + +func WaitWithTimeout(wg *sync.WaitGroup, seconds int) bool { + done := make(chan struct{}) + go func() { + defer close(done) + wg.Wait() + }() + + return SelectWithTimeout(done, seconds) +}