| @@ -63,6 +63,18 @@ option(YUZU_DOWNLOAD_TIME_ZONE_DATA "Always download time zone binaries" OFF) | ||||
|  | ||||
| CMAKE_DEPENDENT_OPTION(YUZU_USE_FASTER_LD "Check if a faster linker is available" ON "NOT WIN32" OFF) | ||||
|  | ||||
| set(DEFAULT_ENABLE_OPENSSL ON) | ||||
| if (ANDROID OR WIN32 OR APPLE) | ||||
|     # - Windows defaults to the Schannel backend. | ||||
|     # - macOS defaults to the SecureTransport backend. | ||||
|     # - Android currently has no SSL backend as the NDK doesn't include any SSL | ||||
|     #   library; a proper 'native' backend would have to go through Java. | ||||
|     # But you can force builds for those platforms to use OpenSSL if you have | ||||
|     # your own copy of it. | ||||
|     set(DEFAULT_ENABLE_OPENSSL OFF) | ||||
| endif() | ||||
| option(ENABLE_OPENSSL "Enable OpenSSL backend for ISslConnection" ${DEFAULT_ENABLE_OPENSSL}) | ||||
|  | ||||
| # On Android, fetch and compile libcxx before doing anything else | ||||
| if (ANDROID) | ||||
|     set(CMAKE_SKIP_INSTALL_RULES ON) | ||||
| @@ -322,6 +334,10 @@ if (MINGW) | ||||
|     find_library(MSWSOCK_LIBRARY mswsock REQUIRED) | ||||
| endif() | ||||
|  | ||||
| if(ENABLE_OPENSSL) | ||||
|     find_package(OpenSSL 1.1.1 REQUIRED) | ||||
| endif() | ||||
|  | ||||
| # Please consider this as a stub | ||||
| if(ENABLE_QT6 AND Qt6_LOCATION) | ||||
|     list(APPEND CMAKE_PREFIX_PATH "${Qt6_LOCATION}") | ||||
|   | ||||
| @@ -5,15 +5,19 @@ | ||||
|  | ||||
| #include "common/common_types.h" | ||||
|  | ||||
| #include <optional> | ||||
|  | ||||
| namespace Network { | ||||
|  | ||||
| /// Address families | ||||
| enum class Domain : u8 { | ||||
|     INET, ///< Address family for IPv4 | ||||
|     Unspecified, ///< Represents 0, used in getaddrinfo hints | ||||
|     INET,        ///< Address family for IPv4 | ||||
| }; | ||||
|  | ||||
| /// Socket types | ||||
| enum class Type { | ||||
|     Unspecified, ///< Represents 0, used in getaddrinfo hints | ||||
|     STREAM, | ||||
|     DGRAM, | ||||
|     RAW, | ||||
| @@ -22,6 +26,7 @@ enum class Type { | ||||
|  | ||||
| /// Protocol values for sockets | ||||
| enum class Protocol : u8 { | ||||
|     Unspecified, ///< Represents 0, usable in various places | ||||
|     ICMP, | ||||
|     TCP, | ||||
|     UDP, | ||||
| @@ -48,4 +53,13 @@ constexpr u32 FLAG_MSG_PEEK = 0x2; | ||||
| constexpr u32 FLAG_MSG_DONTWAIT = 0x80; | ||||
| constexpr u32 FLAG_O_NONBLOCK = 0x800; | ||||
|  | ||||
| /// Cross-platform addrinfo structure | ||||
| struct AddrInfo { | ||||
|     Domain family; | ||||
|     Type socket_type; | ||||
|     Protocol protocol; | ||||
|     SockAddrIn addr; | ||||
|     std::optional<std::string> canon_name; | ||||
| }; | ||||
|  | ||||
| } // namespace Network | ||||
|   | ||||
| @@ -723,6 +723,7 @@ add_library(core STATIC | ||||
|     hle/service/spl/spl_types.h | ||||
|     hle/service/ssl/ssl.cpp | ||||
|     hle/service/ssl/ssl.h | ||||
|     hle/service/ssl/ssl_backend.h | ||||
|     hle/service/time/clock_types.h | ||||
|     hle/service/time/ephemeral_network_system_clock_context_writer.h | ||||
|     hle/service/time/ephemeral_network_system_clock_core.h | ||||
| @@ -864,6 +865,23 @@ if (ARCHITECTURE_x86_64 OR ARCHITECTURE_arm64) | ||||
|     target_link_libraries(core PRIVATE dynarmic::dynarmic) | ||||
| endif() | ||||
|  | ||||
| if(ENABLE_OPENSSL) | ||||
|     target_sources(core PRIVATE | ||||
|         hle/service/ssl/ssl_backend_openssl.cpp) | ||||
|     target_link_libraries(core PRIVATE OpenSSL::SSL) | ||||
| elseif (APPLE) | ||||
|     target_sources(core PRIVATE | ||||
|         hle/service/ssl/ssl_backend_securetransport.cpp) | ||||
|     target_link_libraries(core PRIVATE "-framework Security") | ||||
| elseif (WIN32) | ||||
|     target_sources(core PRIVATE | ||||
|         hle/service/ssl/ssl_backend_schannel.cpp) | ||||
|     target_link_libraries(core PRIVATE secur32) | ||||
| else() | ||||
|     target_sources(core PRIVATE | ||||
|         hle/service/ssl/ssl_backend_none.cpp) | ||||
| endif() | ||||
|  | ||||
| if (YUZU_USE_PRECOMPILED_HEADERS) | ||||
|     target_precompile_headers(core PRIVATE precompiled_headers.h) | ||||
| endif() | ||||
|   | ||||
| @@ -20,6 +20,9 @@ | ||||
| #include "core/internal_network/sockets.h" | ||||
| #include "network/network.h" | ||||
|  | ||||
| using Common::Expected; | ||||
| using Common::Unexpected; | ||||
|  | ||||
| namespace Service::Sockets { | ||||
|  | ||||
| namespace { | ||||
| @@ -265,16 +268,19 @@ void BSD::GetSockOpt(HLERequestContext& ctx) { | ||||
|     const u32 level = rp.Pop<u32>(); | ||||
|     const auto optname = static_cast<OptName>(rp.Pop<u32>()); | ||||
|  | ||||
|     LOG_WARNING(Service, "(STUBBED) called. fd={} level={} optname=0x{:x}", fd, level, optname); | ||||
|  | ||||
|     std::vector<u8> optval(ctx.GetWriteBufferSize()); | ||||
|  | ||||
|     LOG_DEBUG(Service, "called. fd={} level={} optname=0x{:x} len=0x{:x}", fd, level, optname, | ||||
|               optval.size()); | ||||
|  | ||||
|     const Errno err = GetSockOptImpl(fd, level, optname, optval); | ||||
|  | ||||
|     ctx.WriteBuffer(optval); | ||||
|  | ||||
|     IPC::ResponseBuilder rb{ctx, 5}; | ||||
|     rb.Push(ResultSuccess); | ||||
|     rb.Push<s32>(-1); | ||||
|     rb.PushEnum(Errno::NOTCONN); | ||||
|     rb.Push<s32>(err == Errno::SUCCESS ? 0 : -1); | ||||
|     rb.PushEnum(err); | ||||
|     rb.Push<u32>(static_cast<u32>(optval.size())); | ||||
| } | ||||
|  | ||||
| @@ -436,6 +442,31 @@ void BSD::Close(HLERequestContext& ctx) { | ||||
|     BuildErrnoResponse(ctx, CloseImpl(fd)); | ||||
| } | ||||
|  | ||||
| void BSD::DuplicateSocket(HLERequestContext& ctx) { | ||||
|     struct InputParameters { | ||||
|         s32 fd; | ||||
|         u64 reserved; | ||||
|     }; | ||||
|     static_assert(sizeof(InputParameters) == 0x10); | ||||
|  | ||||
|     struct OutputParameters { | ||||
|         s32 ret; | ||||
|         Errno bsd_errno; | ||||
|     }; | ||||
|     static_assert(sizeof(OutputParameters) == 0x8); | ||||
|  | ||||
|     IPC::RequestParser rp{ctx}; | ||||
|     auto input = rp.PopRaw<InputParameters>(); | ||||
|  | ||||
|     Expected<s32, Errno> res = DuplicateSocketImpl(input.fd); | ||||
|     IPC::ResponseBuilder rb{ctx, 4}; | ||||
|     rb.Push(ResultSuccess); | ||||
|     rb.PushRaw(OutputParameters{ | ||||
|         .ret = res.value_or(0), | ||||
|         .bsd_errno = res ? Errno::SUCCESS : res.error(), | ||||
|     }); | ||||
| } | ||||
|  | ||||
| void BSD::EventFd(HLERequestContext& ctx) { | ||||
|     IPC::RequestParser rp{ctx}; | ||||
|     const u64 initval = rp.Pop<u64>(); | ||||
| @@ -477,12 +508,12 @@ std::pair<s32, Errno> BSD::SocketImpl(Domain domain, Type type, Protocol protoco | ||||
|  | ||||
|     auto room_member = room_network.GetRoomMember().lock(); | ||||
|     if (room_member && room_member->IsConnected()) { | ||||
|         descriptor.socket = std::make_unique<Network::ProxySocket>(room_network); | ||||
|         descriptor.socket = std::make_shared<Network::ProxySocket>(room_network); | ||||
|     } else { | ||||
|         descriptor.socket = std::make_unique<Network::Socket>(); | ||||
|         descriptor.socket = std::make_shared<Network::Socket>(); | ||||
|     } | ||||
|  | ||||
|     descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(type, protocol)); | ||||
|     descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(protocol)); | ||||
|     descriptor.is_connection_based = IsConnectionBased(type); | ||||
|  | ||||
|     return {fd, Errno::SUCCESS}; | ||||
| @@ -538,7 +569,7 @@ std::pair<s32, Errno> BSD::PollImpl(std::vector<u8>& write_buffer, std::span<con | ||||
|     std::transform(fds.begin(), fds.end(), host_pollfds.begin(), [this](PollFD pollfd) { | ||||
|         Network::PollFD result; | ||||
|         result.socket = file_descriptors[pollfd.fd]->socket.get(); | ||||
|         result.events = TranslatePollEventsToHost(pollfd.events); | ||||
|         result.events = Translate(pollfd.events); | ||||
|         result.revents = Network::PollEvents{}; | ||||
|         return result; | ||||
|     }); | ||||
| @@ -547,7 +578,7 @@ std::pair<s32, Errno> BSD::PollImpl(std::vector<u8>& write_buffer, std::span<con | ||||
|  | ||||
|     const size_t num = host_pollfds.size(); | ||||
|     for (size_t i = 0; i < num; ++i) { | ||||
|         fds[i].revents = TranslatePollEventsToGuest(host_pollfds[i].revents); | ||||
|         fds[i].revents = Translate(host_pollfds[i].revents); | ||||
|     } | ||||
|     std::memcpy(write_buffer.data(), fds.data(), length); | ||||
|  | ||||
| @@ -617,7 +648,8 @@ Errno BSD::GetPeerNameImpl(s32 fd, std::vector<u8>& write_buffer) { | ||||
|     } | ||||
|     const SockAddrIn guest_addrin = Translate(addr_in); | ||||
|  | ||||
|     ASSERT(write_buffer.size() == sizeof(guest_addrin)); | ||||
|     ASSERT(write_buffer.size() >= sizeof(guest_addrin)); | ||||
|     write_buffer.resize(sizeof(guest_addrin)); | ||||
|     std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin)); | ||||
|     return Translate(bsd_errno); | ||||
| } | ||||
| @@ -633,7 +665,8 @@ Errno BSD::GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer) { | ||||
|     } | ||||
|     const SockAddrIn guest_addrin = Translate(addr_in); | ||||
|  | ||||
|     ASSERT(write_buffer.size() == sizeof(guest_addrin)); | ||||
|     ASSERT(write_buffer.size() >= sizeof(guest_addrin)); | ||||
|     write_buffer.resize(sizeof(guest_addrin)); | ||||
|     std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin)); | ||||
|     return Translate(bsd_errno); | ||||
| } | ||||
| @@ -671,13 +704,47 @@ std::pair<s32, Errno> BSD::FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg) { | ||||
|     } | ||||
| } | ||||
|  | ||||
| Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) { | ||||
|     UNIMPLEMENTED_IF(level != 0xffff); // SOL_SOCKET | ||||
|  | ||||
| Errno BSD::GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector<u8>& optval) { | ||||
|     if (!IsFileDescriptorValid(fd)) { | ||||
|         return Errno::BADF; | ||||
|     } | ||||
|  | ||||
|     if (level != static_cast<u32>(SocketLevel::SOCKET)) { | ||||
|         UNIMPLEMENTED_MSG("Unknown getsockopt level"); | ||||
|         return Errno::SUCCESS; | ||||
|     } | ||||
|  | ||||
|     Network::SocketBase* const socket = file_descriptors[fd]->socket.get(); | ||||
|  | ||||
|     switch (optname) { | ||||
|     case OptName::ERROR_: { | ||||
|         auto [pending_err, getsockopt_err] = socket->GetPendingError(); | ||||
|         if (getsockopt_err == Network::Errno::SUCCESS) { | ||||
|             Errno translated_pending_err = Translate(pending_err); | ||||
|             ASSERT_OR_EXECUTE_MSG( | ||||
|                 optval.size() == sizeof(Errno), { return Errno::INVAL; }, | ||||
|                 "Incorrect getsockopt option size"); | ||||
|             optval.resize(sizeof(Errno)); | ||||
|             memcpy(optval.data(), &translated_pending_err, sizeof(Errno)); | ||||
|         } | ||||
|         return Translate(getsockopt_err); | ||||
|     } | ||||
|     default: | ||||
|         UNIMPLEMENTED_MSG("Unimplemented optname={}", optname); | ||||
|         return Errno::SUCCESS; | ||||
|     } | ||||
| } | ||||
|  | ||||
| Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) { | ||||
|     if (!IsFileDescriptorValid(fd)) { | ||||
|         return Errno::BADF; | ||||
|     } | ||||
|  | ||||
|     if (level != static_cast<u32>(SocketLevel::SOCKET)) { | ||||
|         UNIMPLEMENTED_MSG("Unknown setsockopt level"); | ||||
|         return Errno::SUCCESS; | ||||
|     } | ||||
|  | ||||
|     Network::SocketBase* const socket = file_descriptors[fd]->socket.get(); | ||||
|  | ||||
|     if (optname == OptName::LINGER) { | ||||
| @@ -711,6 +778,9 @@ Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, con | ||||
|         return Translate(socket->SetSndTimeo(value)); | ||||
|     case OptName::RCVTIMEO: | ||||
|         return Translate(socket->SetRcvTimeo(value)); | ||||
|     case OptName::NOSIGPIPE: | ||||
|         LOG_WARNING(Service, "(STUBBED) setting NOSIGPIPE to {}", value); | ||||
|         return Errno::SUCCESS; | ||||
|     default: | ||||
|         UNIMPLEMENTED_MSG("Unimplemented optname={}", optname); | ||||
|         return Errno::SUCCESS; | ||||
| @@ -841,6 +911,28 @@ Errno BSD::CloseImpl(s32 fd) { | ||||
|     return bsd_errno; | ||||
| } | ||||
|  | ||||
| Expected<s32, Errno> BSD::DuplicateSocketImpl(s32 fd) { | ||||
|     if (!IsFileDescriptorValid(fd)) { | ||||
|         return Unexpected(Errno::BADF); | ||||
|     } | ||||
|  | ||||
|     const s32 new_fd = FindFreeFileDescriptorHandle(); | ||||
|     if (new_fd < 0) { | ||||
|         LOG_ERROR(Service, "No more file descriptors available"); | ||||
|         return Unexpected(Errno::MFILE); | ||||
|     } | ||||
|  | ||||
|     file_descriptors[new_fd] = file_descriptors[fd]; | ||||
|     return new_fd; | ||||
| } | ||||
|  | ||||
| std::optional<std::shared_ptr<Network::SocketBase>> BSD::GetSocket(s32 fd) { | ||||
|     if (!IsFileDescriptorValid(fd)) { | ||||
|         return std::nullopt; | ||||
|     } | ||||
|     return file_descriptors[fd]->socket; | ||||
| } | ||||
|  | ||||
| s32 BSD::FindFreeFileDescriptorHandle() noexcept { | ||||
|     for (s32 fd = 0; fd < static_cast<s32>(file_descriptors.size()); ++fd) { | ||||
|         if (!file_descriptors[fd]) { | ||||
| @@ -911,7 +1003,7 @@ BSD::BSD(Core::System& system_, const char* name) | ||||
|         {24, &BSD::Write, "Write"}, | ||||
|         {25, &BSD::Read, "Read"}, | ||||
|         {26, &BSD::Close, "Close"}, | ||||
|         {27, nullptr, "DuplicateSocket"}, | ||||
|         {27, &BSD::DuplicateSocket, "DuplicateSocket"}, | ||||
|         {28, nullptr, "GetResourceStatistics"}, | ||||
|         {29, nullptr, "RecvMMsg"}, | ||||
|         {30, nullptr, "SendMMsg"}, | ||||
|   | ||||
| @@ -8,6 +8,7 @@ | ||||
| #include <vector> | ||||
|  | ||||
| #include "common/common_types.h" | ||||
| #include "common/expected.h" | ||||
| #include "common/socket_types.h" | ||||
| #include "core/hle/service/service.h" | ||||
| #include "core/hle/service/sockets/sockets.h" | ||||
| @@ -29,12 +30,19 @@ public: | ||||
|     explicit BSD(Core::System& system_, const char* name); | ||||
|     ~BSD() override; | ||||
|  | ||||
|     // These methods are called from SSL; the first two are also called from | ||||
|     // this class for the corresponding IPC methods. | ||||
|     // On the real device, the SSL service makes IPC calls to this service. | ||||
|     Common::Expected<s32, Errno> DuplicateSocketImpl(s32 fd); | ||||
|     Errno CloseImpl(s32 fd); | ||||
|     std::optional<std::shared_ptr<Network::SocketBase>> GetSocket(s32 fd); | ||||
|  | ||||
| private: | ||||
|     /// Maximum number of file descriptors | ||||
|     static constexpr size_t MAX_FD = 128; | ||||
|  | ||||
|     struct FileDescriptor { | ||||
|         std::unique_ptr<Network::SocketBase> socket; | ||||
|         std::shared_ptr<Network::SocketBase> socket; | ||||
|         s32 flags = 0; | ||||
|         bool is_connection_based = false; | ||||
|     }; | ||||
| @@ -138,6 +146,7 @@ private: | ||||
|     void Write(HLERequestContext& ctx); | ||||
|     void Read(HLERequestContext& ctx); | ||||
|     void Close(HLERequestContext& ctx); | ||||
|     void DuplicateSocket(HLERequestContext& ctx); | ||||
|     void EventFd(HLERequestContext& ctx); | ||||
|  | ||||
|     template <typename Work> | ||||
| @@ -153,6 +162,7 @@ private: | ||||
|     Errno GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer); | ||||
|     Errno ListenImpl(s32 fd, s32 backlog); | ||||
|     std::pair<s32, Errno> FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg); | ||||
|     Errno GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector<u8>& optval); | ||||
|     Errno SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval); | ||||
|     Errno ShutdownImpl(s32 fd, s32 how); | ||||
|     std::pair<s32, Errno> RecvImpl(s32 fd, u32 flags, std::vector<u8>& message); | ||||
| @@ -161,7 +171,6 @@ private: | ||||
|     std::pair<s32, Errno> SendImpl(s32 fd, u32 flags, std::span<const u8> message); | ||||
|     std::pair<s32, Errno> SendToImpl(s32 fd, u32 flags, std::span<const u8> message, | ||||
|                                      std::span<const u8> addr); | ||||
|     Errno CloseImpl(s32 fd); | ||||
|  | ||||
|     s32 FindFreeFileDescriptorHandle() noexcept; | ||||
|     bool IsFileDescriptorValid(s32 fd) const noexcept; | ||||
|   | ||||
| @@ -1,10 +1,15 @@ | ||||
| // SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project | ||||
| // SPDX-License-Identifier: GPL-2.0-or-later | ||||
|  | ||||
| #include "core/hle/service/ipc_helpers.h" | ||||
| #include "core/hle/service/sockets/nsd.h" | ||||
|  | ||||
| #include "common/string_util.h" | ||||
|  | ||||
| namespace Service::Sockets { | ||||
|  | ||||
| constexpr Result ResultOverflow{ErrorModule::NSD, 6}; | ||||
|  | ||||
| NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, name} { | ||||
|     // clang-format off | ||||
|     static const FunctionInfo functions[] = { | ||||
| @@ -15,8 +20,8 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na | ||||
|         {13, nullptr, "DeleteSettings"}, | ||||
|         {14, nullptr, "ImportSettings"}, | ||||
|         {15, nullptr, "SetChangeEnvironmentIdentifierDisabled"}, | ||||
|         {20, nullptr, "Resolve"}, | ||||
|         {21, nullptr, "ResolveEx"}, | ||||
|         {20, &NSD::Resolve, "Resolve"}, | ||||
|         {21, &NSD::ResolveEx, "ResolveEx"}, | ||||
|         {30, nullptr, "GetNasServiceSetting"}, | ||||
|         {31, nullptr, "GetNasServiceSettingEx"}, | ||||
|         {40, nullptr, "GetNasRequestFqdn"}, | ||||
| @@ -40,6 +45,55 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na | ||||
|     RegisterHandlers(functions); | ||||
| } | ||||
|  | ||||
| static ResultVal<std::string> ResolveImpl(const std::string& fqdn_in) { | ||||
|     // The real implementation makes various substitutions. | ||||
|     // For now we just return the string as-is, which is good enough when not | ||||
|     // connecting to real Nintendo servers. | ||||
|     LOG_WARNING(Service, "(STUBBED) called, fqdn_in={}", fqdn_in); | ||||
|     return fqdn_in; | ||||
| } | ||||
|  | ||||
| static Result ResolveCommon(const std::string& fqdn_in, std::array<char, 0x100>& fqdn_out) { | ||||
|     const auto res = ResolveImpl(fqdn_in); | ||||
|     if (res.Failed()) { | ||||
|         return res.Code(); | ||||
|     } | ||||
|     if (res->size() >= fqdn_out.size()) { | ||||
|         return ResultOverflow; | ||||
|     } | ||||
|     std::memcpy(fqdn_out.data(), res->c_str(), res->size() + 1); | ||||
|     return ResultSuccess; | ||||
| } | ||||
|  | ||||
| void NSD::Resolve(HLERequestContext& ctx) { | ||||
|     const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0)); | ||||
|  | ||||
|     std::array<char, 0x100> fqdn_out{}; | ||||
|     const Result res = ResolveCommon(fqdn_in, fqdn_out); | ||||
|  | ||||
|     ctx.WriteBuffer(fqdn_out); | ||||
|     IPC::ResponseBuilder rb{ctx, 2}; | ||||
|     rb.Push(res); | ||||
| } | ||||
|  | ||||
| void NSD::ResolveEx(HLERequestContext& ctx) { | ||||
|     const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0)); | ||||
|  | ||||
|     std::array<char, 0x100> fqdn_out; | ||||
|     const Result res = ResolveCommon(fqdn_in, fqdn_out); | ||||
|  | ||||
|     if (res.IsError()) { | ||||
|         IPC::ResponseBuilder rb{ctx, 2}; | ||||
|         rb.Push(res); | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     ctx.WriteBuffer(fqdn_out); | ||||
|     IPC::ResponseBuilder rb{ctx, 4}; | ||||
|     rb.Push(ResultSuccess); | ||||
|     rb.Push(ResultSuccess); | ||||
| } | ||||
|  | ||||
| NSD::~NSD() = default; | ||||
|  | ||||
| } // namespace Service::Sockets | ||||
|   | ||||
| @@ -15,6 +15,10 @@ class NSD final : public ServiceFramework<NSD> { | ||||
| public: | ||||
|     explicit NSD(Core::System& system_, const char* name); | ||||
|     ~NSD() override; | ||||
|  | ||||
| private: | ||||
|     void Resolve(HLERequestContext& ctx); | ||||
|     void ResolveEx(HLERequestContext& ctx); | ||||
| }; | ||||
|  | ||||
| } // namespace Service::Sockets | ||||
|   | ||||
| @@ -10,27 +10,18 @@ | ||||
| #include "core/core.h" | ||||
| #include "core/hle/service/ipc_helpers.h" | ||||
| #include "core/hle/service/sockets/sfdnsres.h" | ||||
| #include "core/hle/service/sockets/sockets.h" | ||||
| #include "core/hle/service/sockets/sockets_translate.h" | ||||
| #include "core/internal_network/network.h" | ||||
| #include "core/memory.h" | ||||
|  | ||||
| #ifdef _WIN32 | ||||
| #include <ws2tcpip.h> | ||||
| #elif YUZU_UNIX | ||||
| #include <arpa/inet.h> | ||||
| #include <netdb.h> | ||||
| #include <netinet/in.h> | ||||
| #include <sys/socket.h> | ||||
| #ifndef EAI_NODATA | ||||
| #define EAI_NODATA EAI_NONAME | ||||
| #endif | ||||
| #endif | ||||
|  | ||||
| namespace Service::Sockets { | ||||
|  | ||||
| SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres"} { | ||||
|     static const FunctionInfo functions[] = { | ||||
|         {0, nullptr, "SetDnsAddressesPrivateRequest"}, | ||||
|         {1, nullptr, "GetDnsAddressPrivateRequest"}, | ||||
|         {2, nullptr, "GetHostByNameRequest"}, | ||||
|         {2, &SFDNSRES::GetHostByNameRequest, "GetHostByNameRequest"}, | ||||
|         {3, nullptr, "GetHostByAddrRequest"}, | ||||
|         {4, nullptr, "GetHostStringErrorRequest"}, | ||||
|         {5, nullptr, "GetGaiStringErrorRequest"}, | ||||
| @@ -38,11 +29,11 @@ SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres" | ||||
|         {7, nullptr, "GetNameInfoRequest"}, | ||||
|         {8, nullptr, "RequestCancelHandleRequest"}, | ||||
|         {9, nullptr, "CancelRequest"}, | ||||
|         {10, nullptr, "GetHostByNameRequestWithOptions"}, | ||||
|         {10, &SFDNSRES::GetHostByNameRequestWithOptions, "GetHostByNameRequestWithOptions"}, | ||||
|         {11, nullptr, "GetHostByAddrRequestWithOptions"}, | ||||
|         {12, &SFDNSRES::GetAddrInfoRequestWithOptions, "GetAddrInfoRequestWithOptions"}, | ||||
|         {13, nullptr, "GetNameInfoRequestWithOptions"}, | ||||
|         {14, nullptr, "ResolverSetOptionRequest"}, | ||||
|         {14, &SFDNSRES::ResolverSetOptionRequest, "ResolverSetOptionRequest"}, | ||||
|         {15, nullptr, "ResolverGetOptionRequest"}, | ||||
|     }; | ||||
|     RegisterHandlers(functions); | ||||
| @@ -59,188 +50,285 @@ enum class NetDbError : s32 { | ||||
|     NoData = 4, | ||||
| }; | ||||
|  | ||||
| static NetDbError AddrInfoErrorToNetDbError(s32 result) { | ||||
|     // Best effort guess to map errors | ||||
| static NetDbError GetAddrInfoErrorToNetDbError(GetAddrInfoError result) { | ||||
|     // These combinations have been verified on console (but are not | ||||
|     // exhaustive). | ||||
|     switch (result) { | ||||
|     case 0: | ||||
|     case GetAddrInfoError::SUCCESS: | ||||
|         return NetDbError::Success; | ||||
|     case EAI_AGAIN: | ||||
|     case GetAddrInfoError::AGAIN: | ||||
|         return NetDbError::TryAgain; | ||||
|     case EAI_NODATA: | ||||
|         return NetDbError::NoData; | ||||
|     case GetAddrInfoError::NODATA: | ||||
|         return NetDbError::HostNotFound; | ||||
|     case GetAddrInfoError::SERVICE: | ||||
|         return NetDbError::Success; | ||||
|     default: | ||||
|         return NetDbError::HostNotFound; | ||||
|     } | ||||
| } | ||||
|  | ||||
| static std::vector<u8> SerializeAddrInfo(const addrinfo* addrinfo, s32 result_code, | ||||
| static Errno GetAddrInfoErrorToErrno(GetAddrInfoError result) { | ||||
|     // These combinations have been verified on console (but are not | ||||
|     // exhaustive). | ||||
|     switch (result) { | ||||
|     case GetAddrInfoError::SUCCESS: | ||||
|         // Note: Sometimes a successful lookup sets errno to EADDRNOTAVAIL for | ||||
|         // some reason, but that doesn't seem useful to implement. | ||||
|         return Errno::SUCCESS; | ||||
|     case GetAddrInfoError::AGAIN: | ||||
|         return Errno::SUCCESS; | ||||
|     case GetAddrInfoError::NODATA: | ||||
|         return Errno::SUCCESS; | ||||
|     case GetAddrInfoError::SERVICE: | ||||
|         return Errno::INVAL; | ||||
|     default: | ||||
|         return Errno::SUCCESS; | ||||
|     } | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| static void Append(std::vector<u8>& vec, T t) { | ||||
|     const size_t offset = vec.size(); | ||||
|     vec.resize(offset + sizeof(T)); | ||||
|     std::memcpy(vec.data() + offset, &t, sizeof(T)); | ||||
| } | ||||
|  | ||||
| static void AppendNulTerminated(std::vector<u8>& vec, std::string_view str) { | ||||
|     const size_t offset = vec.size(); | ||||
|     vec.resize(offset + str.size() + 1); | ||||
|     std::memmove(vec.data() + offset, str.data(), str.size()); | ||||
| } | ||||
|  | ||||
| // We implement gethostbyname using the host's getaddrinfo rather than the | ||||
| // host's gethostbyname, because it simplifies portability: e.g., getaddrinfo | ||||
| // behaves the same on Unix and Windows, unlike gethostbyname where Windows | ||||
| // doesn't implement h_errno. | ||||
| static std::vector<u8> SerializeAddrInfoAsHostEnt(const std::vector<Network::AddrInfo>& vec, | ||||
|                                                   std::string_view host) { | ||||
|  | ||||
|     std::vector<u8> data; | ||||
|     // h_name: use the input hostname (append nul-terminated) | ||||
|     AppendNulTerminated(data, host); | ||||
|     // h_aliases: leave empty | ||||
|  | ||||
|     Append<u32_be>(data, 0); // count of h_aliases | ||||
|     // (If the count were nonzero, the aliases would be appended as nul-terminated here.) | ||||
|     Append<u16_be>(data, static_cast<u16>(Domain::INET)); // h_addrtype | ||||
|     Append<u16_be>(data, sizeof(Network::IPv4Address));   // h_length | ||||
|     // h_addr_list: | ||||
|     size_t count = vec.size(); | ||||
|     ASSERT(count <= UINT32_MAX); | ||||
|     Append<u32_be>(data, static_cast<uint32_t>(count)); | ||||
|     for (const Network::AddrInfo& addrinfo : vec) { | ||||
|         // On the Switch, this is passed through htonl despite already being | ||||
|         // big-endian, so it ends up as little-endian. | ||||
|         Append<u32_le>(data, Network::IPv4AddressToInteger(addrinfo.addr.ip)); | ||||
|  | ||||
|         LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host, | ||||
|                  Network::IPv4AddressToString(addrinfo.addr.ip)); | ||||
|     } | ||||
|     return data; | ||||
| } | ||||
|  | ||||
| static std::pair<u32, GetAddrInfoError> GetHostByNameRequestImpl(HLERequestContext& ctx) { | ||||
|     struct InputParameters { | ||||
|         u8 use_nsd_resolve; | ||||
|         u32 cancel_handle; | ||||
|         u64 process_id; | ||||
|     }; | ||||
|     static_assert(sizeof(InputParameters) == 0x10); | ||||
|  | ||||
|     IPC::RequestParser rp{ctx}; | ||||
|     const auto parameters = rp.PopRaw<InputParameters>(); | ||||
|  | ||||
|     LOG_WARNING( | ||||
|         Service, | ||||
|         "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}", | ||||
|         parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id); | ||||
|  | ||||
|     const auto host_buffer = ctx.ReadBuffer(0); | ||||
|     const std::string host = Common::StringFromBuffer(host_buffer); | ||||
|     // For now, ignore options, which are in input buffer 1 for GetHostByNameRequestWithOptions. | ||||
|  | ||||
|     auto res = Network::GetAddressInfo(host, /*service*/ std::nullopt); | ||||
|     if (!res.has_value()) { | ||||
|         return {0, Translate(res.error())}; | ||||
|     } | ||||
|  | ||||
|     const std::vector<u8> data = SerializeAddrInfoAsHostEnt(res.value(), host); | ||||
|     const u32 data_size = static_cast<u32>(data.size()); | ||||
|     ctx.WriteBuffer(data, 0); | ||||
|  | ||||
|     return {data_size, GetAddrInfoError::SUCCESS}; | ||||
| } | ||||
|  | ||||
| void SFDNSRES::GetHostByNameRequest(HLERequestContext& ctx) { | ||||
|     auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx); | ||||
|  | ||||
|     struct OutputParameters { | ||||
|         NetDbError netdb_error; | ||||
|         Errno bsd_errno; | ||||
|         u32 data_size; | ||||
|     }; | ||||
|     static_assert(sizeof(OutputParameters) == 0xc); | ||||
|  | ||||
|     IPC::ResponseBuilder rb{ctx, 5}; | ||||
|     rb.Push(ResultSuccess); | ||||
|     rb.PushRaw(OutputParameters{ | ||||
|         .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err), | ||||
|         .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), | ||||
|         .data_size = data_size, | ||||
|     }); | ||||
| } | ||||
|  | ||||
| void SFDNSRES::GetHostByNameRequestWithOptions(HLERequestContext& ctx) { | ||||
|     auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx); | ||||
|  | ||||
|     struct OutputParameters { | ||||
|         u32 data_size; | ||||
|         NetDbError netdb_error; | ||||
|         Errno bsd_errno; | ||||
|     }; | ||||
|     static_assert(sizeof(OutputParameters) == 0xc); | ||||
|  | ||||
|     IPC::ResponseBuilder rb{ctx, 5}; | ||||
|     rb.Push(ResultSuccess); | ||||
|     rb.PushRaw(OutputParameters{ | ||||
|         .data_size = data_size, | ||||
|         .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err), | ||||
|         .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), | ||||
|     }); | ||||
| } | ||||
|  | ||||
| static std::vector<u8> SerializeAddrInfo(const std::vector<Network::AddrInfo>& vec, | ||||
|                                          std::string_view host) { | ||||
|     // Adapted from | ||||
|     // https://github.com/switchbrew/libnx/blob/c5a9a909a91657a9818a3b7e18c9b91ff0cbb6e3/nx/source/runtime/resolver.c#L190 | ||||
|     std::vector<u8> data; | ||||
|  | ||||
|     auto* current = addrinfo; | ||||
|     while (current != nullptr) { | ||||
|         struct SerializedResponseHeader { | ||||
|             u32 magic; | ||||
|             s32 flags; | ||||
|             s32 family; | ||||
|             s32 socket_type; | ||||
|             s32 protocol; | ||||
|             u32 address_length; | ||||
|         }; | ||||
|         static_assert(sizeof(SerializedResponseHeader) == 0x18, | ||||
|                       "Response header size must be 0x18 bytes"); | ||||
|     for (const Network::AddrInfo& addrinfo : vec) { | ||||
|         // serialized addrinfo: | ||||
|         Append<u32_be>(data, 0xBEEFCAFE);                                        // magic | ||||
|         Append<u32_be>(data, 0);                                                 // ai_flags | ||||
|         Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.family)));      // ai_family | ||||
|         Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.socket_type))); // ai_socktype | ||||
|         Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.protocol)));    // ai_protocol | ||||
|         Append<u32_be>(data, sizeof(SockAddrIn));                                // ai_addrlen | ||||
|         // ^ *not* sizeof(SerializedSockAddrIn), not that it matters since they're the same size | ||||
|  | ||||
|         constexpr auto header_size = sizeof(SerializedResponseHeader); | ||||
|         const auto addr_size = | ||||
|             current->ai_addr && current->ai_addrlen > 0 ? current->ai_addrlen : 4; | ||||
|         const auto canonname_size = current->ai_canonname ? strlen(current->ai_canonname) + 1 : 1; | ||||
|         // ai_addr: | ||||
|         Append<u16_be>(data, static_cast<u16>(Translate(addrinfo.addr.family))); // sin_family | ||||
|         // On the Switch, the following fields are passed through htonl despite | ||||
|         // already being big-endian, so they end up as little-endian. | ||||
|         Append<u16_le>(data, addrinfo.addr.portno);                            // sin_port | ||||
|         Append<u32_le>(data, Network::IPv4AddressToInteger(addrinfo.addr.ip)); // sin_addr | ||||
|         data.resize(data.size() + 8, 0);                                       // sin_zero | ||||
|  | ||||
|         const auto last_size = data.size(); | ||||
|         data.resize(last_size + header_size + addr_size + canonname_size); | ||||
|  | ||||
|         // Header in network byte order | ||||
|         SerializedResponseHeader header{}; | ||||
|  | ||||
|         constexpr auto HEADER_MAGIC = 0xBEEFCAFE; | ||||
|         header.magic = htonl(HEADER_MAGIC); | ||||
|         header.family = htonl(current->ai_family); | ||||
|         header.flags = htonl(current->ai_flags); | ||||
|         header.socket_type = htonl(current->ai_socktype); | ||||
|         header.protocol = htonl(current->ai_protocol); | ||||
|         header.address_length = current->ai_addr ? htonl((u32)current->ai_addrlen) : 0; | ||||
|  | ||||
|         auto* header_ptr = data.data() + last_size; | ||||
|         std::memcpy(header_ptr, &header, header_size); | ||||
|  | ||||
|         if (header.address_length == 0) { | ||||
|             std::memset(header_ptr + header_size, 0, 4); | ||||
|         if (addrinfo.canon_name.has_value()) { | ||||
|             AppendNulTerminated(data, *addrinfo.canon_name); | ||||
|         } else { | ||||
|             switch (current->ai_family) { | ||||
|             case AF_INET: { | ||||
|                 struct SockAddrIn { | ||||
|                     s16 sin_family; | ||||
|                     u16 sin_port; | ||||
|                     u32 sin_addr; | ||||
|                     u8 sin_zero[8]; | ||||
|                 }; | ||||
|  | ||||
|                 SockAddrIn serialized_addr{}; | ||||
|                 const auto addr = *reinterpret_cast<sockaddr_in*>(current->ai_addr); | ||||
|                 serialized_addr.sin_port = htons(addr.sin_port); | ||||
|                 serialized_addr.sin_family = htons(addr.sin_family); | ||||
|                 serialized_addr.sin_addr = htonl(addr.sin_addr.s_addr); | ||||
|                 std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn)); | ||||
|  | ||||
|                 char addr_string_buf[64]{}; | ||||
|                 inet_ntop(AF_INET, &addr.sin_addr, addr_string_buf, std::size(addr_string_buf)); | ||||
|                 LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host, addr_string_buf); | ||||
|                 break; | ||||
|             } | ||||
|             case AF_INET6: { | ||||
|                 struct SockAddrIn6 { | ||||
|                     s16 sin6_family; | ||||
|                     u16 sin6_port; | ||||
|                     u32 sin6_flowinfo; | ||||
|                     u8 sin6_addr[16]; | ||||
|                     u32 sin6_scope_id; | ||||
|                 }; | ||||
|  | ||||
|                 SockAddrIn6 serialized_addr{}; | ||||
|                 const auto addr = *reinterpret_cast<sockaddr_in6*>(current->ai_addr); | ||||
|                 serialized_addr.sin6_family = htons(addr.sin6_family); | ||||
|                 serialized_addr.sin6_port = htons(addr.sin6_port); | ||||
|                 serialized_addr.sin6_flowinfo = htonl(addr.sin6_flowinfo); | ||||
|                 serialized_addr.sin6_scope_id = htonl(addr.sin6_scope_id); | ||||
|                 std::memcpy(serialized_addr.sin6_addr, &addr.sin6_addr, | ||||
|                             sizeof(SockAddrIn6::sin6_addr)); | ||||
|                 std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn6)); | ||||
|  | ||||
|                 char addr_string_buf[64]{}; | ||||
|                 inet_ntop(AF_INET6, &addr.sin6_addr, addr_string_buf, std::size(addr_string_buf)); | ||||
|                 LOG_INFO(Service, "Resolved host '{}' to IPv6 address {}", host, addr_string_buf); | ||||
|                 break; | ||||
|             } | ||||
|             default: | ||||
|                 std::memcpy(header_ptr + header_size, current->ai_addr, addr_size); | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
|         if (current->ai_canonname) { | ||||
|             std::memcpy(header_ptr + addr_size, current->ai_canonname, canonname_size); | ||||
|         } else { | ||||
|             *(header_ptr + header_size + addr_size) = 0; | ||||
|             data.push_back(0); | ||||
|         } | ||||
|  | ||||
|         current = current->ai_next; | ||||
|         LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host, | ||||
|                  Network::IPv4AddressToString(addrinfo.addr.ip)); | ||||
|     } | ||||
|  | ||||
|     // 4-byte sentinel value | ||||
|     data.push_back(0); | ||||
|     data.push_back(0); | ||||
|     data.push_back(0); | ||||
|     data.push_back(0); | ||||
|     data.resize(data.size() + 4, 0); // 4-byte sentinel value | ||||
|  | ||||
|     return data; | ||||
| } | ||||
|  | ||||
| static std::pair<u32, s32> GetAddrInfoRequestImpl(HLERequestContext& ctx) { | ||||
|     struct Parameters { | ||||
| static std::pair<u32, GetAddrInfoError> GetAddrInfoRequestImpl(HLERequestContext& ctx) { | ||||
|     struct InputParameters { | ||||
|         u8 use_nsd_resolve; | ||||
|         u32 unknown; | ||||
|         u32 cancel_handle; | ||||
|         u64 process_id; | ||||
|     }; | ||||
|     static_assert(sizeof(InputParameters) == 0x10); | ||||
|  | ||||
|     IPC::RequestParser rp{ctx}; | ||||
|     const auto parameters = rp.PopRaw<Parameters>(); | ||||
|     const auto parameters = rp.PopRaw<InputParameters>(); | ||||
|  | ||||
|     LOG_WARNING(Service, | ||||
|                 "called with ignored parameters: use_nsd_resolve={}, unknown={}, process_id={}", | ||||
|                 parameters.use_nsd_resolve, parameters.unknown, parameters.process_id); | ||||
|     LOG_WARNING( | ||||
|         Service, | ||||
|         "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}", | ||||
|         parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id); | ||||
|  | ||||
|     // TODO: If use_nsd_resolve is true, pass the name through NSD::Resolve | ||||
|     // before looking up. | ||||
|  | ||||
|     const auto host_buffer = ctx.ReadBuffer(0); | ||||
|     const std::string host = Common::StringFromBuffer(host_buffer); | ||||
|  | ||||
|     const auto service_buffer = ctx.ReadBuffer(1); | ||||
|     const std::string service = Common::StringFromBuffer(service_buffer); | ||||
|  | ||||
|     addrinfo* addrinfo; | ||||
|     // Pass null for hints. Serialized hints are also passed in a buffer, but are ignored for now | ||||
|     s32 result_code = getaddrinfo(host.c_str(), service.c_str(), nullptr, &addrinfo); | ||||
|  | ||||
|     u32 data_size = 0; | ||||
|     if (result_code == 0 && addrinfo != nullptr) { | ||||
|         const std::vector<u8>& data = SerializeAddrInfo(addrinfo, result_code, host); | ||||
|         data_size = static_cast<u32>(data.size()); | ||||
|         freeaddrinfo(addrinfo); | ||||
|  | ||||
|         ctx.WriteBuffer(data, 0); | ||||
|     std::optional<std::string> service = std::nullopt; | ||||
|     if (ctx.CanReadBuffer(1)) { | ||||
|         const std::span<const u8> service_buffer = ctx.ReadBuffer(1); | ||||
|         service = Common::StringFromBuffer(service_buffer); | ||||
|     } | ||||
|  | ||||
|     return std::make_pair(data_size, result_code); | ||||
|     // Serialized hints are also passed in a buffer, but are ignored for now. | ||||
|  | ||||
|     auto res = Network::GetAddressInfo(host, service); | ||||
|     if (!res.has_value()) { | ||||
|         return {0, Translate(res.error())}; | ||||
|     } | ||||
|  | ||||
|     const std::vector<u8> data = SerializeAddrInfo(res.value(), host); | ||||
|     const u32 data_size = static_cast<u32>(data.size()); | ||||
|     ctx.WriteBuffer(data, 0); | ||||
|  | ||||
|     return {data_size, GetAddrInfoError::SUCCESS}; | ||||
| } | ||||
|  | ||||
| void SFDNSRES::GetAddrInfoRequest(HLERequestContext& ctx) { | ||||
|     auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx); | ||||
|     auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx); | ||||
|  | ||||
|     IPC::ResponseBuilder rb{ctx, 4}; | ||||
|     struct OutputParameters { | ||||
|         Errno bsd_errno; | ||||
|         GetAddrInfoError gai_error; | ||||
|         u32 data_size; | ||||
|     }; | ||||
|     static_assert(sizeof(OutputParameters) == 0xc); | ||||
|  | ||||
|     IPC::ResponseBuilder rb{ctx, 5}; | ||||
|     rb.Push(ResultSuccess); | ||||
|     rb.Push(static_cast<s32>(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode | ||||
|     rb.Push(result_code);                                              // errno | ||||
|     rb.Push(data_size);                                                // serialized size | ||||
|     rb.PushRaw(OutputParameters{ | ||||
|         .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), | ||||
|         .gai_error = emu_gai_err, | ||||
|         .data_size = data_size, | ||||
|     }); | ||||
| } | ||||
|  | ||||
| void SFDNSRES::GetAddrInfoRequestWithOptions(HLERequestContext& ctx) { | ||||
|     // Additional options are ignored | ||||
|     auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx); | ||||
|     auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx); | ||||
|  | ||||
|     IPC::ResponseBuilder rb{ctx, 5}; | ||||
|     struct OutputParameters { | ||||
|         u32 data_size; | ||||
|         GetAddrInfoError gai_error; | ||||
|         NetDbError netdb_error; | ||||
|         Errno bsd_errno; | ||||
|     }; | ||||
|     static_assert(sizeof(OutputParameters) == 0x10); | ||||
|  | ||||
|     IPC::ResponseBuilder rb{ctx, 6}; | ||||
|     rb.Push(ResultSuccess); | ||||
|     rb.Push(data_size);                                                // serialized size | ||||
|     rb.Push(result_code);                                              // errno | ||||
|     rb.Push(static_cast<s32>(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode | ||||
|     rb.Push(0); | ||||
|     rb.PushRaw(OutputParameters{ | ||||
|         .data_size = data_size, | ||||
|         .gai_error = emu_gai_err, | ||||
|         .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err), | ||||
|         .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), | ||||
|     }); | ||||
| } | ||||
|  | ||||
| void SFDNSRES::ResolverSetOptionRequest(HLERequestContext& ctx) { | ||||
|     LOG_WARNING(Service, "(STUBBED) called"); | ||||
|  | ||||
|     IPC::ResponseBuilder rb{ctx, 3}; | ||||
|  | ||||
|     rb.Push(ResultSuccess); | ||||
|     rb.Push<s32>(0); // bsd errno | ||||
| } | ||||
|  | ||||
| } // namespace Service::Sockets | ||||
|   | ||||
| @@ -17,8 +17,11 @@ public: | ||||
|     ~SFDNSRES() override; | ||||
|  | ||||
| private: | ||||
|     void GetHostByNameRequest(HLERequestContext& ctx); | ||||
|     void GetHostByNameRequestWithOptions(HLERequestContext& ctx); | ||||
|     void GetAddrInfoRequest(HLERequestContext& ctx); | ||||
|     void GetAddrInfoRequestWithOptions(HLERequestContext& ctx); | ||||
|     void ResolverSetOptionRequest(HLERequestContext& ctx); | ||||
| }; | ||||
|  | ||||
| } // namespace Service::Sockets | ||||
|   | ||||
| @@ -22,13 +22,35 @@ enum class Errno : u32 { | ||||
|     CONNRESET = 104, | ||||
|     NOTCONN = 107, | ||||
|     TIMEDOUT = 110, | ||||
|     INPROGRESS = 115, | ||||
| }; | ||||
|  | ||||
| enum class GetAddrInfoError : s32 { | ||||
|     SUCCESS = 0, | ||||
|     ADDRFAMILY = 1, | ||||
|     AGAIN = 2, | ||||
|     BADFLAGS = 3, | ||||
|     FAIL = 4, | ||||
|     FAMILY = 5, | ||||
|     MEMORY = 6, | ||||
|     NODATA = 7, | ||||
|     NONAME = 8, | ||||
|     SERVICE = 9, | ||||
|     SOCKTYPE = 10, | ||||
|     SYSTEM = 11, | ||||
|     BADHINTS = 12, | ||||
|     PROTOCOL = 13, | ||||
|     OVERFLOW_ = 14, // avoid name collision with Windows macro | ||||
|     OTHER = 15, | ||||
| }; | ||||
|  | ||||
| enum class Domain : u32 { | ||||
|     Unspecified = 0, | ||||
|     INET = 2, | ||||
| }; | ||||
|  | ||||
| enum class Type : u32 { | ||||
|     Unspecified = 0, | ||||
|     STREAM = 1, | ||||
|     DGRAM = 2, | ||||
|     RAW = 3, | ||||
| @@ -36,12 +58,16 @@ enum class Type : u32 { | ||||
| }; | ||||
|  | ||||
| enum class Protocol : u32 { | ||||
|     UNSPECIFIED = 0, | ||||
|     Unspecified = 0, | ||||
|     ICMP = 1, | ||||
|     TCP = 6, | ||||
|     UDP = 17, | ||||
| }; | ||||
|  | ||||
| enum class SocketLevel : u32 { | ||||
|     SOCKET = 0xffff, // i.e. SOL_SOCKET | ||||
| }; | ||||
|  | ||||
| enum class OptName : u32 { | ||||
|     REUSEADDR = 0x4, | ||||
|     KEEPALIVE = 0x8, | ||||
| @@ -51,6 +77,8 @@ enum class OptName : u32 { | ||||
|     RCVBUF = 0x1002, | ||||
|     SNDTIMEO = 0x1005, | ||||
|     RCVTIMEO = 0x1006, | ||||
|     ERROR_ = 0x1007,   // avoid name collision with Windows macro | ||||
|     NOSIGPIPE = 0x800, // at least according to libnx | ||||
| }; | ||||
|  | ||||
| enum class ShutdownHow : s32 { | ||||
| @@ -80,6 +108,9 @@ enum class PollEvents : u16 { | ||||
|     Err = 1 << 3, | ||||
|     Hup = 1 << 4, | ||||
|     Nval = 1 << 5, | ||||
|     RdNorm = 1 << 6, | ||||
|     RdBand = 1 << 7, | ||||
|     WrBand = 1 << 8, | ||||
| }; | ||||
|  | ||||
| DECLARE_ENUM_FLAG_OPERATORS(PollEvents); | ||||
|   | ||||
| @@ -29,6 +29,8 @@ Errno Translate(Network::Errno value) { | ||||
|         return Errno::TIMEDOUT; | ||||
|     case Network::Errno::CONNRESET: | ||||
|         return Errno::CONNRESET; | ||||
|     case Network::Errno::INPROGRESS: | ||||
|         return Errno::INPROGRESS; | ||||
|     default: | ||||
|         UNIMPLEMENTED_MSG("Unimplemented errno={}", value); | ||||
|         return Errno::SUCCESS; | ||||
| @@ -39,8 +41,50 @@ std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value) { | ||||
|     return {value.first, Translate(value.second)}; | ||||
| } | ||||
|  | ||||
| GetAddrInfoError Translate(Network::GetAddrInfoError error) { | ||||
|     switch (error) { | ||||
|     case Network::GetAddrInfoError::SUCCESS: | ||||
|         return GetAddrInfoError::SUCCESS; | ||||
|     case Network::GetAddrInfoError::ADDRFAMILY: | ||||
|         return GetAddrInfoError::ADDRFAMILY; | ||||
|     case Network::GetAddrInfoError::AGAIN: | ||||
|         return GetAddrInfoError::AGAIN; | ||||
|     case Network::GetAddrInfoError::BADFLAGS: | ||||
|         return GetAddrInfoError::BADFLAGS; | ||||
|     case Network::GetAddrInfoError::FAIL: | ||||
|         return GetAddrInfoError::FAIL; | ||||
|     case Network::GetAddrInfoError::FAMILY: | ||||
|         return GetAddrInfoError::FAMILY; | ||||
|     case Network::GetAddrInfoError::MEMORY: | ||||
|         return GetAddrInfoError::MEMORY; | ||||
|     case Network::GetAddrInfoError::NODATA: | ||||
|         return GetAddrInfoError::NODATA; | ||||
|     case Network::GetAddrInfoError::NONAME: | ||||
|         return GetAddrInfoError::NONAME; | ||||
|     case Network::GetAddrInfoError::SERVICE: | ||||
|         return GetAddrInfoError::SERVICE; | ||||
|     case Network::GetAddrInfoError::SOCKTYPE: | ||||
|         return GetAddrInfoError::SOCKTYPE; | ||||
|     case Network::GetAddrInfoError::SYSTEM: | ||||
|         return GetAddrInfoError::SYSTEM; | ||||
|     case Network::GetAddrInfoError::BADHINTS: | ||||
|         return GetAddrInfoError::BADHINTS; | ||||
|     case Network::GetAddrInfoError::PROTOCOL: | ||||
|         return GetAddrInfoError::PROTOCOL; | ||||
|     case Network::GetAddrInfoError::OVERFLOW_: | ||||
|         return GetAddrInfoError::OVERFLOW_; | ||||
|     case Network::GetAddrInfoError::OTHER: | ||||
|         return GetAddrInfoError::OTHER; | ||||
|     default: | ||||
|         UNIMPLEMENTED_MSG("Unimplemented GetAddrInfoError={}", error); | ||||
|         return GetAddrInfoError::OTHER; | ||||
|     } | ||||
| } | ||||
|  | ||||
| Network::Domain Translate(Domain domain) { | ||||
|     switch (domain) { | ||||
|     case Domain::Unspecified: | ||||
|         return Network::Domain::Unspecified; | ||||
|     case Domain::INET: | ||||
|         return Network::Domain::INET; | ||||
|     default: | ||||
| @@ -51,6 +95,8 @@ Network::Domain Translate(Domain domain) { | ||||
|  | ||||
| Domain Translate(Network::Domain domain) { | ||||
|     switch (domain) { | ||||
|     case Network::Domain::Unspecified: | ||||
|         return Domain::Unspecified; | ||||
|     case Network::Domain::INET: | ||||
|         return Domain::INET; | ||||
|     default: | ||||
| @@ -61,39 +107,69 @@ Domain Translate(Network::Domain domain) { | ||||
|  | ||||
| Network::Type Translate(Type type) { | ||||
|     switch (type) { | ||||
|     case Type::Unspecified: | ||||
|         return Network::Type::Unspecified; | ||||
|     case Type::STREAM: | ||||
|         return Network::Type::STREAM; | ||||
|     case Type::DGRAM: | ||||
|         return Network::Type::DGRAM; | ||||
|     case Type::RAW: | ||||
|         return Network::Type::RAW; | ||||
|     case Type::SEQPACKET: | ||||
|         return Network::Type::SEQPACKET; | ||||
|     default: | ||||
|         UNIMPLEMENTED_MSG("Unimplemented type={}", type); | ||||
|         return Network::Type{}; | ||||
|     } | ||||
| } | ||||
|  | ||||
| Network::Protocol Translate(Type type, Protocol protocol) { | ||||
| Type Translate(Network::Type type) { | ||||
|     switch (type) { | ||||
|     case Network::Type::Unspecified: | ||||
|         return Type::Unspecified; | ||||
|     case Network::Type::STREAM: | ||||
|         return Type::STREAM; | ||||
|     case Network::Type::DGRAM: | ||||
|         return Type::DGRAM; | ||||
|     case Network::Type::RAW: | ||||
|         return Type::RAW; | ||||
|     case Network::Type::SEQPACKET: | ||||
|         return Type::SEQPACKET; | ||||
|     default: | ||||
|         UNIMPLEMENTED_MSG("Unimplemented type={}", type); | ||||
|         return Type{}; | ||||
|     } | ||||
| } | ||||
|  | ||||
| Network::Protocol Translate(Protocol protocol) { | ||||
|     switch (protocol) { | ||||
|     case Protocol::UNSPECIFIED: | ||||
|         LOG_WARNING(Service, "Unspecified protocol, assuming protocol from type"); | ||||
|         switch (type) { | ||||
|         case Type::DGRAM: | ||||
|             return Network::Protocol::UDP; | ||||
|         case Type::STREAM: | ||||
|             return Network::Protocol::TCP; | ||||
|         default: | ||||
|             return Network::Protocol::TCP; | ||||
|         } | ||||
|     case Protocol::Unspecified: | ||||
|         return Network::Protocol::Unspecified; | ||||
|     case Protocol::TCP: | ||||
|         return Network::Protocol::TCP; | ||||
|     case Protocol::UDP: | ||||
|         return Network::Protocol::UDP; | ||||
|     default: | ||||
|         UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol); | ||||
|         return Network::Protocol::TCP; | ||||
|         return Network::Protocol::Unspecified; | ||||
|     } | ||||
| } | ||||
|  | ||||
| Network::PollEvents TranslatePollEventsToHost(PollEvents flags) { | ||||
| Protocol Translate(Network::Protocol protocol) { | ||||
|     switch (protocol) { | ||||
|     case Network::Protocol::Unspecified: | ||||
|         return Protocol::Unspecified; | ||||
|     case Network::Protocol::TCP: | ||||
|         return Protocol::TCP; | ||||
|     case Network::Protocol::UDP: | ||||
|         return Protocol::UDP; | ||||
|     default: | ||||
|         UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol); | ||||
|         return Protocol::Unspecified; | ||||
|     } | ||||
| } | ||||
|  | ||||
| Network::PollEvents Translate(PollEvents flags) { | ||||
|     Network::PollEvents result{}; | ||||
|     const auto translate = [&result, &flags](PollEvents from, Network::PollEvents to) { | ||||
|         if (True(flags & from)) { | ||||
| @@ -107,12 +183,15 @@ Network::PollEvents TranslatePollEventsToHost(PollEvents flags) { | ||||
|     translate(PollEvents::Err, Network::PollEvents::Err); | ||||
|     translate(PollEvents::Hup, Network::PollEvents::Hup); | ||||
|     translate(PollEvents::Nval, Network::PollEvents::Nval); | ||||
|     translate(PollEvents::RdNorm, Network::PollEvents::RdNorm); | ||||
|     translate(PollEvents::RdBand, Network::PollEvents::RdBand); | ||||
|     translate(PollEvents::WrBand, Network::PollEvents::WrBand); | ||||
|  | ||||
|     UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags); | ||||
|     return result; | ||||
| } | ||||
|  | ||||
| PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) { | ||||
| PollEvents Translate(Network::PollEvents flags) { | ||||
|     PollEvents result{}; | ||||
|     const auto translate = [&result, &flags](Network::PollEvents from, PollEvents to) { | ||||
|         if (True(flags & from)) { | ||||
| @@ -127,13 +206,18 @@ PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) { | ||||
|     translate(Network::PollEvents::Err, PollEvents::Err); | ||||
|     translate(Network::PollEvents::Hup, PollEvents::Hup); | ||||
|     translate(Network::PollEvents::Nval, PollEvents::Nval); | ||||
|     translate(Network::PollEvents::RdNorm, PollEvents::RdNorm); | ||||
|     translate(Network::PollEvents::RdBand, PollEvents::RdBand); | ||||
|     translate(Network::PollEvents::WrBand, PollEvents::WrBand); | ||||
|  | ||||
|     UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags); | ||||
|     return result; | ||||
| } | ||||
|  | ||||
| Network::SockAddrIn Translate(SockAddrIn value) { | ||||
|     ASSERT(value.len == 0 || value.len == sizeof(value)); | ||||
|     // Note: 6 is incorrect, but can be passed by homebrew (because libnx sets | ||||
|     // sin_len to 6 when deserializing getaddrinfo results). | ||||
|     ASSERT(value.len == 0 || value.len == sizeof(value) || value.len == 6); | ||||
|  | ||||
|     return { | ||||
|         .family = Translate(static_cast<Domain>(value.family)), | ||||
|   | ||||
| @@ -17,6 +17,9 @@ Errno Translate(Network::Errno value); | ||||
| /// Translate abstract return value errno pair to guest return value errno pair | ||||
| std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value); | ||||
|  | ||||
| /// Translate abstract getaddrinfo error to guest getaddrinfo error | ||||
| GetAddrInfoError Translate(Network::GetAddrInfoError value); | ||||
|  | ||||
| /// Translate guest domain to abstract domain | ||||
| Network::Domain Translate(Domain domain); | ||||
|  | ||||
| @@ -26,14 +29,20 @@ Domain Translate(Network::Domain domain); | ||||
| /// Translate guest type to abstract type | ||||
| Network::Type Translate(Type type); | ||||
|  | ||||
| /// Translate guest protocol to abstract protocol | ||||
| Network::Protocol Translate(Type type, Protocol protocol); | ||||
| /// Translate abstract type to guest type | ||||
| Type Translate(Network::Type type); | ||||
|  | ||||
| /// Translate abstract poll event flags to guest poll event flags | ||||
| Network::PollEvents TranslatePollEventsToHost(PollEvents flags); | ||||
| /// Translate guest protocol to abstract protocol | ||||
| Network::Protocol Translate(Protocol protocol); | ||||
|  | ||||
| /// Translate abstract protocol to guest protocol | ||||
| Protocol Translate(Network::Protocol protocol); | ||||
|  | ||||
| /// Translate guest poll event flags to abstract poll event flags | ||||
| PollEvents TranslatePollEventsToGuest(Network::PollEvents flags); | ||||
| Network::PollEvents Translate(PollEvents flags); | ||||
|  | ||||
| /// Translate abstract poll event flags to guest poll event flags | ||||
| PollEvents Translate(Network::PollEvents flags); | ||||
|  | ||||
| /// Translate guest socket address structure to abstract socket address structure | ||||
| Network::SockAddrIn Translate(SockAddrIn value); | ||||
|   | ||||
| @@ -1,10 +1,18 @@ | ||||
| // SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project | ||||
| // SPDX-License-Identifier: GPL-2.0-or-later | ||||
|  | ||||
| #include "common/string_util.h" | ||||
|  | ||||
| #include "core/core.h" | ||||
| #include "core/hle/service/ipc_helpers.h" | ||||
| #include "core/hle/service/server_manager.h" | ||||
| #include "core/hle/service/service.h" | ||||
| #include "core/hle/service/sm/sm.h" | ||||
| #include "core/hle/service/sockets/bsd.h" | ||||
| #include "core/hle/service/ssl/ssl.h" | ||||
| #include "core/hle/service/ssl/ssl_backend.h" | ||||
| #include "core/internal_network/network.h" | ||||
| #include "core/internal_network/sockets.h" | ||||
|  | ||||
| namespace Service::SSL { | ||||
|  | ||||
| @@ -20,6 +28,18 @@ enum class ContextOption : u32 { | ||||
|     CrlImportDateCheckEnable = 1, | ||||
| }; | ||||
|  | ||||
| // This is nn::ssl::Connection::IoMode | ||||
| enum class IoMode : u32 { | ||||
|     Blocking = 1, | ||||
|     NonBlocking = 2, | ||||
| }; | ||||
|  | ||||
| // This is nn::ssl::sf::OptionType | ||||
| enum class OptionType : u32 { | ||||
|     DoNotCloseSocket = 0, | ||||
|     GetServerCertChain = 1, | ||||
| }; | ||||
|  | ||||
| // This is nn::ssl::sf::SslVersion | ||||
| struct SslVersion { | ||||
|     union { | ||||
| @@ -34,35 +54,42 @@ struct SslVersion { | ||||
|     }; | ||||
| }; | ||||
|  | ||||
| struct SslContextSharedData { | ||||
|     u32 connection_count = 0; | ||||
| }; | ||||
|  | ||||
| class ISslConnection final : public ServiceFramework<ISslConnection> { | ||||
| public: | ||||
|     explicit ISslConnection(Core::System& system_, SslVersion version) | ||||
|         : ServiceFramework{system_, "ISslConnection"}, ssl_version{version} { | ||||
|     explicit ISslConnection(Core::System& system_in, SslVersion ssl_version_in, | ||||
|                             std::shared_ptr<SslContextSharedData>& shared_data_in, | ||||
|                             std::unique_ptr<SSLConnectionBackend>&& backend_in) | ||||
|         : ServiceFramework{system_in, "ISslConnection"}, ssl_version{ssl_version_in}, | ||||
|           shared_data{shared_data_in}, backend{std::move(backend_in)} { | ||||
|         // clang-format off | ||||
|         static const FunctionInfo functions[] = { | ||||
|             {0, nullptr, "SetSocketDescriptor"}, | ||||
|             {1, nullptr, "SetHostName"}, | ||||
|             {2, nullptr, "SetVerifyOption"}, | ||||
|             {3, nullptr, "SetIoMode"}, | ||||
|             {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"}, | ||||
|             {1, &ISslConnection::SetHostName, "SetHostName"}, | ||||
|             {2, &ISslConnection::SetVerifyOption, "SetVerifyOption"}, | ||||
|             {3, &ISslConnection::SetIoMode, "SetIoMode"}, | ||||
|             {4, nullptr, "GetSocketDescriptor"}, | ||||
|             {5, nullptr, "GetHostName"}, | ||||
|             {6, nullptr, "GetVerifyOption"}, | ||||
|             {7, nullptr, "GetIoMode"}, | ||||
|             {8, nullptr, "DoHandshake"}, | ||||
|             {9, nullptr, "DoHandshakeGetServerCert"}, | ||||
|             {10, nullptr, "Read"}, | ||||
|             {11, nullptr, "Write"}, | ||||
|             {12, nullptr, "Pending"}, | ||||
|             {8, &ISslConnection::DoHandshake, "DoHandshake"}, | ||||
|             {9, &ISslConnection::DoHandshakeGetServerCert, "DoHandshakeGetServerCert"}, | ||||
|             {10, &ISslConnection::Read, "Read"}, | ||||
|             {11, &ISslConnection::Write, "Write"}, | ||||
|             {12, &ISslConnection::Pending, "Pending"}, | ||||
|             {13, nullptr, "Peek"}, | ||||
|             {14, nullptr, "Poll"}, | ||||
|             {15, nullptr, "GetVerifyCertError"}, | ||||
|             {16, nullptr, "GetNeededServerCertBufferSize"}, | ||||
|             {17, nullptr, "SetSessionCacheMode"}, | ||||
|             {17, &ISslConnection::SetSessionCacheMode, "SetSessionCacheMode"}, | ||||
|             {18, nullptr, "GetSessionCacheMode"}, | ||||
|             {19, nullptr, "FlushSessionCache"}, | ||||
|             {20, nullptr, "SetRenegotiationMode"}, | ||||
|             {21, nullptr, "GetRenegotiationMode"}, | ||||
|             {22, nullptr, "SetOption"}, | ||||
|             {22, &ISslConnection::SetOption, "SetOption"}, | ||||
|             {23, nullptr, "GetOption"}, | ||||
|             {24, nullptr, "GetVerifyCertErrors"}, | ||||
|             {25, nullptr, "GetCipherInfo"}, | ||||
| @@ -80,21 +107,299 @@ public: | ||||
|         // clang-format on | ||||
|  | ||||
|         RegisterHandlers(functions); | ||||
|  | ||||
|         shared_data->connection_count++; | ||||
|     } | ||||
|  | ||||
|     ~ISslConnection() { | ||||
|         shared_data->connection_count--; | ||||
|         if (fd_to_close.has_value()) { | ||||
|             const s32 fd = *fd_to_close; | ||||
|             if (!do_not_close_socket) { | ||||
|                 LOG_ERROR(Service_SSL, | ||||
|                           "do_not_close_socket was changed after setting socket; is this right?"); | ||||
|             } else { | ||||
|                 auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u"); | ||||
|                 if (bsd) { | ||||
|                     auto err = bsd->CloseImpl(fd); | ||||
|                     if (err != Service::Sockets::Errno::SUCCESS) { | ||||
|                         LOG_ERROR(Service_SSL, "Failed to close duplicated socket: {}", err); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
| private: | ||||
|     SslVersion ssl_version; | ||||
|     std::shared_ptr<SslContextSharedData> shared_data; | ||||
|     std::unique_ptr<SSLConnectionBackend> backend; | ||||
|     std::optional<int> fd_to_close; | ||||
|     bool do_not_close_socket = false; | ||||
|     bool get_server_cert_chain = false; | ||||
|     std::shared_ptr<Network::SocketBase> socket; | ||||
|     bool did_set_host_name = false; | ||||
|     bool did_handshake = false; | ||||
|  | ||||
|     ResultVal<s32> SetSocketDescriptorImpl(s32 fd) { | ||||
|         LOG_DEBUG(Service_SSL, "called, fd={}", fd); | ||||
|         ASSERT(!did_handshake); | ||||
|         auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u"); | ||||
|         ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; }); | ||||
|         s32 ret_fd; | ||||
|         // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor | ||||
|         if (do_not_close_socket) { | ||||
|             auto res = bsd->DuplicateSocketImpl(fd); | ||||
|             if (!res.has_value()) { | ||||
|                 LOG_ERROR(Service_SSL, "Failed to duplicate socket with fd {}", fd); | ||||
|                 return ResultInvalidSocket; | ||||
|             } | ||||
|             fd = *res; | ||||
|             fd_to_close = fd; | ||||
|             ret_fd = fd; | ||||
|         } else { | ||||
|             ret_fd = -1; | ||||
|         } | ||||
|         std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd); | ||||
|         if (!sock.has_value()) { | ||||
|             LOG_ERROR(Service_SSL, "invalid socket fd {}", fd); | ||||
|             return ResultInvalidSocket; | ||||
|         } | ||||
|         socket = std::move(*sock); | ||||
|         backend->SetSocket(socket); | ||||
|         return ret_fd; | ||||
|     } | ||||
|  | ||||
|     Result SetHostNameImpl(const std::string& hostname) { | ||||
|         LOG_DEBUG(Service_SSL, "called. hostname={}", hostname); | ||||
|         ASSERT(!did_handshake); | ||||
|         Result res = backend->SetHostName(hostname); | ||||
|         if (res == ResultSuccess) { | ||||
|             did_set_host_name = true; | ||||
|         } | ||||
|         return res; | ||||
|     } | ||||
|  | ||||
|     Result SetVerifyOptionImpl(u32 option) { | ||||
|         ASSERT(!did_handshake); | ||||
|         LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option); | ||||
|         return ResultSuccess; | ||||
|     } | ||||
|  | ||||
|     Result SetIoModeImpl(u32 input_mode) { | ||||
|         auto mode = static_cast<IoMode>(input_mode); | ||||
|         ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking); | ||||
|         ASSERT_OR_EXECUTE(socket, { return ResultNoSocket; }); | ||||
|  | ||||
|         const bool non_block = mode == IoMode::NonBlocking; | ||||
|         const Network::Errno error = socket->SetNonBlock(non_block); | ||||
|         if (error != Network::Errno::SUCCESS) { | ||||
|             LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block); | ||||
|         } | ||||
|         return ResultSuccess; | ||||
|     } | ||||
|  | ||||
|     Result SetSessionCacheModeImpl(u32 mode) { | ||||
|         ASSERT(!did_handshake); | ||||
|         LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode); | ||||
|         return ResultSuccess; | ||||
|     } | ||||
|  | ||||
|     Result DoHandshakeImpl() { | ||||
|         ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; }); | ||||
|         ASSERT_OR_EXECUTE_MSG( | ||||
|             did_set_host_name, { return ResultInternalError; }, | ||||
|             "Expected SetHostName before DoHandshake"); | ||||
|         Result res = backend->DoHandshake(); | ||||
|         did_handshake = res.IsSuccess(); | ||||
|         return res; | ||||
|     } | ||||
|  | ||||
|     std::vector<u8> SerializeServerCerts(const std::vector<std::vector<u8>>& certs) { | ||||
|         struct Header { | ||||
|             u64 magic; | ||||
|             u32 count; | ||||
|             u32 pad; | ||||
|         }; | ||||
|         struct EntryHeader { | ||||
|             u32 size; | ||||
|             u32 offset; | ||||
|         }; | ||||
|         if (!get_server_cert_chain) { | ||||
|             // Just return the first one, unencoded. | ||||
|             ASSERT_OR_EXECUTE_MSG( | ||||
|                 !certs.empty(), { return {}; }, "Should be at least one server cert"); | ||||
|             return certs[0]; | ||||
|         } | ||||
|         std::vector<u8> ret; | ||||
|         Header header{0x4E4D684374726543, static_cast<u32>(certs.size()), 0}; | ||||
|         ret.insert(ret.end(), reinterpret_cast<u8*>(&header), reinterpret_cast<u8*>(&header + 1)); | ||||
|         size_t data_offset = sizeof(Header) + certs.size() * sizeof(EntryHeader); | ||||
|         for (auto& cert : certs) { | ||||
|             EntryHeader entry_header{static_cast<u32>(cert.size()), static_cast<u32>(data_offset)}; | ||||
|             data_offset += cert.size(); | ||||
|             ret.insert(ret.end(), reinterpret_cast<u8*>(&entry_header), | ||||
|                        reinterpret_cast<u8*>(&entry_header + 1)); | ||||
|         } | ||||
|         for (auto& cert : certs) { | ||||
|             ret.insert(ret.end(), cert.begin(), cert.end()); | ||||
|         } | ||||
|         return ret; | ||||
|     } | ||||
|  | ||||
|     ResultVal<std::vector<u8>> ReadImpl(size_t size) { | ||||
|         ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); | ||||
|         std::vector<u8> res(size); | ||||
|         ResultVal<size_t> actual = backend->Read(res); | ||||
|         if (actual.Failed()) { | ||||
|             return actual.Code(); | ||||
|         } | ||||
|         res.resize(*actual); | ||||
|         return res; | ||||
|     } | ||||
|  | ||||
|     ResultVal<size_t> WriteImpl(std::span<const u8> data) { | ||||
|         ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); | ||||
|         return backend->Write(data); | ||||
|     } | ||||
|  | ||||
|     ResultVal<s32> PendingImpl() { | ||||
|         LOG_WARNING(Service_SSL, "(STUBBED) called."); | ||||
|         return 0; | ||||
|     } | ||||
|  | ||||
|     void SetSocketDescriptor(HLERequestContext& ctx) { | ||||
|         IPC::RequestParser rp{ctx}; | ||||
|         const s32 fd = rp.Pop<s32>(); | ||||
|         const ResultVal<s32> res = SetSocketDescriptorImpl(fd); | ||||
|         IPC::ResponseBuilder rb{ctx, 3}; | ||||
|         rb.Push(res.Code()); | ||||
|         rb.Push<s32>(res.ValueOr(-1)); | ||||
|     } | ||||
|  | ||||
|     void SetHostName(HLERequestContext& ctx) { | ||||
|         const std::string hostname = Common::StringFromBuffer(ctx.ReadBuffer()); | ||||
|         const Result res = SetHostNameImpl(hostname); | ||||
|         IPC::ResponseBuilder rb{ctx, 2}; | ||||
|         rb.Push(res); | ||||
|     } | ||||
|  | ||||
|     void SetVerifyOption(HLERequestContext& ctx) { | ||||
|         IPC::RequestParser rp{ctx}; | ||||
|         const u32 option = rp.Pop<u32>(); | ||||
|         const Result res = SetVerifyOptionImpl(option); | ||||
|         IPC::ResponseBuilder rb{ctx, 2}; | ||||
|         rb.Push(res); | ||||
|     } | ||||
|  | ||||
|     void SetIoMode(HLERequestContext& ctx) { | ||||
|         IPC::RequestParser rp{ctx}; | ||||
|         const u32 mode = rp.Pop<u32>(); | ||||
|         const Result res = SetIoModeImpl(mode); | ||||
|         IPC::ResponseBuilder rb{ctx, 2}; | ||||
|         rb.Push(res); | ||||
|     } | ||||
|  | ||||
|     void DoHandshake(HLERequestContext& ctx) { | ||||
|         const Result res = DoHandshakeImpl(); | ||||
|         IPC::ResponseBuilder rb{ctx, 2}; | ||||
|         rb.Push(res); | ||||
|     } | ||||
|  | ||||
|     void DoHandshakeGetServerCert(HLERequestContext& ctx) { | ||||
|         struct OutputParameters { | ||||
|             u32 certs_size; | ||||
|             u32 certs_count; | ||||
|         }; | ||||
|         static_assert(sizeof(OutputParameters) == 0x8); | ||||
|  | ||||
|         const Result res = DoHandshakeImpl(); | ||||
|         OutputParameters out{}; | ||||
|         if (res == ResultSuccess) { | ||||
|             auto certs = backend->GetServerCerts(); | ||||
|             if (certs.Succeeded()) { | ||||
|                 const std::vector<u8> certs_buf = SerializeServerCerts(*certs); | ||||
|                 ctx.WriteBuffer(certs_buf); | ||||
|                 out.certs_count = static_cast<u32>(certs->size()); | ||||
|                 out.certs_size = static_cast<u32>(certs_buf.size()); | ||||
|             } | ||||
|         } | ||||
|         IPC::ResponseBuilder rb{ctx, 4}; | ||||
|         rb.Push(res); | ||||
|         rb.PushRaw(out); | ||||
|     } | ||||
|  | ||||
|     void Read(HLERequestContext& ctx) { | ||||
|         const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize()); | ||||
|         IPC::ResponseBuilder rb{ctx, 3}; | ||||
|         rb.Push(res.Code()); | ||||
|         if (res.Succeeded()) { | ||||
|             rb.Push(static_cast<u32>(res->size())); | ||||
|             ctx.WriteBuffer(*res); | ||||
|         } else { | ||||
|             rb.Push(static_cast<u32>(0)); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     void Write(HLERequestContext& ctx) { | ||||
|         const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer()); | ||||
|         IPC::ResponseBuilder rb{ctx, 3}; | ||||
|         rb.Push(res.Code()); | ||||
|         rb.Push(static_cast<u32>(res.ValueOr(0))); | ||||
|     } | ||||
|  | ||||
|     void Pending(HLERequestContext& ctx) { | ||||
|         const ResultVal<s32> res = PendingImpl(); | ||||
|         IPC::ResponseBuilder rb{ctx, 3}; | ||||
|         rb.Push(res.Code()); | ||||
|         rb.Push<s32>(res.ValueOr(0)); | ||||
|     } | ||||
|  | ||||
|     void SetSessionCacheMode(HLERequestContext& ctx) { | ||||
|         IPC::RequestParser rp{ctx}; | ||||
|         const u32 mode = rp.Pop<u32>(); | ||||
|         const Result res = SetSessionCacheModeImpl(mode); | ||||
|         IPC::ResponseBuilder rb{ctx, 2}; | ||||
|         rb.Push(res); | ||||
|     } | ||||
|  | ||||
|     void SetOption(HLERequestContext& ctx) { | ||||
|         struct Parameters { | ||||
|             OptionType option; | ||||
|             s32 value; | ||||
|         }; | ||||
|         static_assert(sizeof(Parameters) == 0x8, "Parameters is an invalid size"); | ||||
|  | ||||
|         IPC::RequestParser rp{ctx}; | ||||
|         const auto parameters = rp.PopRaw<Parameters>(); | ||||
|  | ||||
|         switch (parameters.option) { | ||||
|         case OptionType::DoNotCloseSocket: | ||||
|             do_not_close_socket = static_cast<bool>(parameters.value); | ||||
|             break; | ||||
|         case OptionType::GetServerCertChain: | ||||
|             get_server_cert_chain = static_cast<bool>(parameters.value); | ||||
|             break; | ||||
|         default: | ||||
|             LOG_WARNING(Service_SSL, "Unknown option={}, value={}", parameters.option, | ||||
|                         parameters.value); | ||||
|         } | ||||
|  | ||||
|         IPC::ResponseBuilder rb{ctx, 2}; | ||||
|         rb.Push(ResultSuccess); | ||||
|     } | ||||
| }; | ||||
|  | ||||
| class ISslContext final : public ServiceFramework<ISslContext> { | ||||
| public: | ||||
|     explicit ISslContext(Core::System& system_, SslVersion version) | ||||
|         : ServiceFramework{system_, "ISslContext"}, ssl_version{version} { | ||||
|         : ServiceFramework{system_, "ISslContext"}, ssl_version{version}, | ||||
|           shared_data{std::make_shared<SslContextSharedData>()} { | ||||
|         static const FunctionInfo functions[] = { | ||||
|             {0, &ISslContext::SetOption, "SetOption"}, | ||||
|             {1, nullptr, "GetOption"}, | ||||
|             {2, &ISslContext::CreateConnection, "CreateConnection"}, | ||||
|             {3, nullptr, "GetConnectionCount"}, | ||||
|             {3, &ISslContext::GetConnectionCount, "GetConnectionCount"}, | ||||
|             {4, &ISslContext::ImportServerPki, "ImportServerPki"}, | ||||
|             {5, &ISslContext::ImportClientPki, "ImportClientPki"}, | ||||
|             {6, nullptr, "RemoveServerPki"}, | ||||
| @@ -111,6 +416,7 @@ public: | ||||
|  | ||||
| private: | ||||
|     SslVersion ssl_version; | ||||
|     std::shared_ptr<SslContextSharedData> shared_data; | ||||
|  | ||||
|     void SetOption(HLERequestContext& ctx) { | ||||
|         struct Parameters { | ||||
| @@ -130,11 +436,24 @@ private: | ||||
|     } | ||||
|  | ||||
|     void CreateConnection(HLERequestContext& ctx) { | ||||
|         LOG_WARNING(Service_SSL, "(STUBBED) called"); | ||||
|         LOG_WARNING(Service_SSL, "called"); | ||||
|  | ||||
|         auto backend_res = CreateSSLConnectionBackend(); | ||||
|  | ||||
|         IPC::ResponseBuilder rb{ctx, 2, 0, 1}; | ||||
|         rb.Push(backend_res.Code()); | ||||
|         if (backend_res.Succeeded()) { | ||||
|             rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data, | ||||
|                                                 std::move(*backend_res)); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     void GetConnectionCount(HLERequestContext& ctx) { | ||||
|         LOG_DEBUG(Service_SSL, "connection_count={}", shared_data->connection_count); | ||||
|  | ||||
|         IPC::ResponseBuilder rb{ctx, 3}; | ||||
|         rb.Push(ResultSuccess); | ||||
|         rb.PushIpcInterface<ISslConnection>(system, ssl_version); | ||||
|         rb.Push(shared_data->connection_count); | ||||
|     } | ||||
|  | ||||
|     void ImportServerPki(HLERequestContext& ctx) { | ||||
|   | ||||
							
								
								
									
										45
									
								
								src/core/hle/service/ssl/ssl_backend.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								src/core/hle/service/ssl/ssl_backend.h
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,45 @@ | ||||
| // SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project | ||||
| // SPDX-License-Identifier: GPL-2.0-or-later | ||||
|  | ||||
| #pragma once | ||||
|  | ||||
| #include "core/hle/result.h" | ||||
|  | ||||
| #include "common/common_types.h" | ||||
|  | ||||
| #include <memory> | ||||
| #include <span> | ||||
| #include <string> | ||||
| #include <vector> | ||||
|  | ||||
| namespace Network { | ||||
| class SocketBase; | ||||
| } | ||||
|  | ||||
| namespace Service::SSL { | ||||
|  | ||||
| constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103}; | ||||
| constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106}; | ||||
| constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205}; | ||||
| constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up | ||||
|  | ||||
| // ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake, | ||||
| // with no way in the latter case to distinguish whether the client should poll | ||||
| // for read or write.  The one official client I've seen handles this by always | ||||
| // polling for read (with a timeout). | ||||
| constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204}; | ||||
|  | ||||
| class SSLConnectionBackend { | ||||
| public: | ||||
|     virtual ~SSLConnectionBackend() {} | ||||
|     virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0; | ||||
|     virtual Result SetHostName(const std::string& hostname) = 0; | ||||
|     virtual Result DoHandshake() = 0; | ||||
|     virtual ResultVal<size_t> Read(std::span<u8> data) = 0; | ||||
|     virtual ResultVal<size_t> Write(std::span<const u8> data) = 0; | ||||
|     virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0; | ||||
| }; | ||||
|  | ||||
| ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend(); | ||||
|  | ||||
| } // namespace Service::SSL | ||||
							
								
								
									
										16
									
								
								src/core/hle/service/ssl/ssl_backend_none.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								src/core/hle/service/ssl/ssl_backend_none.cpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,16 @@ | ||||
| // SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project | ||||
| // SPDX-License-Identifier: GPL-2.0-or-later | ||||
|  | ||||
| #include "core/hle/service/ssl/ssl_backend.h" | ||||
|  | ||||
| #include "common/logging/log.h" | ||||
|  | ||||
| namespace Service::SSL { | ||||
|  | ||||
| ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { | ||||
|     LOG_ERROR(Service_SSL, | ||||
|               "Can't create SSL connection because no SSL backend is available on this platform"); | ||||
|     return ResultInternalError; | ||||
| } | ||||
|  | ||||
| } // namespace Service::SSL | ||||
							
								
								
									
										351
									
								
								src/core/hle/service/ssl/ssl_backend_openssl.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										351
									
								
								src/core/hle/service/ssl/ssl_backend_openssl.cpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,351 @@ | ||||
| // SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project | ||||
| // SPDX-License-Identifier: GPL-2.0-or-later | ||||
|  | ||||
| #include "core/hle/service/ssl/ssl_backend.h" | ||||
| #include "core/internal_network/network.h" | ||||
| #include "core/internal_network/sockets.h" | ||||
|  | ||||
| #include "common/fs/file.h" | ||||
| #include "common/hex_util.h" | ||||
| #include "common/string_util.h" | ||||
|  | ||||
| #include <mutex> | ||||
|  | ||||
| #include <openssl/bio.h> | ||||
| #include <openssl/err.h> | ||||
| #include <openssl/ssl.h> | ||||
| #include <openssl/x509.h> | ||||
|  | ||||
| using namespace Common::FS; | ||||
|  | ||||
| namespace Service::SSL { | ||||
|  | ||||
| // Import OpenSSL's `SSL` type into the namespace.  This is needed because the | ||||
| // namespace is also named `SSL`. | ||||
| using ::SSL; | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| std::once_flag one_time_init_flag; | ||||
| bool one_time_init_success = false; | ||||
|  | ||||
| SSL_CTX* ssl_ctx; | ||||
| IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment | ||||
| BIO_METHOD* bio_meth; | ||||
|  | ||||
| Result CheckOpenSSLErrors(); | ||||
| void OneTimeInit(); | ||||
| void OneTimeInitLogFile(); | ||||
| bool OneTimeInitBIO(); | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend { | ||||
| public: | ||||
|     Result Init() { | ||||
|         std::call_once(one_time_init_flag, OneTimeInit); | ||||
|  | ||||
|         if (!one_time_init_success) { | ||||
|             LOG_ERROR(Service_SSL, | ||||
|                       "Can't create SSL connection because OpenSSL one-time initialization failed"); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|  | ||||
|         ssl = SSL_new(ssl_ctx); | ||||
|         if (!ssl) { | ||||
|             LOG_ERROR(Service_SSL, "SSL_new failed"); | ||||
|             return CheckOpenSSLErrors(); | ||||
|         } | ||||
|  | ||||
|         SSL_set_connect_state(ssl); | ||||
|  | ||||
|         bio = BIO_new(bio_meth); | ||||
|         if (!bio) { | ||||
|             LOG_ERROR(Service_SSL, "BIO_new failed"); | ||||
|             return CheckOpenSSLErrors(); | ||||
|         } | ||||
|  | ||||
|         BIO_set_data(bio, this); | ||||
|         BIO_set_init(bio, 1); | ||||
|         SSL_set_bio(ssl, bio, bio); | ||||
|  | ||||
|         return ResultSuccess; | ||||
|     } | ||||
|  | ||||
|     void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override { | ||||
|         socket = std::move(socket_in); | ||||
|     } | ||||
|  | ||||
|     Result SetHostName(const std::string& hostname) override { | ||||
|         if (!SSL_set1_host(ssl, hostname.c_str())) { // hostname for verification | ||||
|             LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname); | ||||
|             return CheckOpenSSLErrors(); | ||||
|         } | ||||
|         if (!SSL_set_tlsext_host_name(ssl, hostname.c_str())) { // hostname for SNI | ||||
|             LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname); | ||||
|             return CheckOpenSSLErrors(); | ||||
|         } | ||||
|         return ResultSuccess; | ||||
|     } | ||||
|  | ||||
|     Result DoHandshake() override { | ||||
|         SSL_set_verify_result(ssl, X509_V_OK); | ||||
|         const int ret = SSL_do_handshake(ssl); | ||||
|         const long verify_result = SSL_get_verify_result(ssl); | ||||
|         if (verify_result != X509_V_OK) { | ||||
|             LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}", | ||||
|                       X509_verify_cert_error_string(verify_result)); | ||||
|             return CheckOpenSSLErrors(); | ||||
|         } | ||||
|         if (ret <= 0) { | ||||
|             const int ssl_err = SSL_get_error(ssl, ret); | ||||
|             if (ssl_err == SSL_ERROR_ZERO_RETURN || | ||||
|                 (ssl_err == SSL_ERROR_SYSCALL && got_read_eof)) { | ||||
|                 LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); | ||||
|                 return ResultInternalError; | ||||
|             } | ||||
|         } | ||||
|         return HandleReturn("SSL_do_handshake", 0, ret).Code(); | ||||
|     } | ||||
|  | ||||
|     ResultVal<size_t> Read(std::span<u8> data) override { | ||||
|         size_t actual; | ||||
|         const int ret = SSL_read_ex(ssl, data.data(), data.size(), &actual); | ||||
|         return HandleReturn("SSL_read_ex", actual, ret); | ||||
|     } | ||||
|  | ||||
|     ResultVal<size_t> Write(std::span<const u8> data) override { | ||||
|         size_t actual; | ||||
|         const int ret = SSL_write_ex(ssl, data.data(), data.size(), &actual); | ||||
|         return HandleReturn("SSL_write_ex", actual, ret); | ||||
|     } | ||||
|  | ||||
|     ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) { | ||||
|         const int ssl_err = SSL_get_error(ssl, ret); | ||||
|         CheckOpenSSLErrors(); | ||||
|         switch (ssl_err) { | ||||
|         case SSL_ERROR_NONE: | ||||
|             return actual; | ||||
|         case SSL_ERROR_ZERO_RETURN: | ||||
|             LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what); | ||||
|             // DoHandshake special-cases this, but for Read and Write: | ||||
|             return size_t(0); | ||||
|         case SSL_ERROR_WANT_READ: | ||||
|             LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what); | ||||
|             return ResultWouldBlock; | ||||
|         case SSL_ERROR_WANT_WRITE: | ||||
|             LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what); | ||||
|             return ResultWouldBlock; | ||||
|         default: | ||||
|             if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) { | ||||
|                 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what); | ||||
|                 return size_t(0); | ||||
|             } | ||||
|             LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { | ||||
|         STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl); | ||||
|         if (!chain) { | ||||
|             LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr"); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|         std::vector<std::vector<u8>> ret; | ||||
|         int count = sk_X509_num(chain); | ||||
|         ASSERT(count >= 0); | ||||
|         for (int i = 0; i < count; i++) { | ||||
|             X509* x509 = sk_X509_value(chain, i); | ||||
|             ASSERT_OR_EXECUTE(x509 != nullptr, { continue; }); | ||||
|             unsigned char* buf = nullptr; | ||||
|             int len = i2d_X509(x509, &buf); | ||||
|             ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; }); | ||||
|             ret.emplace_back(buf, buf + len); | ||||
|             OPENSSL_free(buf); | ||||
|         } | ||||
|         return ret; | ||||
|     } | ||||
|  | ||||
|     ~SSLConnectionBackendOpenSSL() { | ||||
|         // these are null-tolerant: | ||||
|         SSL_free(ssl); | ||||
|         BIO_free(bio); | ||||
|     } | ||||
|  | ||||
|     static void KeyLogCallback(const SSL* ssl, const char* line) { | ||||
|         std::string str(line); | ||||
|         str.push_back('\n'); | ||||
|         // Do this in a single WriteString for atomicity if multiple instances | ||||
|         // are running on different threads (though that can't currently | ||||
|         // happen). | ||||
|         if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) { | ||||
|             LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE"); | ||||
|         } | ||||
|         LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line); | ||||
|     } | ||||
|  | ||||
|     static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) { | ||||
|         auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); | ||||
|         ASSERT_OR_EXECUTE_MSG( | ||||
|             self->socket, { return 0; }, "OpenSSL asked to send but we have no socket"); | ||||
|         BIO_clear_retry_flags(bio); | ||||
|         auto [actual, err] = self->socket->Send({reinterpret_cast<const u8*>(buf), len}, 0); | ||||
|         switch (err) { | ||||
|         case Network::Errno::SUCCESS: | ||||
|             *actual_p = actual; | ||||
|             return 1; | ||||
|         case Network::Errno::AGAIN: | ||||
|             BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY); | ||||
|             return 0; | ||||
|         default: | ||||
|             LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err); | ||||
|             return -1; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) { | ||||
|         auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); | ||||
|         ASSERT_OR_EXECUTE_MSG( | ||||
|             self->socket, { return 0; }, "OpenSSL asked to recv but we have no socket"); | ||||
|         BIO_clear_retry_flags(bio); | ||||
|         auto [actual, err] = self->socket->Recv(0, {reinterpret_cast<u8*>(buf), len}); | ||||
|         switch (err) { | ||||
|         case Network::Errno::SUCCESS: | ||||
|             *actual_p = actual; | ||||
|             if (actual == 0) { | ||||
|                 self->got_read_eof = true; | ||||
|             } | ||||
|             return actual ? 1 : 0; | ||||
|         case Network::Errno::AGAIN: | ||||
|             BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY); | ||||
|             return 0; | ||||
|         default: | ||||
|             LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); | ||||
|             return -1; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     static long CtrlCallback(BIO* bio, int cmd, long l_arg, void* p_arg) { | ||||
|         switch (cmd) { | ||||
|         case BIO_CTRL_FLUSH: | ||||
|             // Nothing to flush. | ||||
|             return 1; | ||||
|         case BIO_CTRL_PUSH: | ||||
|         case BIO_CTRL_POP: | ||||
| #ifdef BIO_CTRL_GET_KTLS_SEND | ||||
|         case BIO_CTRL_GET_KTLS_SEND: | ||||
|         case BIO_CTRL_GET_KTLS_RECV: | ||||
| #endif | ||||
|             // We don't support these operations, but don't bother logging them | ||||
|             // as they're nothing unusual. | ||||
|             return 0; | ||||
|         default: | ||||
|             LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, l_arg, p_arg); | ||||
|             return 0; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     SSL* ssl = nullptr; | ||||
|     BIO* bio = nullptr; | ||||
|     bool got_read_eof = false; | ||||
|  | ||||
|     std::shared_ptr<Network::SocketBase> socket; | ||||
| }; | ||||
|  | ||||
| ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { | ||||
|     auto conn = std::make_unique<SSLConnectionBackendOpenSSL>(); | ||||
|     const Result res = conn->Init(); | ||||
|     if (res.IsFailure()) { | ||||
|         return res; | ||||
|     } | ||||
|     return conn; | ||||
| } | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| Result CheckOpenSSLErrors() { | ||||
|     unsigned long rc; | ||||
|     const char* file; | ||||
|     int line; | ||||
|     const char* func; | ||||
|     const char* data; | ||||
|     int flags; | ||||
| #if OPENSSL_VERSION_NUMBER >= 0x30000000L | ||||
|     while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags))) | ||||
| #else | ||||
|     // Can't get function names from OpenSSL on this version, so use mine: | ||||
|     func = __func__; | ||||
|     while ((rc = ERR_get_error_line_data(&file, &line, &data, &flags))) | ||||
| #endif | ||||
|     { | ||||
|         std::string msg; | ||||
|         msg.resize(1024, '\0'); | ||||
|         ERR_error_string_n(rc, msg.data(), msg.size()); | ||||
|         msg.resize(strlen(msg.data()), '\0'); | ||||
|         if (flags & ERR_TXT_STRING) { | ||||
|             msg.append(" | "); | ||||
|             msg.append(data); | ||||
|         } | ||||
|         Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error, | ||||
|                                    Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}", | ||||
|                                    msg); | ||||
|     } | ||||
|     return ResultInternalError; | ||||
| } | ||||
|  | ||||
| void OneTimeInit() { | ||||
|     ssl_ctx = SSL_CTX_new(TLS_client_method()); | ||||
|     if (!ssl_ctx) { | ||||
|         LOG_ERROR(Service_SSL, "SSL_CTX_new failed"); | ||||
|         CheckOpenSSLErrors(); | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr); | ||||
|  | ||||
|     if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) { | ||||
|         LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed"); | ||||
|         CheckOpenSSLErrors(); | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     OneTimeInitLogFile(); | ||||
|  | ||||
|     if (!OneTimeInitBIO()) { | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     one_time_init_success = true; | ||||
| } | ||||
|  | ||||
| void OneTimeInitLogFile() { | ||||
|     const char* logfile = getenv("SSLKEYLOGFILE"); | ||||
|     if (logfile) { | ||||
|         key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile, | ||||
|                           FileShareFlag::ShareWriteOnly); | ||||
|         if (key_log_file.IsOpen()) { | ||||
|             SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback); | ||||
|         } else { | ||||
|             LOG_CRITICAL(Service_SSL, | ||||
|                          "SSLKEYLOGFILE was set but file could not be opened; not logging keys!"); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| bool OneTimeInitBIO() { | ||||
|     bio_meth = | ||||
|         BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL"); | ||||
|     if (!bio_meth || | ||||
|         !BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) || | ||||
|         !BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) || | ||||
|         !BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) { | ||||
|         LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD"); | ||||
|         return false; | ||||
|     } | ||||
|     return true; | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| } // namespace Service::SSL | ||||
							
								
								
									
										543
									
								
								src/core/hle/service/ssl/ssl_backend_schannel.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										543
									
								
								src/core/hle/service/ssl/ssl_backend_schannel.cpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,543 @@ | ||||
| // SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project | ||||
| // SPDX-License-Identifier: GPL-2.0-or-later | ||||
|  | ||||
| #include "core/hle/service/ssl/ssl_backend.h" | ||||
| #include "core/internal_network/network.h" | ||||
| #include "core/internal_network/sockets.h" | ||||
|  | ||||
| #include "common/error.h" | ||||
| #include "common/fs/file.h" | ||||
| #include "common/hex_util.h" | ||||
| #include "common/string_util.h" | ||||
|  | ||||
| #include <mutex> | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| // These includes are inside the namespace to avoid a conflict on MinGW where | ||||
| // the headers define an enum containing Network and Service as enumerators | ||||
| // (which clash with the correspondingly named namespaces). | ||||
| #define SECURITY_WIN32 | ||||
| #include <schnlsp.h> | ||||
| #include <security.h> | ||||
|  | ||||
| std::once_flag one_time_init_flag; | ||||
| bool one_time_init_success = false; | ||||
|  | ||||
| SCHANNEL_CRED schannel_cred{}; | ||||
| CredHandle cred_handle; | ||||
|  | ||||
| static void OneTimeInit() { | ||||
|     schannel_cred.dwVersion = SCHANNEL_CRED_VERSION; | ||||
|     schannel_cred.dwFlags = | ||||
|         SCH_USE_STRONG_CRYPTO |         // don't allow insecure protocols | ||||
|         SCH_CRED_AUTO_CRED_VALIDATION | // validate certs | ||||
|         SCH_CRED_NO_DEFAULT_CREDS;      // don't automatically present a client certificate | ||||
|     // ^ I'm assuming that nobody would want to connect Yuzu to a | ||||
|     // service that requires some OS-provided corporate client | ||||
|     // certificate, and presenting one to some arbitrary server | ||||
|     // might be a privacy concern?  Who knows, though. | ||||
|  | ||||
|     const SECURITY_STATUS ret = | ||||
|         AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND, | ||||
|                                  nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr); | ||||
|     if (ret != SEC_E_OK) { | ||||
|         // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString. | ||||
|         LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}", | ||||
|                   Common::NativeErrorToString(ret)); | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     if (getenv("SSLKEYLOGFILE")) { | ||||
|         LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but Schannel does not support exporting " | ||||
|                                   "keys; not logging keys!"); | ||||
|         // Not fatal. | ||||
|     } | ||||
|  | ||||
|     one_time_init_success = true; | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| namespace Service::SSL { | ||||
|  | ||||
| class SSLConnectionBackendSchannel final : public SSLConnectionBackend { | ||||
| public: | ||||
|     Result Init() { | ||||
|         std::call_once(one_time_init_flag, OneTimeInit); | ||||
|  | ||||
|         if (!one_time_init_success) { | ||||
|             LOG_ERROR( | ||||
|                 Service_SSL, | ||||
|                 "Can't create SSL connection because Schannel one-time initialization failed"); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|  | ||||
|         return ResultSuccess; | ||||
|     } | ||||
|  | ||||
|     void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override { | ||||
|         socket = std::move(socket_in); | ||||
|     } | ||||
|  | ||||
|     Result SetHostName(const std::string& hostname_in) override { | ||||
|         hostname = hostname_in; | ||||
|         return ResultSuccess; | ||||
|     } | ||||
|  | ||||
|     Result DoHandshake() override { | ||||
|         while (1) { | ||||
|             Result r; | ||||
|             switch (handshake_state) { | ||||
|             case HandshakeState::Initial: | ||||
|                 if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || | ||||
|                     (r = CallInitializeSecurityContext()) != ResultSuccess) { | ||||
|                     return r; | ||||
|                 } | ||||
|                 // CallInitializeSecurityContext updated `handshake_state`. | ||||
|                 continue; | ||||
|             case HandshakeState::ContinueNeeded: | ||||
|             case HandshakeState::IncompleteMessage: | ||||
|                 if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || | ||||
|                     (r = FillCiphertextReadBuf()) != ResultSuccess) { | ||||
|                     return r; | ||||
|                 } | ||||
|                 if (ciphertext_read_buf.empty()) { | ||||
|                     LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); | ||||
|                     return ResultInternalError; | ||||
|                 } | ||||
|                 if ((r = CallInitializeSecurityContext()) != ResultSuccess) { | ||||
|                     return r; | ||||
|                 } | ||||
|                 // CallInitializeSecurityContext updated `handshake_state`. | ||||
|                 continue; | ||||
|             case HandshakeState::DoneAfterFlush: | ||||
|                 if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) { | ||||
|                     return r; | ||||
|                 } | ||||
|                 handshake_state = HandshakeState::Connected; | ||||
|                 return ResultSuccess; | ||||
|             case HandshakeState::Connected: | ||||
|                 LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook"); | ||||
|                 return ResultInternalError; | ||||
|             case HandshakeState::Error: | ||||
|                 return ResultInternalError; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     Result FillCiphertextReadBuf() { | ||||
|         const size_t fill_size = read_buf_fill_size ? read_buf_fill_size : 4096; | ||||
|         read_buf_fill_size = 0; | ||||
|         // This unnecessarily zeroes the buffer; oh well. | ||||
|         const size_t offset = ciphertext_read_buf.size(); | ||||
|         ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; }); | ||||
|         ciphertext_read_buf.resize(offset + fill_size, 0); | ||||
|         const auto read_span = std::span(ciphertext_read_buf).subspan(offset, fill_size); | ||||
|         const auto [actual, err] = socket->Recv(0, read_span); | ||||
|         switch (err) { | ||||
|         case Network::Errno::SUCCESS: | ||||
|             ASSERT(static_cast<size_t>(actual) <= fill_size); | ||||
|             ciphertext_read_buf.resize(offset + actual); | ||||
|             return ResultSuccess; | ||||
|         case Network::Errno::AGAIN: | ||||
|             ciphertext_read_buf.resize(offset); | ||||
|             return ResultWouldBlock; | ||||
|         default: | ||||
|             ciphertext_read_buf.resize(offset); | ||||
|             LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // Returns success if the write buffer has been completely emptied. | ||||
|     Result FlushCiphertextWriteBuf() { | ||||
|         while (!ciphertext_write_buf.empty()) { | ||||
|             const auto [actual, err] = socket->Send(ciphertext_write_buf, 0); | ||||
|             switch (err) { | ||||
|             case Network::Errno::SUCCESS: | ||||
|                 ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf.size()); | ||||
|                 ciphertext_write_buf.erase(ciphertext_write_buf.begin(), | ||||
|                                            ciphertext_write_buf.begin() + actual); | ||||
|                 break; | ||||
|             case Network::Errno::AGAIN: | ||||
|                 return ResultWouldBlock; | ||||
|             default: | ||||
|                 LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err); | ||||
|                 return ResultInternalError; | ||||
|             } | ||||
|         } | ||||
|         return ResultSuccess; | ||||
|     } | ||||
|  | ||||
|     Result CallInitializeSecurityContext() { | ||||
|         const unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | | ||||
|                                   ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT | | ||||
|                                   ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM | | ||||
|                                   ISC_REQ_USE_SUPPLIED_CREDS; | ||||
|         unsigned long attr; | ||||
|         // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel | ||||
|         std::array<SecBuffer, 2> input_buffers{{ | ||||
|             // only used if `initial_call_done` | ||||
|             { | ||||
|                 // [0] | ||||
|                 .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()), | ||||
|                 .BufferType = SECBUFFER_TOKEN, | ||||
|                 .pvBuffer = ciphertext_read_buf.data(), | ||||
|             }, | ||||
|             { | ||||
|                 // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is | ||||
|                 //     returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the | ||||
|                 //     whole buffer wasn't used) | ||||
|                 .cbBuffer = 0, | ||||
|                 .BufferType = SECBUFFER_EMPTY, | ||||
|                 .pvBuffer = nullptr, | ||||
|             }, | ||||
|         }}; | ||||
|         std::array<SecBuffer, 2> output_buffers{{ | ||||
|             { | ||||
|                 .cbBuffer = 0, | ||||
|                 .BufferType = SECBUFFER_TOKEN, | ||||
|                 .pvBuffer = nullptr, | ||||
|             }, // [0] | ||||
|             { | ||||
|                 .cbBuffer = 0, | ||||
|                 .BufferType = SECBUFFER_ALERT, | ||||
|                 .pvBuffer = nullptr, | ||||
|             }, // [1] | ||||
|         }}; | ||||
|         SecBufferDesc input_desc{ | ||||
|             .ulVersion = SECBUFFER_VERSION, | ||||
|             .cBuffers = static_cast<unsigned long>(input_buffers.size()), | ||||
|             .pBuffers = input_buffers.data(), | ||||
|         }; | ||||
|         SecBufferDesc output_desc{ | ||||
|             .ulVersion = SECBUFFER_VERSION, | ||||
|             .cBuffers = static_cast<unsigned long>(output_buffers.size()), | ||||
|             .pBuffers = output_buffers.data(), | ||||
|         }; | ||||
|         ASSERT_OR_EXECUTE_MSG( | ||||
|             input_buffers[0].cbBuffer == ciphertext_read_buf.size(), | ||||
|             { return ResultInternalError; }, "read buffer too large"); | ||||
|  | ||||
|         bool initial_call_done = handshake_state != HandshakeState::Initial; | ||||
|         if (initial_call_done) { | ||||
|             LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext", | ||||
|                       ciphertext_read_buf.size()); | ||||
|         } | ||||
|  | ||||
|         const SECURITY_STATUS ret = | ||||
|             InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr, | ||||
|                                        // Caller ensured we have set a hostname: | ||||
|                                        const_cast<char*>(hostname.value().c_str()), req, | ||||
|                                        0, // Reserved1 | ||||
|                                        0, // TargetDataRep not used with Schannel | ||||
|                                        initial_call_done ? &input_desc : nullptr, | ||||
|                                        0, // Reserved2 | ||||
|                                        initial_call_done ? nullptr : &ctxt, &output_desc, &attr, | ||||
|                                        nullptr); // ptsExpiry | ||||
|  | ||||
|         if (output_buffers[0].pvBuffer) { | ||||
|             const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer), | ||||
|                                  output_buffers[0].cbBuffer); | ||||
|             ciphertext_write_buf.insert(ciphertext_write_buf.end(), span.begin(), span.end()); | ||||
|             FreeContextBuffer(output_buffers[0].pvBuffer); | ||||
|         } | ||||
|  | ||||
|         if (output_buffers[1].pvBuffer) { | ||||
|             const std::span span(static_cast<u8*>(output_buffers[1].pvBuffer), | ||||
|                                  output_buffers[1].cbBuffer); | ||||
|             // The documentation doesn't explain what format this data is in. | ||||
|             LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(), | ||||
|                       Common::HexToString(span)); | ||||
|         } | ||||
|  | ||||
|         switch (ret) { | ||||
|         case SEC_I_CONTINUE_NEEDED: | ||||
|             LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED"); | ||||
|             if (input_buffers[1].BufferType == SECBUFFER_EXTRA) { | ||||
|                 LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer); | ||||
|                 ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf.size()); | ||||
|                 ciphertext_read_buf.erase(ciphertext_read_buf.begin(), | ||||
|                                           ciphertext_read_buf.end() - input_buffers[1].cbBuffer); | ||||
|             } else { | ||||
|                 ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY); | ||||
|                 ciphertext_read_buf.clear(); | ||||
|             } | ||||
|             handshake_state = HandshakeState::ContinueNeeded; | ||||
|             return ResultSuccess; | ||||
|         case SEC_E_INCOMPLETE_MESSAGE: | ||||
|             LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE"); | ||||
|             ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING); | ||||
|             read_buf_fill_size = input_buffers[1].cbBuffer; | ||||
|             handshake_state = HandshakeState::IncompleteMessage; | ||||
|             return ResultSuccess; | ||||
|         case SEC_E_OK: | ||||
|             LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK"); | ||||
|             ciphertext_read_buf.clear(); | ||||
|             handshake_state = HandshakeState::DoneAfterFlush; | ||||
|             return GrabStreamSizes(); | ||||
|         default: | ||||
|             LOG_ERROR(Service_SSL, | ||||
|                       "InitializeSecurityContext failed (probably certificate/protocol issue): {}", | ||||
|                       Common::NativeErrorToString(ret)); | ||||
|             handshake_state = HandshakeState::Error; | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     Result GrabStreamSizes() { | ||||
|         const SECURITY_STATUS ret = | ||||
|             QueryContextAttributes(&ctxt, SECPKG_ATTR_STREAM_SIZES, &stream_sizes); | ||||
|         if (ret != SEC_E_OK) { | ||||
|             LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}", | ||||
|                       Common::NativeErrorToString(ret)); | ||||
|             handshake_state = HandshakeState::Error; | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|         return ResultSuccess; | ||||
|     } | ||||
|  | ||||
|     ResultVal<size_t> Read(std::span<u8> data) override { | ||||
|         if (handshake_state != HandshakeState::Connected) { | ||||
|             LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake"); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|         if (data.size() == 0 || got_read_eof) { | ||||
|             return size_t(0); | ||||
|         } | ||||
|         while (1) { | ||||
|             if (!cleartext_read_buf.empty()) { | ||||
|                 const size_t read_size = std::min(cleartext_read_buf.size(), data.size()); | ||||
|                 std::memcpy(data.data(), cleartext_read_buf.data(), read_size); | ||||
|                 cleartext_read_buf.erase(cleartext_read_buf.begin(), | ||||
|                                          cleartext_read_buf.begin() + read_size); | ||||
|                 return read_size; | ||||
|             } | ||||
|             if (!ciphertext_read_buf.empty()) { | ||||
|                 SecBuffer empty{ | ||||
|                     .cbBuffer = 0, | ||||
|                     .BufferType = SECBUFFER_EMPTY, | ||||
|                     .pvBuffer = nullptr, | ||||
|                 }; | ||||
|                 std::array<SecBuffer, 5> buffers{{ | ||||
|                     { | ||||
|                         .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()), | ||||
|                         .BufferType = SECBUFFER_DATA, | ||||
|                         .pvBuffer = ciphertext_read_buf.data(), | ||||
|                     }, | ||||
|                     empty, | ||||
|                     empty, | ||||
|                     empty, | ||||
|                 }}; | ||||
|                 ASSERT_OR_EXECUTE_MSG( | ||||
|                     buffers[0].cbBuffer == ciphertext_read_buf.size(), | ||||
|                     { return ResultInternalError; }, "read buffer too large"); | ||||
|                 SecBufferDesc desc{ | ||||
|                     .ulVersion = SECBUFFER_VERSION, | ||||
|                     .cBuffers = static_cast<unsigned long>(buffers.size()), | ||||
|                     .pBuffers = buffers.data(), | ||||
|                 }; | ||||
|                 SECURITY_STATUS ret = | ||||
|                     DecryptMessage(&ctxt, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr); | ||||
|                 switch (ret) { | ||||
|                 case SEC_E_OK: | ||||
|                     ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER, | ||||
|                                       { return ResultInternalError; }); | ||||
|                     ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA, | ||||
|                                       { return ResultInternalError; }); | ||||
|                     ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER, | ||||
|                                       { return ResultInternalError; }); | ||||
|                     cleartext_read_buf.assign(static_cast<u8*>(buffers[1].pvBuffer), | ||||
|                                               static_cast<u8*>(buffers[1].pvBuffer) + | ||||
|                                                   buffers[1].cbBuffer); | ||||
|                     if (buffers[3].BufferType == SECBUFFER_EXTRA) { | ||||
|                         ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf.size()); | ||||
|                         ciphertext_read_buf.erase(ciphertext_read_buf.begin(), | ||||
|                                                   ciphertext_read_buf.end() - buffers[3].cbBuffer); | ||||
|                     } else { | ||||
|                         ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY); | ||||
|                         ciphertext_read_buf.clear(); | ||||
|                     } | ||||
|                     continue; | ||||
|                 case SEC_E_INCOMPLETE_MESSAGE: | ||||
|                     break; | ||||
|                 case SEC_I_CONTEXT_EXPIRED: | ||||
|                     // Server hung up by sending close_notify. | ||||
|                     got_read_eof = true; | ||||
|                     return size_t(0); | ||||
|                 default: | ||||
|                     LOG_ERROR(Service_SSL, "DecryptMessage failed: {}", | ||||
|                               Common::NativeErrorToString(ret)); | ||||
|                     return ResultInternalError; | ||||
|                 } | ||||
|             } | ||||
|             const Result r = FillCiphertextReadBuf(); | ||||
|             if (r != ResultSuccess) { | ||||
|                 return r; | ||||
|             } | ||||
|             if (ciphertext_read_buf.empty()) { | ||||
|                 got_read_eof = true; | ||||
|                 return size_t(0); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     ResultVal<size_t> Write(std::span<const u8> data) override { | ||||
|         if (handshake_state != HandshakeState::Connected) { | ||||
|             LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake"); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|         if (data.size() == 0) { | ||||
|             return size_t(0); | ||||
|         } | ||||
|         data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage)); | ||||
|         if (!cleartext_write_buf.empty()) { | ||||
|             // Already in the middle of a write.  It wouldn't make sense to not | ||||
|             // finish sending the entire buffer since TLS has | ||||
|             // header/MAC/padding/etc. | ||||
|             if (data.size() != cleartext_write_buf.size() || | ||||
|                 std::memcmp(data.data(), cleartext_write_buf.data(), data.size())) { | ||||
|                 LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer"); | ||||
|                 return ResultInternalError; | ||||
|             } | ||||
|             return WriteAlreadyEncryptedData(); | ||||
|         } else { | ||||
|             cleartext_write_buf.assign(data.begin(), data.end()); | ||||
|         } | ||||
|  | ||||
|         std::vector<u8> header_buf(stream_sizes.cbHeader, 0); | ||||
|         std::vector<u8> tmp_data_buf = cleartext_write_buf; | ||||
|         std::vector<u8> trailer_buf(stream_sizes.cbTrailer, 0); | ||||
|  | ||||
|         std::array<SecBuffer, 3> buffers{{ | ||||
|             { | ||||
|                 .cbBuffer = stream_sizes.cbHeader, | ||||
|                 .BufferType = SECBUFFER_STREAM_HEADER, | ||||
|                 .pvBuffer = header_buf.data(), | ||||
|             }, | ||||
|             { | ||||
|                 .cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()), | ||||
|                 .BufferType = SECBUFFER_DATA, | ||||
|                 .pvBuffer = tmp_data_buf.data(), | ||||
|             }, | ||||
|             { | ||||
|                 .cbBuffer = stream_sizes.cbTrailer, | ||||
|                 .BufferType = SECBUFFER_STREAM_TRAILER, | ||||
|                 .pvBuffer = trailer_buf.data(), | ||||
|             }, | ||||
|         }}; | ||||
|         ASSERT_OR_EXECUTE_MSG( | ||||
|             buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; }, | ||||
|             "temp buffer too large"); | ||||
|         SecBufferDesc desc{ | ||||
|             .ulVersion = SECBUFFER_VERSION, | ||||
|             .cBuffers = static_cast<unsigned long>(buffers.size()), | ||||
|             .pBuffers = buffers.data(), | ||||
|         }; | ||||
|  | ||||
|         const SECURITY_STATUS ret = EncryptMessage(&ctxt, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0); | ||||
|         if (ret != SEC_E_OK) { | ||||
|             LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret)); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|         ciphertext_write_buf.insert(ciphertext_write_buf.end(), header_buf.begin(), | ||||
|                                     header_buf.end()); | ||||
|         ciphertext_write_buf.insert(ciphertext_write_buf.end(), tmp_data_buf.begin(), | ||||
|                                     tmp_data_buf.end()); | ||||
|         ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(), | ||||
|                                     trailer_buf.end()); | ||||
|         return WriteAlreadyEncryptedData(); | ||||
|     } | ||||
|  | ||||
|     ResultVal<size_t> WriteAlreadyEncryptedData() { | ||||
|         const Result r = FlushCiphertextWriteBuf(); | ||||
|         if (r != ResultSuccess) { | ||||
|             return r; | ||||
|         } | ||||
|         // write buf is empty | ||||
|         const size_t cleartext_bytes_written = cleartext_write_buf.size(); | ||||
|         cleartext_write_buf.clear(); | ||||
|         return cleartext_bytes_written; | ||||
|     } | ||||
|  | ||||
|     ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { | ||||
|         PCCERT_CONTEXT returned_cert = nullptr; | ||||
|         const SECURITY_STATUS ret = | ||||
|             QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert); | ||||
|         if (ret != SEC_E_OK) { | ||||
|             LOG_ERROR(Service_SSL, | ||||
|                       "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}", | ||||
|                       Common::NativeErrorToString(ret)); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|         PCCERT_CONTEXT some_cert = nullptr; | ||||
|         std::vector<std::vector<u8>> certs; | ||||
|         while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) { | ||||
|             certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded), | ||||
|                                static_cast<u8*>(some_cert->pbCertEncoded) + | ||||
|                                    some_cert->cbCertEncoded); | ||||
|         } | ||||
|         std::reverse(certs.begin(), | ||||
|                      certs.end()); // Windows returns certs in reverse order from what we want | ||||
|         CertFreeCertificateContext(returned_cert); | ||||
|         return certs; | ||||
|     } | ||||
|  | ||||
|     ~SSLConnectionBackendSchannel() { | ||||
|         if (handshake_state != HandshakeState::Initial) { | ||||
|             DeleteSecurityContext(&ctxt); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     enum class HandshakeState { | ||||
|         // Haven't called anything yet. | ||||
|         Initial, | ||||
|         // `SEC_I_CONTINUE_NEEDED` was returned by | ||||
|         // `InitializeSecurityContext`; must finish sending data (if any) in | ||||
|         // the write buffer, then read at least one byte before calling | ||||
|         // `InitializeSecurityContext` again. | ||||
|         ContinueNeeded, | ||||
|         // `SEC_E_INCOMPLETE_MESSAGE` was returned by | ||||
|         // `InitializeSecurityContext`; hopefully the write buffer is empty; | ||||
|         // must read at least one byte before calling | ||||
|         // `InitializeSecurityContext` again. | ||||
|         IncompleteMessage, | ||||
|         // `SEC_E_OK` was returned by `InitializeSecurityContext`; must | ||||
|         // finish sending data in the write buffer before having `DoHandshake` | ||||
|         // report success. | ||||
|         DoneAfterFlush, | ||||
|         // We finished the above and are now connected.  At this point, writing | ||||
|         // and reading are separate 'state machines' represented by the | ||||
|         // nonemptiness of the ciphertext and cleartext read and write buffers. | ||||
|         Connected, | ||||
|         // Another error was returned and we shouldn't allow initialization | ||||
|         // to continue. | ||||
|         Error, | ||||
|     } handshake_state = HandshakeState::Initial; | ||||
|  | ||||
|     CtxtHandle ctxt; | ||||
|     SecPkgContext_StreamSizes stream_sizes; | ||||
|  | ||||
|     std::shared_ptr<Network::SocketBase> socket; | ||||
|     std::optional<std::string> hostname; | ||||
|  | ||||
|     std::vector<u8> ciphertext_read_buf; | ||||
|     std::vector<u8> ciphertext_write_buf; | ||||
|     std::vector<u8> cleartext_read_buf; | ||||
|     std::vector<u8> cleartext_write_buf; | ||||
|  | ||||
|     bool got_read_eof = false; | ||||
|     size_t read_buf_fill_size = 0; | ||||
| }; | ||||
|  | ||||
| ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { | ||||
|     auto conn = std::make_unique<SSLConnectionBackendSchannel>(); | ||||
|     const Result res = conn->Init(); | ||||
|     if (res.IsFailure()) { | ||||
|         return res; | ||||
|     } | ||||
|     return conn; | ||||
| } | ||||
|  | ||||
| } // namespace Service::SSL | ||||
							
								
								
									
										219
									
								
								src/core/hle/service/ssl/ssl_backend_securetransport.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										219
									
								
								src/core/hle/service/ssl/ssl_backend_securetransport.cpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,219 @@ | ||||
| // SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project | ||||
| // SPDX-License-Identifier: GPL-2.0-or-later | ||||
|  | ||||
| #include "core/hle/service/ssl/ssl_backend.h" | ||||
| #include "core/internal_network/network.h" | ||||
| #include "core/internal_network/sockets.h" | ||||
|  | ||||
| #include <mutex> | ||||
|  | ||||
| #include <Security/SecureTransport.h> | ||||
|  | ||||
| // SecureTransport has been deprecated in its entirety in favor of | ||||
| // Network.framework, but that does not allow layering TLS on top of an | ||||
| // arbitrary socket. | ||||
| #pragma GCC diagnostic ignored "-Wdeprecated-declarations" | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| template <typename T> | ||||
| struct CFReleaser { | ||||
|     T ptr; | ||||
|  | ||||
|     YUZU_NON_COPYABLE(CFReleaser); | ||||
|     constexpr CFReleaser() : ptr(nullptr) {} | ||||
|     constexpr CFReleaser(T ptr) : ptr(ptr) {} | ||||
|     constexpr operator T() { | ||||
|         return ptr; | ||||
|     } | ||||
|     ~CFReleaser() { | ||||
|         if (ptr) { | ||||
|             CFRelease(ptr); | ||||
|         } | ||||
|     } | ||||
| }; | ||||
|  | ||||
| std::string CFStringToString(CFStringRef cfstr) { | ||||
|     CFReleaser<CFDataRef> cfdata( | ||||
|         CFStringCreateExternalRepresentation(nullptr, cfstr, kCFStringEncodingUTF8, 0)); | ||||
|     ASSERT_OR_EXECUTE(cfdata, { return "???"; }); | ||||
|     return std::string(reinterpret_cast<const char*>(CFDataGetBytePtr(cfdata)), | ||||
|                        CFDataGetLength(cfdata)); | ||||
| } | ||||
|  | ||||
| std::string OSStatusToString(OSStatus status) { | ||||
|     CFReleaser<CFStringRef> cfstr(SecCopyErrorMessageString(status, nullptr)); | ||||
|     if (!cfstr) { | ||||
|         return "[unknown error]"; | ||||
|     } | ||||
|     return CFStringToString(cfstr); | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| namespace Service::SSL { | ||||
|  | ||||
| class SSLConnectionBackendSecureTransport final : public SSLConnectionBackend { | ||||
| public: | ||||
|     Result Init() { | ||||
|         static std::once_flag once_flag; | ||||
|         std::call_once(once_flag, []() { | ||||
|             if (getenv("SSLKEYLOGFILE")) { | ||||
|                 LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but SecureTransport does not " | ||||
|                                           "support exporting keys; not logging keys!"); | ||||
|                 // Not fatal. | ||||
|             } | ||||
|         }); | ||||
|  | ||||
|         context.ptr = SSLCreateContext(nullptr, kSSLClientSide, kSSLStreamType); | ||||
|         if (!context) { | ||||
|             LOG_ERROR(Service_SSL, "SSLCreateContext failed"); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|  | ||||
|         OSStatus status; | ||||
|         if ((status = SSLSetIOFuncs(context, ReadCallback, WriteCallback)) || | ||||
|             (status = SSLSetConnection(context, this))) { | ||||
|             LOG_ERROR(Service_SSL, "SSLContext initialization failed: {}", | ||||
|                       OSStatusToString(status)); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|  | ||||
|         return ResultSuccess; | ||||
|     } | ||||
|  | ||||
|     void SetSocket(std::shared_ptr<Network::SocketBase> in_socket) override { | ||||
|         socket = std::move(in_socket); | ||||
|     } | ||||
|  | ||||
|     Result SetHostName(const std::string& hostname) override { | ||||
|         OSStatus status = SSLSetPeerDomainName(context, hostname.c_str(), hostname.size()); | ||||
|         if (status) { | ||||
|             LOG_ERROR(Service_SSL, "SSLSetPeerDomainName failed: {}", OSStatusToString(status)); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|         return ResultSuccess; | ||||
|     } | ||||
|  | ||||
|     Result DoHandshake() override { | ||||
|         OSStatus status = SSLHandshake(context); | ||||
|         return HandleReturn("SSLHandshake", 0, status).Code(); | ||||
|     } | ||||
|  | ||||
|     ResultVal<size_t> Read(std::span<u8> data) override { | ||||
|         size_t actual; | ||||
|         OSStatus status = SSLRead(context, data.data(), data.size(), &actual); | ||||
|         ; | ||||
|         return HandleReturn("SSLRead", actual, status); | ||||
|     } | ||||
|  | ||||
|     ResultVal<size_t> Write(std::span<const u8> data) override { | ||||
|         size_t actual; | ||||
|         OSStatus status = SSLWrite(context, data.data(), data.size(), &actual); | ||||
|         ; | ||||
|         return HandleReturn("SSLWrite", actual, status); | ||||
|     } | ||||
|  | ||||
|     ResultVal<size_t> HandleReturn(const char* what, size_t actual, OSStatus status) { | ||||
|         switch (status) { | ||||
|         case 0: | ||||
|             return actual; | ||||
|         case errSSLWouldBlock: | ||||
|             return ResultWouldBlock; | ||||
|         default: { | ||||
|             std::string reason; | ||||
|             if (got_read_eof) { | ||||
|                 reason = "server hung up"; | ||||
|             } else { | ||||
|                 reason = OSStatusToString(status); | ||||
|             } | ||||
|             LOG_ERROR(Service_SSL, "{} failed: {}", what, reason); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { | ||||
|         CFReleaser<SecTrustRef> trust; | ||||
|         OSStatus status = SSLCopyPeerTrust(context, &trust.ptr); | ||||
|         if (status) { | ||||
|             LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status)); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|         std::vector<std::vector<u8>> ret; | ||||
|         for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) { | ||||
|             SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i); | ||||
|             CFReleaser<CFDataRef> data(SecCertificateCopyData(cert)); | ||||
|             ASSERT_OR_EXECUTE(data, { return ResultInternalError; }); | ||||
|             const u8* ptr = CFDataGetBytePtr(data); | ||||
|             ret.emplace_back(ptr, ptr + CFDataGetLength(data)); | ||||
|         } | ||||
|         return ret; | ||||
|     } | ||||
|  | ||||
|     static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) { | ||||
|         return ReadOrWriteCallback(connection, data, dataLength, true); | ||||
|     } | ||||
|  | ||||
|     static OSStatus WriteCallback(SSLConnectionRef connection, const void* data, | ||||
|                                   size_t* dataLength) { | ||||
|         return ReadOrWriteCallback(connection, const_cast<void*>(data), dataLength, false); | ||||
|     } | ||||
|  | ||||
|     static OSStatus ReadOrWriteCallback(SSLConnectionRef connection, void* data, size_t* dataLength, | ||||
|                                         bool is_read) { | ||||
|         auto self = | ||||
|             static_cast<SSLConnectionBackendSecureTransport*>(const_cast<void*>(connection)); | ||||
|         ASSERT_OR_EXECUTE_MSG( | ||||
|             self->socket, { return 0; }, "SecureTransport asked to {} but we have no socket", | ||||
|             is_read ? "read" : "write"); | ||||
|  | ||||
|         // SecureTransport callbacks (unlike OpenSSL BIO callbacks) are | ||||
|         // expected to read/write the full requested dataLength or return an | ||||
|         // error, so we have to add a loop ourselves. | ||||
|         size_t requested_len = *dataLength; | ||||
|         size_t offset = 0; | ||||
|         while (offset < requested_len) { | ||||
|             std::span cur(reinterpret_cast<u8*>(data) + offset, requested_len - offset); | ||||
|             auto [actual, err] = is_read ? self->socket->Recv(0, cur) : self->socket->Send(cur, 0); | ||||
|             LOG_CRITICAL(Service_SSL, "op={}, offset={} actual={}/{} err={}", is_read, offset, | ||||
|                          actual, cur.size(), static_cast<s32>(err)); | ||||
|             switch (err) { | ||||
|             case Network::Errno::SUCCESS: | ||||
|                 offset += actual; | ||||
|                 if (actual == 0) { | ||||
|                     ASSERT(is_read); | ||||
|                     self->got_read_eof = true; | ||||
|                     return errSecEndOfData; | ||||
|                 } | ||||
|                 break; | ||||
|             case Network::Errno::AGAIN: | ||||
|                 *dataLength = offset; | ||||
|                 return errSSLWouldBlock; | ||||
|             default: | ||||
|                 LOG_ERROR(Service_SSL, "Socket {} returned Network::Errno {}", | ||||
|                           is_read ? "recv" : "send", err); | ||||
|                 return errSecIO; | ||||
|             } | ||||
|         } | ||||
|         ASSERT(offset == requested_len); | ||||
|         return 0; | ||||
|     } | ||||
|  | ||||
| private: | ||||
|     CFReleaser<SSLContextRef> context = nullptr; | ||||
|     bool got_read_eof = false; | ||||
|  | ||||
|     std::shared_ptr<Network::SocketBase> socket; | ||||
| }; | ||||
|  | ||||
| ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { | ||||
|     auto conn = std::make_unique<SSLConnectionBackendSecureTransport>(); | ||||
|     const Result res = conn->Init(); | ||||
|     if (res.IsFailure()) { | ||||
|         return res; | ||||
|     } | ||||
|     return conn; | ||||
| } | ||||
|  | ||||
| } // namespace Service::SSL | ||||
| @@ -27,6 +27,7 @@ | ||||
|  | ||||
| #include "common/assert.h" | ||||
| #include "common/common_types.h" | ||||
| #include "common/expected.h" | ||||
| #include "common/logging/log.h" | ||||
| #include "common/settings.h" | ||||
| #include "core/internal_network/network.h" | ||||
| @@ -97,6 +98,8 @@ bool EnableNonBlock(SOCKET fd, bool enable) { | ||||
|  | ||||
| Errno TranslateNativeError(int e) { | ||||
|     switch (e) { | ||||
|     case 0: | ||||
|         return Errno::SUCCESS; | ||||
|     case WSAEBADF: | ||||
|         return Errno::BADF; | ||||
|     case WSAEINVAL: | ||||
| @@ -121,6 +124,8 @@ Errno TranslateNativeError(int e) { | ||||
|         return Errno::MSGSIZE; | ||||
|     case WSAETIMEDOUT: | ||||
|         return Errno::TIMEDOUT; | ||||
|     case WSAEINPROGRESS: | ||||
|         return Errno::INPROGRESS; | ||||
|     default: | ||||
|         UNIMPLEMENTED_MSG("Unimplemented errno={}", e); | ||||
|         return Errno::OTHER; | ||||
| @@ -195,6 +200,8 @@ bool EnableNonBlock(int fd, bool enable) { | ||||
|  | ||||
| Errno TranslateNativeError(int e) { | ||||
|     switch (e) { | ||||
|     case 0: | ||||
|         return Errno::SUCCESS; | ||||
|     case EBADF: | ||||
|         return Errno::BADF; | ||||
|     case EINVAL: | ||||
| @@ -219,8 +226,10 @@ Errno TranslateNativeError(int e) { | ||||
|         return Errno::MSGSIZE; | ||||
|     case ETIMEDOUT: | ||||
|         return Errno::TIMEDOUT; | ||||
|     case EINPROGRESS: | ||||
|         return Errno::INPROGRESS; | ||||
|     default: | ||||
|         UNIMPLEMENTED_MSG("Unimplemented errno={}", e); | ||||
|         UNIMPLEMENTED_MSG("Unimplemented errno={} ({})", e, strerror(e)); | ||||
|         return Errno::OTHER; | ||||
|     } | ||||
| } | ||||
| @@ -234,15 +243,84 @@ Errno GetAndLogLastError() { | ||||
|     int e = errno; | ||||
| #endif | ||||
|     const Errno err = TranslateNativeError(e); | ||||
|     if (err == Errno::AGAIN || err == Errno::TIMEDOUT) { | ||||
|     if (err == Errno::AGAIN || err == Errno::TIMEDOUT || err == Errno::INPROGRESS) { | ||||
|         // These happen during normal operation, so only log them at debug level. | ||||
|         LOG_DEBUG(Network, "Socket operation error: {}", Common::NativeErrorToString(e)); | ||||
|         return err; | ||||
|     } | ||||
|     LOG_ERROR(Network, "Socket operation error: {}", Common::NativeErrorToString(e)); | ||||
|     return err; | ||||
| } | ||||
|  | ||||
| int TranslateDomain(Domain domain) { | ||||
| GetAddrInfoError TranslateGetAddrInfoErrorFromNative(int gai_err) { | ||||
|     switch (gai_err) { | ||||
|     case 0: | ||||
|         return GetAddrInfoError::SUCCESS; | ||||
| #ifdef EAI_ADDRFAMILY | ||||
|     case EAI_ADDRFAMILY: | ||||
|         return GetAddrInfoError::ADDRFAMILY; | ||||
| #endif | ||||
|     case EAI_AGAIN: | ||||
|         return GetAddrInfoError::AGAIN; | ||||
|     case EAI_BADFLAGS: | ||||
|         return GetAddrInfoError::BADFLAGS; | ||||
|     case EAI_FAIL: | ||||
|         return GetAddrInfoError::FAIL; | ||||
|     case EAI_FAMILY: | ||||
|         return GetAddrInfoError::FAMILY; | ||||
|     case EAI_MEMORY: | ||||
|         return GetAddrInfoError::MEMORY; | ||||
|     case EAI_NONAME: | ||||
|         return GetAddrInfoError::NONAME; | ||||
|     case EAI_SERVICE: | ||||
|         return GetAddrInfoError::SERVICE; | ||||
|     case EAI_SOCKTYPE: | ||||
|         return GetAddrInfoError::SOCKTYPE; | ||||
|         // These codes may not be defined on all systems: | ||||
| #ifdef EAI_SYSTEM | ||||
|     case EAI_SYSTEM: | ||||
|         return GetAddrInfoError::SYSTEM; | ||||
| #endif | ||||
| #ifdef EAI_BADHINTS | ||||
|     case EAI_BADHINTS: | ||||
|         return GetAddrInfoError::BADHINTS; | ||||
| #endif | ||||
| #ifdef EAI_PROTOCOL | ||||
|     case EAI_PROTOCOL: | ||||
|         return GetAddrInfoError::PROTOCOL; | ||||
| #endif | ||||
| #ifdef EAI_OVERFLOW | ||||
|     case EAI_OVERFLOW: | ||||
|         return GetAddrInfoError::OVERFLOW_; | ||||
| #endif | ||||
|     default: | ||||
| #ifdef EAI_NODATA | ||||
|         // This can't be a case statement because it would create a duplicate | ||||
|         // case on Windows where EAI_NODATA is an alias for EAI_NONAME. | ||||
|         if (gai_err == EAI_NODATA) { | ||||
|             return GetAddrInfoError::NODATA; | ||||
|         } | ||||
| #endif | ||||
|         return GetAddrInfoError::OTHER; | ||||
|     } | ||||
| } | ||||
|  | ||||
| Domain TranslateDomainFromNative(int domain) { | ||||
|     switch (domain) { | ||||
|     case 0: | ||||
|         return Domain::Unspecified; | ||||
|     case AF_INET: | ||||
|         return Domain::INET; | ||||
|     default: | ||||
|         UNIMPLEMENTED_MSG("Unhandled domain={}", domain); | ||||
|         return Domain::INET; | ||||
|     } | ||||
| } | ||||
|  | ||||
| int TranslateDomainToNative(Domain domain) { | ||||
|     switch (domain) { | ||||
|     case Domain::Unspecified: | ||||
|         return 0; | ||||
|     case Domain::INET: | ||||
|         return AF_INET; | ||||
|     default: | ||||
| @@ -251,20 +329,58 @@ int TranslateDomain(Domain domain) { | ||||
|     } | ||||
| } | ||||
|  | ||||
| int TranslateType(Type type) { | ||||
| Type TranslateTypeFromNative(int type) { | ||||
|     switch (type) { | ||||
|     case 0: | ||||
|         return Type::Unspecified; | ||||
|     case SOCK_STREAM: | ||||
|         return Type::STREAM; | ||||
|     case SOCK_DGRAM: | ||||
|         return Type::DGRAM; | ||||
|     case SOCK_RAW: | ||||
|         return Type::RAW; | ||||
|     case SOCK_SEQPACKET: | ||||
|         return Type::SEQPACKET; | ||||
|     default: | ||||
|         UNIMPLEMENTED_MSG("Unimplemented type={}", type); | ||||
|         return Type::STREAM; | ||||
|     } | ||||
| } | ||||
|  | ||||
| int TranslateTypeToNative(Type type) { | ||||
|     switch (type) { | ||||
|     case Type::Unspecified: | ||||
|         return 0; | ||||
|     case Type::STREAM: | ||||
|         return SOCK_STREAM; | ||||
|     case Type::DGRAM: | ||||
|         return SOCK_DGRAM; | ||||
|     case Type::RAW: | ||||
|         return SOCK_RAW; | ||||
|     default: | ||||
|         UNIMPLEMENTED_MSG("Unimplemented type={}", type); | ||||
|         return 0; | ||||
|     } | ||||
| } | ||||
|  | ||||
| int TranslateProtocol(Protocol protocol) { | ||||
| Protocol TranslateProtocolFromNative(int protocol) { | ||||
|     switch (protocol) { | ||||
|     case 0: | ||||
|         return Protocol::Unspecified; | ||||
|     case IPPROTO_TCP: | ||||
|         return Protocol::TCP; | ||||
|     case IPPROTO_UDP: | ||||
|         return Protocol::UDP; | ||||
|     default: | ||||
|         UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol); | ||||
|         return Protocol::Unspecified; | ||||
|     } | ||||
| } | ||||
|  | ||||
| int TranslateProtocolToNative(Protocol protocol) { | ||||
|     switch (protocol) { | ||||
|     case Protocol::Unspecified: | ||||
|         return 0; | ||||
|     case Protocol::TCP: | ||||
|         return IPPROTO_TCP; | ||||
|     case Protocol::UDP: | ||||
| @@ -275,21 +391,10 @@ int TranslateProtocol(Protocol protocol) { | ||||
|     } | ||||
| } | ||||
|  | ||||
| SockAddrIn TranslateToSockAddrIn(sockaddr input_) { | ||||
|     sockaddr_in input; | ||||
|     std::memcpy(&input, &input_, sizeof(input)); | ||||
|  | ||||
| SockAddrIn TranslateToSockAddrIn(sockaddr_in input, size_t input_len) { | ||||
|     SockAddrIn result; | ||||
|  | ||||
|     switch (input.sin_family) { | ||||
|     case AF_INET: | ||||
|         result.family = Domain::INET; | ||||
|         break; | ||||
|     default: | ||||
|         UNIMPLEMENTED_MSG("Unhandled sockaddr family={}", input.sin_family); | ||||
|         result.family = Domain::INET; | ||||
|         break; | ||||
|     } | ||||
|     result.family = TranslateDomainFromNative(input.sin_family); | ||||
|  | ||||
|     result.portno = ntohs(input.sin_port); | ||||
|  | ||||
| @@ -301,22 +406,33 @@ SockAddrIn TranslateToSockAddrIn(sockaddr input_) { | ||||
| short TranslatePollEvents(PollEvents events) { | ||||
|     short result = 0; | ||||
|  | ||||
|     if (True(events & PollEvents::In)) { | ||||
|         events &= ~PollEvents::In; | ||||
|         result |= POLLIN; | ||||
|     } | ||||
|     if (True(events & PollEvents::Pri)) { | ||||
|         events &= ~PollEvents::Pri; | ||||
|     const auto translate = [&result, &events](PollEvents guest, short host) { | ||||
|         if (True(events & guest)) { | ||||
|             events &= ~guest; | ||||
|             result |= host; | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     translate(PollEvents::In, POLLIN); | ||||
|     translate(PollEvents::Pri, POLLPRI); | ||||
|     translate(PollEvents::Out, POLLOUT); | ||||
|     translate(PollEvents::Err, POLLERR); | ||||
|     translate(PollEvents::Hup, POLLHUP); | ||||
|     translate(PollEvents::Nval, POLLNVAL); | ||||
|     translate(PollEvents::RdNorm, POLLRDNORM); | ||||
|     translate(PollEvents::RdBand, POLLRDBAND); | ||||
|     translate(PollEvents::WrBand, POLLWRBAND); | ||||
|  | ||||
| #ifdef _WIN32 | ||||
|         LOG_WARNING(Service, "Winsock doesn't support POLLPRI"); | ||||
| #else | ||||
|         result |= POLLPRI; | ||||
|     short allowed_events = POLLRDBAND | POLLRDNORM | POLLWRNORM; | ||||
|     // Unlike poll on other OSes, WSAPoll will complain if any other flags are set on input. | ||||
|     if (result & ~allowed_events) { | ||||
|         LOG_DEBUG(Network, | ||||
|                   "Removing WSAPoll input events 0x{:x} because Windows doesn't support them", | ||||
|                   result & ~allowed_events); | ||||
|     } | ||||
|     result &= allowed_events; | ||||
| #endif | ||||
|     } | ||||
|     if (True(events & PollEvents::Out)) { | ||||
|         events &= ~PollEvents::Out; | ||||
|         result |= POLLOUT; | ||||
|     } | ||||
|  | ||||
|     UNIMPLEMENTED_IF_MSG((u16)events != 0, "Unhandled guest events=0x{:x}", (u16)events); | ||||
|  | ||||
| @@ -337,6 +453,10 @@ PollEvents TranslatePollRevents(short revents) { | ||||
|     translate(POLLOUT, PollEvents::Out); | ||||
|     translate(POLLERR, PollEvents::Err); | ||||
|     translate(POLLHUP, PollEvents::Hup); | ||||
|     translate(POLLNVAL, PollEvents::Nval); | ||||
|     translate(POLLRDNORM, PollEvents::RdNorm); | ||||
|     translate(POLLRDBAND, PollEvents::RdBand); | ||||
|     translate(POLLWRBAND, PollEvents::WrBand); | ||||
|  | ||||
|     UNIMPLEMENTED_IF_MSG(revents != 0, "Unhandled host revents=0x{:x}", revents); | ||||
|  | ||||
| @@ -360,12 +480,51 @@ std::optional<IPv4Address> GetHostIPv4Address() { | ||||
|         return {}; | ||||
|     } | ||||
|  | ||||
|     std::array<char, 16> ip_addr = {}; | ||||
|     ASSERT(inet_ntop(AF_INET, &network_interface->ip_address, ip_addr.data(), sizeof(ip_addr)) != | ||||
|            nullptr); | ||||
|     return TranslateIPv4(network_interface->ip_address); | ||||
| } | ||||
|  | ||||
| std::string IPv4AddressToString(IPv4Address ip_addr) { | ||||
|     std::array<char, INET_ADDRSTRLEN> buf = {}; | ||||
|     ASSERT(inet_ntop(AF_INET, &ip_addr, buf.data(), sizeof(buf)) == buf.data()); | ||||
|     return std::string(buf.data()); | ||||
| } | ||||
|  | ||||
| u32 IPv4AddressToInteger(IPv4Address ip_addr) { | ||||
|     return static_cast<u32>(ip_addr[0]) << 24 | static_cast<u32>(ip_addr[1]) << 16 | | ||||
|            static_cast<u32>(ip_addr[2]) << 8 | static_cast<u32>(ip_addr[3]); | ||||
| } | ||||
|  | ||||
| Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddressInfo( | ||||
|     const std::string& host, const std::optional<std::string>& service) { | ||||
|     addrinfo hints{}; | ||||
|     hints.ai_family = AF_INET; // Switch only supports IPv4. | ||||
|     addrinfo* addrinfo; | ||||
|     s32 gai_err = getaddrinfo(host.c_str(), service.has_value() ? service->c_str() : nullptr, | ||||
|                               &hints, &addrinfo); | ||||
|     if (gai_err != 0) { | ||||
|         return Common::Unexpected(TranslateGetAddrInfoErrorFromNative(gai_err)); | ||||
|     } | ||||
|     std::vector<AddrInfo> ret; | ||||
|     for (auto* current = addrinfo; current; current = current->ai_next) { | ||||
|         // We should only get AF_INET results due to the hints value. | ||||
|         ASSERT_OR_EXECUTE(addrinfo->ai_family == AF_INET && | ||||
|                               addrinfo->ai_addrlen == sizeof(sockaddr_in), | ||||
|                           continue;); | ||||
|  | ||||
|         AddrInfo& out = ret.emplace_back(); | ||||
|         out.family = TranslateDomainFromNative(current->ai_family); | ||||
|         out.socket_type = TranslateTypeFromNative(current->ai_socktype); | ||||
|         out.protocol = TranslateProtocolFromNative(current->ai_protocol); | ||||
|         out.addr = TranslateToSockAddrIn(*reinterpret_cast<sockaddr_in*>(current->ai_addr), | ||||
|                                          current->ai_addrlen); | ||||
|         if (current->ai_canonname != nullptr) { | ||||
|             out.canon_name = current->ai_canonname; | ||||
|         } | ||||
|     } | ||||
|     freeaddrinfo(addrinfo); | ||||
|     return ret; | ||||
| } | ||||
|  | ||||
| std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) { | ||||
|     const size_t num = pollfds.size(); | ||||
|  | ||||
| @@ -411,9 +570,21 @@ Socket::Socket(Socket&& rhs) noexcept { | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) { | ||||
| std::pair<T, Errno> Socket::GetSockOpt(SOCKET fd_so, int option) { | ||||
|     T value{}; | ||||
|     socklen_t len = sizeof(value); | ||||
|     const int result = getsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast<char*>(&value), &len); | ||||
|     if (result != SOCKET_ERROR) { | ||||
|         ASSERT(len == sizeof(value)); | ||||
|         return {value, Errno::SUCCESS}; | ||||
|     } | ||||
|     return {value, GetAndLogLastError()}; | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| Errno Socket::SetSockOpt(SOCKET fd_so, int option, T value) { | ||||
|     const int result = | ||||
|         setsockopt(fd_, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value)); | ||||
|         setsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value)); | ||||
|     if (result != SOCKET_ERROR) { | ||||
|         return Errno::SUCCESS; | ||||
|     } | ||||
| @@ -421,7 +592,8 @@ Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) { | ||||
| } | ||||
|  | ||||
| Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { | ||||
|     fd = socket(TranslateDomain(domain), TranslateType(type), TranslateProtocol(protocol)); | ||||
|     fd = socket(TranslateDomainToNative(domain), TranslateTypeToNative(type), | ||||
|                 TranslateProtocolToNative(protocol)); | ||||
|     if (fd != INVALID_SOCKET) { | ||||
|         return Errno::SUCCESS; | ||||
|     } | ||||
| @@ -430,19 +602,17 @@ Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { | ||||
| } | ||||
|  | ||||
| std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() { | ||||
|     sockaddr addr; | ||||
|     sockaddr_in addr; | ||||
|     socklen_t addrlen = sizeof(addr); | ||||
|     const SOCKET new_socket = accept(fd, &addr, &addrlen); | ||||
|     const SOCKET new_socket = accept(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen); | ||||
|  | ||||
|     if (new_socket == INVALID_SOCKET) { | ||||
|         return {AcceptResult{}, GetAndLogLastError()}; | ||||
|     } | ||||
|  | ||||
|     ASSERT(addrlen == sizeof(sockaddr_in)); | ||||
|  | ||||
|     AcceptResult result{ | ||||
|         .socket = std::make_unique<Socket>(new_socket), | ||||
|         .sockaddr_in = TranslateToSockAddrIn(addr), | ||||
|         .sockaddr_in = TranslateToSockAddrIn(addr, addrlen), | ||||
|     }; | ||||
|  | ||||
|     return {std::move(result), Errno::SUCCESS}; | ||||
| @@ -458,25 +628,23 @@ Errno Socket::Connect(SockAddrIn addr_in) { | ||||
| } | ||||
|  | ||||
| std::pair<SockAddrIn, Errno> Socket::GetPeerName() { | ||||
|     sockaddr addr; | ||||
|     sockaddr_in addr; | ||||
|     socklen_t addrlen = sizeof(addr); | ||||
|     if (getpeername(fd, &addr, &addrlen) == SOCKET_ERROR) { | ||||
|     if (getpeername(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) { | ||||
|         return {SockAddrIn{}, GetAndLogLastError()}; | ||||
|     } | ||||
|  | ||||
|     ASSERT(addrlen == sizeof(sockaddr_in)); | ||||
|     return {TranslateToSockAddrIn(addr), Errno::SUCCESS}; | ||||
|     return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS}; | ||||
| } | ||||
|  | ||||
| std::pair<SockAddrIn, Errno> Socket::GetSockName() { | ||||
|     sockaddr addr; | ||||
|     sockaddr_in addr; | ||||
|     socklen_t addrlen = sizeof(addr); | ||||
|     if (getsockname(fd, &addr, &addrlen) == SOCKET_ERROR) { | ||||
|     if (getsockname(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) { | ||||
|         return {SockAddrIn{}, GetAndLogLastError()}; | ||||
|     } | ||||
|  | ||||
|     ASSERT(addrlen == sizeof(sockaddr_in)); | ||||
|     return {TranslateToSockAddrIn(addr), Errno::SUCCESS}; | ||||
|     return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS}; | ||||
| } | ||||
|  | ||||
| Errno Socket::Bind(SockAddrIn addr) { | ||||
| @@ -519,7 +687,7 @@ Errno Socket::Shutdown(ShutdownHow how) { | ||||
|     return GetAndLogLastError(); | ||||
| } | ||||
|  | ||||
| std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) { | ||||
| std::pair<s32, Errno> Socket::Recv(int flags, std::span<u8> message) { | ||||
|     ASSERT(flags == 0); | ||||
|     ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); | ||||
|  | ||||
| @@ -532,21 +700,20 @@ std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) { | ||||
|     return {-1, GetAndLogLastError()}; | ||||
| } | ||||
|  | ||||
| std::pair<s32, Errno> Socket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) { | ||||
| std::pair<s32, Errno> Socket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) { | ||||
|     ASSERT(flags == 0); | ||||
|     ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); | ||||
|  | ||||
|     sockaddr addr_in{}; | ||||
|     sockaddr_in addr_in{}; | ||||
|     socklen_t addrlen = sizeof(addr_in); | ||||
|     socklen_t* const p_addrlen = addr ? &addrlen : nullptr; | ||||
|     sockaddr* const p_addr_in = addr ? &addr_in : nullptr; | ||||
|     sockaddr* const p_addr_in = addr ? reinterpret_cast<sockaddr*>(&addr_in) : nullptr; | ||||
|  | ||||
|     const auto result = recvfrom(fd, reinterpret_cast<char*>(message.data()), | ||||
|                                  static_cast<int>(message.size()), 0, p_addr_in, p_addrlen); | ||||
|     if (result != SOCKET_ERROR) { | ||||
|         if (addr) { | ||||
|             ASSERT(addrlen == sizeof(addr_in)); | ||||
|             *addr = TranslateToSockAddrIn(addr_in); | ||||
|             *addr = TranslateToSockAddrIn(addr_in, addrlen); | ||||
|         } | ||||
|         return {static_cast<s32>(result), Errno::SUCCESS}; | ||||
|     } | ||||
| @@ -597,6 +764,11 @@ Errno Socket::Close() { | ||||
|     return Errno::SUCCESS; | ||||
| } | ||||
|  | ||||
| std::pair<Errno, Errno> Socket::GetPendingError() { | ||||
|     auto [pending_err, getsockopt_err] = GetSockOpt<int>(fd, SO_ERROR); | ||||
|     return {TranslateNativeError(pending_err), getsockopt_err}; | ||||
| } | ||||
|  | ||||
| Errno Socket::SetLinger(bool enable, u32 linger) { | ||||
|     return SetSockOpt(fd, SO_LINGER, MakeLinger(enable, linger)); | ||||
| } | ||||
|   | ||||
| @@ -5,6 +5,7 @@ | ||||
|  | ||||
| #include <array> | ||||
| #include <optional> | ||||
| #include <vector> | ||||
|  | ||||
| #include "common/common_funcs.h" | ||||
| #include "common/common_types.h" | ||||
| @@ -16,6 +17,11 @@ | ||||
| #include <netinet/in.h> | ||||
| #endif | ||||
|  | ||||
| namespace Common { | ||||
| template <typename T, typename E> | ||||
| class Expected; | ||||
| } | ||||
|  | ||||
| namespace Network { | ||||
|  | ||||
| class SocketBase; | ||||
| @@ -36,6 +42,26 @@ enum class Errno { | ||||
|     NETUNREACH, | ||||
|     TIMEDOUT, | ||||
|     MSGSIZE, | ||||
|     INPROGRESS, | ||||
|     OTHER, | ||||
| }; | ||||
|  | ||||
| enum class GetAddrInfoError { | ||||
|     SUCCESS, | ||||
|     ADDRFAMILY, | ||||
|     AGAIN, | ||||
|     BADFLAGS, | ||||
|     FAIL, | ||||
|     FAMILY, | ||||
|     MEMORY, | ||||
|     NODATA, | ||||
|     NONAME, | ||||
|     SERVICE, | ||||
|     SOCKTYPE, | ||||
|     SYSTEM, | ||||
|     BADHINTS, | ||||
|     PROTOCOL, | ||||
|     OVERFLOW_, | ||||
|     OTHER, | ||||
| }; | ||||
|  | ||||
| @@ -49,6 +75,9 @@ enum class PollEvents : u16 { | ||||
|     Err = 1 << 3, | ||||
|     Hup = 1 << 4, | ||||
|     Nval = 1 << 5, | ||||
|     RdNorm = 1 << 6, | ||||
|     RdBand = 1 << 7, | ||||
|     WrBand = 1 << 8, | ||||
| }; | ||||
|  | ||||
| DECLARE_ENUM_FLAG_OPERATORS(PollEvents); | ||||
| @@ -82,4 +111,11 @@ constexpr IPv4Address TranslateIPv4(in_addr addr) { | ||||
| /// @return human ordered IPv4 address (e.g. 192.168.0.1) as an array | ||||
| std::optional<IPv4Address> GetHostIPv4Address(); | ||||
|  | ||||
| std::string IPv4AddressToString(IPv4Address ip_addr); | ||||
| u32 IPv4AddressToInteger(IPv4Address ip_addr); | ||||
|  | ||||
| // named to avoid name collision with Windows macro | ||||
| Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddressInfo( | ||||
|     const std::string& host, const std::optional<std::string>& service); | ||||
|  | ||||
| } // namespace Network | ||||
|   | ||||
| @@ -98,7 +98,7 @@ Errno ProxySocket::Shutdown(ShutdownHow how) { | ||||
|     return Errno::SUCCESS; | ||||
| } | ||||
|  | ||||
| std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) { | ||||
| std::pair<s32, Errno> ProxySocket::Recv(int flags, std::span<u8> message) { | ||||
|     LOG_WARNING(Network, "(STUBBED) called"); | ||||
|     ASSERT(flags == 0); | ||||
|     ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); | ||||
| @@ -106,7 +106,7 @@ std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) { | ||||
|     return {static_cast<s32>(0), Errno::SUCCESS}; | ||||
| } | ||||
|  | ||||
| std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) { | ||||
| std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) { | ||||
|     ASSERT(flags == 0); | ||||
|     ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); | ||||
|  | ||||
| @@ -140,8 +140,8 @@ std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message, | ||||
|     } | ||||
| } | ||||
|  | ||||
| std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& message, | ||||
|                                                  SockAddrIn* addr, std::size_t max_length) { | ||||
| std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr, | ||||
|                                                  std::size_t max_length) { | ||||
|     ProxyPacket& packet = received_packets.front(); | ||||
|     if (addr) { | ||||
|         addr->family = Domain::INET; | ||||
| @@ -153,10 +153,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes | ||||
|     std::size_t read_bytes; | ||||
|     if (packet.data.size() > max_length) { | ||||
|         read_bytes = max_length; | ||||
|         message.clear(); | ||||
|         std::copy(packet.data.begin(), packet.data.begin() + read_bytes, | ||||
|                   std::back_inserter(message)); | ||||
|         message.resize(max_length); | ||||
|         memcpy(message.data(), packet.data.data(), max_length); | ||||
|  | ||||
|         if (protocol == Protocol::UDP) { | ||||
|             if (!peek) { | ||||
| @@ -171,9 +168,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes | ||||
|         } | ||||
|     } else { | ||||
|         read_bytes = packet.data.size(); | ||||
|         message.clear(); | ||||
|         std::copy(packet.data.begin(), packet.data.end(), std::back_inserter(message)); | ||||
|         message.resize(max_length); | ||||
|         memcpy(message.data(), packet.data.data(), read_bytes); | ||||
|         if (!peek) { | ||||
|             received_packets.pop(); | ||||
|         } | ||||
| @@ -293,6 +288,11 @@ Errno ProxySocket::SetNonBlock(bool enable) { | ||||
|     return Errno::SUCCESS; | ||||
| } | ||||
|  | ||||
| std::pair<Errno, Errno> ProxySocket::GetPendingError() { | ||||
|     LOG_DEBUG(Network, "(STUBBED) called"); | ||||
|     return {Errno::SUCCESS, Errno::SUCCESS}; | ||||
| } | ||||
|  | ||||
| bool ProxySocket::IsOpened() const { | ||||
|     return fd != INVALID_SOCKET; | ||||
| } | ||||
|   | ||||
| @@ -39,11 +39,11 @@ public: | ||||
|  | ||||
|     Errno Shutdown(ShutdownHow how) override; | ||||
|  | ||||
|     std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override; | ||||
|     std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override; | ||||
|  | ||||
|     std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override; | ||||
|     std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override; | ||||
|  | ||||
|     std::pair<s32, Errno> ReceivePacket(int flags, std::vector<u8>& message, SockAddrIn* addr, | ||||
|     std::pair<s32, Errno> ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr, | ||||
|                                         std::size_t max_length); | ||||
|  | ||||
|     std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override; | ||||
| @@ -74,6 +74,8 @@ public: | ||||
|     template <typename T> | ||||
|     Errno SetSockOpt(SOCKET fd, int option, T value); | ||||
|  | ||||
|     std::pair<Errno, Errno> GetPendingError() override; | ||||
|  | ||||
|     bool IsOpened() const override; | ||||
|  | ||||
| private: | ||||
|   | ||||
| @@ -59,10 +59,9 @@ public: | ||||
|  | ||||
|     virtual Errno Shutdown(ShutdownHow how) = 0; | ||||
|  | ||||
|     virtual std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) = 0; | ||||
|     virtual std::pair<s32, Errno> Recv(int flags, std::span<u8> message) = 0; | ||||
|  | ||||
|     virtual std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, | ||||
|                                            SockAddrIn* addr) = 0; | ||||
|     virtual std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) = 0; | ||||
|  | ||||
|     virtual std::pair<s32, Errno> Send(std::span<const u8> message, int flags) = 0; | ||||
|  | ||||
| @@ -87,6 +86,8 @@ public: | ||||
|  | ||||
|     virtual Errno SetNonBlock(bool enable) = 0; | ||||
|  | ||||
|     virtual std::pair<Errno, Errno> GetPendingError() = 0; | ||||
|  | ||||
|     virtual bool IsOpened() const = 0; | ||||
|  | ||||
|     virtual void HandleProxyPacket(const ProxyPacket& packet) = 0; | ||||
| @@ -126,9 +127,9 @@ public: | ||||
|  | ||||
|     Errno Shutdown(ShutdownHow how) override; | ||||
|  | ||||
|     std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override; | ||||
|     std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override; | ||||
|  | ||||
|     std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override; | ||||
|     std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override; | ||||
|  | ||||
|     std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override; | ||||
|  | ||||
| @@ -156,6 +157,11 @@ public: | ||||
|     template <typename T> | ||||
|     Errno SetSockOpt(SOCKET fd, int option, T value); | ||||
|  | ||||
|     std::pair<Errno, Errno> GetPendingError() override; | ||||
|  | ||||
|     template <typename T> | ||||
|     std::pair<T, Errno> GetSockOpt(SOCKET fd, int option); | ||||
|  | ||||
|     bool IsOpened() const override; | ||||
|  | ||||
|     void HandleProxyPacket(const ProxyPacket& packet) override; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 liamwhite
					liamwhite