diff --git a/go.mod b/go.mod index c3a01f2..25b0368 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/georgysavva/scany/v2 v2.0.0 github.com/google/subcommands v1.2.0 github.com/lib/pq v1.10.9 + github.com/matryer/is v1.4.1 github.com/redis/go-redis/v9 v9.0.5 golang.org/x/crypto v0.7.0 ) diff --git a/go.sum b/go.sum index 5525e7d..c590978 100644 --- a/go.sum +++ b/go.sum @@ -48,6 +48,8 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/matryer/is v1.4.1 h1:55ehd8zaGABKLXQUe2awZ99BD/PTc2ls+KV/dXphgEQ= +github.com/matryer/is v1.4.1/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU= github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= diff --git a/internal/db/db_test.go b/internal/db/db_test.go index 76b0a8c..a85b37c 100644 --- a/internal/db/db_test.go +++ b/internal/db/db_test.go @@ -6,6 +6,8 @@ import ( "os" "testing" + "github.com/matryer/is" + "github.com/CPunch/gopenfusion/internal/db" "github.com/CPunch/gopenfusion/internal/protocol" "github.com/bitcomplete/sqltestutil" @@ -37,23 +39,20 @@ func TestMain(m *testing.M) { } func TestDBAccount(t *testing.T) { - if _, err := testDB.NewAccount("test", "test"); err != nil { - t.Error(err) - } + is := is.New(t) + + // create new account + _, err := testDB.NewAccount("test", "test") + is.NoErr(err) // now try to retrieve account data acc, err := testDB.TryLogin("test", "test") - if err != nil { - t.Error(err) - } + is.NoErr(err) - if acc.Login != "test" { - t.Error("account username is not test") - } + _, err = testDB.TryLogin("test", "wrongpassword") - if _, err = testDB.TryLogin("test", "wrongpassword"); !errors.Is(err, db.ErrLoginInvalidPassword) { - t.Error("expected ErrLoginInvalidPassword") - } + is.True(acc.Login == "test") // login data should match created account + is.True(errors.Is(err, db.ErrLoginInvalidPassword)) // wrong password passed to TryLogin() should fail with ErrLoginInvalidPassword } /* @@ -81,22 +80,18 @@ gopenfusion=# SELECT * FROM Inventory; */ func TestDBPlayer(t *testing.T) { - if _, err := testDB.NewAccount("testplayer", "test"); err != nil { - t.Error(err) - } + is := is.New(t) + _, err := testDB.NewAccount("testplayer", "test") + is.NoErr(err) // now try to retrieve account data acc, err := testDB.TryLogin("testplayer", "test") - if err != nil { - t.Error(err) - } + is.NoErr(err) plrID, err := testDB.NewPlayer(acc.AccountID, "Neil", "Mcscout", 1) - if err != nil { - t.Error(err) - } + is.NoErr(err) - if err = testDB.FinishPlayer(&protocol.SP_CL2LS_REQ_CHAR_CREATE{ + err = testDB.FinishPlayer(&protocol.SP_CL2LS_REQ_CHAR_CREATE{ PCStyle: protocol.SPCStyle{ IPC_UID: int64(plrID), INameCheck: 1, @@ -116,11 +111,9 @@ func TestDBPlayer(t *testing.T) { IEquipLBID: 359, IEquipFootID: 194, }, - }, acc.AccountID); err != nil { - t.Error(err) - } + }, acc.AccountID) + is.NoErr(err) - if err = testDB.FinishTutorial(plrID, acc.AccountID); err != nil { - t.Error(err) - } + err = testDB.FinishTutorial(plrID, acc.AccountID) + is.NoErr(err) } diff --git a/internal/entity/entity_test.go b/internal/entity/entity_test.go index b043b11..05255b7 100644 --- a/internal/entity/entity_test.go +++ b/internal/entity/entity_test.go @@ -4,9 +4,11 @@ import ( "testing" "github.com/CPunch/gopenfusion/internal/entity" + "github.com/matryer/is" ) func TestChunkSliceDifference(t *testing.T) { + is := is.New(t) chunks := []*entity.Chunk{ entity.NewChunk(entity.MakeChunkPosition(0, 0)), entity.NewChunk(entity.MakeChunkPosition(0, 1)), @@ -28,13 +30,7 @@ func TestChunkSliceDifference(t *testing.T) { } diff := entity.ChunkSliceDifference(c1, c2) - if len(diff) != 1 { - t.Logf("%+v", diff) - t.Error("expected 1 chunk in difference") - } - if diff[0] != chunks[3] { - t.Logf("%+v", diff) - t.Error("wrong difference") - } + is.True(len(diff) == 1) // should be 1 chunk in difference + is.True(diff[0] == chunks[3]) // should be chunks[3] in difference } diff --git a/internal/protocol/protocol_test.go b/internal/protocol/protocol_test.go index 54c1ad8..92e8900 100644 --- a/internal/protocol/protocol_test.go +++ b/internal/protocol/protocol_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/CPunch/gopenfusion/internal/protocol" + "github.com/matryer/is" ) type TestPacketData struct { @@ -53,59 +54,51 @@ var ( ) func TestPacketEncode(t *testing.T) { + is := is.New(t) buf := bytes.NewBuffer(nil) pkt := protocol.NewPacket(buf) - if err := pkt.Encode(testStruct); err != nil { - t.Error(err) - } + err := pkt.Encode(testStruct) + is.NoErr(err) - if !bytes.Equal(buf.Bytes(), testData[:]) { - t.Error("packet data does not match!") - } + is.True(bytes.Equal(buf.Bytes(), testData[:])) // encoded data should match expected data } func TestPacketDecode(t *testing.T) { + is := is.New(t) buf := bytes.NewBuffer(nil) pkt := protocol.NewPacket(buf) buf.Write(testData[:]) var test TestPacketData - if err := pkt.Decode(&test); err != nil { - t.Error(err) - } - - if test != testStruct { - t.Error("packet data does not match!") - } + err := pkt.Decode(&test) + is.NoErr(err) + is.True(test == testStruct) // decoded data should match testStruct } func TestDataEncrypt(t *testing.T) { + is := is.New(t) buf := make([]byte, len(testData)) copy(buf, testData[:]) protocol.EncryptData(buf, []byte(protocol.DEFAULT_KEY)) - if !bytes.Equal(buf, encTestData) { - t.Error("encrypted data does not match!") - } + is.True(bytes.Equal(buf, encTestData)) // encrypted data should match expected data } func TestDataDecrypt(t *testing.T) { + is := is.New(t) buf := make([]byte, len(encTestData)) copy(buf, encTestData) protocol.DecryptData(buf, []byte(protocol.DEFAULT_KEY)) - if !bytes.Equal(buf, testData[:]) { - t.Error("decrypted data does not match!") - } + is.True(bytes.Equal(buf, testData[:])) // decrypted data should match expected data } func TestCreateNewKey(t *testing.T) { + is := is.New(t) key := protocol.CreateNewKey(123456789, 0x1234567890abcdef, 0x1234567890abcdef) - if !bytes.Equal(key, []byte{0x0, 0x31, 0xb8, 0xcd, 0xd, 0xc3, 0xad, 0x67}) { - t.Error("key does not match!") - } + is.True(bytes.Equal(key, []byte{0x0, 0x31, 0xb8, 0xcd, 0xd, 0xc3, 0xad, 0x67})) // key should match expected data } diff --git a/internal/service/service_test.go b/internal/service/service_test.go index c8c8128..b7e086d 100644 --- a/internal/service/service_test.go +++ b/internal/service/service_test.go @@ -12,6 +12,7 @@ import ( "github.com/CPunch/gopenfusion/internal/protocol" "github.com/CPunch/gopenfusion/internal/service" + "github.com/matryer/is" ) var ( @@ -23,6 +24,15 @@ const ( maxDummyPeers = 5 ) +func selectWithTimeout(ch <-chan struct{}, seconds int) bool { + select { + case <-ch: + return true + case <-time.After(time.Duration(seconds) * time.Second): + return false + } +} + func waitWithTimeout(wg *sync.WaitGroup, seconds int) bool { done := make(chan struct{}) go func() { @@ -30,12 +40,7 @@ func waitWithTimeout(wg *sync.WaitGroup, seconds int) bool { wg.Wait() }() - select { - case <-done: - return true - case <-time.After(time.Duration(seconds) * time.Second): - return false - } + return selectWithTimeout(done, seconds) } func TestMain(m *testing.M) { @@ -49,34 +54,38 @@ func TestMain(m *testing.M) { } func TestService(t *testing.T) { + is := is.New(t) ctx, cancel := context.WithCancel(context.Background()) srvc := service.NewService(ctx, "TEST", srvcPort) - - // waitgroup to wait for test packet handler to be called wg := sync.WaitGroup{} + // shutdown service when test is done + defer func() { + cancel() + is.True(selectWithTimeout(srvc.Stopped(), timeout)) // wait for service to stop with timeout + }() + + // our dummy packet handler srvc.AddPacketHandler(0x1234, func(peer *protocol.CNPeer, pkt protocol.Packet) error { log.Printf("Received packet %#v", pkt) wg.Done() return nil }) + // run service go func() { - if err := srvc.Start(); err != nil { - t.Error(err) - } + err := srvc.Start() + is.NoErr(err) // srvc.Start error }() - // wait for service to start - <-srvc.Started() + is.True(selectWithTimeout(srvc.Started(), timeout)) // wait for service to start with timeout + wg.Add(maxDummyPeers * 3) // 2 wg.Done() calls per dummy peer for i := 0; i < maxDummyPeers; i++ { go func() { // make dummy client conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", srvcPort)) - if err != nil { - t.Error(err) - } + is.NoErr(err) // net.Dial error peer := protocol.NewCNPeer(ctx, conn) go func() { @@ -84,9 +93,8 @@ func TestService(t *testing.T) { // send dummy packets for i := 0; i < 2; i++ { - if err := peer.Send(0x1234); err != nil { - t.Error(err) - } + err := peer.Send(0x1234) + is.NoErr(err) // peer.Send error } }() @@ -96,10 +104,5 @@ func TestService(t *testing.T) { }() } - if !waitWithTimeout(&wg, timeout) { - t.Error("failed to wait for packet handler to be called") - } - - cancel() - <-srvc.Stopped() + is.True(waitWithTimeout(&wg, timeout)) // wait for all dummy peers to be done with timeout }