Compare commits

..

6 Commits

Author SHA1 Message Date
df8e9f7e83 refactoring: switched to match/case
wow! python actually added switch cases! too bad this is just syntax sugar...
2023-12-09 12:01:04 -06:00
a22aa808e0 lp: added support for OP_TEST 2022-08-26 01:18:24 -05:00
935844f274 more minor refactoring 2022-08-22 00:59:21 -05:00
c37e9a21d8 ld: minor refactoring 2022-08-22 00:54:59 -05:00
34b1ec7285 ld: added LuaDump()
- chunks can now be serialized back into lua bytecode dumps :D
2022-08-22 00:50:08 -05:00
f9f1d4af00 ld: minor refactoring 2022-08-19 15:46:47 -05:00
3 changed files with 316 additions and 194 deletions

View File

@@ -231,7 +231,8 @@ class LuaDecomp:
def __emitOperand(self, a: int, b: str, c: str, op: str) -> None:
self.__setReg(a, "(" + b + op + c + ")")
def __compJmp(self, op: str):
# handles conditional jumps
def __condJmp(self, op: str, rkBC: bool = True):
instr = self.__getCurrInstr()
jmpType = "if"
scopeStart = "then"
@@ -254,7 +255,13 @@ class LuaDecomp:
self.__addExpr("%s not " % jmpType)
else:
self.__addExpr("%s " % jmpType)
self.__addExpr(self.__readRK(instr.B) + op + self.__readRK(instr.C) + " ")
# write actual comparison
if rkBC:
self.__addExpr(self.__readRK(instr.B) + op + self.__readRK(instr.C) + " ")
else: # just testing rkB
self.__addExpr(op + self.__readRK(instr.B))
self.pc += 1 # skip next instr
if scopeStart:
self.__startScope("%s " % scopeStart, self.pc - 1, jmp)
@@ -315,120 +322,125 @@ class LuaDecomp:
def parseInstr(self):
instr = self.__getCurrInstr()
# python, add switch statements *please*
if instr.opcode == Opcodes.MOVE: # move is a fake ABC instr, C is ignored
# move registers
self.__setReg(instr.A, self.__getReg(instr.B))
elif instr.opcode == Opcodes.LOADK:
self.__setReg(instr.A, self.chunk.getConstant(instr.B).toCode())
elif instr.opcode == Opcodes.LOADBOOL:
if instr.B == 0:
self.__setReg(instr.A, "false")
else:
self.__setReg(instr.A, "true")
elif instr.opcode == Opcodes.GETGLOBAL:
self.__setReg(instr.A, self.chunk.getConstant(instr.B).data)
elif instr.opcode == Opcodes.GETTABLE:
self.__setReg(instr.A, self.__getReg(instr.B) + "[" + self.__readRK(instr.C) + "]")
elif instr.opcode == Opcodes.SETGLOBAL:
self.__addExpr(self.chunk.getConstant(instr.B).data + " = " + self.__getReg(instr.A))
self.__endStatement()
elif instr.opcode == Opcodes.SETTABLE:
self.__addExpr(self.__getReg(instr.A) + "[" + self.__readRK(instr.B) + "] = " + self.__readRK(instr.C))
self.__endStatement()
elif instr.opcode == Opcodes.NEWTABLE:
self.__parseNewTable(instr.A)
elif instr.opcode == Opcodes.ADD:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " + ")
elif instr.opcode == Opcodes.SUB:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " - ")
elif instr.opcode == Opcodes.MUL:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " * ")
elif instr.opcode == Opcodes.DIV:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " / ")
elif instr.opcode == Opcodes.MOD:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " % ")
elif instr.opcode == Opcodes.POW:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " ^ ")
elif instr.opcode == Opcodes.UNM:
self.__setReg(instr.A, "-" + self.__getReg(instr.B))
elif instr.opcode == Opcodes.NOT:
self.__setReg(instr.A, "not " + self.__getReg(instr.B))
elif instr.opcode == Opcodes.LEN:
self.__setReg(instr.A, "#" + self.__getReg(instr.B))
elif instr.opcode == Opcodes.CONCAT:
count = instr.C-instr.B+1
concatStr = ""
# concat all items on stack from RC to RB
for i in range(count):
concatStr += self.__getReg(instr.B + i) + (" .. " if not i == count - 1 else "")
self.__setReg(instr.A, concatStr)
elif instr.opcode == Opcodes.JMP:
pass
elif instr.opcode == Opcodes.EQ:
self.__compJmp(" == ")
elif instr.opcode == Opcodes.LT:
self.__compJmp(" < ")
elif instr.opcode == Opcodes.LE:
self.__compJmp(" <= ")
elif instr.opcode == Opcodes.CALL:
preStr = ""
callStr = ""
ident = ""
# parse arguments
callStr += self.__getReg(instr.A) + "("
for i in range(instr.A + 1, instr.A + instr.B):
callStr += self.__getReg(i) + (", " if not i + 1 == instr.A + instr.B else "")
callStr += ")"
# parse return values
if instr.C > 1:
preStr = "local "
for indx in range(instr.A, instr.A + instr.C - 1):
if indx in self.locals:
ident = self.locals[indx]
else:
ident = self.__makeLocalIdentifier(indx)
preStr += ident
# normally setReg() does this
self.top[indx] = ident
# just so we don't have a trailing ', '
preStr += ", " if not indx == instr.A + instr.C - 2 else ""
preStr += " = "
self.__addExpr(preStr + callStr)
self.__endStatement()
elif instr.opcode == Opcodes.RETURN:
self.__endStatement()
pass # no-op for now
elif instr.opcode == Opcodes.FORLOOP:
pass # no-op for now
elif instr.opcode == Opcodes.FORPREP:
self.__addExpr("for %s = %s, %s, %s " % (self.__getLocal(instr.A+3), self.__getReg(instr.A), self.__getReg(instr.A + 1), self.__getReg(instr.A + 2)))
self.__startScope("do", self.pc, instr.B)
elif instr.opcode == Opcodes.SETLIST:
# LFIELDS_PER_FLUSH (50) is the number of elements that *should* have been set in the list in the *last* SETLIST
# eg.
# [ 49] LOADK : R[49] K[1] ; load 0.0 into R[49]
# [ 50] LOADK : R[50] K[1] ; load 0.0 into R[50]
# [ 51] SETLIST : 0 50 1 ; sets list[1..50]
# [ 52] LOADK : R[1] K[1] ; load 0.0 into R[1]
# [ 53] SETLIST : 0 1 2 ; sets list[51..51]
numElems = instr.B
startAt = ((instr.C - 1) * 50)
ident = self.__getLocal(instr.A)
# set each index (TODO: make tables less verbose)
for i in range(numElems):
self.__addExpr("%s[%d] = %s" % (ident, (startAt + i + 1), self.__getReg(instr.A + i + 1)))
match instr.opcode:
case Opcodes.MOVE: # move is a fake ABC instr, C is ignored
# move registers
self.__setReg(instr.A, self.__getReg(instr.B))
case Opcodes.LOADK:
self.__setReg(instr.A, self.chunk.getConstant(instr.B).toCode())
case Opcodes.LOADBOOL:
if instr.B == 0:
self.__setReg(instr.A, "false")
else:
self.__setReg(instr.A, "true")
case Opcodes.GETGLOBAL:
self.__setReg(instr.A, self.chunk.getConstant(instr.B).data)
case Opcodes.GETTABLE:
self.__setReg(instr.A, self.__getReg(instr.B) + "[" + self.__readRK(instr.C) + "]")
case Opcodes.SETGLOBAL:
self.__addExpr(self.chunk.getConstant(instr.B).data + " = " + self.__getReg(instr.A))
self.__endStatement()
elif instr.opcode == Opcodes.CLOSURE:
proto = LuaDecomp(self.chunk.protos[instr.B], headChunk=False, scopeOffset=len(self.scope))
self.__setReg(instr.A, proto.getPseudoCode())
else:
raise Exception("unsupported instruction: %s" % instr.toString())
case Opcodes.SETTABLE:
self.__addExpr(self.__getReg(instr.A) + "[" + self.__readRK(instr.B) + "] = " + self.__readRK(instr.C))
self.__endStatement()
case Opcodes.NEWTABLE:
self.__parseNewTable(instr.A)
case Opcodes.ADD:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " + ")
case Opcodes.SUB:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " - ")
case Opcodes.MUL:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " * ")
case Opcodes.DIV:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " / ")
case Opcodes.MOD:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " % ")
case Opcodes.POW:
self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " ^ ")
case Opcodes.UNM:
self.__setReg(instr.A, "-" + self.__getReg(instr.B))
case Opcodes.NOT:
self.__setReg(instr.A, "not " + self.__getReg(instr.B))
case Opcodes.LEN:
self.__setReg(instr.A, "#" + self.__getReg(instr.B))
case Opcodes.CONCAT:
count = instr.C-instr.B+1
concatStr = ""
# concat all items on stack from RC to RB
for i in range(count):
concatStr += self.__getReg(instr.B + i) + (" .. " if not i == count - 1 else "")
self.__setReg(instr.A, concatStr)
case Opcodes.JMP:
pass
case Opcodes.EQ:
self.__condJmp(" == ")
case Opcodes.LT:
self.__condJmp(" < ")
case Opcodes.LE:
self.__condJmp(" <= ")
case Opcodes.TEST:
if instr.C == 0:
self.__condJmp("", False)
else:
self.__condJmp("not ", False)
case Opcodes.CALL:
preStr = ""
callStr = ""
ident = ""
# parse arguments
callStr += self.__getReg(instr.A) + "("
for i in range(instr.A + 1, instr.A + instr.B):
callStr += self.__getReg(i) + (", " if not i + 1 == instr.A + instr.B else "")
callStr += ")"
# parse return values
if instr.C > 1:
preStr = "local "
for indx in range(instr.A, instr.A + instr.C - 1):
if indx in self.locals:
ident = self.locals[indx]
else:
ident = self.__makeLocalIdentifier(indx)
preStr += ident
# normally setReg() does this
self.top[indx] = ident
# just so we don't have a trailing ', '
preStr += ", " if not indx == instr.A + instr.C - 2 else ""
preStr += " = "
self.__addExpr(preStr + callStr)
self.__endStatement()
case Opcodes.RETURN:
self.__endStatement()
pass # no-op for now
case Opcodes.FORLOOP:
pass # no-op for now
case Opcodes.FORPREP:
self.__addExpr("for %s = %s, %s, %s " % (self.__getLocal(instr.A+3), self.__getReg(instr.A), self.__getReg(instr.A + 1), self.__getReg(instr.A + 2)))
self.__startScope("do", self.pc, instr.B)
case Opcodes.SETLIST:
# LFIELDS_PER_FLUSH (50) is the number of elements that *should* have been set in the list in the *last* SETLIST
# eg.
# [ 49] LOADK : R[49] K[1] ; load 0.0 into R[49]
# [ 50] LOADK : R[50] K[1] ; load 0.0 into R[50]
# [ 51] SETLIST : 0 50 1 ; sets list[1..50]
# [ 52] LOADK : R[1] K[1] ; load 0.0 into R[1]
# [ 53] SETLIST : 0 1 2 ; sets list[51..51]
numElems = instr.B
startAt = ((instr.C - 1) * 50)
ident = self.__getLocal(instr.A)
# set each index (TODO: make tables less verbose)
for i in range(numElems):
self.__addExpr("%s[%d] = %s" % (ident, (startAt + i + 1), self.__getReg(instr.A + i + 1)))
self.__endStatement()
case Opcodes.CLOSURE:
proto = LuaDecomp(self.chunk.protos[instr.B], headChunk=False, scopeOffset=len(self.scope))
self.__setReg(instr.A, proto.getPseudoCode())
case _:
raise Exception("unsupported instruction: %s" % instr.toString())

