diff --git a/src/binary_translation/BlockColors.cpp b/src/binary_translation/BlockColors.cpp
new file mode 100644
index 000000000..64d58fcc2
--- /dev/null
+++ b/src/binary_translation/BlockColors.cpp
@@ -0,0 +1,111 @@
+#include "BlockColors.h"
+#include <cassert>
+#include <stack>
+#include "InstructionBlock.h"
+#include <llvm/IR/Function.h>
+#include "common/logging/log.h"
+
+using namespace llvm;
+
+// Caches the signature shared by every color function: void(i32 index).
+BlockColors::BlockColors(ModuleGen* module) : module(module)
+{
+    auto ir_builder = module->IrBuilder();
+    function_type = FunctionType::get(ir_builder->getVoidTy(), ir_builder->getInt32Ty(), false);
+}
+
+BlockColors::~BlockColors()
+{
+}
+
+// Assigns a fresh color to block and to every block reachable from it through
+// next/prev links. A block that already has a color is left untouched.
+void BlockColors::AddBlock(InstructionBlock* block)
+{
+    if (block->HasColor()) return;
+
+    std::stack<InstructionBlock*> current_color_stack;
+    current_color_stack.push(block);
+    auto color = colors.size();
+    colors.push_back({ color });
+
+    while (current_color_stack.size())
+    {
+        auto item = current_color_stack.top();
+        current_color_stack.pop();
+
+        // A block can be pushed more than once before it is first popped (for example
+        // when it is linked to two neighbours that are both still uncolored); record it
+        // only once so that it gets a single index in the color's instruction list.
+        if (item->HasColor())
+        {
+            assert(item->GetColor() == color);
+            continue;
+        }
+
+        item->SetColor(color);
+        colors[color].instructions.push_back(item);
+        for (auto next : item->GetNexts())
+        {
+            if (next->HasColor()) assert(next->GetColor() == color);
+            else current_color_stack.push(next);
+        }
+        for (auto prev : item->GetPrevs())
+        {
+            if (prev->HasColor()) assert(prev->GetColor() == color);
+            else current_color_stack.push(prev);
+        }
+    }
+}
+
+// Creates one private function per color. Its body is a switch over the index
+// argument that jumps to the entry basic block of color.instructions[index].
+void BlockColors::GenerateFunctions()
+{
+    auto ir_builder = module->IrBuilder();
+
+    // colors.size() is a size_t; cast it so the argument matches the %x conversion
+    LOG_INFO(BinaryTranslator, "%x block colors", static_cast<unsigned>(colors.size()));
+
+    for (auto &color : colors)
+    {
+        auto function = Function::Create(function_type, GlobalValue::PrivateLinkage,
+            "ColorFunction", module->Module());
+        color.function = function;
+        auto index = &function->getArgumentList().front();
+
+        auto entry_basic_block = BasicBlock::Create(getGlobalContext(), "Entry", function);
+        auto default_case_basic_block = BasicBlock::Create(getGlobalContext(), "Default", function);
+
+        // The default case of the switch is marked unreachable
+        ir_builder->SetInsertPoint(default_case_basic_block);
+        ir_builder->CreateUnreachable();
+
+        ir_builder->SetInsertPoint(entry_basic_block);
+        auto switch_instruction = ir_builder->CreateSwitch(index, default_case_basic_block, color.instructions.size());
+        for (size_t i = 0; i < color.instructions.size(); ++i)
+        {
+            switch_instruction->addCase(ir_builder->getInt32(i), color.instructions[i]->GetEntryBasicBlock());
+            AddBasicBlocksToFunction(function, color.instructions[i]->GetEntryBasicBlock());
+        }
+    }
+}
+
+// Inserts basic_block, and every block reachable through terminator successors,
+// into function. A block that already has a parent must already belong to function.
+void BlockColors::AddBasicBlocksToFunction(Function* function, BasicBlock* basic_block)
+{
+    if (basic_block->getParent())
+    {
+        assert(basic_block->getParent() == function);
+        return;
+    }
+
+    basic_block->insertInto(function);
+    auto terminator = basic_block->getTerminator();
+    // getNumSuccessors() is unsigned; use a matching index type
+    for (unsigned i = 0; i < terminator->getNumSuccessors(); ++i)
+    {
+        AddBasicBlocksToFunction(function, terminator->getSuccessor(i));
+    }
+}
\ No newline at end of file
diff --git a/src/binary_translation/BlockColors.h b/src/binary_translation/BlockColors.h
new file mode 100644
index 000000000..12f7991ba
--- /dev/null
+++ b/src/binary_translation/BlockColors.h
@@ -0,0 +1,58 @@
+#pragma once
+
+#include <cstddef>
+#include <vector>
+
+namespace llvm
+{
+    class BasicBlock;
+    class Function;
+    class FunctionType;
+}
+class InstructionBlock;
+class ModuleGen;
+
+/*
+
+Responsible for partitioning the blocks by connectivity: each disjoint graph gets a color.
+Also generates a function for each color.
+
+*/
+
+class BlockColors
+{
+public:
+    BlockColors(ModuleGen *module);
+    ~BlockColors();
+
+    // Colors block and everything linked to it, unless block is already colored
+    void AddBlock(InstructionBlock *block);
+    // Generates a function for each color
+    void GenerateFunctions();
+
+    llvm::FunctionType *GetFunctionType() { return function_type; }
+    size_t GetColorCount() { return colors.size(); }
+    size_t GetColorInstructionCount(size_t color) { return colors[color].instructions.size(); }
+    InstructionBlock *GetColorInstruction(size_t color, size_t index) { return colors[color].instructions[index]; }
+    llvm::Function *GetColorFunction(size_t color) { return colors[color].function; }
+private:
+    ModuleGen *module;
+
+    // void ColorFunction(int i)
+    // Runs the code for color->instructions[i]
+    llvm::FunctionType *function_type;
+
+    // Recursively inserts basic_block and its successors into function
+    void AddBasicBlocksToFunction(llvm::Function *function, llvm::BasicBlock *basic_block);
+
+    struct Color
+    {
+        // Index of this entry in colors
+        size_t color;
+        // Blocks of this color; the position is the index passed to function
+        std::vector<InstructionBlock *> instructions;
+        // Assigned by GenerateFunctions
+        llvm::Function *function;
+    };
+    std::vector<Color> colors;
+};
\ No newline at end of file
diff --git a/src/binary_translation/CMakeLists.txt b/src/binary_translation/CMakeLists.txt
index 465910a0a..a057311d6 100644
--- a/src/binary_translation/CMakeLists.txt
+++ b/src/binary_translation/CMakeLists.txt
@@ -7,6 +7,7 @@ set(SRCS
     MachineState.cpp
     TBAA.cpp
     ARMFuncs.cpp
+    BlockColors.cpp
 
     Instructions/Instruction.cpp
     Instructions/MovShift.cpp
@@ -21,6 +22,7 @@ set(HEADERS
     TBAA.h
     BinarySearch.h
     ARMFuncs.h
+    BlockColors.h
 
     Instructions/Types.h
     Instructions/Instruction.h
diff --git a/src/binary_translation/InstructionBlock.cpp b/src/binary_translation/InstructionBlock.cpp
index 025aca6fa..3fec2169f 100644
--- a/src/binary_translation/InstructionBlock.cpp
+++ b/src/binary_translation/InstructionBlock.cpp
@@ -46,6 +46,12 @@ llvm::BasicBlock *InstructionBlock::CreateBasicBlock(const char *name)
     return llvm::BasicBlock::Create(llvm::getGlobalContext(), address_string + name);
 }
 
+void InstructionBlock::Link(InstructionBlock* prev, InstructionBlock* next)
+{
+    prev->nexts.push_back(next);
+    next->prevs.push_back(prev);
+}
+
 u32 InstructionBlock::Address()
 {
     return instruction->Address();
diff --git a/src/binary_translation/InstructionBlock.h b/src/binary_translation/InstructionBlock.h
index 0f312ef69..53038a5e7 100644
--- a/src/binary_translation/InstructionBlock.h
+++ b/src/binary_translation/InstructionBlock.h
@@ -50,11 +50,23 @@ public:
      */
     llvm::BasicBlock *CreateBasicBlock(const char *name);
 
+    /*
+     * Links two instructions, adding to prev and next lists
+     */
+    static void Link(InstructionBlock *prev, InstructionBlock *next);
+
     u32 Address();
     ModuleGen *Module() { return module; }
     llvm::IRBuilder<> *IrBuilder() { return module->IrBuilder(); }
 
     llvm::BasicBlock *GetEntryBasicBlock() { return entry_basic_block; }
+
+    bool HasColor() { return has_color; }
+    void SetColor(size_t color) { this->color = color; has_color = true; }
+    size_t GetColor() { return color; }
+
+    const std::list<InstructionBlock *> &GetNexts() { return nexts; }
+    const std::list<InstructionBlock *> &GetPrevs() { return prevs; }
 private:
     // Textual representation of the address
     // Used to generate names
@@ -65,4 +77,10 @@ private:
 
     // The block at the entry to instruction
     llvm::BasicBlock *entry_basic_block;
+
+    bool has_color = false;
+    size_t color = 0;
+
+    std::list<InstructionBlock *> nexts;
+    std::list<InstructionBlock *> prevs;
 };
\ No newline at end of file
diff --git a/src/binary_translation/Instructions/Branch.cpp b/src/binary_translation/Instructions/Branch.cpp
index 5e9003399..1680b822a 100644
--- a/src/binary_translation/Instructions/Branch.cpp
+++ b/src/binary_translation/Instructions/Branch.cpp
@@ -37,7 +37,7 @@ void Branch::GenerateInstructionCode(InstructionBlock* instruction_block)
         auto pc = static_cast(imm24 << 2);
         pc = pc << 6 >> 6; // Sign extend
         pc += instruction_block->Address() + 8;
-        instruction_block->Module()->BranchWritePCConst(pc);
+        instruction_block->Module()->BranchWritePCConst(instruction_block, pc);
     }
     else
     {
diff --git a/src/binary_translation/Instructions/Instruction.cpp b/src/binary_translation/Instructions/Instruction.cpp
index e4a0a16c8..00a39b9ff 100644
--- a/src/binary_translation/Instructions/Instruction.cpp
+++ b/src/binary_translation/Instructions/Instruction.cpp
@@ -56,7 +56,7 @@ void Instruction::GenerateCode(InstructionBlock *instruction_block)
     // If not, jump to the next instruction
     if (!ir_builder->GetInsertBlock()->getTerminator())
     {
-        instruction_block->Module()->BranchWritePCConst(Address() + 4);
+        instruction_block->Module()->BranchWritePCConst(instruction_block, Address() + 4);
     }
 }
 
