From b1c2f791af08b3eaba53c1ce1673fe0729fc5d26 Mon Sep 17 00:00:00 2001
From: Liam <byteslice@airmail.cc>
Date: Mon, 1 Jan 2024 21:58:56 -0500
Subject: [PATCH] am: rework IStorage for transfer storage

---
 src/core/hle/service/am/am_results.h          |  1 +
 .../hle/service/am/library_applet_creator.cpp | 28 +++----
 src/core/hle/service/am/storage.cpp           | 73 ++++++++--------
 src/core/hle/service/am/storage.h             | 25 ++----
 src/core/hle/service/am/storage_accessor.cpp  | 84 ++++++++++---------
 src/core/hle/service/am/storage_accessor.h    | 17 +++-
 6 files changed, 118 insertions(+), 110 deletions(-)

diff --git a/src/core/hle/service/am/am_results.h b/src/core/hle/service/am/am_results.h
index e82d391adc..a2afc9eec2 100644
--- a/src/core/hle/service/am/am_results.h
+++ b/src/core/hle/service/am/am_results.h
@@ -10,6 +10,7 @@ namespace Service::AM {
 constexpr Result ResultNoDataInChannel{ErrorModule::AM, 2};
 constexpr Result ResultNoMessages{ErrorModule::AM, 3};
 constexpr Result ResultInvalidOffset{ErrorModule::AM, 503};
+constexpr Result ResultInvalidStorageType{ErrorModule::AM, 511};
 constexpr Result ResultFatalSectionCountImbalance{ErrorModule::AM, 512};
 
 } // namespace Service::AM
diff --git a/src/core/hle/service/am/library_applet_creator.cpp b/src/core/hle/service/am/library_applet_creator.cpp
index e4332e244d..888b8b44bf 100644
--- a/src/core/hle/service/am/library_applet_creator.cpp
+++ b/src/core/hle/service/am/library_applet_creator.cpp
@@ -6,6 +6,7 @@
 #include "core/hle/service/am/frontend/applets.h"
 #include "core/hle/service/am/library_applet_accessor.h"
 #include "core/hle/service/am/library_applet_creator.h"
+#include "core/hle/service/am/library_applet_storage.h"
 #include "core/hle/service/am/storage.h"
 #include "core/hle/service/ipc_helpers.h"
 #include "core/hle/service/sm/sm.h"
@@ -164,28 +165,28 @@ void ILibraryAppletCreator::CreateStorage(HLERequestContext& ctx) {
         return;
     }
 
-    std::vector<u8> buffer(size);
+    std::vector<u8> data(size);
 
     IPC::ResponseBuilder rb{ctx, 2, 0, 1};
     rb.Push(ResultSuccess);
-    rb.PushIpcInterface<IStorage>(system, std::move(buffer));
+    rb.PushIpcInterface<IStorage>(system, AM::CreateStorage(std::move(data)));
 }
 
 void ILibraryAppletCreator::CreateTransferMemoryStorage(HLERequestContext& ctx) {
     IPC::RequestParser rp{ctx};
 
     struct Parameters {
-        u8 permissions;
+        bool is_writable;
         s64 size;
     };
 
-    const auto parameters{rp.PopRaw<Parameters>()};
+    const auto params{rp.PopRaw<Parameters>()};
     const auto handle{ctx.GetCopyHandle(0)};
 
-    LOG_DEBUG(Service_AM, "called, permissions={}, size={}, handle={:08X}", parameters.permissions,
-              parameters.size, handle);
+    LOG_DEBUG(Service_AM, "called, is_writable={}, size={}, handle={:08X}", params.is_writable,
+              params.size, handle);
 
-    if (parameters.size <= 0) {
+    if (params.size <= 0) {
         LOG_ERROR(Service_AM, "size is less than or equal to 0");
         IPC::ResponseBuilder rb{ctx, 2};
         rb.Push(ResultUnknown);
@@ -201,12 +202,11 @@ void ILibraryAppletCreator::CreateTransferMemoryStorage(HLERequestContext& ctx)
         return;
     }
 
-    std::vector<u8> memory(transfer_mem->GetSize());
-    ctx.GetMemory().ReadBlock(transfer_mem->GetSourceAddress(), memory.data(), memory.size());
-
     IPC::ResponseBuilder rb{ctx, 2, 0, 1};
     rb.Push(ResultSuccess);
-    rb.PushIpcInterface<IStorage>(system, std::move(memory));
+    rb.PushIpcInterface<IStorage>(
+        system, AM::CreateTransferMemoryStorage(ctx.GetMemory(), transfer_mem.GetPointerUnsafe(),
+                                                params.is_writable, params.size));
 }
 
 void ILibraryAppletCreator::CreateHandleStorage(HLERequestContext& ctx) {
@@ -233,12 +233,10 @@ void ILibraryAppletCreator::CreateHandleStorage(HLERequestContext& ctx) {
         return;
     }
 
-    std::vector<u8> memory(transfer_mem->GetSize());
-    ctx.GetMemory().ReadBlock(transfer_mem->GetSourceAddress(), memory.data(), memory.size());
-
     IPC::ResponseBuilder rb{ctx, 2, 0, 1};
     rb.Push(ResultSuccess);
-    rb.PushIpcInterface<IStorage>(system, std::move(memory));
+    rb.PushIpcInterface<IStorage>(
+        system, AM::CreateHandleStorage(ctx.GetMemory(), transfer_mem.GetPointerUnsafe(), size));
 }
 
 } // namespace Service::AM
