mirror of
https://github.com/citra-emu/citra.git
synced 2024-12-18 12:30:04 +00:00
Service/UDS: Updated BeginHostingNetwork
This commit is contained in:
parent
f6d16c3f87
commit
ed9db735a2
@ -4,6 +4,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <cstring>
|
||||
#include <list>
|
||||
#include <mutex>
|
||||
@ -27,6 +28,12 @@
|
||||
namespace Service {
|
||||
namespace NWM {
|
||||
|
||||
namespace ErrCodes {
|
||||
enum {
|
||||
NotInitialized = 2,
|
||||
};
|
||||
} // namespace ErrCodes
|
||||
|
||||
// Event that is signaled every time the connection status changes.
|
||||
static Kernel::SharedPtr<Kernel::Event> connection_status_event;
|
||||
|
||||
@ -37,6 +44,8 @@ static Kernel::SharedPtr<Kernel::SharedMemory> recv_buffer_memory;
|
||||
// Connection status of this 3DS.
|
||||
static ConnectionStatus connection_status{};
|
||||
|
||||
static std::atomic<bool> initialized(false);
|
||||
|
||||
/* Node information about the current network.
|
||||
* The amount of elements in this vector is always the maximum number
|
||||
* of nodes specified in the network configuration.
|
||||
@ -155,7 +164,7 @@ void HandleAssociationResponseFrame(const Network::WifiPacket& packet) {
|
||||
"Could not join network");
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(connection_status_mutex);
|
||||
ASSERT(connection_status.status == static_cast<u32>(NetworkStatus::NotConnected));
|
||||
ASSERT(connection_status.status == static_cast<u32>(NetworkStatus::Connecting));
|
||||
}
|
||||
|
||||
// Send the EAPoL-Start packet to the server.
|
||||
@ -171,8 +180,9 @@ void HandleAssociationResponseFrame(const Network::WifiPacket& packet) {
|
||||
}
|
||||
|
||||
static void HandleEAPoLPacket(const Network::WifiPacket& packet) {
|
||||
std::lock_guard<std::recursive_mutex> hle_lock(HLE::g_hle_lock);
|
||||
std::lock_guard<std::mutex> lock(connection_status_mutex);
|
||||
std::unique_lock<std::recursive_mutex> hle_lock(HLE::g_hle_lock, std::defer_lock);
|
||||
std::unique_lock<std::mutex> lock(connection_status_mutex, std::defer_lock);
|
||||
std::lock(hle_lock, lock);
|
||||
|
||||
if (GetEAPoLFrameType(packet.data) == EAPoLStartMagic) {
|
||||
if (connection_status.status != static_cast<u32>(NetworkStatus::ConnectedAsHost)) {
|
||||
@ -220,7 +230,7 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) {
|
||||
// The 3ds does this presumably to support spectators.
|
||||
connection_status_event->Signal();
|
||||
} else {
|
||||
if (connection_status.status != static_cast<u32>(NetworkStatus::NotConnected)) {
|
||||
if (connection_status.status != static_cast<u32>(NetworkStatus::Connecting)) {
|
||||
LOG_DEBUG(Service_NWM, "Connection sequence aborted, because connection status is %u",
|
||||
connection_status.status);
|
||||
return;
|
||||
@ -249,15 +259,15 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) {
|
||||
// Some games require ConnectToNetwork to block, for now it doesn't
|
||||
// If blocking is implemented this lock needs to be changed,
|
||||
// otherwise it might cause deadlocks
|
||||
std::lock_guard<std::recursive_mutex> lock(HLE::g_hle_lock);
|
||||
connection_status_event->Signal();
|
||||
}
|
||||
}
|
||||
|
||||
static void HandleSecureDataPacket(const Network::WifiPacket& packet) {
|
||||
auto secure_data = ParseSecureDataHeader(packet.data);
|
||||
std::lock_guard<std::recursive_mutex> hle_lock(HLE::g_hle_lock);
|
||||
std::lock_guard<std::mutex> lock(connection_status_mutex);
|
||||
std::unique_lock<std::recursive_mutex> hle_lock(HLE::g_hle_lock, std::defer_lock);
|
||||
std::unique_lock<std::mutex> lock(connection_status_mutex, std::defer_lock);
|
||||
std::lock(hle_lock, lock);
|
||||
|
||||
if (secure_data.src_node_id == connection_status.network_node_id) {
|
||||
// Ignore packets that came from ourselves.
|
||||
@ -315,7 +325,7 @@ void StartConnectionSequence(const MacAddress& server) {
|
||||
WifiPacket auth_request;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(connection_status_mutex);
|
||||
ASSERT(connection_status.status == static_cast<u32>(NetworkStatus::NotConnected));
|
||||
connection_status.status = static_cast<u32>(NetworkStatus::Connecting);
|
||||
|
||||
// TODO(Subv): Handle timeout.
|
||||
|
||||
@ -546,6 +556,8 @@ static void InitializeWithVersion(Interface* self) {
|
||||
|
||||
recv_buffer_memory = Kernel::g_handle_table.Get<Kernel::SharedMemory>(sharedmem_handle);
|
||||
|
||||
initialized = true;
|
||||
|
||||
ASSERT_MSG(recv_buffer_memory->size == sharedmem_size, "Invalid shared memory size.");
|
||||
|
||||
{
|
||||
@ -614,8 +626,12 @@ static void GetNodeInformation(Interface* self) {
|
||||
IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0xD, 1, 0);
|
||||
u16 network_node_id = rp.Pop<u16>();
|
||||
|
||||
IPC::RequestBuilder rb = rp.MakeBuilder(11, 0);
|
||||
rb.Push(RESULT_SUCCESS);
|
||||
if (!initialized) {
|
||||
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
|
||||
rb.Push(ResultCode(ErrorDescription::NotInitialized, ErrorModule::UDS,
|
||||
ErrorSummary::StatusChanged, ErrorLevel::Status));
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(connection_status_mutex);
|
||||
@ -623,7 +639,15 @@ static void GetNodeInformation(Interface* self) {
|
||||
[network_node_id](const NodeInfo& node) {
|
||||
return node.network_node_id == network_node_id;
|
||||
});
|
||||
ASSERT(itr != node_info.end());
|
||||
if (itr == node_info.end()) {
|
||||
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
|
||||
rb.Push(ResultCode(ErrorDescription::NotFound, ErrorModule::UDS,
|
||||
ErrorSummary::WrongArgument, ErrorLevel::Status));
|
||||
return;
|
||||
}
|
||||
|
||||
IPC::RequestBuilder rb = rp.MakeBuilder(11, 0);
|
||||
rb.Push(RESULT_SUCCESS);
|
||||
rb.PushRaw<NodeInfo>(*itr);
|
||||
}
|
||||
LOG_DEBUG(Service_NWM, "called");
|
||||
@ -653,13 +677,29 @@ static void Bind(Interface* self) {
|
||||
|
||||
LOG_DEBUG(Service_NWM, "called");
|
||||
|
||||
if (data_channel == 0) {
|
||||
if (data_channel == 0 || bind_node_id == 0) {
|
||||
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
|
||||
rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS,
|
||||
ErrorSummary::WrongArgument, ErrorLevel::Usage));
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr size_t MaxBindNodes = 16;
|
||||
if (channel_data.size() >= MaxBindNodes) {
|
||||
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
|
||||
rb.Push(ResultCode(ErrorDescription::OutOfMemory, ErrorModule::UDS,
|
||||
ErrorSummary::OutOfResource, ErrorLevel::Status));
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr u32 MinRecvBufferSize = 0x5F4;
|
||||
if (recv_buffer_size < MinRecvBufferSize) {
|
||||
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
|
||||
rb.Push(ResultCode(ErrorDescription::TooLarge, ErrorModule::UDS,
|
||||
ErrorSummary::WrongArgument, ErrorLevel::Usage));
|
||||
return;
|
||||
}
|
||||
|
||||
// Create a new event for this bind node.
|
||||
auto event = Kernel::Event::Create(Kernel::ResetType::OneShot,
|
||||
"NWM::BindNodeEvent" + std::to_string(bind_node_id));
|
||||
@ -687,6 +727,12 @@ static void Unbind(Interface* self) {
|
||||
IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0x12, 1, 0);
|
||||
|
||||
u32 bind_node_id = rp.Pop<u32>();
|
||||
if (bind_node_id == 0) {
|
||||
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
|
||||
rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS,
|
||||
ErrorSummary::WrongArgument, ErrorLevel::Usage));
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(connection_status_mutex);
|
||||
|
||||
@ -699,8 +745,13 @@ static void Unbind(Interface* self) {
|
||||
channel_data.erase(itr);
|
||||
}
|
||||
|
||||
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
|
||||
IPC::RequestBuilder rb = rp.MakeBuilder(5, 0);
|
||||
rb.Push(RESULT_SUCCESS);
|
||||
rb.Push(bind_node_id);
|
||||
// TODO(B3N30): Find out what the other return values are
|
||||
rb.Push<u32>(0);
|
||||
rb.Push<u32>(0);
|
||||
rb.Push<u32>(0);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -729,13 +780,14 @@ static void BeginHostingNetwork(Interface* self) {
|
||||
|
||||
LOG_DEBUG(Service_NWM, "called");
|
||||
|
||||
Memory::ReadBlock(network_info_address, &network_info, sizeof(NetworkInfo));
|
||||
|
||||
// The real UDS module throws a fatal error if this assert fails.
|
||||
ASSERT_MSG(network_info.max_nodes > 1, "Trying to host a network of only one member.");
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(connection_status_mutex);
|
||||
|
||||
Memory::ReadBlock(network_info_address, &network_info, sizeof(NetworkInfo));
|
||||
|
||||
// The real UDS module throws a fatal error if this assert fails.
|
||||
ASSERT_MSG(network_info.max_nodes > 1, "Trying to host a network of only one member.");
|
||||
|
||||
connection_status.status = static_cast<u32>(NetworkStatus::ConnectedAsHost);
|
||||
|
||||
// Ensure the application data size is less than the maximum value.
|
||||
@ -749,11 +801,13 @@ static void BeginHostingNetwork(Interface* self) {
|
||||
connection_status.max_nodes = network_info.max_nodes;
|
||||
|
||||
// Resize the nodes list to hold max_nodes.
|
||||
node_info.clear();
|
||||
node_info.resize(network_info.max_nodes);
|
||||
|
||||
// There's currently only one node in the network (the host).
|
||||
connection_status.total_nodes = 1;
|
||||
network_info.total_nodes = 1;
|
||||
|
||||
// The host is always the first node
|
||||
connection_status.network_node_id = 1;
|
||||
current_node.network_node_id = 1;
|
||||
@ -762,12 +816,22 @@ static void BeginHostingNetwork(Interface* self) {
|
||||
connection_status.node_bitmask |= 1;
|
||||
// Notify the application that the first node was set.
|
||||
connection_status.changed_nodes |= 1;
|
||||
node_info[0] = current_node;
|
||||
}
|
||||
|
||||
// If the game has a preferred channel, use that instead.
|
||||
if (network_info.channel != 0)
|
||||
network_channel = network_info.channel;
|
||||
if (auto room_member = Network::GetRoomMember().lock()) {
|
||||
if (room_member->IsConnected()) {
|
||||
network_info.host_mac_address = room_member->GetMacAddress();
|
||||
} else {
|
||||
network_info.host_mac_address = {{0x0, 0x0, 0x0, 0x0, 0x0, 0x0}};
|
||||
}
|
||||
}
|
||||
node_info[0] = current_node;
|
||||
|
||||
// If the game has a preferred channel, use that instead.
|
||||
if (network_info.channel != 0)
|
||||
network_channel = network_info.channel;
|
||||
else
|
||||
network_info.channel = DefaultNetworkChannel;
|
||||
}
|
||||
|
||||
connection_status_event->Signal();
|
||||
|
||||
@ -775,8 +839,7 @@ static void BeginHostingNetwork(Interface* self) {
|
||||
CoreTiming::ScheduleEvent(msToCycles(DefaultBeaconInterval * MillisecondsPerTU),
|
||||
beacon_broadcast_event, 0);
|
||||
|
||||
LOG_WARNING(Service_NWM,
|
||||
"An UDS network has been created, but broadcasting it is unimplemented.");
|
||||
LOG_DEBUG(Service_NWM, "An UDS network has been created.");
|
||||
|
||||
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
|
||||
rb.Push(RESULT_SUCCESS);
|
||||
@ -929,6 +992,14 @@ static void PullPacket(Interface* self) {
|
||||
ASSERT(desc_size == max_out_buff_size);
|
||||
|
||||
std::lock_guard<std::mutex> lock(connection_status_mutex);
|
||||
if (connection_status.status != static_cast<u32>(NetworkStatus::ConnectedAsHost) &&
|
||||
connection_status.status != static_cast<u32>(NetworkStatus::ConnectedAsClient) &&
|
||||
connection_status.status != static_cast<u32>(NetworkStatus::ConnectedAsSpectator)) {
|
||||
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
|
||||
rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS,
|
||||
ErrorSummary::InvalidState, ErrorLevel::Status));
|
||||
return;
|
||||
}
|
||||
|
||||
auto channel =
|
||||
std::find_if(channel_data.begin(), channel_data.end(), [bind_node_id](const auto& data) {
|
||||
@ -937,8 +1008,8 @@ static void PullPacket(Interface* self) {
|
||||
|
||||
if (channel == channel_data.end()) {
|
||||
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
|
||||
// TODO(B3N30): Find the right error code
|
||||
rb.Push<u32>(-1);
|
||||
rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS,
|
||||
ErrorSummary::WrongArgument, ErrorLevel::Usage));
|
||||
return;
|
||||
}
|
||||
|
||||
@ -959,7 +1030,8 @@ static void PullPacket(Interface* self) {
|
||||
|
||||
if (data_size > max_out_buff_size) {
|
||||
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
|
||||
rb.Push<u32>(0xE10113E9);
|
||||
rb.Push(ResultCode(ErrorDescription::TooLarge, ErrorModule::UDS,
|
||||
ErrorSummary::WrongArgument, ErrorLevel::Usage));
|
||||
return;
|
||||
}
|
||||
|
||||
@ -1225,6 +1297,7 @@ NWM_UDS::~NWM_UDS() {
|
||||
channel_data.clear();
|
||||
connection_status_event = nullptr;
|
||||
recv_buffer_memory = nullptr;
|
||||
initialized = false;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(connection_status_mutex);
|
||||
|
@ -32,7 +32,7 @@ struct NodeInfo {
|
||||
std::array<u16_le, 10> username;
|
||||
INSERT_PADDING_BYTES(4);
|
||||
u16_le network_node_id;
|
||||
std::array<u8, 6> address;
|
||||
INSERT_PADDING_BYTES(6);
|
||||
};
|
||||
|
||||
static_assert(sizeof(NodeInfo) == 40, "NodeInfo has incorrect size.");
|
||||
@ -42,6 +42,7 @@ using NodeList = std::vector<NodeInfo>;
|
||||
enum class NetworkStatus {
|
||||
NotConnected = 3,
|
||||
ConnectedAsHost = 6,
|
||||
Connecting = 7,
|
||||
ConnectedAsClient = 9,
|
||||
ConnectedAsSpectator = 10,
|
||||
};
|
||||
|
@ -52,7 +52,7 @@ struct SecureDataHeader {
|
||||
u16_be dest_node_id;
|
||||
u16_be src_node_id;
|
||||
|
||||
u32 GetActualDataSize() {
|
||||
u32 GetActualDataSize() const {
|
||||
return protocol_size - sizeof(SecureDataHeader);
|
||||
}
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user