diff --git a/src/binary_translation/ModuleGen.cpp b/src/binary_translation/ModuleGen.cpp
index b3a0db027..f2bbd0a1a 100644
--- a/src/binary_translation/ModuleGen.cpp
+++ b/src/binary_translation/ModuleGen.cpp
@@ -10,6 +10,7 @@
 #include
 #include "MachineState.h"
 #include "TBAA.h"
+#include "BlockColors.h"
 
 using namespace llvm;
 
@@ -19,6 +20,7 @@ ModuleGen::ModuleGen(llvm::Module* module)
     ir_builder = make_unique<IRBuilder<>>(getGlobalContext());
     machine = make_unique<MachineState>(this);
     tbaa = make_unique<TBAA>();
+    block_colors = make_unique<BlockColors>(this);
 }
 
 ModuleGen::~ModuleGen()
@@ -38,23 +40,26 @@ void ModuleGen::Run()
     GenerateGetBlockAddressFunction();
 
     GenerateInstructionsCode();
-    AddInstructionsToRunFunction();
+    ColorBlocks();
 
     GenerateBlockAddressArray();
 }
 
 void ModuleGen::BranchReadPC()
 {
-    ir_builder->CreateBr(run_function_re_entry);
+    auto call = ir_builder->CreateCall(run_function);
+    call->setTailCall();
+    ir_builder->CreateRetVoid();
 }
 