diff --git a/src/core/hle/service/am/storage.cpp b/src/core/hle/service/am/storage.cpp
index 9a86c867a8..4e82afd1ca 100644
--- a/src/core/hle/service/am/storage.cpp
+++ b/src/core/hle/service/am/storage.cpp
@@ -1,60 +1,59 @@
 // SPDX-FileCopyrightText: Copyright 2024 yuzu Emulator Project
 // SPDX-License-Identifier: GPL-2.0-or-later
 
+#include "core/hle/service/am/am_results.h"
+#include "core/hle/service/am/library_applet_storage.h"
 #include "core/hle/service/am/storage.h"
 #include "core/hle/service/am/storage_accessor.h"
 #include "core/hle/service/ipc_helpers.h"
 
 namespace Service::AM {
 
-IStorageImpl::~IStorageImpl() = default;
-
-class StorageDataImpl final : public IStorageImpl {
-public:
-    explicit StorageDataImpl(std::vector<u8>&& buffer_) : buffer{std::move(buffer_)} {}
-
-    std::vector<u8>& GetData() override {
-        return buffer;
-    }
-
-    const std::vector<u8>& GetData() const override {
-        return buffer;
-    }
-
-    std::size_t GetSize() const override {
-        return buffer.size();
-    }
-
-private:
-    std::vector<u8> buffer;
-};
-
-IStorage::IStorage(Core::System& system_, std::vector<u8>&& buffer)
-    : ServiceFramework{system_, "IStorage"},
-      impl{std::make_shared<StorageDataImpl>(std::move(buffer))} {
-    Register();
-}
-
-void IStorage::Register() {
-    // clang-format off
-        static const FunctionInfo functions[] = {
-            {0, &IStorage::Open, "Open"},
-            {1, nullptr, "OpenTransferStorage"},
-        };
-    // clang-format on
+IStorage::IStorage(Core::System& system_, std::shared_ptr<LibraryAppletStorage> impl_)
+    : ServiceFramework{system_, "IStorage"}, impl{std::move(impl_)} {
+    static const FunctionInfo functions[] = {
+        {0, &IStorage::Open, "Open"},
+        {1, &IStorage::OpenTransferStorage, "OpenTransferStorage"},
+    };
 
     RegisterHandlers(functions);
 }
 
+IStorage::IStorage(Core::System& system_, std::vector<u8>&& data)
+    : IStorage(system_, CreateStorage(std::move(data))) {}
+
 IStorage::~IStorage() = default;
 
 void IStorage::Open(HLERequestContext& ctx) {
     LOG_DEBUG(Service_AM, "called");
 
-    IPC::ResponseBuilder rb{ctx, 2, 0, 1};
+    if (impl->GetHandle() != nullptr) {
+        IPC::ResponseBuilder rb{ctx, 2};
+        rb.Push(AM::ResultInvalidStorageType);
+        return;
+    }
 
+    IPC::ResponseBuilder rb{ctx, 2, 0, 1};
     rb.Push(ResultSuccess);
-    rb.PushIpcInterface<IStorageAccessor>(system, *this);
+    rb.PushIpcInterface<IStorageAccessor>(system, impl);
+}
+
+void IStorage::OpenTransferStorage(HLERequestContext& ctx) {
+    LOG_DEBUG(Service_AM, "called");
+
+    if (impl->GetHandle() == nullptr) {
+        IPC::ResponseBuilder rb{ctx, 2};
+        rb.Push(AM::ResultInvalidStorageType);
+        return;
+    }
+
+    IPC::ResponseBuilder rb{ctx, 2, 0, 1};
+    rb.Push(ResultSuccess);
+    rb.PushIpcInterface<ITransferStorageAccessor>(system, impl);
+}
+
+std::vector<u8> IStorage::GetData() const {
+    return impl->GetData();
 }
 
 } // namespace Service::AM
diff --git a/src/core/hle/service/am/storage.h b/src/core/hle/service/am/storage.h
index d47a8d89f7..10d00b1419 100644
--- a/src/core/hle/service/am/storage.h
+++ b/src/core/hle/service/am/storage.h
@@ -7,36 +7,25 @@
 
 namespace Service::AM {
 
-class IStorageImpl {
-public:
-    virtual ~IStorageImpl();
-    virtual std::vector<u8>& GetData() = 0;
-    virtual const std::vector<u8>& GetData() const = 0;
-    virtual std::size_t GetSize() const = 0;
-};
+class LibraryAppletStorage;
 
 class IStorage final : public ServiceFramework<IStorage> {
 public:
+    explicit IStorage(Core::System& system_, std::shared_ptr<LibraryAppletStorage> impl_);
     explicit IStorage(Core::System& system_, std::vector<u8>&& buffer);
     ~IStorage() override;
 
-    std::vector<u8>& GetData() {
-        return impl->GetData();
+    std::shared_ptr<LibraryAppletStorage> GetImpl() const {
+        return impl;
     }
 
-    const std::vector<u8>& GetData() const {
-        return impl->GetData();
-    }
-
-    std::size_t GetSize() const {
-        return impl->GetSize();
-    }
+    std::vector<u8> GetData() const;
 
 private:
-    void Register();
     void Open(HLERequestContext& ctx);
+    void OpenTransferStorage(HLERequestContext& ctx);
 
-    std::shared_ptr<IStorageImpl> impl;
+    const std::shared_ptr<LibraryAppletStorage> impl;
 };
 
 } // namespace Service::AM
diff --git a/src/core/hle/service/am/storage_accessor.cpp b/src/core/hle/service/am/storage_accessor.cpp
index 7d8c82de33..a1184b0652 100644
--- a/src/core/hle/service/am/storage_accessor.cpp
+++ b/src/core/hle/service/am/storage_accessor.cpp
@@ -1,21 +1,22 @@
 // SPDX-FileCopyrightText: Copyright 2024 yuzu Emulator Project
 // SPDX-License-Identifier: GPL-2.0-or-later
 
+#include "core/hle/kernel/k_transfer_memory.h"
 #include "core/hle/service/am/am_results.h"
+#include "core/hle/service/am/library_applet_storage.h"
 #include "core/hle/service/am/storage_accessor.h"
 #include "core/hle/service/ipc_helpers.h"
 
 namespace Service::AM {
 
-IStorageAccessor::IStorageAccessor(Core::System& system_, IStorage& backing_)
-    : ServiceFramework{system_, "IStorageAccessor"}, backing{backing_} {
-    // clang-format off
-        static const FunctionInfo functions[] = {
-            {0, &IStorageAccessor::GetSize, "GetSize"},
-            {10, &IStorageAccessor::Write, "Write"},
-            {11, &IStorageAccessor::Read, "Read"},
-        };
-    // clang-format on
+IStorageAccessor::IStorageAccessor(Core::System& system_,
+                                   std::shared_ptr<LibraryAppletStorage> impl_)
+    : ServiceFramework{system_, "IStorageAccessor"}, impl{std::move(impl_)} {
+    static const FunctionInfo functions[] = {
+        {0, &IStorageAccessor::GetSize, "GetSize"},
+        {10, &IStorageAccessor::Write, "Write"},
+        {11, &IStorageAccessor::Read, "Read"},
+    };
 
     RegisterHandlers(functions);
 }
@@ -28,55 +29,62 @@ void IStorageAccessor::GetSize(HLERequestContext& ctx) {
     IPC::ResponseBuilder rb{ctx, 4};
 
     rb.Push(ResultSuccess);
-    rb.Push(static_cast<u64>(backing.GetSize()));
+    rb.Push(impl->GetSize());
 }
 
 void IStorageAccessor::Write(HLERequestContext& ctx) {
     IPC::RequestParser rp{ctx};
 
-    const u64 offset{rp.Pop<u64>()};
+    const s64 offset{rp.Pop<s64>()};
     const auto data{ctx.ReadBuffer()};
-    const std::size_t size{std::min<u64>(data.size(), backing.GetSize() - offset)};
+    LOG_DEBUG(Service_AM, "called, offset={}, size={}", offset, data.size());
 
-    LOG_DEBUG(Service_AM, "called, offset={}, size={}", offset, size);
-
-    if (offset > backing.GetSize()) {
-        LOG_ERROR(Service_AM,
-                  "offset is out of bounds, backing_buffer_sz={}, data_size={}, offset={}",
-                  backing.GetSize(), size, offset);
-
-        IPC::ResponseBuilder rb{ctx, 2};
-        rb.Push(AM::ResultInvalidOffset);
-        return;
-    }
-
-    std::memcpy(backing.GetData().data() + offset, data.data(), size);
+    const auto res{impl->Write(offset, data.data(), data.size())};
 
     IPC::ResponseBuilder rb{ctx, 2};
-    rb.Push(ResultSuccess);
+    rb.Push(res);
 }
 
 void IStorageAccessor::Read(HLERequestContext& ctx) {
     IPC::RequestParser rp{ctx};
 
-    const u64 offset{rp.Pop<u64>()};
-    const std::size_t size{std::min<u64>(ctx.GetWriteBufferSize(), backing.GetSize() - offset)};
+    const s64 offset{rp.Pop<s64>()};
+    std::vector<u8> data(ctx.GetWriteBufferSize());
 
-    LOG_DEBUG(Service_AM, "called, offset={}, size={}", offset, size);
+    LOG_DEBUG(Service_AM, "called, offset={}, size={}", offset, data.size());
 
-    if (offset > backing.GetSize()) {
-        LOG_ERROR(Service_AM, "offset is out of bounds, backing_buffer_sz={}, size={}, offset={}",
-                  backing.GetSize(), size, offset);
+    const auto res{impl->Read(offset, data.data(), data.size())};
 
-        IPC::ResponseBuilder rb{ctx, 2};
-        rb.Push(AM::ResultInvalidOffset);
-        return;
-    }
-
-    ctx.WriteBuffer(backing.GetData().data() + offset, size);
+    ctx.WriteBuffer(data);
 
     IPC::ResponseBuilder rb{ctx, 2};
+    rb.Push(res);
+}
+
+ITransferStorageAccessor::ITransferStorageAccessor(Core::System& system_,
+                                                   std::shared_ptr<LibraryAppletStorage> impl_)
+    : ServiceFramework{system_, "ITransferStorageAccessor"}, impl{std::move(impl_)} {
+    static const FunctionInfo functions[] = {
+        {0, &ITransferStorageAccessor::GetSize, "GetSize"},
+        {1, &ITransferStorageAccessor::GetHandle, "GetHandle"},
+    };
+
+    RegisterHandlers(functions);
+}
+
+ITransferStorageAccessor::~ITransferStorageAccessor() = default;
+
+void ITransferStorageAccessor::GetSize(HLERequestContext& ctx) {
+    IPC::ResponseBuilder rb{ctx, 4};
     rb.Push(ResultSuccess);
+    rb.Push(impl->GetSize());
+}
+
+void ITransferStorageAccessor::GetHandle(HLERequestContext& ctx) {
+    IPC::ResponseBuilder rb{ctx, 4, 1};
+    rb.Push(ResultSuccess);
+    rb.Push(impl->GetSize());
+    rb.PushCopyObjects(impl->GetHandle());
 }
 
 } // namespace Service::AM
diff --git a/src/core/hle/service/am/storage_accessor.h b/src/core/hle/service/am/storage_accessor.h
index 8648bfc138..b9aa85a666 100644
--- a/src/core/hle/service/am/storage_accessor.h
+++ b/src/core/hle/service/am/storage_accessor.h
@@ -10,7 +10,7 @@ namespace Service::AM {
 
 class IStorageAccessor final : public ServiceFramework<IStorageAccessor> {
 public:
-    explicit IStorageAccessor(Core::System& system_, IStorage& backing_);
+    explicit IStorageAccessor(Core::System& system_, std::shared_ptr<LibraryAppletStorage> impl_);
     ~IStorageAccessor() override;
 
 private:
@@ -18,7 +18,20 @@ private:
     void Write(HLERequestContext& ctx);
     void Read(HLERequestContext& ctx);
 
-    IStorage& backing;
+    const std::shared_ptr<LibraryAppletStorage> impl;
+};
+
+class ITransferStorageAccessor final : public ServiceFramework<ITransferStorageAccessor> {
+public:
+    explicit ITransferStorageAccessor(Core::System& system_,
+                                      std::shared_ptr<LibraryAppletStorage> impl_);
+    ~ITransferStorageAccessor() override;
+
+private:
+    void GetSize(HLERequestContext& ctx);
+    void GetHandle(HLERequestContext& ctx);
+
+    const std::shared_ptr<LibraryAppletStorage> impl;
 };
 
 } // namespace Service::AM