LuaDecompy/lparser.py

412 lines
16 KiB
Python

'''
lparser.py
Depends on lundump.py for lua dump deserialization.
An experimental bytecode decompiler.
'''
from operator import concat
from subprocess import call
from lundump import Chunk, Constant, Instruction, Opcodes, whichRK, readRKasK
class _Scope:
    '''
    PC range of one block that is currently open while decompiling.

    LuaDecomp pushes one of these when it emits a block opener and pops it
    again once the program counter has moved past `endPC`.
    '''

    def __init__(self, startPC: int, endPC: int):
        # both bounds are stored exactly as given by the caller
        self.startPC, self.endPC = startPC, endPC
class _Traceback:
    '''
    Register activity recorded for a single PC.

    `sets` collects the registers written and `uses` the registers read at
    that PC; `isConst` starts out as False.
    '''

    def __init__(self):
        # two independent lists, one per kind of register access
        self.sets, self.uses = [], []
        self.isConst = False
class _Line:
    '''
    One line of generated pseudo-code.

    startPC / endPC: the instruction range the line was generated from.
    src: the text of the line, without indentation.
    scope: nesting depth, used as the indentation level when printing.
    '''

    def __init__(self, startPC: int, endPC: int, src: str, scope: int):
        self.startPC, self.endPC = startPC, endPC
        self.src, self.scope = src, scope
# words reserved by Lua 5.1; none of them may be used as an identifier
_LUA_KEYWORDS = frozenset((
    "and", "break", "do", "else", "elseif", "end", "false", "for",
    "function", "if", "in", "local", "nil", "not", "or", "repeat",
    "return", "then", "true", "until", "while",
))
# characters allowed in the first position / in every later position
_IDENT_START = frozenset("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ_")
_IDENT_REST = _IDENT_START | frozenset("1234567890")


def isValidLocal(ident: str) -> bool:
    '''
    Tell whether `ident` can be written as the name of a Lua local.

    A valid name is non-empty, starts with an ASCII letter or `_`, continues
    with ASCII letters, digits or `_`, and is not a Lua reserved word.
    Returns True for a usable name and False otherwise; never raises for an
    empty string.
    '''
    # an empty name has no first character to inspect, and a reserved word
    # would produce pseudo-code that does not parse
    if not ident or ident in _LUA_KEYWORDS:
        return False

    # has to start with an alpha or _
    if ident[0] not in _IDENT_START:
        return False

    # then it can be alphanum or _
    return all(c in _IDENT_REST for c in ident[1:])