-void ModuleGen::BranchWritePCConst(u32 pc)
+void ModuleGen::BranchWritePCConst(InstructionBlock *current, u32 pc)
 {
     auto i = instruction_blocks_by_pc.find(pc);
     if (i != instruction_blocks_by_pc.end())
     {
         // Found instruction, jump to it
         ir_builder->CreateBr(i->second->GetEntryBasicBlock());
+        InstructionBlock::Link(current, i->second); // current branches to the target, so current is prev
     }
     else
     {
@@ -68,7 +73,11 @@ void ModuleGen::GenerateGlobals()
 {
     machine->GenerateGlobals();
 
-    auto get_block_address_function_type = FunctionType::get(ir_builder->getInt8PtrTy(), ir_builder->getInt32Ty(), false);
+    auto function_pointer = PointerType::get(block_colors->GetFunctionType(), 0);
+    block_address_type = StructType::get(function_pointer, ir_builder->getInt32Ty(), nullptr);
+    block_address_not_present = ConstantStruct::get(block_address_type, ConstantPointerNull::get(function_pointer), ir_builder->getInt32(0), nullptr);
+
+    auto get_block_address_function_type = FunctionType::get(block_address_type, ir_builder->getInt32Ty(), false);
     get_block_address_function = Function::Create(get_block_address_function_type, GlobalValue::PrivateLinkage, "GetBlockAddress", module);
 
     auto can_run_function_type = FunctionType::get(ir_builder->getInt1Ty(), false);
@@ -80,7 +89,7 @@ void ModuleGen::GenerateGlobals()
     block_address_array_base = Loader::ROMCodeStart / 4;
     block_address_array_size = Loader::ROMCodeSize / 4;
 
-    block_address_array_type = ArrayType::get(ir_builder->getInt8PtrTy(), block_address_array_size);
+    block_address_array_type = ArrayType::get(block_address_type, block_address_array_size);
     block_address_array = new GlobalVariable(*module, block_address_array_type, true, GlobalValue::ExternalLinkage, nullptr, "BlockAddressArray");
 }
 
@@ -91,15 +100,27 @@ void ModuleGen::GenerateBlockAddressArray()
     std::fill(
         local_block_address_array_values.get(),
         local_block_address_array_values.get() + block_address_array_size,
-        ConstantPointerNull::get(ir_builder->getInt8PtrTy()));
+        block_address_not_present);
 
-    for (auto i = 0; i < instruction_blocks.size(); ++i)
+    /*for (auto i = 0; i < instruction_blocks.size(); ++i)
     {
         auto &block = instruction_blocks[i];
         auto entry_basic_block = block->GetEntryBasicBlock();
         auto index = block->Address() / 4 - block_address_array_base;
-        local_block_address_array_values[index] = BlockAddress::get(entry_basic_block->getParent(), entry_basic_block);
-    }
+        auto color_index = 0;
+        local_block_address_array_values[index] = BConst
+    }*/
+    for (size_t color = 0; color < block_colors->GetColorCount(); ++color)
+    {
+        auto function = block_colors->GetColorFunction(color);
+        for (size_t i = 0; i < block_colors->GetColorInstructionCount(color); ++i)
+        {
+            auto block = block_colors->GetColorInstruction(color, i);
+            auto index = block->Address() / 4 - block_address_array_base;
+            auto value = ConstantStruct::get(block_address_type, function, ir_builder->getInt32(static_cast<u32>(i)), nullptr);
+            local_block_address_array_values[index] = value;
+        }
+    }
 
     auto local_block_address_array_values_ref = ArrayRef<Constant*>(local_block_address_array_values.get(), block_address_array_size);
     auto local_blocks_address_array = ConstantArray::get(block_address_array_type, local_block_address_array_values_ref);
@@ -140,29 +161,30 @@ void ModuleGen::GenerateGetBlockAddressFunction()
     ir_builder->CreateRet(block_address);
 
     ir_builder->SetInsertPoint(index_out_of_bounds_basic_block);
-    ir_builder->CreateRet(ConstantPointerNull::get(ir_builder->getInt8PtrTy()));
+    ir_builder->CreateRet(block_address_not_present);
 }
 
 void ModuleGen::GenerateCanRunFunction()
 {
-    // return GetBlockAddress(Read(PC)) != nullptr;
+    // return GetBlockAddress(Read(PC)).function != nullptr;
     auto basic_block = BasicBlock::Create(getGlobalContext(), "Entry", can_run_function);
 
     ir_builder->SetInsertPoint(basic_block);
     auto block_address = ir_builder->CreateCall(get_block_address_function, machine->ReadRegiser(Register::PC));
-    ir_builder->CreateRet(ir_builder->CreateICmpNE(block_address, ConstantPointerNull::get(ir_builder->getInt8PtrTy())));
+    auto function = ir_builder->CreateExtractValue(block_address, 0);
+    ir_builder->CreateRet(ir_builder->CreateICmpNE(function,
+        ConstantPointerNull::get(cast<PointerType>(function->getType()))));
 }
 
 void ModuleGen::GenerateRunFunction()
 {
     /*
     run_function_entry:
-    run_function_re_entry:
-    auto block_address = GetBlockAddress(Read(PC))
-    if(index != nullptr)
+    auto block = GetBlockAddress(Read(PC))
+    if(block.function != nullptr)
     {
         block_present_basic_block:
-        goto block_address;
+        block.function(block.index);
         return;
     }
     else
@@ -172,25 +194,21 @@ void ModuleGen::GenerateRunFunction()
     }
     */
     run_function_entry = BasicBlock::Create(getGlobalContext(), "Entry", run_function);
-    // run_function_re_entry is needed because it isn't possible to jump to the first block of a function
-    run_function_re_entry = BasicBlock::Create(getGlobalContext(), "ReEntry", run_function);
     auto block_present_basic_block = BasicBlock::Create(getGlobalContext(), "BlockPresent", run_function);
     auto block_not_present_basic_block = BasicBlock::Create(getGlobalContext(), "BlockNotPresent", run_function);
 
-    ir_builder->SetInsertPoint(run_function_entry);
-    ir_builder->CreateBr(run_function_re_entry);
-
-    ir_builder->SetInsertPoint(run_function_re_entry);
-    auto block_address = ir_builder->CreateCall(get_block_address_function, Machine()->ReadRegiser(Register::PC));
-    auto block_present_pred = ir_builder->CreateICmpNE(block_address, ConstantPointerNull::get(ir_builder->getInt8PtrTy()));
+    ir_builder->SetInsertPoint(run_function_entry);
+    auto block_address = ir_builder->CreateCall(get_block_address_function, Machine()->ReadRegiser(Register::PC));
+    auto function = ir_builder->CreateExtractValue(block_address, 0);
+    auto block_present_pred = ir_builder->CreateICmpNE(function,
+        ConstantPointerNull::get(cast<PointerType>(function->getType())));
     ir_builder->CreateCondBr(block_present_pred, block_present_basic_block, block_not_present_basic_block);
 
     ir_builder->SetInsertPoint(block_present_basic_block);
-    auto indirect_br = ir_builder->CreateIndirectBr(block_address, instruction_blocks.size());
-    for (auto &block : instruction_blocks)
-    {
-        indirect_br->addDestination(block->GetEntryBasicBlock());
-    }
+    auto index = ir_builder->CreateExtractValue(block_address, 1);
+    auto call = ir_builder->CreateCall(function, index);
+    call->setTailCall();
+    ir_builder->CreateRetVoid();
 
     ir_builder->SetInsertPoint(block_not_present_basic_block);
     ir_builder->CreateRetVoid();
@@ -230,27 +248,11 @@ void ModuleGen::GenerateInstructionsCode()
     }
 }
 
