hle: kernel: Implement CloneCurrentObject and improve session management.
This commit is contained in:
		| @@ -86,10 +86,8 @@ public: | |||||||
|  |  | ||||||
|         // The entire size of the raw data section in u32 units, including the 16 bytes of mandatory |         // The entire size of the raw data section in u32 units, including the 16 bytes of mandatory | ||||||
|         // padding. |         // padding. | ||||||
|         u32 raw_data_size = ctx.IsTipc() |         u32 raw_data_size = ctx.write_size = | ||||||
|                                 ? normal_params_size - 1 |             ctx.IsTipc() ? normal_params_size - 1 : normal_params_size; | ||||||
|                                 : sizeof(IPC::DataPayloadHeader) / 4 + 4 + normal_params_size; |  | ||||||
|  |  | ||||||
|         u32 num_handles_to_move{}; |         u32 num_handles_to_move{}; | ||||||
|         u32 num_domain_objects{}; |         u32 num_domain_objects{}; | ||||||
|         const bool always_move_handles{ |         const bool always_move_handles{ | ||||||
| @@ -101,16 +99,20 @@ public: | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         if (ctx.Session()->IsDomain()) { |         if (ctx.Session()->IsDomain()) { | ||||||
|             raw_data_size += static_cast<u32>(sizeof(DomainMessageHeader) / 4 + num_domain_objects); |             raw_data_size += | ||||||
|  |                 static_cast<u32>(sizeof(DomainMessageHeader) / sizeof(u32) + num_domain_objects); | ||||||
|  |             ctx.write_size += num_domain_objects; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         if (ctx.IsTipc()) { |         if (ctx.IsTipc()) { | ||||||
|             header.type.Assign(ctx.GetCommandType()); |             header.type.Assign(ctx.GetCommandType()); | ||||||
|  |         } else { | ||||||
|  |             raw_data_size += static_cast<u32>(sizeof(IPC::DataPayloadHeader) / sizeof(u32) + 4 + | ||||||
|  |                                               normal_params_size); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         ctx.data_size = static_cast<u32>(raw_data_size); |         header.data_size.Assign(raw_data_size); | ||||||
|         header.data_size.Assign(static_cast<u32>(raw_data_size)); |         if (num_handles_to_copy || num_handles_to_move) { | ||||||
|         if (num_handles_to_copy != 0 || num_handles_to_move != 0) { |  | ||||||
|             header.enable_handle_descriptor.Assign(1); |             header.enable_handle_descriptor.Assign(1); | ||||||
|         } |         } | ||||||
|         PushRaw(header); |         PushRaw(header); | ||||||
| @@ -143,7 +145,8 @@ public: | |||||||
|         data_payload_index = index; |         data_payload_index = index; | ||||||
|  |  | ||||||
|         ctx.data_payload_offset = index; |         ctx.data_payload_offset = index; | ||||||
|         ctx.domain_offset = index + raw_data_size / 4; |         ctx.write_size += index; | ||||||
|  |         ctx.domain_offset = static_cast<u32>(index + raw_data_size / sizeof(u32)); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     template <class T> |     template <class T> | ||||||
| @@ -404,7 +407,7 @@ public: | |||||||
|     std::shared_ptr<T> PopIpcInterface() { |     std::shared_ptr<T> PopIpcInterface() { | ||||||
|         ASSERT(context->Session()->IsDomain()); |         ASSERT(context->Session()->IsDomain()); | ||||||
|         ASSERT(context->GetDomainMessageHeader().input_object_count > 0); |         ASSERT(context->GetDomainMessageHeader().input_object_count > 0); | ||||||
|         return context->GetDomainRequestHandler<T>(Pop<u32>() - 1); |         return context->GetDomainHandler<T>(Pop<u32>() - 1); | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -35,11 +35,11 @@ SessionRequestHandler::SessionRequestHandler() = default; | |||||||
| SessionRequestHandler::~SessionRequestHandler() = default; | SessionRequestHandler::~SessionRequestHandler() = default; | ||||||
|  |  | ||||||
| void SessionRequestHandler::ClientConnected(KServerSession* session) { | void SessionRequestHandler::ClientConnected(KServerSession* session) { | ||||||
|     session->SetHleHandler(shared_from_this()); |     session->SetSessionHandler(shared_from_this()); | ||||||
| } | } | ||||||
|  |  | ||||||
| void SessionRequestHandler::ClientDisconnected(KServerSession* session) { | void SessionRequestHandler::ClientDisconnected(KServerSession* session) { | ||||||
|     session->SetHleHandler(nullptr); |     session->SetSessionHandler(nullptr); | ||||||
| } | } | ||||||
|  |  | ||||||
| HLERequestContext::HLERequestContext(KernelCore& kernel_, Core::Memory::Memory& memory_, | HLERequestContext::HLERequestContext(KernelCore& kernel_, Core::Memory::Memory& memory_, | ||||||
| @@ -186,18 +186,6 @@ ResultCode HLERequestContext::WriteToOutgoingCommandBuffer(KThread& requesting_t | |||||||
|     auto& owner_process = *requesting_thread.GetOwnerProcess(); |     auto& owner_process = *requesting_thread.GetOwnerProcess(); | ||||||
|     auto& handle_table = owner_process.GetHandleTable(); |     auto& handle_table = owner_process.GetHandleTable(); | ||||||
|  |  | ||||||
|     // The data_size already includes the payload header, the padding and the domain header. |  | ||||||
|     std::size_t size{}; |  | ||||||
|  |  | ||||||
|     if (IsTipc()) { |  | ||||||
|         size = cmd_buf.size(); |  | ||||||
|     } else { |  | ||||||
|         size = data_payload_offset + data_size - sizeof(IPC::DataPayloadHeader) / sizeof(u32) - 4; |  | ||||||
|         if (Session()->IsDomain()) { |  | ||||||
|             size -= sizeof(IPC::DomainMessageHeader) / sizeof(u32); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     for (auto& object : copy_objects) { |     for (auto& object : copy_objects) { | ||||||
|         Handle handle{}; |         Handle handle{}; | ||||||
|         if (object) { |         if (object) { | ||||||
| @@ -222,7 +210,7 @@ ResultCode HLERequestContext::WriteToOutgoingCommandBuffer(KThread& requesting_t | |||||||
|     if (Session()->IsDomain()) { |     if (Session()->IsDomain()) { | ||||||
|         current_offset = domain_offset - static_cast<u32>(domain_objects.size()); |         current_offset = domain_offset - static_cast<u32>(domain_objects.size()); | ||||||
|         for (const auto& object : domain_objects) { |         for (const auto& object : domain_objects) { | ||||||
|             server_session->AppendDomainRequestHandler(object); |             server_session->AppendDomainHandler(object); | ||||||
|             cmd_buf[current_offset++] = |             cmd_buf[current_offset++] = | ||||||
|                 static_cast<u32_le>(server_session->NumDomainRequestHandlers()); |                 static_cast<u32_le>(server_session->NumDomainRequestHandlers()); | ||||||
|         } |         } | ||||||
| @@ -230,7 +218,7 @@ ResultCode HLERequestContext::WriteToOutgoingCommandBuffer(KThread& requesting_t | |||||||
|  |  | ||||||
|     // Copy the translated command buffer back into the thread's command buffer area. |     // Copy the translated command buffer back into the thread's command buffer area. | ||||||
|     memory.WriteBlock(owner_process, requesting_thread.GetTLSAddress(), cmd_buf.data(), |     memory.WriteBlock(owner_process, requesting_thread.GetTLSAddress(), cmd_buf.data(), | ||||||
|                       size * sizeof(u32)); |                       write_size * sizeof(u32)); | ||||||
|  |  | ||||||
|     return RESULT_SUCCESS; |     return RESULT_SUCCESS; | ||||||
| } | } | ||||||
|   | |||||||
| @@ -12,6 +12,8 @@ | |||||||
| #include <type_traits> | #include <type_traits> | ||||||
| #include <vector> | #include <vector> | ||||||
| #include <boost/container/small_vector.hpp> | #include <boost/container/small_vector.hpp> | ||||||
|  |  | ||||||
|  | #include "common/assert.h" | ||||||
| #include "common/common_types.h" | #include "common/common_types.h" | ||||||
| #include "common/concepts.h" | #include "common/concepts.h" | ||||||
| #include "common/swap.h" | #include "common/swap.h" | ||||||
| @@ -84,6 +86,69 @@ public: | |||||||
|     void ClientDisconnected(KServerSession* session); |     void ClientDisconnected(KServerSession* session); | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | using SessionRequestHandlerPtr = std::shared_ptr<SessionRequestHandler>; | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Manages the underlying HLE requests for a session, and whether (or not) the session should be | ||||||
|  |  * treated as a domain. This is managed separately from server sessions, as this state is shared | ||||||
|  |  * when objects are cloned. | ||||||
|  |  */ | ||||||
|  | class SessionRequestManager final { | ||||||
|  | public: | ||||||
|  |     SessionRequestManager() = default; | ||||||
|  |  | ||||||
|  |     bool IsDomain() const { | ||||||
|  |         return is_domain; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     void ConvertToDomain() { | ||||||
|  |         domain_handlers = {session_handler}; | ||||||
|  |         is_domain = true; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     std::size_t DomainHandlerCount() const { | ||||||
|  |         return domain_handlers.size(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     bool HasSessionHandler() const { | ||||||
|  |         return session_handler != nullptr; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     SessionRequestHandler& SessionHandler() { | ||||||
|  |         return *session_handler; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     const SessionRequestHandler& SessionHandler() const { | ||||||
|  |         return *session_handler; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     void CloseDomainHandler(std::size_t index) { | ||||||
|  |         if (index < DomainHandlerCount()) { | ||||||
|  |             domain_handlers[index] = nullptr; | ||||||
|  |         } else { | ||||||
|  |             UNREACHABLE_MSG("Unexpected handler index {}", index); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     SessionRequestHandlerPtr DomainHandler(std::size_t index) const { | ||||||
|  |         ASSERT_MSG(index < DomainHandlerCount(), "Unexpected handler index {}", index); | ||||||
|  |         return domain_handlers.at(index); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     void AppendDomainHandler(SessionRequestHandlerPtr&& handler) { | ||||||
|  |         domain_handlers.emplace_back(std::move(handler)); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     void SetSessionHandler(SessionRequestHandlerPtr&& handler) { | ||||||
|  |         session_handler = std::move(handler); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  | private: | ||||||
|  |     bool is_domain{}; | ||||||
|  |     SessionRequestHandlerPtr session_handler; | ||||||
|  |     std::vector<SessionRequestHandlerPtr> domain_handlers; | ||||||
|  | }; | ||||||
|  |  | ||||||
| /** | /** | ||||||
|  * Class containing information about an in-flight IPC request being handled by an HLE service |  * Class containing information about an in-flight IPC request being handled by an HLE service | ||||||
|  * implementation. Services should avoid using old global APIs (e.g. Kernel::GetCommandBuffer()) and |  * implementation. Services should avoid using old global APIs (e.g. Kernel::GetCommandBuffer()) and | ||||||
| @@ -239,18 +304,17 @@ public: | |||||||
|         copy_objects.emplace_back(object); |         copy_objects.emplace_back(object); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     void AddDomainObject(std::shared_ptr<SessionRequestHandler> object) { |     void AddDomainObject(SessionRequestHandlerPtr object) { | ||||||
|         domain_objects.emplace_back(std::move(object)); |         domain_objects.emplace_back(std::move(object)); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     template <typename T> |     template <typename T> | ||||||
|     std::shared_ptr<T> GetDomainRequestHandler(std::size_t index) const { |     std::shared_ptr<T> GetDomainHandler(std::size_t index) const { | ||||||
|         return std::static_pointer_cast<T>(domain_request_handlers.at(index)); |         return std::static_pointer_cast<T>(manager->DomainHandler(index)); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     void SetDomainRequestHandlers( |     void SetSessionRequestManager(std::shared_ptr<SessionRequestManager> manager_) { | ||||||
|         const std::vector<std::shared_ptr<SessionRequestHandler>>& handlers) { |         manager = std::move(manager_); | ||||||
|         domain_request_handlers = handlers; |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     /// Clears the list of objects so that no lingering objects are written accidentally to the |     /// Clears the list of objects so that no lingering objects are written accidentally to the | ||||||
| @@ -297,7 +361,7 @@ private: | |||||||
|     boost::container::small_vector<Handle, 8> copy_handles; |     boost::container::small_vector<Handle, 8> copy_handles; | ||||||
|     boost::container::small_vector<KAutoObject*, 8> move_objects; |     boost::container::small_vector<KAutoObject*, 8> move_objects; | ||||||
|     boost::container::small_vector<KAutoObject*, 8> copy_objects; |     boost::container::small_vector<KAutoObject*, 8> copy_objects; | ||||||
|     boost::container::small_vector<std::shared_ptr<SessionRequestHandler>, 8> domain_objects; |     boost::container::small_vector<SessionRequestHandlerPtr, 8> domain_objects; | ||||||
|  |  | ||||||
|     std::optional<IPC::CommandHeader> command_header; |     std::optional<IPC::CommandHeader> command_header; | ||||||
|     std::optional<IPC::HandleDescriptorHeader> handle_descriptor_header; |     std::optional<IPC::HandleDescriptorHeader> handle_descriptor_header; | ||||||
| @@ -311,12 +375,12 @@ private: | |||||||
|  |  | ||||||
|     u32_le command{}; |     u32_le command{}; | ||||||
|     u64 pid{}; |     u64 pid{}; | ||||||
|  |     u32 write_size{}; | ||||||
|     u32 data_payload_offset{}; |     u32 data_payload_offset{}; | ||||||
|     u32 handles_offset{}; |     u32 handles_offset{}; | ||||||
|     u32 domain_offset{}; |     u32 domain_offset{}; | ||||||
|     u32 data_size{}; |  | ||||||
|  |  | ||||||
|     std::vector<std::shared_ptr<SessionRequestHandler>> domain_request_handlers; |     std::shared_ptr<SessionRequestManager> manager; | ||||||
|     bool is_thread_waiting{}; |     bool is_thread_waiting{}; | ||||||
|  |  | ||||||
|     KernelCore& kernel; |     KernelCore& kernel; | ||||||
|   | |||||||
| @@ -31,6 +31,9 @@ public: | |||||||
|     const KPort* GetParent() const { |     const KPort* GetParent() const { | ||||||
|         return parent; |         return parent; | ||||||
|     } |     } | ||||||
|  |     KPort* GetParent() { | ||||||
|  |         return parent; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     s32 GetNumSessions() const { |     s32 GetNumSessions() const { | ||||||
|         return num_sessions; |         return num_sessions; | ||||||
|   | |||||||
| @@ -56,11 +56,8 @@ ResultCode KPort::EnqueueSession(KServerSession* session) { | |||||||
|  |  | ||||||
|     R_UNLESS(state == State::Normal, ResultPortClosed); |     R_UNLESS(state == State::Normal, ResultPortClosed); | ||||||
|  |  | ||||||
|     if (server.HasHLEHandler()) { |     server.GetSessionRequestHandler()->ClientConnected(session); | ||||||
|         server.GetHLEHandler()->ClientConnected(session); |     server.EnqueueSession(session); | ||||||
|     } else { |  | ||||||
|         server.EnqueueSession(session); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     return RESULT_SUCCESS; |     return RESULT_SUCCESS; | ||||||
| } | } | ||||||
|   | |||||||
| @@ -32,26 +32,24 @@ public: | |||||||
|     explicit KServerPort(KernelCore& kernel_); |     explicit KServerPort(KernelCore& kernel_); | ||||||
|     virtual ~KServerPort() override; |     virtual ~KServerPort() override; | ||||||
|  |  | ||||||
|     using HLEHandler = std::shared_ptr<SessionRequestHandler>; |  | ||||||
|  |  | ||||||
|     void Initialize(KPort* parent_, std::string&& name_); |     void Initialize(KPort* parent_, std::string&& name_); | ||||||
|  |  | ||||||
|     /// Whether or not this server port has an HLE handler available. |     /// Whether or not this server port has an HLE handler available. | ||||||
|     bool HasHLEHandler() const { |     bool HasSessionRequestHandler() const { | ||||||
|         return hle_handler != nullptr; |         return session_handler != nullptr; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     /// Gets the HLE handler for this port. |     /// Gets the HLE handler for this port. | ||||||
|     HLEHandler GetHLEHandler() const { |     SessionRequestHandlerPtr GetSessionRequestHandler() const { | ||||||
|         return hle_handler; |         return session_handler; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     /** |     /** | ||||||
|      * Sets the HLE handler template for the port. ServerSessions crated by connecting to this port |      * Sets the HLE handler template for the port. ServerSessions crated by connecting to this port | ||||||
|      * will inherit a reference to this handler. |      * will inherit a reference to this handler. | ||||||
|      */ |      */ | ||||||
|     void SetHleHandler(HLEHandler hle_handler_) { |     void SetSessionHandler(SessionRequestHandlerPtr&& handler) { | ||||||
|         hle_handler = std::move(hle_handler_); |         session_handler = std::move(handler); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     void EnqueueSession(KServerSession* pending_session); |     void EnqueueSession(KServerSession* pending_session); | ||||||
| @@ -73,7 +71,7 @@ private: | |||||||
|  |  | ||||||
| private: | private: | ||||||
|     SessionList session_list; |     SessionList session_list; | ||||||
|     HLEHandler hle_handler; |     SessionRequestHandlerPtr session_handler; | ||||||
|     KPort* parent{}; |     KPort* parent{}; | ||||||
| }; | }; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -23,7 +23,8 @@ | |||||||
|  |  | ||||||
| namespace Kernel { | namespace Kernel { | ||||||
|  |  | ||||||
| KServerSession::KServerSession(KernelCore& kernel_) : KSynchronizationObject{kernel_} {} | KServerSession::KServerSession(KernelCore& kernel_) | ||||||
|  |     : KSynchronizationObject{kernel_}, manager{std::make_shared<SessionRequestManager>()} {} | ||||||
|  |  | ||||||
| KServerSession::~KServerSession() { | KServerSession::~KServerSession() { | ||||||
|     kernel.ReleaseServiceThread(service_thread); |     kernel.ReleaseServiceThread(service_thread); | ||||||
| @@ -43,14 +44,8 @@ void KServerSession::Destroy() { | |||||||
| } | } | ||||||
|  |  | ||||||
| void KServerSession::OnClientClosed() { | void KServerSession::OnClientClosed() { | ||||||
|     // We keep a shared pointer to the hle handler to keep it alive throughout |     if (manager->HasSessionHandler()) { | ||||||
|     // the call to ClientDisconnected, as ClientDisconnected invalidates the |         manager->SessionHandler().ClientDisconnected(this); | ||||||
|     // hle_handler member itself during the course of the function executing. |  | ||||||
|     std::shared_ptr<SessionRequestHandler> handler = hle_handler; |  | ||||||
|     if (handler) { |  | ||||||
|         // Note that after this returns, this server session's hle_handler is |  | ||||||
|         // invalidated (set to null). |  | ||||||
|         handler->ClientDisconnected(this); |  | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -66,12 +61,12 @@ bool KServerSession::IsSignaled() const { | |||||||
|     return false; |     return false; | ||||||
| } | } | ||||||
|  |  | ||||||
| void KServerSession::AppendDomainRequestHandler(std::shared_ptr<SessionRequestHandler> handler) { | void KServerSession::AppendDomainHandler(SessionRequestHandlerPtr handler) { | ||||||
|     domain_request_handlers.push_back(std::move(handler)); |     manager->AppendDomainHandler(std::move(handler)); | ||||||
| } | } | ||||||
|  |  | ||||||
| std::size_t KServerSession::NumDomainRequestHandlers() const { | std::size_t KServerSession::NumDomainRequestHandlers() const { | ||||||
|     return domain_request_handlers.size(); |     return manager->DomainHandlerCount(); | ||||||
| } | } | ||||||
|  |  | ||||||
| ResultCode KServerSession::HandleDomainSyncRequest(Kernel::HLERequestContext& context) { | ResultCode KServerSession::HandleDomainSyncRequest(Kernel::HLERequestContext& context) { | ||||||
| @@ -80,14 +75,14 @@ ResultCode KServerSession::HandleDomainSyncRequest(Kernel::HLERequestContext& co | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     // Set domain handlers in HLE context, used for domain objects (IPC interfaces) as inputs |     // Set domain handlers in HLE context, used for domain objects (IPC interfaces) as inputs | ||||||
|     context.SetDomainRequestHandlers(domain_request_handlers); |     context.SetSessionRequestManager(manager); | ||||||
|  |  | ||||||
|     // If there is a DomainMessageHeader, then this is CommandType "Request" |     // If there is a DomainMessageHeader, then this is CommandType "Request" | ||||||
|     const auto& domain_message_header = context.GetDomainMessageHeader(); |     const auto& domain_message_header = context.GetDomainMessageHeader(); | ||||||
|     const u32 object_id{domain_message_header.object_id}; |     const u32 object_id{domain_message_header.object_id}; | ||||||
|     switch (domain_message_header.command) { |     switch (domain_message_header.command) { | ||||||
|     case IPC::DomainMessageHeader::CommandType::SendMessage: |     case IPC::DomainMessageHeader::CommandType::SendMessage: | ||||||
|         if (object_id > domain_request_handlers.size()) { |         if (object_id > manager->DomainHandlerCount()) { | ||||||
|             LOG_CRITICAL(IPC, |             LOG_CRITICAL(IPC, | ||||||
|                          "object_id {} is too big! This probably means a recent service call " |                          "object_id {} is too big! This probably means a recent service call " | ||||||
|                          "to {} needed to return a new interface!", |                          "to {} needed to return a new interface!", | ||||||
| @@ -95,12 +90,12 @@ ResultCode KServerSession::HandleDomainSyncRequest(Kernel::HLERequestContext& co | |||||||
|             UNREACHABLE(); |             UNREACHABLE(); | ||||||
|             return RESULT_SUCCESS; // Ignore error if asserts are off |             return RESULT_SUCCESS; // Ignore error if asserts are off | ||||||
|         } |         } | ||||||
|         return domain_request_handlers[object_id - 1]->HandleSyncRequest(*this, context); |         return manager->DomainHandler(object_id - 1)->HandleSyncRequest(*this, context); | ||||||
|  |  | ||||||
|     case IPC::DomainMessageHeader::CommandType::CloseVirtualHandle: { |     case IPC::DomainMessageHeader::CommandType::CloseVirtualHandle: { | ||||||
|         LOG_DEBUG(IPC, "CloseVirtualHandle, object_id=0x{:08X}", object_id); |         LOG_DEBUG(IPC, "CloseVirtualHandle, object_id=0x{:08X}", object_id); | ||||||
|  |  | ||||||
|         domain_request_handlers[object_id - 1] = nullptr; |         manager->CloseDomainHandler(object_id - 1); | ||||||
|  |  | ||||||
|         IPC::ResponseBuilder rb{context, 2}; |         IPC::ResponseBuilder rb{context, 2}; | ||||||
|         rb.Push(RESULT_SUCCESS); |         rb.Push(RESULT_SUCCESS); | ||||||
| @@ -133,14 +128,14 @@ ResultCode KServerSession::CompleteSyncRequest(HLERequestContext& context) { | |||||||
|     if (IsDomain() && context.HasDomainMessageHeader()) { |     if (IsDomain() && context.HasDomainMessageHeader()) { | ||||||
|         result = HandleDomainSyncRequest(context); |         result = HandleDomainSyncRequest(context); | ||||||
|         // If there is no domain header, the regular session handler is used |         // If there is no domain header, the regular session handler is used | ||||||
|     } else if (hle_handler != nullptr) { |     } else if (manager->HasSessionHandler()) { | ||||||
|         // If this ServerSession has an associated HLE handler, forward the request to it. |         // If this ServerSession has an associated HLE handler, forward the request to it. | ||||||
|         result = hle_handler->HandleSyncRequest(*this, context); |         result = manager->SessionHandler().HandleSyncRequest(*this, context); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (convert_to_domain) { |     if (convert_to_domain) { | ||||||
|         ASSERT_MSG(IsSession(), "ServerSession is already a domain instance."); |         ASSERT_MSG(!IsDomain(), "ServerSession is already a domain instance."); | ||||||
|         domain_request_handlers = {hle_handler}; |         manager->ConvertToDomain(); | ||||||
|         convert_to_domain = false; |         convert_to_domain = false; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -12,6 +12,7 @@ | |||||||
| #include <boost/intrusive/list.hpp> | #include <boost/intrusive/list.hpp> | ||||||
|  |  | ||||||
| #include "common/threadsafe_queue.h" | #include "common/threadsafe_queue.h" | ||||||
|  | #include "core/hle/kernel/hle_ipc.h" | ||||||
| #include "core/hle/kernel/k_synchronization_object.h" | #include "core/hle/kernel/k_synchronization_object.h" | ||||||
| #include "core/hle/kernel/service_thread.h" | #include "core/hle/kernel/service_thread.h" | ||||||
| #include "core/hle/result.h" | #include "core/hle/result.h" | ||||||
| @@ -64,8 +65,8 @@ public: | |||||||
|      * instead of the regular IPC machinery. (The regular IPC machinery is currently not |      * instead of the regular IPC machinery. (The regular IPC machinery is currently not | ||||||
|      * implemented.) |      * implemented.) | ||||||
|      */ |      */ | ||||||
|     void SetHleHandler(std::shared_ptr<SessionRequestHandler> hle_handler_) { |     void SetSessionHandler(SessionRequestHandlerPtr handler) { | ||||||
|         hle_handler = std::move(hle_handler_); |         manager->SetSessionHandler(std::move(handler)); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     /** |     /** | ||||||
| @@ -82,7 +83,7 @@ public: | |||||||
|  |  | ||||||
|     /// Adds a new domain request handler to the collection of request handlers within |     /// Adds a new domain request handler to the collection of request handlers within | ||||||
|     /// this ServerSession instance. |     /// this ServerSession instance. | ||||||
|     void AppendDomainRequestHandler(std::shared_ptr<SessionRequestHandler> handler); |     void AppendDomainHandler(SessionRequestHandlerPtr handler); | ||||||
|  |  | ||||||
|     /// Retrieves the total number of domain request handlers that have been |     /// Retrieves the total number of domain request handlers that have been | ||||||
|     /// appended to this ServerSession instance. |     /// appended to this ServerSession instance. | ||||||
| @@ -90,12 +91,7 @@ public: | |||||||
|  |  | ||||||
|     /// Returns true if the session has been converted to a domain, otherwise False |     /// Returns true if the session has been converted to a domain, otherwise False | ||||||
|     bool IsDomain() const { |     bool IsDomain() const { | ||||||
|         return !IsSession(); |         return manager->IsDomain(); | ||||||
|     } |  | ||||||
|  |  | ||||||
|     /// Returns true if this session has not been converted to a domain, otherwise false. |  | ||||||
|     bool IsSession() const { |  | ||||||
|         return domain_request_handlers.empty(); |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     /// Converts the session to a domain at the end of the current command |     /// Converts the session to a domain at the end of the current command | ||||||
| @@ -103,6 +99,21 @@ public: | |||||||
|         convert_to_domain = true; |         convert_to_domain = true; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     /// Gets the session request manager, which forwards requests to the underlying service | ||||||
|  |     std::shared_ptr<SessionRequestManager>& GetSessionRequestManager() { | ||||||
|  |         return manager; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// Gets the session request manager, which forwards requests to the underlying service | ||||||
|  |     const std::shared_ptr<SessionRequestManager>& GetSessionRequestManager() const { | ||||||
|  |         return manager; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// Sets the session request manager, which forwards requests to the underlying service | ||||||
|  |     void SetSessionRequestManager(std::shared_ptr<SessionRequestManager> manager_) { | ||||||
|  |         manager = std::move(manager_); | ||||||
|  |     } | ||||||
|  |  | ||||||
| private: | private: | ||||||
|     /// Queues a sync request from the emulated application. |     /// Queues a sync request from the emulated application. | ||||||
|     ResultCode QueueSyncRequest(KThread* thread, Core::Memory::Memory& memory); |     ResultCode QueueSyncRequest(KThread* thread, Core::Memory::Memory& memory); | ||||||
| @@ -114,11 +125,8 @@ private: | |||||||
|     /// object handle. |     /// object handle. | ||||||
|     ResultCode HandleDomainSyncRequest(Kernel::HLERequestContext& context); |     ResultCode HandleDomainSyncRequest(Kernel::HLERequestContext& context); | ||||||
|  |  | ||||||
|     /// This session's HLE request handler (applicable when not a domain) |     /// This session's HLE request handlers | ||||||
|     std::shared_ptr<SessionRequestHandler> hle_handler; |     std::shared_ptr<SessionRequestManager> manager; | ||||||
|  |  | ||||||
|     /// This is the list of domain request handlers (after conversion to a domain) |  | ||||||
|     std::vector<std::shared_ptr<SessionRequestHandler>> domain_request_handlers; |  | ||||||
|  |  | ||||||
|     /// When set to True, converts the session to a domain at the end of the command |     /// When set to True, converts the session to a domain at the end of the command | ||||||
|     bool convert_to_domain{}; |     bool convert_to_domain{}; | ||||||
|   | |||||||
| @@ -66,6 +66,10 @@ public: | |||||||
|         return port; |         return port; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     KClientPort* GetParent() { | ||||||
|  |         return port; | ||||||
|  |     } | ||||||
|  |  | ||||||
| private: | private: | ||||||
|     enum class State : u8 { |     enum class State : u8 { | ||||||
|         Invalid = 0, |         Invalid = 0, | ||||||
|   | |||||||
| @@ -107,7 +107,7 @@ void ServiceFrameworkBase::InstallAsService(SM::ServiceManager& service_manager) | |||||||
|     ASSERT(!port_installed); |     ASSERT(!port_installed); | ||||||
|  |  | ||||||
|     auto port = service_manager.RegisterService(service_name, max_sessions).Unwrap(); |     auto port = service_manager.RegisterService(service_name, max_sessions).Unwrap(); | ||||||
|     port->SetHleHandler(shared_from_this()); |     port->SetSessionHandler(shared_from_this()); | ||||||
|     port_installed = true; |     port_installed = true; | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -118,7 +118,7 @@ Kernel::KClientPort& ServiceFrameworkBase::CreatePort(Kernel::KernelCore& kernel | |||||||
|  |  | ||||||
|     auto* port = Kernel::KPort::Create(kernel); |     auto* port = Kernel::KPort::Create(kernel); | ||||||
|     port->Initialize(max_sessions, false, service_name); |     port->Initialize(max_sessions, false, service_name); | ||||||
|     port->GetServerPort().SetHleHandler(shared_from_this()); |     port->GetServerPort().SetSessionHandler(shared_from_this()); | ||||||
|  |  | ||||||
|     port_installed = true; |     port_installed = true; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -4,8 +4,13 @@ | |||||||
|  |  | ||||||
| #include "common/assert.h" | #include "common/assert.h" | ||||||
| #include "common/logging/log.h" | #include "common/logging/log.h" | ||||||
|  | #include "core/core.h" | ||||||
| #include "core/hle/ipc_helpers.h" | #include "core/hle/ipc_helpers.h" | ||||||
|  | #include "core/hle/kernel/k_client_port.h" | ||||||
| #include "core/hle/kernel/k_client_session.h" | #include "core/hle/kernel/k_client_session.h" | ||||||
|  | #include "core/hle/kernel/k_port.h" | ||||||
|  | #include "core/hle/kernel/k_scoped_resource_reservation.h" | ||||||
|  | #include "core/hle/kernel/k_server_port.h" | ||||||
| #include "core/hle/kernel/k_server_session.h" | #include "core/hle/kernel/k_server_session.h" | ||||||
| #include "core/hle/kernel/k_session.h" | #include "core/hle/kernel/k_session.h" | ||||||
| #include "core/hle/service/sm/controller.h" | #include "core/hle/service/sm/controller.h" | ||||||
| @@ -13,7 +18,7 @@ | |||||||
| namespace Service::SM { | namespace Service::SM { | ||||||
|  |  | ||||||
| void Controller::ConvertCurrentObjectToDomain(Kernel::HLERequestContext& ctx) { | void Controller::ConvertCurrentObjectToDomain(Kernel::HLERequestContext& ctx) { | ||||||
|     ASSERT_MSG(ctx.Session()->IsSession(), "Session is already a domain"); |     ASSERT_MSG(!ctx.Session()->IsDomain(), "Session is already a domain"); | ||||||
|     LOG_DEBUG(Service, "called, server_session={}", ctx.Session()->GetId()); |     LOG_DEBUG(Service, "called, server_session={}", ctx.Session()->GetId()); | ||||||
|     ctx.Session()->ConvertToDomain(); |     ctx.Session()->ConvertToDomain(); | ||||||
|  |  | ||||||
| @@ -29,16 +34,36 @@ void Controller::CloneCurrentObject(Kernel::HLERequestContext& ctx) { | |||||||
|  |  | ||||||
|     LOG_DEBUG(Service, "called"); |     LOG_DEBUG(Service, "called"); | ||||||
|  |  | ||||||
|     auto session = ctx.Session()->GetParent(); |     auto& kernel = system.Kernel(); | ||||||
|  |     auto* session = ctx.Session()->GetParent(); | ||||||
|  |     auto* port = session->GetParent()->GetParent(); | ||||||
|  |  | ||||||
|     // Open a reference to the session to simulate a new one being created. |     // Reserve a new session from the process resource limit. | ||||||
|     session->Open(); |     Kernel::KScopedResourceReservation session_reservation( | ||||||
|     session->GetClientSession().Open(); |         kernel.CurrentProcess()->GetResourceLimit(), Kernel::LimitableResource::Sessions); | ||||||
|     session->GetServerSession().Open(); |     if (!session_reservation.Succeeded()) { | ||||||
|  |         IPC::ResponseBuilder rb{ctx, 2}; | ||||||
|  |         rb.Push(Kernel::ResultLimitReached); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // Create a new session. | ||||||
|  |     auto* clone = Kernel::KSession::Create(kernel); | ||||||
|  |     clone->Initialize(&port->GetClientPort(), session->GetName()); | ||||||
|  |  | ||||||
|  |     // Commit the session reservation. | ||||||
|  |     session_reservation.Commit(); | ||||||
|  |  | ||||||
|  |     // Enqueue the session with the named port. | ||||||
|  |     port->EnqueueSession(&clone->GetServerSession()); | ||||||
|  |  | ||||||
|  |     // Set the session request manager. | ||||||
|  |     clone->GetServerSession().SetSessionRequestManager( | ||||||
|  |         session->GetServerSession().GetSessionRequestManager()); | ||||||
|  |  | ||||||
|  |     // We succeeded. | ||||||
|     IPC::ResponseBuilder rb{ctx, 2, 0, 1, IPC::ResponseBuilder::Flags::AlwaysMoveHandles}; |     IPC::ResponseBuilder rb{ctx, 2, 0, 1, IPC::ResponseBuilder::Flags::AlwaysMoveHandles}; | ||||||
|     rb.Push(RESULT_SUCCESS); |     rb.Push(RESULT_SUCCESS); | ||||||
|     rb.PushMoveObjects(session->GetClientSession()); |     rb.PushMoveObjects(clone->GetClientSession()); | ||||||
| } | } | ||||||
|  |  | ||||||
| void Controller::CloneCurrentObjectEx(Kernel::HLERequestContext& ctx) { | void Controller::CloneCurrentObjectEx(Kernel::HLERequestContext& ctx) { | ||||||
|   | |||||||
| @@ -150,31 +150,31 @@ ResultVal<Kernel::KClientSession*> SM::GetServiceImpl(Kernel::HLERequestContext& | |||||||
|     IPC::RequestParser rp{ctx}; |     IPC::RequestParser rp{ctx}; | ||||||
|     std::string name(PopServiceName(rp)); |     std::string name(PopServiceName(rp)); | ||||||
|  |  | ||||||
|  |     // Find the named port. | ||||||
|     auto result = service_manager.GetServicePort(name); |     auto result = service_manager.GetServicePort(name); | ||||||
|     if (result.Failed()) { |     if (result.Failed()) { | ||||||
|         LOG_ERROR(Service_SM, "called service={} -> error 0x{:08X}", name, result.Code().raw); |         LOG_ERROR(Service_SM, "called service={} -> error 0x{:08X}", name, result.Code().raw); | ||||||
|         return result.Code(); |         return result.Code(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     auto* port = result.Unwrap(); |     auto* port = result.Unwrap(); | ||||||
|  |  | ||||||
|  |     // Reserve a new session from the process resource limit. | ||||||
|     Kernel::KScopedResourceReservation session_reservation( |     Kernel::KScopedResourceReservation session_reservation( | ||||||
|         kernel.CurrentProcess()->GetResourceLimit(), Kernel::LimitableResource::Sessions); |         kernel.CurrentProcess()->GetResourceLimit(), Kernel::LimitableResource::Sessions); | ||||||
|     R_UNLESS(session_reservation.Succeeded(), Kernel::ResultLimitReached); |     R_UNLESS(session_reservation.Succeeded(), Kernel::ResultLimitReached); | ||||||
|  |  | ||||||
|  |     // Create a new session. | ||||||
|     auto* session = Kernel::KSession::Create(kernel); |     auto* session = Kernel::KSession::Create(kernel); | ||||||
|     session->Initialize(&port->GetClientPort(), std::move(name)); |     session->Initialize(&port->GetClientPort(), std::move(name)); | ||||||
|  |  | ||||||
|     // Commit the session reservation. |     // Commit the session reservation. | ||||||
|     session_reservation.Commit(); |     session_reservation.Commit(); | ||||||
|  |  | ||||||
|     if (port->GetServerPort().GetHLEHandler()) { |     // Enqueue the session with the named port. | ||||||
|         port->GetServerPort().GetHLEHandler()->ClientConnected(&session->GetServerSession()); |     port->EnqueueSession(&session->GetServerSession()); | ||||||
|     } else { |  | ||||||
|         port->EnqueueSession(&session->GetServerSession()); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     LOG_DEBUG(Service_SM, "called service={} -> session={}", name, session->GetId()); |     LOG_DEBUG(Service_SM, "called service={} -> session={}", name, session->GetId()); | ||||||
|  |  | ||||||
|     return MakeResult(&session->GetClientSession()); |     return MakeResult(&session->GetClientSession()); | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -73,7 +73,7 @@ public: | |||||||
|         if (port == nullptr) { |         if (port == nullptr) { | ||||||
|             return nullptr; |             return nullptr; | ||||||
|         } |         } | ||||||
|         return std::static_pointer_cast<T>(port->GetServerPort().GetHLEHandler()); |         return std::static_pointer_cast<T>(port->GetServerPort().GetSessionRequestHandler()); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     void InvokeControlRequest(Kernel::HLERequestContext& context); |     void InvokeControlRequest(Kernel::HLERequestContext& context); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 bunnei
					bunnei