Common: Polish Fiber class, add comments, asserts and more tests.
This commit is contained in:
		| @@ -2,6 +2,7 @@ | ||||
| // Licensed under GPLv2 or any later version | ||||
| // Refer to the license.txt file included. | ||||
|  | ||||
| #include "common/assert.h" | ||||
| #include "common/fiber.h" | ||||
| #ifdef _MSC_VER | ||||
| #include <windows.h> | ||||
| @@ -18,11 +19,11 @@ struct Fiber::FiberImpl { | ||||
| }; | ||||
|  | ||||
| void Fiber::start() { | ||||
|     if (previous_fiber) { | ||||
|         previous_fiber->guard.unlock(); | ||||
|         previous_fiber = nullptr; | ||||
|     } | ||||
|     ASSERT(previous_fiber != nullptr); | ||||
|     previous_fiber->guard.unlock(); | ||||
|     previous_fiber.reset(); | ||||
|     entry_point(start_parameter); | ||||
|     UNREACHABLE(); | ||||
| } | ||||
|  | ||||
| void __stdcall Fiber::FiberStartFunc(void* fiber_parameter) | ||||
| @@ -43,12 +44,16 @@ Fiber::Fiber() : guard{}, entry_point{}, start_parameter{}, previous_fiber{} { | ||||
|  | ||||
| Fiber::~Fiber() { | ||||
|     // Make sure the Fiber is not being used | ||||
|     guard.lock(); | ||||
|     guard.unlock(); | ||||
|     bool locked = guard.try_lock(); | ||||
|     ASSERT_MSG(locked, "Destroying a fiber that's still running"); | ||||
|     if (locked) { | ||||
|         guard.unlock(); | ||||
|     } | ||||
|     DeleteFiber(impl->handle); | ||||
| } | ||||
|  | ||||
| void Fiber::Exit() { | ||||
|     ASSERT_MSG(is_thread_fiber, "Exitting non main thread fiber"); | ||||
|     if (!is_thread_fiber) { | ||||
|         return; | ||||
|     } | ||||
| @@ -57,14 +62,15 @@ void Fiber::Exit() { | ||||
| } | ||||
|  | ||||
| void Fiber::YieldTo(std::shared_ptr<Fiber> from, std::shared_ptr<Fiber> to) { | ||||
|     ASSERT_MSG(from != nullptr, "Yielding fiber is null!"); | ||||
|     ASSERT_MSG(to != nullptr, "Next fiber is null!"); | ||||
|     to->guard.lock(); | ||||
|     to->previous_fiber = from; | ||||
|     SwitchToFiber(to->impl->handle); | ||||
|     auto previous_fiber = from->previous_fiber; | ||||
|     if (previous_fiber) { | ||||
|         previous_fiber->guard.unlock(); | ||||
|         previous_fiber.reset(); | ||||
|     } | ||||
|     ASSERT(previous_fiber != nullptr); | ||||
|     previous_fiber->guard.unlock(); | ||||
|     previous_fiber.reset(); | ||||
| } | ||||
|  | ||||
| std::shared_ptr<Fiber> Fiber::ThreadToFiber() { | ||||
| @@ -85,12 +91,12 @@ struct alignas(64) Fiber::FiberImpl { | ||||
| }; | ||||
|  | ||||
| void Fiber::start(boost::context::detail::transfer_t& transfer) { | ||||
|     if (previous_fiber) { | ||||
|         previous_fiber->impl->context = transfer.fctx; | ||||
|         previous_fiber->guard.unlock(); | ||||
|         previous_fiber = nullptr; | ||||
|     } | ||||
|     ASSERT(previous_fiber != nullptr); | ||||
|     previous_fiber->impl->context = transfer.fctx; | ||||
|     previous_fiber->guard.unlock(); | ||||
|     previous_fiber.reset(); | ||||
|     entry_point(start_parameter); | ||||
|     UNREACHABLE(); | ||||
| } | ||||
|  | ||||
| void Fiber::FiberStartFunc(boost::context::detail::transfer_t transfer) | ||||
| @@ -113,11 +119,15 @@ Fiber::Fiber() : guard{}, entry_point{}, start_parameter{}, previous_fiber{} { | ||||
|  | ||||
| Fiber::~Fiber() { | ||||
|     // Make sure the Fiber is not being used | ||||
|     guard.lock(); | ||||
|     guard.unlock(); | ||||
|     bool locked = guard.try_lock(); | ||||
|     ASSERT_MSG(locked, "Destroying a fiber that's still running"); | ||||
|     if (locked) { | ||||
|         guard.unlock(); | ||||
|     } | ||||
| } | ||||
|  | ||||
| void Fiber::Exit() { | ||||
|     ASSERT_MSG(is_thread_fiber, "Exitting non main thread fiber"); | ||||
|     if (!is_thread_fiber) { | ||||
|         return; | ||||
|     } | ||||
| @@ -125,15 +135,16 @@ void Fiber::Exit() { | ||||
| } | ||||
|  | ||||
| void Fiber::YieldTo(std::shared_ptr<Fiber> from, std::shared_ptr<Fiber> to) { | ||||
|     ASSERT_MSG(from != nullptr, "Yielding fiber is null!"); | ||||
|     ASSERT_MSG(to != nullptr, "Next fiber is null!"); | ||||
|     to->guard.lock(); | ||||
|     to->previous_fiber = from; | ||||
|     auto transfer = boost::context::detail::jump_fcontext(to->impl.context, nullptr); | ||||
|     auto previous_fiber = from->previous_fiber; | ||||
|     if (previous_fiber) { | ||||
|         previous_fiber->impl->context = transfer.fctx; | ||||
|         previous_fiber->guard.unlock(); | ||||
|         previous_fiber.reset(); | ||||
|     } | ||||
|     ASSERT(previous_fiber != nullptr); | ||||
|     previous_fiber->impl->context = transfer.fctx; | ||||
|     previous_fiber->guard.unlock(); | ||||
|     previous_fiber.reset(); | ||||
| } | ||||
|  | ||||
| std::shared_ptr<Fiber> Fiber::ThreadToFiber() { | ||||
|   | ||||
| @@ -18,6 +18,18 @@ namespace boost::context::detail { | ||||
|  | ||||
| namespace Common { | ||||
|  | ||||
| /** | ||||
|  * Fiber class | ||||
|  * a fiber is a userspace thread with it's own context. They can be used to | ||||
|  * implement coroutines, emulated threading systems and certain asynchronous | ||||
|  * patterns. | ||||
|  * | ||||
|  * This class implements fibers at a low level, thus allowing greater freedom | ||||
|  * to implement such patterns. This fiber class is 'threadsafe' only one fiber | ||||
|  * can be running at a time and threads will be locked while trying to yield to | ||||
|  * a running fiber until it yields. WARNING exchanging two running fibers between | ||||
|  * threads will cause a deadlock. | ||||
|  */ | ||||
| class Fiber { | ||||
| public: | ||||
|     Fiber(std::function<void(void*)>&& entry_point_func, void* start_parameter); | ||||
| @@ -53,8 +65,6 @@ private: | ||||
|     static void FiberStartFunc(boost::context::detail::transfer_t transfer); | ||||
| #endif | ||||
|  | ||||
|  | ||||
|  | ||||
|     struct FiberImpl; | ||||
|  | ||||
|     SpinLock guard; | ||||
|   | ||||
| @@ -43,4 +43,11 @@ void SpinLock::unlock() { | ||||
|     lck.clear(std::memory_order_release); | ||||
| } | ||||
|  | ||||
| bool SpinLock::try_lock() { | ||||
|     if (lck.test_and_set(std::memory_order_acquire)) { | ||||
|         return false; | ||||
|     } | ||||
|     return true; | ||||
| } | ||||
|  | ||||
| } // namespace Common | ||||
|   | ||||
| @@ -12,6 +12,7 @@ class SpinLock { | ||||
| public: | ||||
|     void lock(); | ||||
|     void unlock(); | ||||
|     bool try_lock(); | ||||
|  | ||||
| private: | ||||
|     std::atomic_flag lck = ATOMIC_FLAG_INIT; | ||||
|   | ||||
| @@ -64,7 +64,9 @@ static void ThreadStart1(u32 id, TestControl1& test_control) { | ||||
|     test_control.ExecuteThread(id); | ||||
| } | ||||
|  | ||||
|  | ||||
| /** This test checks for fiber setup configuration and validates that fibers are | ||||
|  *  doing all the work required. | ||||
|  */ | ||||
| TEST_CASE("Fibers::Setup", "[common]") { | ||||
|     constexpr u32 num_threads = 7; | ||||
|     TestControl1 test_control{}; | ||||
| @@ -188,6 +190,10 @@ static void ThreadStart2_2(u32 id, TestControl2& test_control) { | ||||
|     test_control.Exit(); | ||||
| } | ||||
|  | ||||
| /** This test checks for fiber thread exchange configuration and validates that fibers are | ||||
|  *  that a fiber has been succesfully transfered from one thread to another and that the TLS | ||||
|  *  region of the thread is kept while changing fibers. | ||||
|  */ | ||||
| TEST_CASE("Fibers::InterExchange", "[common]") { | ||||
|     TestControl2 test_control{}; | ||||
|     test_control.thread_fibers.resize(2, nullptr); | ||||
| @@ -210,5 +216,92 @@ TEST_CASE("Fibers::InterExchange", "[common]") { | ||||
|     REQUIRE(test_control.value1 == cal_value); | ||||
| } | ||||
|  | ||||
| class TestControl3 { | ||||
| public: | ||||
|     TestControl3() = default; | ||||
|  | ||||
|     void DoWork1() { | ||||
|         value1 += 1; | ||||
|         Fiber::YieldTo(fiber1, fiber2); | ||||
|         std::thread::id this_id = std::this_thread::get_id(); | ||||
|         u32 id = ids[this_id]; | ||||
|         value3 += 1; | ||||
|         Fiber::YieldTo(fiber1, thread_fibers[id]); | ||||
|     } | ||||
|  | ||||
|     void DoWork2() { | ||||
|         value2 += 1; | ||||
|         std::thread::id this_id = std::this_thread::get_id(); | ||||
|         u32 id = ids[this_id]; | ||||
|         Fiber::YieldTo(fiber2, thread_fibers[id]); | ||||
|     } | ||||
|  | ||||
|     void ExecuteThread(u32 id); | ||||
|  | ||||
|     void CallFiber1() { | ||||
|         std::thread::id this_id = std::this_thread::get_id(); | ||||
|         u32 id = ids[this_id]; | ||||
|         Fiber::YieldTo(thread_fibers[id], fiber1); | ||||
|     } | ||||
|  | ||||
|     void Exit(); | ||||
|  | ||||
|     u32 value1{}; | ||||
|     u32 value2{}; | ||||
|     u32 value3{}; | ||||
|     std::unordered_map<std::thread::id, u32> ids; | ||||
|     std::vector<std::shared_ptr<Common::Fiber>> thread_fibers; | ||||
|     std::shared_ptr<Common::Fiber> fiber1; | ||||
|     std::shared_ptr<Common::Fiber> fiber2; | ||||
| }; | ||||
|  | ||||
| static void WorkControl3_1(void* control) { | ||||
|     TestControl3* test_control = static_cast<TestControl3*>(control); | ||||
|     test_control->DoWork1(); | ||||
| } | ||||
|  | ||||
| static void WorkControl3_2(void* control) { | ||||
|     TestControl3* test_control = static_cast<TestControl3*>(control); | ||||
|     test_control->DoWork2(); | ||||
| } | ||||
|  | ||||
| void TestControl3::ExecuteThread(u32 id) { | ||||
|     std::thread::id this_id = std::this_thread::get_id(); | ||||
|     ids[this_id] = id; | ||||
|     auto thread_fiber = Fiber::ThreadToFiber(); | ||||
|     thread_fibers[id] = thread_fiber; | ||||
| } | ||||
|  | ||||
| void TestControl3::Exit() { | ||||
|     std::thread::id this_id = std::this_thread::get_id(); | ||||
|     u32 id = ids[this_id]; | ||||
|     thread_fibers[id]->Exit(); | ||||
| } | ||||
|  | ||||
| static void ThreadStart3(u32 id, TestControl3& test_control) { | ||||
|     test_control.ExecuteThread(id); | ||||
|     test_control.CallFiber1(); | ||||
|     test_control.Exit(); | ||||
| } | ||||
|  | ||||
| /** This test checks for one two threads racing for starting the same fiber. | ||||
|  *  It checks execution occured in an ordered manner and by no time there were | ||||
|  *  two contexts at the same time. | ||||
|  */ | ||||
| TEST_CASE("Fibers::StartRace", "[common]") { | ||||
|     TestControl3 test_control{}; | ||||
|     test_control.thread_fibers.resize(2, nullptr); | ||||
|     test_control.fiber1 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl3_1}, &test_control); | ||||
|     test_control.fiber2 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl3_2}, &test_control); | ||||
|     std::thread thread1(ThreadStart3, 0, std::ref(test_control)); | ||||
|     std::thread thread2(ThreadStart3, 1, std::ref(test_control)); | ||||
|     thread1.join(); | ||||
|     thread2.join(); | ||||
|     REQUIRE(test_control.value1 == 1); | ||||
|     REQUIRE(test_control.value2 == 1); | ||||
|     REQUIRE(test_control.value3 == 1); | ||||
| } | ||||
|  | ||||
|  | ||||
|  | ||||
| } // namespace Common | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Fernando Sahmkow
					Fernando Sahmkow