-void ModuleGen::AddInstructionsToRunFunction()
+void ModuleGen::ColorBlocks()
 {
-    std::stack<BasicBlock*> basic_blocks_stack;
-
-    for (auto &block : instruction_blocks)
-    {
-        basic_blocks_stack.push(block->GetEntryBasicBlock());
-
-        while (basic_blocks_stack.size())
-        {
-            auto basic_block = basic_blocks_stack.top();
-            basic_blocks_stack.pop();
-            if (basic_block->getParent()) continue; // Already added to run
-            basic_block->insertInto(run_function);
-            auto terminator = basic_block->getTerminator();
-            for (auto i = 0; i < terminator->getNumSuccessors(); ++i)
-            {
-                auto new_basic_block = terminator->getSuccessor(i);
-                if (new_basic_block->getParent()) continue; // Already added to run
-                basic_blocks_stack.push(new_basic_block);
-            }
-        }
-    }
+    for (auto &instruction : instruction_blocks)
+    {
+        block_colors->AddBlock(instruction.get());
+    }
+    block_colors->GenerateFunctions();
 }
\ No newline at end of file
diff --git a/src/binary_translation/ModuleGen.h b/src/binary_translation/ModuleGen.h
index befa17cbf..486e2d763 100644
--- a/src/binary_translation/ModuleGen.h
+++ b/src/binary_translation/ModuleGen.h
@@ -8,6 +8,7 @@ enum class Register;
 class InstructionBlock;
 class MachineState;
 class TBAA;