View File

@@ -1,7 +1,7 @@
'''
l(un)dump.py
A Lua5.1 cross-platform bytecode deserializer. This module pulls int and size_t sizes from the
A Lua5.1 cross-platform bytecode deserializer && serializer. This module pulls int and size_t sizes from the
chunk header, meaning it should be able to deserialize lua bytecode dumps from most platforms,
regardless of the host machine.
@@ -9,11 +9,9 @@
as well as read the lundump.c source file from the Lua5.1 source.
'''
from multiprocessing.spawn import get_executable
import struct
import array
from enum import IntEnum, Enum, auto
from typing_extensions import Self
class InstructionType(Enum):
ABC = auto(),
@@ -70,6 +68,8 @@ _RKBCInstr = [Opcodes.SETTABLE, Opcodes.ADD, Opcodes.SUB, Opcodes.MUL, Opcodes.D
_RKCInstr = [Opcodes.GETTABLE, Opcodes.SELF]
_KBx = [Opcodes.LOADK, Opcodes.GETGLOBAL, Opcodes.SETGLOBAL]
_LUAMAGIC = b'\x1bLua'
# is an 'RK' value a K? (result is true for K, false for R)
def whichRK(rk: int):
return (rk & (1 << 8)) > 0
@@ -189,6 +189,7 @@ class Chunk:
self.maxStack: int = 0
self.upvalues: list[str] = []
self.lineNums: list[int] = []
self.locals: list[Local] = []
def appendInstruction(self, instr: Instruction):
@@ -200,9 +201,15 @@ class Chunk:
def appendProto(self, proto):
self.protos.append(proto)
def appendLine(self, line: int):
self.lineNums.append(line)
def appendLocal(self, local: Local):
self.locals.append(local)
def appendUpval(self, upval: str):
self.upvalues.append(upval)
def findLocal(self, pc: int) -> Local:
for l in self.locals:
if l.start <= pc and l.end >= pc:
@@ -298,11 +305,7 @@ class LuaUndump:
self.rootChunk: Chunk = None
self.index = 0
@staticmethod
def dis_chunk(chunk: Chunk):
chunk.print()
def loadBlock(self, sz) -> bytearray:
def _loadBlock(self, sz) -> bytearray:
if self.index + sz > len(self.bytecode):
raise Exception("Malformed bytecode!")
@@ -310,82 +313,71 @@ class LuaUndump:
self.index = self.index + sz
return temp
def get_byte(self) -> int:
return self.loadBlock(1)[0]
def _get_byte(self) -> int:
return self._loadBlock(1)[0]
def get_int32(self) -> int:
if (self.big_endian):
return int.from_bytes(self.loadBlock(4), byteorder='big', signed=False)
else:
return int.from_bytes(self.loadBlock(4), byteorder='little', signed=False)
def _get_uint32(self) -> int:
order = 'big' if self.big_endian else 'little'
return int.from_bytes(self._loadBlock(4), byteorder=order, signed=False)
def get_int(self) -> int:
if (self.big_endian):
return int.from_bytes(self.loadBlock(self.int_size), byteorder='big', signed=False)
else:
return int.from_bytes(self.loadBlock(self.int_size), byteorder='little', signed=False)
def _get_uint(self) -> int:
order = 'big' if self.big_endian else 'little'
return int.from_bytes(self._loadBlock(self.int_size), byteorder=order, signed=False)
def get_size_t(self) -> int:
if (self.big_endian):
return int.from_bytes(self.loadBlock(self.size_t), byteorder='big', signed=False)
else:
return int.from_bytes(self.loadBlock(self.size_t), byteorder='little', signed=False)
def _get_size_t(self) -> int:
order = 'big' if self.big_endian else 'little'
return int.from_bytes(self._loadBlock(self.size_t), byteorder=order, signed=False)
def get_double(self) -> int:
if self.big_endian:
return struct.unpack('>d', self.loadBlock(8))[0]
else:
return struct.unpack('<d', self.loadBlock(8))[0]
def _get_double(self) -> int:
order = '>d' if self.big_endian else '<d'
return struct.unpack(order, self._loadBlock(self.l_number_size))[0]
def get_string(self, size) -> str:
if (size == None):
size = self.get_size_t()
if (size == 0):
return ""
def _get_string(self) -> str:
size = self._get_size_t()
if (size == 0):
return ""
return "".join(chr(x) for x in self.loadBlock(size))
# [:-1] to remove the NULL terminator
return ("".join(chr(x) for x in self._loadBlock(size)))[:-1]
def decode_chunk(self) -> Chunk:
chunk = Chunk()
chunk.name = self.get_string(None)
chunk.frst_line = self.get_int()
chunk.last_line = self.get_int()
chunk.numUpvals = self.get_byte()
chunk.numParams = self.get_byte()
chunk.isVarg = (self.get_byte() != 0)
chunk.maxStack = self.get_byte()
if (not chunk.name == ""):
chunk.name = chunk.name[1:-1]
# chunk meta info
chunk.name = self._get_string()
chunk.frst_line = self._get_uint()
chunk.last_line = self._get_uint()
chunk.numUpvals = self._get_byte()
chunk.numParams = self._get_byte()
chunk.isVarg = (self._get_byte() != 0)
chunk.maxStack = self._get_byte()
# parse instructions
num = self.get_int()
num = self._get_uint()
for i in range(num):
chunk.appendInstruction(_decode_instr(self.get_int32()))
chunk.appendInstruction(_decode_instr(self._get_uint32()))
# get constants
num = self.get_int()
num = self._get_uint()
for i in range(num):
constant: Constant = None
type = self.get_byte()
type = self._get_byte()
if type == 0: #nil
if type == 0: # nil
constant = Constant(ConstType.NIL, None)
elif type == 1: # bool
constant = Constant(ConstType.BOOL, (self.get_byte() != 0))
constant = Constant(ConstType.BOOL, (self._get_byte() != 0))
elif type == 3: # number
constant = Constant(ConstType.NUMBER, self.get_double())
constant = Constant(ConstType.NUMBER, self._get_double())
elif type == 4: # string
constant = Constant(ConstType.STRING, self.get_string(None)[:-1])
constant = Constant(ConstType.STRING, self._get_string())
else:
raise Exception("Unknown Datatype! [%d]" % type)
chunk.appendConstant(constant)
# parse protos
num = self.get_int()
num = self._get_uint()
for i in range(num):
chunk.appendProto(self.decode_chunk())
@@ -393,47 +385,47 @@ class LuaUndump:
# eh, for now just consume the bytes.
# line numbers
num = self.get_int()
num = self._get_uint()
for i in range(num):
self.get_int()
self._get_uint()
# locals
num = self.get_int()
num = self._get_uint()
for i in range(num):
name = self.get_string(None)[:-1] # local name ([:-1] to remove the NULL terminator)
start = self.get_int() # local start PC
end = self.get_int() # local end PC
name = self._get_string() # local name
start = self._get_uint() # local start PC
end = self._get_uint() # local end PC
chunk.appendLocal(Local(name, start, end))
# upvalues
num = self.get_int()
num = self._get_uint()
for i in range(num):
self.get_string(None) # upvalue name
chunk.appendUpval(self._get_string()) # upvalue name
return chunk
def decode_rawbytecode(self, rawbytecode):
# bytecode sanity checks
if not rawbytecode[0:4] == b'\x1bLua':
if not rawbytecode[0:4] == _LUAMAGIC:
raise Exception("Lua Bytecode expected!")
bytecode = array.array('b', rawbytecode)
return self.decode_bytecode(bytecode)
def decode_bytecode(self, bytecode):
self.bytecode = bytecode
self.bytecode = bytecode
# aligns index, skips header
self.index = 4
self.vm_version = self.get_byte()
self.bytecode_format = self.get_byte()
self.big_endian = (self.get_byte() == 0)
self.int_size = self.get_byte()
self.size_t = self.get_byte()
self.instr_size = self.get_byte() # gets size of instructions
self.l_number_size = self.get_byte() # size of lua_Number
self.integral_flag = self.get_byte()
self.vm_version = self._get_byte()
self.bytecode_format = self._get_byte()
self.big_endian = (self._get_byte() == 0)
self.int_size = self._get_byte()
self.size_t = self._get_byte()
self.instr_size = self._get_byte() # gets size of instructions
self.l_number_size = self._get_byte() # size of lua_Number
self.integral_flag = self._get_byte() # is lua_Number defined as an int? false = float/double, true = int/long/short/etc.
self.rootChunk = self.decode_chunk()
return self.rootChunk
@@ -444,5 +436,122 @@ class LuaUndump:
return self.decode_rawbytecode(bytecode)
def print_dissassembly(self):
LuaUndump.dis_chunk(self.rootChunk)
self.rootChunk.print()
class LuaDump:
def __init__(self, rootChunk: Chunk):
self.rootChunk = rootChunk
self.bytecode = bytearray()
# header info
self.vm_version = 0x51
self.bytecode_format = 0x00
self.big_endian = False
# data sizes
self.int_size = 4
self.size_t = 8
self.instr_size = 4
self.l_number_size = 8
self.integral_flag = False # lua_Number is a double
def _writeBlock(self, data: bytes):
self.bytecode += bytearray(data)
def _set_byte(self, b: int):
self.bytecode.append(b)
def _set_uint32(self, i: int):
order = 'big' if self.big_endian else 'little'
self._writeBlock(i.to_bytes(4, order, signed=False))
def _set_uint(self, i: int):
order = 'big' if self.big_endian else 'little'
self._writeBlock(i.to_bytes(self.int_size, order, signed=False))
def _set_size_t(self, i: int):
order = 'big' if self.big_endian else 'little'
self._writeBlock(i.to_bytes(self.size_t, order, signed=False))
def _set_double(self, f: float):
order = '>d' if self.big_endian else '<d'
self._writeBlock(struct.pack(order, f))
def _set_string(self, string: str):
self._set_size_t(len(string)+1)
self._writeBlock(string.encode('utf-8'))
self._set_byte(0x00) # write null terminator
def _dumpChunk(self, chunk: Chunk):
# write meta info
self._set_string(chunk.name)
self._set_uint(chunk.frst_line)
self._set_uint(chunk.last_line)
self._set_byte(chunk.numUpvals)
self._set_byte(chunk.numParams)
self._set_byte(1 if chunk.isVarg else 1)
self._set_byte(chunk.maxStack)
# write instructions
self._set_uint(len(chunk.instructions))
for l in chunk.instructions:
self._set_uint32(_encode_instr(l))
# write constants
self._set_uint(len(chunk.constants))
for constant in chunk.constants:
# write constant data
if constant.type == ConstType.NIL:
self._set_byte(0)
elif constant.type == ConstType.BOOL:
self._set_byte(1)
self._set_byte(1 if constant.data else 0)
elif constant.type == ConstType.NUMBER: # number
self._set_byte(3)
self._set_double(constant.data)
elif constant.type == ConstType.STRING: # string
self._set_byte(4)
self._set_string(constant.data)
else:
raise Exception("Unknown Datatype! [%s]" % str(constant.type))
# write child protos
self._set_uint(len(chunk.protos))
for p in chunk.protos:
self._dumpChunk(p)
# write line numbers
self._set_uint(len(chunk.lineNums))
for l in chunk.lineNums:
self._set_uint(l)
# write locals
self._set_uint(len(chunk.locals))
for l in chunk.locals:
self._set_string(l.name)
self._set_uint(l.start)
self._set_uint(l.end)
# write upvals
self._set_uint(len(chunk.upvalues))
for u in chunk.upvalues:
self._set_string(u)
def _dumpHeader(self):
self._writeBlock(_LUAMAGIC)
# write header info
self._set_byte(self.vm_version)
self._set_byte(self.bytecode_format)
self._set_byte(0 if self.big_endian else 1)
self._set_byte(self.int_size)
self._set_byte(self.size_t)
self._set_byte(self.instr_size)
self._set_byte(self.l_number_size)
self._set_byte(self.integral_flag)
def dump(self) -> bytearray:
self._dumpHeader()
self._dumpChunk(self.rootChunk)
return self.bytecode

1
main.py Normal file → Executable file
View File

@@ -1,3 +1,4 @@
#!/usr/bin/env python3
import sys
import lundump
import lparser