class LuaDecomp:
    '''
    Experimental decompiler that turns one deserialized Lua chunk into pseudo-code.

    Constructing an instance walks every instruction of the chunk, collects
    the generated statements in `self.lines` and prints them to stdout. The
    same text can be fetched again afterwards with `getPseudoCode()`.
    '''

    # number of array items a single SETLIST flushes (LFIELDS_PER_FLUSH)
    _FIELDS_PER_FLUSH = 50

    def __init__(self, chunk: "Chunk", *, aggressiveLocals: bool = False,
                 annotateLines: bool = False, indexWidth: int = 4):
        '''
        Decompile `chunk` and print the resulting pseudo-code.

        chunk: the chunk to decompile; its `instructions`, `locals`, `name`
            and `getConstant()` are read.
        aggressiveLocals: should *EVERY* set register be considered a local?
        annotateLines: print the PC range of every line above it.
        indexWidth: how many spaces to use per indentation level.

        Raises Exception when an instruction is unsupported or a jump leaves
        the chunk.
        '''
        self.chunk = chunk
        self.pc = 0
        self.scope: list[_Scope] = []
        self.lines: list[_Line] = []
        self.top = {}
        self.locals = {}
        self.traceback = {}
        self.unknownLocalCount = 0
        self.src: str = ""

        # configurations!
        self.aggressiveLocals = aggressiveLocals
        self.annotateLines = annotateLines
        self.indexWidth = indexWidth

        self.__loadLocals()

        # parse instructions
        while self.pc < len(self.chunk.instructions):
            self.parseInstr()
            self.pc += 1

            # end the scopes (if we're supposed to)
            self.__checkScope()

        print("\n==== [[" + str(self.chunk.name) + "'s pseudo-code]] ====\n")
        if self.lines:
            print(self.getPseudoCode())

    def getPseudoCode(self) -> str:
        '''
        Render the collected lines as one string.

        Every line is indented by `indexWidth` spaces per scope level; when
        `annotateLines` is set, a `-- PC: ...` comment precedes each line.
        '''
        rows = []
        for line in self.lines:
            if self.annotateLines:
                rows.append("-- PC: %d to PC: %d" % (line.startPC, line.endPC))
            rows.append(((' ' * self.indexWidth) * line.scope) + line.src)
        return "\n".join(rows)

    # =======================================[[ Helpers ]]=========================================

    def __getInstrAtPC(self, pc: int) -> "Instruction":
        '''Return the instruction at `pc`; raise Exception when `pc` is outside the chunk.'''
        # a negative pc would silently index from the end of the list, so reject it as well
        if 0 <= pc < len(self.chunk.instructions):
            return self.chunk.instructions[pc]

        raise Exception("Decompilation failed!")

    def __getNextInstr(self) -> "Instruction":
        '''Return the instruction right after the current PC.'''
        return self.__getInstrAtPC(self.pc + 1)

    def __getCurrInstr(self) -> "Instruction":
        '''Return the instruction at the current PC.'''
        return self.__getInstrAtPC(self.pc)

    def __makeTracIfNotExist(self) -> None:
        '''Create the traceback entry for the current PC if there is none yet.'''
        if self.pc not in self.traceback:
            self.traceback[self.pc] = _Traceback()

    # when we read from a register, call this
    def __addUseTraceback(self, reg: int) -> None:
        '''Record that `reg` was read at the current PC.'''
        self.__makeTracIfNotExist()
        self.traceback[self.pc].uses.append(reg)

    # when we write to a register, call this
    def __addSetTraceback(self, reg: int) -> None:
        '''Record that `reg` was written at the current PC.'''
        self.__makeTracIfNotExist()
        self.traceback[self.pc].sets.append(reg)

    def __addExpr(self, code: str) -> None:
        '''Append `code` to the statement that is currently being built.'''
        self.src += code

    def __endStatement(self):
        '''Flush the pending statement into `self.lines` (empty statements are dropped).'''
        # the new line starts right after the previous one ended
        startPC = self.lines[-1].endPC + 1 if self.lines else 0
        endPC = self.pc

        # make sure we don't write an empty line
        if self.src != "":
            self.lines.append(_Line(startPC, endPC, self.src, len(self.scope)))
        self.src = ""

    def __insertStatement(self, pc: int) -> int:
        '''
        Insert the pending statement in front of the line that covers `pc`.

        Returns the index of the inserted line, or -1 when no existing line
        covers `pc`. The pending statement is cleared in both cases.
        '''
        for i, line in enumerate(self.lines):
            if line.startPC <= pc <= line.endPC:
                # the new line takes over the scope of the line before it
                scope = self.lines[i - 1].scope if i > 0 else 0
                self.lines.insert(i, _Line(pc, pc, self.src, scope))
                self.src = ""
                return i

        self.src = ""
        return -1

    # walks traceback, if local wasn't set before, the local needs to be defined
    def __needsDefined(self, reg) -> bool:
        '''Return True when `reg` has not been written at any PC recorded so far.'''
        return not any(reg in trace.sets for trace in self.traceback.values())

    def __loadLocals(self):
        '''Seed `self.locals` from the local names stored in the chunk.'''
        for i, local in enumerate(self.chunk.locals):
            name = local.name
            # the emptiness check keeps a blank debug name away from isValidLocal
            if name and isValidLocal(name):
                self.locals[i] = name
            elif "(for " not in name:  # if it's a for loop register, ignore
                self.__makeLocalIdentifier(i)

    # when you *know* the register *has* to be a local (for loops, etc.)
    def __getLocal(self, indx: int) -> str:
        '''Return the local name of register `indx`, generating one if needed.'''
        return self.locals[indx] if indx in self.locals else self.__makeLocalIdentifier(indx)

    def __getReg(self, indx: int) -> str:
        '''
        Return the expression currently held by register `indx`.

        The read is recorded in the traceback. Raises KeyError when the
        register is neither a local nor has been set.
        '''
        self.__addUseTraceback(indx)

        # if the top indx is a local, get it
        return self.locals[indx] if indx in self.locals else self.top[indx]

    def __setReg(self, indx: int, code: str, forceLocal: bool = False) -> None:
        '''
        Store `code` in register `indx`.

        A statement is emitted when the register is a local (definition or
        assignment), or when `forceLocal` / `aggressiveLocals` turn it into one.
        '''
        # if the top indx is a local, set it
        if indx in self.locals:
            if self.__needsDefined(indx):
                self.__newLocal(indx, code)
            else:
                self.__addExpr(self.locals[indx] + " = " + code)
                self.__endStatement()
        elif self.aggressiveLocals or forceLocal:  # 'every register is a local!!'
            self.__newLocal(indx, code)

        self.__addSetTraceback(indx)
        self.top[indx] = code

    # ========================================[[ Locals ]]=========================================

    def __makeLocalIdentifier(self, indx: int) -> str:
        '''Return the name of local `indx`, inventing an `__unknLocalN` name when it has none.'''
        # first, check if we have a local name already determined
        if indx in self.locals:
            return self.locals[indx]

        # otherwise, generate a local
        self.locals[indx] = "__unknLocal%d" % self.unknownLocalCount
        self.unknownLocalCount += 1
        return self.locals[indx]

    def __newLocal(self, indx: int, expr: str) -> None:
        '''Emit `local <name> = <expr>` for register `indx`.'''
        # TODO: grab identifier from chunk(?)
        self.__makeLocalIdentifier(indx)

        self.__addExpr("local " + self.locals[indx] + " = " + expr)
        self.__endStatement()

    # ========================================[[ Scopes ]]=========================================

    def __startScope(self, scopeType: str, start: int, size: int) -> None:
        '''Finish the pending statement with `scopeType` and open a scope of `size` instructions at `start`.'''
        self.__addExpr(scopeType)
        self.__endStatement()
        self.scope.append(_Scope(start, start + size))

    # checks if we need to end a scope
    def __checkScope(self) -> None:
        '''Close every open scope whose end lies before the current PC.'''
        # nested blocks may finish on the same instruction, so keep closing
        # until the innermost remaining scope is still active
        while self.scope and self.pc > self.scope[-1].endPC:
            self.__endScope()

    def __endScope(self) -> None:
        '''Emit `end` and pop the innermost scope.'''
        self.__endStatement()
        self.__addExpr("end")
        self.scope.pop()
        self.__endStatement()

    # =====================================[[ Instructions ]]======================================

    def __emitOperand(self, a: int, b: str, c: str, op: str) -> None:
        '''Store the parenthesised binary expression `b op c` in register `a`.'''
        self.__setReg(a, "(" + b + op + c + ")")

    def __emitUnary(self, a: int, op: str, operand: str) -> None:
        '''Store the parenthesised unary expression `op operand` in register `a`.'''
        # the parentheses keep a nested negation from reading as `--` (a Lua
        # comment) and keep the operand grouping when the result feeds `^`
        self.__setReg(a, "(" + op + operand + ")")

    def __compJmp(self, op: str):
        '''
        Decompile a comparison instruction together with the jump that follows it.

        Emits an `if`, a `while` or (for a backwards jump) a `repeat ... until`
        construct using the comparison operator text `op`.
        '''
        instr = self.__getCurrInstr()
        jmpType = "if"
        scopeStart = "then"

        # we need to check if the jmp location has a jump back (if so, it's a while loop)
        jmp = self.__getNextInstr().B + 1
        jmpToInstr = self.__getInstrAtPC(self.pc + jmp)

        if jmpToInstr.opcode == Opcodes.JMP:
            # if this jump jumps back to this compJmp, it's a loop!
            if self.pc + jmp + jmpToInstr.B <= self.pc + 1:
                jmpType = "while"
                scopeStart = "do"
        elif jmp < 0:
            # 'repeat until' loop (probably)
            jmpType = "until"
            scopeStart = None

        if instr.A > 0:
            self.__addExpr("%s not " % jmpType)
        else:
            self.__addExpr("%s " % jmpType)
        self.__addExpr(self.__readRK(instr.B) + op + self.__readRK(instr.C) + " ")
        self.pc += 1  # skip next instr

        if scopeStart:
            self.__startScope("%s " % scopeStart, self.pc - 1, jmp)

            # we end the statement *after* scopeStart
            self.__endStatement()
        else:
            # end the statement prior to repeat
            self.__endStatement()

            # it's a repeat until loop, insert 'repeat' at the jumpTo location
            self.__addExpr("repeat")
            insertedLine = self.__insertStatement(self.pc + jmp)
            if insertedLine < 0:
                raise Exception("Decompilation failed! no statement found for the start of a repeat loop")

            # add scope to every line in-between (the trailing 'until' line stays put)
            for i in range(insertedLine + 1, len(self.lines) - 1):
                self.lines[i].scope += 1

    # 'RK's are special because they can be a register or a konstant. a bitflag is read to determine which
    def __readRK(self, rk: int) -> str:
        '''Return the code for an RK operand: a constant when the flag is set, else a register.'''
        if whichRK(rk) > 0:
            return self.chunk.getConstant(readRKasK(rk)).toCode()
        return self.__getReg(rk)

    # walk & peek ahead NEWTABLE
    def __parseNewTable(self, indx: int):
        '''
        Decompile NEWTABLE into a table constructor stored in register `indx`.

        The LOADK / SETLIST instructions that directly follow are consumed and
        folded into the constructor.
        '''
        # TODO: parse SETTABLE too?
        tblOps = [Opcodes.LOADK, Opcodes.SETLIST]

        instr = self.__getNextInstr()
        cachedRegs = {}
        tbl = "{"
        while instr.opcode in tblOps:
            if instr.opcode == Opcodes.LOADK:  # operate on registers
                cachedRegs[instr.A] = self.chunk.getConstant(instr.B).toCode()
            elif instr.opcode == Opcodes.SETLIST:
                numElems = instr.B

                # move the loaded values out of the cache and into the constructor
                for i in range(numElems):
                    tbl += "%s, " % cachedRegs.pop(instr.A + i + 1)

            self.pc += 1
            instr = self.__getNextInstr()
        tbl += "}"

        # i use forceLocal here even though i don't know *for sure* that the register is a local.
        # this does help later though if the table is reused (which is 99% of the time). the other 1%
        # only affects syntax and may look a little weird but is fine and equivalent nonetheless
        self.__setReg(indx, tbl, forceLocal=True)
        self.__endStatement()

        # if we have leftovers... oops, set those
        for reg, code in cachedRegs.items():
            self.__setReg(reg, code)

    def __parseConcat(self, instr) -> None:
        '''Decompile CONCAT: join registers B..C with `..` and store the result in register A.'''
        count = instr.C - instr.B + 1

        # concat all items on stack from RB to RC
        concatStr = " .. ".join(self.__getReg(instr.B + i) for i in range(count))
        self.__setReg(instr.A, concatStr)

    def __parseCall(self, instr) -> None:
        '''Decompile CALL: emit the call and, when it has results, bind them to locals.'''
        # the callee is read first, then the arguments in register order
        callee = self.__getReg(instr.A)
        args = [self.__getReg(i) for i in range(instr.A + 1, instr.A + instr.B)]
        callStr = callee + "(" + ", ".join(args) + ")"

        # parse return values
        preStr = ""
        if instr.C > 1:
            idents = []
            for indx in range(instr.A, instr.A + instr.C - 1):
                ident = self.__getLocal(indx)

                # normally setReg() does this
                self.top[indx] = ident
                idents.append(ident)
            preStr = "local " + ", ".join(idents) + " = "

        self.__addExpr(preStr + callStr)
        self.__endStatement()

    def __parseSetList(self, instr) -> None:
        '''Decompile a stand-alone SETLIST into one indexed assignment per element.'''
        # LFIELDS_PER_FLUSH (50) is the number of elements that *should* have been set in the list in the *last* SETLIST
        # eg.
        # [ 49] LOADK : R[49] K[1] ; load 0.0 into R[49]
        # [ 50] LOADK : R[50] K[1] ; load 0.0 into R[50]
        # [ 51] SETLIST : 0 50 1 ; sets list[1..50]
        # [ 52] LOADK : R[1] K[1] ; load 0.0 into R[1]
        # [ 53] SETLIST : 0 1 2 ; sets list[51..51]
        numElems = instr.B
        startAt = (instr.C - 1) * self._FIELDS_PER_FLUSH
        ident = self.__getLocal(instr.A)

        # set each index (TODO: make tables less verbose)
        for i in range(numElems):
            self.__addExpr("%s[%d] = %s" % (ident, startAt + i + 1, self.__getReg(instr.A + i + 1)))
            self.__endStatement()

    def parseInstr(self):
        '''
        Decompile the instruction at the current PC.

        Some instructions consume the ones that follow them and advance
        `self.pc` themselves. Raises Exception for an unsupported opcode.
        '''
        instr = self.__getCurrInstr()
        opcode = instr.opcode

        # python, add switch statements *please*
        if opcode == Opcodes.MOVE:  # move is a fake ABC instr, C is ignored
            # move registers
            self.__setReg(instr.A, self.__getReg(instr.B))
        elif opcode == Opcodes.LOADK:
            self.__setReg(instr.A, self.chunk.getConstant(instr.B).toCode())
        elif opcode == Opcodes.LOADBOOL:
            self.__setReg(instr.A, "false" if instr.B == 0 else "true")
        elif opcode == Opcodes.GETGLOBAL:
            self.__setReg(instr.A, self.chunk.getConstant(instr.B).data)
        elif opcode == Opcodes.GETTABLE:
            self.__setReg(instr.A, self.__getReg(instr.B) + "[" + self.__readRK(instr.C) + "]")
        elif opcode == Opcodes.SETGLOBAL:
            self.__addExpr(self.chunk.getConstant(instr.B).data + " = " + self.__getReg(instr.A))
            self.__endStatement()
        elif opcode == Opcodes.SETTABLE:
            self.__addExpr(self.__getReg(instr.A) + "[" + self.__readRK(instr.B) + "] = " + self.__readRK(instr.C))
            self.__endStatement()
        elif opcode == Opcodes.NEWTABLE:
            self.__parseNewTable(instr.A)
        elif opcode == Opcodes.ADD:
            self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " + ")
        elif opcode == Opcodes.SUB:
            self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " - ")
        elif opcode == Opcodes.MUL:
            self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " * ")
        elif opcode == Opcodes.DIV:
            self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " / ")
        elif opcode == Opcodes.MOD:
            self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " % ")
        elif opcode == Opcodes.POW:
            self.__emitOperand(instr.A, self.__readRK(instr.B), self.__readRK(instr.C), " ^ ")
        elif opcode == Opcodes.UNM:
            self.__emitUnary(instr.A, "-", self.__getReg(instr.B))
        elif opcode == Opcodes.NOT:
            self.__emitUnary(instr.A, "not ", self.__getReg(instr.B))
        elif opcode == Opcodes.LEN:
            self.__emitUnary(instr.A, "#", self.__getReg(instr.B))
        elif opcode == Opcodes.CONCAT:
            self.__parseConcat(instr)
        elif opcode == Opcodes.JMP:
            pass
        elif opcode == Opcodes.EQ:
            self.__compJmp(" == ")
        elif opcode == Opcodes.LT:
            self.__compJmp(" < ")
        elif opcode == Opcodes.LE:
            self.__compJmp(" <= ")
        elif opcode == Opcodes.CALL:
            self.__parseCall(instr)
        elif opcode == Opcodes.RETURN:
            self.__endStatement()  # no-op for now
        elif opcode == Opcodes.FORLOOP:
            pass  # no-op for now
        elif opcode == Opcodes.FORPREP:
            self.__addExpr("for %s = %s, %s, %s " % (
                self.__getLocal(instr.A + 3),
                self.__getReg(instr.A),
                self.__getReg(instr.A + 1),
                self.__getReg(instr.A + 2),
            ))
            self.__startScope("do", self.pc, instr.B)
        elif opcode == Opcodes.SETLIST:
            self.__parseSetList(instr)
        else:
            raise Exception("unsupported instruction: %s" % instr.toString())