+class BlockColors;
 
 namespace llvm
 {
@@ -25,7 +26,7 @@ public:
     // Generate code to read pc and run all following instructions, used in cases of indirect branch
     void BranchReadPC();
     // Generate code to write to pc and run all following instructions, used in cases of direct branch
-    void BranchWritePCConst(u32 pc);
+    void BranchWritePCConst(InstructionBlock *current, u32 pc);
 
     llvm::IRBuilder<> *IrBuilder() { return ir_builder.get(); }
     llvm::Module *Module() { return module; }
@@ -44,9 +45,10 @@ private:
     // Generates the entry basic blocks for each instruction
     void GenerateInstructionsEntry();
     // Generates the code of each instruction
-    void GenerateInstructionsCode();
-    // Adds all the basic blocks of an instruction to the run function
-    void AddInstructionsToRunFunction();
+    void GenerateInstructionsCode();
+    // Must be run after the instruction code is generated since it depends on the
+    // inter block jumps
+    void ColorBlocks();
 
     std::unique_ptr<MachineState> machine;
     std::unique_ptr<TBAA> tbaa;
@@ -55,7 +57,16 @@ private:
     llvm::Module *module;
 
     size_t block_address_array_base;
-    size_t block_address_array_size;
+    size_t block_address_array_size;
+    /*
+     * struct BlockAddress
+     * {
+     *     void (*function)(u32 index);
+     *     u32 index;
+     * }
+     */
+    llvm::StructType *block_address_type;
+    llvm::Constant *block_address_not_present;
     /*
      * i8 **BlockAddressArray;
      * The array at [i/4 - block_address_array_base] contains the block address for the instruction at i
@@ -79,11 +90,12 @@ private:
      */
     llvm::Function *run_function;
     llvm::BasicBlock *run_function_entry;
-    llvm::BasicBlock *run_function_re_entry;
 
     /*
      * All the instruction blocks
      */
     std::vector<std::unique_ptr<InstructionBlock>> instruction_blocks;
     std::unordered_map<u32, InstructionBlock *> instruction_blocks_by_pc;
+
+    std::unique_ptr<BlockColors> block_colors;
 };
\ No newline at end of file